use dashmap::DashMap;
use liblevenshtein::phonetic::{zompist_rules_char, OnlinePhoneticTransducerChar, RewriteRuleChar};
use std::sync::Arc;
use super::SubwordEmbedding;
/// Default blend factor for [`PhoneticEmbedding`]: the share of the combined
/// similarity contributed by the phonetic score (the rest is orthographic).
pub const DEFAULT_PHONETIC_WEIGHT: f64 = 0.3;
/// Default fuel value for phonetic processing.
// NOTE(review): not referenced anywhere in the visible code of this file;
// presumably a bound on rewrite-rule applications used elsewhere — confirm
// its consumer or remove it.
pub const DEFAULT_PHONETIC_FUEL: usize = 1_000;
/// Word-similarity model that blends an orthographic [`SubwordEmbedding`]
/// with a phonetic view, obtained by rewriting words through phonetic rules
/// before comparing them.
// NOTE(review): the derived `Debug` prints every cached normalization and
// every rule; with a large cache this output can be very big.
#[derive(Debug)]
pub struct PhoneticEmbedding {
/// Underlying subword model; shared (reference-counted) between clones.
orthographic: Arc<SubwordEmbedding>,
/// Rewrite rules handed to the phonetic transducer during normalization.
rules: Vec<RewriteRuleChar>,
/// Blend factor in [0.0, 1.0]: 0 = purely orthographic, 1 = purely phonetic.
phonetic_weight: f64,
/// Memoized `word -> normalized form`; a concurrent map so `&self` methods
/// can populate it.
normalization_cache: DashMap<String, String>,
/// New entries are no longer cached once the cache holds this many.
max_cache_size: usize,
}
impl PhoneticEmbedding {
    /// Default upper bound on the number of cached normalizations.
    pub const DEFAULT_MAX_CACHE_SIZE: usize = 100_000;

    /// Creates a phonetic model that takes ownership of `orthographic`.
    ///
    /// Uses the Zompist phonetic rule set, [`DEFAULT_PHONETIC_WEIGHT`] and
    /// [`Self::DEFAULT_MAX_CACHE_SIZE`].
    pub fn new(orthographic: SubwordEmbedding) -> Self {
        Self::from_arc(Arc::new(orthographic))
    }

    /// Creates a phonetic model around an already shared orthographic model.
    ///
    /// Same defaults as [`Self::new`]; the `Arc` lets several wrappers reuse
    /// one `SubwordEmbedding` without copying it.
    pub fn from_arc(orthographic: Arc<SubwordEmbedding>) -> Self {
        Self {
            orthographic,
            rules: zompist_rules_char(),
            phonetic_weight: DEFAULT_PHONETIC_WEIGHT,
            normalization_cache: DashMap::new(),
            max_cache_size: Self::DEFAULT_MAX_CACHE_SIZE,
        }
    }

    /// Replaces the phonetic rewrite rules.
    ///
    /// Cached normalizations were produced by the previous rules, so the
    /// cache is cleared.
    pub fn with_rules(mut self, rules: Vec<RewriteRuleChar>) -> Self {
        self.rules = rules;
        self.normalization_cache.clear();
        self
    }

