1#![doc = include_str!("../README.md")]
2
3pub mod surface;
4use std::collections::BTreeMap;
5use std::fmt;
6use std::fs;
7use std::path::{Path, PathBuf};
8use std::process::{Child, Command, ExitStatus, Output, Stdio};
9use std::time::Duration;
10
11use audio_analysis_core::OwnedAudioWaveformBatch;
12use serde::Deserialize;
13use serde_json::Value;
14use text_core::{tokenize, tokenize_words, TextProcessingOptions, TokenKind};
15
16mod whisper_cpp;
17
18use thiserror::Error;
19use video_analysis_core::{AnalysisEvent, OwnedTextSegment, TextAnalyzer, TextSegment, Timestamp};
20use video_analysis_ingest::{
21 MediaSourceInfo, SourceMode, TextFormat as IngestTextFormat, TextSegmentSource, TextStreamInfo,
22};
23pub mod contracts;
24pub use contracts::{
25 text_segment_contract_with_source, TranscriptCharContract, TranscriptSegmentContract,
26 TranscriptWordContract, TranscriptionContract,
27};
28pub use whisper_cpp::{
30 transcription_catalog as whisper_cpp_catalog, whisper_cpp_system_info,
31 ModelStore as WhisperCppModelStore, WhisperCppCatalog, WhisperCppConfig, WhisperCppError,
32 WhisperCppModel, WhisperCppModelStatus, WhisperCppPhase, WhisperCppProgressEvent,
33 WhisperCppSegment, WhisperCppTranscriber as NativeWhisperCppTranscriber,
34};
35
/// Errors produced while running transcriber backends or parsing/validating transcripts.
#[derive(Debug, Error)]
pub enum TranscriptionError {
    /// Filesystem or process I/O failed (creating directories, spawning/waiting on a child,
    /// reading or writing files).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The transcript bytes could not be deserialized as the expected JSON.
    #[error("invalid transcript JSON: {0}")]
    Json(#[from] serde_json::Error),
    /// The transcript parsed, but its content is unusable (wrong shape, missing data, etc.).
    #[error("invalid transcript: {0}")]
    InvalidTranscript(String),
    /// An external transcriber process exited unsuccessfully; holds the command's display form.
    #[error("transcriber command `{0}` failed")]
    CommandFailed(String),
    /// An external transcriber process ran past its configured timeout and was killed.
    #[error("transcriber command `{command}` timed out after {seconds} seconds")]
    CommandTimeout {
        /// Display form of the command that timed out.
        command: String,
        /// The timeout that was exceeded, in seconds.
        seconds: u64,
    },
    /// Error forwarded from `video_analysis_core` (e.g. invalid waveform input).
    #[error("{0}")]
    Detect(#[from] video_analysis_core::DetectError),
    /// Error forwarded from the native whisper.cpp backend.
    #[error("{0}")]
    WhisperCpp(#[from] WhisperCppError),
}

/// Crate-wide result alias whose error type is [`TranscriptionError`].
pub type Result<T> = std::result::Result<T, TranscriptionError>;
69
/// Serialized layouts a transcript can be read from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TranscriptFormat {
    /// Free-form text, read line by line (same parser as [`TranscriptFormat::Lines`]).
    Plain,
    /// One segment per non-empty line.
    Lines,
    /// Whisper-style JSON output (`text`, `language`, `segments`).
    WhisperJson,
    /// SubRip subtitles.
    Srt,
    /// WebVTT subtitles.
    WebVtt,
}

impl TranscriptFormat {
    /// Maps a file extension (or format nickname) to a [`TranscriptFormat`].
    ///
    /// Surrounding whitespace and any leading dots are ignored and matching is ASCII
    /// case-insensitive, so `".SRT"`, `"srt"` and `" srt "` are all equivalent.
    /// Returns `None` for anything unrecognised.
    pub fn from_extension(extension: &str) -> Option<Self> {
        let key = extension.trim().trim_start_matches('.').to_ascii_lowercase();
        let format = match key.as_str() {
            "txt" | "text" => Self::Plain,
            "lines" => Self::Lines,
            "json" | "whisper" | "whisperjson" | "whisper-json" => Self::WhisperJson,
            "srt" => Self::Srt,
            "vtt" | "webvtt" | "web-vtt" => Self::WebVtt,
            _ => return None,
        };
        Some(format)
    }
}
103
/// One transcribed span of speech or subtitle text.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptSegment {
    /// Segment identifier. Parsers fill it from the source's own id when present, otherwise
    /// from the segment's position (e.g. the line number in `parse_plain_lines`).
    pub index: u64,
    /// Start time in seconds, if known.
    pub start_seconds: Option<f64>,
    /// End time in seconds, if known.
    pub end_seconds: Option<f64>,
    /// Segment text.
    pub text: String,
    /// Language tag, if the producer reported one.
    pub language: Option<String>,
    /// Speaker label, if known.
    pub speaker: Option<String>,
    /// Optional confidence score. Its range depends on the producer and is not validated here.
    pub confidence: Option<f32>,
    /// Whether the segment is final. The constructors visible in this module always set
    /// `true`; presumably `false` marks provisional (e.g. live) hypotheses. TODO confirm.
    pub is_final: bool,
}
124
125impl TranscriptSegment {
126 pub fn metadata(&self) -> BTreeMap<String, String> {
128 let mut metadata = BTreeMap::new();
129 insert_optional(&mut metadata, "language", self.language.as_deref());
130 insert_optional(&mut metadata, "speaker", self.speaker.as_deref());
131 insert_optional_number(&mut metadata, "start_seconds", self.start_seconds);
132 insert_optional_number(&mut metadata, "end_seconds", self.end_seconds);
133 insert_optional_display(&mut metadata, "confidence", self.confidence);
134 metadata
135 }
136
137 pub fn metadata_with_source(&self, source: impl Into<String>) -> BTreeMap<String, String> {
139 let mut metadata = self.metadata();
140 let source = source.into();
141 if !source.is_empty() {
142 metadata.insert("source".to_string(), source);
143 }
144 metadata
145 }
146}
147
/// A complete transcript as produced by a [`Transcriber`] or one of the parsers.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptionResult {
    /// Full transcript text, when the producer supplies or derives one.
    pub text: Option<String>,
    /// Transcript-level language tag, if known.
    pub language: Option<String>,
    /// The individual segments, in producer order.
    pub segments: Vec<TranscriptSegment>,
    /// Where the transcript came from. For example, the file path set by
    /// `parse_transcript_file` or by the Whisper CLI transcriber.
    pub source: Option<String>,
}
160
/// A single transcribed word with optional timing and confidence.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptWord {
    /// The word text.
    pub text: String,
    /// Start time in seconds, if known.
    pub start_seconds: Option<f64>,
    /// End time in seconds, if known.
    pub end_seconds: Option<f64>,
    /// Optional confidence score (range not validated by this type).
    pub confidence: Option<f32>,
}
173
/// A backend that turns an input file into a [`TranscriptionResult`].
///
/// The implementations in this module shell out to an arbitrary command, to the Whisper CLI,
/// or call the native whisper.cpp bindings.
pub trait Transcriber {
    /// Transcribes the file at `input`.
    ///
    /// # Errors
    /// Implementation-specific; see each implementor.
    fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult>;
}
179
/// Switches for the passes applied by `normalize_subtitle_text`.
#[derive(Debug, Clone)]
pub struct SubtitleNormalizationOptions {
    /// Remove recognised subtitle markup tags (and turn inline timestamps into spaces).
    pub strip_markup: bool,
    /// Decode basic HTML-style character entities.
    pub decode_basic_entities: bool,
    /// Collapse runs of whitespace.
    pub collapse_whitespace: bool,
}

