Skip to main content

text_transcripts/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod surface;
4use std::collections::BTreeMap;
5use std::fmt;
6use std::fs;
7use std::path::{Path, PathBuf};
8use std::process::{Child, Command, ExitStatus, Output, Stdio};
9use std::time::Duration;
10
11use audio_analysis_core::OwnedAudioWaveformBatch;
12use serde::Deserialize;
13use serde_json::Value;
14use text_core::{tokenize, tokenize_words, TextProcessingOptions, TokenKind};
15
16mod whisper_cpp;
17
18use thiserror::Error;
19use video_analysis_core::{AnalysisEvent, OwnedTextSegment, TextAnalyzer, TextSegment, Timestamp};
20use video_analysis_ingest::{
21    MediaSourceInfo, SourceMode, TextFormat as IngestTextFormat, TextSegmentSource, TextStreamInfo,
22};
23pub mod contracts;
24pub use contracts::{
25    text_segment_contract_with_source, TranscriptCharContract, TranscriptSegmentContract,
26    TranscriptWordContract, TranscriptionContract,
27};
28/// Re-exports the text transcript native whisper.cpp API.
29pub use whisper_cpp::{
30    transcription_catalog as whisper_cpp_catalog, whisper_cpp_system_info,
31    ModelStore as WhisperCppModelStore, WhisperCppCatalog, WhisperCppConfig, WhisperCppError,
32    WhisperCppModel, WhisperCppModelStatus, WhisperCppPhase, WhisperCppProgressEvent,
33    WhisperCppSegment, WhisperCppTranscriber as NativeWhisperCppTranscriber,
34};
35
36#[derive(Debug, Error)]
37/// Variants describing transcription error.
38pub enum TranscriptionError {
39    #[error("I/O error: {0}")]
40    /// The I/O variant.
41    Io(#[from] std::io::Error),
42    #[error("invalid transcript JSON: {0}")]
43    /// The JSON variant.
44    Json(#[from] serde_json::Error),
45    #[error("invalid transcript: {0}")]
46    /// The invalid transcript variant.
47    InvalidTranscript(String),
48    #[error("transcriber command `{0}` failed")]
49    /// The command failed variant.
50    CommandFailed(String),
51    #[error("transcriber command `{command}` timed out after {seconds} seconds")]
52    /// The command timeout variant.
53    CommandTimeout {
54        /// Command that timed out.
55        command: String,
56        /// Timeout in seconds.
57        seconds: u64,
58    },
59    #[error("{0}")]
60    /// The detect variant.
61    Detect(#[from] video_analysis_core::DetectError),
62    #[error("{0}")]
63    /// The whisper cpp variant.
64    WhisperCpp(#[from] WhisperCppError),
65}
66
67/// Type alias for result.
68pub type Result<T> = std::result::Result<T, TranscriptionError>;
69
/// Transcript encodings understood by this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TranscriptFormat {
    /// Free-form text, parsed as one segment per non-empty line.
    Plain,
    /// Line-oriented text, parsed as one segment per non-empty line.
    Lines,
    /// OpenAI Whisper JSON output.
    WhisperJson,
    /// SubRip (`.srt`) subtitles.
    Srt,
    /// WebVTT (`.vtt`) subtitles.
    WebVtt,
}

impl TranscriptFormat {
    /// Maps a file extension to a transcript format.
    ///
    /// Surrounding whitespace and leading dots are ignored, and matching is ASCII
    /// case-insensitive, so `".SRT"` and `"srt"` both yield [`TranscriptFormat::Srt`].
    /// Returns `None` for unrecognised extensions.
    pub fn from_extension(extension: &str) -> Option<Self> {
        let key = extension.trim().trim_start_matches('.').to_ascii_lowercase();
        let format = match key.as_str() {
            "txt" | "text" => Self::Plain,
            "lines" => Self::Lines,
            "json" | "whisper" | "whisperjson" | "whisper-json" => Self::WhisperJson,
            "srt" => Self::Srt,
            "vtt" | "webvtt" | "web-vtt" => Self::WebVtt,
            _ => return None,
        };
        Some(format)
    }
}
103
/// One unit of transcribed speech, with optional timing and attribution.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptSegment {
    /// Position of the segment; parsers may take it from the source data instead of counting.
    pub index: u64,
    /// Start time in seconds, when known.
    pub start_seconds: Option<f64>,
    /// End time in seconds, when known.
    pub end_seconds: Option<f64>,
    /// Transcribed text.
    pub text: String,
    /// Language tag, when known.
    pub language: Option<String>,
    /// Speaker label, when known.
    pub speaker: Option<String>,
    /// Confidence score, when the source provides one.
    pub confidence: Option<f32>,
    /// Whether the segment is final rather than a provisional (live) hypothesis.
    pub is_final: bool,
}
124
125impl TranscriptSegment {
126    /// Returns metadata.
127    pub fn metadata(&self) -> BTreeMap<String, String> {
128        let mut metadata = BTreeMap::new();
129        insert_optional(&mut metadata, "language", self.language.as_deref());
130        insert_optional(&mut metadata, "speaker", self.speaker.as_deref());
131        insert_optional_number(&mut metadata, "start_seconds", self.start_seconds);
132        insert_optional_number(&mut metadata, "end_seconds", self.end_seconds);
133        insert_optional_display(&mut metadata, "confidence", self.confidence);
134        metadata
135    }
136
137    /// Returns metadata with source.
138    pub fn metadata_with_source(&self, source: impl Into<String>) -> BTreeMap<String, String> {
139        let mut metadata = self.metadata();
140        let source = source.into();
141        if !source.is_empty() {
142            metadata.insert("source".to_string(), source);
143        }
144        metadata
145    }
146}
147
148#[derive(Debug, Clone, PartialEq)]
149/// Data type for transcription result.
150pub struct TranscriptionResult {
151    /// Text content for this value.
152    pub text: Option<String>,
153    /// Language tag for this value.
154    pub language: Option<String>,
155    /// The segments value.
156    pub segments: Vec<TranscriptSegment>,
157    /// The source value.
158    pub source: Option<String>,
159}
160
/// A single transcribed word with optional timing and confidence.
#[derive(Debug, Clone, PartialEq)]
pub struct TranscriptWord {
    /// The word as transcribed.
    pub text: String,
    /// Start time in seconds, when known.
    pub start_seconds: Option<f64>,
    /// End time in seconds, when known.
    pub end_seconds: Option<f64>,
    /// Confidence score, when the source provides one.
    pub confidence: Option<f32>,
}
173
174/// Trait for transcriber implementations.
175pub trait Transcriber {
176    /// Returns transcribe.
177    fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult>;
178}
179
/// Switches for [`normalize_subtitle_text`]; every pass is enabled by default.
#[derive(Debug, Clone)]
pub struct SubtitleNormalizationOptions {
    /// Remove WebVTT/SRT rendering markup.
    pub strip_markup: bool,
    /// Decode the common HTML entities found in captions.
    pub decode_basic_entities: bool,
    /// Replace each whitespace run with a single ASCII space.
    pub collapse_whitespace: bool,
}

impl Default for SubtitleNormalizationOptions {
    /// Enables markup stripping, entity decoding and whitespace collapsing.
    fn default() -> Self {
        let enabled = true;
        Self {
            strip_markup: enabled,
            decode_basic_entities: enabled,
            collapse_whitespace: enabled,
        }
    }
}
200
201/// Normalizes subtitle cue text without parsing timing blocks.
202pub fn normalize_subtitle_text(text: &str, options: SubtitleNormalizationOptions) -> String {
203    let mut normalized = text.to_string();
204    if options.strip_markup {
205        normalized = strip_subtitle_markup(&normalized);
206    }
207    if options.decode_basic_entities {
208        normalized = decode_basic_entities(&normalized);
209    }
210    if options.collapse_whitespace {
211        normalized = collapse_whitespace(&normalized);
212    }
213    normalized
214}
215
216#[derive(Debug, Clone)]
217/// Options for command transcriber construction.
218pub struct CommandTranscriberOptions {
219    /// Command path.
220    pub command: PathBuf,
221    /// Extra command arguments.
222    pub args: Vec<String>,
223    /// Expected output format.
224    pub format: TranscriptFormat,
225    /// Optional timeout in seconds.
226    pub timeout_seconds: Option<u64>,
227}
228
229#[derive(Debug, Default, Clone)]
230/// Transcript-specific deterministic analyzer.
231pub struct TranscriptHeuristicAnalyzer;
232
233impl TextAnalyzer for TranscriptHeuristicAnalyzer {
234    fn name(&self) -> &str {
235        "transcript_heuristics"
236    }
237
238    fn process_segment(
239        &mut self,
240        segment: &TextSegment<'_>,
241    ) -> video_analysis_core::Result<Vec<AnalysisEvent>> {
242        let mut events = Vec::new();
243        let text = segment.text.trim();
244        if text.ends_with(['?', '؟', '?']) {
245            events.push(event_at(self.name(), "speech:question", segment.timestamp));
246        }
247        if has_token_kind(text, TokenKind::Url) {
248            events.push(event_at(self.name(), "speech:url", segment.timestamp));
249        }
250        if has_token_kind(text, TokenKind::Number) {
251            events.push(event_at(self.name(), "speech:number", segment.timestamp));
252        }
253        if tokenize_words(text).len() >= 30 {
254            events.push(event_at(
255                self.name(),
256                "speech:long_segment",
257                segment.timestamp,
258            ));
259        }
260        Ok(events)
261    }
262}
263
264#[derive(Debug, Clone)]
265/// Data type for command transcriber.
266pub struct CommandTranscriber {
267    command: PathBuf,
268    args: Vec<String>,
269    format: TranscriptFormat,
270    timeout_seconds: Option<u64>,
271}
272
273impl CommandTranscriber {
274    /// Creates a new value.
275    pub fn new(command: impl Into<PathBuf>, format: TranscriptFormat) -> Self {
276        Self {
277            command: command.into(),
278            args: Vec::new(),
279            format,
280            timeout_seconds: None,
281        }
282    }
283
284    /// Creates from options.
285    pub fn from_options(options: CommandTranscriberOptions) -> Self {
286        Self {
287            command: options.command,
288            args: options.args,
289            format: options.format,
290            timeout_seconds: options.timeout_seconds,
291        }
292    }
293
294    /// Returns args.
295    pub fn args(mut self, args: impl IntoIterator<Item = String>) -> Self {
296        self.args.extend(args);
297        self
298    }
299
300    /// Returns this value with timeout.
301    pub fn timeout_seconds(mut self, timeout_seconds: Option<u64>) -> Self {
302        self.timeout_seconds = timeout_seconds;
303        self
304    }
305}
306
307impl Transcriber for CommandTranscriber {
308    fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
309        let child = Command::new(&self.command)
310            .args(&self.args)
311            .arg(input)
312            .stdin(Stdio::null())
313            .stdout(Stdio::piped())
314            .stderr(Stdio::piped())
315            .spawn()?;
316        let output = wait_with_optional_timeout(child, &self.command, self.timeout_seconds)?;
317        if !output.status.success() {
318            return Err(TranscriptionError::CommandFailed(
319                self.command.display().to_string(),
320            ));
321        }
322        parse_transcript_bytes(&output.stdout, self.format)
323    }
324}
325
/// Settings consumed by [`WhisperCliTranscriber::from_options`].
#[derive(Debug, Clone)]
pub struct WhisperCliTranscriberOptions {
    /// Whisper CLI executable.
    pub command: PathBuf,
    /// Extra arguments, placed after the input path.
    pub args: Vec<String>,
    /// Directory for the JSON output; `None` means a `transcript` directory next to the input.
    pub output_dir: Option<PathBuf>,
    /// Optional per-run timeout, in seconds.
    pub timeout_seconds: Option<u64>,
}

/// [`Transcriber`] that drives the `whisper` command-line tool and reads its JSON output.
#[derive(Debug, Clone)]
pub struct WhisperCliTranscriber {
    command: PathBuf,
    args: Vec<String>,
    output_dir: Option<PathBuf>,
    timeout_seconds: Option<u64>,
}

impl WhisperCliTranscriber {
    /// Creates a transcriber for `command` with no extra arguments.
    ///
    /// It also uses the default output directory and no timeout.
    pub fn new(command: impl Into<PathBuf>) -> Self {
        Self::from_options(WhisperCliTranscriberOptions {
            command: command.into(),
            args: Vec::new(),
            output_dir: None,
            timeout_seconds: None,
        })
    }

    /// Creates a transcriber from a fully specified options value.
    pub fn from_options(options: WhisperCliTranscriberOptions) -> Self {
        let WhisperCliTranscriberOptions {
            command,
            args,
            output_dir,
            timeout_seconds,
        } = options;
        Self {
            command,
            args,
            output_dir,
            timeout_seconds,
        }
    }

    /// Appends extra arguments; they are placed after the input path.
    pub fn args(mut self, args: impl IntoIterator<Item = String>) -> Self {
        self.args.extend(args);
        self
    }

    /// Sets the directory the whisper CLI writes its JSON into.
    pub fn output_dir(self, output_dir: impl Into<PathBuf>) -> Self {
        Self {
            output_dir: Some(output_dir.into()),
            ..self
        }
    }

    /// Sets the per-run timeout in seconds, or clears it with `None`.
    pub fn timeout_seconds(self, timeout_seconds: Option<u64>) -> Self {
        Self {
            timeout_seconds,
            ..self
        }
    }
}
387
388impl Transcriber for WhisperCliTranscriber {
389    fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
390        let output_dir = self.output_dir.clone().unwrap_or_else(|| {
391            input
392                .parent()
393                .unwrap_or_else(|| Path::new("."))
394                .join("transcript")
395        });
396        fs::create_dir_all(&output_dir)?;
397
398        let child = Command::new(&self.command)
399            .arg(input)
400            .args(&self.args)
401            .arg("--output_format")
402            .arg("json")
403            .arg("--output_dir")
404            .arg(&output_dir)
405            .stdin(Stdio::null())
406            .spawn()?;
407        let status = wait_status_with_optional_timeout(child, &self.command, self.timeout_seconds)?;
408        if !status.success() {
409            return Err(TranscriptionError::CommandFailed(
410                self.command.display().to_string(),
411            ));
412        }
413
414        let transcript_path = find_transcript_json(&output_dir).ok_or_else(|| {
415            TranscriptionError::InvalidTranscript(
416                "transcriber completed but no JSON transcript was found".to_string(),
417            )
418        })?;
419        let bytes = fs::read(&transcript_path)?;
420        let mut result = parse_whisper_json(&bytes)?;
421        result.source = Some(transcript_path.to_string_lossy().into_owned());
422        Ok(result)
423    }
424}
425
426/// Data type for whisper cpp transcriber.
427pub struct WhisperCppTranscriber {
428    inner: NativeWhisperCppTranscriber,
429}
430
431impl WhisperCppTranscriber {
432    /// Creates a new value.
433    pub fn new(config: WhisperCppConfig) -> Self {
434        Self {
435            inner: NativeWhisperCppTranscriber::new(config),
436        }
437    }
438
439    /// Returns this value with model store.
440    pub fn with_model_store(mut self, store: WhisperCppModelStore) -> Self {
441        self.inner = self.inner.with_model_store(store);
442        self
443    }
444
445    /// Returns on progress.
446    pub fn on_progress<F>(mut self, callback: F) -> Self
447    where
448        F: FnMut(WhisperCppProgressEvent) + 'static,
449    {
450        self.inner = self.inner.on_progress(callback);
451        self
452    }
453
454    /// Returns transcribe with progress.
455    pub fn transcribe_with_progress(
456        &mut self,
457        input: &Path,
458        progress: &mut dyn FnMut(WhisperCppProgressEvent),
459    ) -> Result<TranscriptionResult> {
460        let transcript = self.inner.transcribe_file_with_progress(input, progress)?;
461        Ok(whisper_cpp_result_to_transcription_result(transcript))
462    }
463}
464
465impl Transcriber for WhisperCppTranscriber {
466    fn transcribe(&mut self, input: &Path) -> Result<TranscriptionResult> {
467        let transcript = self.inner.transcribe_file(input)?;
468        Ok(whisper_cpp_result_to_transcription_result(transcript))
469    }
470}
471
472fn whisper_cpp_result_to_transcription_result(
473    transcript: whisper_cpp::WhisperCppTranscription,
474) -> TranscriptionResult {
475    TranscriptionResult {
476        text: transcript.text,
477        language: transcript.language.clone(),
478        segments: transcript
479            .segments
480            .into_iter()
481            .map(|segment| TranscriptSegment {
482                index: segment.index,
483                start_seconds: segment.start_seconds,
484                end_seconds: segment.end_seconds,
485                text: segment.text,
486                language: transcript.language.clone(),
487                speaker: None,
488                confidence: segment.confidence,
489                is_final: true,
490            })
491            .collect(),
492        source: transcript.source,
493    }
494}
495
496/// Data type for transcript segment source.
497pub struct TranscriptSegmentSource {
498    source_info: MediaSourceInfo,
499    segments: Vec<TranscriptSegment>,
500    next_index: usize,
501}
502
503impl TranscriptSegmentSource {
504    /// Returns recorded.
505    pub fn recorded(input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
506        Self::new(SourceMode::Recorded, input, segments)
507    }
508
509    /// Returns live.
510    pub fn live(input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
511        Self::new(SourceMode::Live, input, segments)
512    }
513
514    fn new(mode: SourceMode, input: impl Into<String>, segments: Vec<TranscriptSegment>) -> Self {
515        let language = segments.iter().find_map(|segment| segment.language.clone());
516        let source_info = MediaSourceInfo {
517            input: input.into(),
518            mode,
519            video: None,
520            audio: Vec::new(),
521            text: vec![TextStreamInfo {
522                format: IngestTextFormat::Transcript,
523                language,
524            }],
525        };
526        Self {
527            source_info,
528            segments,
529            next_index: 0,
530        }
531    }
532}
533
534impl TextSegmentSource for TranscriptSegmentSource {
535    fn source_info(&self) -> &MediaSourceInfo {
536        &self.source_info
537    }
538
539    fn next_text_segment(&mut self) -> video_analysis_core::Result<Option<OwnedTextSegment>> {
540        let Some(segment) = self.segments.get(self.next_index) else {
541            return Ok(None);
542        };
543        self.next_index += 1;
544        Ok(Some(segment_to_owned_text_segment(segment)))
545    }
546}
547
548#[derive(Debug, Deserialize)]
549struct WhisperOutput {
550    text: Option<String>,
551    language: Option<String>,
552    #[serde(default)]
553    segments: Vec<WhisperSegment>,
554}
555
556#[derive(Debug, Deserialize)]
557struct WhisperSegment {
558    id: Option<u64>,
559    start: Option<f64>,
560    end: Option<f64>,
561    text: String,
562    #[serde(default)]
563    avg_logprob: Option<f32>,
564    #[serde(default)]
565    no_speech_prob: Option<f32>,
566}
567
568/// Parses parse whisper JSON.
569pub fn parse_whisper_json(bytes: &[u8]) -> Result<TranscriptionResult> {
570    let parsed: WhisperOutput = serde_json::from_slice(bytes)?;
571    let segments = parsed
572        .segments
573        .into_iter()
574        .enumerate()
575        .map(|(index, segment)| TranscriptSegment {
576            index: segment.id.unwrap_or(index as u64),
577            start_seconds: segment.start,
578            end_seconds: segment.end,
579            text: segment.text.trim().to_string(),
580            language: parsed.language.clone(),
581            speaker: None,
582            confidence: whisper_confidence(segment.avg_logprob, segment.no_speech_prob),
583            is_final: true,
584        })
585        .collect::<Vec<_>>();
586    Ok(TranscriptionResult {
587        text: parsed.text.map(|text| text.trim().to_string()),
588        language: parsed.language,
589        segments,
590        source: None,
591    })
592}
593
594/// Normalizes an existing transcription contract.
595pub fn normalize_transcription_contract(
596    contract: TranscriptionContract,
597) -> Result<TranscriptionContract> {
598    contract.normalized()
599}
600
601/// Builds and normalizes a transcription contract from imported transcript segments.
602pub fn normalize_imported_segments(
603    source: Option<String>,
604    language: Option<String>,
605    segments: Vec<TranscriptSegmentContract>,
606) -> Result<TranscriptionContract> {
607    TranscriptionContract::from_segments(source, language, segments)
608}
609
610/// Parses WhisperX JSON into the shared transcription contract.
611pub fn parse_whisperx_json(bytes: &[u8]) -> Result<TranscriptionContract> {
612    let value: Value = serde_json::from_slice(bytes)?;
613    let object = value.as_object().ok_or_else(|| {
614        TranscriptionError::InvalidTranscript("WhisperX JSON must be an object".to_string())
615    })?;
616    let language = object
617        .get("language")
618        .and_then(Value::as_str)
619        .map(str::to_string);
620    let text = object
621        .get("text")
622        .and_then(Value::as_str)
623        .map(|text| text.trim().to_string())
624        .filter(|text| !text.is_empty());
625    let source = object
626        .get("source")
627        .and_then(Value::as_str)
628        .map(str::to_string);
629    let mut attributes = unknown_attributes(
630        object,
631        &["language", "text", "source", "segments", "word_segments"],
632    );
633    if let Some(count) = object.get("segment_count").and_then(Value::as_u64) {
634        attributes.insert("segment_count".to_string(), count.to_string());
635    }
636
637    let mut segments = object
638        .get("segments")
639        .and_then(Value::as_array)
640        .ok_or_else(|| {
641            TranscriptionError::InvalidTranscript(
642                "WhisperX JSON must include a segments array".to_string(),
643            )
644        })?
645        .iter()
646        .enumerate()
647        .map(|(index, segment)| whisperx_segment(segment, index as u64, language.clone()))
648        .collect::<Result<Vec<_>>>()?;
649
650    if segments.iter().all(|segment| segment.words.is_empty()) {
651        if let Some(words) = object.get("word_segments").and_then(Value::as_array) {
652            attach_flat_whisperx_words(&mut segments, words)?;
653        }
654    }
655
656    let mut contract = TranscriptionContract {
657        text,
658        language,
659        segments,
660        source,
661        attributes,
662    }
663    .normalized()?;
664    for segment in &mut contract.segments {
665        if segment.speaker.is_none() {
666            segment.speaker = infer_segment_speaker(&segment.words);
667        }
668    }
669    contract.validate_strict()?;
670    Ok(contract)
671}
672
673/// Parses parse srt.
674pub fn parse_srt(text: &str) -> Result<TranscriptionResult> {
675    parse_subtitle_blocks(text, TranscriptFormat::Srt)
676}
677
678/// Parses parse webvtt.
679pub fn parse_webvtt(text: &str) -> Result<TranscriptionResult> {
680    parse_subtitle_blocks(text, TranscriptFormat::WebVtt)
681}
682
683/// Parses parse plain lines.
684pub fn parse_plain_lines(text: &str) -> TranscriptionResult {
685    let segments = text
686        .lines()
687        .enumerate()
688        .filter_map(|(index, line)| {
689            let line = line.trim();
690            (!line.is_empty()).then(|| TranscriptSegment {
691                index: index as u64,
692                start_seconds: None,
693                end_seconds: None,
694                text: line.to_string(),
695                language: None,
696                speaker: None,
697                confidence: None,
698                is_final: true,
699            })
700        })
701        .collect::<Vec<_>>();
702    TranscriptionResult {
703        text: Some(
704            segments
705                .iter()
706                .map(|segment| segment.text.as_str())
707                .collect::<Vec<_>>()
708                .join("\n"),
709        ),
710        language: None,
711        segments,
712        source: None,
713    }
714}
715
716/// Returns format srt.
717pub fn format_srt(segments: &[TranscriptSegment]) -> String {
718    let mut output = String::new();
719    for (index, segment) in segments.iter().enumerate() {
720        let start = segment.start_seconds.unwrap_or(0.0);
721        let end = segment
722            .end_seconds
723            .unwrap_or_else(|| (start + 2.0).max(start));
724        output.push_str(&(index + 1).to_string());
725        output.push('\n');
726        output.push_str(&format_srt_timestamp(start));
727        output.push_str(" --> ");
728        output.push_str(&format_srt_timestamp(end.max(start)));
729        output.push('\n');
730        output.push_str(segment.text.trim());
731        output.push_str("\n\n");
732    }
733    output
734}
735
736/// Returns format webvtt.
737pub fn format_webvtt(segments: &[TranscriptSegment]) -> String {
738    let mut output = String::from("WEBVTT\n\n");
739    for (index, segment) in segments.iter().enumerate() {
740        if index > 0 {
741            output.push('\n');
742        }
743        let start = segment.start_seconds.unwrap_or(0.0);
744        let end = segment
745            .end_seconds
746            .unwrap_or_else(|| (start + 2.0).max(start));
747        output.push_str(&format_webvtt_timestamp(start));
748        output.push_str(" --> ");
749        output.push_str(&format_webvtt_timestamp(end.max(start)));
750        output.push('\n');
751        output.push_str(segment.text.trim());
752        output.push('\n');
753    }
754    output
755}
756
757/// Writes srt.
758pub fn write_srt(path: impl AsRef<Path>, segments: &[TranscriptSegment]) -> Result<()> {
759    let path = path.as_ref();
760    if let Some(parent) = path.parent() {
761        fs::create_dir_all(parent)?;
762    }
763    fs::write(path, format_srt(segments))?;
764    Ok(())
765}
766
767/// Returns transcribe waveform batch.
768pub fn transcribe_waveform_batch<T: Transcriber>(
769    transcriber: &mut T,
770    batch: &OwnedAudioWaveformBatch,
771    wav_path: &Path,
772) -> Result<TranscriptionResult> {
773    write_waveform_batch_as_wav(wav_path, batch)?;
774    transcriber.transcribe(wav_path)
775}
776
777fn write_waveform_batch_as_wav(
778    path: impl AsRef<Path>,
779    batch: &OwnedAudioWaveformBatch,
780) -> Result<()> {
781    let view = batch.as_view()?;
782    if view.batch_size() != 1 {
783        return Err(video_analysis_core::DetectError::InvalidArgument(
784            "waveform WAV export requires a batch size of 1".to_string(),
785        )
786        .into());
787    }
788    let path = path.as_ref();
789    if let Some(parent) = path.parent() {
790        fs::create_dir_all(parent)?;
791    }
792    let spec = hound::WavSpec {
793        channels: view.channel_count() as u16,
794        sample_rate: view.sample_rate,
795        bits_per_sample: 32,
796        sample_format: hound::SampleFormat::Float,
797    };
798    let mut writer = hound::WavWriter::create(path, spec).map_err(|err| {
799        video_analysis_core::DetectError::Source(format!(
800            "failed to create WAV `{}`: {err}",
801            path.display()
802        ))
803    })?;
804    for time_index in 0..view.time_steps() {
805        for channel_index in 0..view.channel_count() {
806            let sample = view.waveform(0, channel_index)?[time_index];
807            writer.write_sample(sample).map_err(|err| {
808                video_analysis_core::DetectError::Source(format!(
809                    "failed to write WAV sample `{}`: {err}",
810                    path.display()
811                ))
812            })?;
813        }
814    }
815    writer.finalize().map_err(|err| {
816        video_analysis_core::DetectError::Source(format!(
817            "failed to finalize WAV `{}`: {err}",
818            path.display()
819        ))
820    })?;
821    Ok(())
822}
823
/// Formats a time in seconds as an SRT timestamp `HH:MM:SS,mmm`.
///
/// The value is rounded to the nearest millisecond. Negative values and NaN become zero,
/// because `f64::max` returns the non-NaN operand. Hours do not wrap and widen past two digits.
pub fn format_srt_timestamp(seconds: f64) -> String {
    let total_millis = (seconds.max(0.0) * 1_000.0).round() as u64;
    let (total_secs, millis) = (total_millis / 1_000, total_millis % 1_000);
    let (total_mins, secs) = (total_secs / 60, total_secs % 60);
    let (hours, minutes) = (total_mins / 60, total_mins % 60);
    format!("{hours:02}:{minutes:02}:{secs:02},{millis:03}")
}

