use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;
use regex::Regex;
use serde::Deserialize;
use crate::error::G2pError;
/// Largest dictionary file, in bytes, that `load_dictionary` accepts (10 MiB).
/// Files whose reported size exceeds this are rejected before being read.
const MAX_DICT_SIZE: u64 = 10 * 1024 * 1024;
/// A dictionary entry as stored in memory: the replacement text for a word
/// together with the priority used to arbitrate between competing definitions.
#[derive(Debug, Clone)]
pub struct DictEntry {
/// Text substituted for the word when the dictionary is applied.
pub pronunciation: String,
/// Rank of this definition. For case-insensitive words, a new definition
/// only replaces an existing one when its priority is strictly greater.
pub priority: i32,
}
/// On-disk shape of a single dictionary value.
///
/// Deserialized untagged, so a value may be either a bare JSON string or an
/// object matching `DetailedEntry`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum RawEntry {
/// Bare string form: the string is the pronunciation.
Simple(String),
/// Object form carrying a pronunciation and an optional priority.
Detailed(DetailedEntry),
}
/// Object form of a dictionary value as it appears in the JSON file.
#[derive(Debug, Deserialize)]
struct DetailedEntry {
/// Replacement text for the word (required).
pronunciation: String,
/// Optional priority; falls back to `default_priority()` when the key is absent.
#[serde(default = "default_priority")]
priority: i32,
}
/// Priority given to entries that do not specify one (bare-string entries and
/// object entries without a `priority` key).
fn default_priority() -> i32 {
    const FALLBACK_PRIORITY: i32 = 5;
    FALLBACK_PRIORITY
}
/// Top-level structure of a dictionary JSON file.
#[derive(Debug, Deserialize)]
struct DictFile {
/// Format version string; defaults to `default_version()` when absent.
/// It is deserialized but never read by the code in this file, hence the
/// `dead_code` allowance.
#[serde(default = "default_version")]
#[allow(dead_code)]
version: String,
/// Map from word to its raw definition; an absent key yields an empty map.
#[serde(default)]
entries: HashMap<String, RawEntry>,
}
/// Version string assumed for dictionary files that omit the `version` key.
fn default_version() -> String {
    String::from("1.0")
}
/// User-supplied pronunciation dictionary that rewrites words in input text.
///
/// Words are split into two tables depending on their spelling: mixed-case
/// words are matched exactly, everything else is matched case-insensitively.
pub struct CustomDictionary {
/// Case-insensitive entries, keyed by the lowercased word.
entries: HashMap<String, DictEntry>,
/// Case-sensitive entries, keyed by the word exactly as it was added.
case_sensitive_entries: HashMap<String, DictEntry>,
/// Compiled regexes keyed by `"{word}_{case_sensitive}"`. Wrapped in a
/// `Mutex` so patterns can be cached from methods that take `&self`.
pattern_cache: Mutex<HashMap<String, Regex>>,
}
impl CustomDictionary {
pub fn new() -> Self {
Self {
entries: HashMap::new(),
case_sensitive_entries: HashMap::new(),
pattern_cache: Mutex::new(HashMap::new()),
}
}
pub fn load_dictionary(&mut self, path: &Path) -> Result<(), G2pError> {
let metadata = std::fs::metadata(path).map_err(|_| G2pError::DictionaryLoad {
path: path.display().to_string(),
})?;
if metadata.len() > MAX_DICT_SIZE {
return Err(G2pError::DictionaryLoad {
path: format!(
"{}: file too large ({} bytes, max {})",
path.display(),
metadata.len(),
MAX_DICT_SIZE,
),
});
}
let content = std::fs::read_to_string(path).map_err(|_| G2pError::DictionaryLoad {
path: path.display().to_string(),
})?;
let dict_file: DictFile =
serde_json::from_str(&content).map_err(|e| G2pError::DictionaryLoad {
path: format!("{}: {}", path.display(), e),
})?;
for (word, raw_entry) in dict_file.entries {
if word.starts_with("//") {
continue;
}
let entry = match raw_entry {
RawEntry::Simple(pronunciation) => DictEntry {
pronunciation,
priority: default_priority(),
},
RawEntry::Detailed(d) => DictEntry {
pronunciation: d.pronunciation,
priority: d.priority,
},
};
self.add_entry(&word, entry);
}
Ok(())
}
pub fn apply_to_text(&self, text: &str) -> String {
let mut result = text.to_string();
let mut cs_entries: Vec<_> = self.case_sensitive_entries.iter().collect();
cs_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
for (word, entry) in &cs_entries {
let pattern = self.get_word_pattern(word, true);
result = pattern
.replace_all(&result, entry.pronunciation.as_str())
.to_string();
}
let mut ci_entries: Vec<_> = self.entries.iter().collect();
ci_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
for (word, entry) in &ci_entries {
let pattern = self.get_word_pattern(word, false);
result = pattern
.replace_all(&result, entry.pronunciation.as_str())
.to_string();
}
result
}
pub fn add_word(&mut self, word: &str, pronunciation: &str, priority: i32) {
let entry = DictEntry {
pronunciation: pronunciation.to_string(),
priority,
};
self.add_entry(word, entry);
self.pattern_cache.lock().unwrap().clear();
}
pub fn get_pronunciation(&self, word: &str) -> Option<&str> {
if let Some(entry) = self.case_sensitive_entries.get(word) {
return Some(&entry.pronunciation);
}
let normalized = word.to_lowercase();
self.entries
.get(&normalized)
.map(|e| e.pronunciation.as_str())
}
fn add_entry(&mut self, word: &str, entry: DictEntry) {
let lower = word.to_lowercase();
let upper = word.to_uppercase();
if word != lower && word != upper {
self.case_sensitive_entries.insert(word.to_string(), entry);
} else {
let normalized = lower;
if let Some(existing) = self.entries.get(&normalized)
&& entry.priority <= existing.priority
{
return; }
self.entries.insert(normalized, entry);
}
}
fn get_word_pattern(&self, word: &str, case_sensitive: bool) -> Regex {
let cache_key = format!("{}_{}", word, case_sensitive);
let mut cache = self.pattern_cache.lock().unwrap();
if let Some(cached) = cache.get(&cache_key) {
return cached.clone();
}
let escaped = regex::escape(word);
let has_non_ascii = word.chars().any(|c| c as u32 > 127);
let pattern_str = if has_non_ascii {
escaped
} else {
format!(r"(?-u:\b){}(?-u:\b)", escaped)
};
let pattern = if case_sensitive {
Regex::new(&pattern_str)
} else {
Regex::new(&format!("(?i){}", pattern_str))
};
let pat = pattern.expect("failed to compile regex pattern");
cache.insert(cache_key, pat.clone());
pat
}
}
impl Default for CustomDictionary {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Keeps temp-file names unique across tests running in the same process.
    static COUNTER: AtomicU32 = AtomicU32::new(0);

