use std::io::Write;
use crate::Result;
use crate::segment_encoder::SegmentEncoder;
use crate::segments::Segment;
/// Encoder that serializes segments as WebVTT cues into a [`Write`] sink.
///
/// The `WEBVTT` header is written lazily, immediately before the first cue,
/// so an encoder that never receives a segment produces no output.
pub struct VttEncoder<W>
where
    W: Write,
{
    /// Sink that receives the encoded bytes.
    w: W,
    /// `true` once the `WEBVTT` header has been written.
    started: bool,
    /// `true` once the encoder has been closed; later segment writes are rejected.
    closed: bool,
}
impl<W: Write> VttEncoder<W> {
pub fn new(w: W) -> Self {
Self {
w,
started: false,
closed: false,
}
}
fn start_if_needed(&mut self) -> Result<()> {
if !self.started {
self.w.write_all(b"WEBVTT\n\n")?;
self.started = true;
}
Ok(())
}
}
impl<W: Write> SegmentEncoder for VttEncoder<W> {
    /// Writes one segment as a WebVTT cue and flushes the sink.
    ///
    /// The cue consists of a `start --> end` timing line, the segment text on
    /// the following line, and a terminating blank line. The file header is
    /// emitted first if this is the first cue.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if the encoder has already been closed,
    /// or propagates any error from the underlying writer.
    fn write_segment(&mut self, seg: &Segment) -> Result<()> {
        if self.closed {
            return Err(crate::Error::invalid_input(
                "cannot write segment: encoder is already closed",
            ));
        }
        self.start_if_needed()?;

        let sink = &mut self.w;
        writeln!(
            sink,
            "{} --> {}",
            format_timestamp_vtt(seg.start_seconds),
            format_timestamp_vtt(seg.end_seconds),
        )?;
        writeln!(sink, "{}", seg.text)?;
        // Blank line terminates the cue.
        writeln!(sink)?;
        sink.flush()?;
        Ok(())
    }

    /// Flushes the sink and marks the encoder as closed.
    ///
    /// Calling `close` again after a successful close is a no-op. If the flush
    /// fails the encoder is left open.
    ///
    /// # Errors
    ///
    /// Propagates any flush error from the underlying writer.
    fn close(&mut self) -> Result<()> {
        if !self.closed {
            self.w.flush()?;
            self.closed = true;
        }
        Ok(())
    }
}
/// Formats an offset in seconds as a WebVTT timestamp, `HH:MM:SS.mmm`.
///
/// The offset is rounded to the nearest millisecond (halves round away from
/// zero). Hours are zero-padded to two digits and grow wider when needed.
///
/// Negative and NaN inputs format as `00:00:00.000`, and positive infinity
/// saturates to the largest representable millisecond count, because the
/// float-to-integer `as` cast saturates.
fn format_timestamp_vtt(seconds: f32) -> String {
    // Widen to f64 *before* scaling. In f32 the product `seconds * 1000.0`
    // cannot represent odd integers above 2^24 (about 4 h 39 m of
    // milliseconds), which shifted long-running timestamps by up to 1 ms.
    let total_ms = (f64::from(seconds) * 1000.0).round() as u64;

    let (total_s, ms) = (total_ms / 1000, total_ms % 1000);
    let (total_m, s) = (total_s / 60, total_s % 60);
    let (h, m) = (total_m / 60, total_m % 60);

    format!("{h:02}:{m:02}:{s:02}.{ms:03}")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an English segment with the given timing and text and no tokens.
    fn seg(start: f32, end: f32, text: &str) -> Segment {
        Segment {
            start_seconds: start,
            end_seconds: end,
            text: text.to_owned(),
            tokens: vec![],
            language_code: "en".to_owned(),
            next_speaker_turn: false,
        }
    }

    /// Encodes `cues` with a fresh encoder, closes it, and returns the output as UTF-8 text.
    fn render(cues: &[Segment]) -> anyhow::Result<String> {
        let mut buf = Vec::<u8>::new();
        {
            let mut encoder = VttEncoder::new(&mut buf);
            for cue in cues {
                encoder.write_segment(cue)?;
            }
            encoder.close()?;
        }
        Ok(String::from_utf8(buf)?)
    }

    #[test]
    fn vtt_close_without_segments_emits_nothing() -> anyhow::Result<()> {
        let rendered = render(&[])?;
        assert_eq!(rendered, "");
        Ok(())
    }

    #[test]
    fn vtt_writes_header_once_and_formats_cues() -> anyhow::Result<()> {
        let cues = [seg(0.0, 1.2345, "hello"), seg(61.2, 62.0, "world")];
        let rendered = render(&cues)?;

        assert!(rendered.starts_with("WEBVTT\n\n"));
        assert!(rendered.contains("00:00:00.000 --> 00:00:01.235\nhello\n\n"));
        assert!(rendered.contains("00:01:01.200 --> 00:01:02.000\nworld\n\n"));
        assert_eq!(rendered.matches("WEBVTT\n\n").count(), 1);
        Ok(())
    }

    #[test]
    fn vtt_format_timestamp_rounds_to_nearest_millisecond() {
        let cases = [
            (0.0004_f32, "00:00:00.000"),
            (0.0005, "00:00:00.001"),
            (1.9995, "00:00:02.000"),
        ];
        for (seconds, expected) in cases {
            assert_eq!(format_timestamp_vtt(seconds), expected);
        }
    }

    #[test]
    fn vtt_write_after_close_errors() -> anyhow::Result<()> {
        let mut buf = Vec::<u8>::new();
        let mut encoder = VttEncoder::new(&mut buf);
        encoder.close()?;

        let rejected = encoder.write_segment(&seg(0.0, 1.0, "nope")).unwrap_err();
        assert!(rejected.to_string().contains("already closed"));
        Ok(())
    }
}