1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13const EOS_NODE: EdgeId = EdgeId(1u32);
14
15#[derive(
17 Clone,
18 Copy,
19 Debug,
20 Eq,
21 PartialEq,
22 Serialize,
23 Deserialize,
24 Default,
25 Archive,
26 RkyvSerialize,
27 RkyvDeserialize,
28)]
29
30pub enum LexType {
31 #[default]
33 System,
34 User,
36 Unknown,
38}
39
40#[derive(
41 Clone,
42 Copy,
43 Debug,
44 Eq,
45 PartialEq,
46 Serialize,
47 Deserialize,
48 Archive,
49 RkyvDeserialize,
50 RkyvSerialize,
51)]
52
53pub struct WordId {
54 pub id: u32,
55 pub is_system: bool,
56 pub lex_type: LexType,
57}
58
59impl WordId {
60 pub fn new(lex_type: LexType, id: u32) -> Self {
62 WordId {
63 id,
64 is_system: matches!(lex_type, LexType::System),
65 lex_type,
66 }
67 }
68
69 pub fn is_unknown(&self) -> bool {
70 self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
71 }
72
73 pub fn is_system(&self) -> bool {
74 self.is_system
75 }
76
77 pub fn lex_type(&self) -> LexType {
78 self.lex_type
79 }
80}
81
82impl Default for WordId {
83 fn default() -> Self {
84 WordId {
85 id: u32::MAX,
86 is_system: true,
87 lex_type: LexType::System,
88 }
89 }
90}
91
92#[derive(
93 Default,
94 Clone,
95 Copy,
96 Debug,
97 Eq,
98 PartialEq,
99 Serialize,
100 Deserialize,
101 Archive,
102 RkyvSerialize,
103 RkyvDeserialize,
104)]
105
106pub struct WordEntry {
107 pub word_id: WordId,
108 pub word_cost: i16,
109 pub left_id: u16,
110 pub right_id: u16,
111}
112
113impl WordEntry {
114 pub const SERIALIZED_LEN: usize = 10;
115
116 pub fn left_id(&self) -> u32 {
117 self.left_id as u32
118 }
119
120 pub fn right_id(&self) -> u32 {
121 self.right_id as u32
122 }
123
124 pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
125 wtr.write_u32::<LittleEndian>(self.word_id.id)?;
126 wtr.write_i16::<LittleEndian>(self.word_cost)?;
127 wtr.write_u16::<LittleEndian>(self.left_id)?;
128 wtr.write_u16::<LittleEndian>(self.right_id)?;
129 Ok(())
130 }
131
132 pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
133 let word_id = WordId::new(
134 if is_system_entry {
135 LexType::System
136 } else {
137 LexType::User
138 },
139 LittleEndian::read_u32(&data[0..4]),
140 );
141 let word_cost = LittleEndian::read_i16(&data[4..6]);
142 let left_id = LittleEndian::read_u16(&data[6..8]);
143 let right_id = LittleEndian::read_u16(&data[8..10]);
144 WordEntry {
145 word_id,
146 word_cost,
147 left_id,
148 right_id,
149 }
150 }
151}
152
/// Origin of an edge in the lattice.
#[derive(Debug, Default, Clone, Copy)]
pub enum EdgeType {
    /// Dictionary word. `Lattice::set_text` uses it for both system and user
    /// dictionary matches. This is the default.
    #[default]
    KNOWN,
    /// Word built from the unknown-word dictionary.
    UNKNOWN,
    /// Not produced by `Lattice::set_text`; NOTE(review): confirm which callers use it.
    USER,
    /// Not produced by `Lattice::set_text`; NOTE(review): confirm which callers use it.
    INSERTED,
}
161
/// Index of an [`Edge`] in the lattice's `edges` vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct EdgeId(pub u32);
164
165#[derive(Default, Clone, Debug)]
166pub struct Edge {
167 pub edge_type: EdgeType,
168 pub word_entry: WordEntry,
169
170 pub path_cost: i32,
171 pub left_edge: Option<EdgeId>,
172
173 pub start_index: u32,
174 pub stop_index: u32,
175
176 pub kanji_only: bool,
177}
178
179impl Edge {
180 pub fn num_chars(&self) -> usize {
182 (self.stop_index - self.start_index) as usize / 3
183 }
184}
185
186#[derive(Clone, Default)]
187pub struct Lattice {
188 capacity: usize,
189 edges: Vec<Edge>,
190 starts_at: Vec<Vec<EdgeId>>,
191 ends_at: Vec<Vec<EdgeId>>,
192 edge_buffer: Vec<Edge>,
194 edge_id_buffer: Vec<EdgeId>,
195 left_cache_buffer: Vec<(u32, i32, i32, EdgeId)>,
196 char_info_buffer: Vec<CharData>,
197 categories_buffer: Vec<CategoryId>,
198 char_category_cache: Vec<Vec<CategoryId>>,
200}
201
/// Per-character information computed by `Lattice::set_text`.
#[derive(Debug, Default, Clone, Copy)]
struct CharData {
    /// Byte offset of the char in the text.
    byte_offset: u32,
    /// Result of `is_kanji` for the char.
    is_kanji: bool,
    /// Index of the char's first category in `Lattice::categories_buffer`.
    categories_start: u32,
    /// Number of categories stored for the char.
    categories_len: u16,
    /// Byte length of the run of consecutive kanji starting at this char (0 if not kanji).
    kanji_run_byte_len: u32,
}
210
/// Returns `true` if `c` lies in CJK Unified Ideographs Extension A
/// (U+3400..=U+4DBF) or in U+4E00..=U+9FAF of the CJK Unified Ideographs block.
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(u32::from(c), 0x3400..=0x4DBF | 0x4E00..=0x9FAF)
}
217
218impl Lattice {
219 #[inline]
221 fn create_edge(
222 edge_type: EdgeType,
223 word_entry: WordEntry,
224 start: usize,
225 stop: usize,
226 kanji_only: bool,
227 ) -> Edge {
228 Edge {
229 edge_type,
230 word_entry,
231 left_edge: None,
232 start_index: start as u32,
233 stop_index: stop as u32,
234 path_cost: i32::MAX,
235 kanji_only,
236 }
237 }
238
239 pub fn clear(&mut self) {
240 for edge_vec in &mut self.starts_at {
241 edge_vec.clear();
242 }
243 for edge_vec in &mut self.ends_at {
244 edge_vec.clear();
245 }
246 self.edges.clear();
247 self.edge_buffer.clear();
249 self.edge_id_buffer.clear();
250 self.left_cache_buffer.clear();
251 self.char_info_buffer.clear();
252 self.categories_buffer.clear();
253 }
254
255 pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
257 self.edge_buffer.clear();
258 &mut self.edge_buffer
259 }
260
261 pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
263 self.edge_id_buffer.clear();
264 &mut self.edge_id_buffer
265 }
266
267 #[inline]
268 fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
269 self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
270 }
271
272 #[inline]
273 fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
274 let char_data = &self.char_info_buffer[char_idx];
275 self.categories_buffer[char_data.categories_start as usize + category_ord]
276 }
277
278 fn set_capacity(&mut self, text_len: usize) {
279 self.clear();
280 if self.capacity < text_len {
281 self.capacity = text_len;
282 self.edges.clear();
283 self.starts_at.resize(text_len + 1, Vec::new());
284 self.ends_at.resize(text_len + 1, Vec::new());
285 }
286 }
287
288 #[inline(never)]
289 pub fn set_text(
290 &mut self,
291 dict: &PrefixDictionary,
292 user_dict: &Option<&PrefixDictionary>,
293 char_definitions: &CharacterDefinition,
294 unknown_dictionary: &UnknownDictionary,
295 text: &str,
296 search_mode: &Mode,
297 ) {
298 let len = text.len();
299 self.set_capacity(len);
300
301 self.char_info_buffer.clear();
303 self.categories_buffer.clear();
304
305 if self.char_category_cache.is_empty() {
306 self.char_category_cache.resize(256, Vec::new());
307 }
308
309 for (byte_offset, c) in text.char_indices() {
310 let categories_start = self.categories_buffer.len() as u32;
311
312 if (c as u32) < 256 {
313 let cached = &mut self.char_category_cache[c as usize];
314 if cached.is_empty() {
315 let cats = char_definitions.lookup_categories(c);
316 for &category in cats {
317 cached.push(category);
318 }
319 }
320 for &category in cached.iter() {
321 self.categories_buffer.push(category);
322 }
323 } else {
324 let categories = char_definitions.lookup_categories(c);
325 for &category in categories {
326 self.categories_buffer.push(category);
327 }
328 }
329
330 let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
331
332 self.char_info_buffer.push(CharData {
333 byte_offset: byte_offset as u32,
334 is_kanji: is_kanji(c),
335 categories_start,
336 categories_len,
337 kanji_run_byte_len: 0,
338 });
339 }
340 self.char_info_buffer.push(CharData {
342 byte_offset: len as u32,
343 is_kanji: false,
344 categories_start: 0,
345 categories_len: 0,
346 kanji_run_byte_len: 0,
347 });
348
349 for i in (0..self.char_info_buffer.len() - 1).rev() {
351 if self.char_info_buffer[i].is_kanji {
352 let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
353 let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
354 self.char_info_buffer[i].kanji_run_byte_len =
355 char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
356 } else {
357 self.char_info_buffer[i].kanji_run_byte_len = 0;
358 }
359 }
360
361 let start_edge_id = self.add_edge(Edge::default());
362 let end_edge_id = self.add_edge(Edge::default());
363
364 assert_eq!(EOS_NODE, end_edge_id);
365 self.ends_at[0].push(start_edge_id);
366 self.starts_at[len].push(end_edge_id);
367
368 let mut unknown_word_end: Option<usize> = None;
370
371 for char_idx in 0..self.char_info_buffer.len() - 1 {
372 let start = self.char_info_buffer[char_idx].byte_offset as usize;
373
374 if self.ends_at[start].is_empty() {
377 continue;
378 }
379
380 let suffix = &text[start..];
381
382 let mut found: bool = false;
383
384 if user_dict.is_some() {
386 let dict = user_dict.as_ref().unwrap();
387 for (prefix_len, word_entry) in dict.prefix(suffix) {
388 let kanji_only = self.is_kanji_all(char_idx, prefix_len);
389 let edge = Self::create_edge(
390 EdgeType::KNOWN,
391 word_entry,
392 start,
393 start + prefix_len,
394 kanji_only,
395 );
396 self.add_edge_in_lattice(edge);
397 found = true;
398 }
399 }
400
401 for (prefix_len, word_entry) in dict.prefix(suffix) {
404 let kanji_only = self.is_kanji_all(char_idx, prefix_len);
405 let edge = Self::create_edge(
406 EdgeType::KNOWN,
407 word_entry,
408 start,
409 start + prefix_len,
410 kanji_only,
411 );
412 self.add_edge_in_lattice(edge);
413 found = true;
414 }
415
416 if (search_mode.is_search()
418 || unknown_word_end.map(|index| index <= start).unwrap_or(true))
419 && char_idx < self.char_info_buffer.len() - 1
420 {
421 let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
422 for category_ord in 0..num_categories {
423 let category = self.get_cached_category(char_idx, category_ord);
424 unknown_word_end = self.process_unknown_word(
425 char_definitions,
426 unknown_dictionary,
427 category,
428 category_ord,
429 unknown_word_end,
430 start,
431 char_idx,
432 found,
433 );
434 }
435 }
436 }
437 }
438
439 #[allow(clippy::too_many_arguments)]
440 fn process_unknown_word(
441 &mut self,
442 char_definitions: &CharacterDefinition,
443 unknown_dictionary: &UnknownDictionary,
444 category: CategoryId,
445 category_ord: usize,
446 unknown_word_index: Option<usize>,
447 start: usize,
448 char_idx: usize,
449 found: bool,
450 ) -> Option<usize> {
451 let mut unknown_word_num_chars: usize = 0;
452 let category_data = char_definitions.lookup_definition(category);
453 if category_data.invoke || !found {
454 unknown_word_num_chars = 1;
455 if category_data.group {
456 for i in 1.. {
457 let next_idx = char_idx + i;
458 if next_idx >= self.char_info_buffer.len() - 1 {
459 break;
460 }
461 let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
462 let mut found_cat = false;
463 if category_ord < num_categories {
464 let cat = self.get_cached_category(next_idx, category_ord);
465 if cat == category {
466 unknown_word_num_chars += 1;
467 found_cat = true;
468 }
469 }
470 if !found_cat {
471 break;
472 }
473 }
474 }
475 }
476 if unknown_word_num_chars > 0 {
477 let byte_end_offset =
478 self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
479 let byte_len = byte_end_offset as usize - start;
480
481 let kanji_only = self.is_kanji_all(char_idx, byte_len);
483
484 for &word_id in unknown_dictionary.lookup_word_ids(category) {
485 let word_entry = unknown_dictionary.word_entry(word_id);
486 let edge = Self::create_edge(
487 EdgeType::UNKNOWN,
488 word_entry,
489 start,
490 start + byte_len,
491 kanji_only,
492 );
493 self.add_edge_in_lattice(edge);
494 }
495 return Some(start + byte_len);
496 }
497 unknown_word_index
498 }
499
500 fn add_edge_in_lattice(&mut self, edge: Edge) {
501 let start_index = edge.start_index as usize;
502 let stop_index = edge.stop_index as usize;
503 let edge_id = self.add_edge(edge);
504 self.starts_at[start_index].push(edge_id);
505 self.ends_at[stop_index].push(edge_id);
506 }
507
508 fn add_edge(&mut self, edge: Edge) -> EdgeId {
509 let edge_id = EdgeId(self.edges.len() as u32);
510 self.edges.push(edge);
511 edge_id
512 }
513
514 pub fn edge(&self, edge_id: EdgeId) -> &Edge {
515 &self.edges[edge_id.0 as usize]
516 }
517
518 #[inline(never)]
519 pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
520 let text_len = self.starts_at.len();
521 for i in 0..text_len {
522 let left_edge_ids = &self.ends_at[i];
523 let right_edge_ids = &self.starts_at[i];
524
525 if right_edge_ids.is_empty() || left_edge_ids.is_empty() {
526 continue;
527 }
528
529 let mut left_cache = std::mem::take(&mut self.left_cache_buffer);
532 for &left_edge_id in left_edge_ids {
533 let left_edge = &self.edges[left_edge_id.0 as usize];
534 left_cache.push((
535 left_edge.word_entry.right_id(),
536 left_edge.path_cost,
537 mode.penalty_cost(left_edge),
538 left_edge_id,
539 ));
540 }
541
542 for &right_edge_id in right_edge_ids {
543 let right_edge = &self.edges[right_edge_id.0 as usize];
545 let right_word_entry = right_edge.word_entry;
546 let right_left_id = right_word_entry.left_id();
547
548 let mut best_cost = i32::MAX;
550 let mut best_left = None;
551
552 for &(left_right_id, left_path_cost, left_penalty, left_edge_id) in
553 left_cache.iter()
554 {
555 let mut path_cost =
557 left_path_cost + cost_matrix.cost(left_right_id, right_left_id);
558 path_cost += left_penalty;
559
560 if path_cost < best_cost {
562 best_cost = path_cost;
563 best_left = Some(left_edge_id);
564 }
565 }
566
567 if let Some(best_left_id) = best_left {
569 let edge = &mut self.edges[right_edge_id.0 as usize];
570 edge.left_edge = Some(best_left_id);
571 edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
572 }
573 }
574 left_cache.clear();
575 self.left_cache_buffer = left_cache;
576 }
577 }
578
579 pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
580 let mut offsets = Vec::new();
581 let mut edge_id = EOS_NODE;
582 let _edge = self.edge(EOS_NODE);
583 loop {
584 let edge = self.edge(edge_id);
585 if let Some(left_edge_id) = edge.left_edge {
586 offsets.push((edge.start_index as usize, edge.word_entry.word_id));
587 edge_id = left_edge_id;
588 } else {
589 break;
590 }
591 }
592 offsets.reverse();
593 offsets.pop();
594 offsets
595 }
596}
597
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// Serializes `entry`, checks the encoded length, and decodes it again with
    /// the given lexicon flag.
    fn round_trip(entry: &WordEntry, is_system_entry: bool) -> WordEntry {
        let mut buffer: Vec<u8> = Vec::new();
        entry.serialize(&mut buffer).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, buffer.len());
        WordEntry::deserialize(&buffer[..], is_system_entry)
    }

    #[test]
    fn test_word_entry() {
        let word_entry = WordEntry {
            word_id: WordId {
                id: 1u32,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };
        assert_eq!(word_entry, round_trip(&word_entry, true));
    }

    #[test]
    fn test_user_word_entry() {
        // `deserialize(.., false)` must rebuild a user-dictionary id.
        let word_entry = WordEntry {
            word_id: WordId {
                id: 42u32,
                is_system: false,
                lex_type: LexType::User,
            },
            word_cost: 300i16,
            left_id: 7u16,
            right_id: 8u16,
        };
        assert_eq!(word_entry, round_trip(&word_entry, false));
    }
}