    /// A path in the system temp directory whose file is deleted on drop, so
    /// tests do not leave files behind even when an assertion fails.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Picks a unique `<prefix>_<pid>_<n>.json` path; the file itself is
        /// not created here.
        fn reserve(prefix: &str) -> Self {
            let id = COUNTER.fetch_add(1, Ordering::SeqCst);
            let path = std::env::temp_dir().join(format!(
                "{}_{}_{}.json",
                prefix,
                std::process::id(),
                id
            ));
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Writes `content` to a fresh temp file and returns its guard.
    fn write_temp_json(content: &str) -> TempFile {
        let file = TempFile::reserve("piper_test_dict");
        std::fs::write(file.path(), content).unwrap();
        file
    }

    #[test]
    fn test_load_v1_dictionary() {
        let json = r#"{
"version": "1.0",
"entries": {
"API": "エーピーアイ",
"CPU": "シーピーユー"
}
}"#;
        let f = write_temp_json(json);
        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("cpu"), Some("シーピーユー"));
    }

    #[test]
    fn test_load_v2_dictionary() {
        let json = r#"{
"version": "2.0",
"entries": {
"API": {"pronunciation": "エーピーアイ", "priority": 8},
"GPU": {"pronunciation": "ジーピーユー"}
}
}"#;
        let f = write_temp_json(json);
        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }

    #[test]
    fn test_v2_comment_lines_skipped() {
        let json = r#"{
"version": "2.0",
"entries": {
"// this is a comment": {"pronunciation": "ignored", "priority": 1},
"API": {"pronunciation": "エーピーアイ", "priority": 5}
}
}"#;
        let f = write_temp_json(json);
        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();
        assert_eq!(dict.get_pronunciation("// this is a comment"), None);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
    }

    #[test]
    fn test_load_nonexistent_file() {
        let mut dict = CustomDictionary::new();
        let result = dict.load_dictionary(Path::new("/no/such/file.json"));
        assert!(result.is_err());
    }

    #[test]
    fn test_load_file_too_large() {
        let file = TempFile::reserve("piper_test_dict_large");
        // Extend an empty file to one byte past the limit instead of
        // allocating and writing the whole payload; only the reported length
        // matters because the loader rejects the file before reading it.
        let handle = std::fs::File::create(file.path()).unwrap();
        handle.set_len(super::MAX_DICT_SIZE + 1).unwrap();
        drop(handle);

        let mut dict = CustomDictionary::new();
        let result = dict.load_dictionary(file.path());
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("file too large"),
            "error should mention 'file too large': {}",
            err_msg
        );
    }

    #[test]
    fn test_case_sensitivity() {
        let mut dict = CustomDictionary::new();
        dict.add_word("GitHub", "ギットハブ", 5);
        dict.add_word("API", "エーピーアイ", 5);
        assert_eq!(dict.get_pronunciation("GitHub"), Some("ギットハブ"));
        assert_eq!(dict.get_pronunciation("github"), None);
        assert_eq!(dict.get_pronunciation("API"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("Api"), Some("エーピーアイ"));
    }

    #[test]
    fn test_priority_ordering() {
        let mut dict = CustomDictionary::new();
        dict.add_word("API", "エーピーアイ低", 3);
        dict.add_word("API", "エーピーアイ高", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
        dict.add_word("API", "エーピーアイ同", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
        dict.add_word("API", "エーピーアイ低2", 2);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
    }

    #[test]
    fn test_japanese_word_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("東京都", "トウキョウト", 5);
        let result = dict.apply_to_text("私は東京都に住んでいます");
        assert_eq!(result, "私はトウキョウトに住んでいます");
    }

    #[test]
    fn test_japanese_substring_no_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("京都", "キョウト", 5);
        dict.add_word("東京都", "トウキョウト", 5);
        let result = dict.apply_to_text("東京都と京都");
        assert_eq!(result, "トウキョウトとキョウト");
    }

    #[test]
    fn test_english_word_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("API", "エーピーアイ", 5);
        assert_eq!(dict.apply_to_text("Use API here"), "Use エーピーアイ here");
        assert_eq!(dict.apply_to_text("UseAPIhere"), "UseAPIhere");
        assert_eq!(dict.apply_to_text("(API)"), "(エーピーアイ)");
    }

    #[test]
    fn test_english_case_insensitive_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("CPU", "シーピーユー", 5);
        assert_eq!(dict.apply_to_text("my cpu"), "my シーピーユー");
        assert_eq!(dict.apply_to_text("my CPU"), "my シーピーユー");
    }

    #[test]
    fn test_apply_mixed_ja_en_text() {
        let mut dict = CustomDictionary::new();
        dict.add_word("GitHub", "ギットハブ", 5);
        dict.add_word("API", "エーピーアイ", 5);
        dict.add_word("東京", "トウキョウ", 5);
        let input = "東京のGitHubでAPI開発";
        let result = dict.apply_to_text(input);
        assert_eq!(result, "トウキョウのギットハブでエーピーアイ開発");
    }

    #[test]
    fn test_apply_case_sensitive_before_insensitive() {
        let mut dict = CustomDictionary::new();
        dict.add_word("iOS", "アイオーエス", 5);
        dict.add_word("android", "アンドロイド", 5);
        let result = dict.apply_to_text("iOS and Android");
        assert_eq!(result, "アイオーエス and アンドロイド");
        let result2 = dict.apply_to_text("ios test");
        assert_eq!(result2, "ios test");
    }

    #[test]
    fn test_longest_match_first() {
        let mut dict = CustomDictionary::new();
        dict.add_word("DB", "ディービー", 5);
        dict.add_word("DBMS", "ディービーエムエス", 5);
        let result = dict.apply_to_text("DBMS and DB");
        assert_eq!(result, "ディービーエムエス and ディービー");
    }

    #[test]
    fn test_default_empty() {
        let dict = CustomDictionary::default();
        assert_eq!(dict.get_pronunciation("anything"), None);
    }

    #[test]
    fn test_load_multiple_dictionaries() {
        let json1 = r#"{
"version": "2.0",
"entries": {
"API": {"pronunciation": "エーピーアイ", "priority": 3}
}
}"#;
        let json2 = r#"{
"version": "2.0",
"entries": {
"API": {"pronunciation": "エーピーアイ改", "priority": 8},
"GPU": {"pronunciation": "ジーピーユー", "priority": 5}
}
}"#;
        let f1 = write_temp_json(json1);
        let f2 = write_temp_json(json2);
        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f1.path()).unwrap();
        dict.load_dictionary(f2.path()).unwrap();
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ改"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }
}