awaken_contract/state/
slot.rs1use serde::{Deserialize, Serialize, de::DeserializeOwned};
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4use typedmap::clone::SyncCloneBounds;
5use typedmap::{TypedMap, TypedMapKey};
6
7use crate::error::StateError;
8use crate::model::{JsonValue, decode_json, encode_json};
9
/// Marker type that namespaces [`TypedKey`] inside the state map's `TypedMap`.
///
/// It only ever appears as a type parameter.
struct ExtensionMarker;

/// Zero-sized, type-level key that stands in for the [`StateKey`] type `K`.
///
/// `PhantomData<fn() -> K>` is used instead of `PhantomData<K>` so the key is
/// `Send + Sync` whatever `K` is, and does not claim to own a `K`.
struct TypedKey<K>(PhantomData<fn() -> K>);

impl<K> TypedKey<K> {
    /// Builds the key for `K`; every such value is interchangeable.
    const fn new() -> Self {
        TypedKey(PhantomData)
    }
}

// The trait impls below are hand-written rather than derived: `#[derive]`
// would add `K: Clone` / `K: PartialEq` / `K: Hash` bounds, which a
// `StateKey` marker type is not required to satisfy (the trait only demands
// `'static + Send + Sync`).

impl<K> Copy for TypedKey<K> {}

impl<K> Clone for TypedKey<K> {
    fn clone(&self) -> Self {
        // Canonical `Clone` for a `Copy` type.
        *self
    }
}

impl<K> PartialEq for TypedKey<K> {
    /// Two keys for the same `K` are always equal: the key carries no data,
    /// so its whole identity is the type parameter.
    fn eq(&self, _: &Self) -> bool {
        true
    }
}

impl<K> Eq for TypedKey<K> {}

