Skip to main content

awaken_runtime/state/
store.rs

1use std::any::TypeId;
2use std::sync::Arc;
3
4use parking_lot::{Mutex, RwLock};
5
6use crate::plugins::{InstalledPlugin, KeyRegistration, Plugin, PluginRegistrar, PluginRegistry};
7use awaken_contract::StateError;
8
9use super::{MutationBatch, Snapshot, StateCommand, StateKey, StateMap};
10
11#[derive(Clone)]
12pub struct CommitEvent {
13    pub previous_revision: u64,
14    pub new_revision: u64,
15    pub op_count: usize,
16    pub snapshot: Snapshot,
17}
18
19pub trait CommitHook: Send + Sync + 'static {
20    fn on_commit(&self, event: &CommitEvent);
21}
22
23pub struct StateStore {
24    pub(crate) inner: Arc<RwLock<Snapshot>>,
25    pub(crate) registry: Arc<Mutex<PluginRegistry>>,
26    pub(crate) hooks: Arc<RwLock<Vec<Arc<dyn CommitHook>>>>,
27}
28
29impl Clone for StateStore {
30    fn clone(&self) -> Self {
31        Self {
32            inner: Arc::clone(&self.inner),
33            registry: Arc::clone(&self.registry),
34            hooks: Arc::clone(&self.hooks),
35        }
36    }
37}
38
39impl StateStore {
40    pub fn new() -> Self {
41        Self {
42            inner: Arc::new(RwLock::new(Snapshot {
43                revision: 0,
44                ext: Arc::new(StateMap::default()),
45            })),
46            registry: Arc::new(Mutex::new(PluginRegistry::default())),
47            hooks: Arc::new(RwLock::new(Vec::new())),
48        }
49    }
50
51    pub fn snapshot(&self) -> Snapshot {
52        self.inner.read().clone()
53    }
54
55    pub fn revision(&self) -> u64 {
56        self.inner.read().revision
57    }
58
59    pub fn read<K>(&self) -> Option<K::Value>
60    where
61        K: StateKey,
62    {
63        let guard = self.inner.read();
64        guard.get::<K>().cloned()
65    }
66
67    pub fn add_hook<H>(&self, hook: H)
68    where
69        H: CommitHook,
70    {
71        self.hooks.write().push(Arc::new(hook));
72    }
73
74    pub fn begin_mutation(&self) -> MutationBatch {
75        MutationBatch::new()
76    }
77
78    /// Merge two batches from parallel execution using registered merge strategies.
79    pub fn merge_parallel(
80        &self,
81        left: MutationBatch,
82        right: MutationBatch,
83    ) -> Result<MutationBatch, StateError> {
84        let registry = self.registry.lock();
85        left.merge_parallel(right, |key| registry.merge_strategy(key))
86    }
87
88    /// Merge multiple commands from parallel execution into one.
89    pub fn merge_all_commands(
90        &self,
91        commands: Vec<StateCommand>,
92    ) -> Result<StateCommand, StateError> {
93        let registry = self.registry.lock();
94        commands
95            .into_iter()
96            .try_fold(StateCommand::new(), |acc, cmd| {
97                acc.merge_parallel(cmd, |key| registry.merge_strategy(key))
98            })
99    }
100
101    pub fn commit(&self, patch: MutationBatch) -> Result<u64, StateError> {
102        if patch.is_empty() {
103            return Ok(self.revision());
104        }
105
106        let op_count = patch.op_len();
107        let hooks = self.hooks.read().clone();
108
109        let registry = self.registry.lock();
110        let mut state = self.inner.write();
111
112        if let Some(expected) = patch.base_revision
113            && state.revision != expected
114        {
115            return Err(StateError::RevisionConflict {
116                expected,
117                actual: state.revision,
118            });
119        }
120
121        for key in &patch.touched_keys {
122            registry.ensure_key(key)?;
123        }
124
125        let previous_revision = state.revision;
126        for op in patch.ops {
127            op.apply(&mut state);
128        }
129        state.revision += 1;
130        let new_revision = state.revision;
131        let snapshot = state.clone();
132        drop(state);
133        drop(registry);
134
135        let event = CommitEvent {
136            previous_revision,
137            new_revision,
138            op_count,
139            snapshot,
140        };
141        for hook in hooks {
142            hook.on_commit(&event);
143        }
144
145        Ok(new_revision)
146    }
147
148    pub fn install_plugin<P>(&self, plugin: P) -> Result<(), StateError>
149    where
150        P: Plugin,
151    {
152        let mut registrar = PluginRegistrar::new();
153        plugin.register(&mut registrar)?;
154        let plugin_type_id = TypeId::of::<P>();
155        self.install_plugin_with_keys(plugin_type_id, Arc::new(plugin), registrar.keys)
156    }
157
158    pub(crate) fn install_plugin_with_keys(
159        &self,
160        plugin_type_id: TypeId,
161        plugin: Arc<dyn Plugin>,
162        registrations: Vec<KeyRegistration>,
163    ) -> Result<(), StateError> {
164        let descriptor = plugin.descriptor();
165
166        {
167            let mut registry = self.registry.lock();
168            if registry.plugins.contains_key(&plugin_type_id) {
169                return Err(StateError::PluginAlreadyInstalled {
170                    name: descriptor.name.to_string(),
171                });
172            }
173
174            for reg in &registrations {
175                if registry.keys_by_name.contains_key(&reg.key) {
176                    return Err(StateError::KeyAlreadyRegistered {
177                        key: reg.key.clone(),
178                    });
179                }
180            }
181
182            for reg in &registrations {
183                registry.keys_by_name.insert(reg.key.clone(), reg.clone());
184                registry.keys_by_type.insert(reg.type_id, reg.clone());
185            }
186
187            registry.plugins.insert(
188                plugin_type_id,
189                InstalledPlugin {
190                    owned_key_type_ids: registrations.iter().map(|r| r.type_id).collect(),
191                },
192            );
193        }
194
195        Ok(())
196    }
197
198    /// Register standalone state keys (not owned by any plugin).
199    ///
200    /// Keys that are already registered are silently skipped.
201    /// This is used to install plugin-declared state keys collected by
202    /// `ExecutionEnv::from_plugins()`.
203    pub(crate) fn register_keys(
204        &self,
205        registrations: &[KeyRegistration],
206    ) -> Result<(), StateError> {
207        let mut registry = self.registry.lock();
208        for reg in registrations {
209            if registry.keys_by_name.contains_key(&reg.key) {
210                // Already registered (e.g., by LoopStatePlugin or another source) — skip.
211                continue;
212            }
213            registry.keys_by_name.insert(reg.key.clone(), reg.clone());
214            registry.keys_by_type.insert(reg.type_id, reg.clone());
215        }
216        Ok(())
217    }
218
219    pub fn uninstall_plugin<P>(&self) -> Result<(), StateError>
220    where
221        P: Plugin,
222    {
223        let plugin_type_id = TypeId::of::<P>();
224        let registrations =
225            {
226                let registry = self.registry.lock();
227                let installed = registry.plugins.get(&plugin_type_id).ok_or(
228                    StateError::PluginNotInstalled {
229                        type_name: std::any::type_name::<P>(),
230                    },
231                )?;
232                installed
233                    .owned_key_type_ids
234                    .iter()
235                    .filter_map(|type_id| registry.keys_by_type.get(type_id).cloned())
236                    .collect::<Vec<_>>()
237            };
238
239        let mut patch = MutationBatch::new().with_base_revision(self.revision());
240        for reg in &registrations {
241            if !reg.options.retain_on_uninstall {
242                patch.clear_extension_with(reg.key.clone(), reg.clear);
243            }
244        }
245        self.commit(patch).map(|_| ())?;
246        self.unregister_plugin_type_id(plugin_type_id)
247    }
248
249    fn unregister_plugin_type_id(&self, plugin_type_id: TypeId) -> Result<(), StateError> {
250        {
251            let mut registry = self.registry.lock();
252            let installed =
253                registry
254                    .plugins
255                    .remove(&plugin_type_id)
256                    .ok_or(StateError::PluginNotInstalled {
257                        type_name: "unknown",
258                    })?;
259
260            for type_id in &installed.owned_key_type_ids {
261                if let Some(reg) = registry.keys_by_type.remove(type_id) {
262                    registry.keys_by_name.remove(&reg.key);
263                }
264            }
265        }
266
267        Ok(())
268    }
269}
270
271impl Default for StateStore {
272    fn default() -> Self {
273        Self::new()
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
    use crate::state::StateKey;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

    /// Additive `i64` counter key used throughout these tests.
    struct TestCounter;

    impl StateKey for TestCounter {
        const KEY: &'static str = "test.store_counter";
        type Value = i64;
        type Update = i64;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// Plugin whose only job is to register `TestCounter`.
    struct TestStorePlugin;

    impl Plugin for TestStorePlugin {
        fn descriptor(&self) -> PluginDescriptor {
            PluginDescriptor {
                name: "test-store-plugin",
            }
        }

        fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
            registrar.register_key::<TestCounter>(crate::state::StateKeyOptions::default())
        }
    }

    /// Fresh store with `TestStorePlugin` already installed.
    fn store_with_plugin() -> StateStore {
        let store = StateStore::new();
        store.install_plugin(TestStorePlugin).unwrap();
        store
    }

    /// Commits one batch holding the given counter increments; returns the resulting revision.
    fn commit_updates(store: &StateStore, deltas: &[i64]) -> u64 {
        let mut batch = store.begin_mutation();
        for &delta in deltas {
            batch.update::<TestCounter>(delta);
        }
        store.commit(batch).unwrap()
    }

    #[test]
    fn store_new_starts_at_revision_zero() {
        assert_eq!(StateStore::new().revision(), 0);
    }

    #[test]
    fn store_commit_increments_revision() {
        let store = store_with_plugin();
        assert_eq!(commit_updates(&store, &[1]), 1);
        assert_eq!(commit_updates(&store, &[2]), 2);
    }

    #[test]
    fn store_empty_commit_returns_current_revision() {
        let store = StateStore::new();
        assert_eq!(commit_updates(&store, &[]), 0);
    }

    #[test]
    fn store_read_returns_none_before_write() {
        assert!(store_with_plugin().read::<TestCounter>().is_none());
    }

    #[test]
    fn store_read_after_write() {
        let store = store_with_plugin();
        commit_updates(&store, &[42]);
        assert_eq!(store.read::<TestCounter>(), Some(42));
    }

    #[test]
    fn store_multiple_updates_accumulate() {
        let store = store_with_plugin();
        commit_updates(&store, &[10]);
        commit_updates(&store, &[20]);
        assert_eq!(store.read::<TestCounter>(), Some(30));
    }

    #[test]
    fn store_snapshot_is_independent_copy() {
        let store = store_with_plugin();
        commit_updates(&store, &[10]);

        let frozen = store.snapshot();
        assert_eq!(frozen.revision, 1);

        commit_updates(&store, &[20]);
        assert_eq!(frozen.revision, 1);
        assert_eq!(store.revision(), 2);
    }

    #[test]
    fn store_clone_shares_state() {
        let original = store_with_plugin();
        commit_updates(&original, &[100]);

        let shared = original.clone();
        assert_eq!(shared.read::<TestCounter>(), Some(100));
        assert_eq!(shared.revision(), 1);

        commit_updates(&shared, &[50]);
        assert_eq!(original.read::<TestCounter>(), Some(150));
    }

    #[test]
    fn store_install_plugin_duplicate_rejected() {
        let store = store_with_plugin();
        assert!(store.install_plugin(TestStorePlugin).is_err());
    }

    #[test]
    fn store_commit_hook_fires() {
        struct TestHook {
            revision: Arc<AtomicU64>,
        }

        impl CommitHook for TestHook {
            fn on_commit(&self, event: &CommitEvent) {
                self.revision.store(event.new_revision, Ordering::SeqCst);
            }
        }

        let store = store_with_plugin();
        let seen = Arc::new(AtomicU64::new(0));
        store.add_hook(TestHook {
            revision: Arc::clone(&seen),
        });

        commit_updates(&store, &[1]);
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn store_base_revision_conflict() {
        let store = store_with_plugin();
        commit_updates(&store, &[1]);

        let mut stale = MutationBatch::new().with_base_revision(0);
        stale.update::<TestCounter>(2);
        assert!(store.commit(stale).is_err());
    }

    #[test]
    fn store_uninstall_plugin() {
        let store = store_with_plugin();
        store.uninstall_plugin::<TestStorePlugin>().unwrap();
        assert!(store.uninstall_plugin::<TestStorePlugin>().is_err());
    }

    #[test]
    fn commit_with_wrong_base_revision_rejected() {
        let store = store_with_plugin();
        assert_eq!(commit_updates(&store, &[1]), 1);

        // Build batch with stale base_revision=0 while store is at revision 1.
        let mut stale_batch = MutationBatch::new().with_base_revision(0);
        stale_batch.update::<TestCounter>(2);
        let err = store.commit(stale_batch).unwrap_err();
        assert!(
            matches!(
                err,
                StateError::RevisionConflict {
                    expected: 0,
                    actual: 1
                }
            ),
            "expected RevisionConflict, got: {err:?}"
        );
    }

    #[test]
    fn concurrent_snapshots_independent() {
        let store = store_with_plugin();

        let before = store.snapshot();
        assert!(before.get::<TestCounter>().is_none());

        commit_updates(&store, &[42]);
        let after = store.snapshot();

        // The earlier snapshot must not observe the later commit.
        assert!(before.get::<TestCounter>().is_none());
        assert_eq!(before.revision, 0);

        // The later snapshot must observe it.
        assert_eq!(after.get::<TestCounter>().copied(), Some(42));
        assert_eq!(after.revision, 1);
    }

    #[test]
    fn empty_commit_returns_current_revision() {
        let store = store_with_plugin();
        commit_updates(&store, &[1]);
        assert_eq!(store.revision(), 1);

        // An empty commit reports the current revision without incrementing it.
        assert_eq!(commit_updates(&store, &[]), 1);
        assert_eq!(store.revision(), 1);
    }

    #[test]
    fn commit_hook_receives_correct_metadata() {
        struct VerifyHook {
            prev_rev: Arc<AtomicU64>,
            new_rev: Arc<AtomicU64>,
            op_count: Arc<AtomicUsize>,
        }

        impl CommitHook for VerifyHook {
            fn on_commit(&self, event: &CommitEvent) {
                self.prev_rev
                    .store(event.previous_revision, Ordering::SeqCst);
                self.new_rev.store(event.new_revision, Ordering::SeqCst);
                self.op_count.store(event.op_count, Ordering::SeqCst);
            }
        }

        let store = store_with_plugin();
        let prev_rev = Arc::new(AtomicU64::new(999));
        let new_rev = Arc::new(AtomicU64::new(999));
        let op_count = Arc::new(AtomicUsize::new(999));
        store.add_hook(VerifyHook {
            prev_rev: Arc::clone(&prev_rev),
            new_rev: Arc::clone(&new_rev),
            op_count: Arc::clone(&op_count),
        });

        commit_updates(&store, &[1, 2, 3]);

        assert_eq!(prev_rev.load(Ordering::SeqCst), 0);
        assert_eq!(new_rev.load(Ordering::SeqCst), 1);
        assert_eq!(op_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn store_multiple_updates_in_single_batch() {
        let store = store_with_plugin();
        commit_updates(&store, &[10, 20, 30]);

        assert_eq!(store.read::<TestCounter>(), Some(60));
        assert_eq!(store.revision(), 1);
    }

    #[test]
    fn store_commit_event_has_correct_metadata() {
        struct MetadataHook {
            op_count: Arc<AtomicUsize>,
            prev_rev: Arc<AtomicU64>,
        }

        impl CommitHook for MetadataHook {
            fn on_commit(&self, event: &CommitEvent) {
                self.op_count.store(event.op_count, Ordering::SeqCst);
                self.prev_rev
                    .store(event.previous_revision, Ordering::SeqCst);
            }
        }

        let store = store_with_plugin();
        let op_count = Arc::new(AtomicUsize::new(0));
        let prev_rev = Arc::new(AtomicU64::new(999));
        store.add_hook(MetadataHook {
            op_count: Arc::clone(&op_count),
            prev_rev: Arc::clone(&prev_rev),
        });

        commit_updates(&store, &[1, 2]);

        assert_eq!(op_count.load(Ordering::SeqCst), 2);
        assert_eq!(prev_rev.load(Ordering::SeqCst), 0);
    }
}