anomaly_grid/
context_trie.rs

1//! Trie-based context storage for memory-efficient prefix sharing
2//!
3//! This module implements a trie (prefix tree) data structure for storing
4//! variable-order Markov chain contexts with significant memory savings
5//! through prefix sharing.
6
7use crate::context_tree::ContextNode;
8use crate::string_interner::{StateId, StringInterner};
9use smallvec::{SmallVec, smallvec};
10use std::sync::Arc;
11
/// Compact, copyable handle identifying one node of a context trie.
///
/// The wrapped `u32` is the node's index into the trie's node storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(u32);

impl NodeId {
    /// Wraps the raw index `id` as a `NodeId`.
    pub fn new(id: u32) -> Self {
        NodeId(id)
    }

    /// Returns the raw index carried by this handle.
    pub fn id(self) -> u32 {
        let NodeId(raw) = self;
        raw
    }
}
27
28/// A node in the context trie
29#[derive(Debug, Clone, Default)]
30pub struct TrieNode {
31    /// Children nodes: (StateId, NodeId) pairs
32    /// Uses SmallVec for memory efficiency since most nodes have few children
33    children: SmallVec<[(StateId, NodeId); 4]>,
34    
35    /// Context data if this node represents a complete context
36    context_data: Option<ContextNode>,
37    
38    /// Parent node for navigation (None for root)
39    parent: Option<NodeId>,
40    
41    /// The state that led to this node from parent
42    state_from_parent: Option<StateId>,
43}
44
45impl TrieNode {
46    /// Create a new empty trie node
47    pub fn new(parent: Option<NodeId>, state_from_parent: Option<StateId>) -> Self {
48        Self {
49            children: smallvec![],
50            context_data: None,
51            parent,
52            state_from_parent,
53        }
54    }
55
56    /// Add a child node
57    pub fn add_child(&mut self, state: StateId, child_id: NodeId) {
58        // Check if child already exists
59        for (existing_state, existing_id) in &mut self.children {
60            if *existing_state == state {
61                *existing_id = child_id;
62                return;
63            }
64        }
65        
66        // Add new child
67        self.children.push((state, child_id));
68    }
69
70    /// Get child node ID for a given state
71    pub fn get_child(&self, state: StateId) -> Option<NodeId> {
72        self.children
73            .iter()
74            .find(|(s, _)| *s == state)
75            .map(|(_, id)| *id)
76    }
77
78    /// Get all children
79    pub fn children(&self) -> &[(StateId, NodeId)] {
80        &self.children
81    }
82
83    /// Set context data for this node
84    pub fn set_context_data(&mut self, data: ContextNode) {
85        self.context_data = Some(data);
86    }
87
88    /// Get context data if present
89    pub fn context_data(&self) -> Option<&ContextNode> {
90        self.context_data.as_ref()
91    }
92
93    /// Get mutable context data if present
94    pub fn context_data_mut(&mut self) -> Option<&mut ContextNode> {
95        self.context_data.as_mut()
96    }
97
98    /// Get parent node ID
99    pub fn parent(&self) -> Option<NodeId> {
100        self.parent
101    }
102
103    /// Get the state that led to this node from parent
104    pub fn state_from_parent(&self) -> Option<StateId> {
105        self.state_from_parent
106    }
107
108    /// Check if this node has context data
109    pub fn has_context_data(&self) -> bool {
110        self.context_data.is_some()
111    }
112
113    /// Get memory usage estimate for this node
114    pub fn memory_usage(&self) -> usize {
115        let mut size = std::mem::size_of::<Self>();
116        
117        // Children storage
118        size += self.children.capacity() * std::mem::size_of::<(StateId, NodeId)>();
119        
120        // Context data if present
121        if let Some(ref data) = self.context_data {
122            size += std::mem::size_of::<ContextNode>();
123            // Add estimated size of transition counts
124            size += data.vocab_size() * std::mem::size_of::<(StateId, usize)>();
125        }
126        
127        size
128    }
129
130    /// Reset the trie node for reuse in memory pool
131    pub fn reset(&mut self, parent: Option<NodeId>, state_from_parent: Option<StateId>) {
132        self.children.clear();
133        self.context_data = None;
134        self.parent = parent;
135        self.state_from_parent = state_from_parent;
136    }
137
138    /// Clear the trie node data for memory pool return
139    pub fn clear(&mut self) {
140        self.children.clear();
141        self.context_data = None;
142        self.parent = None;
143        self.state_from_parent = None;
144    }
145}
146
147/// Trie-based context storage with prefix sharing
148#[derive(Debug, Clone)]
149pub struct ContextTrie {
150    /// All nodes in the trie stored in a vector for cache efficiency
151    nodes: Vec<TrieNode>,
152    
153    /// Root node ID
154    root: NodeId,
155    
156    /// Free node IDs for reuse
157    free_nodes: Vec<NodeId>,
158    
159    /// Maximum context order
160    max_order: usize,
161    
162    /// String interner for state management
163    interner: Arc<StringInterner>,
164}
165
166impl ContextTrie {
167    /// Create a new context trie
168    pub fn new(max_order: usize, interner: Arc<StringInterner>) -> Self {
169        let mut nodes = Vec::new();
170        let root = NodeId::new(0);
171        
172        // Create root node
173        nodes.push(TrieNode::new(None, None));
174        
175        Self {
176            nodes,
177            root,
178            free_nodes: Vec::new(),
179            max_order,
180            interner,
181        }
182    }
183
184    /// Allocate a new node ID
185    fn allocate_node_id(&mut self) -> NodeId {
186        if let Some(id) = self.free_nodes.pop() {
187            id
188        } else {
189            let id = NodeId::new(self.nodes.len() as u32);
190            self.nodes.push(TrieNode::new(None, None));
191            id
192        }
193    }
194
195    /// Get a node by ID
196    fn get_node(&self, id: NodeId) -> Option<&TrieNode> {
197        self.nodes.get(id.id() as usize)
198    }
199
200    /// Get a mutable node by ID
201    fn get_node_mut(&mut self, id: NodeId) -> Option<&mut TrieNode> {
202        self.nodes.get_mut(id.id() as usize)
203    }
204
205    /// Insert a context path and return the node ID for the final context
206    pub fn insert_context_path(&mut self, context: &[StateId]) -> NodeId {
207        let mut current_id = self.root;
208        
209        for &state in context {
210            let next_id = {
211                let current_node = self.get_node(current_id).expect("Invalid node ID");
212                current_node.get_child(state)
213            };
214            
215            current_id = if let Some(existing_id) = next_id {
216                existing_id
217            } else {
218                // Create new child node
219                let new_id = self.allocate_node_id();
220                
221                // Set up the new node
222                if let Some(new_node) = self.get_node_mut(new_id) {
223                    new_node.parent = Some(current_id);
224                    new_node.state_from_parent = Some(state);
225                }
226                
227                // Add child to current node
228                if let Some(current_node) = self.get_node_mut(current_id) {
229                    current_node.add_child(state, new_id);
230                }
231                
232                new_id
233            };
234        }
235        
236        current_id
237    }
238
239    /// Get the node ID for a context path
240    pub fn get_context_node_id(&self, context: &[StateId]) -> Option<NodeId> {
241        let mut current_id = self.root;
242        
243        for &state in context {
244            let current_node = self.get_node(current_id)?;
245            current_id = current_node.get_child(state)?;
246        }
247        
248        Some(current_id)
249    }
250
251    /// Get context data for a given context
252    pub fn get_context_data(&self, context: &[StateId]) -> Option<&ContextNode> {
253        let node_id = self.get_context_node_id(context)?;
254        let node = self.get_node(node_id)?;
255        node.context_data()
256    }
257
258    /// Get mutable context data for a given context
259    pub fn get_context_data_mut(&mut self, context: &[StateId]) -> Option<&mut ContextNode> {
260        let node_id = self.get_context_node_id(context)?;
261        let node = self.get_node_mut(node_id)?;
262        node.context_data_mut()
263    }
264
265    /// Set context data for a given context
266    pub fn set_context_data(&mut self, context: &[StateId], data: ContextNode) {
267        let node_id = self.insert_context_path(context);
268        if let Some(node) = self.get_node_mut(node_id) {
269            node.set_context_data(data);
270        }
271    }
272
273    /// Get or create context data for a given context
274    pub fn get_or_create_context_data(&mut self, context: &[StateId]) -> &mut ContextNode {
275        let node_id = self.insert_context_path(context);
276        
277        // Check if context data exists
278        let needs_creation = {
279            let node = self.get_node(node_id).expect("Invalid node ID");
280            !node.has_context_data()
281        };
282        
283        if needs_creation {
284            let new_data = ContextNode::new(Arc::clone(&self.interner));
285            if let Some(node) = self.get_node_mut(node_id) {
286                node.set_context_data(new_data);
287            }
288        }
289        
290        self.get_node_mut(node_id)
291            .expect("Invalid node ID")
292            .context_data_mut()
293            .expect("Context data should exist")
294    }
295
296    /// Iterate over all contexts with data
297    pub fn iter_contexts(&self) -> impl Iterator<Item = (Vec<StateId>, &ContextNode)> {
298        ContextTrieIterator::new(self)
299    }
300
301    /// Get the number of contexts with data
302    pub fn context_count(&self) -> usize {
303        self.nodes
304            .iter()
305            .filter(|node| node.has_context_data())
306            .count()
307    }
308
309    /// Get the total number of nodes in the trie
310    pub fn node_count(&self) -> usize {
311        self.nodes.len()
312    }
313
314    /// Get memory usage estimate
315    pub fn memory_usage(&self) -> usize {
316        let mut total = std::mem::size_of::<Self>();
317        
318        // Node storage
319        total += self.nodes.capacity() * std::mem::size_of::<TrieNode>();
320        
321        // Individual node memory usage
322        for node in &self.nodes {
323            total += node.memory_usage();
324        }
325        
326        // Free nodes vector
327        total += self.free_nodes.capacity() * std::mem::size_of::<NodeId>();
328        
329        total
330    }
331
332    /// Get access to the string interner
333    pub fn interner(&self) -> &Arc<StringInterner> {
334        &self.interner
335    }
336
337    /// Get maximum order
338    pub fn max_order(&self) -> usize {
339        self.max_order
340    }
341}
342
343/// Iterator over contexts in the trie
344pub struct ContextTrieIterator<'a> {
345    trie: &'a ContextTrie,
346    stack: Vec<(NodeId, Vec<StateId>)>,
347}
348
349impl<'a> ContextTrieIterator<'a> {
350    fn new(trie: &'a ContextTrie) -> Self {
351        let stack = vec![(trie.root, Vec::new())];
352        
353        Self { trie, stack }
354    }
355}
356
357impl<'a> Iterator for ContextTrieIterator<'a> {
358    type Item = (Vec<StateId>, &'a ContextNode);
359
360    fn next(&mut self) -> Option<Self::Item> {
361        while let Some((node_id, path)) = self.stack.pop() {
362            if let Some(node) = self.trie.get_node(node_id) {
363                // Add children to stack for further exploration
364                for &(state, child_id) in node.children() {
365                    let mut child_path = path.clone();
366                    child_path.push(state);
367                    self.stack.push((child_id, child_path));
368                }
369                
370                // If this node has context data, return it
371                if let Some(context_data) = node.context_data() {
372                    return Some((path, context_data));
373                }
374            }
375        }
376        
377        None
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Insertion, lookup and data attachment on a single two-state context.
    #[test]
    fn test_trie_basic_operations() {
        let interner = Arc::new(StringInterner::new());
        let mut trie = ContextTrie::new(3, Arc::clone(&interner));

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);

        let context = vec![state_a, state_b];
        let node_id = trie.insert_context_path(&context);

        // Lookup and repeated insertion both resolve to the same node.
        assert_eq!(trie.get_context_node_id(&context), Some(node_id));
        assert_eq!(trie.insert_context_path(&context), node_id);

        // An unknown path is not found.
        assert_eq!(trie.get_context_node_id(&[state_b]), None);

        // Creating the path alone does not attach data.
        assert!(trie.get_context_data(&context).is_none());

        let context_data = ContextNode::new(Arc::clone(&interner));
        trie.set_context_data(&context, context_data);
        assert!(trie.get_context_data(&context).is_some());

        // The intermediate prefix node exists but carries no data.
        assert!(trie.get_context_node_id(&[state_a]).is_some());
        assert!(trie.get_context_data(&[state_a]).is_none());
    }

    /// Contexts with common prefixes share trie nodes.
    #[test]
    fn test_trie_prefix_sharing() {
        let interner = Arc::new(StringInterner::new());
        let mut trie = ContextTrie::new(3, Arc::clone(&interner));

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);
        let state_c = StateId::new(3);

        let context1 = vec![state_a, state_b];
        let context2 = vec![state_a, state_b, state_c];
        let context3 = vec![state_a, state_c];

        trie.insert_context_path(&context1);
        trie.insert_context_path(&context2);
        trie.insert_context_path(&context3);

        // Root + A + A->B + A->B->C + A->C: the A and A->B prefixes are shared.
        assert_eq!(trie.node_count(), 5);

        assert!(trie.get_context_node_id(&context1).is_some());
        assert!(trie.get_context_node_id(&context2).is_some());
        assert!(trie.get_context_node_id(&context3).is_some());
    }

    /// Iteration yields exactly the contexts that carry data, with their paths.
    #[test]
    fn test_trie_iteration() {
        let interner = Arc::new(StringInterner::new());
        let mut trie = ContextTrie::new(2, Arc::clone(&interner));

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);

        let context1 = vec![state_a];
        let context2 = vec![state_a, state_b];

        trie.set_context_data(&context1, ContextNode::new(Arc::clone(&interner)));
        trie.set_context_data(&context2, ContextNode::new(Arc::clone(&interner)));

        let contexts: Vec<_> = trie.iter_contexts().collect();
        assert_eq!(contexts.len(), 2);
        assert!(contexts.iter().any(|(path, _)| *path == context1));
        assert!(contexts.iter().any(|(path, _)| *path == context2));

        assert_eq!(trie.context_count(), 2);
    }

    /// Calling `get_or_create_context_data` twice reuses the same node and data.
    #[test]
    fn test_get_or_create_context_data_reuses_existing_node() {
        let interner = Arc::new(StringInterner::new());
        let mut trie = ContextTrie::new(2, Arc::clone(&interner));
        let context = vec![StateId::new(1), StateId::new(2)];

        trie.get_or_create_context_data(&context);
        trie.get_or_create_context_data(&context);

        assert_eq!(trie.context_count(), 1);
        // Root + two path nodes.
        assert_eq!(trie.node_count(), 3);
        assert!(trie.get_context_data(&context).is_some());
    }

    /// Adding a context with data increases the memory estimate.
    #[test]
    fn test_memory_usage_calculation() {
        let interner = Arc::new(StringInterner::new());
        let mut trie = ContextTrie::new(2, Arc::clone(&interner));

        let initial_usage = trie.memory_usage();
        assert!(initial_usage > 0);

        let context = vec![StateId::new(1)];
        trie.set_context_data(&context, ContextNode::new(Arc::clone(&interner)));

        let final_usage = trie.memory_usage();
        assert!(final_usage > initial_usage);
    }
}