Skip to main content

aster/agents/communication/
shared_state.rs

1//! Shared State Manager
2//!
3//! Provides shared state storage with distributed locking,
4//! watch callbacks, and atomic operations.
5//!
6//! # Features
7//! - Key-value state storage with JSON values
8//! - Watch callbacks for state changes
9//! - Distributed locking with timeouts
10//! - Atomic compare-and-swap operations
11//! - Atomic increment operations
12//! - Automatic cleanup of expired locks
13
14use chrono::{DateTime, Duration, Utc};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19use thiserror::Error;
20use tokio::sync::oneshot;
21
/// Result type alias for shared state operations.
///
/// Shorthand for `Result<T, SharedStateError>`, used by the fallible methods
/// in this module.
pub type SharedStateResult<T> = Result<T, SharedStateError>;
24
/// Error types for shared state operations.
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the
/// `String` payload of each variant is interpolated into the message.
#[derive(Debug, Error)]
pub enum SharedStateError {
    /// Key not found
    #[error("Key not found: {0}")]
    KeyNotFound(String),

    /// Lock timeout (payload is the key that could not be locked)
    #[error("Lock timeout for key: {0}")]
    LockTimeout(String),

    /// Lock not held (payload is the key that has no lock)
    #[error("Lock not held: {0}")]
    LockNotHeld(String),

    /// Invalid lock (payload is a human-readable explanation)
    #[error("Invalid lock: {0}")]
    InvalidLock(String),

    /// Compare and swap failed
    #[error("Compare and swap failed: expected value does not match")]
    CompareAndSwapFailed,

