lindera_dictionary/dictionary/unknown_dictionary.rs

use std::str::FromStr;

use log::warn;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};

use crate::LinderaResult;
use crate::dictionary::character_definition::CategoryId;
use crate::error::LinderaErrorKind;
use crate::viterbi::WordEntry;

/// Dictionary used to produce entries for words that are not found in the main lexicon.
#[derive(Serialize, Deserialize, Clone, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct UnknownDictionary {
    /// For each character category, the ids of the unknown-word entries that apply to it.
    pub category_references: Vec<Vec<u32>>,
    /// Connection ids and word cost for each unknown-word entry.
    pub costs: Vec<WordEntry>,
}

impl UnknownDictionary {
    /// Deserializes an `UnknownDictionary` from rkyv-serialized bytes.
    pub fn load(unknown_data: &[u8]) -> LinderaResult<UnknownDictionary> {
        // rkyv requires aligned input, so copy the bytes into a 16-byte-aligned buffer first.
        let mut aligned = rkyv::util::AlignedVec::<16>::new();
        aligned.extend_from_slice(unknown_data);
        rkyv::from_bytes::<UnknownDictionary, rkyv::rancor::Error>(&aligned).map_err(|err| {
            LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
        })
    }

    /// Returns the cost entry for the given unknown-word id.
    pub fn word_entry(&self, word_id: u32) -> WordEntry {
        self.costs[word_id as usize]
    }

    /// Returns the unknown-word entry ids registered for the given character category.
    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[u32] {
        &self.category_references[category_id.0][..]
    }

    /// Generates unknown-word candidates starting at `start_pos` in `sentence`
    /// and passes each candidate to `callback`.
    pub fn gen_unk_words<F>(
        &self,
        sentence: &str,
        start_pos: usize,
        has_matched: bool,
        max_grouping_len: Option<usize>,
        mut callback: F,
    ) where
        F: FnMut(UnkWord),
    {
        let chars: Vec<char> = sentence.chars().collect();
        let max_len = max_grouping_len.unwrap_or(10);

        // Emit only a single-character candidate if a dictionary word already matched
        // at this position; otherwise allow up to three characters (bounded by `max_len`).
        let actual_max_len = if has_matched { 1 } else { max_len.min(3) };

        for length in 1..=actual_max_len {
            if start_pos + length > chars.len() {
                break;
            }

            let end_pos = start_pos + length;

            // The candidate's category is determined by its first character.
            let first_char = chars[start_pos];
            let char_type = classify_char_type(first_char);

            let unk_word = UnkWord {
                word_idx: WordIdx::new(char_type as u32),
                end_char: end_pos,
            };

            callback(unk_word);
        }
    }

    /// Returns a `WordIdx` for the character at `start` if `feature` describes a noun
    /// whose subcategory is compatible with that character's type.
    pub fn compatible_unk_index(
        &self,
        sentence: &str,
        start: usize,
        _end: usize,
        feature: &str,
    ) -> Option<WordIdx> {
        let chars: Vec<char> = sentence.chars().collect();
        if start >= chars.len() {
            return None;
        }

        let first_char = chars[start];
        let char_type = classify_char_type(first_char);

        // e.g. "名詞,数,..." for digits, or "名詞,固有名詞,..." for alphabetic characters.
        if feature.starts_with(&format!("名詞,{}", get_type_name(char_type))) {
            Some(WordIdx::new(char_type as u32))
        } else {
            None
        }
    }
}

/// An unknown-word candidate: the entry index to use and the exclusive end
/// position (in characters) of the candidate within the sentence.
#[derive(Debug, Clone)]
pub struct UnkWord {
    pub word_idx: WordIdx,
    pub end_char: usize,
}

impl UnkWord {
    pub fn word_idx(&self) -> WordIdx {
        self.word_idx
    }

    pub fn end_char(&self) -> usize {
        self.end_char
    }
}

/// Identifier of an unknown-word entry.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    pub word_id: u32,
}

impl WordIdx {
    pub fn new(word_id: u32) -> Self {
        Self { word_id }
    }
}

/// Maps a character to a coarse type id:
/// 5 = ASCII digit, 4 = ASCII alphabetic, 3 = kanji, 2 = katakana, 1 = hiragana, 0 = other.
fn classify_char_type(ch: char) -> usize {
    if ch.is_ascii_digit() {
        5
    } else if ch.is_ascii_alphabetic() {
        4
    } else if is_kanji(ch) {
        3
    } else if is_katakana(ch) {
        2
    } else if is_hiragana(ch) {
        1
    } else {
        0
    }
}

/// Maps a character type id to the part-of-speech subcategory name used in feature strings.
fn get_type_name(char_type: usize) -> &'static str {
    match char_type {
        1 => "一般",     // hiragana: general
        2 => "一般",     // katakana: general
        3 => "一般",     // kanji: general
        4 => "固有名詞", // alphabetic: proper noun
        5 => "数",       // digit: numeral
        _ => "一般",
    }
}

fn is_hiragana(ch: char) -> bool {
    matches!(ch, '\u{3041}'..='\u{3096}')
}

fn is_katakana(ch: char) -> bool {
    matches!(ch, '\u{30A1}'..='\u{30F6}' | '\u{30F7}'..='\u{30FA}' | '\u{31F0}'..='\u{31FF}')
}

fn is_kanji(ch: char) -> bool {
    matches!(ch, '\u{4E00}'..='\u{9FAF}' | '\u{3400}'..='\u{4DBF}')
}

/// One row of an unknown-word definition (unk.def-style) file: the category surface
/// plus its connection ids and word cost.
#[derive(Debug)]
pub struct UnknownDictionaryEntry {
    pub surface: String,
    pub left_id: u32,
    pub right_id: u32,
    pub word_cost: i32,
}

fn parse_dictionary_entry(
    fields: &[&str],
    expected_fields_len: usize,
) -> LinderaResult<UnknownDictionaryEntry> {
    if fields.len() != expected_fields_len {
        return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
            "Invalid number of fields. Expected {}, got {}",
            expected_fields_len,
            fields.len()
        )));
    }
    // Only the first four fields (surface, left id, right id, cost) are used here.
    let surface = fields[0];
    let left_id = u32::from_str(fields[1])
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
    let right_id = u32::from_str(fields[2])
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
    let word_cost = i32::from_str(fields[3])
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;

    Ok(UnknownDictionaryEntry {
        surface: surface.to_string(),
        left_id,
        right_id,
        word_cost,
    })
}

fn get_entry_id_matching_surface(
    entries: &[UnknownDictionaryEntry],
    target_surface: &str,
) -> Vec<u32> {
    entries
        .iter()
        .enumerate()
        .filter_map(|(entry_id, entry)| {
            if entry.surface == *target_surface {
                Some(entry_id as u32)
            } else {
                None
            }
        })
        .collect()
}

fn make_category_references(
    categories: &[String],
    entries: &[UnknownDictionaryEntry],
) -> Vec<Vec<u32>> {
    categories
        .iter()
        .map(|category| get_entry_id_matching_surface(entries, category))
        .collect()
}

fn make_costs_array(entries: &[UnknownDictionaryEntry]) -> Vec<WordEntry> {
    entries
        .iter()
        .map(|e| {
            if e.left_id != e.right_id {
                warn!("left id and right id are not the same: {e:?}");
            }
            WordEntry {
                // Unknown-word entries carry a sentinel word id rather than a lexicon id.
                word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::Unknown, u32::MAX),
                left_id: e.left_id as u16,
                right_id: e.right_id as u16,
                word_cost: e.word_cost as i16,
            }
        })
        .collect()
}

/// Parses the contents of an unknown-word definition file into an `UnknownDictionary`,
/// using `categories` to build the per-category entry references.
pub fn parse_unk(categories: &[String], file_content: &str) -> LinderaResult<UnknownDictionary> {
    let mut unknown_dict_entries = Vec::new();
    for line in file_content.lines() {
        // Skip blank lines so the field parser below is not fed an empty record.
        if line.trim().is_empty() {
            continue;
        }
        let fields: Vec<&str> = line.split(',').collect::<Vec<&str>>();
        let entry = parse_dictionary_entry(&fields[..], fields.len())?;
        unknown_dict_entries.push(entry);
    }

    let category_references = make_category_references(categories, &unknown_dict_entries[..]);
    let costs = make_costs_array(&unknown_dict_entries[..]);
    Ok(UnknownDictionary {
        category_references,
        costs,
    })
}

impl ArchivedUnknownDictionary {
    /// Deserializes the archived cost entry for the given unknown-word id.
    pub fn word_entry(&self, word_id: u32) -> WordEntry {
        let archived_entry = &self.costs[word_id as usize];
        rkyv::deserialize::<WordEntry, rkyv::rancor::Error>(archived_entry).unwrap()
    }

    /// Returns the archived entry ids registered for the given character category.
    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[rkyv::rend::u32_le] {
        self.category_references[category_id.0].as_slice()
    }
}
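
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of how `parse_unk` and `UnknownDictionary::load` fit together,
    // not a test taken from the original crate. The category list and the two
    // unk.def-style rows below are illustrative values, and the round-trip assumes
    // `rkyv::to_bytes` from the same rkyv API family as the `from_bytes` call in `load`.
    #[test]
    fn parse_and_reload_unknown_dictionary() {
        let categories = vec!["DEFAULT".to_string(), "NUMERIC".to_string()];
        let content = "DEFAULT,5,5,4769,記号,一般,*,*,*,*,*\n\
                       NUMERIC,1295,1295,27386,名詞,数,*,*,*,*,*\n";

        let dict = parse_unk(&categories, content).unwrap();

        // One list of entry ids per category, in category order.
        assert_eq!(dict.category_references, vec![vec![0u32], vec![1u32]]);
        assert_eq!(dict.costs.len(), 2);
        assert_eq!(dict.costs[1].word_cost, 27386);

        // Serialize with rkyv and read the bytes back through `load`.
        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&dict).unwrap();
        let reloaded = UnknownDictionary::load(&bytes).unwrap();
        assert_eq!(reloaded.costs[0].left_id, 5);
    }
}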