1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use serde::{Deserialize, Serialize};
5
6use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
7use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
8use crate::dictionary::prefix_dictionary::PrefixDictionary;
9use crate::dictionary::unknown_dictionary::UnknownDictionary;
10use crate::mode::Mode;
11
12const EOS_NODE: EdgeId = EdgeId(1u32);
13
/// Which lexicon a [`WordId`] refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize, Default)]
pub enum LexType {
    /// System-dictionary entry; this is the default.
    #[default]
    System,
    /// Non-system (user) dictionary entry; `WordEntry::deserialize` picks this when
    /// `is_system_entry` is `false`.
    User,
    /// Unknown word; `WordId::is_unknown` returns `true` for this lexicon type.
    Unknown,
}
25
/// Identifies a lexicon entry together with the lexicon it belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct WordId {
    /// Entry id within its lexicon. `u32::MAX` is the sentinel used by `WordId::default()`
    /// and is reported as unknown by `WordId::is_unknown`.
    pub id: u32,
    /// Mirrors `lex_type == LexType::System`. `WordId::new` keeps the two in sync, but
    /// constructing the struct literally does not.
    pub is_system: bool,
    /// Lexicon the entry belongs to.
    pub lex_type: LexType,
}
32
33impl WordId {
34 pub fn new(lex_type: LexType, id: u32) -> Self {
36 WordId {
37 id,
38 is_system: matches!(lex_type, LexType::System),
39 lex_type,
40 }
41 }
42
43 pub fn is_unknown(&self) -> bool {
44 self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
45 }
46
47 pub fn is_system(&self) -> bool {
48 self.is_system
49 }
50
51 pub fn lex_type(&self) -> LexType {
52 self.lex_type
53 }
54}
55
56impl Default for WordId {
57 fn default() -> Self {
58 WordId {
59 id: u32::MAX,
60 is_system: true,
61 lex_type: LexType::System,
62 }
63 }
64}
65
/// A lexicon entry: its word id plus the cost data used when scoring lattice paths.
///
/// `WordEntry::serialize` stores only `word_id.id` and the three cost fields (10 bytes,
/// little endian). The lexicon type is supplied again when deserializing.
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct WordEntry {
    /// Id of the entry and the lexicon it came from.
    pub word_id: WordId,
    /// Cost of the word itself. `Lattice::calculate_path_costs` adds it to an edge's path
    /// cost.
    pub word_cost: i16,
    /// Connection id used when this entry is the right-hand side of a connection.
    pub left_id: u16,
    /// Connection id used when this entry is the left-hand side of a connection.
    pub right_id: u16,
}
73
74impl WordEntry {
75 pub const SERIALIZED_LEN: usize = 10;
76
77 pub fn left_id(&self) -> u32 {
78 self.left_id as u32
79 }
80
81 pub fn right_id(&self) -> u32 {
82 self.right_id as u32
83 }
84
85 pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
86 wtr.write_u32::<LittleEndian>(self.word_id.id)?;
87 wtr.write_i16::<LittleEndian>(self.word_cost)?;
88 wtr.write_u16::<LittleEndian>(self.left_id)?;
89 wtr.write_u16::<LittleEndian>(self.right_id)?;
90 Ok(())
91 }
92
93 pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
94 let word_id = WordId::new(
95 if is_system_entry {
96 LexType::System
97 } else {
98 LexType::User
99 },
100 LittleEndian::read_u32(&data[0..4]),
101 );
102 let word_cost = LittleEndian::read_i16(&data[4..6]);
103 let left_id = LittleEndian::read_u16(&data[6..8]);
104 let right_id = LittleEndian::read_u16(&data[8..10]);
105 WordEntry {
106 word_id,
107 word_cost,
108 left_id,
109 right_id,
110 }
111 }
112}
113
/// Origin of a lattice edge.
///
/// `Lattice::set_text` produces `KNOWN` edges for both system- and user-dictionary
/// matches, and `UNKNOWN` edges for unknown-word candidates. It never produces `USER` or
/// `INSERTED`.
#[derive(Clone, Copy, Debug, Default)]
pub enum EdgeType {
    /// Dictionary match. This is the default, so the BOS/EOS sentinels (built with
    /// `Edge::default()`) also carry it.
    #[default]
    KNOWN,
    /// Unknown-word candidate built from character categories.
    UNKNOWN,
    /// NOTE(review): not produced by `Lattice::set_text`; presumably intended for
    /// user-dictionary edges — confirm with callers.
    USER,
    /// NOTE(review): not produced by `Lattice::set_text`; purpose not visible here — confirm.
    INSERTED,
}
122
/// Index of an edge in a `Lattice`'s edge vector, assigned in insertion order by
/// `Lattice::add_edge`.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
pub struct EdgeId(pub u32);
125
/// A lattice node: one candidate word covering the byte range `start_index..stop_index`.
#[derive(Default, Clone, Debug)]
pub struct Edge {
    /// Where this edge came from.
    pub edge_type: EdgeType,
    /// Lexicon data (word cost and connection ids) for the word.
    pub word_entry: WordEntry,

