1use crate::error::AnyResult;
2use cosmwasm_std::Storage;
3use cosmwasm_std::{Order, Record};
4use std::cmp::Ordering;
5use std::collections::BTreeMap;
6use std::iter;
7use std::iter::Peekable;
8use std::ops::{Bound, RangeBounds};
9
/// A borrowed `(key, value)` entry of a `BTreeMap` keyed by byte vectors, i.e. the
/// item type yielded by `BTreeMap::iter`/`BTreeMap::range`. The value type `V`
/// defaults to `Vec<u8>`.
type BTreeMapPairRef<'a, V = Vec<u8>> = (&'a Vec<u8>, &'a V);
13
14pub fn transactional<F, T>(base: &mut dyn Storage, action: F) -> AnyResult<T>
16where
17 F: FnOnce(&mut dyn Storage, &dyn Storage) -> AnyResult<T>,
18{
19 let mut cache = StorageTransaction::new(base);
20 let res = action(&mut cache, base)?;
21 cache.prepare().commit(base);
22 Ok(res)
23}
24
25pub struct StorageTransaction<'a> {
27 storage: &'a dyn Storage,
29 local_state: BTreeMap<Vec<u8>, Delta>,
31 rep_log: RepLog,
33}
34
35impl<'a> StorageTransaction<'a> {
36 pub fn new(storage: &'a dyn Storage) -> Self {
38 StorageTransaction {
39 storage,
40 local_state: BTreeMap::new(),
41 rep_log: RepLog::new(),
42 }
43 }
44
45 pub fn prepare(self) -> RepLog {
47 self.rep_log
48 }
49}
50
51impl<'a> Storage for StorageTransaction<'a> {
52 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
53 match self.local_state.get(key) {
54 Some(val) => match val {
55 Delta::Set { value } => Some(value.clone()),
56 Delta::Delete {} => None,
57 },
58 None => self.storage.get(key),
59 }
60 }
61
62 fn range<'b>(
66 &'b self,
67 start: Option<&[u8]>,
68 end: Option<&[u8]>,
69 order: Order,
70 ) -> Box<dyn Iterator<Item = Record> + 'b> {
71 let bounds = range_bounds(start, end);
72
73 let local: Box<dyn Iterator<Item = BTreeMapPairRef<Delta>>> =
76 match (bounds.start_bound(), bounds.end_bound()) {
77 (Bound::Included(start), Bound::Excluded(end)) if start > end => {
78 Box::new(iter::empty())
79 }
80 _ => {
81 let local_raw = self.local_state.range(bounds);
82 match order {
83 Order::Ascending => Box::new(local_raw),
84 Order::Descending => Box::new(local_raw.rev()),
85 }
86 }
87 };
88
89 let base = self.storage.range(start, end, order);
90 let merged = MergeOverlay::new(local, base, order);
91 Box::new(merged)
92 }
93
94 fn set(&mut self, key: &[u8], value: &[u8]) {
95 let op = Op::Set {
96 key: key.to_vec(),
97 value: value.to_vec(),
98 };
99 self.local_state.insert(key.to_vec(), op.to_delta());
100 self.rep_log.append(op);
101 }
102
103 fn remove(&mut self, key: &[u8]) {
104 let op = Op::Delete { key: key.to_vec() };
105 self.local_state.insert(key.to_vec(), op.to_delta());
106 self.rep_log.append(op);
107 }
108}
109
110pub struct RepLog {
111 ops_log: Vec<Op>,
113}
114
115impl RepLog {
116 fn new() -> Self {
117 RepLog { ops_log: vec![] }
118 }
119
120 fn append(&mut self, op: Op) {
122 self.ops_log.push(op);
123 }
124
125 pub fn commit(self, storage: &mut dyn Storage) {
127 for op in self.ops_log {
128 op.apply(storage);
129 }
130 }
131}
132
133enum Op {
136 Set {
138 key: Vec<u8>,
139 value: Vec<u8>,
140 },
141 Delete {
142 key: Vec<u8>,
143 },
144}
145
146impl Op {
147 pub fn apply(&self, storage: &mut dyn Storage) {
149 match self {
150 Op::Set { key, value } => storage.set(key, value),
151 Op::Delete { key } => storage.remove(key),
152 }
153 }
154
155 pub fn to_delta(&self) -> Delta {
157 match self {
158 Op::Set { value, .. } => Delta::Set {
159 value: value.clone(),
160 },
161 Op::Delete { .. } => Delta::Delete {},
162 }
163 }
164}
165
/// The pending state of a single key inside a [`StorageTransaction`].
enum Delta {
    /// The key was removed: reads return `None` and any base-storage entry is hidden.
    Delete {},
    /// The key was written: reads return a copy of `value`.
    Set { value: Vec<u8> },
}
173
174struct MergeOverlay<'a, L, R>
175where
176 L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
177 R: Iterator<Item = Record>,
178{
179 left: Peekable<L>,
180 right: Peekable<R>,
181 order: Order,
182}
183
184impl<'a, L, R> MergeOverlay<'a, L, R>
185where
186 L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
187 R: Iterator<Item = Record>,
188{
189 fn new(left: L, right: R, order: Order) -> Self {
190 MergeOverlay {
191 left: left.peekable(),
192 right: right.peekable(),
193 order,
194 }
195 }
196
197 fn pick_match(&mut self, lkey: Vec<u8>, rkey: Vec<u8>) -> Option<Record> {
198 let order = match self.order {
200 Order::Ascending => lkey.cmp(&rkey),
201 Order::Descending => rkey.cmp(&lkey),
202 };
203
204 match order {
206 Ordering::Less => self.take_left(),
207 Ordering::Equal => {
208 let _ = self.right.next();
210 self.take_left()
211 }
212 Ordering::Greater => self.right.next(),
213 }
214 }
215
216 fn take_left(&mut self) -> Option<Record> {
218 let (lkey, lval) = self.left.next().unwrap();
219 match lval {
220 Delta::Set { value } => Some((lkey.clone(), value.clone())),
221 Delta::Delete {} => self.next(),
222 }
223 }
224}
225
226impl<'a, L, R> Iterator for MergeOverlay<'a, L, R>
227where
228 L: Iterator<Item = BTreeMapPairRef<'a, Delta>>,
229 R: Iterator<Item = Record>,
230{
231 type Item = Record;
232
233 fn next(&mut self) -> Option<Self::Item> {
234 let (left, right) = (self.left.peek(), self.right.peek());
235 match (left, right) {
236 (Some(litem), Some(ritem)) => {
237 let (lkey, _) = litem;
238 let (rkey, _) = ritem;
239
240 let (l, r) = (lkey.to_vec(), rkey.to_vec());
243 self.pick_match(l, r)
244 }
245 (Some(_), None) => self.take_left(),
246 (None, Some(_)) => self.right.next(),
247 (None, None) => None,
248 }
249 }
250}
251
/// Converts optional `start`/`end` keys into half-open `[start, end)` bounds.
///
/// A missing `start` or `end` is unbounded on that side. `start` is inclusive
/// and `end` is exclusive.
fn range_bounds(start: Option<&[u8]>, end: Option<&[u8]>) -> impl RangeBounds<Vec<u8>> {
    let lower = match start {
        Some(key) => Bound::Included(key.to_vec()),
        None => Bound::Unbounded,
    };
    let upper = match end {
        Some(key) => Bound::Excluded(key.to_vec()),
        None => Bound::Unbounded,
    };
    (lower, upper)
}
258
#[cfg(test)]
mod test {
    use super::*;
    use std::cell::RefCell;
    use std::ops::{Deref, DerefMut};

