use once_cell::sync::Lazy;
use regex::Regex;
/// Presumably the size of the speech-codec vocabulary, i.e. valid ids are
/// `0..NUM_SPEECH_TOKENS` — confirm against the codec; nothing in this section enforces it.
pub const NUM_SPEECH_TOKENS: usize = 1024;

/// Matches a single `<|speech_N|>` token; capture group 1 holds the decimal id `N`.
///
/// `[0-9]` is used instead of `\d` because the `regex` crate's `\d` is Unicode-aware, whereas the
/// captured digits are parsed with `str::parse::<i32>`, which only accepts ASCII digits. Keeping
/// the pattern ASCII-only makes the matcher agree with the parser.
static RE_SPEECH_TOKEN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"<\|speech_([0-9]+)\|>")
        .expect("RE_SPEECH_TOKEN is a constant, valid regex pattern")
});
/// Renders each id as a `<|speech_{id}|>` token and concatenates them in order.
///
/// An empty slice yields an empty string. For non-negative ids the output uses the same token
/// format that [`extract_ids`] parses back, and it is the speech prefix that [`build_prompt`]
/// appends to the prompt.
pub fn ids_to_token_str(ids: &[i32]) -> String {
    use std::fmt::Write as _;

    // 15 bytes fits "<|speech_1023|>"; longer (or negative) ids only cause the buffer to regrow.
    let mut s = String::with_capacity(ids.len() * 15);
    for &id in ids {
        // Format straight into `s` instead of allocating a temporary `String` per id
        // (which `push_str(&format!(..))` would do).
        write!(s, "<|speech_{id}|>").expect("writing into a String cannot fail");
    }
    s
}
/// Returns the numeric id of every `<|speech_N|>` token in `s`, in order of appearance.
///
/// Text that is not a speech token (for example control markers such as
/// `<|SPEECH_GENERATION_END|>`) is skipped. A token whose digits do not fit in an `i32` is
/// silently dropped rather than reported.
pub fn extract_ids(s: &str) -> Vec<i32> {
    let mut found = Vec::new();
    for caps in RE_SPEECH_TOKEN.captures_iter(s) {
        // Group 1 is non-optional in the pattern, so indexing it cannot panic.
        if let Ok(id) = caps[1].parse::<i32>() {
            found.push(id);
        }
    }
    found
}
/// Control marker presumably emitted by the model to end a speech-token sequence.
///
/// It is not of the `<|speech_N|>` form, so [`extract_ids`] skips it when parsing output.
pub const STOP_TOKEN: &str = "<|SPEECH_GENERATION_END|>";
/// Builds the text prompt for one synthesis request.
///
/// `ref_phones` and `input_phones` are joined by a single space between the text-prompt markers.
/// The reference codes are rendered with [`ids_to_token_str`] and placed right after
/// `<|SPEECH_GENERATION_START|>`, presumably so generation continues from the reference audio.
/// Resulting layout (the only line break is the `\n` before `assistant:`):
///
/// ```text
/// user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_phones} {input_phones}<|TEXT_PROMPT_END|>
/// assistant:<|SPEECH_GENERATION_START|><|speech_..|>...
/// ```
pub fn build_prompt(ref_phones: &str, input_phones: &str, ref_codes: &[i32]) -> String {
    // Trailing `\` continues the literal and strips the next line's leading whitespace, so the
    // indentation below does not leak into the prompt.
    format!(
        "user: Convert the text to speech:\
         <|TEXT_PROMPT_START|>{ref_phones} {input_phones}<|TEXT_PROMPT_END|>\n\
         assistant:<|SPEECH_GENERATION_START|>{speech_prefix}",
        speech_prefix = ids_to_token_str(ref_codes),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ids_to_token_str() {
        assert_eq!(ids_to_token_str(&[0, 5, 42]), "<|speech_0|><|speech_5|><|speech_42|>");
        assert_eq!(ids_to_token_str(&[]), "");
    }

    #[test]
    fn test_extract_ids() {
        let s = "<|speech_0|><|speech_5|><|speech_42|><|SPEECH_GENERATION_END|>";
        assert_eq!(extract_ids(s), vec![0, 5, 42]);
    }

    #[test]
    fn test_extract_ids_empty() {
        assert_eq!(extract_ids("no tokens here"), Vec::<i32>::new());
    }

    #[test]
    fn test_extract_ids_ignores_non_speech() {
        let s = "<|SPEECH_GENERATION_START|><|speech_10|><|SPEECH_GENERATION_END|>";
        assert_eq!(extract_ids(s), vec![10]);
    }

    #[test]
    fn test_extract_ids_ignores_stop_token() {
        assert_eq!(extract_ids(STOP_TOKEN), Vec::<i32>::new());
    }

    #[test]
    fn test_extract_ids_drops_i32_overflow() {
        // Digits that overflow i32 are silently skipped rather than panicking.
        let s = "<|speech_99999999999|><|speech_7|>";
        assert_eq!(extract_ids(s), vec![7]);
    }

    #[test]
    fn test_round_trip() {
        let ids = [0, 7, 512, 1023];
        assert_eq!(extract_ids(&ids_to_token_str(&ids)), ids.to_vec());
    }

    #[test]
    fn test_build_prompt_contains_key_parts() {
        let prompt = build_prompt("hɛloʊ", "wɜːld", &[0, 1, 2]);
        assert!(prompt.contains("<|TEXT_PROMPT_START|>hɛloʊ wɜːld<|TEXT_PROMPT_END|>"));
        assert!(prompt.contains("<|SPEECH_GENERATION_START|>"));
        assert!(prompt.contains("<|speech_0|><|speech_1|><|speech_2|>"));
    }

    #[test]
    fn test_build_prompt_exact_layout() {
        // Pins the full layout, including the single `\n` before `assistant:`, which the
        // `contains` checks above cannot detect.
        let prompt = build_prompt("a", "b", &[3]);
        assert_eq!(
            prompt,
            "user: Convert the text to speech:<|TEXT_PROMPT_START|>a b<|TEXT_PROMPT_END|>\n\
             assistant:<|SPEECH_GENERATION_START|><|speech_3|>"
        );
    }

    #[test]
    fn test_build_prompt_without_reference_codes() {
        let prompt = build_prompt("a", "b", &[]);
        assert!(prompt.ends_with("\nassistant:<|SPEECH_GENERATION_START|>"));
        assert_eq!(extract_ids(&prompt), Vec::<i32>::new());
    }
}