1use std::collections::HashMap;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::sync::{Arc, Mutex};
31
32use car_eventlog::{EventKind, EventLog};
33use car_inference::{GenerateRequest, InferenceEngine, StreamEvent};
34use serde_json::Value;
35use tokio::sync::{mpsc, oneshot};
36use tokio_util::sync::CancellationToken;
37
/// Returns a process-unique, monotonically increasing turn identifier.
///
/// Backed by a private atomic counter starting at 1. `Relaxed` ordering is
/// sufficient: only the uniqueness of the fetched value matters, not
/// synchronization with any other memory.
fn next_turn_id() -> u64 {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}
43
/// Final, authoritative result produced by a voice turn's sidecar track.
#[derive(Debug, Clone)]
pub struct SidecarResult {
    /// Id of the turn this result belongs to (see `next_turn_id`).
    pub turn_id: u64,
    /// Resolved answer text for the turn.
    pub text: String,
    /// Optional structured payload accompanying the text.
    pub data: Option<serde_json::Value>,
}
58
/// Ways a voice turn's sidecar track can fail to resolve.
#[derive(Debug, thiserror::Error)]
pub enum VoiceTurnError {
    /// The inference engine call failed; carries the engine's error text.
    #[error("inference failed: {0}")]
    Inference(String),
    /// The turn's cancellation token fired before a result was produced
    /// (user barge-in, or a newer turn superseding this one).
    #[error("turn cancelled (barge-in or supersession)")]
    Cancelled,
}
72
/// Cloneable handle for cancelling a single voice turn.
///
/// Clones share the same underlying `CancellationToken`, so cancelling any
/// clone cancels the turn everywhere it is observed.
#[derive(Clone)]
pub struct VoiceTurnControl {
    /// Process-unique id of the turn this control belongs to.
    pub turn_id: u64,
    // Shared token observed by both the fast and sidecar tasks.
    cancel: CancellationToken,
}
85
impl VoiceTurnControl {
    /// Cancels the turn (idempotent); both inference tracks observe this
    /// token and stop the next time they poll it.
    pub fn cancel(&self) {
        self.cancel.cancel();
    }

    /// Returns `true` once `cancel` has been called on any clone of this
    /// control.
    pub fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }
}
97
/// Caller-facing handle for one dispatched voice turn.
pub struct VoiceTurnHandle {
    /// Cancellation control shared with the spawned tasks.
    pub control: VoiceTurnControl,
    /// Incremental fast-track stream events; `recv()` yields `None` once the
    /// fast track ends — immediately so for sidecar-only dispatches, whose
    /// sender is dropped at creation.
    pub fast: mpsc::Receiver<StreamEvent>,
    /// One-shot delivery of the sidecar track's final result.
    pub sidecar: oneshot::Receiver<Result<SidecarResult, VoiceTurnError>>,
}
115
impl VoiceTurnHandle {
    /// Id of the turn this handle tracks (delegates to `control`).
    pub fn turn_id(&self) -> u64 {
        self.control.turn_id
    }

