use crate::Result;
use crate::error::Error;
use crate::ingest::SourceProvenance;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Kind of media that an ingest request or a transcript refers to.
///
/// Serialized through serde with snake_case variant names
/// (`"audio"` and `"video"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MediaKind {
/// Audio media.
Audio,
/// Video media.
Video,
}
/// Identifies the engine associated with a transcript or a transcription
/// backend.
///
/// Serialized through serde with snake_case variant names
/// (`"whisper"`, `"mock"`, `"external"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionEngine {
/// The Whisper engine.
Whisper,
/// A mock engine.
Mock,
/// NOTE(review): presumably an engine that lives outside this crate —
/// confirm against the backend implementations.
External,
}
/// Describes a media file to be ingested.
///
/// Build one with [`MediaIngestRequest::new`], which leaves
/// `language_hints` and `metadata` empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MediaIngestRequest {
/// Path of the media.
pub path: PathBuf,
/// Whether the media is audio or video.
pub media_kind: MediaKind,
/// Provenance record for the media source.
pub provenance: SourceProvenance,
/// Language hints, in the order they were added.
// NOTE(review): the expected format (e.g. language codes) is not visible
// in this file — confirm with the consumers of this field.
pub language_hints: Vec<String>,
/// Free-form string key/value pairs attached to the request.
pub metadata: HashMap<String, String>,
}
impl MediaIngestRequest {
    /// Creates a request for the media at `path`.
    ///
    /// The returned request carries the given `media_kind` and `provenance`
    /// and starts with no language hints and no metadata.
    pub fn new(
        path: impl Into<PathBuf>,
        media_kind: MediaKind,
        provenance: SourceProvenance,
    ) -> Self {
        let path: PathBuf = path.into();
        let language_hints = Vec::default();
        let metadata = HashMap::default();
        Self {
            path,
            media_kind,
            provenance,
            language_hints,
            metadata,
        }
    }

    /// Appends `language` to the list of language hints and hands the
    /// request back, so calls can be chained.
    pub fn with_language_hint(mut self, language: impl Into<String>) -> Self {
        let hint: String = language.into();
        self.language_hints.push(hint);
        self
    }
}
/// Size limits used when chunking a transcript.
///
/// The [`Default`] implementation uses 60 000 ms and 2 000 characters.
// NOTE(review): the code that enforces these limits is not in this file —
// confirm the exact semantics (inclusive/exclusive, chars vs. bytes) there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TranscriptChunkPolicy {
/// Maximum segment duration, in milliseconds.
pub max_segment_duration_ms: u64,
/// Maximum number of characters per chunk.
pub max_chars_per_chunk: usize,
}
impl Default for TranscriptChunkPolicy {
    /// Returns the default policy: segments of at most 60 000 ms (one
    /// minute) and chunks of at most 2 000 characters.
    fn default() -> Self {
        const DEFAULT_MAX_SEGMENT_DURATION_MS: u64 = 60 * 1_000;
        const DEFAULT_MAX_CHARS_PER_CHUNK: usize = 2_000;

        Self {
            max_segment_duration_ms: DEFAULT_MAX_SEGMENT_DURATION_MS,
            max_chars_per_chunk: DEFAULT_MAX_CHARS_PER_CHUNK,
        }
    }
}
/// A request to transcribe one piece of media.
///
/// Build one with [`TranscriptionRequest::new`], which uses the default
/// chunk policy and leaves diarization off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionRequest {
/// The media to transcribe.
pub media: MediaIngestRequest,
/// Chunking limits for the resulting transcript.
pub chunk_policy: TranscriptChunkPolicy,
/// Whether diarization is requested.
// NOTE(review): presumably this asks the backend to label speakers
// (see `TranscriptSegment::speaker`) — confirm with backend implementations.
pub diarization: bool,
}
impl TranscriptionRequest {
    /// Wraps `media` in a request that uses the default chunk policy and has
    /// diarization switched off.
    pub fn new(media: MediaIngestRequest) -> Self {
        let chunk_policy = TranscriptChunkPolicy::default();
        Self {
            media,
            chunk_policy,
            diarization: false,
        }
    }

