use std::str::FromStr;
2
3use log::warn;
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::LinderaResult;
8use crate::dictionary::character_definition::CategoryId;
9use crate::error::LinderaErrorKind;
10use crate::viterbi::WordEntry;
11
/// Dictionary for words that are not present in the system lexicon,
/// organized by character category.
///
/// NOTE(review): this type is rkyv-archived; field order affects the
/// serialized layout, so do not reorder fields.
#[derive(Serialize, Deserialize, Clone, Archive, RkyvSerialize, RkyvDeserialize)]

pub struct UnknownDictionary {
    // For each character category (indexed by `CategoryId.0`), the entry ids
    // whose surface equals that category's name (see `make_category_references`).
    pub category_references: Vec<Vec<u32>>,
    // Connection ids and word cost per unknown-word entry.
    pub costs: Vec<WordEntry>,
    // Byte offset into `words_data` where each entry's detail record starts.
    pub words_idx_data: Vec<u32>,
    // Packed detail records: [u32 LE byte length][NUL-separated UTF-8 details].
    pub words_data: Vec<u8>,
}
23
24impl UnknownDictionary {
25 pub fn load(unknown_data: &[u8]) -> LinderaResult<UnknownDictionary> {
26 let mut aligned = rkyv::util::AlignedVec::<16>::new();
27 aligned.extend_from_slice(unknown_data);
28 rkyv::from_bytes::<UnknownDictionary, rkyv::rancor::Error>(&aligned).map_err(|err| {
29 LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
30 })
31 }
32
33 pub fn word_entry(&self, word_id: u32) -> WordEntry {
34 self.costs[word_id as usize]
35 }
36
37 pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[u32] {
38 &self.category_references[category_id.0][..]
39 }
40
41 pub fn word_details(&self, word_id: u32) -> Option<Vec<&str>> {
43 let idx = word_id as usize;
44 if idx >= self.words_idx_data.len() {
45 return None;
46 }
47 let offset = self.words_idx_data[idx] as usize;
48 if offset + 4 > self.words_data.len() {
49 return None;
50 }
51 let len = u32::from_le_bytes(self.words_data[offset..offset + 4].try_into().ok()?) as usize;
52 if offset + 4 + len > self.words_data.len() {
53 return None;
54 }
55 let text = std::str::from_utf8(&self.words_data[offset + 4..offset + 4 + len]).ok()?;
56 Some(text.split('\0').collect())
57 }
58
59 pub fn gen_unk_words<F>(
61 &self,
62 sentence: &str,
63 start_pos: usize,
64 has_matched: bool,
65 max_grouping_len: Option<usize>,
66 mut callback: F,
67 ) where
68 F: FnMut(UnkWord),
69 {
70 let chars: Vec<char> = sentence.chars().collect();
71 let max_len = max_grouping_len.unwrap_or(10);
72
73 let actual_max_len = if has_matched { 1 } else { max_len.min(3) };
75
76 for length in 1..=actual_max_len {
77 if start_pos + length > chars.len() {
78 break;
79 }
80
81 let end_pos = start_pos + length;
82
83 let first_char = chars[start_pos];
85 let char_type = classify_char_type(first_char);
86
87 let unk_word = UnkWord {
89 word_idx: WordIdx::new(char_type as u32),
90 end_char: end_pos,
91 };
92
93 callback(unk_word);
94 }
95 }
96
97 pub fn compatible_unk_index(
99 &self,
100 sentence: &str,
101 start: usize,
102 _end: usize,
103 feature: &str,
104 ) -> Option<WordIdx> {
105 let chars: Vec<char> = sentence.chars().collect();
106 if start >= chars.len() {
107 return None;
108 }
109
110 let first_char = chars[start];
111 let char_type = classify_char_type(first_char);
112
113 if feature.starts_with(&format!("名詞,{}", get_type_name(char_type))) {
115 Some(WordIdx::new(char_type as u32))
116 } else {
117 None
118 }
119 }
120}
121
/// An unknown-word candidate produced by `UnknownDictionary::gen_unk_words`.
#[derive(Debug, Clone)]
pub struct UnkWord {
    // Index of the unknown-word entry (derived from the character category).
    pub word_idx: WordIdx,
    // End position of the candidate in characters (exclusive).
    pub end_char: usize,
}
128
impl UnkWord {
    /// Returns the word index of this unknown-word candidate.
    pub fn word_idx(&self) -> WordIdx {
        self.word_idx
    }

