1use crate::builder::Builder;
3use crate::errors::Result;
4use crate::mapper::CodeMapper;
5use crate::Node;
6
7use crate::END_CODE;
8
9use alloc::vec::Vec;
10
11use core::mem;
12
13pub struct Trie {
15 pub(crate) mapper: CodeMapper,
16 pub(crate) nodes: Vec<Node>,
17}
18
19impl Trie {
20 pub fn from_keys<I, K>(keys: I) -> Result<Self>
50 where
51 I: IntoIterator<Item = K>,
52 K: AsRef<str>,
53 {
54 Builder::new().build_from_keys(keys)?.release_trie()
55 }
56
57 pub fn from_records<I, K>(records: I) -> Result<Self>
84 where
85 I: IntoIterator<Item = (K, u32)>,
86 K: AsRef<str>,
87 {
88 Builder::new().build_from_records(records)?.release_trie()
89 }
90
91 pub fn serialize_to_vec(&self) -> Vec<u8> {
103 let mut dest = Vec::with_capacity(self.io_bytes());
104 self.mapper.serialize_into_vec(&mut dest);
105 dest.extend_from_slice(&u32::try_from(self.nodes.len()).unwrap().to_le_bytes());
106 for node in &self.nodes {
107 dest.extend_from_slice(&node.serialize());
108 }
109 dest
110 }
111
112 pub fn deserialize_from_slice(source: &[u8]) -> (Self, &[u8]) {
136 let (mapper, mut source) = CodeMapper::deserialize_from_slice(source);
137 let nodes = {
138 let len = u32::from_le_bytes(source[..4].try_into().unwrap()) as usize;
139 source = &source[4..];
140 let mut nodes = Vec::with_capacity(len);
141 for _ in 0..len {
142 nodes.push(Node::deserialize(
143 source[..Node::io_bytes()].try_into().unwrap(),
144 ));
145 source = &source[Node::io_bytes()..];
146 }
147 nodes
148 };
149 (Self { mapper, nodes }, source)
150 }
151
152 #[inline(always)]
170 pub fn exact_match<I>(&self, key: I) -> Option<u32>
171 where
172 I: IntoIterator<Item = char>,
173 {
174 let mut node_idx = 0;
175 for c in key {
176 node_idx = self
177 .mapper
178 .get(c)
179 .and_then(|mc| self.get_child_idx(node_idx, mc))?;
180 }
181 if self.is_leaf(node_idx) {
182 Some(self.get_value(node_idx))
183 } else if self.has_leaf(node_idx) {
184 Some(self.get_value(self.get_leaf_idx(node_idx)))
185 } else {
186 None
187 }
188 }
189
190 pub const fn common_prefix_search<I>(&self, haystack: I) -> CommonPrefixSearchIter<I> {
221 CommonPrefixSearchIter {
222 haystack,
223 haystack_pos: 0,
224 trie: self,
225 node_idx: 0,
226 }
227 }
228
229 #[inline(always)]
230 fn get_child_idx(&self, node_idx: u32, mc: u32) -> Option<u32> {
231 if self.is_leaf(node_idx) {
232 return None;
233 }
234 Some(self.get_base(node_idx) ^ mc)
235 .filter(|&child_idx| self.get_check(child_idx) == node_idx)
236 }
237
238 #[inline(always)]
239 fn node_ref(&self, node_idx: u32) -> &Node {
240 &self.nodes[usize::try_from(node_idx).unwrap()]
241 }
242
243 #[inline(always)]
244 fn get_base(&self, node_idx: u32) -> u32 {
245 self.node_ref(node_idx).get_base()
246 }
247
248 #[inline(always)]
249 fn get_check(&self, node_idx: u32) -> u32 {
250 self.node_ref(node_idx).get_check()
251 }
252
253 #[inline(always)]
254 fn is_leaf(&self, node_idx: u32) -> bool {
255 self.node_ref(node_idx).is_leaf()
256 }
257
258 #[inline(always)]
259 fn has_leaf(&self, node_idx: u32) -> bool {
260 self.node_ref(node_idx).has_leaf()
261 }
262
263 #[inline(always)]
264 fn get_leaf_idx(&self, node_idx: u32) -> u32 {
265 let leaf_idx = self.get_base(node_idx) ^ END_CODE;
266 debug_assert_eq!(self.get_check(leaf_idx), node_idx);
267 leaf_idx
268 }
269
270 #[inline(always)]
271 fn get_value(&self, node_idx: u32) -> u32 {
272 debug_assert!(self.is_leaf(node_idx));
273 self.node_ref(node_idx).get_base()
274 }
275
276 pub fn heap_bytes(&self) -> usize {
278 self.mapper.heap_bytes() + self.nodes.len() * mem::size_of::<Node>()
279 }
280
281 pub fn io_bytes(&self) -> usize {
283 self.mapper.io_bytes() + self.nodes.len() * Node::io_bytes() + mem::size_of::<u32>()
284 }
285
286 pub fn num_elems(&self) -> usize {
288 self.nodes.len()
289 }
290
291 pub fn num_vacants(&self) -> usize {
297 self.nodes.iter().filter(|nd| nd.is_vacant()).count()
298 }
299}
300
301pub struct CommonPrefixSearchIter<'t, I> {
303 haystack: I,
304 haystack_pos: usize,
305 trie: &'t Trie,
306 node_idx: u32,
307}
308
309impl<I> Iterator for CommonPrefixSearchIter<'_, I>
310where
311 I: Iterator<Item = char>,
312{
313 type Item = (u32, usize);
314
315 #[inline(always)]
316 fn next(&mut self) -> Option<Self::Item> {
317 for c in self.haystack.by_ref() {
318 let mc = self.trie.mapper.get(c)?;
319 self.node_idx = self.trie.get_child_idx(self.node_idx, mc)?;
320 self.haystack_pos += 1;
321 if self.trie.is_leaf(self.node_idx) {
322 return Some((self.trie.get_value(self.node_idx), self.haystack_pos));
323 } else if self.trie.has_leaf(self.node_idx) {
324 let leaf_idx = self.trie.get_leaf_idx(self.node_idx);
325 return Some((self.trie.get_value(leaf_idx), self.haystack_pos));
326 }
327 }
328 None
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys shared by several tests: they include common prefixes
    /// (`世界`/`世界中`, `世界`/`世論調査`) and a common suffix (`調査`).
    const KEYS: [&str; 4] = ["世界", "世界中", "世論調査", "統計調査"];

    /// Runs a common-prefix search from every start position of `haystack`
    /// and returns each match as `(value, character range)`.
    fn all_matches(trie: &Trie, haystack: &str) -> Vec<(u32, core::ops::Range<usize>)> {
        let chars: Vec<char> = haystack.chars().collect();
        let mut matches = vec![];
        for i in 0..chars.len() {
            for (v, j) in trie.common_prefix_search(chars[i..].iter().copied()) {
                matches.push((v, i..i + j));
            }
        }
        matches
    }

    #[test]
    fn test_exact_match() {
        let trie = Trie::from_keys(KEYS).unwrap();
        for (i, key) in KEYS.iter().enumerate() {
            assert_eq!(
                trie.exact_match(key.chars()),
                Some(u32::try_from(i).unwrap())
            );
        }
        for absent in ["世", "世論", "世界中で", "統計", "統計調", "日本"] {
            assert_eq!(trie.exact_match(absent.chars()), None);
        }
    }

    #[test]
    fn test_from_records() {
        let records = [("世界", 10), ("世界中", 20), ("統計調査", 30)];
        let trie = Trie::from_records(records).unwrap();
        for (key, value) in records {
            assert_eq!(trie.exact_match(key.chars()), Some(value));
        }
        assert_eq!(trie.exact_match("世".chars()), None);
    }

    #[test]
    fn test_common_prefix_search() {
        let trie = Trie::from_keys(KEYS).unwrap();
        assert_eq!(
            all_matches(&trie, "世界中の統計世論調査"),
            vec![(0, 0..2), (1, 0..3), (2, 6..10)]
        );
    }

    #[test]
    fn test_serialize() {
        let trie = Trie::from_keys(KEYS).unwrap();

        let bytes = trie.serialize_to_vec();
        assert_eq!(trie.io_bytes(), bytes.len());

        let (other, remain) = Trie::deserialize_from_slice(&bytes);
        assert!(remain.is_empty());

        assert_eq!(trie.mapper, other.mapper);
        assert_eq!(trie.nodes, other.nodes);
        for (i, key) in KEYS.iter().enumerate() {
            assert_eq!(
                other.exact_match(key.chars()),
                Some(u32::try_from(i).unwrap())
            );
        }
    }

    #[test]
    fn test_deserialize_keeps_remainder() {
        let trie = Trie::from_keys(KEYS).unwrap();
        let mut bytes = trie.serialize_to_vec();
        bytes.extend_from_slice(&[1, 2, 3]);

        let (other, remain) = Trie::deserialize_from_slice(&bytes);
        assert_eq!(remain, &[1u8, 2, 3][..]);
        assert_eq!(trie.nodes, other.nodes);
    }

    #[test]
    fn test_empty_set() {
        assert!(Trie::from_keys(core::iter::empty::<&str>()).is_err());
    }

    #[test]
    fn test_empty_char() {
        assert!(Trie::from_keys([""]).is_err());
    }

    #[test]
    fn test_empty_key() {
        assert!(Trie::from_keys(["", "AAA"]).is_err());
    }

    #[test]
    fn test_unsorted_keys() {
        assert!(Trie::from_keys(["BB", "AA"]).is_ok());
        assert!(Trie::from_keys(["AAA", "AA"]).is_ok());
    }

    #[test]
    fn test_duplicate_keys() {
        assert!(Trie::from_keys(["AA", "AA"]).is_err());
    }

    #[test]
    fn test_common_prefix_search_null() {
        let trie = Trie::from_keys(["世界\0", "世界中", "世間"]).unwrap();
        assert_eq!(
            all_matches(&trie, "世界\0中の人\0世間"),
            vec![(0, 0..3), (2, 7..9)]
        );
    }
}
426}