    /// Cancels both tracks of the turn (delegates to `control`).
    pub fn cancel(&self) {
        self.control.cancel();
    }
}
127
/// Hook that can answer an utterance directly from structured data,
/// bypassing the LLM sidecar entirely.
#[async_trait::async_trait]
pub trait DirectDataFetcher: Send + Sync {
    /// Attempts to answer `utterance` without the LLM.
    ///
    /// Returns `None` when this fetcher has no direct answer, `Some(Ok(text))`
    /// with a ready answer, or `Some(Err(msg))` when a direct fetch was
    /// attempted but failed. The dispatcher falls through to the LLM for both
    /// `None` and `Some(Err(_))`.
    async fn try_fetch(&self, utterance: &str) -> Option<Result<String, String>>;
}
151
/// Cheap-to-clone emitter of voice-turn events into a shared event log.
#[derive(Clone)]
pub struct VoiceTelemetry {
    // Shared, mutex-guarded event sink.
    log: Arc<Mutex<EventLog>>,
}
159
160impl VoiceTelemetry {
161 pub fn new(log: Arc<Mutex<EventLog>>) -> Self {
163 Self { log }
164 }
165
166 pub fn emit(&self, kind: EventKind, turn_id: u64, extra: Vec<(&str, Value)>) {
170 let mut data: HashMap<String, Value> = HashMap::new();
171 data.insert("turn_id".to_string(), Value::from(turn_id));
172 for (k, v) in extra {
173 data.insert(k.to_string(), v);
174 }
175 if let Ok(mut guard) = self.log.lock() {
176 guard.append(kind, None, None, data);
177 }
178 }
179}
180
/// Dispatches a dual-track (fast + sidecar) voice turn without telemetry.
///
/// Convenience wrapper around [`dispatch_voice_turn_with_telemetry`] with
/// `telemetry = None`.
pub fn dispatch_voice_turn(
    engine: Arc<InferenceEngine>,
    utterance: String,
    fast_request: GenerateRequest,
    sidecar_request: GenerateRequest,
) -> VoiceTurnHandle {
    dispatch_voice_turn_with_telemetry(engine, utterance, fast_request, sidecar_request, None)
}
201
202pub fn dispatch_voice_turn_with_telemetry(
207 engine: Arc<InferenceEngine>,
208 _utterance: String,
209 fast_request: GenerateRequest,
210 sidecar_request: GenerateRequest,
211 telemetry: Option<VoiceTelemetry>,
212) -> VoiceTurnHandle {
213 let turn_id = next_turn_id();
214 let cancel = CancellationToken::new();
215 let (fast_tx, fast_rx) = mpsc::channel::<StreamEvent>(64);
216 let (sidecar_tx, sidecar_rx) = oneshot::channel();
217
218 if let Some(t) = telemetry.as_ref() {
219 t.emit(EventKind::VoiceFastTurnStarted, turn_id, vec![]);
220 }
221
222 spawn_fast_task(
223 engine.clone(),
224 fast_request,
225 fast_tx,
226 cancel.clone(),
227 turn_id,
228 telemetry.clone(),
229 );
230 spawn_sidecar_task(
231 engine,
232 sidecar_request,
233 sidecar_tx,
234 cancel.clone(),
235 turn_id,
236 telemetry,
237 );
238
239 VoiceTurnHandle {
240 control: VoiceTurnControl { turn_id, cancel },
241 fast: fast_rx,
242 sidecar: sidecar_rx,
243 }
244}
245
/// Dispatches a sidecar-only voice turn (no fast track) without telemetry.
///
/// Convenience wrapper around
/// [`dispatch_voice_turn_sidecar_only_with_telemetry`] with `telemetry = None`.
pub fn dispatch_voice_turn_sidecar_only(
    engine: Arc<InferenceEngine>,
    utterance: String,
    sidecar_request: GenerateRequest,
) -> VoiceTurnHandle {
    dispatch_voice_turn_sidecar_only_with_telemetry(engine, utterance, sidecar_request, None)
}
260
261pub fn dispatch_voice_turn_sidecar_only_with_telemetry(
264 engine: Arc<InferenceEngine>,
265 utterance: String,
266 sidecar_request: GenerateRequest,
267 telemetry: Option<VoiceTelemetry>,
268) -> VoiceTurnHandle {
269 dispatch_voice_turn_sidecar_only_with_classifier(
270 engine,
271 utterance,
272 sidecar_request,
273 None,
274 telemetry,
275 )
276}
277
278pub fn dispatch_voice_turn_sidecar_only_with_classifier(
284 engine: Arc<InferenceEngine>,
285 utterance: String,
286 sidecar_request: GenerateRequest,
287 fetcher: Option<Arc<dyn DirectDataFetcher>>,
288 telemetry: Option<VoiceTelemetry>,
289) -> VoiceTurnHandle {
290 let turn_id = next_turn_id();
291 let cancel = CancellationToken::new();
292 let (fast_tx, fast_rx) = mpsc::channel::<StreamEvent>(1);
295 drop(fast_tx);
296 let (sidecar_tx, sidecar_rx) = oneshot::channel();
297
298 spawn_sidecar_task_classified(
299 engine,
300 utterance,
301 sidecar_request,
302 sidecar_tx,
303 cancel.clone(),
304 turn_id,
305 fetcher,
306 telemetry,
307 );
308
309 VoiceTurnHandle {
310 control: VoiceTurnControl { turn_id, cancel },
311 fast: fast_rx,
312 sidecar: sidecar_rx,
313 }
314}
315
/// Spawns the fast (low-latency) inference task for a turn.
///
/// The task races engine stream startup against cancellation (`biased`
/// prefers the cancel branch when both are ready), then relays stream
/// events to `out` via `relay_fast_stream`. On exit it emits either
/// `VoiceTurnCancelled` (track = "fast") or `VoiceFastTurnEnded`.
/// Note: an inference startup error is logged but still ends with
/// `VoiceFastTurnEnded` — there is no dedicated fast-failure event here.
fn spawn_fast_task(
    engine: Arc<InferenceEngine>,
    request: GenerateRequest,
    out: mpsc::Sender<StreamEvent>,
    cancel: CancellationToken,
    turn_id: u64,
    telemetry: Option<VoiceTelemetry>,
) {
    tokio::spawn(async move {
        // True when the turn was cancelled before or during the relay.
        let cancelled_during = tokio::select! {
            biased;
            _ = cancel.cancelled() => {
                tracing::debug!(turn_id, "fast task cancelled before inference start");
                true
            }
            res = engine.generate_tracked_stream(request) => {
                match res {
                    Ok(mut rx) => {
                        relay_fast_stream(&mut rx, &out, &cancel, turn_id).await;
                        // The relay may have stopped because of cancellation;
                        // re-check the token to pick the right telemetry event.
                        cancel.is_cancelled()
                    }
                    Err(e) => {
                        tracing::error!(turn_id, error=%e, "fast turn inference failed");
                        false
                    }
                }
            }
        };
        if let Some(t) = telemetry {
            if cancelled_during {
                t.emit(EventKind::VoiceTurnCancelled, turn_id, vec![("track", "fast".into())]);
            } else {
                t.emit(EventKind::VoiceFastTurnEnded, turn_id, vec![]);
            }
        }
    });
}
353
354async fn relay_fast_stream(
355 rx: &mut mpsc::Receiver<StreamEvent>,
356 out: &mpsc::Sender<StreamEvent>,
357 cancel: &CancellationToken,
358 turn_id: u64,
359) {
360 loop {
361 tokio::select! {
362 biased;
363 _ = cancel.cancelled() => {
364 tracing::debug!(turn_id, "fast stream cancelled mid-relay");
365 break;
366 }
367 evt = rx.recv() => match evt {
368 Some(e) => {
369 if out.send(e).await.is_err() {
370 break;
372 }
373 }
374 None => break,
375 }
376 }
377 }
378}
379
380fn spawn_sidecar_task_classified(
381 engine: Arc<InferenceEngine>,
382 utterance: String,
383 request: GenerateRequest,
384 sender: oneshot::Sender<Result<SidecarResult, VoiceTurnError>>,
385 cancel: CancellationToken,
386 turn_id: u64,
387 fetcher: Option<Arc<dyn DirectDataFetcher>>,
388 telemetry: Option<VoiceTelemetry>,
389) {
390 tokio::spawn(async move {
391 if let Some(f) = fetcher.as_ref() {
393 let fetch_outcome = tokio::select! {
394 biased;
395 _ = cancel.cancelled() => None,
396 outcome = f.try_fetch(&utterance) => outcome,
397 };
398 match fetch_outcome {
399 Some(Ok(text)) => {
400 let result = Ok(SidecarResult {
401 turn_id,
402 text: text.clone(),
403 data: None,
404 });
405 if let Some(t) = telemetry {
406 t.emit(
407 EventKind::VoiceSidecarResolved,
408 turn_id,
409 vec![
410 ("text_len", Value::from(text.len())),
411 ("source", "direct_fetch".into()),
412 ],
413 );
414 }
415 let _ = sender.send(result);
416 return;
417 }
418 Some(Err(e)) => {
419 tracing::debug!(turn_id, error=%e, "DirectDataFetcher errored; falling through to LLM");
420 }
421 None => { }
422 }
423 if cancel.is_cancelled() {
426 let _ = sender.send(Err(VoiceTurnError::Cancelled));
427 if let Some(t) = telemetry {
428 t.emit(
429 EventKind::VoiceTurnCancelled,
430 turn_id,
431 vec![("track", "sidecar".into())],
432 );
433 }
434 return;
435 }
436 }
437 run_llm_sidecar(engine, request, sender, cancel, turn_id, telemetry).await;
438 });
439}
440
/// Runs the LLM sidecar generation for a turn and delivers the outcome.
///
/// Races `engine.generate` against cancellation (`biased` prefers the cancel
/// branch when both are ready). Emits exactly one telemetry event matching
/// the outcome — resolved, cancelled, or failed — before sending the result
/// on `sender`; a send error (receiver dropped) is deliberately ignored.
async fn run_llm_sidecar(
    engine: Arc<InferenceEngine>,
    request: GenerateRequest,
    sender: oneshot::Sender<Result<SidecarResult, VoiceTurnError>>,
    cancel: CancellationToken,
    turn_id: u64,
    telemetry: Option<VoiceTelemetry>,
) {
    let result = tokio::select! {
        biased;
        _ = cancel.cancelled() => Err(VoiceTurnError::Cancelled),
        res = engine.generate(request) => {
            res.map(|text| SidecarResult { turn_id, text, data: None })
            .map_err(|e| VoiceTurnError::Inference(e.to_string()))
        }
    };
    if let Some(t) = telemetry {
        match &result {
            Ok(r) => t.emit(
                EventKind::VoiceSidecarResolved,
                turn_id,
                vec![("text_len", Value::from(r.text.len()))],
            ),
            Err(VoiceTurnError::Cancelled) => {
                t.emit(EventKind::VoiceTurnCancelled, turn_id, vec![("track", "sidecar".into())]);
            }
            Err(VoiceTurnError::Inference(e)) => {
                t.emit(
                    EventKind::VoiceSidecarFailed,
                    turn_id,
                    vec![("error", Value::from(e.clone()))],
                );
            }
        }
    }
    let _ = sender.send(result);
}
478
/// Spawns the plain (no classifier) LLM sidecar task for a turn.
///
/// Thin wrapper that moves `run_llm_sidecar` onto the tokio runtime; the
/// task's outcome is delivered exclusively through `sender`.
fn spawn_sidecar_task(
    engine: Arc<InferenceEngine>,
    request: GenerateRequest,
    sender: oneshot::Sender<Result<SidecarResult, VoiceTurnError>>,
    cancel: CancellationToken,
    turn_id: u64,
    telemetry: Option<VoiceTelemetry>,
) {
    tokio::spawn(run_llm_sidecar(
        engine, request, sender, cancel, turn_id, telemetry,
    ));
}
491
#[cfg(test)]
mod tests {
    use super::*;

