use crate::{
audio::tts::g2p::types::{PhonemeUnit, Phonemizer},
error::{Error, OutOfRangePayload, Result},
};
/// A [`Phonemizer`] that delegates grapheme-to-phoneme conversion to a
/// caller-supplied backend (a function or closure).
///
/// The backend is invoked as `convert(grapheme, language)` for every token.
/// The `Phonemizer` impl turns the string it returns into [`PhonemeUnit`]s.
pub struct NeuralPhonemizer<F> {
    /// Backend invoked as `convert(grapheme, language)`.
    convert: F,
    /// Language tag forwarded verbatim to the backend on every call.
    language: String,
}

// `Debug` is implemented by hand because `F` is usually a closure, which does
// not implement `Debug`. The backend is omitted rather than requiring
// `F: Debug`, so every instantiation stays printable.
impl<F> std::fmt::Debug for NeuralPhonemizer<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NeuralPhonemizer")
            .field("language", &self.language)
            .finish_non_exhaustive()
    }
}
// The `F` bound is kept on this impl (not only on the `Phonemizer` impl) so
// that closures passed to `new` get their argument and return types inferred
// from it, e.g. the error type of a bare `Ok(...)` in the closure body.
impl<F> NeuralPhonemizer<F>
where
    F: Fn(&str, &str) -> Result<String>,
{
    /// Creates a phonemizer that calls `convert(grapheme, language)` for each
    /// token it phonemizes.
    ///
    /// `language` is stored once and forwarded unchanged on every call.
    pub fn new(convert: F, language: impl Into<String>) -> Self {
        Self {
            convert,
            language: language.into(),
        }
    }

    /// Returns the language tag that is forwarded to the backend.
    pub fn language(&self) -> &str {
        &self.language
    }
}
impl<F> Phonemizer for NeuralPhonemizer<F>
where
    F: Fn(&str, &str) -> Result<String>,
{
    /// Converts `grapheme` to phoneme units using the backend.
    ///
    /// The backend output is trimmed and split into one unit per base
    /// character, with whitespace ignored:
    ///
    /// - Combining diacritical marks stay attached to the unit they modify,
    ///   instead of becoming stray units of their own. An example is the
    ///   nasalisation tilde in a decomposed `ã`.
    /// - A double diacritic, such as the IPA tie bar in `t͡ʃ`, joins the
    ///   characters on both sides into a single unit.
    /// - Adjacent independent letters are still emitted as separate units, so
    ///   `oʊ` yields `o` and `ʊ`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend. Returns
    /// [`Error::OutOfRange`] when the backend output is empty or contains only
    /// whitespace.
    fn phonemize(&self, grapheme: &str) -> Result<Vec<PhonemeUnit>> {
        let raw = (self.convert)(grapheme, &self.language)?;
        let ipa = raw.trim();
        if ipa.is_empty() {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "neural G2P output",
                "must be non-empty for input token",
                grapheme,
            )));
        }
        Ok(split_phoneme_symbols(ipa)
            .into_iter()
            .map(PhonemeUnit::new)
            .collect())
    }
}

/// Splits IPA text into phoneme symbols, one per base character.
///
/// The rules are:
///
/// - Whitespace is dropped.
/// - Combining marks are appended to the preceding symbol.
/// - After a double diacritic, the next base character is appended as well.
/// - A combining mark with no preceding symbol becomes a symbol of its own.
fn split_phoneme_symbols(ipa: &str) -> Vec<String> {
    let mut symbols: Vec<String> = Vec::new();
    // Set after a double diacritic (e.g. the U+0361 tie bar). It stays set
    // until the second base character that the diacritic spans is appended.
    let mut join_next = false;
    for c in ipa.chars().filter(|c| !c.is_whitespace()) {
        let combining = is_combining_mark(c);
        let attach_to_previous = combining || join_next;
        join_next = combining && (join_next || is_double_diacritic(c));
        if attach_to_previous {
            if let Some(last) = symbols.last_mut() {
                last.push(c);
                continue;
            }
        }
        symbols.push(c.to_string());
    }
    symbols
}

/// Returns `true` for characters in the Unicode combining-diacritic blocks.
/// These characters modify the preceding character and cannot stand alone.
fn is_combining_mark(c: char) -> bool {
    matches!(
        c,
        '\u{0300}'..='\u{036F}' // Combining Diacritical Marks
            | '\u{1AB0}'..='\u{1AFF}' // Combining Diacritical Marks Extended
            | '\u{1DC0}'..='\u{1DFF}' // Combining Diacritical Marks Supplement
            | '\u{20D0}'..='\u{20FF}' // Combining Diacritical Marks for Symbols
            | '\u{FE20}'..='\u{FE2F}' // Combining Half Marks
    )
}

/// Returns `true` for double diacritics (U+035C..=U+0362, including the IPA
/// tie bar U+0361). Each one spans the base character before it and the base
/// character after it.
fn is_double_diacritic(c: char) -> bool {
    matches!(c, '\u{035C}'..='\u{0362}')
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Space-separated backend output is split per character, so the
    /// diphthong `oʊ` becomes two units.
    #[test]
    fn happy_path_splits_into_phoneme_units() {
        let phonemizer = NeuralPhonemizer::new(
            |_word: &str, _lang: &str| Ok(String::from("h ə l oʊ")),
            "eng-us",
        );
        let expected: Vec<PhonemeUnit> = ["h", "ə", "l", "o", "ʊ"]
            .iter()
            .copied()
            .map(PhonemeUnit::new)
            .collect();
        assert_eq!(phonemizer.phonemize("hello").unwrap(), expected);
    }

    /// Whitespace-only backend output is rejected, and the error names the token.
    #[test]
    fn empty_backend_output_errors_with_token() {
        let phonemizer =
            NeuralPhonemizer::new(|_word: &str, _lang: &str| Ok(String::from(" ")), "eng-us");
        let msg = phonemizer.phonemize("ghost").unwrap_err().to_string();
        assert!(msg.contains("ghost"), "expected token in {msg:?}");
        assert!(msg.contains("empty"), "expected 'empty' in {msg:?}");
    }

    /// Backend failures surface unchanged through `phonemize`.
    #[test]
    fn backend_error_propagates() {
        let failing = |_word: &str, _lang: &str| -> Result<String> {
            let payload = crate::error::InvariantViolationPayload::new(
                "neural_phonemizer test mock",
                "simulated model failure",
            );
            Err(Error::InvariantViolation(payload))
        };
        let outcome = NeuralPhonemizer::new(failing, "eng-us").phonemize("test");
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("model failure"));
    }

    /// The configured language reaches the backend and is exposed by the getter.
    #[test]
    fn language_is_threaded_to_backend() {
        let echo = |word: &str, lang: &str| Ok(format!("<{lang}>:{word}"));
        let phonemizer = NeuralPhonemizer::new(echo, "es");
        let joined = phonemizer
            .phonemize("hola")
            .unwrap()
            .iter()
            .map(|unit| unit.symbol())
            .collect::<String>();
        assert_eq!(joined, "<es>:hola");
        assert_eq!(phonemizer.language(), "es");
    }
}