use std::collections::VecDeque;
use std::sync::mpsc::{channel, Receiver, Sender};
use crate::{AsrError, AudioFrame, Channel, StreamingAsr, TranscriptEvent};
/// Scripted stand-in for a streaming speech recognizer.
///
/// Each accepted `push_audio` call forwards the next event from `script`
/// over `tx`; `finish` emits a single `SpeechEnded` event and rejects any
/// further calls.
pub struct MockAsr {
// Sending half of the event channel; the matching `Receiver` is handed back by the constructors.
tx: Sender<TranscriptEvent>,
// Events still to be emitted, consumed front-to-back (one per accepted `push_audio`).
script: VecDeque<TranscriptEvent>,
// Set by `finish`; once true, `push_audio` and `finish` return `AsrError::AlreadyFinished`.
finished: bool,
// Number of `push_audio` calls accepted before `finish` (rejected calls are not counted).
audio_pushes: u64,
}
impl MockAsr {
pub fn new() -> (Self, Receiver<TranscriptEvent>) {
Self::with_script(default_script())
}
pub fn with_script(script: Vec<TranscriptEvent>) -> (Self, Receiver<TranscriptEvent>) {
let (tx, rx) = channel();
(
Self {
tx,
script: script.into(),
finished: false,
audio_pushes: 0,
},
rx,
)
}
pub fn audio_pushes(&self) -> u64 {
self.audio_pushes
}
}
impl StreamingAsr for MockAsr {
    /// Accepts one audio frame (its contents and channel are ignored), counts
    /// it, and emits the next scripted event if any remain.
    ///
    /// # Errors
    /// Returns `AsrError::AlreadyFinished` once `finish` has been called.
    fn push_audio(&mut self, _frame: &AudioFrame, _channel: Channel) -> Result<(), AsrError> {
        if self.finished {
            Err(AsrError::AlreadyFinished)
        } else {
            self.audio_pushes += 1;
            if let Some(next_event) = self.script.pop_front() {
                // A send error only means the receiver was dropped; the mock
                // deliberately tolerates having no listener.
                let _ = self.tx.send(next_event);
            }
            Ok(())
        }
    }

    /// Ends the stream by emitting a `SpeechEnded` event on the local channel.
    ///
    /// Scripted events not yet consumed by `push_audio` are left unsent.
    ///
    /// # Errors
    /// Returns `AsrError::AlreadyFinished` if called more than once.
    fn finish(&mut self) -> Result<(), AsrError> {
        // Set the flag and learn its previous value in one step.
        let was_finished = std::mem::replace(&mut self.finished, true);
        if was_finished {
            return Err(AsrError::AlreadyFinished);
        }
        let ended = TranscriptEvent::SpeechEnded {
            channel: Channel::Local,
            ts_ms: 0,
        };
        // Ignore a dropped receiver, as in `push_audio`.
        let _ = self.tx.send(ended);
        Ok(())
    }
}
/// Script replayed by `MockAsr::new`, in this order:
/// - a speech start;
/// - two growing partial hypotheses ("hello", then "hello there");
/// - a final result ending at 1500 ms with confidence 1.0.
///
/// Every event is on the local channel with `ts_ms` 0.
fn default_script() -> Vec<TranscriptEvent> {
    let partial = |text: &'static str| TranscriptEvent::Partial {
        channel: Channel::Local,
        ts_ms: 0,
        text: text.into(),
    };

    vec![
        TranscriptEvent::SpeechStarted {
            channel: Channel::Local,
            ts_ms: 0,
        },
        partial("hello"),
        partial("hello there"),
        TranscriptEvent::Final {
            channel: Channel::Local,
            ts_ms: 0,
            end_ms: 1500,
            text: "hello there".into(),
            confidence: 1.0,
        },
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default script must arrive in order: start, partial, partial, final.
    #[test]
    fn default_script_emits_partials_then_final() {
        let (mut asr, rx) = MockAsr::new();
        let samples = vec![0i16; 160];
        let frame = AudioFrame::new(&samples, 16_000);
        for _ in 0..4 {
            asr.push_audio(&frame, Channel::Local).unwrap();
        }
        let events: Vec<_> = rx.try_iter().collect();
        assert_eq!(events.len(), 4);
        assert!(matches!(events[0], TranscriptEvent::SpeechStarted { .. }));
        assert!(matches!(events[1], TranscriptEvent::Partial { .. }));
        assert!(matches!(events[2], TranscriptEvent::Partial { .. }));
        assert!(matches!(events[3], TranscriptEvent::Final { .. }));
    }

    /// Once the script is exhausted, pushes still succeed and are counted
    /// but emit no further events.
    #[test]
    fn pushes_past_end_of_script_emit_nothing_but_are_counted() {
        let (mut asr, rx) = MockAsr::new();
        let samples = vec![0i16; 160];
        let frame = AudioFrame::new(&samples, 16_000);
        for _ in 0..6 {
            asr.push_audio(&frame, Channel::Local).unwrap();
        }
        assert_eq!(asr.audio_pushes(), 6);
        assert_eq!(rx.try_iter().count(), 4);
    }

    /// After `finish`, pushes are rejected and not counted. The only event
    /// emitted is `SpeechEnded`, because unconsumed script events are not flushed.
    #[test]
    fn push_after_finish_errors_and_is_not_counted() {
        let (mut asr, rx) = MockAsr::new();
        let samples = vec![0i16; 160];
        let frame = AudioFrame::new(&samples, 16_000);
        asr.finish().unwrap();
        assert!(matches!(
            asr.push_audio(&frame, Channel::Local),
            Err(AsrError::AlreadyFinished)
        ));
        assert_eq!(asr.audio_pushes(), 0);
        let events: Vec<_> = rx.try_iter().collect();
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], TranscriptEvent::SpeechEnded { .. }));
    }

    /// A second `finish` is rejected.
    #[test]
    fn finish_twice_errors() {
        let (mut asr, _rx) = MockAsr::new();
        asr.finish().unwrap();
        assert!(matches!(asr.finish(), Err(AsrError::AlreadyFinished)));
    }
}