Skip to main content

gemini_live/
session.rs

1//! Session layer — the primary interface for the Gemini Live API.
2//!
3//! [`Session`] manages the full connection lifecycle: WebSocket connect,
4//! `setup` handshake, automatic reconnection, and typed send/receive.
5//! Because Live API WebSocket sessions are finite-lived, this layer also owns
6//! `goAway` handling and session-resumption handoff.
7//!
8//! # Architecture
9//!
10//! ```text
11//!                        ┌──────────────┐
12//!                        │   Session    │  ← cheap Clone (Arc)
13//!                        │  (handle)    │
14//!                        └──┬───────┬───┘
15//!                 cmd_tx    │       │  event_rx (broadcast)
16//!                           ▼       ▼
17//!             ┌─────────────────────────────────┐
18//!             │          Runner task            │  ← tokio::spawn
19//!             │  ┌───────────┐  ┌────────────┐  │
20//!             │  │ Send Loop │  │ Recv Loop  │  │
21//!             │  │ (ws sink) │  │ (ws stream)│  │
22//!             │  └───────────┘  └────────────┘  │
23//!             │  reconnect · GoAway · resume    │
24//!             └─────────────────────────────────┘
25//! ```
26//!
27//! The runner is a single `tokio::spawn`'d task that uses `tokio::select!`
28//! to multiplex user commands and WebSocket frames.  Reconnection is
29//! transparent — messages buffer in the mpsc channel during downtime.
30
31use std::sync::atomic::{AtomicU8, Ordering};
32use std::sync::{Arc, Mutex};
33use std::time::Duration;
34
35use base64::Engine;
36use futures_util::Stream;
37use tokio::sync::{broadcast, mpsc};
38
39use crate::codec;
40use crate::error::SessionError;
41use crate::transport::{Connection, RawFrame, TransportConfig};
42use crate::types::*;
43
/// Timeout for the `setup` → `setupComplete` handshake.
const SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// Capacity of the broadcast channel that fans server events out to every
/// subscriber.  A receiver that falls further behind than this is told it
/// lagged instead of receiving the overwritten events.
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Capacity of the bounded mpsc channel that carries commands from
/// [`Session`] handles to the runner task.
const COMMAND_CHANNEL_CAPACITY: usize = 64;
48
49// ── Public config types ──────────────────────────────────────────────────────
50
/// Complete session configuration combining transport, protocol, and
/// reconnection settings.
///
/// `setup` is sent on the initial connect and again on every reconnect. When a
/// fresh resume handle exists, the session layer injects it into
/// `setup.session_resumption.handle` before sending (and clears
/// `setup.history_config` for that resumed handshake).
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Transport settings handed to `Connection::connect` for the initial
    /// connection and for every reconnect attempt.
    pub transport: TransportConfig,
    /// Protocol-level `setup` payload used as the template for each handshake.
    pub setup: SetupConfig,
    /// Policy controlling whether and how the runner reconnects.
    pub reconnect: ReconnectPolicy,
}
63
/// How the runner reacts to an unexpected disconnect or a server `goAway`.
///
/// The delay before attempt `n` follows `base_backoff × 2^(n − 1)` and never
/// exceeds `max_backoff`.
///
/// While reconnect attempts are in progress, outbound messages keep queueing
/// in the command channel.
#[derive(Debug, Clone)]
pub struct ReconnectPolicy {
    /// Whether automatic reconnection is attempted at all.  Default: `true`.
    pub enabled: bool,
    /// Delay before the first attempt.  Default: 500 ms.
    pub base_backoff: Duration,
    /// Upper bound on the delay between attempts.  Default: 5 s.
    pub max_backoff: Duration,
    /// Number of attempts before giving up; `None` retries forever.
    /// Default: `Some(10)`.
    pub max_attempts: Option<u32>,
}

