1use std::future::Future;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::time::Duration;
30
31use futures_util::stream::{SplitSink, SplitStream};
32use futures_util::{SinkExt, StreamExt};
33use tokio::net::TcpStream;
34use tokio_tungstenite::tungstenite::client::IntoClientRequest;
35use tokio_tungstenite::tungstenite::http::{HeaderValue, Request, header::AUTHORIZATION};
36use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
37use tokio_tungstenite::tungstenite::{self, Message};
38use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
39use url::Url;
40
41use crate::error::{BearerTokenError, ConnectError, RecvError, SendError};
42
43#[cfg(feature = "vertex-auth")]
44mod vertex_auth;
45
46#[cfg(feature = "vertex-auth")]
47pub use vertex_auth::VertexAiApplicationDefaultCredentials;
48
/// Scheme + host of the public Gemini API websocket service.
const GEMINI_API_HOST: &str = "wss://generativelanguage.googleapis.com";

/// Gemini API path used for every auth kind except ephemeral tokens
/// (the API key itself travels in the `key` query parameter).
const GEMINI_API_KEY_PATH: &str =
    "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent";

/// Gemini API path used with `Auth::EphemeralToken`
/// (the token travels in the `access_token` query parameter).
const GEMINI_EPHEMERAL_TOKEN_PATH: &str =
    "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContentConstrained";

