// rhaki_cw_multi_test/transactions.rs

1use crate::error::AnyResult;
2use cosmwasm_std::Storage;
3use cosmwasm_std::{Order, Record};
4use std::cmp::Ordering;
5use std::collections::BTreeMap;
6use std::iter;
7use std::iter::Peekable;
8use std::ops::{Bound, RangeBounds};
9
/// A borrowed `(key, value)` entry exactly as yielded when iterating or ranging
/// over a `BTreeMap<Vec<u8>, T>`.
///
/// Kept private: it mirrors the current overlay map type and may change if that
/// map implementation is ever replaced.
type BTreeMapPairRef<'a, T = Vec<u8>> = (&'a Vec<u8>, &'a T);
13
14/// transactional
15pub fn transactional<F, T>(base: &mut dyn Storage, action: F) -> AnyResult<T>
16where
17    F: FnOnce(&mut dyn Storage, &dyn Storage) -> AnyResult<T>,
18{
19    let mut cache = StorageTransaction::new(base);
20    let res = action(&mut cache, base)?;
21    cache.prepare().commit(base);
22    Ok(res)
23}
24
25/// StorageTransaction is a wrapper around a Storage that allows cache before commit
26pub struct StorageTransaction<'a> {
27    /// read-only access to backing storage
28    storage: &'a dyn Storage,
29    /// these are local changes not flushed to backing storage
30    local_state: BTreeMap<Vec<u8>, Delta>,
31    /// a log of local changes not yet flushed to backing storage
32    rep_log: RepLog,
33}
34
35impl<'a> StorageTransaction<'a> {
36    /// Create a new StoratgeTransaction
37    pub fn new(storage: &'a dyn Storage) -> Self {
38        StorageTransaction {
39            storage,
40            local_state: BTreeMap::new(),
41            rep_log: RepLog::new(),
42        }
43    }
44
45    /// prepares this transaction to be committed to storage
46    pub fn prepare(self) -> RepLog {
47        self.rep_log
48    }
49}
50
51impl<'a> Storage for StorageTransaction<'a> {
52    fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53        match self.local_state.get(key) {
54            Some(val) => match val {
55                Delta::Set { value } => Some(value.clone()),
56                Delta::Delete {} => None,
57            },
58            None => self.storage.get(key),
59        }
60    }
61
62    /// Range allows iteration over a set of keys, either forwards or backwards
63    /// uses standard Rust range notation, e.g. `db.range(b"foo"‥b"bar")`,
64    /// works also in reverse order.
65    fn range<'b>(
66        &'b self,
67        start: Option<&[u8]>,
68        end: Option<&[u8]>,
69        order: Order,
70    ) -> Box<dyn Iterator<Item = Record> + 'b> {
71        let bounds = range_bounds(start, end);
72
73        // BTreeMap.range panics if range is start > end.
74        // However, this cases represent just empty range and we treat it as such.
75        let local: Box<dyn Iterator<Item = BTreeMapPairRef<Delta>>> =
76            match (bounds.start_bound(), bounds.end_bound()) {
77                (Bound::Included(start), Bound::Excluded(end)) if start > end => {
78                    Box::new(iter::empty())
79                }
80                _ => {
81                    let local_raw = self.local_state.range(bounds);
82                    match order {
83                        Order::Ascending => Box::new(local_raw),
84                        Order::Descending => Box::new(local_raw.rev()),
85                    }
86                }
87            };
88
89        let base = self.storage.range(start, end, order);
90        let merged = MergeOverlay::new(local, base, order);
91        Box::new(merged)
92    }
93
94    fn set(&mut self, key: &[u8], value: &[u8]) {
95        let op = Op::Set {
96            key: key.to_vec(),
97            value: value.to_vec(),
98        };
99        self.local_state.insert(key.to_vec(), op.to_delta());
100        self.rep_log.append(op);
101    }
102
103    fn remove(&mut self, key: &[u8]) {
104        let op = Op::Delete { key: key.to_vec() };
105        self.local_state.insert(key.to_vec(), op.to_delta());
106        self.rep_log.append(op);
107    }
108}
109
110pub struct RepLog {
111    /// this is a list of changes to be written to backing storage upon commit
112    ops_log: Vec<Op>,
113}
114
115impl RepLog {
116    fn new() -> Self {
117        RepLog { ops_log: vec![] }
118    }
119
120    /// appends an op to the list of changes to be applied upon commit
121    fn append(&mut self, op: Op) {
122        self.ops_log.push(op);
123    }
124
125    /// applies the stored list of `Op`s to the provided `Storage`
126    pub fn commit(self, storage: &mut dyn Storage) {
127        for op in self.ops_log {
128            op.apply(storage);
129        }
130    }
131}
132
133/// Op is the user operation, which can be stored in the RepLog.
134/// Currently: `Set` or `Delete`.
135enum Op {
136    /// represents the `Set` operation for setting a key-value pair in storage
137    Set {
138        key: Vec<u8>,
139        value: Vec<u8>,
140    },
141    Delete {
142        key: Vec<u8>,
143    },
144}
145
146impl Op {
147    /// applies this `Op` to the provided storage
148    pub fn apply(&self, storage: &mut dyn Storage) {
149        match self {
150            Op::Set { key, value } => storage.set(key, value),
151            Op::Delete { key } => storage.remove(key),
152        }
153    }
154
155    /// converts the Op to a delta, which can be stored in a local cache
156    pub fn to_delta(&self) -> Delta {
157        match self {
158            Op::Set { value, .. } => Delta::Set {
159                value: value.clone(),
160            },
161            Op::Delete { .. } => Delta::Delete {},
162        }
163    }
164}
165
/// A pending change to one key, held in a transaction's local overlay.
///
/// The key itself is not stored here: it is the key of the `BTreeMap` entry that
/// owns this value.
enum Delta {
    /// The key now holds `value`, shadowing whatever the backing storage has.
    Set { value: Vec<u8> },
    /// The key was removed; any backing-storage value is hidden.
    Delete {},
}
173
174struct MergeOverlay<'a, L, R>
175where
176    L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
177    R: Iterator<Item = Record>,
178{
179    left: Peekable<L>,
180    right: Peekable<R>,
181    order: Order,
182}
183
184impl<'a, L, R> MergeOverlay<'a, L, R>
185where
186    L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
187    R: Iterator<Item = Record>,
188{
189    fn new(left: L, right: R, order: Order) -> Self {
190        MergeOverlay {
191            left: left.peekable(),
192            right: right.peekable(),
193            order,
194        }
195    }
196
197    fn pick_match(&mut self, lkey: Vec<u8>, rkey: Vec<u8>) -> Option<Record> {
198        // compare keys - result is such that Ordering::Less => return left side
199        let order = match self.order {
200            Order::Ascending => lkey.cmp(&rkey),
201            Order::Descending => rkey.cmp(&lkey),
202        };
203
204        // left must be translated and filtered before return, not so with right
205        match order {
206            Ordering::Less => self.take_left(),
207            Ordering::Equal => {
208                //
209                let _ = self.right.next();
210                self.take_left()
211            }
212            Ordering::Greater => self.right.next(),
213        }
214    }
215
216    /// take_left must only be called when we know self.left.next() will return Some
217    fn take_left(&mut self) -> Option<Record> {
218        let (lkey, lval) = self.left.next().unwrap();
219        match lval {
220            Delta::Set { value } => Some((lkey.clone(), value.clone())),
221            Delta::Delete {} => self.next(),
222        }
223    }
224}
225
226impl<'a, L, R> Iterator for MergeOverlay<'a, L, R>
227where
228    L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
229    R: Iterator<Item = Record>,
230{
231    type Item = Record;
232
233    fn next(&mut self) -> Option<Self::Item> {
234        let (left, right) = (self.left.peek(), self.right.peek());
235        match (left, right) {
236            (Some(litem), Some(ritem)) => {
237                let (lkey, _) = litem;
238                let (rkey, _) = ritem;
239
240                // we just use cloned keys to avoid double mutable references
241                // (we must release the return value from peek, before beginning to call next or other mut methods
242                let (l, r) = (lkey.to_vec(), rkey.to_vec());
243                self.pick_match(l, r)
244            }
245            (Some(_), None) => self.take_left(),
246            (None, Some(_)) => self.right.next(),
247            (None, None) => None,
248        }
249    }
250}
251
/// Converts optional raw-key limits into a half-open `[start, end)` range.
///
/// A `None` limit leaves that side unbounded.
fn range_bounds(start: Option<&[u8]>, end: Option<&[u8]>) -> impl RangeBounds<Vec<u8>> {
    let lower = match start {
        Some(key) => Bound::Included(key.to_vec()),
        None => Bound::Unbounded,
    };
    let upper = match end {
        Some(key) => Bound::Excluded(key.to_vec()),
        None => Bound::Unbounded,
    };
    (lower, upper)
}
258
// Unit tests for `StorageTransaction`: write buffering, committing the `RepLog`,
// discarding writes by not committing, and merged range iteration over the
// local overlay and the backing storage.
#[cfg(test)]
mod test {
    use super::*;
    use std::cell::RefCell;
    use std::ops::{Deref, DerefMut};

