lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
5use serde::{Deserialize, Serialize};
6
7use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
8use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
9use crate::dictionary::prefix_dictionary::PrefixDictionary;
10use crate::dictionary::unknown_dictionary::UnknownDictionary;
11use crate::mode::Mode;
12
13const EOS_NODE: EdgeId = EdgeId(1u32);
14
15/// Type of lexicon containing the word
16#[derive(
17    Clone,
18    Copy,
19    Debug,
20    Eq,
21    PartialEq,
22    Serialize,
23    Deserialize,
24    Default,
25    Archive,
26    RkyvSerialize,
27    RkyvDeserialize,
28)]
29
30pub enum LexType {
31    /// System dictionary (base dictionary)
32    #[default]
33    System,
34    /// User dictionary (additional vocabulary)
35    User,
36    /// Unknown words (OOV handling)
37    Unknown,
38}
39
40#[derive(
41    Clone,
42    Copy,
43    Debug,
44    Eq,
45    PartialEq,
46    Serialize,
47    Deserialize,
48    Archive,
49    RkyvDeserialize,
50    RkyvSerialize,
51)]
52
53pub struct WordId {
54    pub id: u32,
55    pub is_system: bool,
56    pub lex_type: LexType,
57}
58
59impl WordId {
60    /// Creates a new WordId with specified lexicon type
61    pub fn new(lex_type: LexType, id: u32) -> Self {
62        WordId {
63            id,
64            is_system: matches!(lex_type, LexType::System),
65            lex_type,
66        }
67    }
68
69    pub fn is_unknown(&self) -> bool {
70        self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
71    }
72
73    pub fn is_system(&self) -> bool {
74        self.is_system
75    }
76
77    pub fn lex_type(&self) -> LexType {
78        self.lex_type
79    }
80}
81
82impl Default for WordId {
83    fn default() -> Self {
84        WordId {
85            id: u32::MAX,
86            is_system: true,
87            lex_type: LexType::System,
88        }
89    }
90}
91
92#[derive(
93    Default,
94    Clone,
95    Copy,
96    Debug,
97    Eq,
98    PartialEq,
99    Serialize,
100    Deserialize,
101    Archive,
102    RkyvSerialize,
103    RkyvDeserialize,
104)]
105
106pub struct WordEntry {
107    pub word_id: WordId,
108    pub word_cost: i16,
109    pub left_id: u16,
110    pub right_id: u16,
111}
112
113impl WordEntry {
114    pub const SERIALIZED_LEN: usize = 10;
115
116    pub fn left_id(&self) -> u32 {
117        self.left_id as u32
118    }
119
120    pub fn right_id(&self) -> u32 {
121        self.right_id as u32
122    }
123
124    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
125        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
126        wtr.write_i16::<LittleEndian>(self.word_cost)?;
127        wtr.write_u16::<LittleEndian>(self.left_id)?;
128        wtr.write_u16::<LittleEndian>(self.right_id)?;
129        Ok(())
130    }
131
132    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
133        let word_id = WordId::new(
134            if is_system_entry {
135                LexType::System
136            } else {
137                LexType::User
138            },
139            LittleEndian::read_u32(&data[0..4]),
140        );
141        let word_cost = LittleEndian::read_i16(&data[4..6]);
142        let left_id = LittleEndian::read_u16(&data[6..8]);
143        let right_id = LittleEndian::read_u16(&data[8..10]);
144        WordEntry {
145            word_id,
146            word_cost,
147            left_id,
148            right_id,
149        }
150    }
151}
152
/// Origin of a lattice edge.
///
/// NOTE(review): the lattice code in this file builds dictionary edges (system and user)
/// as `KNOWN` and unknown-word edges as `UNKNOWN`. `USER` and `INSERTED` are not produced
/// here; presumably they are set elsewhere — confirm.
#[derive(Debug, Default, Copy, Clone)]
pub enum EdgeType {
    /// Dictionary word; the default.
    #[default]
    KNOWN,
    /// Word produced by unknown-word handling.
    UNKNOWN,
    /// User-dictionary word.
    USER,
    /// Inserted edge.
    INSERTED,
}
161
/// Index of an edge in the lattice's edge vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EdgeId(pub u32);
164
165#[derive(Default, Clone, Debug)]
166pub struct Edge {
167    pub edge_type: EdgeType,
168    pub word_entry: WordEntry,
169
170    pub path_cost: i32,
171    pub left_edge: Option<EdgeId>,
172
173    pub start_index: u32,
174    pub stop_index: u32,
175
176    pub kanji_only: bool,
177}
178
179impl Edge {
180    // TODO fix em
181    pub fn num_chars(&self) -> usize {
182        (self.stop_index - self.start_index) as usize / 3
183    }
184}
185
186#[derive(Clone, Default)]
187pub struct Lattice {
188    capacity: usize,
189    edges: Vec<Edge>,
190    starts_at: Vec<Vec<EdgeId>>,
191    ends_at: Vec<Vec<EdgeId>>,
192    // Buffer reuse optimization: pre-allocated vectors for reuse
193    edge_buffer: Vec<Edge>,
194    edge_id_buffer: Vec<EdgeId>,
195}
196
/// Returns `true` if `c` is in U+4E00..=U+9FAF, the range this module treats as kanji.
///
/// The range starts at the beginning of the CJK Unified Ideographs block but stops at
/// U+9FAF, short of the block's end at U+9FFF, so code points above U+9FAF are not
/// counted. Every code point in the range is 3 bytes long in UTF-8.
#[inline]
fn is_kanji(c: char) -> bool {
    // 0x4E00 == 19968 and 0x9FAF == 40879.
    matches!(u32::from(c), 0x4E00..=0x9FAF)
}

