1use once_cell::sync::Lazy;
7use pink::CacheOp;
8use sp_core::{crypto::AccountId32, ByteArray};
9use std::{
10 borrow::Cow,
11 collections::BTreeMap,
12 sync::atomic::{AtomicBool, Ordering},
13 time::Instant,
14};
15
16pub use pink::chain_extension::StorageQuotaExceeded;
17
/// Selects the backend used by `with_global_cache`.
///
/// `false` (the default) means one process-wide cache; `true` means each thread has its own.
/// Nothing in this module ever resets it back to `false`.
static TEST_MODE: AtomicBool = AtomicBool::new(false);

/// Routes every subsequent cache access in this process to a per-thread cache.
///
/// Used by the tests so that test threads running in parallel do not share cache contents.
pub(crate) fn enable_test_mode() {
    let ordering = Ordering::Relaxed;
    TEST_MODE.store(true, ordering);
}
23
24fn with_global_cache<T>(f: impl FnOnce(&mut LocalCache) -> T) -> T {
25 if TEST_MODE.load(Ordering::Relaxed) {
26 use std::cell::RefCell;
28 thread_local! {
29 pub static GLOBAL_CACHE: RefCell<LocalCache> = RefCell::new(LocalCache::new());
30 }
31 GLOBAL_CACHE.with(move |cache| f(&mut cache.borrow_mut()))
32 } else {
33 use std::sync::Mutex;
34 pub static GLOBAL_CACHE: Mutex<LocalCache> = Mutex::new(LocalCache::new());
35 f(&mut GLOBAL_CACHE.lock().unwrap())
36 }
37}
38
39struct Storage {
40 size: usize,
42 max_size: usize,
43 kvs: BTreeMap<Vec<u8>, StorageValue>,
44}
45
46impl Storage {
47 fn new(max_size: usize) -> Self {
48 Self {
49 size: 0,
50 max_size,
51 kvs: Default::default(),
52 }
53 }
54
55 fn fit_size(&mut self) {
59 if self.size <= self.max_size {
60 return;
61 }
62 let map = std::mem::take(&mut self.kvs);
63
64 let mut kvs: Vec<_> = map
65 .into_iter()
66 .map(|(k, v)| (v.expire_at, (k, v)))
67 .collect();
68 kvs.sort_by_key(|(expire, _)| *expire);
69 self.kvs = kvs
70 .into_iter()
71 .filter_map(|(_, (k, v))| {
72 if self.size <= self.max_size {
73 return Some((k, v));
74 }
75 self.size -= k.len() + v.value.len();
76 None
77 })
78 .collect();
79 }
80
81 fn clear_expired(&mut self, now: u64) {
82 self.kvs.retain(|k, v| {
83 if v.expire_at > now {
84 true
85 } else {
86 self.size -= v.value.len() + k.len();
87 false
88 }
89 });
90 }
91
92 fn remove(&mut self, key: &[u8]) -> Option<Vec<u8>> {
93 let v = self.kvs.remove(key).map(|v| v.value);
94 if let Some(v) = &v {
95 self.size -= v.len() + key.len();
96 }
97 v
98 }
99
100 fn set(
101 &mut self,
102 key: Cow<[u8]>,
103 value: Cow<[u8]>,
104 lifetime: u64,
105 ) -> Result<(), StorageQuotaExceeded> {
106 _ = self.remove(key.as_ref());
107 let data_len = key.len() + value.len();
108 let mut store_size = self.size + data_len;
109 if store_size > self.max_size {
110 self.clear_expired(now());
111 store_size = self.size + data_len;
112 if store_size > self.max_size {
113 return Err(StorageQuotaExceeded);
114 }
115 }
116 self.size = store_size;
117 self.kvs.insert(
118 key.into_owned(),
119 StorageValue {
120 expire_at: now().saturating_add(lifetime),
121 value: value.into_owned(),
122 },
123 );
124 Ok(())
125 }
126
127 #[cfg(test)]
128 fn get(&self, key: &[u8]) -> Option<&StorageValue> {
129 self.kvs.get(key)
130 }
131}
132
133struct StorageValue {
134 expire_at: u64,
136 value: Vec<u8>,
137}
138
139pub struct LocalCache {
140 gc_interval: u64,
142 sets_since_last_gc: u64,
144 default_value_lifetime: u64,
146 storages: BTreeMap<Vec<u8>, Storage>,
147}
148
149impl LocalCache {
150 const fn new() -> Self {
151 Self {
152 gc_interval: 1000,
153 sets_since_last_gc: 0,
154 default_value_lifetime: 3600 * 24 * 7, storages: BTreeMap::new(),
156 }
157 }
158}
159
160impl LocalCache {
161 fn maybe_clear_expired(&mut self) {
162 self.sets_since_last_gc += 1;
163 if self.sets_since_last_gc == self.gc_interval {
164 self.clear_expired();
165 }
166 }
167
168 fn clear_expired(&mut self) {
169 self.sets_since_last_gc = 0;
170 let now = now();
171 self.storages.values_mut().for_each(|storage| {
172 storage.clear_expired(now);
173 });
174 }
175
176 pub fn get(&self, id: &[u8], key: &[u8]) -> Option<Vec<u8>> {
177 let entry = self.storages.get(id)?.kvs.get(key)?;
178 if entry.expire_at <= now() {
179 None
180 } else {
181 Some(entry.value.to_owned())
182 }
183 }
184
185 #[cfg(test)]
186 fn get_include_expired(&self, id: &[u8], key: &[u8]) -> Option<Vec<u8>> {
187 Some(self.storages.get(id)?.kvs.get(key)?.value.to_owned())
188 }
189
190 pub fn set(
191 &mut self,
192 id: Cow<[u8]>,
193 key: Cow<[u8]>,
194 value: Cow<[u8]>,
195 ) -> Result<(), StorageQuotaExceeded> {
196 self.maybe_clear_expired();
197 self.storages
198 .get_mut(id.as_ref())
199 .ok_or(StorageQuotaExceeded)?
200 .set(key, value, self.default_value_lifetime)
201 }
202
203 pub fn set_expire(&mut self, id: Cow<[u8]>, key: Cow<[u8]>, expire: u64) {
204 self.maybe_clear_expired();
205 if expire == 0 {
206 let _ = self.remove(id.as_ref(), key.as_ref());
207 } else if let Some(v) = self
208 .storages
209 .get_mut(id.as_ref())
210 .and_then(|storage| storage.kvs.get_mut(key.as_ref()))
211 {
212 v.expire_at = now().saturating_add(expire)
213 }
214 }
215
216 pub fn remove(&mut self, id: &[u8], key: &[u8]) -> Option<Vec<u8>> {
217 self.maybe_clear_expired();
218 let store = self.storages.get_mut(id)?;
219 store.remove(key)
220 }
221
222 pub fn apply_quotas<'a>(&mut self, quotas: impl IntoIterator<Item = (&'a [u8], usize)>) {
223 for (contract, max_size) in quotas.into_iter() {
224 log::trace!(
225 "Applying cache quotas for {} max_size={max_size}",
226 hex_fmt::HexFmt(contract)
227 );
228 if max_size == 0 {
229 self.storages.remove(contract);
230 continue;
231 }
232 match self.storages.get_mut(contract) {
233 Some(store) => {
234 store.max_size = max_size;
235 store.fit_size();
236 }
237 None => {
238 self.storages
239 .insert(contract.to_vec(), Storage::new(max_size));
240 }
241 }
242 }
243 }
244}
245
246fn now() -> u64 {
247 static REF_TIME: Lazy<Instant> = Lazy::new(Instant::now);
248 REF_TIME.elapsed().as_secs()
249}
250
251pub fn apply_cache_op(contract: &AccountId32, op: CacheOp) {
252 match op {
253 CacheOp::Set { key, value } => {
254 let _ = set(contract.as_slice(), &key, &value);
255 }
256 CacheOp::SetExpiration { key, expiration } => {
257 set_expiration(contract.as_slice(), &key, expiration);
258 }
259 CacheOp::Remove { key } => {
260 let _ = remove(contract.as_slice(), &key);
261 }
262 }
263}
264
265pub fn set(contract: &[u8], key: &[u8], value: &[u8]) -> Result<(), StorageQuotaExceeded> {
266 with_global_cache(|cache| cache.set(contract.into(), key.into(), value.into()))
267}
268
269pub fn get(contract: &[u8], key: &[u8]) -> Option<Vec<u8>> {
270 with_global_cache(|cache| cache.get(contract, key))
271}
272
273pub fn set_expiration(contract: &[u8], key: &[u8], expiration: u64) {
274 with_global_cache(|cache| cache.set_expire(contract.into(), key.into(), expiration))
275}
276
277pub fn remove(contract: &[u8], key: &[u8]) -> Option<Vec<u8>> {
278 with_global_cache(|cache| cache.remove(contract, key))
279}
280
281pub fn apply_quotas<'a>(quotas: impl IntoIterator<Item = (&'a [u8], usize)>) {
282 with_global_cache(|cache| cache.apply_quotas(quotas))
283}
284
#[cfg(test)]
mod test {
    use super::*;

