use crate::decoder::TimedToken;
/// Granularity at which timed segments are reported by `process_timestamps`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TimestampMode {
    /// Keep the decoder's tokens exactly as produced (the default mode).
    #[default]
    Tokens,
    /// Merge sub-word tokens into whole words.
    Words,
    /// Merge words into sentences split at `.`, `?` or `!`.
    Sentences,
}
pub fn process_timestamps(tokens: &[TimedToken], mode: TimestampMode) -> Vec<TimedToken> {
match mode {
TimestampMode::Tokens => tokens.to_vec(),
TimestampMode::Words => group_by_words(tokens),
TimestampMode::Sentences => group_by_sentences(tokens),
}
}
/// Merges sub-word tokens into word-level [`TimedToken`]s.
///
/// Word boundaries:
/// * A token that starts with a boundary marker (`▁` or a space), or that consists
///   only of ASCII punctuation, begins a new word.
/// * Exception: if the token, after stripping its leading `▁` run and then its
///   leading spaces, starts with `'` (a contraction such as `'t`) or `-` (a
///   hyphenated continuation such as `-two`), it is glued onto the current word.
/// * Tokens that are empty or whitespace-only contribute no text but terminate the
///   current word (e.g. the `" "` between `" like"` and `"1"`).
///
/// Each emitted word spans from the `start` of its first token to the `end` of its
/// last token, and its text has the leading markers removed.
///
/// A word whose lowercase form equals the previously *emitted* word is dropped, so
/// immediate case-insensitive repeats (`the The`) collapse to the first occurrence.
pub(crate) fn group_by_words(tokens: &[TimedToken]) -> Vec<TimedToken> {
    let mut words = Vec::new();
    // Word under construction; an empty `text` means no word is pending.
    let mut pending = TimedToken {
        text: String::new(),
        start: 0.0,
        end: 0.0,
    };
    let mut last_word_lower = String::new();

    for token in tokens {
        if token.text.trim().is_empty() {
            flush_word(&mut words, &mut pending, &mut last_word_lower);
            continue;
        }

        let stripped = token.text.trim_start_matches('▁').trim_start_matches(' ');
        let has_boundary_marker = token.text.starts_with('▁') || token.text.starts_with(' ');
        // `token.text` is known to be non-empty here (whitespace-only was skipped),
        // so `all` cannot be vacuously true.
        let is_pure_punctuation = token.text.chars().all(|c| c.is_ascii_punctuation());
        let continues_word = stripped.starts_with('\'') || stripped.starts_with('-');

        if (has_boundary_marker || is_pure_punctuation) && !continues_word {
            flush_word(&mut words, &mut pending, &mut last_word_lower);
        }
        if pending.text.is_empty() {
            pending.start = token.start;
        }
        pending.text.push_str(stripped);
        // Always track the latest contributing token so the word ends where it does.
        pending.end = token.end;
    }
    flush_word(&mut words, &mut pending, &mut last_word_lower);
    words
}

/// Emits `pending` as a finished word and resets its text; no-op if it is empty.
///
/// The word is skipped when its lowercase form equals `last_word_lower` (the
/// lowercase form of the most recently emitted word).
fn flush_word(
    words: &mut Vec<TimedToken>,
    pending: &mut TimedToken,
    last_word_lower: &mut String,
) {
    if pending.text.is_empty() {
        return;
    }
    let text = std::mem::take(&mut pending.text);
    let lower = text.to_lowercase();
    if lower != *last_word_lower {
        words.push(TimedToken {
            text,
            start: pending.start,
            end: pending.end,
        });
        *last_word_lower = lower;
    }
}
/// Groups tokens into sentence-level [`TimedToken`]s.
///
/// Tokens are first merged into words with `group_by_words`. Words are then
/// accumulated until one ends with a sentence terminator (see
/// `ends_with_sentence_terminator`). Each sentence:
/// * spans from the `start` of its first word to the `end` of its last word;
/// * takes its text from `format_sentence`.
///
/// Words left after the last terminator form a final, unterminated sentence.
/// Returns an empty vector when `tokens` yields no words.
fn group_by_sentences(tokens: &[TimedToken]) -> Vec<TimedToken> {
    let mut sentences = Vec::new();
    let mut current_sentence: Vec<TimedToken> = Vec::new();
    for word in group_by_words(tokens) {
        // Decide before moving `word` into the buffer, so no clone is needed.
        let is_terminal = ends_with_sentence_terminator(&word.text);
        current_sentence.push(word);
        if is_terminal {
            push_sentence(&mut sentences, &current_sentence);
            current_sentence.clear();
        }
    }
    // Flush trailing words that were not followed by terminal punctuation.
    push_sentence(&mut sentences, &current_sentence);
    sentences
}

