use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use svara::phoneme::Phoneme;
use super::PronunciationDict;
use super::entry::DictEntry;
/// Pronunciation predicted by a [`G2PModel`], paired with the confidence the
/// model attached to that prediction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct G2PResult {
    /// Predicted phoneme sequence.
    phonemes: Vec<Phoneme>,
    /// Confidence reported by the model. The scale is model-defined
    /// (presumably 0.0..=1.0 — not enforced here, confirm with implementors).
    confidence: f32,
}

impl G2PResult {
    /// Builds a result from a phoneme sequence and a confidence score.
    ///
    /// The confidence is stored exactly as given; no range check is applied.
    #[must_use]
    pub fn new(phonemes: Vec<Phoneme>, confidence: f32) -> Self {
        Self { phonemes, confidence }
    }

    /// Borrows the predicted phoneme sequence.
    #[must_use]
    pub fn phonemes(&self) -> &[Phoneme] {
        self.phonemes.as_slice()
    }

    /// Returns the confidence score supplied at construction.
    #[must_use]
    pub fn confidence(&self) -> f32 {
        self.confidence
    }

    /// Consumes the result and hands back the owned phoneme sequence,
    /// discarding the confidence.
    #[must_use]
    pub fn into_phonemes(self) -> Vec<Phoneme> {
        let Self { phonemes, .. } = self;
        phonemes
    }
}
/// A grapheme-to-phoneme model: predicts a pronunciation for a written word.
///
/// Implementors must be `Send + Sync`, so a model can be shared across threads.
pub trait G2PModel
where
    Self: Send + Sync,
{
    /// Predicts the phoneme sequence for `word`, or returns `None` when the
    /// model has no prediction to offer.
    ///
    /// [`FallbackDict`] always passes the lowercased form of the word.
    fn predict(&self, word: &str) -> Option<G2PResult>;
}
/// Identifies where a pronunciation returned by
/// [`FallbackDict::lookup_with_source`] came from.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum LookupSource {
    /// Found among the dictionary's user-overlay entries.
    UserOverlay,
    /// Found among the dictionary's base entries.
    BaseDictionary,
    /// Not in the dictionary; predicted by the G2P model with the given
    /// confidence score.
    G2PModel { confidence: f32 },
}
pub struct FallbackDict<M: G2PModel> {
dict: PronunciationDict,
model: M,
}
impl<M: G2PModel> FallbackDict<M> {
#[must_use]
pub fn new(dict: PronunciationDict, model: M) -> Self {
Self { dict, model }
}
#[must_use]
pub fn dict(&self) -> &PronunciationDict {
&self.dict
}
pub fn dict_mut(&mut self) -> &mut PronunciationDict {
&mut self.dict
}
#[must_use]
pub fn model(&self) -> &M {
&self.model
}
#[must_use]
pub fn lookup(&self, word: &str) -> Option<Vec<Phoneme>> {
if let Some(phonemes) = self.dict.lookup(word) {
return Some(phonemes.to_vec());
}
self.model
.predict(&word.to_lowercase())
.map(|r| r.into_phonemes())
}
#[must_use]
pub fn lookup_with_source(&self, word: &str) -> Option<(Vec<Phoneme>, LookupSource)> {
let key = alloc::string::ToString::to_string(&word.to_lowercase());
if let Some(entry) = self.dict.user_entries().get(&key) {
return Some((entry.primary_phonemes().to_vec(), LookupSource::UserOverlay));
}
if let Some(entry) = self.dict.entries().get(&key) {
return Some((
entry.primary_phonemes().to_vec(),
LookupSource::BaseDictionary,
));
}
let result = self.model.predict(&key)?;
let confidence = result.confidence();
Some((
result.into_phonemes(),
LookupSource::G2PModel { confidence },
))
}
#[must_use]
pub fn lookup_entry(&self, word: &str) -> Option<&DictEntry> {
self.dict.lookup_entry(word)
}
pub fn promote_prediction(&mut self, word: &str) -> bool {
if self.dict.lookup_entry(word).is_some() {
return false;
}
let key = alloc::string::ToString::to_string(&word.to_lowercase());
if let Some(result) = self.model.predict(&key) {
self.dict.insert_user(word, result.phonemes());
true
} else {
false
}
}
pub fn promote_if_confident(&mut self, word: &str, threshold: f32) -> bool {
if self.dict.lookup_entry(word).is_some() {
return false;
}
let key = alloc::string::ToString::to_string(&word.to_lowercase());
if let Some(result) = self.model.predict(&key)
&& result.confidence() >= threshold
{
self.dict.insert_user(word, result.phonemes());
return true;
}
false
}
#[must_use]
pub fn into_parts(self) -> (PronunciationDict, M) {
(self.dict, self.model)
}
}
/// Debug-formats both the dictionary and the model. Only available when the
/// model type itself implements `Debug`.
impl<M: G2PModel + core::fmt::Debug> core::fmt::Debug for FallbackDict<M> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Destructuring means a newly added field fails to compile here until
        // it is deliberately included or skipped.
        let Self { dict, model } = self;
        f.debug_struct("FallbackDict")
            .field("dict", dict)
            .field("model", model)
            .finish()
    }
}
/// Clones the dictionary and the model independently. Only available when
/// the model type is `Clone`.
impl<M: G2PModel + Clone> Clone for FallbackDict<M> {
    fn clone(&self) -> Self {
        let Self { dict, model } = self;
        Self {
            dict: dict.clone(),
            model: model.clone(),
        }
    }
}
/// Description of an FST-based G2P model: the path to its model file and the
/// phoneme notation it is tagged with.
///
/// Only the configuration is stored here; no model file is read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FstModel {
    /// Path to the model file, stored verbatim.
    model_path: alloc::string::String,
    /// Notation the model is tagged with.
    notation: FstNotation,
}
/// Phoneme notation associated with an [`FstModel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum FstNotation {
    /// ARPAbet notation.
    Arpabet,
    /// IPA notation.
    Ipa,
}
impl FstModel {
    /// Records `model_path` and `notation`. The path is copied into an owned
    /// string; nothing is opened or validated.
    #[must_use]
    pub fn new(model_path: &str, notation: FstNotation) -> Self {
        let owned_path = model_path.into();
        Self {
            model_path: owned_path,
            notation,
        }
    }

