Skip to main content

polyglot_sql/
trie.rs

//! Trie data structure for efficient prefix matching
//!
//! This module provides a trie implementation used for:
//! - Efficient keyword matching in the tokenizer
//! - Time format conversion with overlapping patterns
//! - Schema table name resolution
//!
//! Based on the Python implementation in `sqlglot/trie.py`.
9
10use std::collections::HashMap;
11
/// Outcome of looking up a key (or a single character) in a [`Trie`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrieResult {
    /// The key does not correspond to any path in the trie.
    Failed,
    /// The path for the key exists, but no value is stored at its end.
    Prefix,
    /// A value is stored for exactly this key.
    Exists,
}

/// A trie (prefix tree) keyed by `char`.
///
/// Generic over the value type `V`. If no value is needed, use `()`.
///
/// # Example
///
/// ```
/// use polyglot_sql::trie::{Trie, TrieResult};
///
/// let mut trie = Trie::new();
/// trie.insert("cat", 1);
/// trie.insert("car", 2);
///
/// assert_eq!(trie.in_trie("cat"), (TrieResult::Exists, Some(&1)));
/// assert_eq!(trie.in_trie("ca").0, TrieResult::Prefix);
/// assert_eq!(trie.in_trie("dog").0, TrieResult::Failed);
/// ```
#[derive(Debug, Clone)]
pub struct Trie<V> {
    /// Sub-tries reachable from this node, indexed by the next character.
    children: HashMap<char, Trie<V>>,
    /// Value stored for the key that ends at this node, if any.
    value: Option<V>,
}

// Implemented by hand: `#[derive(Default)]` would add a `V: Default` bound
// that an empty trie does not need.
impl<V> Default for Trie<V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<V> Trie<V> {
    /// Create a new empty trie.
    pub fn new() -> Self {
        Self {
            children: HashMap::new(),
            value: None,
        }
    }

    /// Insert a key-value pair into the trie.
    ///
    /// The key is consumed one `char` at a time; missing nodes along the path
    /// are created. If the key was already present its value is replaced.
    ///
    /// # Arguments
    /// * `key` - The key to insert (a string slice)
    /// * `value` - The value to associate with the key
    pub fn insert(&mut self, key: &str, value: V) {
        let mut node = self;
        for ch in key.chars() {
            node = node.children.entry(ch).or_default();
        }
        node.value = Some(value);
    }

    /// Follow `key` from this node and return the node it ends at, or `None`
    /// as soon as a character has no matching child.
    fn descend(&self, key: &str) -> Option<&Trie<V>> {
        let mut node = self;
        for ch in key.chars() {
            node = node.children.get(&ch)?;
        }
        Some(node)
    }

    /// Get the value associated with a key.
    ///
    /// Returns `None` if the key doesn't exist or only exists as a prefix.
    /// An empty key addresses this node itself, so it yields this node's value.
    pub fn get(&self, key: &str) -> Option<&V> {
        self.descend(key)?.value.as_ref()
    }

    /// Check if a key exists in the trie.
    ///
    /// Returns a tuple of (TrieResult, Option<&V>) where:
    /// - `TrieResult::Failed` - key not found
    /// - `TrieResult::Prefix` - key is a prefix of an existing key
    /// - `TrieResult::Exists` - key exists in trie
    ///
    /// When the result is `Exists`, the Option contains the value; otherwise
    /// it is `None`.
    ///
    /// An empty key is always reported as `Failed`, even if a value was
    /// stored under the empty key with [`Trie::insert`].
    pub fn in_trie(&self, key: &str) -> (TrieResult, Option<&V>) {
        if key.is_empty() {
            return (TrieResult::Failed, None);
        }

        match self.descend(key) {
            None => (TrieResult::Failed, None),
            Some(node) => match node.value.as_ref() {
                Some(found) => (TrieResult::Exists, Some(found)),
                None => (TrieResult::Prefix, None),
            },
        }
    }

