use crate::errors::{Result, TypecastError};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
/// A word of the synthesized text aligned to its timing in the audio.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct AlignmentSegmentWord {
    /// Text of this word.
    pub text: String,
    /// Start offset of the word within the audio — in seconds, judging by the
    /// comparison against `MAX_CAPTION_SECONDS` in `group_into_cues`.
    pub start: f64,
    /// End offset of the word within the audio (same unit as `start`).
    pub end: f64,
}
/// A single character of the synthesized text aligned to its timing in the audio.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct AlignmentSegmentCharacter {
    /// Text of this character (stored as a `String`; may be multi-byte).
    pub text: String,
    /// Start offset of the character within the audio — in seconds, judging by
    /// the comparison against `MAX_CAPTION_SECONDS` in `group_into_cues`.
    pub start: f64,
    /// End offset of the character within the audio (same unit as `start`).
    pub end: f64,
}
/// Request body for a text-to-speech call that also returns alignment
/// timestamps. Optional fields are omitted from the serialized JSON when unset.
#[derive(Debug, Clone, Serialize)]
pub struct TTSRequestWithTimestamps {
    /// Identifier of the voice to synthesize with.
    pub voice_id: String,
    /// The text to synthesize.
    pub text: String,
    /// Model selection (project-defined enum).
    pub model: crate::models::TTSModel,
    /// Optional language hint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Optional prompt payload; free-form JSON whose shape is defined by the
    /// API — NOTE(review): confirm the expected schema against the API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<serde_json::Value>,
    /// Optional output configuration; free-form JSON whose shape is defined by
    /// the API — NOTE(review): confirm the expected schema against the API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<serde_json::Value>,
    /// Optional seed, presumably for reproducible synthesis — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
}
impl TTSRequestWithTimestamps {
    /// Build a request with the required fields; every optional field starts unset.
    pub fn new(
        voice_id: impl Into<String>,
        text: impl Into<String>,
        model: crate::models::TTSModel,
    ) -> Self {
        Self {
            voice_id: voice_id.into(),
            text: text.into(),
            model,
            language: None,
            prompt: None,
            output: None,
            seed: None,
        }
    }

    /// Builder-style setter for the language hint.
    pub fn language(self, language: impl Into<String>) -> Self {
        Self {
            language: Some(language.into()),
            ..self
        }
    }

    /// Builder-style setter for the prompt payload.
    pub fn prompt(self, prompt: serde_json::Value) -> Self {
        Self {
            prompt: Some(prompt),
            ..self
        }
    }

    /// Builder-style setter for the output configuration.
    pub fn output(self, output: serde_json::Value) -> Self {
        Self {
            output: Some(output),
            ..self
        }
    }