    /// Borrows the stored model path.
    #[must_use]
    pub fn model_path(&self) -> &str {
        self.model_path.as_str()
    }

    /// Returns the notation this model was configured with.
    #[must_use]
    pub fn notation(&self) -> FstNotation {
        self.notation
    }
}

impl G2PModel for FstModel {
    /// Placeholder: FST inference is not implemented here, so every word
    /// yields `None`, regardless of `model_path` or `notation`.
    fn predict(&self, _word: &str) -> Option<G2PResult> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model that always predicts a single schwa with a fixed confidence.
    #[derive(Debug, Clone)]
    struct MockG2PModel {
        confidence: f32,
    }

    impl MockG2PModel {
        fn new(confidence: f32) -> Self {
            Self { confidence }
        }
    }

    impl G2PModel for MockG2PModel {
        fn predict(&self, _word: &str) -> Option<G2PResult> {
            Some(G2PResult::new(
                alloc::vec![Phoneme::VowelSchwa],
                self.confidence,
            ))
        }
    }

    /// Model that never produces a prediction.
    #[derive(Debug, Clone)]
    struct NoopModel;

    impl G2PModel for NoopModel {
        fn predict(&self, _word: &str) -> Option<G2PResult> {
            None
        }
    }

    #[test]
    fn test_g2p_result_new() {
        let result = G2PResult::new(alloc::vec![Phoneme::PlosiveK], 0.85);
        assert_eq!(result.phonemes(), &[Phoneme::PlosiveK]);
        assert_eq!(result.confidence(), 0.85);
    }

    #[test]
    fn test_g2p_result_into_phonemes() {
        let result = G2PResult::new(alloc::vec![Phoneme::PlosiveK, Phoneme::VowelAsh], 0.9);
        let phonemes = result.into_phonemes();
        assert_eq!(phonemes, alloc::vec![Phoneme::PlosiveK, Phoneme::VowelAsh]);
    }

    #[test]
    fn test_g2p_result_serde_roundtrip() {
        let result = G2PResult::new(alloc::vec![Phoneme::PlosiveK, Phoneme::VowelAsh], 0.75);
        let json = serde_json::to_string(&result).unwrap();
        let result2: G2PResult = serde_json::from_str(&json).unwrap();
        assert_eq!(result, result2);
    }

    #[test]
    fn test_lookup_source_serde_roundtrip() {
        let sources = [
            LookupSource::UserOverlay,
            LookupSource::BaseDictionary,
            LookupSource::G2PModel { confidence: 0.8 },
        ];
        for source in &sources {
            let json = serde_json::to_string(source).unwrap();
            let source2: LookupSource = serde_json::from_str(&json).unwrap();
            assert_eq!(source, &source2);
        }
    }

