lindera_dictionary/dictionary/
unknown_dictionary.rs

1use std::str::FromStr;
2
3use log::warn;
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::LinderaResult;
8use crate::dictionary::character_definition::CategoryId;
9use crate::error::LinderaErrorKind;
10use crate::viterbi::WordEntry;
11
12#[derive(Serialize, Deserialize, Clone, Archive, RkyvSerialize, RkyvDeserialize)]
13
14pub struct UnknownDictionary {
15    pub category_references: Vec<Vec<u32>>,
16    pub costs: Vec<WordEntry>,
17}
18
19impl UnknownDictionary {
20    pub fn load(unknown_data: &[u8]) -> LinderaResult<UnknownDictionary> {
21        let mut aligned = rkyv::util::AlignedVec::<16>::new();
22        aligned.extend_from_slice(unknown_data);
23        rkyv::from_bytes::<UnknownDictionary, rkyv::rancor::Error>(&aligned).map_err(|err| {
24            LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
25        })
26    }
27
28    pub fn word_entry(&self, word_id: u32) -> WordEntry {
29        self.costs[word_id as usize]
30    }
31
32    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[u32] {
33        &self.category_references[category_id.0][..]
34    }
35
36    /// Unknown word generation with callback system
37    pub fn gen_unk_words<F>(
38        &self,
39        sentence: &str,
40        start_pos: usize,
41        has_matched: bool,
42        max_grouping_len: Option<usize>,
43        mut callback: F,
44    ) where
45        F: FnMut(UnkWord),
46    {
47        let chars: Vec<char> = sentence.chars().collect();
48        let max_len = max_grouping_len.unwrap_or(10);
49
50        // Limit based on dictionary matches for efficiency
51        let actual_max_len = if has_matched { 1 } else { max_len.min(3) };
52
53        for length in 1..=actual_max_len {
54            if start_pos + length > chars.len() {
55                break;
56            }
57
58            let end_pos = start_pos + length;
59
60            // Classify character type for unknown word
61            let first_char = chars[start_pos];
62            let char_type = classify_char_type(first_char);
63
64            // Create unknown word entry
65            let unk_word = UnkWord {
66                word_idx: WordIdx::new(char_type as u32),
67                end_char: end_pos,
68            };
69
70            callback(unk_word);
71        }
72    }
73
74    /// Check compatibility with unknown word based on feature matching
75    pub fn compatible_unk_index(
76        &self,
77        sentence: &str,
78        start: usize,
79        _end: usize,
80        feature: &str,
81    ) -> Option<WordIdx> {
82        let chars: Vec<char> = sentence.chars().collect();
83        if start >= chars.len() {
84            return None;
85        }
86
87        let first_char = chars[start];
88        let char_type = classify_char_type(first_char);
89
90        // Simple compatibility check based on feature string
91        if feature.starts_with(&format!("名詞,{}", get_type_name(char_type))) {
92            Some(WordIdx::new(char_type as u32))
93        } else {
94            None
95        }
96    }
97}
98
99/// Unknown word structure for callback system
100#[derive(Debug, Clone)]
101pub struct UnkWord {
102    pub word_idx: WordIdx,
103    pub end_char: usize,
104}
105
106impl UnkWord {
107    pub fn word_idx(&self) -> WordIdx {
108        self.word_idx
109    }
110
111    pub fn end_char(&self) -> usize {
112        self.end_char
113    }
114}
115
/// Lightweight identifier attached to unknown-word candidates.
#[derive(Clone, Copy, Debug)]
pub struct WordIdx {
    /// Raw numeric id; within this module it is filled with a character-type code.
    pub word_id: u32,
}

impl WordIdx {
    /// Wraps the raw `word_id`.
    pub fn new(word_id: u32) -> Self {
        WordIdx { word_id }
    }
}
126
127/// Classify character type (compatible with existing system)
128fn classify_char_type(ch: char) -> usize {
129    if ch.is_ascii_digit() {
130        5 // NUMERIC
131    } else if ch.is_ascii_alphabetic() {
132        4 // ALPHA
133    } else if is_kanji(ch) {
134        3 // KANJI
135    } else if is_katakana(ch) {
136        2 // KATAKANA
137    } else if is_hiragana(ch) {
138        1 // HIRAGANA
139    } else {
140        0 // DEFAULT
141    }
142}
143
/// Returns the noun subcategory that follows `名詞,` in a compatible feature
/// string for the given character-type code.
///
/// ASCII letters (4) map to `固有名詞` and digits (5) map to `数`. Every other
/// code, including hiragana, katakana, kanji and default, maps to `一般`.
fn get_type_name(char_type: usize) -> &'static str {
    match char_type {
        4 => "固有名詞",
        5 => "数",
        _ => "一般",
    }
}
154
// Character classification helpers.

