1use std::any::TypeId;
2use std::sync::Arc;
3
4use parking_lot::{Mutex, RwLock};
5
6use crate::plugins::{InstalledPlugin, KeyRegistration, Plugin, PluginRegistrar, PluginRegistry};
7use awaken_contract::StateError;
8
9use super::{MutationBatch, Snapshot, StateCommand, StateKey, StateMap};
10
11#[derive(Clone)]
12pub struct CommitEvent {
13 pub previous_revision: u64,
14 pub new_revision: u64,
15 pub op_count: usize,
16 pub snapshot: Snapshot,
17}
18
19pub trait CommitHook: Send + Sync + 'static {
20 fn on_commit(&self, event: &CommitEvent);
21}
22
23pub struct StateStore {
24 pub(crate) inner: Arc<RwLock<Snapshot>>,
25 pub(crate) registry: Arc<Mutex<PluginRegistry>>,
26 pub(crate) hooks: Arc<RwLock<Vec<Arc<dyn CommitHook>>>>,
27}
28
29impl Clone for StateStore {
30 fn clone(&self) -> Self {
31 Self {
32 inner: Arc::clone(&self.inner),
33 registry: Arc::clone(&self.registry),
34 hooks: Arc::clone(&self.hooks),
35 }
36 }
37}
38
39impl StateStore {
40 pub fn new() -> Self {
41 Self {
42 inner: Arc::new(RwLock::new(Snapshot {
43 revision: 0,
44 ext: Arc::new(StateMap::default()),
45 })),
46 registry: Arc::new(Mutex::new(PluginRegistry::default())),
47 hooks: Arc::new(RwLock::new(Vec::new())),
48 }
49 }
50
51 pub fn snapshot(&self) -> Snapshot {
52 self.inner.read().clone()
53 }
54
55 pub fn revision(&self) -> u64 {
56 self.inner.read().revision
57 }
58
59 pub fn read<K>(&self) -> Option<K::Value>
60 where
61 K: StateKey,
62 {
63 let guard = self.inner.read();
64 guard.get::<K>().cloned()
65 }
66
67 pub fn add_hook<H>(&self, hook: H)
68 where
69 H: CommitHook,
70 {
71 self.hooks.write().push(Arc::new(hook));
72 }
73
74 pub fn begin_mutation(&self) -> MutationBatch {
75 MutationBatch::new()
76 }
77
78 pub fn merge_parallel(
80 &self,
81 left: MutationBatch,
82 right: MutationBatch,
83 ) -> Result<MutationBatch, StateError> {
84 let registry = self.registry.lock();
85 left.merge_parallel(right, |key| registry.merge_strategy(key))
86 }
87
88 pub fn merge_all_commands(
90 &self,
91 commands: Vec<StateCommand>,
92 ) -> Result<StateCommand, StateError> {
93 let registry = self.registry.lock();
94 commands
95 .into_iter()
96 .try_fold(StateCommand::new(), |acc, cmd| {
97 acc.merge_parallel(cmd, |key| registry.merge_strategy(key))
98 })
99 }
100
101 pub fn commit(&self, patch: MutationBatch) -> Result<u64, StateError> {
102 if patch.is_empty() {
103 return Ok(self.revision());
104 }
105
106 let op_count = patch.op_len();
107 let hooks = self.hooks.read().clone();
108
109 let registry = self.registry.lock();
110 let mut state = self.inner.write();
111
112 if let Some(expected) = patch.base_revision
113 && state.revision != expected
114 {
115 return Err(StateError::RevisionConflict {
116 expected,
117 actual: state.revision,
118 });
119 }
120
121 for key in &patch.touched_keys {
122 registry.ensure_key(key)?;
123 }
124
125 let previous_revision = state.revision;
126 for op in patch.ops {
127 op.apply(&mut state);
128 }
129 state.revision += 1;
130 let new_revision = state.revision;
131 let snapshot = state.clone();
132 drop(state);
133 drop(registry);
134
135 let event = CommitEvent {
136 previous_revision,
137 new_revision,
138 op_count,
139 snapshot,
140 };
141 for hook in hooks {
142 hook.on_commit(&event);
143 }
144
145 Ok(new_revision)
146 }
147
148 pub fn install_plugin<P>(&self, plugin: P) -> Result<(), StateError>
149 where
150 P: Plugin,
151 {
152 let mut registrar = PluginRegistrar::new();
153 plugin.register(&mut registrar)?;
154 let plugin_type_id = TypeId::of::<P>();
155 self.install_plugin_with_keys(plugin_type_id, Arc::new(plugin), registrar.keys)
156 }
157
158 pub(crate) fn install_plugin_with_keys(
159 &self,
160 plugin_type_id: TypeId,
161 plugin: Arc<dyn Plugin>,
162 registrations: Vec<KeyRegistration>,
163 ) -> Result<(), StateError> {
164 let descriptor = plugin.descriptor();
165
166 {
167 let mut registry = self.registry.lock();
168 if registry.plugins.contains_key(&plugin_type_id) {
169 return Err(StateError::PluginAlreadyInstalled {
170 name: descriptor.name.to_string(),
171 });
172 }
173
174 for reg in ®istrations {
175 if registry.keys_by_name.contains_key(®.key) {
176 return Err(StateError::KeyAlreadyRegistered {
177 key: reg.key.clone(),
178 });
179 }
180 }
181
182 for reg in ®istrations {
183 registry.keys_by_name.insert(reg.key.clone(), reg.clone());
184 registry.keys_by_type.insert(reg.type_id, reg.clone());
185 }
186
187 registry.plugins.insert(
188 plugin_type_id,
189 InstalledPlugin {
190 owned_key_type_ids: registrations.iter().map(|r| r.type_id).collect(),
191 },
192 );
193 }
194
195 Ok(())
196 }
197
198 pub(crate) fn register_keys(
204 &self,
205 registrations: &[KeyRegistration],
206 ) -> Result<(), StateError> {
207 let mut registry = self.registry.lock();
208 for reg in registrations {
209 if registry.keys_by_name.contains_key(®.key) {
210 continue;
212 }
213 registry.keys_by_name.insert(reg.key.clone(), reg.clone());
214 registry.keys_by_type.insert(reg.type_id, reg.clone());
215 }
216 Ok(())
217 }
218
219 pub fn uninstall_plugin<P>(&self) -> Result<(), StateError>
220 where
221 P: Plugin,
222 {
223 let plugin_type_id = TypeId::of::<P>();
224 let registrations =
225 {
226 let registry = self.registry.lock();
227 let installed = registry.plugins.get(&plugin_type_id).ok_or(
228 StateError::PluginNotInstalled {
229 type_name: std::any::type_name::<P>(),
230 },
231 )?;
232 installed
233 .owned_key_type_ids
234 .iter()
235 .filter_map(|type_id| registry.keys_by_type.get(type_id).cloned())
236 .collect::<Vec<_>>()
237 };
238
239 let mut patch = MutationBatch::new().with_base_revision(self.revision());
240 for reg in ®istrations {
241 if !reg.options.retain_on_uninstall {
242 patch.clear_extension_with(reg.key.clone(), reg.clear);
243 }
244 }
245 self.commit(patch).map(|_| ())?;
246 self.unregister_plugin_type_id(plugin_type_id)
247 }
248
249 fn unregister_plugin_type_id(&self, plugin_type_id: TypeId) -> Result<(), StateError> {
250 {
251 let mut registry = self.registry.lock();
252 let installed =
253 registry
254 .plugins
255 .remove(&plugin_type_id)
256 .ok_or(StateError::PluginNotInstalled {
257 type_name: "unknown",
258 })?;
259
260 for type_id in &installed.owned_key_type_ids {
261 if let Some(reg) = registry.keys_by_type.remove(type_id) {
262 registry.keys_by_name.remove(®.key);
263 }
264 }
265 }
266
267 Ok(())
268 }
269}
270
271impl Default for StateStore {
272 fn default() -> Self {
273 Self::new()
274 }
275}
276
277#[cfg(test)]
278mod tests {
279 use super::*;
280 use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
281 use crate::state::StateKey;
282 use std::sync::atomic::AtomicU64;
283
284 struct TestCounter;
285
286 impl StateKey for TestCounter {
287 const KEY: &'static str = "test.store_counter";
288 type Value = i64;
289 type Update = i64;
290
291 fn apply(value: &mut Self::Value, update: Self::Update) {
292 *value += update;
293 }
294 }
295
296 struct TestStorePlugin;
297
298 impl Plugin for TestStorePlugin {
299 fn descriptor(&self) -> PluginDescriptor {
300 PluginDescriptor {
301 name: "test-store-plugin",
302 }
303 }
304
305 fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
306 registrar.register_key::<TestCounter>(crate::state::StateKeyOptions::default())
307 }
308 }
309
310 #[test]
311 fn store_new_starts_at_revision_zero() {
312 let store = StateStore::new();
313 assert_eq!(store.revision(), 0);
314 }
315
316 #[test]
317 fn store_commit_increments_revision() {
318 let store = StateStore::new();
319 store.install_plugin(TestStorePlugin).unwrap();
320
321 let mut batch = store.begin_mutation();
322 batch.update::<TestCounter>(1);
323 let rev = store.commit(batch).unwrap();
324 assert_eq!(rev, 1);
325
326 let mut batch = store.begin_mutation();
327 batch.update::<TestCounter>(2);
328 let rev = store.commit(batch).unwrap();
329 assert_eq!(rev, 2);
330 }
331
332 #[test]
333 fn store_empty_commit_returns_current_revision() {
334 let store = StateStore::new();
335 let batch = store.begin_mutation();
336 let rev = store.commit(batch).unwrap();
337 assert_eq!(rev, 0);
338 }
339
340 #[test]
341 fn store_read_returns_none_before_write() {
342 let store = StateStore::new();
343 store.install_plugin(TestStorePlugin).unwrap();
344 let val = store.read::<TestCounter>();
345 assert!(val.is_none());
346 }
347
348 #[test]
349 fn store_read_after_write() {
350 let store = StateStore::new();
351 store.install_plugin(TestStorePlugin).unwrap();
352
353 let mut batch = store.begin_mutation();
354 batch.update::<TestCounter>(42);
355 store.commit(batch).unwrap();
356
357 let val = store.read::<TestCounter>().unwrap();
358 assert_eq!(val, 42);
359 }
360
361 #[test]
362 fn store_multiple_updates_accumulate() {
363 let store = StateStore::new();
364 store.install_plugin(TestStorePlugin).unwrap();
365
366 let mut batch = store.begin_mutation();
367 batch.update::<TestCounter>(10);
368 store.commit(batch).unwrap();
369
370 let mut batch = store.begin_mutation();
371 batch.update::<TestCounter>(20);
372 store.commit(batch).unwrap();
373
374 let val = store.read::<TestCounter>().unwrap();
375 assert_eq!(val, 30);
376 }
377
378 #[test]
379 fn store_snapshot_is_independent_copy() {
380 let store = StateStore::new();
381 store.install_plugin(TestStorePlugin).unwrap();
382
383 let mut batch = store.begin_mutation();
384 batch.update::<TestCounter>(10);
385 store.commit(batch).unwrap();
386
387 let snap = store.snapshot();
388 assert_eq!(snap.revision, 1);
389
390 let mut batch = store.begin_mutation();
391 batch.update::<TestCounter>(20);
392 store.commit(batch).unwrap();
393
394 assert_eq!(snap.revision, 1);
395 assert_eq!(store.revision(), 2);
396 }
397
398 #[test]
399 fn store_clone_shares_state() {
400 let store = StateStore::new();
401 store.install_plugin(TestStorePlugin).unwrap();
402
403 let mut batch = store.begin_mutation();
404 batch.update::<TestCounter>(100);
405 store.commit(batch).unwrap();
406
407 let store2 = store.clone();
408 assert_eq!(store2.read::<TestCounter>().unwrap(), 100);
409 assert_eq!(store2.revision(), 1);
410
411 let mut batch = store2.begin_mutation();
412 batch.update::<TestCounter>(50);
413 store2.commit(batch).unwrap();
414 assert_eq!(store.read::<TestCounter>().unwrap(), 150);
415 }
416
417 #[test]
418 fn store_install_plugin_duplicate_rejected() {
419 let store = StateStore::new();
420 store.install_plugin(TestStorePlugin).unwrap();
421 let err = store.install_plugin(TestStorePlugin);
422 assert!(err.is_err());
423 }
424
425 #[test]
426 fn store_commit_hook_fires() {
427 use std::sync::atomic::Ordering;
428
429 struct TestHook {
430 revision: Arc<AtomicU64>,
431 }
432
433 impl CommitHook for TestHook {
434 fn on_commit(&self, event: &CommitEvent) {
435 self.revision.store(event.new_revision, Ordering::SeqCst);
436 }
437 }
438
439 let store = StateStore::new();
440 store.install_plugin(TestStorePlugin).unwrap();
441
442 let rev = Arc::new(AtomicU64::new(0));
443 store.add_hook(TestHook {
444 revision: rev.clone(),
445 });
446
447 let mut batch = store.begin_mutation();
448 batch.update::<TestCounter>(1);
449 store.commit(batch).unwrap();
450
451 assert_eq!(rev.load(std::sync::atomic::Ordering::SeqCst), 1);
452 }
453
454 #[test]
455 fn store_base_revision_conflict() {
456 let store = StateStore::new();
457 store.install_plugin(TestStorePlugin).unwrap();
458
459 let mut batch = store.begin_mutation();
460 batch.update::<TestCounter>(1);
461 store.commit(batch).unwrap();
462
463 let mut batch = MutationBatch::new().with_base_revision(0);
464 batch.update::<TestCounter>(2);
465 let err = store.commit(batch);
466 assert!(err.is_err());
467 }
468
469 #[test]
470 fn store_uninstall_plugin() {
471 let store = StateStore::new();
472 store.install_plugin(TestStorePlugin).unwrap();
473 store.uninstall_plugin::<TestStorePlugin>().unwrap();
474 let err = store.uninstall_plugin::<TestStorePlugin>();
475 assert!(err.is_err());
476 }
477
478 #[test]
479 fn commit_with_wrong_base_revision_rejected() {
480 let store = StateStore::new();
481 store.install_plugin(TestStorePlugin).unwrap();
482
483 let mut batch = store.begin_mutation();
484 batch.update::<TestCounter>(1);
485 let rev = store.commit(batch).unwrap();
486 assert_eq!(rev, 1);
487
488 let mut stale_batch = MutationBatch::new().with_base_revision(0);
490 stale_batch.update::<TestCounter>(2);
491 let err = store.commit(stale_batch).unwrap_err();
492 assert!(
493 matches!(
494 err,
495 StateError::RevisionConflict {
496 expected: 0,
497 actual: 1
498 }
499 ),
500 "expected RevisionConflict, got: {err:?}"
501 );
502 }
503
504 #[test]
505 fn concurrent_snapshots_independent() {
506 let store = StateStore::new();
507 store.install_plugin(TestStorePlugin).unwrap();
508
509 let snap_before = store.snapshot();
511 assert!(snap_before.get::<TestCounter>().is_none());
512
513 let mut batch = store.begin_mutation();
515 batch.update::<TestCounter>(42);
516 store.commit(batch).unwrap();
517
518 let snap_after = store.snapshot();
520
521 assert!(snap_before.get::<TestCounter>().is_none());
523 assert_eq!(snap_before.revision, 0);
524
525 assert_eq!(*snap_after.get::<TestCounter>().unwrap(), 42);
527 assert_eq!(snap_after.revision, 1);
528 }
529
530 #[test]
531 fn empty_commit_returns_current_revision() {
532 let store = StateStore::new();
533 store.install_plugin(TestStorePlugin).unwrap();
534
535 let mut batch = store.begin_mutation();
537 batch.update::<TestCounter>(1);
538 store.commit(batch).unwrap();
539 assert_eq!(store.revision(), 1);
540
541 let empty_batch = store.begin_mutation();
543 let rev = store.commit(empty_batch).unwrap();
544 assert_eq!(rev, 1);
545 assert_eq!(store.revision(), 1);
546 }
547
548 #[test]
549 fn commit_hook_receives_correct_metadata() {
550 use std::sync::atomic::{AtomicUsize, Ordering};
551
552 struct VerifyHook {
553 prev_rev: Arc<AtomicU64>,
554 new_rev: Arc<AtomicU64>,
555 op_count: Arc<AtomicUsize>,
556 }
557
558 impl CommitHook for VerifyHook {
559 fn on_commit(&self, event: &CommitEvent) {
560 self.prev_rev
561 .store(event.previous_revision, Ordering::SeqCst);
562 self.new_rev.store(event.new_revision, Ordering::SeqCst);
563 self.op_count.store(event.op_count, Ordering::SeqCst);
564 }
565 }
566
567 let store = StateStore::new();
568 store.install_plugin(TestStorePlugin).unwrap();
569
570 let prev_rev = Arc::new(AtomicU64::new(999));
571 let new_rev = Arc::new(AtomicU64::new(999));
572 let op_count = Arc::new(AtomicUsize::new(999));
573 store.add_hook(VerifyHook {
574 prev_rev: prev_rev.clone(),
575 new_rev: new_rev.clone(),
576 op_count: op_count.clone(),
577 });
578
579 let mut batch = store.begin_mutation();
580 batch.update::<TestCounter>(1);
581 batch.update::<TestCounter>(2);
582 batch.update::<TestCounter>(3);
583 store.commit(batch).unwrap();
584
585 assert_eq!(prev_rev.load(Ordering::SeqCst), 0);
586 assert_eq!(new_rev.load(Ordering::SeqCst), 1);
587 assert_eq!(op_count.load(Ordering::SeqCst), 3);
588 }
589
590 #[test]
591 fn store_multiple_updates_in_single_batch() {
592 let store = StateStore::new();
593 store.install_plugin(TestStorePlugin).unwrap();
594
595 let mut batch = store.begin_mutation();
596 batch.update::<TestCounter>(10);
597 batch.update::<TestCounter>(20);
598 batch.update::<TestCounter>(30);
599 store.commit(batch).unwrap();
600
601 let val = store.read::<TestCounter>().unwrap();
602 assert_eq!(val, 60);
603 assert_eq!(store.revision(), 1);
604 }
605
606 #[test]
607 fn store_commit_event_has_correct_metadata() {
608 use std::sync::atomic::{AtomicUsize, Ordering};
609
610 struct MetadataHook {
611 op_count: Arc<AtomicUsize>,
612 prev_rev: Arc<AtomicU64>,
613 }
614
615 impl CommitHook for MetadataHook {
616 fn on_commit(&self, event: &CommitEvent) {
617 self.op_count.store(event.op_count, Ordering::SeqCst);
618 self.prev_rev
619 .store(event.previous_revision, Ordering::SeqCst);
620 }
621 }
622
623 let store = StateStore::new();
624 store.install_plugin(TestStorePlugin).unwrap();
625
626 let op_count = Arc::new(AtomicUsize::new(0));
627 let prev_rev = Arc::new(AtomicU64::new(999));
628 store.add_hook(MetadataHook {
629 op_count: op_count.clone(),
630 prev_rev: prev_rev.clone(),
631 });
632
633 let mut batch = store.begin_mutation();
634 batch.update::<TestCounter>(1);
635 batch.update::<TestCounter>(2);
636 store.commit(batch).unwrap();
637
638 assert_eq!(op_count.load(std::sync::atomic::Ordering::SeqCst), 2);
639 assert_eq!(prev_rev.load(std::sync::atomic::Ordering::SeqCst), 0);
640 }
641}