1#![cfg_attr(not(feature = "std"), no_std)]
2
3use keccak_const::Keccak256;
4
5use array_concat::concat_arrays;
6
7pub use bobcat_maths::U;
8
9use bobcat_maths::wrapping_sub;
10
// Host functions imported from the `vm_hooks` wasm module (presumably the
// Arbitrum Stylus VM — confirm). Only linked on `wasm32-unknown-unknown`-style
// targets; native builds use the `storage_host` emulation instead.
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
#[link(wasm_import_module = "vm_hooks")]
unsafe extern "C" {
    // Persistent storage: read a 32-byte slot, and write one through the host cache.
    fn storage_load_bytes32(slot: *const u8, dest: *mut u8);
    fn storage_cache_bytes32(slot: *const u8, word: *const u8);

    /// Host cache flush; exported so callers outside this crate can invoke it.
    pub fn storage_flush_cache(clear: bool);

    // Transient storage: read and write a 32-byte slot.
    fn transient_load_bytes32(slot: *const u8, dest: *mut u8);
    fn transient_store_bytes32(slot: *const u8, word: *const u8);

    // Hash `len` bytes starting at `data`, writing the 32-byte digest to `digest`.
    fn native_keccak256(data: *const u8, len: usize, digest: *mut u8);
}
21
/// Native, thread-local emulation of the storage host hooks.
///
/// Used when `std` is enabled without the `mutex` feature. Each thread owns its
/// own `STORAGE` and `TRANSIENT` maps, so writes made on one thread are not seen
/// by any other thread. Keys that were never written read back as `U::ZERO`.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    not(feature = "mutex"),
    feature = "std"
))]
pub mod storage_host {
    use super::*;

    use std::{cell::RefCell, collections::HashMap, ptr::copy_nonoverlapping, thread::LocalKey};

    type WordHashMap = HashMap<U, U>;

    /// Handle type that `thread_local!` produces for one of the word maps below.
    type Table = LocalKey<RefCell<WordHashMap>>;

    thread_local! {
        /// Emulated persistent storage for the current thread.
        pub static STORAGE: RefCell<WordHashMap> = RefCell::default();
        /// Emulated transient storage for the current thread.
        pub static TRANSIENT: RefCell<WordHashMap> = RefCell::default();
    }

    /// Removes every entry from this thread's emulated persistent storage.
    pub fn storage_clear() {
        STORAGE.with(|map| map.borrow_mut().clear())
    }

    /// Removes every entry from this thread's emulated transient storage.
    pub fn transient_clear() {
        TRANSIENT.with(|map| map.borrow_mut().clear())
    }

    /// Copies the 32 bytes at `src` into a new word.
    ///
    /// # Safety
    /// `src` must be valid for reads of 32 bytes.
    unsafe fn read_word(src: *const u8) -> U {
        let mut bytes = [0u8; 32];
        // SAFETY: the caller guarantees `src` is readable for 32 bytes; `bytes`
        // is a separate local buffer of exactly that size.
        unsafe { copy_nonoverlapping(src, bytes.as_mut_ptr(), bytes.len()) };
        U(bytes)
    }

    /// Copies the 32 bytes of `word` to `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 32 bytes.
    unsafe fn write_word(dst: *mut u8, word: U) {
        // SAFETY: `dst` is writable for 32 bytes (caller contract) and `word` is
        // a local copy, so the regions cannot overlap.
        unsafe { copy_nonoverlapping(word.as_ptr(), dst, 32) };
    }

