Skip to main content

text_transcripts/
contracts.rs

1use std::collections::BTreeMap;
2
3use serde::{Deserialize, Serialize};
4use text_core::{AsTextSegmentContract, TextSegmentContract, TextSourceRef, TimestampContract};
5use video_analysis_core::{Timebase, Timestamp};
6
7use crate::{TranscriptSegment, TranscriptWord, TranscriptionError, TranscriptionResult};
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct TranscriptCharContract {
12    #[serde(rename = "char")]
13    pub character: String,
14    #[serde(
15        default,
16        rename = "start",
17        alias = "start_seconds",
18        alias = "startSeconds"
19    )]
20    pub start_seconds: Option<f64>,
21    #[serde(default, rename = "end", alias = "end_seconds", alias = "endSeconds")]
22    pub end_seconds: Option<f64>,
23    #[serde(default, rename = "score", alias = "confidence")]
24    pub confidence: Option<f32>,
25    #[serde(default)]
26    pub attributes: BTreeMap<String, String>,
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30#[serde(rename_all = "camelCase")]
31pub struct TranscriptWordContract {
32    pub text: String,
33    #[serde(default)]
34    pub start_seconds: Option<f64>,
35    #[serde(default)]
36    pub end_seconds: Option<f64>,
37    #[serde(default)]
38    pub confidence: Option<f32>,
39    #[serde(default)]
40    pub speaker: Option<String>,
41    #[serde(default)]
42    pub attributes: BTreeMap<String, String>,
43}
44
45impl From<TranscriptWord> for TranscriptWordContract {
46    fn from(value: TranscriptWord) -> Self {
47        Self {
48            text: value.text,
49            start_seconds: value.start_seconds,
50            end_seconds: value.end_seconds,
51            confidence: sanitize_confidence(value.confidence),
52            speaker: None,
53            attributes: BTreeMap::new(),
54        }
55    }
56}
57
58impl From<TranscriptWordContract> for TranscriptWord {
59    fn from(value: TranscriptWordContract) -> Self {
60        Self {
61            text: value.text,
62            start_seconds: value.start_seconds,
63            end_seconds: value.end_seconds,
64            confidence: sanitize_confidence(value.confidence),
65        }
66    }
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct TranscriptSegmentContract {
72    pub index: u64,
73    #[serde(default)]
74    pub start_seconds: Option<f64>,
75    #[serde(default)]
76    pub end_seconds: Option<f64>,
77    pub text: String,
78    #[serde(default)]
79    pub language: Option<String>,
80    #[serde(default)]
81    pub speaker: Option<String>,
82    #[serde(default)]
83    pub confidence: Option<f32>,
84    pub is_final: bool,
85    #[serde(default)]
86    pub words: Vec<TranscriptWordContract>,
87    #[serde(default)]
88    pub chars: Vec<TranscriptCharContract>,
89    #[serde(default)]
90    pub attributes: BTreeMap<String, String>,
91}
92
93impl TranscriptSegmentContract {
94    pub fn new(index: u64, text: impl Into<String>) -> Self {
95        Self {
96            index,
97            start_seconds: None,
98            end_seconds: None,
99            text: text.into(),
100            language: None,
101            speaker: None,
102            confidence: None,
103            is_final: true,
104            words: Vec::new(),
105            chars: Vec::new(),
106            attributes: BTreeMap::new(),
107        }
108    }
109
110    pub fn validate(&self) -> crate::Result<()> {
111        validate_seconds_range(self.start_seconds, self.end_seconds)?;
112        if self
113            .confidence
114            .is_some_and(|confidence| !confidence.is_finite())
115        {
116            return Err(TranscriptionError::InvalidTranscript(
117                "transcript segment confidence must be finite".to_string(),
118            ));
119        }
120        for word in &self.words {
121            validate_seconds_range(word.start_seconds, word.end_seconds)?;
122            if word
123                .confidence
124                .is_some_and(|confidence| !confidence.is_finite())
125            {
126                return Err(TranscriptionError::InvalidTranscript(
127                    "transcript word confidence must be finite".to_string(),
128                ));
129            }
130        }
131        for character in &self.chars {
132            validate_seconds_range(character.start_seconds, character.end_seconds)?;
133            if character
134                .confidence
135                .is_some_and(|confidence| !confidence.is_finite())
136            {
137                return Err(TranscriptionError::InvalidTranscript(
138                    "transcript char confidence must be finite".to_string(),
139                ));
140            }
141        }
142        Ok(())
143    }
144
145    pub fn validated(mut self) -> crate::Result<Self> {
146        self.confidence = sanitize_confidence(self.confidence);
147        for word in &mut self.words {
148            word.confidence = sanitize_confidence(word.confidence);
149        }
150        for character in &mut self.chars {
151            character.confidence = sanitize_confidence(character.confidence);
152        }
153        self.validate()?;
154        Ok(self)
155    }
156
157    pub fn normalized(mut self) -> Self {
158        self.text = self.text.trim().to_string();
159        self.confidence = sanitize_confidence(self.confidence);
160        self.words = self
161            .words
162            .into_iter()
163            .filter_map(|mut word| {
164                word.text = word.text.trim().to_string();
165                word.confidence = sanitize_confidence(word.confidence);
166                word.speaker = word
167                    .speaker
168                    .map(|speaker| speaker.trim().to_string())
169                    .filter(|speaker| !speaker.is_empty());
170                (!word.text.is_empty()).then_some(word)
171            })
172            .collect();
173        self.chars = self
174            .chars
175            .into_iter()
176            .filter_map(|mut character| {
177                character.confidence = sanitize_confidence(character.confidence);
178                (!character.character.is_empty()).then_some(character)
179            })
180            .collect();
181        self
182    }
183
184    pub fn duration_seconds(&self) -> Option<f64> {
185        Some((self.end_seconds? - self.start_seconds?).max(0.0))
186    }
187
188    pub fn midpoint_seconds(&self) -> Option<f64> {
189        Some((self.start_seconds? + self.end_seconds?) * 0.5)
190    }
191}
192
193impl AsTextSegmentContract for TranscriptSegmentContract {
194    fn as_text_segment_contract(&self) -> TextSegmentContract {
195        let mut attributes = self.attributes.clone();
196        insert_optional(&mut attributes, "speaker", self.speaker.as_deref());
197        insert_optional_display(&mut attributes, "confidence", self.confidence);
198
199        TextSegmentContract {
200            stream_id: None,
201            segment_index: self.index,
202            text: self.text.clone(),
203            language: self.language.clone(),
204            timestamp: self.start_seconds.map(seconds_to_timestamp_contract),
205            duration_seconds: self.duration_seconds(),
206            is_final: self.is_final,
207            attributes,
208            source: Some(TextSourceRef {
209                source_id: None,
210                source_kind: Some("transcript_segment".to_string()),
211                uri: None,
212                media_timestamp: self.start_seconds.map(seconds_to_timestamp_contract),
213                duration_seconds: self.duration_seconds(),
214            }),
215            provenance: Vec::new(),
216            annotations: Vec::new(),
217        }
218    }
219}
220
221impl From<TranscriptSegment> for TranscriptSegmentContract {
222    fn from(value: TranscriptSegment) -> Self {
223        Self {
224            index: value.index,
225            start_seconds: value.start_seconds,
226            end_seconds: value.end_seconds,
227            text: value.text,
228            language: value.language,
229            speaker: value.speaker,
230            confidence: sanitize_confidence(value.confidence),
231            is_final: value.is_final,
232            words: Vec::new(),
233            chars: Vec::new(),
234            attributes: BTreeMap::new(),
235        }
236    }
237}
238
239impl From<&TranscriptSegment> for TranscriptSegmentContract {
240    fn from(value: &TranscriptSegment) -> Self {
241        value.clone().into()
242    }
243}
244
245impl From<TranscriptSegmentContract> for TranscriptSegment {
246    fn from(value: TranscriptSegmentContract) -> Self {
247        Self {
248            index: value.index,
249            start_seconds: value.start_seconds,
250            end_seconds: value.end_seconds,
251            text: value.text,
252            language: value.language,
253            speaker: value.speaker,
254            confidence: sanitize_confidence(value.confidence),
255            is_final: value.is_final,
256        }
257    }
258}
259
260impl From<TranscriptSegmentContract> for TextSegmentContract {
261    fn from(value: TranscriptSegmentContract) -> Self {
262        value.as_text_segment_contract()
263    }
264}
265
266impl From<&TranscriptSegmentContract> for TextSegmentContract {
267    fn from(value: &TranscriptSegmentContract) -> Self {
268        value.as_text_segment_contract()
269    }
270}
271
272pub fn text_segment_contract_with_source(
273    segment: &TranscriptSegmentContract,
274    stream_id: impl Into<String>,
275    source_kind: impl Into<String>,
276    uri: impl Into<String>,
277) -> TextSegmentContract {
278    let stream_id = stream_id.into();
279    let source_kind = source_kind.into();
280    let uri = uri.into();
281    let mut contract = segment.as_text_segment_contract();
282    contract.stream_id = Some(stream_id.clone());
283    contract.source = Some(TextSourceRef {
284        source_id: Some(stream_id),
285        source_kind: Some(source_kind),
286        uri: Some(uri),
287        media_timestamp: contract.timestamp,
288        duration_seconds: contract.duration_seconds,
289    });
290    contract
291}
292
293#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
294#[serde(rename_all = "camelCase")]
295pub struct TranscriptionContract {
296    #[serde(default)]
297    pub text: Option<String>,
298    #[serde(default)]
299    pub language: Option<String>,
300    #[serde(default)]
301    pub segments: Vec<TranscriptSegmentContract>,
302    #[serde(default)]
303    pub source: Option<String>,
304    #[serde(default)]
305    pub attributes: BTreeMap<String, String>,
306}
307
308impl TranscriptionContract {
309    pub fn new(segments: Vec<TranscriptSegmentContract>) -> Self {
310        Self {
311            text: None,
312            language: None,
313            segments,
314            source: None,
315            attributes: BTreeMap::new(),
316        }
317    }
318
319    pub fn from_segments(
320        source: Option<String>,
321        language: Option<String>,
322        segments: Vec<TranscriptSegmentContract>,
323    ) -> crate::Result<Self> {
324        let segments = segments
325            .into_iter()
326            .map(|mut segment| {
327                if segment.language.is_none() {
328                    segment.language = language.clone();
329                }
330                segment
331            })
332            .collect();
333        Self {
334            text: None,
335            language,
336            segments,
337            source,
338            attributes: BTreeMap::new(),
339        }
340        .normalized()
341    }
342
343    pub fn validate(&self) -> crate::Result<()> {
344        for segment in &self.segments {
345            segment.validate()?;
346        }
347        Ok(())
348    }
349
350    pub fn validate_strict(&self) -> crate::Result<()> {
351        self.validate()?;
352        let mut last_start_seconds = None;
353        for segment in &self.segments {
354            if segment.text.trim().is_empty() {
355                return Err(TranscriptionError::InvalidTranscript(
356                    "transcript segment text must not be empty".to_string(),
357                ));
358            }
359            if let (Some(previous), Some(current)) = (last_start_seconds, segment.start_seconds) {
360                if current < previous {
361                    return Err(TranscriptionError::InvalidTranscript(
362                        "transcript segment start_seconds must not move backward".to_string(),
363                    ));
364                }
365            }
366            if segment.start_seconds.is_some() {
367                last_start_seconds = segment.start_seconds;
368            }
369            for word in &segment.words {
370                validate_word_inside_segment(segment, word)?;
371            }
372            for character in &segment.chars {
373                validate_char_inside_segment(segment, character)?;
374            }
375        }
376        Ok(())
377    }
378
379    pub fn joined_text(&self) -> String {
380        self.segments
381            .iter()
382            .map(|segment| segment.text.trim())
383            .filter(|text| !text.is_empty())
384            .collect::<Vec<_>>()
385            .join(" ")
386    }
387
388    pub fn text_or_joined(&self) -> String {
389        self.text
390            .as_deref()
391            .map(str::trim)
392            .filter(|text| !text.is_empty())
393            .map(str::to_string)
394            .unwrap_or_else(|| self.joined_text())
395    }
396
397    pub fn normalized(mut self) -> crate::Result<Self> {
398        self.text = self
399            .text
400            .map(|text| text.trim().to_string())
401            .filter(|text| !text.is_empty());
402        self.segments = self
403            .segments
404            .into_iter()
405            .map(TranscriptSegmentContract::normalized)
406            .collect();
407        if self.text.is_none() {
408            let joined = self.joined_text();
409            if !joined.is_empty() {
410                self.text = Some(joined);
411            }
412        }
413        self.validate()?;
414        Ok(self)
415    }
416}
417
418impl From<TranscriptionResult> for TranscriptionContract {
419    fn from(value: TranscriptionResult) -> Self {
420        Self {
421            text: value.text,
422            language: value.language,
423            segments: value.segments.into_iter().map(Into::into).collect(),
424            source: value.source,
425            attributes: BTreeMap::new(),
426        }
427    }
428}
429
430impl From<TranscriptionContract> for TranscriptionResult {
431    fn from(value: TranscriptionContract) -> Self {
432        Self {
433            text: value.text,
434            language: value.language,
435            segments: value.segments.into_iter().map(Into::into).collect(),
436            source: value.source,
437        }
438    }
439}
440
441fn seconds_to_timestamp_contract(seconds: f64) -> TimestampContract {
442    Timestamp::new((seconds * 1_000.0).round() as i64, Timebase::new(1, 1_000)).into()
443}
444
/// Clamps a finite confidence into `0.0..=1.0`; a missing or non-finite value yields `None`.
fn sanitize_confidence(value: Option<f32>) -> Option<f32> {
    match value {
        Some(confidence) if confidence.is_finite() => Some(confidence.clamp(0.0, 1.0)),
        _ => None,
    }
}
448
449fn validate_seconds_range(start: Option<f64>, end: Option<f64>) -> crate::Result<()> {
450    if start.is_some_and(|value| !value.is_finite()) || end.is_some_and(|value| !value.is_finite())
451    {
452        return Err(TranscriptionError::InvalidTranscript(
453            "transcript timestamps must be finite".to_string(),
454        ));
455    }
456    if let (Some(start), Some(end)) = (start, end) {
457        if end < start {
458            return Err(TranscriptionError::InvalidTranscript(
459                "transcript segment end_seconds must be greater than or equal to start_seconds"
460                    .to_string(),
461            ));
462        }
463    }
464    Ok(())
465}
466
467fn validate_word_inside_segment(
468    segment: &TranscriptSegmentContract,
469    word: &TranscriptWordContract,
470) -> crate::Result<()> {
471    if let (Some(segment_start), Some(word_start)) = (segment.start_seconds, word.start_seconds) {
472        if word_start < segment_start {
473            return Err(TranscriptionError::InvalidTranscript(
474                "transcript word start_seconds must be within its segment".to_string(),
475            ));
476        }
477    }
478    if let (Some(segment_end), Some(word_end)) = (segment.end_seconds, word.end_seconds) {
479        if word_end > segment_end {
480            return Err(TranscriptionError::InvalidTranscript(
481                "transcript word end_seconds must be within its segment".to_string(),
482            ));
483        }
484    }
485    Ok(())
486}
487
488fn validate_char_inside_segment(
489    segment: &TranscriptSegmentContract,
490    character: &TranscriptCharContract,
491) -> crate::Result<()> {
492    if let (Some(segment_start), Some(char_start)) =
493        (segment.start_seconds, character.start_seconds)
494    {
495        if char_start < segment_start {
496            return Err(TranscriptionError::InvalidTranscript(
497                "transcript char start_seconds must be within its segment".to_string(),
498            ));
499        }
500    }
501    if let (Some(segment_end), Some(char_end)) = (segment.end_seconds, character.end_seconds) {
502        if char_end > segment_end {
503            return Err(TranscriptionError::InvalidTranscript(
504                "transcript char end_seconds must be within its segment".to_string(),
505            ));
506        }
507    }
508    Ok(())
509}
510
/// Stores `value` under `key` when it is present, replacing any existing entry; does nothing
/// for `None`.
fn insert_optional(metadata: &mut BTreeMap<String, String>, key: &str, value: Option<&str>) {
    // An `Option` iterates over zero or one item, so `extend` inserts at most one entry.
    metadata.extend(value.map(|value| (key.to_string(), value.to_string())));
}
516
/// Stores the `Display` rendering of `value` under `key` when it is present, replacing any
/// existing entry; does nothing for `None`.
fn insert_optional_display<T: std::fmt::Display>(
    metadata: &mut BTreeMap<String, String>,
    key: &str,
    value: Option<T>,
) {
    let Some(value) = value else {
        return;
    };
    metadata.insert(key.to_owned(), value.to_string());
}