    /// Returns the exclusive end position of this candidate, in characters.
    pub fn end_char(&self) -> usize {
        self.end_char
    }
}
138
/// A lightweight identifier for a dictionary word entry.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    // Raw word id value.
    pub word_id: u32,
}
143
144impl WordIdx {
145 pub fn new(word_id: u32) -> Self {
146 Self { word_id }
147 }
148}
149
150fn classify_char_type(ch: char) -> usize {
152 if ch.is_ascii_digit() {
153 5 } else if ch.is_ascii_alphabetic() {
155 4 } else if is_kanji(ch) {
157 3 } else if is_katakana(ch) {
159 2 } else if is_hiragana(ch) {
161 1 } else {
163 0 }
165}
166
/// Returns the part-of-speech subtype name for a character category code.
/// Codes 1-3 (hiragana/katakana/kanji) and any unknown code map to "一般";
/// 4 (ASCII letters) maps to "固有名詞" and 5 (digits) to "数".
fn get_type_name(char_type: usize) -> &'static str {
    match char_type {
        4 => "固有名詞",
        5 => "数",
        _ => "一般",
    }
}
177
/// True if `ch` is in the hiragana block (U+3041..=U+3096).
fn is_hiragana(ch: char) -> bool {
    ('\u{3041}'..='\u{3096}').contains(&ch)
}
182
/// True if `ch` is katakana: U+30A1..=U+30FA (the original code split this
/// contiguous span at U+30F6/U+30F7) or the phonetic extensions
/// U+31F0..=U+31FF. Excludes U+30FB (middle dot) and U+30FC (prolonged mark).
fn is_katakana(ch: char) -> bool {
    ('\u{30A1}'..='\u{30FA}').contains(&ch) || ('\u{31F0}'..='\u{31FF}').contains(&ch)
}
186
/// True if `ch` is a CJK ideograph in the main block (U+4E00..=U+9FAF) or
/// Extension A (U+3400..=U+4DBF).
fn is_kanji(ch: char) -> bool {
    ('\u{4E00}'..='\u{9FAF}').contains(&ch) || ('\u{3400}'..='\u{4DBF}').contains(&ch)
}
190
/// One parsed row of an unknown-word definition (unk.def-style CSV) file.
#[derive(Debug)]
pub struct UnknownDictionaryEntry {
    // Surface field; for unknown-word rows this holds the character
    // category name (it is matched against category names when building
    // `category_references`).
    pub surface: String,
    // Left connection id.
    pub left_id: u32,
    // Right connection id.
    pub right_id: u32,
    // Word cost used by the Viterbi search.
    pub word_cost: i32,
}
198
199fn parse_dictionary_entry(
200 fields: &[&str],
201 expected_fields_len: usize,
202) -> LinderaResult<UnknownDictionaryEntry> {
203 if fields.len() != expected_fields_len {
204 return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
205 "Invalid number of fields. Expect {}, got {}",
206 expected_fields_len,
207 fields.len()
208 )));
209 }
210 let surface = fields[0];
211 let left_id = u32::from_str(fields[1])
212 .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
213 let right_id = u32::from_str(fields[2])
214 .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
215 let word_cost = i32::from_str(fields[3])
216 .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
217
218 Ok(UnknownDictionaryEntry {
219 surface: surface.to_string(),
220 left_id,
221 right_id,
222 word_cost,
223 })
224}
225
226fn get_entry_id_matching_surface(
227 entries: &[UnknownDictionaryEntry],
228 target_surface: &str,
229) -> Vec<u32> {
230 entries
231 .iter()
232 .enumerate()
233 .filter_map(|(entry_id, entry)| {
234 if entry.surface == *target_surface {
235 Some(entry_id as u32)
236 } else {
237 None
238 }
239 })
240 .collect()
241}
242
243fn make_category_references(
244 categories: &[String],
245 entries: &[UnknownDictionaryEntry],
246) -> Vec<Vec<u32>> {
247 categories
248 .iter()
249 .map(|category| get_entry_id_matching_surface(entries, category))
250 .collect()
251}
252
253fn make_costs_array(entries: &[UnknownDictionaryEntry]) -> Vec<WordEntry> {
254 entries
255 .iter()
256 .enumerate()
257 .map(|(i, e)| {
258 if e.left_id != e.right_id {
261 warn!("left id and right id are not same: {e:?}");
262 }
263 WordEntry {
264 word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::Unknown, i as u32),
265 left_id: e.left_id as u16,
266 right_id: e.right_id as u16,
267 word_cost: e.word_cost as i16,
268 }
269 })
270 .collect()
271}
272
273pub fn parse_unk(categories: &[String], file_content: &str) -> LinderaResult<UnknownDictionary> {
274 let mut unknown_dict_entries = Vec::new();
275 let mut words_idx_data = Vec::new();
276 let mut words_data: Vec<u8> = Vec::new();
277
278 for line in file_content.lines() {
279 let fields: Vec<&str> = line.split(',').collect::<Vec<&str>>();
280 let entry = parse_dictionary_entry(&fields[..], fields.len())?;
281 unknown_dict_entries.push(entry);
282
283 let offset = words_data.len() as u32;
285 words_idx_data.push(offset);
286
287 let details = if fields.len() > 4 {
288 fields[4..].join("\0")
289 } else {
290 String::new()
291 };
292 let details_bytes = details.as_bytes();
293 let len = details_bytes.len() as u32;
294 words_data.extend_from_slice(&len.to_le_bytes());
295 words_data.extend_from_slice(details_bytes);
296 }
297
298 let category_references = make_category_references(categories, &unknown_dict_entries[..]);
299 let costs = make_costs_array(&unknown_dict_entries[..]);
300 Ok(UnknownDictionary {
301 category_references,
302 costs,
303 words_idx_data,
304 words_data,
305 })
306}
307
impl ArchivedUnknownDictionary {
    /// Returns the cost entry for an unknown-word id, deserializing it out of
    /// the rkyv archived (zero-copy) representation.
    ///
    /// # Panics
    ///
    /// Panics if `word_id` is out of bounds for `costs`, or if the archived
    /// entry fails to deserialize (not expected for data produced by this
    /// crate's serializer).
    pub fn word_entry(&self, word_id: u32) -> WordEntry {
        let archived_entry = &self.costs[word_id as usize];
        rkyv::deserialize::<WordEntry, rkyv::rancor::Error>(archived_entry).unwrap()
    }

