Skip to main content

awaken_contract/state/
mutation.rs

1use std::collections::HashSet;
2use std::fmt;
3use std::marker::PhantomData;
4
5use crate::StateError;
6
7use super::{MergeStrategy, Snapshot, StateKey, StateMap};
8
9/// A type-erased state mutation operation.
10pub trait MutationOp: Send {
11    /// Apply this mutation to the given snapshot.
12    fn apply(self: Box<Self>, state: &mut Snapshot);
13}
14
15/// Target for a typed state mutation.
16pub trait MutationTarget {
17    type Update: Send + 'static;
18
19    fn apply(state: &mut Snapshot, update: Self::Update);
20}
21
22impl<K> MutationTarget for K
23where
24    K: StateKey,
25{
26    type Update = K::Update;
27
28    fn apply(state: &mut Snapshot, update: Self::Update) {
29        let value = std::sync::Arc::make_mut(&mut state.ext).get_or_insert_default::<K>();
30        K::apply(value, update);
31    }
32}
33
34struct KeyPatch<S: MutationTarget> {
35    update: Option<S::Update>,
36    _marker: PhantomData<S>,
37}
38
39impl<S> KeyPatch<S>
40where
41    S: MutationTarget,
42{
43    fn new(update: S::Update) -> Self {
44        Self {
45            update: Some(update),
46            _marker: PhantomData,
47        }
48    }
49}
50
51impl<S> MutationOp for KeyPatch<S>
52where
53    S: MutationTarget + Send,
54{
55    fn apply(mut self: Box<Self>, state: &mut Snapshot) {
56        let update = self.update.take().expect("key patch already applied");
57        S::apply(state, update);
58    }
59}
60
61struct ClearKeyMutation {
62    clear: fn(&mut StateMap),
63}
64
65impl ClearKeyMutation {
66    fn new(clear: fn(&mut StateMap)) -> Self {
67        Self { clear }
68    }
69}
70
71impl MutationOp for ClearKeyMutation {
72    fn apply(self: Box<Self>, state: &mut Snapshot) {
73        (self.clear)(std::sync::Arc::make_mut(&mut state.ext));
74    }
75}
76
77/// A batch of state mutation operations.
78///
79/// Collects typed key updates and applies them atomically to a [`Snapshot`].
80pub struct MutationBatch {
81    pub base_revision: Option<u64>,
82    pub ops: Vec<Box<dyn MutationOp>>,
83    pub touched_keys: Vec<String>,
84}
85
86impl fmt::Debug for MutationBatch {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("MutationBatch")
89            .field("base_revision", &self.base_revision)
90            .field("ops_len", &self.ops.len())
91            .field("touched_keys", &self.touched_keys)
92            .finish()
93    }
94}
95
96impl MutationBatch {
97    pub fn new() -> Self {
98        Self {
99            base_revision: None,
100            ops: Vec::new(),
101            touched_keys: Vec::new(),
102        }
103    }
104
105    pub fn with_base_revision(mut self, revision: u64) -> Self {
106        self.base_revision = Some(revision);
107        self
108    }
109
110    pub fn base_revision(&self) -> Option<u64> {
111        self.base_revision
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.ops.is_empty()
116    }
117
118    pub fn update<K>(&mut self, update: K::Update)
119    where
120        K: StateKey,
121    {
122        self.ops.push(Box::new(KeyPatch::<K>::new(update)));
123        self.touched_keys.push(K::KEY.to_string());
124    }
125
126    pub fn clear_extension_with(&mut self, key: impl Into<String>, clear: fn(&mut StateMap)) {
127        self.ops.push(Box::new(ClearKeyMutation::new(clear)));
128        self.touched_keys.push(key.into());
129    }
130
131    pub fn extend(&mut self, mut other: Self) -> Result<(), StateError> {
132        self.base_revision = match (self.base_revision, other.base_revision) {
133            (Some(left), Some(right)) if left != right => {
134                return Err(StateError::MutationBaseRevisionMismatch { left, right });
135            }
136            (Some(left), _) => Some(left),
137            (None, Some(right)) => Some(right),
138            (None, None) => None,
139        };
140
141        self.ops.append(&mut other.ops);
142        self.touched_keys.append(&mut other.touched_keys);
143        Ok(())
144    }
145
146    pub fn op_len(&self) -> usize {
147        self.ops.len()
148    }
149
150    /// Merge two batches produced by parallel execution.
151    ///
152    /// - Disjoint keys: always merged.
153    /// - Overlapping keys with `Commutative` strategy: merged (order irrelevant).
154    /// - Overlapping keys with `Exclusive` strategy: returns `ParallelMergeConflict`.
155    pub fn merge_parallel<F>(mut self, mut other: Self, strategy: F) -> Result<Self, StateError>
156    where
157        F: Fn(&str) -> MergeStrategy,
158    {
159        // Reconcile base revisions
160        self.base_revision = match (self.base_revision, other.base_revision) {
161            (Some(left), Some(right)) if left != right => {
162                return Err(StateError::MutationBaseRevisionMismatch { left, right });
163            }
164            (Some(left), _) => Some(left),
165            (None, Some(right)) => Some(right),
166            (None, None) => None,
167        };
168
169        // Check overlapping keys
170        let self_keys: HashSet<&str> = self.touched_keys.iter().map(|s| s.as_str()).collect();
171        for key in &other.touched_keys {
172            if self_keys.contains(key.as_str()) && strategy(key) == MergeStrategy::Exclusive {
173                return Err(StateError::ParallelMergeConflict { key: key.clone() });
174            }
175        }
176
177        // Merge ops and keys
178        self.ops.append(&mut other.ops);
179        self.touched_keys.append(&mut other.touched_keys);
180        Ok(self)
181    }
182}
183
184impl Default for MutationBatch {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    struct Counter;

    impl StateKey for Counter {
        const KEY: &'static str = "counter";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// A second key, used to exercise merges over disjoint key sets.
    struct Other;

    impl StateKey for Other {
        const KEY: &'static str = "other";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// Builds an empty snapshot at revision 0.
    fn empty_snapshot() -> Snapshot {
        Snapshot {
            revision: 0,
            ext: std::sync::Arc::new(StateMap::default()),
        }
    }

    /// Drains `batch` and applies its ops, in order, to `snapshot`.
    fn apply_all(batch: &mut MutationBatch, snapshot: &mut Snapshot) {
        for op in batch.ops.drain(..) {
            op.apply(snapshot);
        }
    }

    /// Clear function used with `clear_extension_with`: resets the counter slot to 0.
    fn reset_counter(map: &mut StateMap) {
        *map.get_or_insert_default::<Counter>() = 0;
    }

    #[test]
    fn mutation_batch_merges_matching_base_revisions() {
        let mut left = MutationBatch::new().with_base_revision(3);
        left.update::<Counter>(1);

        let mut right = MutationBatch::new().with_base_revision(3);
        right.update::<Counter>(2);

        left.extend(right)
            .expect("matching base revisions should merge");
        assert_eq!(left.base_revision(), Some(3));
        assert_eq!(left.op_len(), 2);
    }

    #[test]
    fn mutation_batch_rejects_mismatched_base_revisions() {
        let mut left = MutationBatch::new().with_base_revision(1);
        left.update::<Counter>(1);
        let right = MutationBatch::new().with_base_revision(2);

        let err = left.extend(right).expect_err("mismatch should fail");
        assert!(matches!(
            err,
            StateError::MutationBaseRevisionMismatch { left: 1, right: 2 }
        ));
        // A failed extend must leave the receiver untouched.
        assert_eq!(left.base_revision(), Some(1));
        assert_eq!(left.op_len(), 1);
    }

    #[test]
    fn mutation_batch_extend_adopts_other_base_revision() {
        let mut left = MutationBatch::new();
        let right = MutationBatch::new().with_base_revision(5);

        left.extend(right).expect("one-sided revision should merge");
        assert_eq!(left.base_revision(), Some(5));
    }

    #[test]
    fn mutation_ops_apply_into_snapshot() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(4);

        let mut snapshot = empty_snapshot();
        apply_all(&mut batch, &mut snapshot);

        assert_eq!(snapshot.get::<Counter>().copied(), Some(4));
    }

    #[test]
    fn mutation_ops_apply_in_insertion_order() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(4);
        batch.clear_extension_with(Counter::KEY, reset_counter);
        batch.update::<Counter>(2);
        assert_eq!(batch.touched_keys, vec!["counter"; 3]);

        let mut snapshot = empty_snapshot();
        apply_all(&mut batch, &mut snapshot);

        assert_eq!(snapshot.get::<Counter>().copied(), Some(2));
    }

    #[test]
    fn mutation_batch_is_empty_when_new() {
        let batch = MutationBatch::new();
        assert!(batch.is_empty());
        assert_eq!(batch.op_len(), 0);
        assert!(batch.base_revision().is_none());
    }

    #[test]
    fn mutation_batch_not_empty_after_update() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(1);
        assert!(!batch.is_empty());
        assert_eq!(batch.op_len(), 1);
    }

    #[test]
    fn mutation_batch_parallel_merge_commutative_overlap() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.update::<Counter>(20);

        let merged = left
            .merge_parallel(right, |_| MergeStrategy::Commutative)
            .expect("commutative overlap should merge");
        assert_eq!(merged.op_len(), 2);
    }

    #[test]
    fn mutation_batch_parallel_merge_exclusive_conflict() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.update::<Counter>(20);

        let result = left.merge_parallel(right, |_| MergeStrategy::Exclusive);
        // Pin the variant and key: a bare `is_err()` would also accept an
        // unrelated error such as a base-revision mismatch.
        match result {
            Err(StateError::ParallelMergeConflict { key }) => assert_eq!(key, "counter"),
            other => panic!("expected ParallelMergeConflict, got {:?}", other),
        }
    }

    #[test]
    fn mutation_batch_parallel_merge_disjoint_exclusive_keys() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.update::<Other>(20);

        let mut merged = left
            .merge_parallel(right, |_| MergeStrategy::Exclusive)
            .expect("disjoint keys never conflict");
        assert_eq!(merged.op_len(), 2);

        let mut snapshot = empty_snapshot();
        apply_all(&mut merged, &mut snapshot);
        assert_eq!(snapshot.get::<Counter>().copied(), Some(10));
        assert_eq!(snapshot.get::<Other>().copied(), Some(20));
    }

    #[test]
    fn mutation_batch_parallel_merge_rejects_mismatched_base_revisions() {
        let left = MutationBatch::new().with_base_revision(1);
        let right = MutationBatch::new().with_base_revision(2);

        let err = left
            .merge_parallel(right, |_| MergeStrategy::Commutative)
            .expect_err("mismatched revisions should fail");
        assert!(matches!(
            err,
            StateError::MutationBaseRevisionMismatch { left: 1, right: 2 }
        ));
    }
}