Skip to main content

piper_plus/phonemize/
custom_dict.rs

1//! カスタム辞書モジュール
2//!
3//! 技術用語や固有名詞の読みを管理し、テキスト前処理を行う。
4//! Python `custom_dict.py` と同一ロジックの Rust 移植。
5//!
6//! ## JSON 辞書フォーマット
7//!
8//! **v1.0** (単純形式):
9//! ```json
10//! { "version": "1.0", "entries": { "API": "エーピーアイ" } }
11//! ```
12//!
13//! **v2.0** (詳細形式):
14//! ```json
15//! { "version": "2.0", "entries": { "API": { "pronunciation": "エーピーアイ", "priority": 5 } } }
16//! ```
17
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::Mutex;
21
22use regex::Regex;
23use serde::Deserialize;
24
25use crate::error::PiperError;
26
27// ---------------------------------------------------------------------------
28// Types
29// ---------------------------------------------------------------------------
30
31/// 辞書エントリ
32#[derive(Debug, Clone)]
33pub struct DictEntry {
34    pub pronunciation: String,
35    pub priority: i32,
36}
37
38/// JSON v2.0 のエントリ表現 (デシリアライズ用)
39#[derive(Debug, Deserialize)]
40#[serde(untagged)]
41enum RawEntry {
42    /// v1.0 互換: 文字列のみ
43    Simple(String),
44    /// v2.0: pronunciation + optional priority
45    Detailed(DetailedEntry),
46}
47
48#[derive(Debug, Deserialize)]
49struct DetailedEntry {
50    pronunciation: String,
51    #[serde(default = "default_priority")]
52    priority: i32,
53}
54
55fn default_priority() -> i32 {
56    5
57}
58
59/// JSON 辞書ファイルのトップレベル構造
60#[derive(Debug, Deserialize)]
61struct DictFile {
62    /// バージョン文字列 (将来の拡張用に保持)
63    #[serde(default = "default_version")]
64    #[allow(dead_code)]
65    version: String,
66    #[serde(default)]
67    entries: HashMap<String, RawEntry>,
68}
69
70fn default_version() -> String {
71    "1.0".to_string()
72}
73
74// ---------------------------------------------------------------------------
75// CustomDictionary
76// ---------------------------------------------------------------------------
77
78/// カスタム辞書
79///
80/// 技術用語・固有名詞の読みを保持し、テキスト中の該当箇所を置換する。
81///
82/// - 大文字小文字混在の単語 (例: "GitHub") は case-sensitive マップに格納
83/// - 全大文字/全小文字の単語は lowercase 正規化して case-insensitive マップに格納
84/// - 日本語 (非 ASCII) 文字を含む単語は単純部分文字列マッチ
85/// - ASCII のみの単語は単語境界正規表現でマッチ
86pub struct CustomDictionary {
87    /// Case-insensitive エントリ (キーは lowercase 正規化済み)
88    entries: HashMap<String, DictEntry>,
89    /// Case-sensitive エントリ (混在ケースの単語)
90    case_sensitive_entries: HashMap<String, DictEntry>,
91    /// コンパイル済み正規表現キャッシュ (interior mutability で &self から挿入可能)
92    pattern_cache: Mutex<HashMap<String, Regex>>,
93}
94
95impl CustomDictionary {
96    /// 空の辞書を作成
97    pub fn new() -> Self {
98        Self {
99            entries: HashMap::new(),
100            case_sensitive_entries: HashMap::new(),
101            pattern_cache: Mutex::new(HashMap::new()),
102        }
103    }
104
105    /// JSON 辞書ファイルを読み込む (v1.0 / v2.0 対応)
106    pub fn load_dictionary(&mut self, path: &Path) -> Result<(), PiperError> {
107        let content = std::fs::read_to_string(path).map_err(|_| PiperError::DictionaryLoad {
108            path: path.display().to_string(),
109        })?;
110
111        let dict_file: DictFile =
112            serde_json::from_str(&content).map_err(|e| PiperError::DictionaryLoad {
113                path: format!("{}: {}", path.display(), e),
114            })?;
115
116        for (word, raw_entry) in dict_file.entries {
117            // v2.0: コメント行スキップ
118            if word.starts_with("//") {
119                continue;
120            }
121
122            let entry = match raw_entry {
123                RawEntry::Simple(pronunciation) => DictEntry {
124                    pronunciation,
125                    priority: default_priority(),
126                },
127                RawEntry::Detailed(d) => DictEntry {
128                    pronunciation: d.pronunciation,
129                    priority: d.priority,
130                },
131            };
132
133            self.add_entry(&word, entry);
134        }
135
136        Ok(())
137    }
138
139    /// テキストに辞書を適用して単語を置換
140    ///
141    /// 1. Case-sensitive エントリを長い順に処理
142    /// 2. Case-insensitive エントリを長い順に処理
143    pub fn apply_to_text(&self, text: &str) -> String {
144        let mut result = text.to_string();
145
146        // Case-sensitive エントリ (長い順)
147        let mut cs_entries: Vec<_> = self.case_sensitive_entries.iter().collect();
148        cs_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
149
150        for (word, entry) in &cs_entries {
151            let pattern = self.get_word_pattern(word, true);
152            result = pattern
153                .replace_all(&result, entry.pronunciation.as_str())
154                .to_string();
155        }
156
157        // Case-insensitive エントリ (長い順)
158        let mut ci_entries: Vec<_> = self.entries.iter().collect();
159        ci_entries.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
160
161        for (word, entry) in &ci_entries {
162            let pattern = self.get_word_pattern(word, false);
163            result = pattern
164                .replace_all(&result, entry.pronunciation.as_str())
165                .to_string();
166        }
167
168        result
169    }
170
171    /// 単語と読みを追加
172    ///
173    /// 既存エントリより優先度が低い場合は追加しない。
174    /// パターンキャッシュはクリアされる。
175    pub fn add_word(&mut self, word: &str, pronunciation: &str, priority: i32) {
176        let entry = DictEntry {
177            pronunciation: pronunciation.to_string(),
178            priority,
179        };
180        self.add_entry(word, entry);
181        self.pattern_cache.lock().unwrap().clear();
182    }
183
184    /// 単語の読みを取得
185    ///
186    /// Case-sensitive マップを先に検索し、見つからなければ case-insensitive マップを検索。
187    pub fn get_pronunciation(&self, word: &str) -> Option<&str> {
188        // Case-sensitive を先にチェック
189        if let Some(entry) = self.case_sensitive_entries.get(word) {
190            return Some(&entry.pronunciation);
191        }
192
193        // Case-insensitive (lowercase 正規化)
194        let normalized = word.to_lowercase();
195        self.entries
196            .get(&normalized)
197            .map(|e| e.pronunciation.as_str())
198    }
199
200    // -----------------------------------------------------------------------
201    // Internal helpers
202    // -----------------------------------------------------------------------
203
204    /// エントリを適切なマップに追加
205    fn add_entry(&mut self, word: &str, entry: DictEntry) {
206        let lower = word.to_lowercase();
207        let upper = word.to_uppercase();
208
209        if word != lower && word != upper {
210            // 大文字小文字混在 → case-sensitive マップ
211            self.case_sensitive_entries.insert(word.to_string(), entry);
212        } else {
213            // 全大文字 or 全小文字 → lowercase 正規化して case-insensitive マップ
214            let normalized = lower;
215
216            if let Some(existing) = self.entries.get(&normalized)
217                && entry.priority <= existing.priority
218            {
219                return; // 既存の方が優先度が高い (または同じ)
220            }
221
222            self.entries.insert(normalized, entry);
223        }
224    }
225
226    /// 単語の正規表現パターンを取得 (キャッシュ利用)
227    fn get_word_pattern(&self, word: &str, case_sensitive: bool) -> Regex {
228        let cache_key = format!("{}_{}", word, case_sensitive);
229
230        let mut cache = self.pattern_cache.lock().unwrap();
231        if let Some(cached) = cache.get(&cache_key) {
232            return cached.clone();
233        }
234
235        let escaped = regex::escape(word);
236
237        // 非 ASCII 文字を含むかチェック (日本語等)
238        let has_non_ascii = word.chars().any(|c| c as u32 > 127);
239
240        let pattern_str = if has_non_ascii {
241            // 日本語を含む場合: 単純部分文字列マッチ
242            escaped
243        } else {
244            // ASCII のみ: ASCII ワード境界で区切る
245            // (?-u:\b) は ASCII のみの \b — 日本語文字の隣でも正しく動作する
246            format!(r"(?-u:\b){}(?-u:\b)", escaped)
247        };
248
249        let pattern = if case_sensitive {
250            Regex::new(&pattern_str)
251        } else {
252            Regex::new(&format!("(?i){}", pattern_str))
253        };
254
255        let pat = pattern.expect("failed to compile regex pattern");
256        cache.insert(cache_key, pat.clone());
257        pat
258    }
259}
260
261impl Default for CustomDictionary {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267// ---------------------------------------------------------------------------
268// Tests
269// ---------------------------------------------------------------------------
270
271#[cfg(test)]
272mod tests {
273    use super::*;
274    use std::io::Write;
275    use std::sync::atomic::{AtomicU32, Ordering};
276
277    static COUNTER: AtomicU32 = AtomicU32::new(0);
278
279    /// テスト用に一時 JSON ファイルを作成するヘルパー
280    fn write_temp_json(content: &str) -> std::path::PathBuf {
281        let id = COUNTER.fetch_add(1, Ordering::SeqCst);
282        let path = std::env::temp_dir().join(format!(
283            "piper_test_dict_{}_{}.json",
284            std::process::id(),
285            id
286        ));
287        let mut f = std::fs::File::create(&path).unwrap();
288        f.write_all(content.as_bytes()).unwrap();
289        f.flush().unwrap();
290        path
291    }
292
293    // ----- v1.0 / v2.0 ロード -----
294
295    #[test]
296    fn test_load_v1_dictionary() {
297        let json = r#"{
298            "version": "1.0",
299            "entries": {
300                "API": "エーピーアイ",
301                "CPU": "シーピーユー"
302            }
303        }"#;
304        let f = write_temp_json(json);
305
306        let mut dict = CustomDictionary::new();
307        dict.load_dictionary(&f).unwrap();
308
309        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
310        assert_eq!(dict.get_pronunciation("cpu"), Some("シーピーユー"));
311    }
312
313    #[test]
314    fn test_load_v2_dictionary() {
315        let json = r#"{
316            "version": "2.0",
317            "entries": {
318                "API": {"pronunciation": "エーピーアイ", "priority": 8},
319                "GPU": {"pronunciation": "ジーピーユー"}
320            }
321        }"#;
322        let f = write_temp_json(json);
323
324        let mut dict = CustomDictionary::new();
325        dict.load_dictionary(&f).unwrap();
326
327        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
328        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
329    }
330
331    #[test]
332    fn test_v2_comment_lines_skipped() {
333        let json = r#"{
334            "version": "2.0",
335            "entries": {
336                "// this is a comment": {"pronunciation": "ignored", "priority": 1},
337                "API": {"pronunciation": "エーピーアイ", "priority": 5}
338            }
339        }"#;
340        let f = write_temp_json(json);
341
342        let mut dict = CustomDictionary::new();
343        dict.load_dictionary(&f).unwrap();
344
345        // コメント行は登録されない
346        assert_eq!(dict.get_pronunciation("// this is a comment"), None);
347        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
348    }
349
350    #[test]
351    fn test_load_nonexistent_file() {
352        let mut dict = CustomDictionary::new();
353        let result = dict.load_dictionary(Path::new("/no/such/file.json"));
354        assert!(result.is_err());
355    }
356
357    // ----- Case sensitivity -----
358
359    #[test]
360    fn test_case_sensitivity() {
361        let mut dict = CustomDictionary::new();
362
363        // 混在ケース → case-sensitive マップ
364        dict.add_word("GitHub", "ギットハブ", 5);
365        // 全大文字 → case-insensitive マップ (lowercase 正規化)
366        dict.add_word("API", "エーピーアイ", 5);
367
368        // case-sensitive: 完全一致のみ
369        assert_eq!(dict.get_pronunciation("GitHub"), Some("ギットハブ"));
370        // "github" (全小文字) は case-sensitive マップにないので None
371        // ただし case-insensitive マップにも登録されていないので None
372        assert_eq!(dict.get_pronunciation("github"), None);
373
374        // case-insensitive: どのケースでも取得可能
375        assert_eq!(dict.get_pronunciation("API"), Some("エーピーアイ"));
376        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ"));
377        assert_eq!(dict.get_pronunciation("Api"), Some("エーピーアイ"));
378    }
379
380    // ----- Priority -----
381
382    #[test]
383    fn test_priority_ordering() {
384        let mut dict = CustomDictionary::new();
385
386        dict.add_word("API", "エーピーアイ低", 3);
387        dict.add_word("API", "エーピーアイ高", 7);
388        // 優先度 7 > 3 なので上書きされる
389        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
390
391        // 同じ優先度では上書きされない
392        dict.add_word("API", "エーピーアイ同", 7);
393        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
394
395        // 低い優先度では上書きされない
396        dict.add_word("API", "エーピーアイ低2", 2);
397        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ高"));
398    }
399
400    // ----- Japanese word matching -----
401
402    #[test]
403    fn test_japanese_word_matching() {
404        let mut dict = CustomDictionary::new();
405        dict.add_word("東京都", "トウキョウト", 5);
406
407        let result = dict.apply_to_text("私は東京都に住んでいます");
408        assert_eq!(result, "私はトウキョウトに住んでいます");
409    }
410
411    #[test]
412    fn test_japanese_substring_no_boundary() {
413        let mut dict = CustomDictionary::new();
414        dict.add_word("京都", "キョウト", 5);
415        dict.add_word("東京都", "トウキョウト", 5);
416
417        // 長い方が先にマッチ → 「東京都」が置換される
418        let result = dict.apply_to_text("東京都と京都");
419        assert_eq!(result, "トウキョウトとキョウト");
420    }
421
422    // ----- English word boundary matching -----
423
424    #[test]
425    fn test_english_word_boundary() {
426        let mut dict = CustomDictionary::new();
427        dict.add_word("API", "エーピーアイ", 5);
428
429        // 単語境界あり → マッチ
430        assert_eq!(dict.apply_to_text("Use API here"), "Use エーピーアイ here");
431
432        // 英数字に隣接 → マッチしない
433        assert_eq!(dict.apply_to_text("UseAPIhere"), "UseAPIhere");
434
435        // 記号に隣接 → マッチ
436        assert_eq!(dict.apply_to_text("(API)"), "(エーピーアイ)");
437    }
438
439    #[test]
440    fn test_english_case_insensitive_matching() {
441        let mut dict = CustomDictionary::new();
442        dict.add_word("CPU", "シーピーユー", 5);
443
444        // case-insensitive: 大文字小文字問わずマッチ
445        assert_eq!(dict.apply_to_text("my cpu"), "my シーピーユー");
446        assert_eq!(dict.apply_to_text("my CPU"), "my シーピーユー");
447    }
448
449    // ----- apply_to_text with mixed text -----
450
451    #[test]
452    fn test_apply_mixed_ja_en_text() {
453        let mut dict = CustomDictionary::new();
454        dict.add_word("GitHub", "ギットハブ", 5);
455        dict.add_word("API", "エーピーアイ", 5);
456        dict.add_word("東京", "トウキョウ", 5);
457
458        let input = "東京のGitHubでAPI開発";
459        let result = dict.apply_to_text(input);
460        assert_eq!(result, "トウキョウのギットハブでエーピーアイ開発");
461    }
462
463    #[test]
464    fn test_apply_case_sensitive_before_insensitive() {
465        let mut dict = CustomDictionary::new();
466        // "iOS" は混在ケース → case-sensitive
467        dict.add_word("iOS", "アイオーエス", 5);
468        // "android" は全小文字 → case-insensitive
469        dict.add_word("android", "アンドロイド", 5);
470
471        let result = dict.apply_to_text("iOS and Android");
472        assert_eq!(result, "アイオーエス and アンドロイド");
473
474        // "ios" (全小文字) は case-sensitive マップの "iOS" にマッチしない
475        // case-insensitive マップにも無いのでそのまま
476        let result2 = dict.apply_to_text("ios test");
477        assert_eq!(result2, "ios test");
478    }
479
480    // ----- Longest match first -----
481
482    #[test]
483    fn test_longest_match_first() {
484        let mut dict = CustomDictionary::new();
485        dict.add_word("DB", "ディービー", 5);
486        dict.add_word("DBMS", "ディービーエムエス", 5);
487
488        // "DBMS" が先にマッチし、残った部分に "DB" はマッチしない
489        let result = dict.apply_to_text("DBMS and DB");
490        assert_eq!(result, "ディービーエムエス and ディービー");
491    }
492
493    // ----- Default constructor -----
494
495    #[test]
496    fn test_default_empty() {
497        let dict = CustomDictionary::default();
498        assert_eq!(dict.get_pronunciation("anything"), None);
499    }
500
501    // ----- Multiple dictionaries -----
502
503    #[test]
504    fn test_load_multiple_dictionaries() {
505        let json1 = r#"{
506            "version": "2.0",
507            "entries": {
508                "API": {"pronunciation": "エーピーアイ", "priority": 3}
509            }
510        }"#;
511        let json2 = r#"{
512            "version": "2.0",
513            "entries": {
514                "API": {"pronunciation": "エーピーアイ改", "priority": 8},
515                "GPU": {"pronunciation": "ジーピーユー", "priority": 5}
516            }
517        }"#;
518        let f1 = write_temp_json(json1);
519        let f2 = write_temp_json(json2);
520
521        let mut dict = CustomDictionary::new();
522        dict.load_dictionary(&f1).unwrap();
523        dict.load_dictionary(&f2).unwrap();
524
525        // 2番目のファイルの方が優先度が高い → 上書き
526        assert_eq!(dict.get_pronunciation("api"), Some("エーピーアイ改"));
527        assert_eq!(dict.get_pronunciation("gpu"), Some("ジーピーユー"));
528    }
529}