anomaly_grid/
context_trie.rs1use crate::context_tree::ContextNode;
8use crate::string_interner::{StateId, StringInterner};
9use smallvec::{smallvec, SmallVec};
10use std::sync::Arc;
11
/// Identifier of a node, wrapping the raw `u32` used to index node storage.
///
/// The type is `Copy`, so it is passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(u32);

impl NodeId {
    /// Wraps a raw `u32` as a `NodeId`.
    pub fn new(id: u32) -> Self {
        NodeId(id)
    }

    /// Returns the raw `u32` wrapped by this id.
    pub fn id(self) -> u32 {
        let NodeId(raw) = self;
        raw
    }
}
27
28#[derive(Debug, Clone, Default)]
30pub struct TrieNode {
31 children: SmallVec<[(StateId, NodeId); 4]>,
34
35 context_data: Option<ContextNode>,
37
38 parent: Option<NodeId>,
40
41 state_from_parent: Option<StateId>,
43}
44
45impl TrieNode {
46 pub fn new(parent: Option<NodeId>, state_from_parent: Option<StateId>) -> Self {
48 Self {
49 children: smallvec![],
50 context_data: None,
51 parent,
52 state_from_parent,
53 }
54 }
55
56 pub fn add_child(&mut self, state: StateId, child_id: NodeId) {
58 for (existing_state, existing_id) in &mut self.children {
60 if *existing_state == state {
61 *existing_id = child_id;
62 return;
63 }
64 }
65
66 self.children.push((state, child_id));
68 }
69
70 pub fn get_child(&self, state: StateId) -> Option<NodeId> {
72 self.children
73 .iter()
74 .find(|(s, _)| *s == state)
75 .map(|(_, id)| *id)
76 }
77
78 pub fn children(&self) -> &[(StateId, NodeId)] {
80 &self.children
81 }
82
83 pub fn set_context_data(&mut self, data: ContextNode) {
85 self.context_data = Some(data);
86 }
87
88 pub fn context_data(&self) -> Option<&ContextNode> {
90 self.context_data.as_ref()
91 }
92
93 pub fn context_data_mut(&mut self) -> Option<&mut ContextNode> {
95 self.context_data.as_mut()
96 }
97
98 pub fn parent(&self) -> Option<NodeId> {
100 self.parent
101 }
102
103 pub fn state_from_parent(&self) -> Option<StateId> {
105 self.state_from_parent
106 }
107
108 pub fn has_context_data(&self) -> bool {
110 self.context_data.is_some()
111 }
112
113 pub fn memory_usage(&self) -> usize {
115 let mut size = std::mem::size_of::<Self>();
116
117 size += self.children.capacity() * std::mem::size_of::<(StateId, NodeId)>();
119
120 if let Some(ref data) = self.context_data {
122 size += std::mem::size_of::<ContextNode>();
123 size += data.vocab_size() * std::mem::size_of::<(StateId, usize)>();
125 }
126
127 size
128 }
129
130 pub fn reset(&mut self, parent: Option<NodeId>, state_from_parent: Option<StateId>) {
132 self.children.clear();
133 self.context_data = None;
134 self.parent = parent;
135 self.state_from_parent = state_from_parent;
136 }
137
138 pub fn clear(&mut self) {
140 self.children.clear();
141 self.context_data = None;
142 self.parent = None;
143 self.state_from_parent = None;
144 }
145}
146
147#[derive(Debug, Clone)]
149pub struct ContextTrie {
150 nodes: Vec<TrieNode>,
152
153 root: NodeId,
155
156 free_nodes: Vec<NodeId>,
158
159 max_order: usize,
161
162 interner: Arc<StringInterner>,
164}
165
166impl ContextTrie {
167 pub fn new(max_order: usize, interner: Arc<StringInterner>) -> Self {
169 let mut nodes = Vec::new();
170 let root = NodeId::new(0);
171
172 nodes.push(TrieNode::new(None, None));
174
175 Self {
176 nodes,
177 root,
178 free_nodes: Vec::new(),
179 max_order,
180 interner,
181 }
182 }
183
184 fn allocate_node_id(&mut self) -> NodeId {
186 if let Some(id) = self.free_nodes.pop() {
187 id
188 } else {
189 let id = NodeId::new(self.nodes.len() as u32);
190 self.nodes.push(TrieNode::new(None, None));
191 id
192 }
193 }
194
195 fn get_node(&self, id: NodeId) -> Option<&TrieNode> {
197 self.nodes.get(id.id() as usize)
198 }
199
200 fn get_node_mut(&mut self, id: NodeId) -> Option<&mut TrieNode> {
202 self.nodes.get_mut(id.id() as usize)
203 }
204
205 pub fn insert_context_path(&mut self, context: &[StateId]) -> NodeId {
207 let mut current_id = self.root;
208
209 for &state in context {
210 let next_id = {
211 let current_node = self.get_node(current_id).expect("Invalid node ID");
212 current_node.get_child(state)
213 };
214
215 current_id = if let Some(existing_id) = next_id {
216 existing_id
217 } else {
218 let new_id = self.allocate_node_id();
220
221 if let Some(new_node) = self.get_node_mut(new_id) {
223 new_node.parent = Some(current_id);
224 new_node.state_from_parent = Some(state);
225 }
226
227 if let Some(current_node) = self.get_node_mut(current_id) {
229 current_node.add_child(state, new_id);
230 }
231
232 new_id
233 };
234 }
235
236 current_id
237 }
238
239 pub fn get_context_node_id(&self, context: &[StateId]) -> Option<NodeId> {
241 let mut current_id = self.root;
242
243 for &state in context {
244 let current_node = self.get_node(current_id)?;
245 current_id = current_node.get_child(state)?;
246 }
247
248 Some(current_id)
249 }
250
251 pub fn get_context_data(&self, context: &[StateId]) -> Option<&ContextNode> {
253 let node_id = self.get_context_node_id(context)?;
254 let node = self.get_node(node_id)?;
255 node.context_data()
256 }
257
258 pub fn get_context_data_mut(&mut self, context: &[StateId]) -> Option<&mut ContextNode> {
260 let node_id = self.get_context_node_id(context)?;
261 let node = self.get_node_mut(node_id)?;
262 node.context_data_mut()
263 }
264
265 pub fn set_context_data(&mut self, context: &[StateId], data: ContextNode) {
267 let node_id = self.insert_context_path(context);
268 if let Some(node) = self.get_node_mut(node_id) {
269 node.set_context_data(data);
270 }
271 }
272
273 pub fn get_or_create_context_data(&mut self, context: &[StateId]) -> &mut ContextNode {
275 let node_id = self.insert_context_path(context);
276
277 let needs_creation = {
279 let node = self.get_node(node_id).expect("Invalid node ID");
280 !node.has_context_data()
281 };
282
283 if needs_creation {
284 let new_data = ContextNode::new(Arc::clone(&self.interner));
285 if let Some(node) = self.get_node_mut(node_id) {
286 node.set_context_data(new_data);
287 }
288 }
289
290 self.get_node_mut(node_id)
291 .expect("Invalid node ID")
292 .context_data_mut()
293 .expect("Context data should exist")
294 }
295
296 pub fn iter_contexts(&self) -> impl Iterator<Item = (Vec<StateId>, &ContextNode)> {
298 ContextTrieIterator::new(self)
299 }
300
301 pub fn context_count(&self) -> usize {
303 self.nodes
304 .iter()
305 .filter(|node| node.has_context_data())
306 .count()
307 }
308
309 pub fn node_count(&self) -> usize {
311 self.nodes.len()
312 }
313
314 pub fn memory_usage(&self) -> usize {
316 let mut total = std::mem::size_of::<Self>();
317
318 total += self.nodes.capacity() * std::mem::size_of::<TrieNode>();
320
321 for node in &self.nodes {
323 total += node.memory_usage();
324 }
325
326 total += self.free_nodes.capacity() * std::mem::size_of::<NodeId>();
328
329 total
330 }
331
332 pub fn interner(&self) -> &Arc<StringInterner> {
334 &self.interner
335 }
336
337 pub fn max_order(&self) -> usize {
339 self.max_order
340 }
341}
342
343pub struct ContextTrieIterator<'a> {
345 trie: &'a ContextTrie,
346 stack: Vec<(NodeId, Vec<StateId>)>,
347}
348
349impl<'a> ContextTrieIterator<'a> {
350 fn new(trie: &'a ContextTrie) -> Self {
351 let stack = vec![(trie.root, Vec::new())];
352
353 Self { trie, stack }
354 }
355}
356
357impl<'a> Iterator for ContextTrieIterator<'a> {
358 type Item = (Vec<StateId>, &'a ContextNode);
359
360 fn next(&mut self) -> Option<Self::Item> {
361 while let Some((node_id, path)) = self.stack.pop() {
362 if let Some(node) = self.trie.get_node(node_id) {
363 for &(state, child_id) in node.children() {
365 let mut child_path = path.clone();
366 child_path.push(state);
367 self.stack.push((child_id, child_path));
368 }
369
370 if let Some(context_data) = node.context_data() {
372 return Some((path, context_data));
373 }
374 }
375 }
376
377 None
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh interner together with a trie of the given order that
    /// shares it.
    fn fixture(max_order: usize) -> (Arc<StringInterner>, ContextTrie) {
        let interner = Arc::new(StringInterner::new());
        let trie = ContextTrie::new(max_order, Arc::clone(&interner));
        (interner, trie)
    }

    /// Inserting a path, looking it up, and attaching data all agree on the
    /// same node.
    #[test]
    fn test_trie_basic_operations() {
        let (interner, mut trie) = fixture(3);

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);
        let _state_c = StateId::new(3);

        let context = vec![state_a, state_b];

        let inserted = trie.insert_context_path(&context);
        let looked_up = trie.get_context_node_id(&context);
        assert_eq!(Some(inserted), looked_up);

        trie.set_context_data(&context, ContextNode::new(Arc::clone(&interner)));

        let retrieved_data = trie.get_context_data(&context);
        assert!(retrieved_data.is_some());
    }

    /// Paths with a common prefix reuse the prefix nodes.
    #[test]
    fn test_trie_prefix_sharing() {
        let (_interner, mut trie) = fixture(3);

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);
        let state_c = StateId::new(3);

        let context1 = vec![state_a, state_b];
        let context2 = vec![state_a, state_b, state_c];
        let context3 = vec![state_a, state_c];

        for context in [&context1, &context2, &context3].iter() {
            trie.insert_context_path(context);
        }

        let node_count = trie.node_count();
        assert!(node_count <= 6);

        assert!(trie.get_context_node_id(&context1).is_some());
        assert!(trie.get_context_node_id(&context2).is_some());
        assert!(trie.get_context_node_id(&context3).is_some());
    }

    /// Iteration yields exactly the contexts that carry data.
    #[test]
    fn test_trie_iteration() {
        let (interner, mut trie) = fixture(2);

        let state_a = StateId::new(1);
        let state_b = StateId::new(2);

        let context1 = vec![state_a];
        let context2 = vec![state_a, state_b];

        trie.set_context_data(&context1, ContextNode::new(Arc::clone(&interner)));
        trie.set_context_data(&context2, ContextNode::new(Arc::clone(&interner)));

        let contexts: Vec<_> = trie.iter_contexts().collect();
        assert_eq!(contexts.len(), 2);

        assert_eq!(trie.context_count(), 2);
    }

    /// The memory estimate is positive and grows when a context is added.
    #[test]
    fn test_memory_usage_calculation() {
        let (interner, mut trie) = fixture(2);

        let initial_usage = trie.memory_usage();
        assert!(initial_usage > 0);

        let context = vec![StateId::new(1)];
        trie.set_context_data(&context, ContextNode::new(Arc::clone(&interner)));

        let final_usage = trie.memory_usage();
        assert!(final_usage > initial_usage);
    }
}