Skip to main content

gemini_live/
session.rs

1//! Session layer — the primary interface for the Gemini Live API.
2//!
3//! [`Session`] manages the full connection lifecycle: WebSocket connect,
4//! `setup` handshake, automatic reconnection, and typed send/receive.
5//!
6//! # Architecture
7//!
8//! ```text
9//!                        ┌──────────────┐
10//!                        │   Session    │  ← cheap Clone (Arc)
11//!                        │  (handle)    │
12//!                        └──┬───────┬───┘
13//!                 cmd_tx    │       │  event_rx (broadcast)
14//!                           ▼       ▼
15//!             ┌─────────────────────────────────┐
16//!             │          Runner task            │  ← tokio::spawn
17//!             │  ┌───────────┐  ┌────────────┐  │
18//!             │  │ Send Loop │  │ Recv Loop  │  │
19//!             │  │ (ws sink) │  │ (ws stream)│  │
20//!             │  └───────────┘  └────────────┘  │
21//!             │  reconnect · GoAway · resume    │
22//!             └─────────────────────────────────┘
23//! ```
24//!
25//! The runner is a single `tokio::spawn`'d task that uses `tokio::select!`
26//! to multiplex user commands and WebSocket frames.  Reconnection is
27//! transparent — messages buffer in the mpsc channel during downtime.
28
29use std::sync::atomic::{AtomicU8, Ordering};
30use std::sync::{Arc, Mutex};
31use std::time::Duration;
32
33use base64::Engine;
34use futures_util::Stream;
35use tokio::sync::{broadcast, mpsc};
36
37use crate::audio::INPUT_AUDIO_MIME;
38use crate::codec;
39use crate::error::SessionError;
40use crate::transport::{Connection, RawFrame, TransportConfig};
41use crate::types::*;
42
/// Timeout for the `setup` → `setupComplete` handshake.
const SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// Capacity of the broadcast channel that fans server events out to
/// subscribers.
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Capacity of the mpsc channel that carries commands from session handles
/// to the runner task.
const COMMAND_CHANNEL_CAPACITY: usize = 64;
47
48// ── Public config types ──────────────────────────────────────────────────────
49
/// Complete session configuration combining transport, protocol, and
/// reconnection settings.
pub struct SessionConfig {
    /// Passed to `Connection::connect` for the initial connection and for
    /// every reconnect attempt.
    pub transport: TransportConfig,
    /// Payload of the `setup` message sent during each handshake.
    pub setup: SetupConfig,
    /// Controls whether and how the runner reconnects after a disconnect.
    pub reconnect: ReconnectPolicy,
}
57
/// Reconnection behaviour after an unexpected disconnect or `goAway`.
///
/// Backoff formula: `base_backoff × 2^(attempt − 1)`, capped at `max_backoff`.
///
/// The type is plain data, so it derives `Debug`, `Clone`, `PartialEq` and
/// `Eq` to make it easy to log, copy between configurations, and compare in
/// tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectPolicy {
    /// Enable automatic reconnection.  Default: `true`.
    pub enabled: bool,
    /// Initial backoff delay.  Default: 500 ms.
    pub base_backoff: Duration,
    /// Maximum backoff delay.  Default: 5 s.
    pub max_backoff: Duration,
    /// Maximum reconnection attempts.  `None` = unlimited.  Default: `Some(10)`.
    pub max_attempts: Option<u32>,
}

