1use std::time::Duration;
19
20use futures_util::stream::{SplitSink, SplitStream};
21use futures_util::{SinkExt, StreamExt};
22use tokio::net::TcpStream;
23use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
24use tokio_tungstenite::tungstenite::{self, Message};
25use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
26
27use crate::error::{ConnectError, RecvError, SendError};
28
// Scheme and host that the default endpoint paths below are appended to.
const DEFAULT_HOST: &str = "wss://generativelanguage.googleapis.com";
// Endpoint path selected for `Auth::ApiKey` (v1beta `BidiGenerateContent`).
const API_KEY_PATH: &str =
    "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent";
// Endpoint path selected for `Auth::EphemeralToken`
// (v1alpha `BidiGenerateContentConstrained`).
const EPHEMERAL_TOKEN_PATH: &str =
    "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContentConstrained";
34
/// Credential attached to the WebSocket URL when connecting.
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// payload never ends up in logs, panic messages or the `Debug` output of
/// any type that embeds an `Auth`; only the variant name is printed.
#[derive(Clone)]
pub enum Auth {
    /// API key, placed in the `key` query parameter of the default endpoint.
    ApiKey(String),
    /// Ephemeral token, placed in the `access_token` query parameter of the
    /// default endpoint.
    EphemeralToken(String),
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::ApiKey(_) => "ApiKey",
            Self::EphemeralToken(_) => "EphemeralToken",
        };
        // Keep the tuple-variant shape of the derived output, minus the secret.
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
45
/// Settings consumed by `Connection::connect`.
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Credential used to build the default endpoint URL. It is not used
    /// when `endpoint_override` is set.
    pub auth: Auth,
    /// When `Some`, this URL is used verbatim instead of the default endpoint.
    pub endpoint_override: Option<String>,
    /// Write-buffer size handed to the WebSocket configuration; the maximum
    /// write-buffer size is derived from this value as well.
    pub write_buffer_size: usize,
    /// Upper bound applied to both a single frame and a whole message.
    pub max_frame_size: usize,
    /// Maximum time allowed for the connection attempt to complete.
    pub connect_timeout: Duration,
}
64
65impl Default for TransportConfig {
66 fn default() -> Self {
67 Self {
68 auth: Auth::ApiKey(String::new()),
69 endpoint_override: None,
70 write_buffer_size: 64 * 1024,
71 max_frame_size: 16 * 1024 * 1024,
72 connect_timeout: Duration::from_secs(10),
73 }
74 }
75}
76
/// A data or close frame surfaced to callers of `Connection::recv`.
///
/// Every payload type is totally comparable, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RawFrame {
    /// Payload of a text message.
    Text(String),
    /// Payload of a binary message.
    Binary(Vec<u8>),
    /// The peer sent a close frame; carries its reason text when the frame
    /// had a body.
    Close(Option<String>),
}
89
/// Stream type used for the connection: a WebSocket over a TCP socket that
/// may or may not be wrapped in TLS.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// An open WebSocket connection, stored as separate write and read halves.
pub struct Connection {
    /// Write half; outgoing messages are sent through it.
    sink: SplitSink<WsStream, Message>,
    /// Read half; `recv` pulls incoming messages from it.
    stream: SplitStream<WsStream>,
}
103
104impl Connection {
105 pub async fn connect(config: &TransportConfig) -> Result<Self, ConnectError> {
107 let _ = rustls::crypto::ring::default_provider().install_default();
109
110 let url = build_url(config);
111 let mut ws_config = WebSocketConfig::default();
112 ws_config.write_buffer_size = config.write_buffer_size;
113 ws_config.max_write_buffer_size = config.write_buffer_size * 2;
114 ws_config.max_frame_size = Some(config.max_frame_size);
115 ws_config.max_message_size = Some(config.max_frame_size);
116
117 let connect_fut = connect_async_with_config(url, Some(ws_config), false);
118
119 let (ws_stream, _response) = tokio::time::timeout(config.connect_timeout, connect_fut)
120 .await
121 .map_err(|_| ConnectError::Timeout(config.connect_timeout))?
122 .map_err(classify_connect_error)?;
123
124 let (sink, stream) = ws_stream.split();
125 tracing::debug!("WebSocket connection established");
126 Ok(Self { sink, stream })
127 }
128
129 pub async fn send_text(&mut self, json: &str) -> Result<(), SendError> {
131 self.sink
132 .send(Message::text(json))
133 .await
134 .map_err(classify_send_error)
135 }
136
137 pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
139 self.sink
140 .send(Message::binary(data.to_vec()))
141 .await
142 .map_err(classify_send_error)
143 }
144
145 pub async fn recv(&mut self) -> Result<RawFrame, RecvError> {
147 loop {
148 match self.stream.next().await {
149 Some(Ok(msg)) => {
150 tracing::trace!(msg_type = ?std::mem::discriminant(&msg), "raw ws frame received");
151 match msg {
152 Message::Text(text) => return Ok(RawFrame::Text(text.to_string())),
153 Message::Binary(data) => return Ok(RawFrame::Binary(data.to_vec())),
154 Message::Close(frame) => {
155 let reason = frame.map(|f| f.reason.to_string());
156 return Ok(RawFrame::Close(reason));
157 }
158 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
160 }
161 }
162 Some(Err(e)) => return Err(RecvError::Ws(e)),
163 None => return Err(RecvError::Closed),
164 }
165 }
166 }
167
168 pub(crate) async fn send_close(&mut self) -> Result<(), SendError> {
170 self.sink
171 .send(Message::Close(None))
172 .await
173 .map_err(classify_send_error)
174 }
175
176 pub async fn close(mut self) -> Result<(), SendError> {
178 self.send_close().await
179 }
180}
181
182fn build_url(config: &TransportConfig) -> String {
185 if let Some(url) = &config.endpoint_override {
186 return url.clone();
187 }
188 match &config.auth {
189 Auth::ApiKey(key) => format!("{DEFAULT_HOST}{API_KEY_PATH}?key={key}"),
190 Auth::EphemeralToken(token) => {
191 format!("{DEFAULT_HOST}{EPHEMERAL_TOKEN_PATH}?access_token={token}")
192 }
193 }
194}
195
196fn classify_connect_error(e: tungstenite::Error) -> ConnectError {
197 match e {
198 tungstenite::Error::Http(response) => ConnectError::Rejected {
199 status: response.status().as_u16(),
200 },
201 other => ConnectError::Ws(other),
202 }
203}
204
205fn classify_send_error(e: tungstenite::Error) -> SendError {
206 match e {
207 tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
208 SendError::Closed
209 }
210 other => SendError::Ws(other),
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with only the credential replaced.
    fn config_with(auth: Auth) -> TransportConfig {
        TransportConfig {
            auth,
            ..TransportConfig::default()
        }
    }

    /// An API key selects the v1beta endpoint and is passed as `key`.
    #[test]
    fn url_api_key() {
        let url = build_url(&config_with(Auth::ApiKey(String::from("test-key-123"))));

        assert!(url.starts_with("wss://generativelanguage.googleapis.com"));
        assert!(url.contains("BidiGenerateContent?key=test-key-123"));
        assert!(!url.contains("v1alpha"));
    }

    /// An ephemeral token selects the v1alpha endpoint and is passed as
    /// `access_token`.
    #[test]
    fn url_ephemeral_token() {
        let url = build_url(&config_with(Auth::EphemeralToken(String::from("tok-abc"))));

        assert!(url.contains("v1alpha"));
        assert!(url.contains("BidiGenerateContentConstrained?access_token=tok-abc"));
    }

    /// An endpoint override is used verbatim, regardless of the credential.
    #[test]
    fn url_endpoint_override() {
        let mut config = config_with(Auth::ApiKey(String::from("ignored")));
        config.endpoint_override = Some(String::from("wss://custom.example.com/ws"));

        assert_eq!(build_url(&config), "wss://custom.example.com/ws");
    }

    /// Pins the documented defaults.
    #[test]
    fn default_config_values() {
        let TransportConfig {
            endpoint_override,
            write_buffer_size,
            max_frame_size,
            connect_timeout,
            ..
        } = TransportConfig::default();

        assert_eq!(write_buffer_size, 64 * 1024);
        assert_eq!(max_frame_size, 16 * 1024 * 1024);
        assert_eq!(connect_timeout, Duration::from_secs(10));
        assert!(endpoint_override.is_none());
    }
}