//! N-API bindings for the kiri tokenizer engine (`kiri_native/lib.rs`).
1use std::fs;
2use std::sync::Arc;
3
4use memmap2::Mmap;
5use napi::bindgen_prelude::*;
6use napi_derive::napi;
7use serde::Serialize;
8
9use kiri_engine::dictionary::grammar::read_grammar;
10use kiri_engine::dictionary::header::read_dictionary_header;
11use kiri_engine::dictionary::lexicon::Lexicon;
12use kiri_engine::dictionary::lexicon_set::LexiconSet;
13use kiri_engine::inhibit_connection_in_data;
14use kiri_engine::oov::mecab::{CategoryInfo, MecabOov, MecabOovConfig};
15use kiri_engine::oov::simple::OovProviderConfig;
16use kiri_engine::shared::SharedDictionary;
17use kiri_engine::tokenizer::{build_lattice_and_solve, DictionaryCtx, LatticeInput, OovCtx};
18use kiri_engine::types::{
19    has_synonym_group_ids, is_user_dictionary, ConnectionCosts, Grammar, DICTIONARY_HEADER_SIZE,
20};
21use kiri_engine::DictData;
22
/// MeCab OOV options passed from JS.
///
/// Mirrors `kiri_engine::oov::mecab::CategoryInfo` field-for-field;
/// `build_mecab_config` converts it 1:1. Field semantics (invoke/group/
/// length) are defined by the engine's MeCab OOV provider — see kiri_engine.
#[napi(object)]
pub struct NativeMecabCategory {
    // Character-category identifier this entry applies to.
    pub category_type: u32,
    pub is_invoke: bool,
    pub is_group: bool,
    pub length: u32,
}
31
/// A MeCab OOV word template passed from JS.
///
/// Fields are `i32` at the JS boundary but are narrowed to `i16` via
/// `checked_i16` in `build_mecab_config`; out-of-range values are rejected
/// with an error there.
#[napi(object)]
pub struct NativeMecabOov {
    // Groups this template under a `NativeMecabCategory.category_type`.
    pub category_type: u32,
    pub left_id: i32,
    pub right_id: i32,
    pub cost: i32,
    pub pos_id: i32,
}
40
/// Tokenizer construction options passed from JS.
///
/// All fields are optional. Absent OOV fields default to 0, except
/// `oov_cost` which defaults to 10000 (see `NativeTokenizer::new`).
#[napi(object)]
pub struct NativeTokenizerOptions {
    pub oov_left_id: Option<i32>,
    pub oov_right_id: Option<i32>,
    pub oov_cost: Option<i32>,
    pub oov_pos_id: Option<i32>,
    // MeCab OOV is only enabled when BOTH of the next two fields are set
    // (see build_mecab_config).
    pub mecab_categories: Option<Vec<NativeMecabCategory>>,
    pub mecab_oovs: Option<Vec<NativeMecabOov>>,
    // `[left_id, right_id]` pairs whose connection is inhibited.
    // Inner vectors whose length != 2 are silently ignored.
    pub inhibited_connections: Option<Vec<Vec<i32>>>,
}
51
/// PathNode returned to JS.
///
/// One node of the best path produced by `build_lattice_and_solve`;
/// numeric fields are widened from the engine's types in `tokenize`.
#[napi(object)]
#[derive(Serialize)]
pub struct NativePathNode {
    // Span of this node in the input (units as produced by the engine —
    // presumably byte or code-point offsets; see kiri_engine).
    pub begin: u32,
    pub end: u32,
    // Packed word id; resolve details via `NativeTokenizer::get_word_info`.
    pub word_id: i32,
    pub left_id: i32,
    pub right_id: i32,
    // Cost of this node alone.
    pub cost: i32,
    // Cost as computed by the engine (presumably cumulative along the
    // path — confirm against kiri_engine).
    pub total_cost: i32,
    // True when the node came from an OOV provider rather than the lexicon.
    pub is_oov: bool,
    // POS id assigned by the OOV provider, when `is_oov` is true.
    pub oov_pos_id: Option<i32>,
}
66
/// WordInfo returned to JS.
///
/// Field values are copied verbatim from the engine's word-info record in
/// `NativeTokenizer::get_word_info`; exact semantics of the split/structure
/// arrays are defined by kiri_engine.
#[napi(object)]
pub struct NativeWordInfo {
    pub surface: String,
    pub headword_length: u32,
    pub pos_id: i32,
    pub normalized_form: String,
    pub dictionary_form_word_id: i32,
    pub dictionary_form: String,
    pub reading_form: String,
    // Word-id lists for sub-unit splits and internal structure —
    // passed through from the engine unchanged.
    pub a_unit_split: Vec<i32>,
    pub b_unit_split: Vec<i32>,
    pub word_structure: Vec<i32>,
    // Synonym group ids; empty unless the dictionary version carries them
    // (see has_synonym_group_ids usage in the constructors).
    pub synonym_gids: Vec<i32>,
}
82
/// Pre-parsed shared dictionary that can be reused across multiple tokenizers.
/// The ~150 MB DoubleArrayTrie is shared via Arc, so creating additional
/// tokenizers from this handle is nearly free.
#[napi]
pub struct NativeSharedDictionary {
    // Engine-side shared dictionary; consumed through `data()`,
    // `create_grammar()` and `create_lexicon()` in
    // `NativeTokenizer::from_shared`.
    inner: SharedDictionary,
}
90
91#[napi]
92impl NativeSharedDictionary {
93    /// Load a dictionary from disk and pre-parse the trie for sharing.
94    #[napi(constructor)]
95    pub fn new(dict_path: String, inhibited_connections: Option<Vec<Vec<i32>>>) -> Result<Self> {
96        let file = fs::File::open(&dict_path)
97            .map_err(|e| Error::from_reason(format!("Failed to open dictionary: {e}")))?;
98
99        let mmap = unsafe { Mmap::map(&file) }
100            .map_err(|e| Error::from_reason(format!("Failed to mmap dictionary: {e}")))?;
101
102        let needs_inhibition = inhibited_connections
103            .as_ref()
104            .is_some_and(|pairs| !pairs.is_empty());
105
106        let inner = if needs_inhibition {
107            let mut data: Vec<u8> = mmap[..].to_vec();
108            let pairs = inhibited_connections.unwrap();
109
110            let (grammar, _) = read_grammar(&data, DICTIONARY_HEADER_SIZE)
111                .map_err(|e| Error::from_reason(format!("Failed to read grammar: {e}")))?;
112
113            for pair in &pairs {
114                if pair.len() == 2 {
115                    inhibit_connection_in_data(
116                        &mut data,
117                        &grammar.connection,
118                        checked_i16(pair[0], "inhibited_connection.left_id")?,
119                        checked_i16(pair[1], "inhibited_connection.right_id")?,
120                    );
121                }
122            }
123
124            SharedDictionary::new(data)
125                .map_err(|e| Error::from_reason(format!("Failed to build shared dict: {e}")))?
126        } else {
127            // Zero-copy: keep the mmap, OS manages the pages
128            SharedDictionary::from_mmap(mmap)
129                .map_err(|e| Error::from_reason(format!("Failed to build shared dict: {e}")))?
130        };
131
132        Ok(Self { inner })
133    }
134}
135
/// Native tokenizer exposed to JS: lexicon + grammar + OOV configuration.
#[napi]
pub struct NativeTokenizer {
    // System lexicon, plus any user dictionaries appended later via
    // `add_user_dictionary`.
    lexicon_set: LexiconSet,
    // Grammar (POS list, connection matrix metadata). User-dictionary POS
    // entries are appended to `pos_list` by `add_user_dictionary`.
    grammar: Grammar,
    /// Arc'd system dict bytes (shared with LexiconSet). Can be mmap or owned.
    system_data: Arc<DictData>,
    // Settings for the simple (default) OOV provider.
    oov_config: OovProviderConfig,
    // MeCab-style OOV settings; `None` unless configured via options.
    mecab_config: Option<MecabOovConfig>,
}
145
#[napi]
impl NativeTokenizer {
    /// Load a system dictionary and create the native tokenizer.
    ///
    /// Opens and mmaps the file, copies the bytes into an owned buffer,
    /// parses header/grammar/lexicon, applies any connection-cost
    /// inhibitions from `options`, then freezes the bytes in an `Arc`.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or mapped, if header or grammar
    /// parsing fails, or if a provided id/cost does not fit in `i16`.
    #[napi(constructor)]
    pub fn new(dict_path: String, options: Option<NativeTokenizerOptions>) -> Result<Self> {
        let file = fs::File::open(&dict_path)
            .map_err(|e| Error::from_reason(format!("Failed to open dictionary: {e}")))?;

        // SAFETY: the mapping is only read within this function — its bytes
        // are copied into `data` immediately below, after which the Mmap is
        // dropped. (It is NOT kept alive for the tokenizer's lifetime; the
        // tokenizer owns the copied Vec.)
        let mmap = unsafe { Mmap::map(&file) }
            .map_err(|e| Error::from_reason(format!("Failed to mmap dictionary: {e}")))?;

        // Owned copy so the inhibition patches below can mutate the bytes.
        let mut data: Vec<u8> = mmap[..].to_vec();

        let header = read_dictionary_header(&data, 0)
            .map_err(|e| Error::from_reason(format!("Failed to read header: {e}")))?;

        // Dictionary version determines whether per-word synonym group ids
        // are present in the lexicon.
        let has_synonyms = has_synonym_group_ids(header.version);

        let (grammar, grammar_bytes) = read_grammar(&data, DICTIONARY_HEADER_SIZE)
            .map_err(|e| Error::from_reason(format!("Failed to read grammar: {e}")))?;

        // The lexicon sits directly after the fixed-size header + grammar.
        let lexicon_offset = DICTIONARY_HEADER_SIZE + grammar_bytes;
        let (lexicon, _) = Lexicon::from_bytes(&data, lexicon_offset, has_synonyms);

        // Apply connection cost inhibitions before wrapping data in Arc.
        // Pairs whose length != 2 are silently skipped (same policy as
        // NativeSharedDictionary::new).
        if let Some(ref opts) = options {
            if let Some(ref pairs) = opts.inhibited_connections {
                for pair in pairs {
                    if pair.len() == 2 {
                        inhibit_connection_in_data(
                            &mut data,
                            &grammar.connection,
                            checked_i16(pair[0], "inhibited_connection.left_id")?,
                            checked_i16(pair[1], "inhibited_connection.right_id")?,
                        );
                    }
                }
            }
        }

        let data_arc = Arc::new(DictData::Owned(data));
        let lexicon_set = LexiconSet::new(lexicon, data_arc.clone(), grammar.pos_list.len());

        // Build OOV config (defaults when absent: ids/pos 0, cost 10000)
        let oov_config = if let Some(ref opts) = options {
            OovProviderConfig {
                left_id: checked_i16(opts.oov_left_id.unwrap_or(0), "oov_left_id")?,
                right_id: checked_i16(opts.oov_right_id.unwrap_or(0), "oov_right_id")?,
                cost: checked_i16(opts.oov_cost.unwrap_or(10000), "oov_cost")?,
                pos_id: checked_i16(opts.oov_pos_id.unwrap_or(0), "oov_pos_id")?,
            }
        } else {
            OovProviderConfig::default()
        };

        // Build MeCab OOV config — None unless both category and OOV lists
        // are provided (see build_mecab_config).
        let mecab_config = if let Some(ref opts) = options {
            build_mecab_config(opts)?
        } else {
            None
        };

        Ok(Self {
            lexicon_set,
            grammar,
            system_data: data_arc,
            oov_config,
            mecab_config,
        })
    }

