1use alloc::{
2 boxed::Box,
3 collections::{BTreeMap, BTreeSet},
4};
5use core::cell::RefCell;
6
/// A key-value map interface used by this module.
///
/// Implementors must also be constructible from, extendable with, and consumable into
/// `(K, V)` pairs, so generic code can build and drain a map without knowing its concrete
/// type.
pub trait KvMap<K, V>:
    Extend<(K, V)> + FromIterator<(K, V)> + IntoIterator<Item = (K, V)>
where
    K: Ord + Clone,
    V: Clone,
{
    /// Returns a reference to the value stored under `key`, if any.
    fn get(&self, key: &K) -> Option<&V>;

    /// Returns `true` if an entry for `key` is present.
    fn contains_key(&self, key: &K) -> bool;

    /// Returns the number of entries.
    fn len(&self) -> usize;

    /// Returns `true` when the map holds no entries.
    ///
    /// Provided in terms of [`KvMap::len`]; implementors normally need not override it.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Stores `value` under `key`, returning the value previously stored there, if any.
    fn insert(&mut self, key: K, value: V) -> Option<V>;

    /// Removes the entry for `key`, returning its value, if any.
    fn remove(&mut self, key: &K) -> Option<V>;

    /// Returns a boxed iterator over borrowed `(key, value)` pairs.
    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_>;
}
25
26impl<K: Ord + Clone, V: Clone> KvMap<K, V> for BTreeMap<K, V> {
30 fn get(&self, key: &K) -> Option<&V> {
31 self.get(key)
32 }
33
34 fn contains_key(&self, key: &K) -> bool {
35 self.contains_key(key)
36 }
37
38 fn len(&self) -> usize {
39 self.len()
40 }
41
42 fn insert(&mut self, key: K, value: V) -> Option<V> {
43 self.insert(key, value)
44 }
45
46 fn remove(&mut self, key: &K) -> Option<V> {
47 self.remove(key)
48 }
49
50 fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
51 Box::new(self.iter())
52 }
53}
54
/// A key-value map that records the original values of entries it observes.
///
/// - `data` holds the live contents.
/// - `updates` holds every key written through `KvMap::insert`, plus every present key
///   removed through `KvMap::remove`.
/// - `trace` maps keys to the value they held before any update. It is filled when a
///   not-yet-updated key is read through `KvMap::get`/`KvMap::contains_key`, and when a
///   key that already had a value is updated for the first time.
///
/// `trace` sits behind a `RefCell` so that `&self` lookups can still record into it.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct RecordingMap<K, V> {
    /// Current contents of the map.
    data: BTreeMap<K, V>,
    /// Keys that have been written at least once.
    updates: BTreeSet<K>,
    /// Pre-update values of keys that were read or overwritten (see the type docs).
    trace: RefCell<BTreeMap<K, V>>,
}

impl<K: Ord + Clone, V: Clone> RecordingMap<K, V> {
    /// Creates a map whose contents are collected from `init`, with no updates or trace
    /// recorded yet.
    pub fn new(init: impl IntoIterator<Item = (K, V)>) -> Self {
        RecordingMap {
            data: init.into_iter().collect(),
            updates: BTreeSet::new(),
            trace: RefCell::new(BTreeMap::new()),
        }
    }

    /// Returns the current contents without recording anything in the trace.
    pub fn inner(&self) -> &BTreeMap<K, V> {
        &self.data
    }

