anomaly_grid/context_tree/
mod.rs

1//! Context Tree module for variable-order Markov model implementation
2//!
3//! This module implements context storage and probability estimation for building
4//! variable-order Markov models with information-theoretic measures.
5//!
6//! MEMORY OPTIMIZATIONS:
7//! - String interning: Uses StateId instead of String to reduce duplication
8//! - On-demand computation: Probabilities calculated when needed, not stored
9//! - Cached totals: Avoids recomputing transition counts repeatedly
10
11use crate::config::AnomalyGridConfig;
12use crate::context_trie::ContextTrie;
13use crate::error::{AnomalyGridError, AnomalyGridResult};
14use crate::string_interner::{StateId, StringInterner};
15use crate::transition_counts::TransitionCounts;
16use std::collections::HashMap;
17use std::collections::hash_map::DefaultHasher;
18use std::hash::{Hash, Hasher};
19use std::sync::Arc;
20
21/// A node in the context tree that stores transition statistics
22///
23/// Uses optimized storage for small collections and StateId for memory efficiency
24/// Implements lazy computation with caching for entropy and KL divergence
25#[derive(Debug, Clone)]
26pub struct ContextNode {
27    /// Optimized transition counts using SmallVec for small collections
28    counts: TransitionCounts,
29    /// Cached total count to avoid recomputation
30    total_count: usize,
31    /// String interner for converting between strings and StateIds
32    interner: Arc<StringInterner>,
33    /// Cached entropy value (computed lazily)
34    cached_entropy: Option<f64>,
35    /// Cached KL divergence value (computed lazily)
36    cached_kl_divergence: Option<f64>,
37    /// Configuration hash for cache invalidation
38    cached_config_hash: Option<u64>,
39}
40
41impl ContextNode {
42    /// Create a new empty context node with string interner
43    pub fn new(interner: Arc<StringInterner>) -> Self {
44        Self {
45            counts: TransitionCounts::new(),
46            total_count: 0,
47            interner,
48            cached_entropy: None,
49            cached_kl_divergence: None,
50            cached_config_hash: None,
51        }
52    }
53
54    /// Add a transition to this context using string interning
55    pub fn add_transition(&mut self, next_state: &str) {
56        let state_id = self.interner.get_or_intern(next_state);
57        self.counts.increment(state_id);
58        self.total_count += 1;
59        self.invalidate_cache();
60    }
61
62    /// Add a transition using StateId directly (internal use)
63    pub fn add_transition_by_id(&mut self, state_id: StateId) {
64        self.counts.increment(state_id);
65        self.total_count += 1;
66        self.invalidate_cache();
67    }
68
69    /// Invalidate cached computations when data changes
70    fn invalidate_cache(&mut self) {
71        self.cached_entropy = None;
72        self.cached_kl_divergence = None;
73        self.cached_config_hash = None;
74    }
75
76    /// Compute a hash of the configuration for cache validation
77    fn compute_config_hash(config: &AnomalyGridConfig) -> u64 {
78        let mut hasher = DefaultHasher::new();
79        // Hash the relevant configuration parameters that affect entropy/KL divergence
80        config.smoothing_alpha.to_bits().hash(&mut hasher);
81        hasher.finish()
82    }
83
84    /// Check if the cached values are valid for the given configuration
85    fn is_cache_valid(&self, config: &AnomalyGridConfig) -> bool {
86        if let Some(cached_hash) = self.cached_config_hash {
87            cached_hash == Self::compute_config_hash(config)
88        } else {
89            false
90        }
91    }
92
93    /// Get the total number of transitions from this context
94    pub fn total_count(&self) -> usize {
95        self.total_count
96    }
97
98    /// Get the count for a specific next state
99    pub fn get_count(&self, next_state: &str) -> usize {
100        let state_id = self.interner.get_or_intern(next_state);
101        self.counts.get(state_id)
102    }
103
104    /// Get the count for a StateId directly (internal use)
105    pub fn get_count_by_id(&self, state_id: StateId) -> usize {
106        self.counts.get(state_id)
107    }
108
109    /// Get the number of unique next states
110    pub fn vocab_size(&self) -> usize {
111        self.counts.len()
112    }
113
114    /// Get all state IDs with their counts (internal use)
115    pub fn get_state_counts(&self) -> impl Iterator<Item = (StateId, usize)> + '_ {
116        self.counts.iter()
117    }
118
119    /// Get the sum of all transition counts (for compatibility)
120    pub fn total_transitions(&self) -> usize {
121        self.total_count
122    }
123
124    /// Get all counts as strings for compatibility with performance module
125    pub fn get_string_counts(&self) -> HashMap<String, usize> {
126        self.counts
127            .iter()
128            .filter_map(|(state_id, count)| self.interner.get_string(state_id).map(|s| (s, count)))
129            .collect()
130    }
131
132    /// Get counts for debugging (returns string representation)
133    pub fn counts(&self) -> HashMap<String, usize> {
134        self.get_string_counts()
135    }
136
137    /// Get the probability for a specific next state using Laplace smoothing
138    ///
139    /// Computes probability on-demand: P(state) = (count + α) / (total + α * |V|)
140    pub fn get_probability(&self, next_state: &str, config: &AnomalyGridConfig) -> f64 {
141        let state_id = self.interner.get_or_intern(next_state);
142        self.get_probability_by_id(state_id, config)
143    }
144
145    /// Get probability for a StateId directly (internal use)
146    pub fn get_probability_by_id(&self, state_id: StateId, config: &AnomalyGridConfig) -> f64 {
147        if self.total_count == 0 {
148            return 1.0 / (self.vocab_size() as f64).max(1.0);
149        }
150
151        let count = self.get_count_by_id(state_id) as f64;
152        let vocab_size = self.vocab_size() as f64;
153
154        (count + config.smoothing_alpha)
155            / (self.total_count as f64 + config.smoothing_alpha * vocab_size)
156    }
157
158    /// Calculate Shannon entropy with lazy computation and caching: H(X) = -∑ P(x) log₂ P(x)
159    pub fn calculate_entropy(&mut self, config: &AnomalyGridConfig) -> f64 {
160        // Check if we have a valid cached value
161        if self.is_cache_valid(config) {
162            if let Some(cached_entropy) = self.cached_entropy {
163                return cached_entropy;
164            }
165        }
166
167        // Compute entropy
168        let entropy = if self.total_count == 0 {
169            0.0
170        } else {
171            self.counts
172                .keys()
173                .map(|state_id| {
174                    let p = self.get_probability_by_id(state_id, config);
175                    if p > 0.0 {
176                        -p * p.log2()
177                    } else {
178                        0.0
179                    }
180                })
181                .sum()
182        };
183
184        // Cache the result
185        self.cached_entropy = Some(entropy);
186        self.cached_config_hash = Some(Self::compute_config_hash(config));
187        
188        entropy
189    }
190
191    /// Calculate Shannon entropy without caching (for immutable access)
192    pub fn compute_entropy(&self, config: &AnomalyGridConfig) -> f64 {
193        if self.total_count == 0 {
194            return 0.0;
195        }
196
197        self.counts
198            .keys()
199            .map(|state_id| {
200                let p = self.get_probability_by_id(state_id, config);
201                if p > 0.0 {
202                    -p * p.log2()
203                } else {
204                    0.0
205                }
206            })
207            .sum()
208    }
209
210    /// Calculate KL divergence from uniform distribution with lazy computation and caching
211    pub fn calculate_kl_divergence(&mut self, config: &AnomalyGridConfig) -> f64 {
212        // Check if we have a valid cached value
213        if self.is_cache_valid(config) {
214            if let Some(cached_kl_div) = self.cached_kl_divergence {
215                return cached_kl_div;
216            }
217        }
218
219        // Compute KL divergence
220        let kl_divergence = if self.total_count == 0 {
221            0.0
222        } else {
223            let uniform_prob = 1.0 / self.vocab_size() as f64;
224            
225            self.counts
226                .keys()
227                .map(|state_id| {
228                    let p = self.get_probability_by_id(state_id, config);
229                    if p > 0.0 {
230                        p * (p / uniform_prob).log2()
231                    } else {
232                        0.0
233                    }
234                })
235                .sum()
236        };
237
238        // Cache the result
239        self.cached_kl_divergence = Some(kl_divergence);
240        self.cached_config_hash = Some(Self::compute_config_hash(config));
241        
242        kl_divergence
243    }
244
245    /// Calculate KL divergence from uniform distribution without caching (for immutable access)
246    pub fn compute_kl_divergence(&self, config: &AnomalyGridConfig) -> f64 {
247        if self.total_count == 0 {
248            return 0.0;
249        }
250
251        let uniform_prob = 1.0 / self.vocab_size() as f64;
252
253        self.counts
254            .keys()
255            .map(|state_id| {
256                let p = self.get_probability_by_id(state_id, config);
257                if p > 0.0 {
258                    p * (p / uniform_prob).log2()
259                } else {
260                    0.0
261                }
262            })
263            .sum()
264    }
265
266    /// Get all probabilities as a HashMap (for compatibility with existing code)
267    ///
268    /// Note: This creates temporary storage and should be used sparingly
269    pub fn get_all_probabilities(&self, config: &AnomalyGridConfig) -> HashMap<String, f64> {
270        self.counts
271            .keys()
272            .filter_map(|state_id| {
273                self.interner.get_string(state_id).map(|state_string| {
274                    let prob = self.get_probability_by_id(state_id, config);
275                    (state_string, prob)
276                })
277            })
278            .collect()
279    }
280
281    /// Reset the context node for reuse in memory pool
282    pub fn reset(&mut self, interner: Arc<StringInterner>) {
283        self.counts = TransitionCounts::new();
284        self.total_count = 0;
285        self.interner = interner;
286        self.cached_entropy = None;
287        self.cached_kl_divergence = None;
288        self.cached_config_hash = None;
289    }
290
291    /// Clear the context node data for memory pool return
292    pub fn clear(&mut self) {
293        self.counts = TransitionCounts::new();
294        self.total_count = 0;
295        self.cached_entropy = None;
296        self.cached_kl_divergence = None;
297        self.cached_config_hash = None;
298        // Keep the interner for potential reuse
299    }
300
301    /// Get cache hit statistics for monitoring
302    pub fn cache_stats(&self) -> (bool, bool) {
303        (self.cached_entropy.is_some(), self.cached_kl_divergence.is_some())
304    }
305}
306
307impl Default for ContextNode {
308    fn default() -> Self {
309        Self {
310            counts: TransitionCounts::new(),
311            total_count: 0,
312            interner: Arc::new(StringInterner::new()),
313            cached_entropy: None,
314            cached_kl_divergence: None,
315            cached_config_hash: None,
316        }
317    }
318}
319
320/// Context tree for storing variable-order Markov chain contexts
321///
322/// Uses trie-based storage for memory efficiency through prefix sharing
323#[derive(Debug, Clone)]
324pub struct ContextTree {
325    /// Trie-based storage for memory-efficient prefix sharing
326    trie: ContextTrie,
327    /// Maximum context order (length)
328    pub max_order: usize,
329    /// String interner for converting between strings and StateIds
330    interner: Arc<StringInterner>,
331}
332
333impl ContextTree {
334    /// Create a new context tree with specified maximum order
335    pub fn new(max_order: usize) -> AnomalyGridResult<Self> {
336        if max_order == 0 {
337            return Err(AnomalyGridError::invalid_max_order(max_order));
338        }
339
340        let interner = Arc::new(StringInterner::new());
341        let trie = ContextTrie::new(max_order, Arc::clone(&interner));
342
343        Ok(Self {
344            trie,
345            max_order,
346            interner,
347        })
348    }
349
350    /// Create a new context tree with existing string interner
351    pub fn with_interner(
352        max_order: usize,
353        interner: Arc<StringInterner>,
354    ) -> AnomalyGridResult<Self> {
355        if max_order == 0 {
356            return Err(AnomalyGridError::invalid_max_order(max_order));
357        }
358
359        let trie = ContextTrie::new(max_order, Arc::clone(&interner));
360
361        Ok(Self {
362            trie,
363            max_order,
364            interner,
365        })
366    }
367
368    /// Build the context tree from a training sequence
369    ///
370    /// # Complexity
371    /// - Time: O(n × max_order × |alphabet|) where n = sequence length
372    /// - Space: O(|alphabet|^max_order) in worst case
373    ///
374    /// # Performance Guarantees
375    /// - Memory usage is bounded by config.memory_limit if set
376    /// - Processing time scales linearly with sequence length
377    /// - Uses string interning to reduce memory duplication
378    pub fn build_from_sequence(
379        &mut self,
380        sequence: &[String],
381        config: &AnomalyGridConfig,
382    ) -> AnomalyGridResult<()> {
383        // Validate sequence length
384        if sequence.len() < config.min_sequence_length {
385            return Err(AnomalyGridError::sequence_too_short(
386                config.min_sequence_length,
387                sequence.len(),
388                "context tree building",
389            ));
390        }
391
392        // Extract contexts of all orders from 1 to max_order
393        for window_size in 1..=self.max_order {
394            for window in sequence.windows(window_size + 1) {
395                // Check memory limit before adding new context
396                if let Some(limit) = config.memory_limit {
397                    if self.trie.context_count() >= limit {
398                        return Err(AnomalyGridError::memory_limit_exceeded(
399                            self.trie.context_count(),
400                            limit,
401                        ));
402                    }
403                }
404
405                // Convert context to StateIds for trie storage
406                let context_state_ids: Vec<StateId> = window[..window_size]
407                    .iter()
408                    .map(|s| self.interner.get_or_intern(s))
409                    .collect();
410                let next_state = &window[window_size];
411
412                // Get or create context node in trie
413                let node = self.trie.get_or_create_context_data(&context_state_ids);
414                node.add_transition(next_state);
415            }
416        }
417
418        Ok(())
419    }
420
421    /// Get the transition probability for a given context and next state
422    pub fn get_transition_probability(&self, context: &[String], next_state: &str) -> Option<f64> {
423        // Convert context to StateIds
424        let context_state_ids: Vec<StateId> = context
425            .iter()
426            .map(|s| self.interner.get_or_intern(s))
427            .collect();
428        
429        self.trie
430            .get_context_data(&context_state_ids)
431            .map(|node| node.get_probability(next_state, &AnomalyGridConfig::default()))
432    }
433
434    /// Get the transition probability with custom config
435    pub fn get_transition_probability_with_config(
436        &self,
437        context: &[String],
438        next_state: &str,
439        config: &AnomalyGridConfig,
440    ) -> Option<f64> {
441        // Convert context to StateIds
442        let context_state_ids: Vec<StateId> = context
443            .iter()
444            .map(|s| self.interner.get_or_intern(s))
445            .collect();
446        
447        self.trie
448            .get_context_data(&context_state_ids)
449            .map(|node| node.get_probability(next_state, config))
450    }
451
452    /// Get a context node for the given context
453    pub fn get_context_node(&self, context: &[String]) -> Option<&ContextNode> {
454        // Convert context to StateIds
455        let context_state_ids: Vec<StateId> = context
456            .iter()
457            .map(|s| self.interner.get_or_intern(s))
458            .collect();
459        
460        self.trie.get_context_data(&context_state_ids)
461    }
462
463    /// Get all contexts of a specific order
464    pub fn get_contexts_of_order(&self, order: usize) -> Vec<Vec<String>> {
465        self.trie
466            .iter_contexts()
467            .filter_map(|(state_ids, _)| {
468                if state_ids.len() == order {
469                    // Convert StateIds back to strings
470                    let strings: Option<Vec<String>> = state_ids
471                        .iter()
472                        .map(|&state_id| self.interner.get_string(state_id))
473                        .collect();
474                    strings
475                } else {
476                    None
477                }
478            })
479            .collect()
480    }
481
482    /// Get the number of contexts stored
483    pub fn context_count(&self) -> usize {
484        self.trie.context_count()
485    }
486
487    /// Get access to the string interner
488    pub fn interner(&self) -> &Arc<StringInterner> {
489        &self.interner
490    }
491
492    /// Get all contexts as a HashMap for compatibility with existing code
493    /// 
494    /// Note: This creates a temporary HashMap and should be used sparingly
495    /// for compatibility with existing tests and code that expects the old interface
496    pub fn contexts(&self) -> HashMap<Vec<String>, ContextNode> {
497        let mut contexts = HashMap::new();
498        
499        for (state_ids, node) in self.trie.iter_contexts() {
500            // Convert StateIds back to strings
501            if let Some(strings) = state_ids
502                .iter()
503                .map(|&state_id| self.interner.get_string(state_id))
504                .collect::<Option<Vec<String>>>()
505            {
506                contexts.insert(strings, node.clone());
507            }
508        }
509        
510        contexts
511    }
512
513    /// Get the trie for internal operations
514    pub(crate) fn trie(&self) -> &ContextTrie {
515        &self.trie
516    }
517}