    /// Check a single character against the children of this node.
    ///
    /// This is useful for streaming/incremental matching. Returns:
    /// - `TrieResult::Failed` - character not found from current position
    /// - `TrieResult::Prefix` - character found, but not at end of a word
    /// - `TrieResult::Exists` - character found and at end of a word
    ///
    /// The second element is the subtrie reached by `ch`; it is `None` only
    /// when the result is `Failed`.
    pub fn in_trie_char(&self, ch: char) -> (TrieResult, Option<&Trie<V>>) {
        match self.children.get(&ch) {
            Some(child) if child.has_value() => (TrieResult::Exists, Some(child)),
            Some(child) => (TrieResult::Prefix, Some(child)),
            None => (TrieResult::Failed, None),
        }
    }

    /// Get the subtrie for a given character.
    pub fn get_child(&self, ch: char) -> Option<&Trie<V>> {
        self.children.get(&ch)
    }

    /// Check if this node has a value (is a complete word).
    pub fn has_value(&self) -> bool {
        self.value.is_some()
    }

    /// Get the value at this node.
    pub fn value(&self) -> Option<&V> {
        self.value.as_ref()
    }

    /// Check if the trie is empty, i.e. this node has neither children nor a value.
    pub fn is_empty(&self) -> bool {
        self.children.is_empty() && self.value.is_none()
    }

    /// Get all keys in the trie.
    ///
    /// Children are visited in `HashMap` iteration order, so the order of the
    /// returned keys is unspecified; sort the result if a stable order is needed.
    pub fn keys(&self) -> Vec<String> {
        let mut result = Vec::new();
        let mut prefix = String::new();
        self.collect_keys(&mut prefix, &mut result);
        result
    }