impl<K: 'static> Hash for TypedKey<K> {
    /// Feeds exactly `TypeId::of::<K>()` into the hasher (`TypeId::of`
    /// is why `K: 'static` is required here).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let type_id = std::any::TypeId::of::<K>();
        Hash::hash(&type_id, state);
    }
}
41
42impl<K> TypedMapKey<ExtensionMarker> for TypedKey<K>
43where
44 K: StateKey,
45{
46 type Value = K::Value;
47}
48
49pub struct StateMap {
50 values: TypedMap<ExtensionMarker, SyncCloneBounds, SyncCloneBounds>,
51}
52
53impl Default for StateMap {
54 fn default() -> Self {
55 Self {
56 values: TypedMap::new_with_bounds(),
57 }
58 }
59}
60
61impl Clone for StateMap {
62 fn clone(&self) -> Self {
63 let mut values = TypedMap::new_with_bounds();
64 for entry in self.values.iter() {
65 values.insert_key_value(entry.to_owned());
66 }
67 Self { values }
68 }
69}
70
71impl StateMap {
72 pub fn contains<K: StateKey>(&self) -> bool {
73 self.values.contains_key(&TypedKey::<K>::new())
74 }
75
76 pub fn get<K: StateKey>(&self) -> Option<&K::Value> {
77 self.values.get(&TypedKey::<K>::new())
78 }
79
80 pub fn get_mut<K: StateKey>(&mut self) -> Option<&mut K::Value> {
81 self.values.get_mut(&TypedKey::<K>::new())
82 }
83
84 pub fn insert<K: StateKey>(&mut self, value: K::Value) {
85 self.values.insert(TypedKey::<K>::new(), value);
86 }
87
88 pub fn remove<K: StateKey>(&mut self) -> Option<K::Value> {
89 self.values.remove(&TypedKey::<K>::new())
90 }
91
92 pub fn get_or_insert_default<K: StateKey>(&mut self) -> &mut K::Value {
93 if !self.contains::<K>() {
94 self.insert::<K>(K::Value::default());
95 }
96
97 self.get_mut::<K>()
98 .expect("value should exist after insertion")
99 }
100}
101
102#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
106pub enum KeyScope {
107 #[default]
109 Run,
110 Thread,
112}
113
114#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
119pub enum MergeStrategy {
120 #[default]
123 Exclusive,
124 Commutative,
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
131pub struct StateKeyOptions {
132 pub persistent: bool,
133 pub retain_on_uninstall: bool,
134 pub scope: KeyScope,
135}
136
137impl Default for StateKeyOptions {
138 fn default() -> Self {
139 Self {
140 persistent: true,
141 retain_on_uninstall: false,
142 scope: KeyScope::Run,
143 }
144 }
145}
146
147pub trait StateKey: 'static + Send + Sync {
148 const KEY: &'static str;
149
150 const MERGE: MergeStrategy = MergeStrategy::Exclusive;
152
153 const SCOPE: KeyScope = KeyScope::Run;
155
156 type Value: Clone + Default + Serialize + DeserializeOwned + Send + Sync + 'static;
157 type Update: Send + 'static;
158
159 fn apply(value: &mut Self::Value, update: Self::Update);
160
161 fn encode(value: &Self::Value) -> Result<JsonValue, StateError> {
162 encode_json(Self::KEY, value)
163 }
164
165 fn decode(value: JsonValue) -> Result<Self::Value, StateError> {
166 decode_json(Self::KEY, value)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Test key whose updates add to a running `usize` total.
    struct Counter;

    impl StateKey for Counter {
        const KEY: &'static str = "counter";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// Second key with the same `Value` type as [`Counter`]. It checks that
    /// entries are told apart by key type, not by value type (`TypedKey`
    /// equality is purely type-level).
    struct Gauge;

    impl StateKey for Gauge {
        const KEY: &'static str = "gauge";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value = update;
        }
    }

    #[test]
    fn state_map_can_store_and_update_typed_values() {
        let mut slots = StateMap::default();
        Counter::apply(slots.get_or_insert_default::<Counter>(), 2);
        Counter::apply(slots.get_or_insert_default::<Counter>(), 3);

        assert_eq!(slots.get::<Counter>().copied(), Some(5));
        assert_eq!(slots.remove::<Counter>(), Some(5));
        assert!(!slots.contains::<Counter>());
    }

    #[test]
    fn missing_keys_report_absence() {
        let mut slots = StateMap::default();

        assert!(!slots.contains::<Counter>());
        assert_eq!(slots.get::<Counter>(), None);
        assert_eq!(slots.get_mut::<Counter>(), None);
        assert_eq!(slots.remove::<Counter>(), None);
    }

    #[test]
    fn insert_replaces_existing_value() {
        let mut slots = StateMap::default();
        slots.insert::<Counter>(1);
        slots.insert::<Counter>(7);

        assert_eq!(slots.get::<Counter>().copied(), Some(7));
    }

    #[test]
    fn keys_with_same_value_type_are_independent() {
        let mut slots = StateMap::default();
        slots.insert::<Counter>(1);
        Gauge::apply(slots.get_or_insert_default::<Gauge>(), 9);

        assert_eq!(slots.get::<Counter>().copied(), Some(1));
        assert_eq!(slots.get::<Gauge>().copied(), Some(9));

        assert_eq!(slots.remove::<Gauge>(), Some(9));
        assert_eq!(slots.get::<Counter>().copied(), Some(1));
    }

    #[test]
    fn clone_produces_an_independent_copy() {
        let mut original = StateMap::default();
        original.insert::<Counter>(1);

        let mut copy = original.clone();
        Counter::apply(copy.get_or_insert_default::<Counter>(), 4);
        copy.insert::<Gauge>(3);

        // Mutating the clone must not leak back into the original.
        assert_eq!(original.get::<Counter>().copied(), Some(1));
        assert!(!original.contains::<Gauge>());
        assert_eq!(copy.get::<Counter>().copied(), Some(5));
        assert_eq!(copy.get::<Gauge>().copied(), Some(3));
    }
}