lindera_dictionary/
viterbi.rs

1use std::io;
2
3use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
4use serde::{Deserialize, Serialize};
5
6use crate::dictionary::character_definition::{CategoryId, CharacterDefinition};
7use crate::dictionary::connection_cost_matrix::ConnectionCostMatrix;
8use crate::dictionary::prefix_dictionary::PrefixDictionary;
9use crate::dictionary::unknown_dictionary::UnknownDictionary;
10use crate::mode::Mode;
11
12const EOS_NODE: EdgeId = EdgeId(1u32);
13
14/// Type of lexicon containing the word
15#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize, Default)]
16pub enum LexType {
17    /// System dictionary (base dictionary)
18    #[default]
19    System,
20    /// User dictionary (additional vocabulary)
21    User,
22    /// Unknown words (OOV handling)
23    Unknown,
24}
25
26#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
27pub struct WordId {
28    pub id: u32,
29    pub is_system: bool,
30    pub lex_type: LexType,
31}
32
33impl WordId {
34    /// Creates a new WordId with specified lexicon type
35    pub fn new(lex_type: LexType, id: u32) -> Self {
36        WordId {
37            id,
38            is_system: matches!(lex_type, LexType::System),
39            lex_type,
40        }
41    }
42
43    pub fn is_unknown(&self) -> bool {
44        self.id == u32::MAX || matches!(self.lex_type, LexType::Unknown)
45    }
46
47    pub fn is_system(&self) -> bool {
48        self.is_system
49    }
50
51    pub fn lex_type(&self) -> LexType {
52        self.lex_type
53    }
54}
55
56impl Default for WordId {
57    fn default() -> Self {
58        WordId {
59            id: u32::MAX,
60            is_system: true,
61            lex_type: LexType::System,
62        }
63    }
64}
65
66#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
67pub struct WordEntry {
68    pub word_id: WordId,
69    pub word_cost: i16,
70    pub left_id: u16,
71    pub right_id: u16,
72}
73
74impl WordEntry {
75    pub const SERIALIZED_LEN: usize = 10;
76
77    pub fn left_id(&self) -> u32 {
78        self.left_id as u32
79    }
80
81    pub fn right_id(&self) -> u32 {
82        self.right_id as u32
83    }
84
85    pub fn serialize<W: io::Write>(&self, wtr: &mut W) -> io::Result<()> {
86        wtr.write_u32::<LittleEndian>(self.word_id.id)?;
87        wtr.write_i16::<LittleEndian>(self.word_cost)?;
88        wtr.write_u16::<LittleEndian>(self.left_id)?;
89        wtr.write_u16::<LittleEndian>(self.right_id)?;
90        Ok(())
91    }
92
93    pub fn deserialize(data: &[u8], is_system_entry: bool) -> WordEntry {
94        let word_id = WordId::new(
95            if is_system_entry {
96                LexType::System
97            } else {
98                LexType::User
99            },
100            LittleEndian::read_u32(&data[0..4]),
101        );
102        let word_cost = LittleEndian::read_i16(&data[4..6]);
103        let left_id = LittleEndian::read_u16(&data[6..8]);
104        let right_id = LittleEndian::read_u16(&data[8..10]);
105        WordEntry {
106            word_id,
107            word_cost,
108            left_id,
109            right_id,
110        }
111    }
112}
113
/// Origin of a lattice edge. The default is [`EdgeType::KNOWN`].
#[derive(Debug, Default, Clone, Copy)]
pub enum EdgeType {
    /// Edge created from a dictionary lookup.
    #[default]
    KNOWN,
    /// Edge created by unknown-word processing.
    UNKNOWN,
    /// Not assigned anywhere in this file; presumably reserved for
    /// user-dictionary edges (TODO confirm).
    USER,
    /// Not assigned anywhere in this file; meaning to be confirmed.
    INSERTED,
}
122
/// Index of an edge inside a lattice's edge list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EdgeId(pub u32);
125
126#[derive(Default, Clone, Debug)]
127pub struct Edge {
128    pub edge_type: EdgeType,
129    pub word_entry: WordEntry,
130
131    pub path_cost: i32,
132    pub left_edge: Option<EdgeId>,
133
134    pub start_index: u32,
135    pub stop_index: u32,
136
137    pub kanji_only: bool,
138}
139
140impl Edge {
141    // TODO fix em
142    pub fn num_chars(&self) -> usize {
143        (self.stop_index - self.start_index) as usize / 3
144    }
145}
146
147#[derive(Clone, Default)]
148pub struct Lattice {
149    capacity: usize,
150    edges: Vec<Edge>,
151    starts_at: Vec<Vec<EdgeId>>,
152    ends_at: Vec<Vec<EdgeId>>,
153    // Buffer reuse optimization: pre-allocated vectors for reuse
154    edge_buffer: Vec<Edge>,
155    edge_id_buffer: Vec<EdgeId>,
156}
157
/// Returns `true` when `c` lies in the code point range U+4E00..=U+9FAF
/// (19968..=40879).
#[inline]
fn is_kanji(c: char) -> bool {
    // A `char` range pattern expresses the same inclusive bounds check.
    matches!(c, '\u{4E00}'..='\u{9FAF}')
}
164
165#[inline]
166fn is_kanji_only(s: &str) -> bool {
167    // Early exit for empty strings
168    !s.is_empty() && s.chars().all(is_kanji)
169}
170
171impl Lattice {
172    /// Helper method to create an edge efficiently
173    #[inline]
174    fn create_edge(
175        edge_type: EdgeType,
176        word_entry: WordEntry,
177        start: usize,
178        stop: usize,
179        kanji_only: bool,
180    ) -> Edge {
181        Edge {
182            edge_type,
183            word_entry,
184            left_edge: None,
185            start_index: start as u32,
186            stop_index: stop as u32,
187            path_cost: i32::MAX,
188            kanji_only,
189        }
190    }
191
192    pub fn clear(&mut self) {
193        for edge_vec in &mut self.starts_at {
194            edge_vec.clear();
195        }
196        for edge_vec in &mut self.ends_at {
197            edge_vec.clear();
198        }
199        self.edges.clear();
200        // Clear buffers but preserve capacity for reuse
201        self.edge_buffer.clear();
202        self.edge_id_buffer.clear();
203    }
204
205    /// Get a reusable edge buffer with preserved capacity
206    pub fn get_edge_buffer(&mut self) -> &mut Vec<Edge> {
207        self.edge_buffer.clear();
208        &mut self.edge_buffer
209    }
210
211    /// Get a reusable edge ID buffer with preserved capacity
212    pub fn get_edge_id_buffer(&mut self) -> &mut Vec<EdgeId> {
213        self.edge_id_buffer.clear();
214        &mut self.edge_id_buffer
215    }
216
217    fn set_capacity(&mut self, text_len: usize) {
218        self.clear();
219        if self.capacity < text_len {
220            self.capacity = text_len;
221            self.edges.clear();
222            self.starts_at.resize(text_len + 1, Vec::new());
223            self.ends_at.resize(text_len + 1, Vec::new());
224        }
225    }
226
227    #[inline(never)]
228    pub fn set_text(
229        &mut self,
230        dict: &PrefixDictionary,
231        user_dict: &Option<&PrefixDictionary>,
232        char_definitions: &CharacterDefinition,
233        unknown_dictionary: &UnknownDictionary,
234        text: &str,
235        search_mode: &Mode,
236    ) {
237        let len = text.len();
238        self.set_capacity(len);
239
240        let start_edge_id = self.add_edge(Edge::default());
241        let end_edge_id = self.add_edge(Edge::default());
242
243        assert_eq!(EOS_NODE, end_edge_id);
244        self.ends_at[0].push(start_edge_id);
245        self.starts_at[len].push(end_edge_id);
246
247        // index of the last character of unknown word
248        let mut unknown_word_end: Option<usize> = None;
249
250        for start in 0..len {
251            // No arc is ending here.
252            // No need to check if a valid word starts here.
253            if self.ends_at[start].is_empty() {
254                continue;
255            }
256
257            let suffix = &text[start..];
258
259            let mut found: bool = false;
260
261            // lookup user dictionary
262            if user_dict.is_some() {
263                let dict = user_dict.as_ref().unwrap();
264                for (prefix_len, word_entry) in dict.prefix(suffix) {
265                    let edge = Self::create_edge(
266                        EdgeType::KNOWN,
267                        word_entry,
268                        start,
269                        start + prefix_len,
270                        is_kanji_only(&suffix[..prefix_len]),
271                    );
272                    self.add_edge_in_lattice(edge);
273                    found = true;
274                }
275            }
276
277            // we check all word starting at start, using the double array, like we would use
278            // a prefix trie, and populate the lattice with as many edges
279            for (prefix_len, word_entry) in dict.prefix(suffix) {
280                let edge = Self::create_edge(
281                    EdgeType::KNOWN,
282                    word_entry,
283                    start,
284                    start + prefix_len,
285                    is_kanji_only(&suffix[..prefix_len]),
286                );
287                self.add_edge_in_lattice(edge);
288                found = true;
289            }
290
291            // In the case of normal mode, it doesn't process unknown word greedily.
292            if (search_mode.is_search()
293                || unknown_word_end.map(|index| index <= start).unwrap_or(true))
294                && let Some(first_char) = suffix.chars().next()
295            {
296                let categories = char_definitions.lookup_categories(first_char);
297                for (category_ord, &category) in categories.iter().enumerate() {
298                    unknown_word_end = self.process_unknown_word(
299                        char_definitions,
300                        unknown_dictionary,
301                        category,
302                        category_ord,
303                        unknown_word_end,
304                        start,
305                        suffix,
306                        found,
307                    );
308                }
309            }
310        }
311    }
312
313    #[allow(clippy::too_many_arguments)]
314    fn process_unknown_word(
315        &mut self,
316        char_definitions: &CharacterDefinition,
317        unknown_dictionary: &UnknownDictionary,
318        category: CategoryId,
319        category_ord: usize,
320        unknown_word_index: Option<usize>,
321        start: usize,
322        suffix: &str,
323        found: bool,
324    ) -> Option<usize> {
325        let mut unknown_word_num_chars: usize = 0;
326        let category_data = char_definitions.lookup_definition(category);
327        if category_data.invoke || !found {
328            unknown_word_num_chars = 1;
329            if category_data.group {
330                for c in suffix.chars().skip(1) {
331                    let categories = char_definitions.lookup_categories(c);
332                    if categories.len() > category_ord && categories[category_ord] == category {
333                        unknown_word_num_chars += 1;
334                    } else {
335                        break;
336                    }
337                }
338            }
339        }
340        if unknown_word_num_chars > 0 {
341            // Optimized: Direct byte boundary calculation
342            let byte_end = suffix
343                .char_indices()
344                .nth(unknown_word_num_chars)
345                .map_or(suffix.len(), |(pos, _)| pos);
346            let unknown_word = &suffix[..byte_end];
347            for &word_id in unknown_dictionary.lookup_word_ids(category) {
348                let word_entry = unknown_dictionary.word_entry(word_id);
349                let edge = Self::create_edge(
350                    EdgeType::UNKNOWN,
351                    word_entry,
352                    start,
353                    start + unknown_word.len(),
354                    is_kanji_only(unknown_word),
355                );
356                self.add_edge_in_lattice(edge);
357            }
358            return Some(start + unknown_word.len());
359        }
360        unknown_word_index
361    }
362
363    fn add_edge_in_lattice(&mut self, edge: Edge) {
364        let start_index = edge.start_index as usize;
365        let stop_index = edge.stop_index as usize;
366        let edge_id = self.add_edge(edge);
367        self.starts_at[start_index].push(edge_id);
368        self.ends_at[stop_index].push(edge_id);
369    }
370
371    fn add_edge(&mut self, edge: Edge) -> EdgeId {
372        let edge_id = EdgeId(self.edges.len() as u32);
373        self.edges.push(edge);
374        edge_id
375    }
376
377    pub fn edge(&self, edge_id: EdgeId) -> &Edge {
378        &self.edges[edge_id.0 as usize]
379    }
380
381    #[inline(never)]
382    pub fn calculate_path_costs(&mut self, cost_matrix: &ConnectionCostMatrix, mode: &Mode) {
383        let text_len = self.starts_at.len();
384        for i in 0..text_len {
385            let left_edge_ids = &self.ends_at[i];
386            let right_edge_ids = &self.starts_at[i];
387
388            for &right_edge_id in right_edge_ids {
389                // Cache right edge data to avoid repeated access
390                let right_edge = &self.edges[right_edge_id.0 as usize];
391                let right_word_entry = right_edge.word_entry;
392                let right_left_id = right_word_entry.left_id();
393
394                // Manual loop for better performance (avoids iterator overhead)
395                let mut best_cost = i32::MAX;
396                let mut best_left = None;
397
398                for &left_edge_id in left_edge_ids {
399                    let left_edge = &self.edges[left_edge_id.0 as usize];
400                    let left_right_id = left_edge.word_entry.right_id();
401
402                    // Calculate path cost directly
403                    let mut path_cost =
404                        left_edge.path_cost + cost_matrix.cost(left_right_id, right_left_id);
405                    path_cost += mode.penalty_cost(left_edge);
406
407                    // Track minimum cost with branch-free comparison when possible
408                    if path_cost < best_cost {
409                        best_cost = path_cost;
410                        best_left = Some(left_edge_id);
411                    }
412                }
413
414                // Update edge with best path if found
415                if let Some(best_left_id) = best_left {
416                    let edge = &mut self.edges[right_edge_id.0 as usize];
417                    edge.left_edge = Some(best_left_id);
418                    edge.path_cost = right_word_entry.word_cost as i32 + best_cost;
419                }
420            }
421        }
422    }
423
424    pub fn tokens_offset(&self) -> Vec<(usize, WordId)> {
425        let mut offsets = Vec::new();
426        let mut edge_id = EOS_NODE;
427        let _edge = self.edge(EOS_NODE);
428        loop {
429            let edge = self.edge(edge_id);
430            if let Some(left_edge_id) = edge.left_edge {
431                offsets.push((edge.start_index as usize, edge.word_entry.word_id));
432                edge_id = left_edge_id;
433            } else {
434                break;
435            }
436        }
437        offsets.reverse();
438        offsets.pop();
439        offsets
440    }
441}
442
#[cfg(test)]
mod tests {
    use crate::viterbi::{LexType, WordEntry, WordId};

    /// A system entry must survive a serialize/deserialize round trip and
    /// occupy exactly `SERIALIZED_LEN` bytes.
    #[test]
    fn test_word_entry() {
        let original = WordEntry {
            word_id: WordId::new(LexType::System, 1),
            word_cost: -17,
            left_id: 1411,
            right_id: 1412,
        };

        let mut bytes = Vec::new();
        original.serialize(&mut bytes).unwrap();
        assert_eq!(WordEntry::SERIALIZED_LEN, bytes.len());

        let decoded = WordEntry::deserialize(&bytes, true);
        assert_eq!(original, decoded);
    }
}
465}