/// Returns `true` if `ch` is in U+3041..=U+3096 (hiragana letters `ぁ` through `ゖ`).
///
/// Iteration marks such as `ゝ`/`ゞ` (U+309D/U+309E) are outside this range.
fn is_hiragana(ch: char) -> bool {
    ('\u{3041}'..='\u{3096}').contains(&ch)
}
159
/// Returns `true` if `ch` is a katakana letter: U+30A1..=U+30FA (`ァ` through `ヺ`)
/// or a Katakana Phonetic Extension in U+31F0..=U+31FF.
///
/// NOTE(review): the prolonged sound mark `ー` (U+30FC), the middle dot `・`
/// and half-width katakana are not included.
fn is_katakana(ch: char) -> bool {
    // U+30A1..=U+30F6 and U+30F7..=U+30FA are contiguous, so one range covers both.
    matches!(ch, '\u{30A1}'..='\u{30FA}' | '\u{31F0}'..='\u{31FF}')
}
163
/// Returns `true` if `ch` is a CJK ideograph (kanji).
///
/// Covered ranges:
/// - U+3400..=U+4DBF   CJK Unified Ideographs Extension A
/// - U+4E00..=U+9FFF   CJK Unified Ideographs (the whole block; the previous
///   upper bound U+9FAF dropped assigned ideographs U+9FB0..=U+9FFF)
/// - U+F900..=U+FAFF   CJK Compatibility Ideographs (e.g. `﨑` U+FA11, common in names)
/// - U+20000..=U+3FFFF Supplementary and Tertiary Ideographic Planes
///   (Extensions B onward and the compatibility supplement, e.g. `𠮟` U+20B9F)
fn is_kanji(ch: char) -> bool {
    matches!(
        ch,
        '\u{3400}'..='\u{4DBF}'
            | '\u{4E00}'..='\u{9FFF}'
            | '\u{F900}'..='\u{FAFF}'
            | '\u{20000}'..='\u{3FFFF}'
    )
}
167
/// One parsed line of an `unk.def`-style file.
///
/// Only the first four columns are kept; `parse_dictionary_entry` ignores any
/// further columns.
#[derive(Debug)]
pub struct UnknownDictionaryEntry {
    /// First column: the character-category name this entry belongs to.
    pub surface: String,
    /// Second column: left context id.
    pub left_id: u32,
    /// Third column: right context id.
    pub right_id: u32,
    /// Fourth column: word cost.
    pub word_cost: i32,
}
175
176fn parse_dictionary_entry(
177    fields: &[&str],
178    expected_fields_len: usize,
179) -> LinderaResult<UnknownDictionaryEntry> {
180    if fields.len() != expected_fields_len {
181        return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
182            "Invalid number of fields. Expect {}, got {}",
183            expected_fields_len,
184            fields.len()
185        )));
186    }
187    let surface = fields[0];
188    let left_id = u32::from_str(fields[1])
189        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
190    let right_id = u32::from_str(fields[2])
191        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
192    let word_cost = i32::from_str(fields[3])
193        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
194
195    Ok(UnknownDictionaryEntry {
196        surface: surface.to_string(),
197        left_id,
198        right_id,
199        word_cost,
200    })
201}
202
203fn get_entry_id_matching_surface(
204    entries: &[UnknownDictionaryEntry],
205    target_surface: &str,
206) -> Vec<u32> {
207    entries
208        .iter()
209        .enumerate()
210        .filter_map(|(entry_id, entry)| {
211            if entry.surface == *target_surface {
212                Some(entry_id as u32)
213            } else {
214                None
215            }
216        })
217        .collect()
218}
219
220fn make_category_references(
221    categories: &[String],
222    entries: &[UnknownDictionaryEntry],
223) -> Vec<Vec<u32>> {
224    categories
225        .iter()
226        .map(|category| get_entry_id_matching_surface(entries, category))
227        .collect()
228}
229
230fn make_costs_array(entries: &[UnknownDictionaryEntry]) -> Vec<WordEntry> {
231    entries
232        .iter()
233        .map(|e| {
234            // Do not perform strict checks on left context id and right context id in unk.def.
235            // Just output a warning.
236            if e.left_id != e.right_id {
237                warn!("left id and right id are not same: {e:?}");
238            }
239            WordEntry {
240                word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::Unknown, u32::MAX),
241                left_id: e.left_id as u16,
242                right_id: e.right_id as u16,
243                word_cost: e.word_cost as i16,
244            }
245        })
246        .collect()
247}
248
249pub fn parse_unk(categories: &[String], file_content: &str) -> LinderaResult<UnknownDictionary> {
250    let mut unknown_dict_entries = Vec::new();
251    for line in file_content.lines() {
252        let fields: Vec<&str> = line.split(',').collect::<Vec<&str>>();
253        let entry = parse_dictionary_entry(&fields[..], fields.len())?;
254        unknown_dict_entries.push(entry);
255    }
256
257    let category_references = make_category_references(categories, &unknown_dict_entries[..]);
258    let costs = make_costs_array(&unknown_dict_entries[..]);
259    Ok(UnknownDictionary {
260        category_references,
261        costs,
262    })
263}
264
265impl ArchivedUnknownDictionary {
266    pub fn word_entry(&self, word_id: u32) -> WordEntry {
267        // We have to deserialize the single entry or extract fields.
268        // Simple Archive usually preserves layout for primitives.
269        // Using deserialize ensures we get the native struct.
270        // Since WordEntry is small and Copy, this is efficient enough.
271        let archived_entry = &self.costs[word_id as usize];
272        rkyv::deserialize::<WordEntry, rkyv::rancor::Error>(archived_entry).unwrap()
273    }
274
275    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[rkyv::rend::u32_le] {
276        self.category_references[category_id.0].as_slice()
277    }
278}