impl Default for SubtitleNormalizationOptions {
    /// Every normalization pass is enabled by default.
    fn default() -> Self {
        let enabled = true;
        Self {
            collapse_whitespace: enabled,
            decode_basic_entities: enabled,
            strip_markup: enabled,
        }
    }
}
200
201pub fn normalize_subtitle_text(text: &str, options: SubtitleNormalizationOptions) -> String {
203 let mut normalized = text.to_string();
204 if options.strip_markup {
205 normalized = strip_subtitle_markup(&normalized);
206 }
207 if options.decode_basic_entities {
208 normalized = decode_basic_entities(&normalized);
209 }
210 if options.collapse_whitespace {
211 normalized = collapse_whitespace(&normalized);
212 }
213 normalized
214}
215
/// Configuration for [`CommandTranscriber::from_options`].
#[derive(Debug, Clone)]
pub struct CommandTranscriberOptions {
    /// Executable to run. The input path is appended after `args`.
    pub command: PathBuf,
    /// Extra arguments passed before the input path.
    pub args: Vec<String>,
    /// How the command's stdout is parsed into a transcript.
    pub format: TranscriptFormat,
    /// If set, the child is killed once it runs longer than this many seconds.
    pub timeout_seconds: Option<u64>,
}
228
229#[derive(Debug, Default, Clone)]
230pub struct TranscriptHeuristicAnalyzer;
232
233impl TextAnalyzer for TranscriptHeuristicAnalyzer {
234 fn name(&self) -> &str {
235 "transcript_heuristics"
236 }
237
238 fn process_segment(
239 &mut self,
240 segment: &TextSegment<'_>,
241 ) -> video_analysis_core::Result<Vec<AnalysisEvent>> {
242 let mut events = Vec::new();
243 let text = segment.text.trim();
244 if text.ends_with(['?', '؟', '?']) {
245 events.push(event_at(self.name(), "speech:question", segment.timestamp));
246 }
247 if has_token_kind(text, TokenKind::Url) {
248 events.push(event_at(self.name(), "speech:url", segment.timestamp));
249 }
250 if has_token_kind(text, TokenKind::Number) {
251 events.push(event_at(self.name(), "speech:number", segment.timestamp));
252 }
253 if tokenize_words(text).len() >= 30 {
254 events.push(event_at(
255 self.name(),
256 "speech:long_segment",
257 segment.timestamp,
258 ));
259 }
260 Ok(events)
261 }
262}
263
264#[derive(Debug, Clone)]
265pub struct CommandTranscriber {
267 command: PathBuf,
268 args: Vec<String>,
269 format: TranscriptFormat,
270 timeout_seconds: Option<u64>,
271}
272
273impl CommandTranscriber {
274 pub fn new(command: impl Into<PathBuf>, format: TranscriptFormat) -> Self {
276 Self {
277 command: command.into(),
278 args: Vec::new(),
279 format,
280 timeout_seconds: None,
281 }
282 }
283
284 pub fn from_options(options: CommandTranscriberOptions) -> Self {
286 Self {
287 command: options.command,
288 args: options.args,
289 format: options.format,
290 timeout_seconds: options.timeout_seconds,
291 }
292 }
293
294 pub fn args(mut self, args: impl IntoIterator<Item = String>) -> Self {
296 self.args.extend(args);
297 self
298 }
299
300 pub fn timeout_seconds(mut self, timeout_seconds: Option<u64>) -> Self {
302 self.timeout_seconds = timeout_seconds;
303 self
304 }
305}
306
307impl Transcriber for CommandTranscriber {
308 fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
309 let child = Command::new(&self.command)
310 .args(&self.args)
311 .arg(input)
312 .stdin(Stdio::null())
313 .stdout(Stdio::piped())
314 .stderr(Stdio::piped())
315 .spawn()?;
316 let output = wait_with_optional_timeout(child, &self.command, self.timeout_seconds)?;
317 if !output.status.success() {
318 return Err(TranscriptionError::CommandFailed(
319 self.command.display().to_string(),
320 ));
321 }
322 parse_transcript_bytes(&output.stdout, self.format)
323 }
324}
325
/// Configuration for [`WhisperCliTranscriber::from_options`].
#[derive(Debug, Clone)]
pub struct WhisperCliTranscriberOptions {
    /// Whisper CLI executable.
    pub command: PathBuf,
    /// Extra arguments inserted after the input path and before
    /// `--output_format json --output_dir <dir>`.
    pub args: Vec<String>,
    /// Where the CLI writes its JSON. Defaults to `<input parent>/transcript` when `None`.
    pub output_dir: Option<PathBuf>,
    /// If set, the child is killed once it runs longer than this many seconds.
    pub timeout_seconds: Option<u64>,
}
338
339#[derive(Debug, Clone)]
340pub struct WhisperCliTranscriber {
342 command: PathBuf,
343 args: Vec<String>,
344 output_dir: Option<PathBuf>,
345 timeout_seconds: Option<u64>,
346}
347
348impl WhisperCliTranscriber {
349 pub fn new(command: impl Into<PathBuf>) -> Self {
351 Self {
352 command: command.into(),
353 args: Vec::new(),
354 output_dir: None,
355 timeout_seconds: None,
356 }
357 }
358
359 pub fn from_options(options: WhisperCliTranscriberOptions) -> Self {
361 Self {
362 command: options.command,
363 args: options.args,
364 output_dir: options.output_dir,
365 timeout_seconds: options.timeout_seconds,
366 }
367 }
368
369 pub fn args(mut self, args: impl IntoIterator<Item = String>) -> Self {
371 self.args.extend(args);
372 self
373 }
374
375 pub fn output_dir(mut self, output_dir: impl Into<PathBuf>) -> Self {
377 self.output_dir = Some(output_dir.into());
378 self
379 }
380
381 pub fn timeout_seconds(mut self, timeout_seconds: Option<u64>) -> Self {
383 self.timeout_seconds = timeout_seconds;
384 self
385 }
386}
387
388impl Transcriber for WhisperCliTranscriber {
389 fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
390 let output_dir = self.output_dir.clone().unwrap_or_else(|| {
391 input
392 .parent()
393 .unwrap_or_else(|| Path::new("."))
394 .join("transcript")
395 });
396 fs::create_dir_all(&output_dir)?;
397
398 let child = Command::new(&self.command)
399 .arg(input)
400 .args(&self.args)
401 .arg("--output_format")
402 .arg("json")
403 .arg("--output_dir")
404 .arg(&output_dir)
405 .stdin(Stdio::null())
406 .spawn()?;
407 let status = wait_status_with_optional_timeout(child, &self.command, self.timeout_seconds)?;
408 if !status.success() {
409 return Err(TranscriptionError::CommandFailed(
410 self.command.display().to_string(),
411 ));
412 }
413
414 let transcript_path = find_transcript_json(&output_dir).ok_or_else(|| {
415 TranscriptionError::InvalidTranscript(
416 "transcriber completed but no JSON transcript was found".to_string(),
417 )
418 })?;
419 let bytes = fs::read(&transcript_path)?;
420 let mut result = parse_whisper_json(&bytes)?;
421 result.source = Some(transcript_path.to_string_lossy().into_owned());
422 Ok(result)
423 }
424}
425
426pub struct WhisperCppTranscriber {
428 inner: NativeWhisperCppTranscriber,
429}
430
431impl WhisperCppTranscriber {
432 pub fn new(config: WhisperCppConfig) -> Self {
434 Self {
435 inner: NativeWhisperCppTranscriber::new(config),
436 }
437 }
438
439 pub fn with_model_store(mut self, store: WhisperCppModelStore) -> Self {
441 self.inner = self.inner.with_model_store(store);
442 self
443 }
444
445 pub fn on_progress<F>(mut self, callback: F) -> Self
447 where
448 F: FnMut(WhisperCppProgressEvent) + 'static,
449 {
450 self.inner = self.inner.on_progress(callback);
451 self
452 }
453
454 pub fn transcribe_with_progress(
456 &mut self,
457 input: &Path,
458 progress: &mut dyn FnMut(WhisperCppProgressEvent),
459 ) -> Result<TranscriptionResult> {
460 let transcript = self.inner.transcribe_file_with_progress(input, progress)?;
461 Ok(whisper_cpp_result_to_transcription_result(transcript))
462 }
463}
464
465impl Transcriber for WhisperCppTranscriber {
466 fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
467 let transcript = self.inner.transcribe_file(input)?;
468 Ok(whisper_cpp_result_to_transcription_result(transcript))
469 }
470}
471
472fn whisper_cpp_result_to_transcription_result(
473 transcript: whisper_cpp::WhisperCppTranscription,
474) -> TranscriptionResult {
475 TranscriptionResult {
476 text: transcript.text,
477 language: transcript.language.clone(),
478 segments: transcript
479 .segments
480 .into_iter()
481 .map(|segment| TranscriptSegment {
482 index: segment.index,
483 start_seconds: segment.start_seconds,
484 end_seconds: segment.end_seconds,
485 text: segment.text,
486 language: transcript.language.clone(),
487 speaker: None,
488 confidence: segment.confidence,
489 is_final: true,
490 })
491 .collect(),
492 source: transcript.source,
493 }
494}
495
496pub struct TranscriptSegmentSource {
498 source_info: MediaSourceInfo,
499 segments: Vec<TranscriptSegment>,
500 next_index: usize,
501}
502
503impl TranscriptSegmentSource {
504 pub fn recorded(input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
506 Self::new(SourceMode::Recorded, input, segments)
507 }
508
509 pub fn live(input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
511 Self::new(SourceMode::Live, input, segments)
512 }
513
514 fn new(mode: SourceMode, input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
515 let language = segments.iter().find_map(|segment| segment.language.clone());
516 let source_info = MediaSourceInfo {
517 input: input.into(),
518 mode,
519 video: None,
520 audio: Vec::new(),
521 text: vec![TextStreamInfo {
522 format: IngestTextFormat::Transcript,
523 language,
524 }],
525 };
526 Self {
527 source_info,
528 segments,
529 next_index: 0,
530 }
531 }
532}
533
534impl TextSegmentSource for TranscriptSegmentSource {
535 fn source_info(&self) -> &MediaSourceInfo {
536 &self.source_info
537 }
538
539 fn next_text_segment(&mut self) -> video_analysis_core::Result<Option<OwnedTextSegment>> {
540 let Some(segment) = self.segments.get(self.next_index) else {
541 return Ok(None);
542 };
543 self.next_index += 1;
544 Ok(Some(segment_to_owned_text_segment(segment)))
545 }
546}
547
/// Whisper-style JSON document, as requested from the CLI via `--output_format json`.
/// Only the fields below are read. Other keys are ignored, since serde's default does not
/// deny unknown fields.
#[derive(Debug, Deserialize)]
struct WhisperOutput {
    /// Full transcript text.
    text: Option<String>,
    /// Detected or forced language.
    language: Option<String>,
    /// Per-segment results. Defaults to empty when the key is missing.
    #[serde(default)]
    segments: Vec<WhisperSegment>,
}