    /// Cache with a 2-call sweep interval and a 2-second default lifetime, so expiry can be
    /// observed with short sleeps.
    fn test_cache() -> LocalCache {
        LocalCache {
            gc_interval: 2,
            sets_since_last_gc: 0,
            default_value_lifetime: 2,
            storages: Default::default(),
        }
    }

    /// Borrows any byte-like value as a `Cow<[u8]>`.
    fn cow(s: &impl AsRef<[u8]>) -> Cow<[u8]> {
        Cow::Borrowed(s.as_ref())
    }

    /// Forces at least one sweep by issuing `gc_interval + 1` writes.
    ///
    /// The `_` contract has no quota, so these writes fail, but each call still counts
    /// towards the sweep interval.
    fn gc(cache: &mut LocalCache) {
        for _ in 0..cache.gc_interval + 1 {
            let _ = cache.set(cow(b"_"), cow(b"_"), cow(b"_"));
        }
    }

    fn sleep(secs: u64) {
        std::thread::sleep(std::time::Duration::from_secs(secs));
    }

    /// Bytes charged to contract `id`; panics if the contract has no store.
    fn get_size(cache: &LocalCache, id: &[u8]) -> usize {
        cache.storages.get(id).unwrap().size
    }

    #[test]
    fn default_expire_should_work() {
        let mut cache = test_cache();
        cache.apply_quotas([(&b"id"[..], 1000)]);
        let _ = cache.set(cow(b"id"), cow(b"foo"), cow(b"value"));
        assert_eq!(cache.get(b"id", b"foo"), Some(b"value".to_vec()));

        // Past the default lifetime the entry is hidden from `get` but still stored...
        sleep(cache.default_value_lifetime);
        assert_eq!(cache.get(b"id", b"foo"), None);
        assert!(cache.get_include_expired(b"id", b"foo").is_some());
        // ...until a sweep drops it and releases its bytes.
        gc(&mut cache);
        assert_eq!(cache.get_include_expired(b"id", b"foo"), None);
        assert_eq!(get_size(&cache, b"id"), 0);
    }

