bobcat_storage/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2
3use keccak_const::Keccak256;
4
5use array_concat::concat_arrays;
6
7pub use bobcat_maths::U;
8
9use bobcat_maths::wrapping_sub;
10
// Stylus VM host hooks, imported from the `vm_hooks` wasm module. They only
// exist when targeting `wasm*-unknown-unknown`; every other target routes the
// same names to the `storage_host` shims instead. Every key/value pointer below
// refers to a 32-byte word.
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
#[link(wasm_import_module = "vm_hooks")]
unsafe extern "C" {
    // Persistent storage: read the word under `key` into `dest`, or stage
    // `value` for `key` in the host's storage cache.
    fn storage_load_bytes32(key: *const u8, dest: *mut u8);
    fn storage_cache_bytes32(key: *const u8, value: *const u8);

    // Transient storage: read the word under `key` into `dest`, or write `value`.
    fn transient_load_bytes32(key: *const u8, dest: *mut u8);
    fn transient_store_bytes32(key: *const u8, value: *const u8);

    // Hashes `len` bytes starting at `bytes`, writing the 32-byte digest to `output`.
    fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8);

    /// Persists the host's storage cache; per the Stylus hostio docs, `clear`
    /// additionally drops the cache afterwards.
    pub fn storage_flush_cache(clear: bool);
}
21
/// Off-chain emulation of the Stylus storage hooks, backed by thread-local hash
/// maps: every thread (e.g. every libtest test thread) gets its own independent
/// persistent and transient storage. Selected on non-wasm `std` builds without
/// the `mutex` feature.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    not(feature = "mutex"),
    feature = "std"
))]
pub mod storage_host {
    use super::*;

    use std::{cell::RefCell, collections::HashMap, ptr::copy_nonoverlapping};

    type WordHashMap = HashMap<U, U>;

    thread_local! {
        /// Emulated persistent storage for the current thread.
        pub static STORAGE: RefCell<WordHashMap> = RefCell::default();
        /// Emulated transient storage for the current thread.
        pub static TRANSIENT: RefCell<WordHashMap> = RefCell::default();
    }

    /// Erases every persistent storage slot on the current thread.
    pub fn storage_clear() {
        STORAGE.with_borrow_mut(|s| s.clear())
    }

    /// Erases every transient storage slot on the current thread.
    pub fn transient_clear() {
        TRANSIENT.with_borrow_mut(|s| s.clear())
    }

    /// Copies the 32-byte word at `src` out of caller memory.
    ///
    /// # Safety
    /// `src` must be valid for reads of 32 bytes.
    unsafe fn read_word(src: *const u8) -> U {
        let mut r = [0u8; 32];
        // SAFETY: the caller guarantees `src` is readable for 32 bytes; `r` is a
        // fresh local buffer, so the two regions cannot overlap.
        unsafe { copy_nonoverlapping(src, r.as_mut_ptr(), 32) };
        U(r)
    }

    /// Copies `val` into the 32 bytes at `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 32 bytes.
    unsafe fn write_word(dst: *mut u8, val: U) {
        // SAFETY: the caller guarantees `dst` is writable for 32 bytes; `val` is a
        // by-value local, so it cannot overlap `dst`.
        unsafe { copy_nonoverlapping(val.as_ptr(), dst, 32) };
    }

    /// Reads `k` from `map`; slots that were never written read as zero.
    fn lookup(map: &WordHashMap, k: &U) -> U {
        map.get(k).copied().unwrap_or(U::ZERO)
    }

