Skip to main content

vox_rtc_server/
session.rs

1use crate::error::{Result, VoxRtcError};
2use crate::socket::RawSocketChannel;
3use crate::types::*;
4use serde_json::Value;
5use std::ops::ControlFlow;
6use tokio::sync::broadcast::error::RecvError;
7use tokio::task::JoinHandle;
8use tokio::time::{Duration, timeout};
9
/// Control handle for one RTC session, backed by a [`RawSocketChannel`].
///
/// Cloning copies the channel handle together with the session id, the
/// channel name and the join timeout.
#[derive(Clone)]
pub struct VoxRtcControlSession {
    /// Socket channel used to join, leave, subscribe and send control messages.
    channel: RawSocketChannel,
    /// Session identifier; also passed as the fallback id when building events.
    session_id: String,
    /// Channel name derived from the session id as `/rtc/{session_id}`.
    channel_name: String,
    /// Maximum time `join` waits for the channel state to settle.
    join_timeout: Duration,
}
17
18pub struct Listener {
19    handle: JoinHandle<()>,
20}
21
/// Ties the lifetime of the dispatch task to the lifetime of the handle.
impl Drop for Listener {
    /// Aborts the spawned task so the handler stops being invoked.
    fn drop(&mut self) {
        self.handle.abort();
    }
}
27
28impl VoxRtcControlSession {
29    pub(crate) fn new(
30        channel: RawSocketChannel,
31        session_id: String,
32        join_timeout: Duration,
33    ) -> Self {
34        let channel_name = format!("/rtc/{session_id}");
35        Self {
36            channel,
37            session_id,
38            channel_name,
39            join_timeout,
40        }
41    }
42
    /// Returns the session identifier this handle was created with.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }
46
    /// Returns the channel name derived from the session id at construction.
    pub fn channel_name(&self) -> &str {
        &self.channel_name
    }
50
51    pub async fn join(&self) -> Result<()> {
52        let mut states = self.channel.subscribe_state();
53        self.channel.join().await?;
54        let channel_name = self.channel.name().to_owned();
55        timeout(self.join_timeout, async move {
56            loop {
57                let state = *states.borrow_and_update();
58                match state {
59                    ChannelState::Joined => return Ok(()),
60                    ChannelState::Closed | ChannelState::Declined => {
61                        return Err(VoxRtcError::JoinFailed {
62                            channel: channel_name,
63                            state: format!("{state:?}"),
64                        });
65                    }
66                    _ => {}
67                }
68                if states.changed().await.is_err() {
69                    return Err(VoxRtcError::Disconnected);
70                }
71            }
72        })
73        .await
74        .map_err(|_| VoxRtcError::JoinTimeout(self.channel_name.clone()))?
75    }
76
    /// Leaves the underlying channel, returning whatever `leave` reports.
    pub async fn close(&self) -> Result<()> {
        self.channel.leave().await
    }
80
81    pub fn on_event<F>(&self, handler: F) -> Listener
82    where
83        F: Fn(WireEvent) + Send + Sync + 'static,
84    {
85        let mut messages = self.channel.subscribe_messages();
86        let session_id = self.session_id.clone();
87        let channel_name = self.channel_name.clone();
88        Listener {
89            handle: tokio::spawn(async move {
90                loop {
91                    match next_message(messages.recv().await) {
92                        ControlFlow::Break(()) => break,
93                        ControlFlow::Continue(None) => continue,
94                        ControlFlow::Continue(Some((event, payload))) => handler(WireEvent {
95                            r#type: event,
96                            data: payload,
97                            session_id: session_id.clone(),
98                            channel_name: channel_name.clone(),
99                        }),
100                    }
101                }
102            }),
103        }
104    }
105
106    pub fn on<F>(&self, event_name: impl Into<String>, handler: F) -> Listener
107    where
108        F: Fn(EventData) + Send + Sync + 'static,
109    {
110        let event_name = event_name.into();
111        let mut messages = self.channel.subscribe_messages();
112        Listener {
113            handle: tokio::spawn(async move {
114                loop {
115                    match next_message(messages.recv().await) {
116                        ControlFlow::Break(()) => break,
117                        ControlFlow::Continue(None) => continue,
118                        ControlFlow::Continue(Some((event, payload))) => {
119                            if event == event_name {
120                                handler(payload);
121                            }
122                        }
123                    }
124                }
125            }),
126        }
127    }
128
129    pub fn on_session_attached<F>(&self, handler: F) -> Listener
130    where
131        F: Fn(SessionAttachedEvent) + Send + Sync + 'static,
132    {
133        let session_id = self.session_id.clone();
134        let channel_name = self.channel_name.clone();
135        self.on(EVENT_RTC_SESSION_ATTACHED, move |payload| {
136            handler(SessionAttachedEvent {
137                session_id: base_session_id(&payload, &session_id),
138                channel_name: channel_name.clone(),
139                data: payload,
140            })
141        })
142    }
143
144    pub fn on_session_created<F>(&self, handler: F) -> Listener
145    where
146        F: Fn(SessionCreatedEvent) + Send + Sync + 'static,
147    {
148        let session_id = self.session_id.clone();
149        let channel_name = self.channel_name.clone();
150        self.on(EVENT_SESSION_CREATED, move |payload| {
151            let session = payload.get("session").and_then(Value::as_object).cloned();
152            handler(SessionCreatedEvent {
153                session_id: base_session_id(&payload, &session_id),
154                channel_name: channel_name.clone(),
155                data: payload,
156                session,
157            });
158        })
159    }
160
161    pub fn on_transcript<F>(&self, handler: F) -> Listener
162    where
163        F: Fn(TranscriptEvent) + Send + Sync + 'static,
164    {
165        let session_id = self.session_id.clone();
166        let channel_name = self.channel_name.clone();
167        self.on(EVENT_TRANSCRIPT_COMPLETED, move |payload| {
168            handler(TranscriptEvent {
169                session_id: base_session_id(&payload, &session_id),
170                channel_name: channel_name.clone(),
171                transcript: required_string(&payload, "transcript", ""),
172                language: optional_string(&payload, "language"),
173                start_ms: optional_number(&payload, "start_ms"),
174                end_ms: optional_number(&payload, "end_ms"),
175                eou_probability: optional_number(&payload, "eou_probability"),
176                topics: optional_string_vec(&payload, "topics"),
177                data: payload,
178            });
179        })
180    }
181
182    pub fn on_turn_state_changed<F>(&self, handler: F) -> Listener
183    where
184        F: Fn(TurnStateEvent) + Send + Sync + 'static,
185    {
186        let session_id = self.session_id.clone();
187        let channel_name = self.channel_name.clone();
188        self.on(EVENT_TURN_STATE_CHANGED, move |payload| {
189            handler(TurnStateEvent {
190                session_id: base_session_id(&payload, &session_id),
191                channel_name: channel_name.clone(),
192                state: required_string(&payload, "state", "unknown"),
193                previous_state: optional_string(&payload, "previous_state"),
194                data: payload,
195            });
196        })
197    }
198
199    pub fn on_speech_started<F>(&self, handler: F) -> Listener
200    where
201        F: Fn(SpeechStartedEvent) + Send + Sync + 'static,
202    {
203        let session_id = self.session_id.clone();
204        let channel_name = self.channel_name.clone();
205        self.on(EVENT_SPEECH_STARTED, move |payload| {
206            handler(SpeechStartedEvent {
207                session_id: base_session_id(&payload, &session_id),
208                channel_name: channel_name.clone(),
209                timestamp_ms: optional_number(&payload, "timestamp_ms"),
210                data: payload,
211            });
212        })
213    }
214
215    pub fn on_speech_stopped<F>(&self, handler: F) -> Listener
216    where
217        F: Fn(SpeechStoppedEvent) + Send + Sync + 'static,
218    {
219        let session_id = self.session_id.clone();
220        let channel_name = self.channel_name.clone();
221        self.on(EVENT_SPEECH_STOPPED, move |payload| {
222            handler(SpeechStoppedEvent {
223                session_id: base_session_id(&payload, &session_id),
224                channel_name: channel_name.clone(),
225                timestamp_ms: optional_number(&payload, "timestamp_ms"),
226                data: payload,
227            });
228        })
229    }
230
231    pub fn on_transcript_delta<F>(&self, handler: F) -> Listener
232    where
233        F: Fn(TranscriptDeltaEvent) + Send + Sync + 'static,
234    {
235        let session_id = self.session_id.clone();
236        let channel_name = self.channel_name.clone();
237        self.on(EVENT_TRANSCRIPT_DELTA, move |payload| {
238            handler(TranscriptDeltaEvent {
239                session_id: base_session_id(&payload, &session_id),
240                channel_name: channel_name.clone(),
241                delta: required_string(&payload, "delta", ""),
242                start_ms: optional_number(&payload, "start_ms"),
243                end_ms: optional_number(&payload, "end_ms"),
244                data: payload,
245            });
246        })
247    }
248
249    pub fn on_turn_eou_predicted<F>(&self, handler: F) -> Listener
250    where
251        F: Fn(TurnEouPredictedEvent) + Send + Sync + 'static,
252    {
253        let session_id = self.session_id.clone();
254        let channel_name = self.channel_name.clone();
255        self.on(EVENT_TURN_EOU_PREDICTED, move |payload| {
256            handler(TurnEouPredictedEvent {
257                session_id: base_session_id(&payload, &session_id),
258                channel_name: channel_name.clone(),
259                probability: optional_number(&payload, "probability"),
260                threshold: optional_number(&payload, "threshold"),
261                delay_ms: optional_number(&payload, "delay_ms"),
262                start_ms: optional_number(&payload, "start_ms"),
263                end_ms: optional_number(&payload, "end_ms"),
264                decision: optional_string(&payload, "decision"),
265                action: optional_string(&payload, "action"),
266                turn_detector: optional_string(&payload, "turn_detector"),
267                data: payload,
268            });
269        })
270    }
271
    /// Registers `handler` for `EVENT_RESPONSE_CREATED` messages, delivered as
    /// [`ResponseEvent`] values via `on_response_event`.
    pub fn on_response_created<F>(&self, handler: F) -> Listener
    where
        F: Fn(ResponseEvent) + Send + Sync + 'static,
    {
        self.on_response_event(EVENT_RESPONSE_CREATED, handler)
    }
278
    /// Registers `handler` for `EVENT_RESPONSE_COMMITTED` messages, delivered
    /// as [`ResponseEvent`] values via `on_response_event`.
    pub fn on_response_committed<F>(&self, handler: F) -> Listener
    where
        F: Fn(ResponseEvent) + Send + Sync + 'static,
    {
        self.on_response_event(EVENT_RESPONSE_COMMITTED, handler)
    }
285
    /// Registers `handler` for `EVENT_RESPONSE_DONE` messages, delivered as
    /// [`ResponseEvent`] values via `on_response_event`.
    pub fn on_response_done<F>(&self, handler: F) -> Listener
    where
        F: Fn(ResponseEvent) + Send + Sync + 'static,
    {
        self.on_response_event(EVENT_RESPONSE_DONE, handler)
    }
292
    /// Registers `handler` for `EVENT_RESPONSE_CANCELLED` messages, delivered
    /// as [`ResponseEvent`] values via `on_response_event`.
    pub fn on_response_cancelled<F>(&self, handler: F) -> Listener
    where
        F: Fn(ResponseEvent) + Send + Sync + 'static,
    {
        self.on_response_event(EVENT_RESPONSE_CANCELLED, handler)
    }
299
    /// Registers `handler` for `EVENT_RESPONSE_AUDIO_CLEAR` messages, delivered
    /// as [`ResponseEvent`] values via `on_response_event`.
    pub fn on_response_audio_clear<F>(&self, handler: F) -> Listener
    where
        F: Fn(ResponseEvent) + Send + Sync + 'static,
    {
        self.on_response_event(EVENT_RESPONSE_AUDIO_CLEAR, handler)
    }
306
307    fn on_response_event<F>(&self, event_name: &'static str, handler: F) -> Listener
308    where
309        F: Fn(ResponseEvent) + Send + Sync + 'static,
310    {
311        let session_id = self.session_id.clone();
312        let channel_name = self.channel_name.clone();
313        self.on(event_name, move |payload| {
314            handler(response_event(payload, &session_id, &channel_name));
315        })
316    }
317
    /// Registers `handler` for `EVENT_INTERRUPTION_DETECTED` messages,
    /// delivered as [`InterruptionEvent`] values via `on_interruption_event`.
    pub fn on_interruption_detected<F>(&self, handler: F) -> Listener
    where
        F: Fn(InterruptionEvent) + Send + Sync + 'static,
    {
        self.on_interruption_event(EVENT_INTERRUPTION_DETECTED, handler)
    }
324
    /// Registers `handler` for `EVENT_INTERRUPTION_FALSE_POSITIVE` messages,
    /// delivered as [`InterruptionEvent`] values via `on_interruption_event`.
    pub fn on_interruption_false_positive<F>(&self, handler: F) -> Listener
    where
        F: Fn(InterruptionEvent) + Send + Sync + 'static,
    {
        self.on_interruption_event(EVENT_INTERRUPTION_FALSE_POSITIVE, handler)
    }
331
332    fn on_interruption_event<F>(&self, event_name: &'static str, handler: F) -> Listener
333    where
334        F: Fn(InterruptionEvent) + Send + Sync + 'static,
335    {
336        let session_id = self.session_id.clone();
337        let channel_name = self.channel_name.clone();
338        self.on(event_name, move |payload| {
339            handler(InterruptionEvent {
340                response: response_event(payload.clone(), &session_id, &channel_name),
341                vad_active_ms: optional_number(&payload, "vad_active_ms"),
342                partial_transcript: optional_string(&payload, "partial_transcript"),
343            });
344        })
345    }
346
347    pub fn on_browser_event<F>(&self, handler: F) -> Listener
348    where
349        F: Fn(BrowserEvent) + Send + Sync + 'static,
350    {
351        let session_id = self.session_id.clone();
352        let channel_name = self.channel_name.clone();
353        self.on(EVENT_BROWSER_EVENT, move |payload| {
354            handler(BrowserEvent {
355                session_id: base_session_id(&payload, &session_id),
356                channel_name: channel_name.clone(),
357                event: required_string(&payload, "event", ""),
358                payload: payload.get("payload").cloned().unwrap_or(Value::Null),
359                data: payload,
360            });
361        })
362    }
363
364    pub fn on_close<F>(&self, handler: F) -> Listener
365    where
366        F: Fn(CloseEvent) + Send + Sync + 'static,
367    {
368        let session_id = self.session_id.clone();
369        let channel_name = self.channel_name.clone();
370        self.on(EVENT_RTC_CLIENT_DISCONNECTED, move |payload| {
371            handler(CloseEvent {
372                session_id: base_session_id(&payload, &session_id),
373                channel_name: channel_name.clone(),
374                reason: required_string(&payload, "reason", "unknown"),
375                connection_state: optional_string(&payload, "connection_state"),
376                ice_connection_state: optional_string(&payload, "ice_connection_state"),
377                data_channel_state: optional_string(&payload, "data_channel_state"),
378                data: payload,
379            });
380        })
381    }
382
383    pub fn on_error<F>(&self, handler: F) -> Listener
384    where
385        F: Fn(ErrorEvent) + Send + Sync + 'static,
386    {
387        let session_id = self.session_id.clone();
388        let channel_name = self.channel_name.clone();
389        self.on(EVENT_ERROR, move |payload| {
390            handler(ErrorEvent {
391                session_id: base_session_id(&payload, &session_id),
392                channel_name: channel_name.clone(),
393                message: optional_string(&payload, "message"),
394                code: optional_string(&payload, "code"),
395                data: payload,
396            });
397        })
398    }
399
    /// Sends `payload` on the channel under the event name `event`, returning
    /// whatever the channel's `send_message` reports.
    pub async fn send_control(&self, event: &str, payload: EventData) -> Result<()> {
        self.channel.send_message(event, payload).await
    }
403
404    pub async fn configure(&self, config: SessionConfig) -> Result<()> {
405        let mut session = config.extra;
406        insert_opt(&mut session, "stt_model", config.stt_model);
407        insert_opt(&mut session, "tts_model", config.tts_model);
408        insert_opt(&mut session, "voice", config.voice);
409        insert_opt(&mut session, "turn_profile", config.turn_profile);
410        insert_opt(&mut session, "vad_backend", config.vad_backend);
411        insert_opt(&mut session, "turn_detector", config.turn_detector);
412
413        let mut payload = EventData::new();
414        payload.insert("session".to_owned(), Value::Object(session));
415        self.send_control("session.update", payload).await
416    }
417
    /// Sends a `response.start` control message whose payload is built from
    /// `options` by `response_options_payload`.
    pub async fn start_response(&self, options: Option<ResponseOptions>) -> Result<()> {
        self.send_control("response.start", response_options_payload(options))
            .await
    }
422
423    pub async fn append_response_text(
424        &self,
425        delta: impl Into<String>,
426        options: Option<ResponseOptions>,
427    ) -> Result<()> {
428        let mut payload = response_options_payload(options);
429        payload.insert("delta".to_owned(), Value::String(delta.into()));
430        self.send_control("response.delta", payload).await
431    }
432
    /// Sends a `response.commit` control message with an empty payload.
    pub async fn commit_response(&self) -> Result<()> {
        self.send_control("response.commit", EventData::new()).await
    }
436
    /// Sends a `response.cancel` control message with an empty payload.
    pub async fn cancel_response(&self) -> Result<()> {
        self.send_control("response.cancel", EventData::new()).await
    }
440
441    pub async fn replace_response_text(
442        &self,
443        text: impl Into<String>,
444        options: Option<ResponseOptions>,
445    ) -> Result<()> {
446        let mut payload = response_options_payload(options);
447        payload.insert("text".to_owned(), Value::String(text.into()));
448        self.send_control("response.replace_text", payload).await
449    }
450
451    pub async fn send_text_response(
452        &self,
453        text: impl Into<String>,
454        options: Option<ResponseOptions>,
455        cancel_first: bool,
456    ) -> Result<()> {
457        let text = text.into();
458        if cancel_first {
459            return self.replace_response_text(text, options).await;
460        }
461        self.start_response(options.clone()).await?;
462        self.append_response_text(text, options).await?;
463        self.commit_response().await
464    }
465
466    pub async fn send_client_event(&self, envelope: ClientEventEnvelope) -> Result<()> {
467        let mut payload = EventData::new();
468        payload.insert("event".to_owned(), Value::String(envelope.event));
469        payload.insert("payload".to_owned(), envelope.payload);
470        self.send_control(EVENT_CLIENT_EVENT, payload).await
471    }
472}
473
474fn next_message(
475    result: std::result::Result<(String, EventData), RecvError>,
476) -> ControlFlow<(), Option<(String, EventData)>> {
477    match result {
478        Ok(message) => ControlFlow::Continue(Some(message)),
479        Err(RecvError::Lagged(_)) => ControlFlow::Continue(None),
480        Err(RecvError::Closed) => ControlFlow::Break(()),
481    }
482}
483
484fn insert_opt(session: &mut EventData, key: &str, value: Option<String>) {
485    if let Some(value) = value {
486        session.insert(key.to_owned(), Value::String(value));
487    }
488}
489
490fn response_options_payload(options: Option<ResponseOptions>) -> EventData {
491    let mut payload = EventData::new();
492    if let Some(options) = options
493        && let Some(allow) = options.allow_interruptions
494    {
495        payload.insert("allow_interruptions".to_owned(), Value::Bool(allow));
496    }
497    payload
498}
499
/// Resolves an event's session id by delegating to `required_string` with the
/// `session_id` key and `fallback` as its fallback argument.
fn base_session_id(payload: &EventData, fallback: &str) -> String {
    required_string(payload, "session_id", fallback)
}
503
504fn response_event(payload: EventData, session_id: &str, channel_name: &str) -> ResponseEvent {
505    ResponseEvent {
506        session_id: base_session_id(&payload, session_id),
507        channel_name: channel_name.to_owned(),
508        response_id: optional_string(&payload, "response_id"),
509        data: payload,
510    }
511}
512
// Unit tests for the typed event listeners and the receive-result classifier.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::socket::test_channel;
    use serde_json::json;
    use tokio::sync::broadcast;
    use tokio::sync::mpsc;

    /// Builds a session with id `sess-1` on a test channel and returns it with
    /// the broadcast sender used to inject messages.
    async fn session() -> (
        VoxRtcControlSession,
        broadcast::Sender<(String, EventData)>,
    ) {
        let (channel, sender) = test_channel().await;
        let session =
            VoxRtcControlSession::new(channel, "sess-1".to_owned(), Duration::from_secs(1));
        (session, sender)
    }

    /// Converts a JSON object literal into an `EventData`; panics on non-objects.
    fn payload(value: Value) -> EventData {
        value.as_object().cloned().expect("object payload")
    }

    /// Waits up to one second for the handler to forward an event.
    async fn recv<T>(rx: &mut mpsc::UnboundedReceiver<T>) -> T {
        timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("handler fired within timeout")
            .expect("handler produced an event")
    }

    /// Checks the three mappings of `next_message`: message, lag and close.
    #[test]
    fn next_message_classifies_lag_close_and_ok() {
        assert!(matches!(
            next_message(Ok(("e".to_owned(), EventData::new()))),
            ControlFlow::Continue(Some(_))
        ));
        assert!(matches!(
            next_message(Err(RecvError::Lagged(7))),
            ControlFlow::Continue(None)
        ));
        assert!(matches!(
            next_message(Err(RecvError::Closed)),
            ControlFlow::Break(())
        ));
    }

    /// A speech-started message reaches the handler with session id, channel
    /// name and timestamp populated.
    #[tokio::test]
    async fn on_speech_started_fires_with_timestamp() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        // Keep the listener bound: dropping it would abort the dispatch task.
        let _listener = session.on_speech_started(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_SPEECH_STARTED.to_owned(),
                payload(json!({ "session_id": "sess-1", "timestamp_ms": 1234 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.session_id, "sess-1");
        assert_eq!(event.channel_name, "/rtc/sess-1");
        assert_eq!(event.timestamp_ms, Some(1234.0));
    }

    /// A speech-stopped message without a session id still delivers its timestamp.
    #[tokio::test]
    async fn on_speech_stopped_fires_with_timestamp() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_speech_stopped(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_SPEECH_STOPPED.to_owned(),
                payload(json!({ "timestamp_ms": 5678 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.timestamp_ms, Some(5678.0));
    }

    /// A transcript-delta message delivers the delta text and both offsets.
    #[tokio::test]
    async fn on_transcript_delta_fires_with_fields() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_transcript_delta(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_TRANSCRIPT_DELTA.to_owned(),
                payload(json!({ "delta": "hel", "start_ms": 10, "end_ms": 20 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.delta, "hel");
        assert_eq!(event.start_ms, Some(10.0));
        assert_eq!(event.end_ms, Some(20.0));
    }

    /// An end-of-utterance prediction delivers every numeric and string field.
    #[tokio::test]
    async fn on_turn_eou_predicted_fires_with_fields() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_turn_eou_predicted(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_TURN_EOU_PREDICTED.to_owned(),
                payload(json!({
                    "probability": 0.82,
                    "threshold": 0.5,
                    "delay_ms": 120,
                    "start_ms": 0,
                    "end_ms": 300,
                    "decision": "end",
                    "action": "commit",
                    "turn_detector": "smart"
                })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.probability, Some(0.82));
        assert_eq!(event.threshold, Some(0.5));
        assert_eq!(event.delay_ms, Some(120.0));
        assert_eq!(event.start_ms, Some(0.0));
        assert_eq!(event.end_ms, Some(300.0));
        assert_eq!(event.decision.as_deref(), Some("end"));
        assert_eq!(event.action.as_deref(), Some("commit"));
        assert_eq!(event.turn_detector.as_deref(), Some("smart"));
    }

    /// The dispatch loop keeps delivering after the receiver falls behind.
    #[tokio::test]
    async fn handler_survives_a_lagged_broadcast() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_speech_started(move |event| {
            tx.send(event.timestamp_ms).unwrap();
        });

        // Flood the channel without yielding to the listener task.
        // NOTE(review): presumably 2100 exceeds the capacity configured in
        // `test_channel`, forcing a lag — confirm against that helper.
        for index in 0..2100u32 {
            let _ = sender.send((
                EVENT_SPEECH_STARTED.to_owned(),
                payload(json!({ "timestamp_ms": index })),
            ));
        }
        // Marker message that must still arrive after the flood.
        let _ = sender.send((
            EVENT_SPEECH_STARTED.to_owned(),
            payload(json!({ "timestamp_ms": 9999 })),
        ));

        // Drain forwarded values until the marker shows up or one second
        // passes without any value.
        let mut saw_final = false;
        while let Ok(Some(value)) = timeout(Duration::from_secs(1), rx.recv()).await {
            if value == Some(9999.0) {
                saw_final = true;
                break;
            }
        }
        assert!(
            saw_final,
            "loop must keep delivering events after a broadcast lag"
        );
    }
}