    /// Returns the same request with diarization switched on; every other
    /// field is carried over unchanged.
    pub fn with_diarization(self) -> Self {
        Self {
            diarization: true,
            ..self
        }
    }
}
/// One timed piece of a transcript.
///
/// [`TranscriptSegment::new`] leaves `speaker` and `confidence` as `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TranscriptSegment {
/// Index of the segment as supplied by whoever created it.
pub index: usize,
/// Start time, in milliseconds.
pub start_ms: u64,
/// End time, in milliseconds; a valid segment has `end_ms > start_ms`.
pub end_ms: u64,
/// Text of the segment.
pub text: String,
/// Speaker label, if any.
pub speaker: Option<String>,
/// Confidence value, if any.
// NOTE(review): the value range (e.g. 0.0..=1.0) is not established in
// this file — confirm with the backends that populate it.
pub confidence: Option<f32>,
}
impl TranscriptSegment {
    /// Creates a segment covering `start_ms..end_ms` with the given `text`.
    ///
    /// `speaker` and `confidence` start out as `None`. The time range is not
    /// checked here; see [`TranscriptSegment::has_valid_time_range`].
    pub fn new(index: usize, start_ms: u64, end_ms: u64, text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self {
            index,
            start_ms,
            end_ms,
            text,
            speaker: None,
            confidence: None,
        }
    }

    /// Length of the segment in milliseconds.
    ///
    /// Saturates at zero when `end_ms` is not after `start_ms`, so an
    /// inverted range never underflows.
    pub fn duration_ms(&self) -> u64 {
        let (start, end) = (self.start_ms, self.end_ms);
        end.saturating_sub(start)
    }

    /// Returns `true` only when the segment ends strictly after it starts;
    /// zero-length and inverted ranges are rejected.
    pub fn has_valid_time_range(&self) -> bool {
        self.start_ms < self.end_ms
    }
}
/// A transcript together with the engine, media kind, and provenance it is
/// associated with.
///
/// [`TranscriptDocument::new`] creates an empty document; callers fill in
/// `language`, `segments`, and `full_text` afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptDocument {
/// Engine associated with this transcript.
pub engine: TranscriptionEngine,
/// Whether the transcribed media is audio or video.
pub media_kind: MediaKind,
/// Language of the transcript, if set.
pub language: Option<String>,
/// Timed segments; `validate` expects them in non-overlapping order.
pub segments: Vec<TranscriptSegment>,
/// Whole transcript text. When this is blank after trimming,
/// `effective_text` falls back to joining the segment texts.
pub full_text: String,
/// Timestamp taken from `Utc::now()` when the document is created via `new`.
pub generated_at: DateTime<Utc>,
/// Provenance record for the media source.
pub provenance: SourceProvenance,
}
impl TranscriptDocument {
    /// Creates an empty transcript document.
    ///
    /// The document has no language, no segments, and an empty `full_text`;
    /// `generated_at` is set to the current UTC time.
    pub fn new(
        engine: TranscriptionEngine,
        media_kind: MediaKind,
        provenance: SourceProvenance,
    ) -> Self {
        let generated_at = Utc::now();
        Self {
            engine,
            media_kind,
            provenance,
            generated_at,
            language: None,
            segments: Vec::new(),
            full_text: String::new(),
        }
    }

    /// Returns the text that best represents this transcript.
    ///
    /// If `full_text` contains anything besides whitespace, its trimmed form
    /// is returned. Otherwise the trimmed texts of all segments are joined
    /// with newlines, skipping segments that are blank after trimming.
    pub fn effective_text(&self) -> String {
        let whole = self.full_text.trim();
        if whole.is_empty() {
            let mut pieces: Vec<&str> = Vec::new();
            for segment in &self.segments {
                let piece = segment.text.trim();
                if !piece.is_empty() {
                    pieces.push(piece);
                }
            }
            pieces.join("\n")
        } else {
            whole.to_string()
        }
    }