    // Ids from the shared atomic counter must strictly increase.
    #[test]
    fn turn_ids_are_monotonic_and_unique() {
        let a = next_turn_id();
        let b = next_turn_id();
        let c = next_turn_id();
        assert!(b > a);
        assert!(c > b);
    }

    // Cancellation through one clone is observable via every other clone.
    #[test]
    fn control_cancel_is_observable() {
        let control = VoiceTurnControl {
            turn_id: 42,
            cancel: CancellationToken::new(),
        };
        assert!(!control.is_cancelled());
        let clone = control.clone();
        clone.cancel();
        assert!(control.is_cancelled());
    }

    // `VoiceTurnHandle::turn_id`/`cancel` are thin delegations to `control`.
    #[test]
    fn handle_turn_id_delegates_to_control() {
        let (_tx, fast_rx) = mpsc::channel::<StreamEvent>(1);
        let (_stx, sidecar_rx) = oneshot::channel();
        let handle = VoiceTurnHandle {
            control: VoiceTurnControl {
                turn_id: 7,
                cancel: CancellationToken::new(),
            },
            fast: fast_rx,
            sidecar: sidecar_rx,
        };
        assert_eq!(handle.turn_id(), 7);
        assert!(!handle.control.is_cancelled());
        handle.cancel();
        assert!(handle.control.is_cancelled());
    }

