1use std::sync::atomic::{AtomicU8, Ordering};
32use std::sync::{Arc, Mutex};
33use std::time::Duration;
34
35use base64::Engine;
36use futures_util::Stream;
37use tokio::sync::{broadcast, mpsc};
38
39use crate::codec;
40use crate::error::SessionError;
41use crate::transport::{Connection, RawFrame, TransportConfig};
42use crate::types::*;
43
/// Longest time a handshake waits for the server's setup acknowledgement.
const SETUP_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the broadcast channel carrying server events; a receiver that
/// falls further behind than this observes a lag.
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Capacity of the bounded queue from `Session` handles to the runner task.
const COMMAND_CHANNEL_CAPACITY: usize = 64;
48
49#[derive(Debug, Clone)]
58pub struct SessionConfig {
59 pub transport: TransportConfig,
60 pub setup: SetupConfig,
61 pub reconnect: ReconnectPolicy,
62}
63
/// Automatic reconnection behaviour after a `go_away` or a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectPolicy {
    /// When `false`, a disconnect closes the session immediately.
    pub enabled: bool,
    /// Delay before the first attempt; doubled for each further attempt.
    pub base_backoff: Duration,
    /// Upper bound on the delay between attempts.
    pub max_backoff: Duration,
    /// Attempts allowed per disconnect; `None` keeps retrying indefinitely.
    pub max_attempts: Option<u32>,
}

