1use std::collections::HashMap;
19use std::path::Path;
20use std::sync::Mutex;
21
22use regex::Regex;
23use serde::Deserialize;
24
25use crate::error::PiperError;
26
/// A single pronunciation override stored in a `CustomDictionary`.
///
/// Derives `PartialEq`/`Eq` so entries can be compared directly, e.g. in
/// assertions or when de-duplicating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DictEntry {
    /// Replacement text substituted for the matched word.
    pub pronunciation: String,
    /// Precedence of the entry. For case-insensitive words, a new entry only
    /// replaces an existing one when its priority is strictly higher.
    pub priority: i32,
}
37
/// On-disk representation of one dictionary value.
///
/// The enum is deserialized untagged: serde tries the variants in declaration
/// order and keeps the first one that fits, so a JSON string becomes
/// `Simple` and a JSON object becomes `Detailed`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum RawEntry {
    /// Bare string value holding only the pronunciation.
    Simple(String),
    /// Object value carrying a pronunciation and an optional priority.
    Detailed(DetailedEntry),
}
47
/// Object form of a dictionary value, e.g.
/// `{"pronunciation": "...", "priority": 8}`.
#[derive(Debug, Deserialize)]
struct DetailedEntry {
    /// Replacement text for the word; required.
    pronunciation: String,
    /// Entry precedence; filled in by `default_priority` when the key is
    /// missing from the JSON object.
    #[serde(default = "default_priority")]
    priority: i32,
}
54
/// Priority given to entries that do not state one themselves.
///
/// Serves as the serde default for `DetailedEntry::priority` and as the
/// priority of plain-string entries.
fn default_priority() -> i32 {
    const FALLBACK_PRIORITY: i32 = 5;
    FALLBACK_PRIORITY
}
58
/// Top-level shape of a dictionary JSON file.
#[derive(Debug, Deserialize)]
struct DictFile {
    /// Format version string; falls back to `default_version` when absent.
    /// No code visible in this file reads it, hence the `dead_code` allowance.
    #[serde(default = "default_version")]
    #[allow(dead_code)]
    version: String,
    /// Word -> entry map; an absent `entries` key deserializes to an empty map.
    #[serde(default)]
    entries: HashMap<String, RawEntry>,
}
69
/// Version string assumed for dictionary files that omit the `version` key.
fn default_version() -> String {
    String::from("1.0")
}
73
/// User-supplied pronunciation dictionary applied to text by word replacement.
///
/// Words are split into two tables: mixed-case words are matched exactly,
/// while all-lowercase / all-uppercase words are matched case-insensitively.
pub struct CustomDictionary {
    /// Case-insensitive entries, keyed by the lowercased word.
    entries: HashMap<String, DictEntry>,
    /// Case-sensitive entries, keyed by the word exactly as it was added.
    case_sensitive_entries: HashMap<String, DictEntry>,
    /// Compiled match patterns, memoized per word and case-sensitivity flag.
    /// Kept behind a `Mutex` so it can be filled from `&self` methods.
    pattern_cache: Mutex<HashMap<String, Regex>>,
}
94
95impl CustomDictionary {
96 pub fn new() -> Self {
98 Self {
99 entries: HashMap::new(),
100 case_sensitive_entries: HashMap::new(),
101 pattern_cache: Mutex::new(HashMap::new()),
102 }
103 }
104
105 pub fn load_dictionary(&mut self, path: &Path) -> Result<(), PiperError> {
107 let content = std::fs::read_to_string(path).map_err(|_| PiperError::DictionaryLoad {
108 path: path.display().to_string(),
109 })?;
110
111 let dict_file: DictFile =
112 serde_json::from_str(&content).map_err(|e| PiperError::DictionaryLoad {
113 path: format!("{}: {}", path.display(), e),
114 })?;
115
116 for (word, raw_entry) in dict_file.entries {
117 if word.starts_with("//") {
119 continue;
120 }
121
122 let entry = match raw_entry {
123 RawEntry::Simple(pronunciation) => DictEntry {
124 pronunciation,
125 priority: default_priority(),
126 },
127 RawEntry::Detailed(d) => DictEntry {
128 pronunciation: d.pronunciation,
129 priority: d.priority,
130 },
131 };
132
133 self.add_entry(&word, entry);
134 }
135
136 Ok(())
137 }
138
139 pub fn apply_to_text(&self, text: &str) -> String {
144 let mut result = text.to_string();
145
146 let mut cs_entries: Vec<_> = self.case_sensitive_entries.iter().collect();
148 cs_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
149
150 for (word, entry) in &cs_entries {
151 let pattern = self.get_word_pattern(word, true);
152 result = pattern
153 .replace_all(&result, entry.pronunciation.as_str())
154 .to_string();
155 }
156
157 let mut ci_entries: Vec<_> = self.entries.iter().collect();
159 ci_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
160
161 for (word, entry) in &ci_entries {
162 let pattern = self.get_word_pattern(word, false);
163 result = pattern
164 .replace_all(&result, entry.pronunciation.as_str())
165 .to_string();
166 }
167
168 result
169 }
170
171 pub fn add_word(&mut self, word: &str, pronunciation: &str, priority: i32) {
176 let entry = DictEntry {
177 pronunciation: pronunciation.to_string(),
178 priority,
179 };
180 self.add_entry(word, entry);
181 self.pattern_cache.lock().unwrap().clear();
182 }
183
184 pub fn get_pronunciation(&self, word: &str) -> Option<&str> {
188 if let Some(entry) = self.case_sensitive_entries.get(word) {
190 return Some(&entry.pronunciation);
191 }
192
193 let normalized = word.to_lowercase();
195 self.entries
196 .get(&normalized)
197 .map(|e| e.pronunciation.as_str())
198 }
199
200 fn add_entry(&mut self, word: &str, entry: DictEntry) {
206 let lower = word.to_lowercase();
207 let upper = word.to_uppercase();
208
209 if word != lower && word != upper {
210 self.case_sensitive_entries.insert(word.to_string(), entry);
212 } else {
213 let normalized = lower;
215
216 if let Some(existing) = self.entries.get(&normalized)
217 && entry.priority <= existing.priority
218 {
219 return; }
221
222 self.entries.insert(normalized, entry);
223 }
224 }
225
226 fn get_word_pattern(&self, word: &str, case_sensitive: bool) -> Regex {
228 let cache_key = format!("{}_{}", word, case_sensitive);
229
230 let mut cache = self.pattern_cache.lock().unwrap();
231 if let Some(cached) = cache.get(&cache_key) {
232 return cached.clone();
233 }
234
235 let escaped = regex::escape(word);
236
237 let has_non_ascii = word.chars().any(|c| c as u32 > 127);
239
240 let pattern_str = if has_non_ascii {
241 escaped
243 } else {
244 format!(r"(?-u:\b){}(?-u:\b)", escaped)
247 };
248
249 let pattern = if case_sensitive {
250 Regex::new(&pattern_str)
251 } else {
252 Regex::new(&format!("(?i){}", pattern_str))
253 };
254
255 let pat = pattern.expect("failed to compile regex pattern");
256 cache.insert(cache_key, pat.clone());
257 pat
258 }
259}
260
261impl Default for CustomDictionary {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Makes fixture file names unique within one test process.
    static COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Temporary JSON fixture that removes its backing file when dropped,
    /// so test runs do not accumulate files in the system temp directory.
    struct TempJson(std::path::PathBuf);

    impl std::ops::Deref for TempJson {
        type Target = std::path::Path;

        fn deref(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempJson {
        fn drop(&mut self) {
            // Best-effort cleanup; a failed removal must not fail the test.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Writes `content` to a uniquely named file in the temp directory and
    /// returns a guard that dereferences to its path.
    fn write_temp_json(content: &str) -> TempJson {
        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!(
            "piper_test_dict_{}_{}.json",
            std::process::id(),
            id
        ));
        std::fs::write(&path, content).unwrap();
        TempJson(path)
    }

    /// Plain-string values load with case-insensitive lookup for all-caps words.
    #[test]
    fn test_load_v1_dictionary() {
        let json = r#"{
            "version": "1.0",
            "entries": {
                "API": "エーピーアイ",
                "CPU": "シーピーユー"
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(&f).unwrap();

        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("cpu"), Some("シーピーユー"));
    }

    /// Object values load with and without an explicit priority.
    #[test]
    fn test_load_v2_dictionary() {
        let json = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ", "priority": 8},
                "GPU": {"pronunciation": "ジーピーユー"}
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(&f).unwrap();

        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }

    /// Keys starting with "//" are comments and never become entries.
    #[test]
    fn test_v2_comment_lines_skipped() {
        let json = r#"{
            "version": "2.0",
            "entries": {
                "// this is a comment": {"pronunciation": "ignored", "priority": 1},
                "API": {"pronunciation": "エーピーアイ", "priority": 5}
            }
        }"#;
        let f = write_temp_json(json);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(&f).unwrap();

        assert_eq!(dict.get_pronunciation("// this is a comment"), None);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
    }

    #[test]
    fn test_load_nonexistent_file() {
        let mut dict = CustomDictionary::new();
        let result = dict.load_dictionary(Path::new("/no/such/file.json"));
        assert!(result.is_err());
    }

    /// Mixed-case words match exactly; all-caps words match in any case.
    #[test]
    fn test_case_sensitivity() {
        let mut dict = CustomDictionary::new();

        dict.add_word("GitHub", "ギットハブ", 5);
        dict.add_word("API", "エーピーアイ", 5);

        assert_eq!(dict.get_pronunciation("GitHub"), Some("ギットハブ"));
        assert_eq!(dict.get_pronunciation("github"), None);

        assert_eq!(dict.get_pronunciation("API"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
        assert_eq!(dict.get_pronunciation("Api"), Some("エーピーアイ"));
    }

    /// Only a strictly higher priority replaces an existing entry.
    #[test]
    fn test_priority_ordering() {
        let mut dict = CustomDictionary::new();

        dict.add_word("API", "エーピーアイ低", 3);
        dict.add_word("API", "エーピーアイ高", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));

        dict.add_word("API", "エーピーアイ同", 7);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));

        dict.add_word("API", "エーピーアイ低2", 2);
        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
    }

    /// Non-ASCII words are matched as substrings, without word boundaries.
    #[test]
    fn test_japanese_word_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("東京都", "トウキョウト", 5);

        let result = dict.apply_to_text("私は東京都に住んでいます");
        assert_eq!(result, "私はトウキョウトに住んでいます");
    }

    #[test]
    fn test_japanese_substring_no_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("京都", "キョウト", 5);
        dict.add_word("東京都", "トウキョウト", 5);

        let result = dict.apply_to_text("東京都と京都");
        assert_eq!(result, "トウキョウトとキョウト");
    }

    /// ASCII words only match at word boundaries.
    #[test]
    fn test_english_word_boundary() {
        let mut dict = CustomDictionary::new();
        dict.add_word("API", "エーピーアイ", 5);

        assert_eq!(dict.apply_to_text("Use API here"), "Use エーピーアイ here");

        assert_eq!(dict.apply_to_text("UseAPIhere"), "UseAPIhere");

        assert_eq!(dict.apply_to_text("(API)"), "(エーピーアイ)");
    }

    #[test]
    fn test_english_case_insensitive_matching() {
        let mut dict = CustomDictionary::new();
        dict.add_word("CPU", "シーピーユー", 5);

        assert_eq!(dict.apply_to_text("my cpu"), "my シーピーユー");
        assert_eq!(dict.apply_to_text("my CPU"), "my シーピーユー");
    }

    /// ASCII words embedded in Japanese text are still replaced.
    #[test]
    fn test_apply_mixed_ja_en_text() {
        let mut dict = CustomDictionary::new();
        dict.add_word("GitHub", "ギットハブ", 5);
        dict.add_word("API", "エーピーアイ", 5);
        dict.add_word("東京", "トウキョウ", 5);

        let input = "東京のGitHubでAPI開発";
        let result = dict.apply_to_text(input);
        assert_eq!(result, "トウキョウのギットハブでエーピーアイ開発");
    }

    #[test]
    fn test_apply_case_sensitive_before_insensitive() {
        let mut dict = CustomDictionary::new();
        dict.add_word("iOS", "アイオーエス", 5);
        dict.add_word("android", "アンドロイド", 5);

        let result = dict.apply_to_text("iOS and Android");
        assert_eq!(result, "アイオーエス and アンドロイド");

        let result2 = dict.apply_to_text("ios test");
        assert_eq!(result2, "ios test");
    }

    /// Longer words are replaced before their shorter prefixes.
    #[test]
    fn test_longest_match_first() {
        let mut dict = CustomDictionary::new();
        dict.add_word("DB", "ディービー", 5);
        dict.add_word("DBMS", "ディービーエムエス", 5);

        let result = dict.apply_to_text("DBMS and DB");
        assert_eq!(result, "ディービーエムエス and ディービー");
    }

    #[test]
    fn test_default_empty() {
        let dict = CustomDictionary::default();
        assert_eq!(dict.get_pronunciation("anything"), None);
    }

    /// A later file overrides an earlier one only through higher priority.
    #[test]
    fn test_load_multiple_dictionaries() {
        let json1 = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ", "priority": 3}
            }
        }"#;
        let json2 = r#"{
            "version": "2.0",
            "entries": {
                "API": {"pronunciation": "エーピーアイ改", "priority": 8},
                "GPU": {"pronunciation": "ジーピーユー", "priority": 5}
            }
        }"#;
        let f1 = write_temp_json(json1);
        let f2 = write_temp_json(json2);

        let mut dict = CustomDictionary::new();
        dict.load_dictionary(&f1).unwrap();
        dict.load_dictionary(&f2).unwrap();

        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ改"));
        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
    }
}