lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13/// Type of lexicon containing the word
14#[derive(
15    Clone,
16    Copy,
17    Debug,
18    Eq,
19    PartialEq,
20    Serialize,
21    Deserialize,
22    Default,
23    Archive,
24    RkyvSerialize,
25    RkyvDeserialize,
26)]
27
28pub enum LexType {
29    /// System dictionary (base dictionary)
30    #[default]
31    System,
32    /// User dictionary (additional vocabulary)
33    User,
34    /// Unknown words (OOV handling)
35    Unknown,
36}
37
38#[derive(
39    Clone,
40    Copy,
41    Debug,
42    Eq,
43    PartialEq,
44    Serialize,
45    Deserialize,
46    Archive,
47    RkyvDeserialize,
48    RkyvSerialize,
49)]
50
51pub struct WordId {
52    pub id: u32,
53    pub is_system: bool,
54    pub lex_type: LexType,
55}
56
57impl WordId {
58    /// Creates a new WordId with specified lexicon type
59    pub fn new(lex_type: LexType, id: u32) -> Self {
60        WordId {
61            id,
62            is_system: matches!(lex_type, LexType::System),
63            lex_type,
64        }
65    }
66
67    pub fn is_unknown(&self) -> bool {
68        self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
69    }
70
71    pub fn is_system(&self) -> bool {
72        self.is_system
73    }
74
75    pub fn lex_type(&self) -> LexType {
76        self.lex_type
77    }
78}
79
80impl Default for WordId {
81    fn default() -> Self {
82        WordId {
83            id: u32::MAX,
84            is_system: true,
85            lex_type: LexType::System,
86        }
87    }
88}
89
90#[derive(
91    Default,
92    Clone,
93    Copy,
94    Debug,
95    Eq,
96    PartialEq,
97    Serialize,
98    Deserialize,
99    Archive,
100    RkyvSerialize,
101    RkyvDeserialize,
102)]
103
104pub struct WordEntry {
105    pub word_id: WordId,
106    pub word_cost: i16,
107    pub left_id: u16,
108    pub right_id: u16,
109}
110
111impl WordEntry {
112    pub const SERIALIZED_LEN: usize = 10;
113
114    pub fn left_id(&self) -> u32 {
115        self.left_id as u32
116    }
117
118    pub fn right_id(&self) -> u32 {
119        self.right_id as u32
120    }
121
122    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
123        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
124        wtr.write_i16::<LittleEndian>(self.word_cost)?;
125        wtr.write_u16::<LittleEndian>(self.left_id)?;
126        wtr.write_u16::<LittleEndian>(self.right_id)?;
127        Ok(())
128    }
129
130    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
131        let word_id = WordId::new(
132            if is_system_entry {
133                LexType::System
134            } else {
135                LexType::User
136            },
137            LittleEndian::read_u32(&data[0..4]),
138        );
139        let word_cost = LittleEndian::read_i16(&data[4..6]);
140        let left_id = LittleEndian::read_u16(&data[6..8]);
141        let right_id = LittleEndian::read_u16(&data[8..10]);
142        WordEntry {
143            word_id,
144            word_cost,
145            left_id,
146            right_id,
147        }
148    }
149}
150
/// Origin of a lattice edge; defaults to `KNOWN`.
#[derive(Default, Debug, Clone, Copy)]
pub enum EdgeType {
    /// Word found by dictionary prefix lookup (system and user dictionaries both use this
    /// tag in `Lattice::set_text`).
    #[default]
    KNOWN,
    /// Word synthesized by unknown-word processing.
    UNKNOWN,
    /// NOTE(review): not produced by the lattice in this file; presumably meant for
    /// user-dictionary words — confirm with other users of this enum.
    USER,
    /// NOTE(review): not produced by the lattice in this file; meaning defined elsewhere.
    INSERTED,
}
159
160#[derive(Default, Clone, Debug)]
161pub struct Edge {
162    pub edge_type: EdgeType,
163    pub word_entry: WordEntry,
164
165    pub path_cost: i32,
166    pub left_index: u16, // Index in the previous position's vector
167
168    pub start_index: u32,
169    pub stop_index: u32,
170
171    pub kanji_only: bool,
172}
173
174impl Edge {
175    pub fn num_chars(&self) -> usize {
176        (self.stop_index - self.start_index) as usize / 3
177    }
178}
179
180#[derive(Clone, Default)]
181pub struct Lattice {
182    capacity: usize,
183    ends_at: Vec<Vec<Edge>>, // Now stores edges directly
184    char_info_buffer: Vec<CharData>,
185    categories_buffer: Vec<CategoryId>,
186    char_category_cache: Vec<Vec<CategoryId>>,
187}
188
/// Per-character data that the lattice precomputes for the current text.
#[derive(Default, Debug, Clone, Copy)]
struct CharData {
    /// Byte offset of the character within the text.
    byte_offset: u32,
    /// Whether `is_kanji` accepts the character.
    is_kanji: bool,
    /// Start of this character's categories in the lattice's category buffer.
    categories_start: u32,
    /// Number of categories this character has.
    categories_len: u16,
    /// Byte length of the run of consecutive kanji starting at this character
    /// (0 when the character is not a kanji).
    kanji_run_byte_len: u32,
}
197
/// Returns `true` if `c` is treated as a kanji.
///
/// Accepted ranges are CJK Unified Ideographs Extension A (U+3400..=U+4DBF) and
/// U+4E00..=U+9FAF. The latter is the part of the CJK Unified Ideographs block that this
/// tokenizer counts as kanji; the block itself extends to U+9FFF.
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(u32::from(c), 0x3400..=0x4DBF | 0x4E00..=0x9FAF)
}
204
205impl Lattice {
206    /// Helper method to create an edge efficiently
207    #[inline]
208    fn create_edge(
209        edge_type: EdgeType,
210        word_entry: WordEntry,
211        start: usize,
212        stop: usize,
213        kanji_only: bool,
214    ) -> Edge {
215        Edge {
216            edge_type,
217            word_entry,
218            left_index: u16::MAX,
219            start_index: start as u32,
220            stop_index: stop as u32,
221            path_cost: i32::MAX,
222            kanji_only,
223        }
224    }
225
226    pub fn clear(&mut self) {
227        for edge_vec in &mut self.ends_at {
228            edge_vec.clear();
229        }
230        self.char_info_buffer.clear();
231        self.categories_buffer.clear();
232    }
233
234    #[inline]
235    fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
236        self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
237    }
238
239    #[inline]
240    fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
241        let char_data = &self.char_info_buffer[char_idx];
242        self.categories_buffer[char_data.categories_start as usize + category_ord]
243    }
244
245    fn set_capacity(&mut self, text_len: usize) {
246        self.clear();
247        if self.capacity <= text_len {
248            self.capacity = text_len;
249            self.ends_at.resize(text_len + 1, Vec::new());
250        }
251        for vec in &mut self.ends_at {
252            vec.clear();
253        }
254    }
255
256    #[inline(never)]
257    // Forward Viterbi implementation:
258    // Constructs the lattice and calculates the path costs simultaneously.
259    // This improves performance by avoiding a separate lattice traversal pass.
260    pub fn set_text(
261        &mut self,
262        dict: &PrefixDictionary,
263        user_dict: &Option<&PrefixDictionary>,
264        char_definitions: &CharacterDefinition,
265        unknown_dictionary: &UnknownDictionary,
266        cost_matrix: &ConnectionCostMatrix,
267        text: &str,
268        search_mode: &Mode,
269    ) {
270        let len = text.len();
271        self.set_capacity(len);
272
273        // Pre-calculate character information for the text
274        self.char_info_buffer.clear();
275        self.categories_buffer.clear();
276
277        if self.char_category_cache.is_empty() {
278            self.char_category_cache.resize(256, Vec::new());
279        }
280
281        for (byte_offset, c) in text.char_indices() {
282            let categories_start = self.categories_buffer.len() as u32;
283
284            if (c as u32) < 256 {
285                let cached = &mut self.char_category_cache[c as usize];
286                if cached.is_empty() {
287                    let cats = char_definitions.lookup_categories(c);
288                    for &category in cats {
289                        cached.push(category);
290                    }
291                }
292                for &category in cached.iter() {
293                    self.categories_buffer.push(category);
294                }
295            } else {
296                let categories = char_definitions.lookup_categories(c);
297                for &category in categories {
298                    self.categories_buffer.push(category);
299                }
300            }
301
302            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
303
304            self.char_info_buffer.push(CharData {
305                byte_offset: byte_offset as u32,
306                is_kanji: is_kanji(c),
307                categories_start,
308                categories_len,
309                kanji_run_byte_len: 0,
310            });
311        }
312        // Sentinel for end of text
313        self.char_info_buffer.push(CharData {
314            byte_offset: len as u32,
315            is_kanji: false,
316            categories_start: 0,
317            categories_len: 0,
318            kanji_run_byte_len: 0,
319        });
320
321        // Pre-calculate Kanji run lengths (backwards)
322        for i in (0..self.char_info_buffer.len() - 1).rev() {
323            if self.char_info_buffer[i].is_kanji {
324                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
325                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
326                self.char_info_buffer[i].kanji_run_byte_len =
327                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
328            } else {
329                self.char_info_buffer[i].kanji_run_byte_len = 0;
330            }
331        }
332
333        let mut start_edge = Edge::default();
334        start_edge.path_cost = 0;
335        start_edge.left_index = u16::MAX;
336        self.ends_at[0].push(start_edge);
337
338        // Index of the last character of unknown word
339        let mut unknown_word_end: Option<usize> = None;
340
341        for char_idx in 0..self.char_info_buffer.len() - 1 {
342            let start = self.char_info_buffer[char_idx].byte_offset as usize;
343
344            // No arc is ending here.
345            // No need to check if a valid word starts here.
346            if self.ends_at[start].is_empty() {
347                continue;
348            }
349
350            let suffix = &text[start..];
351
352            let mut found: bool = false;
353
354            // Lookup user dictionary
355            if user_dict.is_some() {
356                let dict = user_dict.as_ref().unwrap();
357                for (prefix_len, word_entry) in dict.prefix(suffix) {
358                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
359                    let edge = Self::create_edge(
360                        EdgeType::KNOWN,
361                        word_entry,
362                        start,
363                        start + prefix_len,
364                        kanji_only,
365                    );
366                    self.add_edge_in_lattice(edge, cost_matrix, search_mode);
367                    found = true;
368                }
369            }
370
371            // Check all word starting at start, using the double array, like we would use
372            // a prefix trie, and populate the lattice with as many edges
373            for (prefix_len, word_entry) in dict.prefix(suffix) {
374                let kanji_only = self.is_kanji_all(char_idx, prefix_len);
375                let edge = Self::create_edge(
376                    EdgeType::KNOWN,
377                    word_entry,
378                    start,
379                    start + prefix_len,
380                    kanji_only,
381                );
382                self.add_edge_in_lattice(edge, cost_matrix, search_mode);
383                found = true;
384            }
385
386            // In the case of normal mode, it doesn't process unknown word greedily.
387            if (search_mode.is_search()
388                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
389                && char_idx < self.char_info_buffer.len() - 1
390            {
391                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
392                for category_ord in 0..num_categories {
393                    let category = self.get_cached_category(char_idx, category_ord);
394                    unknown_word_end = self.process_unknown_word(
395                        char_definitions,
396                        unknown_dictionary,
397                        cost_matrix,
398                        search_mode,
399                        category,
400                        category_ord,
401                        unknown_word_end,
402                        start,
403                        char_idx,
404                        found,
405                    );
406                }
407            }
408        }
409
410        // Connect EOS
411        if !self.ends_at[len].is_empty() {
412            let mut eos_edge = Edge::default();
413            eos_edge.start_index = len as u32;
414            eos_edge.stop_index = len as u32;
415            // Calculate cost for EOS
416            let left_edges = &self.ends_at[len];
417            let mut best_cost = i32::MAX;
418            let mut best_left = None;
419            let right_left_id = 0; // EOS default left_id
420
421            for (i, left_edge) in left_edges.iter().enumerate() {
422                let left_right_id = left_edge.word_entry.right_id();
423                let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
424                let path_cost = left_edge.path_cost.saturating_add(conn_cost);
425                if path_cost < best_cost {
426                    best_cost = path_cost;
427                    best_left = Some(i as u16);
428                }
429            }
430            if let Some(left_idx) = best_left {
431                eos_edge.left_index = left_idx;
432                eos_edge.path_cost = best_cost;
433                self.ends_at[len].push(eos_edge);
434            }
435        }
436    }
437
438    #[allow(clippy::too_many_arguments)]
439    fn process_unknown_word(
440        &mut self,
441        char_definitions: &CharacterDefinition,
442        unknown_dictionary: &UnknownDictionary,
443        cost_matrix: &ConnectionCostMatrix,
444        search_mode: &Mode,
445        category: CategoryId,
446        category_ord: usize,
447        unknown_word_index: Option<usize>,
448        start: usize,
449        char_idx: usize,
450        found: bool,
451    ) -> Option<usize> {
452        let mut unknown_word_num_chars: usize = 0;
453        let category_data = char_definitions.lookup_definition(category);
454        if category_data.invoke || !found {
455            unknown_word_num_chars = 1;
456            if category_data.group {
457                for i in 1.. {
458                    let next_idx = char_idx + i;
459                    if next_idx >= self.char_info_buffer.len() - 1 {
460                        break;
461                    }
462                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
463                    let mut found_cat = false;
464                    if category_ord < num_categories {
465                        let cat = self.get_cached_category(next_idx, category_ord);
466                        if cat == category {
467                            unknown_word_num_chars += 1;
468                            found_cat = true;
469                        }
470                    }
471                    if !found_cat {
472                        break;
473                    }
474                }
475            }
476        }
477        if unknown_word_num_chars > 0 {
478            let byte_end_offset =
479                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
480            let byte_len = byte_end_offset as usize - start;
481
482            // Check Kanji status using pre-calculated buffer
483            let kanji_only = self.is_kanji_all(char_idx, byte_len);
484
485            for &word_id in unknown_dictionary.lookup_word_ids(category) {
486                let word_entry = unknown_dictionary.word_entry(word_id);
487                let edge = Self::create_edge(
488                    EdgeType::UNKNOWN,
489                    word_entry,
490                    start,
491                    start + byte_len,
492                    kanji_only,
493                );
494                self.add_edge_in_lattice(edge, cost_matrix, search_mode);
495            }
496            return Some(start + byte_len);
497        }
498        unknown_word_index
499    }
500
501    // Adds an edge to the lattice and calculates the minimum cost to reach it.
502    fn add_edge_in_lattice(
503        &mut self,
504        mut edge: Edge,
505        cost_matrix: &ConnectionCostMatrix,
506        mode: &Mode,
507    ) {
508        let start_index = edge.start_index as usize;
509        let stop_index = edge.stop_index as usize;
510
511        let left_edges = &self.ends_at[start_index];
512        if left_edges.is_empty() {
513            return;
514        }
515
516        let mut best_cost = i32::MAX;
517        let mut best_left = None;
518        let right_left_id = edge.word_entry.left_id();
519
520        for (i, left_edge) in left_edges.iter().enumerate() {
521            let left_right_id = left_edge.word_entry.right_id();
522            let conn_cost = cost_matrix.cost(left_right_id, right_left_id);
523            let penalty = mode.penalty_cost(left_edge);
524            let total_cost = left_edge
525                .path_cost
526                .saturating_add(conn_cost)
527                .saturating_add(penalty);
528
529            if total_cost < best_cost {
530                best_cost = total_cost;
531                best_left = Some(i as u16);
532            }
533        }
534
535        if let Some(best_left_idx) = best_left {
536            edge.path_cost = best_cost.saturating_add(edge.word_entry.word_cost as i32);
537            edge.left_index = best_left_idx;
538            self.ends_at[stop_index].push(edge);
539        }
540    }
541
542    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
543        let mut offsets = Vec::new();
544
545        if self.ends_at.is_empty() {
546            return offsets;
547        }
548
549        let mut last_idx = self.ends_at.len() - 1;
550        while last_idx > 0 && self.ends_at[last_idx].is_empty() {
551            last_idx -= 1;
552        }
553
554        if self.ends_at[last_idx].is_empty() {
555            return offsets;
556        }
557
558        let idx = self.ends_at[last_idx].len() - 1;
559        let mut edge = &self.ends_at[last_idx][idx];
560
561        if edge.left_index == u16::MAX {
562            return offsets;
563        }
564
565        loop {
566            if edge.left_index == u16::MAX {
567                break;
568            }
569
570            offsets.push((edge.start_index as usize, edge.word_entry.word_id));
571
572            let left_idx = edge.left_index as usize;
573            let start_idx = edge.start_index as usize;
574
575            edge = &self.ends_at[start_idx][left_idx];
576        }
577
578        offsets.reverse();
579        offsets.pop(); // Remove EOS
580
581        offsets
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::{LexType, WordEntry, WordId};

    /// A system entry survives a serialize/deserialize round trip unchanged, and the
    /// encoded form is exactly `SERIALIZED_LEN` bytes.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            word_id: WordId {
                id: 1,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17,
            left_id: 1411,
            right_id: 1412,
        };

        let mut bytes: Vec<u8> = Vec::with_capacity(WordEntry::SERIALIZED_LEN);
        original.serialize(&mut bytes).unwrap();
        assert_eq!(bytes.len(), WordEntry::SERIALIZED_LEN);

        let round_tripped = WordEntry::deserialize(&bytes, true);
        assert_eq!(round_tripped, original);
    }
}