1use crate::error::{Result, VoxRtcError};
2use crate::socket::RawSocketChannel;
3use crate::types::*;
4use serde_json::Value;
5use std::ops::ControlFlow;
6use tokio::sync::broadcast::error::RecvError;
7use tokio::task::JoinHandle;
8use tokio::time::{Duration, timeout};
9
/// Control-plane handle for a single RTC session.
///
/// Wraps the underlying socket channel. It provides join/leave, typed event
/// subscriptions (each returning a [`Listener`]) and helpers that send control
/// messages such as `session.update` and `response.*`.
#[derive(Clone)]
pub struct VoxRtcControlSession {
    /// Socket channel used to join, leave, subscribe to state/messages and send.
    channel: RawSocketChannel,
    /// Session id supplied at construction; used as the fallback session id
    /// when an inbound payload does not carry one.
    session_id: String,
    /// Channel name derived as `/rtc/{session_id}` in `new`.
    channel_name: String,
    /// Upper bound on how long `join` waits for the channel to become `Joined`.
    join_timeout: Duration,
}
17
18pub struct Listener {
19 handle: JoinHandle<()>,
20}
21
22impl Drop for Listener {
23 fn drop(&mut self) {
24 self.handle.abort();
25 }
26}
27
28impl VoxRtcControlSession {
29 pub(crate) fn new(
30 channel: RawSocketChannel,
31 session_id: String,
32 join_timeout: Duration,
33 ) -> Self {
34 let channel_name = format!("/rtc/{session_id}");
35 Self {
36 channel,
37 session_id,
38 channel_name,
39 join_timeout,
40 }
41 }
42
43 pub fn session_id(&self) -> &str {
44 &self.session_id
45 }
46
47 pub fn channel_name(&self) -> &str {
48 &self.channel_name
49 }
50
51 pub async fn join(&self) -> Result<()> {
52 let mut states = self.channel.subscribe_state();
53 self.channel.join().await?;
54 let channel_name = self.channel.name().to_owned();
55 timeout(self.join_timeout, async move {
56 loop {
57 let state = *states.borrow_and_update();
58 match state {
59 ChannelState::Joined => return Ok(()),
60 ChannelState::Closed | ChannelState::Declined => {
61 return Err(VoxRtcError::JoinFailed {
62 channel: channel_name,
63 state: format!("{state:?}"),
64 });
65 }
66 _ => {}
67 }
68 if states.changed().await.is_err() {
69 return Err(VoxRtcError::Disconnected);
70 }
71 }
72 })
73 .await
74 .map_err(|_| VoxRtcError::JoinTimeout(self.channel_name.clone()))?
75 }
76
77 pub async fn close(&self) -> Result<()> {
78 self.channel.leave().await
79 }
80
81 pub fn on_event<F>(&self, handler: F) -> Listener
82 where
83 F: Fn(WireEvent) + Send + Sync + 'static,
84 {
85 let mut messages = self.channel.subscribe_messages();
86 let session_id = self.session_id.clone();
87 let channel_name = self.channel_name.clone();
88 Listener {
89 handle: tokio::spawn(async move {
90 loop {
91 match next_message(messages.recv().await) {
92 ControlFlow::Break(()) => break,
93 ControlFlow::Continue(None) => continue,
94 ControlFlow::Continue(Some((event, payload))) => handler(WireEvent {
95 r#type: event,
96 data: payload,
97 session_id: session_id.clone(),
98 channel_name: channel_name.clone(),
99 }),
100 }
101 }
102 }),
103 }
104 }
105
106 pub fn on<F>(&self, event_name: impl Into<String>, handler: F) -> Listener
107 where
108 F: Fn(EventData) + Send + Sync + 'static,
109 {
110 let event_name = event_name.into();
111 let mut messages = self.channel.subscribe_messages();
112 Listener {
113 handle: tokio::spawn(async move {
114 loop {
115 match next_message(messages.recv().await) {
116 ControlFlow::Break(()) => break,
117 ControlFlow::Continue(None) => continue,
118 ControlFlow::Continue(Some((event, payload))) => {
119 if event == event_name {
120 handler(payload);
121 }
122 }
123 }
124 }
125 }),
126 }
127 }
128
129 pub fn on_session_attached<F>(&self, handler: F) -> Listener
130 where
131 F: Fn(SessionAttachedEvent) + Send + Sync + 'static,
132 {
133 let session_id = self.session_id.clone();
134 let channel_name = self.channel_name.clone();
135 self.on(EVENT_RTC_SESSION_ATTACHED, move |payload| {
136 handler(SessionAttachedEvent {
137 session_id: base_session_id(&payload, &session_id),
138 channel_name: channel_name.clone(),
139 data: payload,
140 })
141 })
142 }
143
144 pub fn on_session_created<F>(&self, handler: F) -> Listener
145 where
146 F: Fn(SessionCreatedEvent) + Send + Sync + 'static,
147 {
148 let session_id = self.session_id.clone();
149 let channel_name = self.channel_name.clone();
150 self.on(EVENT_SESSION_CREATED, move |payload| {
151 let session = payload.get("session").and_then(Value::as_object).cloned();
152 handler(SessionCreatedEvent {
153 session_id: base_session_id(&payload, &session_id),
154 channel_name: channel_name.clone(),
155 data: payload,
156 session,
157 });
158 })
159 }
160
161 pub fn on_transcript<F>(&self, handler: F) -> Listener
162 where
163 F: Fn(TranscriptEvent) + Send + Sync + 'static,
164 {
165 let session_id = self.session_id.clone();
166 let channel_name = self.channel_name.clone();
167 self.on(EVENT_TRANSCRIPT_COMPLETED, move |payload| {
168 handler(TranscriptEvent {
169 session_id: base_session_id(&payload, &session_id),
170 channel_name: channel_name.clone(),
171 transcript: required_string(&payload, "transcript", ""),
172 language: optional_string(&payload, "language"),
173 start_ms: optional_number(&payload, "start_ms"),
174 end_ms: optional_number(&payload, "end_ms"),
175 eou_probability: optional_number(&payload, "eou_probability"),
176 topics: optional_string_vec(&payload, "topics"),
177 data: payload,
178 });
179 })
180 }
181
182 pub fn on_turn_state_changed<F>(&self, handler: F) -> Listener
183 where
184 F: Fn(TurnStateEvent) + Send + Sync + 'static,
185 {
186 let session_id = self.session_id.clone();
187 let channel_name = self.channel_name.clone();
188 self.on(EVENT_TURN_STATE_CHANGED, move |payload| {
189 handler(TurnStateEvent {
190 session_id: base_session_id(&payload, &session_id),
191 channel_name: channel_name.clone(),
192 state: required_string(&payload, "state", "unknown"),
193 previous_state: optional_string(&payload, "previous_state"),
194 data: payload,
195 });
196 })
197 }
198
199 pub fn on_speech_started<F>(&self, handler: F) -> Listener
200 where
201 F: Fn(SpeechStartedEvent) + Send + Sync + 'static,
202 {
203 let session_id = self.session_id.clone();
204 let channel_name = self.channel_name.clone();
205 self.on(EVENT_SPEECH_STARTED, move |payload| {
206 handler(SpeechStartedEvent {
207 session_id: base_session_id(&payload, &session_id),
208 channel_name: channel_name.clone(),
209 timestamp_ms: optional_number(&payload, "timestamp_ms"),
210 data: payload,
211 });
212 })
213 }
214
215 pub fn on_speech_stopped<F>(&self, handler: F) -> Listener
216 where
217 F: Fn(SpeechStoppedEvent) + Send + Sync + 'static,
218 {
219 let session_id = self.session_id.clone();
220 let channel_name = self.channel_name.clone();
221 self.on(EVENT_SPEECH_STOPPED, move |payload| {
222 handler(SpeechStoppedEvent {
223 session_id: base_session_id(&payload, &session_id),
224 channel_name: channel_name.clone(),
225 timestamp_ms: optional_number(&payload, "timestamp_ms"),
226 data: payload,
227 });
228 })
229 }
230
231 pub fn on_transcript_delta<F>(&self, handler: F) -> Listener
232 where
233 F: Fn(TranscriptDeltaEvent) + Send + Sync + 'static,
234 {
235 let session_id = self.session_id.clone();
236 let channel_name = self.channel_name.clone();
237 self.on(EVENT_TRANSCRIPT_DELTA, move |payload| {
238 handler(TranscriptDeltaEvent {
239 session_id: base_session_id(&payload, &session_id),
240 channel_name: channel_name.clone(),
241 delta: required_string(&payload, "delta", ""),
242 start_ms: optional_number(&payload, "start_ms"),
243 end_ms: optional_number(&payload, "end_ms"),
244 data: payload,
245 });
246 })
247 }
248
249 pub fn on_turn_eou_predicted<F>(&self, handler: F) -> Listener
250 where
251 F: Fn(TurnEouPredictedEvent) + Send + Sync + 'static,
252 {
253 let session_id = self.session_id.clone();
254 let channel_name = self.channel_name.clone();
255 self.on(EVENT_TURN_EOU_PREDICTED, move |payload| {
256 handler(TurnEouPredictedEvent {
257 session_id: base_session_id(&payload, &session_id),
258 channel_name: channel_name.clone(),
259 probability: optional_number(&payload, "probability"),
260 threshold: optional_number(&payload, "threshold"),
261 delay_ms: optional_number(&payload, "delay_ms"),
262 start_ms: optional_number(&payload, "start_ms"),
263 end_ms: optional_number(&payload, "end_ms"),
264 decision: optional_string(&payload, "decision"),
265 action: optional_string(&payload, "action"),
266 turn_detector: optional_string(&payload, "turn_detector"),
267 data: payload,
268 });
269 })
270 }
271
272 pub fn on_response_created<F>(&self, handler: F) -> Listener
273 where
274 F: Fn(ResponseEvent) + Send + Sync + 'static,
275 {
276 self.on_response_event(EVENT_RESPONSE_CREATED, handler)
277 }
278
279 pub fn on_response_committed<F>(&self, handler: F) -> Listener
280 where
281 F: Fn(ResponseEvent) + Send + Sync + 'static,
282 {
283 self.on_response_event(EVENT_RESPONSE_COMMITTED, handler)
284 }
285
286 pub fn on_response_done<F>(&self, handler: F) -> Listener
287 where
288 F: Fn(ResponseEvent) + Send + Sync + 'static,
289 {
290 self.on_response_event(EVENT_RESPONSE_DONE, handler)
291 }
292
293 pub fn on_response_cancelled<F>(&self, handler: F) -> Listener
294 where
295 F: Fn(ResponseEvent) + Send + Sync + 'static,
296 {
297 self.on_response_event(EVENT_RESPONSE_CANCELLED, handler)
298 }
299
300 pub fn on_response_audio_clear<F>(&self, handler: F) -> Listener
301 where
302 F: Fn(ResponseEvent) + Send + Sync + 'static,
303 {
304 self.on_response_event(EVENT_RESPONSE_AUDIO_CLEAR, handler)
305 }
306
307 fn on_response_event<F>(&self, event_name: &'static str, handler: F) -> Listener
308 where
309 F: Fn(ResponseEvent) + Send + Sync + 'static,
310 {
311 let session_id = self.session_id.clone();
312 let channel_name = self.channel_name.clone();
313 self.on(event_name, move |payload| {
314 handler(response_event(payload, &session_id, &channel_name));
315 })
316 }
317
318 pub fn on_interruption_detected<F>(&self, handler: F) -> Listener
319 where
320 F: Fn(InterruptionEvent) + Send + Sync + 'static,
321 {
322 self.on_interruption_event(EVENT_INTERRUPTION_DETECTED, handler)
323 }
324
325 pub fn on_interruption_false_positive<F>(&self, handler: F) -> Listener
326 where
327 F: Fn(InterruptionEvent) + Send + Sync + 'static,
328 {
329 self.on_interruption_event(EVENT_INTERRUPTION_FALSE_POSITIVE, handler)
330 }
331
332 fn on_interruption_event<F>(&self, event_name: &'static str, handler: F) -> Listener
333 where
334 F: Fn(InterruptionEvent) + Send + Sync + 'static,
335 {
336 let session_id = self.session_id.clone();
337 let channel_name = self.channel_name.clone();
338 self.on(event_name, move |payload| {
339 handler(InterruptionEvent {
340 response: response_event(payload.clone(), &session_id, &channel_name),
341 vad_active_ms: optional_number(&payload, "vad_active_ms"),
342 partial_transcript: optional_string(&payload, "partial_transcript"),
343 });
344 })
345 }
346
347 pub fn on_browser_event<F>(&self, handler: F) -> Listener
348 where
349 F: Fn(BrowserEvent) + Send + Sync + 'static,
350 {
351 let session_id = self.session_id.clone();
352 let channel_name = self.channel_name.clone();
353 self.on(EVENT_BROWSER_EVENT, move |payload| {
354 handler(BrowserEvent {
355 session_id: base_session_id(&payload, &session_id),
356 channel_name: channel_name.clone(),
357 event: required_string(&payload, "event", ""),
358 payload: payload.get("payload").cloned().unwrap_or(Value::Null),
359 data: payload,
360 });
361 })
362 }
363
364 pub fn on_close<F>(&self, handler: F) -> Listener
365 where
366 F: Fn(CloseEvent) + Send + Sync + 'static,
367 {
368 let session_id = self.session_id.clone();
369 let channel_name = self.channel_name.clone();
370 self.on(EVENT_RTC_CLIENT_DISCONNECTED, move |payload| {
371 handler(CloseEvent {
372 session_id: base_session_id(&payload, &session_id),
373 channel_name: channel_name.clone(),
374 reason: required_string(&payload, "reason", "unknown"),
375 connection_state: optional_string(&payload, "connection_state"),
376 ice_connection_state: optional_string(&payload, "ice_connection_state"),
377 data_channel_state: optional_string(&payload, "data_channel_state"),
378 data: payload,
379 });
380 })
381 }
382
383 pub fn on_error<F>(&self, handler: F) -> Listener
384 where
385 F: Fn(ErrorEvent) + Send + Sync + 'static,
386 {
387 let session_id = self.session_id.clone();
388 let channel_name = self.channel_name.clone();
389 self.on(EVENT_ERROR, move |payload| {
390 handler(ErrorEvent {
391 session_id: base_session_id(&payload, &session_id),
392 channel_name: channel_name.clone(),
393 message: optional_string(&payload, "message"),
394 code: optional_string(&payload, "code"),
395 data: payload,
396 });
397 })
398 }
399
400 pub async fn send_control(&self, event: &str, payload: EventData) -> Result<()> {
401 self.channel.send_message(event, payload).await
402 }
403
404 pub async fn configure(&self, config: SessionConfig) -> Result<()> {
405 let mut session = config.extra;
406 insert_opt(&mut session, "stt_model", config.stt_model);
407 insert_opt(&mut session, "tts_model", config.tts_model);
408 insert_opt(&mut session, "voice", config.voice);
409 insert_opt(&mut session, "turn_profile", config.turn_profile);
410 insert_opt(&mut session, "vad_backend", config.vad_backend);
411 insert_opt(&mut session, "turn_detector", config.turn_detector);
412
413 let mut payload = EventData::new();
414 payload.insert("session".to_owned(), Value::Object(session));
415 self.send_control("session.update", payload).await
416 }
417
418 pub async fn start_response(&self, options: Option<ResponseOptions>) -> Result<()> {
419 self.send_control("response.start", response_options_payload(options))
420 .await
421 }
422
423 pub async fn append_response_text(
424 &self,
425 delta: impl Into<String>,
426 options: Option<ResponseOptions>,
427 ) -> Result<()> {
428 let mut payload = response_options_payload(options);
429 payload.insert("delta".to_owned(), Value::String(delta.into()));
430 self.send_control("response.delta", payload).await
431 }
432
433 pub async fn commit_response(&self) -> Result<()> {
434 self.send_control("response.commit", EventData::new()).await
435 }
436
437 pub async fn cancel_response(&self) -> Result<()> {
438 self.send_control("response.cancel", EventData::new()).await
439 }
440
441 pub async fn replace_response_text(
442 &self,
443 text: impl Into<String>,
444 options: Option<ResponseOptions>,
445 ) -> Result<()> {
446 let mut payload = response_options_payload(options);
447 payload.insert("text".to_owned(), Value::String(text.into()));
448 self.send_control("response.replace_text", payload).await
449 }
450
451 pub async fn send_text_response(
452 &self,
453 text: impl Into<String>,
454 options: Option<ResponseOptions>,
455 cancel_first: bool,
456 ) -> Result<()> {
457 let text = text.into();
458 if cancel_first {
459 return self.replace_response_text(text, options).await;
460 }
461 self.start_response(options.clone()).await?;
462 self.append_response_text(text, options).await?;
463 self.commit_response().await
464 }
465
466 pub async fn send_client_event(&self, envelope: ClientEventEnvelope) -> Result<()> {
467 let mut payload = EventData::new();
468 payload.insert("event".to_owned(), Value::String(envelope.event));
469 payload.insert("payload".to_owned(), envelope.payload);
470 self.send_control(EVENT_CLIENT_EVENT, payload).await
471 }
472}
473
474fn next_message(
475 result: std::result::Result<(String, EventData), RecvError>,
476) -> ControlFlow<(), Option<(String, EventData)>> {
477 match result {
478 Ok(message) => ControlFlow::Continue(Some(message)),
479 Err(RecvError::Lagged(_)) => ControlFlow::Continue(None),
480 Err(RecvError::Closed) => ControlFlow::Break(()),
481 }
482}
483
484fn insert_opt(session: &mut EventData, key: &str, value: Option<String>) {
485 if let Some(value) = value {
486 session.insert(key.to_owned(), Value::String(value));
487 }
488}
489
490fn response_options_payload(options: Option<ResponseOptions>) -> EventData {
491 let mut payload = EventData::new();
492 if let Some(options) = options
493 && let Some(allow) = options.allow_interruptions
494 {
495 payload.insert("allow_interruptions".to_owned(), Value::Bool(allow));
496 }
497 payload
498}
499
500fn base_session_id(payload: &EventData, fallback: &str) -> String {
501 required_string(payload, "session_id", fallback)
502}
503
504fn response_event(payload: EventData, session_id: &str, channel_name: &str) -> ResponseEvent {
505 ResponseEvent {
506 session_id: base_session_id(&payload, session_id),
507 channel_name: channel_name.to_owned(),
508 response_id: optional_string(&payload, "response_id"),
509 data: payload,
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::socket::test_channel;
    use serde_json::json;
    use tokio::sync::broadcast;
    use tokio::sync::mpsc;

    /// Builds a session (id `sess-1`, 1s join timeout) over the in-memory test
    /// channel.
    ///
    /// Also returns the broadcast sender that the tests use to inject inbound
    /// `(event, payload)` messages.
    async fn session() -> (
        VoxRtcControlSession,
        broadcast::Sender<(String, EventData)>,
    ) {
        let (channel, sender) = test_channel().await;
        let session =
            VoxRtcControlSession::new(channel, "sess-1".to_owned(), Duration::from_secs(1));
        (session, sender)
    }

    /// Converts a JSON object literal into an `EventData` map; panics if
    /// `value` is not an object.
    fn payload(value: Value) -> EventData {
        value.as_object().cloned().expect("object payload")
    }

    /// Awaits the next value from `rx`. Fails the test if nothing arrives
    /// within one second or if the channel is closed.
    async fn recv<T>(rx: &mut mpsc::UnboundedReceiver<T>) -> T {
        timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("handler fired within timeout")
            .expect("handler produced an event")
    }

    /// `next_message` maps Ok → deliver, Lagged → skip, Closed → stop.
    #[test]
    fn next_message_classifies_lag_close_and_ok() {
        assert!(matches!(
            next_message(Ok(("e".to_owned(), EventData::new()))),
            ControlFlow::Continue(Some(_))
        ));
        assert!(matches!(
            next_message(Err(RecvError::Lagged(7))),
            ControlFlow::Continue(None)
        ));
        assert!(matches!(
            next_message(Err(RecvError::Closed)),
            ControlFlow::Break(())
        ));
    }

    /// Payload `session_id` and `timestamp_ms` reach the typed event; the
    /// channel name is the derived `/rtc/{session_id}`.
    #[tokio::test]
    async fn on_speech_started_fires_with_timestamp() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_speech_started(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_SPEECH_STARTED.to_owned(),
                payload(json!({ "session_id": "sess-1", "timestamp_ms": 1234 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.session_id, "sess-1");
        assert_eq!(event.channel_name, "/rtc/sess-1");
        assert_eq!(event.timestamp_ms, Some(1234.0));
    }

    /// `timestamp_ms` is parsed even when the payload has no `session_id`.
    #[tokio::test]
    async fn on_speech_stopped_fires_with_timestamp() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_speech_stopped(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_SPEECH_STOPPED.to_owned(),
                payload(json!({ "timestamp_ms": 5678 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.timestamp_ms, Some(5678.0));
    }

    /// Delta text and integer timings are surfaced on the typed event.
    #[tokio::test]
    async fn on_transcript_delta_fires_with_fields() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_transcript_delta(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_TRANSCRIPT_DELTA.to_owned(),
                payload(json!({ "delta": "hel", "start_ms": 10, "end_ms": 20 })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.delta, "hel");
        assert_eq!(event.start_ms, Some(10.0));
        assert_eq!(event.end_ms, Some(20.0));
    }

    /// Every optional EOU-prediction field is mapped when present.
    #[tokio::test]
    async fn on_turn_eou_predicted_fires_with_fields() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_turn_eou_predicted(move |event| {
            tx.send(event).unwrap();
        });
        sender
            .send((
                EVENT_TURN_EOU_PREDICTED.to_owned(),
                payload(json!({
                    "probability": 0.82,
                    "threshold": 0.5,
                    "delay_ms": 120,
                    "start_ms": 0,
                    "end_ms": 300,
                    "decision": "end",
                    "action": "commit",
                    "turn_detector": "smart"
                })),
            ))
            .unwrap();
        let event = recv(&mut rx).await;
        assert_eq!(event.probability, Some(0.82));
        assert_eq!(event.threshold, Some(0.5));
        assert_eq!(event.delay_ms, Some(120.0));
        assert_eq!(event.start_ms, Some(0.0));
        assert_eq!(event.end_ms, Some(300.0));
        assert_eq!(event.decision.as_deref(), Some("end"));
        assert_eq!(event.action.as_deref(), Some("commit"));
        assert_eq!(event.turn_detector.as_deref(), Some("smart"));
    }

    /// A listener must keep delivering events after its receiver lags.
    #[tokio::test]
    async fn handler_survives_a_lagged_broadcast() {
        let (session, sender) = session().await;
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _listener = session.on_speech_started(move |event| {
            tx.send(event.timestamp_ms).unwrap();
        });

        // Flood the channel before the first await so the listener's receiver
        // lags. `#[tokio::test]` uses a current-thread runtime by default, so
        // the spawned listener cannot drain anything yet. Assumes the
        // broadcast capacity set by `test_channel` is below 2100 — TODO
        // confirm.
        for index in 0..2100u32 {
            let _ = sender.send((
                EVENT_SPEECH_STARTED.to_owned(),
                payload(json!({ "timestamp_ms": index })),
            ));
        }
        // Sentinel sent after the flood; it must still be delivered.
        let _ = sender.send((
            EVENT_SPEECH_STARTED.to_owned(),
            payload(json!({ "timestamp_ms": 9999 })),
        ));

        // Drain until the sentinel shows up, or stop after a 1s gap with no events.
        let mut saw_final = false;
        while let Ok(Some(value)) = timeout(Duration::from_secs(1), rx.recv()).await {
            if value == Some(9999.0) {
                saw_final = true;
                break;
            }
        }
        assert!(
            saw_final,
            "loop must keep delivering events after a broadcast lag"
        );
    }
}