/// One entry of [`WhisperOutput::segments`].
#[derive(Debug, Deserialize)]
struct WhisperSegment {
    /// Segment id. When absent, the segment's position is used instead.
    id: Option<u64>,
    /// Start time in seconds.
    start: Option<f64>,
    /// End time in seconds.
    end: Option<f64>,
    /// Segment text (required).
    text: String,
    /// Average token log-probability, fed to `whisper_confidence`.
    #[serde(default)]
    avg_logprob: Option<f32>,
    /// Probability the segment is non-speech, fed to `whisper_confidence`.
    #[serde(default)]
    no_speech_prob: Option<f32>,
}
567
568pub fn parse_whisper_json(bytes: &[u8]) -> Result<TranscriptionResult> {
570 let parsed: WhisperOutput = serde_json::from_slice(bytes)?;
571 let segments = parsed
572 .segments
573 .into_iter()
574 .enumerate()
575 .map(|(index, segment)| TranscriptSegment {
576 index: segment.id.unwrap_or(index as u64),
577 start_seconds: segment.start,
578 end_seconds: segment.end,
579 text: segment.text.trim().to_string(),
580 language: parsed.language.clone(),
581 speaker: None,
582 confidence: whisper_confidence(segment.avg_logprob, segment.no_speech_prob),
583 is_final: true,
584 })
585 .collect::<Vec<_>>();
586 Ok(TranscriptionResult {
587 text: parsed.text.map(|text| text.trim().to_string()),
588 language: parsed.language,
589 segments,
590 source: None,
591 })
592}
593
594pub fn normalize_transcription_contract(
596 contract: TranscriptionContract,
597) -> Result<TranscriptionContract> {
598 contract.normalized()
599}
600
601pub fn normalize_imported_segments(
603 source: Option<String>,
604 language: Option<String>,
605 segments: Vec<TranscriptSegmentContract>,
606) -> Result<TranscriptionContract> {
607 TranscriptionContract::from_segments(source, language, segments)
608}
609
610pub fn parse_whisperx_json(bytes: &[u8]) -> Result<TranscriptionContract> {
612 let value: Value = serde_json::from_slice(bytes)?;
613 let object = value.as_object().ok_or_else(|| {
614 TranscriptionError::InvalidTranscript("WhisperX JSON must be an object".to_string())
615 })?;
616 let language = object
617 .get("language")
618 .and_then(Value::as_str)
619 .map(str::to_string);
620 let text = object
621 .get("text")
622 .and_then(Value::as_str)
623 .map(|text| text.trim().to_string())
624 .filter(|text| !text.is_empty());
625 let source = object
626 .get("source")
627 .and_then(Value::as_str)
628 .map(str::to_string);
629 let mut attributes = unknown_attributes(
630 object,
631 &["language", "text", "source", "segments", "word_segments"],
632 );
633 if let Some(count) = object.get("segment_count").and_then(Value::as_u64) {
634 attributes.insert("segment_count".to_string(), count.to_string());
635 }
636
637 let mut segments = object
638 .get("segments")
639 .and_then(Value::as_array)
640 .ok_or_else(|| {
641 TranscriptionError::InvalidTranscript(
642 "WhisperX JSON must include a segments array".to_string(),
643 )
644 })?
645 .iter()
646 .enumerate()
647 .map(|(index, segment)| whisperx_segment(segment, index as u64, language.clone()))
648 .collect::<Result<Vec<_>>>()?;
649
650 if segments.iter().all(|segment| segment.words.is_empty()) {
651 if let Some(words) = object.get("word_segments").and_then(Value::as_array) {
652 attach_flat_whisperx_words(&mut segments, words)?;
653 }
654 }
655
656 let mut contract = TranscriptionContract {
657 text,
658 language,
659 segments,
660 source,
661 attributes,
662 }
663 .normalized()?;
664 for segment in &mut contract.segments {
665 if segment.speaker.is_none() {
666 segment.speaker = infer_segment_speaker(&segment.words);
667 }
668 }
669 contract.validate_strict()?;
670 Ok(contract)
671}
672
673pub fn parse_srt(text: &str) -> Result<TranscriptionResult> {
675 parse_subtitle_blocks(text, TranscriptFormat::Srt)
676}
677
678pub fn parse_webvtt(text: &str) -> Result<TranscriptionResult> {
680 parse_subtitle_blocks(text, TranscriptFormat::WebVtt)
681}
682
683pub fn parse_plain_lines(text: &str) -> TranscriptionResult {
685 let segments = text
686 .lines()
687 .enumerate()
688 .filter_map(|(index, line)| {
689 let line = line.trim();
690 (!line.is_empty()).then(|| TranscriptSegment {
691 index: index as u64,
692 start_seconds: None,
693 end_seconds: None,
694 text: line.to_string(),
695 language: None,
696 speaker: None,
697 confidence: None,
698 is_final: true,
699 })
700 })
701 .collect::<Vec<_>>();
702 TranscriptionResult {
703 text: Some(
704 segments
705 .iter()
706 .map(|segment| segment.text.as_str())
707 .collect::<Vec<_>>()
708 .join("\n"),
709 ),
710 language: None,
711 segments,
712 source: None,
713 }
714}
715
716pub fn format_srt(segments: &[TranscriptSegment]) -> String {
718 let mut output = String::new();
719 for (index, segment) in segments.iter().enumerate() {
720 let start = segment.start_seconds.unwrap_or(0.0);
721 let end = segment
722 .end_seconds
723 .unwrap_or_else(|| (start + 2.0).max(start));
724 output.push_str(&(index + 1).to_string());
725 output.push('\n');
726 output.push_str(&format_srt_timestamp(start));
727 output.push_str(" --> ");
728 output.push_str(&format_srt_timestamp(end.max(start)));
729 output.push('\n');
730 output.push_str(segment.text.trim());
731 output.push_str("\n\n");
732 }
733 output
734}
735
736pub fn format_webvtt(segments: &[TranscriptSegment]) -> String {
738 let mut output = String::from("WEBVTT\n\n");
739 for (index, segment) in segments.iter().enumerate() {
740 if index > 0 {
741 output.push('\n');
742 }
743 let start = segment.start_seconds.unwrap_or(0.0);
744 let end = segment
745 .end_seconds
746 .unwrap_or_else(|| (start + 2.0).max(start));
747 output.push_str(&format_webvtt_timestamp(start));
748 output.push_str(" --> ");
749 output.push_str(&format_webvtt_timestamp(end.max(start)));
750 output.push('\n');
751 output.push_str(segment.text.trim());
752 output.push('\n');
753 }
754 output
755}
756
757pub fn write_srt(path: impl AsRef<Path>, segments: &[TranscriptSegment]) -> Result<()> {
759 let path = path.as_ref();
760 if let Some(parent) = path.parent() {
761 fs::create_dir_all(parent)?;
762 }
763 fs::write(path, format_srt(segments))?;
764 Ok(())
765}
766
767pub fn transcribe_waveform_batch<T: Transcriber>(
769 transcriber: &mut T,
770 batch: &OwnedAudioWaveformBatch,
771 wav_path: &Path,
772) -> Result<TranscriptionResult> {
773 write_waveform_batch_as_wav(wav_path, batch)?;
774 transcriber.transcribe(wav_path)
775}
776
777fn write_waveform_batch_as_wav(
778 path: impl AsRef<Path>,
779 batch: &OwnedAudioWaveformBatch,
780) -> Result<()> {
781 let view = batch.as_view()?;
782 if view.batch_size() != 1 {
783 return Err(video_analysis_core::DetectError::InvalidArgument(
784 "waveform WAV export requires a batch size of 1".to_string(),
785 )
786 .into());
787 }
788 let path = path.as_ref();
789 if let Some(parent) = path.parent() {
790 fs::create_dir_all(parent)?;
791 }
792 let spec = hound::WavSpec {
793 channels: view.channel_count() as u16,
794 sample_rate: view.sample_rate,
795 bits_per_sample: 32,
796 sample_format: hound::SampleFormat::Float,
797 };
798 let mut writer = hound::WavWriter::create(path, spec).map_err(|err| {
799 video_analysis_core::DetectError::Source(format!(
800 "failed to create WAV `{}`: {err}",
801 path.display()
802 ))
803 })?;
804 for time_index in 0..view.time_steps() {
805 for channel_index in 0..view.channel_count() {
806 let sample = view.waveform(0, channel_index)?[time_index];
807 writer.write_sample(sample).map_err(|err| {
808 video_analysis_core::DetectError::Source(format!(
809 "failed to write WAV sample `{}`: {err}",
810 path.display()
811 ))
812 })?;
813 }
814 }
815 writer.finalize().map_err(|err| {
816 video_analysis_core::DetectError::Source(format!(
817 "failed to finalize WAV `{}`: {err}",
818 path.display()
819 ))
820 })?;
821 Ok(())
822}
823
/// Formats a time in seconds as an SRT timestamp `HH:MM:SS,mmm`.
///
/// The value is rounded to the nearest millisecond. Negative and non-finite inputs (NaN,
/// ±∞) are treated as `0`; without that guard, `+∞ as u64` would saturate into an absurd
/// hour count. Hours are not wrapped, so durations of 100 h or more print with more than
/// two hour digits.
pub fn format_srt_timestamp(seconds: f64) -> String {
    let seconds = if seconds.is_finite() {
        seconds.max(0.0)
    } else {
        0.0
    };
    let total_millis = (seconds * 1_000.0).round() as u64;
    let millis = total_millis % 1_000;
    let total_seconds = total_millis / 1_000;
    let secs = total_seconds % 60;
    let total_minutes = total_seconds / 60;
    let minutes = total_minutes % 60;
    let hours = total_minutes / 60;
    format!("{hours:02}:{minutes:02}:{secs:02},{millis:03}")
}
835
836fn format_webvtt_timestamp(seconds: f64) -> String {
837 format_srt_timestamp(seconds).replace(',', ".")
838}
839
840pub fn segment_to_owned_text_segment(segment: &TranscriptSegment) -> OwnedTextSegment {
842 text_core::TextSegmentContract::from(TranscriptSegmentContract::from(segment))
843 .to_owned_text_segment()
844}
845
846pub fn parse_transcript_file(path: impl AsRef<Path>) -> Result<TranscriptionResult> {
848 let path = path.as_ref();
849 let extension = path
850 .extension()
851 .and_then(|value| value.to_str())
852 .ok_or_else(|| {
853 TranscriptionError::InvalidTranscript("transcript file missing extension".to_string())
854 })?;
855 let format = TranscriptFormat::from_extension(extension).ok_or_else(|| {
856 TranscriptionError::InvalidTranscript(format!(
857 "unsupported transcript file extension `{extension}`"
858 ))
859 })?;
860 let bytes = fs::read(path)?;
861 let mut parsed = parse_transcript_bytes(&bytes, format)?;
862 parsed.source = Some(path.to_string_lossy().into_owned());
863 Ok(parsed)
864}
865
866pub fn parse_normalized_transcript_file(
868 path: impl AsRef<Path>,
869 options: SubtitleNormalizationOptions,
870) -> Result<TranscriptionContract> {
871 let parsed = parse_transcript_file(path)?;
872 let mut contract = TranscriptionContract::from(parsed);
873 contract.segments = contract
874 .segments
875 .into_iter()
876 .filter_map(|mut segment| {
877 segment.text = normalize_subtitle_text(&segment.text, options.clone());
878 (!segment.text.is_empty()).then_some(segment)
879 })
880 .collect();
881 contract.text = contract
882 .text
883 .as_deref()
884 .map(|text| normalize_subtitle_text(text, options))
885 .filter(|text| !text.is_empty());
886 contract.normalized()
887}
888
889fn parse_transcript_bytes(bytes: &[u8], format: TranscriptFormat) -> Result<TranscriptionResult> {
890 match format {
891 TranscriptFormat::Plain | TranscriptFormat::Lines => {
892 Ok(parse_plain_lines(&String::from_utf8_lossy(bytes)))
893 }
894 TranscriptFormat::WhisperJson => parse_whisper_json(bytes),
895 TranscriptFormat::Srt => parse_srt(&String::from_utf8_lossy(bytes)),
896 TranscriptFormat::WebVtt => parse_webvtt(&String::from_utf8_lossy(bytes)),
897 }
898}
899
900fn whisperx_segment(
901 value: &Value,
902 fallback_index: u64,
903 language: Option<String>,
904) -> Result<TranscriptSegmentContract> {
905 let object = value.as_object().ok_or_else(|| {
906 TranscriptionError::InvalidTranscript("WhisperX segment must be an object".to_string())
907 })?;
908 let mut segment = TranscriptSegmentContract::new(
909 object
910 .get("id")
911 .or_else(|| object.get("index"))
912 .and_then(Value::as_u64)
913 .unwrap_or(fallback_index),
914 object
915 .get("text")
916 .and_then(Value::as_str)
917 .unwrap_or_default(),
918 );
919 segment.start_seconds = number_field(object, &["start", "start_seconds", "startSeconds"]);
920 segment.end_seconds = number_field(object, &["end", "end_seconds", "endSeconds"]);
921 segment.language = object
922 .get("language")
923 .and_then(Value::as_str)
924 .map(str::to_string)
925 .or(language);
926 segment.speaker = object
927 .get("speaker")
928 .or_else(|| object.get("speaker_label"))
929 .or_else(|| object.get("speakerLabel"))
930 .and_then(Value::as_str)
931 .map(str::to_string);
932 segment.confidence = confidence_field(
933 object,
934 &["confidence", "score", "avg_logprob", "no_speech_prob"],
935 );
936 segment.words = object
937 .get("words")
938 .or_else(|| object.get("word_segments"))
939 .and_then(Value::as_array)
940 .map(|words| {
941 words
942 .iter()
943 .map(whisperx_word)
944 .collect::<Result<Vec<TranscriptWordContract>>>()
945 })
946 .transpose()?
947 .unwrap_or_default();
948 segment.chars = object
949 .get("chars")
950 .or_else(|| object.get("characters"))
951 .and_then(Value::as_array)
952 .map(|chars| {
953 chars
954 .iter()
955 .map(whisperx_char)
956 .collect::<Result<Vec<TranscriptCharContract>>>()
957 })
958 .transpose()?
959 .unwrap_or_default();
960 if segment.speaker.is_none() {
961 segment.speaker = infer_segment_speaker(&segment.words);
962 }
963 segment.attributes = unknown_attributes(
964 object,
965 &[
966 "id",
967 "index",
968 "start",
969 "start_seconds",
970 "startSeconds",
971 "end",
972 "end_seconds",
973 "endSeconds",
974 "text",
975 "language",
976 "speaker",
977 "speaker_label",
978 "speakerLabel",
979 "confidence",
980 "score",
981 "avg_logprob",
982 "no_speech_prob",
983 "words",
984 "word_segments",
985 "chars",
986 "characters",
987 ],
988 );
989 Ok(segment)
990}
991
992fn whisperx_word(value: &Value) -> Result<TranscriptWordContract> {
993 let object = value.as_object().ok_or_else(|| {
994 TranscriptionError::InvalidTranscript("WhisperX word must be an object".to_string())
995 })?;
996 Ok(TranscriptWordContract {
997 text: object
998 .get("word")
999 .or_else(|| object.get("text"))
1000 .and_then(Value::as_str)
1001 .unwrap_or_default()
1002 .to_string(),
1003 start_seconds: number_field(object, &["start", "start_seconds", "startSeconds"]),
1004 end_seconds: number_field(object, &["end", "end_seconds", "endSeconds"]),
1005 confidence: confidence_field(object, &["confidence", "score", "probability"]),
1006 speaker: object
1007 .get("speaker")
1008 .or_else(|| object.get("speaker_label"))
1009 .or_else(|| object.get("speakerLabel"))
1010 .and_then(Value::as_str)
1011 .map(str::to_string),
1012 attributes: unknown_attributes(
1013 object,
1014 &[
1015 "word",
1016 "text",
1017 "start",
1018 "start_seconds",
1019 "startSeconds",
1020 "end",
1021 "end_seconds",
1022 "endSeconds",
1023 "confidence",
1024 "score",
1025 "probability",
1026 "speaker",
1027 "speaker_label",
1028 "speakerLabel",
1029 ],
1030 ),
1031 })
1032}
1033
1034fn whisperx_char(value: &Value) -> Result<TranscriptCharContract> {
1035 let object = value.as_object().ok_or_else(|| {
1036 TranscriptionError::InvalidTranscript("WhisperX char must be an object".to_string())
1037 })?;
1038 Ok(TranscriptCharContract {
1039 character: object
1040 .get("char")
1041 .or_else(|| object.get("character"))
1042 .or_else(|| object.get("text"))
1043 .and_then(Value::as_str)
1044 .unwrap_or_default()
1045 .to_string(),
1046 start_seconds: number_field(object, &["start", "start_seconds", "startSeconds"]),
1047 end_seconds: number_field(object, &["end", "end_seconds", "endSeconds"]),
1048 confidence: confidence_field(object, &["confidence", "score", "probability"]),
1049 attributes: unknown_attributes(
1050 object,
1051 &[
1052 "char",
1053 "character",
1054 "text",
1055 "start",
1056 "start_seconds",
1057 "startSeconds",
1058 "end",
1059 "end_seconds",
1060 "endSeconds",
1061 "confidence",
1062 "score",
1063 "probability",
1064 ],
1065 ),
1066 })
1067}
1068
1069fn attach_flat_whisperx_words(
1070 segments: &mut [TranscriptSegmentContract],
1071 words: &[Value],
1072) -> Result<()> {
1073 for value in words {
1074 let word = whisperx_word(value)?;
1075 let midpoint = match (word.start_seconds, word.end_seconds) {
1076 (Some(start), Some(end)) => Some((start + end) * 0.5),
1077 (Some(start), None) => Some(start),
1078 (None, Some(end)) => Some(end),
1079 (None, None) => None,
1080 };
1081 let Some(segment) = segments.iter_mut().find(|segment| {
1082 midpoint
1083 .zip(segment.start_seconds.zip(segment.end_seconds))
1084 .map(|(midpoint, (start, end))| midpoint >= start && midpoint <= end)
1085 .unwrap_or(false)
1086 }) else {
1087 continue;
1088 };
1089 segment.words.push(word);
1090 }
1091 for segment in segments {
1092 if segment.speaker.is_none() {
1093 segment.speaker = infer_segment_speaker(&segment.words);
1094 }
1095 }
1096 Ok(())
1097}
1098
1099fn infer_segment_speaker(words: &[TranscriptWordContract]) -> Option<String> {
1100 let mut scores: BTreeMap<&str, (f64, usize)> = BTreeMap::new();
1101 for word in words {
1102 let Some(speaker) = word
1103 .speaker
1104 .as_deref()
1105 .filter(|speaker| !speaker.is_empty())
1106 else {
1107 continue;
1108 };
1109 let duration = word
1110 .start_seconds
1111 .zip(word.end_seconds)
1112 .map(|(start, end)| (end - start).max(0.0))
1113 .filter(|duration| duration.is_finite() && *duration > 0.0)
1114 .unwrap_or(1.0);
1115 let entry = scores.entry(speaker).or_insert((0.0, 0));
1116 entry.0 += duration;
1117 entry.1 += 1;
1118 }
1119 scores
1120 .into_iter()
1121 .max_by(|left, right| {
1122 left.1
1123 .0
1124 .total_cmp(&right.1 .0)
1125 .then(left.1 .1.cmp(&right.1 .1))
1126 .then_with(|| right.0.cmp(left.0))
1127 })
1128 .map(|(speaker, _)| speaker.to_string())
1129}
1130
1131fn unknown_attributes(
1132 object: &serde_json::Map<String, Value>,
1133 known_fields: &[&str],
1134) -> BTreeMap<String, String> {
1135 object
1136 .iter()
1137 .filter(|(key, _)| !known_fields.contains(&key.as_str()))
1138 .map(|(key, value)| (key.clone(), json_attribute(value)))
1139 .collect()
1140}
1141
1142fn json_attribute(value: &Value) -> String {
1143 match value {
1144 Value::String(value) => value.clone(),
1145 Value::Number(value) => value.to_string(),
1146 Value::Bool(value) => value.to_string(),
1147 Value::Null => "null".to_string(),
1148 other => other.to_string(),
1149 }
1150}
1151
1152fn number_field(object: &serde_json::Map<String, Value>, names: &[&str]) -> Option<f64> {
1153 names
1154 .iter()
1155 .find_map(|name| object.get(*name).and_then(Value::as_f64))
1156 .filter(|value| value.is_finite())
1157}
1158
1159fn confidence_field(object: &serde_json::Map<String, Value>, names: &[&str]) -> Option<f32> {
1160 for name in names {
1161 let Some(value) = object.get(*name).and_then(Value::as_f64) else {
1162 continue;
1163 };
1164 if !value.is_finite() {
1165 continue;
1166 }
1167 return match *name {
1168 "avg_logprob" => Some(value.exp().clamp(0.0, 1.0) as f32),
1169 "no_speech_prob" => Some((1.0 - value).clamp(0.0, 1.0) as f32),
1170 _ => Some(value.clamp(0.0, 1.0) as f32),
1171 };
1172 }
1173 None
1174}
1175
1176fn wait_with_optional_timeout(
1177 mut child: Child,
1178 command: &Path,
1179 timeout_seconds: Option<u64>,
1180) -> Result<Output> {
1181 if let Some(seconds) = timeout_seconds {
1182 let deadline = std::time::Instant::now() + Duration::from_secs(seconds);
1183 loop {
1184 if child.try_wait()?.is_some() {
1185 return Ok(child.wait_with_output()?);
1186 }
1187 if std::time::Instant::now() >= deadline {
1188 let _ = child.kill();
1189 let _ = child.wait();
1190 return Err(TranscriptionError::CommandTimeout {
1191 command: command.display().to_string(),
1192 seconds,
1193 });
1194 }
1195 std::thread::sleep(Duration::from_millis(25));
1196 }
1197 }
1198 Ok(child.wait_with_output()?)
1199}
1200
1201fn wait_status_with_optional_timeout(
1202 mut child: Child,
1203 command: &Path,
1204 timeout_seconds: Option<u64>,
1205) -> Result<ExitStatus> {
1206 if let Some(seconds) = timeout_seconds {
1207 let deadline = std::time::Instant::now() + Duration::from_secs(seconds);
1208 loop {
1209 if let Some(status) = child.try_wait()? {
1210 return Ok(status);
1211 }
1212 if std::time::Instant::now() >= deadline {
1213 let _ = child.kill();
1214 let _ = child.wait();
1215 return Err(TranscriptionError::CommandTimeout {
1216 command: command.display().to_string(),
1217 seconds,
1218 });
1219 }
1220 std::thread::sleep(Duration::from_millis(25));
1221 }
1222 }
1223 Ok(child.wait()?)
1224}
1225
1226fn strip_subtitle_markup(text: &str) -> String {
1227 let mut output = String::with_capacity(text.len());
1228 let mut chars = text.chars().peekable();
1229 while let Some(ch) = chars.next() {
1230 if ch != '<' {
1231 output.push(ch);
1232 continue;
1233 }
1234
1235 let mut tag = String::new();
1236 let mut closed = false;
1237 for tag_ch in chars.by_ref() {
1238 if tag_ch == '>' {
1239 closed = true;
1240 break;
1241 }
1242 tag.push(tag_ch);
1243 }
1244
1245 if !closed {
1246 output.push('<');
1247 output.push_str(&tag);
1248 break;
1249 }
1250
1251 let tag = tag.trim();
1252 if tag.is_empty() {
1253 continue;
1254 }
1255
1256 if is_subtitle_timestamp(tag) {
1257 output.push(' ');
1258 } else if is_subtitle_tag(tag) {
1259 continue;
1260 } else {
1261 output.push('<');
1262 output.push_str(tag);
1263 output.push('>');
1264 }
1265 }
1266 output
1267}
1268
/// Returns `true` when `tag` (the text between `<` and `>`) is a known subtitle formatting tag.
///
/// Closing tags (leading `/`) are accepted. The tag name ends at the first whitespace or `.`, so
/// `v Speaker` and `c.yellow` resolve to `v` and `c`.
///
/// Recognised names:
/// - WebVTT: `b`, `c`, `i`, `lang`, `rt`, `ruby`, `u`, `v`;
/// - SubRip: `font` (as in `<font color="...">`).
///
/// Names are compared ASCII case-insensitively because SRT files in the wild often use uppercase
/// tags such as `<I>`.
fn is_subtitle_tag(tag: &str) -> bool {
    const KNOWN_TAGS: [&str; 9] = ["b", "c", "font", "i", "lang", "rt", "ruby", "u", "v"];
    let name = tag
        .trim_start_matches('/')
        .split(|ch: char| ch.is_whitespace() || ch == '.')
        .next()
        .unwrap_or_default();
    KNOWN_TAGS.iter().any(|known| name.eq_ignore_ascii_case(known))
}
1277
/// Returns `true` when `value` is the body of a WebVTT inline timestamp tag.
///
/// The accepted shape is `[hh:]mm:ss.ttt`:
/// - hours are optional, as WebVTT allows; when present they have two or more digits;
/// - minutes and seconds have exactly two digits;
/// - the fraction has exactly three digits.
///
/// An SRT-style comma is accepted in place of the dot. Without the hours-optional and
/// long-hours forms, tags like `<01:02.500>` would leak into the text as literal markup.
fn is_subtitle_timestamp(value: &str) -> bool {
    let all_digits =
        |part: &str| !part.is_empty() && part.bytes().all(|byte| byte.is_ascii_digit());

    let value = value.replace(',', ".");
    let parts = value.split(':').collect::<Vec<_>>();
    let (hours, minutes, seconds) = match parts.as_slice() {
        &[minutes, seconds] => (None, minutes, seconds),
        &[hours, minutes, seconds] => (Some(hours), minutes, seconds),
        _ => return false,
    };
    let Some((whole, fraction)) = seconds.split_once('.') else {
        return false;
    };

    let hours_ok = match hours {
        Some(hours) => hours.len() >= 2 && all_digits(hours),
        None => true,
    };
    hours_ok
        && minutes.len() == 2
        && all_digits(minutes)
        && whole.len() == 2
        && all_digits(whole)
        && fraction.len() == 3
        && all_digits(fraction)
}
1291
/// Returns `true` when `value` consists of exactly two ASCII digits.
fn is_two_digits(value: &str) -> bool {
    matches!(
        value.as_bytes(),
        [tens, ones] if tens.is_ascii_digit() && ones.is_ascii_digit()
    )
}
1295
/// Decodes the basic HTML character references that commonly appear in subtitle text.
///
/// Recognised references, written here without their leading ampersand:
/// - named: `amp;`, `lt;`, `gt;`, `quot;`, `apos;`, `nbsp;`;
/// - numeric: `#34;`, `#39;`, `#x27;`, `#160;`.
///
/// A non-breaking space decodes to an ordinary space. Any other `&` is copied through unchanged.
///
/// Decoding is a single left-to-right pass, so each reference is decoded exactly once. A
/// double-escaped ampersand followed by `lt;` yields the literal text of the `lt;` reference,
/// not `<`. Chained `str::replace` calls that handle the ampersand reference first would decode
/// it twice.
fn decode_basic_entities(text: &str) -> String {
    // Reference names that follow `&`, paired with their decoded text.
    const ENTITIES: [(&str, &str); 10] = [
        ("amp;", "&"),
        ("lt;", "<"),
        ("gt;", ">"),
        ("quot;", "\""),
        ("#34;", "\""),
        ("#39;", "'"),
        ("#x27;", "'"),
        ("apos;", "'"),
        ("nbsp;", " "),
        ("#160;", " "),
    ];

    let mut output = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(ampersand) = rest.find('&') {
        output.push_str(&rest[..ampersand]);
        let after = &rest[ampersand + 1..];
        match ENTITIES.iter().find(|&&(name, _)| after.starts_with(name)) {
            Some(&(name, replacement)) => {
                output.push_str(replacement);
                rest = &after[name.len()..];
            }
            None => {
                output.push('&');
                rest = after;
            }
        }
    }
    output.push_str(rest);
    output
}
1305
/// Trims `text` and replaces every run of Unicode whitespace with a single ASCII space.
fn collapse_whitespace(text: &str) -> String {
    let mut collapsed = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed
}
1309
1310fn parse_subtitle_blocks(text: &str, format: TranscriptFormat) -> Result<TranscriptionResult> {
1311 let normalized = text.replace("\r\n", "\n").replace('\r', "\n");
1312 let mut segments = Vec::new();
1313
1314 for block in normalized.split("\n\n") {
1315 let lines = block
1316 .lines()
1317 .map(str::trim)
1318 .filter(|line| !line.is_empty())
1319 .collect::<Vec<_>>();
1320 if lines.is_empty()
1321 || (format == TranscriptFormat::WebVtt && lines[0].starts_with("WEBVTT"))
1322 {
1323 continue;
1324 }
1325
1326 let time_line_index = lines
1327 .iter()
1328 .position(|line| line.contains("-->"))
1329 .ok_or_else(|| {
1330 TranscriptionError::InvalidTranscript(
1331 "subtitle block missing timestamp".to_string(),
1332 )
1333 })?;
1334 let (start_seconds, end_seconds) = parse_timestamp_range(lines[time_line_index])?;
1335 let text_lines = &lines[time_line_index + 1..];
1336 if text_lines.is_empty() {
1337 continue;
1338 }
1339 let index = if time_line_index > 0 {
1340 lines[0].parse::<u64>().unwrap_or(segments.len() as u64)
1341 } else {
1342 segments.len() as u64
1343 };
1344 segments.push(TranscriptSegment {
1345 index,
1346 start_seconds: Some(start_seconds),
1347 end_seconds: Some(end_seconds),
1348 text: text_lines.join(" "),
1349 language: None,
1350 speaker: None,
1351 confidence: None,
1352 is_final: true,
1353 });
1354 }
1355
1356 Ok(TranscriptionResult {
1357 text: Some(
1358 segments
1359 .iter()
1360 .map(|segment| segment.text.as_str())
1361 .collect::<Vec<_>>()
1362 .join("\n"),
1363 ),
1364 language: None,
1365 segments,
1366 source: None,
1367 })
1368}
1369
1370fn parse_timestamp_range(line: &str) -> Result<(f64, f64)> {
1371 let Some((start, end_with_settings)) = line.split_once("-->") else {
1372 return Err(TranscriptionError::InvalidTranscript(
1373 "timestamp range missing -->".to_string(),
1374 ));
1375 };
1376 let end = end_with_settings
1377 .split_whitespace()
1378 .next()
1379 .unwrap_or(end_with_settings);
1380 Ok((parse_timestamp(start.trim())?, parse_timestamp(end.trim())?))
1381}
1382
1383fn parse_timestamp(value: &str) -> Result<f64> {
1384 let value = value.replace(',', ".");
1385 let pieces = value.split(':').collect::<Vec<_>>();
1386 let seconds = match pieces.as_slice() {
1387 [minutes, seconds] => {
1388 Some(parse_timestamp_component(minutes)? * 60.0 + parse_timestamp_component(seconds)?)
1389 }
1390 [hours, minutes, seconds] => Some(
1391 parse_timestamp_component(hours)? * 3600.0
1392 + parse_timestamp_component(minutes)? * 60.0
1393 + parse_timestamp_component(seconds)?,
1394 ),
1395 _ => None,
1396 };
1397 seconds.ok_or_else(|| {
1398 TranscriptionError::InvalidTranscript(format!("invalid timestamp `{value}`"))
1399 })
1400}
1401
1402fn parse_timestamp_component(value: &str) -> Result<f64> {
1403 value.parse::<f64>().map_err(|_| {
1404 TranscriptionError::InvalidTranscript(format!("invalid timestamp component `{value}`"))
1405 })
1406}
1407
/// Derives a segment confidence in `[0, 1]` from Whisper's per-segment statistics.
///
/// - `avg_logprob` is an average token log-probability (≤ 0), so it is exponentiated back into
///   a probability. This matches how `confidence_field` treats an `avg_logprob` key. Returning
///   the raw log-probability would report negative "confidences".
/// - When `avg_logprob` is absent or non-finite, the complement of `no_speech_prob` is used.
///
/// Results are clamped to `[0, 1]`, and non-finite inputs are ignored.
fn whisper_confidence(avg_logprob: Option<f32>, no_speech_prob: Option<f32>) -> Option<f32> {
    avg_logprob
        .filter(|logprob| logprob.is_finite())
        .map(|logprob| logprob.exp().clamp(0.0, 1.0))
        .or_else(|| {
            no_speech_prob
                .filter(|probability| probability.is_finite())
                .map(|probability| (1.0 - probability).clamp(0.0, 1.0))
        })
}
1411
/// Stores `value` under `key` when present, replacing any previous entry. `None` leaves
/// `metadata` untouched.
fn insert_optional(metadata: &mut BTreeMap<String, String>, key: &str, value: Option<&str>) {
    let Some(text) = value else {
        return;
    };
    metadata.insert(key.to_owned(), text.to_owned());
}
1417
/// Stores the `Display` rendering of a number under `key` when present (`2.0` becomes `"2"`).
///
/// Any previous entry under `key` is replaced; `None` leaves `metadata` untouched.
fn insert_optional_number(metadata: &mut BTreeMap<String, String>, key: &str, value: Option<f64>) {
    metadata.extend(value.map(|number| (key.to_owned(), number.to_string())));
}
1421
/// Stores the `Display` rendering of `value` under `key` when present.
///
/// Any previous entry under `key` is replaced; `None` leaves `metadata` untouched.
fn insert_optional_display<T: fmt::Display>(
    metadata: &mut BTreeMap<String, String>,
    key: &str,
    value: Option<T>,
) {
    if let Some(rendered) = value.map(|shown| shown.to_string()) {
        metadata.insert(key.to_owned(), rendered);
    }
}
1431
1432fn has_token_kind(text: &str, kind: TokenKind) -> bool {
1433 tokenize(text, &TextProcessingOptions::default())
1434 .into_iter()
1435 .any(|token| token.kind == kind)
1436}
1437
1438fn event_at(analyzer: &str, label: &str, timestamp: Option<Timestamp>) -> AnalysisEvent {
1439 let event = AnalysisEvent::new(analyzer, label);
1440 if let Some(timestamp) = timestamp {
1441 event.at_timestamp(timestamp)
1442 } else {
1443 event
1444 }
1445}
1446
/// Returns the most recently modified `*.json` file directly inside `output_dir`.
///
/// Selection rules:
/// - Ties on modification time are broken by path; the lexicographically greatest path wins.
/// - Entries whose modification time cannot be read rank below any entry that has one.
/// - Directories are ignored, even if their name ends in `.json`.
///
/// Returns `None` when the directory cannot be read or contains no candidate.
///
/// Each entry's metadata is read once, up front. A comparator that stats files repeatedly costs
/// O(n log n) syscalls and can see a different value on each call, giving an inconsistent
/// ordering.
fn find_transcript_json(output_dir: &Path) -> Option<PathBuf> {
    fs::read_dir(output_dir)
        .ok()?
        .filter_map(|entry| entry.ok().map(|entry| entry.path()))
        .filter(|path| path.extension().and_then(|value| value.to_str()) == Some("json"))
        .filter_map(|path| {
            let metadata = fs::metadata(&path).ok();
            if matches!(&metadata, Some(meta) if meta.is_dir()) {
                return None;
            }
            let modified = metadata.and_then(|meta| meta.modified().ok());
            Some((modified, path))
        })
        .max()
        .map(|(_, path)| path)
}
1466
#[cfg(test)]
mod tests {
    // Unit tests for transcript parsing/formatting, subtitle text normalization, segment
    // metadata, analyzers, and the external-command transcriber.
    use super::*;
    use tempfile::tempdir;
    use video_analysis_core::{AudioBuffer, OwnedAudioFrame, Timebase, Timestamp};
    use video_analysis_ingest::TextSegmentSource;

