1use std::path::Path;
15use std::time::Duration;
16
17use anyhow::{bail, Context, Result};
18use serde::{Deserialize, Serialize};
19
/// Identifier type of the `event_id` field on [`TranscriptEvent::Final`] events (a ULID).
pub type EventId = ulid::Ulid;

/// Speaker label carried by the optional `speaker` field of transcript events.
///
/// NOTE(review): the format of these labels is not defined here — confirm with the
/// transcriber implementations that populate them.
pub type SpeakerId = String;

/// One batch of signed 16-bit audio samples, as returned by [`AudioInput::next_chunk`].
///
/// The sample rate and channel layout are not part of the type. `VecAudioInput` only
/// accepts 16 kHz mono input; presumably every `AudioInput` is expected to match that —
/// TODO confirm.
pub type AudioChunk = Vec<i16>;
38
/// A pull-based source of audio that a [`Transcriber`] consumes chunk by chunk.
///
/// The `Send` supertrait makes a boxed input (`Box<dyn AudioInput>`, as taken by
/// [`Transcriber::transcribe`]) movable across threads.
pub trait AudioInput: Send {
    /// Returns the next chunk of samples, or `None` when no more audio is available.
    ///
    /// NOTE(review): the trait does not say whether audio may follow a `None`.
    /// `VecAudioInput` keeps returning `None` once exhausted; presumably callers
    /// should treat the first `None` as end of input — confirm.
    fn next_chunk(&mut self) -> Option<AudioChunk>;
}
50
/// A stream of transcript events: any `Send` iterator whose items are
/// `anyhow::Result<TranscriptEvent>`.
///
/// Each item is either an event or an error.
/// NOTE(review): whether iteration may continue after an `Err` item is left to the
/// producer — confirm.
pub trait EventStream: Iterator<Item = Result<TranscriptEvent>> + Send {}

