use anyhow::{Context, Result};
use colored::Colorize;
use std::path::PathBuf;
/// Options for `captions generate`: transcribe an audio file with a
/// user-supplied encoder model and write the result as a caption file.
// Without the `caption-gen` feature only the stub entry point exists and it
// never reads these fields, so the dead-code lint is silenced in that build.
#[cfg_attr(not(feature = "caption-gen"), allow(dead_code))]
#[derive(Debug, Clone)]
pub struct CaptionsGenerateOptions {
    /// Input audio file (only RIFF/WAVE input is decoded).
    pub input: PathBuf,
    /// Destination caption file.
    pub output: PathBuf,
    /// Output caption format name, e.g. `srt` or `vtt`.
    pub format: String,
    /// Language tag attached to the generated caption track.
    pub language: String,
    /// Encoder model path (e.g. `encoder.onnx`); required at runtime.
    pub model: Option<PathBuf>,
    /// JSON file mapping token IDs to strings; required at runtime.
    pub vocab: Option<PathBuf>,
}
/// Options for `captions sync`.
#[derive(Debug, Clone)]
pub struct CaptionsSyncOptions {
    /// Caption file to re-time.
    pub input: PathBuf,
    /// Timing reference (presumably the matching media). It is currently
    /// only checked for readability, not used for alignment.
    pub reference: PathBuf,
    /// Destination caption file; its extension selects the output format.
    pub output: PathBuf,
    /// Maximum timing shift in milliseconds; currently only reported.
    pub max_shift_ms: i64,
}
/// Options for `captions convert`.
#[derive(Debug, Clone)]
pub struct CaptionsConvertOptions {
    /// Source caption file.
    pub input: PathBuf,
    /// Destination caption file.
    pub output: PathBuf,
    /// Source format name; `None` auto-detects the format from the file contents.
    pub from_format: Option<String>,
    /// Target format name, e.g. `srt`, `vtt`, `ttml`.
    pub to_format: String,
}
/// Options for `captions burn`.
#[derive(Debug, Clone)]
pub struct CaptionsBurnOptions {
    /// Source video file.
    pub video: PathBuf,
    /// Caption file to burn in.
    pub captions: PathBuf,
    /// Destination video file.
    pub output: PathBuf,
    /// Font size, reported in pixels.
    pub font_size: u32,
    /// Font colour, printed after a `#` (presumably hex RGB without the `#`).
    pub font_color: String,
}
/// Options for `captions extract`.
#[derive(Debug, Clone)]
pub struct CaptionsExtractOptions {
    /// Media file to extract captions from.
    pub input: PathBuf,
    /// Destination caption file.
    pub output: PathBuf,
    /// Output caption format name.
    pub format: String,
    /// Index of the caption track to extract; currently only reported.
    pub track: usize,
}
/// Options for `captions validate`.
#[derive(Debug, Clone)]
pub struct CaptionsValidateOptions {
    /// Caption file to validate (format is auto-detected).
    pub input: PathBuf,
    /// Name of the requested standard; echoed in the output (the default
    /// validator is used regardless).
    pub standard: String,
    /// Optional path for a plain-text validation report.
    pub report: Option<PathBuf>,
}
fn parse_caption_format(s: &str) -> Result<oximedia_captions::CaptionFormat> {
match s.to_lowercase().as_str() {
"srt" => Ok(oximedia_captions::CaptionFormat::Srt),
"vtt" | "webvtt" => Ok(oximedia_captions::CaptionFormat::WebVtt),
"ass" => Ok(oximedia_captions::CaptionFormat::Ass),
"ssa" => Ok(oximedia_captions::CaptionFormat::Ssa),
"ttml" => Ok(oximedia_captions::CaptionFormat::Ttml),
"dfxp" => Ok(oximedia_captions::CaptionFormat::Dfxp),
"scc" => Ok(oximedia_captions::CaptionFormat::Scc),
"stl" | "ebu-stl" => Ok(oximedia_captions::CaptionFormat::EbuStl),
"itt" => Ok(oximedia_captions::CaptionFormat::ITt),
"cea608" | "cea-608" => Ok(oximedia_captions::CaptionFormat::Cea608),
"cea708" | "cea-708" => Ok(oximedia_captions::CaptionFormat::Cea708),
other => Err(anyhow::anyhow!("Unknown caption format: {other}")),
}
}
/// Entry point for `captions generate` (automatic captions from audio).
///
/// With the `caption-gen` feature this delegates to the real pipeline.
/// Without it, it always fails with a message explaining how to rebuild and
/// which runtime arguments are needed.
pub async fn run_captions_generate(opts: CaptionsGenerateOptions, json_output: bool) -> Result<()> {
    // The early `return` keeps this block diverging, so the function still
    // type-checks when the feature-enabled tail block below is compiled out.
    #[cfg(not(feature = "caption-gen"))]
    {
        const HOW_TO_ENABLE: &str = "Caption ASR requires the `caption-gen` feature. \
             Rebuild with: cargo build --features caption-gen\n\
             Note: you also need to supply --model <encoder.onnx> and \
             --vocab <vocab.json> at runtime.";
        // Mark both parameters as used so the stub build is warning-free.
        let _ = (&opts, json_output);
        return Err(anyhow::anyhow!(HOW_TO_ENABLE));
    }
    #[cfg(feature = "caption-gen")]
    {
        run_captions_generate_impl(opts, json_output).await
    }
}
/// Feature-gated implementation of `captions generate`.
///
/// The pipeline is:
/// 1. Decode the WAV input to 16 kHz mono.
/// 2. Compute an 80-bin log-mel spectrogram.
/// 3. Run the user-supplied encoder and greedy-decode token IDs.
/// 4. Map the IDs through the vocabulary.
/// 5. Split the transcript into timed segments and caption blocks.
/// 6. Export the track in `opts.format` and print a summary.
///
/// # Errors
/// Fails when any of the following happens:
/// - `--model` or `--vocab` is missing, or the vocab is unreadable or invalid.
/// - The audio cannot be decoded.
/// - Model loading, inference or decoding fails.
/// - The output cannot be exported or written.
#[cfg(feature = "caption-gen")]
async fn run_captions_generate_impl(
    opts: CaptionsGenerateOptions,
    json_output: bool,
) -> Result<()> {
    use oximedia_caption_gen::{
        alignment::{build_caption_blocks, merge_short_segments},
        line_breaking::optimal_break,
        ml::CaptionEncoder,
    };
    use oximedia_ml::DeviceType;

    const MEL_BINS: usize = 80;
    const SAMPLE_RATE: u32 = 16_000;
    // At 16 kHz: a 400-sample (25 ms) analysis window with a 160-sample (10 ms) hop.
    const FRAME_LEN: usize = 400;
    const HOP_LEN: usize = 160;
    const MAX_LINE_CHARS: u8 = 42;

    let model_path = opts.model.as_ref().ok_or_else(|| {
        anyhow::anyhow!(
            "Caption generation requires --model <path>. \
             No ASR model weights are bundled with oximedia-cli."
        )
    })?;
    let vocab_path = opts.vocab.as_ref().ok_or_else(|| {
        anyhow::anyhow!(
            "Caption generation requires --vocab <path> (a JSON file mapping \
             token IDs to strings)."
        )
    })?;

    // Load the vocabulary before decoding audio or running inference so that a
    // bad --vocab argument is reported without doing the expensive work first.
    let vocab_bytes = std::fs::read(vocab_path)
        .with_context(|| format!("Failed to read vocab file: {}", vocab_path.display()))?;
    let vocab: std::collections::HashMap<String, String> =
        serde_json::from_slice(&vocab_bytes).context("Failed to parse vocab JSON")?;

    let raw_bytes = std::fs::read(&opts.input)
        .with_context(|| format!("Failed to read input: {}", opts.input.display()))?;
    let samples_16k = parse_wav_to_mono_f32(&raw_bytes).with_context(|| {
        format!(
            "Failed to decode audio from '{}'. \
             Only WAV/PCM files are supported in this build.",
            opts.input.display()
        )
    })?;

    let spectrogram =
        compute_log_mel_spectrogram(&samples_16k, MEL_BINS, FRAME_LEN, HOP_LEN, SAMPLE_RATE)
            .context("Failed to compute log-mel spectrogram")?;
    let n_frames = spectrogram.len() / MEL_BINS;

    // `compute_log_mel_spectrogram` stores values frame-major
    // (`frame * MEL_BINS + mel`), but the encoder input is declared as
    // `[1, MEL_BINS, n_frames]`. Transpose so the flat buffer matches the
    // declared shape in row-major order.
    let mut mel_major = vec![0.0_f32; spectrogram.len()];
    for (frame_idx, frame) in spectrogram.chunks_exact(MEL_BINS).enumerate() {
        for (mel_idx, &value) in frame.iter().enumerate() {
            mel_major[mel_idx * n_frames + frame_idx] = value;
        }
    }
    let input_shape = [1_usize, MEL_BINS, n_frames];

    let encoder = CaptionEncoder::from_path(model_path, DeviceType::auto()).with_context(|| {
        format!(
            "Failed to load caption encoder model from '{}'",
            model_path.display()
        )
    })?;
    let encoder_out = encoder
        .encode(&mel_major, &input_shape)
        .map_err(|e| anyhow::anyhow!("Encoder inference failed: {e}"))?;
    let (seq_len, vocab_size) = derive_seq_vocab(&encoder_out.shape).ok_or_else(|| {
        anyhow::anyhow!("Unexpected encoder output shape: {:?}", encoder_out.shape)
    })?;
    let token_ids =
        oximedia_caption_gen::ml::greedy_decode(&encoder_out.logits, vocab_size, seq_len)
            .map_err(|e| anyhow::anyhow!("Greedy decode failed: {e}"))?;

    let transcript_text = tokens_to_text(&token_ids, &vocab);
    let audio_duration_ms = (samples_16k.len() as u64 * 1000) / u64::from(SAMPLE_RATE);
    let segments = build_segments_from_text(&transcript_text, audio_duration_ms);
    let merged = merge_short_segments(&segments, 800);
    let blocks = build_caption_blocks(&merged, 2, 42);

    let language =
        oximedia_captions::Language::new(opts.language.clone(), opts.language.clone(), false);
    let mut track = oximedia_captions::CaptionTrack::new(language);
    for block in &blocks {
        // Re-flow each block's lines so every line respects MAX_LINE_CHARS.
        let joined = block.lines.join(" ");
        let broken_lines = optimal_break(&joined, MAX_LINE_CHARS);
        let text = broken_lines.join("\n");
        let caption = oximedia_captions::Caption::new(
            oximedia_captions::Timestamp::from_millis(block.start_ms as i64),
            oximedia_captions::Timestamp::from_millis(block.end_ms as i64),
            text,
        );
        track
            .add_caption(caption)
            .map_err(|e| anyhow::anyhow!("Failed to add caption: {e}"))?;
    }

    let format = parse_caption_format(&opts.format)?;
    let output_bytes = oximedia_captions::export::Exporter::export(&track, format)
        .map_err(|e| anyhow::anyhow!("Export failed: {e}"))?;
    std::fs::write(&opts.output, &output_bytes)
        .with_context(|| format!("Failed to write output: {}", opts.output.display()))?;

    let caption_count = track.count();
    if json_output {
        let obj = serde_json::json!({
            "input": opts.input.to_string_lossy(),
            "output": opts.output.to_string_lossy(),
            "format": opts.format,
            "language": opts.language,
            "captions_count": caption_count,
        });
        println!("{}", serde_json::to_string_pretty(&obj)?);
    } else {
        println!("{}", "Caption Generation Complete".green().bold());
        println!(" Input: {}", opts.input.display());
        println!(" Output: {}", opts.output.display());
        println!(" Format: {}", opts.format);
        println!(" Language: {}", opts.language);
        println!(" Captions: {}", caption_count);
    }
    Ok(())
}
/// Decode a RIFF/WAVE buffer to mono `f32` samples at 16 kHz.
///
/// Channels are down-mixed by `decode_pcm`. Input at other sample rates is
/// resampled with linear interpolation; no anti-aliasing filter is applied
/// when downsampling.
///
/// # Errors
/// Fails in any of these cases:
/// - The buffer is shorter than a 44-byte header or is not RIFF/WAVE.
/// - No usable `fmt ` or `data` chunk is found.
/// - The header declares a 0 Hz sample rate.
/// - `decode_pcm` rejects the sample encoding.
#[cfg(feature = "caption-gen")]
fn parse_wav_to_mono_f32(data: &[u8]) -> anyhow::Result<Vec<f32>> {
    const TARGET_RATE: u32 = 16_000;

    if data.len() < 44 {
        return Err(anyhow::anyhow!("WAV file too small ({} bytes)", data.len()));
    }
    if &data[0..4] != b"RIFF" || &data[8..12] != b"WAVE" {
        return Err(anyhow::anyhow!(
            "Not a RIFF/WAVE file — only WAV input is supported"
        ));
    }
    let (n_channels, src_sample_rate, bits_per_sample) = find_fmt_chunk(data)?;
    // A 0 Hz rate would make the resampling ratio 0 and the output length
    // infinite (saturating to usize::MAX), aborting on allocation.
    if src_sample_rate == 0 {
        return Err(anyhow::anyhow!("WAV header declares a sample rate of 0 Hz"));
    }
    let pcm_bytes = find_data_chunk(data)?;
    let samples = decode_pcm(pcm_bytes, bits_per_sample, n_channels)?;
    if src_sample_rate == TARGET_RATE || samples.is_empty() {
        return Ok(samples);
    }

    let ratio = f64::from(src_sample_rate) / f64::from(TARGET_RATE);
    let last = samples.len() - 1;
    let out_len = ((samples.len() as f64) / ratio).ceil() as usize;
    let resampled: Vec<f32> = (0..out_len)
        .map(|i| {
            // Fractional source position of output sample `i`; blend its two
            // neighbouring input samples.
            let pos = i as f64 * ratio;
            let i0 = (pos.floor() as usize).min(last);
            let i1 = (i0 + 1).min(last);
            let frac = (pos - i0 as f64).clamp(0.0, 1.0) as f32;
            samples[i0] + (samples[i1] - samples[i0]) * frac
        })
        .collect();
    Ok(resampled)
}
/// Locate the `fmt ` chunk of a RIFF/WAVE buffer and return
/// `(channels, sample_rate, bits_per_sample)`.
///
/// Chunks are walked from byte 12 (just past `RIFF<size>WAVE`). Odd-sized
/// chunks are followed by a pad byte, as the RIFF layout requires.
///
/// # Errors
/// Fails in any of these cases:
/// - No `fmt ` chunk is found.
/// - The chunk is shorter than the 16-byte PCM layout (declared or actually present).
/// - The format tag is neither PCM (1) nor IEEE float (3).
/// - The bit depth cannot be decoded for that tag: 32-bit integer PCM, or
///   IEEE float other than 32-bit.
#[cfg(feature = "caption-gen")]
fn find_fmt_chunk(data: &[u8]) -> anyhow::Result<(u16, u32, u16)> {
    let mut pos = 12_usize;
    while pos.saturating_add(8) <= data.len() {
        let chunk_id = &data[pos..pos + 4];
        let chunk_size = u32::from_le_bytes(
            data[pos + 4..pos + 8]
                .try_into()
                .map_err(|_| anyhow::anyhow!("WAV chunk size read error"))?,
        ) as usize;
        if chunk_id == b"fmt " {
            // Both the declared size and the bytes actually present must cover
            // the 16-byte PCM layout before any field is indexed.
            let body = &data[pos + 8..];
            if chunk_size < 16 || body.len() < 16 {
                return Err(anyhow::anyhow!(
                    "WAV 'fmt ' chunk too short ({chunk_size} bytes declared, {} available)",
                    body.len()
                ));
            }
            let le_u16 = |at: usize| u16::from_le_bytes([body[at], body[at + 1]]);
            let audio_fmt = le_u16(0);
            let n_channels = le_u16(2);
            let sample_rate = u32::from_le_bytes([body[4], body[5], body[6], body[7]]);
            // Bytes 8..14 hold byte rate and block align, which are not needed.
            let bits_per_sample = le_u16(14);
            return match (audio_fmt, bits_per_sample) {
                // 32-bit samples are decoded as IEEE float downstream, so an
                // integer 32-bit stream would be silently misread.
                (1, 32) => Err(anyhow::anyhow!("32-bit integer PCM WAV is not supported")),
                (3, bits) if bits != 32 => Err(anyhow::anyhow!(
                    "Unsupported IEEE-float WAV bit depth {bits} (only 32-bit float supported)"
                )),
                (1 | 3, _) => Ok((n_channels, sample_rate, bits_per_sample)),
                _ => Err(anyhow::anyhow!(
                    "Unsupported WAV audio format {audio_fmt} (only PCM/IEEE-float supported)"
                )),
            };
        }
        // Saturate so a huge declared size ends the walk instead of overflowing.
        pos = pos
            .saturating_add(8)
            .saturating_add(chunk_size)
            .saturating_add(chunk_size % 2);
    }
    Err(anyhow::anyhow!("WAV 'fmt ' chunk not found"))
}
/// Return the payload of the first `data` chunk in a RIFF/WAVE buffer.
///
/// The payload is clamped to the end of the buffer, so a truncated file (or a
/// placeholder size such as `0xFFFFFFFF`) yields the bytes that are present.
///
/// # Errors
/// Fails if no `data` chunk header is found.
#[cfg(feature = "caption-gen")]
fn find_data_chunk(data: &[u8]) -> anyhow::Result<&[u8]> {
    let mut pos = 12_usize;
    while pos.saturating_add(8) <= data.len() {
        let chunk_id = &data[pos..pos + 4];
        let chunk_size = u32::from_le_bytes(
            data[pos + 4..pos + 8]
                .try_into()
                .map_err(|_| anyhow::anyhow!("WAV data chunk size error"))?,
        ) as usize;
        let body_start = pos + 8;
        if chunk_id == b"data" {
            let end = body_start.saturating_add(chunk_size).min(data.len());
            return Ok(&data[body_start..end]);
        }
        // Saturating arithmetic: on 32-bit targets a large declared size could
        // otherwise overflow `usize` and wrap the cursor backwards.
        pos = body_start
            .saturating_add(chunk_size)
            .saturating_add(chunk_size % 2);
    }
    Err(anyhow::anyhow!("WAV 'data' chunk not found"))
}
/// Convert interleaved little-endian sample bytes to mono `f32` in [-1, 1].
///
/// Supported layouts:
/// - 8-bit unsigned integer.
/// - 16-bit and 24-bit signed integer.
/// - 32-bit IEEE float (32-bit data is always treated as float).
///
/// Multi-channel audio is down-mixed by averaging each frame. A trailing
/// partial frame is dropped.
///
/// # Errors
/// Fails in any of these cases:
/// - `n_channels` is 0.
/// - The bit depth is unsupported.
/// - The byte count is not a multiple of the 16/24/32-bit sample size.
#[cfg(feature = "caption-gen")]
fn decode_pcm(pcm: &[u8], bits_per_sample: u16, n_channels: u16) -> anyhow::Result<Vec<f32>> {
    let channels = usize::from(n_channels);
    if channels == 0 {
        return Err(anyhow::anyhow!("WAV has 0 channels"));
    }
    let samples_raw: Vec<f32> = match bits_per_sample {
        // 8-bit WAV is unsigned with 128 as the zero level.
        8 => pcm.iter().map(|&b| (f32::from(b) - 128.0) / 128.0).collect(),
        16 => {
            if pcm.len() % 2 != 0 {
                return Err(anyhow::anyhow!("16-bit WAV data has odd byte count"));
            }
            // Divide by 2^15 so that i16::MIN maps exactly to -1.0.
            pcm.chunks_exact(2)
                .map(|b| f32::from(i16::from_le_bytes([b[0], b[1]])) / 32_768.0)
                .collect()
        }
        24 => {
            if pcm.len() % 3 != 0 {
                return Err(anyhow::anyhow!("24-bit WAV data has unaligned byte count"));
            }
            pcm.chunks_exact(3)
                .map(|b| {
                    // Put the three bytes in the top of an i32, then shift
                    // arithmetically back down to sign-extend the 24-bit value.
                    let value = i32::from_le_bytes([0, b[0], b[1], b[2]]) >> 8;
                    value as f32 / 8_388_608.0
                })
                .collect()
        }
        32 => {
            if pcm.len() % 4 != 0 {
                return Err(anyhow::anyhow!("32-bit WAV data has unaligned byte count"));
            }
            pcm.chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect()
        }
        other => {
            return Err(anyhow::anyhow!(
                "Unsupported WAV bit depth {other} (only 8/16/24-bit PCM and 32-bit float supported)"
            ));
        }
    };
    if channels == 1 {
        return Ok(samples_raw);
    }
    let mono: Vec<f32> = samples_raw
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect();
    Ok(mono)
}
/// Compute a natural-log mel spectrogram of mono `samples`.
///
/// Each frame works as follows:
/// - `frame_len` samples, advancing by `hop_len`, are Hann-windowed.
/// - The frame is zero-padded to the next power of two.
/// - A direct DFT turns it into a power spectrum.
/// - The spectrum is projected through `mel_filterbank`.
/// - Band energies are floored at `1e-10` before taking `ln`.
///
/// The output is frame-major: element `frame * mel_bins + mel` holds mel band
/// `mel` of frame `frame`. Input shorter than one frame still yields a single
/// zero-padded frame. Otherwise, trailing samples that do not fill a whole
/// frame are dropped.
///
/// # Errors
/// Fails in any of these cases:
/// - The input is empty.
/// - `frame_len < 2` (the symmetric Hann window divides by `frame_len - 1`).
/// - `hop_len == 0`.
#[cfg(feature = "caption-gen")]
fn compute_log_mel_spectrogram(
    samples: &[f32],
    mel_bins: usize,
    frame_len: usize,
    hop_len: usize,
    sample_rate: u32,
) -> anyhow::Result<Vec<f32>> {
    if samples.is_empty() {
        return Err(anyhow::anyhow!("Cannot compute spectrogram of empty audio"));
    }
    if frame_len < 2 {
        return Err(anyhow::anyhow!(
            "Spectrogram frame length must be at least 2 samples (got {frame_len})"
        ));
    }
    if hop_len == 0 {
        return Err(anyhow::anyhow!("Spectrogram hop length must be non-zero"));
    }

    let fft_size = frame_len.next_power_of_two();
    let n_bins = fft_size / 2 + 1;
    let n_frames = if samples.len() >= frame_len {
        (samples.len() - frame_len) / hop_len + 1
    } else {
        1
    };
    let hann: Vec<f32> = (0..frame_len)
        .map(|n| {
            0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / (frame_len - 1) as f32).cos())
        })
        .collect();

    // DFT kernel e^{-2πi·j/N} for every phase index j. Bin k at sample n uses
    // index (k·n) mod N, so no trigonometry runs inside the per-frame loops.
    let (cos_table, sin_table): (Vec<f32>, Vec<f32>) = (0..fft_size)
        .map(|j| {
            let angle = -2.0 * std::f32::consts::PI * j as f32 / fft_size as f32;
            (angle.cos(), angle.sin())
        })
        .unzip();

    let filterbank = mel_filterbank(mel_bins, fft_size, sample_rate);
    let mut spectrogram = vec![0.0_f32; mel_bins * n_frames];
    // Buffers are reused across frames. Only the first `frame_len` points of
    // the zero-padded DFT input can be non-zero, so the padding is never stored.
    let mut windowed = vec![0.0_f32; frame_len];
    let mut power = vec![0.0_f32; n_bins];
    for frame_idx in 0..n_frames {
        let start = frame_idx * hop_len;
        let end = (start + frame_len).min(samples.len());
        windowed.fill(0.0);
        for ((dst, &s), &w) in windowed.iter_mut().zip(&samples[start..end]).zip(&hann) {
            *dst = s * w;
        }
        for (k, p) in power.iter_mut().enumerate() {
            let (mut re, mut im) = (0.0_f32, 0.0_f32);
            for (n, &x) in windowed.iter().enumerate() {
                let j = (k * n) % fft_size;
                re += x * cos_table[j];
                im += x * sin_table[j];
            }
            *p = re * re + im * im;
        }
        let row = &mut spectrogram[frame_idx * mel_bins..(frame_idx + 1) * mel_bins];
        for (dst, weights) in row.iter_mut().zip(&filterbank) {
            let energy: f32 = weights.iter().zip(&power).map(|(&w, &p)| w * p).sum();
            *dst = energy.max(1e-10).ln();
        }
    }
    Ok(spectrogram)
}
/// Build `mel_bins` triangular filters over the `fft_size / 2 + 1` bins of a
/// power spectrum. The filters are spaced evenly on the HTK mel scale
/// (`2595 · log10(1 + f / 700)`) from 0 Hz to Nyquist.
///
/// Filter `m` rises linearly from edge `m` to a peak of 1.0 at edge `m + 1`,
/// then falls to 0 at edge `m + 2`. Weights are evaluated at each bin's
/// frequency (`k · sample_rate / fft_size`).
///
/// Returns `bank[mel][bin]`.
#[cfg(feature = "caption-gen")]
fn mel_filterbank(mel_bins: usize, fft_size: usize, sample_rate: u32) -> Vec<Vec<f32>> {
    let n_bins = fft_size / 2 + 1;
    let hz_to_mel = |hz: f32| -> f32 { 2595.0 * (1.0 + hz / 700.0).log10() };
    let mel_to_hz = |mel: f32| -> f32 { 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0) };
    let f_min_mel = hz_to_mel(0.0);
    let f_max_mel = hz_to_mel(sample_rate as f32 / 2.0);
    // mel_bins + 2 edge frequencies: each filter uses three consecutive edges.
    let edges_hz: Vec<f32> = (0..=mel_bins + 1)
        .map(|i| mel_to_hz(f_min_mel + (f_max_mel - f_min_mel) * i as f32 / (mel_bins + 1) as f32))
        .collect();
    let bin_hz = sample_rate as f32 / fft_size as f32;

    // Edges stay exact frequencies instead of being floored to bin indices.
    // Flooring made neighbouring low-frequency edges coincide and left those
    // filters with all-zero weights.
    edges_hz
        .windows(3)
        .map(|edge| {
            let (lo, mid, hi) = (edge[0], edge[1], edge[2]);
            (0..n_bins)
                .map(|k| {
                    let f = k as f32 * bin_hz;
                    let rising = (f - lo) / (mid - lo);
                    let falling = (hi - f) / (hi - mid);
                    // `f32::min` ignores a NaN operand (degenerate 0/0 edges),
                    // and `max(0.0)` clears both NaN and the out-of-triangle side.
                    rising.min(falling).max(0.0)
                })
                .collect()
        })
        .collect()
}
/// Turn decoded token IDs into text by looking each ID up in `vocab` and
/// joining the pieces with single spaces.
///
/// Vocabulary keys are the decimal form of the ID. IDs without an entry are
/// skipped.
#[cfg(feature = "caption-gen")]
fn tokens_to_text(token_ids: &[u32], vocab: &std::collections::HashMap<String, String>) -> String {
    let mut pieces: Vec<&str> = Vec::with_capacity(token_ids.len());
    for id in token_ids {
        if let Some(piece) = vocab.get(&id.to_string()) {
            pieces.push(piece.as_str());
        }
    }
    pieces.join(" ")
}
/// Interpret an encoder output shape as `(sequence_length, vocab_size)`.
///
/// Accepted shapes:
/// - `[seq, vocab]`.
/// - `[batch, seq, vocab]`, where the batch dimension is ignored.
/// - `[vocab]`, treated as a single step.
///
/// Any other rank yields `None`.
#[cfg(feature = "caption-gen")]
fn derive_seq_vocab(shape: &[usize]) -> Option<(usize, usize)> {
    match *shape {
        [vocab] => Some((1, vocab)),
        [seq, vocab] | [_, seq, vocab] => Some((seq, vocab)),
        _ => None,
    }
}
/// Spread a plain transcript over `total_ms` as consecutive timed segments.
///
/// The text is cut into pieces of at most 42 characters. Each non-final piece
/// gets time in proportion to its character count, capped at 5 s and kept
/// below `total_ms`. The final piece always ends exactly at `total_ms`.
/// Empty or whitespace-only text yields no segments.
#[cfg(feature = "caption-gen")]
fn build_segments_from_text(
    text: &str,
    total_ms: u64,
) -> Vec<oximedia_caption_gen::alignment::TranscriptSegment> {
    use oximedia_caption_gen::alignment::TranscriptSegment;
    const MAX_CHARS: usize = 42;
    const MAX_MS: u64 = 5_000;

    let trimmed = text.trim();
    if trimmed.is_empty() {
        return Vec::new();
    }
    let pieces = chunk_text(trimmed, MAX_CHARS);
    if pieces.is_empty() {
        return Vec::new();
    }

    // Every piece weighs at least one character so none gets zero time.
    let weight = |piece: &String| piece.chars().count().max(1);
    let total_weight: usize = pieces.iter().map(weight).sum();
    let last_idx = pieces.len() - 1;

    let mut out = Vec::with_capacity(pieces.len());
    let mut cursor = 0_u64;
    for (idx, piece) in pieces.iter().enumerate() {
        let proportional = if total_weight == 0 {
            total_ms / pieces.len() as u64
        } else {
            (total_ms as f64 * weight(piece) as f64 / total_weight as f64).round() as u64
        };
        let end_ms = if idx == last_idx {
            total_ms
        } else {
            (cursor + proportional.min(MAX_MS)).min(total_ms.saturating_sub(1))
        };
        out.push(TranscriptSegment {
            text: piece.clone(),
            start_ms: cursor,
            end_ms,
            speaker_id: None,
            words: Vec::new(),
        });
        cursor = end_ms;
    }
    out
}
/// Split `text` into trimmed pieces of at most `max_chars` characters.
///
/// Break points come from `find_break_point`, which prefers sentence ends and
/// then word boundaries. When it finds none, the text is hard-cut at
/// `max_chars`. Text that already fits is returned unchanged as one piece. A
/// `max_chars` of 0 is treated as 1, since a zero limit could never make
/// progress.
#[cfg(feature = "caption-gen")]
fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
    let max_chars = max_chars.max(1);
    if text.chars().count() <= max_chars {
        return vec![text.to_string()];
    }
    let mut chunks = Vec::new();
    let mut remaining = text.trim();
    while !remaining.is_empty() {
        // `nth` walks at most max_chars + 1 chars, instead of counting the
        // whole remainder on every iteration.
        if remaining.chars().nth(max_chars).is_none() {
            chunks.push(remaining.to_string());
            break;
        }
        // One character of lookahead lets a break land exactly on the limit.
        let window: String = remaining.chars().take(max_chars + 1).collect();
        // A zero cut would produce an empty chunk and never advance.
        let cut = find_break_point(&window, max_chars)
            .filter(|&c| c > 0)
            .unwrap_or(max_chars);
        let byte_pos = remaining
            .char_indices()
            .nth(cut)
            .map_or(remaining.len(), |(b, _)| b);
        chunks.push(remaining[..byte_pos].trim_end().to_string());
        remaining = remaining[byte_pos..].trim_start();
    }
    chunks
}
/// Choose where to split `text` so that the first piece has at most
/// `max_chars` characters. The result is a character index.
///
/// Preference order:
/// 1. Just after the last `.`, `!` or `?` within the first `max_chars` chars.
/// 2. Exactly `max_chars`, when the character right after the limit is a
///    space. Callers supply that one character of lookahead.
/// 3. The last space within the first `max_chars` chars, excluding index 0,
///    which would give an empty piece.
///
/// Returns `None` when none apply, leaving the caller to hard-cut.
#[cfg(feature = "caption-gen")]
fn find_break_point(text: &str, max_chars: usize) -> Option<usize> {
    let chars: Vec<char> = text.chars().take(max_chars.saturating_add(1)).collect();
    let head = &chars[..chars.len().min(max_chars)];
    if let Some(i) = head.iter().rposition(|&ch| matches!(ch, '.' | '!' | '?')) {
        return Some(i + 1);
    }
    if max_chars > 0 && chars.get(max_chars) == Some(&' ') {
        return Some(max_chars);
    }
    head.iter().rposition(|&ch| ch == ' ').filter(|&i| i > 0)
}
pub async fn run_captions_sync(opts: CaptionsSyncOptions, json_output: bool) -> Result<()> {
let caption_data = std::fs::read(&opts.input)
.with_context(|| format!("Failed to read captions: {}", opts.input.display()))?;
let _ref_data = std::fs::read(&opts.reference)
.with_context(|| format!("Failed to read reference: {}", opts.reference.display()))?;
let track = oximedia_captions::import::Importer::import_auto(&caption_data)
.map_err(|e| anyhow::anyhow!("Failed to parse captions: {e}"))?;
let caption_count = track.count();
let out_format =
oximedia_captions::export::Exporter::detect_format_from_extension(&opts.output)
.unwrap_or(oximedia_captions::CaptionFormat::Srt);
let output_bytes = oximedia_captions::export::Exporter::export(&track, out_format)
.map_err(|e| anyhow::anyhow!("Export failed: {e}"))?;
std::fs::write(&opts.output, &output_bytes)
.with_context(|| format!("Failed to write output: {}", opts.output.display()))?;
if json_output {
let obj = serde_json::json!({
"input": opts.input.to_string_lossy(),
"reference": opts.reference.to_string_lossy(),
"output": opts.output.to_string_lossy(),
"max_shift_ms": opts.max_shift_ms,
"captions_synced": caption_count,
});
println!("{}", serde_json::to_string_pretty(&obj)?);
} else {
println!("{}", "Caption Sync Complete".green().bold());
println!(" Captions: {}", opts.input.display());
println!(" Reference: {}", opts.reference.display());
println!(" Output: {}", opts.output.display());
println!(" Max shift: {}ms", opts.max_shift_ms);
println!(" Synced: {} captions", caption_count);
}
Ok(())
}
/// Convert a caption file from one format to another.
///
/// The source format comes from `opts.from_format`, or is auto-detected from
/// the file contents when absent. The result is written in `opts.to_format`
/// and summarised as JSON or human-readable text.
pub async fn run_captions_convert(opts: CaptionsConvertOptions, json_output: bool) -> Result<()> {
    use oximedia_captions::export::Exporter;
    use oximedia_captions::import::Importer;

    let data = std::fs::read(&opts.input)
        .with_context(|| format!("Failed to read input: {}", opts.input.display()))?;
    let track = match opts.from_format.as_deref() {
        Some(from) => {
            let src_fmt = parse_caption_format(from)?;
            Importer::import(&data, src_fmt).map_err(|e| anyhow::anyhow!("Import failed: {e}"))?
        }
        None => Importer::import_auto(&data)
            .map_err(|e| anyhow::anyhow!("Auto-detect import failed: {e}"))?,
    };
    let target_fmt = parse_caption_format(&opts.to_format)?;
    let encoded = Exporter::export(&track, target_fmt)
        .map_err(|e| anyhow::anyhow!("Export failed: {e}"))?;
    std::fs::write(&opts.output, &encoded)
        .with_context(|| format!("Failed to write output: {}", opts.output.display()))?;

    let source_label = opts.from_format.as_deref().unwrap_or("auto");
    let caption_count = track.count();
    if json_output {
        let summary = serde_json::json!({
            "input": opts.input.to_string_lossy(),
            "output": opts.output.to_string_lossy(),
            "from_format": source_label,
            "to_format": opts.to_format,
            "captions_count": caption_count,
        });
        println!("{}", serde_json::to_string_pretty(&summary)?);
    } else {
        println!("{}", "Caption Conversion Complete".green().bold());
        println!(" Input: {}", opts.input.display());
        println!(" Output: {}", opts.output.display());
        println!(" Format: {} -> {}", source_label, opts.to_format);
        println!(" Captions: {}", caption_count);
    }
    Ok(())
}
pub async fn run_captions_burn(opts: CaptionsBurnOptions, json_output: bool) -> Result<()> {
let _video_data = std::fs::read(&opts.video)
.with_context(|| format!("Failed to read video: {}", opts.video.display()))?;
let caption_data = std::fs::read(&opts.captions)
.with_context(|| format!("Failed to read captions: {}", opts.captions.display()))?;
let track = oximedia_captions::import::Importer::import_auto(&caption_data)
.map_err(|e| anyhow::anyhow!("Failed to parse captions: {e}"))?;
let caption_count = track.count();
std::fs::copy(&opts.video, &opts.output)
.with_context(|| format!("Failed to write output: {}", opts.output.display()))?;
if json_output {
let obj = serde_json::json!({
"video": opts.video.to_string_lossy(),
"captions": opts.captions.to_string_lossy(),
"output": opts.output.to_string_lossy(),
"font_size": opts.font_size,
"font_color": opts.font_color,
"captions_burned": caption_count,
});
println!("{}", serde_json::to_string_pretty(&obj)?);
} else {
println!("{}", "Caption Burn Complete".green().bold());
println!(" Video: {}", opts.video.display());
println!(" Captions: {}", opts.captions.display());
println!(" Output: {}", opts.output.display());
println!(" Font: {}px, #{}", opts.font_size, opts.font_color);
println!(" Burned: {} captions", caption_count);
}
Ok(())
}
pub async fn run_captions_extract(opts: CaptionsExtractOptions, json_output: bool) -> Result<()> {
let _data = std::fs::read(&opts.input)
.with_context(|| format!("Failed to read input: {}", opts.input.display()))?;
let format = parse_caption_format(&opts.format)?;
let language = oximedia_captions::Language::english();
let track = oximedia_captions::CaptionTrack::new(language);
let output_bytes = oximedia_captions::export::Exporter::export(&track, format)
.map_err(|e| anyhow::anyhow!("Export failed: {e}"))?;
std::fs::write(&opts.output, &output_bytes)
.with_context(|| format!("Failed to write output: {}", opts.output.display()))?;
if json_output {
let obj = serde_json::json!({
"input": opts.input.to_string_lossy(),
"output": opts.output.to_string_lossy(),
"format": opts.format,
"track": opts.track,
"captions_extracted": track.count(),
});
println!("{}", serde_json::to_string_pretty(&obj)?);
} else {
println!("{}", "Caption Extraction Complete".green().bold());
println!(" Input: {}", opts.input.display());
println!(" Output: {}", opts.output.display());
println!(" Format: {}", opts.format);
println!(" Track: {}", opts.track);
println!(" Extracted: {} captions", track.count());
}
Ok(())
}
pub async fn run_captions_validate(opts: CaptionsValidateOptions, json_output: bool) -> Result<()> {
let data = std::fs::read(&opts.input)
.with_context(|| format!("Failed to read input: {}", opts.input.display()))?;
let track = oximedia_captions::import::Importer::import_auto(&data)
.map_err(|e| anyhow::anyhow!("Failed to parse captions: {e}"))?;
let validator = oximedia_captions::validation::Validator::new();
let report = validator
.validate(&track)
.map_err(|e| anyhow::anyhow!("Validation failed: {e}"))?;
if let Some(ref report_path) = opts.report {
let report_text = render_validation_report(&report, &opts.input, &opts.standard);
std::fs::write(report_path, &report_text)
.with_context(|| format!("Failed to write report: {}", report_path.display()))?;
}
if json_output {
let issues_json: Vec<serde_json::Value> = report
.issues
.iter()
.map(|issue| {
serde_json::json!({
"severity": format!("{:?}", issue.severity),
"message": issue.message,
"rule": issue.rule,
})
})
.collect();
let obj = serde_json::json!({
"input": opts.input.to_string_lossy(),
"standard": opts.standard,
"passed": report.passed(),
"statistics": {
"total_captions": report.statistics.total_captions,
"total_words": report.statistics.total_words,
"avg_reading_speed": report.statistics.avg_reading_speed,
"max_reading_speed": report.statistics.max_reading_speed,
"avg_chars_per_line": report.statistics.avg_chars_per_line,
"max_chars_per_line": report.statistics.max_chars_per_line,
"errors": report.statistics.error_count,
"warnings": report.statistics.warning_count,
},
"issues": issues_json,
});
println!("{}", serde_json::to_string_pretty(&obj)?);
} else {
let status = if report.passed() {
"PASSED".green().bold().to_string()
} else {
"FAILED".red().bold().to_string()
};
println!("{}", "Caption Validation".green().bold());
println!(" File: {}", opts.input.display());
println!(" Standard: {}", opts.standard);
println!(" Result: {}", status);
println!();
println!(" {}", "Statistics:".cyan().bold());
println!(" Captions: {}", report.statistics.total_captions);
println!(" Words: {}", report.statistics.total_words);
println!(
" Avg WPM: {:.1}",
report.statistics.avg_reading_speed
);
println!(
" Max WPM: {:.1}",
report.statistics.max_reading_speed
);
println!(
" Max chars/line: {}",
report.statistics.max_chars_per_line
);
if !report.issues.is_empty() {
println!();
println!(" {}", "Issues:".yellow().bold());
for issue in &report.issues {
let sev_str = match issue.severity {
oximedia_captions::validation::IssueSeverity::Error => {
"ERROR".red().to_string()
}
oximedia_captions::validation::IssueSeverity::Warning => {
"WARN".yellow().to_string()
}
oximedia_captions::validation::IssueSeverity::Info => {
"INFO".dimmed().to_string()
}
};
println!(
" [{}] {} ({})",
sev_str,
issue.message,
issue.rule.dimmed()
);
}
}
if let Some(ref rp) = opts.report {
println!("\n Report saved: {}", rp.display());
}
}
Ok(())
}
fn render_validation_report(
report: &oximedia_captions::validation::ValidationReport,
input: &PathBuf,
standard: &str,
) -> String {
let mut buf = String::new();
buf.push_str("Caption Validation Report\n");
buf.push_str(&format!("File: {}\n", input.display()));
buf.push_str(&format!("Standard: {}\n", standard));
buf.push_str(&format!("Passed: {}\n\n", report.passed()));
buf.push_str(&format!("Captions: {}\n", report.statistics.total_captions));
buf.push_str(&format!("Words: {}\n", report.statistics.total_words));
buf.push_str(&format!(
"Avg reading speed: {:.1} WPM\n",
report.statistics.avg_reading_speed
));
buf.push_str(&format!(
"Errors: {}, Warnings: {}\n\n",
report.statistics.error_count, report.statistics.warning_count
));
for issue in &report.issues {
buf.push_str(&format!(
"[{:?}] {} (rule: {})\n",
issue.severity, issue.message, issue.rule
));
}
buf
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(not(feature = "caption-gen"))]
    #[tokio::test]
    async fn test_run_captions_generate_no_feature_error() {
        let tmp = std::env::temp_dir();
        let opts = CaptionsGenerateOptions {
            input: tmp.join("nonexistent_input.wav"),
            output: tmp.join("out.srt"),
            format: "srt".to_string(),
            language: "en".to_string(),
            model: None,
            vocab: None,
        };
        let err = run_captions_generate(opts, false)
            .await
            .expect_err("must fail without caption-gen feature");
        let msg = err.to_string();
        assert!(
            msg.contains("caption-gen"),
            "error message should mention 'caption-gen', got: {msg}"
        );
    }

    #[cfg(feature = "caption-gen")]
    #[tokio::test]
    async fn test_run_captions_generate_no_model_error() {
        let tmp = std::env::temp_dir();
        let wav_path = tmp.join("oximedia_cli_test_captions_gen_no_model.wav");
        // Best effort: the missing-model check runs before the input is read,
        // so the assertion does not depend on this write succeeding.
        let _ = std::fs::write(&wav_path, minimal_wav_bytes());
        let opts = CaptionsGenerateOptions {
            input: wav_path.clone(),
            output: tmp.join("out_no_model.srt"),
            format: "srt".to_string(),
            language: "en".to_string(),
            model: None,
            vocab: None,
        };
        let result = run_captions_generate(opts, false).await;
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&wav_path);
        let err = result.expect_err("must fail when no model is provided");
        let msg = err.to_string();
        assert!(
            msg.contains("model"),
            "error message should mention 'model', got: {msg}"
        );
    }

    /// A 46-byte RIFF/WAVE file: mono, 16 kHz, 16-bit PCM, one silent sample.
    #[cfg(feature = "caption-gen")]
    fn minimal_wav_bytes() -> Vec<u8> {
        let sample_rate: u32 = 16_000;
        let block_align: u16 = 2; // 1 channel * 2 bytes per 16-bit sample
        let byte_rate: u32 = sample_rate * u32::from(block_align);
        let data_size: u32 = 2;
        let file_size: u32 = 36 + data_size;
        let mut v = Vec::with_capacity(46);
        v.extend_from_slice(b"RIFF");
        v.extend_from_slice(&file_size.to_le_bytes());
        v.extend_from_slice(b"WAVE");
        v.extend_from_slice(b"fmt ");
        v.extend_from_slice(&16_u32.to_le_bytes()); // fmt chunk size
        v.extend_from_slice(&1_u16.to_le_bytes()); // audio format: PCM
        v.extend_from_slice(&1_u16.to_le_bytes()); // channels
        v.extend_from_slice(&sample_rate.to_le_bytes());
        v.extend_from_slice(&byte_rate.to_le_bytes());
        v.extend_from_slice(&block_align.to_le_bytes());
        v.extend_from_slice(&16_u16.to_le_bytes()); // bits per sample
        v.extend_from_slice(b"data");
        v.extend_from_slice(&data_size.to_le_bytes());
        v.extend_from_slice(&[0x00, 0x00]); // one silent sample
        v
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_parse_minimal_wav() {
        let samples = parse_wav_to_mono_f32(&minimal_wav_bytes()).expect("valid WAV");
        assert_eq!(samples, vec![0.0_f32]);
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_decode_pcm_stereo_downmix_and_errors() {
        let mono = decode_pcm(&[0, 0, 0, 0], 16, 2).expect("one stereo frame");
        assert_eq!(mono, vec![0.0_f32]);
        assert!(decode_pcm(&[0, 0], 16, 0).is_err());
        assert!(decode_pcm(&[0, 0], 12, 1).is_err());
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_derive_seq_vocab() {
        assert_eq!(derive_seq_vocab(&[10, 50]), Some((10, 50)));
        assert_eq!(derive_seq_vocab(&[1, 10, 50]), Some((10, 50)));
        assert_eq!(derive_seq_vocab(&[7]), Some((1, 7)));
        assert_eq!(derive_seq_vocab(&[]), None);
        assert_eq!(derive_seq_vocab(&[1, 2, 3, 4]), None);
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_tokens_to_text_skips_unknown_ids() {
        let mut vocab = std::collections::HashMap::new();
        vocab.insert("1".to_string(), "hello".to_string());
        vocab.insert("2".to_string(), "world".to_string());
        assert_eq!(tokens_to_text(&[1, 3, 2], &vocab), "hello world");
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_chunk_text_respects_limit() {
        assert_eq!(chunk_text("short", 42), vec!["short"]);
        assert_eq!(chunk_text("aaa bbb ccc", 5), vec!["aaa", "bbb", "ccc"]);
        let text = "The quick brown fox jumps over the lazy dog. It was not amused at all.";
        for chunk in chunk_text(text, 20) {
            assert!(!chunk.is_empty());
            assert!(chunk.chars().count() <= 20, "chunk too long: {chunk:?}");
        }
    }

    #[cfg(feature = "caption-gen")]
    #[test]
    fn test_find_break_point_prefers_sentence_end() {
        assert_eq!(find_break_point("Hi. there", 9), Some(3));
    }

    #[test]
    fn test_parse_caption_format_srt() {
        let fmt = parse_caption_format("srt");
        assert!(fmt.is_ok());
        assert_eq!(
            fmt.expect("should parse srt"),
            oximedia_captions::CaptionFormat::Srt
        );
    }

    #[test]
    fn test_parse_caption_format_webvtt() {
        let fmt = parse_caption_format("webvtt");
        assert!(fmt.is_ok());
        assert_eq!(
            fmt.expect("should parse webvtt"),
            oximedia_captions::CaptionFormat::WebVtt
        );
    }

    #[test]
    fn test_parse_caption_format_aliases() {
        assert_eq!(
            parse_caption_format("vtt").expect("vtt alias"),
            oximedia_captions::CaptionFormat::WebVtt
        );
        assert_eq!(
            parse_caption_format("ebu-stl").expect("ebu-stl alias"),
            oximedia_captions::CaptionFormat::EbuStl
        );
        assert_eq!(
            parse_caption_format("cea-608").expect("cea-608 alias"),
            oximedia_captions::CaptionFormat::Cea608
        );
    }

    #[test]
    fn test_parse_caption_format_unknown() {
        let fmt = parse_caption_format("xyz123");
        assert!(fmt.is_err());
    }

    #[test]
    fn test_parse_caption_format_case_insensitive() {
        let fmt = parse_caption_format("SRT");
        assert!(fmt.is_ok());
        let fmt2 = parse_caption_format("Ttml");
        assert!(fmt2.is_ok());
    }

    #[test]
    fn test_render_validation_report() {
        let report = oximedia_captions::validation::ValidationReport::new();
        let path = std::env::temp_dir().join("test.srt");
        let text = render_validation_report(&report, &path, "fcc");
        assert!(text.contains("Caption Validation Report"));
        assert!(text.contains("fcc"));
        assert!(text.contains("Passed: true"));
    }
}