impl Default for ReconnectPolicy {
    /// Reconnection enabled, 500 ms initial delay, 5 s ceiling, 10 attempts.
    fn default() -> Self {
        const INITIAL_DELAY: Duration = Duration::from_millis(500);
        const DELAY_CEILING: Duration = Duration::from_secs(5);
        const ATTEMPT_LIMIT: u32 = 10;

        Self {
            enabled: true,
            base_backoff: INITIAL_DELAY,
            max_backoff: DELAY_CEILING,
            max_attempts: Some(ATTEMPT_LIMIT),
        }
    }
}
92
/// Observable session state.
///
/// The value is stored in shared state as its `u8` discriminant, so the
/// explicit discriminants below must stay in sync with the decoding in
/// `SharedState::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// Initial state while the first connection and handshake are in flight.
    Connecting = 0,
    /// The `setup` handshake completed and the runner is driving the socket.
    Connected = 1,
    /// The connection was lost (or the server sent `goAway`) and the runner
    /// is attempting to reconnect.
    Reconnecting = 2,
    /// Terminal state: the runner has stopped and will not reconnect.
    Closed = 3,
}
101
/// Observable items from the session receive cursor.
///
/// Most callers only care about [`ServerEvent`] values and can continue using
/// [`Session::next_event`]. Callers that need visibility into dropped events
/// should use [`Session::next_observed_event`].
#[derive(Debug, Clone)]
pub enum SessionObservation {
    /// A server event delivered in order.
    Event(ServerEvent),
    /// This cursor fell behind the broadcast channel; `count` is the number
    /// reported by the channel's lag notification (events this cursor missed).
    Lagged { count: u64 },
}
112
113// ── Session handle ───────────────────────────────────────────────────────────
114
/// An active session with the Gemini Live API.
///
/// Cheaply [`Clone`]able — all clones share the same underlying connection
/// and runner task.  Each clone has its own event cursor, so events are
/// never "stolen" between consumers.
///
/// Created via [`Session::connect`].
pub struct Session {
    /// Sending half of the bounded command channel consumed by the runner.
    cmd_tx: mpsc::Sender<Command>,
    /// Broadcast sender, retained so clones and `events()` can create new
    /// subscriptions.
    event_tx: broadcast::Sender<ServerEvent>,
    /// This handle's own cursor into the event broadcast.
    event_rx: broadcast::Receiver<ServerEvent>,
    /// Status and resume handle shared with the runner and every clone.
    state: Arc<SharedState>,
}
128
129impl Clone for Session {
130    fn clone(&self) -> Self {
131        Self {
132            cmd_tx: self.cmd_tx.clone(),
133            event_tx: self.event_tx.clone(),
134            event_rx: self.event_tx.subscribe(),
135            state: self.state.clone(),
136        }
137    }
138}
139
140impl Session {
141    /// Connect to the Gemini Live API and complete the `setup` handshake.
142    ///
143    /// On success the session is immediately usable — `setupComplete` has
144    /// already been received.  A background runner task is spawned to
145    /// manage the connection and handle reconnection.
146    pub async fn connect(config: SessionConfig) -> Result<Self, SessionError> {
147        let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
148        let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
149        let state = Arc::new(SharedState::new());
150
151        // 1. Establish WebSocket connection
152        state.set_status(SessionStatus::Connecting);
153        let mut conn = Connection::connect(&config.transport)
154            .await
155            .map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
156
157        // 2. Send setup and await setupComplete
158        do_handshake(&mut conn, &config.setup, None).await?;
159        state.set_status(SessionStatus::Connected);
160        tracing::info!("session established");
161
162        // 3. Spawn the background runner
163        let runner = Runner {
164            cmd_rx,
165            event_tx: event_tx.clone(),
166            conn,
167            config,
168            state: Arc::clone(&state),
169        };
170        tokio::spawn(runner.run());
171
172        let event_rx = event_tx.subscribe();
173        Ok(Self {
174            cmd_tx,
175            event_tx,
176            event_rx,
177            state,
178        })
179    }
180
181    /// Current session status.
182    pub fn status(&self) -> SessionStatus {
183        self.state.status()
184    }
185
186    // ── Send convenience methods ─────────────────────────────────────
187
188    /// Stream audio.  Accepts raw i16 little-endian PCM bytes — base64
189    /// encoding and `realtimeInput` wrapping are handled internally.
190    ///
191    /// **Performance note:** allocates a new `String` for base64 on every
192    /// call (`roadmap.md` P-1).  For zero-allocation streaming, use
193    /// [`AudioEncoder`](crate::audio::AudioEncoder) with [`send_raw`](Self::send_raw).
194    pub async fn send_audio(&self, pcm_i16_le: &[u8]) -> Result<(), SessionError> {
195        self.send_audio_at_rate(pcm_i16_le, crate::audio::INPUT_SAMPLE_RATE)
196            .await
197    }
198
199    /// Stream audio at a specific sample rate.
200    ///
201    /// Like [`send_audio`](Self::send_audio) but lets you specify the sample
202    /// rate (e.g. for mic capture at the device's native rate).  The server
203    /// resamples as needed.
204    ///
205    /// **Performance note:** allocates a new `String` for base64 on every
206    /// call (`roadmap.md` P-1).  For zero-allocation streaming, use
207    /// [`AudioEncoder`](crate::audio::AudioEncoder) with [`send_raw`](Self::send_raw).
208    pub async fn send_audio_at_rate(
209        &self,
210        pcm_i16_le: &[u8],
211        sample_rate: u32,
212    ) -> Result<(), SessionError> {
213        let b64 = base64::engine::general_purpose::STANDARD.encode(pcm_i16_le);
214        let mime = format!("audio/pcm;rate={sample_rate}");
215        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
216            audio: Some(Blob {
217                data: b64,
218                mime_type: mime,
219            }),
220            ..Default::default()
221        }))
222        .await
223    }
224
225    /// Stream a video frame.  Accepts encoded JPEG/PNG bytes.
226    pub async fn send_video(&self, data: &[u8], mime: &str) -> Result<(), SessionError> {
227        let b64 = base64::engine::general_purpose::STANDARD.encode(data);
228        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
229            video: Some(Blob {
230                data: b64,
231                mime_type: mime.into(),
232            }),
233            ..Default::default()
234        }))
235        .await
236    }
237
238    /// Send text via the real-time input channel.
239    pub async fn send_text(&self, text: &str) -> Result<(), SessionError> {
240        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
241            text: Some(text.into()),
242            ..Default::default()
243        }))
244        .await
245    }
246
247    /// Send conversation history or incremental content.
248    pub async fn send_client_content(&self, content: ClientContent) -> Result<(), SessionError> {
249        self.send_raw(ClientMessage::ClientContent(content)).await
250    }
251
252    /// Manual VAD: signal that user activity (speech) has started.
253    pub async fn activity_start(&self) -> Result<(), SessionError> {
254        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
255            activity_start: Some(EmptyObject {}),
256            ..Default::default()
257        }))
258        .await
259    }
260
261    /// Manual VAD: signal that user activity has ended.
262    pub async fn activity_end(&self) -> Result<(), SessionError> {
263        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
264            activity_end: Some(EmptyObject {}),
265            ..Default::default()
266        }))
267        .await
268    }
269
270    /// Notify the server that the audio stream has ended (auto VAD mode).
271    pub async fn audio_stream_end(&self) -> Result<(), SessionError> {
272        self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
273            audio_stream_end: Some(true),
274            ..Default::default()
275        }))
276        .await
277    }
278
279    /// Reply to one or more server-initiated function calls.
280    pub async fn send_tool_response(
281        &self,
282        responses: Vec<FunctionResponse>,
283    ) -> Result<(), SessionError> {
284        self.send_raw(ClientMessage::ToolResponse(ToolResponseMessage {
285            function_responses: responses,
286        }))
287        .await
288    }
289
290    /// Send an arbitrary [`ClientMessage`] (escape hatch for future types).
291    pub async fn send_raw(&self, msg: ClientMessage) -> Result<(), SessionError> {
292        self.cmd_tx
293            .send(Command::Send(Box::new(msg)))
294            .await
295            .map_err(|_| SessionError::Closed)
296    }
297
298    // ── Receive ──────────────────────────────────────────────────────
299
300    /// Create a new event [`Stream`].
301    ///
302    /// Each call produces an **independent** subscription — multiple streams
303    /// can coexist without stealing events.
304    pub fn events(&self) -> impl Stream<Item = ServerEvent> {
305        let rx = self.event_tx.subscribe();
306        futures_util::stream::unfold(rx, |mut rx| async move {
307            loop {
308                match rx.recv().await {
309                    Ok(event) => return Some((event, rx)),
310                    Err(broadcast::error::RecvError::Lagged(n)) => {
311                        tracing::warn!(n, "event stream lagged, some events were missed");
312                        continue;
313                    }
314                    Err(broadcast::error::RecvError::Closed) => return None,
315                }
316            }
317        })
318    }
319
320    /// Wait for the next event on this handle's cursor.
321    ///
322    /// Returns `None` when the session is permanently closed.
323    pub async fn next_event(&mut self) -> Option<ServerEvent> {
324        loop {
325            match self.next_observed_event().await? {
326                SessionObservation::Event(event) => return Some(event),
327                SessionObservation::Lagged { count: n } => {
328                    tracing::warn!(n, "event consumer lagged, some events were missed");
329                }
330            }
331        }
332    }
333
334    /// Wait for the next observable item on this handle's cursor.
335    ///
336    /// Unlike [`Session::next_event`], this does not hide lagged broadcast
337    /// notifications. Callers can therefore surface lost-event state directly
338    /// in their own runtime or UI layer.
339    pub async fn next_observed_event(&mut self) -> Option<SessionObservation> {
340        match self.event_rx.recv().await {
341            Ok(event) => Some(SessionObservation::Event(event)),
342            Err(broadcast::error::RecvError::Lagged(n)) => {
343                Some(SessionObservation::Lagged { count: n })
344            }
345            Err(broadcast::error::RecvError::Closed) => None,
346        }
347    }
348
349    // ── Lifecycle ────────────────────────────────────────────────────
350
351    /// Gracefully close the session.
352    ///
353    /// Sends a WebSocket close frame and shuts down the background runner.
354    /// Other clones of this session will observe `SessionStatus::Closed`.
355    ///
356    /// This only enqueues the close request; it does not await runner-task
357    /// completion yet.
358    pub async fn close(self) -> Result<(), SessionError> {
359        let _ = self.cmd_tx.send(Command::Close).await;
360        Ok(())
361    }
362}
363
// ── Internals ────────────────────────────────────────────────────────────────
365
/// Requests sent from [`Session`] handles to the runner task.
enum Command {
    /// Encode and send this message on the WebSocket.  Boxed so the enum
    /// itself stays small.
    Send(Box<ClientMessage>),
    /// Send a WebSocket close frame and stop the runner.
    Close,
}
370
// ── Shared state (survives reconnects) ───────────────────────────────────────
372
373struct SharedState {
374    resume_handle: Mutex<Option<String>>,
375    status: AtomicU8,
376}
377
378impl SharedState {
379    fn new() -> Self {
380        Self {
381            resume_handle: Mutex::new(None),
382            status: AtomicU8::new(SessionStatus::Connecting as u8),
383        }
384    }
385
386    fn status(&self) -> SessionStatus {
387        match self.status.load(Ordering::Relaxed) {
388            0 => SessionStatus::Connecting,
389            1 => SessionStatus::Connected,
390            2 => SessionStatus::Reconnecting,
391            _ => SessionStatus::Closed,
392        }
393    }
394
395    fn set_status(&self, s: SessionStatus) {
396        self.status.store(s as u8, Ordering::Relaxed);
397    }
398
399    fn resume_handle(&self) -> Option<String> {
400        self.resume_handle.lock().unwrap().clone()
401    }
402
403    fn set_resume_handle(&self, handle: Option<String>) {
404        *self.resume_handle.lock().unwrap() = handle;
405    }
406}
407
408// ── Runner task ──────────────────────────────────────────────────────────────
409
/// Why the connected loop stopped driving the current connection.
enum DisconnectReason {
    /// A decoded server message contained `goAway`; the runner reconnects.
    GoAway,
    /// A send or receive failed, or the peer sent a close frame; the runner
    /// reconnects.
    ConnectionLost,
    /// A `Command::Close` was received; the runner shuts down.
    UserClose,
    /// The command channel yielded `None` (no senders remain); the runner
    /// shuts down.
    SendersDropped,
}
416
/// Background task state: owns the live connection and both channel ends
/// that face the [`Session`] handles.
struct Runner {
    /// Commands queued by session handles.
    cmd_rx: mpsc::Receiver<Command>,
    /// Broadcast sender used to publish decoded server events.
    event_tx: broadcast::Sender<ServerEvent>,
    /// The current WebSocket connection; replaced after each reconnect.
    conn: Connection,
    /// Configuration reused for every reconnect and handshake.
    config: SessionConfig,
    /// Status and resume handle shared with the session handles.
    state: Arc<SharedState>,
}
424
425impl Runner {
426    async fn run(mut self) {
427        loop {
428            let reason = self.run_connected().await;
429
430            match reason {
431                DisconnectReason::UserClose | DisconnectReason::SendersDropped => {
432                    self.state.set_status(SessionStatus::Closed);
433                    tracing::info!("session closed");
434                    break;
435                }
436                DisconnectReason::GoAway | DisconnectReason::ConnectionLost => {
437                    if !self.config.reconnect.enabled {
438                        self.state.set_status(SessionStatus::Closed);
439                        let _ = self.event_tx.send(ServerEvent::Closed {
440                            reason: "disconnected (reconnect disabled)".into(),
441                        });
442                        break;
443                    }
444
445                    self.state.set_status(SessionStatus::Reconnecting);
446                    tracing::info!("attempting reconnection");
447
448                    match self.reconnect().await {
449                        Ok(conn) => {
450                            self.conn = conn;
451                            self.state.set_status(SessionStatus::Connected);
452                            tracing::info!("reconnected successfully");
453                        }
454                        Err(e) => {
455                            self.state.set_status(SessionStatus::Closed);
456                            let _ = self.event_tx.send(ServerEvent::Error(ApiError {
457                                message: e.to_string(),
458                            }));
459                            break;
460                        }
461                    }
462                }
463            }
464        }
465    }
466
467    /// Drive the connection: forward commands to the WebSocket, broadcast
468    /// received frames as events.  Returns the reason for disconnection.
469    async fn run_connected(&mut self) -> DisconnectReason {
470        loop {
471            tokio::select! {
472                cmd = self.cmd_rx.recv() => {
473                    match cmd {
474                        Some(Command::Send(msg)) => { let msg = *msg;
475                            match codec::encode(&msg) {
476                                Ok(json) => {
477                                    if let Err(e) = self.conn.send_text(&json).await {
478                                        tracing::warn!(error = %e, "send failed");
479                                        return DisconnectReason::ConnectionLost;
480                                    }
481                                }
482                                Err(e) => {
483                                    tracing::warn!(error = %e, "message encode failed, dropping");
484                                }
485                            }
486                        }
487                        Some(Command::Close) => {
488                            let _ = self.conn.send_close().await;
489                            return DisconnectReason::UserClose;
490                        }
491                        None => {
492                            let _ = self.conn.send_close().await;
493                            return DisconnectReason::SendersDropped;
494                        }
495                    }
496                }
497                frame = self.conn.recv() => {
498                    match frame {
499                        Ok(RawFrame::Text(text)) => {
500                            if let Some(reason) = self.try_decode_and_process(&text) {
501                                return reason;
502                            }
503                        }
504                        Ok(RawFrame::Binary(data)) => {
505                            // Gemini Live API may send JSON as binary frames.
506                            if let Ok(text) = std::str::from_utf8(&data)
507                                && let Some(reason) = self.try_decode_and_process(text)
508                            {
509                                return reason;
510                            }
511                        }
512                        Ok(RawFrame::Close(reason)) => {
513                            let _ = self.event_tx.send(ServerEvent::Closed {
514                                reason: reason.unwrap_or_default(),
515                            });
516                            return DisconnectReason::ConnectionLost;
517                        }
518                        Err(e) => {
519                            tracing::warn!(error = %e, "recv error");
520                            return DisconnectReason::ConnectionLost;
521                        }
522                    }
523                }
524            }
525        }
526    }
527
528    /// Decode a server message, track session state, and broadcast events.
529    /// Returns `true` if the message contained a `goAway`.
530    /// Try to decode a JSON string and process it. Returns `Some(reason)` if
531    /// the connection loop should exit.
532    fn try_decode_and_process(&self, text: &str) -> Option<DisconnectReason> {
533        match codec::decode(text) {
534            Ok(msg) => {
535                if self.process_message(msg) {
536                    Some(DisconnectReason::GoAway)
537                } else {
538                    None
539                }
540            }
541            Err(e) => {
542                tracing::warn!(error = %e, "failed to decode server message");
543                None
544            }
545        }
546    }
547
548    fn process_message(&self, msg: ServerMessage) -> bool {
549        // Track the latest resume handle for reconnection.
550        if let Some(ref sr) = msg.session_resumption_update
551            && let Some(ref handle) = sr.new_handle
552        {
553            self.state.set_resume_handle(Some(handle.clone()));
554        }
555
556        let is_go_away = msg.go_away.is_some();
557
558        for event in codec::into_events(msg) {
559            let _ = self.event_tx.send(event);
560        }
561
562        is_go_away
563    }
564
565    /// Attempt reconnection with exponential backoff.
566    async fn reconnect(&mut self) -> Result<Connection, SessionError> {
567        let policy = &self.config.reconnect;
568        let mut attempt = 0u32;
569
570        loop {
571            attempt += 1;
572            if policy.max_attempts.is_some_and(|max| attempt > max) {
573                return Err(SessionError::ReconnectExhausted {
574                    attempts: attempt - 1,
575                });
576            }
577
578            let backoff = compute_backoff(policy, attempt);
579            tracing::debug!(attempt, ?backoff, "reconnect backoff");
580            tokio::time::sleep(backoff).await;
581
582            let mut conn = match Connection::connect(&self.config.transport).await {
583                Ok(c) => c,
584                Err(e) => {
585                    tracing::warn!(attempt, error = %e, "reconnect connect failed");
586                    continue;
587                }
588            };
589
590            let resume_handle = self.state.resume_handle();
591            match do_handshake(&mut conn, &self.config.setup, resume_handle).await {
592                Ok(()) => return Ok(conn),
593                Err(e) => {
594                    tracing::warn!(attempt, error = %e, "reconnect handshake failed");
595                    continue;
596                }
597            }
598        }
599    }
600}
601
602// ── Handshake ────────────────────────────────────────────────────────────────
603
604/// Send `setup` and wait for `setupComplete`.
605///
606/// If `resume_handle` is `Some`, it is injected into the setup's
607/// `sessionResumption` config so the server can resume state.
608async fn do_handshake(
609    conn: &mut Connection,
610    setup: &SetupConfig,
611    resume_handle: Option<String>,
612) -> Result<(), SessionError> {
613    let setup = setup_for_handshake(setup, resume_handle);
614
615    let json = codec::encode(&ClientMessage::Setup(setup))?;
616    tracing::debug!(setup_json = %json, "sending setup message");
617    conn.send_text(&json)
618        .await
619        .map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
620
621    tokio::time::timeout(SETUP_TIMEOUT, wait_setup_complete(conn))
622        .await
623        .map_err(|_| SessionError::SetupTimeout(SETUP_TIMEOUT))?
624}
625
626fn setup_for_handshake(setup: &SetupConfig, resume_handle: Option<String>) -> SetupConfig {
627    let mut setup = setup.clone();
628    if let Some(handle) = resume_handle {
629        let sr = setup
630            .session_resumption
631            .get_or_insert_with(SessionResumptionConfig::default);
632        sr.handle = Some(handle);
633        // Initial-history mode is only valid on a fresh session before the
634        // first realtime input. Resumed sessions continue existing history and
635        // must not wait for new bootstrap `clientContent`.
636        setup.history_config = None;
637    }
638    setup
639}
640
641async fn wait_setup_complete(conn: &mut Connection) -> Result<(), SessionError> {
642    loop {
643        match conn.recv().await {
644            Ok(RawFrame::Text(text)) => {
645                tracing::debug!(raw = %text, "received text during setup");
646                match try_parse_setup_response(&text)? {
647                    SetupResult::Complete => return Ok(()),
648                    SetupResult::Continue => {}
649                }
650            }
651            Ok(RawFrame::Binary(data)) => {
652                // Gemini Live API may send JSON as binary frames.
653                if let Ok(text) = std::str::from_utf8(&data) {
654                    tracing::debug!(raw = %text, "received binary (UTF-8) during setup");
655                    match try_parse_setup_response(text)? {
656                        SetupResult::Complete => return Ok(()),
657                        SetupResult::Continue => {}
658                    }
659                }
660            }
661            Ok(RawFrame::Close(reason)) => {
662                return Err(SessionError::SetupFailed(format!(
663                    "closed during setup: {}",
664                    reason.unwrap_or_default()
665                )));
666            }
667            Err(e) => return Err(SessionError::SetupFailed(format_error_chain(&e))),
668        }
669    }
670}
671
/// Outcome of inspecting one server message during the handshake.
enum SetupResult {
    /// The message carried `setupComplete`; the handshake is done.
    Complete,
    /// The message was neither `setupComplete` nor an error; keep waiting.
    Continue,
}
676
677fn try_parse_setup_response(text: &str) -> Result<SetupResult, SessionError> {
678    let msg = codec::decode(text).map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
679    if msg.setup_complete.is_some() {
680        return Ok(SetupResult::Complete);
681    }
682    if let Some(err) = msg.error {
683        return Err(SessionError::Api(err.message));
684    }
685    Ok(SetupResult::Continue)
686}
687
/// Render an error together with its chain of sources as
/// `"outer: cause: root cause"`.
///
/// A source is skipped when its text is empty or when the message built so
/// far already ends with it (so wrappers that embed their source's text do
/// not produce duplicates).  Skipping a source does not stop the walk.
fn format_error_chain(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();

    for cause in std::iter::successors(error.source(), |&cause| cause.source()) {
        let cause_text = cause.to_string();
        if cause_text.is_empty() || rendered.ends_with(&cause_text) {
            continue;
        }
        rendered.push_str(": ");
        rendered.push_str(&cause_text);
    }

    rendered
}
701
702// ── Backoff ──────────────────────────────────────────────────────────────────
703
704/// Exponential backoff: `base × 2^(attempt − 1)`, capped at `max`.
705fn compute_backoff(policy: &ReconnectPolicy, attempt: u32) -> Duration {
706    let exp = attempt.saturating_sub(1).min(10);
707    let factor = 2u64.saturating_pow(exp);
708    let ms = policy.base_backoff.as_millis() as u64 * factor;
709    Duration::from_millis(ms.min(policy.max_backoff.as_millis() as u64))
710}
711
// Unit tests for the pure helpers in this module; nothing here opens a
// connection or needs an async runtime.
#[cfg(test)]
mod tests {
    use crate::error::{BearerTokenError, ConnectError};
    use crate::types::HistoryConfig;

