Skip to main content

awaken_contract/state/
slot.rs

1use serde::{Deserialize, Serialize, de::DeserializeOwned};
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4use typedmap::clone::SyncCloneBounds;
5use typedmap::{TypedMap, TypedMapKey};
6
7use crate::error::StateError;
8use crate::model::{JsonValue, decode_json, encode_json};
9
/// Zero-sized marker type used to namespace the keys of the backing store of [`StateMap`].
struct ExtensionMarker;

/// Zero-sized, type-level key that stands in for the state key type `K`.
///
/// The `PhantomData` holds `fn() -> K` rather than `K`, so the key never owns a
/// `K` and is `Send + Sync` whatever `K` is.
struct TypedKey<K>(PhantomData<fn() -> K>);

impl<K> TypedKey<K> {
    /// Builds the single (unit-like) value of this key type.
    const fn new() -> Self {
        Self(PhantomData)
    }
}

// Copy/Clone/Eq/PartialEq are implemented by hand: `#[derive]` would add
// `K: Copy`/`K: Clone`/`K: PartialEq` bounds that state-key types need not satisfy.
impl<K> Copy for TypedKey<K> {}

impl<K> Clone for TypedKey<K> {
    fn clone(&self) -> Self {
        // Canonical `Clone` for a `Copy` type (clippy: `non_canonical_clone_impl`).
        *self
    }
}

impl<K> Eq for TypedKey<K> {}

impl<K> PartialEq for TypedKey<K> {
    /// All values of `TypedKey<K>` for one `K` denote the same key, so this is always `true`.
    fn eq(&self, _: &Self) -> bool {
        true
    }
}

