use crate::enrollment::TranscriptRole;
use crate::{Listener, Result, VoiceConfig, VoiceError, VoiceEvent};
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinHandle;
use tracing::info;
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970, this returns `0` instead of panicking.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Destination for the JSON event frames produced by a `VoiceSession`'s drain task.
///
/// `VoiceSession::start` receives the sink as `Arc<dyn VoiceEventSink>` and moves it into a
/// spawned tokio task, which is why implementors must be `Send + Sync`.
pub trait VoiceEventSink: Send + Sync {
    /// Delivers one JSON-encoded event belonging to `session_id`.
    ///
    /// The drain task calls this once per listener event, then once more with
    /// `{"type":"done"}` when draining ends.
    /// NOTE(review): it is invoked inline from an async task, so implementations presumably
    /// should not block — confirm with sink implementors.
    fn send(&self, session_id: &str, event_json: String);
}
/// One voice listening session: a `Listener` plus the task that forwards its events to a sink.
// Field order is also drop order; keep it stable.
pub struct VoiceSession {
    /// Identifier attached to every event sent to the sink; also the registry key.
    session_id: String,
    /// The listener. It sits behind an async mutex because `start`/`stop` await while holding it.
    listener: Arc<Mutex<Box<dyn Listener>>>,
    /// Handle of the drain task spawned by `start`; taken and awaited by `stop`.
    drain_handle: Mutex<Option<JoinHandle<()>>>,
    /// One-shot used by `stop` to tell the drain task to exit.
    stop_tx: Mutex<Option<oneshot::Sender<()>>>,
    /// Unix seconds of the last forwarded event (initially the creation time).
    /// It is shared with the drain task, which updates it.
    last_active_at: Arc<AtomicU64>,
}
impl VoiceSession {
pub fn new(session_id: impl Into<String>, listener: Box<dyn Listener>) -> Self {
Self {
session_id: session_id.into(),
listener: Arc::new(Mutex::new(listener)),
drain_handle: Mutex::new(None),
stop_tx: Mutex::new(None),
last_active_at: Arc::new(AtomicU64::new(now_secs())),
}
}
pub fn last_active_at(&self) -> u64 {
self.last_active_at.load(Ordering::Relaxed)
}
fn shareable_last_active(&self) -> Arc<AtomicU64> {
self.last_active_at.clone()
}
pub fn session_id(&self) -> &str {
&self.session_id
}
pub async fn start(&self, config: VoiceConfig, sink: Arc<dyn VoiceEventSink>) -> Result<()> {
let mut rx = {
let mut listener = self.listener.lock().await;
listener.start(config).await?
};
let session_id = self.session_id.clone();
let last_active = self.shareable_last_active();
let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
*self.stop_tx.lock().await = Some(stop_tx);
let handle = tokio::spawn(async move {
loop {
tokio::select! {
biased;
_ = &mut stop_rx => break,
maybe_evt = rx.recv() => {
match maybe_evt {
Some(evt) => {
last_active.store(now_secs(), Ordering::Relaxed);
sink.send(&session_id, voice_event_to_json(&evt));
}
None => break, }
}
}
}
while let Ok(evt) = rx.try_recv() {
last_active.store(now_secs(), Ordering::Relaxed);
sink.send(&session_id, voice_event_to_json(&evt));
}
sink.send(&session_id, r#"{"type":"done"}"#.to_string());
});
*self.drain_handle.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
if let Some(tx) = self.stop_tx.lock().await.take() {
let _ = tx.send(());
}
let stop_result = {
let mut listener = self.listener.lock().await;
listener.stop().await
};
match stop_result {
Ok(()) | Err(VoiceError::NotRunning) => {}
Err(e) => return Err(e),
}
if let Some(h) = self.drain_handle.lock().await.take() {
let _ = h.await;
}
Ok(())
}
}
fn voice_event_to_json(evt: &VoiceEvent) -> String {
match evt {
VoiceEvent::SpeechStart => r#"{"type":"speech_start"}"#.to_string(),
VoiceEvent::SpeechEnd => r#"{"type":"speech_end"}"#.to_string(),
VoiceEvent::Transcript {
text,
duration_ms,
role,
} => serde_json::json!({
"type": "transcript",
"text": text,
"duration_ms": duration_ms,
"role": role_to_str(role),
})
.to_string(),
VoiceEvent::Partial { text, duration_ms } => serde_json::json!({
"type": "partial",
"text": text,
"duration_ms": duration_ms,
})
.to_string(),
VoiceEvent::AudioChunk {
samples,
sample_rate,
} => serde_json::json!({
"type": "audio_chunk",
"sample_rate": sample_rate,
"frame_count": samples.len(),
})
.to_string(),
VoiceEvent::BargeIn => r#"{"type":"barge_in"}"#.to_string(),
VoiceEvent::EnrollmentCaptured { label, save_path } => serde_json::json!({
"type": "enrollment_captured",
"label": label,
"save_path": save_path.display().to_string(),
})
.to_string(),
VoiceEvent::EnrollmentFailed { reason } => serde_json::json!({
"type": "enrollment_failed",
"reason": reason,
})
.to_string(),
}
}
/// Wire label for a transcript's speaker role.
///
/// Returns `"enrolled_user"`, `"unknown"`, or `"other:<local_id>"` for other speakers.
fn role_to_str(role: &TranscriptRole) -> String {
    let fixed = match role {
        TranscriptRole::OtherSpeaker { local_id } => return format!("other:{}", local_id),
        TranscriptRole::EnrolledUser => "enrolled_user",
        TranscriptRole::Unknown => "unknown",
    };
    fixed.to_owned()
}
/// Concurrent map of live `VoiceSession`s keyed by session id, with optional idle reaping.
// Field order is also drop order; keep it stable.
pub struct VoiceSessionRegistry {
    /// Live sessions keyed by their `session_id`.
    sessions: DashMap<String, VoiceSession>,
    /// Handle of the background sweeper task, set once `start_sweeper` has spawned it.
    sweep_handle: std::sync::Mutex<Option<JoinHandle<()>>>,
    /// A session whose last activity is more than this many seconds old is reaped by `reap_idle`.
    idle_timeout_secs: u64,
    /// Period, in seconds, between sweeper passes.
    sweep_interval_secs: u64,
}
/// Default idle timeout: five minutes without a forwarded event before a session is reaped.
const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 5 * 60;
/// Default sweeper period: one pass per minute.
const DEFAULT_SWEEP_INTERVAL_SECS: u64 = 60;
impl Default for VoiceSessionRegistry {
    /// An empty registry with the default idle timeout and sweep interval and no sweeper running.
    fn default() -> Self {
        let idle_timeout_secs = DEFAULT_IDLE_TIMEOUT_SECS;
        let sweep_interval_secs = DEFAULT_SWEEP_INTERVAL_SECS;
        Self {
            sessions: DashMap::new(),
            sweep_handle: std::sync::Mutex::new(None),
            idle_timeout_secs,
            sweep_interval_secs,
        }
    }
}
impl VoiceSessionRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_idle_timeout_secs(mut self, secs: u64) -> Self {
self.idle_timeout_secs = secs;
self
}
pub fn with_sweep_interval_secs(mut self, secs: u64) -> Self {
self.sweep_interval_secs = secs;
self
}
pub fn start_sweeper(self: &Arc<Self>) {
let mut guard = match self.sweep_handle.lock() {
Ok(g) => g,
Err(p) => p.into_inner(),
};
if guard.is_some() {
return;
}
let rt = match tokio::runtime::Handle::try_current() {
Ok(h) => h,
Err(_) => {
tracing::warn!(
"VoiceSessionRegistry::start_sweeper called outside a tokio runtime; \
sweeper not started — idle sessions will not be reaped automatically"
);
return;
}
};
let weak = Arc::downgrade(self);
let interval_secs = self.sweep_interval_secs;
let handle = rt.spawn(async move {
run_sweeper(weak, interval_secs).await;
});
*guard = Some(handle);
}
pub fn insert(&self, session: VoiceSession) -> Result<()> {
let id = session.session_id.clone();
if self.sessions.contains_key(&id) {
return Err(VoiceError::AlreadyRunning);
}
self.sessions.insert(id, session);
Ok(())
}
pub async fn stop(&self, session_id: &str) -> Result<()> {
let session = self
.sessions
.remove(session_id)
.ok_or(VoiceError::NotRunning)?
.1;
session.stop().await
}
pub fn contains(&self, session_id: &str) -> bool {
self.sessions.contains_key(session_id)
}
pub fn list(&self) -> Vec<String> {
self.sessions.iter().map(|e| e.key().clone()).collect()
}
pub fn with<F, R>(&self, session_id: &str, f: F) -> Result<R>
where
F: FnOnce(&VoiceSession) -> R,
{
let entry = self
.sessions
.get(session_id)
.ok_or(VoiceError::NotRunning)?;
Ok(f(entry.value()))
}
pub async fn reap_idle(&self) -> Vec<String> {
let cutoff = now_secs().saturating_sub(self.idle_timeout_secs);
let stale: Vec<String> = self
.sessions
.iter()
.filter(|e| e.value().last_active_at() < cutoff)
.map(|e| e.key().clone())
.collect();
for id in &stale {
if let Err(e) = self.stop(id).await {
info!(session_id = %id, error = %e, "voice session reap: stop returned error");
} else {
info!(session_id = %id, "voice session reaped (idle)");
}
}
stale
}
}
/// Background loop spawned by `VoiceSessionRegistry::start_sweeper`.
///
/// Every `interval_secs` seconds it upgrades `weak` and calls `reap_idle`. It exits on the
/// first tick after the registry has been dropped.
///
/// An `interval_secs` of 0 is clamped to 1 second, because `tokio::time::interval` panics on a
/// zero period. Without the clamp, `with_sweep_interval_secs(0)` would crash the sweeper task.
async fn run_sweeper(weak: Weak<VoiceSessionRegistry>, interval_secs: u64) {
    let period = std::time::Duration::from_secs(interval_secs.max(1));
    let mut interval = tokio::time::interval(period);
    // If a sweep overruns, skip the missed ticks rather than firing a burst of catch-up sweeps.
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    // The first tick completes immediately; consume it so the first sweep is one period out.
    interval.tick().await;
    loop {
        interval.tick().await;
        // Upgrade only for the duration of one sweep, so the sweeper never keeps the registry
        // alive between ticks.
        match weak.upgrade() {
            Some(reg) => {
                // `reap_idle` already logs each reaped id; the returned list is not needed here.
                let _ = reg.reap_idle().await;
            }
            None => break,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Listener;
    use async_trait::async_trait;
    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::mpsc;

    /// Test listener that replays a fixed script of events, then drops its sender.
    struct ScriptedListener {
        script: Vec<VoiceEvent>,
        running: bool,
    }

    impl ScriptedListener {
        fn new(script: Vec<VoiceEvent>) -> Self {
            Self {
                script,
                running: false,
            }
        }
    }

    #[async_trait]
    impl Listener for ScriptedListener {
        async fn start(&mut self, _config: VoiceConfig) -> Result<mpsc::Receiver<VoiceEvent>> {
            let (tx, rx) = mpsc::channel(16);
            self.running = true;
            let script = std::mem::take(&mut self.script);
            // `tx` is dropped once the script is exhausted, which closes the channel and ends
            // the session's drain loop.
            tokio::spawn(async move {
                for evt in script {
                    if tx.send(evt).await.is_err() {
                        break;
                    }
                }
            });
            Ok(rx)
        }

        async fn stop(&mut self) -> Result<()> {
            self.running = false;
            Ok(())
        }

        fn is_running(&self) -> bool {
            self.running
        }
    }

    /// Test sink that records every `(session_id, json)` pair it receives, in order.
    #[derive(Default)]
    struct CollectingSink {
        events: Mutex<Vec<(String, String)>>,
    }

    impl VoiceEventSink for CollectingSink {
        fn send(&self, session_id: &str, event_json: String) {
            self.events
                .lock()
                .unwrap()
                .push((session_id.to_string(), event_json));
        }
    }

    #[tokio::test]
    async fn drains_listener_events_to_sink_as_json() {
        let listener = ScriptedListener::new(vec![
            VoiceEvent::SpeechStart,
            VoiceEvent::Transcript {
                text: "hello world".into(),
                duration_ms: 1100,
                role: TranscriptRole::EnrolledUser,
            },
            VoiceEvent::SpeechEnd,
        ]);
        let sink: Arc<CollectingSink> = Arc::new(CollectingSink::default());
        let session = VoiceSession::new("s1", Box::new(listener));
        session
            .start(VoiceConfig::default(), sink.clone())
            .await
            .unwrap();
        // Poll (up to ~1s) until the terminal "done" frame arrives.
        for _ in 0..50 {
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
            if sink
                .events
                .lock()
                .unwrap()
                .last()
                .map(|(_, j)| j.contains(r#""done""#))
                .unwrap_or(false)
            {
                break;
            }
        }
        let events = sink.events.lock().unwrap().clone();
        let types: Vec<&str> = events
            .iter()
            .map(|(_, j)| {
                if j.contains(r#""speech_start""#) {
                    "speech_start"
                } else if j.contains(r#""transcript""#) {
                    "transcript"
                } else if j.contains(r#""speech_end""#) {
                    "speech_end"
                } else if j.contains(r#""done""#) {
                    "done"
                } else {
                    "other"
                }
            })
            .collect();
        assert_eq!(
            types,
            vec!["speech_start", "transcript", "speech_end", "done"]
        );
        for (sid, _) in &events {
            assert_eq!(sid, "s1");
        }
        let transcript_json = events
            .iter()
            .find(|(_, j)| j.contains(r#""transcript""#))
            .map(|(_, j)| j.clone())
            .unwrap();
        assert!(transcript_json.contains(r#""text":"hello world""#));
        assert!(transcript_json.contains(r#""role":"enrolled_user""#));
        assert!(transcript_json.contains(r#""duration_ms":1100"#));
    }

    #[tokio::test]
    async fn registry_rejects_duplicate_session_ids() {
        let registry = VoiceSessionRegistry::new();
        let l1 = Box::new(ScriptedListener::new(vec![]));
        let l2 = Box::new(ScriptedListener::new(vec![]));
        registry.insert(VoiceSession::new("dup", l1)).unwrap();
        let err = registry
            .insert(VoiceSession::new("dup", l2))
            .expect_err("duplicate id should error");
        // `matches!` alone only evaluates to a bool; it must be asserted to check anything.
        assert!(
            matches!(err, VoiceError::AlreadyRunning),
            "expected AlreadyRunning, got {:?}",
            err
        );
        assert_eq!(registry.list(), vec!["dup".to_string()]);
    }

    #[tokio::test]
    async fn reap_idle_reaps_stale_keeps_fresh() {
        let registry = Arc::new(
            VoiceSessionRegistry::new()
                .with_idle_timeout_secs(60)
                .with_sweep_interval_secs(60),
        );
        let sink: Arc<CollectingSink> = Arc::new(CollectingSink::default());
        let stale = VoiceSession::new("stale", Box::new(ScriptedListener::new(vec![])));
        stale
            .start(VoiceConfig::default(), sink.clone())
            .await
            .unwrap();
        // Backdate activity by an hour so it is well past the 60s timeout.
        stale
            .last_active_at
            .store(now_secs().saturating_sub(3600), Ordering::Relaxed);
        let fresh = VoiceSession::new("fresh", Box::new(ScriptedListener::new(vec![])));
        fresh
            .start(VoiceConfig::default(), sink.clone())
            .await
            .unwrap();
        registry.insert(stale).unwrap();
        registry.insert(fresh).unwrap();
        let reaped = registry.reap_idle().await;
        assert_eq!(reaped, vec!["stale".to_string()]);
        assert!(!registry.contains("stale"));
        assert!(registry.contains("fresh"));
    }

    #[tokio::test]
    async fn reap_idle_returns_empty_when_no_sessions_are_stale() {
        let registry = Arc::new(VoiceSessionRegistry::new().with_idle_timeout_secs(60));
        let sink: Arc<CollectingSink> = Arc::new(CollectingSink::default());
        let session = VoiceSession::new("s1", Box::new(ScriptedListener::new(vec![])));
        session.start(VoiceConfig::default(), sink).await.unwrap();
        registry.insert(session).unwrap();
        assert!(registry.reap_idle().await.is_empty());
        assert!(registry.contains("s1"));
    }

    #[tokio::test]
    async fn sweeper_exits_when_registry_dropped() {
        let registry = Arc::new(
            VoiceSessionRegistry::new()
                .with_sweep_interval_secs(1)
                .with_idle_timeout_secs(1),
        );
        registry.start_sweeper();
        let handle = registry
            .sweep_handle
            .lock()
            .unwrap()
            .take()
            .expect("sweeper should be spawned");
        drop(registry);
        tokio::time::timeout(std::time::Duration::from_secs(5), handle)
            .await
            .expect("sweeper task did not exit within 5s after registry drop")
            .expect("sweeper task panicked");
    }

    #[tokio::test]
    async fn start_sweeper_is_idempotent() {
        let registry = Arc::new(VoiceSessionRegistry::new());
        registry.start_sweeper();
        registry.start_sweeper();
        assert!(registry.sweep_handle.lock().unwrap().is_some());
    }

    #[tokio::test]
    async fn registry_stop_removes_session() {
        let registry = VoiceSessionRegistry::new();
        let listener = Box::new(ScriptedListener::new(vec![VoiceEvent::SpeechStart]));
        let sink: Arc<CollectingSink> = Arc::new(CollectingSink::default());
        let session = VoiceSession::new("s1", listener);
        session
            .start(VoiceConfig::default(), sink.clone())
            .await
            .unwrap();
        registry.insert(session).unwrap();
        assert!(registry.contains("s1"));
        registry.stop("s1").await.unwrap();
        assert!(!registry.contains("s1"));
    }

    #[test]
    fn role_to_str_covers_all_variants() {
        assert_eq!(role_to_str(&TranscriptRole::EnrolledUser), "enrolled_user");
        assert_eq!(role_to_str(&TranscriptRole::Unknown), "unknown");
        assert_eq!(
            role_to_str(&TranscriptRole::OtherSpeaker {
                local_id: "alice".into()
            }),
            "other:alice"
        );
    }

    #[test]
    fn partial_event_serializes_with_text_and_duration() {
        let json = voice_event_to_json(&VoiceEvent::Partial {
            text: "hello wor".to_string(),
            duration_ms: 750,
        });
        assert!(json.contains(r#""type":"partial""#));
        assert!(json.contains(r#""text":"hello wor""#));
        assert!(json.contains(r#""duration_ms":750"#));
    }

    #[test]
    fn audio_chunk_emits_meta_only() {
        let json = voice_event_to_json(&VoiceEvent::AudioChunk {
            samples: vec![0i16; 480],
            sample_rate: 16000,
        });
        assert!(json.contains(r#""sample_rate":16000"#));
        assert!(json.contains(r#""frame_count":480"#));
        assert!(!json.contains("samples"));
    }

    #[test]
    fn enrollment_event_carries_path() {
        let json = voice_event_to_json(&VoiceEvent::EnrollmentCaptured {
            label: "alice".into(),
            save_path: PathBuf::from("/tmp/alice.toml"),
        });
        assert!(json.contains(r#""label":"alice""#));
        assert!(json.contains("/tmp/alice.toml"));
    }
}