impl Default for ReconnectPolicy {
    /// Enabled, 500 ms initial backoff capped at 5 s, at most 10 attempts.
    fn default() -> Self {
        let base_backoff = Duration::from_millis(500);
        let max_backoff = Duration::from_secs(5);
        Self {
            enabled: true,
            base_backoff,
            max_backoff,
            max_attempts: Some(10),
        }
    }
}
92
/// Lifecycle state of a session, as reported by `Session::status`.
///
/// The explicit discriminants are what `SharedState` stores in its `AtomicU8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionStatus {
    /// Initial connect and handshake in progress.
    Connecting = 0,
    /// Handshake completed; the runner is serving the connection.
    Connected = 1,
    /// Connection lost or `go_away` received; reconnect attempts are running.
    Reconnecting = 2,
    /// Terminal: the runner has stopped.
    Closed = 3,
}
101
102#[derive(Debug, Clone)]
108pub enum SessionObservation {
109 Event(ServerEvent),
110 Lagged { count: u64 },
111}
112
113pub struct Session {
123 cmd_tx: mpsc::Sender<Command>,
124 event_tx: broadcast::Sender<ServerEvent>,
125 event_rx: broadcast::Receiver<ServerEvent>,
126 state: Arc<SharedState>,
127}
128
129impl Clone for Session {
130 fn clone(&self) -> Self {
131 Self {
132 cmd_tx: self.cmd_tx.clone(),
133 event_tx: self.event_tx.clone(),
134 event_rx: self.event_tx.subscribe(),
135 state: self.state.clone(),
136 }
137 }
138}
139
140impl Session {
141 pub async fn connect(config: SessionConfig) -> Result<Self, SessionError> {
147 let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
148 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
149 let state = Arc::new(SharedState::new());
150
151 state.set_status(SessionStatus::Connecting);
153 let mut conn = Connection::connect(&config.transport)
154 .await
155 .map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
156
157 do_handshake(&mut conn, &config.setup, None).await?;
159 state.set_status(SessionStatus::Connected);
160 tracing::info!("session established");
161
162 let runner = Runner {
164 cmd_rx,
165 event_tx: event_tx.clone(),
166 conn,
167 config,
168 state: Arc::clone(&state),
169 };
170 tokio::spawn(runner.run());
171
172 let event_rx = event_tx.subscribe();
173 Ok(Self {
174 cmd_tx,
175 event_tx,
176 event_rx,
177 state,
178 })
179 }
180
181 pub fn status(&self) -> SessionStatus {
183 self.state.status()
184 }
185
186 pub async fn send_audio(&self, pcm_i16_le: &[u8]) -> Result<(), SessionError> {
195 self.send_audio_at_rate(pcm_i16_le, crate::audio::INPUT_SAMPLE_RATE)
196 .await
197 }
198
199 pub async fn send_audio_at_rate(
209 &self,
210 pcm_i16_le: &[u8],
211 sample_rate: u32,
212 ) -> Result<(), SessionError> {
213 let b64 = base64::engine::general_purpose::STANDARD.encode(pcm_i16_le);
214 let mime = format!("audio/pcm;rate={sample_rate}");
215 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
216 audio: Some(Blob {
217 data: b64,
218 mime_type: mime,
219 }),
220 ..Default::default()
221 }))
222 .await
223 }
224
225 pub async fn send_video(&self, data: &[u8], mime: &str) -> Result<(), SessionError> {
227 let b64 = base64::engine::general_purpose::STANDARD.encode(data);
228 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
229 video: Some(Blob {
230 data: b64,
231 mime_type: mime.into(),
232 }),
233 ..Default::default()
234 }))
235 .await
236 }
237
238 pub async fn send_text(&self, text: &str) -> Result<(), SessionError> {
240 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
241 text: Some(text.into()),
242 ..Default::default()
243 }))
244 .await
245 }
246
247 pub async fn send_client_content(&self, content: ClientContent) -> Result<(), SessionError> {
249 self.send_raw(ClientMessage::ClientContent(content)).await
250 }
251
252 pub async fn activity_start(&self) -> Result<(), SessionError> {
254 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
255 activity_start: Some(EmptyObject {}),
256 ..Default::default()
257 }))
258 .await
259 }
260
261 pub async fn activity_end(&self) -> Result<(), SessionError> {
263 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
264 activity_end: Some(EmptyObject {}),
265 ..Default::default()
266 }))
267 .await
268 }
269
270 pub async fn audio_stream_end(&self) -> Result<(), SessionError> {
272 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
273 audio_stream_end: Some(true),
274 ..Default::default()
275 }))
276 .await
277 }
278
279 pub async fn send_tool_response(
281 &self,
282 responses: Vec<FunctionResponse>,
283 ) -> Result<(), SessionError> {
284 self.send_raw(ClientMessage::ToolResponse(ToolResponseMessage {
285 function_responses: responses,
286 }))
287 .await
288 }
289
290 pub async fn send_raw(&self, msg: ClientMessage) -> Result<(), SessionError> {
292 self.cmd_tx
293 .send(Command::Send(Box::new(msg)))
294 .await
295 .map_err(|_| SessionError::Closed)
296 }
297
298 pub fn events(&self) -> impl Stream<Item = ServerEvent> {
305 let rx = self.event_tx.subscribe();
306 futures_util::stream::unfold(rx, |mut rx| async move {
307 loop {
308 match rx.recv().await {
309 Ok(event) => return Some((event, rx)),
310 Err(broadcast::error::RecvError::Lagged(n)) => {
311 tracing::warn!(n, "event stream lagged, some events were missed");
312 continue;
313 }
314 Err(broadcast::error::RecvError::Closed) => return None,
315 }
316 }
317 })
318 }
319
320 pub async fn next_event(&mut self) -> Option<ServerEvent> {
324 loop {
325 match self.next_observed_event().await? {
326 SessionObservation::Event(event) => return Some(event),
327 SessionObservation::Lagged { count: n } => {
328 tracing::warn!(n, "event consumer lagged, some events were missed");
329 }
330 }
331 }
332 }
333
334 pub async fn next_observed_event(&mut self) -> Option<SessionObservation> {
340 match self.event_rx.recv().await {
341 Ok(event) => Some(SessionObservation::Event(event)),
342 Err(broadcast::error::RecvError::Lagged(n)) => {
343 Some(SessionObservation::Lagged { count: n })
344 }
345 Err(broadcast::error::RecvError::Closed) => None,
346 }
347 }
348
349 pub async fn close(self) -> Result<(), SessionError> {
359 let _ = self.cmd_tx.send(Command::Close).await;
360 Ok(())
361 }
362}
363
364enum Command {
367 Send(Box<ClientMessage>),
368 Close,
369}
370
371struct SharedState {
374 resume_handle: Mutex<Option<String>>,
375 status: AtomicU8,
376}
377
378impl SharedState {
379 fn new() -> Self {
380 Self {
381 resume_handle: Mutex::new(None),
382 status: AtomicU8::new(SessionStatus::Connecting as u8),
383 }
384 }
385
386 fn status(&self) -> SessionStatus {
387 match self.status.load(Ordering::Relaxed) {
388 0 => SessionStatus::Connecting,
389 1 => SessionStatus::Connected,
390 2 => SessionStatus::Reconnecting,
391 _ => SessionStatus::Closed,
392 }
393 }
394
395 fn set_status(&self, s: SessionStatus) {
396 self.status.store(s as u8, Ordering::Relaxed);
397 }
398
399 fn resume_handle(&self) -> Option<String> {
400 self.resume_handle.lock().unwrap().clone()
401 }
402
403 fn set_resume_handle(&self, handle: Option<String>) {
404 *self.resume_handle.lock().unwrap() = handle;
405 }
406}
407
/// Why `Runner::run_connected` stopped serving the current connection.
enum DisconnectReason {
    /// A server message carried `go_away`.
    GoAway,
    /// A send or receive failed, or the server sent a close frame.
    ConnectionLost,
    /// A `Session` handle requested `close`.
    UserClose,
    /// Every command sender (every `Session` handle) was dropped.
    SendersDropped,
}
416
417struct Runner {
418 cmd_rx: mpsc::Receiver<Command>,
419 event_tx: broadcast::Sender<ServerEvent>,
420 conn: Connection,
421 config: SessionConfig,
422 state: Arc<SharedState>,
423}
424
impl Runner {
    /// Drives the session until it is closed for good.
    ///
    /// Each iteration serves one connection via `run_connected`. An explicit
    /// `Close` or all senders being dropped ends the loop. A `GoAway` or a lost
    /// connection triggers `reconnect` when the policy enables it. Otherwise,
    /// or when reconnection gives up, the status becomes `Closed` and a final
    /// event is broadcast before the loop ends.
    async fn run(mut self) {
        loop {
            let reason = self.run_connected().await;

            match reason {
                DisconnectReason::UserClose | DisconnectReason::SendersDropped => {
                    self.state.set_status(SessionStatus::Closed);
                    tracing::info!("session closed");
                    break;
                }
                DisconnectReason::GoAway | DisconnectReason::ConnectionLost => {
                    if !self.config.reconnect.enabled {
                        self.state.set_status(SessionStatus::Closed);
                        // A send error only means there are no subscribers.
                        let _ = self.event_tx.send(ServerEvent::Closed {
                            reason: "disconnected (reconnect disabled)".into(),
                        });
                        break;
                    }

                    self.state.set_status(SessionStatus::Reconnecting);
                    tracing::info!("attempting reconnection");

                    match self.reconnect().await {
                        Ok(conn) => {
                            // The previous connection is dropped by this assignment.
                            self.conn = conn;
                            self.state.set_status(SessionStatus::Connected);
                            tracing::info!("reconnected successfully");
                        }
                        Err(e) => {
                            self.state.set_status(SessionStatus::Closed);
                            // Surface the terminal reconnect failure to subscribers.
                            let _ = self.event_tx.send(ServerEvent::Error(ApiError {
                                message: e.to_string(),
                            }));
                            break;
                        }
                    }
                }
            }
        }
    }

