#![forbid(unsafe_code)]
/// Codec identifier strings for the subtitle track types known to this module.
pub mod codec_id {
    // --- Text-based subtitle formats -------------------------------------

    /// Plain UTF-8 text.
    pub const UTF8: &str = "S_TEXT/UTF8";
    /// WebVTT.
    pub const WEBVTT: &str = "S_TEXT/WEBVTT";
    /// Advanced SubStation Alpha.
    pub const ASS: &str = "S_TEXT/ASS";
    /// SubStation Alpha.
    pub const SSA: &str = "S_TEXT/SSA";

    // --- Bitmap-based subtitle formats -----------------------------------

    /// HDMV presentation graphics.
    pub const HDMV_PGS: &str = "S_HDMV/PGS";
    /// VobSub.
    pub const VOBSUB: &str = "S_VOBSUB";
}
/// One timed subtitle entry, independent of the source format.
#[derive(Debug, Clone, PartialEq)]
pub struct SubtitleCue {
    /// Start time in milliseconds.
    pub start_ms: i64,
    /// On-screen duration in milliseconds, when known.
    pub duration_ms: Option<u64>,
    /// Cue payload text.
    pub text: String,
    /// Optional cue identifier.
    pub id: Option<String>,
}

impl SubtitleCue {
    /// Creates a cue with a known duration and no identifier.
    #[must_use]
    pub fn new(start_ms: i64, duration_ms: u64, text: impl Into<String>) -> Self {
        let text: String = text.into();
        let duration_ms = Some(duration_ms);
        Self {
            id: None,
            text,
            duration_ms,
            start_ms,
        }
    }

    /// Consumes the cue and returns it with its identifier set to `id`.
    #[must_use]
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct WebVttDocument {
pub header: Option<String>,
pub cues: Vec<SubtitleCue>,
}
impl WebVttDocument {
#[must_use]
pub fn parse(input: &str) -> Self {
let mut doc = Self::default();
let trimmed = input.trim_start_matches('\u{FEFF}');
if !trimmed.starts_with("WEBVTT") {
return doc;
}
let mut lines = trimmed.lines().peekable();
lines.next();
let mut header_lines: Vec<&str> = Vec::new();
while let Some(line) = lines.peek().copied() {
if line.trim().is_empty() {
break;
}
if line.contains("-->") {
break;
}
header_lines.push(line);
lines.next();
}
let header_text = header_lines.join("\n").trim().to_owned();
if !header_text.is_empty() {
doc.header = Some(header_text);
}
while let Some(cue) = parse_vtt_cue(&mut lines) {
doc.cues.push(cue);
}
doc
}
#[must_use]
pub fn to_string(&self) -> String {
let mut out = String::from("WEBVTT\n");
if let Some(ref hdr) = self.header {
out.push('\n');
out.push_str(hdr);
out.push('\n');
}
for cue in &self.cues {
out.push('\n');
if let Some(ref id) = cue.id {
out.push_str(id);
out.push('\n');
}
let dur = cue.duration_ms.unwrap_or(0);
let end_ms = cue.start_ms as u64 + dur;
out.push_str(&format_vtt_time(cue.start_ms as u64));
out.push_str(" --> ");
out.push_str(&format_vtt_time(end_ms));
out.push('\n');
out.push_str(&cue.text);
out.push('\n');
}
out
}
}
#[derive(Debug, Clone, Default)]
pub struct AssDocument {
pub script_info: Vec<(String, String)>,
pub styles: Vec<AssStyle>,
pub events: Vec<SubtitleCue>,
pub styles_raw: String,
}
#[derive(Debug, Clone)]
pub struct AssStyle {
pub name: String,
pub fontname: String,
pub fontsize: u32,
pub primary_colour: String,
pub bold: bool,
pub italic: bool,
}
impl AssDocument {
#[must_use]
pub fn parse(input: &str) -> Self {
let mut doc = Self::default();
let mut section = "";
for line in input.lines() {
let line = line.trim();
if line.starts_with('[') && line.ends_with(']') {
section = line;
continue;
}
match section {
"[Script Info]" => {
if let Some((k, v)) = line.split_once(':') {
doc.script_info
.push((k.trim().to_owned(), v.trim().to_owned()));
}
}
"[V4+ Styles]" | "[V4 Styles]" => {
doc.styles_raw.push_str(line);
doc.styles_raw.push('\n');
if let Some(stripped) = line.strip_prefix("Style: ") {
let parts: Vec<&str> = stripped.splitn(24, ',').collect();
if parts.len() >= 4 {
doc.styles.push(AssStyle {
name: parts[0].to_owned(),
fontname: parts[1].to_owned(),
fontsize: parts[2].parse().unwrap_or(20),
primary_colour: parts[3].to_owned(),
bold: parts.get(7).map_or(false, |&s| s == "-1"),
italic: parts.get(8).map_or(false, |&s| s == "-1"),
});
}
}
}
"[Events]" => {
if let Some(stripped) = line.strip_prefix("Dialogue: ") {
if let Some(cue) = parse_ass_dialogue(stripped) {
doc.events.push(cue);
}
}
}
_ => {}
}
}
doc
}
#[must_use]
pub fn codec_private(&self) -> String {
let mut out = String::from("[Script Info]\n");
for (k, v) in &self.script_info {
out.push_str(k);
out.push_str(": ");
out.push_str(v);
out.push('\n');
}
out.push_str("\n[V4+ Styles]\n");
out.push_str(&self.styles_raw);
out
}
}
/// Text subtitle formats supported by [`SubtitleEncoder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubtitleFormat {
    /// WebVTT text.
    WebVtt,
    /// Advanced SubStation Alpha.
    Ass,
    /// Plain UTF-8 text.
    Utf8,
}