    use cosmwasm_std::MemoryStorage;

    /// Builds an owned `(key, value)` record from byte-string slices.
    fn record(key: &[u8], value: &[u8]) -> Record {
        (key.to_vec(), value.to_vec())
    }

    /// Drains `store.range(start, end, order)` into a vector.
    fn collect_range<S: Storage>(
        store: &S,
        start: Option<&[u8]>,
        end: Option<&[u8]>,
        order: Order,
    ) -> Vec<Record> {
        store.range(start, end, order).collect()
    }

    #[test]
    fn wrap_storage() {
        let mut base = MemoryStorage::new();
        let mut tx = StorageTransaction::new(&base);
        tx.set(b"foo", b"bar");

        // Nothing reaches the base storage until the log is committed.
        assert_eq!(base.get(b"foo"), None);
        let log = tx.prepare();
        log.commit(&mut base);
        assert_eq!(base.get(b"foo"), Some(b"bar".to_vec()));
    }

    #[test]
    fn wrap_ref_cell() {
        let cell = RefCell::new(MemoryStorage::new());
        let log = {
            let guard = cell.borrow();
            let mut tx = StorageTransaction::new(guard.deref());
            tx.set(b"foo", b"bar");
            assert_eq!(cell.borrow().get(b"foo"), None);
            tx.prepare()
        };
        log.commit(cell.borrow_mut().deref_mut());
        assert_eq!(cell.borrow().get(b"foo"), Some(b"bar".to_vec()));
    }

