// anomaly_grid/context_tree/mod.rs

1//! Context Tree module for variable-order Markov model implementation
2//!
3//! This module implements context storage and probability estimation for building
4//! variable-order Markov models with information-theoretic measures.
5//!
6//! MEMORY OPTIMIZATIONS:
7//! - String interning: Uses StateId instead of String to reduce duplication
8//! - On-demand computation: Probabilities calculated when needed, not stored
9//! - Cached totals: Avoids recomputing transition counts repeatedly
10
11use crate::config::AnomalyGridConfig;
12use crate::context_trie::ContextTrie;
13use crate::error::{AnomalyGridError, AnomalyGridResult};
14use crate::string_interner::{StateId, StringInterner};
15use crate::transition_counts::TransitionCounts;
16use std::collections::hash_map::DefaultHasher;
17use std::collections::HashMap;
18use std::hash::{Hash, Hasher};
19use std::sync::Arc;
20
/// A node in the context tree that stores transition statistics
///
/// Uses optimized storage for small collections and StateId for memory efficiency.
/// Implements lazy computation with caching for entropy and KL divergence.
#[derive(Debug, Clone)]
pub struct ContextNode {
    /// Optimized transition counts using SmallVec for small collections
    counts: TransitionCounts,
    /// Running total of all transition counts; kept in sync by the add/reset/
    /// clear methods so it never has to be recomputed from `counts`.
    total_count: usize,
    /// Shared string interner for converting between strings and StateIds
    interner: Arc<StringInterner>,
    /// Lazily computed Shannon entropy; `None` until first requested, and
    /// dropped whenever the counts change.
    cached_entropy: Option<f64>,
    /// Lazily computed KL divergence from uniform; invalidated together with
    /// the entropy cache.
    cached_kl_divergence: Option<f64>,
    /// Hash of the configuration the cached values were computed under; a
    /// mismatch means the caches must be recomputed.
    cached_config_hash: Option<u64>,
}
40
41impl ContextNode {
42    /// Create a new empty context node with string interner
43    pub fn new(interner: Arc<StringInterner>) -> Self {
44        Self {
45            counts: TransitionCounts::new(),
46            total_count: 0,
47            interner,
48            cached_entropy: None,
49            cached_kl_divergence: None,
50            cached_config_hash: None,
51        }
52    }
53
54    /// Add a transition to this context using string interning
55    pub fn add_transition(&mut self, next_state: &str) {
56        let state_id = self.interner.get_or_intern(next_state);
57        self.counts.increment(state_id);
58        self.total_count += 1;
59        self.invalidate_cache();
60    }
61
62    /// Add a transition using StateId directly (internal use)
63    pub fn add_transition_by_id(&mut self, state_id: StateId) {
64        self.counts.increment(state_id);
65        self.total_count += 1;
66        self.invalidate_cache();
67    }
68
69    /// Invalidate cached computations when data changes
70    fn invalidate_cache(&mut self) {
71        self.cached_entropy = None;
72        self.cached_kl_divergence = None;
73        self.cached_config_hash = None;
74    }
75
76    /// Compute a hash of the configuration for cache validation
77    fn compute_config_hash(config: &AnomalyGridConfig) -> u64 {
78        let mut hasher = DefaultHasher::new();
79        // Hash the relevant configuration parameters that affect entropy/KL divergence
80        config.smoothing_alpha.to_bits().hash(&mut hasher);
81        hasher.finish()
82    }
83
84    /// Check if the cached values are valid for the given configuration
85    fn is_cache_valid(&self, config: &AnomalyGridConfig) -> bool {
86        if let Some(cached_hash) = self.cached_config_hash {
87            cached_hash == Self::compute_config_hash(config)
88        } else {
89            false
90        }
91    }
92
93    /// Get the total number of transitions from this context
94    pub fn total_count(&self) -> usize {
95        self.total_count
96    }
97
98    /// Get the count for a specific next state
99    pub fn get_count(&self, next_state: &str) -> usize {
100        let state_id = self.interner.get_or_intern(next_state);
101        self.counts.get(state_id)
102    }
103
104    /// Get the count for a StateId directly (internal use)
105    pub fn get_count_by_id(&self, state_id: StateId) -> usize {
106        self.counts.get(state_id)
107    }
108
109    /// Get the number of unique next states
110    pub fn vocab_size(&self) -> usize {
111        self.counts.len()
112    }
113
114    /// Get all state IDs with their counts (internal use)
115    pub fn get_state_counts(&self) -> impl Iterator<Item = (StateId, usize)> + '_ {
116        self.counts.iter()
117    }
118
119    /// Get the sum of all transition counts (for compatibility)
120    pub fn total_transitions(&self) -> usize {
121        self.total_count
122    }
123
124    /// Get all counts as strings for compatibility with performance module
125    pub fn get_string_counts(&self) -> HashMap<String, usize> {
126        self.counts
127            .iter()
128            .filter_map(|(state_id, count)| self.interner.get_string(state_id).map(|s| (s, count)))
129            .collect()
130    }
131
132    /// Get counts for debugging (returns string representation)
133    pub fn counts(&self) -> HashMap<String, usize> {
134        self.get_string_counts()
135    }
136
137    /// Get the probability for a specific next state using Laplace smoothing
138    ///
139    /// Computes probability on-demand: P(state) = (count + α) / (total + α * |V|)
140    pub fn get_probability(&self, next_state: &str, config: &AnomalyGridConfig) -> f64 {
141        let state_id = self.interner.get_or_intern(next_state);
142        self.get_probability_by_id(state_id, config)
143    }
144
145    /// Get probability for a StateId directly (internal use)
146    pub fn get_probability_by_id(&self, state_id: StateId, config: &AnomalyGridConfig) -> f64 {
147        if self.total_count == 0 {
148            return 1.0 / (self.vocab_size() as f64).max(1.0);
149        }
150
151        let count = self.get_count_by_id(state_id) as f64;
152        let vocab_size = self.vocab_size() as f64;
153
154        (count + config.smoothing_alpha)
155            / (self.total_count as f64 + config.smoothing_alpha * vocab_size)
156    }
157
158    /// Get probability with proper normalization using global vocabulary size
159    pub fn get_probability_normalized(
160        &self,
161        next_state: &str,
162        config: &AnomalyGridConfig,
163        global_vocab_size: usize,
164    ) -> f64 {
165        if self.total_count == 0 {
166            return 1.0 / (global_vocab_size as f64).max(1.0);
167        }
168
169        let state_id = self.interner.get_or_intern(next_state);
170        let count = self.get_count_by_id(state_id) as f64;
171        let global_vocab_size_f64 = global_vocab_size as f64;
172
173        // Use global vocabulary size for proper normalization
174        // This ensures that probabilities for all states in the global vocabulary sum to 1
175        (count + config.smoothing_alpha)
176            / (self.total_count as f64 + config.smoothing_alpha * global_vocab_size_f64)
177    }
178
179    /// Get probability with proper normalization using global vocabulary by StateId
180    pub fn get_probability_normalized_by_id(
181        &self,
182        state_id: StateId,
183        config: &AnomalyGridConfig,
184        global_vocab_size: usize,
185    ) -> f64 {
186        if self.total_count == 0 {
187            return 1.0 / (global_vocab_size as f64).max(1.0);
188        }
189
190        let count = self.get_count_by_id(state_id) as f64;
191        let global_vocab_size_f64 = global_vocab_size as f64;
192
193        (count + config.smoothing_alpha)
194            / (self.total_count as f64 + config.smoothing_alpha * global_vocab_size_f64)
195    }
196
197    /// Calculate Shannon entropy with lazy computation and caching: H(X) = -∑ P(x) log₂ P(x)
198    pub fn calculate_entropy(&mut self, config: &AnomalyGridConfig) -> f64 {
199        // Check if we have a valid cached value
200        if self.is_cache_valid(config) {
201            if let Some(cached_entropy) = self.cached_entropy {
202                return cached_entropy;
203            }
204        }
205
206        // Compute entropy
207        let entropy = if self.total_count == 0 {
208            0.0
209        } else {
210            self.counts
211                .keys()
212                .map(|state_id| {
213                    let p = self.get_probability_by_id(state_id, config);
214                    if p > 0.0 {
215                        -p * p.log2()
216                    } else {
217                        0.0
218                    }
219                })
220                .sum()
221        };
222
223        // Cache the result
224        self.cached_entropy = Some(entropy);
225        self.cached_config_hash = Some(Self::compute_config_hash(config));
226
227        entropy
228    }
229
230    /// Calculate Shannon entropy without caching (for immutable access)
231    pub fn compute_entropy(&self, config: &AnomalyGridConfig) -> f64 {
232        if self.total_count == 0 {
233            return 0.0;
234        }
235
236        self.counts
237            .keys()
238            .map(|state_id| {
239                let p = self.get_probability_by_id(state_id, config);
240                if p > 0.0 {
241                    -p * p.log2()
242                } else {
243                    0.0
244                }
245            })
246            .sum()
247    }
248
249    /// Calculate KL divergence from uniform distribution with lazy computation and caching
250    pub fn calculate_kl_divergence(&mut self, config: &AnomalyGridConfig) -> f64 {
251        // Check if we have a valid cached value
252        if self.is_cache_valid(config) {
253            if let Some(cached_kl_div) = self.cached_kl_divergence {
254                return cached_kl_div;
255            }
256        }
257
258        // Compute KL divergence
259        let kl_divergence = if self.total_count == 0 {
260            0.0
261        } else {
262            let uniform_prob = 1.0 / self.vocab_size() as f64;
263
264            self.counts
265                .keys()
266                .map(|state_id| {
267                    let p = self.get_probability_by_id(state_id, config);
268                    if p > 0.0 {
269                        p * (p / uniform_prob).log2()
270                    } else {
271                        0.0
272                    }
273                })
274                .sum()
275        };
276
277        // Cache the result
278        self.cached_kl_divergence = Some(kl_divergence);
279        self.cached_config_hash = Some(Self::compute_config_hash(config));
280
281        kl_divergence
282    }
283
284    /// Calculate KL divergence from uniform distribution without caching (for immutable access)
285    pub fn compute_kl_divergence(&self, config: &AnomalyGridConfig) -> f64 {
286        if self.total_count == 0 {
287            return 0.0;
288        }
289
290        let uniform_prob = 1.0 / self.vocab_size() as f64;
291
292        self.counts
293            .keys()
294            .map(|state_id| {
295                let p = self.get_probability_by_id(state_id, config);
296                if p > 0.0 {
297                    p * (p / uniform_prob).log2()
298                } else {
299                    0.0
300                }
301            })
302            .sum()
303    }
304
305    /// Get all probabilities as a HashMap (for compatibility with existing code)
306    ///
307    /// Note: This creates temporary storage and should be used sparingly
308    pub fn get_all_probabilities(&self, config: &AnomalyGridConfig) -> HashMap<String, f64> {
309        self.counts
310            .keys()
311            .filter_map(|state_id| {
312                self.interner.get_string(state_id).map(|state_string| {
313                    let prob = self.get_probability_by_id(state_id, config);
314                    (state_string, prob)
315                })
316            })
317            .collect()
318    }
319
320    /// Reset the context node for reuse in memory pool
321    pub fn reset(&mut self, interner: Arc<StringInterner>) {
322        self.counts = TransitionCounts::new();
323        self.total_count = 0;
324        self.interner = interner;
325        self.cached_entropy = None;
326        self.cached_kl_divergence = None;
327        self.cached_config_hash = None;
328    }
329
330    /// Clear the context node data for memory pool return
331    pub fn clear(&mut self) {
332        self.counts = TransitionCounts::new();
333        self.total_count = 0;
334        self.cached_entropy = None;
335        self.cached_kl_divergence = None;
336        self.cached_config_hash = None;
337        // Keep the interner for potential reuse
338    }
339
340    /// Get cache hit statistics for monitoring
341    pub fn cache_stats(&self) -> (bool, bool) {
342        (
343            self.cached_entropy.is_some(),
344            self.cached_kl_divergence.is_some(),
345        )
346    }
347}
348
349impl Default for ContextNode {
350    fn default() -> Self {
351        Self {
352            counts: TransitionCounts::new(),
353            total_count: 0,
354            interner: Arc::new(StringInterner::new()),
355            cached_entropy: None,
356            cached_kl_divergence: None,
357            cached_config_hash: None,
358        }
359    }
360}
361
/// Context tree for storing variable-order Markov chain contexts
///
/// Uses trie-based storage for memory efficiency through prefix sharing.
#[derive(Debug, Clone)]
pub struct ContextTree {
    /// Trie-based storage for memory-efficient prefix sharing
    trie: ContextTrie,
    /// Maximum context order (length); guaranteed >= 1 by the constructors
    pub max_order: usize,
    /// Shared string interner for converting between strings and StateIds
    interner: Arc<StringInterner>,
    /// Configuration captured by the most recent `build_from_sequence` call;
    /// used as the default for probability queries that take no config.
    pub(crate) last_config: AnomalyGridConfig,
}
376
377impl ContextTree {
378    /// Create a new context tree with specified maximum order
379    pub fn new(max_order: usize) -> AnomalyGridResult<Self> {
380        if max_order == 0 {
381            return Err(AnomalyGridError::invalid_max_order(max_order));
382        }
383
384        let interner = Arc::new(StringInterner::new());
385        let trie = ContextTrie::new(max_order, Arc::clone(&interner));
386        let last_config = AnomalyGridConfig::default();
387
388        Ok(Self {
389            trie,
390            max_order,
391            interner,
392            last_config,
393        })
394    }
395
396    /// Create a new context tree with existing string interner
397    pub fn with_interner(
398        max_order: usize,
399        interner: Arc<StringInterner>,
400    ) -> AnomalyGridResult<Self> {
401        if max_order == 0 {
402            return Err(AnomalyGridError::invalid_max_order(max_order));
403        }
404
405        let trie = ContextTrie::new(max_order, Arc::clone(&interner));
406        let last_config = AnomalyGridConfig::default();
407
408        Ok(Self {
409            trie,
410            max_order,
411            interner,
412            last_config,
413        })
414    }
415
416    /// Build the context tree from a training sequence
417    ///
418    /// # Complexity
419    /// - Time: O(n × max_order × |alphabet|) where n = sequence length
420    /// - Space: O(|alphabet|^max_order) in worst case
421    ///
422    /// # Performance Guarantees
423    /// - Memory usage is bounded by config.memory_limit if set
424    /// - Processing time scales linearly with sequence length
425    /// - Uses string interning to reduce memory duplication
426    pub fn build_from_sequence(
427        &mut self,
428        sequence: &[String],
429        config: &AnomalyGridConfig,
430    ) -> AnomalyGridResult<()> {
431        // Validate sequence length
432        if sequence.len() < config.min_sequence_length {
433            return Err(AnomalyGridError::sequence_too_short(
434                config.min_sequence_length,
435                sequence.len(),
436                "context tree building",
437            ));
438        }
439
440        // Extract contexts of all orders from 1 to max_order
441        for window_size in 1..=self.max_order {
442            for window in sequence.windows(window_size + 1) {
443                // Check memory limit before adding new context
444                if let Some(limit) = config.memory_limit {
445                    if self.trie.context_count() >= limit {
446                        return Err(AnomalyGridError::memory_limit_exceeded(
447                            self.trie.context_count(),
448                            limit,
449                        ));
450                    }
451                }
452
453                // Convert context to StateIds for trie storage
454                let context_state_ids: Vec<StateId> = window[..window_size]
455                    .iter()
456                    .map(|s| self.interner.get_or_intern(s))
457                    .collect();
458                let next_state = &window[window_size];
459
460                // Get or create context node in trie
461                let node = self.trie.get_or_create_context_data(&context_state_ids);
462                node.add_transition(next_state);
463            }
464        }
465
466        // Store last-used config for future probability queries
467        self.last_config = config.clone();
468
469        Ok(())
470    }
471
472    /// Get the transition probability for a given context and next state
473    ///
474    /// Uses the last configuration seen during training (falls back to default if none).
475    pub fn get_transition_probability(&self, context: &[String], next_state: &str) -> Option<f64> {
476        // Convert context to StateIds
477        let context_state_ids: Vec<StateId> = context
478            .iter()
479            .map(|s| self.interner.get_or_intern(s))
480            .collect();
481
482        self.trie
483            .get_context_data(&context_state_ids)
484            .map(|node| node.get_probability(next_state, &self.last_config))
485    }
486
487    /// Get the transition probability with custom config
488    pub fn get_transition_probability_with_config(
489        &self,
490        context: &[String],
491        next_state: &str,
492        config: &AnomalyGridConfig,
493    ) -> Option<f64> {
494        // Convert context to StateIds
495        let context_state_ids: Vec<StateId> = context
496            .iter()
497            .map(|s| self.interner.get_or_intern(s))
498            .collect();
499
500        self.trie
501            .get_context_data(&context_state_ids)
502            .map(|node| node.get_probability(next_state, config))
503    }
504
505    /// Get the transition probability with proper normalization using global vocabulary
506    pub fn get_transition_probability_normalized(
507        &self,
508        context: &[String],
509        next_state: &str,
510        config: &AnomalyGridConfig,
511        global_state_mapping: &std::collections::HashMap<String, usize>,
512    ) -> Option<f64> {
513        // Convert context to StateIds
514        let context_state_ids: Vec<StateId> = context
515            .iter()
516            .map(|s| self.interner.get_or_intern(s))
517            .collect();
518
519        self.trie.get_context_data(&context_state_ids).map(|node| {
520            node.get_probability_normalized(next_state, config, global_state_mapping.len())
521        })
522    }
523
524    /// Get the transition probability with proper normalization using StateIds
525    pub fn get_transition_probability_normalized_ids(
526        &self,
527        context_ids: &[StateId],
528        next_state_id: StateId,
529        config: &AnomalyGridConfig,
530        global_vocab_size: usize,
531    ) -> Option<f64> {
532        self.trie.get_context_data(context_ids).map(|node| {
533            node.get_probability_normalized_by_id(next_state_id, config, global_vocab_size)
534        })
535    }
536
537    /// Get a context node for the given context
538    pub fn get_context_node(&self, context: &[String]) -> Option<&ContextNode> {
539        // Convert context to StateIds
540        let context_state_ids: Vec<StateId> = context
541            .iter()
542            .map(|s| self.interner.get_or_intern(s))
543            .collect();
544
545        self.trie.get_context_data(&context_state_ids)
546    }
547
548    /// Get the total count for a given context (for adaptive context selection)
549    pub fn get_context_count(&self, context: &[String]) -> Option<usize> {
550        self.get_context_node(context)
551            .map(|node| node.total_count())
552    }
553
554    /// Get the total count for a given context by StateIds
555    pub fn get_context_count_by_ids(&self, context_ids: &[StateId]) -> Option<usize> {
556        self.trie
557            .get_context_data(context_ids)
558            .map(|node| node.total_count())
559    }
560
561    /// Get all contexts of a specific order
562    pub fn get_contexts_of_order(&self, order: usize) -> Vec<Vec<String>> {
563        self.trie
564            .iter_contexts()
565            .filter_map(|(state_ids, _)| {
566                if state_ids.len() == order {
567                    // Convert StateIds back to strings
568                    let strings: Option<Vec<String>> = state_ids
569                        .iter()
570                        .map(|&state_id| self.interner.get_string(state_id))
571                        .collect();
572                    strings
573                } else {
574                    None
575                }
576            })
577            .collect()
578    }
579
580    /// Get the number of contexts stored
581    pub fn context_count(&self) -> usize {
582        self.trie.context_count()
583    }
584
585    /// Get access to the string interner
586    pub fn interner(&self) -> &Arc<StringInterner> {
587        &self.interner
588    }
589
590    /// Get all contexts as a HashMap for compatibility with existing code
591    ///
592    /// Note: This creates a temporary HashMap and should be used sparingly
593    /// for compatibility with existing tests and code that expects the old interface
594    pub fn contexts(&self) -> HashMap<Vec<String>, ContextNode> {
595        let mut contexts = HashMap::new();
596
597        for (state_ids, node) in self.trie.iter_contexts() {
598            // Convert StateIds back to strings
599            if let Some(strings) = state_ids
600                .iter()
601                .map(|&state_id| self.interner.get_string(state_id))
602                .collect::<Option<Vec<String>>>()
603            {
604                contexts.insert(strings, node.clone());
605            }
606        }
607
608        contexts
609    }
610
611    /// Get the trie for internal operations
612    pub(crate) fn trie(&self) -> &ContextTrie {
613        &self.trie
614    }
615
616    /// Rebuild the trie using a filter predicate; returns number of removed contexts
617    pub(crate) fn rebuild_filtered<F>(&mut self, mut keep: F) -> usize
618    where
619        F: FnMut(&[StateId], &ContextNode) -> bool,
620    {
621        let original_count = self.trie.context_count();
622        let mut new_trie = ContextTrie::new(self.max_order, Arc::clone(&self.interner));
623
624        for (state_ids, node) in self.trie.iter_contexts() {
625            if keep(&state_ids, node) {
626                let new_node = new_trie.get_or_create_context_data(&state_ids);
627                for (state_id, count) in node.get_state_counts() {
628                    for _ in 0..count {
629                        new_node.add_transition_by_id(state_id);
630                    }
631                }
632            }
633        }
634
635        // Avoid pruning everything; if nothing would remain, keep original trie
636        if new_trie.context_count() == 0 {
637            0
638        } else {
639            let removed = original_count.saturating_sub(new_trie.context_count());
640            self.trie = new_trie;
641            removed
642        }
643    }
644}