    /// Looks up the word at `key` in `table` and writes the value (zero if absent) to `out`.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    unsafe fn load_into(table: &'static Table, key: *const u8, out: *mut u8) {
        let key = unsafe { read_word(key) };
        let found = table.with(|map| map.borrow().get(&key).copied().unwrap_or(U::ZERO));
        unsafe { write_word(out, found) };
    }

    /// Reads a key word and a value word from raw pointers and inserts them into `table`.
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    unsafe fn store_from(table: &'static Table, key: *const u8, value: *const u8) {
        let key = unsafe { read_word(key) };
        let value = unsafe { read_word(value) };
        table.with(|map| {
            map.borrow_mut().insert(key, value);
        });
    }

    pub(crate) unsafe fn storage_load_bytes32(key: *const u8, out: *mut u8) {
        unsafe { load_into(&STORAGE, key, out) }
    }

    pub(crate) unsafe fn storage_cache_bytes32(key: *const u8, value: *const u8) {
        unsafe { store_from(&STORAGE, key, value) }
    }

    pub(crate) unsafe fn transient_load_bytes32(key: *const u8, out: *mut u8) {
        unsafe { load_into(&TRANSIENT, key, out) }
    }

    pub(crate) unsafe fn transient_store_bytes32(key: *const u8, value: *const u8) {
        unsafe { store_from(&TRANSIENT, key, value) }
    }

    /// Emulates the host's storage-cache flush.
    ///
    /// No cache is modelled here, so `clear == false` does nothing. `clear == true`
    /// empties this thread's emulated persistent storage.
    ///
    /// NOTE(review): on the real host, `clear` is understood to drop the cache
    /// *after* persisting it, not to erase stored state. Confirm this divergence
    /// is intended as a test-reset hook.
    ///
    /// # Safety
    /// No preconditions; `unsafe` only to match the host hook's signature.
    pub unsafe fn storage_flush_cache(clear: bool) {
        if clear {
            storage_clear();
        }
    }
}
97
/// Native, process-wide emulation of the storage host hooks.
///
/// Used when the `mutex` feature is enabled. `STORAGE` and `TRANSIENT` are global
/// maps behind a `Mutex`, so every thread sees the same state. Keys that were
/// never written read back as `U::ZERO`. A poisoned lock causes a panic.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    feature = "mutex",
    feature = "std"
))]
pub mod storage_host {
    use super::*;

    use std::{
        collections::HashMap,
        ptr::copy_nonoverlapping,
        sync::{LazyLock, Mutex},
    };

    type WordHashMap = HashMap<U, U>;

    /// A lazily created, lock-protected word map.
    type Table = LazyLock<Mutex<WordHashMap>>;

    /// Process-wide emulated persistent storage.
    pub static STORAGE: LazyLock<Mutex<WordHashMap>> =
        LazyLock::new(|| Mutex::new(WordHashMap::new()));
    /// Process-wide emulated transient storage.
    pub static TRANSIENT: LazyLock<Mutex<WordHashMap>> =
        LazyLock::new(|| Mutex::new(WordHashMap::new()));

    /// Removes every entry from the emulated persistent storage.
    pub fn storage_clear() {
        STORAGE.lock().unwrap().clear()
    }

    /// Removes every entry from the emulated transient storage.
    pub fn transient_clear() {
        TRANSIENT.lock().unwrap().clear()
    }

    /// Copies the 32 bytes at `src` into a new word.
    ///
    /// # Safety
    /// `src` must be valid for reads of 32 bytes.
    unsafe fn read_word(src: *const u8) -> U {
        let mut bytes = [0u8; 32];
        // SAFETY: the caller guarantees `src` is readable for 32 bytes; `bytes`
        // is a separate local buffer of exactly that size.
        unsafe { copy_nonoverlapping(src, bytes.as_mut_ptr(), bytes.len()) };
        U(bytes)
    }

    /// Copies the 32 bytes of `word` to `dst`.
    ///
    /// # Safety
    /// `dst` must be valid for writes of 32 bytes.
    unsafe fn write_word(dst: *mut u8, word: U) {
        // SAFETY: `dst` is writable for 32 bytes (caller contract) and `word` is
        // a local copy, so the regions cannot overlap.
        unsafe { copy_nonoverlapping(word.as_ptr(), dst, 32) };
    }

    /// Looks up the word at `key` in `table` and writes the value (zero if absent) to `out`.
    ///
    /// # Safety
    /// `key` must be readable and `out` writable for 32 bytes.
    unsafe fn load_into(table: &Table, key: *const u8, out: *mut u8) {
        let key = unsafe { read_word(key) };
        let found = table.lock().unwrap().get(&key).copied().unwrap_or(U::ZERO);
        unsafe { write_word(out, found) };
    }

    /// Reads a key word and a value word from raw pointers and inserts them into `table`.
    ///
    /// # Safety
    /// `key` and `value` must each be readable for 32 bytes.
    unsafe fn store_from(table: &Table, key: *const u8, value: *const u8) {
        let key = unsafe { read_word(key) };
        let value = unsafe { read_word(value) };
        table.lock().unwrap().insert(key, value);
    }