impl SubtitleFormat {
    /// Returns the codec identifier string for this format, taken from the
    /// `codec_id` module constants.
    #[must_use]
    pub fn codec_id(self) -> &'static str {
        match self {
            SubtitleFormat::Utf8 => codec_id::UTF8,
            SubtitleFormat::Ass => codec_id::ASS,
            SubtitleFormat::WebVtt => codec_id::WEBVTT,
        }
    }
}
/// One encoded subtitle frame, as produced by [`SubtitleEncoder`].
#[derive(Debug, Clone)]
pub struct SubtitlePacket {
    /// Presentation timestamp in milliseconds.
    pub timestamp_ms: i64,
    /// Display duration in milliseconds, if known.
    pub duration_ms: Option<u64>,
    /// Encoded payload bytes.
    pub payload: Vec<u8>,
    /// Keyframe flag for the packet.
    pub is_keyframe: bool,
}
/// Converts [`SubtitleCue`]s into block payloads for one [`SubtitleFormat`].
#[derive(Debug, Clone)]
pub struct SubtitleEncoder {
    format: SubtitleFormat,
}

impl SubtitleEncoder {
    /// Creates an encoder that produces payloads for `format`.
    #[must_use]
    pub fn new(format: SubtitleFormat) -> Self {
        Self { format }
    }

    /// Encodes one cue into a keyframe packet carrying the cue's start time
    /// and duration.
    ///
    /// WebVTT and plain UTF-8 payloads are the cue text as UTF-8 bytes.
    ///
    /// ASS payloads use the Matroska `S_TEXT/ASS` block layout
    /// `ReadOrder,Layer,Style,Name,MarginL,MarginR,MarginV,Effect,Text`. A
    /// cue encoded on its own gets `ReadOrder` 0; use
    /// [`SubtitleEncoder::encode_all`] to get sequential read orders.
    #[must_use]
    pub fn encode_cue(&self, cue: &SubtitleCue) -> SubtitlePacket {
        self.encode_with_read_order(cue, 0)
    }

    /// Encodes every cue in order. For ASS, each cue's `ReadOrder` is its
    /// index in `cues`.
    #[must_use]
    pub fn encode_all(&self, cues: &[SubtitleCue]) -> Vec<SubtitlePacket> {
        cues.iter()
            .enumerate()
            .map(|(read_order, cue)| self.encode_with_read_order(cue, read_order))
            .collect()
    }