impl Default for ReconnectPolicy {
    /// Reconnection enabled, 500 ms initial delay, 5 s ceiling, 10 attempts.
    fn default() -> Self {
        Self {
            enabled: true,
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            max_attempts: Some(10),
        }
    }
}
82
/// Observable session state.
///
/// The explicit discriminants are the values stored in the shared
/// `AtomicU8` status cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Initial state, before the first handshake has completed.
    Connecting = 0,
    /// A handshake has completed and the runner is driving the connection.
    Connected = 1,
    /// The connection was lost (or a `goAway` arrived) and the runner is
    /// attempting to reconnect.
    Reconnecting = 2,
    /// The runner has exited; the session will not reconnect.
    Closed = 3,
}
91
92// ── Session handle ───────────────────────────────────────────────────────────
93
94/// An active session with the Gemini Live API.
95///
96/// Cheaply [`Clone`]able — all clones share the same underlying connection
97/// and runner task.  Each clone has its own event cursor, so events are
98/// never "stolen" between consumers.
99///
100/// Created via [`Session::connect`].
101pub struct Session {
102    cmd_tx: mpsc::Sender<Command>,
103    event_tx: broadcast::Sender<ServerEvent>,
104    event_rx: broadcast::Receiver<ServerEvent>,
105    state: Arc<SharedState>,
106}
107
108impl Clone for Session {
109    fn clone(&self) -> Self {
110        Self {
111            cmd_tx: self.cmd_tx.clone(),
112            event_tx: self.event_tx.clone(),
113            event_rx: self.event_tx.subscribe(),
114            state: self.state.clone(),
115        }
116    }
117}
118
119impl Session {
120    /// Connect to the Gemini Live API and complete the `setup` handshake.
121    ///
122    /// On success the session is immediately usable — `setupComplete` has
123    /// already been received.  A background runner task is spawned to
124    /// manage the connection and handle reconnection.
125    pub async fn connect(config: SessionConfig) -> Result<Self, SessionError> {
126        let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
127        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
128        let state = Arc::new(SharedState::new());
129
130        // 1. Establish WebSocket connection
131        state.set_status(SessionStatus::Connecting);
132        let mut conn = Connection::connect(&config.transport)
133            .await
134            .map_err(|e| SessionError::SetupFailed(e.to_string()))?;
135
136        // 2. Send setup and await setupComplete
137        do_handshake(&mut conn, &config.setup, None).await?;
138        state.set_status(SessionStatus::Connected);
139        tracing::info!("session established");
140
141        // 3. Spawn the background runner
142        let runner = Runner {
143            cmd_rx,
144            event_tx: event_tx.clone(),
145            conn,
146            config,
147            state: Arc::clone(&state),
148        };
149        tokio::spawn(runner.run());
150
151        let event_rx = event_tx.subscribe();
152        Ok(Self {
153            cmd_tx,
154            event_tx,
155            event_rx,
156            state,
157        })
158    }
159
160    /// Current session status.
161    pub fn status(&self) -> SessionStatus {
162        self.state.status()
163    }
164
165    // ── Send convenience methods ─────────────────────────────────────
166
167    /// Stream audio.  Accepts raw i16 little-endian PCM bytes — base64
168    /// encoding and `realtimeInput` wrapping are handled internally.
169    ///
170    /// **Performance note:** allocates a new `String` for base64 on every
171    /// call (`roadmap.md` P-1).  For zero-allocation streaming, use
172    /// [`AudioEncoder`](crate::audio::AudioEncoder) with [`send_raw`](Self::send_raw).
173    pub async fn send_audio(&self, pcm_i16_le: &[u8]) -> Result<(), SessionError> {
174        let b64 = base64::engine::general_purpose::STANDARD.encode(pcm_i16_le);
175        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
176            audio: Some(Blob {
177                data: b64,
178                mime_type: INPUT_AUDIO_MIME.into(),
179            }),
180            video: None,
181            text: None,
182            activity_start: None,
183            activity_end: None,
184            audio_stream_end: None,
185        }))
186        .await
187    }
188
189    /// Stream a video frame.  Accepts encoded JPEG/PNG bytes.
190    pub async fn send_video(&self, data: &[u8], mime: &str) -> Result<(), SessionError> {
191        let b64 = base64::engine::general_purpose::STANDARD.encode(data);
192        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
193            video: Some(Blob {
194                data: b64,
195                mime_type: mime.into(),
196            }),
197            audio: None,
198            text: None,
199            activity_start: None,
200            activity_end: None,
201            audio_stream_end: None,
202        }))
203        .await
204    }
205
206    /// Send text via the real-time input channel.
207    pub async fn send_text(&self, text: &str) -> Result<(), SessionError> {
208        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
209            text: Some(text.into()),
210            audio: None,
211            video: None,
212            activity_start: None,
213            activity_end: None,
214            audio_stream_end: None,
215        }))
216        .await
217    }
218
219    /// Send conversation history or incremental content.
220    pub async fn send_client_content(&self, content: ClientContent) -> Result<(), SessionError> {
221        self.send_raw(ClientMessage::ClientContent(content)).await
222    }
223
224    /// Manual VAD: signal that user activity (speech) has started.
225    pub async fn activity_start(&self) -> Result<(), SessionError> {
226        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
227            activity_start: Some(EmptyObject {}),
228            audio: None,
229            video: None,
230            text: None,
231            activity_end: None,
232            audio_stream_end: None,
233        }))
234        .await
235    }
236
237    /// Manual VAD: signal that user activity has ended.
238    pub async fn activity_end(&self) -> Result<(), SessionError> {
239        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
240            activity_end: Some(EmptyObject {}),
241            audio: None,
242            video: None,
243            text: None,
244            activity_start: None,
245            audio_stream_end: None,
246        }))
247        .await
248    }
249
250    /// Notify the server that the audio stream has ended (auto VAD mode).
251    pub async fn audio_stream_end(&self) -> Result<(), SessionError> {
252        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
253            audio_stream_end: Some(true),
254            audio: None,
255            video: None,
256            text: None,
257            activity_start: None,
258            activity_end: None,
259        }))
260        .await
261    }
262
263    /// Reply to one or more server-initiated function calls.
264    pub async fn send_tool_response(
265        &self,
266        responses: Vec<FunctionResponse>,
267    ) -> Result<(), SessionError> {
268        self.send_raw(ClientMessage::ToolResponse(ToolResponseMessage {
269            function_responses: responses,
270        }))
271        .await
272    }
273
274    /// Send an arbitrary [`ClientMessage`] (escape hatch for future types).
275    pub async fn send_raw(&self, msg: ClientMessage) -> Result<(), SessionError> {
276        self.cmd_tx
277            .send(Command::Send(Box::new(msg)))
278            .await
279            .map_err(|_| SessionError::Closed)
280    }
281
282    // ── Receive ──────────────────────────────────────────────────────
283
284    /// Create a new event [`Stream`].
285    ///
286    /// Each call produces an **independent** subscription — multiple streams
287    /// can coexist without stealing events.
288    pub fn events(&self) -> impl Stream<Item = ServerEvent> {
289        let rx = self.event_tx.subscribe();
290        futures_util::stream::unfold(rx, |mut rx| async move {
291            loop {
292                match rx.recv().await {
293                    Ok(event) => return Some((event, rx)),
294                    Err(broadcast::error::RecvError::Lagged(n)) => {
295                        tracing::warn!(n, "event stream lagged, some events were missed");
296                        continue;
297                    }
298                    Err(broadcast::error::RecvError::Closed) => return None,
299                }
300            }
301        })
302    }
303
304    /// Wait for the next event on this handle's cursor.
305    ///
306    /// Returns `None` when the session is permanently closed.
307    pub async fn next_event(&mut self) -> Option<ServerEvent> {
308        loop {
309            match self.event_rx.recv().await {
310                Ok(event) => return Some(event),
311                Err(broadcast::error::RecvError::Lagged(n)) => {
312                    tracing::warn!(n, "event consumer lagged, some events were missed");
313                    continue;
314                }
315                Err(broadcast::error::RecvError::Closed) => return None,
316            }
317        }
318    }
319
320    // ── Lifecycle ────────────────────────────────────────────────────
321
322    /// Gracefully close the session.
323    ///
324    /// Sends a WebSocket close frame and shuts down the background runner.
325    /// Other clones of this session will observe `SessionStatus::Closed`.
326    pub async fn close(self) -> Result<(), SessionError> {
327        let _ = self.cmd_tx.send(Command::Close).await;
328        Ok(())
329    }
330}
331
// ── Internals ────────────────────────────────────────────────────────────────
333
/// Instruction sent from a `Session` handle to the runner task over the
/// command channel.
enum Command {
    /// Encode the message and transmit it on the WebSocket.
    /// Boxed, presumably to keep the enum small — confirm.
    Send(Box<ClientMessage>),
    /// Send a WebSocket close frame and stop the runner.
    Close,
}
338
// ── Shared state (survives reconnects) ───────────────────────────────────────
340
341struct SharedState {
342    resume_handle: Mutex<Option<String>>,
343    status: AtomicU8,
344}
345
346impl SharedState {
347    fn new() -> Self {
348        Self {
349            resume_handle: Mutex::new(None),
350            status: AtomicU8::new(SessionStatus::Connecting as u8),
351        }
352    }
353
354    fn status(&self) -> SessionStatus {
355        match self.status.load(Ordering::Relaxed) {
356            0 => SessionStatus::Connecting,
357            1 => SessionStatus::Connected,
358            2 => SessionStatus::Reconnecting,
359            _ => SessionStatus::Closed,
360        }
361    }
362
363    fn set_status(&self, s: SessionStatus) {
364        self.status.store(s as u8, Ordering::Relaxed);
365    }
366
367    fn resume_handle(&self) -> Option<String> {
368        self.resume_handle.lock().unwrap().clone()
369    }
370
371    fn set_resume_handle(&self, handle: Option<String>) {
372        *self.resume_handle.lock().unwrap() = handle;
373    }
374}
375
376// ── Runner task ──────────────────────────────────────────────────────────────
377
/// Why `Runner::run_connected` stopped driving the current connection.
enum DisconnectReason {
    /// A decoded server message contained `goAway`.
    GoAway,
    /// A send or receive failed, or the peer sent a close frame.
    ConnectionLost,
    /// A `Command::Close` was received from a session handle.
    UserClose,
    /// The command channel returned `None` (no senders remain).
    SendersDropped,
}
384
/// Background task state: owns the connection and bridges it to the
/// command and event channels.
struct Runner {
    /// Commands arriving from `Session` handles.
    cmd_rx: mpsc::Receiver<Command>,
    /// Broadcast side used to publish decoded server events.
    event_tx: broadcast::Sender<ServerEvent>,
    /// The current connection; replaced after a successful reconnect.
    conn: Connection,
    /// Transport, setup and reconnect settings reused for reconnects.
    config: SessionConfig,
    /// Status and resume handle shared with the session handles.
    state: Arc<SharedState>,
}
392
393impl Runner {
394    async fn run(mut self) {
395        loop {
396            let reason = self.run_connected().await;
397
398            match reason {
399                DisconnectReason::UserClose | DisconnectReason::SendersDropped => {
400                    self.state.set_status(SessionStatus::Closed);
401                    tracing::info!("session closed");
402                    break;
403                }
404                DisconnectReason::GoAway | DisconnectReason::ConnectionLost => {
405                    if !self.config.reconnect.enabled {
406                        self.state.set_status(SessionStatus::Closed);
407                        let _ = self.event_tx.send(ServerEvent::Closed {
408                            reason: "disconnected (reconnect disabled)".into(),
409                        });
410                        break;
411                    }
412
413                    self.state.set_status(SessionStatus::Reconnecting);
414                    tracing::info!("attempting reconnection");
415
416                    match self.reconnect().await {
417                        Ok(conn) => {
418                            self.conn = conn;
419                            self.state.set_status(SessionStatus::Connected);
420                            tracing::info!("reconnected successfully");
421                        }
422                        Err(e) => {
423                            self.state.set_status(SessionStatus::Closed);
424                            let _ = self.event_tx.send(ServerEvent::Error(ApiError {
425                                message: e.to_string(),
426                            }));
427                            break;
428                        }
429                    }
430                }
431            }
432        }
433    }
434
435    /// Drive the connection: forward commands to the WebSocket, broadcast
436    /// received frames as events.  Returns the reason for disconnection.
437    async fn run_connected(&mut self) -> DisconnectReason {
438        loop {
439            tokio::select! {
440                cmd = self.cmd_rx.recv() => {
441                    match cmd {
442                        Some(Command::Send(msg)) => { let msg = *msg;
443                            match codec::encode(&msg) {
444                                Ok(json) => {
445                                    if let Err(e) = self.conn.send_text(&json).await {
446                                        tracing::warn!(error = %e, "send failed");
447                                        return DisconnectReason::ConnectionLost;
448                                    }
449                                }
450                                Err(e) => {
451                                    tracing::warn!(error = %e, "message encode failed, dropping");
452                                }
453                            }
454                        }
455                        Some(Command::Close) => {
456                            let _ = self.conn.send_close().await;
457                            return DisconnectReason::UserClose;
458                        }
459                        None => {
460                            let _ = self.conn.send_close().await;
461                            return DisconnectReason::SendersDropped;
462                        }
463                    }
464                }
465                frame = self.conn.recv() => {
466                    match frame {
467                        Ok(RawFrame::Text(text)) => {
468                            if let Some(reason) = self.try_decode_and_process(&text) {
469                                return reason;
470                            }
471                        }
472                        Ok(RawFrame::Binary(data)) => {
473                            // Gemini Live API may send JSON as binary frames.
474                            if let Ok(text) = std::str::from_utf8(&data)
475                                && let Some(reason) = self.try_decode_and_process(text)
476                            {
477                                return reason;
478                            }
479                        }
480                        Ok(RawFrame::Close(reason)) => {
481                            let _ = self.event_tx.send(ServerEvent::Closed {
482                                reason: reason.unwrap_or_default(),
483                            });
484                            return DisconnectReason::ConnectionLost;
485                        }
486                        Err(e) => {
487                            tracing::warn!(error = %e, "recv error");
488                            return DisconnectReason::ConnectionLost;
489                        }
490                    }
491                }
492            }
493        }
494    }
495
496    /// Decode a server message, track session state, and broadcast events.
497    /// Returns `true` if the message contained a `goAway`.
498    /// Try to decode a JSON string and process it. Returns `Some(reason)` if
499    /// the connection loop should exit.
500    fn try_decode_and_process(&self, text: &str) -> Option<DisconnectReason> {
501        match codec::decode(text) {
502            Ok(msg) => {
503                if self.process_message(msg) {
504                    Some(DisconnectReason::GoAway)
505                } else {
506                    None
507                }
508            }
509            Err(e) => {
510                tracing::warn!(error = %e, "failed to decode server message");
511                None
512            }
513        }
514    }
515
516    fn process_message(&self, msg: ServerMessage) -> bool {
517        // Track the latest resume handle for reconnection.
518        if let Some(ref sr) = msg.session_resumption_update
519            && let Some(ref handle) = sr.new_handle
520        {
521            self.state.set_resume_handle(Some(handle.clone()));
522        }
523
524        let is_go_away = msg.go_away.is_some();
525
526        for event in codec::into_events(msg) {
527            let _ = self.event_tx.send(event);
528        }
529
530        is_go_away
531    }
532
533    /// Attempt reconnection with exponential backoff.
534    async fn reconnect(&mut self) -> Result<Connection, SessionError> {
535        let policy = &self.config.reconnect;
536        let mut attempt = 0u32;
537
538        loop {
539            attempt += 1;
540            if policy.max_attempts.is_some_and(|max| attempt > max) {
541                return Err(SessionError::ReconnectExhausted {
542                    attempts: attempt - 1,
543                });
544            }
545
546            let backoff = compute_backoff(policy, attempt);
547            tracing::debug!(attempt, ?backoff, "reconnect backoff");
548            tokio::time::sleep(backoff).await;
549
550            let mut conn = match Connection::connect(&self.config.transport).await {
551                Ok(c) => c,
552                Err(e) => {
553                    tracing::warn!(attempt, error = %e, "reconnect connect failed");
554                    continue;
555                }
556            };
557
558            let resume_handle = self.state.resume_handle();
559            match do_handshake(&mut conn, &self.config.setup, resume_handle).await {
560                Ok(()) => return Ok(conn),
561                Err(e) => {
562                    tracing::warn!(attempt, error = %e, "reconnect handshake failed");
563                    continue;
564                }
565            }
566        }
567    }
568}
569
570// ── Handshake ────────────────────────────────────────────────────────────────
571
572/// Send `setup` and wait for `setupComplete`.
573///
574/// If `resume_handle` is `Some`, it is injected into the setup's
575/// `sessionResumption` config so the server can resume state.
576async fn do_handshake(
577    conn: &mut Connection,
578    setup: &SetupConfig,
579    resume_handle: Option<String>,
580) -> Result<(), SessionError> {
581    let mut setup = setup.clone();
582    if let Some(handle) = resume_handle {
583        let sr = setup
584            .session_resumption
585            .get_or_insert_with(SessionResumptionConfig::default);
586        sr.handle = Some(handle);
587    }
588
589    let json = codec::encode(&ClientMessage::Setup(setup))?;
590    tracing::debug!(setup_json = %json, "sending setup message");
591    conn.send_text(&json)
592        .await
593        .map_err(|e| SessionError::SetupFailed(e.to_string()))?;
594
595    tokio::time::timeout(SETUP_TIMEOUT, wait_setup_complete(conn))
596        .await
597        .map_err(|_| SessionError::SetupTimeout(SETUP_TIMEOUT))?
598}
599
600async fn wait_setup_complete(conn: &mut Connection) -> Result<(), SessionError> {
601    loop {
602        match conn.recv().await {
603            Ok(RawFrame::Text(text)) => {
604                tracing::debug!(raw = %text, "received text during setup");
605                match try_parse_setup_response(&text)? {
606                    SetupResult::Complete => return Ok(()),
607                    SetupResult::Continue => {}
608                }
609            }
610            Ok(RawFrame::Binary(data)) => {
611                // Gemini Live API may send JSON as binary frames.
612                if let Ok(text) = std::str::from_utf8(&data) {
613                    tracing::debug!(raw = %text, "received binary (UTF-8) during setup");
614                    match try_parse_setup_response(text)? {
615                        SetupResult::Complete => return Ok(()),
616                        SetupResult::Continue => {}
617                    }
618                }
619            }
620            Ok(RawFrame::Close(reason)) => {
621                return Err(SessionError::SetupFailed(format!(
622                    "closed during setup: {}",
623                    reason.unwrap_or_default()
624                )));
625            }
626            Err(e) => return Err(SessionError::SetupFailed(e.to_string())),
627        }
628    }
629}
630
631enum SetupResult {
632    Complete,
633    Continue,
634}
635
636fn try_parse_setup_response(text: &str) -> Result<SetupResult, SessionError> {
637    let msg = codec::decode(text).map_err(|e| SessionError::SetupFailed(e.to_string()))?;
638    if msg.setup_complete.is_some() {
639        return Ok(SetupResult::Complete);
640    }
641    if let Some(err) = msg.error {
642        return Err(SessionError::Api(err.message));
643    }
644    Ok(SetupResult::Continue)
645}
646
647// ── Backoff ──────────────────────────────────────────────────────────────────
648
649/// Exponential backoff: `base × 2^(attempt − 1)`, capped at `max`.
650fn compute_backoff(policy: &ReconnectPolicy, attempt: u32) -> Duration {
651    let exp = attempt.saturating_sub(1).min(10);
652    let factor = 2u64.saturating_pow(exp);
653    let ms = policy.base_backoff.as_millis() as u64 * factor;
654    Duration::from_millis(ms.min(policy.max_backoff.as_millis() as u64))
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Backoff doubles per attempt and never exceeds the configured cap.
    #[test]
    fn backoff_exponential_with_cap() {
        let policy = ReconnectPolicy {
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            ..Default::default()
        };
        // (attempt, expected delay in milliseconds)
        let expectations: [(u32, u64); 6] = [
            (1, 500),
            (2, 1000),
            (3, 2000),
            (4, 4000),
            (5, 5000), // capped
            (100, 5000),
        ];
        for (attempt, millis) in expectations {
            assert_eq!(
                compute_backoff(&policy, attempt),
                Duration::from_millis(millis),
                "attempt {attempt}"
            );
        }
    }

    /// Every status survives a store/load round trip through the atomic.
    #[test]
    fn status_round_trip() {
        let state = SharedState::new();
        assert_eq!(state.status(), SessionStatus::Connecting);

        let sequence = [
            SessionStatus::Connected,
            SessionStatus::Reconnecting,
            SessionStatus::Closed,
        ];
        for status in sequence {
            state.set_status(status);
            assert_eq!(state.status(), status);
        }
    }

    /// The resume handle can be set, replaced and cleared.
    #[test]
    fn resume_handle_tracking() {
        let state = SharedState::new();
        assert!(state.resume_handle().is_none());

        for handle in ["h1", "h2"] {
            state.set_resume_handle(Some(handle.into()));
            assert_eq!(state.resume_handle().as_deref(), Some(handle));
        }

        state.set_resume_handle(None);
        assert!(state.resume_handle().is_none());
    }

    /// The documented defaults of `ReconnectPolicy` hold.
    #[test]
    fn default_reconnect_policy() {
        let ReconnectPolicy {
            enabled,
            base_backoff,
            max_backoff,
            max_attempts,
        } = ReconnectPolicy::default();
        assert!(enabled);
        assert_eq!(base_backoff, Duration::from_millis(500));
        assert_eq!(max_backoff, Duration::from_secs(5));
        assert_eq!(max_attempts, Some(10));
    }
}