// Blanket impl: every qualifying iterator is an `EventStream`. This makes the trait act
// like an alias, so e.g. `Box::new(vec_of_results.into_iter())` can be used wherever a
// `Box<dyn EventStream>` is expected.
impl<T> EventStream for T where T: Iterator<Item = Result<TranscriptEvent>> + Send {}
59
/// A single word with its time span, optionally attached to partial and final
/// transcript events.
///
/// On the wire, `start` and `end` are `f64` seconds and `confidence` is omitted when
/// `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Word {
    /// The word's text.
    pub text: String,
    /// Start time, (de)serialized as seconds via `duration_secs`.
    /// NOTE(review): presumably an offset from the start of the audio input — confirm.
    #[serde(with = "duration_secs")]
    pub start: Duration,
    /// End time, (de)serialized as seconds via `duration_secs`.
    #[serde(with = "duration_secs")]
    pub end: Duration,
    /// Optional confidence for this word; not serialized when `None`.
    /// NOTE(review): the scale is not defined here (presumably 0.0–1.0) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f32>,
}
77
/// Classifies a [`TranscriptEvent::Endpoint`] (its `kind` field).
///
/// Serialized as a snake_case string: `"silence_gap"`, `"utterance_end"` or
/// `"stream_end"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EndpointKind {
    /// NOTE(review): by its name, an endpoint caused by a pause in the audio — confirm
    /// the detection rule with the transcriber implementations.
    SilenceGap,
    /// NOTE(review): by its name, the end of an utterance — confirm.
    UtteranceEnd,
    /// NOTE(review): by its name, the end of the audio input — confirm.
    StreamEnd,
}
89
/// One item produced by a [`Transcriber`] through its [`EventStream`].
///
/// Serialized as an internally tagged object. A `"type"` field holds the snake_case
/// variant name (`"partial"`, `"final"` or `"endpoint"`) alongside the variant's own
/// fields. Every `Duration` field is written as `f64` seconds via `duration_secs`, and
/// every `Option` field is omitted when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TranscriptEvent {
    /// Transcribed text for the span `start..end`, without an id, confidence or
    /// revisability flag.
    ///
    /// NOTE(review): by its name, an interim hypothesis that later events may replace —
    /// confirm against the transcriber implementations.
    Partial {
        /// Transcribed text for this span.
        text: String,
        /// Start of the span.
        /// NOTE(review): presumably an offset from the start of the audio input — confirm.
        #[serde(with = "duration_secs")]
        start: Duration,
        /// End of the span.
        #[serde(with = "duration_secs")]
        end: Duration,
        /// Optional per-word breakdown; not serialized when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        words: Option<Vec<Word>>,
        /// Optional speaker label; not serialized when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        speaker: Option<SpeakerId>,
    },
    /// Transcribed text for the span `start..end`, identified by `event_id`.
    Final {
        /// Identifier of this event.
        event_id: EventId,
        /// Transcribed text for this span.
        text: String,
        /// Start of the span.
        #[serde(with = "duration_secs")]
        start: Duration,
        /// End of the span.
        #[serde(with = "duration_secs")]
        end: Duration,
        /// Confidence for the whole span; always serialized.
        /// NOTE(review): the scale is not defined here (presumably 0.0–1.0) — confirm.
        confidence: f32,
        /// Optional per-word breakdown; not serialized when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        words: Option<Vec<Word>>,
        /// Optional speaker label; not serialized when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        speaker: Option<SpeakerId>,
        /// NOTE(review): by its name, whether a later event may still revise this one —
        /// confirm the exact semantics with producers.
        revisable: bool,
    },
    /// A boundary in the audio at time `at`, classified by `kind`.
    Endpoint {
        /// Position of the boundary, (de)serialized as seconds.
        #[serde(with = "duration_secs")]
        at: Duration,
        /// What kind of boundary this is.
        kind: EndpointKind,
    },
}
151
/// An engine that consumes an [`AudioInput`] and produces an [`EventStream`] of
/// [`TranscriptEvent`]s.
///
/// `Send + Sync` allows one transcriber to be shared between threads (e.g. via `Arc`).
pub trait Transcriber: Send + Sync {
    /// Starts transcribing `audio`, taking ownership of it.
    ///
    /// # Errors
    ///
    /// Failures can be reported in two places: the outer `Result`, and `Err` items
    /// yielded by the returned stream.
    /// NOTE(review): presumably the outer error covers failures to start and stream
    /// items cover failures mid-transcription — confirm with implementations.
    fn transcribe(&self, audio: Box<dyn AudioInput>) -> Result<Box<dyn EventStream>>;
}
161
/// An [`AudioInput`] that serves samples from an in-memory buffer in fixed-size chunks.
///
/// Construct it with `VecAudioInput::from_samples` or `VecAudioInput::from_wav_path`.
/// The fields are private, so those constructors are the only way in.
#[derive(Debug)]
pub struct VecAudioInput {
    /// Every sample to be served, in the order they are handed out.
    samples: Vec<i16>,
    /// Index into `samples` of the next sample to hand out.
    cursor: usize,
    /// Maximum number of samples per chunk; the constructors clamp this to at least 1.
    chunk_samples: usize,
}
177
178impl VecAudioInput {
179 pub fn from_wav_path(path: impl AsRef<Path>, chunk_samples: usize) -> Result<Self> {
183 let path = path.as_ref();
184 let mut reader = hound::WavReader::open(path)
185 .with_context(|| format!("Failed to open WAV at {}", path.display()))?;
186 let spec = reader.spec();
187 if spec.sample_rate != 16_000 {
188 bail!(
189 "WAV at {} must be 16000 Hz (got {}). Resample before constructing VecAudioInput.",
190 path.display(),
191 spec.sample_rate
192 );
193 }
194 if spec.channels != 1 {
195 bail!(
196 "WAV at {} must be mono (got {} channels). Mix down before constructing VecAudioInput.",
197 path.display(),
198 spec.channels
199 );
200 }
201 if spec.bits_per_sample != 16 || spec.sample_format != hound::SampleFormat::Int {
202 bail!(
203 "WAV at {} must be 16-bit signed PCM (got {}-bit {:?})",
204 path.display(),
205 spec.bits_per_sample,
206 spec.sample_format
207 );
208 }
209 let samples: Vec<i16> = reader
210 .samples::<i16>()
211 .collect::<Result<Vec<_>, _>>()
212 .with_context(|| format!("Failed to decode i16 PCM samples from {}", path.display()))?;
213 Ok(Self::from_samples(samples, chunk_samples))
214 }
215
216 pub fn from_samples(samples: Vec<i16>, chunk_samples: usize) -> Self {
219 Self {
220 samples,
221 cursor: 0,
222 chunk_samples: chunk_samples.max(1),
223 }
224 }
225}
226
227impl AudioInput for VecAudioInput {
228 fn next_chunk(&mut self) -> Option<AudioChunk> {
229 if self.cursor >= self.samples.len() {
230 return None;
231 }
232 let end = (self.cursor + self.chunk_samples).min(self.samples.len());
233 let chunk = self.samples[self.cursor..end].to_vec();
234 self.cursor = end;
235 Some(chunk)
236 }
237}
238
239mod duration_secs {
242 use serde::{Deserialize, Deserializer, Serializer};
243 use std::time::Duration;
244
245 pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
246 s.serialize_f64(d.as_secs_f64())
247 }
248
249 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
250 let secs = f64::deserialize(d)?;
251 Ok(Duration::from_secs_f64(secs.max(0.0)))
252 }
253}
254
// Unit tests for `VecAudioInput`, the `EventStream` blanket impl and the serde
// representation of `TranscriptEvent`.
#[cfg(test)]
// Tests use unwrap/expect freely: a panic on a failed setup step is the desired failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes `samples` as an integer-PCM WAV named `name` inside `dir`, with the given
    /// sample rate, channel count and bit depth, and returns the file's path.
    fn write_fixture_wav(
        dir: &TempDir,
        name: &str,
        sample_rate: u32,
        channels: u16,
        bits: u16,
        samples: &[i16],
    ) -> std::path::PathBuf {
        let path = dir.path().join(name);
        let spec = hound::WavSpec {
            channels,
            sample_rate,
            bits_per_sample: bits,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(&path, spec).unwrap();
        for s in samples {
            writer.write_sample(*s).unwrap();
        }
        writer.finalize().unwrap();
        path
    }

    #[test]
    fn vec_audio_input_from_samples_chunks_correctly() {
        // Five samples in chunks of two: the last chunk carries the single leftover.
        let mut input = VecAudioInput::from_samples(vec![1, 2, 3, 4, 5], 2);
        assert_eq!(input.next_chunk(), Some(vec![1, 2]));
        assert_eq!(input.next_chunk(), Some(vec![3, 4]));
        assert_eq!(input.next_chunk(), Some(vec![5]));
        assert_eq!(input.next_chunk(), None);
    }

    #[test]
    fn vec_audio_input_zero_chunk_size_clamps_to_one() {
        // A chunk size of 0 must behave like 1 rather than yielding empty chunks forever.
        let mut input = VecAudioInput::from_samples(vec![10, 20], 0);
        assert_eq!(input.next_chunk(), Some(vec![10]));
        assert_eq!(input.next_chunk(), Some(vec![20]));
        assert_eq!(input.next_chunk(), None);
    }

    #[test]
    fn vec_audio_input_empty_yields_none() {
        let mut input = VecAudioInput::from_samples(vec![], 16);
        assert!(input.next_chunk().is_none());
    }

    #[test]
    fn vec_audio_input_reads_16k_mono_i16_wav() {
        // Happy path: a file in exactly the accepted format round-trips its samples.
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "ok.wav", 16_000, 1, 16, &[100, 200, 300, 400]);
        let mut input = VecAudioInput::from_wav_path(&path, 2).unwrap();
        assert_eq!(input.next_chunk(), Some(vec![100, 200]));
        assert_eq!(input.next_chunk(), Some(vec![300, 400]));
        assert!(input.next_chunk().is_none());
    }

    #[test]
    fn vec_audio_input_rejects_wrong_sample_rate() {
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "44k.wav", 44_100, 1, 16, &[0, 0]);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("16000 Hz"), "got: {err}");
    }

    #[test]
    fn vec_audio_input_rejects_stereo() {
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "stereo.wav", 16_000, 2, 16, &[0, 0, 0, 0]);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("mono"), "got: {err}");
    }

    #[test]
    fn vec_audio_input_rejects_wrong_bit_depth() {
        // The fixture is 16 kHz mono, so it passes the earlier checks and reaches the
        // bit-depth/sample-format check.
        let tmp = TempDir::new().unwrap();
        let path = dir_with_wav_f32(&tmp);
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(err.to_string().contains("16-bit"), "got: {err}");
    }

    /// Writes a one-sample, 16 kHz mono, 32-bit float WAV named `f32.wav` into `dir` and
    /// returns its path (despite the name, it creates a file, not a directory).
    fn dir_with_wav_f32(dir: &TempDir) -> std::path::PathBuf {
        let path = dir.path().join("f32.wav");
        let spec = hound::WavSpec {
            channels: 1,
            sample_rate: 16_000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        };
        let mut writer = hound::WavWriter::create(&path, spec).unwrap();
        writer.write_sample(0.0_f32).unwrap();
        writer.finalize().unwrap();
        path
    }

    #[test]
    fn vec_audio_input_missing_file_errors() {
        let err = VecAudioInput::from_wav_path("/nope/does/not/exist.wav", 16).unwrap_err();
        assert!(err.to_string().contains("Failed to open WAV"), "got: {err}");
    }

    #[test]
    fn event_stream_blanket_impl_compiles() {
        // Compile-time check: a plain `vec::IntoIter` of results coerces to
        // `Box<dyn EventStream>` via the blanket impl.
        fn accepts(_s: Box<dyn EventStream>) {}
        let events: Vec<Result<TranscriptEvent>> = vec![Ok(TranscriptEvent::Endpoint {
            at: Duration::from_secs(1),
            kind: EndpointKind::StreamEnd,
        })];
        accepts(Box::new(events.into_iter()));
    }

    #[test]
    fn transcript_event_serde_round_trips() {
        // `None` options are skipped on serialize and must come back as `None`.
        let event = TranscriptEvent::Final {
            event_id: ulid::Ulid::from_parts(0, 1),
            text: "hello".to_string(),
            start: Duration::from_millis(0),
            end: Duration::from_millis(500),
            confidence: 0.97,
            words: None,
            speaker: None,
            revisable: false,
        };
        let json = serde_json::to_string(&event).unwrap();
        let back: TranscriptEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(event, back);
    }

    #[test]
    fn duration_serialises_as_seconds() {
        let event = TranscriptEvent::Endpoint {
            at: Duration::from_millis(1500),
            kind: EndpointKind::StreamEnd,
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(
            json.contains("\"at\":1.5"),
            "duration should serialise as f64 seconds, got: {json}"
        );
    }

    #[test]
    fn duration_deserialise_rejects_non_numeric_seconds() {
        // A string where seconds are expected must be a deserialization error, not a panic.
        let bad_json = r#"{"type":"endpoint","at":"not a number","kind":"stream_end"}"#;
        let result: Result<TranscriptEvent, _> = serde_json::from_str(bad_json);
        assert!(result.is_err(), "expected deserialization to fail");
    }

    #[test]
    fn vec_audio_input_propagates_decode_failure() {
        // Chop one byte off a valid file. The header checks still pass, but the last
        // sample is presumably left incomplete, so decoding must fail with the
        // decode-context message.
        let tmp = TempDir::new().unwrap();
        let path = write_fixture_wav(&tmp, "truncated.wav", 16_000, 1, 16, &[1, 2, 3, 4]);
        let len = std::fs::metadata(&path).unwrap().len();
        std::fs::OpenOptions::new()
            .write(true)
            .open(&path)
            .unwrap()
            .set_len(len - 1)
            .unwrap();
        let err = VecAudioInput::from_wav_path(&path, 16).unwrap_err();
        assert!(
            err.to_string().contains("Failed to decode i16 PCM samples"),
            "got: {err}"
        );
    }
}