anomaly_grid/
transition_counts.rs

1//! Optimized transition count storage for small collections
2//!
3//! This module provides memory-efficient storage for transition counts,
4//! optimizing for the common case where contexts have few transitions.
5
6use crate::string_interner::StateId;
7use smallvec::{smallvec, SmallVec};
8use std::collections::HashMap;
9
10/// Memory-efficient storage for transition counts
11///
12/// Uses SmallVec for small collections (≤4 transitions) and HashMap for larger ones.
13/// Based on analysis showing 100% of typical contexts have ≤4 transitions.
14#[derive(Debug, Clone)]
15pub enum TransitionCounts {
16    /// Inline storage for small collections (≤4 transitions)
17    /// Uses stack allocation to avoid heap overhead
18    Small(SmallVec<[(StateId, usize); 4]>),
19
20    /// HashMap storage for large collections (>4 transitions)
21    /// Falls back to HashMap when small storage is exceeded
22    Large(HashMap<StateId, usize>),
23}
24
25impl TransitionCounts {
26    /// Create a new empty transition counts collection
27    pub fn new() -> Self {
28        Self::Small(smallvec![])
29    }
30
31    /// Get the count for a specific state
32    pub fn get(&self, state_id: StateId) -> usize {
33        match self {
34            Self::Small(vec) => vec
35                .iter()
36                .find(|(id, _)| *id == state_id)
37                .map(|(_, count)| *count)
38                .unwrap_or(0),
39            Self::Large(map) => map.get(&state_id).copied().unwrap_or(0),
40        }
41    }
42
43    /// Insert or update a count for a state
44    pub fn insert(&mut self, state_id: StateId, count: usize) {
45        match self {
46            Self::Small(vec) => {
47                // Try to find existing entry
48                if let Some((_, existing_count)) = vec.iter_mut().find(|(id, _)| *id == state_id) {
49                    *existing_count = count;
50                    return;
51                }
52
53                // Check if we need to promote to Large
54                if vec.len() >= 4 {
55                    // Promote to HashMap
56                    let mut map = HashMap::new();
57                    for (id, c) in vec.iter() {
58                        map.insert(*id, *c);
59                    }
60                    map.insert(state_id, count);
61                    *self = Self::Large(map);
62                } else {
63                    // Add to small vector
64                    vec.push((state_id, count));
65                }
66            }
67            Self::Large(map) => {
68                map.insert(state_id, count);
69            }
70        }
71    }
72
73    /// Increment the count for a state, inserting if not present
74    pub fn increment(&mut self, state_id: StateId) {
75        let current = self.get(state_id);
76        self.insert(state_id, current + 1);
77    }
78
79    /// Get the number of unique states
80    pub fn len(&self) -> usize {
81        match self {
82            Self::Small(vec) => vec.len(),
83            Self::Large(map) => map.len(),
84        }
85    }
86
87    /// Check if the collection is empty
88    pub fn is_empty(&self) -> bool {
89        self.len() == 0
90    }
91
92    /// Iterate over all (state_id, count) pairs
93    pub fn iter(&self) -> TransitionCountsIter {
94        match self {
95            Self::Small(vec) => TransitionCountsIter::Small(vec.iter()),
96            Self::Large(map) => TransitionCountsIter::Large(map.iter()),
97        }
98    }
99
100    /// Get all state IDs
101    pub fn keys(&self) -> impl Iterator<Item = StateId> + '_ {
102        self.iter().map(|(state_id, _)| state_id)
103    }
104
105    /// Get all counts
106    pub fn values(&self) -> impl Iterator<Item = usize> + '_ {
107        self.iter().map(|(_, count)| count)
108    }
109
110    /// Check if the collection is using small storage
111    pub fn is_small(&self) -> bool {
112        matches!(self, Self::Small(_))
113    }
114
115    /// Get memory usage estimate in bytes
116    pub fn memory_usage(&self) -> usize {
117        match self {
118            Self::Small(vec) => {
119                // SmallVec overhead + inline storage
120                std::mem::size_of::<SmallVec<[(StateId, usize); 4]>>()
121                    + if vec.spilled() {
122                        vec.capacity() * std::mem::size_of::<(StateId, usize)>()
123                    } else {
124                        0 // Inline storage already counted
125                    }
126            }
127            Self::Large(map) => {
128                // HashMap overhead + entries
129                std::mem::size_of::<HashMap<StateId, usize>>()
130                    + map.capacity()
131                        * (std::mem::size_of::<StateId>() + std::mem::size_of::<usize>())
132            }
133        }
134    }
135}
136
137impl Default for TransitionCounts {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143/// Iterator over transition counts
144pub enum TransitionCountsIter<'a> {
145    Small(std::slice::Iter<'a, (StateId, usize)>),
146    Large(std::collections::hash_map::Iter<'a, StateId, usize>),
147}
148
149impl<'a> Iterator for TransitionCountsIter<'a> {
150    type Item = (StateId, usize);
151
152    fn next(&mut self) -> Option<Self::Item> {
153        match self {
154            Self::Small(iter) => iter.next().map(|(id, count)| (*id, *count)),
155            Self::Large(iter) => iter.next().map(|(id, count)| (*id, *count)),
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_interner::StateId;

    /// Incrementing a few states keeps the `Small` representation and tracks
    /// per-state counts, reporting zero for states never seen.
    #[test]
    fn test_small_collection_operations() {
        let mut tc = TransitionCounts::new();
        assert!(tc.is_empty());
        assert!(tc.is_small());

        // Two distinct states; state 1 is incremented twice.
        for raw in [1, 2, 1] {
            tc.increment(StateId::new(raw));
        }

        assert_eq!(tc.len(), 2);
        assert_eq!(tc.get(StateId::new(1)), 2);
        assert_eq!(tc.get(StateId::new(2)), 1);
        assert_eq!(tc.get(StateId::new(3)), 0);
        assert!(tc.is_small());
    }

    /// The fifth distinct state switches the collection to `Large` without
    /// losing any previously recorded count.
    #[test]
    fn test_promotion_to_large() {
        let mut tc = TransitionCounts::new();

        // Four distinct states still fit in the small representation.
        for raw in 1..=4 {
            tc.increment(StateId::new(raw));
        }
        assert!(tc.is_small());
        assert_eq!(tc.len(), 4);

        // One more distinct state triggers the promotion.
        tc.increment(StateId::new(5));
        assert!(!tc.is_small());
        assert_eq!(tc.len(), 5);

        // Every state, old and new, must still report a count of one.
        for raw in 1..=5 {
            assert_eq!(tc.get(StateId::new(raw)), 1);
        }
    }

    /// Lookups and increments keep working after promotion to `Large`.
    #[test]
    fn test_large_collection_operations() {
        let mut tc = TransitionCounts::new();

        // Ten distinct states are well past the promotion threshold.
        for raw in 1..=10 {
            tc.increment(StateId::new(raw));
        }
        assert!(!tc.is_small());
        assert_eq!(tc.len(), 10);

        // A second increment of state 5 must bring it to two, leaving the
        // other states untouched.
        tc.increment(StateId::new(5));
        assert_eq!(tc.get(StateId::new(5)), 2);
        assert_eq!(tc.get(StateId::new(1)), 1);
    }

    /// `iter` yields one pair per distinct state with the accumulated count.
    #[test]
    fn test_iteration() {
        let mut tc = TransitionCounts::new();
        for raw in [1, 2, 1] {
            tc.increment(StateId::new(raw));
        }

        let pairs: Vec<_> = tc.iter().collect();
        assert_eq!(pairs.len(), 2);

        // Iteration order is not relied upon, so look each state up by ID.
        let count_for = |wanted: StateId| {
            pairs
                .iter()
                .find(|(id, _)| *id == wanted)
                .map(|(_, count)| *count)
                .unwrap()
        };

        assert_eq!(count_for(StateId::new(1)), 2);
        assert_eq!(count_for(StateId::new(2)), 1);
    }

    /// `memory_usage` returns a positive estimate for both representations.
    #[test]
    fn test_memory_usage() {
        let small_usage = TransitionCounts::new().memory_usage();

        let mut promoted = TransitionCounts::new();
        for raw in 1..=10 {
            promoted.increment(StateId::new(raw));
        }
        let large_usage = promoted.memory_usage();

        // Only positivity is checked; the relative sizes depend on the
        // concrete type sizes and are printed for inspection.
        assert!(small_usage > 0);
        assert!(large_usage > 0);

        println!("Small usage: {small_usage} bytes");
        println!("Large usage: {large_usage} bytes");
    }

    /// `keys` and `values` expose the stored IDs and counts respectively.
    #[test]
    fn test_keys_and_values() {
        let mut tc = TransitionCounts::new();
        for raw in [1, 2, 1] {
            tc.increment(StateId::new(raw));
        }

        let ids: Vec<_> = tc.keys().collect();
        let tallies: Vec<_> = tc.values().collect();

        assert_eq!(ids.len(), 2);
        assert_eq!(tallies.len(), 2);
        assert!(ids.contains(&StateId::new(1)));
        assert!(ids.contains(&StateId::new(2)));
        assert!(tallies.contains(&1));
        assert!(tallies.contains(&2));
    }
}