use std::{
collections::BTreeMap,
fs::File,
io::{BufWriter, Write},
path::Path,
};
use derive_more::{IsVariant, TryUnwrap, Unwrap};
use serde::{Deserialize, Serialize};
use crate::error::{Error, FileIoPayload, FileOp, Result};
/// A single word with its time span, plus any additional keys carried with it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Word {
/// Start time of the word; formatted as seconds when written as an SRT/VTT cue.
start: f64,
/// End time of the word, in the same unit as `start`.
end: f64,
/// The word's text.
word: String,
/// Any other keys, kept as raw JSON values. `flatten` makes serde read and
/// write them at the same level as the named fields instead of nesting them.
#[serde(flatten)]
extra: BTreeMap<String, serde_json::Value>,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line getters, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl Word {
    /// Builds a word from its time span, its text, and any extra keyed values.
    ///
    /// `word` accepts anything convertible into a `String`. The values are
    /// stored as given; no validation (for example `start <= end`) is done.
    pub fn new(
        start: f64,
        end: f64,
        word: impl Into<String>,
        extra: BTreeMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            start,
            end,
            word: word.into(),
            extra,
        }
    }

    /// Returns the start time of the word.
    #[inline]
    pub fn start(&self) -> f64 {
        self.start
    }

    /// Returns the end time of the word.
    #[inline]
    pub fn end(&self) -> f64 {
        self.end
    }

    /// Returns the word's text.
    #[inline]
    pub fn word(&self) -> &str {
        &self.word
    }

    /// Returns the extra key/value pairs stored alongside the word.
    #[inline]
    pub fn extra(&self) -> &BTreeMap<String, serde_json::Value> {
        &self.extra
    }
}
/// A timed span of transcript text, optionally broken down into words and
/// optionally attributed to a speaker.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Segment {
/// Start time of the segment; formatted as seconds when written as an SRT/VTT cue.
start: f64,
/// End time of the segment, in the same unit as `start`.
end: f64,
/// The segment's text.
text: String,
/// Per-word entries. Left out of the serialized form when empty, and
/// defaults to empty when the key is missing on input.
#[serde(skip_serializing_if = "Vec::is_empty", default)]
words: Vec<Word>,
/// Speaker label. An empty string is left out of the serialized form, and
/// the field defaults to empty when the key is missing on input.
#[serde(skip_serializing_if = "String::is_empty", default)]
speaker_id: String,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line methods, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl Segment {
    /// Builds a segment from its time span, text, words, and speaker label.
    ///
    /// `text` and `speaker_id` accept anything convertible into a `String`.
    /// The values are stored as given; no validation is done.
    #[inline]
    pub fn new(
        start: f64,
        end: f64,
        text: impl Into<String>,
        words: Vec<Word>,
        speaker_id: impl Into<String>,
    ) -> Self {
        Self {
            start,
            end,
            text: text.into(),
            words,
            speaker_id: speaker_id.into(),
        }
    }

    /// Returns the start time of the segment.
    #[inline]
    pub fn start(&self) -> f64 {
        self.start
    }

    /// Returns the end time of the segment.
    #[inline]
    pub fn end(&self) -> f64 {
        self.end
    }

    /// Returns the segment's text.
    #[inline]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the segment's words as a slice (possibly empty).
    #[inline]
    pub fn words_slice(&self) -> &[Word] {
        &self.words
    }

    /// Returns the speaker label (possibly empty).
    #[inline]
    pub fn speaker_id(&self) -> &str {
        &self.speaker_id
    }
}
/// One timed token inside a [`Sentence`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SentenceToken {
/// The token's text.
text: String,
/// Start time of the token. NOTE(review): presumably seconds, like the
/// enclosing sentence - confirm with the producer of this data.
start: f64,
/// End time of the token, in the same unit as `start`.
end: f64,
/// Stored duration of the token. NOTE(review): presumably `end - start`;
/// nothing in this type enforces that.
duration: f64,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line getters, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl SentenceToken {
    /// Builds a token from its text, time span, and duration.
    ///
    /// `text` accepts anything convertible into a `String`. `duration` is
    /// stored as given; it is not checked against `end - start`.
    pub fn new(text: impl Into<String>, start: f64, end: f64, duration: f64) -> Self {
        Self {
            text: text.into(),
            start,
            end,
            duration,
        }
    }

    /// Returns the token's text.
    #[inline]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the start time of the token.
    #[inline]
    pub fn start(&self) -> f64 {
        self.start
    }

    /// Returns the end time of the token.
    #[inline]
    pub fn end(&self) -> f64 {
        self.end
    }

    /// Returns the stored duration of the token.
    #[inline]
    pub fn duration(&self) -> f64 {
        self.duration
    }
}
/// A timed sentence made of tokens, optionally attributed to a speaker.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Sentence {
/// The sentence's text.
text: String,
/// Start time of the sentence; formatted as seconds when written as an SRT/VTT cue.
start: f64,
/// End time of the sentence, in the same unit as `start`.
end: f64,
/// Stored duration of the sentence. NOTE(review): presumably `end - start`;
/// nothing in this type enforces that.
duration: f64,
/// The tokens that make up the sentence.
tokens: Vec<SentenceToken>,
/// Speaker label. An empty string is left out of the serialized form, and
/// the field defaults to empty when the key is missing on input.
#[serde(skip_serializing_if = "String::is_empty", default)]
speaker_id: String,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line methods, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl Sentence {
    /// Builds a sentence from its text, time span, duration, tokens, and
    /// speaker label.
    ///
    /// `text` and `speaker_id` accept anything convertible into a `String`.
    /// `duration` is stored as given; it is not checked against `end - start`.
    #[inline]
    pub fn new(
        text: impl Into<String>,
        start: f64,
        end: f64,
        duration: f64,
        tokens: Vec<SentenceToken>,
        speaker_id: impl Into<String>,
    ) -> Self {
        Self {
            text: text.into(),
            start,
            end,
            duration,
            tokens,
            speaker_id: speaker_id.into(),
        }
    }

    /// Returns the sentence's text.
    #[inline]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the start time of the sentence.
    #[inline]
    pub fn start(&self) -> f64 {
        self.start
    }

    /// Returns the end time of the sentence.
    #[inline]
    pub fn end(&self) -> f64 {
        self.end
    }

    /// Returns the stored duration of the sentence.
    #[inline]
    pub fn duration(&self) -> f64 {
        self.duration
    }

    /// Returns the sentence's tokens as a slice.
    #[inline]
    pub fn tokens(&self) -> &[SentenceToken] {
        &self.tokens
    }

    /// Returns the speaker label (possibly empty).
    #[inline]
    pub fn speaker_id(&self) -> &str {
        &self.speaker_id
    }
}
/// Segment-shaped transcript: the full text plus its list of segments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SegmentsPayload {
/// The full transcript text.
text: String,
/// The segments of the transcript, in stored order.
segments: Vec<Segment>,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line getters, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl SegmentsPayload {
    /// Builds a payload from the full text and its segments.
    ///
    /// `text` accepts anything convertible into a `String`.
    pub fn new(text: impl Into<String>, segments: Vec<Segment>) -> Self {
        Self {
            text: text.into(),
            segments,
        }
    }

    /// Returns the full transcript text.
    #[inline]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the segments as a slice.
    #[inline]
    pub fn segments(&self) -> &[Segment] {
        &self.segments
    }
}
/// Sentence-shaped transcript: the full text plus its list of sentences.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SentencesPayload {
/// The full transcript text.
text: String,
/// The sentences of the transcript, in stored order.
sentences: Vec<Sentence>,
}
// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line getters, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl SentencesPayload {
    /// Builds a payload from the full text and its sentences.
    ///
    /// `text` accepts anything convertible into a `String`.
    pub fn new(text: impl Into<String>, sentences: Vec<Sentence>) -> Self {
        Self {
            text: text.into(),
            sentences,
        }
    }

    /// Returns the full transcript text.
    #[inline]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the sentences as a slice.
    #[inline]
    pub fn sentences(&self) -> &[Sentence] {
        &self.sentences
    }
}
/// A transcript in one of two shapes: segment-based or sentence-based.
///
/// `untagged` means no variant name is serialized; when deserializing, serde
/// tries the variants in declaration order and keeps the first that matches.
/// The derive_more derives (`IsVariant`, `Unwrap`, `TryUnwrap`) generate
/// per-variant predicate and unwrap helpers, and `#[unwrap(ref, ref_mut)]`
/// additionally requests borrowing forms of the unwrap helpers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, IsVariant, Unwrap, TryUnwrap)]
#[unwrap(ref, ref_mut)]
#[serde(untagged)]
pub enum Transcript {
/// Transcript expressed as segments (optionally with per-word entries).
Segments(SegmentsPayload),
/// Transcript expressed as sentences made of tokens.
Sentences(SentencesPayload),
}
impl Transcript {
pub fn text(&self) -> &str {
match self {
Transcript::Segments(p) => p.text(),
Transcript::Sentences(p) => p.text(),
}
}
}
/// A borrowed subtitle cue: a time span plus the text shown during it.
///
/// The text is borrowed for `'a`, so a `Cue` is a cheap `Copy` view into data
/// owned elsewhere.
#[derive(Debug, Clone, Copy)]
pub struct Cue<'a> {
    /// Start time of the cue.
    start: f64,
    /// End time of the cue.
    end: f64,
    /// Text of the cue, borrowed from its owner.
    text: &'a str,
}