/// Returns `true` if `word` ends with `.`, `?` or `!`, ignoring trailing closing
/// quotes or brackets (so `."` and `?)` still end a sentence).
///
/// Only the end of the word is inspected, so a `.` inside a word such as `3.14`
/// or `.5` does not split the sentence.
fn ends_with_sentence_terminator(word: &str) -> bool {
    let core = word.trim_end_matches(|c: char| {
        matches!(c, '"' | '\'' | ')' | ']' | '}' | '”' | '’' | '»')
    });
    matches!(core.chars().next_back(), Some('.') | Some('?') | Some('!'))
}

/// Appends `words` to `sentences` as one sentence token.
///
/// Does nothing if `words` is empty or formats to an empty string.
fn push_sentence(sentences: &mut Vec<TimedToken>, words: &[TimedToken]) {
    if let (Some(first), Some(last)) = (words.first(), words.last()) {
        let text = format_sentence(words);
        if !text.is_empty() {
            sentences.push(TimedToken {
                text,
                start: first.start,
                end: last.end,
            });
        }
    }
}
/// Joins word tokens into a single space-separated string.
///
/// A word made up entirely of closing punctuation (`.` `,` `!` `?` `;` `:` `)`)
/// attaches to the preceding word without a space. This covers single marks (`.`)
/// and runs such as `...`, `?!` or `.)`. Every other word after the first is
/// preceded by one space. Repeated words are kept as-is; timing fields are ignored.
fn format_sentence(words: &[TimedToken]) -> String {
    let mut output = String::with_capacity(words.iter().map(|w| w.text.len() + 1).sum());
    for (i, word) in words.iter().map(|w| w.text.as_str()).enumerate() {
        // `!is_empty()` keeps an empty word from counting as punctuation via a
        // vacuously-true `all`.
        let is_standalone_punct = !word.is_empty()
            && word
                .chars()
                .all(|c| matches!(c, '.' | ',' | '!' | '?' | ';' | ':' | ')'));
        if i > 0 && !is_standalone_punct {
            output.push(' ');
        }
        output.push_str(word);
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TimedToken`. A macro (rather than a helper fn) lets the float
    /// literals take whatever numeric type `TimedToken::start`/`end` use.
    macro_rules! tok {
        ($text:expr, $start:expr, $end:expr) => {
            TimedToken {
                text: $text.to_string(),
                start: $start,
                end: $end,
            }
        };
    }

    #[test]
    fn test_word_grouping() {
        let tokens = vec![tok!("▁Hello", 0.0, 0.5), tok!("▁world", 0.5, 1.0)];
        let words = group_by_words(&tokens);
        assert_eq!(words.len(), 2);
        assert_eq!(words[0].text, "Hello");
        assert_eq!(words[1].text, "world");
    }

    #[test]
    fn test_word_grouping_with_hyphenated_word() {
        let tokens = vec![
            tok!("▁twenty", 0.0, 0.3),
            tok!("-two", 0.3, 0.6),
            tok!("▁apples", 0.6, 1.0),
        ];
        let words = group_by_words(&tokens);
        assert_eq!(words.len(), 2);
        assert_eq!(words[0].text, "twenty-two");
        assert_eq!(words[1].text, "apples");
        assert_eq!(words[0].start, 0.0);
        assert_eq!(words[0].end, 0.6);
        assert_eq!(words[1].start, 0.6);
        assert_eq!(words[1].end, 1.0);
    }

    #[test]
    fn test_word_grouping_with_contraction() {
        let tokens = vec![
            tok!("▁don", 0.0, 0.2),
            tok!("'t", 0.2, 0.3),
            tok!("▁go", 0.3, 0.6),
        ];
        let words = group_by_words(&tokens);
        assert_eq!(words.len(), 2);
        assert_eq!(words[0].text, "don't");
        assert_eq!(words[0].start, 0.0);
        assert_eq!(words[0].end, 0.3);
        assert_eq!(words[1].text, "go");
    }

    #[test]
    fn test_sentence_grouping() {
        let tokens = vec![
            tok!("▁Hello", 0.0, 0.5),
            tok!("▁world", 0.5, 1.0),
            tok!(".", 1.0, 1.1),
        ];
        let sentences = group_by_sentences(&tokens);
        assert_eq!(sentences.len(), 1);
        assert_eq!(sentences[0].text, "Hello world.");
        assert_eq!(sentences[0].start, 0.0);
        assert_eq!(sentences[0].end, 1.1);
    }

    #[test]
    fn test_sentence_grouping_splits_multiple_sentences() {
        let tokens = vec![
            tok!("▁Hi", 0.0, 0.3),
            tok!(".", 0.3, 0.4),
            tok!("▁How", 0.5, 0.7),
            tok!("▁are", 0.7, 0.8),
            tok!("▁you", 0.8, 1.0),
            tok!("?", 1.0, 1.1),
            tok!("▁Fine", 1.2, 1.5),
        ];
        let sentences = group_by_sentences(&tokens);
        assert_eq!(sentences.len(), 3);
        assert_eq!(sentences[0].text, "Hi.");
        assert_eq!(sentences[0].start, 0.0);
        assert_eq!(sentences[0].end, 0.4);
        assert_eq!(sentences[1].text, "How are you?");
        assert_eq!(sentences[1].start, 0.5);
        assert_eq!(sentences[1].end, 1.1);
        // Trailing words without terminal punctuation still form a sentence.
        assert_eq!(sentences[2].text, "Fine");
        assert_eq!(sentences[2].start, 1.2);
        assert_eq!(sentences[2].end, 1.5);
    }

    #[test]
    fn test_repetition_preservation() {
        let words = vec![
            tok!("uh", 0.0, 0.5),
            tok!("uh", 0.5, 1.0),
            tok!("hello", 1.0, 1.5),
        ];
        let result = format_sentence(&words);
        assert_eq!(result, "uh uh hello");
    }

    #[test]
    fn test_space_token_separates_words_from_digits() {
        let tokens = vec![
            tok!(" like", 0.0, 0.5),
            tok!(" ", 0.5, 0.5),
            tok!("1", 0.5, 0.6),
            tok!("0", 0.6, 0.7),
            tok!("0", 0.7, 0.8),
        ];
        let words = group_by_words(&tokens);
        assert_eq!(words.len(), 2);
        assert_eq!(words[0].text, "like");
        assert_eq!(words[1].text, "100");
        let sentence = format_sentence(&words);
        assert_eq!(sentence, "like 100");
    }

    #[test]
    fn test_process_timestamps_modes() {
        let tokens = vec![tok!("▁Hel", 0.0, 0.2), tok!("lo", 0.2, 0.5)];

        let raw = process_timestamps(&tokens, TimestampMode::Tokens);
        assert_eq!(raw.len(), 2);
        assert_eq!(raw[0].text, "▁Hel");
        assert_eq!(raw[1].text, "lo");
        assert_eq!(raw[1].end, 0.5);

        let words = process_timestamps(&tokens, TimestampMode::Words);
        assert_eq!(words.len(), 1);
        assert_eq!(words[0].text, "Hello");
        assert_eq!(words[0].start, 0.0);
        assert_eq!(words[0].end, 0.5);

        assert_eq!(TimestampMode::default(), TimestampMode::Tokens);
    }

    #[test]
    fn test_empty_input_yields_no_output() {
        for mode in [
            TimestampMode::Tokens,
            TimestampMode::Words,
            TimestampMode::Sentences,
        ] {
            assert!(process_timestamps(&[], mode).is_empty());
        }
        assert_eq!(format_sentence(&[]), "");
    }
}