    // Mirrors the sidecar-only dispatch trick: dropping the sender closes the
    // fast channel so `recv` resolves to `None` immediately.
    #[tokio::test]
    async fn closed_fast_channel_recv_is_none() {
        let (fast_tx, mut fast_rx) = mpsc::channel::<StreamEvent>(1);
        drop(fast_tx);
        assert!(fast_rx.recv().await.is_none());
    }

    // A cancel fired mid-stream must terminate `relay_fast_stream` promptly
    // even while a producer keeps sending.
    #[tokio::test]
    async fn cancellation_propagates_to_relay_fast_stream() {
        let (in_tx, mut in_rx) = mpsc::channel::<StreamEvent>(8);
        let (out_tx, mut out_rx) = mpsc::channel::<StreamEvent>(8);
        let cancel = CancellationToken::new();

        // Producer keeps feeding deltas until its channel closes.
        let producer = tokio::spawn(async move {
            for i in 0..100u32 {
                if in_tx.send(StreamEvent::TextDelta(format!("d{i}"))).await.is_err() {
                    break;
                }
            }
        });

        let cancel_clone = cancel.clone();
        let relay = tokio::spawn(async move {
            relay_fast_stream(&mut in_rx, &out_tx, &cancel_clone, 1).await;
        });

        // Wait for at least one relayed event before cancelling, so the relay
        // is provably mid-stream.
        let first = out_rx.recv().await.expect("first event");
        match first {
            StreamEvent::TextDelta(_) => {}
            other => panic!("unexpected event: {other:?}"),
        }
        cancel.cancel();

        tokio::time::timeout(std::time::Duration::from_secs(1), relay)
            .await
            .expect("relay did not exit after cancel")
            .expect("relay panicked");

        producer.abort();
    }