    /// Create a tokenizer from a shared dictionary handle.
    /// The ~150 MB trie is shared via Arc (16-byte refcount bump, not a copy).
    ///
    /// NOTE(review): the OOV/MeCab config construction below duplicates
    /// `new`; a shared private helper would keep the two in sync.
    #[napi(factory)]
    pub fn from_shared(
        shared: &NativeSharedDictionary,
        options: Option<NativeTokenizerOptions>,
    ) -> Result<Self> {
        // Cheap: clones the Arc handle, not the dictionary bytes.
        let data_arc = shared.inner.data().clone();

        let grammar = shared
            .inner
            .create_grammar()
            .map_err(|e| Error::from_reason(format!("Failed to read grammar: {e}")))?;

        let (lexicon, _) = shared.inner.create_lexicon();

        let lexicon_set = LexiconSet::new(lexicon, data_arc.clone(), grammar.pos_list.len());

        // Same defaults as in `new`: ids/pos 0, cost 10000.
        let oov_config = if let Some(ref opts) = options {
            OovProviderConfig {
                left_id: checked_i16(opts.oov_left_id.unwrap_or(0), "oov_left_id")?,
                right_id: checked_i16(opts.oov_right_id.unwrap_or(0), "oov_right_id")?,
                cost: checked_i16(opts.oov_cost.unwrap_or(10000), "oov_cost")?,
                pos_id: checked_i16(opts.oov_pos_id.unwrap_or(0), "oov_pos_id")?,
            }
        } else {
            OovProviderConfig::default()
        };

        let mecab_config = if let Some(ref opts) = options {
            build_mecab_config(opts)?
        } else {
            None
        };

        Ok(Self {
            lexicon_set,
            grammar,
            system_data: data_arc,
            oov_config,
            mecab_config,
        })
    }