    #[test]
    fn set_expire_should_work() {
        let mut cache = test_cache();
        cache.apply_quotas([(&b"id"[..], 1000)]);

        let _ = cache.set(cow(b"id"), cow(b"foo"), cow(b"value"));
        assert_eq!(cache.get(b"id", b"foo"), Some(b"value".to_vec()));
        // Extend the lifetime 2 seconds beyond the default.
        cache.set_expire(cow(b"id"), cow(b"foo"), cache.default_value_lifetime + 2);

        // At the default expiry time the extended entry survives a sweep.
        sleep(cache.default_value_lifetime);
        gc(&mut cache);

        assert_eq!(cache.get(b"id", b"foo"), Some(b"value".to_vec()));

        // After the extended lifetime it is swept away.
        sleep(2);
        gc(&mut cache);

        assert_eq!(cache.get_include_expired(b"id", b"foo"), None);
    }

    #[test]
    fn size_limit_should_work() {
        let mut cache = test_cache();
        cache.apply_quotas([(&b"id"[..], 10)]);

        // 3 + 5 = 8 bytes fits the 10-byte quota; a second 8-byte entry does not.
        assert!(cache.set(cow(b"id"), cow(b"foo"), cow(b"value")).is_ok());
        assert!(cache.set(cow(b"id"), cow(b"bar"), cow(b"value")).is_err());
    }