    /// Best cumulative cost of reaching this edge, as computed by
    /// `Lattice::calculate_path_costs`. Edges built by `Lattice::create_edge` start at
    /// `i32::MAX`; a `Default` edge (used for BOS/EOS) starts at 0.
    pub path_cost: i32,
    /// Predecessor on the best path, set by `Lattice::calculate_path_costs`. Stays `None`
    /// for BOS and for edges that had no candidate predecessor.
    pub left_edge: Option<EdgeId>,

    /// Byte offset in the text where the word starts.
    pub start_index: u32,
    /// Byte offset in the text just past the end of the word.
    pub stop_index: u32,

    /// Whether the word's surface is non-empty and consists only of kanji
    /// (see `is_kanji_only`).
    pub kanji_only: bool,
}
139
140impl Edge {
141 pub fn num_chars(&self) -> usize {
143 (self.stop_index - self.start_index) as usize / 3
144 }
145}
146
/// Word lattice over a text.
///
/// Meant to be reused across texts: the per-position vectors only ever grow, and clearing
/// keeps their allocations.
#[derive(Clone, Default)]
pub struct Lattice {
    /// Largest text length (in bytes) the per-position vectors have been sized for.
    capacity: usize,
    /// Every edge; `EdgeId` indexes into this. After `set_text`, edge 0 is BOS and edge 1
    /// is EOS.
    edges: Vec<Edge>,
    /// `starts_at[i]` holds the ids of edges starting at byte offset `i`.
    starts_at: Vec<Vec<EdgeId>>,
    /// `ends_at[i]` holds the ids of edges ending at byte offset `i`.
    ends_at: Vec<Vec<EdgeId>>,
    /// Scratch buffer handed out (cleared) by `get_edge_buffer`; the lattice itself does
    /// not read it.
    edge_buffer: Vec<Edge>,
    /// Scratch buffer handed out (cleared) by `get_edge_id_buffer`; the lattice itself does
    /// not read it.
    edge_id_buffer: Vec<EdgeId>,
}
157
/// Returns `true` if `c` lies in `U+4E00..=U+9FAF`, the range this module treats as kanji.
///
/// This is most of the CJK Unified Ideographs block. It does not include the block's tail
/// (`U+9FB0..=U+9FFF`) or the CJK extension blocks.
#[inline]
fn is_kanji(c: char) -> bool {
    ('\u{4E00}'..='\u{9FAF}').contains(&c)
}
164
165#[inline]
166fn is_kanji_only(s: &str) -> bool {
167 !s.is_empty() && s.chars().all(is_kanji)
169}
170
171impl Lattice {
172 #[inline]
174 fn create_edge(
175 edge_type: EdgeType,
176 word_entry: WordEntry,
177 start: usize,
178 stop: usize,
179 kanji_only: bool,
180 ) -> Edge {
181 Edge {
182 edge_type,
183 word_entry,
184 left_edge: None,
185 start_index: start as u32,
186 stop_index: stop as u32,
187 path_cost: i32::MAX,
188 kanji_only,
189 }
190 }
191
192 pub fn clear(&mut self) {
193 for edge_vec in &mut self.starts_at {
194 edge_vec.clear();
195 }
196 for edge_vec in &mut self.ends_at {
197 edge_vec.clear();
198 }
199 self.edges.clear();
200 self.edge_buffer.clear();
202 self.edge_id_buffer.clear();
203 }
204
205 pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
207 self.edge_buffer.clear();
208 &mut self.edge_buffer
209 }
210
211 pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
213 self.edge_id_buffer.clear();
214 &mut self.edge_id_buffer
215 }
216
217 fn set_capacity(&mut self, text_len: usize) {
218 self.clear();
219 if self.capacity < text_len {
220 self.capacity = text_len;
221 self.edges.clear();
222 self.starts_at.resize(text_len + 1, Vec::new());
223 self.ends_at.resize(text_len + 1, Vec::new());
224 }
225 }
226
227 #[inline(never)]
228 pub fn set_text(
229 &mut self,
230 dict: &PrefixDictionary,
231 user_dict: &Option<&PrefixDictionary>,
232 char_definitions: &CharacterDefinition,
233 unknown_dictionary: &UnknownDictionary,
234 text: &str,
235 search_mode: &Mode,
236 ) {
237 let len = text.len();
238 self.set_capacity(len);
239
240 let start_edge_id = self.add_edge(Edge::default());
241 let end_edge_id = self.add_edge(Edge::default());
242
243 assert_eq!(EOS_NODE, end_edge_id);
244 self.ends_at[0].push(start_edge_id);
245 self.starts_at[len].push(end_edge_id);
246
247 let mut unknown_word_end: Option<usize> = None;
249
250 for start in 0..len {
251 if self.ends_at[start].is_empty() {
254 continue;
255 }
256
257 let suffix = &text[start..];
258
259 let mut found: bool = false;
260
261 if user_dict.is_some() {
263 let dict = user_dict.as_ref().unwrap();
264 for (prefix_len, word_entry) in dict.prefix(suffix) {
265 let edge = Self::create_edge(
266 EdgeType::KNOWN,
267 word_entry,
268 start,
269 start + prefix_len,
270 is_kanji_only(&suffix[..prefix_len]),
271 );
272 self.add_edge_in_lattice(edge);
273 found = true;
274 }
275 }
276
277 for (prefix_len, word_entry) in dict.prefix(suffix) {
280 let edge = Self::create_edge(
281 EdgeType::KNOWN,
282 word_entry,
283 start,
284 start + prefix_len,
285 is_kanji_only(&suffix[..prefix_len]),
286 );
287 self.add_edge_in_lattice(edge);
288 found = true;
289 }
290
291 if (search_mode.is_search()
293 || unknown_word_end.map(|index| index <= start).unwrap_or(true))
294 && let Some(first_char) = suffix.chars().next()
295 {
296 let categories = char_definitions.lookup_categories(first_char);
297 for (category_ord, &category) in categories.iter().enumerate() {
298 unknown_word_end = self.process_unknown_word(
299 char_definitions,
300 unknown_dictionary,
301 category,
302 category_ord,
303 unknown_word_end,
304 start,
305 suffix,
306 found,
307 );
308 }
309 }
310 }
311 }
312
313 #[allow(clippy::too_many_arguments)]
314 fn process_unknown_word(
315 &mut self,
316 char_definitions: &CharacterDefinition,
317 unknown_dictionary: &UnknownDictionary,
318 category: CategoryId,
319 category_ord: usize,
320 unknown_word_index: Option<usize>,
321 start: usize,
322 suffix: &str,
323 found: bool,
324 ) -> Option<usize> {
325 let mut unknown_word_num_chars: usize = 0;
326 let category_data = char_definitions.lookup_definition(category);
327 if category_data.invoke || !found {
328 unknown_word_num_chars = 1;
329 if category_data.group {
330 for c in suffix.chars().skip(1) {
331 let categories = char_definitions.lookup_categories(c);
332 if categories.len() > category_ord && categories[category_ord] == category {
333 unknown_word_num_chars += 1;
334 } else {
335 break;
336 }
337 }
338 }
339 }
340 if unknown_word_num_chars > 0 {
341 let byte_end = suffix
343 .char_indices()
344 .nth(unknown_word_num_chars)
345 .map_or(suffix.len(), |(pos, _)| pos);
346 let unknown_word = &suffix[..byte_end];
347 for &word_id in unknown_dictionary.lookup_word_ids(category) {
348 let word_entry = unknown_dictionary.word_entry(word_id);
349 let edge = Self::create_edge(
350 EdgeType::UNKNOWN,
351 word_entry,
352 start,
353 start + unknown_word.len(),
354 is_kanji_only(unknown_word),
355 );
356 self.add_edge_in_lattice(edge);
357 }
358 return Some(start + unknown_word.len());
359 }
360 unknown_word_index
361 }
362
363 fn add_edge_in_lattice(&mut self, edge: Edge) {
364 let start_index = edge.start_index as usize;
365 let stop_index = edge.stop_index as usize;
366 let edge_id = self.add_edge(edge);
367 self.starts_at[start_index].push(edge_id);
368 self.ends_at[stop_index].push(edge_id);
369 }
370
371 fn add_edge(&mut self, edge: Edge) -> EdgeId {
372 let edge_id = EdgeId(self.edges.len() as u32);
373 self.edges.push(edge);
374 edge_id
375 }
376
377 pub fn edge(&self, edge_id: EdgeId) -> &Edge {
378 &self.edges[edge_id.0 as usize]
379 }
380
381 #[inline(never)]
382 pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
383 let text_len = self.starts_at.len();
384 for i in 0..text_len {
385 let left_edge_ids = &self.ends_at[i];
386 let right_edge_ids = &self.starts_at[i];
387
388 for &right_edge_id in right_edge_ids {
389 let right_edge = &self.edges[right_edge_id.0 as usize];
391 let right_word_entry = right_edge.word_entry;
392 let right_left_id = right_word_entry.left_id();
393
394 let mut best_cost = i32::MAX;
396 let mut best_left = None;
397
398 for &left_edge_id in left_edge_ids {
399 let left_edge = &self.edges[left_edge_id.0 as usize];
400 let left_right_id = left_edge.word_entry.right_id();
401
402 let mut path_cost =
404 left_edge.path_cost + cost_matrix.cost(left_right_id, right_left_id);
405 path_cost += mode.penalty_cost(left_edge);
406
407 if path_cost < best_cost {
409 best_cost = path_cost;
410 best_left = Some(left_edge_id);
411 }
412 }
413
414 if let Some(best_left_id) = best_left {
416 let edge = &mut self.edges[right_edge_id.0 as usize];
417 edge.left_edge = Some(best_left_id);
418 edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
419 }
420 }
421 }
422 }
423
424 pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
425 let mut offsets = Vec::new();
426 let mut edge_id = EOS_NODE;
427 let _edge = self.edge(EOS_NODE);
428 loop {
429 let edge = self.edge(edge_id);
430 if let Some(left_edge_id) = edge.left_edge {
431 offsets.push((edge.start_index as usize, edge.word_entry.word_id));
432 edge_id = left_edge_id;
433 } else {
434 break;
435 }
436 }
437 offsets.reverse();
438 offsets.pop();
439 offsets
440 }
441}
442
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// A system `WordEntry` survives a serialize/deserialize round trip and encodes to
    /// exactly `WordEntry::SERIALIZED_LEN` bytes.
    #[test]
    fn test_word_entry() {
        let expected = WordEntry {
            word_id: WordId::new(LexType::System, 1),
            word_cost: -17,
            left_id: 1411,
            right_id: 1412,
        };

        let mut encoded: Vec<u8> = Vec::with_capacity(WordEntry::SERIALIZED_LEN);
        expected
            .serialize(&mut encoded)
            .expect("writing to a Vec<u8> cannot fail");
        assert_eq!(encoded.len(), WordEntry::SERIALIZED_LEN);

        let decoded = WordEntry::deserialize(&encoded, true);
        assert_eq!(decoded, expected);
    }
}