    /// Multiplexes queued commands and incoming frames on the current
    /// connection until something ends it, and returns the reason.
    async fn run_connected(&mut self) -> DisconnectReason {
        loop {
            tokio::select! {
                cmd = self.cmd_rx.recv() => {
                    match cmd {
                        Some(Command::Send(msg)) => {
                            let msg = *msg;
                            match codec::encode(&msg) {
                                Ok(json) => {
                                    // A failed send is treated as a lost connection.
                                    if let Err(e) = self.conn.send_text(&json).await {
                                        tracing::warn!(error = %e, "send failed");
                                        return DisconnectReason::ConnectionLost;
                                    }
                                }
                                Err(e) => {
                                    // Unencodable messages are dropped; the
                                    // connection stays up.
                                    tracing::warn!(error = %e, "message encode failed, dropping");
                                }
                            }
                        }
                        Some(Command::Close) => {
                            // Best-effort close frame; its result is ignored.
                            let _ = self.conn.send_close().await;
                            return DisconnectReason::UserClose;
                        }
                        None => {
                            // Every command sender is gone: nobody can drive
                            // the session any more.
                            let _ = self.conn.send_close().await;
                            return DisconnectReason::SendersDropped;
                        }
                    }
                }
                frame = self.conn.recv() => {
                    match frame {
                        Ok(RawFrame::Text(text)) => {
                            if let Some(reason) = self.try_decode_and_process(&text) {
                                return reason;
                            }
                        }
                        Ok(RawFrame::Binary(data)) => {
                            // Binary frames are decoded as UTF-8 text; frames
                            // that are not valid UTF-8 are silently ignored.
                            if let Ok(text) = std::str::from_utf8(&data)
                                && let Some(reason) = self.try_decode_and_process(text)
                            {
                                return reason;
                            }
                        }
                        Ok(RawFrame::Close(reason)) => {
                            // Report the server's close reason (empty if none),
                            // then let `run` apply the reconnect policy.
                            let _ = self.event_tx.send(ServerEvent::Closed {
                                reason: reason.unwrap_or_default(),
                            });
                            return DisconnectReason::ConnectionLost;
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "recv error");
                            return DisconnectReason::ConnectionLost;
                        }
                    }
                }
            }
        }
    }

    /// Decodes one server payload and dispatches it.
    ///
    /// Returns `Some(GoAway)` when the message carried `go_away`, otherwise
    /// `None`. Undecodable payloads are logged and skipped.
    fn try_decode_and_process(&self, text: &str) -> Option<DisconnectReason> {
        match codec::decode(text) {
            Ok(msg) => {
                if self.process_message(msg) {
                    Some(DisconnectReason::GoAway)
                } else {
                    None
                }
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to decode server message");
                None
            }
        }
    }

    /// Records any new resume handle, broadcasts the message's events, and
    /// returns whether the message contained `go_away`.
    fn process_message(&self, msg: ServerMessage) -> bool {
        // The handle is stored before events are broadcast, so a later
        // reconnect uses the newest handle.
        if let Some(ref sr) = msg.session_resumption_update
            && let Some(ref handle) = sr.new_handle
        {
            self.state.set_resume_handle(Some(handle.clone()));
        }

        // Read before `msg` is consumed by `into_events`.
        let is_go_away = msg.go_away.is_some();

        for event in codec::into_events(msg) {
            let _ = self.event_tx.send(event);
        }

        is_go_away
    }

    /// Retries connect + handshake with backoff until one succeeds or
    /// `max_attempts` is exhausted.
    ///
    /// The attempt counter is local, so every disconnect gets a fresh budget.
    /// The backoff sleep happens before every attempt, including the first.
    /// Connect failures and handshake failures both consume an attempt.
    async fn reconnect(&mut self) -> Result<Connection, SessionError> {
        let policy = &self.config.reconnect;
        let mut attempt = 0u32;

        loop {
            attempt += 1;
            if policy.max_attempts.is_some_and(|max| attempt > max) {
                return Err(SessionError::ReconnectExhausted {
                    attempts: attempt - 1,
                });
            }

            let backoff = compute_backoff(policy, attempt);
            tracing::debug!(attempt, ?backoff, "reconnect backoff");
            tokio::time::sleep(backoff).await;

            let mut conn = match Connection::connect(&self.config.transport).await {
                Ok(c) => c,
                Err(e) => {
                    tracing::warn!(attempt, error = %e, "reconnect connect failed");
                    continue;
                }
            };

            // Re-read on every attempt so the latest handle from the server is used.
            let resume_handle = self.state.resume_handle();
            match do_handshake(&mut conn, &self.config.setup, resume_handle).await {
                Ok(()) => return Ok(conn),
                Err(e) => {
                    tracing::warn!(attempt, error = %e, "reconnect handshake failed");
                    continue;
                }
            }
        }
    }
}
601
602async fn do_handshake(
609 conn: &mut Connection,
610 setup: &SetupConfig,
611 resume_handle: Option<String>,
612) -> Result<(), SessionError> {
613 let setup = setup_for_handshake(setup, resume_handle);
614
615 let json = codec::encode(&ClientMessage::Setup(setup))?;
616 tracing::debug!(setup_json = %json, "sending setup message");
617 conn.send_text(&json)
618 .await
619 .map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
620
621 tokio::time::timeout(SETUP_TIMEOUT, wait_setup_complete(conn))
622 .await
623 .map_err(|_| SessionError::SetupTimeout(SETUP_TIMEOUT))?
624}
625
626fn setup_for_handshake(setup: &SetupConfig, resume_handle: Option<String>) -> SetupConfig {
627 let mut setup = setup.clone();
628 if let Some(handle) = resume_handle {
629 let sr = setup
630 .session_resumption
631 .get_or_insert_with(SessionResumptionConfig::default);
632 sr.handle = Some(handle);
633 setup.history_config = None;
637 }
638 setup
639}
640
641async fn wait_setup_complete(conn: &mut Connection) -> Result<(), SessionError> {
642 loop {
643 match conn.recv().await {
644 Ok(RawFrame::Text(text)) => {
645 tracing::debug!(raw = %text, "received text during setup");
646 match try_parse_setup_response(&text)? {
647 SetupResult::Complete => return Ok(()),
648 SetupResult::Continue => {}
649 }
650 }
651 Ok(RawFrame::Binary(data)) => {
652 if let Ok(text) = std::str::from_utf8(&data) {
654 tracing::debug!(raw = %text, "received binary (UTF-8) during setup");
655 match try_parse_setup_response(text)? {
656 SetupResult::Complete => return Ok(()),
657 SetupResult::Continue => {}
658 }
659 }
660 }
661 Ok(RawFrame::Close(reason)) => {
662 return Err(SessionError::SetupFailed(format!(
663 "closed during setup: {}",
664 reason.unwrap_or_default()
665 )));
666 }
667 Err(e) => return Err(SessionError::SetupFailed(format_error_chain(&e))),
668 }
669 }
670}
671
/// Outcome of inspecting one message received during the handshake.
enum SetupResult {
    /// The server acknowledged setup; the handshake is done.
    Complete,
    /// Not the acknowledgement; keep reading.
    Continue,
}
676
677fn try_parse_setup_response(text: &str) -> Result<SetupResult, SessionError> {
678 let msg = codec::decode(text).map_err(|e| SessionError::SetupFailed(format_error_chain(&e)))?;
679 if msg.setup_complete.is_some() {
680 return Ok(SetupResult::Complete);
681 }
682 if let Some(err) = msg.error {
683 return Err(SessionError::Api(err.message));
684 }
685 Ok(SetupResult::Continue)
686}
687
/// Renders `error` followed by its `source()` chain, separated by `": "`.
///
/// A source is skipped when its text is empty or when the message built so far
/// already ends with it. This avoids duplicates from error types whose
/// `Display` already embeds their source.
fn format_error_chain(error: &dyn std::error::Error) -> String {
    let mut message = error.to_string();
    for source in std::iter::successors(error.source(), |&err| err.source()) {
        let text = source.to_string();
        if text.is_empty() || message.ends_with(&text) {
            continue;
        }
        message.push_str(": ");
        message.push_str(&text);
    }
    message
}
701
702fn compute_backoff(policy: &ReconnectPolicy, attempt: u32) -> Duration {
706 let exp = attempt.saturating_sub(1).min(10);
707 let factor = 2u64.saturating_pow(exp);
708 let ms = policy.base_backoff.as_millis() as u64 * factor;
709 Duration::from_millis(ms.min(policy.max_backoff.as_millis() as u64))
710}
711
#[cfg(test)]
mod tests {
    use crate::error::{BearerTokenError, ConnectError};
    use crate::types::HistoryConfig;