    #[test]
    fn wrap_box_storage() {
        let mut base: Box<MemoryStorage> = Box::new(MemoryStorage::new());
        let mut tx = StorageTransaction::new(base.as_ref());
        tx.set(b"foo", b"bar");

        assert_eq!(base.get(b"foo"), None);
        let log = tx.prepare();
        log.commit(base.as_mut());
        assert_eq!(base.get(b"foo"), Some(b"bar".to_vec()));
    }

    #[test]
    fn wrap_box_dyn_storage() {
        let mut base: Box<dyn Storage> = Box::new(MemoryStorage::new());
        let mut tx = StorageTransaction::new(base.as_ref());
        tx.set(b"foo", b"bar");

        assert_eq!(base.get(b"foo"), None);
        let log = tx.prepare();
        log.commit(base.as_mut());
        assert_eq!(base.get(b"foo"), Some(b"bar".to_vec()));
    }

    #[test]
    fn wrap_ref_cell_dyn_storage() {
        let boxed: Box<dyn Storage> = Box::new(MemoryStorage::new());
        let cell = RefCell::new(boxed);
        let log = {
            // The shared borrow must end before the mutable borrow used by commit.
            let guard = cell.borrow();
            let mut tx = StorageTransaction::new(guard.as_ref());
            tx.set(b"foo", b"bar");

            assert_eq!(cell.borrow().get(b"foo"), None);
            tx.prepare()
        };
        log.commit(cell.borrow_mut().as_mut());
        assert_eq!(cell.borrow().get(b"foo"), Some(b"bar".to_vec()));
    }

    /// Shared range/iteration checks.
    ///
    /// `store` must start with exactly one visible entry, `foo -> bar`. The
    /// suite adds `ant` and `ze`, and writes then removes `bye`.
    fn iterator_test_suite<S: Storage>(store: &mut S) {
        assert_eq!(store.get(b"foo"), Some(b"bar".to_vec()));
        assert_eq!(store.range(None, None, Order::Ascending).count(), 1);

        store.set(b"ant", b"hill");
        store.set(b"ze", b"bra");

        // A key written and then removed must not show up in any range.
        store.set(b"bye", b"bye");
        store.remove(b"bye");

        let ant = record(b"ant", b"hill");
        let foo = record(b"foo", b"bar");
        let ze = record(b"ze", b"bra");

        // Full range, both directions.
        assert_eq!(
            collect_range(store, None, None, Order::Ascending),
            vec![ant.clone(), foo.clone(), ze.clone()]
        );
        assert_eq!(
            collect_range(store, None, None, Order::Descending),
            vec![ze.clone(), foo.clone(), ant.clone()]
        );

        // Bounded on both sides.
        assert_eq!(
            collect_range(store, Some(b"f"), Some(b"n"), Order::Ascending),
            vec![foo.clone()]
        );
        assert_eq!(
            collect_range(store, Some(b"air"), Some(b"loop"), Order::Descending),
            vec![foo.clone(), ant.clone()]
        );

        // Empty (start == end) and inverted (start > end) ranges yield nothing.
        assert!(collect_range(store, Some(b"foo"), Some(b"foo"), Order::Ascending).is_empty());
        assert!(collect_range(store, Some(b"foo"), Some(b"foo"), Order::Descending).is_empty());
        assert!(collect_range(store, Some(b"z"), Some(b"a"), Order::Ascending).is_empty());
        assert!(collect_range(store, Some(b"z"), Some(b"a"), Order::Descending).is_empty());

        // Open-ended upper bound.
        assert_eq!(
            collect_range(store, Some(b"f"), None, Order::Ascending),
            vec![foo.clone(), ze.clone()]
        );
        assert_eq!(
            collect_range(store, Some(b"f"), None, Order::Descending),
            vec![ze, foo.clone()]
        );

        // Open-ended lower bound.
        assert_eq!(
            collect_range(store, None, Some(b"f"), Order::Ascending),
            vec![ant.clone()]
        );
        assert_eq!(
            collect_range(store, None, Some(b"no"), Order::Descending),
            vec![foo, ant]
        );
    }