    /// Depth-first helper for [`Trie::keys`].
    ///
    /// `prefix` holds the characters on the path from the root to this node.
    /// It is a single buffer shared by the whole traversal: each child's
    /// character is pushed before recursing and popped afterwards, so a
    /// `String` is only allocated for keys that are actually emitted.
    fn collect_keys(&self, prefix: &mut String, result: &mut Vec<String>) {
        if self.value.is_some() {
            result.push(prefix.clone());
        }
        for (&ch, child) in &self.children {
            prefix.push(ch);
            child.collect_keys(prefix, result);
            prefix.pop();
        }
    }
}
175
176/// Create a new trie from an iterator of (key, value) pairs
177///
178/// # Example
179///
180/// ```
181/// use polyglot_sql::trie::new_trie;
182///
183/// let trie = new_trie([
184///     ("foo".to_string(), 1),
185///     ("bar".to_string(), 2),
186/// ]);
187/// assert_eq!(trie.get("foo"), Some(&1));
188/// ```
189pub fn new_trie<V, I>(keywords: I) -> Trie<V>
190where
191    I: IntoIterator<Item = (String, V)>,
192{
193    let mut trie = Trie::new();
194    for (key, value) in keywords {
195        trie.insert(&key, value);
196    }
197    trie
198}
199
200/// Create a new trie from an iterator of keys (values are unit type)
201///
202/// Useful when you only need to check for key presence.
203///
204/// # Example
205///
206/// ```
207/// use polyglot_sql::trie::{new_trie_from_keys, TrieResult};
208///
209/// let trie = new_trie_from_keys(["SELECT", "FROM", "WHERE"]);
210/// assert_eq!(trie.in_trie("SELECT").0, TrieResult::Exists);
211/// ```
212pub fn new_trie_from_keys<I, S>(keywords: I) -> Trie<()>
213where
214    I: IntoIterator<Item = S>,
215    S: AsRef<str>,
216{
217    let mut trie = Trie::new();
218    for key in keywords {
219        trie.insert(key.as_ref(), ());
220    }
221    trie
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_trie() {
        let trie = new_trie([
            ("bla".to_string(), ()),
            ("foo".to_string(), ()),
            ("blab".to_string(), ()),
        ]);

        assert_eq!(trie.in_trie("bla").0, TrieResult::Exists);
        assert_eq!(trie.in_trie("blab").0, TrieResult::Exists);
        assert_eq!(trie.in_trie("foo").0, TrieResult::Exists);
    }

    #[test]
    fn test_in_trie_failed() {
        let trie = new_trie_from_keys(["cat"]);
        assert_eq!(trie.in_trie("bob"), (TrieResult::Failed, None));
    }

    #[test]
    fn test_in_trie_prefix() {
        let trie = new_trie_from_keys(["cat"]);
        assert_eq!(trie.in_trie("ca"), (TrieResult::Prefix, None));
    }

    #[test]
    fn test_in_trie_exists() {
        let trie = new_trie_from_keys(["cat"]);
        assert_eq!(trie.in_trie("cat"), (TrieResult::Exists, Some(&())));
    }

    #[test]
    fn test_empty_key() {
        let trie = new_trie_from_keys(["cat"]);
        assert_eq!(trie.in_trie("").0, TrieResult::Failed);
    }

    #[test]
    fn test_get_value() {
        let trie = new_trie([
            ("foo".to_string(), 42),
            ("bar".to_string(), 100),
        ]);

        assert_eq!(trie.get("foo"), Some(&42));
        assert_eq!(trie.get("bar"), Some(&100));
        assert_eq!(trie.get("baz"), None);
        assert_eq!(trie.get("fo"), None); // Prefix only
    }

    #[test]
    fn test_insert_overwrites_value() {
        let mut trie = Trie::new();
        trie.insert("foo", 1);
        trie.insert("foo", 2);

        assert_eq!(trie.get("foo"), Some(&2));
        assert_eq!(trie.keys(), vec!["foo"]);
    }

    #[test]
    fn test_is_empty() {
        let mut trie: Trie<i32> = Trie::new();
        assert!(trie.is_empty());
        assert!(trie.keys().is_empty());

        trie.insert("a", 1);
        assert!(!trie.is_empty());
    }

    #[test]
    fn test_in_trie_char() {
        let trie = new_trie_from_keys(["cat", "car"]);

        // Start from root
        let (result, subtrie) = trie.in_trie_char('c');
        assert_eq!(result, TrieResult::Prefix);
        assert!(subtrie.is_some());

        // Continue with 'a'
        let subtrie = subtrie.unwrap();
        let (result, subtrie) = subtrie.in_trie_char('a');
        assert_eq!(result, TrieResult::Prefix);
        assert!(subtrie.is_some());

        // Continue with 't' (reaches 'cat')
        let subtrie = subtrie.unwrap();
        let (result, _) = subtrie.in_trie_char('t');
        assert_eq!(result, TrieResult::Exists);

        // Try 'd' which doesn't exist
        let (result, subtrie) = trie.in_trie_char('d');
        assert_eq!(result, TrieResult::Failed);
        assert!(subtrie.is_none());
    }

    #[test]
    fn test_keys() {
        let trie = new_trie_from_keys(["cat", "car", "card"]);
        let mut keys = trie.keys();
        // `keys()` follows HashMap iteration order, so sort before comparing.
        keys.sort();
        assert_eq!(keys, vec!["car", "card", "cat"]);
    }

    #[test]
    fn test_unicode() {
        // U+00E9 is "e" with an acute accent, so the two keys share the
        // prefix "caf" and differ only in their final character.
        let trie = new_trie_from_keys(["cafe", "caf\u{00e9}"]);
        assert_eq!(trie.in_trie("cafe").0, TrieResult::Exists);
        assert_eq!(trie.in_trie("caf\u{00e9}").0, TrieResult::Exists);
        assert_eq!(trie.in_trie("caf").0, TrieResult::Prefix);
    }

    #[test]
    fn test_overlapping_prefixes() {
        // Test case from sqlglot: "bla" and "blab"
        let trie = new_trie_from_keys(["bla", "blab"]);

        // "bla" should exist
        assert_eq!(trie.in_trie("bla").0, TrieResult::Exists);

        // "blab" should exist
        assert_eq!(trie.in_trie("blab").0, TrieResult::Exists);

        // "bl" should be prefix
        assert_eq!(trie.in_trie("bl").0, TrieResult::Prefix);
    }
}