    use super::*;

    // Backoff doubles from `base_backoff` per attempt and is clamped to
    // `max_backoff`, including for very large attempt numbers.
    #[test]
    fn backoff_exponential_with_cap() {
        let policy = ReconnectPolicy {
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            ..Default::default()
        };
        assert_eq!(compute_backoff(&policy, 1), Duration::from_millis(500));
        assert_eq!(compute_backoff(&policy, 2), Duration::from_millis(1000));
        assert_eq!(compute_backoff(&policy, 3), Duration::from_millis(2000));
        assert_eq!(compute_backoff(&policy, 4), Duration::from_millis(4000));
        assert_eq!(compute_backoff(&policy, 5), Duration::from_secs(5));
        assert_eq!(compute_backoff(&policy, 100), Duration::from_secs(5));
    }

    // Every `SessionStatus` survives a store/load through the atomic, starting
    // from `Connecting`.
    #[test]
    fn status_round_trip() {
        let state = SharedState::new();
        assert_eq!(state.status(), SessionStatus::Connecting);

        state.set_status(SessionStatus::Connected);
        assert_eq!(state.status(), SessionStatus::Connected);

        state.set_status(SessionStatus::Reconnecting);
        assert_eq!(state.status(), SessionStatus::Reconnecting);

        state.set_status(SessionStatus::Closed);
        assert_eq!(state.status(), SessionStatus::Closed);
    }

