1use std::collections::HashMap;
19use std::path::Path;
20use std::sync::Mutex;
21
22use regex::Regex;
23use serde::Deserialize;
24
25use crate::error::G2pError;
26
/// Largest dictionary file, in bytes, that `load_dictionary` accepts (10 MiB).
const MAX_DICT_SIZE: u64 = 10 << 20;
29
/// A single dictionary entry: the replacement pronunciation for a word and
/// the priority used to decide which of two competing entries is kept.
///
/// `PartialEq`/`Eq` are derived so entries can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DictEntry {
    /// Text that replaces the word when the dictionary is applied.
    pub pronunciation: String,
    /// Rank of this entry; a higher value wins over a lower one when two
    /// case-insensitive entries share the same normalized word.
    pub priority: i32,
}
40
41#[derive(Debug, Deserialize)]
43#[serde(untagged)]
44enum RawEntry {
45 Simple(String),
47 Detailed(DetailedEntry),
49}
50
51#[derive(Debug, Deserialize)]
52struct DetailedEntry {
53 pronunciation: String,
54 #[serde(default = "default_priority")]
55 priority: i32,
56}
57
/// Priority assigned to entries that do not specify one.
fn default_priority() -> i32 {
    const DEFAULT_PRIORITY: i32 = 5;
    DEFAULT_PRIORITY
}
61
62#[derive(Debug, Deserialize)]
64struct DictFile {
65 #[serde(default = "default_version")]
67 #[allow(dead_code)]
68 version: String,
69 #[serde(default)]
70 entries: HashMap<String, RawEntry>,
71}
72
/// Version string assumed for dictionary files that omit `version`.
fn default_version() -> String {
    String::from("1.0")
}
76
77pub struct CustomDictionary {
90 entries: HashMap<String, DictEntry>,
92 case_sensitive_entries: HashMap<String, DictEntry>,
94 pattern_cache: Mutex<HashMap<String, Regex>>,
96}
97
98impl CustomDictionary {
99 pub fn new() -> Self {
101 Self {
102 entries: HashMap::new(),
103 case_sensitive_entries: HashMap::new(),
104 pattern_cache: Mutex::new(HashMap::new()),
105 }
106 }
107
108 pub fn load_dictionary(&mut self, path: &Path) -> Result<(), G2pError> {
110 let metadata = std::fs::metadata(path).map_err(|_| G2pError::DictionaryLoad {
112 path: path.display().to_string(),
113 })?;
114 if metadata.len() > MAX_DICT_SIZE {
115 return Err(G2pError::DictionaryLoad {
116 path: format!(
117 "{}: file too large ({} bytes, max {})",
118 path.display(),
119 metadata.len(),
120 MAX_DICT_SIZE,
121 ),
122 });
123 }
124
125 let content = std::fs::read_to_string(path).map_err(|_| G2pError::DictionaryLoad {
126 path: path.display().to_string(),
127 })?;
128
129 let dict_file: DictFile =
130 serde_json::from_str(&content).map_err(|e| G2pError::DictionaryLoad {
131 path: format!("{}: {}", path.display(), e),
132 })?;
133
134 for (word, raw_entry) in dict_file.entries {
135 if word.starts_with("//") {
137 continue;
138 }
139
140 let entry = match raw_entry {
141 RawEntry::Simple(pronunciation) => DictEntry {
142 pronunciation,
143 priority: default_priority(),
144 },
145 RawEntry::Detailed(d) => DictEntry {
146 pronunciation: d.pronunciation,
147 priority: d.priority,
148 },
149 };
150
151 self.add_entry(&word, entry);
152 }
153
154 Ok(())
155 }
156
157 pub fn apply_to_text(&self, text: &str) -> String {
162 let mut result = text.to_string();
163
164 let mut cs_entries: Vec<_> = self.case_sensitive_entries.iter().collect();
166 cs_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
167
168 for (word, entry) in &cs_entries {
169 let pattern = self.get_word_pattern(word, true);
170 result = pattern
171 .replace_all(&result, entry.pronunciation.as_str())
172 .to_string();
173 }
174
175 let mut ci_entries: Vec<_> = self.entries.iter().collect();
177 ci_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
178
179 for (word, entry) in &ci_entries {
180 let pattern = self.get_word_pattern(word, false);
181 result = pattern
182 .replace_all(&result, entry.pronunciation.as_str())
183 .to_string();
184 }
185
186 result
187 }
188
189 pub fn add_word(&mut self, word: &str, pronunciation: &str, priority: i32) {
194 let entry = DictEntry {
195 pronunciation: pronunciation.to_string(),
196 priority,
197 };
198 self.add_entry(word, entry);
199 self.pattern_cache.lock().unwrap().clear();
200 }
201
202 pub fn get_pronunciation(&self, word: &str) -> Option<&str> {
206 if let Some(entry) = self.case_sensitive_entries.get(word) {
208 return Some(&entry.pronunciation);
209 }
210
211 let normalized = word.to_lowercase();
213 self.entries
214 .get(&normalized)
215 .map(|e| e.pronunciation.as_str())
216 }
217
218 fn add_entry(&mut self, word: &str, entry: DictEntry) {
224 let lower = word.to_lowercase();
225 let upper = word.to_uppercase();
226
227 if word != lower && word != upper {
228 self.case_sensitive_entries.insert(word.to_string(), entry);
230 } else {
231 let normalized = lower;
233
234 if let Some(existing) = self.entries.get(&normalized)
235 && entry.priority <= existing.priority
236 {
237 return; }
239
240 self.entries.insert(normalized, entry);
241 }
242 }
243
244 fn get_word_pattern(&self, word: &str, case_sensitive: bool) -> Regex {
246 let cache_key = format!("{}_{}", word, case_sensitive);
247
248 let mut cache = self.pattern_cache.lock().unwrap();
249 if let Some(cached) = cache.get(&cache_key) {
250 return cached.clone();
251 }
252
253 let escaped = regex::escape(word);
254
255 let has_non_ascii = word.chars().any(|c| c as u32 > 127);
257
258 let pattern_str = if has_non_ascii {
259 escaped
261 } else {
262 format!(r"(?-u:\b){}(?-u:\b)", escaped)
265 };
266
267 let pattern = if case_sensitive {
268 Regex::new(&pattern_str)
269 } else {
270 Regex::new(&format!("(?i){}", pattern_str))
271 };
272
273 let pat = pattern.expect("failed to compile regex pattern");
274 cache.insert(cache_key, pat.clone());
275 pat
276 }
277}
278
279impl Default for CustomDictionary {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Makes temp-file names unique within one test process.
    static COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Temp-file guard: owns a unique path under the system temp directory
    /// and removes the file when dropped, so tests clean up after themselves
    /// even when an assertion fails.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Reserves a unique `<prefix>_<pid>_<counter>.json` path; the file
        /// itself is created by the caller.
        fn new(prefix: &str) -> Self {
            let id = COUNTER.fetch_add(1, Ordering::SeqCst);
            let path = std::env::temp_dir().join(format!(
                "{}_{}_{}.json",
                prefix,
                std::process::id(),
                id
            ));
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Writes `content` to a fresh temp file and returns its guard.
    fn write_temp_json(content: &str) -> TempFile {
        let file = TempFile::new("piper_test_dict");
        let mut f = std::fs::File::create(file.path()).unwrap();
        f.write_all(content.as_bytes()).unwrap();
        f.flush().unwrap();
        file
    }

    // --- Loading ---

    #[test]
    fn test_load_v1_dictionary() {
        let json = r#"{
            "version": "1.0",
            "entries": {
                "API": "エーピーアイ",
                "CPU": "シーピーユー"
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();

        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("cpu"), Some("シーピーユー"));
    }

    #[test]
    fn test_load_v2_dictionary() {
        let json = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ", "priority": 8},
                "GPU": {"pronunciation": "ジーピーユー"}
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();

        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }

    #[test]
    fn test_v2_comment_lines_skipped() {
        let json = r#"{
            "version": "2.0",
            "entries": {
                "// this is a comment": {"pronunciation": "ignored", "priority": 1},
                "API": {"pronunciation": "エーピーアイ", "priority": 5}
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f.path()).unwrap();

        // Keys starting with "//" must not become entries.
        assert_eq!(dict.get_pronunciation("// this is a comment"), None);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
    }

    #[test]
    fn test_load_nonexistent_file() {
        let mut dict = CustomDictionary::new();
        let result = dict.load_dictionary(Path::new("/no/such/file.json"));
        assert!(result.is_err());
    }

    #[test]
    fn test_load_file_too_large() {
        let file = TempFile::new("piper_test_dict_large");
        // Extend the file to one byte past the limit without writing the
        // data; only the reported size matters for this check.
        std::fs::File::create(file.path())
            .unwrap()
            .set_len(super::MAX_DICT_SIZE + 1)
            .unwrap();

        let mut dict = CustomDictionary::new();
        let result = dict.load_dictionary(file.path());
        assert!(result.is_err());

        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("file too large"),
            "error should mention 'file too large': {}",
            err_msg
        );
    }

    // --- Lookup ---

    #[test]
    fn test_case_sensitivity() {
        let mut dict = CustomDictionary::new();

        // Mixed case -> case-sensitive entry.
        dict.add_word("GitHub", "ギットハブ", 5);
        // All upper case -> case-insensitive entry.
        dict.add_word("API", "エーピーアイ", 5);

        assert_eq!(dict.get_pronunciation("GitHub"), Some("ギットハブ"));
        assert_eq!(dict.get_pronunciation("github"), None);

        assert_eq!(dict.get_pronunciation("API"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("Api"), Some("エーピーアイ"));
    }

    #[test]
    fn test_priority_ordering() {
        let mut dict = CustomDictionary::new();

        dict.add_word("API", "エーピーアイ低", 3);
        dict.add_word("API", "エーピーアイ高", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));

        // Equal priority does not override.
        dict.add_word("API", "エーピーアイ同", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));

        // Lower priority does not override.
        dict.add_word("API", "エーピーアイ低2", 2);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
    }

    // --- Text replacement ---

    #[test]
    fn test_japanese_word_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("東京都", "トウキョウト", 5);

        let result = dict.apply_to_text("私は東京都に住んでいます");
        assert_eq!(result, "私はトウキョウトに住んでいます");
    }

    #[test]
    fn test_japanese_substring_no_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("京都", "キョウト", 5);
        dict.add_word("東京都", "トウキョウト", 5);

        let result = dict.apply_to_text("東京都と京都");
        // The longer word is replaced first, so "東京都" is not split.
        assert_eq!(result, "トウキョウトとキョウト");
    }

    #[test]
    fn test_english_word_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("API", "エーピーアイ", 5);

        assert_eq!(dict.apply_to_text("Use API here"), "Use エーピーアイ here");

        // Embedded inside a longer ASCII word: no match.
        assert_eq!(dict.apply_to_text("UseAPIhere"), "UseAPIhere");

        // Punctuation counts as a boundary.
        assert_eq!(dict.apply_to_text("(API)"), "(エーピーアイ)");
    }

    #[test]
    fn test_english_case_insensitive_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("CPU", "シーピーユー", 5);

        assert_eq!(dict.apply_to_text("my cpu"), "my シーピーユー");
        assert_eq!(dict.apply_to_text("my CPU"), "my シーピーユー");
    }

    #[test]
    fn test_apply_mixed_ja_en_text() {
        let mut dict = CustomDictionary::new();
        dict.add_word("GitHub", "ギットハブ", 5);
        dict.add_word("API", "エーピーアイ", 5);
        dict.add_word("東京", "トウキョウ", 5);

        let input = "東京のGitHubでAPI開発";
        let result = dict.apply_to_text(input);
        assert_eq!(result, "トウキョウのギットハブでエーピーアイ開発");
    }

    #[test]
    fn test_apply_case_sensitive_before_insensitive() {
        let mut dict = CustomDictionary::new();
        // Mixed case -> case-sensitive.
        dict.add_word("iOS", "アイオーエス", 5);
        // All lower case -> case-insensitive.
        dict.add_word("android", "アンドロイド", 5);

        let result = dict.apply_to_text("iOS and Android");
        assert_eq!(result, "アイオーエス and アンドロイド");

        // "ios" differs in case from the case-sensitive entry: untouched.
        let result2 = dict.apply_to_text("ios test");
        assert_eq!(result2, "ios test");
    }

    #[test]
    fn test_longest_match_first() {
        let mut dict = CustomDictionary::new();
        dict.add_word("DB", "ディービー", 5);
        dict.add_word("DBMS", "ディービーエムエス", 5);

        let result = dict.apply_to_text("DBMS and DB");
        assert_eq!(result, "ディービーエムエス and ディービー");
    }

    // --- Construction and merging ---

    #[test]
    fn test_default_empty() {
        let dict = CustomDictionary::default();
        assert_eq!(dict.get_pronunciation("anything"), None);
    }

    #[test]
    fn test_load_multiple_dictionaries() {
        let json1 = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ", "priority": 3}
            }
        }"#;
        let json2 = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ改", "priority": 8},
                "GPU": {"pronunciation": "ジーピーユー", "priority": 5}
            }
        }"#;
        let f1 = write_temp_json(json1);
        let f2 = write_temp_json(json2);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(f1.path()).unwrap();
        dict.load_dictionary(f2.path()).unwrap();

        // The higher-priority entry from the second file wins.
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ改"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }
}