1use std::{
2 collections::{HashMap, HashSet},
3 time::{SystemTime, UNIX_EPOCH},
4};
5
6use super::{SegmentScript, TextNormalizer, TokenWithScript, script_runs, tokenize_char_ngrams};
7use serde::{Deserialize, Serialize};
8
/// Languages for which a word dictionary can be configured.
///
/// Used as the key when mapping script runs to a dictionary (see
/// `DictionarySegmenter::dictionary_for_script`) and when tagging exports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DictionaryLanguage {
    /// Covers Hiragana and Katakana script runs.
    Japanese,
    /// Covers Hangul script runs.
    Hangul,
}
14
/// A word list for a single language, plus an optional version tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScriptDictionary {
    // Optional version label; exports substitute a generated one when absent
    // (see `export_dictionary` / `version_or_default`).
    pub version: Option<String>,
    // Exact-match dictionary entries.
    pub entries: HashSet<String>,
}
20
impl ScriptDictionary {
    /// Returns `true` when the dictionary holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
26
/// Per-language dictionaries available to a `DictionarySegmenter`.
///
/// A `None` entry means no dictionary is configured for that language.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DictionaryConfig {
    // Dictionary applied to Hiragana/Katakana runs.
    pub japanese: Option<ScriptDictionary>,
    // Dictionary applied to Hangul runs.
    pub hangul: Option<ScriptDictionary>,
}
32
/// Dictionary-backed segmenter: performs greedy longest-match tokenization
/// of script runs against the configured per-language word lists.
#[derive(Debug, Clone, Default)]
pub struct DictionarySegmenter {
    pub config: DictionaryConfig,
}
37
38impl DictionarySegmenter {
39 pub fn new(config: DictionaryConfig) -> Self {
40 Self { config }
41 }
42
43 pub fn export(&self) -> Vec<DictionaryExport> {
44 let mut exports = Vec::new();
45 if let Some(dict) = &self.config.japanese
46 && let Some(export) = export_dictionary(DictionaryLanguage::Japanese, dict)
47 {
48 exports.push(export);
49 }
50 if let Some(dict) = &self.config.hangul
51 && let Some(export) = export_dictionary(DictionaryLanguage::Hangul, dict)
52 {
53 exports.push(export);
54 }
55 exports
56 }
57
58 pub fn segment(
59 &self,
60 segment: &str,
61 base_start: usize,
62 script: SegmentScript,
63 normalizer: &dyn TextNormalizer,
64 out: &mut Vec<TokenWithScript>,
65 seen: &mut HashSet<(String, usize, usize)>,
66 ) -> bool {
67 let Some(dictionary) = self.dictionary_for_script(script) else {
68 return false;
69 };
70 if dictionary.is_empty() {
71 return false;
72 }
73
74 let mut char_offsets: Vec<usize> = segment.char_indices().map(|(i, _)| i).collect();
75 char_offsets.push(segment.len());
76 let char_len = char_offsets.len().saturating_sub(1);
77 if char_len == 0 {
78 return true;
79 }
80
81 let mut covered = vec![false; char_len];
82 let mut matched_any = false;
83 let mut idx = 0;
84
85 while idx < char_len {
86 let mut matched_range: Option<(usize, usize)> = None;
87 for end in (idx + 1..=char_len).rev() {
88 let start_byte = char_offsets[idx];
89 let end_byte = char_offsets[end];
90 let candidate = &segment[start_byte..end_byte];
91 if dictionary.entries.contains(candidate) {
92 matched_range = Some((idx, end));
93 break;
94 }
95 }
96
97 if let Some((start_idx, end_idx)) = matched_range {
98 matched_any = true;
99 let start_byte = char_offsets[start_idx];
100 let end_byte = char_offsets[end_idx];
101 normalizer.normalize(
102 &segment[start_byte..end_byte],
103 base_start + start_byte,
104 script,
105 out,
106 seen,
107 );
108 for item in covered.iter_mut().take(end_idx).skip(start_idx) {
109 *item = true;
110 }
111 idx = end_idx;
112 } else {
113 idx += 1;
114 }
115 }
116
117 if !matched_any {
118 tokenize_char_ngrams(segment, base_start, script, normalizer, out, seen);
119 return true;
120 }
121
122 let mut start = 0;
123 while start < char_len {
124 if covered[start] {
125 start += 1;
126 continue;
127 }
128 let mut end = start + 1;
129 while end < char_len && !covered[end] {
130 end += 1;
131 }
132 let start_byte = char_offsets[start];
133 let end_byte = char_offsets[end];
134 tokenize_char_ngrams(
135 &segment[start_byte..end_byte],
136 base_start + start_byte,
137 script,
138 normalizer,
139 out,
140 seen,
141 );
142 start = end;
143 }
144
145 true
146 }
147
148 fn dictionary_for_script(&self, script: SegmentScript) -> Option<&ScriptDictionary> {
149 match script {
150 SegmentScript::Hiragana | SegmentScript::Katakana => self.config.japanese.as_ref(),
151 SegmentScript::Hangul => self.config.hangul.as_ref(),
152 _ => None,
153 }
154 }
155}
156
/// Metadata attached to an exported dictionary.
#[derive(Debug, Clone)]
pub struct DictionaryMetadata {
    // Language the exported dictionary belongs to.
    pub language: DictionaryLanguage,
    // Resolved version string (never empty: a default is generated when
    // the source dictionary has no version).
    pub version: String,
    // Number of entries in the export; mirrors `DictionaryExport::entries.len()`.
    pub entry_count: usize,
    // Wall-clock time at which the export was produced.
    pub generated_at: SystemTime,
}
164
/// A dictionary snapshot prepared for serialization or inspection:
/// sorted entries plus descriptive metadata.
#[derive(Debug, Clone)]
pub struct DictionaryExport {
    pub metadata: DictionaryMetadata,
    // Entries in sorted (lexicographic) order for deterministic output.
    pub entries: Vec<String>,
}
170
171pub fn export_dictionary(
172 language: DictionaryLanguage,
173 dictionary: &ScriptDictionary,
174) -> Option<DictionaryExport> {
175 if dictionary.is_empty() {
176 return None;
177 }
178 let mut entries: Vec<String> = dictionary.entries.iter().cloned().collect();
179 entries.sort();
180
181 let metadata = DictionaryMetadata {
182 language,
183 version: dictionary
184 .version
185 .clone()
186 .unwrap_or_else(|| format!("{}-unversioned", language_prefix(language))),
187 entry_count: entries.len(),
188 generated_at: SystemTime::now(),
189 };
190
191 Some(DictionaryExport { metadata, entries })
192}
193
/// Knobs controlling dictionary training (`train_dictionary_for_language`).
#[derive(Debug, Clone)]
pub struct DictionaryTrainingConfig {
    // Minimum number of occurrences an n-gram needs to be kept.
    pub min_freq: usize,
    // Shortest candidate length in characters (clamped to at least 1).
    pub min_token_len: usize,
    // Longest candidate length in characters (clamped to at least min_token_len).
    pub max_token_len: usize,
    // Cap on kept entries after ranking; 0 disables the cap.
    pub max_entries: usize,
    // Version to record on the trained dictionary; a timestamped default
    // is generated when `None` (see `version_or_default`).
    pub version: Option<String>,
}
202
impl Default for DictionaryTrainingConfig {
    /// Defaults: keep n-grams seen at least twice, between 2 and 8 characters,
    /// capped at 8000 entries, with an auto-generated version string.
    fn default() -> Self {
        Self {
            min_freq: 2,
            min_token_len: 2,
            max_token_len: 8,
            max_entries: 8_000,
            version: None,
        }
    }
}
214
215pub fn train_dictionary_for_language(
216 corpus: &[String],
217 language: DictionaryLanguage,
218 config: DictionaryTrainingConfig,
219) -> ScriptDictionary {
220 let min_token_len = config.min_token_len.max(1);
221 let max_token_len = config.max_token_len.max(min_token_len);
222 let mut counts: HashMap<String, usize> = HashMap::new();
223
224 for text in corpus {
225 for (script, start, end) in script_runs(text) {
226 if !matches_language(script, language) {
227 continue;
228 }
229 let segment = &text[start..end];
230 let mut char_offsets: Vec<usize> = segment.char_indices().map(|(i, _)| i).collect();
231 char_offsets.push(segment.len());
232 let char_len = char_offsets.len().saturating_sub(1);
233 for i in 0..char_len {
234 for len in min_token_len..=max_token_len {
235 if i + len > char_len {
236 break;
237 }
238 let start_byte = char_offsets[i];
239 let end_byte = char_offsets[i + len];
240 let candidate = &segment[start_byte..end_byte];
241 if candidate.chars().any(|c| c.is_whitespace()) {
242 continue;
243 }
244 *counts.entry(candidate.to_string()).or_insert(0) += 1;
245 }
246 }
247 }
248 }
249
250 let mut entries: Vec<(String, usize)> = counts
251 .into_iter()
252 .filter(|(_, freq)| *freq >= config.min_freq)
253 .collect();
254 entries.sort_by(|a, b| {
255 b.1.cmp(&a.1)
256 .then_with(|| b.0.len().cmp(&a.0.len()))
257 .then_with(|| a.0.cmp(&b.0))
258 });
259 if config.max_entries > 0 && entries.len() > config.max_entries {
260 entries.truncate(config.max_entries);
261 }
262
263 let entries_set: HashSet<String> = entries.into_iter().map(|(entry, _)| entry).collect();
264 ScriptDictionary {
265 version: Some(version_or_default(language, &config.version)),
266 entries: entries_set,
267 }
268}
269
270pub fn train_dictionary_config(
271 corpus: &[String],
272 config: DictionaryTrainingConfig,
273) -> DictionaryConfig {
274 let japanese =
275 train_dictionary_for_language(corpus, DictionaryLanguage::Japanese, config.clone());
276 let hangul = train_dictionary_for_language(corpus, DictionaryLanguage::Hangul, config);
277
278 DictionaryConfig {
279 japanese: (!japanese.is_empty()).then_some(japanese),
280 hangul: (!hangul.is_empty()).then_some(hangul),
281 }
282}
283
284fn version_or_default(language: DictionaryLanguage, provided: &Option<String>) -> String {
285 if let Some(version) = provided {
286 return version.clone();
287 }
288 let ts = SystemTime::now()
289 .duration_since(UNIX_EPOCH)
290 .unwrap_or_default()
291 .as_secs();
292 format!("{}-{ts}", language_prefix(language))
293}
294
295fn language_prefix(language: DictionaryLanguage) -> &'static str {
296 match language {
297 DictionaryLanguage::Japanese => "ja",
298 DictionaryLanguage::Hangul => "ko",
299 }
300}
301
302fn matches_language(script: SegmentScript, language: DictionaryLanguage) -> bool {
303 matches!(
304 (language, script),
305 (DictionaryLanguage::Japanese, SegmentScript::Hiragana)
306 | (DictionaryLanguage::Japanese, SegmentScript::Katakana)
307 | (DictionaryLanguage::Hangul, SegmentScript::Hangul)
308 )
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{SegmentScript, normalize_query};
    use std::collections::HashSet;
    use tempfile::tempdir;

    // A dictionary entry should be emitted for its match, and the unmatched
    // remainder of the segment should still yield fallback n-gram tokens.
    #[test]
    fn segments_dictionary_tokens_and_fallbacks() {
        let mut entries = HashSet::new();
        entries.insert("こん".to_string());
        let config = DictionaryConfig {
            japanese: Some(ScriptDictionary {
                version: Some("v1".to_string()),
                entries,
            }),
            hangul: None,
        };
        let segmenter = DictionarySegmenter::new(config);
        let normalizer = normalize_query();
        let mut out = Vec::new();
        let mut seen = HashSet::new();

        let used = segmenter.segment(
            "こんにちは",
            0,
            SegmentScript::Hiragana,
            normalizer.as_ref(),
            &mut out,
            &mut seen,
        );

        assert!(used, "expected dictionary to be applied when provided");
        assert!(
            out.iter().any(|t| t.term == "こん"),
            "expected dictionary token present, got {:?}",
            out
        );
        // NOTE(review): 12 is a byte offset — each of these chars is 3 bytes
        // in UTF-8, so 12 appears to be the final char "は"; presumably a
        // token emitted by the n-gram fallback over the unmatched span.
        // Confirm against tokenize_char_ngrams' offset semantics.
        assert!(
            out.iter().any(|t| t.start == 12),
            "expected fallback tokens for unmatched spans, got {:?}",
            out
        );
    }

    // End-to-end: train dictionaries for both languages, then export and
    // check metadata (count, freshness of timestamp, provided version).
    #[test]
    fn trains_and_exports_dictionaries() {
        let corpus = vec![
            "こんにちは世界".to_string(),
            "こんにちは友達".to_string(),
            "안녕하세요 세계".to_string(),
        ];
        let config = DictionaryTrainingConfig {
            // min_freq of 1 keeps every observed n-gram so both languages
            // produce non-empty dictionaries from this tiny corpus.
            min_freq: 1,
            min_token_len: 2,
            max_token_len: 3,
            max_entries: 4,
            version: Some("v1".to_string()),
        };

        let dictionaries = train_dictionary_config(&corpus, config);
        let segmenter = DictionarySegmenter::new(dictionaries.clone());
        let exports = segmenter.export();

        assert!(
            dictionaries.japanese.is_some(),
            "expected japanese dictionary"
        );
        assert!(dictionaries.hangul.is_some(), "expected hangul dictionary");
        assert_eq!(exports.len(), 2, "expected exports per language");
        let ja_export = exports
            .iter()
            .find(|e| matches!(e.metadata.language, DictionaryLanguage::Japanese))
            .expect("japanese export present");
        assert_eq!(ja_export.metadata.entry_count, ja_export.entries.len());
        // generated_at is stamped at export time, so it should be recent.
        assert!(
            ja_export
                .metadata
                .generated_at
                .elapsed()
                .unwrap_or_default()
                .as_secs()
                < 5,
            "expected recent generated_at, got {:?}",
            ja_export.metadata.generated_at
        );
        assert!(
            ja_export.metadata.version.starts_with("v1"),
            "expected provided version, got {}",
            ja_export.metadata.version
        );
    }

    // Round-trips a DictionaryConfig through serde JSON on disk.
    #[test]
    fn saves_and_loads_dictionary_config() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("dict.json");

        let mut entries = HashSet::new();
        entries.insert("こん".to_string());
        let config = DictionaryConfig {
            japanese: Some(ScriptDictionary {
                version: Some("v1".to_string()),
                entries,
            }),
            hangul: None,
        };

        save_dictionary(&path, &config).unwrap();
        let loaded = load_dictionary(&path).unwrap();
        assert_eq!(
            loaded.japanese.unwrap().entries.len(),
            1,
            "expected saved japanese entries"
        );
    }

    // Serializes a config to JSON bytes and writes them to `path`.
    fn save_dictionary(path: &std::path::Path, config: &DictionaryConfig) -> std::io::Result<()> {
        let data = serde_json::to_vec(config).map_err(to_io_err)?;
        std::fs::write(path, data)
    }

    // Reads and deserializes a config previously written by `save_dictionary`.
    fn load_dictionary(path: &std::path::Path) -> std::io::Result<DictionaryConfig> {
        let data = std::fs::read(path)?;
        serde_json::from_slice(&data).map_err(to_io_err)
    }

    // Adapts any displayable error into an `std::io::Error` for `?` plumbing.
    fn to_io_err(err: impl std::fmt::Display) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
    }
}