    #[test]
    fn size_calc() {
        let mut cache = test_cache();
        cache.apply_quotas([(&b"id"[..], 100)]);

        // Size is key length plus value length; overwrites replace rather than add.
        assert!(cache.set(cow(b"id"), cow(b"foo"), cow(b"bar")).is_ok());
        assert_eq!(get_size(&cache, b"id"), 6);
        assert!(cache.set(cow(b"id"), cow(b"foo"), cow(b"foobar")).is_ok());
        assert_eq!(get_size(&cache, b"id"), 9);
        assert!(cache.set(cow(b"id"), cow(b"foo"), cow(b"foo")).is_ok());
        assert_eq!(get_size(&cache, b"id"), 6);
        assert!(cache.remove(b"id", b"foo").is_some());
        assert_eq!(get_size(&cache, b"id"), 0);
    }

    #[test]
    fn fit_size_works() {
        // Each "kN"/"v0" entry costs 4 bytes.
        let mut store = Storage::new(20);
        assert!(store.set(cow(b"k0"), cow(b"v0"), 1000).is_ok());
        assert_eq!(store.size, 4);
        assert!(store.set(cow(b"k1"), cow(b"v0"), 50).is_ok());
        assert_eq!(store.size, 8);
        assert!(store.set(cow(b"k2"), cow(b"v0"), 200).is_ok());
        assert_eq!(store.size, 12);
        assert!(store.set(cow(b"k3"), cow(b"v0"), 100).is_ok());
        assert_eq!(store.size, 16);
        assert!(store.set(cow(b"k4"), cow(b"v"), 100).is_ok());
        assert_eq!(store.size, 19);
        // The oversized overwrite fails, but the previous k4 has already been removed.
        assert!(store.set(cow(b"k4"), cow(b"vvvvv"), 100).is_err());
        assert_eq!(store.size, 16);

        assert!(store.get(b"k0").is_some());
        assert!(store.get(b"k1").is_some());
        assert!(store.get(b"k2").is_some());
        assert!(store.get(b"k3").is_some());

        // Shrinking to 10 bytes evicts the soonest-expiring entries (k1 at +50, then k3 at
        // +100) until usage fits.
        store.max_size = 10;
        store.fit_size();

        assert!(store.get(b"k0").is_some());
        assert!(store.get(b"k2").is_some());

        assert!(store.get(b"k1").is_none());
        assert!(store.get(b"k3").is_none());
        assert_eq!(store.size, 8);
    }

    #[test]
    fn cache_op_works() {
        use pink::CacheOp;

        // Route the free functions to a per-thread cache so this test is isolated.
        enable_test_mode();

        let key = b"hello";
        let value = b"world";
        let account = AccountId32::from([2u8; 32]);

        // A zero quota for a contract without a store is a no-op.
        apply_quotas([(account.as_slice(), 1024 * 1024 * 20), (&[1u8; 32], 0)]);

        apply_cache_op(
            &account,
            CacheOp::Set {
                key: key.to_vec(),
                value: value.to_vec(),
            },
        );
        let result = get(account.as_ref(), key);
        assert_eq!(result.unwrap(), value);

        apply_cache_op(&account, CacheOp::Remove { key: key.to_vec() });

        let result = get(account.as_slice(), key);
        assert!(result.is_none());

        // An expiration of 0 deletes the key.
        let result = set(account.as_slice(), key, value);
        assert!(result.is_ok());
        apply_cache_op(
            &account,
            CacheOp::SetExpiration {
                key: key.to_vec(),
                expiration: 0,
            },
        );
        let result = get(account.as_slice(), key);
        assert!(result.is_none());
    }
}