1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13#[derive(
15 Clone,
16 Copy,
17 Debug,
18 Eq,
19 PartialEq,
20 Serialize,
21 Deserialize,
22 Default,
23 Archive,
24 RkyvSerialize,
25 RkyvDeserialize,
26)]
27
28pub enum LexType {
29 #[default]
31 System,
32 User,
34 Unknown,
36}
37
38#[derive(
39 Clone,
40 Copy,
41 Debug,
42 Eq,
43 PartialEq,
44 Serialize,
45 Deserialize,
46 Archive,
47 RkyvDeserialize,
48 RkyvSerialize,
49)]
50
51pub struct WordId {
52 pub id: u32,
53 pub is_system: bool,
54 pub lex_type: LexType,
55}
56
57impl WordId {
58 pub fn new(lex_type: LexType, id: u32) -> Self {
60 WordId {
61 id,
62 is_system: matches!(lex_type, LexType::System),
63 lex_type,
64 }
65 }
66
67 pub fn is_unknown(&self) -> bool {
68 self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
69 }
70
71 pub fn is_system(&self) -> bool {
72 self.is_system
73 }
74
75 pub fn lex_type(&self) -> LexType {
76 self.lex_type
77 }
78}
79
80impl Default for WordId {
81 fn default() -> Self {
82 WordId {
83 id: u32::MAX,
84 is_system: true,
85 lex_type: LexType::System,
86 }
87 }
88}
89
90#[derive(
91 Default,
92 Clone,
93 Copy,
94 Debug,
95 Eq,
96 PartialEq,
97 Serialize,
98 Deserialize,
99 Archive,
100 RkyvSerialize,
101 RkyvDeserialize,
102)]
103
104pub struct WordEntry {
105 pub word_id: WordId,
106 pub word_cost: i16,
107 pub left_id: u16,
108 pub right_id: u16,
109}
110
111impl WordEntry {
112 pub const SERIALIZED_LEN: usize = 10;
113
114 pub fn left_id(&self) -> u32 {
115 self.left_id as u32
116 }
117
118 pub fn right_id(&self) -> u32 {
119 self.right_id as u32
120 }
121
122 pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
123 wtr.write_u32::<LittleEndian>(self.word_id.id)?;
124 wtr.write_i16::<LittleEndian>(self.word_cost)?;
125 wtr.write_u16::<LittleEndian>(self.left_id)?;
126 wtr.write_u16::<LittleEndian>(self.right_id)?;
127 Ok(())
128 }
129
130 pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
131 let word_id = WordId::new(
132 if is_system_entry {
133 LexType::System
134 } else {
135 LexType::User
136 },
137 LittleEndian::read_u32(&data[0..4]),
138 );
139 let word_cost = LittleEndian::read_i16(&data[4..6]);
140 let left_id = LittleEndian::read_u16(&data[6..8]);
141 let right_id = LittleEndian::read_u16(&data[8..10]);
142 WordEntry {
143 word_id,
144 word_cost,
145 left_id,
146 right_id,
147 }
148 }
149}
150
/// Tag describing how a lattice [`Edge`] was produced.
///
/// In this file only `KNOWN` (dictionary matches) and `UNKNOWN` (unknown-word
/// candidates) are assigned; `KNOWN` is also the default.
#[derive(Debug, Default, Clone, Copy)]
pub enum EdgeType {
    /// Edge created from a dictionary match (default).
    #[default]
    KNOWN,
    /// Edge created from the unknown-word dictionary.
    UNKNOWN,
    /// Not assigned anywhere in this file.
    USER,
    /// Not assigned anywhere in this file.
    INSERTED,
}
159
160#[derive(Default, Clone, Debug)]
161pub struct Edge {
162 pub edge_type: EdgeType,
163 pub word_entry: WordEntry,
164
165 pub path_cost: i32,
166 pub left_index: u16, pub start_index: u32,
169 pub stop_index: u32,
170
171 pub kanji_only: bool,
172}
173
174impl Edge {
175 pub fn num_chars(&self) -> usize {
176 (self.stop_index - self.start_index) as usize / 3
177 }
178}
179
180#[derive(Clone, Default)]
181pub struct Lattice {
182 capacity: usize,
183 ends_at: Vec<Vec<Edge>>, char_info_buffer: Vec<CharData>,
185 categories_buffer: Vec<CategoryId>,
186 char_category_cache: Vec<Vec<CategoryId>>,
187}
188
/// Per-character information precomputed by `Lattice::set_text`.
#[derive(Debug, Default, Clone, Copy)]
struct CharData {
    /// Byte offset of the character inside the text.
    byte_offset: u32,
    /// Result of `is_kanji` for the character.
    is_kanji: bool,
    /// Index of the character's first category in `Lattice::categories_buffer`.
    categories_start: u32,
    /// Number of categories stored for the character.
    categories_len: u16,
    /// Byte length of the uninterrupted kanji run starting at this character
    /// (`0` when the character itself is not kanji).
    kanji_run_byte_len: u32,
}
197
/// Returns `true` when `c` lies in `U+4E00..=U+9FAF` or `U+3400..=U+4DBF`.
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(c, '\u{4E00}'..='\u{9FAF}' | '\u{3400}'..='\u{4DBF}')
}
204
205impl Lattice {
206 #[inline]
208 fn create_edge(
209 edge_type: EdgeType,
210 word_entry: WordEntry,
211 start: usize,
212 stop: usize,
213 kanji_only: bool,
214 ) -> Edge {
215 Edge {
216 edge_type,
217 word_entry,
218 left_index: u16::MAX,
219 start_index: start as u32,
220 stop_index: stop as u32,
221 path_cost: i32::MAX,
222 kanji_only,
223 }
224 }
225
226 pub fn clear(&mut self) {
227 for edge_vec in &mut self.ends_at {
228 edge_vec.clear();
229 }
230 self.char_info_buffer.clear();
231 self.categories_buffer.clear();
232 }
233
234 #[inline]
235 fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
236 self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
237 }
238
239 #[inline]
240 fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
241 let char_data = &self.char_info_buffer[char_idx];
242 self.categories_buffer[char_data.categories_start as usize + category_ord]
243 }
244
245 fn set_capacity(&mut self, text_len: usize) {
246 self.clear();
247 if self.capacity <= text_len {
248 self.capacity = text_len;
249 self.ends_at.resize(text_len + 1, Vec::new());
250 }
251 for vec in &mut self.ends_at {
252 vec.clear();
253 }
254 }
255
256 #[inline(never)]
257 pub fn set_text(
261 &mut self,
262 dict: &PrefixDictionary,
263 user_dict: &Option<&PrefixDictionary>,
264 char_definitions: &CharacterDefinition,
265 unknown_dictionary: &UnknownDictionary,
266 cost_matrix: &ConnectionCostMatrix,
267 text: &str,
268 search_mode: &Mode,
269 ) {
270 let len = text.len();
271 self.set_capacity(len);
272
273 self.char_info_buffer.clear();
275 self.categories_buffer.clear();
276
277 if self.char_category_cache.is_empty() {
278 self.char_category_cache.resize(256, Vec::new());
279 }
280
281 for (byte_offset, c) in text.char_indices() {
282 let categories_start = self.categories_buffer.len() as u32;
283
284 if (c as u32) < 256 {
285 let cached = &mut self.char_category_cache[c as usize];
286 if cached.is_empty() {
287 let cats = char_definitions.lookup_categories(c);
288 for &category in cats {
289 cached.push(category);
290 }
291 }
292 for &category in cached.iter() {
293 self.categories_buffer.push(category);
294 }
295 } else {
296 let categories = char_definitions.lookup_categories(c);
297 for &category in categories {
298 self.categories_buffer.push(category);
299 }
300 }
301
302 let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
303
304 self.char_info_buffer.push(CharData {
305 byte_offset: byte_offset as u32,
306 is_kanji: is_kanji(c),
307 categories_start,
308 categories_len,
309 kanji_run_byte_len: 0,
310 });
311 }
312 self.char_info_buffer.push(CharData {
314 byte_offset: len as u32,
315 is_kanji: false,
316 categories_start: 0,
317 categories_len: 0,
318 kanji_run_byte_len: 0,
319 });
320
321 for i in (0..self.char_info_buffer.len() - 1).rev() {
323 if self.char_info_buffer[i].is_kanji {
324 let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
325 let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
326 self.char_info_buffer[i].kanji_run_byte_len =
327 char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
328 } else {
329 self.char_info_buffer[i].kanji_run_byte_len = 0;
330 }
331 }
332
333 let mut start_edge = Edge::default();
334 start_edge.path_cost = 0;
335 start_edge.left_index = u16::MAX;
336 self.ends_at[0].push(start_edge);
337
338 let mut unknown_word_end: Option<usize> = None;
340
341 for char_idx in 0..self.char_info_buffer.len() - 1 {
342 let start = self.char_info_buffer[char_idx].byte_offset as usize;
343
344 if self.ends_at[start].is_empty() {
347 continue;
348 }
349
350 let suffix = &text[start..];
351
352 let mut found: bool = false;
353
354 if user_dict.is_some() {
356 let dict = user_dict.as_ref().unwrap();
357 for (prefix_len, word_entry) in dict.prefix(suffix) {
358 let kanji_only = self.is_kanji_all(char_idx, prefix_len);
359 let edge = Self::create_edge(
360 EdgeType::KNOWN,
361 word_entry,
362 start,
363 start + prefix_len,
364 kanji_only,
365 );
366 self.add_edge_in_lattice(edge, cost_matrix, search_mode);
367 found = true;
368 }
369 }
370
371 for (prefix_len, word_entry) in dict.prefix(suffix) {
374 let kanji_only = self.is_kanji_all(char_idx, prefix_len);
375 let edge = Self::create_edge(
376 EdgeType::KNOWN,
377 word_entry,
378 start,
379 start + prefix_len,
380 kanji_only,
381 );
382 self.add_edge_in_lattice(edge, cost_matrix, search_mode);
383 found = true;
384 }
385
386 if (search_mode.is_search()
388 || unknown_word_end.map(|index| index <= start).unwrap_or(true))
389 && char_idx < self.char_info_buffer.len() - 1
390 {
391 let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
392 for category_ord in 0..num_categories {
393 let category = self.get_cached_category(char_idx, category_ord);
394 unknown_word_end = self.process_unknown_word(
395 char_definitions,
396 unknown_dictionary,
397 cost_matrix,
398 search_mode,
399 category,
400 category_ord,
401 unknown_word_end,
402 start,
403 char_idx,
404 found,
405 );
406 }
407 }
408 }
409
410 if !self.ends_at[len].is_empty() {
412 let mut eos_edge = Edge::default();
413 eos_edge.start_index = len as u32;
414 eos_edge.stop_index = len as u32;
415 let left_edges = &self.ends_at[len];
417 let mut best_cost = i32::MAX;
418 let mut best_left = None;
419 let right_left_id = 0; for (i, left_edge) in left_edges.iter().enumerate() {
422 let left_right_id = left_edge.word_entry.right_id();
423 let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
424 let path_cost = left_edge.path_cost.saturating_add(conn_cost);
425 if path_cost < best_cost {
426 best_cost = path_cost;
427 best_left = Some(i as u16);
428 }
429 }
430 if let Some(left_idx) = best_left {
431 eos_edge.left_index = left_idx;
432 eos_edge.path_cost = best_cost;
433 self.ends_at[len].push(eos_edge);
434 }
435 }
436 }
437
438 #[allow(clippy::too_many_arguments)]
439 fn process_unknown_word(
440 &mut self,
441 char_definitions: &CharacterDefinition,
442 unknown_dictionary: &UnknownDictionary,
443 cost_matrix: &ConnectionCostMatrix,
444 search_mode: &Mode,
445 category: CategoryId,
446 category_ord: usize,
447 unknown_word_index: Option<usize>,
448 start: usize,
449 char_idx: usize,
450 found: bool,
451 ) -> Option<usize> {
452 let mut unknown_word_num_chars: usize = 0;
453 let category_data = char_definitions.lookup_definition(category);
454 if category_data.invoke || !found {
455 unknown_word_num_chars = 1;
456 if category_data.group {
457 for i in 1.. {
458 let next_idx = char_idx + i;
459 if next_idx >= self.char_info_buffer.len() - 1 {
460 break;
461 }
462 let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
463 let mut found_cat = false;
464 if category_ord < num_categories {
465 let cat = self.get_cached_category(next_idx, category_ord);
466 if cat == category {
467 unknown_word_num_chars += 1;
468 found_cat = true;
469 }
470 }
471 if !found_cat {
472 break;
473 }
474 }
475 }
476 }
477 if unknown_word_num_chars > 0 {
478 let byte_end_offset =
479 self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
480 let byte_len = byte_end_offset as usize - start;
481
482 let kanji_only = self.is_kanji_all(char_idx, byte_len);
484
485 for &word_id in unknown_dictionary.lookup_word_ids(category) {
486 let word_entry = unknown_dictionary.word_entry(word_id);
487 let edge = Self::create_edge(
488 EdgeType::UNKNOWN,
489 word_entry,
490 start,
491 start + byte_len,
492 kanji_only,
493 );
494 self.add_edge_in_lattice(edge, cost_matrix, search_mode);
495 }
496 return Some(start + byte_len);
497 }
498 unknown_word_index
499 }
500
501 fn add_edge_in_lattice(
503 &mut self,
504 mut edge: Edge,
505 cost_matrix: &ConnectionCostMatrix,
506 mode: &Mode,
507 ) {
508 let start_index = edge.start_index as usize;
509 let stop_index = edge.stop_index as usize;
510
511 let left_edges = &self.ends_at[start_index];
512 if left_edges.is_empty() {
513 return;
514 }
515
516 let mut best_cost = i32::MAX;
517 let mut best_left = None;
518 let right_left_id = edge.word_entry.left_id();
519
520 for (i, left_edge) in left_edges.iter().enumerate() {
521 let left_right_id = left_edge.word_entry.right_id();
522 let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
523 let penalty = mode.penalty_cost(left_edge);
524 let total_cost = left_edge
525 .path_cost
526 .saturating_add(conn_cost)
527 .saturating_add(penalty);
528
529 if total_cost < best_cost {
530 best_cost = total_cost;
531 best_left = Some(i as u16);
532 }
533 }
534
535 if let Some(best_left_idx) = best_left {
536 edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
537 edge.left_index = best_left_idx;
538 self.ends_at[stop_index].push(edge);
539 }
540 }
541
542 pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
543 let mut offsets = Vec::new();
544
545 if self.ends_at.is_empty() {
546 return offsets;
547 }
548
549 let mut last_idx = self.ends_at.len() - 1;
550 while last_idx > 0 && self.ends_at[last_idx].is_empty() {
551 last_idx -= 1;
552 }
553
554 if self.ends_at[last_idx].is_empty() {
555 return offsets;
556 }
557
558 let idx = self.ends_at[last_idx].len() - 1;
559 let mut edge = &self.ends_at[last_idx][idx];
560
561 if edge.left_index == u16::MAX {
562 return offsets;
563 }
564
565 loop {
566 if edge.left_index == u16::MAX {
567 break;
568 }
569
570 offsets.push((edge.start_index as usize, edge.word_entry.word_id));
571
572 let left_idx = edge.left_index as usize;
573 let start_idx = edge.start_index as usize;
574
575 edge = &self.ends_at[start_idx][left_idx];
576 }
577
578 offsets.reverse();
579 offsets.pop(); offsets
582 }
583}
584
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// Serializes `entry`, checks the encoded length and decodes it again.
    fn round_trip(entry: &WordEntry, is_system_entry: bool) -> WordEntry {
        let mut buffer = Vec::new();
        entry.serialize(&mut buffer).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, buffer.len());
        WordEntry::deserialize(&buffer[..], is_system_entry)
    }

    #[test]
    fn test_word_entry() {
        let word_entry = WordEntry {
            word_id: WordId {
                id: 1u32,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };
        assert_eq!(word_entry, round_trip(&word_entry, true));
    }

    /// The lexicon type is not serialized; decoding with
    /// `is_system_entry == false` must yield a user entry.
    #[test]
    fn test_word_entry_user() {
        let word_entry = WordEntry {
            word_id: WordId::new(LexType::User, 7u32),
            word_cost: 42i16,
            left_id: 3u16,
            right_id: 4u16,
        };
        let decoded = round_trip(&word_entry, false);
        assert_eq!(word_entry, decoded);
        assert_eq!(LexType::User, decoded.word_id.lex_type());
        assert!(!decoded.word_id.is_system());
    }
}