    /// Host stand-in for the `storage_load_bytes32` hook.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    pub(crate) unsafe fn storage_load_bytes32(key: *const u8, out: *mut u8) {
        let k = unsafe { read_word(key) };
        let value = STORAGE.with_borrow(|s| lookup(s, &k));
        unsafe { write_word(out, value) };
    }

    /// Host stand-in for the `storage_cache_bytes32` hook: writes go straight
    /// into [`STORAGE`].
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    pub(crate) unsafe fn storage_cache_bytes32(key: *const u8, value: *const u8) {
        let k = unsafe { read_word(key) };
        let v = unsafe { read_word(value) };
        STORAGE.with_borrow_mut(|s| {
            s.insert(k, v);
        });
    }

    /// Host stand-in for the `transient_load_bytes32` hook.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    pub(crate) unsafe fn transient_load_bytes32(key: *const u8, out: *mut u8) {
        let k = unsafe { read_word(key) };
        let value = TRANSIENT.with_borrow(|s| lookup(s, &k));
        unsafe { write_word(out, value) };
    }

    /// Host stand-in for the `transient_store_bytes32` hook.
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    pub(crate) unsafe fn transient_store_bytes32(key: *const u8, value: *const u8) {
        let k = unsafe { read_word(key) };
        let v = unsafe { read_word(value) };
        TRANSIENT.with_borrow_mut(|s| {
            s.insert(k, v);
        });
    }

    /// Host stand-in for the `storage_flush_cache` hook.
    ///
    /// Every cached write above already lands directly in [`STORAGE`], so there
    /// is no separate cache to persist or drop, and `clear` is intentionally
    /// ignored. Dropping the host's cache must never erase values that were
    /// already persisted; call [`storage_clear`] to reset storage explicitly.
    ///
    /// # Safety
    /// Always safe to call; `unsafe` only mirrors the wasm hook's signature.
    pub unsafe fn storage_flush_cache(_clear: bool) {}
}
97
/// Off-chain emulation of the Stylus storage hooks, backed by process-wide,
/// mutex-protected hash maps so that every thread shares the same persistent
/// and transient storage. Selected on non-wasm `std` builds with the `mutex`
/// feature.
///
/// All operations panic if the corresponding mutex has been poisoned.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    feature = "mutex",
    feature = "std"
))]
pub mod storage_host {
    use super::*;

    use std::{
        collections::HashMap,
        ptr::copy_nonoverlapping,
        sync::{LazyLock, Mutex},
    };

    type WordHashMap = HashMap<U, U>;

    /// Emulated persistent storage, shared by every thread.
    pub static STORAGE: LazyLock<Mutex<WordHashMap>> = LazyLock::new(Mutex::default);
    /// Emulated transient storage, shared by every thread.
    pub static TRANSIENT: LazyLock<Mutex<WordHashMap>> = LazyLock::new(Mutex::default);

    /// Erases every persistent storage slot.
    pub fn storage_clear() {
        STORAGE.lock().unwrap().clear()
    }

    /// Erases every transient storage slot.
    pub fn transient_clear() {
        TRANSIENT.lock().unwrap().clear()
    }

    /// Copies the 32-byte word at `src` out of caller memory.
    ///
    /// # Safety
    /// `src` must be valid for reads of 32 bytes.
    unsafe fn read_word(src: *const u8) -> U {
        let mut r = [0u8; 32];
        // SAFETY: the caller guarantees `src` is readable for 32 bytes; `r` is a
        // fresh local buffer, so the two regions cannot overlap.
        unsafe { copy_nonoverlapping(src, r.as_mut_ptr(), 32) };
        U(r)
    }

    /// Copies `val` into the 32 bytes at `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 32 bytes.
    unsafe fn write_word(dst: *mut u8, val: U) {
        // SAFETY: the caller guarantees `dst` is writable for 32 bytes; `val` is a
        // by-value local, so it cannot overlap `dst`.
        unsafe { copy_nonoverlapping(val.as_ptr(), dst, 32) };
    }

    /// Reads `k` from `map`; slots that were never written read as zero.
    fn lookup(map: &WordHashMap, k: &U) -> U {
        map.get(k).copied().unwrap_or(U::ZERO)
    }