    pub(crate) unsafe fn storage_load_bytes32(key: *const u8, out: *mut u8) {
        unsafe { load_into(&STORAGE, key, out) }
    }

    pub(crate) unsafe fn storage_cache_bytes32(key: *const u8, value: *const u8) {
        unsafe { store_from(&STORAGE, key, value) }
    }

    pub(crate) unsafe fn transient_load_bytes32(key: *const u8, out: *mut u8) {
        unsafe { load_into(&TRANSIENT, key, out) }
    }

    pub(crate) unsafe fn transient_store_bytes32(key: *const u8, value: *const u8) {
        unsafe { store_from(&TRANSIENT, key, value) }
    }

    /// Emulates the host's storage-cache flush.
    ///
    /// No cache is modelled here, so `clear == false` does nothing. `clear == true`
    /// empties the emulated persistent storage for every thread.
    ///
    /// NOTE(review): on the real host, `clear` is understood to drop the cache
    /// *after* persisting it, not to erase stored state. Confirm this divergence
    /// is intended as a test-reset hook.
    ///
    /// # Safety
    /// No preconditions; `unsafe` only to match the host hook's signature.
    pub unsafe fn storage_flush_cache(clear: bool) {
        if clear {
            storage_clear();
        }
    }
}
175
/// Inert stand-ins for the storage host hooks on native `no_std` builds.
///
/// Every hook is a no-op. Loads leave `out` untouched; the `*_load` helpers in
/// this crate zero their buffer first, so they observe `U::ZERO`. Stores are
/// discarded.
#[cfg(all(
    not(all(target_family = "wasm", target_os = "unknown")),
    not(feature = "std")
))]
mod storage_host {
    /// Ignores the flush request.
    pub unsafe fn storage_flush_cache(_clear: bool) {}

    /// Leaves `_out` unchanged.
    pub(crate) unsafe fn storage_load_bytes32(_key: *const u8, _out: *mut u8) {}

    /// Discards the write.
    pub(crate) unsafe fn storage_cache_bytes32(_key: *const u8, _value: *const u8) {}

    /// Leaves `_out` unchanged.
    pub(crate) unsafe fn transient_load_bytes32(_key: *const u8, _out: *mut u8) {}