    use super::*;

    // The delay doubles per attempt and is clamped at `max_backoff`.
    #[test]
    fn backoff_exponential_with_cap() {
        let policy = ReconnectPolicy {
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            ..Default::default()
        };
        assert_eq!(compute_backoff(&policy, 1), Duration::from_millis(500));
        assert_eq!(compute_backoff(&policy, 2), Duration::from_millis(1000));
        assert_eq!(compute_backoff(&policy, 3), Duration::from_millis(2000));
        assert_eq!(compute_backoff(&policy, 4), Duration::from_millis(4000));
        assert_eq!(compute_backoff(&policy, 5), Duration::from_secs(5)); // capped
        assert_eq!(compute_backoff(&policy, 100), Duration::from_secs(5));
    }

    // Every status survives the encode-to-`u8` / decode round trip.
    #[test]
    fn status_round_trip() {
        let state = SharedState::new();
        assert_eq!(state.status(), SessionStatus::Connecting);

        state.set_status(SessionStatus::Connected);
        assert_eq!(state.status(), SessionStatus::Connected);

        state.set_status(SessionStatus::Reconnecting);
        assert_eq!(state.status(), SessionStatus::Reconnecting);

        state.set_status(SessionStatus::Closed);
        assert_eq!(state.status(), SessionStatus::Closed);
    }