    // A fetcher hit resolves the sidecar directly and tags telemetry with
    // source = "direct_fetch"; the engine is never exercised on this path.
    #[tokio::test]
    async fn direct_fetcher_hit_skips_llm_and_resolves_sidecar() {
        struct Hit;
        #[async_trait::async_trait]
        impl DirectDataFetcher for Hit {
            async fn try_fetch(&self, _u: &str) -> Option<Result<String, String>> {
                Some(Ok("3 emails: Bob, Alice, Carol".to_string()))
            }
        }
        let cancel = CancellationToken::new();
        let (tx, rx) = oneshot::channel();
        let log = Arc::new(Mutex::new(EventLog::new()));
        let telemetry = VoiceTelemetry::new(log.clone());
        // Default-config engine only satisfies the signature; the hit path
        // returns before any generate call.
        let dummy_engine = Arc::new(car_inference::InferenceEngine::new(
            car_inference::InferenceConfig::default(),
        ));
        spawn_sidecar_task_classified(
            dummy_engine,
            "any new email today".to_string(),
            GenerateRequest::default(),
            tx,
            cancel,
            99,
            Some(Arc::new(Hit)),
            Some(telemetry),
        );
        let r = rx.await.expect("oneshot delivered").expect("ok");
        assert_eq!(r.turn_id, 99);
        assert_eq!(r.text, "3 emails: Bob, Alice, Carol");
        let g = log.lock().unwrap();
        let evt = g.events().last().expect("event emitted");
        assert_eq!(evt.kind, EventKind::VoiceSidecarResolved);
        assert_eq!(evt.data.get("source"), Some(&Value::from("direct_fetch")));
    }

    // A fetcher miss followed by cancellation resolves with `Cancelled`.
    // NOTE(review): `cancel()` races the spawned task's post-fetch cancel
    // check; if the task reaches `run_llm_sidecar` first, the outcome relies
    // on the select!'s cancel branch beating the dummy engine's `generate` —
    // confirm this cannot flake.
    #[tokio::test]
    async fn direct_fetcher_miss_falls_through_but_we_observe_no_short_circuit() {
        struct Miss;
        #[async_trait::async_trait]
        impl DirectDataFetcher for Miss {
            async fn try_fetch(&self, _u: &str) -> Option<Result<String, String>> {
                None
            }
        }
        let cancel = CancellationToken::new();
        let (tx, rx) = oneshot::channel();
        let dummy_engine = Arc::new(car_inference::InferenceEngine::new(
            car_inference::InferenceConfig::default(),
        ));
        spawn_sidecar_task_classified(
            dummy_engine,
            "what's the weather".to_string(),
            GenerateRequest::default(),
            tx,
            cancel.clone(),
            100,
            Some(Arc::new(Miss)),
            None,
        );
        cancel.cancel();
        match rx.await.expect("oneshot delivered") {
            Err(VoiceTurnError::Cancelled) => {}
            other => panic!("expected Cancelled after fetcher miss + cancel, got {other:?}"),
        }
    }

    // `emit` always tags the event with turn_id and appends to the log.
    #[test]
    fn telemetry_emit_appends_to_eventlog() {
        let log = Arc::new(Mutex::new(EventLog::new()));
        let telemetry = VoiceTelemetry::new(log.clone());
        telemetry.emit(EventKind::VoiceFastTurnStarted, 7, vec![]);
        telemetry.emit(
            EventKind::VoiceSidecarResolved,
            7,
            vec![("text_len", Value::from(42usize))],
        );
        let g = log.lock().unwrap();
        let events = g.events();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].kind, EventKind::VoiceFastTurnStarted);
        assert_eq!(events[0].data.get("turn_id"), Some(&Value::from(7u64)));
        assert_eq!(events[1].kind, EventKind::VoiceSidecarResolved);
        assert_eq!(events[1].data.get("text_len"), Some(&Value::from(42usize)));
    }

    // If the consumer drops its receiver, the relay must stop on the next
    // send attempt even without cancellation.
    #[tokio::test]
    async fn dropped_out_channel_stops_relay_without_cancel() {
        let (in_tx, mut in_rx) = mpsc::channel::<StreamEvent>(8);
        let (out_tx, out_rx) = mpsc::channel::<StreamEvent>(8);
        let cancel = CancellationToken::new();

        drop(out_rx);

        let cancel_clone = cancel.clone();
        let relay = tokio::spawn(async move {
            relay_fast_stream(&mut in_rx, &out_tx, &cancel_clone, 1).await;
        });

        in_tx.send(StreamEvent::TextDelta("x".into())).await.unwrap();

        tokio::time::timeout(std::time::Duration::from_secs(1), relay)
            .await
            .expect("relay did not exit after out_rx drop")
            .expect("relay panicked");
    }
}