// Accessors use plain `#[inline]`: it still allows cross-crate inlining of
// these one-line methods, but leaves the final decision to the optimizer
// instead of forcing it the way `#[inline(always)]` does.
impl<'a> Cue<'a> {
    /// Builds a cue from its time span and borrowed text.
    #[inline]
    pub const fn new(start: f64, end: f64, text: &'a str) -> Self {
        Self { start, end, text }
    }

    /// Returns the start time of the cue.
    #[inline]
    pub fn start(&self) -> f64 {
        self.start
    }

    /// Returns the end time of the cue.
    #[inline]
    pub fn end(&self) -> f64 {
        self.end
    }

    /// Returns the cue's text.
    ///
    /// The returned slice carries the full `'a` lifetime of the underlying
    /// data rather than the (possibly much shorter) borrow of `self`, so it
    /// stays usable after a temporary `Cue` has gone out of scope.
    #[inline]
    pub fn text(&self) -> &'a str {
        self.text
    }
}
/// Flattens a transcript into the ordered list of cues used by the subtitle
/// writers.
///
/// * Sentence transcripts yield one cue per sentence.
/// * Segment transcripts yield one cue per segment, immediately followed by
///   one cue for each of that segment's words.
///
/// NOTE(review): for segments with words this emits both the whole-segment
/// cue and the per-word cues, so the text appears twice with overlapping
/// times - confirm this is the intended subtitle layout.
pub fn get_cues(t: &Transcript) -> Vec<Cue<'_>> {
    match t {
        Transcript::Segments(payload) => payload
            .segments()
            .iter()
            .flat_map(|segment| {
                let whole = Cue::new(segment.start(), segment.end(), segment.text());
                let per_word = segment
                    .words_slice()
                    .iter()
                    .map(|word| Cue::new(word.start(), word.end(), word.word()));
                std::iter::once(whole).chain(per_word)
            })
            .collect(),
        Transcript::Sentences(payload) => {
            let sentences = payload.sentences();
            let mut cues = Vec::with_capacity(sentences.len());
            for sentence in sentences {
                cues.push(Cue::new(sentence.start(), sentence.end(), sentence.text()));
            }
            cues
        }
    }
}
/// Formats a time in seconds as an SRT-style timestamp, `HH:MM:SS,mmm`.
///
/// The value is rounded once to whole milliseconds and then split with
/// integer arithmetic, so a fraction that rounds up carries into the next
/// second/minute/hour (`59.9996` becomes `00:01:00,000`, never `00:00:60,000`).
///
/// Negative and NaN inputs are clamped to zero. Hours are zero-padded to two
/// digits and grow wider when the value exceeds 99 hours.
pub fn format_timestamp(seconds: f64) -> String {
    // `f64::max` returns the non-NaN operand, so this clamps both negative
    // values and NaN to 0. The float-to-int cast saturates on overflow.
    let total_millis = (seconds.max(0.0) * 1000.0).round() as u64;

    let millis = total_millis % 1_000;
    let total_secs = total_millis / 1_000;
    let secs = total_secs % 60;
    let mins = (total_secs / 60) % 60;
    let hours = total_secs / 3_600;

    format!("{hours:02}:{mins:02}:{secs:02},{millis:03}")
}
/// Formats a time in seconds as a VTT-style timestamp.
///
/// This is the output of `format_timestamp` with every `,` replaced by `.`.
pub fn format_vtt_timestamp(seconds: f64) -> String {
    let srt_style = format_timestamp(seconds);
    srt_style
        .chars()
        .map(|ch| if ch == ',' { '.' } else { ch })
        .collect()
}
pub fn save_as_txt(transcript: &Transcript, path: &Path) -> Result<()> {
if path == Path::new("-") {
let stdout = std::io::stdout();
let mut w = stdout.lock();
return save_as_txt_stdout(transcript, &mut w);
}
let final_path = with_extension(path, "txt");
let f = File::create(&final_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_txt",
FileOp::Create,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
let mut w = BufWriter::new(f);
save_as_txt_to_writer(transcript, &mut w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_txt",
FileOp::Write,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_txt",
FileOp::Flush,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
Ok(())
}
/// Writes the transcript's text to `w` exactly as stored (no trailing newline
/// is added).
fn save_as_txt_to_writer<W: Write>(transcript: &Transcript, w: &mut W) -> std::io::Result<()> {
    let bytes = transcript.text().as_bytes();
    w.write_all(bytes)
}
fn save_as_txt_stdout<W: Write>(transcript: &Transcript, w: &mut W) -> Result<()> {
save_as_txt_to_writer(transcript, w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_txt: write to stdout failed",
FileOp::Write,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_txt: stdout flush failed",
FileOp::Flush,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
Ok(())
}
pub fn save_as_srt(transcript: &Transcript, path: &Path) -> Result<()> {
if path == Path::new("-") {
let stdout = std::io::stdout();
let mut w = stdout.lock();
return save_as_srt_stdout(transcript, &mut w);
}
let final_path = with_extension(path, "srt");
let f = File::create(&final_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_srt",
FileOp::Create,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
let mut w = BufWriter::new(f);
save_as_srt_to_writer(transcript, &mut w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_srt",
FileOp::Write,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_srt",
FileOp::Flush,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
Ok(())
}
/// Writes every cue of the transcript to `w` as an SRT block.
///
/// Each block is the 1-based cue number, a `start --> end` line using
/// `format_timestamp`, the cue text, and a blank line. One `write_all` is
/// issued per cue.
fn save_as_srt_to_writer<W: Write>(transcript: &Transcript, w: &mut W) -> std::io::Result<()> {
    let cues = get_cues(transcript);
    for (cue, number) in cues.iter().zip(1usize..) {
        let block = format!(
            "{}\n{} --> {}\n{}\n\n",
            number,
            format_timestamp(cue.start()),
            format_timestamp(cue.end()),
            cue.text(),
        );
        w.write_all(block.as_bytes())?;
    }
    Ok(())
}
fn save_as_srt_stdout<W: Write>(transcript: &Transcript, w: &mut W) -> Result<()> {
save_as_srt_to_writer(transcript, w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_srt: write to stdout failed",
FileOp::Write,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_srt: stdout flush failed",
FileOp::Flush,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
Ok(())
}
pub fn save_as_vtt(transcript: &Transcript, path: &Path) -> Result<()> {
if path == Path::new("-") {
let stdout = std::io::stdout();
let mut w = stdout.lock();
return save_as_vtt_stdout(transcript, &mut w);
}
let final_path = with_extension(path, "vtt");
let f = File::create(&final_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_vtt",
FileOp::Create,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
let mut w = BufWriter::new(f);
save_as_vtt_to_writer(transcript, &mut w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_vtt",
FileOp::Write,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_vtt",
FileOp::Flush,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
Ok(())
}
/// Writes the transcript to `w` as WebVTT.
///
/// Output starts with the `WEBVTT` header and a blank line, followed by one
/// block per cue: the 1-based cue number, a `start --> end` line using
/// `format_vtt_timestamp`, the cue text, and a blank line. One `write_all`
/// is issued per cue.
fn save_as_vtt_to_writer<W: Write>(transcript: &Transcript, w: &mut W) -> std::io::Result<()> {
    w.write_all(b"WEBVTT\n\n")?;

    let cues = get_cues(transcript);
    for (cue, number) in cues.iter().zip(1usize..) {
        let block = format!(
            "{}\n{} --> {}\n{}\n\n",
            number,
            format_vtt_timestamp(cue.start()),
            format_vtt_timestamp(cue.end()),
            cue.text(),
        );
        w.write_all(block.as_bytes())?;
    }
    Ok(())
}
fn save_as_vtt_stdout<W: Write>(transcript: &Transcript, w: &mut W) -> Result<()> {
save_as_vtt_to_writer(transcript, w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_vtt: write to stdout failed",
FileOp::Write,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_vtt: stdout flush failed",
FileOp::Flush,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
Ok(())
}
pub fn save_as_json(transcript: &Transcript, path: &Path) -> Result<()> {
if path == Path::new("-") {
let stdout = std::io::stdout();
let mut w = stdout.lock();
return save_as_json_stdout(transcript, &mut w);
}
let final_path = with_extension(path, "json");
let f = File::create(&final_path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_json",
FileOp::Create,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
let mut w = BufWriter::new(f);
save_as_json_to_writer(transcript, &mut w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_json: serialize",
FileOp::Write,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_json",
FileOp::Flush,
::std::path::PathBuf::from(&final_path),
e,
))
})?;
Ok(())
}
fn save_as_json_to_writer<W: Write>(transcript: &Transcript, w: &mut W) -> std::io::Result<()> {
let value = transcript_to_python_shape(transcript);
serde_json::to_writer_pretty(w, &value).map_err(std::io::Error::other)
}
fn save_as_json_stdout<W: Write>(transcript: &Transcript, w: &mut W) -> Result<()> {
save_as_json_to_writer(transcript, w).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_json: serialize to stdout failed",
FileOp::Write,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
w.flush().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"save_as_json: stdout flush failed",
FileOp::Flush,
::std::path::PathBuf::from("<stdout>"),
e,
))
})?;
Ok(())
}
/// Converts a transcript into the JSON value written by the JSON writer.
///
/// The root object holds `text` plus either `sentences` or `segments`:
///
/// * each sentence has `text`, `start`, `end`, `duration`, `tokens`
///   (objects with `text`, `start`, `end`, `duration`), and `speaker_id`
///   when it is non-empty;
/// * each segment has `text`, `start`, `end`, `duration` (computed here as
///   `end - start`), `words` when there are any (objects with `start`, `end`,
///   `word`, then every extra key of the word), and `speaker_id` when it is
///   non-empty.
///
/// A word's extra keys are inserted after its named keys, so an extra key
/// with the same name replaces the named value.
fn transcript_to_python_shape(t: &Transcript) -> serde_json::Value {
    use serde_json::{Map, Value, json};

    // Finishes an entry: adds `speaker_id` only when one is set.
    let finish = |mut fields: Map<String, Value>, speaker: &str| -> Value {
        if !speaker.is_empty() {
            fields.insert("speaker_id".into(), Value::String(speaker.to_owned()));
        }
        Value::Object(fields)
    };
    // Wraps the entries in the top-level `{ "text": ..., <key>: [...] }` object.
    let wrap = |text: &str, key: &str, entries: Vec<Value>| -> Value {
        let mut root = Map::new();
        root.insert("text".into(), Value::String(text.to_owned()));
        root.insert(key.to_owned(), Value::Array(entries));
        Value::Object(root)
    };

    match t {
        Transcript::Segments(payload) => {
            let entries: Vec<Value> = payload
                .segments()
                .iter()
                .map(|segment| {
                    let mut fields = Map::new();
                    fields.insert("text".into(), Value::String(segment.text().to_owned()));
                    fields.insert("start".into(), json!(segment.start()));
                    fields.insert("end".into(), json!(segment.end()));
                    fields.insert("duration".into(), json!(segment.end() - segment.start()));

                    let words = segment.words_slice();
                    if !words.is_empty() {
                        let rendered: Vec<Value> = words
                            .iter()
                            .map(|word| {
                                let mut entry = Map::new();
                                entry.insert("start".into(), json!(word.start()));
                                entry.insert("end".into(), json!(word.end()));
                                entry.insert(
                                    "word".into(),
                                    Value::String(word.word().to_owned()),
                                );
                                // Extras go last so they override named keys on collision.
                                entry.extend(
                                    word.extra()
                                        .iter()
                                        .map(|(key, value)| (key.clone(), value.clone())),
                                );
                                Value::Object(entry)
                            })
                            .collect();
                        fields.insert("words".into(), Value::Array(rendered));
                    }

                    finish(fields, segment.speaker_id())
                })
                .collect();
            wrap(payload.text(), "segments", entries)
        }
        Transcript::Sentences(payload) => {
            let entries: Vec<Value> = payload
                .sentences()
                .iter()
                .map(|sentence| {
                    let mut fields = Map::new();
                    fields.insert("text".into(), Value::String(sentence.text().to_owned()));
                    fields.insert("start".into(), json!(sentence.start()));
                    fields.insert("end".into(), json!(sentence.end()));
                    fields.insert("duration".into(), json!(sentence.duration()));

                    let tokens: Vec<Value> = sentence
                        .tokens()
                        .iter()
                        .map(|token| {
                            let mut entry = Map::new();
                            entry.insert("text".into(), Value::String(token.text().to_owned()));
                            entry.insert("start".into(), json!(token.start()));
                            entry.insert("end".into(), json!(token.end()));
                            entry.insert("duration".into(), json!(token.duration()));
                            Value::Object(entry)
                        })
                        .collect();
                    fields.insert("tokens".into(), Value::Array(tokens));

                    finish(fields, sentence.speaker_id())
                })
                .collect();
            wrap(payload.text(), "sentences", entries)
        }
    }
}
/// Returns `path` with `.` and `ext` appended to it.
///
/// Unlike `Path::with_extension`, this never replaces an existing extension:
/// the suffix is added to the whole path, so `talk.v1` with `json` becomes
/// `talk.v1.json`.
fn with_extension(path: &Path, ext: &str) -> std::path::PathBuf {
    let base = path.as_os_str();
    let mut joined = std::ffi::OsString::with_capacity(base.len() + 1 + ext.len());
    joined.push(base);
    joined.push(".");
    joined.push(ext);
    joined.into()
}
#[cfg(test)]
mod tests;