Skip to main content

memory_indexer/tokenizer/
dictionary.rs

1use std::{
2    collections::{HashMap, HashSet},
3    time::{SystemTime, UNIX_EPOCH},
4};
5
6use super::{SegmentScript, TextNormalizer, TokenWithScript, script_runs, tokenize_char_ngrams};
7use serde::{Deserialize, Serialize};
8
/// Languages that can be given a segmentation dictionary.
///
/// The Japanese dictionary is consulted for Hiragana and Katakana runs; the
/// Hangul dictionary is consulted for Hangul runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DictionaryLanguage {
    /// Japanese kana text; generated versions use the prefix `ja`.
    Japanese,
    /// Korean Hangul text; generated versions use the prefix `ko`.
    Hangul,
}
14
/// A set of exact-match tokens for one language, plus an optional version label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptDictionary {
    /// Version label; `None` when the dictionary was not given one.
    pub version: Option<String>,
    /// Tokens matched verbatim (greedy longest match) during segmentation.
    pub entries: HashSet<String>,
}
20
21impl ScriptDictionary {
22    pub fn is_empty(&self) -> bool {
23        self.entries.is_empty()
24    }
25}
26
/// Per-language dictionaries used by the dictionary segmenter.
///
/// A `None` or empty dictionary means dictionary segmentation is not applied
/// for that language (the segmenter reports `false` so callers can fall back).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DictionaryConfig {
    /// Dictionary used for Hiragana and Katakana runs.
    pub japanese: Option<ScriptDictionary>,
    /// Dictionary used for Hangul runs.
    pub hangul: Option<ScriptDictionary>,
}
32
/// Greedy longest-match segmenter driven by a [`DictionaryConfig`], with
/// character n-gram fallback for text no dictionary entry covers.
#[derive(Debug, Clone, Default)]
pub struct DictionarySegmenter {
    /// The per-language dictionaries consulted by `segment` and `export`.
    pub config: DictionaryConfig,
}
37
38impl DictionarySegmenter {
39    pub fn new(config: DictionaryConfig) -> Self {
40        Self { config }
41    }
42
43    pub fn export(&self) -> Vec<DictionaryExport> {
44        let mut exports = Vec::new();
45        if let Some(dict) = &self.config.japanese
46            && let Some(export) = export_dictionary(DictionaryLanguage::Japanese, dict)
47        {
48            exports.push(export);
49        }
50        if let Some(dict) = &self.config.hangul
51            && let Some(export) = export_dictionary(DictionaryLanguage::Hangul, dict)
52        {
53            exports.push(export);
54        }
55        exports
56    }
57
58    pub fn segment(
59        &self,
60        segment: &str,
61        base_start: usize,
62        script: SegmentScript,
63        normalizer: &dyn TextNormalizer,
64        out: &mut Vec<TokenWithScript>,
65        seen: &mut HashSet<(String, usize, usize)>,
66    ) -> bool {
67        let Some(dictionary) = self.dictionary_for_script(script) else {
68            return false;
69        };
70        if dictionary.is_empty() {
71            return false;
72        }
73
74        let mut char_offsets: Vec<usize> = segment.char_indices().map(|(i, _)| i).collect();
75        char_offsets.push(segment.len());
76        let char_len = char_offsets.len().saturating_sub(1);
77        if char_len == 0 {
78            return true;
79        }
80
81        let mut covered = vec![false; char_len];
82        let mut matched_any = false;
83        let mut idx = 0;
84
85        while idx < char_len {
86            let mut matched_range: Option<(usize, usize)> = None;
87            for end in (idx + 1..=char_len).rev() {
88                let start_byte = char_offsets[idx];
89                let end_byte = char_offsets[end];
90                let candidate = &segment[start_byte..end_byte];
91                if dictionary.entries.contains(candidate) {
92                    matched_range = Some((idx, end));
93                    break;
94                }
95            }
96
97            if let Some((start_idx, end_idx)) = matched_range {
98                matched_any = true;
99                let start_byte = char_offsets[start_idx];
100                let end_byte = char_offsets[end_idx];
101                normalizer.normalize(
102                    &segment[start_byte..end_byte],
103                    base_start + start_byte,
104                    script,
105                    out,
106                    seen,
107                );
108                for item in covered.iter_mut().take(end_idx).skip(start_idx) {
109                    *item = true;
110                }
111                idx = end_idx;
112            } else {
113                idx += 1;
114            }
115        }
116
117        if !matched_any {
118            tokenize_char_ngrams(segment, base_start, script, normalizer, out, seen);
119            return true;
120        }
121
122        let mut start = 0;
123        while start < char_len {
124            if covered[start] {
125                start += 1;
126                continue;
127            }
128            let mut end = start + 1;
129            while end < char_len && !covered[end] {
130                end += 1;
131            }
132            let start_byte = char_offsets[start];
133            let end_byte = char_offsets[end];
134            tokenize_char_ngrams(
135                &segment[start_byte..end_byte],
136                base_start + start_byte,
137                script,
138                normalizer,
139                out,
140                seen,
141            );
142            start = end;
143        }
144
145        true
146    }
147
148    fn dictionary_for_script(&self, script: SegmentScript) -> Option<&ScriptDictionary> {
149        match script {
150            SegmentScript::Hiragana | SegmentScript::Katakana => self.config.japanese.as_ref(),
151            SegmentScript::Hangul => self.config.hangul.as_ref(),
152            _ => None,
153        }
154    }
155}
156
/// Descriptive metadata attached to a [`DictionaryExport`].
#[derive(Debug, Clone)]
pub struct DictionaryMetadata {
    /// Language the exported dictionary belongs to.
    pub language: DictionaryLanguage,
    /// Dictionary version label (`export_dictionary` substitutes
    /// `<prefix>-unversioned` when the dictionary has none).
    pub version: String,
    /// Number of exported entries.
    pub entry_count: usize,
    /// Wall-clock time at which the export was produced.
    pub generated_at: SystemTime,
}
164
/// A snapshot of one dictionary, suitable for persisting or inspecting.
#[derive(Debug, Clone)]
pub struct DictionaryExport {
    /// Language, version, size, and timestamp of the export.
    pub metadata: DictionaryMetadata,
    /// Dictionary entries; `export_dictionary` returns them sorted ascending.
    pub entries: Vec<String>,
}
170
171pub fn export_dictionary(
172    language: DictionaryLanguage,
173    dictionary: &ScriptDictionary,
174) -> Option<DictionaryExport> {
175    if dictionary.is_empty() {
176        return None;
177    }
178    let mut entries: Vec<String> = dictionary.entries.iter().cloned().collect();
179    entries.sort();
180
181    let metadata = DictionaryMetadata {
182        language,
183        version: dictionary
184            .version
185            .clone()
186            .unwrap_or_else(|| format!("{}-unversioned", language_prefix(language))),
187        entry_count: entries.len(),
188        generated_at: SystemTime::now(),
189    };
190
191    Some(DictionaryExport { metadata, entries })
192}
193
/// Tuning knobs for `train_dictionary_for_language` / `train_dictionary_config`.
#[derive(Debug, Clone)]
pub struct DictionaryTrainingConfig {
    /// Minimum number of occurrences for a substring to be kept.
    pub min_freq: usize,
    /// Shortest substring length considered, in chars (training clamps to >= 1).
    pub min_token_len: usize,
    /// Longest substring length considered, in chars (training clamps to
    /// >= `min_token_len`).
    pub max_token_len: usize,
    /// Maximum number of entries kept after ranking; `0` means no limit.
    pub max_entries: usize,
    /// Version label for the trained dictionary; `None` lets training
    /// generate a `<prefix>-<unix seconds>` label.
    pub version: Option<String>,
}
202
203impl Default for DictionaryTrainingConfig {
204    fn default() -> Self {
205        Self {
206            min_freq: 2,
207            min_token_len: 2,
208            max_token_len: 8,
209            max_entries: 8_000,
210            version: None,
211        }
212    }
213}
214
215pub fn train_dictionary_for_language(
216    corpus: &[String],
217    language: DictionaryLanguage,
218    config: DictionaryTrainingConfig,
219) -> ScriptDictionary {
220    let min_token_len = config.min_token_len.max(1);
221    let max_token_len = config.max_token_len.max(min_token_len);
222    let mut counts: HashMap<String, usize> = HashMap::new();
223
224    for text in corpus {
225        for (script, start, end) in script_runs(text) {
226            if !matches_language(script, language) {
227                continue;
228            }
229            let segment = &text[start..end];
230            let mut char_offsets: Vec<usize> = segment.char_indices().map(|(i, _)| i).collect();
231            char_offsets.push(segment.len());
232            let char_len = char_offsets.len().saturating_sub(1);
233            for i in 0..char_len {
234                for len in min_token_len..=max_token_len {
235                    if i + len > char_len {
236                        break;
237                    }
238                    let start_byte = char_offsets[i];
239                    let end_byte = char_offsets[i + len];
240                    let candidate = &segment[start_byte..end_byte];
241                    if candidate.chars().any(|c| c.is_whitespace()) {
242                        continue;
243                    }
244                    *counts.entry(candidate.to_string()).or_insert(0) += 1;
245                }
246            }
247        }
248    }
249
250    let mut entries: Vec<(String, usize)> = counts
251        .into_iter()
252        .filter(|(_, freq)| *freq >= config.min_freq)
253        .collect();
254    entries.sort_by(|a, b| {
255        b.1.cmp(&a.1)
256            .then_with(|| b.0.len().cmp(&a.0.len()))
257            .then_with(|| a.0.cmp(&b.0))
258    });
259    if config.max_entries > 0 && entries.len() > config.max_entries {
260        entries.truncate(config.max_entries);
261    }
262
263    let entries_set: HashSet<String> = entries.into_iter().map(|(entry, _)| entry).collect();
264    ScriptDictionary {
265        version: Some(version_or_default(language, &config.version)),
266        entries: entries_set,
267    }
268}
269
270pub fn train_dictionary_config(
271    corpus: &[String],
272    config: DictionaryTrainingConfig,
273) -> DictionaryConfig {
274    let japanese =
275        train_dictionary_for_language(corpus, DictionaryLanguage::Japanese, config.clone());
276    let hangul = train_dictionary_for_language(corpus, DictionaryLanguage::Hangul, config);
277
278    DictionaryConfig {
279        japanese: (!japanese.is_empty()).then_some(japanese),
280        hangul: (!hangul.is_empty()).then_some(hangul),
281    }
282}
283
284fn version_or_default(language: DictionaryLanguage, provided: &Option<String>) -> String {
285    if let Some(version) = provided {
286        return version.clone();
287    }
288    let ts = SystemTime::now()
289        .duration_since(UNIX_EPOCH)
290        .unwrap_or_default()
291        .as_secs();
292    format!("{}-{ts}", language_prefix(language))
293}
294
295fn language_prefix(language: DictionaryLanguage) -> &'static str {
296    match language {
297        DictionaryLanguage::Japanese => "ja",
298        DictionaryLanguage::Hangul => "ko",
299    }
300}
301
302fn matches_language(script: SegmentScript, language: DictionaryLanguage) -> bool {
303    matches!(
304        (language, script),
305        (DictionaryLanguage::Japanese, SegmentScript::Hiragana)
306            | (DictionaryLanguage::Japanese, SegmentScript::Katakana)
307            | (DictionaryLanguage::Hangul, SegmentScript::Hangul)
308    )
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{SegmentScript, normalize_query};
    use std::collections::HashSet;
    use tempfile::tempdir;

    #[test]
    fn segments_dictionary_tokens_and_fallbacks() {
        let mut entries = HashSet::new();
        entries.insert("こん".to_string());
        let config = DictionaryConfig {
            japanese: Some(ScriptDictionary {
                version: Some("v1".to_string()),
                entries,
            }),
            hangul: None,
        };
        let segmenter = DictionarySegmenter::new(config);
        let normalizer = normalize_query();
        let mut out = Vec::new();
        let mut seen = HashSet::new();

        let used = segmenter.segment(
            "こんにちは",
            0,
            SegmentScript::Hiragana,
            normalizer.as_ref(),
            &mut out,
            &mut seen,
        );

        assert!(used, "expected dictionary to be applied when provided");
        assert!(
            out.iter().any(|t| t.term == "こん"),
            "expected dictionary token present, got {:?}",
            out
        );
        // Byte 12 is the start of "は", which only the n-gram fallback covers.
        assert!(
            out.iter().any(|t| t.start == 12),
            "expected fallback tokens for unmatched spans, got {:?}",
            out
        );
    }

    #[test]
    fn declines_without_usable_dictionary() {
        let segmenter = DictionarySegmenter::new(DictionaryConfig {
            japanese: Some(ScriptDictionary {
                version: None,
                entries: HashSet::new(),
            }),
            hangul: None,
        });
        let normalizer = normalize_query();
        let mut out = Vec::new();
        let mut seen = HashSet::new();

        // A dictionary is configured for the script but has no entries.
        assert!(!segmenter.segment(
            "こんにちは",
            0,
            SegmentScript::Hiragana,
            normalizer.as_ref(),
            &mut out,
            &mut seen,
        ));
        // No dictionary is configured for the script at all.
        assert!(!segmenter.segment(
            "안녕",
            0,
            SegmentScript::Hangul,
            normalizer.as_ref(),
            &mut out,
            &mut seen,
        ));
        assert!(out.is_empty(), "expected no tokens, got {:?}", out);
    }

    #[test]
    fn trains_and_exports_dictionaries() {
        let corpus = vec![
            "こんにちは世界".to_string(),
            "こんにちは友達".to_string(),
            "안녕하세요 세계".to_string(),
        ];
        let config = DictionaryTrainingConfig {
            min_freq: 1,
            min_token_len: 2,
            max_token_len: 3,
            max_entries: 4,
            version: Some("v1".to_string()),
        };

        let dictionaries = train_dictionary_config(&corpus, config);
        let segmenter = DictionarySegmenter::new(dictionaries.clone());
        let exports = segmenter.export();

        assert!(
            dictionaries.japanese.is_some(),
            "expected japanese dictionary"
        );
        assert!(dictionaries.hangul.is_some(), "expected hangul dictionary");
        assert_eq!(exports.len(), 2, "expected exports per language");
        let ja_export = exports
            .iter()
            .find(|e| matches!(e.metadata.language, DictionaryLanguage::Japanese))
            .expect("japanese export present");
        assert_eq!(ja_export.metadata.entry_count, ja_export.entries.len());
        assert!(
            ja_export
                .metadata
                .generated_at
                .elapsed()
                .unwrap_or_default()
                .as_secs()
                < 5,
            "expected recent generated_at, got {:?}",
            ja_export.metadata.generated_at
        );
        assert!(
            ja_export.metadata.version.starts_with("v1"),
            "expected provided version, got {}",
            ja_export.metadata.version
        );
    }

    #[test]
    fn saves_and_loads_dictionary_config() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("dict.json");

        let mut entries = HashSet::new();
        entries.insert("こん".to_string());
        let config = DictionaryConfig {
            japanese: Some(ScriptDictionary {
                version: Some("v1".to_string()),
                entries,
            }),
            hangul: None,
        };

        save_dictionary(&path, &config).unwrap();
        let loaded = load_dictionary(&path).unwrap();
        assert_eq!(
            loaded.japanese.unwrap().entries.len(),
            1,
            "expected saved japanese entries"
        );
    }

    /// Serializes `config` as JSON and writes it to `path`.
    fn save_dictionary(path: &std::path::Path, config: &DictionaryConfig) -> std::io::Result<()> {
        let data = serde_json::to_vec(config).map_err(to_io_err)?;
        std::fs::write(path, data)
    }

    /// Reads `path` and deserializes it as a JSON `DictionaryConfig`.
    fn load_dictionary(path: &std::path::Path) -> std::io::Result<DictionaryConfig> {
        let data = std::fs::read(path)?;
        serde_json::from_slice(&data).map_err(to_io_err)
    }

    /// Wraps any displayable error in an `io::Error` of kind `Other`.
    fn to_io_err(err: impl std::fmt::Display) -> std::io::Error {
        std::io::Error::other(err.to_string())
    }
}