    // Whisper JSON: top-level text/language are kept and a segment's `id` becomes its index.
    #[test]
    fn parses_whisper_json() {
        let parsed = parse_whisper_json(
            br#"{"text":"hello world","language":"en","segments":[{"id":7,"start":0.0,"end":1.5,"text":" hello"}]}"#,
        )
        .unwrap();

        assert_eq!(parsed.text.as_deref(), Some("hello world"));
        assert_eq!(parsed.language.as_deref(), Some("en"));
        assert_eq!(parsed.segments[0].index, 7);
        assert_eq!(parsed.segments[0].start_seconds, Some(0.0));
    }

    // WhisperX per-character alignments are parsed, `score` becomes the char confidence, and
    // unrecognised keys are kept as string attributes.
    #[test]
    fn parses_whisperx_segment_chars_and_preserves_unknown_fields() {
        let parsed = parse_whisperx_json(
            br#"{
                "text": "hi",
                "language": "en",
                "segments": [{
                    "id": 0,
                    "start": 1.0,
                    "end": 2.0,
                    "text": "hi",
                    "chars": [
                        {"char": "h", "start": 1.1, "end": 1.2, "score": 0.9, "extra": "kept"},
                        {"char": "i", "start": 1.2, "end": 1.3, "score": 0.8}
                    ]
                }]
            }"#,
        )
        .unwrap();

