awaken_contract/state/
mutation.rs1use std::collections::HashSet;
2use std::fmt;
3use std::marker::PhantomData;
4
5use crate::StateError;
6
7use super::{MergeStrategy, Snapshot, StateKey, StateMap};
8
9pub trait MutationOp: Send {
11 fn apply(self: Box<Self>, state: &mut Snapshot);
13}
14
15pub trait MutationTarget {
17 type Update: Send + 'static;
18
19 fn apply(state: &mut Snapshot, update: Self::Update);
20}
21
22impl<K> MutationTarget for K
23where
24 K: StateKey,
25{
26 type Update = K::Update;
27
28 fn apply(state: &mut Snapshot, update: Self::Update) {
29 let value = std::sync::Arc::make_mut(&mut state.ext).get_or_insert_default::<K>();
30 K::apply(value, update);
31 }
32}
33
34struct KeyPatch<S: MutationTarget> {
35 update: Option<S::Update>,
36 _marker: PhantomData<S>,
37}
38
39impl<S> KeyPatch<S>
40where
41 S: MutationTarget,
42{
43 fn new(update: S::Update) -> Self {
44 Self {
45 update: Some(update),
46 _marker: PhantomData,
47 }
48 }
49}
50
51impl<S> MutationOp for KeyPatch<S>
52where
53 S: MutationTarget + Send,
54{
55 fn apply(mut self: Box<Self>, state: &mut Snapshot) {
56 let update = self.update.take().expect("key patch already applied");
57 S::apply(state, update);
58 }
59}
60
61struct ClearKeyMutation {
62 clear: fn(&mut StateMap),
63}
64
65impl ClearKeyMutation {
66 fn new(clear: fn(&mut StateMap)) -> Self {
67 Self { clear }
68 }
69}
70
71impl MutationOp for ClearKeyMutation {
72 fn apply(self: Box<Self>, state: &mut Snapshot) {
73 (self.clear)(std::sync::Arc::make_mut(&mut state.ext));
74 }
75}
76
77pub struct MutationBatch {
81 pub base_revision: Option<u64>,
82 pub ops: Vec<Box<dyn MutationOp>>,
83 pub touched_keys: Vec<String>,
84}
85
86impl fmt::Debug for MutationBatch {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("MutationBatch")
89 .field("base_revision", &self.base_revision)
90 .field("ops_len", &self.ops.len())
91 .field("touched_keys", &self.touched_keys)
92 .finish()
93 }
94}
95
96impl MutationBatch {
97 pub fn new() -> Self {
98 Self {
99 base_revision: None,
100 ops: Vec::new(),
101 touched_keys: Vec::new(),
102 }
103 }
104
105 pub fn with_base_revision(mut self, revision: u64) -> Self {
106 self.base_revision = Some(revision);
107 self
108 }
109
110 pub fn base_revision(&self) -> Option<u64> {
111 self.base_revision
112 }
113
114 pub fn is_empty(&self) -> bool {
115 self.ops.is_empty()
116 }
117
118 pub fn update<K>(&mut self, update: K::Update)
119 where
120 K: StateKey,
121 {
122 self.ops.push(Box::new(KeyPatch::<K>::new(update)));
123 self.touched_keys.push(K::KEY.to_string());
124 }
125
126 pub fn clear_extension_with(&mut self, key: impl Into<String>, clear: fn(&mut StateMap)) {
127 self.ops.push(Box::new(ClearKeyMutation::new(clear)));
128 self.touched_keys.push(key.into());
129 }
130
131 pub fn extend(&mut self, mut other: Self) -> Result<(), StateError> {
132 self.base_revision = match (self.base_revision, other.base_revision) {
133 (Some(left), Some(right)) if left != right => {
134 return Err(StateError::MutationBaseRevisionMismatch { left, right });
135 }
136 (Some(left), _) => Some(left),
137 (None, Some(right)) => Some(right),
138 (None, None) => None,
139 };
140
141 self.ops.append(&mut other.ops);
142 self.touched_keys.append(&mut other.touched_keys);
143 Ok(())
144 }
145
146 pub fn op_len(&self) -> usize {
147 self.ops.len()
148 }
149
150 pub fn merge_parallel<F>(mut self, mut other: Self, strategy: F) -> Result<Self, StateError>
156 where
157 F: Fn(&str) -> MergeStrategy,
158 {
159 self.base_revision = match (self.base_revision, other.base_revision) {
161 (Some(left), Some(right)) if left != right => {
162 return Err(StateError::MutationBaseRevisionMismatch { left, right });
163 }
164 (Some(left), _) => Some(left),
165 (None, Some(right)) => Some(right),
166 (None, None) => None,
167 };
168
169 let self_keys: HashSet<&str> = self.touched_keys.iter().map(|s| s.as_str()).collect();
171 for key in &other.touched_keys {
172 if self_keys.contains(key.as_str()) && strategy(key) == MergeStrategy::Exclusive {
173 return Err(StateError::ParallelMergeConflict { key: key.clone() });
174 }
175 }
176
177 self.ops.append(&mut other.ops);
179 self.touched_keys.append(&mut other.touched_keys);
180 Ok(self)
181 }
182}
183
184impl Default for MutationBatch {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    struct Counter;

    impl StateKey for Counter {
        const KEY: &'static str = "counter";
        type Value = usize;
        type Update = usize;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// A snapshot at revision 0 with an empty extension map.
    fn empty_snapshot() -> Snapshot {
        Snapshot {
            revision: 0,
            ext: std::sync::Arc::new(StateMap::default()),
        }
    }

    /// Applies every op queued in `batch` to `snapshot`, in queue order.
    fn apply_all(batch: MutationBatch, snapshot: &mut Snapshot) {
        for op in batch.ops {
            op.apply(snapshot);
        }
    }

    #[test]
    fn mutation_batch_merges_matching_base_revisions() {
        let mut left = MutationBatch::new().with_base_revision(3);
        left.update::<Counter>(1);

        let mut right = MutationBatch::new().with_base_revision(3);
        right.update::<Counter>(2);

        left.extend(right).expect("matching base revisions should merge");
        assert_eq!(left.base_revision(), Some(3));
        assert_eq!(left.op_len(), 2);
    }

    #[test]
    fn mutation_batch_rejects_mismatched_base_revisions() {
        let mut left = MutationBatch::new().with_base_revision(1);
        let right = MutationBatch::new().with_base_revision(2);

        let err = left.extend(right).expect_err("mismatch should fail");
        assert!(matches!(
            err,
            StateError::MutationBaseRevisionMismatch { left: 1, right: 2 }
        ));
    }

    #[test]
    fn mutation_batch_extend_adopts_other_base_revision() {
        let mut left = MutationBatch::new();
        let right = MutationBatch::new().with_base_revision(5);

        left.extend(right).expect("one-sided base revision should merge");
        assert_eq!(left.base_revision(), Some(5));
    }

    #[test]
    fn mutation_batch_extend_keeps_all_ops() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(1);
        let mut right = MutationBatch::new();
        right.update::<Counter>(2);

        left.extend(right).expect("batches without revisions should merge");
        assert_eq!(left.touched_keys, vec!["counter", "counter"]);

        let mut snapshot = empty_snapshot();
        apply_all(left, &mut snapshot);
        assert_eq!(snapshot.get::<Counter>().copied(), Some(3));
    }

    #[test]
    fn mutation_ops_apply_into_snapshot() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(4);

        let mut snapshot = empty_snapshot();
        apply_all(batch, &mut snapshot);

        assert_eq!(snapshot.get::<Counter>().copied(), Some(4));
    }

    #[test]
    fn clear_extension_with_records_key_and_runs_callback() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(4);
        batch.clear_extension_with("counter", |map: &mut StateMap| *map = StateMap::default());
        assert_eq!(batch.op_len(), 2);
        assert_eq!(batch.touched_keys, vec!["counter", "counter"]);

        let mut snapshot = empty_snapshot();
        apply_all(batch, &mut snapshot);
        assert_eq!(snapshot.get::<Counter>().copied(), None);
    }

    #[test]
    fn mutation_batch_is_empty_when_new() {
        let batch = MutationBatch::new();
        assert!(batch.is_empty());
        assert_eq!(batch.op_len(), 0);
        assert!(batch.base_revision().is_none());
    }

    #[test]
    fn mutation_batch_not_empty_after_update() {
        let mut batch = MutationBatch::new();
        batch.update::<Counter>(1);
        assert!(!batch.is_empty());
        assert_eq!(batch.op_len(), 1);
    }

    #[test]
    fn mutation_batch_parallel_merge_commutative_overlap() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.update::<Counter>(20);

        let merged = left
            .merge_parallel(right, |_| MergeStrategy::Commutative)
            .expect("commutative overlap should merge");
        assert_eq!(merged.op_len(), 2);
    }

    #[test]
    fn mutation_batch_parallel_merge_exclusive_conflict() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.update::<Counter>(20);

        let err = left
            .merge_parallel(right, |_| MergeStrategy::Exclusive)
            .expect_err("exclusive overlap should conflict");
        assert!(matches!(
            &err,
            StateError::ParallelMergeConflict { key } if key == "counter"
        ));
    }

    #[test]
    fn mutation_batch_parallel_merge_exclusive_disjoint_keys() {
        let mut left = MutationBatch::new();
        left.update::<Counter>(10);
        let mut right = MutationBatch::new();
        right.clear_extension_with("other", |_: &mut StateMap| {});

        let merged = left
            .merge_parallel(right, |_| MergeStrategy::Exclusive)
            .expect("disjoint keys should merge even when exclusive");
        assert_eq!(merged.op_len(), 2);
        assert_eq!(merged.touched_keys, vec!["counter", "other"]);
    }

    #[test]
    fn mutation_batch_parallel_merge_rejects_mismatched_base_revisions() {
        let left = MutationBatch::new().with_base_revision(1);
        let right = MutationBatch::new().with_base_revision(2);

        let err = left
            .merge_parallel(right, |_| MergeStrategy::Commutative)
            .expect_err("mismatch should fail");
        assert!(matches!(
            err,
            StateError::MutationBaseRevisionMismatch { left: 1, right: 2 }
        ));
    }

    #[test]
    fn mutation_batch_parallel_merge_adopts_other_base_revision() {
        let left = MutationBatch::new();
        let right = MutationBatch::new().with_base_revision(4);

        let merged = left
            .merge_parallel(right, |_| MergeStrategy::Commutative)
            .expect("one-sided base revision should merge");
        assert_eq!(merged.base_revision(), Some(4));
    }
}