Skip to main content

kiri_engine/dictionary/
trie.rs

1//! Double-array trie for efficient common-prefix search.
2//! Ported from Sudachi's DARTSCLONE implementation.
3//!
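//! The trie is stored as a flat `i32` slice of "units". A node unit packs its
//! edge label (bits 0-7, compared together with bit 31), a has-leaf flag
//! (bit 8) and an offset to the next level (bits 10 and up, scaled by 256
//! when bit 9 is set); a leaf unit carries its value in the low 31 bits.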
6
/// One hit produced by a trie lookup.
///
/// Comparable with `==` so that callers and tests can check matches directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrieMatch {
    /// The value stored in the leaf unit reached for the matched key.
    pub value: i32,
    /// The byte length of the matched key, counted from the search offset.
    pub length: usize,
}
15
16// ---- DARTSCLONE unit bit-field accessors ----
17
/// Reports whether bit 8 of `unit` (the has-leaf flag) is set.
#[inline]
fn unit_has_leaf(unit: i32) -> bool {
    const HAS_LEAF_BIT: u32 = 1 << 8;
    let bits = unit as u32;
    (bits & HAS_LEAF_BIT) != 0
}
22
/// Extracts the value field of `unit`: its low 31 bits, with bit 31 cleared.
#[inline]
fn unit_value(unit: i32) -> i32 {
    // `i32::MAX` is exactly the 31-bit mask 0x7FFF_FFFF.
    const VALUE_MASK: i32 = i32::MAX;
    unit & VALUE_MASK
}
27
/// Extracts the label field of `unit`: bit 31 together with bits 0-7.
///
/// Because bit 31 is kept, the result only equals a plain byte value when
/// that bit is clear in `unit`.
#[inline]
fn unit_label(unit: i32) -> i32 {
    let top_bit = unit & i32::MIN;
    let low_byte = unit & 0xff;
    top_bit | low_byte
}
32
/// Extracts the offset field of `unit`.
///
/// The base offset is everything above bit 9. When bit 9 is set the base is
/// shifted left by a further 8 bits; otherwise it is returned as is.
#[inline]
fn unit_offset(unit: i32) -> u32 {
    const EXTEND_BIT: u32 = 1 << 9;
    let bits = unit as u32;
    let base = bits >> 10;
    if bits & EXTEND_BIT != 0 {
        base << 8
    } else {
        base
    }
}
38
/// Reads a serialized trie from `data`, starting at byte `offset`.
///
/// The encoding is a little-endian `i32` unit count followed by that many
/// little-endian `i32` units.
///
/// Returns `(array, size, bytes_read)`: the decoded units, the unit count
/// taken from the header, and the total number of bytes consumed (the 4-byte
/// header plus 4 bytes per unit).
///
/// # Panics
///
/// Panics if the size header is negative, or if the header or the units it
/// announces do not fit inside `data`. The whole byte range is validated
/// before anything is allocated, so a corrupt header cannot trigger a huge
/// allocation.
pub fn read_trie(data: &[u8], offset: usize) -> (Vec<i32>, usize, usize) {
    const UNIT_BYTES: usize = 4;

    // Decodes one little-endian i32 from a slice known to hold 4 bytes.
    fn le_i32(bytes: &[u8]) -> i32 {
        i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    let header_end = offset
        .checked_add(UNIT_BYTES)
        .expect("trie offset overflows usize");
    let header = data
        .get(offset..header_end)
        .expect("trie size header extends past end of data");

    let raw_size = le_i32(header);
    assert!(raw_size >= 0, "trie size header is negative: {}", raw_size);
    let size = raw_size as usize;

    let body_len = size
        .checked_mul(UNIT_BYTES)
        .expect("trie byte length overflows usize");
    let body_end = header_end
        .checked_add(body_len)
        .expect("trie byte length overflows usize");
    let body = data
        .get(header_end..body_end)
        .expect("trie units extend past end of data");

    let array: Vec<i32> = body.chunks_exact(UNIT_BYTES).map(le_i32).collect();

    (array, size, UNIT_BYTES + body_len)
}
59
/// Double-array trie backed by a pre-built `i32` unit array.
#[derive(Debug, Clone)]
pub struct DoubleArrayTrie {
    /// The trie units; index 0 is read as the root unit by the search methods.
    array: Vec<i32>,
    /// Unit count as supplied to the constructor or read from the serialized
    /// header. It is stored as given and is not derived from `array.len()`.
    size: usize,
}
65
66impl DoubleArrayTrie {
67    /// Create from a pre-built array.
68    pub fn new(array: Vec<i32>, size: usize) -> Self {
69        Self { array, size }
70    }
71
72    /// Read a trie from raw bytes at the given offset.
73    pub fn from_bytes(data: &[u8], offset: usize) -> (Self, usize) {
74        let (array, size, bytes_read) = read_trie(data, offset);
75        (Self { array, size }, bytes_read)
76    }
77
78    /// Common prefix search: find all keys that are prefixes of the input.
79    /// This is the core operation for lattice-based tokenization.
80    #[allow(clippy::needless_range_loop)]
81    pub fn common_prefix_search(&self, key: &[u8], offset: usize, limit: usize) -> Vec<TrieMatch> {
82        let mut results = Vec::new();
83        let arr = &self.array;
84        let end = std::cmp::min(offset + limit, key.len());
85
86        let mut node_pos = 0u32;
87
88        // XOR with root offset
89        let root = match arr.get(node_pos as usize) {
90            Some(&v) => v,
91            None => return results,
92        };
93        node_pos ^= unit_offset(root);
94
95        for i in offset..end {
96            let b = key[i] as u32;
97
98            // Follow edge labeled `b`
99            node_pos ^= b;
100
101            let unit = match arr.get(node_pos as usize) {
102                Some(&v) => v,
103                None => return results,
104            };
105
106            // Check if this node exists (label must match the byte we followed)
107            if unit_label(unit) != b as i32 {
108                return results;
109            }
110
111            // Advance nodePos to next level
112            node_pos ^= unit_offset(unit);
113
114            // If this unit has a leaf, read the value from the NEW nodePos
115            if unit_has_leaf(unit) {
116                if let Some(&leaf) = arr.get(node_pos as usize) {
117                    results.push(TrieMatch {
118                        value: unit_value(leaf),
119                        length: i - offset + 1,
120                    });
121                } else {
122                    return results;
123                }
124            }
125        }
126
127        results
128    }
129
130    /// Exact match search: look up a single key.
131    #[allow(clippy::needless_range_loop)]
132    pub fn exact_match_search(&self, key: &[u8], offset: usize, length: usize) -> i32 {
133        let arr = &self.array;
134        let end = offset + length;
135
136        let mut node_pos = 0u32;
137        let mut last_unit = 0i32;
138
139        let root = match arr.get(node_pos as usize) {
140            Some(&v) => v,
141            None => return -1,
142        };
143        node_pos ^= unit_offset(root);
144
145        for i in offset..end {
146            let b = key[i] as u32;
147            node_pos ^= b;
148
149            let unit = match arr.get(node_pos as usize) {
150                Some(&v) => v,
151                None => return -1,
152            };
153            if unit_label(unit) != b as i32 {
154                return -1;
155            }
156
157            node_pos ^= unit_offset(unit);
158            last_unit = unit;
159        }
160
161        if !unit_has_leaf(last_unit) {
162            return -1;
163        }
164
165        match arr.get(node_pos as usize) {
166            Some(&v) => unit_value(v),
167            None => -1,
168        }
169    }
170
171    /// Total number of units in the trie.
172    pub fn total_size(&self) -> usize {
173        self.size
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_field_accessors() {
        // has-leaf is bit 8 and only bit 8.
        assert!(unit_has_leaf(0x100));
        assert!(!unit_has_leaf(0));
        assert!(!unit_has_leaf(!0x100));

        // Label keeps bit 31 plus bits 0..=7. Compare the full result (no
        // masking) so that a wrong mask is actually detected.
        assert_eq!(unit_label(0x42), 0x42);
        assert_eq!(unit_label(0x1F42), 0x42);
        assert_eq!(unit_label(i32::MIN | 0x42), i32::MIN | 0x42);

        // Value is the low 31 bits; bit 31 is dropped.
        assert_eq!(unit_value(0x7FFF_FFFF), 0x7FFF_FFFF);
        assert_eq!(unit_value(i32::MIN | 7), 7);
    }

    #[test]
    fn test_unit_offset_basic() {
        // With bit 9 clear the offset is simply `unit >> 10`.
        assert_eq!(unit_offset(0), 0);
        assert_eq!(unit_offset(0x400), 1); // bit 10 set
        assert_eq!(unit_offset(0x800), 2); // bit 11 set

        // Bits below bit 10 never contribute to the base offset.
        assert_eq!(unit_offset(0x1FF), 0);
    }

    #[test]
    fn test_unit_offset_extended() {
        // With bit 9 set the base offset is shifted left by 8.
        assert_eq!(unit_offset(0x400 | 0x200), 1 << 8);
        assert_eq!(unit_offset(0x800 | 0x200), 2 << 8);
        assert_eq!(unit_offset(0x200), 0);
    }
}