    // The resume handle starts empty and can be set, replaced and cleared.
    #[test]
    fn resume_handle_tracking() {
        let state = SharedState::new();
        assert!(state.resume_handle().is_none());

        state.set_resume_handle(Some("h1".into()));
        assert_eq!(state.resume_handle().as_deref(), Some("h1"));

        state.set_resume_handle(Some("h2".into()));
        assert_eq!(state.resume_handle().as_deref(), Some("h2"));

        state.set_resume_handle(None);
        assert!(state.resume_handle().is_none());
    }

    // Pins the values returned by `ReconnectPolicy::default()`.
    #[test]
    fn default_reconnect_policy() {
        let p = ReconnectPolicy::default();
        assert!(p.enabled);
        assert_eq!(p.base_backoff, Duration::from_millis(500));
        assert_eq!(p.max_backoff, Duration::from_secs(5));
        assert_eq!(p.max_attempts, Some(10));
    }

    // Resuming injects the handle and drops `history_config`; a fresh
    // handshake (no handle) leaves `history_config` untouched.
    #[test]
    fn handshake_setup_strips_initial_history_when_resuming() {
        let setup = SetupConfig {
            model: "models/test".into(),
            history_config: Some(HistoryConfig {
                initial_history_in_client_content: Some(true),
            }),
            ..Default::default()
        };

        let resumed = setup_for_handshake(&setup, Some("resume-1".into()));
        assert_eq!(
            resumed
                .session_resumption
                .as_ref()
                .and_then(|config| config.handle.as_deref()),
            Some("resume-1")
        );
        assert!(resumed.history_config.is_none());

        let fresh = setup_for_handshake(&setup, None);
        assert_eq!(fresh.history_config, setup.history_config);
    }

    // Each distinct source message is appended after the top-level message,
    // separated by ": ".
    #[test]
    fn format_error_chain_includes_sources() {
        let err = ConnectError::Auth(BearerTokenError::with_source(
            "failed to refresh Google Cloud access token from Application Default Credentials",
            std::io::Error::other("invalid_grant: Account has been deleted"),
        ));

        assert_eq!(
            format_error_chain(&err),
            "failed to obtain bearer token: failed to refresh Google Cloud access token from Application Default Credentials: invalid_grant: Account has been deleted"
        );
    }
}