    /// Host stand-in for the `storage_load_bytes32` hook.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    pub(crate) unsafe fn storage_load_bytes32(key: *const u8, out: *mut u8) {
        let k = unsafe { read_word(key) };
        // The guard is a temporary, so the lock is released at the end of this
        // statement, before any caller memory is written.
        let value = lookup(&STORAGE.lock().unwrap(), &k);
        unsafe { write_word(out, value) };
    }

    /// Host stand-in for the `storage_cache_bytes32` hook: writes go straight
    /// into [`STORAGE`].
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    pub(crate) unsafe fn storage_cache_bytes32(key: *const u8, value: *const u8) {
        let k = unsafe { read_word(key) };
        let v = unsafe { read_word(value) };
        STORAGE.lock().unwrap().insert(k, v);
    }

    /// Host stand-in for the `transient_load_bytes32` hook.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    pub(crate) unsafe fn transient_load_bytes32(key: *const u8, out: *mut u8) {
        let k = unsafe { read_word(key) };
        let value = lookup(&TRANSIENT.lock().unwrap(), &k);
        unsafe { write_word(out, value) };
    }

    /// Host stand-in for the `transient_store_bytes32` hook.
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    pub(crate) unsafe fn transient_store_bytes32(key: *const u8, value: *const u8) {
        let k = unsafe { read_word(key) };
        let v = unsafe { read_word(value) };
        TRANSIENT.lock().unwrap().insert(k, v);
    }

    /// Host stand-in for the `storage_flush_cache` hook.
    ///
    /// Every cached write above already lands directly in [`STORAGE`], so there
    /// is no separate cache to persist or drop, and `clear` is intentionally
    /// ignored. Dropping the host's cache must never erase values that were
    /// already persisted; call [`storage_clear`] to reset storage explicitly.
    ///
    /// # Safety
    /// Always safe to call; `unsafe` only mirrors the wasm hook's signature.
    pub unsafe fn storage_flush_cache(_clear: bool) {}
}
175
/// Inert stand-ins for the storage hooks on non-wasm targets built without
/// `std`: stores are discarded, and loads never touch the output buffer.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    not(feature = "std")
))]
mod storage_host {
    // Persistent storage: nothing is read or written.
    pub(crate) unsafe fn storage_load_bytes32(_key: *const u8, _out: *mut u8) {}
    pub(crate) unsafe fn storage_cache_bytes32(_key: *const u8, _value: *const u8) {}

    // Transient storage: nothing is read or written.
    pub(crate) unsafe fn transient_load_bytes32(_key: *const u8, _out: *mut u8) {}
    pub(crate) unsafe fn transient_store_bytes32(_key: *const u8, _value: *const u8) {}

