Skip to main content

gemini_live/
transport.rs

1//! WebSocket transport layer.
2//!
3//! Handles raw connection establishment, frame I/O, and TLS (via `rustls`).
4//! This is the lowest layer — it knows nothing about JSON or the Gemini
5//! protocol.  The [`session`](crate::session) layer wraps [`Connection`] to
6//! add protocol-level concerns.
7//!
8//! # Endpoints
9//!
10//! | Auth method     | Endpoint                                                                                                                   |
11//! |-----------------|----------------------------------------------------------------------------------------------------------------------------|
12//! | API key         | `wss://generativelanguage.googleapis.com/ws/…v1beta.GenerativeService.BidiGenerateContent?key={KEY}`                        |
13//! | Ephemeral token | `wss://generativelanguage.googleapis.com/ws/…v1alpha.GenerativeService.BidiGenerateContentConstrained?access_token={TOKEN}` |
14//!
15//! Both can be overridden via [`TransportConfig::endpoint_override`] for
16//! testing or Vertex AI endpoints.
17
18use std::time::Duration;
19
20use futures_util::stream::{SplitSink, SplitStream};
21use futures_util::{SinkExt, StreamExt};
22use tokio::net::TcpStream;
23use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
24use tokio_tungstenite::tungstenite::{self, Message};
25use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
26
27use crate::error::{ConnectError, RecvError, SendError};
28
/// Scheme and host shared by both default Gemini Live WebSocket endpoints.
const DEFAULT_HOST: &str = "wss://generativelanguage.googleapis.com";

/// Path of the `v1beta` endpoint, used when authenticating with an API key.
const API_KEY_PATH: &str =
    "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent";

/// Path of the `v1alpha` "constrained" endpoint, used when authenticating
/// with an ephemeral token.
const EPHEMERAL_TOKEN_PATH: &str =
    "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContentConstrained";
34
35// ── Auth ─────────────────────────────────────────────────────────────────────
36
/// Authentication method for the Gemini Live API.
///
/// [`Debug`](std::fmt::Debug) is implemented by hand rather than derived so
/// that the credential itself is never written to logs or panic messages:
/// only the variant name is printed, with the payload shown as `<redacted>`.
#[derive(Clone)]
pub enum Auth {
    /// Standard long-lived API key (sent as `?key=` query param).
    ApiKey(String),
    /// Short-lived token obtained via the ephemeral token endpoint (v1alpha).
    EphemeralToken(String),
}

