1use std::sync::atomic::{AtomicU8, Ordering};
30use std::sync::{Arc, Mutex};
31use std::time::Duration;
32
33use base64::Engine;
34use futures_util::Stream;
35use tokio::sync::{broadcast, mpsc};
36
37use crate::audio::INPUT_AUDIO_MIME;
38use crate::codec;
39use crate::error::SessionError;
40use crate::transport::{Connection, RawFrame, TransportConfig};
41use crate::types::*;
42
/// Maximum time to wait for the server's `setup_complete` after the setup
/// message has been sent.
const SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// Capacity of the broadcast channel that fans server events out to
/// subscribers. A subscriber that falls more than this many events behind
/// skips the oldest ones (the receive loops log this as lag).
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// Capacity of the command queue from `Session` handles to the background
/// runner; `send_raw` waits while the queue is full.
const COMMAND_CHANNEL_CAPACITY: usize = 64;
47
/// Everything needed to open and maintain a [`Session`].
pub struct SessionConfig {
    /// Transport settings, passed to `Connection::connect` for the initial
    /// connection and for every reconnect attempt.
    pub transport: TransportConfig,
    /// Setup message sent during the handshake. On reconnect a clone of it is
    /// sent with the most recent session-resumption handle filled in, if the
    /// server has provided one.
    pub setup: SetupConfig,
    /// Whether and how to reconnect after a lost connection or a GoAway.
    pub reconnect: ReconnectPolicy,
}
57
/// How the background task re-establishes a dropped connection.
///
/// Before reconnect attempt `k` (1-based) it sleeps `base_backoff * 2^(k-1)`,
/// with the exponent capped and the delay never exceeding `max_backoff`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectPolicy {
    /// Whether to reconnect at all. When `false`, a lost connection or a
    /// GoAway ends the session.
    pub enabled: bool,
    /// Delay before the first reconnect attempt; doubled for each later attempt.
    pub base_backoff: Duration,
    /// Upper bound on any single backoff delay.
    pub max_backoff: Duration,
    /// Maximum number of attempts per disconnect; `None` retries forever.
    pub max_attempts: Option<u32>,
}