    /// Builder-style setter for the seed.
    pub fn seed(self, seed: u32) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }
}
/// Response from a text-to-speech call that includes alignment timestamps.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TTSWithTimestampsResponse {
    /// Audio payload, base64-encoded (decoded by `audio_bytes`).
    pub audio: String,
    /// Format label of the audio payload (e.g. a container/codec name —
    /// NOTE(review): exact values not visible here; confirm against API docs).
    pub audio_format: String,
    /// Duration of the audio; presumably seconds, matching the segment
    /// offsets — TODO confirm.
    pub audio_duration: f64,
    /// Word-level alignment, when the API provides it.
    pub words: Option<Vec<AlignmentSegmentWord>>,
    /// Character-level alignment, when the API provides it.
    pub characters: Option<Vec<AlignmentSegmentCharacter>>,
}
impl TTSWithTimestampsResponse {
    /// Decode the base64 `audio` field into raw audio bytes.
    ///
    /// Returns `TypecastError::DecodeError` when the payload is not valid base64.
    pub fn audio_bytes(&self) -> Result<Vec<u8>> {
        match B64.decode(self.audio.as_bytes()) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(TypecastError::DecodeError(err.to_string())),
        }
    }

    /// Decode the audio payload and write it to `path`.
    ///
    /// Returns `TypecastError::IoError` when the write fails, or a decode
    /// error propagated from [`Self::audio_bytes`].
    pub fn save_audio<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let audio = self.audio_bytes()?;
        match fs::write(path, audio) {
            Ok(()) => Ok(()),
            Err(err) => Err(TypecastError::IoError(err.to_string())),
        }
    }

    /// Render the alignment data as SRT captions.
    pub fn to_srt(&self) -> Result<String> {
        format_captions(self, true)
    }

    /// Render the alignment data as WebVTT captions.
    pub fn to_vtt(&self) -> Result<String> {
        format_captions(self, false)
    }
}
// Limits used when grouping alignment segments into caption cues: a cue is
// flushed before it would exceed either bound.
const MAX_CAPTION_SECONDS: f64 = 7.0;
// Maximum rendered characters per cue (42 is a common subtitling guideline).
const MAX_CAPTION_CHARS: usize = 42;
// Punctuation that ends a sentence; the escapes are the full-width CJK
// terminators '。', '？', '！'.
const SENTENCE_TERMINATORS: &[&str] = &[".", "?", "!", "\u{3002}", "\u{ff1f}", "\u{ff01}"];
/// Internal, source-agnostic alignment unit (built from either a word or a
/// character segment) used while grouping captions.
struct Segment {
    text: String,
    start: f64,
    end: f64,
}
/// One rendered caption cue: its display text and its time span.
struct Cue {
    text: String,
    start: f64,
    end: f64,
}
/// Choose the alignment source to caption from and normalize it to `Segment`s.
///
/// Preference order: word alignment with at least two words, then any
/// non-empty character alignment, then a single-word alignment. The returned
/// bool is `true` when the segments are word-level (so they must be joined
/// with spaces).
///
/// Returns `TypecastError::CaptioningError` when no usable alignment exists.
fn pick_segments(resp: &TTSWithTimestampsResponse) -> Result<(Vec<Segment>, bool)> {
    let from_words = |ws: &[crate::timestamps::AlignmentSegmentWord]| -> Vec<Segment> {
        ws.iter()
            .map(|w| Segment {
                text: w.text.clone(),
                start: w.start,
                end: w.end,
            })
            .collect()
    };
    let from_chars = |cs: &[crate::timestamps::AlignmentSegmentCharacter]| -> Vec<Segment> {
        cs.iter()
            .map(|c| Segment {
                text: c.text.clone(),
                start: c.start,
                end: c.end,
            })
            .collect()
    };

    let words = resp.words.as_deref();
    if let Some(ws) = words.filter(|w| w.len() >= 2) {
        return Ok((from_words(ws), true));
    }
    if let Some(cs) = resp.characters.as_deref().filter(|c| !c.is_empty()) {
        return Ok((from_chars(cs), false));
    }
    if let Some(ws) = words.filter(|w| w.len() == 1) {
        return Ok((from_words(ws), true));
    }
    Err(TypecastError::CaptioningError(
        "no alignment segments to caption from".into(),
    ))
}
/// Concatenate caption fragments into display text.
///
/// Word-level fragments are separated by single spaces; character-level
/// fragments are concatenated directly. Leading/trailing whitespace is
/// stripped from the result.
fn join_parts(parts: &[String], word_mode: bool) -> String {
    let joined = if word_mode {
        parts.join(" ")
    } else {
        parts.concat()
    };
    joined.trim().to_owned()
}
/// Whether `text` (ignoring trailing whitespace) ends with a sentence
/// terminator from `SENTENCE_TERMINATORS`.
fn ends_in_sentence(text: &str) -> bool {
    let tail = text.trim_end();
    for term in SENTENCE_TERMINATORS {
        if tail.ends_with(term) {
            return true;
        }
    }
    false
}
/// Group alignment segments into caption cues.
///
/// Segments accumulate into the current cue until either (a) adding the next
/// segment would push the cue past `MAX_CAPTION_SECONDS` or
/// `MAX_CAPTION_CHARS`, in which case the pending cue is flushed first, or
/// (b) a segment ends a sentence, in which case the cue is flushed including
/// that segment. Cues whose joined text is empty are dropped.
fn group_into_cues(segs: &[Segment], word_mode: bool) -> Vec<Cue> {
    let mut cues: Vec<Cue> = Vec::new();
    let mut buffer: Vec<String> = Vec::new();
    let mut cue_start = 0.0_f64;
    let mut prev_end = 0.0_f64;

    // Emit the buffered fragments as one cue (unless empty) and reset the buffer.
    let flush = |cues: &mut Vec<Cue>, buffer: &mut Vec<String>, start: f64, end: f64| {
        let text = join_parts(buffer, word_mode);
        if !text.is_empty() {
            cues.push(Cue { text, start, end });
        }
        buffer.clear();
    };

    for seg in segs {
        if !buffer.is_empty() {
            // Check what the cue would look like with this segment appended.
            let mut candidate = buffer.clone();
            candidate.push(seg.text.clone());
            let merged = join_parts(&candidate, word_mode);
            let over_time = (seg.end - cue_start) > MAX_CAPTION_SECONDS;
            let over_chars = merged.chars().count() > MAX_CAPTION_CHARS;
            if over_time || over_chars {
                flush(&mut cues, &mut buffer, cue_start, prev_end);
            }
        }
        if buffer.is_empty() {
            cue_start = seg.start;
        }
        buffer.push(seg.text.clone());
        prev_end = seg.end;
        if ends_in_sentence(&seg.text) {
            flush(&mut cues, &mut buffer, cue_start, seg.end);
        }
    }
    // Trailing partial cue, if any.
    flush(&mut cues, &mut buffer, cue_start, prev_end);
    cues
}
/// Format a second count as an SRT timestamp (`HH:MM:SS,mmm`).
///
/// The value is rounded to the nearest millisecond before being split into
/// hour/minute/second/millisecond fields.
fn format_srt_time(seconds: f64) -> String {
    let total_ms = (seconds * 1000.0).round() as i64;
    let (rem_secs, ms) = (total_ms / 1000, total_ms % 1000);
    let (rem_mins, ss) = (rem_secs / 60, rem_secs % 60);
    let (hh, mm) = (rem_mins / 60, rem_mins % 60);
    format!("{:02}:{:02}:{:02},{:03}", hh, mm, ss, ms)
}

/// Format a second count as a WebVTT timestamp (`HH:MM:SS.mmm`).
fn format_vtt_time(seconds: f64) -> String {
    // WebVTT is identical to SRT except the millisecond separator is a period.
    format_srt_time(seconds).replacen(',', ".", 1)
}
fn format_captions(resp: &TTSWithTimestampsResponse, srt: bool) -> Result<String> {
let (segs, word_mode) = pick_segments(resp)?;
let cues = group_into_cues(&segs, word_mode);
if cues.is_empty() {
return Err(TypecastError::CaptioningError(
"no alignment segments to caption from".into(),
));
}
let mut out = String::new();
if !srt {
out.push_str("WEBVTT\n\n");
}
for (i, cue) in cues.iter().enumerate() {
if srt {
out.push_str(&format!("{}\n", i + 1));
}
let (s, e) = if srt {
(format_srt_time(cue.start), format_srt_time(cue.end))
} else {
(format_vtt_time(cue.start), format_vtt_time(cue.end))
};
out.push_str(&format!("{} --> {}\n", s, e));
out.push_str(&cue.text);
out.push_str("\n\n");
}
Ok(out)
}