use std::time::Duration;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::{self, Message};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async_with_config};
use crate::error::{ConnectError, RecvError, SendError};
/// Base WebSocket URL used when no `endpoint_override` is configured.
const DEFAULT_HOST: &str = "wss://generativelanguage.googleapis.com";
/// Path of the `v1beta` `BidiGenerateContent` endpoint; `build_url` pairs it with
/// [`Auth::ApiKey`] and a `key` query parameter.
const API_KEY_PATH: &str =
    "/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent";
/// Path of the `v1alpha` `BidiGenerateContentConstrained` endpoint; `build_url` pairs it with
/// [`Auth::EphemeralToken`] and an `access_token` query parameter.
const EPHEMERAL_TOKEN_PATH: &str =
    "/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContentConstrained";
/// Credential used to authenticate the WebSocket connection.
///
/// The secret travels as a URL query parameter (see `build_url`). `Debug` is implemented by
/// hand so the secret is never printed — e.g. when a [`TransportConfig`] is logged with `{:?}`.
#[derive(Clone)]
pub enum Auth {
    /// API key, sent as the `key` query parameter.
    ApiKey(String),
    /// Ephemeral token, sent as the `access_token` query parameter.
    EphemeralToken(String),
}

impl std::fmt::Debug for Auth {
    /// Formats as `ApiKey(<redacted>)` / `EphemeralToken(<redacted>)`, never the secret itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::ApiKey(_) => "ApiKey",
            Self::EphemeralToken(_) => "EphemeralToken",
        };
        // `format_args!` renders without quotes, so the placeholder doesn't look like a value.
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
/// Settings used by [`Connection::connect`].
///
/// Typically built with struct-update syntax over [`TransportConfig::default`].
#[derive(Debug, Clone)]
pub struct TransportConfig {
    /// Credential placed in the query string of the default URL.
    pub auth: Auth,
    /// Full WebSocket URL to use instead of the default one. When set, it is used verbatim
    /// and `auth` is NOT appended to it.
    pub endpoint_override: Option<String>,
    /// Forwarded to tungstenite's `WebSocketConfig::write_buffer_size` (bytes).
    pub write_buffer_size: usize,
    /// Limit in bytes applied to both tungstenite's `max_frame_size` and `max_message_size`
    /// (the incoming-frame and incoming-message caps).
    pub max_frame_size: usize,
    /// Upper bound on the whole connect + WebSocket handshake; exceeding it yields
    /// `ConnectError::Timeout`.
    pub connect_timeout: Duration,
}
impl Default for TransportConfig {
    /// Empty API key, no endpoint override, 64 KiB write buffer, 16 MiB frame/message cap and
    /// a 10 second connect timeout.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        let connect_timeout = Duration::from_secs(10);
        Self {
            connect_timeout,
            max_frame_size: 16 * MIB,
            write_buffer_size: 64 * KIB,
            endpoint_override: None,
            auth: Auth::ApiKey(String::new()),
        }
    }
}
/// A data-carrying or close message received by [`Connection::recv`].
///
/// Ping, pong and raw frames are filtered out before reaching callers.
#[derive(Debug, Clone, PartialEq)]
pub enum RawFrame {
    /// Payload of a text message.
    Text(String),
    /// Payload of a binary message.
    Binary(Vec<u8>),
    /// The server sent a Close message; holds its reason text when a close frame was present,
    /// `None` otherwise.
    Close(Option<String>),
}
/// Client WebSocket stream over plain TCP or TLS, as produced by `connect_async_with_config`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// An open WebSocket connection, split into independent write and read halves.
pub struct Connection {
    /// Write half; used by the `send_*` methods.
    sink: SplitSink<WsStream, Message>,
    /// Read half; used by `recv`.
    stream: SplitStream<WsStream>,
}
impl Connection {
pub async fn connect(config: &TransportConfig) -> Result<Self, ConnectError> {
let _ = rustls::crypto::ring::default_provider().install_default();
let url = build_url(config);
let mut ws_config = WebSocketConfig::default();
ws_config.write_buffer_size = config.write_buffer_size;
ws_config.max_write_buffer_size = config.write_buffer_size * 2;
ws_config.max_frame_size = Some(config.max_frame_size);
ws_config.max_message_size = Some(config.max_frame_size);
let connect_fut = connect_async_with_config(url, Some(ws_config), false);
let (ws_stream, _response) = tokio::time::timeout(config.connect_timeout, connect_fut)
.await
.map_err(|_| ConnectError::Timeout(config.connect_timeout))?
.map_err(classify_connect_error)?;
let (sink, stream) = ws_stream.split();
tracing::debug!("WebSocket connection established");
Ok(Self { sink, stream })
}
pub async fn send_text(&mut self, json: &str) -> Result<(), SendError> {
self.sink
.send(Message::text(json))
.await
.map_err(classify_send_error)
}
pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
self.sink
.send(Message::binary(data.to_vec()))
.await
.map_err(classify_send_error)
}
pub async fn recv(&mut self) -> Result<RawFrame, RecvError> {
loop {
match self.stream.next().await {
Some(Ok(msg)) => {
tracing::trace!(msg_type = ?std::mem::discriminant(&msg), "raw ws frame received");
match msg {
Message::Text(text) => return Ok(RawFrame::Text(text.to_string())),
Message::Binary(data) => return Ok(RawFrame::Binary(data.to_vec())),
Message::Close(frame) => {
let reason = frame.map(|f| f.reason.to_string());
return Ok(RawFrame::Close(reason));
}
Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
}
}
Some(Err(e)) => return Err(RecvError::Ws(e)),
None => return Err(RecvError::Closed),
}
}
}
pub(crate) async fn send_close(&mut self) -> Result<(), SendError> {
self.sink
.send(Message::Close(None))
.await
.map_err(classify_send_error)
}
pub async fn close(mut self) -> Result<(), SendError> {
self.send_close().await
}
}
/// Builds the WebSocket URL for `config`.
///
/// An explicit `endpoint_override` is returned verbatim; the credential is NOT appended to it.
/// Otherwise the URL is [`DEFAULT_HOST`] plus the path matching the credential kind. The
/// credential goes in the query string (`key=` for API keys, `access_token=` for ephemeral
/// tokens), percent-encoded.
fn build_url(config: &TransportConfig) -> String {
    if let Some(url) = &config.endpoint_override {
        return url.clone();
    }
    let (path, param, secret) = match &config.auth {
        Auth::ApiKey(key) => (API_KEY_PATH, "key", key),
        Auth::EphemeralToken(token) => (EPHEMERAL_TOKEN_PATH, "access_token", token),
    };
    format!("{DEFAULT_HOST}{path}?{param}={}", encode_query_value(secret))
}