/// Path appended to `wss://{location}-aiplatform.googleapis.com` for Vertex AI.
const VERTEX_AI_PATH: &str = "/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent";
55
/// Which Live API service a connection targets.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum Endpoint {
    /// Public Gemini API at `generativelanguage.googleapis.com`; this is the default.
    #[default]
    GeminiApi,
    /// Regional Vertex AI endpoint at `{location}-aiplatform.googleapis.com`.
    VertexAi {
        /// Region name that becomes the host prefix, e.g. `us-central1`.
        location: String,
    },
    /// Caller-supplied websocket URL, used as-is (query-string auth may be appended).
    Custom(String),
}
76
77type BearerTokenFuture<'a> =
80 Pin<Box<dyn Future<Output = Result<String, BearerTokenError>> + Send + 'a>>;
81
82trait DynBearerTokenProvider: Send + Sync {
83 fn name(&self) -> &'static str;
84 fn bearer_token(&self) -> BearerTokenFuture<'_>;
85}
86
/// Adapter turning a closure that returns a token future into a provider.
struct FnBearerTokenProvider<F> {
    /// Label reported through `DynBearerTokenProvider::name`.
    name: &'static str,
    /// Closure invoked each time a token is requested.
    func: F,
}
91
92impl<F, Fut> DynBearerTokenProvider for FnBearerTokenProvider<F>
93where
94 F: Fn() -> Fut + Send + Sync + 'static,
95 Fut: Future<Output = Result<String, BearerTokenError>> + Send + 'static,
96{
97 fn name(&self) -> &'static str {
98 self.name
99 }
100
101 fn bearer_token(&self) -> BearerTokenFuture<'_> {
102 Box::pin((self.func)())
103 }
104}
105
106#[derive(Clone)]
112pub struct BearerTokenProvider {
113 inner: Arc<dyn DynBearerTokenProvider>,
114}
115
116impl BearerTokenProvider {
117 fn new<P>(provider: P) -> Self
118 where
119 P: DynBearerTokenProvider + 'static,
120 {
121 Self {
122 inner: Arc::new(provider),
123 }
124 }
125
126 pub fn from_fn<F, Fut>(name: &'static str, func: F) -> Self
128 where
129 F: Fn() -> Fut + Send + Sync + 'static,
130 Fut: Future<Output = Result<String, BearerTokenError>> + Send + 'static,
131 {
132 Self::new(FnBearerTokenProvider { name, func })
133 }
134
135 pub async fn bearer_token(&self) -> Result<String, BearerTokenError> {
137 self.inner.bearer_token().await
138 }
139
140 #[cfg(feature = "vertex-auth")]
141 pub fn vertex_ai_application_default() -> Result<Self, BearerTokenError> {
144 Ok(VertexAiApplicationDefaultCredentials::new()?.into_bearer_token_provider())
145 }
146}
147
148impl std::fmt::Debug for BearerTokenProvider {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("BearerTokenProvider")
151 .field("kind", &self.inner.name())
152 .finish()
153 }
154}
155
156#[derive(Clone)]
160pub enum Auth {
161 None,
166 ApiKey(String),
168 EphemeralToken(String),
170 BearerToken(String),
174 BearerTokenProvider(BearerTokenProvider),
179}
180
181impl std::fmt::Debug for Auth {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 match self {
184 Self::None => f.debug_tuple("None").finish(),
185 Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
186 Self::EphemeralToken(_) => f
187 .debug_tuple("EphemeralToken")
188 .field(&"<redacted>")
189 .finish(),
190 Self::BearerToken(_) => f.debug_tuple("BearerToken").field(&"<redacted>").finish(),
191 Self::BearerTokenProvider(provider) => f
192 .debug_tuple("BearerTokenProvider")
193 .field(provider)
194 .finish(),
195 }
196 }
197}
198
199impl Auth {
200 #[cfg(feature = "vertex-auth")]
201 pub fn vertex_ai_application_default() -> Result<Self, BearerTokenError> {
204 Ok(Self::BearerTokenProvider(
205 BearerTokenProvider::vertex_ai_application_default()?,
206 ))
207 }
208}
209
210#[derive(Debug, Clone)]
218pub struct TransportConfig {
219 pub endpoint: Endpoint,
221 pub auth: Auth,
223 pub write_buffer_size: usize,
225 pub max_frame_size: usize,
227 pub connect_timeout: Duration,
229}
230
231impl Default for TransportConfig {
232 fn default() -> Self {
233 Self {
234 endpoint: Endpoint::GeminiApi,
235 auth: Auth::None,
236 write_buffer_size: 1024 * 1024,
237 max_frame_size: 16 * 1024 * 1024,
238 connect_timeout: Duration::from_secs(10),
239 }
240 }
241}
242
/// A data-carrying websocket message as returned by `Connection::recv`.
#[derive(Clone, Debug, PartialEq)]
pub enum RawFrame {
    /// UTF-8 text payload.
    Text(String),
    /// Binary payload.
    Binary(Vec<u8>),
    /// The peer closed the socket. Carries the close reason, or `None` if the
    /// close message had no close-frame body.
    Close(Option<String>),
}
255
256type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
259
260pub struct Connection {
266 sink: SplitSink<WsStream, Message>,
267 stream: SplitStream<WsStream>,
268}
269
270impl Connection {
271 pub async fn connect(config: &TransportConfig) -> Result<Self, ConnectError> {
273 install_rustls_crypto_provider();
274
275 let request = build_request(config).await?;
276 let mut ws_config = WebSocketConfig::default();
277 ws_config.write_buffer_size = config.write_buffer_size;
278 ws_config.max_write_buffer_size = config.write_buffer_size * 2;
279 ws_config.max_frame_size = Some(config.max_frame_size);
280 ws_config.max_message_size = Some(config.max_frame_size);
281
282 let connect_fut = connect_async_with_config(request, Some(ws_config), false);
283
284 let (ws_stream, _response) = tokio::time::timeout(config.connect_timeout, connect_fut)
285 .await
286 .map_err(|_| ConnectError::Timeout(config.connect_timeout))?
287 .map_err(classify_connect_error)?;
288
289 let (sink, stream) = ws_stream.split();
290 tracing::debug!("WebSocket connection established");
291 Ok(Self { sink, stream })
292 }
293
294 pub async fn send_text(&mut self, json: &str) -> Result<(), SendError> {
296 self.sink
297 .send(Message::text(json))
298 .await
299 .map_err(classify_send_error)
300 }
301
302 pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
304 self.sink
305 .send(Message::binary(data.to_vec()))
306 .await
307 .map_err(classify_send_error)
308 }
309
310 pub async fn recv(&mut self) -> Result<RawFrame, RecvError> {
312 loop {
313 match self.stream.next().await {
314 Some(Ok(msg)) => {
315 tracing::trace!(msg_type = ?std::mem::discriminant(&msg), "raw ws frame received");
316 match msg {
317 Message::Text(text) => return Ok(RawFrame::Text(text.to_string())),
318 Message::Binary(data) => return Ok(RawFrame::Binary(data.to_vec())),
319 Message::Close(frame) => {
320 let reason = frame.map(|f| f.reason.to_string());
321 return Ok(RawFrame::Close(reason));
322 }
323 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
325 }
326 }
327 Some(Err(e)) => return Err(RecvError::Ws(e)),
328 None => return Err(RecvError::Closed),
329 }
330 }
331 }
332
333 pub(crate) async fn send_close(&mut self) -> Result<(), SendError> {
335 self.sink
336 .send(Message::Close(None))
337 .await
338 .map_err(classify_send_error)
339 }
340
341 pub async fn close(mut self) -> Result<(), SendError> {
343 self.send_close().await
344 }
345}
346
347pub(crate) fn install_rustls_crypto_provider() {
350 let _ = rustls::crypto::ring::default_provider().install_default();
353}
354
355async fn build_request(config: &TransportConfig) -> Result<Request<()>, ConnectError> {
356 validate_transport_config(config)?;
357
358 let url = build_url(config)?;
359 let mut request = url
360 .as_str()
361 .into_client_request()
362 .map_err(|e| ConnectError::Config(format!("invalid websocket request: {e}")))?;
363
364 if let Some(header) = build_bearer_header(&config.auth).await? {
365 request.headers_mut().insert(AUTHORIZATION, header);
366 }
367
368 Ok(request)
369}
370
371fn validate_transport_config(config: &TransportConfig) -> Result<(), ConnectError> {
372 match (&config.endpoint, &config.auth) {
373 (Endpoint::GeminiApi, Auth::ApiKey(_) | Auth::EphemeralToken(_)) => Ok(()),
374 (Endpoint::GeminiApi, Auth::None) => Err(ConnectError::Config(
375 "Endpoint::GeminiApi requires Auth::ApiKey or Auth::EphemeralToken".into(),
376 )),
377 (Endpoint::GeminiApi, Auth::BearerToken(_) | Auth::BearerTokenProvider(_)) => Err(
378 ConnectError::Config(
379 "Endpoint::GeminiApi does not use bearer auth; use Auth::ApiKey or Auth::EphemeralToken".into(),
380 ),
381 ),
382 (
383 Endpoint::VertexAi { location },
384 Auth::BearerToken(_) | Auth::BearerTokenProvider(_),
385 ) => {
386 if location.trim().is_empty() {
387 return Err(ConnectError::Config(
388 "Endpoint::VertexAi location must not be empty".into(),
389 ));
390 }
391 Ok(())
392 }
393 (Endpoint::VertexAi { .. }, _) => Err(ConnectError::Config(
394 "Endpoint::VertexAi requires Auth::BearerToken or Auth::BearerTokenProvider"
395 .into(),
396 )),
397 (Endpoint::Custom(url), _) => {
398 Url::parse(url).map_err(|e| ConnectError::Config(format!("invalid custom endpoint URL: {e}")))?;
399 Ok(())
400 }
401 }
402}
403
404fn build_url(config: &TransportConfig) -> Result<Url, ConnectError> {
405 let mut url = match &config.endpoint {
406 Endpoint::GeminiApi => {
407 Url::parse(&format!(
408 "{}{}",
409 GEMINI_API_HOST,
410 gemini_path_for_auth(&config.auth)
411 ))
412 }
413 .map_err(|e| ConnectError::Config(format!("invalid Gemini API endpoint URL: {e}")))?,
414 Endpoint::VertexAi { location } => Url::parse(&format!(
415 "wss://{location}-aiplatform.googleapis.com{VERTEX_AI_PATH}"
416 ))
417 .map_err(|e| ConnectError::Config(format!("invalid Vertex AI endpoint URL: {e}")))?,
418 Endpoint::Custom(url) => Url::parse(url)
419 .map_err(|e| ConnectError::Config(format!("invalid custom endpoint URL: {e}")))?,
420 };
421
422 match &config.auth {
423 Auth::ApiKey(key) => {
424 url.query_pairs_mut().append_pair("key", key);
425 }
426 Auth::EphemeralToken(token) => {
427 url.query_pairs_mut().append_pair("access_token", token);
428 }
429 Auth::None | Auth::BearerToken(_) | Auth::BearerTokenProvider(_) => {}
430 }
431
432 Ok(url)
433}
434
435fn gemini_path_for_auth(auth: &Auth) -> &'static str {
436 match auth {
437 Auth::EphemeralToken(_) => GEMINI_EPHEMERAL_TOKEN_PATH,
438 Auth::None | Auth::ApiKey(_) | Auth::BearerToken(_) | Auth::BearerTokenProvider(_) => {
439 GEMINI_API_KEY_PATH
440 }
441 }
442}
443
444async fn build_bearer_header(auth: &Auth) -> Result<Option<HeaderValue>, ConnectError> {
445 match auth {
446 Auth::BearerToken(token) => HeaderValue::from_str(&format!("Bearer {token}"))
447 .map(Some)
448 .map_err(|e| ConnectError::Config(format!("invalid bearer token header: {e}"))),
449 Auth::BearerTokenProvider(provider) => {
450 let token = provider.bearer_token().await.map_err(ConnectError::Auth)?;
451 HeaderValue::from_str(&format!("Bearer {token}"))
452 .map(Some)
453 .map_err(|e| {
454 ConnectError::Auth(BearerTokenError::with_source(
455 "token provider returned an invalid bearer token",
456 e,
457 ))
458 })
459 }
460 Auth::None | Auth::ApiKey(_) | Auth::EphemeralToken(_) => Ok(None),
461 }
462}
463
464fn classify_connect_error(e: tungstenite::Error) -> ConnectError {
465 match e {
466 tungstenite::Error::Http(response) => ConnectError::Rejected {
467 status: response.status().as_u16(),
468 },
469 other => ConnectError::Ws(other),
470 }
471}
472
473fn classify_send_error(e: tungstenite::Error) -> SendError {
474 match e {
475 tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
476 SendError::Closed
477 }
478 other => SendError::Ws(other),
479 }
480}
481
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Config with the given endpoint and auth; every other field is defaulted.
    fn config_with(endpoint: Endpoint, auth: Auth) -> TransportConfig {
        TransportConfig {
            endpoint,
            auth,
            ..Default::default()
        }
    }

    /// Vertex AI endpoint in `us-central1`, shared by the Vertex tests.
    fn us_central1() -> Endpoint {
        Endpoint::VertexAi {
            location: "us-central1".into(),
        }
    }

    #[tokio::test]
    async fn request_gemini_api_key_uses_query_auth() {
        let cfg = config_with(Endpoint::GeminiApi, Auth::ApiKey("test-key-123".into()));
        let req = build_request(&cfg).await.expect("request");
        let uri = req.uri().to_string();

        assert!(uri.starts_with("wss://generativelanguage.googleapis.com"));
        assert!(uri.contains("BidiGenerateContent?key=test-key-123"));
        assert!(!uri.contains("v1alpha"));
        assert!(req.headers().get(AUTHORIZATION).is_none());
    }

    #[tokio::test]
    async fn request_gemini_ephemeral_token_uses_constrained_path() {
        let cfg = config_with(Endpoint::GeminiApi, Auth::EphemeralToken("tok-abc".into()));
        let req = build_request(&cfg).await.expect("request");
        let uri = req.uri().to_string();

        assert!(uri.contains("v1alpha"));
        assert!(uri.contains("BidiGenerateContentConstrained?access_token=tok-abc"));
        assert!(req.headers().get(AUTHORIZATION).is_none());
    }

    #[tokio::test]
    async fn request_vertex_ai_uses_bearer_header() {
        let cfg = config_with(us_central1(), Auth::BearerToken("vertex-token".into()));
        let req = build_request(&cfg).await.expect("request");

        assert_eq!(
            req.uri(),
            "wss://us-central1-aiplatform.googleapis.com/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
        );
        let auth_header = req
            .headers()
            .get(AUTHORIZATION)
            .expect("authorization header");
        assert_eq!(auth_header, "Bearer vertex-token");
    }

    #[tokio::test]
    async fn request_custom_endpoint_can_skip_auth() {
        let cfg = config_with(
            Endpoint::Custom("wss://custom.example.com/ws".into()),
            Auth::None,
        );
        let req = build_request(&cfg).await.expect("request");

        assert_eq!(req.uri(), "wss://custom.example.com/ws");
        assert!(req.headers().get(AUTHORIZATION).is_none());
    }

    #[tokio::test]
    async fn request_vertex_ai_provider_fetches_token_per_connect() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let provider = BearerTokenProvider::from_fn("test-sequence", move || {
            let counter = Arc::clone(&counter);
            async move {
                let next = counter.fetch_add(1, Ordering::Relaxed) + 1;
                Ok(format!("token-{next}"))
            }
        });
        let cfg = config_with(us_central1(), Auth::BearerTokenProvider(provider));

        // Each request build must ask the provider again, yielding a new token.
        for expected in ["Bearer token-1", "Bearer token-2"] {
            let req = build_request(&cfg).await.expect("request");
            let auth_header = req.headers().get(AUTHORIZATION).expect("auth header");
            assert_eq!(auth_header, expected);
        }
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn request_vertex_ai_provider_error_bubbles() {
        let failing = BearerTokenProvider::from_fn("always-fails", || async {
            Err(BearerTokenError::new("boom"))
        });
        let cfg = config_with(us_central1(), Auth::BearerTokenProvider(failing));

        let err = build_request(&cfg)
            .await
            .expect_err("provider failure should bubble");

        assert!(matches!(err, ConnectError::Auth(source) if source.to_string() == "boom"));
    }

    #[tokio::test]
    async fn invalid_vertex_auth_is_rejected_before_connect() {
        let cfg = config_with(us_central1(), Auth::ApiKey("not-vertex".into()));
        let err = build_request(&cfg).await.expect_err("config error");

        assert!(
            matches!(err, ConnectError::Config(message) if message == "Endpoint::VertexAi requires Auth::BearerToken or Auth::BearerTokenProvider")
        );
    }

    #[tokio::test]
    async fn invalid_gemini_bearer_auth_is_rejected_before_connect() {
        let wrong = BearerTokenProvider::from_fn("wrong", || async { Ok("wrong".into()) });
        let cfg = config_with(Endpoint::GeminiApi, Auth::BearerTokenProvider(wrong));
        let err = build_request(&cfg).await.expect_err("config error");

        assert!(
            matches!(err, ConnectError::Config(message) if message.contains("does not use bearer auth"))
        );
    }

    #[test]
    fn default_config_values() {
        let TransportConfig {
            endpoint,
            auth,
            write_buffer_size,
            max_frame_size,
            connect_timeout,
        } = TransportConfig::default();

        assert_eq!(endpoint, Endpoint::GeminiApi);
        assert!(matches!(auth, Auth::None));
        assert_eq!(write_buffer_size, 1024 * 1024);
        assert_eq!(max_frame_size, 16 * 1024 * 1024);
        assert_eq!(connect_timeout, Duration::from_secs(10));
    }
}