    // There is no cache to flush.
    pub unsafe fn storage_flush_cache(_clear: bool) {}
}
191
192#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
193use storage_host::*;
194
195#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
196pub use storage_host::storage_flush_cache;
197
/// Generates the load / compare-and-exchange helpers for each storage family.
///
/// For every `$prefix` (e.g. `storage`, `transient`) this expects
/// `[<$prefix _load_bytes32>]` and `[<$prefix _store>]` to be in scope, and
/// defines `*_load`, `*_load_bool`, `*_exchange`, `*_exchange_res`,
/// `*_exchange_bool` and `*_exchange_bool_res`.
macro_rules! storage_ops {
    ($($prefix:ident),* $(,)?) => {
        $(
            paste::paste! {
                /// Loads the 32-byte word stored under `x`.
                pub fn [<$prefix _load>](x: &U) -> U {
                    let mut b = [0u8; 32];
                    // SAFETY: `x` is a live 32-byte word and `b` is a writable
                    // 32-byte buffer, which is what the hook reads and writes.
                    unsafe { [<$prefix _load_bytes32>](x.as_ptr(), b.as_mut_ptr()) }
                    U(b)
                }

                /// Loads the word under `x` and converts it to a `bool`.
                pub fn [<$prefix _load_bool>](x: &U) -> bool {
                    [<$prefix _load>](x).into()
                }

                /// Attempts to "exchange" a value: stores `new` under `k` only if
                /// the current value equals `exp`. Returns whether the store happened.
                pub fn [<$prefix _exchange>](k: &U, exp: &U, new: &U) -> bool {
                    [<$prefix _exchange_res>](k, exp, new).is_ok()
                }

                /// Like `*_exchange`, but on a mismatch returns `Err` carrying the
                /// value that was actually found under `k`.
                pub fn [<$prefix _exchange_res>](k: &U, exp: &U, new: &U) -> Result<(), U> {
                    let found = [<$prefix _load>](k);
                    if &found != exp {
                        return Err(found);
                    }
                    [<$prefix _store>](k, new);
                    Ok(())
                }

                /// Sets the boolean `new` under `k`, checking that the slot currently
                /// holds the inverse value. So passing `true` checks that `false` is
                /// set. Returns whether the store happened.
                pub fn [<$prefix _exchange_bool>](k: &U, new: bool) -> bool {
                    [<$prefix _exchange>](k, &U::from(!new), &U::from(new))
                }

                /// Like `*_exchange_bool`, but reports failure as `Err(found)`.
                /// `found` is the current slot value converted with the same
                /// `U -> bool` conversion as `*_load_bool`.
                pub fn [<$prefix _exchange_bool_res>](k: &U, new: bool) -> Result<(), bool> {
                    // Derive the result from `*_exchange_res` so that `Ok` always
                    // means "the store happened", whatever the value of `new`.
                    [<$prefix _exchange_res>](k, &U::from(!new), &U::from(new))
                        .map_err(|found| found.into())
                }
            }
        )*
    };
}
249
250pub fn storage_store(x: &U, y: &U) {
251    unsafe { storage_cache_bytes32(x.as_ptr(), y.as_ptr()) }
252}
253
254pub fn storage_store_bool(x: &U, y: bool) {
255    storage_store(x, &U::from(y))
256}
257
258pub fn transient_store(x: &U, y: &U) {
259    unsafe { transient_store_bytes32(x.as_ptr(), y.as_ptr()) }
260}
261
262pub fn transient_store_bool(x: &U, y: bool) {
263    transient_store(x, &U::from(y))
264}
265
266pub fn flush_cache() {
267    unsafe { storage_flush_cache(false) }
268}
269
270pub fn flush_guard<R, F: FnOnce() -> R>(f: F) -> R {
271    let r = f();
272    flush_cache();
273    r
274}
275
276storage_ops!(storage, transient);
277
/// Generates read-modify-write arithmetic helpers for one storage family.
///
/// For `$prefix` and each operation name `$op` (e.g. `add`), this defines
/// `*_wrapping_$op`, `*_saturating_$op`, `*_checked_$op` and
/// `*_checked_$op_res`. Each one loads the word under `x`, combines it with
/// `new` via the matching `bobcat_maths` function, and stores the result back.
macro_rules! storage_mutate_ops {
    // `$op` is only ever pasted into identifiers, so it is matched as an
    // `ident` rather than an arbitrary `expr`.
    ($prefix:ident, $($op:ident),* $(,)?) => {
        $(
            paste::paste! {
                /// Stores `wrapping_$op(current, new)` under `x`.
                pub fn [<$prefix _wrapping_ $op>](x: &U, new: &U) {
                    let current = [<$prefix _load>](x);
                    [<$prefix _store>](x, &bobcat_maths::[<wrapping_ $op>](&current, new))
                }

                /// Stores `saturating_$op(current, new)` under `x`.
                pub fn [<$prefix _saturating_ $op>](x: &U, new: &U) {
                    let current = [<$prefix _load>](x);
                    [<$prefix _store>](x, &bobcat_maths::[<saturating_ $op>](&current, new))
                }

                /// Stores `checked_$op(current, new)` under `x`.
                ///
                /// NOTE(review): this always returns `Some(())`. Overflow handling
                /// is whatever `bobcat_maths::checked_$op` does; it returns a plain
                /// word here, so it cannot signal failure through this `Option`.
                pub fn [<$prefix _checked_ $op>](x: &U, new: &U) -> Option<()> {
                    let _ = [<$prefix _checked_ $op _res>](x, new);
                    Some(())
                }

                /// Stores `checked_$op(current, new)` under `x` and returns the
                /// value that was stored.
                pub fn [<$prefix _checked_ $op _res>](x: &U, new: &U) -> U {
                    let current = [<$prefix _load>](x);
                    let updated = bobcat_maths::[<checked_ $op>](&current, new);
                    [<$prefix _store>](x, &updated);
                    updated
                }
            }
        )*
    };
}
307
308storage_mutate_ops!(storage, add, sub, mul, div);
309storage_mutate_ops!(transient, add, sub, mul, div);
310
311#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
312pub fn slot_map_slot(k: &U, p: &U) -> U {
313    const_slot_map(k, p)
314}
315
316pub fn reentrancy_guard_entry(x: &U) {
317    assert!(x.len() <= 32, "too large");
318    assert!(transient_exchange_bool(x, true), "reentrancy alarm")
319}
320
321pub fn reentrancy_guard_exit(x: &U) {
322    assert!(x.len() <= 32, "too large");
323    transient_store(x, &U::ZERO);
324}
325
326pub fn reentrancy_guard<R>(k: &U, f: impl FnOnce() -> R) -> R {
327    reentrancy_guard_entry(k);
328    let v = f();
329    reentrancy_guard_exit(k);
330    v
331}
332
333pub fn reentrancy_guard_sel<R>(k: &[u8; 4], f: impl FnOnce() -> R) -> R {
334    reentrancy_guard::<R>(&U::from(k), f)
335}
336
337/// Compute the slot for a slice, and take it off the curve. Useful for
338/// storage slot accesses (and more).
339pub const fn const_slot_off_curve(b: &[u8]) -> U {
340    wrapping_sub(&const_keccak256(b), &U::ONE)
341}
342
343pub fn slot_off_curve(b: &[u8]) -> U {
344    // This won't result in 0 from the keccak, so we can use checked_sub to
345    // use the code the host gives us for a slightly lower codesize profile.
346    bobcat_maths::checked_sub_opt(&keccak256(b), &U::ONE).unwrap()
347}
348
/// Keccak-256 of `b`, computed by the host's `native_keccak256` hook.
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
pub fn keccak256(b: &[u8]) -> U {
    let mut digest = [0u8; 32];
    let (ptr, len) = (b.as_ptr(), b.len());
    // SAFETY: `ptr`/`len` describe the live slice `b`, and `digest` is a writable
    // 32-byte buffer for the host to fill.
    unsafe { native_keccak256(ptr, len, digest.as_mut_ptr()) };
    U(digest)
}
357
358pub const fn const_keccak256(b: &[u8]) -> U {
359    U(Keccak256::new().update(b).finalize())
360}
361
362pub const fn const_keccak256_two(x: &[u8], y: &[u8]) -> U {
363    U(Keccak256::new().update(x).update(y).finalize())
364}
365
366pub const fn const_keccak256_two_off_curve(x: &[u8], y: &[u8]) -> U {
367    wrapping_sub(&const_keccak256_two(x, y), &U::ONE)
368}
369
370#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
371pub fn keccak256(b: &[u8]) -> U {
372    const_keccak256(b)
373}
374
375pub fn reentrancy_guard_const_keccak<R>(k: &[u8], f: impl FnOnce() -> R) -> R {
376    reentrancy_guard(&const_keccak256(k), f)
377}
378
379pub fn reentrancy_guard_keccak<R>(k: &[u8], f: impl FnOnce() -> R) -> R {
380    reentrancy_guard(&keccak256(k), f)
381}
382
383/// Find the storage map slot using keccak_const. Don't do this during
384/// your runtime code, unless you want to pay the codesize price.
385pub const fn const_slot_map(k: &U, p: &U) -> U {
386    let a: [u8; 32 * 2] = concat_arrays!(k.0, p.0);
387    const_keccak256(&a)
388}
389
/// Storage map slot keccak256(k ++ p), hashed by the host. Presumably `k` is
/// the mapping key and `p` the mapping's base slot.
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
pub fn slot_map(k: &U, p: &U) -> U {
    // The host hook needs one contiguous buffer, so build the 64-byte preimage.
    let preimage: [u8; 64] = concat_arrays!(k.0, p.0);
    keccak256(&preimage)
}
395
396#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
397pub fn slot_map(k: &U, p: &U) -> U {
398    const_slot_map(k, p)
399}
400
// Pins `const_slot_off_curve` against a known-good slot value.
#[test]
fn test_slot_edd25519_count() {
    let expected = const_hex::const_decode_to_array::<32>(
        b"709318ac04e7c3155ef66c30be7220b3243d7e2378fa4153b5f14ebd3ea771ab",
    )
    .unwrap();
    let actual = const_slot_off_curve(b"superposition.passport.ed25519_count");
    assert_eq!(U::from(expected), actual);
}
413
#[cfg(all(feature = "std", test))]
mod test {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        /// The guard flag is set while the closure runs and cleared afterwards.
        #[test]
        fn test_reentrancy_guard(x in any::<[u8; 8]>()) {
            let k = U::from(x);
            reentrancy_guard(&k, || {
                assert!(transient_load(&k).is_true());
            });
            assert!(transient_load(&k).is_zero());
        }

        /// A second exchange to `true` must fail while the flag is held.
        #[test]
        fn test_reentrancy_guard_bad(x in any::<[u8; 8]>()) {
            let x = U::from(x);
            transient_store(&x, &U::from(false));
            assert!(transient_exchange_bool(&x, true));
            assert!(!transient_exchange_bool(&x, true));
            assert!(transient_load(&x).is_some());
            // Release the flag so it cannot leak into other tests that share
            // transient storage (e.g. under the `mutex` feature).
            transient_store(&x, &U::ZERO);
        }

        /// Same as `test_reentrancy_guard`, keyed by a 4-byte selector.
        #[test]
        fn test_reentrancy_guard_sel(x in any::<[u8; 4]>()) {
            reentrancy_guard_sel(&x, || {
                assert!(transient_load(&U::from(x)).is_true());
            });
            assert!(transient_load(&U::from(x)).is_zero());
        }
    }
}