    #[test]
    fn test_fallback_dict_lookup_from_dict() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        let phonemes = fallback.lookup("hello");
        assert!(phonemes.is_some());
    }

    #[test]
    fn test_fallback_dict_lookup_from_g2p() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        let phonemes = fallback.lookup("xyzzy");
        assert!(phonemes.is_some());
        assert_eq!(phonemes.unwrap(), alloc::vec![Phoneme::VowelSchwa]);
    }

    #[test]
    fn test_fallback_dict_lookup_none_when_no_model_result() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, NoopModel);
        assert!(fallback.lookup("xyzzy").is_none());
    }

    #[test]
    fn test_lookup_with_source_user_overlay() {
        let mut dict = PronunciationDict::english_minimal();
        dict.insert_user("custom", &[Phoneme::VowelA]);
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        let (_, source) = fallback.lookup_with_source("custom").unwrap();
        assert_eq!(source, LookupSource::UserOverlay);
    }

    #[test]
    fn test_lookup_with_source_base_dict() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        let (_, source) = fallback.lookup_with_source("hello").unwrap();
        assert_eq!(source, LookupSource::BaseDictionary);
    }

    #[test]
    fn test_lookup_with_source_lowercases_word() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        let (_, source) = fallback.lookup_with_source("HELLO").unwrap();
        assert_eq!(source, LookupSource::BaseDictionary);
    }

    #[test]
    fn test_lookup_with_source_g2p() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.75));
        let (phonemes, source) = fallback.lookup_with_source("xyzzy").unwrap();
        assert_eq!(phonemes, alloc::vec![Phoneme::VowelSchwa]);
        assert_eq!(source, LookupSource::G2PModel { confidence: 0.75 });
    }

    #[test]
    fn test_lookup_with_source_none_when_no_model_result() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, NoopModel);
        assert!(fallback.lookup_with_source("xyzzy").is_none());
    }

    #[test]
    fn test_promote_prediction() {
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, MockG2PModel::new(0.8));
        assert!(fallback.promote_prediction("newword"));
        assert!(!fallback.promote_prediction("newword"));
        assert_eq!(fallback.dict().user_len(), 1);
    }

    #[test]
    fn test_promote_prediction_skips_existing() {
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, MockG2PModel::new(0.8));
        assert!(!fallback.promote_prediction("hello"));
        assert_eq!(fallback.dict().user_len(), 0);
    }

    #[test]
    fn test_promote_prediction_without_model_result() {
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, NoopModel);
        assert!(!fallback.promote_prediction("newword"));
        assert_eq!(fallback.dict().user_len(), 0);
    }

    #[test]
    fn test_promote_if_confident_above_threshold() {
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, MockG2PModel::new(0.8));
        assert!(fallback.promote_if_confident("newword", 0.7));
        assert_eq!(fallback.dict().user_len(), 1);
    }

    #[test]
    fn test_promote_if_confident_at_threshold() {
        // The comparison is `>=`, so an exact match is promoted.
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, MockG2PModel::new(0.7));
        assert!(fallback.promote_if_confident("newword", 0.7));
        assert_eq!(fallback.dict().user_len(), 1);
    }

    #[test]
    fn test_promote_if_confident_below_threshold() {
        let dict = PronunciationDict::english_minimal();
        let mut fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        assert!(!fallback.promote_if_confident("newword", 0.7));
        assert_eq!(fallback.dict().user_len(), 0);
    }

    #[test]
    fn test_into_parts() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict.clone(), MockG2PModel::new(0.5));
        let (recovered_dict, _model) = fallback.into_parts();
        assert_eq!(recovered_dict.len(), dict.len());
    }

    #[test]
    fn test_fallback_dict_lookup_entry() {
        let dict = PronunciationDict::english_minimal();
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.5));
        assert!(fallback.lookup_entry("hello").is_some());
        assert!(fallback.lookup_entry("xyzzy").is_none());
    }

    #[test]
    fn test_user_overlay_takes_precedence_over_g2p() {
        let mut dict = PronunciationDict::new();
        dict.insert_user("test", &[Phoneme::PlosiveT]);
        let fallback = FallbackDict::new(dict, MockG2PModel::new(0.9));
        let (phonemes, source) = fallback.lookup_with_source("test").unwrap();
        assert_eq!(phonemes, alloc::vec![Phoneme::PlosiveT]);
        assert_eq!(source, LookupSource::UserOverlay);
    }

    #[test]
    fn test_fst_model_serde_roundtrip() {
        let model = FstModel::new("/path/to/model.fst", FstNotation::Arpabet);
        let json = serde_json::to_string(&model).unwrap();
        let model2: FstModel = serde_json::from_str(&json).unwrap();
        assert_eq!(model.model_path(), model2.model_path());
        assert_eq!(model.notation(), model2.notation());
    }

    #[test]
    fn test_fst_model_predict_is_unimplemented() {
        let model = FstModel::new("/path/to/model.fst", FstNotation::Ipa);
        assert!(model.predict("hello").is_none());
    }

    #[test]
    fn test_fst_notation_serde_roundtrip() {
        for notation in [FstNotation::Arpabet, FstNotation::Ipa] {
            let json = serde_json::to_string(&notation).unwrap();
            let notation2: FstNotation = serde_json::from_str(&json).unwrap();
            assert_eq!(notation, notation2);
        }
    }
}