// anomaly_grid/context_tree/mod.rs
use crate::config::AnomalyGridConfig;
use crate::context_trie::ContextTrie;
use crate::error::{AnomalyGridError, AnomalyGridResult};
use crate::string_interner::{StateId, StringInterner};
use crate::transition_counts::TransitionCounts;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
20
/// Per-context statistics: transition counts out of one context plus
/// memoized entropy/KL metrics.
///
/// State strings are stored as interned `StateId`s; the shared
/// [`StringInterner`] maps ids back to strings on demand.
#[derive(Debug, Clone)]
pub struct ContextNode {
    // Count of each observed next-state (keyed by interned id).
    counts: TransitionCounts,
    // Sum of all counts; kept alongside `counts` to avoid re-summing.
    total_count: usize,
    // Shared interner; also mutated on read paths that call
    // `get_or_intern` (presumably via interior mutability — it is only
    // ever held behind `Arc`).
    interner: Arc<StringInterner>,
    // Memoized Shannon entropy for the config identified by
    // `cached_config_hash`; `None` when not yet computed or invalidated.
    cached_entropy: Option<f64>,
    // Memoized KL divergence from uniform, same caching scheme.
    cached_kl_divergence: Option<f64>,
    // Hash of the config fields used by the cached metrics above.
    cached_config_hash: Option<u64>,
}
40
41impl ContextNode {
42 pub fn new(interner: Arc<StringInterner>) -> Self {
44 Self {
45 counts: TransitionCounts::new(),
46 total_count: 0,
47 interner,
48 cached_entropy: None,
49 cached_kl_divergence: None,
50 cached_config_hash: None,
51 }
52 }
53
54 pub fn add_transition(&mut self, next_state: &str) {
56 let state_id = self.interner.get_or_intern(next_state);
57 self.counts.increment(state_id);
58 self.total_count += 1;
59 self.invalidate_cache();
60 }
61
62 pub fn add_transition_by_id(&mut self, state_id: StateId) {
64 self.counts.increment(state_id);
65 self.total_count += 1;
66 self.invalidate_cache();
67 }
68
69 fn invalidate_cache(&mut self) {
71 self.cached_entropy = None;
72 self.cached_kl_divergence = None;
73 self.cached_config_hash = None;
74 }
75
76 fn compute_config_hash(config: &AnomalyGridConfig) -> u64 {
78 let mut hasher = DefaultHasher::new();
79 config.smoothing_alpha.to_bits().hash(&mut hasher);
81 hasher.finish()
82 }
83
84 fn is_cache_valid(&self, config: &AnomalyGridConfig) -> bool {
86 if let Some(cached_hash) = self.cached_config_hash {
87 cached_hash == Self::compute_config_hash(config)
88 } else {
89 false
90 }
91 }
92
93 pub fn total_count(&self) -> usize {
95 self.total_count
96 }
97
98 pub fn get_count(&self, next_state: &str) -> usize {
100 let state_id = self.interner.get_or_intern(next_state);
101 self.counts.get(state_id)
102 }
103
104 pub fn get_count_by_id(&self, state_id: StateId) -> usize {
106 self.counts.get(state_id)
107 }
108
109 pub fn vocab_size(&self) -> usize {
111 self.counts.len()
112 }
113
114 pub fn get_state_counts(&self) -> impl Iterator<Item = (StateId, usize)> + '_ {
116 self.counts.iter()
117 }
118
119 pub fn total_transitions(&self) -> usize {
121 self.total_count
122 }
123
124 pub fn get_string_counts(&self) -> HashMap<String, usize> {
126 self.counts
127 .iter()
128 .filter_map(|(state_id, count)| self.interner.get_string(state_id).map(|s| (s, count)))
129 .collect()
130 }
131
132 pub fn counts(&self) -> HashMap<String, usize> {
134 self.get_string_counts()
135 }
136
137 pub fn get_probability(&self, next_state: &str, config: &AnomalyGridConfig) -> f64 {
141 let state_id = self.interner.get_or_intern(next_state);
142 self.get_probability_by_id(state_id, config)
143 }
144
145 pub fn get_probability_by_id(&self, state_id: StateId, config: &AnomalyGridConfig) -> f64 {
147 if self.total_count == 0 {
148 return 1.0 / (self.vocab_size() as f64).max(1.0);
149 }
150
151 let count = self.get_count_by_id(state_id) as f64;
152 let vocab_size = self.vocab_size() as f64;
153
154 (count + config.smoothing_alpha)
155 / (self.total_count as f64 + config.smoothing_alpha * vocab_size)
156 }
157
158 pub fn get_probability_normalized(
160 &self,
161 next_state: &str,
162 config: &AnomalyGridConfig,
163 global_vocab_size: usize,
164 ) -> f64 {
165 if self.total_count == 0 {
166 return 1.0 / (global_vocab_size as f64).max(1.0);
167 }
168
169 let state_id = self.interner.get_or_intern(next_state);
170 let count = self.get_count_by_id(state_id) as f64;
171 let global_vocab_size_f64 = global_vocab_size as f64;
172
173 (count + config.smoothing_alpha)
176 / (self.total_count as f64 + config.smoothing_alpha * global_vocab_size_f64)
177 }
178
179 pub fn get_probability_normalized_by_id(
181 &self,
182 state_id: StateId,
183 config: &AnomalyGridConfig,
184 global_vocab_size: usize,
185 ) -> f64 {
186 if self.total_count == 0 {
187 return 1.0 / (global_vocab_size as f64).max(1.0);
188 }
189
190 let count = self.get_count_by_id(state_id) as f64;
191 let global_vocab_size_f64 = global_vocab_size as f64;
192
193 (count + config.smoothing_alpha)
194 / (self.total_count as f64 + config.smoothing_alpha * global_vocab_size_f64)
195 }
196
197 pub fn calculate_entropy(&mut self, config: &AnomalyGridConfig) -> f64 {
199 if self.is_cache_valid(config) {
201 if let Some(cached_entropy) = self.cached_entropy {
202 return cached_entropy;
203 }
204 }
205
206 let entropy = if self.total_count == 0 {
208 0.0
209 } else {
210 self.counts
211 .keys()
212 .map(|state_id| {
213 let p = self.get_probability_by_id(state_id, config);
214 if p > 0.0 {
215 -p * p.log2()
216 } else {
217 0.0
218 }
219 })
220 .sum()
221 };
222
223 self.cached_entropy = Some(entropy);
225 self.cached_config_hash = Some(Self::compute_config_hash(config));
226
227 entropy
228 }
229
230 pub fn compute_entropy(&self, config: &AnomalyGridConfig) -> f64 {
232 if self.total_count == 0 {
233 return 0.0;
234 }
235
236 self.counts
237 .keys()
238 .map(|state_id| {
239 let p = self.get_probability_by_id(state_id, config);
240 if p > 0.0 {
241 -p * p.log2()
242 } else {
243 0.0
244 }
245 })
246 .sum()
247 }
248
249 pub fn calculate_kl_divergence(&mut self, config: &AnomalyGridConfig) -> f64 {
251 if self.is_cache_valid(config) {
253 if let Some(cached_kl_div) = self.cached_kl_divergence {
254 return cached_kl_div;
255 }
256 }
257
258 let kl_divergence = if self.total_count == 0 {
260 0.0
261 } else {
262 let uniform_prob = 1.0 / self.vocab_size() as f64;
263
264 self.counts
265 .keys()
266 .map(|state_id| {
267 let p = self.get_probability_by_id(state_id, config);
268 if p > 0.0 {
269 p * (p / uniform_prob).log2()
270 } else {
271 0.0
272 }
273 })
274 .sum()
275 };
276
277 self.cached_kl_divergence = Some(kl_divergence);
279 self.cached_config_hash = Some(Self::compute_config_hash(config));
280
281 kl_divergence
282 }
283
284 pub fn compute_kl_divergence(&self, config: &AnomalyGridConfig) -> f64 {
286 if self.total_count == 0 {
287 return 0.0;
288 }
289
290 let uniform_prob = 1.0 / self.vocab_size() as f64;
291
292 self.counts
293 .keys()
294 .map(|state_id| {
295 let p = self.get_probability_by_id(state_id, config);
296 if p > 0.0 {
297 p * (p / uniform_prob).log2()
298 } else {
299 0.0
300 }
301 })
302 .sum()
303 }
304
305 pub fn get_all_probabilities(&self, config: &AnomalyGridConfig) -> HashMap<String, f64> {
309 self.counts
310 .keys()
311 .filter_map(|state_id| {
312 self.interner.get_string(state_id).map(|state_string| {
313 let prob = self.get_probability_by_id(state_id, config);
314 (state_string, prob)
315 })
316 })
317 .collect()
318 }
319
320 pub fn reset(&mut self, interner: Arc<StringInterner>) {
322 self.counts = TransitionCounts::new();
323 self.total_count = 0;
324 self.interner = interner;
325 self.cached_entropy = None;
326 self.cached_kl_divergence = None;
327 self.cached_config_hash = None;
328 }
329
330 pub fn clear(&mut self) {
332 self.counts = TransitionCounts::new();
333 self.total_count = 0;
334 self.cached_entropy = None;
335 self.cached_kl_divergence = None;
336 self.cached_config_hash = None;
337 }
339
340 pub fn cache_stats(&self) -> (bool, bool) {
342 (
343 self.cached_entropy.is_some(),
344 self.cached_kl_divergence.is_some(),
345 )
346 }
347}
348
349impl Default for ContextNode {
350 fn default() -> Self {
351 Self {
352 counts: TransitionCounts::new(),
353 total_count: 0,
354 interner: Arc::new(StringInterner::new()),
355 cached_entropy: None,
356 cached_kl_divergence: None,
357 cached_config_hash: None,
358 }
359 }
360}
361
/// Variable-order Markov model: one [`ContextNode`] per observed context
/// of length 1..=`max_order`, stored in a trie keyed by interned state ids.
#[derive(Debug, Clone)]
pub struct ContextTree {
    // Trie mapping context id-sequences to their transition statistics.
    trie: ContextTrie,
    // Maximum context length tracked (validated > 0 by the constructors).
    pub max_order: usize,
    // Interner shared with the trie and every node.
    interner: Arc<StringInterner>,
    // Config from the most recent `build_from_sequence`, used as the
    // default for `get_transition_probability`.
    pub(crate) last_config: AnomalyGridConfig,
}
376
377impl ContextTree {
378 pub fn new(max_order: usize) -> AnomalyGridResult<Self> {
380 if max_order == 0 {
381 return Err(AnomalyGridError::invalid_max_order(max_order));
382 }
383
384 let interner = Arc::new(StringInterner::new());
385 let trie = ContextTrie::new(max_order, Arc::clone(&interner));
386 let last_config = AnomalyGridConfig::default();
387
388 Ok(Self {
389 trie,
390 max_order,
391 interner,
392 last_config,
393 })
394 }
395
396 pub fn with_interner(
398 max_order: usize,
399 interner: Arc<StringInterner>,
400 ) -> AnomalyGridResult<Self> {
401 if max_order == 0 {
402 return Err(AnomalyGridError::invalid_max_order(max_order));
403 }
404
405 let trie = ContextTrie::new(max_order, Arc::clone(&interner));
406 let last_config = AnomalyGridConfig::default();
407
408 Ok(Self {
409 trie,
410 max_order,
411 interner,
412 last_config,
413 })
414 }
415
416 pub fn build_from_sequence(
427 &mut self,
428 sequence: &[String],
429 config: &AnomalyGridConfig,
430 ) -> AnomalyGridResult<()> {
431 if sequence.len() < config.min_sequence_length {
433 return Err(AnomalyGridError::sequence_too_short(
434 config.min_sequence_length,
435 sequence.len(),
436 "context tree building",
437 ));
438 }
439
440 for window_size in 1..=self.max_order {
442 for window in sequence.windows(window_size + 1) {
443 if let Some(limit) = config.memory_limit {
445 if self.trie.context_count() >= limit {
446 return Err(AnomalyGridError::memory_limit_exceeded(
447 self.trie.context_count(),
448 limit,
449 ));
450 }
451 }
452
453 let context_state_ids: Vec<StateId> = window[..window_size]
455 .iter()
456 .map(|s| self.interner.get_or_intern(s))
457 .collect();
458 let next_state = &window[window_size];
459
460 let node = self.trie.get_or_create_context_data(&context_state_ids);
462 node.add_transition(next_state);
463 }
464 }
465
466 self.last_config = config.clone();
468
469 Ok(())
470 }
471
472 pub fn get_transition_probability(&self, context: &[String], next_state: &str) -> Option<f64> {
476 let context_state_ids: Vec<StateId> = context
478 .iter()
479 .map(|s| self.interner.get_or_intern(s))
480 .collect();
481
482 self.trie
483 .get_context_data(&context_state_ids)
484 .map(|node| node.get_probability(next_state, &self.last_config))
485 }
486
487 pub fn get_transition_probability_with_config(
489 &self,
490 context: &[String],
491 next_state: &str,
492 config: &AnomalyGridConfig,
493 ) -> Option<f64> {
494 let context_state_ids: Vec<StateId> = context
496 .iter()
497 .map(|s| self.interner.get_or_intern(s))
498 .collect();
499
500 self.trie
501 .get_context_data(&context_state_ids)
502 .map(|node| node.get_probability(next_state, config))
503 }
504
505 pub fn get_transition_probability_normalized(
507 &self,
508 context: &[String],
509 next_state: &str,
510 config: &AnomalyGridConfig,
511 global_state_mapping: &std::collections::HashMap<String, usize>,
512 ) -> Option<f64> {
513 let context_state_ids: Vec<StateId> = context
515 .iter()
516 .map(|s| self.interner.get_or_intern(s))
517 .collect();
518
519 self.trie.get_context_data(&context_state_ids).map(|node| {
520 node.get_probability_normalized(next_state, config, global_state_mapping.len())
521 })
522 }
523
524 pub fn get_transition_probability_normalized_ids(
526 &self,
527 context_ids: &[StateId],
528 next_state_id: StateId,
529 config: &AnomalyGridConfig,
530 global_vocab_size: usize,
531 ) -> Option<f64> {
532 self.trie.get_context_data(context_ids).map(|node| {
533 node.get_probability_normalized_by_id(next_state_id, config, global_vocab_size)
534 })
535 }
536
537 pub fn get_context_node(&self, context: &[String]) -> Option<&ContextNode> {
539 let context_state_ids: Vec<StateId> = context
541 .iter()
542 .map(|s| self.interner.get_or_intern(s))
543 .collect();
544
545 self.trie.get_context_data(&context_state_ids)
546 }
547
548 pub fn get_context_count(&self, context: &[String]) -> Option<usize> {
550 self.get_context_node(context)
551 .map(|node| node.total_count())
552 }
553
554 pub fn get_context_count_by_ids(&self, context_ids: &[StateId]) -> Option<usize> {
556 self.trie
557 .get_context_data(context_ids)
558 .map(|node| node.total_count())
559 }
560
561 pub fn get_contexts_of_order(&self, order: usize) -> Vec<Vec<String>> {
563 self.trie
564 .iter_contexts()
565 .filter_map(|(state_ids, _)| {
566 if state_ids.len() == order {
567 let strings: Option<Vec<String>> = state_ids
569 .iter()
570 .map(|&state_id| self.interner.get_string(state_id))
571 .collect();
572 strings
573 } else {
574 None
575 }
576 })
577 .collect()
578 }
579
580 pub fn context_count(&self) -> usize {
582 self.trie.context_count()
583 }
584
585 pub fn interner(&self) -> &Arc<StringInterner> {
587 &self.interner
588 }
589
590 pub fn contexts(&self) -> HashMap<Vec<String>, ContextNode> {
595 let mut contexts = HashMap::new();
596
597 for (state_ids, node) in self.trie.iter_contexts() {
598 if let Some(strings) = state_ids
600 .iter()
601 .map(|&state_id| self.interner.get_string(state_id))
602 .collect::<Option<Vec<String>>>()
603 {
604 contexts.insert(strings, node.clone());
605 }
606 }
607
608 contexts
609 }
610
611 pub(crate) fn trie(&self) -> &ContextTrie {
613 &self.trie
614 }
615
616 pub(crate) fn rebuild_filtered<F>(&mut self, mut keep: F) -> usize
618 where
619 F: FnMut(&[StateId], &ContextNode) -> bool,
620 {
621 let original_count = self.trie.context_count();
622 let mut new_trie = ContextTrie::new(self.max_order, Arc::clone(&self.interner));
623
624 for (state_ids, node) in self.trie.iter_contexts() {
625 if keep(&state_ids, node) {
626 let new_node = new_trie.get_or_create_context_data(&state_ids);
627 for (state_id, count) in node.get_state_counts() {
628 for _ in 0..count {
629 new_node.add_transition_by_id(state_id);
630 }
631 }
632 }
633 }
634
635 if new_trie.context_count() == 0 {
637 0
638 } else {
639 let removed = original_count.saturating_sub(new_trie.context_count());
640 self.trie = new_trie;
641 removed
642 }
643 }
644}