/// Formats a time in seconds as a WebVTT timestamp `HH:MM:SS.mmm`.
///
/// Rounding and clamping match [`format_srt_timestamp`]; only the millisecond separator
/// differs. The SRT form contains exactly one comma.
fn format_webvtt_timestamp(seconds: f64) -> String {
    format_srt_timestamp(seconds).replacen(',', ".", 1)
}
839
840/// Returns segment to owned text segment.
841pub fn segment_to_owned_text_segment(segment: &TranscriptSegment) -> OwnedTextSegment {
842    text_core::TextSegmentContract::from(TranscriptSegmentContract::from(segment))
843        .to_owned_text_segment()
844}
845
846/// Parses a transcript file by inferring the format from its extension.
847pub fn parse_transcript_file(path: impl AsRef<Path>) -> Result<TranscriptionResult> {
848    let path = path.as_ref();
849    let extension = path
850        .extension()
851        .and_then(|value| value.to_str())
852        .ok_or_else(|| {
853            TranscriptionError::InvalidTranscript("transcript file missing extension".to_string())
854        })?;
855    let format = TranscriptFormat::from_extension(extension).ok_or_else(|| {
856        TranscriptionError::InvalidTranscript(format!(
857            "unsupported transcript file extension `{extension}`"
858        ))
859    })?;
860    let bytes = fs::read(path)?;
861    let mut parsed = parse_transcript_bytes(&bytes, format)?;
862    parsed.source = Some(path.to_string_lossy().into_owned());
863    Ok(parsed)
864}
865
866/// Parses and normalizes a transcript file into the stable transcript contract.
867pub fn parse_normalized_transcript_file(
868    path: impl AsRef<Path>,
869    options: SubtitleNormalizationOptions,
870) -> Result<TranscriptionContract> {
871    let parsed = parse_transcript_file(path)?;
872    let mut contract = TranscriptionContract::from(parsed);
873    contract.segments = contract
874        .segments
875        .into_iter()
876        .filter_map(|mut segment| {
877            segment.text = normalize_subtitle_text(&segment.text, options.clone());
878            (!segment.text.is_empty()).then_some(segment)
879        })
880        .collect();
881    contract.text = contract
882        .text
883        .as_deref()
884        .map(|text| normalize_subtitle_text(text, options))
885        .filter(|text| !text.is_empty());
886    contract.normalized()
887}
888
889fn parse_transcript_bytes(bytes: &[u8], format: TranscriptFormat) -> Result<TranscriptionResult> {
890    match format {
891        TranscriptFormat::Plain | TranscriptFormat::Lines => {
892            Ok(parse_plain_lines(&String::from_utf8_lossy(bytes)))
893        }
894        TranscriptFormat::WhisperJson => parse_whisper_json(bytes),
895        TranscriptFormat::Srt => parse_srt(&String::from_utf8_lossy(bytes)),
896        TranscriptFormat::WebVtt => parse_webvtt(&String::from_utf8_lossy(bytes)),
897    }
898}
899
900fn whisperx_segment(
901    value: &Value,
902    fallback_index: u64,
903    language: Option<String>,
904) -> Result<TranscriptSegmentContract> {
905    let object = value.as_object().ok_or_else(|| {
906        TranscriptionError::InvalidTranscript("WhisperX segment must be an object".to_string())
907    })?;
908    let mut segment = TranscriptSegmentContract::new(
909        object
910            .get("id")
911            .or_else(|| object.get("index"))
912            .and_then(Value::as_u64)
913            .unwrap_or(fallback_index),
914        object
915            .get("text")
916            .and_then(Value::as_str)
917            .unwrap_or_default(),
918    );
919    segment.start_seconds = number_field(object, &["start", "start_seconds", "startSeconds"]);
920    segment.end_seconds = number_field(object, &["end", "end_seconds", "endSeconds"]);
921    segment.language = object
922        .get("language")
923        .and_then(Value::as_str)
924        .map(str::to_string)
925        .or(language);
926    segment.speaker = object
927        .get("speaker")
928        .or_else(|| object.get("speaker_label"))
929        .or_else(|| object.get("speakerLabel"))
930        .and_then(Value::as_str)
931        .map(str::to_string);
932    segment.confidence = confidence_field(
933        object,
934        &["confidence", "score", "avg_logprob", "no_speech_prob"],
935    );
936    segment.words = object
937        .get("words")
938        .or_else(|| object.get("word_segments"))
939        .and_then(Value::as_array)
940        .map(|words| {
941            words
942                .iter()
943                .map(whisperx_word)
944                .collect::<Result<Vec<TranscriptWordContract>>>()
945        })
946        .transpose()?
947        .unwrap_or_default();
948    segment.chars = object
949        .get("chars")
950        .or_else(|| object.get("characters"))
951        .and_then(Value::as_array)
952        .map(|chars| {
953            chars
954                .iter()
955                .map(whisperx_char)
956                .collect::<Result<Vec<TranscriptCharContract>>>()
957        })
958        .transpose()?
959        .unwrap_or_default();
960    if segment.speaker.is_none() {
961        segment.speaker = infer_segment_speaker(&segment.words);
962    }
963    segment.attributes = unknown_attributes(
964        object,
965        &[
966            "id",
967            "index",
968            "start",
969            "start_seconds",
970            "startSeconds",
971            "end",
972            "end_seconds",
973            "endSeconds",
974            "text",
975            "language",
976            "speaker",
977            "speaker_label",
978            "speakerLabel",
979            "confidence",
980            "score",
981            "avg_logprob",
982            "no_speech_prob",
983            "words",
984            "word_segments",
985            "chars",
986            "characters",
987        ],
988    );
989    Ok(segment)
990}
991
992fn whisperx_word(value: &Value) -> Result<TranscriptWordContract> {
993    let object = value.as_object().ok_or_else(|| {
994        TranscriptionError::InvalidTranscript("WhisperX word must be an object".to_string())
995    })?;
996    Ok(TranscriptWordContract {
997        text: object
998            .get("word")
999            .or_else(|| object.get("text"))
1000            .and_then(Value::as_str)
1001            .unwrap_or_default()
1002            .to_string(),
1003        start_seconds: number_field(object, &["start", "start_seconds", "startSeconds"]),
1004        end_seconds: number_field(object, &["end", "end_seconds", "endSeconds"]),
1005        confidence: confidence_field(object, &["confidence", "score", "probability"]),
1006        speaker: object
1007            .get("speaker")
1008            .or_else(|| object.get("speaker_label"))
1009            .or_else(|| object.get("speakerLabel"))
1010            .and_then(Value::as_str)
1011            .map(str::to_string),
1012        attributes: unknown_attributes(
1013            object,
1014            &[
1015                "word",
1016                "text",
1017                "start",
1018                "start_seconds",
1019                "startSeconds",
1020                "end",
1021                "end_seconds",
1022                "endSeconds",
1023                "confidence",
1024                "score",
1025                "probability",
1026                "speaker",
1027                "speaker_label",
1028                "speakerLabel",
1029            ],
1030        ),
1031    })
1032}
1033
1034fn whisperx_char(value: &Value) -> Result<TranscriptCharContract> {
1035    let object = value.as_object().ok_or_else(|| {
1036        TranscriptionError::InvalidTranscript("WhisperX char must be an object".to_string())
1037    })?;
1038    Ok(TranscriptCharContract {
1039        character: object
1040            .get("char")
1041            .or_else(|| object.get("character"))
1042            .or_else(|| object.get("text"))
1043            .and_then(Value::as_str)
1044            .unwrap_or_default()
1045            .to_string(),
1046        start_seconds: number_field(object, &["start", "start_seconds", "startSeconds"]),
1047        end_seconds: number_field(object, &["end", "end_seconds", "endSeconds"]),
1048        confidence: confidence_field(object, &["confidence", "score", "probability"]),
1049        attributes: unknown_attributes(
1050            object,
1051            &[
1052                "char",
1053                "character",
1054                "text",
1055                "start",
1056                "start_seconds",
1057                "startSeconds",
1058                "end",
1059                "end_seconds",
1060                "endSeconds",
1061                "confidence",
1062                "score",
1063                "probability",
1064            ],
1065        ),
1066    })
1067}
1068
1069fn attach_flat_whisperx_words(
1070    segments: &mut [TranscriptSegmentContract],
1071    words: &[Value],
1072) -> Result<()> {
1073    for value in words {
1074        let word = whisperx_word(value)?;
1075        let midpoint = match (word.start_seconds, word.end_seconds) {
1076            (Some(start), Some(end)) => Some((start + end) * 0.5),
1077            (Some(start), None) => Some(start),
1078            (None, Some(end)) => Some(end),
1079            (None, None) => None,
1080        };
1081        let Some(segment) = segments.iter_mut().find(|segment| {
1082            midpoint
1083                .zip(segment.start_seconds.zip(segment.end_seconds))
1084                .map(|(midpoint, (start, end))| midpoint >= start && midpoint <= end)
1085                .unwrap_or(false)
1086        }) else {
1087            continue;
1088        };
1089        segment.words.push(word);
1090    }
1091    for segment in segments {
1092        if segment.speaker.is_none() {
1093            segment.speaker = infer_segment_speaker(&segment.words);
1094        }
1095    }
1096    Ok(())
1097}
1098
1099fn infer_segment_speaker(words: &[TranscriptWordContract]) -> Option<String> {
1100    let mut scores: BTreeMap<&str, (f64, usize)> = BTreeMap::new();
1101    for word in words {
1102        let Some(speaker) = word
1103            .speaker
1104            .as_deref()
1105            .filter(|speaker| !speaker.is_empty())
1106        else {
1107            continue;
1108        };
1109        let duration = word
1110            .start_seconds
1111            .zip(word.end_seconds)
1112            .map(|(start, end)| (end - start).max(0.0))
1113            .filter(|duration| duration.is_finite() && *duration > 0.0)
1114            .unwrap_or(1.0);
1115        let entry = scores.entry(speaker).or_insert((0.0, 0));
1116        entry.0 += duration;
1117        entry.1 += 1;
1118    }
1119    scores
1120        .into_iter()
1121        .max_by(|left, right| {
1122            left.1
1123                 .0
1124                .total_cmp(&right.1 .0)
1125                .then(left.1 .1.cmp(&right.1 .1))
1126                .then_with(|| right.0.cmp(left.0))
1127        })
1128        .map(|(speaker, _)| speaker.to_string())
1129}
1130
1131fn unknown_attributes(
1132    object: &serde_json::Map<String, Value>,
1133    known_fields: &[&str],
1134) -> BTreeMap<String, String> {
1135    object
1136        .iter()
1137        .filter(|(key, _)| !known_fields.contains(&key.as_str()))
1138        .map(|(key, value)| (key.clone(), json_attribute(value)))
1139        .collect()
1140}
1141
1142fn json_attribute(value: &Value) -> String {
1143    match value {
1144        Value::String(value) => value.clone(),
1145        Value::Number(value) => value.to_string(),
1146        Value::Bool(value) => value.to_string(),
1147        Value::Null => "null".to_string(),
1148        other => other.to_string(),
1149    }
1150}
1151
1152fn number_field(object: &serde_json::Map<String, Value>, names: &[&str]) -> Option<f64> {
1153    names
1154        .iter()
1155        .find_map(|name| object.get(*name).and_then(Value::as_f64))
1156        .filter(|value| value.is_finite())
1157}
1158
1159fn confidence_field(object: &serde_json::Map<String, Value>, names: &[&str]) -> Option<f32> {
1160    for name in names {
1161        let Some(value) = object.get(*name).and_then(Value::as_f64) else {
1162            continue;
1163        };
1164        if !value.is_finite() {
1165            continue;
1166        }
1167        return match *name {
1168            "avg_logprob" => Some(value.exp().clamp(0.0, 1.0) as f32),
1169            "no_speech_prob" => Some((1.0 - value).clamp(0.0, 1.0) as f32),
1170            _ => Some(value.clamp(0.0, 1.0) as f32),
1171        };
1172    }
1173    None
1174}
1175
1176fn wait_with_optional_timeout(
1177    mut child: Child,
1178    command: &Path,
1179    timeout_seconds: Option<u64>,
1180) -> Result<Output> {
1181    if let Some(seconds) = timeout_seconds {
1182        let deadline = std::time::Instant::now() + Duration::from_secs(seconds);
1183        loop {
1184            if child.try_wait()?.is_some() {
1185                return Ok(child.wait_with_output()?);
1186            }
1187            if std::time::Instant::now() >= deadline {
1188                let _ = child.kill();
1189                let _ = child.wait();
1190                return Err(TranscriptionError::CommandTimeout {
1191                    command: command.display().to_string(),
1192                    seconds,
1193                });
1194            }
1195            std::thread::sleep(Duration::from_millis(25));
1196        }
1197    }
1198    Ok(child.wait_with_output()?)
1199}
1200
1201fn wait_status_with_optional_timeout(
1202    mut child: Child,
1203    command: &Path,
1204    timeout_seconds: Option<u64>,
1205) -> Result<ExitStatus> {
1206    if let Some(seconds) = timeout_seconds {
1207        let deadline = std::time::Instant::now() + Duration::from_secs(seconds);
1208        loop {
1209            if let Some(status) = child.try_wait()? {
1210                return Ok(status);
1211            }
1212            if std::time::Instant::now() >= deadline {
1213                let _ = child.kill();
1214                let _ = child.wait();
1215                return Err(TranscriptionError::CommandTimeout {
1216                    command: command.display().to_string(),
1217                    seconds,
1218                });
1219            }
1220            std::thread::sleep(Duration::from_millis(25));
1221        }
1222    }
1223    Ok(child.wait()?)
1224}
1225
1226fn strip_subtitle_markup(text: &str) -> String {
1227    let mut output = String::with_capacity(text.len());
1228    let mut chars = text.chars().peekable();
1229    while let Some(ch) = chars.next() {
1230        if ch != '<' {
1231            output.push(ch);
1232            continue;
1233        }
1234
1235        let mut tag = String::new();
1236        let mut closed = false;
1237        for tag_ch in chars.by_ref() {
1238            if tag_ch == '>' {
1239                closed = true;
1240                break;
1241            }
1242            tag.push(tag_ch);
1243        }
1244
1245        if !closed {
1246            output.push('<');
1247            output.push_str(&tag);
1248            break;
1249        }
1250
1251        let tag = tag.trim();
1252        if tag.is_empty() {
1253            continue;
1254        }
1255
1256        if is_subtitle_timestamp(tag) {
1257            output.push(' ');
1258        } else if is_subtitle_tag(tag) {
1259            continue;
1260        } else {
1261            output.push('<');
1262            output.push_str(tag);
1263            output.push('>');
1264        }
1265    }
1266    output
1267}
1268
/// Returns `true` when `tag` (the trimmed text between `<` and `>`) is a subtitle
/// formatting tag that should be stripped.
///
/// Recognised tags are `b`, `c`, `i`, `lang`, `rt`, `ruby`, `u` and `v`. Each may be a
/// closing tag (`/i`) and may carry `.class` suffixes (`c.yellow`).
///
/// Only `v` (voice) and `lang` may carry a whitespace-separated annotation
/// (`v Speaker`, `lang en`), following the WebVTT cue-span syntax. This keeps literal
/// text such as `a < b and c > d` from being mistaken for a `<b ...>` tag and deleted.
fn is_subtitle_tag(tag: &str) -> bool {
    let tag = tag.trim_start_matches('/');
    let (head, annotation) = match tag.split_once(char::is_whitespace) {
        Some((head, rest)) => (head, rest.trim()),
        None => (tag, ""),
    };
    let name = head.split('.').next().unwrap_or_default();
    let known = matches!(name, "b" | "c" | "i" | "lang" | "rt" | "ruby" | "u" | "v");
    known && (annotation.is_empty() || matches!(name, "v" | "lang"))
}
1277
1278fn is_subtitle_timestamp(value: &str) -> bool {
1279    let value = value.replace(',', ".");
1280    let parts = value.split(':').collect::<Vec<_>>();
1281    let [hours, minutes, seconds] = parts.as_slice() else {
1282        return false;
1283    };
1284    is_two_digits(hours)
1285        && is_two_digits(minutes)
1286        && seconds.len() == 6
1287        && seconds.as_bytes().get(2) == Some(&b'.')
1288        && seconds[..2].bytes().all(|byte| byte.is_ascii_digit())
1289        && seconds[3..].bytes().all(|byte| byte.is_ascii_digit())
1290}
1291
/// Returns `true` when `value` consists of exactly two ASCII digits (e.g. `"07"`).
fn is_two_digits(value: &str) -> bool {
    matches!(
        value.as_bytes(),
        [first, second] if first.is_ascii_digit() && second.is_ascii_digit()
    )
}
1295
/// Decodes the HTML character references that commonly appear in subtitle text:
/// `&amp;`, `&lt;`, `&gt;`, `&quot;`, `&#39;`, `&apos;` and `&nbsp;` (decoded to a plain
/// space).
///
/// Decoding is a single left-to-right pass, so the output of one replacement is never
/// scanned again. For example, `&amp;lt;` decodes to the literal text `&lt;`, not to `<`.
/// Any other `&` sequence is left untouched.
fn decode_basic_entities(text: &str) -> String {
    const ENTITIES: [(&str, &str); 7] = [
        ("&amp;", "&"),
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&quot;", "\""),
        ("&#39;", "'"),
        ("&apos;", "'"),
        ("&nbsp;", " "),
    ];
    let mut decoded = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        decoded.push_str(&rest[..amp]);
        let candidate = &rest[amp..];
        match ENTITIES
            .iter()
            .find(|(entity, _)| candidate.starts_with(*entity))
        {
            Some((entity, replacement)) => {
                decoded.push_str(replacement);
                rest = &candidate[entity.len()..];
            }
            None => {
                // Not a recognised entity: keep the `&` and continue after it.
                decoded.push('&');
                rest = &candidate[1..];
            }
        }
    }
    decoded.push_str(rest);
    decoded
}
1305
/// Replaces every run of Unicode whitespace with a single ASCII space and trims both ends.
fn collapse_whitespace(text: &str) -> String {
    let mut collapsed = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed
}
1309
1310fn parse_subtitle_blocks(text: &str, format: TranscriptFormat) -> Result<TranscriptionResult> {
1311    let normalized = text.replace("\r\n", "\n").replace('\r', "\n");
1312    let mut segments = Vec::new();
1313
1314    for block in normalized.split("\n\n") {
1315        let lines = block
1316            .lines()
1317            .map(str::trim)
1318            .filter(|line| !line.is_empty())
1319            .collect::<Vec<_>>();
1320        if lines.is_empty()
1321            || (format == TranscriptFormat::WebVtt && lines[0].starts_with("WEBVTT"))
1322        {
1323            continue;
1324        }
1325
1326        let time_line_index = lines
1327            .iter()
1328            .position(|line| line.contains("-->"))
1329            .ok_or_else(|| {
1330                TranscriptionError::InvalidTranscript(
1331                    "subtitle block missing timestamp".to_string(),
1332                )
1333            })?;
1334        let (start_seconds, end_seconds) = parse_timestamp_range(lines[time_line_index])?;
1335        let text_lines = &lines[time_line_index + 1..];
1336        if text_lines.is_empty() {
1337            continue;
1338        }
1339        let index = if time_line_index > 0 {
1340            lines[0].parse::<u64>().unwrap_or(segments.len() as u64)
1341        } else {
1342            segments.len() as u64
1343        };
1344        segments.push(TranscriptSegment {
1345            index,
1346            start_seconds: Some(start_seconds),
1347            end_seconds: Some(end_seconds),
1348            text: text_lines.join(" "),
1349            language: None,
1350            speaker: None,
1351            confidence: None,
1352            is_final: true,
1353        });
1354    }
1355
1356    Ok(TranscriptionResult {
1357        text: Some(
1358            segments
1359                .iter()
1360                .map(|segment| segment.text.as_str())
1361                .collect::<Vec<_>>()
1362                .join("\n"),
1363        ),
1364        language: None,
1365        segments,
1366        source: None,
1367    })
1368}
1369
1370fn parse_timestamp_range(line: &str) -> Result<(f64, f64)> {
1371    let Some((start, end_with_settings)) = line.split_once("-->") else {
1372        return Err(TranscriptionError::InvalidTranscript(
1373            "timestamp range missing -->".to_string(),
1374        ));
1375    };
1376    let end = end_with_settings
1377        .split_whitespace()
1378        .next()
1379        .unwrap_or(end_with_settings);
1380    Ok((parse_timestamp(start.trim())?, parse_timestamp(end.trim())?))
1381}
1382
1383fn parse_timestamp(value: &str) -> Result<f64> {
1384    let value = value.replace(',', ".");
1385    let pieces = value.split(':').collect::<Vec<_>>();
1386    let seconds = match pieces.as_slice() {
1387        [minutes, seconds] => {
1388            Some(parse_timestamp_component(minutes)? * 60.0 + parse_timestamp_component(seconds)?)
1389        }
1390        [hours, minutes, seconds] => Some(
1391            parse_timestamp_component(hours)? * 3600.0
1392                + parse_timestamp_component(minutes)? * 60.0
1393                + parse_timestamp_component(seconds)?,
1394        ),
1395        _ => None,
1396    };
1397    seconds.ok_or_else(|| {
1398        TranscriptionError::InvalidTranscript(format!("invalid timestamp `{value}`"))
1399    })
1400}
1401
1402fn parse_timestamp_component(value: &str) -> Result<f64> {
1403    value.parse::<f64>().map_err(|_| {
1404        TranscriptionError::InvalidTranscript(format!("invalid timestamp component `{value}`"))
1405    })
1406}
1407
/// Derives a `[0, 1]` confidence for a Whisper JSON segment.
///
/// `avg_logprob` is a log-probability, so it is mapped through `exp`. When it is absent
/// or not finite, the complement of `no_speech_prob` is used instead. Both results are
/// clamped to `[0, 1]`, the same normalisation `confidence_field` applies to these fields
/// for WhisperX input.
fn whisper_confidence(avg_logprob: Option<f32>, no_speech_prob: Option<f32>) -> Option<f32> {
    avg_logprob
        .filter(|logprob| logprob.is_finite())
        .map(|logprob| logprob.exp().clamp(0.0, 1.0))
        .or_else(|| {
            no_speech_prob
                .filter(|probability| probability.is_finite())
                .map(|probability| (1.0 - probability).clamp(0.0, 1.0))
        })
}
1411
/// Inserts `key -> value` into `metadata` when `value` is present. `None` leaves the map
/// untouched.
fn insert_optional(metadata: &mut BTreeMap<String, String>, key: &str, value: Option<&str>) {
    let Some(value) = value else {
        return;
    };
    metadata.insert(key.to_owned(), value.to_owned());
}
1417
1418fn insert_optional_number(metadata: &mut BTreeMap<String, String>, key: &str, value: Option<f64>) {
1419    insert_optional_display(metadata, key, value);
1420}
1421
/// Inserts `value.to_string()` under `key` when `value` is `Some`. `None` leaves the map
/// untouched.
fn insert_optional_display<T: fmt::Display>(
    metadata: &mut BTreeMap<String, String>,
    key: &str,
    value: Option<T>,
) {
    if let Some(rendered) = value.map(|value| value.to_string()) {
        metadata.insert(key.to_owned(), rendered);
    }
}
1431
1432fn has_token_kind(text: &str, kind: TokenKind) -> bool {
1433    tokenize(text, &TextProcessingOptions::default())
1434        .into_iter()
1435        .any(|token| token.kind == kind)
1436}
1437
1438fn event_at(analyzer: &str, label: &str, timestamp: Option<Timestamp>) -> AnalysisEvent {
1439    let event = AnalysisEvent::new(analyzer, label);
1440    if let Some(timestamp) = timestamp {
1441        event.at_timestamp(timestamp)
1442    } else {
1443        event
1444    }
1445}
1446
/// Returns the most recently modified `*.json` file directly inside `output_dir`.
///
/// Candidates must:
/// * have a `json` extension (case-sensitive);
/// * resolve to a regular file, following symlinks. A directory named `something.json` or
///   an entry whose metadata cannot be read is never returned.
///
/// Candidates are ranked by modification time, with the path as tie-breaker; files whose
/// modification time is unavailable rank lowest. Each file's metadata is read once.
/// Returns `None` when the directory cannot be read or has no candidate.
fn find_transcript_json(output_dir: &Path) -> Option<PathBuf> {
    fs::read_dir(output_dir)
        .ok()?
        .filter_map(|entry| entry.ok().map(|entry| entry.path()))
        .filter(|path| path.extension().and_then(|value| value.to_str()) == Some("json"))
        .filter_map(|path| {
            let metadata = fs::metadata(&path).ok().filter(fs::Metadata::is_file)?;
            Some((metadata.modified().ok(), path))
        })
        .max()
        .map(|(_, path)| path)
}
1466
1467#[cfg(test)]
1468mod tests {
1469    use super::*;
1470    use tempfile::tempdir;
1471    use video_analysis_core::{AudioBuffer, OwnedAudioFrame, Timebase, Timestamp};
1472    use video_analysis_ingest::TextSegmentSource;
1473
1474    #[test]
1475    fn parses_whisper_json() {
1476        let parsed = parse_whisper_json(
1477            br#"{"text":"hello world","language":"en","segments":[{"id":7,"start":0.0,"end":1.5,"text":" hello"}]}"#,
1478        )
1479        .unwrap();
1480
1481        assert_eq!(parsed.text.as_deref(), Some("hello world"));
1482        assert_eq!(parsed.language.as_deref(), Some("en"));
1483        assert_eq!(parsed.segments[0].index, 7);
1484        assert_eq!(parsed.segments[0].start_seconds, Some(0.0));
1485    }
1486
1487    #[test]
1488    fn parses_whisperx_segment_chars_and_preserves_unknown_fields() {
1489        let parsed = parse_whisperx_json(
1490            br#"{
1491                "text": "hi",
1492                "language": "en",
1493                "segments": [{
1494                    "id": 0,
1495                    "start": 1.0,
1496                    "end": 2.0,
1497                    "text": "hi",
1498                    "chars": [
1499                        {"char": "h", "start": 1.1, "end": 1.2, "score": 0.9, "extra": "kept"},
1500                        {"char": "i", "start": 1.2, "end": 1.3, "score": 0.8}
1501                    ]
1502                }]
1503            }"#,
1504        )
1505        .unwrap();
1506
1507        let chars = &parsed.segments[0].chars;
1508        assert_eq!(chars.len(), 2);
1509        assert_eq!(chars[0].character, "h");
1510        assert_eq!(chars[0].start_seconds, Some(1.1));
1511        assert_eq!(chars[0].end_seconds, Some(1.2));
1512        assert_eq!(chars[0].confidence, Some(0.9));
1513        assert_eq!(
1514            chars[0].attributes.get("extra").map(String::as_str),
1515            Some("kept")
1516        );
1517    }
1518
1519    #[test]
1520    fn normalization_keeps_chars_and_strict_validation_checks_bounds() {
1521        let mut segment = TranscriptSegmentContract::new(0, " hello ");
1522        segment.start_seconds = Some(0.0);
1523        segment.end_seconds = Some(1.0);
1524        segment.chars.push(TranscriptCharContract {
1525            character: "h".to_string(),
1526            start_seconds: Some(0.1),
1527            end_seconds: Some(0.2),
1528            confidence: Some(0.5),
1529            attributes: BTreeMap::new(),
1530        });
1531
1532        let normalized = TranscriptionContract::new(vec![segment])
1533            .normalized()
1534            .unwrap();
1535        assert_eq!(normalized.segments[0].text, "hello");
1536        assert_eq!(normalized.segments[0].chars.len(), 1);
1537        normalized.validate_strict().unwrap();
1538
1539        let mut invalid = normalized.clone();
1540        invalid.segments[0].chars[0].end_seconds = Some(1.1);
1541        let error = invalid.validate_strict().unwrap_err().to_string();
1542        assert!(error.contains("transcript char end_seconds"));
1543    }
1544
1545    #[test]
1546    fn parses_srt() {
1547        let parsed = parse_srt("1\n00:00:01,000 --> 00:00:02,500\nHello\n\n").unwrap();
1548        assert_eq!(parsed.segments.len(), 1);
1549        assert_eq!(parsed.segments[0].start_seconds, Some(1.0));
1550        assert_eq!(parsed.segments[0].end_seconds, Some(2.5));
1551    }
1552
1553    #[test]
1554    fn formats_srt() {
1555        let text = format_srt(&[
1556            TranscriptSegment {
1557                index: 0,
1558                start_seconds: Some(1.25),
1559                end_seconds: Some(3.5),
1560                text: "Hello".to_string(),
1561                language: None,
1562                speaker: None,
1563                confidence: None,
1564                is_final: true,
1565            },
1566            TranscriptSegment {
1567                index: 1,
1568                start_seconds: Some(63.0),
1569                end_seconds: Some(65.125),
1570                text: "World".to_string(),
1571                language: None,
1572                speaker: None,
1573                confidence: None,
1574                is_final: true,
1575            },
1576        ]);
1577
1578        assert_eq!(
1579            text,
1580            "1\n00:00:01,250 --> 00:00:03,500\nHello\n\n2\n00:01:03,000 --> 00:01:05,125\nWorld\n\n"
1581        );
1582    }
1583
1584    #[test]
1585    fn formats_webvtt() {
1586        let text = format_webvtt(&[TranscriptSegment {
1587            index: 0,
1588            start_seconds: Some(1.0),
1589            end_seconds: Some(0.5),
1590            text: "Hello.".to_string(),
1591            language: None,
1592            speaker: None,
1593            confidence: None,
1594            is_final: true,
1595        }]);
1596
1597        assert_eq!(text, "WEBVTT\n\n00:00:01.000 --> 00:00:01.000\nHello.\n");
1598    }
1599
1600    #[test]
1601    fn parses_webvtt() {
1602        let parsed = parse_webvtt("WEBVTT\n\ncue\n00:00:03.000 --> 00:00:04.250\nHi\n").unwrap();
1603        assert_eq!(parsed.segments.len(), 1);
1604        assert_eq!(parsed.segments[0].text, "Hi");
1605        assert_eq!(parsed.segments[0].start_seconds, Some(3.0));
1606    }
1607
1608    #[test]
1609    fn parses_plain_lines() {
1610        let parsed = parse_plain_lines("one\n\ntwo\n");
1611        assert_eq!(parsed.segments.len(), 2);
1612        assert_eq!(parsed.text.as_deref(), Some("one\ntwo"));
1613    }
1614
1615    #[test]
1616    fn converts_segment_timestamp() {
1617        let segment = TranscriptSegment {
1618            index: 2,
1619            start_seconds: Some(1.25),
1620            end_seconds: Some(2.0),
1621            text: "hello".to_string(),
1622            language: Some("en".to_string()),
1623            speaker: None,
1624            confidence: None,
1625            is_final: true,
1626        };
1627        let owned = segment_to_owned_text_segment(&segment);
1628        assert_eq!(owned.segment_index, 2);
1629        assert_eq!(owned.timestamp.unwrap().seconds(), 1.25);
1630        assert_eq!(owned.language.as_deref(), Some("en"));
1631    }
1632
1633    #[test]
1634    fn transcript_segment_metadata_preserves_optional_fields() {
1635        let segment = TranscriptSegment {
1636            index: 2,
1637            start_seconds: Some(1.25),
1638            end_seconds: Some(2.0),
1639            text: "hello".to_string(),
1640            language: Some("en".to_string()),
1641            speaker: Some("speaker-1".to_string()),
1642            confidence: Some(0.75),
1643            is_final: true,
1644        };
1645
1646        let metadata = segment.metadata_with_source("fixture.srt");
1647
1648        assert_eq!(metadata["language"], "en");
1649        assert_eq!(metadata["speaker"], "speaker-1");
1650        assert_eq!(metadata["start_seconds"], "1.25");
1651        assert_eq!(metadata["end_seconds"], "2");
1652        assert_eq!(metadata["confidence"], "0.75");
1653        assert_eq!(metadata["source"], "fixture.srt");
1654    }
1655
1656    #[test]
1657    fn transcript_segment_source_iterates() {
1658        let mut source = TranscriptSegmentSource::recorded(
1659            "test",
1660            vec![TranscriptSegment {
1661                index: 0,
1662                start_seconds: None,
1663                end_seconds: None,
1664                text: "hello".to_string(),
1665                language: None,
1666                speaker: None,
1667                confidence: None,
1668                is_final: true,
1669            }],
1670        );
1671        assert_eq!(source.next_text_segment().unwrap().unwrap().text, "hello");
1672        assert!(source.next_text_segment().unwrap().is_none());
1673    }
1674
1675    #[test]
1676    fn transcript_heuristic_analyzer_emits_speech_events() {
1677        let segment = TextSegment {
1678            segment_index: 0,
1679            timestamp: None,
1680            text: "Visit https://example.com at 3?",
1681            language: None,
1682            is_final: true,
1683        };
1684        let mut analyzer = TranscriptHeuristicAnalyzer;
1685
1686        let labels = analyzer
1687            .process_segment(&segment)
1688            .unwrap()
1689            .into_iter()
1690            .map(|event| event.label)
1691            .collect::<Vec<_>>();
1692
1693        assert!(labels.iter().any(|label| label == "speech:question"));
1694        assert!(labels.iter().any(|label| label == "speech:url"));
1695        assert!(labels.iter().any(|label| label == "speech:number"));
1696    }
1697
1698    #[test]
1699    fn command_transcriber_reports_failure() {
1700        let mut transcriber = CommandTranscriber::new("false", TranscriptFormat::Plain);
1701        let err = transcriber.transcribe(Path::new("missing")).unwrap_err();
1702        assert!(matches!(err, TranscriptionError::CommandFailed(_)));
1703    }
1704
1705    #[test]
1706    fn transcribes_waveform_batches_via_existing_command_transcriber() {
1707        let dir = tempdir().unwrap();
1708        let script_path = dir.path().join("transcriber.sh");
1709        fs::write(&script_path, "#!/bin/sh\nprintf 'hello from batch\\n'\n").unwrap();
1710        #[cfg(unix)]
1711        {
1712            use std::os::unix::fs::PermissionsExt;
1713
1714            let mut permissions = fs::metadata(&script_path).unwrap().permissions();
1715            permissions.set_mode(0o755);
1716            fs::set_permissions(&script_path, permissions).unwrap();
1717        }
1718
1719        let mut transcriber = CommandTranscriber::new(&script_path, TranscriptFormat::Plain);
1720        let frame = OwnedAudioFrame::new(
1721            Timestamp::new(0, Timebase::new(1, 16_000)),
1722            16_000,
1723            1,
1724            AudioBuffer::F32(vec![0.0, 0.25, -0.25, 0.5]),
1725        )
1726        .unwrap();
1727        let batch = OwnedAudioWaveformBatch::from_audio_frames(&[frame]).unwrap();
1728        let wav_path = dir.path().join("input.wav");
1729
1730        let result = transcribe_waveform_batch(&mut transcriber, &batch, &wav_path).unwrap();
1731        assert_eq!(result.text.as_deref(), Some("hello from batch"));
1732        assert!(wav_path.is_file());
1733    }
1734
1735    #[test]
1736    fn srt_webvtt_plain_round_trip() {
1737        let srt = "1\n00:00:00,000 --> 00:00:01,000\nHello world\n";
1738        let parsed = parse_srt(srt).unwrap();
1739        let formatted = format_srt(&parsed.segments);
1740        assert!(formatted.contains("Hello world"));
1741
1742        let webvtt =
1743            parse_webvtt("WEBVTT\n\n00:00:00.000 --> 00:00:01.000\nHello world\n").unwrap();
1744        let plain = parse_plain_lines("Hello world\n");
1745
1746        assert_eq!(parsed.segments[0].text, webvtt.segments[0].text);
1747        assert_eq!(plain.text.as_deref(), Some("Hello world"));
1748    }
1749
1750    #[test]
1751    fn normalizes_subtitle_markup_entities_and_whitespace() {
1752        let normalized = normalize_subtitle_text(
1753            "<v Speaker>Hello&nbsp; <c.yellow>Rust</c> &amp; friends <00:00:01.000>\nnow",
1754            SubtitleNormalizationOptions::default(),
1755        );
1756
1757        assert_eq!(normalized, "Hello Rust & friends now");
1758    }
1759
1760    #[test]
1761    fn infers_transcript_format_from_extension() {
1762        assert_eq!(
1763            TranscriptFormat::from_extension(".vtt"),
1764            Some(TranscriptFormat::WebVtt)
1765        );
1766        assert_eq!(
1767            TranscriptFormat::from_extension("SRT"),
1768            Some(TranscriptFormat::Srt)
1769        );
1770        assert_eq!(
1771            TranscriptFormat::from_extension("json"),
1772            Some(TranscriptFormat::WhisperJson)
1773        );
1774        assert_eq!(TranscriptFormat::from_extension("csv"), None);
1775    }
1776
1777    #[test]
1778    fn parses_and_normalizes_transcript_file() {
1779        let dir = tempfile::tempdir().unwrap();
1780        let path = dir.path().join("sample.vtt");
1781        fs::write(
1782            &path,
1783            "WEBVTT\n\n00:00:00.000 --> 00:00:01.000\n<c>Hello&nbsp;file</c>\n",
1784        )
1785        .unwrap();
1786
1787        let parsed =
1788            parse_normalized_transcript_file(&path, SubtitleNormalizationOptions::default())
1789                .unwrap();
1790
1791        assert_eq!(parsed.source.as_deref(), Some(path.to_str().unwrap()));
1792        assert_eq!(parsed.text.as_deref(), Some("Hello file"));
1793        assert_eq!(parsed.segments[0].text, "Hello file");
1794    }
1795
1796    #[test]
1797    fn builds_text_segment_contract_with_source() {
1798        let mut segment = TranscriptSegmentContract::new(7, "hello source");
1799        segment.start_seconds = Some(1.25);
1800        segment.end_seconds = Some(2.5);
1801        segment.language = Some("en".to_string());
1802
1803        let contract = text_segment_contract_with_source(
1804            &segment,
1805            "stream-1",
1806            "caption_manual",
1807            "https://example.test/video",
1808        );
1809
1810        assert_eq!(contract.stream_id.as_deref(), Some("stream-1"));
1811        assert_eq!(contract.segment_index, 7);
1812        assert_eq!(contract.language.as_deref(), Some("en"));
1813        assert_eq!(contract.duration_seconds, Some(1.25));
1814        let source = contract.source.unwrap();
1815        assert_eq!(source.source_id.as_deref(), Some("stream-1"));
1816        assert_eq!(source.source_kind.as_deref(), Some("caption_manual"));
1817        assert_eq!(source.uri.as_deref(), Some("https://example.test/video"));
1818        assert_eq!(source.duration_seconds, Some(1.25));
1819    }
1820
1821    #[test]
1822    fn preserves_unknown_markup_conservatively() {
1823        let normalized = normalize_subtitle_text(
1824            "look <custom value>here</custom>",
1825            SubtitleNormalizationOptions::default(),
1826        );
1827
1828        assert_eq!(normalized, "look <custom value>here</custom>");
1829    }
1830
1831    #[test]
1832    fn decodes_basic_entities() {
1833        let normalized = normalize_subtitle_text(
1834            "&amp; &lt; &gt; &quot; &#39; &apos; &nbsp;",
1835            SubtitleNormalizationOptions {
1836                strip_markup: false,
1837                decode_basic_entities: true,
1838                collapse_whitespace: true,
1839            },
1840        );
1841
1842        assert_eq!(normalized, "& < > \" ' '");
1843    }
1844
1845    #[test]
1846    fn parse_and_normalize_round_trip_for_subtitles() {
1847        let srt = parse_srt("1\n00:00:00,000 --> 00:00:01,000\n<c>Hello&nbsp;SRT</c>\n").unwrap();
1848        let vtt =
1849            parse_webvtt("WEBVTT\n\n00:00:00.000 --> 00:00:01.000\n<v Speaker>Hello&nbsp;VTT\n")
1850                .unwrap();
1851
1852        let options = SubtitleNormalizationOptions::default();
1853        assert_eq!(
1854            normalize_subtitle_text(&srt.segments[0].text, options.clone()),
1855            "Hello SRT"
1856        );
1857        assert_eq!(
1858            normalize_subtitle_text(&vtt.segments[0].text, options),
1859            "Hello VTT"
1860        );
1861    }
1862}