    use cosmwasm_std::MemoryStorage;

    // Writes are invisible in the base storage until the log is committed.
    #[test]
    fn wrap_storage() {
        let mut store = MemoryStorage::new();
        let mut wrap = StorageTransaction::new(&store);
        wrap.set(b"foo", b"bar");

        assert_eq!(None, store.get(b"foo"));
        wrap.prepare().commit(&mut store);
        assert_eq!(Some(b"bar".to_vec()), store.get(b"foo"));
    }

    // Same as `wrap_storage`, but the base storage lives behind a `RefCell`.
    #[test]
    fn wrap_ref_cell() {
        let store = RefCell::new(MemoryStorage::new());
        let ops = {
            let refer = store.borrow();
            let mut wrap = StorageTransaction::new(refer.deref());
            wrap.set(b"foo", b"bar");
            assert_eq!(None, store.borrow().get(b"foo"));
            wrap.prepare()
        };
        // The shared borrow `refer` has been dropped, so `borrow_mut` is allowed.
        ops.commit(store.borrow_mut().deref_mut());
        assert_eq!(Some(b"bar".to_vec()), store.borrow().get(b"foo"));
    }

    // Same as `wrap_storage`, but the base storage is a `Box<MemoryStorage>`.
    #[test]
    fn wrap_box_storage() {
        let mut store: Box<MemoryStorage> = Box::new(MemoryStorage::new());
        let mut wrap = StorageTransaction::new(store.as_ref());
        wrap.set(b"foo", b"bar");

        assert_eq!(None, store.get(b"foo"));
        wrap.prepare().commit(store.as_mut());
        assert_eq!(Some(b"bar".to_vec()), store.get(b"foo"));
    }

