Skip to main content

omni_dev/voice/
transcriber.rs

1//! Transcriber trait and event types per issues #799 and #801.
2//!
3//! `Transcriber` is the contract every speech-to-text backend implements,
4//! whether batch (this issue) or streaming (#806). It takes an
5//! [`AudioInput`] producing 16 kHz mono signed-PCM chunks and returns an
6//! [`EventStream`] of [`TranscriptEvent`]s.
7//!
8//! The separation from [`crate::voice::AudioSource`] (#800) is deliberate:
9//! `AudioSource` is the *hardware-capture* seam (variable rate, variable
10//! channels, `f32`, intentionally `!Send` on macOS per ADR-0031);
11//! `AudioInput` is the *post-mixdown-and-resample* seam (16 kHz mono i16,
12//! `Send`) that ASR engines consume natively. See ADR-0032 for the rationale.
13
14use std::path::Path;
15use std::time::Duration;
16
17use anyhow::{bail, Context, Result};
18use serde::{Deserialize, Serialize};
19
20/// Monotonically-unique identifier for a `Final` event, used by downstream
21/// consumers (commit-message generation, history merging) to deduplicate
22/// across overlapping streaming windows.
23///
24/// ULID rather than UUIDv4 because we want timestamp ordering when finals
25/// arrive out-of-order from a streaming backend (#806). Per #799.
26pub type EventId = ulid::Ulid;
27
/// Speaker label attached to a segment when diarisation (#805) is enabled.
/// The batch backend from #801 always leaves it unset.
pub type SpeakerId = std::string::String;
31
/// A run of 16 kHz mono signed 16-bit PCM samples, in the order they were
/// captured.
///
/// How many samples make up a chunk is the [`AudioInput`] implementation's
/// choice. A `Transcriber` reads every chunk before it runs inference, and an
/// empty chunk simply means "more is coming".
pub type AudioChunk = Vec<i16>;
38
39/// Source of 16 kHz mono signed-PCM audio for transcription.
40///
41/// Distinct from [`crate::voice::AudioSource`] (which is `!Send`, f32, and
42/// variable-rate) — see the module docs and ADR-0032 for why the seam
43/// splits here.
44pub trait AudioInput: Send {
45    /// Returns the next chunk of samples, or `None` when the input is
46    /// exhausted. Implementations may yield chunks of any size; consumers
47    /// must not rely on a particular chunk boundary.
48    fn next_chunk(&mut self) -> Option<AudioChunk>;
49}
50
51/// Stream of transcription events. A blanket impl is provided for any
52/// iterator producing `Result<TranscriptEvent>` that is also `Send`.
53///
54/// Sync `Iterator` shape for the batch backend in #801; the async `Stream`
55/// variant lands alongside streaming work in #806.
56pub trait EventStream: Iterator<Item = Result<TranscriptEvent>> + Send {}
57
58impl<T> EventStream for T where T: Iterator<Item = Result<TranscriptEvent>> + Send {}
59
60/// First-class word-level alignment, optionally returned by backends that
61/// expose it. The batch backend in #801 always emits `None`; word-level
62/// alignment is a backend opt-in, not a guarantee.
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub struct Word {
65    /// The word's text, as it appeared in the source language.
66    pub text: String,
67    /// Start of the word, in stream-relative seconds.
68    #[serde(with = "duration_secs")]
69    pub start: Duration,
70    /// End of the word, in stream-relative seconds.
71    #[serde(with = "duration_secs")]
72    pub end: Duration,
73    /// Per-word confidence in `[0.0, 1.0]`, when the backend provides it.
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub confidence: Option<f32>,
76}
77
78/// What ended a speech region.
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80#[serde(rename_all = "snake_case")]
81pub enum EndpointKind {
82    /// A silence gap exceeded the endpointer's threshold.
83    SilenceGap,
84    /// The speaker explicitly stopped (e.g. push-to-talk release).
85    UtteranceEnd,
86    /// The input source signalled end-of-stream.
87    StreamEnd,
88}
89
90/// One event emitted by a [`Transcriber`].
91///
92/// `Partial` carries no `event_id` because partials supersede each other —
93/// only `Final` is durable enough to deduplicate against. Per #799.
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95#[serde(tag = "type", rename_all = "snake_case")]
96pub enum TranscriptEvent {
97    /// Hypothesis text that may still change. Streaming backends emit
98    /// these; the batch backend in #801 never does.
99    Partial {
100        /// The current best-guess text for this region.
101        text: String,
102        /// Start of the region, in stream-relative seconds.
103        #[serde(with = "duration_secs")]
104        start: Duration,
105        /// End of the region, in stream-relative seconds.
106        #[serde(with = "duration_secs")]
107        end: Duration,
108        /// Word-level alignment, when the backend provides it.
109        #[serde(skip_serializing_if = "Option::is_none")]
110        words: Option<Vec<Word>>,
111        /// Diarisation tag, when speaker labelling is on (#805).
112        #[serde(skip_serializing_if = "Option::is_none")]
113        speaker: Option<SpeakerId>,
114    },
115    /// Committed text for a region. `revisable` is `false` for batch
116    /// backends and for streaming backends that have endpointed the
117    /// region; `true` only when a streaming backend may still revise
118    /// the text in a later pass.
119    Final {
120        /// Unique identifier for deduplication across overlapping windows.
121        event_id: EventId,
122        /// The committed transcript text.
123        text: String,
124        /// Start of the region, in stream-relative seconds.
125        #[serde(with = "duration_secs")]
126        start: Duration,
127        /// End of the region, in stream-relative seconds.
128        #[serde(with = "duration_secs")]
129        end: Duration,
130        /// Segment-level confidence in `[0.0, 1.0]`.
131        confidence: f32,
132        /// Word-level alignment, when the backend provides it.
133        #[serde(skip_serializing_if = "Option::is_none")]
134        words: Option<Vec<Word>>,
135        /// Diarisation tag, when speaker labelling is on (#805).
136        #[serde(skip_serializing_if = "Option::is_none")]
137        speaker: Option<SpeakerId>,
138        /// Whether this Final may still be revised by a later pass. Batch
139        /// backends always set this to `false`.
140        revisable: bool,
141    },
142    /// Marks the end of a speech region or the stream itself.
143    Endpoint {
144        /// Time of the endpoint, in stream-relative seconds.
145        #[serde(with = "duration_secs")]
146        at: Duration,
147        /// What kind of endpoint this is.
148        kind: EndpointKind,
149    },
150}
151
152/// Speech-to-text backend.
153///
154/// `Send + Sync` so a single transcriber can be shared across worker
155/// threads (e.g., one model, many concurrent inputs). Backends that hold
156/// non-thread-safe handles internally wrap them in `Mutex`.
157pub trait Transcriber: Send + Sync {
158    /// Consumes an audio input and returns the resulting event stream.
159    fn transcribe(&self, audio: Box<dyn AudioInput>) -> Result<Box<dyn EventStream>>;
160}
161
/// An [`AudioInput`] whose samples are held entirely in memory.
///
/// Build one from a 16 kHz mono 16-bit PCM WAV on disk, or from an existing
/// `Vec<i16>`. The samples are then handed out in fixed-size chunks, and the
/// final chunk may be shorter.
///
/// WAV files in any other format are rejected rather than converted.
/// [`AudioInput`] promises samples at the rate the transcriber consumes, so
/// resampling, channel mixdown, and bit-depth conversion must happen *before*
/// a `VecAudioInput` is constructed. In the streaming pipeline, that work
/// happens downstream of [`crate::voice::AudioSource`].
#[derive(Debug)]
pub struct VecAudioInput {
    /// Every sample of the input, in capture order.
    samples: Vec<i16>,
    /// Index of the first sample not yet handed out by `next_chunk`.
    cursor: usize,
    /// Maximum number of samples per chunk; the constructors keep this at
    /// least 1.
    chunk_samples: usize,
}
177
178impl VecAudioInput {
179    /// Loads a 16 kHz mono i16 PCM WAV from `path` and chunks it into
180    /// pieces of `chunk_samples` samples each (last chunk may be shorter).
181    /// `chunk_samples` is clamped to at least 1.
182    pub fn from_wav_path(path: impl AsRef<Path>, chunk_samples: usize) -> Result<Self> {
183        let path = path.as_ref();
184        let mut reader = hound::WavReader::open(path)
185            .with_context(|| format!("Failed to open WAV at {}", path.display()))?;
186        let spec = reader.spec();
187        if spec.sample_rate != 16_000 {
188            bail!(
189                "WAV at {} must be 16000 Hz (got {}). Resample before constructing VecAudioInput.",
190                path.display(),
191                spec.sample_rate
192            );
193        }
194        if spec.channels != 1 {
195            bail!(
196                "WAV at {} must be mono (got {} channels). Mix down before constructing VecAudioInput.",
197                path.display(),
198                spec.channels
199            );
200        }
201        if spec.bits_per_sample != 16 || spec.sample_format != hound::SampleFormat::Int {
202            bail!(
203                "WAV at {} must be 16-bit signed PCM (got {}-bit {:?})",
204                path.display(),
205                spec.bits_per_sample,
206                spec.sample_format
207            );
208        }
209        let samples: Vec<i16> = reader
210            .samples::<i16>()
211            .collect::<Result<Vec<_>, _>>()
212            .with_context(|| format!("Failed to decode i16 PCM samples from {}", path.display()))?;
213        Ok(Self::from_samples(samples, chunk_samples))
214    }
215
216    /// Builds an input from an in-memory `Vec<i16>` (already 16 kHz mono).
217    /// Useful for synthesised test signals.
218    pub fn from_samples(samples: Vec<i16>, chunk_samples: usize) -> Self {
219        Self {
220            samples,
221            cursor: 0,
222            chunk_samples: chunk_samples.max(1),
223        }
224    }
225}
226
227impl AudioInput for VecAudioInput {
228    fn next_chunk(&mut self) -> Option<AudioChunk> {
229        if self.cursor >= self.samples.len() {
230            return None;
231        }
232        let end = (self.cursor + self.chunk_samples).min(self.samples.len());
233        let chunk = self.samples[self.cursor..end].to_vec();
234        self.cursor = end;
235        Some(chunk)
236    }
237}
238
239/// Serde helper: serialises a `Duration` as a floating-point number of
240/// seconds, so JSONL snapshots are human-readable and diff-friendly.
241mod duration_secs {
242    use serde::{Deserialize, Deserializer, Serializer};
243    use std::time::Duration;
244
245    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
246        s.serialize_f64(d.as_secs_f64())
247    }
248
249    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
250        let secs = f64::deserialize(d)?;
251        Ok(Duration::from_secs_f64(secs.max(0.0)))
252    }
253}
254
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes an integer-PCM WAV fixture called `name` into `dir` and returns
    /// its path. Each entry of `samples` is written, in order, as one `i16`
    /// sample.
    fn write_fixture_wav(
        dir: &TempDir,
        name: &str,
        sample_rate: u32,
        channels: u16,
        bits: u16,
        samples: &[i16],
    ) -> std::path::PathBuf {
        let path = dir.path().join(name);
        let spec = hound::WavSpec {
            channels,
            sample_rate,
            bits_per_sample: bits,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(&path, spec).unwrap();
        for s in samples {
            writer.write_sample(*s).unwrap();
        }
        writer.finalize().unwrap();
        path
    }

    /// Writes a one-sample 16 kHz mono 32-bit *float* WAV named `f32.wav`
    /// into `dir` and returns its path.
    fn write_f32_fixture_wav(dir: &TempDir) -> std::path::PathBuf {
        let path = dir.path().join("f32.wav");
        let spec = hound::WavSpec {
            channels: 1,
            sample_rate: 16_000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        };
        let mut writer = hound::WavWriter::create(&path, spec).unwrap();
        writer.write_sample(0.0_f32).unwrap();
        writer.finalize().unwrap();
        path
    }

    #[test]
    fn vec_audio_input_from_samples_chunks_correctly() {
        let mut input = VecAudioInput::from_samples(vec![1, 2, 3, 4, 5], 2);
        assert_eq!(input.next_chunk(), Some(vec![1, 2]));
        assert_eq!(input.next_chunk(), Some(vec![3, 4]));
        assert_eq!(input.next_chunk(), Some(vec![5]));
        assert_eq!(input.next_chunk(), None);
    }

    #[test]
    fn vec_audio_input_zero_chunk_size_clamps_to_one() {
        let mut input = VecAudioInput::from_samples(vec![10, 20], 0);
        assert_eq!(input.next_chunk(), Some(vec![10]));
        assert_eq!(input.next_chunk(), Some(vec![20]));
        assert_eq!(input.next_chunk(), None);
    }

    #[test]
    fn vec_audio_input_empty_yields_none() {
        let mut input = VecAudioInput::from_samples(vec![], 16);
        assert!(input.next_chunk().is_none());
    }

    #[test]
    fn vec_audio_input_reads_16k_mono_i16_wav() {
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "ok.wav", 16_000, 1, 16, &[100, 200, 300, 400]);
        let mut input = VecAudioInput::from_wav_path(&path, 2).unwrap();
        assert_eq!(input.next_chunk(), Some(vec![100, 200]));
        assert_eq!(input.next_chunk(), Some(vec![300, 400]));
        assert!(input.next_chunk().is_none());
    }

    #[test]
    fn vec_audio_input_rejects_wrong_sample_rate() {
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "44k.wav", 44_100, 1, 16, &[0, 0]);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("16000 Hz"), "got: {err}");
    }

    #[test]
    fn vec_audio_input_rejects_stereo() {
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "stereo.wav", 16_000, 2, 16, &[0, 0, 0, 0]);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("mono"), "got: {err}");
    }

    #[test]
    fn vec_audio_input_rejects_wrong_bit_depth() {
        // hound only writes float WAVs at 32 bits, so this fixture trips the
        // bit-depth half of the encoding check.
        let tmp = TempDir::new().unwrap();
        let path = write_f32_fixture_wav(&tmp);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("16-bit"), "got: {err}");
    }

    #[test]
    fn vec_audio_input_missing_file_errors() {
        let err = VecAudioInput::from_wav_path("/nope/does/not/exist.wav", 16).unwrap_err();
        assert!(err.to_string().contains("Failed to open WAV"), "got: {err}");
    }

    #[test]
    fn event_stream_blanket_impl_compiles() {
        // Just ensure `Vec<Result<TranscriptEvent>>::into_iter()` satisfies
        // `EventStream` so backends can build their streams trivially.
        fn accepts(_s: Box<dyn EventStream>) {}
        let events: Vec<Result<TranscriptEvent>> = vec![Ok(TranscriptEvent::Endpoint {
            at: Duration::from_secs(1),
            kind: EndpointKind::StreamEnd,
        })];
        accepts(Box::new(events.into_iter()));
    }

    #[test]
    fn transcript_event_serde_round_trips() {
        let event = TranscriptEvent::Final {
            event_id: ulid::Ulid::from_parts(0, 1),
            text: "hello".to_string(),
            start: Duration::from_millis(0),
            end: Duration::from_millis(500),
            confidence: 0.97,
            words: None,
            speaker: None,
            revisable: false,
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TranscriptEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(event, back);
    }

    #[test]
    fn duration_serialises_as_seconds() {
        let event = TranscriptEvent::Endpoint {
            at: Duration::from_millis(1500),
            kind: EndpointKind::StreamEnd,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(
            json.contains("\"at\":1.5"),
            "duration should serialise as f64 seconds, got: {json}"
        );
    }

    #[test]
    fn duration_deserialise_rejects_non_numeric_seconds() {
        // The `duration_secs` helper's deserialize path returns an error
        // when the JSON value isn't a number — pin that behaviour so
        // future changes to the serde shape don't silently swallow it.
        let bad_json = r#"{"type":"endpoint","at":"not a number","kind":"stream_end"}"#;
        let result: Result<TranscriptEvent, _> = serde_json::from_str(bad_json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn partial_serialises_without_absent_optionals() {
        // Pin the wire shape: the snake_case `type` tag comes first,
        // durations are f64 seconds, and `None` optionals are omitted.
        let event = TranscriptEvent::Partial {
            text: "hi".to_string(),
            start: Duration::from_millis(0),
            end: Duration::from_millis(250),
            words: None,
            speaker: None,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert_eq!(json, r#"{"type":"partial","text":"hi","start":0.0,"end":0.25}"#);
    }

    #[test]
    fn endpoint_kind_serialises_snake_case() {
        let event = TranscriptEvent::Endpoint {
            at: Duration::from_millis(1500),
            kind: EndpointKind::SilenceGap,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert_eq!(json, r#"{"type":"endpoint","at":1.5,"kind":"silence_gap"}"#);
    }

    #[test]
    fn partial_with_words_and_speaker_round_trips() {
        // Binary-exact durations (0.25 s, 0.75 s) and confidence (0.5) keep
        // the f64/f32 round trip lossless so `assert_eq!` is meaningful.
        let event = TranscriptEvent::Partial {
            text: "hi there".to_string(),
            start: Duration::from_millis(0),
            end: Duration::from_millis(750),
            words: Some(vec![
                Word {
                    text: "hi".to_string(),
                    start: Duration::from_millis(0),
                    end: Duration::from_millis(250),
                    confidence: Some(0.5),
                },
                Word {
                    text: "there".to_string(),
                    start: Duration::from_millis(250),
                    end: Duration::from_millis(750),
                    confidence: None,
                },
            ]),
            speaker: Some("alice".to_string()),
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TranscriptEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(event, back);
    }

    #[test]
    fn vec_audio_input_propagates_decode_failure() {
        // Truncate a valid WAV mid-sample so hound's i16 iterator errors
        // on the last read. Exercises the `.with_context(…)` arm in
        // VecAudioInput::from_wav_path that wraps decode failures.
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "truncated.wav", 16_000, 1, 16, &[1, 2, 3, 4]);
        let len = std::fs::metadata(&path).unwrap().len();
        std::fs::OpenOptions::new()
            .write(true)
            .open(&path)
            .unwrap()
            .set_len(len - 1)
            .unwrap();
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(
            err.to_string().contains("Failed to decode i16 PCM samples"),
            "got: {err}"
        );
    }
}