1use chrono::{DateTime, Duration, Utc};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::oneshot;
21
22pub type SharedStateResult<T> = Result<T, SharedStateError>;
24
25#[derive(Debug, Error)]
27pub enum SharedStateError {
28 #[error("Key not found: {0}")]
30 KeyNotFound(String),
31
32 #[error("Lock timeout for key: {0}")]
34 LockTimeout(String),
35
36 #[error("Lock not held: {0}")]
38 LockNotHeld(String),
39
40 #[error("Invalid lock: {0}")]
42 InvalidLock(String),
43
44 #[error("Compare and swap failed: expected value does not match")]
46 CompareAndSwapFailed,
47
48 #[error("Serialization error: {0}")]
50 SerializationError(String),
51
52 #[error("Channel error: {0}")]
54 ChannelError(String),
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(rename_all = "camelCase")]
60pub struct Lock {
61 pub id: String,
63 pub holder: String,
65 pub key: String,
67 pub acquired_at: DateTime<Utc>,
69 pub expires_at: Option<DateTime<Utc>>,
71}
72
73impl Lock {
74 pub fn new(
76 key: impl Into<String>,
77 holder: impl Into<String>,
78 timeout: Option<Duration>,
79 ) -> Self {
80 let now = Utc::now();
81 Self {
82 id: uuid::Uuid::new_v4().to_string(),
83 holder: holder.into(),
84 key: key.into(),
85 acquired_at: now,
86 expires_at: timeout.map(|t| now + t),
87 }
88 }
89
90 pub fn is_expired(&self) -> bool {
92 self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
93 }
94}
95
96#[derive(Debug, Clone)]
98pub enum StateEvent {
99 Changed {
101 key: String,
102 value: Value,
103 old_value: Option<Value>,
104 },
105 Deleted {
107 key: String,
108 old_value: Option<Value>,
109 },
110 Cleared,
112 LockAcquired(Lock),
114 LockReleased(Lock),
116}
117
118pub type WatchCallback = Arc<dyn Fn(Option<Value>) + Send + Sync>;
120
121#[derive(Debug, Clone)]
123pub struct WatchHandle {
124 pub key: String,
126 pub id: String,
128}
129
130impl WatchHandle {
131 pub fn new(key: impl Into<String>) -> Self {
133 Self {
134 key: key.into(),
135 id: uuid::Uuid::new_v4().to_string(),
136 }
137 }
138}
139
140struct WatcherEntry {
142 id: String,
143 callback: WatchCallback,
144}
145
146struct LockWaiter {
148 holder: String,
149 timeout: Option<Duration>,
150 sender: oneshot::Sender<Lock>,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155#[serde(rename_all = "camelCase")]
156pub struct SharedStateStats {
157 pub state_size: usize,
159 pub watchers_count: usize,
161 pub total_watchers: usize,
163 pub locks_count: usize,
165 pub wait_queue_size: usize,
167}
168
169#[derive(Default)]
171pub struct SharedStateManager {
172 state: HashMap<String, Value>,
174 watchers: HashMap<String, Vec<WatcherEntry>>,
176 locks: HashMap<String, Lock>,
178 lock_wait_queue: HashMap<String, Vec<LockWaiter>>,
180 event_listeners: Vec<Arc<dyn Fn(StateEvent) + Send + Sync>>,
182}
183
184impl SharedStateManager {
185 pub fn new() -> Self {
187 Self {
188 state: HashMap::new(),
189 watchers: HashMap::new(),
190 locks: HashMap::new(),
191 lock_wait_queue: HashMap::new(),
192 event_listeners: Vec::new(),
193 }
194 }
195
196 pub fn get(&self, key: &str) -> Option<Value> {
198 self.state.get(key).cloned()
199 }
200
201 pub fn get_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
203 self.state
204 .get(key)
205 .and_then(|v| serde_json::from_value(v.clone()).ok())
206 }
207
208 pub fn set(&mut self, key: impl Into<String>, value: Value) {
210 let key = key.into();
211 let old_value = self.state.get(&key).cloned();
212 self.state.insert(key.clone(), value.clone());
213
214 self.notify_watchers(&key, Some(value.clone()));
216
217 self.emit_event(StateEvent::Changed {
219 key,
220 value,
221 old_value,
222 });
223 }
224
225 pub fn set_typed<T: Serialize>(
227 &mut self,
228 key: impl Into<String>,
229 value: &T,
230 ) -> SharedStateResult<()> {
231 let json_value = serde_json::to_value(value)
232 .map_err(|e| SharedStateError::SerializationError(e.to_string()))?;
233 self.set(key, json_value);
234 Ok(())
235 }
236
237 pub fn delete(&mut self, key: &str) -> bool {
239 if let Some(old_value) = self.state.remove(key) {
240 self.notify_watchers(key, None);
242
243 self.emit_event(StateEvent::Deleted {
245 key: key.to_string(),
246 old_value: Some(old_value),
247 });
248
249 true
250 } else {
251 false
252 }
253 }
254
255 pub fn has(&self, key: &str) -> bool {
257 self.state.contains_key(key)
258 }
259
260 pub fn keys(&self) -> Vec<String> {
262 self.state.keys().cloned().collect()
263 }
264
265 pub fn clear(&mut self) {
267 self.state.clear();
268 self.emit_event(StateEvent::Cleared);
269 }
270
271 pub fn watch<F>(&mut self, key: impl Into<String>, callback: F) -> WatchHandle
273 where
274 F: Fn(Option<Value>) + Send + Sync + 'static,
275 {
276 let key = key.into();
277 let handle = WatchHandle::new(&key);
278
279 let entry = WatcherEntry {
280 id: handle.id.clone(),
281 callback: Arc::new(callback),
282 };
283
284 self.watchers.entry(key).or_default().push(entry);
285
286 handle
287 }
288
289 pub fn unwatch(&mut self, handle: &WatchHandle) -> bool {
291 if let Some(watchers) = self.watchers.get_mut(&handle.key) {
292 let before = watchers.len();
293 watchers.retain(|w| w.id != handle.id);
294 let removed = before != watchers.len();
295
296 if watchers.is_empty() {
298 self.watchers.remove(&handle.key);
299 }
300
301 removed
302 } else {
303 false
304 }
305 }
306
307 fn notify_watchers(&self, key: &str, value: Option<Value>) {
309 if let Some(watchers) = self.watchers.get(key) {
310 for watcher in watchers {
311 let callback = watcher.callback.clone();
313 let value = value.clone();
314 std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
315 callback(value);
316 }))
317 .ok();
318 }
319 }
320 }
321
322 pub fn on_event<F>(&mut self, listener: F)
324 where
325 F: Fn(StateEvent) + Send + Sync + 'static,
326 {
327 self.event_listeners.push(Arc::new(listener));
328 }
329
330 fn emit_event(&self, event: StateEvent) {
332 for listener in &self.event_listeners {
333 let listener = listener.clone();
334 let event = event.clone();
335 std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
336 listener(event);
337 }))
338 .ok();
339 }
340 }
341
342 pub fn lock(
347 &mut self,
348 key: impl Into<String>,
349 holder: impl Into<String>,
350 timeout: Option<Duration>,
351 ) -> SharedStateResult<Lock> {
352 let key = key.into();
353 let holder = holder.into();
354
355 if let Some(existing) = self.locks.get(&key) {
357 if existing.is_expired() {
359 let expired_lock = self.locks.remove(&key).unwrap();
361 self.emit_event(StateEvent::LockReleased(expired_lock));
362 } else {
363 return Err(SharedStateError::LockTimeout(key));
365 }
366 }
367
368 let lock = Lock::new(&key, &holder, timeout);
370 self.locks.insert(key, lock.clone());
371 self.emit_event(StateEvent::LockAcquired(lock.clone()));
372
373 Ok(lock)
374 }
375
376 pub fn try_lock(
378 &mut self,
379 key: impl Into<String>,
380 holder: impl Into<String>,
381 timeout: Option<Duration>,
382 ) -> Option<Lock> {
383 let key = key.into();
384 let holder = holder.into();
385
386 if let Some(existing) = self.locks.get(&key) {
388 if !existing.is_expired() {
389 return None;
390 }
391 let expired_lock = self.locks.remove(&key).unwrap();
393 self.emit_event(StateEvent::LockReleased(expired_lock));
394 }
395
396 let lock = Lock::new(&key, &holder, timeout);
398 self.locks.insert(key, lock.clone());
399 self.emit_event(StateEvent::LockAcquired(lock.clone()));
400
401 Some(lock)
402 }
403
404 pub fn prepare_lock(
409 &mut self,
410 key: impl Into<String>,
411 holder: impl Into<String>,
412 timeout: Option<Duration>,
413 ) -> Result<(String, oneshot::Receiver<Lock>), Lock> {
414 let key = key.into();
415 let holder = holder.into();
416
417 if let Some(existing) = self.locks.get(&key) {
419 if !existing.is_expired() {
420 let (tx, rx) = oneshot::channel();
422 let waiter = LockWaiter {
423 holder,
424 timeout,
425 sender: tx,
426 };
427
428 self.lock_wait_queue
429 .entry(key.clone())
430 .or_default()
431 .push(waiter);
432
433 return Ok((key, rx));
434 }
435 let expired_lock = self.locks.remove(&key).unwrap();
437 self.emit_event(StateEvent::LockReleased(expired_lock));
438 }
439
440 let lock = Lock::new(&key, &holder, timeout);
442 self.locks.insert(key, lock.clone());
443 self.emit_event(StateEvent::LockAcquired(lock.clone()));
444
445 Err(lock)
446 }
447
448 pub fn unlock(&mut self, lock: &Lock) -> SharedStateResult<()> {
450 let current = self.locks.get(&lock.key);
452 match current {
453 None => return Err(SharedStateError::LockNotHeld(lock.key.clone())),
454 Some(current) if current.id != lock.id => {
455 return Err(SharedStateError::InvalidLock(format!(
456 "Lock ID mismatch: expected {}, got {}",
457 current.id, lock.id
458 )));
459 }
460 _ => {}
461 }
462
463 let released_lock = self.locks.remove(&lock.key).unwrap();
465 self.emit_event(StateEvent::LockReleased(released_lock));
466
467 let waiter = self
469 .lock_wait_queue
470 .get_mut(&lock.key)
471 .and_then(|waiters| waiters.pop());
472
473 if let Some(waiter) = waiter {
474 let new_lock = Lock::new(&lock.key, &waiter.holder, waiter.timeout);
476 self.locks.insert(lock.key.clone(), new_lock.clone());
477 self.emit_event(StateEvent::LockAcquired(new_lock.clone()));
478
479 let _ = waiter.sender.send(new_lock);
481 }
482
483 if self
485 .lock_wait_queue
486 .get(&lock.key)
487 .map(|w| w.is_empty())
488 .unwrap_or(false)
489 {
490 self.lock_wait_queue.remove(&lock.key);
491 }
492
493 Ok(())
494 }
495
496 pub fn is_locked(&self, key: &str) -> bool {
498 self.locks
499 .get(key)
500 .map(|l| !l.is_expired())
501 .unwrap_or(false)
502 }
503
504 pub fn get_all_locks(&self) -> Vec<Lock> {
506 self.locks
507 .values()
508 .filter(|l| !l.is_expired())
509 .cloned()
510 .collect()
511 }
512
513 pub fn get_lock(&self, key: &str) -> Option<&Lock> {
515 self.locks.get(key).filter(|l| !l.is_expired())
516 }
517
518 pub fn compare_and_swap(&mut self, key: &str, expected: &Value, new_value: Value) -> bool {
523 let current = self.state.get(key);
524
525 if current == Some(expected) {
526 self.set(key.to_string(), new_value);
527 true
528 } else {
529 false
530 }
531 }
532
533 pub fn compare_and_swap_typed<T: Serialize + PartialEq + for<'de> Deserialize<'de>>(
535 &mut self,
536 key: &str,
537 expected: &T,
538 new_value: &T,
539 ) -> SharedStateResult<bool> {
540 let current: Option<T> = self.get_typed(key);
541
542 if current.as_ref() == Some(expected) {
543 self.set_typed(key, new_value)?;
544 Ok(true)
545 } else {
546 Ok(false)
547 }
548 }
549
550 pub fn increment(&mut self, key: &str, delta: i64) -> i64 {
555 let current = self.state.get(key).and_then(|v| v.as_i64()).unwrap_or(0);
556
557 let new_value = current + delta;
558 self.set(key.to_string(), Value::Number(new_value.into()));
559 new_value
560 }
561
562 pub fn decrement(&mut self, key: &str, delta: i64) -> i64 {
564 self.increment(key, -delta)
565 }
566
567 pub fn cleanup_expired_locks(&mut self) -> usize {
571 let expired_keys: Vec<String> = self
572 .locks
573 .iter()
574 .filter(|(_, lock)| lock.is_expired())
575 .map(|(key, _)| key.clone())
576 .collect();
577
578 let count = expired_keys.len();
579
580 for key in expired_keys {
581 if let Some(lock) = self.locks.remove(&key) {
582 self.emit_event(StateEvent::LockReleased(lock));
583
584 let waiter = self
586 .lock_wait_queue
587 .get_mut(&key)
588 .and_then(|waiters| waiters.pop());
589
590 if let Some(waiter) = waiter {
591 let new_lock = Lock::new(&key, &waiter.holder, waiter.timeout);
592 self.locks.insert(key.clone(), new_lock.clone());
593 self.emit_event(StateEvent::LockAcquired(new_lock.clone()));
594 let _ = waiter.sender.send(new_lock);
595 }
596
597 if self
599 .lock_wait_queue
600 .get(&key)
601 .map(|w| w.is_empty())
602 .unwrap_or(false)
603 {
604 self.lock_wait_queue.remove(&key);
605 }
606 }
607 }
608
609 count
610 }
611
612 pub fn get_stats(&self) -> SharedStateStats {
614 let total_watchers: usize = self.watchers.values().map(|w| w.len()).sum();
615 let wait_queue_size: usize = self.lock_wait_queue.values().map(|w| w.len()).sum();
616
617 SharedStateStats {
618 state_size: self.state.len(),
619 watchers_count: self.watchers.len(),
620 total_watchers,
621 locks_count: self.locks.len(),
622 wait_queue_size,
623 }
624 }
625
626 #[cfg(test)]
628 pub fn insert_lock_for_test(&mut self, lock: Lock) {
629 self.locks.insert(lock.key.clone(), lock);
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a lock on `key` held by `holder` that expired five seconds ago.
    fn expired_lock(key: &str, holder: &str) -> Lock {
        Lock {
            id: uuid::Uuid::new_v4().to_string(),
            holder: holder.to_string(),
            key: key.to_string(),
            acquired_at: Utc::now() - Duration::seconds(10),
            expires_at: Some(Utc::now() - Duration::seconds(5)),
        }
    }

    #[test]
    fn test_get_set() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!({"name": "test"}));

        let value = manager.get("key1");
        assert!(value.is_some());
        assert_eq!(value.unwrap(), json!({"name": "test"}));

        assert!(manager.get("non_existent").is_none());
    }

    #[test]
    fn test_get_set_typed() {
        let mut manager = SharedStateManager::new();

        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct Config {
            max_retries: u32,
            timeout: u64,
        }

        let config = Config {
            max_retries: 3,
            timeout: 5000,
        };

        manager.set_typed("config", &config).unwrap();

        let retrieved: Option<Config> = manager.get_typed("config");
        assert_eq!(retrieved, Some(config));
    }

    #[test]
    fn test_delete() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        assert!(manager.has("key1"));

        let deleted = manager.delete("key1");
        assert!(deleted);
        assert!(!manager.has("key1"));

        let deleted = manager.delete("non_existent");
        assert!(!deleted);
    }

    #[test]
    fn test_keys() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));
        manager.set("key3", json!("value3"));

        let keys = manager.keys();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));
        assert!(keys.contains(&"key3".to_string()));
    }

    #[test]
    fn test_clear() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));

        manager.clear();

        assert!(manager.keys().is_empty());
    }

    #[test]
    fn test_watch() {
        let mut manager = SharedStateManager::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let handle = manager.watch("key1", move |_value| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });

        manager.set("key1", json!("value1"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        manager.set("key1", json!("value2"));
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        manager.delete("key1");
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        manager.unwatch(&handle);

        manager.set("key1", json!("value3"));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_unwatch_reports_removal() {
        let mut manager = SharedStateManager::new();

        let handle = manager.watch("key1", |_| {});
        assert!(manager.unwatch(&handle));
        assert!(!manager.unwatch(&handle));
        assert_eq!(manager.get_stats().watchers_count, 0);
    }

    #[test]
    fn test_multiple_watchers() {
        let mut manager = SharedStateManager::new();
        let counter1 = Arc::new(AtomicUsize::new(0));
        let counter2 = Arc::new(AtomicUsize::new(0));

        let c1 = counter1.clone();
        let c2 = counter2.clone();

        let _handle1 = manager.watch("key1", move |_| {
            c1.fetch_add(1, Ordering::SeqCst);
        });

        let _handle2 = manager.watch("key1", move |_| {
            c2.fetch_add(1, Ordering::SeqCst);
        });

        manager.set("key1", json!("value"));

        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lock_unlock() {
        let mut manager = SharedStateManager::new();

        let lock = manager.lock("resource1", "agent1", None).unwrap();
        assert_eq!(lock.key, "resource1");
        assert_eq!(lock.holder, "agent1");
        assert!(manager.is_locked("resource1"));

        let result = manager.lock("resource1", "agent2", None);
        assert!(result.is_err());

        manager.unlock(&lock).unwrap();
        assert!(!manager.is_locked("resource1"));

        let lock2 = manager.lock("resource1", "agent2", None).unwrap();
        assert_eq!(lock2.holder, "agent2");
    }

    #[test]
    fn test_try_lock() {
        let mut manager = SharedStateManager::new();

        let lock = manager.try_lock("resource1", "agent1", None);
        assert!(lock.is_some());

        let lock2 = manager.try_lock("resource1", "agent2", None);
        assert!(lock2.is_none());
    }

    #[test]
    fn test_try_lock_reclaims_expired_lock() {
        let mut manager = SharedStateManager::new();
        manager.insert_lock_for_test(expired_lock("resource1", "agent1"));

        let lock = manager
            .try_lock("resource1", "agent2", None)
            .expect("an expired lock should be reclaimable");
        assert_eq!(lock.holder, "agent2");
        assert_eq!(
            manager.get_lock("resource1").map(|l| l.id.clone()),
            Some(lock.id)
        );
    }

    #[test]
    fn test_lock_expiration() {
        let mut manager = SharedStateManager::new();
        manager.insert_lock_for_test(expired_lock("resource1", "agent1"));

        assert!(!manager.is_locked("resource1"));

        let cleaned = manager.cleanup_expired_locks();
        assert_eq!(cleaned, 1);
        assert!(manager.locks.is_empty());
    }

    #[test]
    fn test_prepare_lock_acquires_free_lock_immediately() {
        let mut manager = SharedStateManager::new();

        match manager.prepare_lock("resource1", "agent1", None) {
            Err(lock) => assert_eq!(lock.holder, "agent1"),
            Ok(_) => panic!("a free lock should be acquired immediately"),
        }
        assert!(manager.is_locked("resource1"));
        assert_eq!(manager.get_stats().wait_queue_size, 0);
    }

    #[test]
    fn test_prepare_lock_waiter_receives_lock_on_unlock() {
        let mut manager = SharedStateManager::new();
        let lock = manager.lock("resource1", "agent1", None).unwrap();

        let (key, mut rx) = match manager.prepare_lock("resource1", "agent2", None) {
            Ok(pending) => pending,
            Err(_) => panic!("a held lock should enqueue the caller"),
        };
        assert_eq!(key, "resource1");
        assert_eq!(manager.get_stats().wait_queue_size, 1);

        manager.unlock(&lock).unwrap();

        let granted = rx
            .try_recv()
            .expect("the waiter should have been granted the lock");
        assert_eq!(granted.holder, "agent2");
        assert_eq!(
            manager.get_lock("resource1").map(|l| l.holder.clone()),
            Some("agent2".to_string())
        );
        assert_eq!(manager.get_stats().wait_queue_size, 0);
    }

    #[test]
    fn test_compare_and_swap() {
        let mut manager = SharedStateManager::new();

        manager.set("counter", json!(10));

        let success = manager.compare_and_swap("counter", &json!(10), json!(20));
        assert!(success);
        assert_eq!(manager.get("counter"), Some(json!(20)));

        let success = manager.compare_and_swap("counter", &json!(10), json!(30));
        assert!(!success);
        assert_eq!(manager.get("counter"), Some(json!(20)));
    }

    #[test]
    fn test_compare_and_swap_typed() {
        let mut manager = SharedStateManager::new();

        manager.set_typed("counter", &1u32).unwrap();

        assert!(manager.compare_and_swap_typed("counter", &1u32, &2u32).unwrap());
        assert!(!manager.compare_and_swap_typed("counter", &1u32, &3u32).unwrap());
        assert_eq!(manager.get_typed::<u32>("counter"), Some(2));
    }

    #[test]
    fn test_increment() {
        let mut manager = SharedStateManager::new();

        let value = manager.increment("counter", 5);
        assert_eq!(value, 5);

        let value = manager.increment("counter", 3);
        assert_eq!(value, 8);

        let value = manager.decrement("counter", 2);
        assert_eq!(value, 6);
    }

    #[test]
    fn test_get_all_locks() {
        let mut manager = SharedStateManager::new();

        manager.lock("resource1", "agent1", None).unwrap();
        manager.lock("resource2", "agent2", None).unwrap();

        let locks = manager.get_all_locks();
        assert_eq!(locks.len(), 2);
    }

    #[test]
    fn test_get_stats() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));
        manager.watch("key1", |_| {});
        manager.watch("key1", |_| {});
        manager.watch("key2", |_| {});
        manager.lock("resource1", "agent1", None).unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.state_size, 2);
        assert_eq!(stats.watchers_count, 2);
        assert_eq!(stats.total_watchers, 3);
        assert_eq!(stats.locks_count, 1);
    }

    #[test]
    fn test_event_listener() {
        let mut manager = SharedStateManager::new();
        let events = Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_clone = events.clone();

        manager.on_event(move |event| {
            events_clone.lock().unwrap().push(format!("{:?}", event));
        });

        manager.set("key1", json!("value1"));
        manager.delete("key1");

        let events = events.lock().unwrap();
        assert_eq!(events.len(), 2);
        assert!(events[0].contains("Changed"));
        assert!(events[1].contains("Deleted"));
    }

    #[test]
    fn test_unlock_invalid_lock() {
        let mut manager = SharedStateManager::new();

        let lock = manager.lock("resource1", "agent1", None).unwrap();

        let fake_lock = Lock {
            id: "fake-id".to_string(),
            holder: "agent1".to_string(),
            key: "resource1".to_string(),
            acquired_at: Utc::now(),
            expires_at: None,
        };

        let result = manager.unlock(&fake_lock);
        assert!(matches!(result, Err(SharedStateError::InvalidLock(_))));

        assert!(manager.is_locked("resource1"));

        manager.unlock(&lock).unwrap();
        assert!(!manager.is_locked("resource1"));
    }

    #[test]
    fn test_unlock_not_held() {
        let mut manager = SharedStateManager::new();

        let fake_lock = Lock {
            id: "fake-id".to_string(),
            holder: "agent1".to_string(),
            key: "resource1".to_string(),
            acquired_at: Utc::now(),
            expires_at: None,
        };

        let result = manager.unlock(&fake_lock);
        assert!(matches!(result, Err(SharedStateError::LockNotHeld(_))));
    }
}
953}