    // Same as `wrap_storage`, but the base storage is a `Box<dyn Storage>`.
    #[test]
    fn wrap_box_dyn_storage() {
        let mut store: Box<dyn Storage> = Box::new(MemoryStorage::new());
        let mut wrap = StorageTransaction::new(store.as_ref());
        wrap.set(b"foo", b"bar");

        assert_eq!(None, store.get(b"foo"));
        wrap.prepare().commit(store.as_mut());
        assert_eq!(Some(b"bar".to_vec()), store.get(b"foo"));
    }

    // Same as `wrap_storage`, but the base storage is a `RefCell<Box<dyn Storage>>`.
    #[test]
    fn wrap_ref_cell_dyn_storage() {
        let inner: Box<dyn Storage> = Box::new(MemoryStorage::new());
        let store = RefCell::new(inner);
        // Tricky, but it works:
        // 1. We cannot inline `StorageTransaction::new(store.borrow().as_ref())`,
        //    because the `Ref` must outlive the `StorageTransaction`.
        // 2. We cannot call `ops.commit()` until `refer` is out of scope, since
        //    `borrow_mut()` would conflict with the outstanding `borrow()`.
        // Careful scoping makes this work; this test serves as a reference.
        let ops = {
            let refer = store.borrow();
            let mut wrap = StorageTransaction::new(refer.as_ref());
            wrap.set(b"foo", b"bar");

            assert_eq!(None, store.borrow().get(b"foo"));
            wrap.prepare()
        };
        ops.commit(store.borrow_mut().as_mut());
        assert_eq!(Some(b"bar".to_vec()), store.borrow().get(b"foo"));
    }