    /// Sets how much the phonetic score contributes to [`Self::similarity`]:
    /// `0.0` is purely orthographic, `1.0` purely phonetic.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is outside `[0.0, 1.0]` (this includes NaN).
    pub fn with_phonetic_weight(mut self, weight: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&weight),
            "Phonetic weight must be in [0.0, 1.0], got {}",
            weight
        );
        self.phonetic_weight = weight;
        self
    }

    /// Sets the maximum number of entries kept in the normalization cache.
    ///
    /// `0` disables caching. Entries already cached are kept even if they
    /// exceed the new limit; new entries simply stop being added.
    pub fn with_cache_size(mut self, size: usize) -> Self {
        self.max_cache_size = size;
        self
    }

    /// Current phonetic blend factor.
    #[inline]
    pub fn phonetic_weight(&self) -> f64 {
        self.phonetic_weight
    }

    /// The wrapped orthographic model.
    #[inline]
    pub fn orthographic(&self) -> &SubwordEmbedding {
        &self.orthographic
    }

    /// The phonetic rewrite rules used by [`Self::normalize`].
    #[inline]
    pub fn rules(&self) -> &[RewriteRuleChar] {
        &self.rules
    }

    /// Embedding dimensionality of the orthographic model.
    #[inline]
    pub fn dim(&self) -> usize {
        self.orthographic.dim()
    }

    /// Vocabulary size of the orthographic model.
    #[inline]
    pub fn vocab_size(&self) -> usize {
        self.orthographic.vocab_size()
    }

    /// Whether `word` is in the orthographic model's vocabulary.
    #[inline]
    pub fn contains(&self, word: &str) -> bool {
        self.orthographic.contains(word)
    }

    /// Rewrites `word` through the phonetic rules and returns the result.
    ///
    /// Results are memoized until the cache holds `max_cache_size` entries.
    /// The size check and the insert are separate operations, so concurrent
    /// callers may push the cache slightly past the limit.
    pub fn normalize(&self, word: &str) -> String {
        if let Some(cached) = self.normalization_cache.get(word) {
            return cached.value().clone();
        }

        // A fresh transducer is driven character by character, then flushed.
        let mut transducer = OnlinePhoneticTransducerChar::new(self.rules.clone());
        let mut result = String::with_capacity(word.len());
        for c in word.chars() {
            result.extend(transducer.feed(c));
        }
        result.extend(transducer.finish());

        if self.normalization_cache.len() < self.max_cache_size {
            self.normalization_cache
                .insert(word.to_string(), result.clone());
        }
        result
    }

    /// Blended similarity `(1 - w) * orthographic + w * phonetic`, where `w`
    /// is the phonetic weight.
    ///
    /// Identical inputs score `1.0`. At `w == 0.0` only the orthographic score
    /// is computed and at `w == 1.0` only the phonetic one, so the unused
    /// component can never leak a NaN into the result.
    pub fn similarity(&self, word1: &str, word2: &str) -> f64 {
        if word1 == word2 {
            return 1.0;
        }
        if self.phonetic_weight == 1.0 {
            return self.phonetic_similarity(word1, word2);
        }
        let ortho_sim = self.orthographic.similarity(word1, word2) as f64;
        if self.phonetic_weight == 0.0 {
            return ortho_sim;
        }
        let phone_sim = self.phonetic_similarity(word1, word2);
        (1.0 - self.phonetic_weight) * ortho_sim + self.phonetic_weight * phone_sim
    }

    /// Similarity of the phonetically normalized forms of the two words.
    ///
    /// Returns `1.0` when both normalize to the same string. Otherwise returns
    /// the orthographic model's similarity between the normalized strings,
    /// which may not be in the vocabulary.
    pub fn phonetic_similarity(&self, word1: &str, word2: &str) -> f64 {
        let norm1 = self.normalize(word1);
        let norm2 = self.normalize(word2);
        if norm1 == norm2 {
            1.0
        } else {
            self.orthographic.similarity(&norm1, &norm2) as f64
        }
    }

    /// Top-`k` neighbours of `word` under [`Self::similarity`].
    ///
    /// The candidate pool is the `2k` orthographic nearest neighbours,
    /// re-ranked by the blended score. Words outside that pool are never
    /// considered. Results are sorted by descending score, and NaN scores
    /// sort last.
    pub fn most_similar(&self, word: &str, k: usize) -> Vec<(String, f64)> {
        if k == 0 {
            return Vec::new();
        }
        let scored = self
            .orthographic
            .most_similar(word, k.saturating_mul(2))
            .into_iter()
            .map(|(w, _)| {
                let sim = self.similarity(word, &w);
                (w, sim)
            })
            .collect();
        Self::rank_top_k(scored, k)
    }

    /// Top-`k` neighbours of `word` under [`Self::phonetic_similarity`].
    ///
    /// The candidate pool is the `3k` orthographic nearest neighbours of the
    /// *normalized* query, re-ranked by phonetic similarity. Results are
    /// sorted by descending score, and NaN scores sort last.
    pub fn most_similar_phonetically(&self, word: &str, k: usize) -> Vec<(String, f64)> {
        if k == 0 {
            return Vec::new();
        }
        let normalized_query = self.normalize(word);
        let scored = self
            .orthographic
            .most_similar(&normalized_query, k.saturating_mul(3))
            .into_iter()
            .map(|(w, _)| {
                let sim = self.phonetic_similarity(word, &w);
                (w, sim)
            })
            .collect();
        Self::rank_top_k(scored, k)
    }

    /// Drops every cached normalization.
    pub fn clear_cache(&self) {
        self.normalization_cache.clear();
    }

    /// Number of cached normalizations.
    pub fn cache_size(&self) -> usize {
        self.normalization_cache.len()
    }

    /// Sorts `scored` by descending score, with NaN scores last, and keeps
    /// the first `k` entries.
    fn rank_top_k(mut scored: Vec<(String, f64)>, k: usize) -> Vec<(String, f64)> {
        // `sort_by` requires a total order and may panic otherwise (Rust >= 1.81).
        // `partial_cmp(..).unwrap_or(Equal)` is not total in the presence of NaN.
        // Mapping NaN to -inf and comparing with `total_cmp` is total, and it
        // ranks undefined scores at the bottom.
        let rank = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };
        scored.sort_by(|a, b| rank(b.1).total_cmp(&rank(a.1)));
        scored.truncate(k);
        scored
    }
}
impl Clone for PhoneticEmbedding {
    /// Produces a copy that shares the orthographic model and duplicates the
    /// configuration, but starts with an empty normalization cache.
    ///
    /// The `SubwordEmbedding` is shared through its `Arc` (a reference-count
    /// bump, not a deep copy); the rule list is cloned element-wise.
    fn clone(&self) -> Self {
        let orthographic = Arc::clone(&self.orthographic);
        let rules = self.rules.clone();
        // Cached normalizations are intentionally not carried over.
        let normalization_cache = DashMap::new();
        Self {
            orthographic,
            rules,
            phonetic_weight: self.phonetic_weight,
            normalization_cache,
            max_cache_size: self.max_cache_size,
        }
    }
}
unsafe impl Send for PhoneticEmbedding {}
unsafe impl Sync for PhoneticEmbedding {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::SubwordEmbedding;