    /// Serialization error (payload is the underlying error's message)
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// Channel error
    #[error("Channel error: {0}")]
    ChannelError(String),
}
56
/// Lock structure representing a distributed lock.
///
/// Serialized with camelCase field names (`acquiredAt`, `expiresAt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Lock {
    /// Unique lock identifier
    pub id: String,
    /// Lock holder identifier
    pub holder: String,
    /// Key being locked
    pub key: String,
    /// When the lock was acquired
    pub acquired_at: DateTime<Utc>,
    /// When the lock expires; `None` means the lock never expires
    pub expires_at: Option<DateTime<Utc>>,
}
72
73impl Lock {
74    /// Create a new lock
75    pub fn new(
76        key: impl Into<String>,
77        holder: impl Into<String>,
78        timeout: Option<Duration>,
79    ) -> Self {
80        let now = Utc::now();
81        Self {
82            id: uuid::Uuid::new_v4().to_string(),
83            holder: holder.into(),
84            key: key.into(),
85            acquired_at: now,
86            expires_at: timeout.map(|t| now + t),
87        }
88    }
89
90    /// Check if the lock has expired
91    pub fn is_expired(&self) -> bool {
92        self.expires_at.map(|exp| Utc::now() > exp).unwrap_or(false)
93    }
94}
95
/// State change event delivered to listeners registered via `on_event`.
#[derive(Debug, Clone)]
pub enum StateEvent {
    /// Value changed
    Changed {
        // Key that was written.
        key: String,
        // Value now stored under `key`.
        value: Value,
        // Value stored before the write, if any.
        old_value: Option<Value>,
    },
    /// Value deleted
    Deleted {
        // Key that was removed.
        key: String,
        // Value that was stored before removal.
        old_value: Option<Value>,
    },
    /// All state cleared
    Cleared,
    /// Lock acquired
    LockAcquired(Lock),
    /// Lock released
    LockReleased(Lock),
}
117
/// Watch callback type.
///
/// Invoked with `Some(new_value)` when the watched key is set and with `None`
/// when it is deleted.
pub type WatchCallback = Arc<dyn Fn(Option<Value>) + Send + Sync>;
120
/// Watch handle for unsubscribing.
///
/// Returned by `SharedStateManager::watch` and passed back to `unwatch`.
#[derive(Debug, Clone)]
pub struct WatchHandle {
    /// Key being watched
    pub key: String,
    /// Unique handle ID
    pub id: String,
}
129
130impl WatchHandle {
131    /// Create a new watch handle
132    pub fn new(key: impl Into<String>) -> Self {
133        Self {
134            key: key.into(),
135            id: uuid::Uuid::new_v4().to_string(),
136        }
137    }
138}
139
/// Watcher entry: one registered callback for a key.
struct WatcherEntry {
    // Matches the `id` of the `WatchHandle` returned when the watcher was added.
    id: String,
    // Callback invoked when the watched key changes.
    callback: WatchCallback,
}
145
/// Pending lock waiter: a request queued because the key was already locked.
struct LockWaiter {
    // Holder name to put on the lock once it is granted.
    holder: String,
    // Expiry timeout to apply to the lock once it is granted.
    timeout: Option<Duration>,
    // Channel end used to deliver the granted lock to the waiting caller.
    sender: oneshot::Sender<Lock>,
}
152
/// Statistics about the shared state.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SharedStateStats {
    /// Number of keys in state
    pub state_size: usize,
    /// Number of keys being watched
    pub watchers_count: usize,
    /// Total number of watcher callbacks
    pub total_watchers: usize,
    /// Number of active locks
    pub locks_count: usize,
    /// Number of waiters in lock queues
    pub wait_queue_size: usize,
}
168
/// Shared State Manager for inter-agent state sharing.
///
/// All mutating methods take `&mut self`; the type contains no internal
/// synchronization of its own.
#[derive(Default)]
pub struct SharedStateManager {
    /// Key-value state storage
    state: HashMap<String, Value>,
    /// Watchers per key
    watchers: HashMap<String, Vec<WatcherEntry>>,
    /// Active locks, keyed by the locked key (may contain expired entries
    /// until they are removed by a lock attempt or a cleanup pass)
    locks: HashMap<String, Lock>,
    /// Lock wait queues, keyed by the locked key
    lock_wait_queue: HashMap<String, Vec<LockWaiter>>,
    /// Event listeners
    event_listeners: Vec<Arc<dyn Fn(StateEvent) + Send + Sync>>,
}
183
184impl SharedStateManager {
185    /// Create a new shared state manager
186    pub fn new() -> Self {
187        Self {
188            state: HashMap::new(),
189            watchers: HashMap::new(),
190            locks: HashMap::new(),
191            lock_wait_queue: HashMap::new(),
192            event_listeners: Vec::new(),
193        }
194    }
195
196    /// Get a value by key
197    pub fn get(&self, key: &str) -> Option<Value> {
198        self.state.get(key).cloned()
199    }
200
201    /// Get a typed value by key
202    pub fn get_typed<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
203        self.state
204            .get(key)
205            .and_then(|v| serde_json::from_value(v.clone()).ok())
206    }
207
208    /// Set a value by key
209    pub fn set(&mut self, key: impl Into<String>, value: Value) {
210        let key = key.into();
211        let old_value = self.state.get(&key).cloned();
212        self.state.insert(key.clone(), value.clone());
213
214        // Notify watchers
215        self.notify_watchers(&key, Some(value.clone()));
216
217        // Emit event
218        self.emit_event(StateEvent::Changed {
219            key,
220            value,
221            old_value,
222        });
223    }
224
225    /// Set a typed value by key
226    pub fn set_typed<T: Serialize>(
227        &mut self,
228        key: impl Into<String>,
229        value: &T,
230    ) -> SharedStateResult<()> {
231        let json_value = serde_json::to_value(value)
232            .map_err(|e| SharedStateError::SerializationError(e.to_string()))?;
233        self.set(key, json_value);
234        Ok(())
235    }
236
237    /// Delete a value by key
238    pub fn delete(&mut self, key: &str) -> bool {
239        if let Some(old_value) = self.state.remove(key) {
240            // Notify watchers with None
241            self.notify_watchers(key, None);
242
243            // Emit event
244            self.emit_event(StateEvent::Deleted {
245                key: key.to_string(),
246                old_value: Some(old_value),
247            });
248
249            true
250        } else {
251            false
252        }
253    }
254
255    /// Check if a key exists
256    pub fn has(&self, key: &str) -> bool {
257        self.state.contains_key(key)
258    }
259
260    /// Get all keys
261    pub fn keys(&self) -> Vec<String> {
262        self.state.keys().cloned().collect()
263    }
264
265    /// Clear all state
266    pub fn clear(&mut self) {
267        self.state.clear();
268        self.emit_event(StateEvent::Cleared);
269    }
270
271    /// Watch for changes to a key
272    pub fn watch<F>(&mut self, key: impl Into<String>, callback: F) -> WatchHandle
273    where
274        F: Fn(Option<Value>) + Send + Sync + 'static,
275    {
276        let key = key.into();
277        let handle = WatchHandle::new(&key);
278
279        let entry = WatcherEntry {
280            id: handle.id.clone(),
281            callback: Arc::new(callback),
282        };
283
284        self.watchers.entry(key).or_default().push(entry);
285
286        handle
287    }
288
289    /// Unwatch a key
290    pub fn unwatch(&mut self, handle: &WatchHandle) -> bool {
291        if let Some(watchers) = self.watchers.get_mut(&handle.key) {
292            let before = watchers.len();
293            watchers.retain(|w| w.id != handle.id);
294            let removed = before != watchers.len();
295
296            // Clean up empty watcher lists
297            if watchers.is_empty() {
298                self.watchers.remove(&handle.key);
299            }
300
301            removed
302        } else {
303            false
304        }
305    }
306
307    /// Notify watchers of a value change
308    fn notify_watchers(&self, key: &str, value: Option<Value>) {
309        if let Some(watchers) = self.watchers.get(key) {
310            for watcher in watchers {
311                // Call the callback, catching any panics
312                let callback = watcher.callback.clone();
313                let value = value.clone();
314                std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
315                    callback(value);
316                }))
317                .ok();
318            }
319        }
320    }
321
322    /// Add an event listener
323    pub fn on_event<F>(&mut self, listener: F)
324    where
325        F: Fn(StateEvent) + Send + Sync + 'static,
326    {
327        self.event_listeners.push(Arc::new(listener));
328    }
329
330    /// Emit an event to all listeners
331    fn emit_event(&self, event: StateEvent) {
332        for listener in &self.event_listeners {
333            let listener = listener.clone();
334            let event = event.clone();
335            std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
336                listener(event);
337            }))
338            .ok();
339        }
340    }
341
342    /// Acquire a lock on a key
343    ///
344    /// If the key is already locked, this will wait until the lock is released
345    /// or the timeout expires.
346    pub fn lock(
347        &mut self,
348        key: impl Into<String>,
349        holder: impl Into<String>,
350        timeout: Option<Duration>,
351    ) -> SharedStateResult<Lock> {
352        let key = key.into();
353        let holder = holder.into();
354
355        // Check if already locked
356        if let Some(existing) = self.locks.get(&key) {
357            // Check if the existing lock has expired
358            if existing.is_expired() {
359                // Remove expired lock and proceed
360                let expired_lock = self.locks.remove(&key).unwrap();
361                self.emit_event(StateEvent::LockReleased(expired_lock));
362            } else {
363                // Lock is held, return error (in async version, would wait)
364                return Err(SharedStateError::LockTimeout(key));
365            }
366        }
367
368        // Create new lock
369        let lock = Lock::new(&key, &holder, timeout);
370        self.locks.insert(key, lock.clone());
371        self.emit_event(StateEvent::LockAcquired(lock.clone()));
372
373        Ok(lock)
374    }
375
376    /// Try to acquire a lock without waiting
377    pub fn try_lock(
378        &mut self,
379        key: impl Into<String>,
380        holder: impl Into<String>,
381        timeout: Option<Duration>,
382    ) -> Option<Lock> {
383        let key = key.into();
384        let holder = holder.into();
385
386        // Check if already locked
387        if let Some(existing) = self.locks.get(&key) {
388            if !existing.is_expired() {
389                return None;
390            }
391            // Remove expired lock
392            let expired_lock = self.locks.remove(&key).unwrap();
393            self.emit_event(StateEvent::LockReleased(expired_lock));
394        }
395
396        // Create new lock
397        let lock = Lock::new(&key, &holder, timeout);
398        self.locks.insert(key, lock.clone());
399        self.emit_event(StateEvent::LockAcquired(lock.clone()));
400
401        Some(lock)
402    }
403
404    /// Prepare an async lock request
405    ///
406    /// Returns a receiver that will receive the lock when it becomes available.
407    /// The caller should await on the receiver with a timeout.
408    pub fn prepare_lock(
409        &mut self,
410        key: impl Into<String>,
411        holder: impl Into<String>,
412        timeout: Option<Duration>,
413    ) -> Result<(String, oneshot::Receiver<Lock>), Lock> {
414        let key = key.into();
415        let holder = holder.into();
416
417        // Check if already locked
418        if let Some(existing) = self.locks.get(&key) {
419            if !existing.is_expired() {
420                // Create a waiter
421                let (tx, rx) = oneshot::channel();
422                let waiter = LockWaiter {
423                    holder,
424                    timeout,
425                    sender: tx,
426                };
427
428                self.lock_wait_queue
429                    .entry(key.clone())
430                    .or_default()
431                    .push(waiter);
432
433                return Ok((key, rx));
434            }
435            // Remove expired lock
436            let expired_lock = self.locks.remove(&key).unwrap();
437            self.emit_event(StateEvent::LockReleased(expired_lock));
438        }
439
440        // Create new lock immediately
441        let lock = Lock::new(&key, &holder, timeout);
442        self.locks.insert(key, lock.clone());
443        self.emit_event(StateEvent::LockAcquired(lock.clone()));
444
445        Err(lock)
446    }
447
448    /// Release a lock
449    pub fn unlock(&mut self, lock: &Lock) -> SharedStateResult<()> {
450        // Verify the lock is valid
451        let current = self.locks.get(&lock.key);
452        match current {
453            None => return Err(SharedStateError::LockNotHeld(lock.key.clone())),
454            Some(current) if current.id != lock.id => {
455                return Err(SharedStateError::InvalidLock(format!(
456                    "Lock ID mismatch: expected {}, got {}",
457                    current.id, lock.id
458                )));
459            }
460            _ => {}
461        }
462
463        // Remove the lock
464        let released_lock = self.locks.remove(&lock.key).unwrap();
465        self.emit_event(StateEvent::LockReleased(released_lock));
466
467        // Process wait queue - take ownership to avoid borrow issues
468        let waiter = self
469            .lock_wait_queue
470            .get_mut(&lock.key)
471            .and_then(|waiters| waiters.pop());
472
473        if let Some(waiter) = waiter {
474            // Grant lock to next waiter
475            let new_lock = Lock::new(&lock.key, &waiter.holder, waiter.timeout);
476            self.locks.insert(lock.key.clone(), new_lock.clone());
477            self.emit_event(StateEvent::LockAcquired(new_lock.clone()));
478
479            // Send lock to waiter (ignore if receiver dropped)
480            let _ = waiter.sender.send(new_lock);
481        }
482
483        // Clean up empty wait queue
484        if self
485            .lock_wait_queue
486            .get(&lock.key)
487            .map(|w| w.is_empty())
488            .unwrap_or(false)
489        {
490            self.lock_wait_queue.remove(&lock.key);
491        }
492
493        Ok(())
494    }
495
496    /// Check if a key is locked
497    pub fn is_locked(&self, key: &str) -> bool {
498        self.locks
499            .get(key)
500            .map(|l| !l.is_expired())
501            .unwrap_or(false)
502    }
503
504    /// Get all active locks
505    pub fn get_all_locks(&self) -> Vec<Lock> {
506        self.locks
507            .values()
508            .filter(|l| !l.is_expired())
509            .cloned()
510            .collect()
511    }
512
513    /// Get lock for a specific key
514    pub fn get_lock(&self, key: &str) -> Option<&Lock> {
515        self.locks.get(key).filter(|l| !l.is_expired())
516    }
517
518    /// Atomic compare-and-swap operation
519    ///
520    /// Sets the value only if the current value equals the expected value.
521    /// Returns true if the swap was successful.
522    pub fn compare_and_swap(&mut self, key: &str, expected: &Value, new_value: Value) -> bool {
523        let current = self.state.get(key);
524
525        if current == Some(expected) {
526            self.set(key.to_string(), new_value);
527            true
528        } else {
529            false
530        }
531    }
532
533    /// Atomic compare-and-swap with typed values
534    pub fn compare_and_swap_typed<T: Serialize + PartialEq + for<'de> Deserialize<'de>>(
535        &mut self,
536        key: &str,
537        expected: &T,
538        new_value: &T,
539    ) -> SharedStateResult<bool> {
540        let current: Option<T> = self.get_typed(key);
541
542        if current.as_ref() == Some(expected) {
543            self.set_typed(key, new_value)?;
544            Ok(true)
545        } else {
546            Ok(false)
547        }
548    }
549
550    /// Atomic increment operation
551    ///
552    /// Increments the value by delta. If the key doesn't exist, initializes to delta.
553    /// Returns the new value.
554    pub fn increment(&mut self, key: &str, delta: i64) -> i64 {
555        let current = self.state.get(key).and_then(|v| v.as_i64()).unwrap_or(0);
556
557        let new_value = current + delta;
558        self.set(key.to_string(), Value::Number(new_value.into()));
559        new_value
560    }
561
562    /// Atomic decrement operation
563    pub fn decrement(&mut self, key: &str, delta: i64) -> i64 {
564        self.increment(key, -delta)
565    }
566
567    /// Cleanup expired locks
568    ///
569    /// Returns the number of locks cleaned up.
570    pub fn cleanup_expired_locks(&mut self) -> usize {
571        let expired_keys: Vec<String> = self
572            .locks
573            .iter()
574            .filter(|(_, lock)| lock.is_expired())
575            .map(|(key, _)| key.clone())
576            .collect();
577
578        let count = expired_keys.len();
579
580        for key in expired_keys {
581            if let Some(lock) = self.locks.remove(&key) {
582                self.emit_event(StateEvent::LockReleased(lock));
583
584                // Process wait queue for this key - take ownership to avoid borrow issues
585                let waiter = self
586                    .lock_wait_queue
587                    .get_mut(&key)
588                    .and_then(|waiters| waiters.pop());
589
590                if let Some(waiter) = waiter {
591                    let new_lock = Lock::new(&key, &waiter.holder, waiter.timeout);
592                    self.locks.insert(key.clone(), new_lock.clone());
593                    self.emit_event(StateEvent::LockAcquired(new_lock.clone()));
594                    let _ = waiter.sender.send(new_lock);
595                }
596
597                // Clean up empty wait queue
598                if self
599                    .lock_wait_queue
600                    .get(&key)
601                    .map(|w| w.is_empty())
602                    .unwrap_or(false)
603                {
604                    self.lock_wait_queue.remove(&key);
605                }
606            }
607        }
608
609        count
610    }
611
612    /// Get statistics about the shared state
613    pub fn get_stats(&self) -> SharedStateStats {
614        let total_watchers: usize = self.watchers.values().map(|w| w.len()).sum();
615        let wait_queue_size: usize = self.lock_wait_queue.values().map(|w| w.len()).sum();
616
617        SharedStateStats {
618            state_size: self.state.len(),
619            watchers_count: self.watchers.len(),
620            total_watchers,
621            locks_count: self.locks.len(),
622            wait_queue_size,
623        }
624    }
625
626    /// Insert a lock directly (for testing purposes)
627    #[cfg(test)]
628    pub fn insert_lock_for_test(&mut self, lock: Lock) {
629        self.locks.insert(lock.key.clone(), lock);
630    }
631}
632
// Unit tests for `SharedStateManager`: key-value access, watchers, locking,
// compare-and-swap, counters, statistics and event listeners.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Round-trips a JSON value and checks that a missing key yields `None`.
    #[test]
    fn test_get_set() {
        let mut manager = SharedStateManager::new();

        // Set a value
        manager.set("key1", json!({"name": "test"}));

        // Get the value
        let value = manager.get("key1");
        assert!(value.is_some());
        assert_eq!(value.unwrap(), json!({"name": "test"}));

        // Get non-existent key
        assert!(manager.get("non_existent").is_none());
    }

    // Round-trips a serde-derived struct through `set_typed` / `get_typed`.
    #[test]
    fn test_get_set_typed() {
        let mut manager = SharedStateManager::new();

        #[derive(Debug, Serialize, Deserialize, PartialEq)]
        struct Config {
            max_retries: u32,
            timeout: u64,
        }

        let config = Config {
            max_retries: 3,
            timeout: 5000,
        };

        manager.set_typed("config", &config).unwrap();

        let retrieved: Option<Config> = manager.get_typed("config");
        assert_eq!(retrieved, Some(config));
    }

    // `delete` reports whether a key was actually removed.
    #[test]
    fn test_delete() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        assert!(manager.has("key1"));

        let deleted = manager.delete("key1");
        assert!(deleted);
        assert!(!manager.has("key1"));

        // Delete non-existent key
        let deleted = manager.delete("non_existent");
        assert!(!deleted);
    }

    // `keys` returns every stored key (order is not asserted).
    #[test]
    fn test_keys() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));
        manager.set("key3", json!("value3"));

        let keys = manager.keys();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));
        assert!(keys.contains(&"key3".to_string()));
    }

    // `clear` leaves no keys behind.
    #[test]
    fn test_clear() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));

        manager.clear();

        assert!(manager.keys().is_empty());
    }

    // A watcher fires on set and delete, and stops firing after `unwatch`.
    #[test]
    fn test_watch() {
        let mut manager = SharedStateManager::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        let handle = manager.watch("key1", move |_value| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });

        // Set value should trigger watcher
        manager.set("key1", json!("value1"));
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Set again
        manager.set("key1", json!("value2"));
        assert_eq!(counter.load(Ordering::SeqCst), 2);

        // Delete should trigger watcher
        manager.delete("key1");
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        // Unwatch
        manager.unwatch(&handle);

        // Set should not trigger watcher anymore
        manager.set("key1", json!("value3"));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Two watchers on the same key are both invoked by a single write.
    #[test]
    fn test_multiple_watchers() {
        let mut manager = SharedStateManager::new();
        let counter1 = Arc::new(AtomicUsize::new(0));
        let counter2 = Arc::new(AtomicUsize::new(0));

        let c1 = counter1.clone();
        let c2 = counter2.clone();

        let _handle1 = manager.watch("key1", move |_| {
            c1.fetch_add(1, Ordering::SeqCst);
        });

        let _handle2 = manager.watch("key1", move |_| {
            c2.fetch_add(1, Ordering::SeqCst);
        });

        manager.set("key1", json!("value"));

        assert_eq!(counter1.load(Ordering::SeqCst), 1);
        assert_eq!(counter2.load(Ordering::SeqCst), 1);
    }

    // A held lock blocks a second `lock` call until it is released.
    #[test]
    fn test_lock_unlock() {
        let mut manager = SharedStateManager::new();

        // Acquire lock
        let lock = manager.lock("resource1", "agent1", None).unwrap();
        assert_eq!(lock.key, "resource1");
        assert_eq!(lock.holder, "agent1");
        assert!(manager.is_locked("resource1"));

        // Try to acquire same lock should fail
        let result = manager.lock("resource1", "agent2", None);
        assert!(result.is_err());

        // Release lock
        manager.unlock(&lock).unwrap();
        assert!(!manager.is_locked("resource1"));

        // Now agent2 can acquire
        let lock2 = manager.lock("resource1", "agent2", None).unwrap();
        assert_eq!(lock2.holder, "agent2");
    }

    // `try_lock` returns `None` while the key is held.
    #[test]
    fn test_try_lock() {
        let mut manager = SharedStateManager::new();

        // Try lock should succeed
        let lock = manager.try_lock("resource1", "agent1", None);
        assert!(lock.is_some());

        // Try lock again should fail
        let lock2 = manager.try_lock("resource1", "agent2", None);
        assert!(lock2.is_none());
    }

    // An already-expired lock is reported as not held and is removed by cleanup.
    #[test]
    fn test_lock_expiration() {
        let mut manager = SharedStateManager::new();

        // Create a lock that's already expired
        let lock = Lock {
            id: uuid::Uuid::new_v4().to_string(),
            holder: "agent1".to_string(),
            key: "resource1".to_string(),
            acquired_at: Utc::now() - Duration::seconds(10),
            expires_at: Some(Utc::now() - Duration::seconds(5)),
        };
        manager.locks.insert("resource1".to_string(), lock);

        // Lock should be considered expired
        assert!(!manager.is_locked("resource1"));

        // Cleanup should remove it
        let cleaned = manager.cleanup_expired_locks();
        assert_eq!(cleaned, 1);
        assert!(manager.locks.is_empty());
    }

    // CAS succeeds only when the expected value matches the stored one.
    #[test]
    fn test_compare_and_swap() {
        let mut manager = SharedStateManager::new();

        manager.set("counter", json!(10));

        // CAS with correct expected value
        let success = manager.compare_and_swap("counter", &json!(10), json!(20));
        assert!(success);
        assert_eq!(manager.get("counter"), Some(json!(20)));

        // CAS with incorrect expected value
        let success = manager.compare_and_swap("counter", &json!(10), json!(30));
        assert!(!success);
        assert_eq!(manager.get("counter"), Some(json!(20)));
    }

    // Counters start from 0 when the key is missing and support decrement.
    #[test]
    fn test_increment() {
        let mut manager = SharedStateManager::new();

        // Increment non-existent key
        let value = manager.increment("counter", 5);
        assert_eq!(value, 5);

        // Increment existing key
        let value = manager.increment("counter", 3);
        assert_eq!(value, 8);

        // Decrement
        let value = manager.decrement("counter", 2);
        assert_eq!(value, 6);
    }

    // Locks on distinct keys are all listed.
    #[test]
    fn test_get_all_locks() {
        let mut manager = SharedStateManager::new();

        manager.lock("resource1", "agent1", None).unwrap();
        manager.lock("resource2", "agent2", None).unwrap();

        let locks = manager.get_all_locks();
        assert_eq!(locks.len(), 2);
    }

    // Statistics reflect stored keys, watched keys, watcher callbacks and locks.
    #[test]
    fn test_get_stats() {
        let mut manager = SharedStateManager::new();

        manager.set("key1", json!("value1"));
        manager.set("key2", json!("value2"));
        manager.watch("key1", |_| {});
        manager.watch("key1", |_| {});
        manager.watch("key2", |_| {});
        manager.lock("resource1", "agent1", None).unwrap();

        let stats = manager.get_stats();
        assert_eq!(stats.state_size, 2);
        assert_eq!(stats.watchers_count, 2);
        assert_eq!(stats.total_watchers, 3);
        assert_eq!(stats.locks_count, 1);
    }

    // Listeners receive a `Changed` event for set and a `Deleted` event for delete.
    #[test]
    fn test_event_listener() {
        let mut manager = SharedStateManager::new();
        let events = Arc::new(std::sync::Mutex::new(Vec::new()));
        let events_clone = events.clone();

        manager.on_event(move |event| {
            events_clone.lock().unwrap().push(format!("{:?}", event));
        });

        manager.set("key1", json!("value1"));
        manager.delete("key1");

        let events = events.lock().unwrap();
        assert_eq!(events.len(), 2);
        assert!(events[0].contains("Changed"));
        assert!(events[1].contains("Deleted"));
    }

    // Unlocking with a lock whose id does not match is rejected and leaves the
    // real lock in place.
    #[test]
    fn test_unlock_invalid_lock() {
        let mut manager = SharedStateManager::new();

        let lock = manager.lock("resource1", "agent1", None).unwrap();

        // Create a fake lock with different ID
        let fake_lock = Lock {
            id: "fake-id".to_string(),
            holder: "agent1".to_string(),
            key: "resource1".to_string(),
            acquired_at: Utc::now(),
            expires_at: None,
        };

        let result = manager.unlock(&fake_lock);
        assert!(matches!(result, Err(SharedStateError::InvalidLock(_))));

        // Original lock should still be valid
        assert!(manager.is_locked("resource1"));

        // Unlock with correct lock should work
        manager.unlock(&lock).unwrap();
    }

    // Unlocking a key that has no lock at all yields `LockNotHeld`.
    #[test]
    fn test_unlock_not_held() {
        let mut manager = SharedStateManager::new();

        let fake_lock = Lock {
            id: "fake-id".to_string(),
            holder: "agent1".to_string(),
            key: "resource1".to_string(),
            acquired_at: Utc::now(),
            expires_at: None,
        };

        let result = manager.unlock(&fake_lock);
        assert!(matches!(result, Err(SharedStateError::LockNotHeld(_))));
    }
}
953}