Skip to main content

crawdad/
trie.rs

1//! A standard trie form that often provides the fastest queries.
2use crate::builder::Builder;
3use crate::errors::Result;
4use crate::mapper::CodeMapper;
5use crate::Node;
6
7use crate::END_CODE;
8
9use alloc::vec::Vec;
10
11use core::mem;
12
13/// A standard trie form that often provides the fastest queries.
14pub struct Trie {
15    pub(crate) mapper: CodeMapper,
16    pub(crate) nodes: Vec<Node>,
17}
18
19impl Trie {
20    /// Creates a new [`Trie`] from input keys.
21    ///
22    /// Values in `[0..n-1]` will be associated with keys in the lexicographical order,
23    /// where `n` is the number of keys.
24    ///
25    /// # Arguments
26    ///
27    /// - `keys`: Sorted list of string keys.
28    ///
29    /// # Errors
30    ///
31    /// [`CrawdadError`](crate::errors::CrawdadError) will be returned when
32    ///
33    /// - `keys` is empty,
34    /// - `keys` contains empty strings,
35    /// - `keys` contains duplicate keys,
36    /// - the scale of `keys` exceeds the expected one, or
37    /// - the scale of the resulting trie exceeds the expected one.
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// use crawdad::Trie;
43    ///
44    /// let keys = vec!["世界", "世界中", "国民"];
45    /// let trie = Trie::from_keys(keys).unwrap();
46    ///
47    /// assert_eq!(trie.num_elems(), 8);
48    /// ```
49    pub fn from_keys<I, K>(keys: I) -> Result<Self>
50    where
51        I: IntoIterator<Item = K>,
52        K: AsRef<str>,
53    {
54        Builder::new().build_from_keys(keys)?.release_trie()
55    }
56
57    /// Creates a new [`Trie`] from input records.
58    ///
59    /// # Arguments
60    ///
61    /// - `records`: Sorted list of key-value pairs.
62    ///
63    /// # Errors
64    ///
65    /// [`CrawdadError`](crate::errors::CrawdadError) will be returned when
66    ///
67    /// - `records` is empty,
68    /// - `records` contains empty strings,
69    /// - `records` contains duplicate keys,
70    /// - the scale of `keys` exceeds the expected one, or
71    /// - the scale of the resulting trie exceeds the expected one.
72    ///
73    /// # Examples
74    ///
75    /// ```
76    /// use crawdad::Trie;
77    ///
78    /// let records = vec![("世界", 2), ("世界中", 3), ("国民", 2)];
79    /// let trie = Trie::from_records(records).unwrap();
80    ///
81    /// assert_eq!(trie.num_elems(), 8);
82    /// ```
83    pub fn from_records<I, K>(records: I) -> Result<Self>
84    where
85        I: IntoIterator<Item = (K, u32)>,
86        K: AsRef<str>,
87    {
88        Builder::new().build_from_records(records)?.release_trie()
89    }
90
91    /// Serializes the data structure into a [`Vec`].
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use crawdad::Trie;
97    ///
98    /// let keys = vec!["世界", "世界中", "国民"];
99    /// let trie = Trie::from_keys(&keys).unwrap();
100    /// let bytes = trie.serialize_to_vec();
101    /// ```
102    pub fn serialize_to_vec(&self) -> Vec<u8> {
103        let mut dest = Vec::with_capacity(self.io_bytes());
104        self.mapper.serialize_into_vec(&mut dest);
105        dest.extend_from_slice(&u32::try_from(self.nodes.len()).unwrap().to_le_bytes());
106        for node in &self.nodes {
107            dest.extend_from_slice(&node.serialize());
108        }
109        dest
110    }
111
112    /// Deserializes the data structure from a given byte slice.
113    ///
114    /// # Arguments
115    ///
116    /// * `source` - A source byte slice.
117    ///
118    /// # Returns
119    ///
120    /// A tuple of the data structure and the slice not used for the deserialization.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use crawdad::Trie;
126    ///
127    /// let keys = vec!["世界", "世界中", "国民"];
128    /// let trie = Trie::from_keys(&keys).unwrap();
129    ///
130    /// let bytes = trie.serialize_to_vec();
131    /// let (other, _) = Trie::deserialize_from_slice(&bytes);
132    ///
133    /// assert_eq!(trie.io_bytes(), other.io_bytes());
134    /// ```
135    pub fn deserialize_from_slice(source: &[u8]) -> (Self, &[u8]) {
136        let (mapper, mut source) = CodeMapper::deserialize_from_slice(source);
137        let nodes = {
138            let len = u32::from_le_bytes(source[..4].try_into().unwrap()) as usize;
139            source = &source[4..];
140            let mut nodes = Vec::with_capacity(len);
141            for _ in 0..len {
142                nodes.push(Node::deserialize(
143                    source[..Node::io_bytes()].try_into().unwrap(),
144                ));
145                source = &source[Node::io_bytes()..];
146            }
147            nodes
148        };
149        (Self { mapper, nodes }, source)
150    }
151
152    /// Returns a value associated with an input key if exists.
153    ///
154    /// # Arguments
155    ///
156    /// - `key`: Search key.
157    ///
158    /// # Examples
159    ///
160    /// ```
161    /// use crawdad::Trie;
162    ///
163    /// let keys = vec!["世界", "世界中", "国民"];
164    /// let trie = Trie::from_keys(&keys).unwrap();
165    ///
166    /// assert_eq!(trie.exact_match("世界中".chars()), Some(1));
167    /// assert_eq!(trie.exact_match("日本中".chars()), None);
168    /// ```
169    #[inline(always)]
170    pub fn exact_match<I>(&self, key: I) -> Option<u32>
171    where
172        I: IntoIterator<Item = char>,
173    {
174        let mut node_idx = 0;
175        for c in key {
176            node_idx = self
177                .mapper
178                .get(c)
179                .and_then(|mc| self.get_child_idx(node_idx, mc))?;
180        }
181        if self.is_leaf(node_idx) {
182            Some(self.get_value(node_idx))
183        } else if self.has_leaf(node_idx) {
184            Some(self.get_value(self.get_leaf_idx(node_idx)))
185        } else {
186            None
187        }
188    }
189
190    /// Returns an iterator for common prefix search.
191    ///
192    /// The iterator reports all occurrences of keys starting from an input haystack, where
193    /// an occurrence consists of its associated value and ending positoin in characters.
194    ///
195    /// # Examples
196    ///
197    /// You can find all occurrences of keys in a haystack by performing common prefix searches
198    /// at all starting positions.
199    ///
200    /// ```
201    /// use crawdad::Trie;
202    ///
203    /// let keys = vec!["世界", "世界中", "国民"];
204    /// let trie = Trie::from_keys(&keys).unwrap();
205    ///
206    /// let haystack: Vec<char> = "国民が世界中にて".chars().collect();
207    /// let mut matches = vec![];
208    ///
209    /// for i in 0..haystack.len() {
210    ///     for (v, j) in trie.common_prefix_search(haystack[i..].iter().copied()) {
211    ///         matches.push((v, i..i + j));
212    ///     }
213    /// }
214    ///
215    /// assert_eq!(
216    ///     matches,
217    ///     vec![(2, 0..2), (0, 3..5), (1, 3..6)]
218    /// );
219    /// ```
220    pub const fn common_prefix_search<I>(&self, haystack: I) -> CommonPrefixSearchIter<I> {
221        CommonPrefixSearchIter {
222            haystack,
223            haystack_pos: 0,
224            trie: self,
225            node_idx: 0,
226        }
227    }
228
229    #[inline(always)]
230    fn get_child_idx(&self, node_idx: u32, mc: u32) -> Option<u32> {
231        if self.is_leaf(node_idx) {
232            return None;
233        }
234        Some(self.get_base(node_idx) ^ mc)
235            .filter(|&child_idx| self.get_check(child_idx) == node_idx)
236    }
237
238    #[inline(always)]
239    fn node_ref(&self, node_idx: u32) -> &Node {
240        &self.nodes[usize::try_from(node_idx).unwrap()]
241    }
242
243    #[inline(always)]
244    fn get_base(&self, node_idx: u32) -> u32 {
245        self.node_ref(node_idx).get_base()
246    }
247
248    #[inline(always)]
249    fn get_check(&self, node_idx: u32) -> u32 {
250        self.node_ref(node_idx).get_check()
251    }
252
253    #[inline(always)]
254    fn is_leaf(&self, node_idx: u32) -> bool {
255        self.node_ref(node_idx).is_leaf()
256    }
257
258    #[inline(always)]
259    fn has_leaf(&self, node_idx: u32) -> bool {
260        self.node_ref(node_idx).has_leaf()
261    }
262
263    #[inline(always)]
264    fn get_leaf_idx(&self, node_idx: u32) -> u32 {
265        let leaf_idx = self.get_base(node_idx) ^ END_CODE;
266        debug_assert_eq!(self.get_check(leaf_idx), node_idx);
267        leaf_idx
268    }
269
270    #[inline(always)]
271    fn get_value(&self, node_idx: u32) -> u32 {
272        debug_assert!(self.is_leaf(node_idx));
273        self.node_ref(node_idx).get_base()
274    }
275
276    /// Returns the total amount of heap used by this automaton in bytes.
277    pub fn heap_bytes(&self) -> usize {
278        self.mapper.heap_bytes() + self.nodes.len() * mem::size_of::<Node>()
279    }
280
281    /// Returns the total amount of bytes to serialize the data structure.
282    pub fn io_bytes(&self) -> usize {
283        self.mapper.io_bytes() + self.nodes.len() * Node::io_bytes() + mem::size_of::<u32>()
284    }
285
286    /// Returns the number of reserved elements.
287    pub fn num_elems(&self) -> usize {
288        self.nodes.len()
289    }
290
291    /// Returns the number of vacant elements.
292    ///
293    /// # Note
294    ///
295    /// It takes `O(num_elems)` time.
296    pub fn num_vacants(&self) -> usize {
297        self.nodes.iter().filter(|nd| nd.is_vacant()).count()
298    }
299}
300
301/// Iterator for common prefix search.
302pub struct CommonPrefixSearchIter<'t, I> {
303    haystack: I,
304    haystack_pos: usize,
305    trie: &'t Trie,
306    node_idx: u32,
307}
308
309impl<I> Iterator for CommonPrefixSearchIter<'_, I>
310where
311    I: Iterator<Item = char>,
312{
313    type Item = (u32, usize);
314
315    #[inline(always)]
316    fn next(&mut self) -> Option<Self::Item> {
317        for c in self.haystack.by_ref() {
318            let mc = self.trie.mapper.get(c)?;
319            self.node_idx = self.trie.get_child_idx(self.node_idx, mc)?;
320            self.haystack_pos += 1;
321            if self.trie.is_leaf(self.node_idx) {
322                return Some((self.trie.get_value(self.node_idx), self.haystack_pos));
323            } else if self.trie.has_leaf(self.node_idx) {
324                let leaf_idx = self.trie.get_leaf_idx(self.node_idx);
325                return Some((self.trie.get_value(leaf_idx), self.haystack_pos));
326            }
327        }
328        None
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a common prefix search from every start position of `text` and gathers
    /// each reported value together with the character range it covers.
    fn find_all(trie: &Trie, text: &str) -> Vec<(u32, core::ops::Range<usize>)> {
        let haystack: Vec<char> = text.chars().collect();
        let mut found = Vec::new();
        for start in 0..haystack.len() {
            let hits = trie.common_prefix_search(haystack[start..].iter().copied());
            found.extend(hits.map(|(value, len)| (value, start..start + len)));
        }
        found
    }

    #[test]
    fn test_exact_match() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = Trie::from_keys(&keys).unwrap();

        // Every stored key maps to its rank in the key list.
        for (key, expected) in keys.iter().zip(0u32..) {
            assert_eq!(trie.exact_match(key.chars()), Some(expected));
        }

        // Proper prefixes, extensions and unrelated strings are not found.
        for absent in ["世", "世論", "世界中で", "統計", "統計調", "日本"] {
            assert_eq!(trie.exact_match(absent.chars()), None);
        }
    }

    #[test]
    fn test_common_prefix_search() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = Trie::from_keys(&keys).unwrap();

        let matches = find_all(&trie, "世界中の統計世論調査");
        assert_eq!(matches, vec![(0, 0..2), (1, 0..3), (2, 6..10)]);
    }

    #[test]
    fn test_serialize() {
        let keys = vec!["世界", "世界中", "世論調査", "統計調査"];
        let trie = Trie::from_keys(&keys).unwrap();

        // The serialized size must agree with what `io_bytes` announces.
        let bytes = trie.serialize_to_vec();
        assert_eq!(trie.io_bytes(), bytes.len());

        // A round trip consumes every byte and restores both components.
        let (other, remain) = Trie::deserialize_from_slice(&bytes);
        assert!(remain.is_empty());

        assert_eq!(trie.mapper, other.mapper);
        assert_eq!(trie.nodes, other.nodes);
    }

    #[test]
    fn test_empty_set() {
        let no_keys: &[&str] = &[];
        assert!(Trie::from_keys(no_keys).is_err());
    }

    #[test]
    fn test_empty_char() {
        let result = Trie::from_keys([""]);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_key() {
        let result = Trie::from_keys(["", "AAA"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unsorted_keys() {
        for keys in [["BB", "AA"], ["AAA", "AA"]] {
            assert!(Trie::from_keys(keys).is_ok());
        }
    }

    #[test]
    fn test_duplicate_keys() {
        let result = Trie::from_keys(["AA", "AA"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_common_prefix_search_null() {
        // Keys and haystacks may contain NUL characters.
        let keys = vec!["世界\0", "世界中", "世間"];
        let trie = Trie::from_keys(&keys).unwrap();

        let matches = find_all(&trie, "世界\0中の人\0世間");
        assert_eq!(matches, vec![(0, 0..3), (2, 7..9)]);
    }
}
426}