    /// Add a user dictionary.
    ///
    /// Validates the header's user-dictionary flag, appends the user POS
    /// entries to the system grammar, and registers the user lexicon at
    /// the corresponding POS offset.
    ///
    /// NOTE(review): `grammar.pos_list` is extended before the fallible
    /// `lexicon_set.add`; if `add` fails, the appended POS entries are not
    /// rolled back. Confirm whether that partial state is acceptable.
    #[napi]
    pub fn add_user_dictionary(&mut self, dict_path: String) -> Result<()> {
        let file = fs::File::open(&dict_path)
            .map_err(|e| Error::from_reason(format!("Failed to open user dictionary: {e}")))?;

        // SAFETY: the mapping is only read within this function; its bytes
        // are copied into `data` immediately below.
        let mmap = unsafe { Mmap::map(&file) }
            .map_err(|e| Error::from_reason(format!("Failed to mmap user dictionary: {e}")))?;

        let data: Vec<u8> = mmap[..].to_vec();

        let header = read_dictionary_header(&data, 0)
            .map_err(|e| Error::from_reason(format!("Failed to read user dict header: {e}")))?;

        // Reject system dictionaries passed here by mistake.
        if !is_user_dictionary(header.version) {
            return Err(Error::from_reason("Not a user dictionary"));
        }

        let has_synonyms = has_synonym_group_ids(header.version);

        let (user_grammar, grammar_bytes) = read_grammar(&data, DICTIONARY_HEADER_SIZE)
            .map_err(|e| Error::from_reason(format!("Failed to read user grammar: {e}")))?;

        let lexicon_offset = DICTIONARY_HEADER_SIZE + grammar_bytes;
        let (lexicon, _) = Lexicon::from_bytes(&data, lexicon_offset, has_synonyms);

        // User POS ids are remapped by offsetting past the system POS list.
        let pos_offset = self.grammar.pos_list.len();
        for pos in &user_grammar.pos_list {
            self.grammar.pos_list.push(pos.clone());
        }

        self.lexicon_set
            .add(lexicon, Arc::new(DictData::Owned(data)), pos_offset)
            .map_err(Error::from_reason)?;

        Ok(())
    }