impl std::fmt::Debug for Auth {
    /// Formats as `ApiKey(<redacted>)` / `EphemeralToken(<redacted>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Auth::ApiKey(_) => "ApiKey",
            Auth::EphemeralToken(_) => "EphemeralToken",
        };
        // `format_args!` keeps the placeholder unquoted in the output.
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
45
46// ── TransportConfig ──────────────────────────────────────────────────────────
47
48/// Transport layer settings.
49///
50/// All fields have sensible defaults (see [`Default`] impl).  In most cases
51/// only [`auth`](Self::auth) needs to be set explicitly.
52#[derive(Debug, Clone)]
53pub struct TransportConfig {
54    pub auth: Auth,
55    /// Override the default endpoint (for testing or Vertex AI).
56    pub endpoint_override: Option<String>,
57    /// WebSocket write buffer size in bytes.  Default: 64 KB.
58    pub write_buffer_size: usize,
59    /// Maximum WebSocket frame size in bytes.  Default: 16 MB.
60    pub max_frame_size: usize,
61    /// Connection timeout.  Default: 10 s.
62    pub connect_timeout: Duration,
63}
64
65impl Default for TransportConfig {
66    fn default() -> Self {
67        Self {
68            auth: Auth::ApiKey(String::new()),
69            endpoint_override: None,
70            write_buffer_size: 64 * 1024,
71            max_frame_size: 16 * 1024 * 1024,
72            connect_timeout: Duration::from_secs(10),
73        }
74    }
75}
76
77// ── RawFrame ─────────────────────────────────────────────────────────────────
78
/// A raw WebSocket frame received from the server.
///
/// Every payload type has total equality, so the enum derives [`Eq`] in
/// addition to [`PartialEq`]; this lets frames be used wherever an `Eq`
/// bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RawFrame {
    /// UTF-8 text frame (JSON on the Gemini Live protocol).
    Text(String),
    /// Binary frame.
    Binary(Vec<u8>),
    /// Close frame with an optional reason string.
    Close(Option<String>),
}
89
90// ── Connection ───────────────────────────────────────────────────────────────
91
92type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
93
94/// Low-level WebSocket connection handle.
95///
96/// Wraps a split `tokio-tungstenite` stream (sink + source) and provides
97/// simple send/recv methods for raw frames.  This type is **not** meant to
98/// be used directly by application code — the session layer manages it.
99pub struct Connection {
100    sink: SplitSink<WsStream, Message>,
101    stream: SplitStream<WsStream>,
102}
103
104impl Connection {
105    /// Establish a WebSocket connection (does **not** send the `setup` message).
106    pub async fn connect(config: &TransportConfig) -> Result<Self, ConnectError> {
107        // Ensure a rustls CryptoProvider is installed (idempotent).
108        let _ = rustls::crypto::ring::default_provider().install_default();
109
110        let url = build_url(config);
111        let mut ws_config = WebSocketConfig::default();
112        ws_config.write_buffer_size = config.write_buffer_size;
113        ws_config.max_write_buffer_size = config.write_buffer_size * 2;
114        ws_config.max_frame_size = Some(config.max_frame_size);
115        ws_config.max_message_size = Some(config.max_frame_size);
116
117        let connect_fut = connect_async_with_config(url, Some(ws_config), false);
118
119        let (ws_stream, _response) = tokio::time::timeout(config.connect_timeout, connect_fut)
120            .await
121            .map_err(|_| ConnectError::Timeout(config.connect_timeout))?
122            .map_err(classify_connect_error)?;
123
124        let (sink, stream) = ws_stream.split();
125        tracing::debug!("WebSocket connection established");
126        Ok(Self { sink, stream })
127    }
128
129    /// Send a text frame (typically a serialised JSON message).
130    pub async fn send_text(&mut self, json: &str) -> Result<(), SendError> {
131        self.sink
132            .send(Message::text(json))
133            .await
134            .map_err(classify_send_error)
135    }
136
137    /// Send a binary frame.
138    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
139        self.sink
140            .send(Message::binary(data.to_vec()))
141            .await
142            .map_err(classify_send_error)
143    }
144
145    /// Receive the next meaningful frame, skipping ping/pong control frames.
146    pub async fn recv(&mut self) -> Result<RawFrame, RecvError> {
147        loop {
148            match self.stream.next().await {
149                Some(Ok(msg)) => {
150                    tracing::trace!(msg_type = ?std::mem::discriminant(&msg), "raw ws frame received");
151                    match msg {
152                        Message::Text(text) => return Ok(RawFrame::Text(text.to_string())),
153                        Message::Binary(data) => return Ok(RawFrame::Binary(data.to_vec())),
154                        Message::Close(frame) => {
155                            let reason = frame.map(|f| f.reason.to_string());
156                            return Ok(RawFrame::Close(reason));
157                        }
158                        // Ping/Pong are handled at the tungstenite protocol level.
159                        Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
160                    }
161                }
162                Some(Err(e)) => return Err(RecvError::Ws(e)),
163                None => return Err(RecvError::Closed),
164            }
165        }
166    }
167
168    /// Send a close frame without consuming the connection.
169    pub(crate) async fn send_close(&mut self) -> Result<(), SendError> {
170        self.sink
171            .send(Message::Close(None))
172            .await
173            .map_err(classify_send_error)
174    }
175
176    /// Gracefully close the connection by sending a close frame.
177    pub async fn close(mut self) -> Result<(), SendError> {
178        self.send_close().await
179    }
180}
181
182// ── helpers ──────────────────────────────────────────────────────────────────
183
184fn build_url(config: &TransportConfig) -> String {
185    if let Some(url) = &config.endpoint_override {
186        return url.clone();
187    }
188    match &config.auth {
189        Auth::ApiKey(key) => format!("{DEFAULT_HOST}{API_KEY_PATH}?key={key}"),
190        Auth::EphemeralToken(token) => {
191            format!("{DEFAULT_HOST}{EPHEMERAL_TOKEN_PATH}?access_token={token}")
192        }
193    }
194}
195
196fn classify_connect_error(e: tungstenite::Error) -> ConnectError {
197    match e {
198        tungstenite::Error::Http(response) => ConnectError::Rejected {
199            status: response.status().as_u16(),
200        },
201        other => ConnectError::Ws(other),
202    }
203}
204
205fn classify_send_error(e: tungstenite::Error) -> SendError {
206    match e {
207        tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
208            SendError::Closed
209        }
210        other => SendError::Ws(other),
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with only the auth method replaced.
    fn config_with(auth: Auth) -> TransportConfig {
        TransportConfig {
            auth,
            ..TransportConfig::default()
        }
    }

    /// API-key auth targets the default host on the v1beta path with `?key=`.
    #[test]
    fn url_api_key() {
        let url = build_url(&config_with(Auth::ApiKey(String::from("test-key-123"))));

        assert!(url.starts_with("wss://generativelanguage.googleapis.com"));
        assert!(url.contains("BidiGenerateContent?key=test-key-123"));
        assert!(!url.contains("v1alpha"));
    }

    /// Ephemeral-token auth targets the v1alpha constrained path with
    /// `?access_token=`.
    #[test]
    fn url_ephemeral_token() {
        let url = build_url(&config_with(Auth::EphemeralToken(String::from("tok-abc"))));

        assert!(url.contains("v1alpha"));
        assert!(url.contains("BidiGenerateContentConstrained?access_token=tok-abc"));
    }

    /// An endpoint override is used verbatim, regardless of the auth method.
    #[test]
    fn url_endpoint_override() {
        let mut config = config_with(Auth::ApiKey(String::from("ignored")));
        config.endpoint_override = Some(String::from("wss://custom.example.com/ws"));

        assert_eq!(build_url(&config), "wss://custom.example.com/ws");
    }

    /// The documented defaults are what `Default` actually produces.
    #[test]
    fn default_config_values() {
        let TransportConfig {
            endpoint_override,
            write_buffer_size,
            max_frame_size,
            connect_timeout,
            ..
        } = TransportConfig::default();

        assert_eq!(write_buffer_size, 64 * 1024);
        assert_eq!(max_frame_size, 16 * 1024 * 1024);
        assert_eq!(connect_timeout, Duration::from_secs(10));
        assert!(endpoint_override.is_none());
    }
}