neptune_common/
storage.rs

1use cosmwasm_std::{Addr, Deps, DepsMut, Order};
2use cw_storage_plus::{Bounder, KeyDeserialize, Map, PrimaryKey};
3use serde::{de::DeserializeOwned, Serialize};
4use std::fmt::Debug;
5
6use crate::{
7    error::{CommonError, CommonResult},
8    neptune_map::*,
9};
10
/// Key string `"params"`; presumably addresses the parameters entry in storage (confirm at use sites).
pub const PARAMS_KEY: &str = "params";
/// Key string `"state"`; presumably addresses the state entry in storage (confirm at use sites).
pub const STATE_KEY: &str = "state";
13
14/// Reads a map from storage is ascending order.
15pub fn read_map<'k, K, O, V>(
16    deps: Deps, start_after: Option<K>, limit: Option<u32>, map: Map<'k, K, V>,
17) -> Result<Vec<(O, V)>, CommonError>
18where
19    K: Bounder<'k> + PrimaryKey<'k> + KeyDeserialize<Output = O>,
20    O: 'static,
21    V: Serialize + DeserializeOwned,
22{
23    let start = start_after.map(|key| key.inclusive_bound().unwrap());
24    let vec = match limit {
25        Some(limit) => map
26            .range(deps.storage, start, None, Order::Ascending)
27            .take(limit as usize)
28            .collect::<Result<Vec<_>, _>>()?,
29        None => map.range(deps.storage, start, None, Order::Ascending).collect::<Result<Vec<_>, _>>()?,
30    };
31    Ok(vec)
32}
33
34/// Trait for types which act as a storage cache with cosmwasm storage plus.
35pub trait Cacher<'s, 'k, K, V>
36where
37    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
38    K: Clone + Debug + PartialEq + Eq,
39    V: Clone + Serialize + DeserializeOwned,
40{
41    fn must_get_mut(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&mut V>;
42    fn must_get(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&V>;
43}
44
/// A single cache slot: the cached value plus a flag recording whether it may have been
/// modified since it was loaded.
///
/// No trait bounds are placed on the type itself; the `impl` blocks that use it carry them.
struct CacheInner<V> {
    /// The cached value.
    value: V,
    /// Set once the value has been handed out mutably; such entries are written back on save.
    is_modified: bool,
}
53
54/// A cache which stores values in memory to avoid repeated disk reads/writes.
55pub struct Cache<'s, 'k, K, V>
56where
57    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
58    K: Clone + Debug + PartialEq + Eq,
59    V: Clone + Serialize + DeserializeOwned,
60{
61    map: NeptuneMap<K, CacheInner<V>>,
62    storage: Map<'s, &'k K, V>,
63}
64
65impl<'s, 'k, K, V> Cache<'s, 'k, K, V>
66where
67    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
68    K: Clone + Debug + PartialEq + Eq,
69    V: Clone + Serialize + DeserializeOwned,
70{
71    pub fn new(storage: Map<'s, &'k K, V>) -> Self {
72        Self { map: NeptuneMap::new(), storage }
73    }
74
75    pub fn save(&mut self, deps: DepsMut<'_>) -> CommonResult<()> {
76        for (key, inner) in self.map.iter() {
77            if inner.is_modified {
78                self.storage.save(deps.storage, key, &inner.value)?;
79            }
80        }
81        Ok(())
82    }
83}
84
85impl<'s, 'k, K, V> Cacher<'s, 'k, K, V> for Cache<'s, 'k, K, V>
86where
87    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
88    K: Clone + Debug + PartialEq + Eq,
89    V: Clone + Serialize + DeserializeOwned,
90{
91    fn must_get_mut(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&mut V> {
92        match self.map.iter().position(|x| &x.0 == key) {
93            Some(index) => {
94                let inner = &mut self.map.0[index].1;
95                inner.is_modified = true;
96                Ok(&mut inner.value)
97            }
98            None => {
99                let value = self.storage.load(deps.storage, key)?;
100                let inner = CacheInner { value, is_modified: true };
101                self.map.insert(key.clone(), inner);
102                Ok(&mut self.map.last_mut().unwrap().1.value)
103            }
104        }
105    }
106
107    fn must_get(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&V> {
108        match self.map.iter().position(|x| &x.0 == key) {
109            Some(index) => Ok(&self.map.0[index].1.value),
110            None => {
111                let value = self.storage.load(deps.storage, key)?;
112                let inner = CacheInner { value, is_modified: false };
113                self.map.insert(key.clone(), inner);
114                Ok(&self.map.last().unwrap().1.value)
115            }
116        }
117    }
118}
119
120/// A cache which stores values in memory to avoid repeated disk reads/writes.
121/// Values are accessed through a raw query to another contracts storage.
122pub struct QueryCache<'s, 'k, K, V>
123where
124    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
125    K: Clone + Debug + PartialEq + Eq,
126    V: Clone + Serialize + DeserializeOwned,
127{
128    map: NeptuneMap<K, V>,
129    storage: Map<'s, &'k K, V>,
130    addr: Addr,
131}
132
133impl<'s, 'k, K, V> QueryCache<'s, 'k, K, V>
134where
135    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
136    K: Clone + Debug + PartialEq + Eq,
137    V: Clone + Serialize + DeserializeOwned,
138{
139    pub fn new(storage: Map<'s, &'k K, V>, addr: Addr) -> Self {
140        Self { map: NeptuneMap::new(), storage, addr }
141    }
142}
143
144impl<'s, 'k, K, V> Cacher<'s, 'k, K, V> for QueryCache<'s, 'k, K, V>
145where
146    for<'a> &'a K: Debug + PartialEq + Eq + PrimaryKey<'a>,
147    K: Clone + Debug + PartialEq + Eq,
148    V: Clone + Serialize + DeserializeOwned,
149{
150    fn must_get_mut(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&mut V> {
151        match self.map.iter().position(|x| &x.0 == key) {
152            Some(index) => Ok(&mut self.map.0[index].1),
153            None => {
154                let value = self
155                    .storage
156                    .query(&deps.querier, self.addr.clone(), key)?
157                    .ok_or_else(|| CommonError::KeyNotFound(format!("{key:?}")))?;
158                self.map.insert(key.clone(), value);
159                Ok(&mut self.map.last_mut().unwrap().1)
160            }
161        }
162    }
163
164    fn must_get(&mut self, deps: Deps<'_>, key: &K) -> CommonResult<&V> {
165        match self.map.iter().position(|x| &x.0 == key) {
166            Some(index) => Ok(&self.map.0[index].1),
167            None => {
168                let value = self
169                    .storage
170                    .query(&deps.querier, self.addr.clone(), key)?
171                    .ok_or_else(|| CommonError::KeyNotFound(format!("{key:?}")))?;
172                self.map.insert(key.clone(), value);
173                Ok(&self.map.last().unwrap().1)
174            }
175        }
176    }
177}