    /// The hot path — builds lattice and returns best path nodes.
    ///
    /// The array arguments are precomputed on the JS side and passed
    /// through to the engine untouched (presumably parallel per-code-point
    /// tables — see `LatticeInput` in kiri_engine for exact semantics).
    #[napi]
    #[allow(clippy::too_many_arguments)]
    pub fn tokenize(
        &self,
        bytes: Uint8Array,
        can_bow: Uint8Array,
        char_categories: Uint32Array,
        word_candidate_lengths: Uint32Array,
        continuous_lengths: Uint32Array,
        code_point_byte_lengths_flat: Uint8Array,
        code_point_offsets: Uint32Array,
    ) -> Result<Vec<NativePathNode>> {
        let dict = DictionaryCtx {
            lexicon: &self.lexicon_set,
            connection: ConnectionCosts::new(&self.grammar.connection, self.system_data.as_ref()),
        };
        let input = LatticeInput {
            bytes: &bytes,
            can_bow: &can_bow,
            char_categories: &char_categories,
            word_candidate_lengths: &word_candidate_lengths,
            continuous_lengths: &continuous_lengths,
            code_point_byte_lengths_flat: &code_point_byte_lengths_flat,
            code_point_offsets: &code_point_offsets,
        };
        let oov = OovCtx {
            simple: &self.oov_config,
            mecab: self.mecab_config.as_ref(),
        };
        let result = build_lattice_and_solve(&dict, &input, &oov).map_err(Error::from_reason)?;

        // Widen engine node fields to the JS-facing i32/u32 representation.
        Ok(result
            .into_iter()
            .map(|n| NativePathNode {
                begin: n.begin as u32,
                end: n.end as u32,
                word_id: n.word_id,
                left_id: n.left_id as i32,
                right_id: n.right_id as i32,
                cost: n.cost as i32,
                total_cost: n.total_cost,
                is_oov: n.is_oov,
                oov_pos_id: n.oov_pos_id.map(|id| id as i32),
            })
            .collect())
    }