    /// Discards the write.
    pub(crate) unsafe fn transient_store_bytes32(_key: *const u8, _value: *const u8) {}
}
191
192#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
193use storage_host::*;
194
195#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
196pub use storage_host::storage_flush_cache;
197
/// Expands to the load and compare-and-swap helpers for each listed storage kind.
///
/// For every `$prefix` (invoked in this file with `storage` and `transient`) it
/// defines `<prefix>_load`, `<prefix>_load_bool`, `<prefix>_exchange`,
/// `<prefix>_exchange_res`, `<prefix>_exchange_bool` and
/// `<prefix>_exchange_bool_res`. The invocation site must have
/// `<prefix>_load_bytes32` and `<prefix>_store` in scope.
macro_rules! storage_ops {
    ($($prefix:ident),* $(,)?) => {
        $(
            paste::paste! {
                /// Reads the 32-byte word stored under key `x`.
                pub fn [<$prefix _load>](x: &U) -> U {
                    let mut b = [0u8; 32];
                    // SAFETY: `x` is a live 32-byte word and `b` is a 32-byte buffer.
                    unsafe { [<$prefix _load_bytes32>](x.as_ptr(), b.as_mut_ptr()) }
                    U(b)
                }

                /// Reads the word under key `x`, converted to `bool` via `Into<bool>`.
                pub fn [<$prefix _load_bool>](x: &U) -> bool {
                    [<$prefix _load>](x).into()
                }

                /// Stores `new` under `k` only if the current word equals `exp`.
                ///
                /// Returns `true` when the store happened, and `false` (slot left
                /// untouched) otherwise.
                pub fn [<$prefix _exchange>](k: &U, exp: &U, new: &U) -> bool {
                    let t = [<$prefix _load>](k);
                    if &t != exp {
                        return false;
                    }
                    [<$prefix _store>](k, new);
                    true
                }

                /// Like `_exchange`, but on a mismatch returns the word actually found.
                pub fn [<$prefix _exchange_res>](k: &U, exp: &U, new: &U) -> Result<(), U> {
                    let t = [<$prefix _load>](k);
                    if &t != exp {
                        return Err(t);
                    }
                    [<$prefix _store>](k, new);
                    Ok(())
                }

                /// Flips the slot at `k` from `U::from(!new)` to `U::from(new)`.
                ///
                /// Returns `true` if the slot held `U::from(!new)` and was updated.
                pub fn [<$prefix _exchange_bool>](k: &U, new: bool) -> bool {
                    [<$prefix _exchange>](k, &U::from(!new), &U::from(new))
                }

                /// Like `_exchange_bool`, but reports a failure together with the value found.
                ///
                /// Returns `Ok(())` when the slot held `U::from(!new)` and now holds
                /// `U::from(new)`. Otherwise the slot is left untouched and the word
                /// that was found is returned, converted with `Into<bool>`.
                pub fn [<$prefix _exchange_bool_res>](k: &U, new: bool) -> Result<(), bool> {
                    // `_exchange_bool` returns only a success flag, not the previous
                    // value. Comparing that flag with `!new` inverts the outcome
                    // whenever `new == true`, so use `_exchange_res` to get the
                    // word that blocked the swap.
                    [<$prefix _exchange_res>](k, &U::from(!new), &U::from(new))
                        .map_err(|found| found.into())
                }
            }
        )*
    };
}
249
250pub fn storage_store(x: &U, y: &U) {
251 unsafe { storage_cache_bytes32(x.as_ptr(), y.as_ptr()) }
252}
253
254pub fn storage_store_bool(x: &U, y: bool) {
255 storage_store(x, &U::from(y))
256}
257
258pub fn transient_store(x: &U, y: &U) {
259 unsafe { transient_store_bytes32(x.as_ptr(), y.as_ptr()) }
260}
261
262pub fn transient_store_bool(x: &U, y: bool) {
263 transient_store(x, &U::from(y))
264}
265
266pub fn flush_cache() {
267 unsafe { storage_flush_cache(false) }
268}
269
270pub fn flush_guard<R, F: FnOnce() -> R>(f: F) -> R {
271 let r = f();
272 flush_cache();
273 r
274}
275
// Generate the `storage_*` and `transient_*` load / exchange helpers.
storage_ops!(storage, transient);
277
278macro_rules! storage_mutate_ops {
279 ($prefix:ident, $($op:expr),* $(,)?) => {
280 $(
281 paste::paste! {
282 pub fn [<$prefix _wrapping_ $op>](x: &U, new: &U) {
283 [<$prefix _store>](x, &bobcat_maths::[<wrapping_ $op>](&[<$prefix _load>](x), new))
284 }
285
286 pub fn [<$prefix _saturating_ $op>](x: &U, new: &U) {
287 [<$prefix _store>](x, &bobcat_maths::[<saturating_ $op>](&[<$prefix _load>](x), new))
288 }
289
290 pub fn [<$prefix _checked_ $op>](x: &U, new: &U) -> Option<()> {
291 let y = [<$prefix _load>](x);
292 let v = bobcat_maths::[<checked_ $op>](&y, new);
293 [<$prefix _store>](x, &v);
294 Some(())
295 }
296
297 pub fn [<$prefix _checked_ $op _res>](x: &U, new: &U) -> U {
298 let y = [<$prefix _load>](x);
299 let v = bobcat_maths::[<checked_ $op>](&y, new);
300 [<$prefix _store>](x, &v);
301 v
302 }
303 }
304 )*
305 };
306}
307
308storage_mutate_ops!(storage, add, sub, mul, div);
309storage_mutate_ops!(transient, add, sub, mul, div);
310
311#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
312pub fn slot_map_slot(k: &U, p: &U) -> U {
313 const_slot_map(k, p)
314}
315
316pub fn reentrancy_guard_entry(x: &U) {
317 assert!(x.len() <= 32, "too large");
318 assert!(transient_exchange_bool(x, true), "reentrancy alarm")
319}
320
321pub fn reentrancy_guard_exit(x: &U) {
322 assert!(x.len() <= 32, "too large");
323 transient_store(x, &U::ZERO);
324}
325
326pub fn reentrancy_guard<R>(k: &U, f: impl FnOnce() -> R) -> R {
327 reentrancy_guard_entry(k);
328 let v = f();
329 reentrancy_guard_exit(k);
330 v
331}
332
333pub fn reentrancy_guard_sel<R>(k: &[u8; 4], f: impl FnOnce() -> R) -> R {
334 reentrancy_guard::<R>(&U::from(k), f)
335}
336
337pub const fn const_slot_off_curve(b: &[u8]) -> U {
340 wrapping_sub(&const_keccak256(b), &U::ONE)
341}
342
343pub fn slot_off_curve(b: &[u8]) -> U {
344 bobcat_maths::checked_sub_opt(&keccak256(b), &U::ONE).unwrap()
347}
348
/// Keccak-256 of `b`, computed by the host's `native_keccak256` hook.
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
pub fn keccak256(b: &[u8]) -> U {
    let mut digest = [0u8; 32];
    let (data, len) = (b.as_ptr(), b.len());
    // SAFETY: `data`/`len` describe the borrowed slice `b`, and `digest` is a
    // 32-byte output buffer.
    unsafe { native_keccak256(data, len, digest.as_mut_ptr()) };
    U(digest)
}
357
358pub const fn const_keccak256(b: &[u8]) -> U {
359 U(Keccak256::new().update(b).finalize())
360}
361
362pub const fn const_keccak256_two(x: &[u8], y: &[u8]) -> U {
363 U(Keccak256::new().update(x).update(y).finalize())
364}
365
366pub const fn const_keccak256_two_off_curve(x: &[u8], y: &[u8]) -> U {
367 wrapping_sub(&const_keccak256_two(x, y), &U::ONE)
368}
369
370#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
371pub fn keccak256(b: &[u8]) -> U {
372 const_keccak256(b)
373}
374
375pub fn reentrancy_guard_const_keccak<R>(k: &[u8], f: impl FnOnce() -> R) -> R {
376 reentrancy_guard(&const_keccak256(k), f)
377}
378
379pub fn reentrancy_guard_keccak<R>(k: &[u8], f: impl FnOnce() -> R) -> R {
380 reentrancy_guard(&keccak256(k), f)
381}
382
383pub const fn const_slot_map(k: &U, p: &U) -> U {
386 let a: [u8; 32 * 2] = concat_arrays!(k.0, p.0);
387 const_keccak256(&a)
388}
389
390#[cfg(all(target_family = "wasm", target_os = "unknown"))]
391pub fn slot_map(k: &U, p: &U) -> U {
392 let b: [u8; 32 * 2] = concat_arrays!(k.0, p.0);
393 keccak256(&b)
394}
395
396#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
397pub fn slot_map(k: &U, p: &U) -> U {
398 const_slot_map(k, p)
399}
400
/// Pins `const_slot_off_curve` for the ed25519 count key against a known slot value.
#[test]
fn test_slot_edd25519_count() {
    let expected = U::from(
        const_hex::const_decode_to_array::<32>(
            b"709318ac04e7c3155ef66c30be7220b3243d7e2378fa4153b5f14ebd3ea771ab",
        )
        .unwrap(),
    );
    let actual = const_slot_off_curve(b"superposition.passport.ed25519_count");
    assert_eq!(expected, actual);
}
413
/// Property tests for the transient-storage reentrancy guard.
#[cfg(all(feature = "std", test))]
mod test {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        /// The slot is set while the guarded closure runs and zero afterwards.
        #[test]
        fn test_reentrancy_guard(x in any::<[u8; 8]>()) {
            reentrancy_guard(&U::from(x), || {
                assert!(transient_load(&U::from(x)).is_true());
            });
            assert!(transient_load(&U::from(x)).is_zero());
        }

        /// A second `false -> true` exchange on an already-set slot must fail.
        #[test]
        fn test_reentrancy_guard_bad(x in any::<[u8; 8]>()) {
            let x = U::from(x);
            transient_store(&x, &U::from(false));
            assert!(transient_exchange_bool(&x, true));
            assert!(!transient_exchange_bool(&x, true));
            assert!(transient_load(&x).is_some());
            // Release the slot again. With the `mutex` feature, transient storage
            // is one process-wide map shared by all test threads, so a leftover
            // `true` could trip another test's guard on the same key.
            transient_store(&x, &U::ZERO);
        }

        /// Same as `test_reentrancy_guard`, keyed by a 4-byte selector.
        #[test]
        fn test_reentrancy_guard_sel(x in any::<[u8; 4]>()) {
            reentrancy_guard_sel(&x, || {
                assert!(transient_load(&U::from(x)).is_true());
            });
            assert!(transient_load(&U::from(x)).is_zero());
        }
    }
}