    /// Returns the codec identifier of the configured format.
    #[must_use]
    pub fn codec_id(&self) -> &'static str {
        self.format.codec_id()
    }

    /// Shared implementation of `encode_cue` and `encode_all`.
    fn encode_with_read_order(&self, cue: &SubtitleCue, read_order: usize) -> SubtitlePacket {
        let payload = match self.format {
            SubtitleFormat::WebVtt | SubtitleFormat::Utf8 => cue.text.as_bytes().to_vec(),
            SubtitleFormat::Ass => Self::ass_block(cue, read_order).into_bytes(),
        };
        SubtitlePacket {
            timestamp_ms: cue.start_ms,
            duration_ms: cue.duration_ms,
            payload,
            is_keyframe: true,
        }
    }

    /// Formats `cue` as an `S_TEXT/ASS` block body.
    ///
    /// `SubtitleCue` carries no layer, style, name, margin or effect, so the
    /// block gets:
    /// - layer 0;
    /// - style `Default` (NOTE(review): assumes the script defines a
    ///   `Default` style);
    /// - margins 0;
    /// - empty name and effect.
    ///
    /// Newlines in the text become ASS `\N` hard line breaks, because a
    /// dialogue line cannot contain a raw newline.
    fn ass_block(cue: &SubtitleCue, read_order: usize) -> String {
        let text = cue.text.replace("\r\n", "\n").replace('\n', "\\N");
        format!("{read_order},0,Default,,0,0,0,,{text}")
    }
}
/// Reads the next cue from `lines`, skipping anything that is not a usable cue.
///
/// Blank lines before a block are skipped. A block is a cue when either:
/// - its first line contains `-->` (the timing line), or
/// - its second line does, in which case the first line is the cue id.
///
/// The cue text is the following non-blank lines, joined with `\n`.
///
/// Some blocks are consumed up to the next blank line and skipped: those
/// with another shape (e.g. `NOTE`, `STYLE` or `REGION` blocks) and those
/// whose timing line does not parse. Cues with no text lines are also
/// skipped. This way a single such block does not end parsing of the rest
/// of the file.
///
/// Returns `None` only once `lines` is exhausted.
fn parse_vtt_cue<'a>(
    lines: &mut std::iter::Peekable<impl Iterator<Item = &'a str>>,
) -> Option<SubtitleCue> {
    loop {
        // Skip blank separator lines.
        while lines.next_if(|l| l.trim().is_empty()).is_some() {}
        let first = lines.next()?;

        let (cue_id, timing_line) = if first.contains("-->") {
            (None, first)
        } else if lines.peek().map_or(false, |l| l.contains("-->")) {
            (Some(first.to_owned()), lines.next()?)
        } else {
            // Not a cue: discard the rest of this block.
            while lines.next_if(|l| !l.trim().is_empty()).is_some() {}
            continue;
        };

        let (start_ms, duration_ms) = match parse_vtt_timing(timing_line) {
            Some(timing) => timing,
            None => {
                while lines.next_if(|l| !l.trim().is_empty()).is_some() {}
                continue;
            }
        };

        let mut text_lines: Vec<&str> = Vec::new();
        while let Some(line) = lines.next_if(|l| !l.trim().is_empty()) {
            text_lines.push(line);
        }
        if text_lines.is_empty() {
            continue;
        }

        return Some(SubtitleCue {
            start_ms,
            duration_ms: Some(duration_ms),
            text: text_lines.join("\n"),
            id: cue_id,
        });
    }
}
/// Parses a WebVTT timing line (`start --> end [settings]`).
///
/// Returns `(start_ms, duration_ms)`. An end before the start yields a zero
/// duration. Anything after the end timestamp (cue settings) is ignored.
///
/// Returns `None` when:
/// - there is no `-->`;
/// - either timestamp is invalid;
/// - the start does not fit in an `i64`.
fn parse_vtt_timing(line: &str) -> Option<(i64, u64)> {
    let mut halves = line.split("-->");
    let start_part = halves.next()?;
    let end_part = halves.next()?;
    let start = parse_vtt_timestamp(start_part.trim())?;
    let end = parse_vtt_timestamp(end_part.split_whitespace().next().unwrap_or(""))?;
    // Checked conversion: an `as` cast would silently wrap huge values negative.
    let start_ms = i64::try_from(start).ok()?;
    Some((start_ms, end.saturating_sub(start)))
}
/// Parses a WebVTT timestamp (`hh:mm:ss.ttt` or `mm:ss.ttt`) into milliseconds.
///
/// Every field must consist only of ASCII digits. The fraction must be
/// exactly three digits, as WebVTT requires; previously `00:01.5` was
/// misread as 1005 ms. The hour field may have any number of digits.
/// Minutes and seconds are not range-checked.
///
/// Returns `None` on malformed input, or when the result would overflow `u64`
/// (instead of panicking or wrapping).
fn parse_vtt_timestamp(s: &str) -> Option<u64> {
    /// Parses a non-empty, ASCII-digits-only field (rejects `+`, spaces, etc.).
    fn digits(field: &str) -> Option<u64> {
        if field.is_empty() || !field.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        field.parse().ok()
    }

    let fields: Vec<&str> = s.splitn(3, ':').collect();
    let (hours, minutes, sec_frac) = match fields.as_slice() {
        [m, rest] => (0, digits(m)?, *rest),
        [h, m, rest] => (digits(h)?, digits(m)?, *rest),
        _ => return None,
    };
    let (sec_str, ms_str) = sec_frac.split_once('.')?;
    if ms_str.len() != 3 {
        return None;
    }
    let secs = digits(sec_str)?;
    let millis = digits(ms_str)?;

    hours
        .checked_mul(3_600_000)?
        .checked_add(minutes.checked_mul(60_000)?)?
        .checked_add(secs.checked_mul(1_000)?)?
        .checked_add(millis)
}
/// Formats a millisecond count as a WebVTT timestamp `HH:MM:SS.mmm`.
///
/// Hours are zero-padded to at least two digits and never wrap, so times of
/// 100 hours or more produce a wider hour field.
fn format_vtt_time(ms: u64) -> String {
    let (total_secs, millis) = (ms / 1_000, ms % 1_000);
    let (total_mins, secs) = (total_secs / 60, total_secs % 60);
    let (hours, mins) = (total_mins / 60, total_mins % 60);
    format!("{hours:02}:{mins:02}:{secs:02}.{millis:03}")
}
/// Converts the body of an ASS `Dialogue:` line (the part after the prefix)
/// into a cue.
///
/// The line is split on at most nine commas, so the last field (the text)
/// may itself contain commas. Only fields 1 (start), 2 (end) and 9 (text)
/// are kept; this assumes the usual V4+ event field order. Verify if `Format:`
/// lines with a custom order must be supported.
///
/// Returns `None` when:
/// - there are fewer than ten fields;
/// - a time is invalid;
/// - the start does not fit in an `i64`.
fn parse_ass_dialogue(line: &str) -> Option<SubtitleCue> {
    let fields: Vec<&str> = line.splitn(10, ',').collect();
    let (start_field, end_field, text) = match fields.as_slice() {
        [_, start, end, _, _, _, _, _, _, text] => (*start, *end, *text),
        _ => return None,
    };
    let start = parse_ass_time(start_field.trim())?;
    let end = parse_ass_time(end_field.trim())?;
    Some(SubtitleCue {
        // Checked conversion: an `as` cast would silently wrap huge values negative.
        start_ms: i64::try_from(start).ok()?,
        duration_ms: Some(end.saturating_sub(start)),
        text: text.to_owned(),
        id: None,
    })
}
/// Parses an ASS timestamp `H:MM:SS.cc` into milliseconds.
///
/// The part after the `.` is read as an integer count of centiseconds, so
/// `1.5` is 50 ms; ASS files normally write exactly two digits there.
///
/// Returns `None` when:
/// - there are not exactly three `:`-separated fields;
/// - there is no `.`;
/// - a field is not a number;
/// - the result would overflow `u64` (previously this panicked in debug
///   builds and wrapped in release).
fn parse_ass_time(s: &str) -> Option<u64> {
    let mut fields = s.splitn(3, ':');
    let (h_str, m_str, rest) = (fields.next()?, fields.next()?, fields.next()?);
    let hours: u64 = h_str.parse().ok()?;
    let minutes: u64 = m_str.parse().ok()?;
    let (sec_str, cs_str) = rest.split_once('.')?;
    let secs: u64 = sec_str.parse().ok()?;
    let centis: u64 = cs_str.parse().ok()?;

    hours
        .checked_mul(3_600_000)?
        .checked_add(minutes.checked_mul(60_000)?)?
        .checked_add(secs.checked_mul(1_000)?)?
        .checked_add(centis.checked_mul(10)?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `cues` holds exactly one cue and returns it.
    fn only_cue(cues: &[SubtitleCue]) -> &SubtitleCue {
        assert_eq!(cues.len(), 1);
        &cues[0]
    }

    #[test]
    fn test_webvtt_parse_simple() {
        let parsed =
            WebVttDocument::parse("WEBVTT\n\n1\n00:00:01.000 --> 00:00:03.000\nHello world\n");
        let cue = only_cue(&parsed.cues);
        assert_eq!(cue.start_ms, 1000);
        assert_eq!(cue.duration_ms, Some(2000));
        assert_eq!(cue.text, "Hello world");
        assert_eq!(cue.id.as_deref(), Some("1"));
    }

    #[test]
    fn test_webvtt_parse_no_id() {
        let parsed = WebVttDocument::parse("WEBVTT\n\n00:00:02.500 --> 00:00:05.000\nSubtitle line\n");
        let cue = only_cue(&parsed.cues);
        assert_eq!(cue.start_ms, 2500);
        assert!(cue.id.is_none());
    }

    #[test]
    fn test_webvtt_roundtrip() {
        let source = "WEBVTT\n\n1\n00:00:01.000 --> 00:00:03.000\nHello\n";
        let written = WebVttDocument::parse(source).to_string();
        let again = WebVttDocument::parse(&written);
        assert_eq!(only_cue(&again.cues).text, "Hello");
    }

    #[test]
    fn test_ass_parse_dialogue() {
        let script = "[Script Info]\nTitle: Test\n\n[V4+ Styles]\n\n[Events]\nDialogue: 0,0:00:01.00,0:00:03.00,Default,,0,0,0,,Hello ASS\n";
        let parsed = AssDocument::parse(script);
        let event = only_cue(&parsed.events);
        assert_eq!(event.start_ms, 1000);
        assert_eq!(event.duration_ms, Some(2000));
        assert!(event.text.contains("Hello ASS"));
    }

    #[test]
    fn test_subtitle_encoder_webvtt() {
        let packet = SubtitleEncoder::new(SubtitleFormat::WebVtt)
            .encode_cue(&SubtitleCue::new(1000, 2000, "Test subtitle"));
        assert_eq!(packet.timestamp_ms, 1000);
        assert_eq!(packet.duration_ms, Some(2000));
        assert_eq!(packet.payload, "Test subtitle".as_bytes());
        assert!(packet.is_keyframe);
    }

    #[test]
    fn test_subtitle_encoder_ass() {
        assert_eq!(SubtitleEncoder::new(SubtitleFormat::Ass).codec_id(), "S_TEXT/ASS");
    }

    #[test]
    fn test_subtitle_encoder_utf8() {
        assert_eq!(SubtitleEncoder::new(SubtitleFormat::Utf8).codec_id(), "S_TEXT/UTF8");
    }

    #[test]
    fn test_encode_all() {
        let input = [
            SubtitleCue::new(0, 1000, "First"),
            SubtitleCue::new(2000, 1500, "Second"),
        ];
        let packets = SubtitleEncoder::new(SubtitleFormat::WebVtt).encode_all(&input);
        assert_eq!(packets.len(), 2);
        let stamps: Vec<i64> = packets.iter().map(|p| p.timestamp_ms).collect();
        assert_eq!(stamps, [0, 2000]);
    }

    #[test]
    fn test_subtitle_cue_with_id() {
        let tagged = SubtitleCue::new(500, 1000, "Line").with_id("42");
        assert_eq!(tagged.id.as_deref(), Some("42"));
    }

    #[test]
    fn test_format_vtt_time() {
        for (millis, expected) in [(61_500, "00:01:01.500"), (3_661_001, "01:01:01.001")] {
            assert_eq!(format_vtt_time(millis), expected);
        }
    }

    #[test]
    fn test_parse_vtt_timestamp_mm_ss() {
        assert_eq!(parse_vtt_timestamp("01:02.500"), Some(62_500));
    }

    #[test]
    fn test_parse_ass_time() {
        assert_eq!(parse_ass_time("0:01:30.00"), Some(90_000));
    }

    #[test]
    fn test_webvtt_invalid() {
        let parsed = WebVttDocument::parse("NOT WEBVTT\nsome garbage");
        assert!(parsed.cues.is_empty());
    }

    #[test]
    fn test_subtitle_format_codec_ids() {
        let table = [
            (SubtitleFormat::WebVtt, "S_TEXT/WEBVTT"),
            (SubtitleFormat::Ass, "S_TEXT/ASS"),
            (SubtitleFormat::Utf8, "S_TEXT/UTF8"),
        ];
        for (format, expected) in table {
            assert_eq!(format.codec_id(), expected);
        }
    }
}