    #[test]
    fn delete_local() {
        let mut base = Box::new(MemoryStorage::new());
        let mut tx = StorageTransaction::new(base.as_ref());
        tx.set(b"foo", b"bar");
        tx.set(b"food", b"bank");
        tx.remove(b"foo");

        assert_eq!(tx.get(b"foo"), None);
        assert_eq!(tx.get(b"food"), Some(b"bank".to_vec()));

        let log = tx.prepare();
        log.commit(base.as_mut());

        assert_eq!(base.get(b"foo"), None);
        assert_eq!(base.get(b"food"), Some(b"bank".to_vec()));
    }

    #[test]
    fn delete_from_base() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");
        let mut tx = StorageTransaction::new(base.as_ref());
        tx.set(b"food", b"bank");
        tx.remove(b"foo");

        assert_eq!(tx.get(b"foo"), None);
        assert_eq!(tx.get(b"food"), Some(b"bank".to_vec()));

        let log = tx.prepare();
        log.commit(base.as_mut());

        assert_eq!(base.get(b"foo"), None);
        assert_eq!(base.get(b"food"), Some(b"bank".to_vec()));
    }

    #[test]
    fn storage_transaction_iterator_empty_base() {
        let base = MemoryStorage::new();
        let mut tx = StorageTransaction::new(&base);
        tx.set(b"foo", b"bar");
        iterator_test_suite(&mut tx);
    }

    #[test]
    fn storage_transaction_iterator_with_base_data() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");
        let mut tx = StorageTransaction::new(&base);
        iterator_test_suite(&mut tx);
    }

    #[test]
    fn storage_transaction_iterator_removed_items_from_base() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");
        base.set(b"food", b"bank");
        let mut tx = StorageTransaction::new(base.as_ref());
        tx.remove(b"food");
        iterator_test_suite(&mut tx);
    }

    #[test]
    fn commit_writes_through() {
        let mut base = Box::new(MemoryStorage::new());
        base.set(b"foo", b"bar");

        let mut tx = StorageTransaction::new(base.as_ref());
        assert_eq!(tx.get(b"foo"), Some(b"bar".to_vec()));
        tx.set(b"subtx", b"works");
        let log = tx.prepare();
        log.commit(base.as_mut());

        assert_eq!(base.get(b"subtx"), Some(b"works".to_vec()));
    }

    #[test]
    fn storage_remains_readable() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");

        let mut tx = StorageTransaction::new(&base);
        assert_eq!(tx.get(b"foo"), Some(b"bar".to_vec()));

        tx.set(b"subtx", b"works");
        assert_eq!(tx.get(b"subtx"), Some(b"works".to_vec()));

        // The base storage stays readable (and unchanged) while the transaction is open.
        assert_eq!(base.get(b"subtx"), None);

        let log = tx.prepare();
        log.commit(&mut base);
        assert_eq!(base.get(b"subtx"), Some(b"works".to_vec()));
    }

    #[test]
    fn ignore_same_as_rollback() {
        let mut base = MemoryStorage::new();
        base.set(b"foo", b"bar");

        let mut tx = StorageTransaction::new(&base);
        assert_eq!(tx.get(b"foo"), Some(b"bar".to_vec()));
        tx.set(b"subtx", b"works");

        // The transaction is never committed, so its write never reaches `base`.
        assert_eq!(base.get(b"subtx"), None);
    }
}