impl<K: 'static> Hash for TypedKey<K> {
    /// Feeds the `TypeId` of `K` to the hasher, so every instance of one key type hashes alike.
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        let id = std::any::TypeId::of::<K>();
        id.hash(hasher);
    }
}
41
42impl<K> TypedMapKey<ExtensionMarker> for TypedKey<K>
43where
44    K: StateKey,
45{
46    type Value = K::Value;
47}
48
49pub struct StateMap {
50    values: TypedMap<ExtensionMarker, SyncCloneBounds, SyncCloneBounds>,
51}
52
53impl Default for StateMap {
54    fn default() -> Self {
55        Self {
56            values: TypedMap::new_with_bounds(),
57        }
58    }
59}
60
61impl Clone for StateMap {
62    fn clone(&self) -> Self {
63        let mut values = TypedMap::new_with_bounds();
64        for entry in self.values.iter() {
65            values.insert_key_value(entry.to_owned());
66        }
67        Self { values }
68    }
69}
70
71impl StateMap {
72    pub fn contains<K: StateKey>(&self) -> bool {
73        self.values.contains_key(&TypedKey::<K>::new())
74    }
75
76    pub fn get<K: StateKey>(&self) -> Option<&K::Value> {
77        self.values.get(&TypedKey::<K>::new())
78    }
79
80    pub fn get_mut<K: StateKey>(&mut self) -> Option<&mut K::Value> {
81        self.values.get_mut(&TypedKey::<K>::new())
82    }
83
84    pub fn insert<K: StateKey>(&mut self, value: K::Value) {
85        self.values.insert(TypedKey::<K>::new(), value);
86    }
87
88    pub fn remove<K: StateKey>(&mut self) -> Option<K::Value> {
89        self.values.remove(&TypedKey::<K>::new())
90    }
91
92    pub fn get_or_insert_default<K: StateKey>(&mut self) -> &mut K::Value {
93        if !self.contains::<K>() {
94            self.insert::<K>(K::Value::default());
95        }
96
97        self.get_mut::<K>()
98            .expect("value should exist after insertion")
99    }
100}
101
102/// Lifetime scope for a state key.
103///
104/// Controls when the key's value is cleared relative to run boundaries.
105#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
106pub enum KeyScope {
107    /// Cleared at run start (default, current behavior).
108    #[default]
109    Run,
110    /// Persists across runs on the same thread.
111    Thread,
112}
113
114/// Parallel merge strategy for a state key.
115///
116/// Determines how concurrent updates to the same key are handled
117/// when merging `MutationBatch`es from parallel tool execution.
118#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
119pub enum MergeStrategy {
120    /// Concurrent updates to this key are mutually exclusive.
121    /// Parallel batches that both touch this key cannot be merged.
122    #[default]
123    Exclusive,
124    /// Updates to this key are commutative — they can be applied
125    /// in any order and produce the same result. Parallel batches
126    /// that both touch this key will have their ops concatenated.
127    Commutative,
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
131pub struct StateKeyOptions {
132    pub persistent: bool,
133    pub retain_on_uninstall: bool,
134    pub scope: KeyScope,
135}
136
137impl Default for StateKeyOptions {
138    fn default() -> Self {
139        Self {
140            persistent: true,
141            retain_on_uninstall: false,
142            scope: KeyScope::Run,
143        }
144    }
145}
146
147pub trait StateKey: 'static + Send + Sync {
148    const KEY: &'static str;
149
150    /// Parallel merge strategy. Default: `Exclusive` (conflict on concurrent writes).
151    const MERGE: MergeStrategy = MergeStrategy::Exclusive;
152
153    /// Lifetime scope. Default: `Run` (cleared at run start).
154    const SCOPE: KeyScope = KeyScope::Run;
155
156    type Value: Clone + Default + Serialize + DeserializeOwned + Send + Sync + 'static;
157    type Update: Send + 'static;
158
159    fn apply(value: &mut Self::Value, update: Self::Update);
160
161    fn encode(value: &Self::Value) -> Result<JsonValue, StateError> {
162        encode_json(Self::KEY, value)
163    }
164
165    fn decode(value: JsonValue) -> Result<Self::Value, StateError> {
166        decode_json(Self::KEY, value)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Additive counter key used to exercise `StateMap`.
    struct Counter;

    impl StateKey for Counter {
        const KEY: &'static str = "counter";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// Second key that shares `Counter`'s value type. It checks that slots are
    /// keyed by the key type, not by the value type.
    struct Total;

    impl StateKey for Total {
        const KEY: &'static str = "total";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    #[test]
    fn state_map_can_store_and_update_typed_values() {
        let mut slots = StateMap::default();
        Counter::apply(slots.get_or_insert_default::<Counter>(), 2);
        Counter::apply(slots.get_or_insert_default::<Counter>(), 3);

        assert_eq!(slots.get::<Counter>().copied(), Some(5));
        assert_eq!(slots.remove::<Counter>(), Some(5));
        assert!(!slots.contains::<Counter>());
    }

    #[test]
    fn keys_with_same_value_type_do_not_collide() {
        let mut slots = StateMap::default();
        slots.insert::<Counter>(1);
        slots.insert::<Total>(2);

        assert_eq!(slots.get::<Counter>().copied(), Some(1));
        assert_eq!(slots.get::<Total>().copied(), Some(2));
        assert_eq!(slots.remove::<Counter>(), Some(1));
        assert!(slots.contains::<Total>());
    }

    #[test]
    fn insert_replaces_and_get_mut_edits_in_place() {
        let mut slots = StateMap::default();
        assert_eq!(slots.get::<Counter>(), None);
        assert_eq!(slots.remove::<Counter>(), None);

        slots.insert::<Counter>(1);
        slots.insert::<Counter>(7);
        *slots
            .get_mut::<Counter>()
            .expect("counter was inserted") += 1;

        assert_eq!(slots.get::<Counter>().copied(), Some(8));
    }

    #[test]
    fn clone_is_independent_of_original() {
        let mut original = StateMap::default();
        original.insert::<Counter>(1);

        let mut copy = original.clone();
        // An existing value must not be reset by `get_or_insert_default`.
        Counter::apply(copy.get_or_insert_default::<Counter>(), 10);
        copy.insert::<Total>(5);

        assert_eq!(original.get::<Counter>().copied(), Some(1));
        assert!(!original.contains::<Total>());
        assert_eq!(copy.get::<Counter>().copied(), Some(11));
        assert_eq!(copy.get::<Total>().copied(), Some(5));
    }

    #[test]
    fn defaults_match_documented_values() {
        assert_eq!(KeyScope::default(), KeyScope::Run);
        assert_eq!(MergeStrategy::default(), MergeStrategy::Exclusive);
        assert_eq!(Counter::SCOPE, KeyScope::Run);
        assert_eq!(Counter::MERGE, MergeStrategy::Exclusive);
        assert_eq!(
            StateKeyOptions::default(),
            StateKeyOptions {
                persistent: true,
                retain_on_uninstall: false,
                scope: KeyScope::Run,
            }
        );
    }
}