    /// Builds an 8-word, 50-dimensional model in which each spelling pair
    /// (`phone`/`fone`, `enough`/`enuf`, `knight`/`night`, `know`/`no`) is
    /// given the same hand-set word vector.
    fn create_test_orthographic() -> SubwordEmbedding {
        let vocab: Vec<String> = [
            "phone", "fone", "enough", "enuf", "knight", "night", "know", "no",
        ]
        .iter()
        .map(|w| w.to_string())
        .collect();
        let mut model = SubwordEmbedding::new(vocab, 50, 10000);
        let embeddings = model.word_embeddings_mut();
        // Pair p occupies rows 2p and 2p+1; both get 1.0 / 0.5 in columns 2p, 2p+1.
        for pair in 0..4_usize {
            for row in [2 * pair, 2 * pair + 1] {
                embeddings[[row, 2 * pair]] = 1.0;
                embeddings[[row, 2 * pair + 1]] = 0.5;
            }
        }
        model
    }

    /// Asserts that scores never increase along `results` (NaN-tolerant).
    fn assert_descending(results: &[(String, f64)]) {
        for pair in results.windows(2) {
            assert_ne!(
                pair[0].1.partial_cmp(&pair[1].1),
                Some(std::cmp::Ordering::Less),
                "results not sorted by descending score: {:?}",
                results
            );
        }
    }