    // Shared range-iteration checks. The storage must already contain exactly
    // one key, "foo" = "bar"; the suite then adds data (and a set-then-removed
    // key) and checks bounded, unbounded, empty and inverted ranges in both
    // orders. Requiring that pre-state lets the suite exercise a
    // `StorageTransaction` whose data sits in the overlay, in the base, or in both.
    fn iterator_test_suite<S: Storage>(store: &mut S) {
        // ensure we had previously set "foo" = "bar"
        assert_eq!(store.get(b"foo"), Some(b"bar".to_vec()));
        assert_eq!(store.range(None, None, Order::Ascending).count(), 1);

        // setup - add some data, and delete part of it as well
        store.set(b"ant", b"hill");
        store.set(b"ze", b"bra");

        // noise that should be ignored
        store.set(b"bye", b"bye");
        store.remove(b"bye");

        // unbounded
        {
            let iter = store.range(None, None, Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"ant".to_vec(), b"hill".to_vec()),
                    (b"foo".to_vec(), b"bar".to_vec()),
                    (b"ze".to_vec(), b"bra".to_vec()),
                ]
            );
        }

        // unbounded (descending)
        {
            let iter = store.range(None, None, Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"ze".to_vec(), b"bra".to_vec()),
                    (b"foo".to_vec(), b"bar".to_vec()),
                    (b"ant".to_vec(), b"hill".to_vec()),
                ]
            );
        }

        // bounded
        {
            let iter = store.range(Some(b"f"), Some(b"n"), Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![(b"foo".to_vec(), b"bar".to_vec())]);
        }

        // bounded (descending)
        {
            let iter = store.range(Some(b"air"), Some(b"loop"), Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"foo".to_vec(), b"bar".to_vec()),
                    (b"ant".to_vec(), b"hill".to_vec()),
                ]
            );
        }

        // bounded empty [a, a)
        {
            let iter = store.range(Some(b"foo"), Some(b"foo"), Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![]);
        }

        // bounded empty [a, a) (descending)
        {
            let iter = store.range(Some(b"foo"), Some(b"foo"), Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![]);
        }

        // bounded empty [a, b) with b < a (must not panic)
        {
            let iter = store.range(Some(b"z"), Some(b"a"), Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![]);
        }

        // bounded empty [a, b) with b < a (descending, must not panic)
        {
            let iter = store.range(Some(b"z"), Some(b"a"), Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![]);
        }

        // right unbounded
        {
            let iter = store.range(Some(b"f"), None, Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"foo".to_vec(), b"bar".to_vec()),
                    (b"ze".to_vec(), b"bra".to_vec()),
                ]
            );
        }

        // right unbounded (descending)
        {
            let iter = store.range(Some(b"f"), None, Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"ze".to_vec(), b"bra".to_vec()),
                    (b"foo".to_vec(), b"bar".to_vec()),
                ]
            );
        }

        // left unbounded
        {
            let iter = store.range(None, Some(b"f"), Order::Ascending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(elements, vec![(b"ant".to_vec(), b"hill".to_vec()),]);
        }

        // left unbounded (descending)
        {
            let iter = store.range(None, Some(b"no"), Order::Descending);
            let elements: Vec<Record> = iter.collect();
            assert_eq!(
                elements,
                vec![
                    (b"foo".to_vec(), b"bar".to_vec()),
                    (b"ant".to_vec(), b"hill".to_vec()),
                ]
            );
        }
    }

    // A key set and then removed inside the transaction stays absent, both in
    // the transaction and in the base after commit.
    #[test]
    fn delete_local() {
        let mut base = Box::new(MemoryStorage::new());
        let mut check = StorageTransaction::new(base.as_ref());
        check.set(b"foo", b"bar");
        check.set(b"food", b"bank");
        check.remove(b"foo");

        assert_eq!(check.get(b"foo"), None);
        assert_eq!(check.get(b"food"), Some(b"bank".to_vec()));

        // now commit to base and query there
        check.prepare().commit(base.as_mut());
        assert_eq!(base.get(b"foo"), None);
        assert_eq!(base.get(b"food"), Some(b"bank".to_vec()));
    }

    // Removing a key that exists only in the base hides it in the transaction
    // and deletes it from the base on commit.
    #[test]
    fn delete_from_base() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");
        let mut check = StorageTransaction::new(base.as_ref());
        check.set(b"food", b"bank");
        check.remove(b"foo");

        assert_eq!(check.get(b"foo"), None);
        assert_eq!(check.get(b"food"), Some(b"bank".to_vec()));

        // now commit to base and query there
        check.prepare().commit(base.as_mut());
        assert_eq!(base.get(b"foo"), None);
        assert_eq!(base.get(b"food"), Some(b"bank".to_vec()));
    }

    // Iterator suite with all data in the overlay (empty base).
    #[test]
    fn storage_transaction_iterator_empty_base() {
        let base = MemoryStorage::new();
        let mut check = StorageTransaction::new(&base);
        check.set(b"foo", b"bar");
        iterator_test_suite(&mut check);
    }

    // Iterator suite with the pre-existing key in the base.
    #[test]
    fn storage_transaction_iterator_with_base_data() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");
        let mut check = StorageTransaction::new(&base);
        iterator_test_suite(&mut check);
    }

    // Iterator suite where an overlay deletion hides an extra base key.
    #[test]
    fn storage_transaction_iterator_removed_items_from_base() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");
        base.set(b"food", b"bank");
        let mut check = StorageTransaction::new(base.as_ref());
        check.remove(b"food");
        iterator_test_suite(&mut check);
    }

    // Base data is readable through the transaction, and new writes reach the
    // base on commit.
    #[test]
    fn commit_writes_through() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");

        let mut check = StorageTransaction::new(base.as_ref());
        assert_eq!(check.get(b"foo"), Some(b"bar".to_vec()));
        check.set(b"subtx", b"works");
        check.prepare().commit(base.as_mut());

        assert_eq!(base.get(b"subtx"), Some(b"works".to_vec()));
    }

    // The base stays readable (and unchanged) while a transaction is open.
    #[test]
    fn storage_remains_readable() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");

        let mut stx1 = StorageTransaction::new(&base);

        assert_eq!(stx1.get(b"foo"), Some(b"bar".to_vec()));

        stx1.set(b"subtx", b"works");
        assert_eq!(stx1.get(b"subtx"), Some(b"works".to_vec()));

        // Can still read from base, txn is not yet committed
        assert_eq!(base.get(b"subtx"), None);

        stx1.prepare().commit(&mut base);
        assert_eq!(base.get(b"subtx"), Some(b"works".to_vec()));
    }

    // Never committing a transaction leaves the base untouched (implicit rollback).
    #[test]
    fn ignore_same_as_rollback() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");

        let mut check = StorageTransaction::new(&base);
        assert_eq!(check.get(b"foo"), Some(b"bar".to_vec()));
        check.set(b"subtx", b"works");

        assert_eq!(base.get(b"subtx"), None);
    }
}