/// Returns `true` if `s` is non-empty and every character satisfies [`is_kanji`].
#[inline]
fn is_kanji_only(s: &str) -> bool {
    // `all` is vacuously true for an empty iterator, so empty strings are rejected explicitly.
    !s.is_empty() && s.chars().all(is_kanji)
}
209
210impl Lattice {
211    /// Helper method to create an edge efficiently
212    #[inline]
213    fn create_edge(
214        edge_type: EdgeType,
215        word_entry: WordEntry,
216        start: usize,
217        stop: usize,
218        kanji_only: bool,
219    ) -> Edge {
220        Edge {
221            edge_type,
222            word_entry,
223            left_edge: None,
224            start_index: start as u32,
225            stop_index: stop as u32,
226            path_cost: i32::MAX,
227            kanji_only,
228        }
229    }
230
231    pub fn clear(&mut self) {
232        for edge_vec in &mut self.starts_at {
233            edge_vec.clear();
234        }
235        for edge_vec in &mut self.ends_at {
236            edge_vec.clear();
237        }
238        self.edges.clear();
239        // Clear buffers but preserve capacity for reuse
240        self.edge_buffer.clear();
241        self.edge_id_buffer.clear();
242    }
243
244    /// Get a reusable edge buffer with preserved capacity
245    pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
246        self.edge_buffer.clear();
247        &mut self.edge_buffer
248    }
249
250    /// Get a reusable edge ID buffer with preserved capacity
251    pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
252        self.edge_id_buffer.clear();
253        &mut self.edge_id_buffer
254    }
255
256    fn set_capacity(&mut self, text_len: usize) {
257        self.clear();
258        if self.capacity < text_len {
259            self.capacity = text_len;
260            self.edges.clear();
261            self.starts_at.resize(text_len + 1, Vec::new());
262            self.ends_at.resize(text_len + 1, Vec::new());
263        }
264    }
265
266    #[inline(never)]
267    pub fn set_text(
268        &mut self,
269        dict: &PrefixDictionary,
270        user_dict: &Option<&PrefixDictionary>,
271        char_definitions: &CharacterDefinition,
272        unknown_dictionary: &UnknownDictionary,
273        text: &str,
274        search_mode: &Mode,
275    ) {
276        let len = text.len();
277        self.set_capacity(len);
278
279        let start_edge_id = self.add_edge(Edge::default());
280        let end_edge_id = self.add_edge(Edge::default());
281
282        assert_eq!(EOS_NODE, end_edge_id);
283        self.ends_at[0].push(start_edge_id);
284        self.starts_at[len].push(end_edge_id);
285
286        // index of the last character of unknown word
287        let mut unknown_word_end: Option<usize> = None;
288
289        for start in 0..len {
290            // No arc is ending here.
291            // No need to check if a valid word starts here.
292            if self.ends_at[start].is_empty() {
293                continue;
294            }
295
296            let suffix = &text[start..];
297
298            let mut found: bool = false;
299
300            // lookup user dictionary
301            if user_dict.is_some() {
302                let dict = user_dict.as_ref().unwrap();
303                for (prefix_len, word_entry) in dict.prefix(suffix) {
304                    let edge = Self::create_edge(
305                        EdgeType::KNOWN,
306                        word_entry,
307                        start,
308                        start + prefix_len,
309                        is_kanji_only(&suffix[..prefix_len]),
310                    );
311                    self.add_edge_in_lattice(edge);
312                    found = true;
313                }
314            }
315
316            // we check all word starting at start, using the double array, like we would use
317            // a prefix trie, and populate the lattice with as many edges
318            for (prefix_len, word_entry) in dict.prefix(suffix) {
319                let edge = Self::create_edge(
320                    EdgeType::KNOWN,
321                    word_entry,
322                    start,
323                    start + prefix_len,
324                    is_kanji_only(&suffix[..prefix_len]),
325                );
326                self.add_edge_in_lattice(edge);
327                found = true;
328            }
329
330            // In the case of normal mode, it doesn't process unknown word greedily.
331            if (search_mode.is_search()
332                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
333                && let Some(first_char) = suffix.chars().next()
334            {
335                let categories = char_definitions.lookup_categories(first_char);
336                for (category_ord, &category) in categories.iter().enumerate() {
337                    unknown_word_end = self.process_unknown_word(
338                        char_definitions,
339                        unknown_dictionary,
340                        category,
341                        category_ord,
342                        unknown_word_end,
343                        start,
344                        suffix,
345                        found,
346                    );
347                }
348            }
349        }
350    }
351
352    #[allow(clippy::too_many_arguments)]
353    fn process_unknown_word(
354        &mut self,
355        char_definitions: &CharacterDefinition,
356        unknown_dictionary: &UnknownDictionary,
357        category: CategoryId,
358        category_ord: usize,
359        unknown_word_index: Option<usize>,
360        start: usize,
361        suffix: &str,
362        found: bool,
363    ) -> Option<usize> {
364        let mut unknown_word_num_chars: usize = 0;
365        let category_data = char_definitions.lookup_definition(category);
366        if category_data.invoke || !found {
367            unknown_word_num_chars = 1;
368            if category_data.group {
369                for c in suffix.chars().skip(1) {
370                    let categories = char_definitions.lookup_categories(c);
371                    if categories.len() > category_ord && categories[category_ord] == category {
372                        unknown_word_num_chars += 1;
373                    } else {
374                        break;
375                    }
376                }
377            }
378        }
379        if unknown_word_num_chars > 0 {
380            // Optimized: Direct byte boundary calculation
381            let byte_end = suffix
382                .char_indices()
383                .nth(unknown_word_num_chars)
384                .map_or(suffix.len(), |(pos, _)| pos);
385            let unknown_word = &suffix[..byte_end];
386            for &word_id in unknown_dictionary.lookup_word_ids(category) {
387                let word_entry = unknown_dictionary.word_entry(word_id);
388                let edge = Self::create_edge(
389                    EdgeType::UNKNOWN,
390                    word_entry,
391                    start,
392                    start + unknown_word.len(),
393                    is_kanji_only(unknown_word),
394                );
395                self.add_edge_in_lattice(edge);
396            }
397            return Some(start + unknown_word.len());
398        }
399        unknown_word_index
400    }
401
402    fn add_edge_in_lattice(&mut self, edge: Edge) {
403        let start_index = edge.start_index as usize;
404        let stop_index = edge.stop_index as usize;
405        let edge_id = self.add_edge(edge);
406        self.starts_at[start_index].push(edge_id);
407        self.ends_at[stop_index].push(edge_id);
408    }
409
410    fn add_edge(&mut self, edge: Edge) -> EdgeId {
411        let edge_id = EdgeId(self.edges.len() as u32);
412        self.edges.push(edge);
413        edge_id
414    }
415
416    pub fn edge(&self, edge_id: EdgeId) -> &Edge {
417        &self.edges[edge_id.0 as usize]
418    }
419
420    #[inline(never)]
421    pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
422        let text_len = self.starts_at.len();
423        for i in 0..text_len {
424            let left_edge_ids = &self.ends_at[i];
425            let right_edge_ids = &self.starts_at[i];
426
427            for &right_edge_id in right_edge_ids {
428                // Cache right edge data to avoid repeated access
429                let right_edge = &self.edges[right_edge_id.0 as usize];
430                let right_word_entry = right_edge.word_entry;
431                let right_left_id = right_word_entry.left_id();
432
433                // Manual loop for better performance (avoids iterator overhead)
434                let mut best_cost = i32::MAX;
435                let mut best_left = None;
436
437                for &left_edge_id in left_edge_ids {
438                    let left_edge = &self.edges[left_edge_id.0 as usize];
439                    let left_right_id = left_edge.word_entry.right_id();
440
441                    // Calculate path cost directly
442                    let mut path_cost =
443                        left_edge.path_cost + cost_matrix.cost(left_right_id, right_left_id);
444                    path_cost += mode.penalty_cost(left_edge);
445
446                    // Track minimum cost with branch-free comparison when possible
447                    if path_cost < best_cost {
448                        best_cost = path_cost;
449                        best_left = Some(left_edge_id);
450                    }
451                }
452
453                // Update edge with best path if found
454                if let Some(best_left_id) = best_left {
455                    let edge = &mut self.edges[right_edge_id.0 as usize];
456                    edge.left_edge = Some(best_left_id);
457                    edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
458                }
459            }
460        }
461    }
462
463    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
464        let mut offsets = Vec::new();
465        let mut edge_id = EOS_NODE;
466        let _edge = self.edge(EOS_NODE);
467        loop {
468            let edge = self.edge(edge_id);
469            if let Some(left_edge_id) = edge.left_edge {
470                offsets.push((edge.start_index as usize, edge.word_entry.word_id));
471                edge_id = left_edge_id;
472            } else {
473                break;
474            }
475        }
476        offsets.reverse();
477        offsets.pop();
478        offsets
479    }
480}
481
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// Round-trips a system entry and pins its exact little-endian byte layout.
    #[test]
    fn test_word_entry() {
        let mut buffer = Vec::new();
        let word_entry = WordEntry {
            word_id: WordId {
                id: 1u32,
                is_system: true,
                lex_type: LexType::System,
            },
            word_cost: -17i16,
            left_id: 1411u16,
            right_id: 1412u16,
        };
        word_entry.serialize(&mut buffer).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, buffer.len());
        // id = 1 (u32), cost = -17 (i16 0xFFEF), left = 1411 (0x0583), right = 1412 (0x0584).
        assert_eq!(
            buffer,
            [0x01, 0x00, 0x00, 0x00, 0xEF, 0xFF, 0x83, 0x05, 0x84, 0x05]
        );
        let word_entry2 = WordEntry::deserialize(&buffer[..], true);
        assert_eq!(word_entry, word_entry2);
    }

    /// Deserializing with `is_system_entry == false` must yield a user-lexicon id.
    #[test]
    fn test_word_entry_user_round_trip() {
        let word_entry = WordEntry {
            word_id: WordId::new(LexType::User, 42),
            word_cost: 300,
            left_id: 7,
            right_id: 8,
        };
        let mut buffer = Vec::new();
        word_entry.serialize(&mut buffer).unwrap();
        let decoded = WordEntry::deserialize(&buffer, false);
        assert_eq!(decoded, word_entry);
        assert!(!decoded.word_id.is_system());
        assert_eq!(decoded.word_id.lex_type(), LexType::User);
    }
}