    /// Returns the unknown-word entry ids for a character category. Elements
    /// are rkyv's little-endian `u32_le` wrappers from the archived layout.
    ///
    /// # Panics
    ///
    /// Panics if `category_id` is out of bounds for `category_references`.
    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[rkyv::rend::u32_le] {
        self.category_references[category_id.0].as_slice()
    }

    /// Returns the feature details for an unknown word, or `None` when the id
    /// is out of range or the stored record is truncated / not valid UTF-8.
    ///
    /// Record layout in `words_data`: a little-endian `u32` byte length
    /// followed by a NUL-separated UTF-8 details string.
    pub fn word_details(&self, word_id: u32) -> Option<Vec<&str>> {
        let idx = word_id as usize;
        if idx >= self.words_idx_data.len() {
            return None;
        }
        // Archived u32 is an endian-aware wrapper; convert to native u32.
        let offset = u32::from(self.words_idx_data[idx]) as usize;
        if offset + 4 > self.words_data.len() {
            return None;
        }
        let len_bytes: [u8; 4] = self.words_data[offset..offset + 4].try_into().ok()?;
        let len = u32::from_le_bytes(len_bytes) as usize;
        if offset + 4 + len > self.words_data.len() {
            return None;
        }
        let text = std::str::from_utf8(&self.words_data[offset + 4..offset + 4 + len]).ok()?;
        Some(text.split('\0').collect())
    }
}