Skip to main content

kiri_engine/
tokenizer.rs

1//! buildLattice main loop: for each byte position, checks canBow, hasPreviousNode,
2//! performs lexicon lookup, MeCab OOV, simple OOV, and inserts into lattice.
3//!
4//! This is where the speedup lives -- the hot path that Rust accelerates.
5
6use crate::dictionary::lexicon_set::LexiconSet;
7use crate::lattice::node::LatticeNode;
8use crate::lattice::Lattice;
9use crate::oov::mecab::{provide_mecab_oov, MecabOovConfig};
10use crate::oov::simple::{provide_oov, OovProviderConfig};
11use crate::types::ConnectionCosts;
12
/// PathNode returned to JS — minimal data for TS post-processing.
///
/// One node of the best path returned by `build_lattice_and_solve`; every field
/// is copied unchanged from the corresponding `LatticeNode`.
#[derive(Debug, Clone)]
pub struct PathNode {
    /// Byte offset in the input text where this node begins.
    pub begin: usize,
    /// Byte offset where this node ends.
    /// NOTE(review): presumably exclusive (`begin` + byte length given to
    /// `Lattice::insert`) — confirm in the lattice module.
    pub end: usize,
    /// Word id of the node; for lexicon nodes this is the lexicon word id.
    /// NOTE(review): the value carried by OOV nodes is set by the OOV providers — confirm.
    pub word_id: i32,
    /// Left connection id (looked up via `LexiconSet::get_left_id` for lexicon words).
    pub left_id: i16,
    /// Right connection id (looked up via `LexiconSet::get_right_id` for lexicon words).
    pub right_id: i16,
    /// Cost of this node on its own (from `LexiconSet::get_cost` for lexicon words).
    pub cost: i16,
    /// Path cost accumulated by the lattice up to this node.
    /// NOTE(review): computed inside `Lattice` — confirm whether it includes connection costs.
    pub total_cost: i32,
    /// Presumably true when the node came from an OOV provider rather than the lexicon.
    pub is_oov: bool,
    /// Part-of-speech id attached to OOV nodes, if any.
    pub oov_pos_id: Option<i16>,
}
26
27impl From<&LatticeNode> for PathNode {
28    fn from(n: &LatticeNode) -> Self {
29        Self {
30            begin: n.begin,
31            end: n.end,
32            word_id: n.word_id,
33            left_id: n.left_id,
34            right_id: n.right_id,
35            cost: n.cost,
36            total_cost: n.total_cost,
37            is_oov: n.is_oov,
38            oov_pos_id: n.oov_pos_id,
39        }
40    }
41}
42
/// Pre-packed input text data for lattice construction.
///
/// Every slice except `bytes` and `code_point_byte_lengths_flat` is read per
/// position, indexed by byte offset `i` into `bytes` (see `build_lattice_and_solve`).
pub struct LatticeInput<'a> {
    /// Raw bytes of the text being tokenized (presumably UTF-8 — confirm with caller).
    pub bytes: &'a [u8],
    /// Per byte offset: nonzero if a word may begin there. Lattice nodes only
    /// start at offsets where this is nonzero.
    pub can_bow: &'a [u8],
    /// Per byte offset: character category handed to the MeCab OOV provider
    /// (1 is used when an entry is missing).
    /// NOTE(review): encoding (single id vs. bit set) is defined by `provide_mecab_oov` — confirm.
    pub char_categories: &'a [u32],
    /// Per byte offset: byte length of the node inserted by the simple OOV fallback.
    pub word_candidate_lengths: &'a [u32],
    /// Per byte offset: continuous-run byte length handed to the MeCab OOV provider;
    /// it also bounds the window of code point lengths passed alongside it.
    pub continuous_lengths: &'a [u32],
    /// Byte length of each code point, flattened; per-position windows are
    /// selected through `code_point_offsets`.
    pub code_point_byte_lengths_flat: &'a [u8],
    /// Per byte offset: index into `code_point_byte_lengths_flat` of the code point
    /// that starts at that offset.
    pub code_point_offsets: &'a [u32],
}
53
/// Dictionary context for lattice construction.
pub struct DictionaryCtx<'a> {
    /// Lexicon used for prefix lookup and for per-word left/right ids and costs.
    pub lexicon: &'a LexiconSet,
    /// Connection cost table, passed to every `Lattice::insert` and to `connect_eos`.
    pub connection: ConnectionCosts<'a>,
}
59
/// OOV provider context.
pub struct OovCtx<'a> {
    /// Configuration for the simple OOV fallback, which is consulted only at
    /// positions where no other candidate was inserted.
    pub simple: &'a OovProviderConfig,
    /// Optional MeCab-style OOV configuration; when `None`, that provider is skipped.
    pub mecab: Option<&'a MecabOovConfig>,
}
65
66/// Build the lattice and return the best path.
67pub fn build_lattice_and_solve(
68    dict: &DictionaryCtx<'_>,
69    input: &LatticeInput<'_>,
70    oov: &OovCtx<'_>,
71) -> Result<Vec<PathNode>, String> {
72    let byte_length = input.bytes.len();
73    let mut lattice = Lattice::new();
74    lattice.resize(byte_length);
75
76    for i in 0..byte_length {
77        // Only start words at character boundaries
78        if input.can_bow[i] == 0 {
79            continue;
80        }
81
82        // Must have previous nodes at this position
83        if !lattice.has_previous_node(i) {
84            continue;
85        }
86
87        let mut has_words = false;
88
89        // Lexicon lookup
90        let matches = dict.lexicon.lookup(input.bytes, i, byte_length - i);
91        for m in &matches {
92            for &word_id in &m.word_ids {
93                let node = LatticeNode {
94                    word_id,
95                    left_id: dict.lexicon.get_left_id(word_id),
96                    right_id: dict.lexicon.get_right_id(word_id),
97                    cost: dict.lexicon.get_cost(word_id),
98                    ..Default::default()
99                };
100                lattice.insert(i, i + m.length, node, &dict.connection);
101                has_words = true;
102            }
103        }
104
105        // MeCab OOV
106        if let Some(mecab_cfg) = oov.mecab {
107            let cat = input.char_categories.get(i).copied().unwrap_or(1);
108            let cont_len = input.continuous_lengths.get(i).copied().unwrap_or(0) as usize;
109
110            // Extract code point byte lengths for this position
111            let cp_start = input.code_point_offsets.get(i).copied().unwrap_or(0) as usize;
112            let cp_end = if i + cont_len < input.code_point_offsets.len() {
113                input.code_point_offsets[i + cont_len] as usize
114            } else {
115                input.code_point_byte_lengths_flat.len()
116            };
117            // Convert flat byte array to usize slice
118            let cp_bytes: Vec<usize> = input
119                .code_point_byte_lengths_flat
120                .get(cp_start..cp_end)
121                .unwrap_or(&[])
122                .iter()
123                .map(|&b| b as usize)
124                .collect();
125
126            let mecab_results = provide_mecab_oov(cat, cont_len, &cp_bytes, has_words, mecab_cfg);
127            for result in mecab_results {
128                lattice.insert(i, i + result.byte_length, result.node, &dict.connection);
129                has_words = true;
130            }
131        }
132
133        // Simple OOV fallback
134        if !has_words {
135            let wc_len = input.word_candidate_lengths.get(i).copied().unwrap_or(0) as usize;
136            if let Some(node) = provide_oov(wc_len, false, oov.simple) {
137                lattice.insert(i, i + wc_len, node, &dict.connection);
138            }
139        }
140    }
141
142    lattice.connect_eos(&dict.connection);
143    let best_path = lattice.get_best_path()?;
144    Ok(best_path.iter().map(PathNode::from).collect())
145}