    /// Get word info by packed word ID. Called during TS post-processing.
    #[napi]
    pub fn get_word_info(&self, word_id: i32) -> Result<NativeWordInfo> {
        let wi = self.lexicon_set.get_word_info(word_id);
        // Straight field-for-field copy into the JS-facing struct.
        Ok(NativeWordInfo {
            surface: wi.surface,
            headword_length: wi.headword_length as u32,
            pos_id: wi.pos_id as i32,
            normalized_form: wi.normalized_form,
            dictionary_form_word_id: wi.dictionary_form_word_id,
            dictionary_form: wi.dictionary_form,
            reading_form: wi.reading_form,
            a_unit_split: wi.a_unit_split,
            b_unit_split: wi.b_unit_split,
            word_structure: wi.word_structure,
            synonym_gids: wi.synonym_gids,
        })
    }

    /// Get POS tags by ID.
    ///
    /// # Errors
    /// Fails when `pos_id` is out of range for the combined
    /// (system + user) POS list.
    #[napi]
    pub fn get_pos(&self, pos_id: u32) -> Result<Vec<String>> {
        let pos = self
            .grammar
            .pos_list
            .get(pos_id as usize)
            .ok_or_else(|| Error::from_reason(format!("POS ID {pos_id} out of range")))?;
        Ok(pos.tags.to_vec())
    }

    /// Total number of POS entries.
    #[napi(getter)]
    pub fn pos_count(&self) -> u32 {
        self.grammar.pos_list.len() as u32
    }
}
384
385/// Convert i32 to i16, returning an error if the value doesn't fit.
386fn checked_i16(value: i32, field: &str) -> Result<i16> {
387    i16::try_from(value)
388        .map_err(|_| Error::from_reason(format!("{field} value {value} exceeds i16 range")))
389}
390
391fn build_mecab_config(opts: &NativeTokenizerOptions) -> Result<Option<MecabOovConfig>> {
392    let categories_raw = match opts.mecab_categories.as_ref() {
393        Some(c) => c,
394        None => return Ok(None),
395    };
396    let oovs_raw = match opts.mecab_oovs.as_ref() {
397        Some(o) => o,
398        None => return Ok(None),
399    };
400
401    let categories: Vec<CategoryInfo> = categories_raw
402        .iter()
403        .map(|c| CategoryInfo {
404            category_type: c.category_type,
405            is_invoke: c.is_invoke,
406            is_group: c.is_group,
407            length: c.length,
408        })
409        .collect();
410
411    // Group OOV entries by category type
412    let mut oov_map: Vec<(u32, Vec<MecabOov>)> = Vec::new();
413    for raw in oovs_raw {
414        let oov = MecabOov {
415            left_id: checked_i16(raw.left_id, "mecab_oov.left_id")?,
416            right_id: checked_i16(raw.right_id, "mecab_oov.right_id")?,
417            cost: checked_i16(raw.cost, "mecab_oov.cost")?,
418            pos_id: checked_i16(raw.pos_id, "mecab_oov.pos_id")?,
419        };
420        if let Some(entry) = oov_map.iter_mut().find(|(t, _)| *t == raw.category_type) {
421            entry.1.push(oov);
422        } else {
423            oov_map.push((raw.category_type, vec![oov]));
424        }
425    }
426
427    Ok(Some(MecabOovConfig {
428        categories,
429        oov_list: oov_map,
430    }))
431}