impl Default for ReconnectPolicy {
    /// Reconnect enabled, 500 ms base backoff, 5 s cap, at most 10 attempts.
    fn default() -> Self {
        Self {
            enabled: true,
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            max_attempts: Some(10),
        }
    }
}
82
/// Coarse connection state of a [`Session`], as reported by `Session::status`.
///
/// The explicit discriminants are the values stored in an `AtomicU8`;
/// `#[repr(u8)]` makes the compiler guarantee every variant fits in that byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SessionStatus {
    /// Initial connection and setup handshake in progress.
    Connecting = 0,
    /// Handshake complete; traffic is flowing.
    Connected = 1,
    /// The connection was lost or a GoAway arrived; the background task is
    /// trying to reconnect.
    Reconnecting = 2,
    /// Terminal state: closed by the client, all handles dropped, reconnect
    /// disabled, or reconnect attempts exhausted.
    Closed = 3,
}
91
92pub struct Session {
102 cmd_tx: mpsc::Sender<Command>,
103 event_tx: broadcast::Sender<ServerEvent>,
104 event_rx: broadcast::Receiver<ServerEvent>,
105 state: Arc<SharedState>,
106}
107
108impl Clone for Session {
109 fn clone(&self) -> Self {
110 Self {
111 cmd_tx: self.cmd_tx.clone(),
112 event_tx: self.event_tx.clone(),
113 event_rx: self.event_tx.subscribe(),
114 state: self.state.clone(),
115 }
116 }
117}
118
119impl Session {
120 pub async fn connect(config: SessionConfig) -> Result<Self, SessionError> {
126 let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_CAPACITY);
127 let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
128 let state = Arc::new(SharedState::new());
129
130 state.set_status(SessionStatus::Connecting);
132 let mut conn = Connection::connect(&config.transport)
133 .await
134 .map_err(|e| SessionError::SetupFailed(e.to_string()))?;
135
136 do_handshake(&mut conn, &config.setup, None).await?;
138 state.set_status(SessionStatus::Connected);
139 tracing::info!("session established");
140
141 let runner = Runner {
143 cmd_rx,
144 event_tx: event_tx.clone(),
145 conn,
146 config,
147 state: Arc::clone(&state),
148 };
149 tokio::spawn(runner.run());
150
151 let event_rx = event_tx.subscribe();
152 Ok(Self {
153 cmd_tx,
154 event_tx,
155 event_rx,
156 state,
157 })
158 }
159
160 pub fn status(&self) -> SessionStatus {
162 self.state.status()
163 }
164
165 pub async fn send_audio(&self, pcm_i16_le: &[u8]) -> Result<(), SessionError> {
174 let b64 = base64::engine::general_purpose::STANDARD.encode(pcm_i16_le);
175 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
176 audio: Some(Blob {
177 data: b64,
178 mime_type: INPUT_AUDIO_MIME.into(),
179 }),
180 video: None,
181 text: None,
182 activity_start: None,
183 activity_end: None,
184 audio_stream_end: None,
185 }))
186 .await
187 }
188
189 pub async fn send_video(&self, data: &[u8], mime: &str) -> Result<(), SessionError> {
191 let b64 = base64::engine::general_purpose::STANDARD.encode(data);
192 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
193 video: Some(Blob {
194 data: b64,
195 mime_type: mime.into(),
196 }),
197 audio: None,
198 text: None,
199 activity_start: None,
200 activity_end: None,
201 audio_stream_end: None,
202 }))
203 .await
204 }
205
206 pub async fn send_text(&self, text: &str) -> Result<(), SessionError> {
208 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
209 text: Some(text.into()),
210 audio: None,
211 video: None,
212 activity_start: None,
213 activity_end: None,
214 audio_stream_end: None,
215 }))
216 .await
217 }
218
219 pub async fn send_client_content(&self, content: ClientContent) -> Result<(), SessionError> {
221 self.send_raw(ClientMessage::ClientContent(content)).await
222 }
223
224 pub async fn activity_start(&self) -> Result<(), SessionError> {
226 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
227 activity_start: Some(EmptyObject {}),
228 audio: None,
229 video: None,
230 text: None,
231 activity_end: None,
232 audio_stream_end: None,
233 }))
234 .await
235 }
236
237 pub async fn activity_end(&self) -> Result<(), SessionError> {
239 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
240 activity_end: Some(EmptyObject {}),
241 audio: None,
242 video: None,
243 text: None,
244 activity_start: None,
245 audio_stream_end: None,
246 }))
247 .await
248 }
249
250 pub async fn audio_stream_end(&self) -> Result<(), SessionError> {
252 self.send_raw(ClientMessage::RealtimeInput(RealtimeInput {
253 audio_stream_end: Some(true),
254 audio: None,
255 video: None,
256 text: None,
257 activity_start: None,
258 activity_end: None,
259 }))
260 .await
261 }
262
263 pub async fn send_tool_response(
265 &self,
266 responses: Vec<FunctionResponse>,
267 ) -> Result<(), SessionError> {
268 self.send_raw(ClientMessage::ToolResponse(ToolResponseMessage {
269 function_responses: responses,
270 }))
271 .await
272 }
273
274 pub async fn send_raw(&self, msg: ClientMessage) -> Result<(), SessionError> {
276 self.cmd_tx
277 .send(Command::Send(Box::new(msg)))
278 .await
279 .map_err(|_| SessionError::Closed)
280 }
281
282 pub fn events(&self) -> impl Stream<Item = ServerEvent> {
289 let rx = self.event_tx.subscribe();
290 futures_util::stream::unfold(rx, |mut rx| async move {
291 loop {
292 match rx.recv().await {
293 Ok(event) => return Some((event, rx)),
294 Err(broadcast::error::RecvError::Lagged(n)) => {
295 tracing::warn!(n, "event stream lagged, some events were missed");
296 continue;
297 }
298 Err(broadcast::error::RecvError::Closed) => return None,
299 }
300 }
301 })
302 }
303
304 pub async fn next_event(&mut self) -> Option<ServerEvent> {
308 loop {
309 match self.event_rx.recv().await {
310 Ok(event) => return Some(event),
311 Err(broadcast::error::RecvError::Lagged(n)) => {
312 tracing::warn!(n, "event consumer lagged, some events were missed");
313 continue;
314 }
315 Err(broadcast::error::RecvError::Closed) => return None,
316 }
317 }
318 }
319
320 pub async fn close(self) -> Result<(), SessionError> {
327 let _ = self.cmd_tx.send(Command::Close).await;
328 Ok(())
329 }
330}
331
/// Requests sent from [`Session`] handles to the background runner.
enum Command {
    /// Encode this message and send it on the current connection.
    /// NOTE(review): boxed presumably to keep `Command` small if
    /// `ClientMessage` is large — confirm.
    Send(Box<ClientMessage>),
    /// Send a close frame and stop the runner.
    Close,
}
338
339struct SharedState {
342 resume_handle: Mutex<Option<String>>,
343 status: AtomicU8,
344}
345
346impl SharedState {
347 fn new() -> Self {
348 Self {
349 resume_handle: Mutex::new(None),
350 status: AtomicU8::new(SessionStatus::Connecting as u8),
351 }
352 }
353
354 fn status(&self) -> SessionStatus {
355 match self.status.load(Ordering::Relaxed) {
356 0 => SessionStatus::Connecting,
357 1 => SessionStatus::Connected,
358 2 => SessionStatus::Reconnecting,
359 _ => SessionStatus::Closed,
360 }
361 }
362
363 fn set_status(&self, s: SessionStatus) {
364 self.status.store(s as u8, Ordering::Relaxed);
365 }
366
367 fn resume_handle(&self) -> Option<String> {
368 self.resume_handle.lock().unwrap().clone()
369 }
370
371 fn set_resume_handle(&self, handle: Option<String>) {
372 *self.resume_handle.lock().unwrap() = handle;
373 }
374}
375
/// Why `Runner::run_connected` returned; decides whether the runner
/// reconnects or stops.
enum DisconnectReason {
    /// The server sent a GoAway message; reconnect if the policy allows.
    GoAway,
    /// A send or receive failed, or the server sent a close frame; reconnect
    /// if the policy allows.
    ConnectionLost,
    /// A handle called `Session::close`; stop.
    UserClose,
    /// Every command sender (i.e. every [`Session`] handle) was dropped; stop.
    SendersDropped,
}
384
/// State owned by the background task spawned in `Session::connect`.
struct Runner {
    /// Commands from the [`Session`] handles.
    cmd_rx: mpsc::Receiver<Command>,
    /// Fan-out of decoded server events to subscribers.
    event_tx: broadcast::Sender<ServerEvent>,
    /// The current connection; replaced after each successful reconnect.
    conn: Connection,
    /// Configuration reused for reconnects (transport, setup, policy).
    config: SessionConfig,
    /// Status and resumption handle shared with the [`Session`] handles.
    state: Arc<SharedState>,
}
392
393impl Runner {
394 async fn run(mut self) {
395 loop {
396 let reason = self.run_connected().await;
397
398 match reason {
399 DisconnectReason::UserClose | DisconnectReason::SendersDropped => {
400 self.state.set_status(SessionStatus::Closed);
401 tracing::info!("session closed");
402 break;
403 }
404 DisconnectReason::GoAway | DisconnectReason::ConnectionLost => {
405 if !self.config.reconnect.enabled {
406 self.state.set_status(SessionStatus::Closed);
407 let _ = self.event_tx.send(ServerEvent::Closed {
408 reason: "disconnected (reconnect disabled)".into(),
409 });
410 break;
411 }
412
413 self.state.set_status(SessionStatus::Reconnecting);
414 tracing::info!("attempting reconnection");
415
416 match self.reconnect().await {
417 Ok(conn) => {
418 self.conn = conn;
419 self.state.set_status(SessionStatus::Connected);
420 tracing::info!("reconnected successfully");
421 }
422 Err(e) => {
423 self.state.set_status(SessionStatus::Closed);
424 let _ = self.event_tx.send(ServerEvent::Error(ApiError {
425 message: e.to_string(),
426 }));
427 break;
428 }
429 }
430 }
431 }
432 }
433 }
434
435 async fn run_connected(&mut self) -> DisconnectReason {
438 loop {
439 tokio::select! {
440 cmd = self.cmd_rx.recv() => {
441 match cmd {
442 Some(Command::Send(msg)) => { let msg = *msg;
443 match codec::encode(&msg) {
444 Ok(json) => {
445 if let Err(e) = self.conn.send_text(&json).await {
446 tracing::warn!(error = %e, "send failed");
447 return DisconnectReason::ConnectionLost;
448 }
449 }
450 Err(e) => {
451 tracing::warn!(error = %e, "message encode failed, dropping");
452 }
453 }
454 }
455 Some(Command::Close) => {
456 let _ = self.conn.send_close().await;
457 return DisconnectReason::UserClose;
458 }
459 None => {
460 let _ = self.conn.send_close().await;
461 return DisconnectReason::SendersDropped;
462 }
463 }
464 }
465 frame = self.conn.recv() => {
466 match frame {
467 Ok(RawFrame::Text(text)) => {
468 if let Some(reason) = self.try_decode_and_process(&text) {
469 return reason;
470 }
471 }
472 Ok(RawFrame::Binary(data)) => {
473 if let Ok(text) = std::str::from_utf8(&data)
475 && let Some(reason) = self.try_decode_and_process(text)
476 {
477 return reason;
478 }
479 }
480 Ok(RawFrame::Close(reason)) => {
481 let _ = self.event_tx.send(ServerEvent::Closed {
482 reason: reason.unwrap_or_default(),
483 });
484 return DisconnectReason::ConnectionLost;
485 }
486 Err(e) => {
487 tracing::warn!(error = %e, "recv error");
488 return DisconnectReason::ConnectionLost;
489 }
490 }
491 }
492 }
493 }
494 }
495
496 fn try_decode_and_process(&self, text: &str) -> Option<DisconnectReason> {
501 match codec::decode(text) {
502 Ok(msg) => {
503 if self.process_message(msg) {
504 Some(DisconnectReason::GoAway)
505 } else {
506 None
507 }
508 }
509 Err(e) => {
510 tracing::warn!(error = %e, "failed to decode server message");
511 None
512 }
513 }
514 }
515
516 fn process_message(&self, msg: ServerMessage) -> bool {
517 if let Some(ref sr) = msg.session_resumption_update
519 && let Some(ref handle) = sr.new_handle
520 {
521 self.state.set_resume_handle(Some(handle.clone()));
522 }
523
524 let is_go_away = msg.go_away.is_some();
525
526 for event in codec::into_events(msg) {
527 let _ = self.event_tx.send(event);
528 }
529
530 is_go_away
531 }
532
533 async fn reconnect(&mut self) -> Result<Connection, SessionError> {
535 let policy = &self.config.reconnect;
536 let mut attempt = 0u32;
537
538 loop {
539 attempt += 1;
540 if policy.max_attempts.is_some_and(|max| attempt > max) {
541 return Err(SessionError::ReconnectExhausted {
542 attempts: attempt - 1,
543 });
544 }
545
546 let backoff = compute_backoff(policy, attempt);
547 tracing::debug!(attempt, ?backoff, "reconnect backoff");
548 tokio::time::sleep(backoff).await;
549
550 let mut conn = match Connection::connect(&self.config.transport).await {
551 Ok(c) => c,
552 Err(e) => {
553 tracing::warn!(attempt, error = %e, "reconnect connect failed");
554 continue;
555 }
556 };
557
558 let resume_handle = self.state.resume_handle();
559 match do_handshake(&mut conn, &self.config.setup, resume_handle).await {
560 Ok(()) => return Ok(conn),
561 Err(e) => {
562 tracing::warn!(attempt, error = %e, "reconnect handshake failed");
563 continue;
564 }
565 }
566 }
567 }
568}
569
570async fn do_handshake(
577 conn: &mut Connection,
578 setup: &SetupConfig,
579 resume_handle: Option<String>,
580) -> Result<(), SessionError> {
581 let mut setup = setup.clone();
582 if let Some(handle) = resume_handle {
583 let sr = setup
584 .session_resumption
585 .get_or_insert_with(SessionResumptionConfig::default);
586 sr.handle = Some(handle);
587 }
588
589 let json = codec::encode(&ClientMessage::Setup(setup))?;
590 tracing::debug!(setup_json = %json, "sending setup message");
591 conn.send_text(&json)
592 .await
593 .map_err(|e| SessionError::SetupFailed(e.to_string()))?;
594
595 tokio::time::timeout(SETUP_TIMEOUT, wait_setup_complete(conn))
596 .await
597 .map_err(|_| SessionError::SetupTimeout(SETUP_TIMEOUT))?
598}
599
600async fn wait_setup_complete(conn: &mut Connection) -> Result<(), SessionError> {
601 loop {
602 match conn.recv().await {
603 Ok(RawFrame::Text(text)) => {
604 tracing::debug!(raw = %text, "received text during setup");
605 match try_parse_setup_response(&text)? {
606 SetupResult::Complete => return Ok(()),
607 SetupResult::Continue => {}
608 }
609 }
610 Ok(RawFrame::Binary(data)) => {
611 if let Ok(text) = std::str::from_utf8(&data) {
613 tracing::debug!(raw = %text, "received binary (UTF-8) during setup");
614 match try_parse_setup_response(text)? {
615 SetupResult::Complete => return Ok(()),
616 SetupResult::Continue => {}
617 }
618 }
619 }
620 Ok(RawFrame::Close(reason)) => {
621 return Err(SessionError::SetupFailed(format!(
622 "closed during setup: {}",
623 reason.unwrap_or_default()
624 )));
625 }
626 Err(e) => return Err(SessionError::SetupFailed(e.to_string())),
627 }
628 }
629}
630
631enum SetupResult {
632 Complete,
633 Continue,
634}
635
636fn try_parse_setup_response(text: &str) -> Result<SetupResult, SessionError> {
637 let msg = codec::decode(text).map_err(|e| SessionError::SetupFailed(e.to_string()))?;
638 if msg.setup_complete.is_some() {
639 return Ok(SetupResult::Complete);
640 }
641 if let Some(err) = msg.error {
642 return Err(SessionError::Api(err.message));
643 }
644 Ok(SetupResult::Continue)
645}
646
647fn compute_backoff(policy: &ReconnectPolicy, attempt: u32) -> Duration {
651 let exp = attempt.saturating_sub(1).min(10);
652 let factor = 2u64.saturating_pow(exp);
653 let ms = policy.base_backoff.as_millis() as u64 * factor;
654 Duration::from_millis(ms.min(policy.max_backoff.as_millis() as u64))
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backoff_exponential_with_cap() {
        let policy = ReconnectPolicy {
            base_backoff: Duration::from_millis(500),
            max_backoff: Duration::from_secs(5),
            ..Default::default()
        };
        // (attempt, expected delay in ms): doubles from 500 ms, capped at 5 s.
        let expectations: [(u32, u64); 6] = [
            (1, 500),
            (2, 1_000),
            (3, 2_000),
            (4, 4_000),
            (5, 5_000),
            (100, 5_000),
        ];
        for (attempt, millis) in expectations {
            assert_eq!(
                compute_backoff(&policy, attempt),
                Duration::from_millis(millis),
                "attempt {attempt}"
            );
        }
    }

    #[test]
    fn status_round_trip() {
        let state = SharedState::new();
        assert_eq!(state.status(), SessionStatus::Connecting);

        for next in [
            SessionStatus::Connected,
            SessionStatus::Reconnecting,
            SessionStatus::Closed,
        ] {
            state.set_status(next);
            assert_eq!(state.status(), next);
        }
    }

    #[test]
    fn resume_handle_tracking() {
        let state = SharedState::new();
        assert!(state.resume_handle().is_none());

        for handle in ["h1", "h2"] {
            state.set_resume_handle(Some(handle.to_owned()));
            assert_eq!(state.resume_handle().as_deref(), Some(handle));
        }

        state.set_resume_handle(None);
        assert!(state.resume_handle().is_none());
    }

    #[test]
    fn default_reconnect_policy() {
        let policy = ReconnectPolicy::default();
        assert!(policy.enabled);
        assert_eq!(policy.base_backoff, Duration::from_millis(500));
        assert_eq!(policy.max_backoff, Duration::from_secs(5));
        assert_eq!(policy.max_attempts, Some(10));
    }
}