    // The stored resume handle starts empty, follows updates, and can be cleared.
    #[test]
    fn resume_handle_tracking() {
        let state = SharedState::new();
        assert!(state.resume_handle().is_none());

        state.set_resume_handle(Some("h1".into()));
        assert_eq!(state.resume_handle().as_deref(), Some("h1"));

        state.set_resume_handle(Some("h2".into()));
        assert_eq!(state.resume_handle().as_deref(), Some("h2"));

        state.set_resume_handle(None);
        assert!(state.resume_handle().is_none());
    }

    // Pins the documented defaults of `ReconnectPolicy`.
    #[test]
    fn default_reconnect_policy() {
        let p = ReconnectPolicy::default();
        assert!(p.enabled);
        assert_eq!(p.base_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(5));
        assert_eq!(p.max_attempts, Some(10));
    }

    // A resumed handshake carries the handle and drops `history_config`;
    // a fresh handshake leaves `history_config` untouched.
    #[test]
    fn handshake_setup_strips_initial_history_when_resuming() {
        let setup = SetupConfig {
            model: "models/test".into(),
            history_config: Some(HistoryConfig {
                initial_history_in_client_content: Some(true),
            }),
            ..Default::default()
        };

        let resumed = setup_for_handshake(&setup, Some("resume-1".into()));
        assert_eq!(
            resumed
                .session_resumption
                .as_ref()
                .and_then(|config| config.handle.as_deref()),
            Some("resume-1")
        );
        assert!(resumed.history_config.is_none());

        let fresh = setup_for_handshake(&setup, None);
        assert_eq!(fresh.history_config, setup.history_config);
    }

    // The rendered message includes every source in the chain, joined by ": ".
    #[test]
    fn format_error_chain_includes_sources() {
        let err = ConnectError::Auth(BearerTokenError::with_source(
            "failed to refresh Google Cloud access token from Application Default Credentials",
            std::io::Error::other("invalid_grant: Account has been deleted"),
        ));

        assert_eq!(
            format_error_chain(&err),
            "failed to obtain bearer token: failed to refresh Google Cloud access token from Application Default Credentials: invalid_grant: Account has been deleted"
        );
    }
}