use std::collections::VecDeque;
use base64::Engine as _;
use futures::{SinkExt, StreamExt};
use serde_json::{Value, json};
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, warn};
use crate::core::Model as _;
use crate::error::{Error, ProviderError, Result};
use crate::genai_types::{Content, FunctionCall, FunctionResponse, Tool, UsageMetadata};
use crate::providers::gemini::Gemini;
/// WebSocket stream type held by [`LiveSession`]: a tokio-tungstenite stream over a
/// tokio TCP socket wrapped in `MaybeTlsStream`.
type WsStream =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
/// Options for opening a live session with [`Gemini::connect_live`].
///
/// Each field is translated into a field of the `setup` message sent right after the
/// WebSocket connection is established.
#[derive(Debug, Clone)]
pub struct LiveConfig {
/// Sent as `generationConfig.responseModalities`.
pub response_modalities: Vec<String>,
/// Sent as `systemInstruction` when `Some`; omitted otherwise.
pub system_instruction: Option<Content>,
/// Sent as `tools` when non-empty; omitted otherwise.
pub tools: Vec<Tool>,
/// Prebuilt voice name, sent as
/// `generationConfig.speechConfig.voiceConfig.prebuiltVoiceConfig.voiceName` when `Some`.
pub voice: Option<String>,
/// When `true`, an empty `inputAudioTranscription` object is added to the setup message.
pub input_audio_transcription: bool,
/// When `true`, an empty `outputAudioTranscription` object is added to the setup message.
pub output_audio_transcription: bool,
}
impl Default for LiveConfig {
    /// Returns a configuration that requests only the `"TEXT"` response modality, with
    /// no system instruction, no tools, no voice, and both transcription flags off.
    fn default() -> Self {
        let text_only = vec![String::from("TEXT")];
        Self {
            response_modalities: text_only,
            system_instruction: None,
            tools: Vec::new(),
            voice: None,
            input_audio_transcription: false,
            output_audio_transcription: false,
        }
    }
}
/// A single event decoded from a server message and returned by [`LiveSession::recv`].
///
/// One server message may produce several events; they are queued and handed out one
/// per `recv` call.
#[derive(Debug, Clone)]
pub enum LiveEvent {
/// The `text` of one part in `serverContent.modelTurn.parts`.
Text(String),
/// An `inlineData` part in `serverContent.modelTurn.parts`.
Audio {
/// Base64-decoded payload; empty when the `data` field is missing or not valid base64.
data: Vec<u8>,
/// The part's `mimeType`, or `"audio/pcm"` when the field is absent.
mime_type: String,
},
/// The `text` of `serverContent.inputTranscription`.
InputTranscription(String),
/// The `text` of `serverContent.outputTranscription`.
OutputTranscription(String),
/// The decoded `toolCall.functionCalls`; only emitted when the list is non-empty.
ToolCall(Vec<FunctionCall>),
/// The `ids` listed in a `toolCallCancellation` message.
ToolCallCancellation(Vec<String>),
/// `serverContent.interrupted` was `true`.
Interrupted,
/// `serverContent.generationComplete` was `true`.
GenerationComplete,
/// `serverContent.turnComplete` was `true`.
TurnComplete,
/// The server sent a `goAway` message.
GoAway {
/// The `timeLeft` string from the message, if present.
time_left: Option<String>,
},
/// The decoded `usageMetadata` object of a server message.
UsageMetadata(UsageMetadata),
}
/// An open live session over a WebSocket, created by [`Gemini::connect_live`].
#[derive(Debug)]
pub struct LiveSession {
/// The underlying WebSocket stream used for both sending and receiving.
ws: WsStream,
/// Events already decoded from a server message but not yet returned by `recv`.
pending: VecDeque<LiveEvent>,
}
impl Gemini {
pub async fn connect_live(&self, cfg: LiveConfig) -> Result<LiveSession> {
let gcfg = self.config();
if gcfg.api_key.is_empty() {
return Err(Error::Provider(ProviderError::Auth(
"Gemini api_key is empty; set $GOOGLE_API_KEY".into(),
)));
}
let base = gcfg.base_url.trim_end_matches('/');
let ws_base = if let Some(rest) = base.strip_prefix("https://") {
format!("wss://{rest}")
} else if let Some(rest) = base.strip_prefix("http://") {
format!("ws://{rest}")
} else {
return Err(Error::config(format!("unsupported base_url: {base}")));
};
let url = format!(
"{ws_base}/ws/google.ai.generativelanguage.{}.GenerativeService.BidiGenerateContent?key={}",
gcfg.api_version, gcfg.api_key,
);
let (ws, _) = tokio_tungstenite::connect_async(&url)
.await
.map_err(|e| ProviderError::Transport(format!("live connect: {e}")))?;
let mut session = LiveSession {
ws,
pending: VecDeque::new(),
};
let mut generation_config = json!({ "responseModalities": cfg.response_modalities });
if let Some(voice) = &cfg.voice {
generation_config["speechConfig"] =
json!({ "voiceConfig": { "prebuiltVoiceConfig": { "voiceName": voice } } });
}
let mut setup = json!({
"model": format!("models/{}", self.name()),
"generationConfig": generation_config,
});
if let Some(sys) = &cfg.system_instruction {
setup["systemInstruction"] = serde_json::to_value(sys)?;
}
if !cfg.tools.is_empty() {
setup["tools"] = serde_json::to_value(&cfg.tools)?;
}
if cfg.input_audio_transcription {
setup["inputAudioTranscription"] = json!({});
}
if cfg.output_audio_transcription {
setup["outputAudioTranscription"] = json!({});
}
session.send_json(&json!({ "setup": setup })).await?;
match session.next_message().await? {
Some(v) if v.get("setupComplete").is_some() => Ok(session),
Some(v) => Err(Error::Provider(ProviderError::Stream(format!(
"expected setupComplete, got: {v}"
)))),
None => Err(Error::Provider(ProviderError::Stream(
"connection closed before setupComplete".into(),
))),
}
}
}
impl LiveSession {
async fn send_json(&mut self, v: &Value) -> Result<()> {
self.ws
.send(Message::Text(v.to_string().into()))
.await
.map_err(|e| Error::Provider(ProviderError::Transport(format!("live send: {e}"))))
}
pub async fn send_text(&mut self, text: &str, turn_complete: bool) -> Result<()> {
self.send_json(&json!({
"clientContent": {
"turns": [{ "role": "user", "parts": [{ "text": text }] }],
"turnComplete": turn_complete,
}
}))
.await
}
pub async fn send_audio(&mut self, pcm: &[u8], mime_type: &str) -> Result<()> {
let data = base64::engine::general_purpose::STANDARD.encode(pcm);
self.send_json(&json!({
"realtimeInput": { "audio": { "data": data, "mimeType": mime_type } }
}))
.await
}
pub async fn send_audio_stream_end(&mut self) -> Result<()> {
self.send_json(&json!({ "realtimeInput": { "audioStreamEnd": true } }))
.await
}
pub async fn send_tool_response(&mut self, responses: Vec<FunctionResponse>) -> Result<()> {
self.send_json(&json!({
"toolResponse": { "functionResponses": serde_json::to_value(&responses)? }
}))
.await
}
pub async fn recv(&mut self) -> Result<Option<LiveEvent>> {
loop {
if let Some(ev) = self.pending.pop_front() {
return Ok(Some(ev));
}
match self.next_message().await? {
Some(v) => self.ingest(&v),
None => return Ok(None),
}
}
}
pub async fn close(mut self) -> Result<()> {
self.ws
.close(None)
.await
.map_err(|e| Error::Provider(ProviderError::Transport(format!("live close: {e}"))))
}
async fn next_message(&mut self) -> Result<Option<Value>> {
loop {
let Some(msg) = self.ws.next().await else {
return Ok(None);
};
let msg =
msg.map_err(|e| Error::Provider(ProviderError::Transport(format!("live: {e}"))))?;
let bytes = match msg {
Message::Text(t) => t.as_bytes().to_vec(),
Message::Binary(b) => b.to_vec(),
Message::Close(_) => return Ok(None),
_ => continue,
};
let v: Value = serde_json::from_slice(&bytes)
.map_err(|e| ProviderError::Decode(format!("live message: {e}")))?;
return Ok(Some(v));
}
}
fn ingest(&mut self, v: &Value) {
if let Some(sc) = v.get("serverContent") {
if sc.get("interrupted").and_then(Value::as_bool) == Some(true) {
self.pending.push_back(LiveEvent::Interrupted);
}
if let Some(t) = sc
.get("inputTranscription")
.and_then(|t| t.get("text"))
.and_then(Value::as_str)
{
self.pending
.push_back(LiveEvent::InputTranscription(t.to_string()));
}
if let Some(t) = sc
.get("outputTranscription")
.and_then(|t| t.get("text"))
.and_then(Value::as_str)
{
self.pending
.push_back(LiveEvent::OutputTranscription(t.to_string()));
}
if let Some(parts) = sc
.get("modelTurn")
.and_then(|mt| mt.get("parts"))
.and_then(Value::as_array)
{
for p in parts {
if let Some(t) = p.get("text").and_then(Value::as_str) {
self.pending.push_back(LiveEvent::Text(t.to_string()));
}
if let Some(inline) = p.get("inlineData") {
let mime_type = inline
.get("mimeType")
.and_then(Value::as_str)
.unwrap_or("audio/pcm")
.to_string();
let data = inline
.get("data")
.and_then(Value::as_str)
.and_then(|d| base64::engine::general_purpose::STANDARD.decode(d).ok())
.unwrap_or_default();
self.pending.push_back(LiveEvent::Audio { data, mime_type });
}
}
}
if sc.get("generationComplete").and_then(Value::as_bool) == Some(true) {
self.pending.push_back(LiveEvent::GenerationComplete);
}
if sc.get("turnComplete").and_then(Value::as_bool) == Some(true) {
self.pending.push_back(LiveEvent::TurnComplete);
}
} else if let Some(tc) = v.get("toolCall") {
let calls: Vec<FunctionCall> = tc
.get("functionCalls")
.map(|fc| serde_json::from_value(fc.clone()).unwrap_or_default())
.unwrap_or_default();
if !calls.is_empty() {
self.pending.push_back(LiveEvent::ToolCall(calls));
}
} else if let Some(tcc) = v.get("toolCallCancellation") {
let ids: Vec<String> = tcc
.get("ids")
.map(|ids| serde_json::from_value(ids.clone()).unwrap_or_default())
.unwrap_or_default();
self.pending.push_back(LiveEvent::ToolCallCancellation(ids));
} else if let Some(ga) = v.get("goAway") {
self.pending.push_back(LiveEvent::GoAway {
time_left: ga
.get("timeLeft")
.and_then(Value::as_str)
.map(str::to_string),
});
} else if let Some(um) = v.get("usageMetadata") {
match serde_json::from_value::<UsageMetadata>(um.clone()) {
Ok(usage) => self.pending.push_back(LiveEvent::UsageMetadata(usage)),
Err(e) => warn!("live usageMetadata decode failed: {e}"),
}
} else {
debug!("ignoring unknown live message: {v}");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::gemini::GeminiConfig;
    use tokio::net::TcpListener;

    /// Wraps a JSON value in a WebSocket text frame.
    fn text_frame(v: &Value) -> Message {
        Message::Text(v.to_string().into())
    }

    /// Starts a one-shot mock server on an ephemeral local port and returns its address.
    ///
    /// The server accepts one connection, checks the setup message, acknowledges it,
    /// expects a "hello" user turn, replies with two content messages (one text frame,
    /// one binary frame carrying `turnComplete`), and closes.
    async fn spawn_mock_live_server() -> std::net::SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let (tcp, _peer) = listener.accept().await.unwrap();
            let mut ws = tokio_tungstenite::accept_async(tcp).await.unwrap();

            // Handshake: validate the setup message, then acknowledge it.
            let frame = ws.next().await.unwrap().unwrap();
            let setup: Value = serde_json::from_slice(&frame.into_data()).unwrap();
            let model = setup["setup"]["model"].as_str().unwrap();
            assert!(model.starts_with("models/"));
            assert_eq!(
                setup["setup"]["generationConfig"]["responseModalities"][0],
                "TEXT"
            );
            ws.send(text_frame(&json!({"setupComplete": {}})))
                .await
                .unwrap();

            // The client's turn must carry the expected text.
            let frame = ws.next().await.unwrap().unwrap();
            let turn: Value = serde_json::from_slice(&frame.into_data()).unwrap();
            assert_eq!(
                turn["clientContent"]["turns"][0]["parts"][0]["text"],
                "hello"
            );

            // First half of the reply as a text frame.
            let first = json!({"serverContent": {"modelTurn": {"parts": [{"text": "hi "}]}}});
            ws.send(text_frame(&first)).await.unwrap();

            // Second half as a binary frame, to exercise that decoding path as well.
            let second = json!({"serverContent": {
                "modelTurn": {"parts": [{"text": "there"}]},
                "turnComplete": true
            }})
            .to_string()
            .into_bytes();
            ws.send(Message::Binary(second.into())).await.unwrap();

            let _ = ws.close(None).await;
        });
        addr
    }

    /// A full round trip against the mock server yields the concatenated text followed
    /// by a turn-complete event and nothing else.
    #[tokio::test]
    async fn live_handshake_text_roundtrip() {
        let addr = spawn_mock_live_server().await;
        let config = GeminiConfig {
            base_url: format!("http://{addr}"),
            api_key: "k".into(),
            ..GeminiConfig::default()
        };
        let g = Gemini::new("gemini-2.5-flash", config).unwrap();

        let mut session = g.connect_live(LiveConfig::default()).await.unwrap();
        session.send_text("hello", true).await.unwrap();

        let mut text = String::new();
        let mut turn_complete = false;
        loop {
            match session.recv().await.unwrap() {
                Some(LiveEvent::Text(t)) => text.push_str(&t),
                Some(LiveEvent::TurnComplete) => {
                    turn_complete = true;
                    break;
                }
                Some(other) => panic!("unexpected event: {other:?}"),
                None => break,
            }
        }
        assert_eq!(text, "hi there");
        assert!(turn_complete);
    }

    /// Connecting with the default config's API key is rejected with an error that
    /// mentions `api_key`.
    #[tokio::test]
    async fn refuses_empty_api_key() {
        let config = GeminiConfig {
            base_url: "http://127.0.0.1:1".into(),
            ..GeminiConfig::default()
        };
        let g = Gemini::new("gemini-2.5-flash", config).unwrap();

        let err = g.connect_live(LiveConfig::default()).await.unwrap_err();
        assert!(err.to_string().contains("api_key"));
    }
}