    /// Consumes the map and returns `(final contents, trace)`.
    pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
        // We own `self`, so unwrap the cell directly. `take()` would do a needless runtime
        // borrow check and allocate a throwaway empty map in its place.
        (self.data, self.trace.into_inner())
    }

    /// Number of entries currently in the trace (test helper).
    #[cfg(test)]
    pub fn trace_len(&self) -> usize {
        self.trace.borrow().len()
    }

    /// Number of keys marked as updated (test helper).
    #[cfg(test)]
    pub fn updates_len(&self) -> usize {
        self.updates.len()
    }
}
120
121impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
122 fn get(&self, key: &K) -> Option<&V> {
129 self.data.get(key).inspect(|&value| {
130 if !self.updates.contains(key) {
131 self.trace.borrow_mut().insert(key.clone(), value.clone());
132 }
133 })
134 }
135
136 fn contains_key(&self, key: &K) -> bool {
140 self.get(key).is_some()
141 }
142
143 fn len(&self) -> usize {
145 self.data.len()
146 }
147
148 fn insert(&mut self, key: K, value: V) -> Option<V> {
156 let new_update = self.updates.insert(key.clone());
157 self.data.insert(key.clone(), value).inspect(|old_value| {
158 if new_update {
159 self.trace.borrow_mut().insert(key, old_value.clone());
160 }
161 })
162 }
163
164 fn remove(&mut self, key: &K) -> Option<V> {
168 self.data.remove(key).inspect(|old_value| {
169 let new_update = self.updates.insert(key.clone());
170 if new_update {
171 self.trace.borrow_mut().insert(key.clone(), old_value.clone());
172 }
173 })
174 }
175
176 fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
181 Box::new(self.data.iter())
182 }
183}
184
185impl<K: Clone + Ord, V: Clone> Extend<(K, V)> for RecordingMap<K, V> {
186 fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
187 iter.into_iter().for_each(move |(k, v)| {
188 self.insert(k, v);
189 });
190 }
191}
192
193impl<K: Clone + Ord, V: Clone> FromIterator<(K, V)> for RecordingMap<K, V> {
194 fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
195 Self::new(iter)
196 }
197}
198
199impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
200 type Item = (K, V);
201 type IntoIter = alloc::collections::btree_map::IntoIter<K, V>;
202
203 fn into_iter(self) -> Self::IntoIter {
204 self.data.into_iter()
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    // Initial contents shared by every test: keys 0..=4, each mapped to itself.
    const ITEMS: [(u64, u64); 5] = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)];

    /// Reading keys with `get` puts exactly those keys (with their initial values) in the proof.
    #[test]
    fn test_get_item() {
        let map = RecordingMap::new(ITEMS.to_vec());

        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.get(key);
        }

        let (_, proof) = map.finalize();

        // Read keys appear with their original value; unread keys are absent.
        for (key, value) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert_eq!(proof.get(key), Some(value)),
                false => assert_eq!(proof.get(key), None),
            }
        }
    }

    /// `contains_key` is traced just like `get`.
    #[test]
    fn test_contains_key() {
        let map = RecordingMap::new(ITEMS.to_vec());

        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }

        let (_, proof) = map.finalize();

        for (key, _) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert!(proof.contains_key(key)),
                false => assert!(!proof.contains_key(key)),
            }
        }
    }

    /// Tracks `len`, trace size and update count across a sequence of mixed operations.
    #[test]
    fn test_len() {
        let mut map = RecordingMap::new(ITEMS.to_vec());
        assert_eq!(map.len(), ITEMS.len());

        // Overwriting an existing key: length unchanged, old value (4, 4) traced, key marked updated.
        map.insert(4, 5);
        assert_eq!(map.len(), ITEMS.len());
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // Inserting a brand-new key: length grows, nothing to trace, key marked updated.
        map.insert(5, 5);
        assert_eq!(map.len(), ITEMS.len() + 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        // Reading three untouched initial keys adds them to the trace.
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // Re-reading the same keys does not grow the trace.
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // Reading an updated key (5) is not traced.
        let _val = map.get(&5).unwrap();
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // A second update of the same key records nothing new.
        map.insert(5, 11);
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        let (_, proof) = map.finalize();

        // Proof = the three read keys plus key 4's pre-overwrite value.
        assert_eq!(proof.len(), get_items.len() + 1);
    }

    /// `iter` reflects the current contents, including overwritten values.
    #[test]
    fn test_iter() {
        let mut map = RecordingMap::new(ITEMS.to_vec());
        assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));

        let new_value = 5;
        map.insert(4, new_value);
        assert_eq!(map.iter().count(), ITEMS.len());
        assert!(map.iter().all(|(x, y)| if x == &4 {
            y == &new_value
        } else {
            ITEMS.contains(&(*x, *y))
        }));
    }

    /// The provided `is_empty` follows `len`.
    #[test]
    fn test_is_empty() {
        let empty_map: RecordingMap<u64, u64> = RecordingMap::default();
        assert!(empty_map.is_empty());

        let map = RecordingMap::new(ITEMS.to_vec());
        assert!(!map.is_empty());
    }

    /// Exercises how `remove` interacts with the trace and the update set.
    #[test]
    fn test_remove() {
        let mut map = RecordingMap::new(ITEMS.to_vec());

        // Removing an initial key traces its original value and marks it updated.
        let key = 0;
        let value = map.remove(&key).unwrap();
        assert_eq!(value, ITEMS[0].1);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // Re-inserting and removing an already-updated key records nothing new.
        let key = 0;
        let value = 0;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 0);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // Removing an absent key returns None and is not counted as an update.
        let key = 100;
        let value = map.remove(&key);
        assert_eq!(value, None);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // Inserting a new key marks it updated without tracing (no prior value); removing it
        // afterwards adds nothing.
        let key = 100;
        let value = 100;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 100);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        let (_, proof) = map.finalize();

        // Only key 0's original value ends up in the proof.
        for (key, value) in ITEMS.iter() {
            match key {
                0 => assert_eq!(proof.get(key), Some(value)),
                _ => assert_eq!(proof.get(key), None),
            }
        }
    }
}