// anomaly_grid/context_tree/mod.rs

use crate::config::AnomalyGridConfig;
use crate::context_trie::ContextTrie;
use crate::error::{AnomalyGridError, AnomalyGridResult};
use crate::string_interner::{StateId, StringInterner};
use crate::transition_counts::TransitionCounts;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

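/// Transition statistics for a single context: counts of each observed next
/// state, the running total of transitions, and lazily cached entropy and
/// KL-divergence values keyed on a hash of the smoothing configuration.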
#[derive(Debug, Clone)]
pub struct ContextNode {
    counts: TransitionCounts,
    total_count: usize,
    interner: Arc<StringInterner>,
    cached_entropy: Option<f64>,
    cached_kl_divergence: Option<f64>,
    cached_config_hash: Option<u64>,
}

impl ContextNode {
    pub fn new(interner: Arc<StringInterner>) -> Self {
        Self {
            counts: TransitionCounts::new(),
            total_count: 0,
            interner,
            cached_entropy: None,
            cached_kl_divergence: None,
            cached_config_hash: None,
        }
    }

    pub fn add_transition(&mut self, next_state: &str) {
        let state_id = self.interner.get_or_intern(next_state);
        self.counts.increment(state_id);
        self.total_count += 1;
        self.invalidate_cache();
    }

    pub fn add_transition_by_id(&mut self, state_id: StateId) {
        self.counts.increment(state_id);
        self.total_count += 1;
        self.invalidate_cache();
    }

    fn invalidate_cache(&mut self) {
        self.cached_entropy = None;
        self.cached_kl_divergence = None;
        self.cached_config_hash = None;
    }

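    /// Hashes the configuration fields that influence the cached statistics.
    /// Only `smoothing_alpha` feeds into the cached entropy and KL divergence,
    /// so it is the only field included in the hash.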
    fn compute_config_hash(config: &AnomalyGridConfig) -> u64 {
        let mut hasher = DefaultHasher::new();
        config.smoothing_alpha.to_bits().hash(&mut hasher);
        hasher.finish()
    }

    fn is_cache_valid(&self, config: &AnomalyGridConfig) -> bool {
        if let Some(cached_hash) = self.cached_config_hash {
            cached_hash == Self::compute_config_hash(config)
        } else {
            false
        }
    }

    pub fn total_count(&self) -> usize {
        self.total_count
    }

    pub fn get_count(&self, next_state: &str) -> usize {
        let state_id = self.interner.get_or_intern(next_state);
        self.counts.get(state_id)
    }

    pub fn get_count_by_id(&self, state_id: StateId) -> usize {
        self.counts.get(state_id)
    }

    pub fn vocab_size(&self) -> usize {
        self.counts.len()
    }

    pub fn get_state_counts(&self) -> impl Iterator<Item = (StateId, usize)> + '_ {
        self.counts.iter()
    }

    pub fn total_transitions(&self) -> usize {
        self.total_count
    }

    pub fn get_string_counts(&self) -> HashMap<String, usize> {
        self.counts
            .iter()
            .filter_map(|(state_id, count)| self.interner.get_string(state_id).map(|s| (s, count)))
            .collect()
    }

    pub fn counts(&self) -> HashMap<String, usize> {
        self.get_string_counts()
    }

    pub fn get_probability(&self, next_state: &str, config: &AnomalyGridConfig) -> f64 {
        let state_id = self.interner.get_or_intern(next_state);
        self.get_probability_by_id(state_id, config)
    }

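    /// Smoothed conditional probability of `state_id` given this context,
    /// using additive (Laplace) smoothing over the observed vocabulary:
    /// `P(s | c) = (count(s) + alpha) / (total_count + alpha * vocab_size)`.
    /// An empty node falls back to a uniform estimate.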
    pub fn get_probability_by_id(&self, state_id: StateId, config: &AnomalyGridConfig) -> f64 {
        if self.total_count == 0 {
            return 1.0 / (self.vocab_size() as f64).max(1.0);
        }

        let count = self.get_count_by_id(state_id) as f64;
        let vocab_size = self.vocab_size() as f64;

        (count + config.smoothing_alpha)
            / (self.total_count as f64 + config.smoothing_alpha * vocab_size)
    }

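    /// Shannon entropy (in bits) of the smoothed next-state distribution:
    /// `H = -sum_s P(s | c) * log2 P(s | c)`. The result is cached together
    /// with a hash of the smoothing configuration and reused until a new
    /// transition is added or the configuration changes.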
    pub fn calculate_entropy(&mut self, config: &AnomalyGridConfig) -> f64 {
        if self.is_cache_valid(config) {
            if let Some(cached_entropy) = self.cached_entropy {
                return cached_entropy;
            }
        }

        let entropy = if self.total_count == 0 {
            0.0
        } else {
            self.counts
                .keys()
                .map(|state_id| {
                    let p = self.get_probability_by_id(state_id, config);
                    if p > 0.0 {
                        -p * p.log2()
                    } else {
                        0.0
                    }
                })
                .sum()
        };

        self.cached_entropy = Some(entropy);
        self.cached_config_hash = Some(Self::compute_config_hash(config));

        entropy
    }

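    /// Like [`Self::calculate_entropy`], but without caching, so it is usable
    /// through a shared reference.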
    pub fn compute_entropy(&self, config: &AnomalyGridConfig) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }

        self.counts
            .keys()
            .map(|state_id| {
                let p = self.get_probability_by_id(state_id, config);
                if p > 0.0 {
                    -p * p.log2()
                } else {
                    0.0
                }
            })
            .sum()
    }

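    /// Kullback-Leibler divergence (in bits) of the smoothed next-state
    /// distribution from a uniform distribution over this node's vocabulary:
    /// `D_KL(P || U) = sum_s P(s | c) * log2(P(s | c) / U(s))` with
    /// `U(s) = 1 / vocab_size`. Cached under the same configuration hash as
    /// the entropy.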
    pub fn calculate_kl_divergence(&mut self, config: &AnomalyGridConfig) -> f64 {
        if self.is_cache_valid(config) {
            if let Some(cached_kl_div) = self.cached_kl_divergence {
                return cached_kl_div;
            }
        }

        let kl_divergence = if self.total_count == 0 {
            0.0
        } else {
            let uniform_prob = 1.0 / self.vocab_size() as f64;

            self.counts
                .keys()
                .map(|state_id| {
                    let p = self.get_probability_by_id(state_id, config);
                    if p > 0.0 {
                        p * (p / uniform_prob).log2()
                    } else {
                        0.0
                    }
                })
                .sum()
        };

        self.cached_kl_divergence = Some(kl_divergence);
        self.cached_config_hash = Some(Self::compute_config_hash(config));

        kl_divergence
    }

    pub fn compute_kl_divergence(&self, config: &AnomalyGridConfig) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }

        let uniform_prob = 1.0 / self.vocab_size() as f64;

        self.counts
            .keys()
            .map(|state_id| {
                let p = self.get_probability_by_id(state_id, config);
                if p > 0.0 {
                    p * (p / uniform_prob).log2()
                } else {
                    0.0
                }
            })
            .sum()
    }

    pub fn get_all_probabilities(&self, config: &AnomalyGridConfig) -> HashMap<String, f64> {
        self.counts
            .keys()
            .filter_map(|state_id| {
                self.interner.get_string(state_id).map(|state_string| {
                    let prob = self.get_probability_by_id(state_id, config);
                    (state_string, prob)
                })
            })
            .collect()
    }

    pub fn reset(&mut self, interner: Arc<StringInterner>) {
        self.counts = TransitionCounts::new();
        self.total_count = 0;
        self.interner = interner;
        self.cached_entropy = None;
        self.cached_kl_divergence = None;
        self.cached_config_hash = None;
    }

    pub fn clear(&mut self) {
        self.counts = TransitionCounts::new();
        self.total_count = 0;
        self.cached_entropy = None;
        self.cached_kl_divergence = None;
        self.cached_config_hash = None;
    }

    pub fn cache_stats(&self) -> (bool, bool) {
        (self.cached_entropy.is_some(), self.cached_kl_divergence.is_some())
    }
}

impl Default for ContextNode {
    fn default() -> Self {
        Self {
            counts: TransitionCounts::new(),
            total_count: 0,
            interner: Arc::new(StringInterner::new()),
            cached_entropy: None,
            cached_kl_divergence: None,
            cached_config_hash: None,
        }
    }
}

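/// Variable-order context model: a trie of contexts up to `max_order` states
/// deep, each holding a [`ContextNode`] with its outgoing transition counts.
/// All state strings are interned through a shared [`StringInterner`].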
#[derive(Debug, Clone)]
pub struct ContextTree {
    trie: ContextTrie,
    pub max_order: usize,
    interner: Arc<StringInterner>,
}

impl ContextTree {
    pub fn new(max_order: usize) -> AnomalyGridResult<Self> {
        if max_order == 0 {
            return Err(AnomalyGridError::invalid_max_order(max_order));
        }

        let interner = Arc::new(StringInterner::new());
        let trie = ContextTrie::new(max_order, Arc::clone(&interner));

        Ok(Self {
            trie,
            max_order,
            interner,
        })
    }

    pub fn with_interner(
        max_order: usize,
        interner: Arc<StringInterner>,
    ) -> AnomalyGridResult<Self> {
        if max_order == 0 {
            return Err(AnomalyGridError::invalid_max_order(max_order));
        }

        let trie = ContextTrie::new(max_order, Arc::clone(&interner));

        Ok(Self {
            trie,
            max_order,
            interner,
        })
    }

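    /// Populates the tree from a sequence of states. For every order `k` in
    /// `1..=max_order`, each window of `k + 1` consecutive states contributes
    /// one transition: the first `k` states form the context and the last
    /// state is the observed successor. Fails if the sequence is shorter than
    /// `config.min_sequence_length` or if `config.memory_limit` would be
    /// exceeded while inserting contexts.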
    pub fn build_from_sequence(
        &mut self,
        sequence: &[String],
        config: &AnomalyGridConfig,
    ) -> AnomalyGridResult<()> {
        if sequence.len() < config.min_sequence_length {
            return Err(AnomalyGridError::sequence_too_short(
                config.min_sequence_length,
                sequence.len(),
                "context tree building",
            ));
        }

        for window_size in 1..=self.max_order {
            for window in sequence.windows(window_size + 1) {
                if let Some(limit) = config.memory_limit {
                    if self.trie.context_count() >= limit {
                        return Err(AnomalyGridError::memory_limit_exceeded(
                            self.trie.context_count(),
                            limit,
                        ));
                    }
                }

                let context_state_ids: Vec<StateId> = window[..window_size]
                    .iter()
                    .map(|s| self.interner.get_or_intern(s))
                    .collect();
                let next_state = &window[window_size];

                let node = self.trie.get_or_create_context_data(&context_state_ids);
                node.add_transition(next_state);
            }
        }

        Ok(())
    }

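    /// Smoothed probability of `next_state` following `context`, evaluated
    /// with the default configuration. Returns `None` if the context was
    /// never observed; use [`Self::get_transition_probability_with_config`]
    /// to supply explicit smoothing parameters.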
    pub fn get_transition_probability(&self, context: &[String], next_state: &str) -> Option<f64> {
        let context_state_ids: Vec<StateId> = context
            .iter()
            .map(|s| self.interner.get_or_intern(s))
            .collect();

        self.trie
            .get_context_data(&context_state_ids)
            .map(|node| node.get_probability(next_state, &AnomalyGridConfig::default()))
    }

    pub fn get_transition_probability_with_config(
        &self,
        context: &[String],
        next_state: &str,
        config: &AnomalyGridConfig,
    ) -> Option<f64> {
        let context_state_ids: Vec<StateId> = context
            .iter()
            .map(|s| self.interner.get_or_intern(s))
            .collect();

        self.trie
            .get_context_data(&context_state_ids)
            .map(|node| node.get_probability(next_state, config))
    }

    pub fn get_context_node(&self, context: &[String]) -> Option<&ContextNode> {
        let context_state_ids: Vec<StateId> = context
            .iter()
            .map(|s| self.interner.get_or_intern(s))
            .collect();

        self.trie.get_context_data(&context_state_ids)
    }

    pub fn get_contexts_of_order(&self, order: usize) -> Vec<Vec<String>> {
        self.trie
            .iter_contexts()
            .filter_map(|(state_ids, _)| {
                if state_ids.len() == order {
                    let strings: Option<Vec<String>> = state_ids
                        .iter()
                        .map(|&state_id| self.interner.get_string(state_id))
                        .collect();
                    strings
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn context_count(&self) -> usize {
        self.trie.context_count()
    }

    pub fn interner(&self) -> &Arc<StringInterner> {
        &self.interner
    }

    pub fn contexts(&self) -> HashMap<Vec<String>, ContextNode> {
        let mut contexts = HashMap::new();

        for (state_ids, node) in self.trie.iter_contexts() {
            if let Some(strings) = state_ids
                .iter()
                .map(|&state_id| self.interner.get_string(state_id))
                .collect::<Option<Vec<String>>>()
            {
                contexts.insert(strings, node.clone());
            }
        }

        contexts
    }

    pub(crate) fn trie(&self) -> &ContextTrie {
        &self.trie
    }
}
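
// A minimal usage sketch rather than an exhaustive test suite. It assumes that
// `AnomalyGridConfig::default()` provides a `min_sequence_length` no larger
// than the toy sequence below and a non-negative `smoothing_alpha`; the
// assertions are deliberately loose so they do not depend on exact defaults.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AnomalyGridConfig;

    fn toy_sequence() -> Vec<String> {
        // "a" is followed by "b" five times and by "c" once.
        ["a", "b", "a", "b", "a", "b", "a", "c", "a", "b", "a", "b"]
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    #[test]
    fn builds_tree_and_exposes_smoothed_probabilities() {
        let config = AnomalyGridConfig::default();
        let mut tree = ContextTree::new(2).expect("max_order = 2 is valid");
        tree.build_from_sequence(&toy_sequence(), &config)
            .expect("toy sequence should satisfy the default minimum length");

        // Contexts of order 1 and 2 were inserted.
        assert!(tree.context_count() > 0);
        assert!(!tree.get_contexts_of_order(1).is_empty());

        // "b" follows "a" more often than "c" does, and smoothing keeps every
        // probability strictly between 0 and 1.
        let ctx = vec!["a".to_string()];
        let p_b = tree
            .get_transition_probability_with_config(&ctx, "b", &config)
            .expect("context [\"a\"] was observed");
        let p_c = tree
            .get_transition_probability_with_config(&ctx, "c", &config)
            .expect("context [\"a\"] was observed");
        assert!(p_b > p_c);
        assert!(p_b > 0.0 && p_b < 1.0);

        // The node behind ["a"] has two distinct successors, so its entropy
        // is strictly positive.
        let node = tree.get_context_node(&ctx).expect("node exists");
        assert!(node.compute_entropy(&config) > 0.0);
    }
}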