Skip to main content

piper_plus_g2p/
custom_dict.rs

1//! カスタム辞書モジュール
2//!
3//! 技術用語や固有名詞の読みを管理し、テキスト前処理を行う。
4//! Python `custom_dict.py` と同一ロジックの Rust 移植。
5//!
6//! ## JSON 辞書フォーマット
7//!
8//! **v1.0** (単純形式):
9//! ```json
10//! { "version": "1.0", "entries": { "API": "エーピーアイ" } }
11//! ```
12//!
13//! **v2.0** (詳細形式):
14//! ```json
15//! { "version": "2.0", "entries": { "API": { "pronunciation": "エーピーアイ", "priority": 5 } } }
16//! ```
17
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::Mutex;
21
22use regex::Regex;
23use serde::Deserialize;
24
25use crate::error::G2pError;
26
27/// 辞書ファイルの最大サイズ (10 MB) — Python 側の制限と一致
28const MAX_DICT_SIZE: u64 = 10 * 1024 * 1024;
29
30// ---------------------------------------------------------------------------
31// Types
32// ---------------------------------------------------------------------------
33
34/// 辞書エントリ
35#[derive(Debug, Clone)]
36pub struct DictEntry {
37    pub pronunciation: String,
38    pub priority: i32,
39}
40
41/// JSON v2.0 のエントリ表現 (デシリアライズ用)
42#[derive(Debug, Deserialize)]
43#[serde(untagged)]
44enum RawEntry {
45    /// v1.0 互換: 文字列のみ
46    Simple(String),
47    /// v2.0: pronunciation + optional priority
48    Detailed(DetailedEntry),
49}
50
51#[derive(Debug, Deserialize)]
52struct DetailedEntry {
53    pronunciation: String,
54    #[serde(default = "default_priority")]
55    priority: i32,
56}
57
58fn default_priority() -> i32 {
59    5
60}
61
62/// JSON 辞書ファイルのトップレベル構造
63#[derive(Debug, Deserialize)]
64struct DictFile {
65    /// バージョン文字列 (将来の拡張用に保持)
66    #[serde(default = "default_version")]
67    #[allow(dead_code)]
68    version: String,
69    #[serde(default)]
70    entries: HashMap<String, RawEntry>,
71}
72
73fn default_version() -> String {
74    "1.0".to_string()
75}
76
77// ---------------------------------------------------------------------------
78// CustomDictionary
79// ---------------------------------------------------------------------------
80
81/// カスタム辞書
82///
83/// 技術用語・固有名詞の読みを保持し、テキスト中の該当箇所を置換する。
84///
85/// - 大文字小文字混在の単語 (例: "GitHub") は case-sensitive マップに格納
86/// - 全大文字/全小文字の単語は lowercase 正規化して case-insensitive マップに格納
87/// - 日本語 (非 ASCII) 文字を含む単語は単純部分文字列マッチ
88/// - ASCII のみの単語は単語境界正規表現でマッチ
89pub struct CustomDictionary {
90    /// Case-insensitive エントリ (キーは lowercase 正規化済み)
91    entries: HashMap<String, DictEntry>,
92    /// Case-sensitive エントリ (混在ケースの単語)
93    case_sensitive_entries: HashMap<String, DictEntry>,
94    /// コンパイル済み正規表現キャッシュ (interior mutability で &self から挿入可能)
95    pattern_cache: Mutex<HashMap<String, Regex>>,
96}
97
98impl CustomDictionary {
99    /// 空の辞書を作成
100    pub fn new() -> Self {
101        Self {
102            entries: HashMap::new(),
103            case_sensitive_entries: HashMap::new(),
104            pattern_cache: Mutex::new(HashMap::new()),
105        }
106    }
107
108    /// JSON 辞書ファイルを読み込む (v1.0 / v2.0 対応)
109    pub fn load_dictionary(&mut self, path: &Path) -> Result<(), G2pError> {
110        // ファイルサイズチェック (DoS 防止)
111        let metadata = std::fs::metadata(path).map_err(|_| G2pError::DictionaryLoad {
112            path: path.display().to_string(),
113        })?;
114        if metadata.len() > MAX_DICT_SIZE {
115            return Err(G2pError::DictionaryLoad {
116                path: format!(
117                    "{}: file too large ({} bytes, max {})",
118                    path.display(),
119                    metadata.len(),
120                    MAX_DICT_SIZE,
121                ),
122            });
123        }
124
125        let content = std::fs::read_to_string(path).map_err(|_| G2pError::DictionaryLoad {
126            path: path.display().to_string(),
127        })?;
128
129        let dict_file: DictFile =
130            serde_json::from_str(&content).map_err(|e| G2pError::DictionaryLoad {
131                path: format!("{}: {}", path.display(), e),
132            })?;
133
134        for (word, raw_entry) in dict_file.entries {
135            // v2.0: コメント行スキップ
136            if word.starts_with("//") {
137                continue;
138            }
139
140            let entry = match raw_entry {
141                RawEntry::Simple(pronunciation) => DictEntry {
142                    pronunciation,
143                    priority: default_priority(),
144                },
145                RawEntry::Detailed(d) => DictEntry {
146                    pronunciation: d.pronunciation,
147                    priority: d.priority,
148                },
149            };
150
151            self.add_entry(&word, entry);
152        }
153
154        Ok(())
155    }
156
157    /// テキストに辞書を適用して単語を置換
158    ///
159    /// 1. Case-sensitive エントリを長い順に処理
160    /// 2. Case-insensitive エントリを長い順に処理
161    pub fn apply_to_text(&self, text: &str) -> String {
162        let mut result = text.to_string();
163
164        // Case-sensitive エントリ (長い順)
165        let mut cs_entries: Vec<_> = self.case_sensitive_entries.iter().collect();
166        cs_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
167
168        for (word, entry) in &cs_entries {
169            let pattern = self.get_word_pattern(word, true);
170            result = pattern
171                .replace_all(&result, entry.pronunciation.as_str())
172                .to_string();
173        }
174
175        // Case-insensitive エントリ (長い順)
176        let mut ci_entries: Vec<_> = self.entries.iter().collect();
177        ci_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
178
179        for (word, entry) in &ci_entries {
180            let pattern = self.get_word_pattern(word, false);
181            result = pattern
182                .replace_all(&result, entry.pronunciation.as_str())
183                .to_string();
184        }
185
186        result
187    }
188
189    /// 単語と読みを追加
190    ///
191    /// 既存エントリより優先度が低い場合は追加しない。
192    /// パターンキャッシュはクリアされる。
193    pub fn add_word(&mut self, word: &str, pronunciation: &str, priority: i32) {
194        let entry = DictEntry {
195            pronunciation: pronunciation.to_string(),
196            priority,
197        };
198        self.add_entry(word, entry);
199        self.pattern_cache.lock().unwrap().clear();
200    }
201
202    /// 単語の読みを取得
203    ///
204    /// Case-sensitive マップを先に検索し、見つからなければ case-insensitive マップを検索。
205    pub fn get_pronunciation(&self, word: &str) -> Option<&str> {
206        // Case-sensitive を先にチェック
207        if let Some(entry) = self.case_sensitive_entries.get(word) {
208            return Some(&entry.pronunciation);
209        }
210
211        // Case-insensitive (lowercase 正規化)
212        let normalized = word.to_lowercase();
213        self.entries
214            .get(&normalized)
215            .map(|e| e.pronunciation.as_str())
216    }
217
218    // -----------------------------------------------------------------------
219    // Internal helpers
220    // -----------------------------------------------------------------------
221
222    /// エントリを適切なマップに追加
223    fn add_entry(&mut self, word: &str, entry: DictEntry) {
224        let lower = word.to_lowercase();
225        let upper = word.to_uppercase();
226
227        if word != lower && word != upper {
228            // 大文字小文字混在 → case-sensitive マップ
229            self.case_sensitive_entries.insert(word.to_string(), entry);
230        } else {
231            // 全大文字 or 全小文字 → lowercase 正規化して case-insensitive マップ
232            let normalized = lower;
233
234            if let Some(existing) = self.entries.get(&normalized)
235                && entry.priority <= existing.priority
236            {
237                return; // 既存の方が優先度が高い (または同じ)
238            }
239
240            self.entries.insert(normalized, entry);
241        }
242    }
243
244    /// 単語の正規表現パターンを取得 (キャッシュ利用)
245    fn get_word_pattern(&self, word: &str, case_sensitive: bool) -> Regex {
246        let cache_key = format!("{}_{}", word, case_sensitive);
247
248        let mut cache = self.pattern_cache.lock().unwrap();
249        if let Some(cached) = cache.get(&cache_key) {
250            return cached.clone();
251        }
252
253        let escaped = regex::escape(word);
254
255        // 非 ASCII 文字を含むかチェック (日本語等)
256        let has_non_ascii = word.chars().any(|c| c as u32 > 127);
257
258        let pattern_str = if has_non_ascii {
259            // 日本語を含む場合: 単純部分文字列マッチ
260            escaped
261        } else {
262            // ASCII のみ: ASCII ワード境界で区切る
263            // (?-u:\b) は ASCII のみの \b — 日本語文字の隣でも正しく動作する
264            format!(r"(?-u:\b){}(?-u:\b)", escaped)
265        };
266
267        let pattern = if case_sensitive {
268            Regex::new(&pattern_str)
269        } else {
270            Regex::new(&format!("(?i){}", pattern_str))
271        };
272
273        let pat = pattern.expect("failed to compile regex pattern");
274        cache.insert(cache_key, pat.clone());
275        pat
276    }
277}
278
279impl Default for CustomDictionary {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285// ---------------------------------------------------------------------------
286// Tests
287// ---------------------------------------------------------------------------
288
289#[cfg(test)]
290mod tests {
291    use super::*;
292    use std::io::Write;
293    use std::sync::atomic::{AtomicU32, Ordering};
294
295    static COUNTER: AtomicU32 = AtomicU32::new(0);
296
297    /// テスト用に一時 JSON ファイルを作成するヘルパー
298    fn write_temp_json(content: &str) -> std::path::PathBuf {
299        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
300        let path = std::env::temp_dir().join(format!(
301            "piper_test_dict_{}_{}.json",
302            std::process::id(),
303            id
304        ));
305        let mut f = std::fs::File::create(&path).unwrap();
306        f.write_all(content.as_bytes()).unwrap();
307        f.flush().unwrap();
308        path
309    }
310
311    // ----- v1.0 / v2.0 ロード -----
312
313    #[test]
314    fn test_load_v1_dictionary() {
315        let json = r#"{
316            "version": "1.0",
317            "entries": {
318                "API": "エーピーアイ",
319                "CPU": "シーピーユー"
320            }
321        }"#;
322        let f = write_temp_json(json);
323
324        let mut dict = CustomDictionary::new();
325        dict.load_dictionary(&f).unwrap();
326
327        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
328        assert_eq!(dict.get_pronunciation("cpu"), Some("シーピーユー"));
329    }
330
331    #[test]
332    fn test_load_v2_dictionary() {
333        let json = r#"{
334            "version": "2.0",
335            "entries": {
336                "API": {"pronunciation": "エーピーアイ", "priority": 8},
337                "GPU": {"pronunciation": "ジーピーユー"}
338            }
339        }"#;
340        let f = write_temp_json(json);
341
342        let mut dict = CustomDictionary::new();
343        dict.load_dictionary(&f).unwrap();
344
345        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
346        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
347    }
348
349    #[test]
350    fn test_v2_comment_lines_skipped() {
351        let json = r#"{
352            "version": "2.0",
353            "entries": {
354                "// this is a comment": {"pronunciation": "ignored", "priority": 1},
355                "API": {"pronunciation": "エーピーアイ", "priority": 5}
356            }
357        }"#;
358        let f = write_temp_json(json);
359
360        let mut dict = CustomDictionary::new();
361        dict.load_dictionary(&f).unwrap();
362
363        // コメント行は登録されない
364        assert_eq!(dict.get_pronunciation("// this is a comment"), None);
365        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
366    }
367
368    #[test]
369    fn test_load_nonexistent_file() {
370        let mut dict = CustomDictionary::new();
371        let result = dict.load_dictionary(Path::new("/no/such/file.json"));
372        assert!(result.is_err());
373    }
374
375    #[test]
376    fn test_load_file_too_large() {
377        // MAX_DICT_SIZE (10 MB) を超えるファイルは拒否される
378        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
379        let path = std::env::temp_dir().join(format!(
380            "piper_test_dict_large_{}_{}.json",
381            std::process::id(),
382            id,
383        ));
384        // 10 MB + 1 byte のダミーファイルを作成
385        let size = (super::MAX_DICT_SIZE + 1) as usize;
386        let data = vec![b' '; size];
387        std::fs::write(&path, &data).unwrap();
388
389        let mut dict = CustomDictionary::new();
390        let result = dict.load_dictionary(&path);
391        assert!(result.is_err());
392
393        let err_msg = format!("{}", result.unwrap_err());
394        assert!(
395            err_msg.contains("file too large"),
396            "error should mention 'file too large': {}",
397            err_msg
398        );
399
400        // テスト後クリーンアップ
401        let _ = std::fs::remove_file(&path);
402    }
403
404    // ----- Case sensitivity -----
405
406    #[test]
407    fn test_case_sensitivity() {
408        let mut dict = CustomDictionary::new();
409
410        // 混在ケース → case-sensitive マップ
411        dict.add_word("GitHub", "ギットハブ", 5);
412        // 全大文字 → case-insensitive マップ (lowercase 正規化)
413        dict.add_word("API", "エーピーアイ", 5);
414
415        // case-sensitive: 完全一致のみ
416        assert_eq!(dict.get_pronunciation("GitHub"), Some("ギットハブ"));
417        // "github" (全小文字) は case-sensitive マップにないので None
418        // ただし case-insensitive マップにも登録されていないので None
419        assert_eq!(dict.get_pronunciation("github"), None);
420
421        // case-insensitive: どのケースでも取得可能
422        assert_eq!(dict.get_pronunciation("API"), Some("エーピーアイ"));
423        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
424        assert_eq!(dict.get_pronunciation("Api"), Some("エーピーアイ"));
425    }
426
427    // ----- Priority -----
428
429    #[test]
430    fn test_priority_ordering() {
431        let mut dict = CustomDictionary::new();
432
433        dict.add_word("API", "エーピーアイ低", 3);
434        dict.add_word("API", "エーピーアイ高", 7);
435        // 優先度 7 > 3 なので上書きされる
436        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
437
438        // 同じ優先度では上書きされない
439        dict.add_word("API", "エーピーアイ同", 7);
440        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
441
442        // 低い優先度では上書きされない
443        dict.add_word("API", "エーピーアイ低2", 2);
444        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
445    }
446
447    // ----- Japanese word matching -----
448
449    #[test]
450    fn test_japanese_word_matching() {
451        let mut dict = CustomDictionary::new();
452        dict.add_word("東京都", "トウキョウト", 5);
453
454        let result = dict.apply_to_text("私は東京都に住んでいます");
455        assert_eq!(result, "私はトウキョウトに住んでいます");
456    }
457
458    #[test]
459    fn test_japanese_substring_no_boundary() {
460        let mut dict = CustomDictionary::new();
461        dict.add_word("京都", "キョウト", 5);
462        dict.add_word("東京都", "トウキョウト", 5);
463
464        // 長い方が先にマッチ → 「東京都」が置換される
465        let result = dict.apply_to_text("東京都と京都");
466        assert_eq!(result, "トウキョウトとキョウト");
467    }
468
469    // ----- English word boundary matching -----
470
471    #[test]
472    fn test_english_word_boundary() {
473        let mut dict = CustomDictionary::new();
474        dict.add_word("API", "エーピーアイ", 5);
475
476        // 単語境界あり → マッチ
477        assert_eq!(dict.apply_to_text("Use API here"), "Use エーピーアイ here");
478
479        // 英数字に隣接 → マッチしない
480        assert_eq!(dict.apply_to_text("UseAPIhere"), "UseAPIhere");
481
482        // 記号に隣接 → マッチ
483        assert_eq!(dict.apply_to_text("(API)"), "(エーピーアイ)");
484    }
485
486    #[test]
487    fn test_english_case_insensitive_matching() {
488        let mut dict = CustomDictionary::new();
489        dict.add_word("CPU", "シーピーユー", 5);
490
491        // case-insensitive: 大文字小文字問わずマッチ
492        assert_eq!(dict.apply_to_text("my cpu"), "my シーピーユー");
493        assert_eq!(dict.apply_to_text("my CPU"), "my シーピーユー");
494    }
495
496    // ----- apply_to_text with mixed text -----
497
498    #[test]
499    fn test_apply_mixed_ja_en_text() {
500        let mut dict = CustomDictionary::new();
501        dict.add_word("GitHub", "ギットハブ", 5);
502        dict.add_word("API", "エーピーアイ", 5);
503        dict.add_word("東京", "トウキョウ", 5);
504
505        let input = "東京のGitHubでAPI開発";
506        let result = dict.apply_to_text(input);
507        assert_eq!(result, "トウキョウのギットハブでエーピーアイ開発");
508    }
509
510    #[test]
511    fn test_apply_case_sensitive_before_insensitive() {
512        let mut dict = CustomDictionary::new();
513        // "iOS" は混在ケース → case-sensitive
514        dict.add_word("iOS", "アイオーエス", 5);
515        // "android" は全小文字 → case-insensitive
516        dict.add_word("android", "アンドロイド", 5);
517
518        let result = dict.apply_to_text("iOS and Android");
519        assert_eq!(result, "アイオーエス and アンドロイド");
520
521        // "ios" (全小文字) は case-sensitive マップの "iOS" にマッチしない
522        // case-insensitive マップにも無いのでそのまま
523        let result2 = dict.apply_to_text("ios test");
524        assert_eq!(result2, "ios test");
525    }
526
527    // ----- Longest match first -----
528
529    #[test]
530    fn test_longest_match_first() {
531        let mut dict = CustomDictionary::new();
532        dict.add_word("DB", "ディービー", 5);
533        dict.add_word("DBMS", "ディービーエムエス", 5);
534
535        // "DBMS" が先にマッチし、残った部分に "DB" はマッチしない
536        let result = dict.apply_to_text("DBMS and DB");
537        assert_eq!(result, "ディービーエムエス and ディービー");
538    }
539
540    // ----- Default constructor -----
541
542    #[test]
543    fn test_default_empty() {
544        let dict = CustomDictionary::default();
545        assert_eq!(dict.get_pronunciation("anything"), None);
546    }
547
548    // ----- Multiple dictionaries -----
549
550    #[test]
551    fn test_load_multiple_dictionaries() {
552        let json1 = r#"{
553            "version": "2.0",
554            "entries": {
555                "API": {"pronunciation": "エーピーアイ", "priority": 3}
556            }
557        }"#;
558        let json2 = r#"{
559            "version": "2.0",
560            "entries": {
561                "API": {"pronunciation": "エーピーアイ改", "priority": 8},
562                "GPU": {"pronunciation": "ジーピーユー", "priority": 5}
563            }
564        }"#;
565        let f1 = write_temp_json(json1);
566        let f2 = write_temp_json(json2);
567
568        let mut dict = CustomDictionary::new();
569        dict.load_dictionary(&f1).unwrap();
570        dict.load_dictionary(&f2).unwrap();
571
572        // 2番目のファイルの方が優先度が高い → 上書き
573        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ改"));
574        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
575    }
576}