/// Percent-encodes `value` for use as a URL query parameter value.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) and `/` pass through unchanged,
/// like Python's `urllib.parse.quote` default. Every other UTF-8 byte becomes `%XX`, so
/// characters such as `&`, `#`, `+`, `%`, `=` or spaces cannot truncate the URL or inject
/// extra parameters.
fn encode_query_value(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'/') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
fn classify_connect_error(e: tungstenite::Error) -> ConnectError {
match e {
tungstenite::Error::Http(response) => ConnectError::Rejected {
status: response.status().as_u16(),
},
other => ConnectError::Ws(other),
}
}
fn classify_send_error(e: tungstenite::Error) -> SendError {
match e {
tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => {
SendError::Closed
}
other => SendError::Ws(other),
}
}
// Unit tests for URL construction and default configuration; none of them open a socket.
#[cfg(test)]
mod tests {
    use super::*;

    // API-key auth targets the default host, uses the `key` query parameter, and does not
    // route through the v1alpha (ephemeral-token) path.
    #[test]
    fn url_api_key() {
        let config = TransportConfig {
            auth: Auth::ApiKey("test-key-123".into()),
            ..Default::default()
        };
        let url = build_url(&config);
        assert!(url.starts_with("wss://generativelanguage.googleapis.com"));
        assert!(url.contains("BidiGenerateContent?key=test-key-123"));
        assert!(!url.contains("v1alpha"));
    }

    // Ephemeral-token auth targets the v1alpha constrained endpoint with `access_token`.
    #[test]
    fn url_ephemeral_token() {
        let config = TransportConfig {
            auth: Auth::EphemeralToken("tok-abc".into()),
            ..Default::default()
        };
        let url = build_url(&config);
        assert!(url.contains("v1alpha"));
        assert!(url.contains("BidiGenerateContentConstrained?access_token=tok-abc"));
    }

    // An endpoint override is returned verbatim; the credential is not appended.
    #[test]
    fn url_endpoint_override() {
        let config = TransportConfig {
            auth: Auth::ApiKey("ignored".into()),
            endpoint_override: Some("wss://custom.example.com/ws".into()),
            ..Default::default()
        };
        let url = build_url(&config);
        assert_eq!(url, "wss://custom.example.com/ws");
    }

    // Pins the documented defaults of `TransportConfig::default`.
    #[test]
    fn default_config_values() {
        let config = TransportConfig::default();
        assert_eq!(config.write_buffer_size, 64 * 1024);
        assert_eq!(config.max_frame_size, 16 * 1024 * 1024);
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert!(config.endpoint_override.is_none());
    }
}