    #[test]
    fn test_phonetic_embedding_creation() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        assert_eq!(phonetic.dim(), 50);
        assert_eq!(phonetic.vocab_size(), 8);
        assert!(!phonetic.rules().is_empty());
    }

    #[test]
    fn test_from_arc_shares_model() {
        let ortho = Arc::new(create_test_orthographic());
        let phonetic = PhoneticEmbedding::from_arc(Arc::clone(&ortho));
        assert_eq!(Arc::strong_count(&ortho), 2);
        assert_eq!(phonetic.dim(), 50);
        let cloned = phonetic.clone();
        assert_eq!(Arc::strong_count(&ortho), 3);
        drop(cloned);
        assert_eq!(Arc::strong_count(&ortho), 2);
    }

    #[test]
    fn test_phonetic_weight() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho).with_phonetic_weight(0.5);
        assert_eq!(phonetic.phonetic_weight(), 0.5);
    }

    #[test]
    #[should_panic(expected = "Phonetic weight must be in [0.0, 1.0]")]
    fn test_invalid_phonetic_weight() {
        let ortho = create_test_orthographic();
        let _ = PhoneticEmbedding::new(ortho).with_phonetic_weight(1.5);
    }

    #[test]
    #[should_panic(expected = "Phonetic weight must be in [0.0, 1.0]")]
    fn test_negative_phonetic_weight() {
        let ortho = create_test_orthographic();
        let _ = PhoneticEmbedding::new(ortho).with_phonetic_weight(-0.1);
    }

    #[test]
    fn test_normalize() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        let first = phonetic.normalize("phone");
        // Recompute from scratch: a cached value must match a fresh one.
        phonetic.clear_cache();
        let recomputed = phonetic.normalize("phone");
        assert_eq!(first, recomputed);
    }

    #[test]
    fn test_normalization_cache() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        let _ = phonetic.normalize("phone");
        assert_eq!(phonetic.cache_size(), 1);
        let _ = phonetic.normalize("phone");
        assert_eq!(phonetic.cache_size(), 1);
        let _ = phonetic.normalize("enough");
        assert_eq!(phonetic.cache_size(), 2);
        phonetic.clear_cache();
        assert_eq!(phonetic.cache_size(), 0);
    }

    #[test]
    fn test_cache_size_limit() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho).with_cache_size(1);
        let _ = phonetic.normalize("phone");
        let _ = phonetic.normalize("enough");
        assert_eq!(phonetic.cache_size(), 1);
    }

    #[test]
    fn test_zero_cache_size_disables_caching() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho).with_cache_size(0);
        let first = phonetic.normalize("knight");
        assert_eq!(phonetic.cache_size(), 0);
        assert_eq!(first, phonetic.normalize("knight"));
    }

    #[test]
    fn test_with_rules_clears_cache() {
        let phonetic = PhoneticEmbedding::new(create_test_orthographic());
        let _ = phonetic.normalize("phone");
        assert_eq!(phonetic.cache_size(), 1);
        let rules = phonetic.rules().to_vec();
        let phonetic = phonetic.with_rules(rules);
        assert_eq!(phonetic.cache_size(), 0);
    }

    #[test]
    fn test_self_similarity() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        assert_eq!(phonetic.similarity("phone", "phone"), 1.0);
        assert_eq!(phonetic.similarity("enough", "enough"), 1.0);
    }

    #[test]
    fn test_phonetic_similarity_identical_normalized() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        assert_eq!(phonetic.phonetic_similarity("phone", "phone"), 1.0);
        // Whatever the rules produce, equal normalized forms must score 1.0.
        let pairs = [
            ("phone", "fone"),
            ("enough", "enuf"),
            ("knight", "night"),
            ("know", "no"),
        ];
        for (a, b) in pairs {
            if phonetic.normalize(a) == phonetic.normalize(b) {
                assert_eq!(phonetic.phonetic_similarity(a, b), 1.0);
            }
        }
    }

    #[test]
    fn test_pure_orthographic_mode() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho.clone()).with_phonetic_weight(0.0);
        let ortho_sim = ortho.similarity("phone", "fone");
        let combined_sim = phonetic.similarity("phone", "fone");
        assert!((ortho_sim as f64 - combined_sim).abs() < 1e-6);
    }

    #[test]
    fn test_clone() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho).with_phonetic_weight(0.5);
        let _ = phonetic.normalize("phone");
        let cloned = phonetic.clone();
        assert_eq!(cloned.phonetic_weight(), 0.5);
        assert_eq!(cloned.rules().len(), phonetic.rules().len());
        assert_eq!(cloned.cache_size(), 0);
    }

    #[test]
    fn test_most_similar() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        let similar = phonetic.most_similar("phone", 3);
        assert!(!similar.is_empty());
        assert!(similar.len() <= 3);
        assert_descending(&similar);
    }

    #[test]
    fn test_most_similar_zero_k() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        assert!(phonetic.most_similar("phone", 0).is_empty());
        assert!(phonetic.most_similar_phonetically("phone", 0).is_empty());
    }

    #[test]
    fn test_most_similar_phonetically() {
        let ortho = create_test_orthographic();
        let phonetic = PhoneticEmbedding::new(ortho);
        let similar = phonetic.most_similar_phonetically("phone", 3);
        assert!(similar.len() <= 3);
        assert_descending(&similar);
    }
}