// anomaly_grid/transition_counts.rs
use crate::string_interner::StateId;
use smallvec::{smallvec, SmallVec};
use std::collections::HashMap;
use std::iter::FusedIterator;

10#[derive(Debug, Clone)]
15pub enum TransitionCounts {
16 Small(SmallVec<[(StateId, usize); 4]>),
19
20 Large(HashMap<StateId, usize>),
23}
24
25impl TransitionCounts {
26 pub fn new() -> Self {
28 Self::Small(smallvec![])
29 }
30
31 pub fn get(&self, state_id: StateId) -> usize {
33 match self {
34 Self::Small(vec) => vec
35 .iter()
36 .find(|(id, _)| *id == state_id)
37 .map(|(_, count)| *count)
38 .unwrap_or(0),
39 Self::Large(map) => map.get(&state_id).copied().unwrap_or(0),
40 }
41 }
42
43 pub fn insert(&mut self, state_id: StateId, count: usize) {
45 match self {
46 Self::Small(vec) => {
47 if let Some((_, existing_count)) = vec.iter_mut().find(|(id, _)| *id == state_id) {
49 *existing_count = count;
50 return;
51 }
52
53 if vec.len() >= 4 {
55 let mut map = HashMap::new();
57 for (id, c) in vec.iter() {
58 map.insert(*id, *c);
59 }
60 map.insert(state_id, count);
61 *self = Self::Large(map);
62 } else {
63 vec.push((state_id, count));
65 }
66 }
67 Self::Large(map) => {
68 map.insert(state_id, count);
69 }
70 }
71 }
72
73 pub fn increment(&mut self, state_id: StateId) {
75 let current = self.get(state_id);
76 self.insert(state_id, current + 1);
77 }
78
79 pub fn len(&self) -> usize {
81 match self {
82 Self::Small(vec) => vec.len(),
83 Self::Large(map) => map.len(),
84 }
85 }
86
87 pub fn is_empty(&self) -> bool {
89 self.len() == 0
90 }
91
92 pub fn iter(&self) -> TransitionCountsIter {
94 match self {
95 Self::Small(vec) => TransitionCountsIter::Small(vec.iter()),
96 Self::Large(map) => TransitionCountsIter::Large(map.iter()),
97 }
98 }
99
100 pub fn keys(&self) -> impl Iterator<Item = StateId> + '_ {
102 self.iter().map(|(state_id, _)| state_id)
103 }
104
105 pub fn values(&self) -> impl Iterator<Item = usize> + '_ {
107 self.iter().map(|(_, count)| count)
108 }
109
110 pub fn is_small(&self) -> bool {
112 matches!(self, Self::Small(_))
113 }
114
115 pub fn memory_usage(&self) -> usize {
117 match self {
118 Self::Small(vec) => {
119 std::mem::size_of::<SmallVec<[(StateId, usize); 4]>>()
121 + if vec.spilled() {
122 vec.capacity() * std::mem::size_of::<(StateId, usize)>()
123 } else {
124 0 }
126 }
127 Self::Large(map) => {
128 std::mem::size_of::<HashMap<StateId, usize>>()
130 + map.capacity()
131 * (std::mem::size_of::<StateId>() + std::mem::size_of::<usize>())
132 }
133 }
134 }
135}
136
137impl Default for TransitionCounts {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143pub enum TransitionCountsIter<'a> {
145 Small(std::slice::Iter<'a, (StateId, usize)>),
146 Large(std::collections::hash_map::Iter<'a, StateId, usize>),
147}
148
149impl<'a> Iterator for TransitionCountsIter<'a> {
150 type Item = (StateId, usize);
151
152 fn next(&mut self) -> Option<Self::Item> {
153 match self {
154 Self::Small(iter) => iter.next().map(|(id, count)| (*id, *count)),
155 Self::Large(iter) => iter.next().map(|(id, count)| (*id, *count)),
156 }
157 }
158}
159
// Unit tests for `TransitionCounts`. They cover the inline representation, promotion to
// the hash map, the replace-vs-add semantics of `insert`, iteration, and the memory
// estimate.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_interner::StateId;

    #[test]
    fn test_small_collection_operations() {
        let mut counts = TransitionCounts::new();
        assert!(counts.is_empty());
        assert!(counts.is_small());

        counts.increment(StateId::new(1));
        counts.increment(StateId::new(2));
        counts.increment(StateId::new(1));

        assert_eq!(counts.len(), 2);
        assert_eq!(counts.get(StateId::new(1)), 2);
        assert_eq!(counts.get(StateId::new(2)), 1);
        assert_eq!(counts.get(StateId::new(3)), 0);
        assert!(counts.is_small());
    }

    #[test]
    fn test_promotion_to_large() {
        let mut counts = TransitionCounts::new();

        for i in 1..=4 {
            counts.increment(StateId::new(i));
        }
        assert!(counts.is_small());
        assert_eq!(counts.len(), 4);

        counts.increment(StateId::new(5));
        assert!(!counts.is_small());
        assert_eq!(counts.len(), 5);

        for i in 1..=5 {
            assert_eq!(counts.get(StateId::new(i)), 1);
        }
    }

    #[test]
    fn test_existing_key_at_capacity_stays_small() {
        let mut counts = TransitionCounts::new();
        for i in 1..=4 {
            counts.increment(StateId::new(i));
        }

        // Re-counting a known state must not promote, even when the inline storage is full.
        counts.increment(StateId::new(4));
        assert!(counts.is_small());
        assert_eq!(counts.len(), 4);
        assert_eq!(counts.get(StateId::new(4)), 2);
    }

    #[test]
    fn test_promotion_preserves_existing_counts() {
        let mut counts = TransitionCounts::new();
        for _ in 0..3 {
            counts.increment(StateId::new(1));
        }
        counts.increment(StateId::new(2));
        counts.increment(StateId::new(3));
        counts.increment(StateId::new(4));
        assert!(counts.is_small());

        counts.increment(StateId::new(5));
        assert!(!counts.is_small());
        assert_eq!(counts.get(StateId::new(1)), 3);
        assert_eq!(counts.get(StateId::new(5)), 1);
        assert_eq!(counts.values().sum::<usize>(), 7);
    }

    #[test]
    fn test_insert_overwrites_instead_of_adding() {
        let mut counts = TransitionCounts::new();
        counts.insert(StateId::new(1), 5);
        counts.insert(StateId::new(1), 3);
        assert_eq!(counts.get(StateId::new(1)), 3);
        assert_eq!(counts.len(), 1);

        // The same replace-not-add contract must hold after promotion.
        for i in 2..=6 {
            counts.insert(StateId::new(i), 7);
        }
        assert!(!counts.is_small());
        counts.insert(StateId::new(1), 9);
        assert_eq!(counts.get(StateId::new(1)), 9);
        assert_eq!(counts.get(StateId::new(6)), 7);
        assert_eq!(counts.len(), 6);
    }

    #[test]
    fn test_large_collection_operations() {
        let mut counts = TransitionCounts::new();

        for i in 1..=10 {
            counts.increment(StateId::new(i));
        }
        assert!(!counts.is_small());
        assert_eq!(counts.len(), 10);

        counts.increment(StateId::new(5));
        assert_eq!(counts.get(StateId::new(5)), 2);
        assert_eq!(counts.get(StateId::new(1)), 1);
    }

    #[test]
    fn test_iteration() {
        let mut counts = TransitionCounts::new();
        counts.increment(StateId::new(1));
        counts.increment(StateId::new(2));
        counts.increment(StateId::new(1));

        let collected: Vec<_> = counts.iter().collect();
        assert_eq!(collected.len(), 2);

        let state_1_count = collected
            .iter()
            .find(|(id, _)| *id == StateId::new(1))
            .unwrap()
            .1;
        let state_2_count = collected
            .iter()
            .find(|(id, _)| *id == StateId::new(2))
            .unwrap()
            .1;

        assert_eq!(state_1_count, 2);
        assert_eq!(state_2_count, 1);
    }

    #[test]
    fn test_memory_usage() {
        let small_counts = TransitionCounts::new();
        let small_usage = small_counts.memory_usage();

        let mut large_counts = TransitionCounts::new();
        for i in 1..=10 {
            large_counts.increment(StateId::new(i));
        }
        let large_usage = large_counts.memory_usage();

        assert!(small_usage > 0);
        assert!(large_usage > 0);
        // Ten hashed entries must be estimated above an empty inline collection.
        assert!(large_usage > small_usage);

        println!("Small usage: {small_usage} bytes");
        println!("Large usage: {large_usage} bytes");
    }

    #[test]
    fn test_keys_and_values() {
        let mut counts = TransitionCounts::new();
        counts.increment(StateId::new(1));
        counts.increment(StateId::new(2));
        counts.increment(StateId::new(1));

        let keys: Vec<_> = counts.keys().collect();
        let values: Vec<_> = counts.values().collect();

        assert_eq!(keys.len(), 2);
        assert_eq!(values.len(), 2);
        assert!(keys.contains(&StateId::new(1)));
        assert!(keys.contains(&StateId::new(2)));
        assert!(values.contains(&1));
        assert!(values.contains(&2));
    }
}