    /// Checks that the segments form a well-ordered, non-overlapping timeline.
    ///
    /// # Errors
    ///
    /// Returns an ingest error for the first segment that either
    /// - does not end strictly after it starts, or
    /// - starts before the end of the segment immediately preceding it.
    ///
    /// A segment may start exactly where the previous one ended.
    pub fn validate(&self) -> Result<()> {
        // `None` until the first segment has been accepted, so the overlap
        // check only ever compares against a real predecessor.
        let mut previous_end: Option<u64> = None;
        for (position, segment) in self.segments.iter().enumerate() {
            if !segment.has_valid_time_range() {
                return Err(Error::ingest(format!(
                    "invalid transcript segment at position {position}: end_ms ({}) must be greater than start_ms ({})",
                    segment.end_ms, segment.start_ms
                )));
            }
            match previous_end {
                Some(last_end) if segment.start_ms < last_end => {
                    return Err(Error::ingest(format!(
                        "overlapping transcript segment at position {position}: start_ms ({}) < previous end_ms ({last_end})",
                        segment.start_ms
                    )));
                }
                _ => {}
            }
            previous_end = Some(segment.end_ms);
        }
        Ok(())
    }
}
/// A backend that turns a [`TranscriptionRequest`] into a
/// [`TranscriptDocument`].
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute is what
/// allows the trait to declare an `async fn`.
#[async_trait]
pub trait TranscriptionBackend: Send + Sync {
/// Returns the [`TranscriptionEngine`] this backend reports for itself.
fn engine(&self) -> TranscriptionEngine;
/// Transcribes the media described by `request`.
///
/// # Errors
///
/// Returns the crate's error type on failure; the failure conditions are
/// defined by each implementation.
async fn transcribe(&self, request: &TranscriptionRequest) -> Result<TranscriptDocument>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ingest::{SourceKind, SourceProvenance};

    /// Builds a mock-engine document of the given kind that holds `segments`.
    fn mock_document(
        media_kind: MediaKind,
        provenance: SourceProvenance,
        segments: Vec<TranscriptSegment>,
    ) -> TranscriptDocument {
        let mut doc = TranscriptDocument::new(TranscriptionEngine::Mock, media_kind, provenance);
        doc.segments = segments;
        doc
    }

    /// With an empty `full_text`, the segment texts are joined by newlines.
    #[test]
    fn transcript_effective_text_falls_back_to_segments() {
        let doc = mock_document(
            MediaKind::Audio,
            SourceProvenance::new(SourceKind::Audio, "file:///meeting.m4a"),
            vec![
                TranscriptSegment::new(0, 0, 1000, "hello world"),
                TranscriptSegment::new(1, 1000, 2000, "second segment"),
            ],
        );

        assert_eq!(doc.effective_text(), "hello world\nsecond segment");
    }

    /// A segment starting before its predecessor ends must be rejected.
    #[test]
    fn transcript_validation_rejects_overlap() {
        let doc = mock_document(
            MediaKind::Video,
            SourceProvenance::new(SourceKind::Video, "file:///clip.mp4"),
            vec![
                TranscriptSegment::new(0, 0, 1500, "a"),
                TranscriptSegment::new(1, 1000, 2000, "b"),
            ],
        );

        let message = doc.validate().unwrap_err().to_string();
        assert!(message.contains("overlapping transcript segment"));
    }

    /// A segment starting exactly where its predecessor ends is accepted.
    #[test]
    fn transcript_validation_accepts_monotonic_segments() {
        let doc = mock_document(
            MediaKind::Video,
            SourceProvenance::new(SourceKind::Video, "file:///clip.mp4"),
            vec![
                TranscriptSegment::new(0, 0, 1500, "a"),
                TranscriptSegment::new(1, 1500, 2200, "b"),
            ],
        );

        assert!(doc.validate().is_ok());
    }
}