lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13const EOS_NODE: EdgeId = EdgeId(1u32);
14
15/// Type of lexicon containing the word
16#[derive(
17    Clone,
18    Copy,
19    Debug,
20    Eq,
21    PartialEq,
22    Serialize,
23    Deserialize,
24    Default,
25    Archive,
26    RkyvSerialize,
27    RkyvDeserialize,
28)]
29
30pub enum LexType {
31    /// System dictionary (base dictionary)
32    #[default]
33    System,
34    /// User dictionary (additional vocabulary)
35    User,
36    /// Unknown words (OOV handling)
37    Unknown,
38}
39
40#[derive(
41    Clone,
42    Copy,
43    Debug,
44    Eq,
45    PartialEq,
46    Serialize,
47    Deserialize,
48    Archive,
49    RkyvDeserialize,
50    RkyvSerialize,
51)]
52
53pub struct WordId {
54    pub id: u32,
55    pub is_system: bool,
56    pub lex_type: LexType,
57}
58
59impl WordId {
60    /// Creates a new WordId with specified lexicon type
61    pub fn new(lex_type: LexType, id: u32) -> Self {
62        WordId {
63            id,
64            is_system: matches!(lex_type, LexType::System),
65            lex_type,
66        }
67    }
68
69    pub fn is_unknown(&self) -> bool {
70        self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
71    }
72
73    pub fn is_system(&self) -> bool {
74        self.is_system
75    }
76
77    pub fn lex_type(&self) -> LexType {
78        self.lex_type
79    }
80}
81
82impl Default for WordId {
83    fn default() -> Self {
84        WordId {
85            id: u32::MAX,
86            is_system: true,
87            lex_type: LexType::System,
88        }
89    }
90}
91
92#[derive(
93    Default,
94    Clone,
95    Copy,
96    Debug,
97    Eq,
98    PartialEq,
99    Serialize,
100    Deserialize,
101    Archive,
102    RkyvSerialize,
103    RkyvDeserialize,
104)]
105
106pub struct WordEntry {
107    pub word_id: WordId,
108    pub word_cost: i16,
109    pub left_id: u16,
110    pub right_id: u16,
111}
112
113impl WordEntry {
114    pub const SERIALIZED_LEN: usize = 10;
115
116    pub fn left_id(&self) -> u32 {
117        self.left_id as u32
118    }
119
120    pub fn right_id(&self) -> u32 {
121        self.right_id as u32
122    }
123
124    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
125        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
126        wtr.write_i16::<LittleEndian>(self.word_cost)?;
127        wtr.write_u16::<LittleEndian>(self.left_id)?;
128        wtr.write_u16::<LittleEndian>(self.right_id)?;
129        Ok(())
130    }
131
132    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
133        let word_id = WordId::new(
134            if is_system_entry {
135                LexType::System
136            } else {
137                LexType::User
138            },
139            LittleEndian::read_u32(&data[0..4]),
140        );
141        let word_cost = LittleEndian::read_i16(&data[4..6]);
142        let left_id = LittleEndian::read_u16(&data[6..8]);
143        let right_id = LittleEndian::read_u16(&data[8..10]);
144        WordEntry {
145            word_id,
146            word_cost,
147            left_id,
148            right_id,
149        }
150    }
151}
152
/// Origin of a lattice edge.
#[derive(Clone, Copy, Debug, Default)]
pub enum EdgeType {
    /// Edge built from a dictionary entry. The lattice builder in this file uses it for
    /// both system- and user-dictionary matches. This is the `Default` variant.
    #[default]
    KNOWN,
    /// Edge built from the unknown-word dictionary.
    UNKNOWN,
    /// NOTE(review): not produced by the lattice builder in this file; presumably meant
    /// for user-dictionary edges — confirm against callers.
    USER,
    /// NOTE(review): not produced in this file; its meaning is defined by callers — confirm.
    INSERTED,
}
161
/// Index of an edge in the lattice's edge list (resolved with `Lattice::edge`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EdgeId(pub u32);
164
165#[derive(Default, Clone, Debug)]
166pub struct Edge {
167    pub edge_type: EdgeType,
168    pub word_entry: WordEntry,
169
170    pub path_cost: i32,
171    pub left_edge: Option<EdgeId>,
172
173    pub start_index: u32,
174    pub stop_index: u32,
175
176    pub kanji_only: bool,
177}
178
179impl Edge {
180    // TODO fix em
181    pub fn num_chars(&self) -> usize {
182        (self.stop_index - self.start_index) as usize / 3
183    }
184}
185
186#[derive(Clone, Default)]
187pub struct Lattice {
188    capacity: usize,
189    edges: Vec<Edge>,
190    starts_at: Vec<Vec<EdgeId>>,
191    ends_at: Vec<Vec<EdgeId>>,
192    // Buffer reuse optimization: pre-allocated vectors for reuse
193    edge_buffer: Vec<Edge>,
194    edge_id_buffer: Vec<EdgeId>,
195    left_cache_buffer: Vec<(u32, i32, i32, EdgeId)>,
196    char_info_buffer: Vec<CharData>,
197    categories_buffer: Vec<CategoryId>,
198    // Fast path cache for character properties (first 256 characters)
199    char_category_cache: Vec<Vec<CategoryId>>,
200}
201
/// Precomputed information about one character of the current text.
#[derive(Clone, Copy, Debug, Default)]
struct CharData {
    /// Byte offset of the character in the text.
    byte_offset: u32,
    /// Whether `is_kanji` holds for the character.
    is_kanji: bool,
    /// Index of the character's first category in the lattice's category buffer.
    categories_start: u32,
    /// Number of categories recorded for the character.
    categories_len: u16,
    /// Byte length of the run of consecutive kanji starting at this character
    /// (0 when the character is not kanji).
    kanji_run_byte_len: u32,
}
210
/// Returns `true` if `c` is treated as a kanji.
///
/// Accepted ranges are CJK Unified Ideographs Extension A (U+3400–U+4DBF) and
/// U+4E00–U+9FAF. The second range stops short of the end of the CJK Unified
/// Ideographs block (U+9FFF), so the later-added ideographs U+9FB0–U+9FFF are *not*
/// classified as kanji here.
#[inline]
pub fn is_kanji(c: char) -> bool {
    matches!(u32::from(c), 0x3400..=0x4DBF | 0x4E00..=0x9FAF)
}
217
218impl Lattice {
219    /// Helper method to create an edge efficiently
220    #[inline]
221    fn create_edge(
222        edge_type: EdgeType,
223        word_entry: WordEntry,
224        start: usize,
225        stop: usize,
226        kanji_only: bool,
227    ) -> Edge {
228        Edge {
229            edge_type,
230            word_entry,
231            left_edge: None,
232            start_index: start as u32,
233            stop_index: stop as u32,
234            path_cost: i32::MAX,
235            kanji_only,
236        }
237    }
238
239    pub fn clear(&mut self) {
240        for edge_vec in &mut self.starts_at {
241            edge_vec.clear();
242        }
243        for edge_vec in &mut self.ends_at {
244            edge_vec.clear();
245        }
246        self.edges.clear();
247        // Clear buffers but preserve capacity for reuse
248        self.edge_buffer.clear();
249        self.edge_id_buffer.clear();
250        self.left_cache_buffer.clear();
251        self.char_info_buffer.clear();
252        self.categories_buffer.clear();
253    }
254
255    /// Get a reusable edge buffer with preserved capacity
256    pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
257        self.edge_buffer.clear();
258        &mut self.edge_buffer
259    }
260
261    /// Get a reusable edge ID buffer with preserved capacity
262    pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
263        self.edge_id_buffer.clear();
264        &mut self.edge_id_buffer
265    }
266
267    #[inline]
268    fn is_kanji_all(&self, char_idx: usize, byte_len: usize) -> bool {
269        self.char_info_buffer[char_idx].kanji_run_byte_len >= byte_len as u32
270    }
271
272    #[inline]
273    fn get_cached_category(&self, char_idx: usize, category_ord: usize) -> CategoryId {
274        let char_data = &self.char_info_buffer[char_idx];
275        self.categories_buffer[char_data.categories_start as usize + category_ord]
276    }
277
278    fn set_capacity(&mut self, text_len: usize) {
279        self.clear();
280        if self.capacity < text_len {
281            self.capacity = text_len;
282            self.edges.clear();
283            self.starts_at.resize(text_len + 1, Vec::new());
284            self.ends_at.resize(text_len + 1, Vec::new());
285        }
286    }
287
288    #[inline(never)]
289    pub fn set_text(
290        &mut self,
291        dict: &PrefixDictionary,
292        user_dict: &Option<&PrefixDictionary>,
293        char_definitions: &CharacterDefinition,
294        unknown_dictionary: &UnknownDictionary,
295        text: &str,
296        search_mode: &Mode,
297    ) {
298        let len = text.len();
299        self.set_capacity(len);
300
301        // Pre-calculate character information for the text
302        self.char_info_buffer.clear();
303        self.categories_buffer.clear();
304
305        if self.char_category_cache.is_empty() {
306            self.char_category_cache.resize(256, Vec::new());
307        }
308
309        for (byte_offset, c) in text.char_indices() {
310            let categories_start = self.categories_buffer.len() as u32;
311
312            if (c as u32) < 256 {
313                let cached = &mut self.char_category_cache[c as usize];
314                if cached.is_empty() {
315                    let cats = char_definitions.lookup_categories(c);
316                    for &category in cats {
317                        cached.push(category);
318                    }
319                }
320                for &category in cached.iter() {
321                    self.categories_buffer.push(category);
322                }
323            } else {
324                let categories = char_definitions.lookup_categories(c);
325                for &category in categories {
326                    self.categories_buffer.push(category);
327                }
328            }
329
330            let categories_len = (self.categories_buffer.len() as u32 - categories_start) as u16;
331
332            self.char_info_buffer.push(CharData {
333                byte_offset: byte_offset as u32,
334                is_kanji: is_kanji(c),
335                categories_start,
336                categories_len,
337                kanji_run_byte_len: 0,
338            });
339        }
340        // Sentinel for end of text
341        self.char_info_buffer.push(CharData {
342            byte_offset: len as u32,
343            is_kanji: false,
344            categories_start: 0,
345            categories_len: 0,
346            kanji_run_byte_len: 0,
347        });
348
349        // Pre-calculate Kanji run lengths (backwards)
350        for i in (0..self.char_info_buffer.len() - 1).rev() {
351            if self.char_info_buffer[i].is_kanji {
352                let next_byte_offset = self.char_info_buffer[i + 1].byte_offset;
353                let char_byte_len = next_byte_offset - self.char_info_buffer[i].byte_offset;
354                self.char_info_buffer[i].kanji_run_byte_len =
355                    char_byte_len + self.char_info_buffer[i + 1].kanji_run_byte_len;
356            } else {
357                self.char_info_buffer[i].kanji_run_byte_len = 0;
358            }
359        }
360
361        let start_edge_id = self.add_edge(Edge::default());
362        let end_edge_id = self.add_edge(Edge::default());
363
364        assert_eq!(EOS_NODE, end_edge_id);
365        self.ends_at[0].push(start_edge_id);
366        self.starts_at[len].push(end_edge_id);
367
368        // index of the last character of unknown word
369        let mut unknown_word_end: Option<usize> = None;
370
371        for char_idx in 0..self.char_info_buffer.len() - 1 {
372            let start = self.char_info_buffer[char_idx].byte_offset as usize;
373
374            // No arc is ending here.
375            // No need to check if a valid word starts here.
376            if self.ends_at[start].is_empty() {
377                continue;
378            }
379
380            let suffix = &text[start..];
381
382            let mut found: bool = false;
383
384            // lookup user dictionary
385            if user_dict.is_some() {
386                let dict = user_dict.as_ref().unwrap();
387                for (prefix_len, word_entry) in dict.prefix(suffix) {
388                    let kanji_only = self.is_kanji_all(char_idx, prefix_len);
389                    let edge = Self::create_edge(
390                        EdgeType::KNOWN,
391                        word_entry,
392                        start,
393                        start + prefix_len,
394                        kanji_only,
395                    );
396                    self.add_edge_in_lattice(edge);
397                    found = true;
398                }
399            }
400
401            // we check all word starting at start, using the double array, like we would use
402            // a prefix trie, and populate the lattice with as many edges
403            for (prefix_len, word_entry) in dict.prefix(suffix) {
404                let kanji_only = self.is_kanji_all(char_idx, prefix_len);
405                let edge = Self::create_edge(
406                    EdgeType::KNOWN,
407                    word_entry,
408                    start,
409                    start + prefix_len,
410                    kanji_only,
411                );
412                self.add_edge_in_lattice(edge);
413                found = true;
414            }
415
416            // In the case of normal mode, it doesn't process unknown word greedily.
417            if (search_mode.is_search()
418                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
419                && char_idx < self.char_info_buffer.len() - 1
420            {
421                let num_categories = self.char_info_buffer[char_idx].categories_len as usize;
422                for category_ord in 0..num_categories {
423                    let category = self.get_cached_category(char_idx, category_ord);
424                    unknown_word_end = self.process_unknown_word(
425                        char_definitions,
426                        unknown_dictionary,
427                        category,
428                        category_ord,
429                        unknown_word_end,
430                        start,
431                        char_idx,
432                        found,
433                    );
434                }
435            }
436        }
437    }
438
439    #[allow(clippy::too_many_arguments)]
440    fn process_unknown_word(
441        &mut self,
442        char_definitions: &CharacterDefinition,
443        unknown_dictionary: &UnknownDictionary,
444        category: CategoryId,
445        category_ord: usize,
446        unknown_word_index: Option<usize>,
447        start: usize,
448        char_idx: usize,
449        found: bool,
450    ) -> Option<usize> {
451        let mut unknown_word_num_chars: usize = 0;
452        let category_data = char_definitions.lookup_definition(category);
453        if category_data.invoke || !found {
454            unknown_word_num_chars = 1;
455            if category_data.group {
456                for i in 1.. {
457                    let next_idx = char_idx + i;
458                    if next_idx >= self.char_info_buffer.len() - 1 {
459                        break;
460                    }
461                    let num_categories = self.char_info_buffer[next_idx].categories_len as usize;
462                    let mut found_cat = false;
463                    if category_ord < num_categories {
464                        let cat = self.get_cached_category(next_idx, category_ord);
465                        if cat == category {
466                            unknown_word_num_chars += 1;
467                            found_cat = true;
468                        }
469                    }
470                    if !found_cat {
471                        break;
472                    }
473                }
474            }
475        }
476        if unknown_word_num_chars > 0 {
477            let byte_end_offset =
478                self.char_info_buffer[char_idx + unknown_word_num_chars].byte_offset;
479            let byte_len = byte_end_offset as usize - start;
480
481            // Check Kanji status using pre-calculated buffer
482            let kanji_only = self.is_kanji_all(char_idx, byte_len);
483
484            for &word_id in unknown_dictionary.lookup_word_ids(category) {
485                let word_entry = unknown_dictionary.word_entry(word_id);
486                let edge = Self::create_edge(
487                    EdgeType::UNKNOWN,
488                    word_entry,
489                    start,
490                    start + byte_len,
491                    kanji_only,
492                );
493                self.add_edge_in_lattice(edge);
494            }
495            return Some(start + byte_len);
496        }
497        unknown_word_index
498    }
499
500    fn add_edge_in_lattice(&mut self, edge: Edge) {
501        let start_index = edge.start_index as usize;
502        let stop_index = edge.stop_index as usize;
503        let edge_id = self.add_edge(edge);
504        self.starts_at[start_index].push(edge_id);
505        self.ends_at[stop_index].push(edge_id);
506    }
507
508    fn add_edge(&mut self, edge: Edge) -> EdgeId {
509        let edge_id = EdgeId(self.edges.len() as u32);
510        self.edges.push(edge);
511        edge_id
512    }
513
514    pub fn edge(&self, edge_id: EdgeId) -> &Edge {
515        &self.edges[edge_id.0 as usize]
516    }
517
518    #[inline(never)]
519    pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
520        let text_len = self.starts_at.len();
521        for i in 0..text_len {
522            let left_edge_ids = &self.ends_at[i];
523            let right_edge_ids = &self.starts_at[i];
524
525            if right_edge_ids.is_empty() || left_edge_ids.is_empty() {
526                continue;
527            }
528
529            // Cache left edge data to avoid repeated access and penalty calculation
530            // We use mem::take to temporarily own the buffer and avoid borrow checker conflicts with self.edges
531            let mut left_cache = std::mem::take(&mut self.left_cache_buffer);
532            for &left_edge_id in left_edge_ids {
533                let left_edge = &self.edges[left_edge_id.0 as usize];
534                left_cache.push((
535                    left_edge.word_entry.right_id(),
536                    left_edge.path_cost,
537                    mode.penalty_cost(left_edge),
538                    left_edge_id,
539                ));
540            }
541
542            for &right_edge_id in right_edge_ids {
543                // Cache right edge data to avoid repeated access
544                let right_edge = &self.edges[right_edge_id.0 as usize];
545                let right_word_entry = right_edge.word_entry;
546                let right_left_id = right_word_entry.left_id();
547
548                // Manual loop for better performance (avoids iterator overhead)
549                let mut best_cost = i32::MAX;
550                let mut best_left = None;
551
552                for &(left_right_id, left_path_cost, left_penalty, left_edge_id) in
553                    left_cache.iter()
554                {
555                    // Calculate path cost directly
556                    let mut path_cost =
557                        left_path_cost + cost_matrix.cost(left_right_id, right_left_id);
558                    path_cost += left_penalty;
559
560                    // Track minimum cost with branch-free comparison when possible
561                    if path_cost < best_cost {
562                        best_cost = path_cost;
563                        best_left = Some(left_edge_id);
564                    }
565                }
566
567                // Update edge with best path if found
568                if let Some(best_left_id) = best_left {
569                    let edge = &mut self.edges[right_edge_id.0 as usize];
570                    edge.left_edge = Some(best_left_id);
571                    edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
572                }
573            }
574            left_cache.clear();
575            self.left_cache_buffer = left_cache;
576        }
577    }
578
579    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
580        let mut offsets = Vec::new();
581        let mut edge_id = EOS_NODE;
582        let _edge = self.edge(EOS_NODE);
583        loop {
584            let edge = self.edge(edge_id);
585            if let Some(left_edge_id) = edge.left_edge {
586                offsets.push((edge.start_index as usize, edge.word_entry.word_id));
587                edge_id = left_edge_id;
588            } else {
589                break;
590            }
591        }
592        offsets.reverse();
593        offsets.pop();
594        offsets
595    }
596}
597
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// A system entry must occupy exactly `SERIALIZED_LEN` bytes and survive a
    /// `serialize` / `deserialize` round trip unchanged.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            word_id: WordId {
                id: 1,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17,
            left_id: 1411,
            right_id: 1412,
        };

        let mut bytes: Vec<u8> = Vec::new();
        original.serialize(&mut bytes).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, bytes.len());

        let restored = WordEntry::deserialize(&bytes, true);
        assert_eq!(original, restored);
    }
}