        let chars = &parsed.segments[0].chars;
        assert_eq!(chars.len(), 2);
        assert_eq!(chars[0].character, "h");
        assert_eq!(chars[0].start_seconds, Some(1.1));
        assert_eq!(chars[0].end_seconds, Some(1.2));
        assert_eq!(chars[0].confidence, Some(0.9));
        assert_eq!(
            chars[0].attributes.get("extra").map(String::as_str),
            Some("kept")
        );
    }

    // Normalization trims segment text but keeps char alignments; strict validation rejects a
    // char whose end lies past its segment's end.
    #[test]
    fn normalization_keeps_chars_and_strict_validation_checks_bounds() {
        let mut segment = TranscriptSegmentContract::new(0, " hello ");
        segment.start_seconds = Some(0.0);
        segment.end_seconds = Some(1.0);
        segment.chars.push(TranscriptCharContract {
            character: "h".to_string(),
            start_seconds: Some(0.1),
            end_seconds: Some(0.2),
            confidence: Some(0.5),
            attributes: BTreeMap::new(),
        });

        let normalized = TranscriptionContract::new(vec![segment])
            .normalized()
            .unwrap();
        assert_eq!(normalized.segments[0].text, "hello");
        assert_eq!(normalized.segments[0].chars.len(), 1);
        normalized.validate_strict().unwrap();

        // Push the char past the segment's 1.0s end to trigger the bounds check.
        let mut invalid = normalized.clone();
        invalid.segments[0].chars[0].end_seconds = Some(1.1);
        let error = invalid.validate_strict().unwrap_err().to_string();
        assert!(error.contains("transcript char end_seconds"));
    }

    // A single SRT cue with comma-separated milliseconds.
    #[test]
    fn parses_srt() {
        let parsed = parse_srt("1\n00:00:01,000 --> 00:00:02,500\nHello\n\n").unwrap();
        assert_eq!(parsed.segments.len(), 1);
        assert_eq!(parsed.segments[0].start_seconds, Some(1.0));
        assert_eq!(parsed.segments[0].end_seconds, Some(2.5));
    }

    // SRT output numbers cues from 1 and uses `HH:MM:SS,mmm` timestamps.
    #[test]
    fn formats_srt() {
        let text = format_srt(&[
            TranscriptSegment {
                index: 0,
                start_seconds: Some(1.25),
                end_seconds: Some(3.5),
                text: "Hello".to_string(),
                language: None,
                speaker: None,
                confidence: None,
                is_final: true,
            },
            TranscriptSegment {
                index: 1,
                start_seconds: Some(63.0),
                end_seconds: Some(65.125),
                text: "World".to_string(),
                language: None,
                speaker: None,
                confidence: None,
                is_final: true,
            },
        ]);

        assert_eq!(
            text,
            "1\n00:00:01,250 --> 00:00:03,500\nHello\n\n2\n00:01:03,000 --> 00:01:05,125\nWorld\n\n"
        );
    }

    // WebVTT output starts with the `WEBVTT` header; an end before the start is raised to the
    // start time.
    #[test]
    fn formats_webvtt() {
        let text = format_webvtt(&[TranscriptSegment {
            index: 0,
            start_seconds: Some(1.0),
            end_seconds: Some(0.5),
            text: "Hello.".to_string(),
            language: None,
            speaker: None,
            confidence: None,
            is_final: true,
        }]);

        assert_eq!(text, "WEBVTT\n\n00:00:01.000 --> 00:00:01.000\nHello.\n");
    }

    // The WEBVTT header is skipped and a non-numeric cue identifier is accepted.
    #[test]
    fn parses_webvtt() {
        let parsed = parse_webvtt("WEBVTT\n\ncue\n00:00:03.000 --> 00:00:04.250\nHi\n").unwrap();
        assert_eq!(parsed.segments.len(), 1);
        assert_eq!(parsed.segments[0].text, "Hi");
        assert_eq!(parsed.segments[0].start_seconds, Some(3.0));
    }

    // Blank lines are dropped; each remaining line becomes a segment.
    #[test]
    fn parses_plain_lines() {
        let parsed = parse_plain_lines("one\n\ntwo\n");
        assert_eq!(parsed.segments.len(), 2);
        assert_eq!(parsed.text.as_deref(), Some("one\ntwo"));
    }

    // Conversion to an owned text segment carries the index, start time and language.
    #[test]
    fn converts_segment_timestamp() {
        let segment = TranscriptSegment {
            index: 2,
            start_seconds: Some(1.25),
            end_seconds: Some(2.0),
            text: "hello".to_string(),
            language: Some("en".to_string()),
            speaker: None,
            confidence: None,
            is_final: true,
        };
        let owned = segment_to_owned_text_segment(&segment);
        assert_eq!(owned.segment_index, 2);
        assert_eq!(owned.timestamp.unwrap().seconds(), 1.25);
        assert_eq!(owned.language.as_deref(), Some("en"));
    }

    // `metadata_with_source` emits every populated optional field plus the source label.
    #[test]
    fn transcript_segment_metadata_preserves_optional_fields() {
        let segment = TranscriptSegment {
            index: 2,
            start_seconds: Some(1.25),
            end_seconds: Some(2.0),
            text: "hello".to_string(),
            language: Some("en".to_string()),
            speaker: Some("speaker-1".to_string()),
            confidence: Some(0.75),
            is_final: true,
        };

        let metadata = segment.metadata_with_source("fixture.srt");

        assert_eq!(metadata["language"], "en");
        assert_eq!(metadata["speaker"], "speaker-1");
        assert_eq!(metadata["start_seconds"], "1.25");
        assert_eq!(metadata["end_seconds"], "2");
        assert_eq!(metadata["confidence"], "0.75");
        assert_eq!(metadata["source"], "fixture.srt");
    }

    // A recorded source yields each segment once, then `None`.
    #[test]
    fn transcript_segment_source_iterates() {
        let mut source = TranscriptSegmentSource::recorded(
            "test",
            vec![TranscriptSegment {
                index: 0,
                start_seconds: None,
                end_seconds: None,
                text: "hello".to_string(),
                language: None,
                speaker: None,
                confidence: None,
                is_final: true,
            }],
        );
        assert_eq!(source.next_text_segment().unwrap().unwrap().text, "hello");
        assert!(source.next_text_segment().unwrap().is_none());
    }

    // The heuristic analyzer flags questions, URLs and numbers in one segment.
    #[test]
    fn transcript_heuristic_analyzer_emits_speech_events() {
        let segment = TextSegment {
            segment_index: 0,
            timestamp: None,
            text: "Visit https://example.com at 3?",
            language: None,
            is_final: true,
        };
        let mut analyzer = TranscriptHeuristicAnalyzer;

        let labels = analyzer
            .process_segment(&segment)
            .unwrap()
            .into_iter()
            .map(|event| event.label)
            .collect::<Vec<_>>();

        assert!(labels.iter().any(|label| label == "speech:question"));
        assert!(labels.iter().any(|label| label == "speech:url"));
        assert!(labels.iter().any(|label| label == "speech:number"));
    }

    // A command that exits unsuccessfully (`false`) surfaces as `CommandFailed`.
    #[test]
    fn command_transcriber_reports_failure() {
        let mut transcriber = CommandTranscriber::new("false", TranscriptFormat::Plain);
        let err = transcriber.transcribe(Path::new("missing")).unwrap_err();
        assert!(matches!(err, TranscriptionError::CommandFailed(_)));
    }

    // A waveform batch is written to `wav_path` and the command's stdout is parsed as plain text.
    #[test]
    fn transcribes_waveform_batches_via_existing_command_transcriber() {
        let dir = tempdir().unwrap();
        let script_path = dir.path().join("transcriber.sh");
        fs::write(&script_path, "#!/bin/sh\nprintf 'hello from batch\\n'\n").unwrap();
        // The script must be executable to be spawned directly.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;

            let mut permissions = fs::metadata(&script_path).unwrap().permissions();
            permissions.set_mode(0o755);
            fs::set_permissions(&script_path, permissions).unwrap();
        }

        let mut transcriber = CommandTranscriber::new(&script_path, TranscriptFormat::Plain);
        let frame = OwnedAudioFrame::new(
            Timestamp::new(0, Timebase::new(1, 16_000)),
            16_000,
            1,
            AudioBuffer::F32(vec![0.0, 0.25, -0.25, 0.5]),
        )
        .unwrap();
        let batch = OwnedAudioWaveformBatch::from_audio_frames(&[frame]).unwrap();
        let wav_path = dir.path().join("input.wav");

        let result = transcribe_waveform_batch(&mut transcriber, &batch, &wav_path).unwrap();
        assert_eq!(result.text.as_deref(), Some("hello from batch"));
        assert!(wav_path.is_file());
    }

    // SRT parse/format keeps the cue text, and the SRT, WebVTT and plain parsers agree on it.
    #[test]
    fn srt_webvtt_plain_round_trip() {
        let srt = "1\n00:00:00,000 --> 00:00:01,000\nHello world\n";
        let parsed = parse_srt(srt).unwrap();
        let formatted = format_srt(&parsed.segments);
        assert!(formatted.contains("Hello world"));

        let webvtt =
            parse_webvtt("WEBVTT\n\n00:00:00.000 --> 00:00:01.000\nHello world\n").unwrap();
        let plain = parse_plain_lines("Hello world\n");

        assert_eq!(parsed.segments[0].text, webvtt.segments[0].text);
        assert_eq!(plain.text.as_deref(), Some("Hello world"));
    }

    // Voice/class tags and inline timestamps are stripped, entities decoded, and whitespace
    // (including the newline) collapsed.
    #[test]
    fn normalizes_subtitle_markup_entities_and_whitespace() {
        let normalized = normalize_subtitle_text(
            "<v Speaker>Hello <c.yellow>Rust</c> &amp; friends <00:00:01.000>\nnow",
            SubtitleNormalizationOptions::default(),
        );

        assert_eq!(normalized, "Hello Rust & friends now");
    }

    // Extension lookup ignores a leading dot and letter case; unknown extensions yield `None`.
    #[test]
    fn infers_transcript_format_from_extension() {
        assert_eq!(
            TranscriptFormat::from_extension(".vtt"),
            Some(TranscriptFormat::WebVtt)
        );
        assert_eq!(
            TranscriptFormat::from_extension("SRT"),
            Some(TranscriptFormat::Srt)
        );
        assert_eq!(
            TranscriptFormat::from_extension("json"),
            Some(TranscriptFormat::WhisperJson)
        );
        assert_eq!(TranscriptFormat::from_extension("csv"), None);
    }

    // File parsing picks the format from the extension, records the path as the source and
    // normalizes cue text.
    #[test]
    fn parses_and_normalizes_transcript_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.vtt");
        fs::write(
            &path,
            "WEBVTT\n\n00:00:00.000 --> 00:00:01.000\n<c>Hello file</c>\n",
        )
        .unwrap();

        let parsed =
            parse_normalized_transcript_file(&path, SubtitleNormalizationOptions::default())
                .unwrap();

        assert_eq!(parsed.source.as_deref(), Some(path.to_str().unwrap()));
        assert_eq!(parsed.text.as_deref(), Some("Hello file"));
        assert_eq!(parsed.segments[0].text, "Hello file");
    }

    // The text-segment contract copies index and language, derives the duration from start/end,
    // and attaches the source description.
    #[test]
    fn builds_text_segment_contract_with_source() {
        let mut segment = TranscriptSegmentContract::new(7, "hello source");
        segment.start_seconds = Some(1.25);
        segment.end_seconds = Some(2.5);
        segment.language = Some("en".to_string());

        let contract = text_segment_contract_with_source(
            &segment,
            "stream-1",
            "caption_manual",
            "https://example.test/video",
        );

        assert_eq!(contract.stream_id.as_deref(), Some("stream-1"));
        assert_eq!(contract.segment_index, 7);
        assert_eq!(contract.language.as_deref(), Some("en"));
        assert_eq!(contract.duration_seconds, Some(1.25));
        let source = contract.source.unwrap();
        assert_eq!(source.source_id.as_deref(), Some("stream-1"));
        assert_eq!(source.source_kind.as_deref(), Some("caption_manual"));
        assert_eq!(source.uri.as_deref(), Some("https://example.test/video"));
        assert_eq!(source.duration_seconds, Some(1.25));
    }

    // Tags outside the known subtitle set are left untouched.
    #[test]
    fn preserves_unknown_markup_conservatively() {
        let normalized = normalize_subtitle_text(
            "look <custom value>here</custom>",
            SubtitleNormalizationOptions::default(),
        );

        assert_eq!(normalized, "look <custom value>here</custom>");
    }

    // With markup stripping disabled, only entity decoding and whitespace collapsing apply;
    // the trailing non-breaking space is trimmed away.
    #[test]
    fn decodes_basic_entities() {
        let normalized = normalize_subtitle_text(
            "&amp; &lt; &gt; &quot; &#39; &apos; &nbsp;",
            SubtitleNormalizationOptions {
                strip_markup: false,
                decode_basic_entities: true,
                collapse_whitespace: true,
            },
        );

        assert_eq!(normalized, "& < > \" ' '");
    }

    // Parsed SRT/WebVTT cue text keeps its markup until it is normalized.
    #[test]
    fn parse_and_normalize_round_trip_for_subtitles() {
        let srt = parse_srt("1\n00:00:00,000 --> 00:00:01,000\n<c>Hello SRT</c>\n").unwrap();
        let vtt =
            parse_webvtt("WEBVTT\n\n00:00:00.000 --> 00:00:01.000\n<v Speaker>Hello VTT\n")
                .unwrap();

        let options = SubtitleNormalizationOptions::default();
        assert_eq!(
            normalize_subtitle_text(&srt.segments[0].text, options.clone()),
            "Hello SRT"
        );
        assert_eq!(
            normalize_subtitle_text(&vtt.segments[0].text, options),
            "Hello VTT"
        );
    }
}