Skip to main content

adk_core/
shared_state.rs

//! Thread-safe shared state for parallel agent coordination.
//!
//! [`SharedState`] is a concurrent key-value store scoped to a single
//! `ParallelAgent::run()` invocation. Sub-agents use [`SharedState::set_shared`],
//! [`SharedState::get_shared`], and [`SharedState::wait_for_key`] to exchange data and coordinate.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use serde::ser::{Serialize, SerializeMap, Serializer};
12use serde_json::Value;
13use tokio::sync::{Notify, RwLock};
14
15use crate::AdkError;
16
/// Longest key, in bytes, that [`SharedState`] accepts.
const MAX_KEY_LEN: usize = 256;
/// Smallest timeout (inclusive) accepted by [`SharedState::wait_for_key`].
const MIN_TIMEOUT: Duration = Duration::from_millis(1);
/// Largest timeout (inclusive) accepted by [`SharedState::wait_for_key`]: five minutes.
const MAX_TIMEOUT: Duration = Duration::from_secs(5 * 60);
23
24/// Errors from [`SharedState`] operations.
25#[derive(Debug, thiserror::Error)]
26pub enum SharedStateError {
27    /// Key must not be empty.
28    #[error("shared state key must not be empty")]
29    EmptyKey,
30
31    /// Key exceeds the maximum length.
32    #[error("shared state key exceeds 256 bytes: {len} bytes")]
33    KeyTooLong { len: usize },
34
35    /// `wait_for_key` timed out.
36    #[error("wait_for_key timed out after {timeout:?} for key \"{key}\"")]
37    Timeout { key: String, timeout: Duration },
38
39    /// Timeout value is outside the valid range.
40    #[error("invalid timeout {timeout:?}: must be between 1ms and 300s")]
41    InvalidTimeout { timeout: Duration },
42}
43
44impl From<SharedStateError> for AdkError {
45    fn from(err: SharedStateError) -> Self {
46        AdkError::agent(err.to_string())
47    }
48}
49
50/// Thread-safe key-value store for parallel agent coordination.
51///
52/// Scoped to a single `ParallelAgent::run()` invocation. All sub-agents
53/// share the same `Arc<SharedState>` instance.
54///
55/// # Example
56///
57/// ```rust,ignore
58/// use adk_core::SharedState;
59/// use std::sync::Arc;
60/// use std::time::Duration;
61///
62/// let state = Arc::new(SharedState::new());
63///
64/// // Agent A publishes a workbook handle
65/// state.set_shared("workbook_id", serde_json::json!("wb-123")).await?;
66///
67/// // Agent B waits for the handle
68/// let handle = state.wait_for_key("workbook_id", Duration::from_secs(30)).await?;
69/// ```
70#[derive(Debug)]
71pub struct SharedState {
72    data: RwLock<HashMap<String, Value>>,
73    notifiers: RwLock<HashMap<String, Arc<Notify>>>,
74}
75
76impl SharedState {
77    /// Creates a new empty `SharedState`.
78    #[must_use]
79    pub fn new() -> Self {
80        Self { data: RwLock::new(HashMap::new()), notifiers: RwLock::new(HashMap::new()) }
81    }
82
83    /// Inserts a key-value pair. Notifies all waiters on that key.
84    ///
85    /// # Errors
86    ///
87    /// Returns [`SharedStateError::EmptyKey`] if key is empty.
88    /// Returns [`SharedStateError::KeyTooLong`] if key exceeds 256 bytes.
89    pub async fn set_shared(
90        &self,
91        key: impl Into<String>,
92        value: Value,
93    ) -> Result<(), SharedStateError> {
94        let key = key.into();
95        validate_key(&key)?;
96
97        self.data.write().await.insert(key.clone(), value);
98
99        // Notify all waiters for this key
100        let notifiers = self.notifiers.read().await;
101        if let Some(notify) = notifiers.get(&key) {
102            notify.notify_waiters();
103        }
104
105        Ok(())
106    }
107
108    /// Returns the value for a key, or `None` if not present.
109    pub async fn get_shared(&self, key: &str) -> Option<Value> {
110        self.data.read().await.get(key).cloned()
111    }
112
113    /// Blocks until the key appears, or the timeout expires.
114    ///
115    /// If the key already exists, returns immediately.
116    ///
117    /// # Errors
118    ///
119    /// Returns [`SharedStateError::Timeout`] if the timeout expires.
120    /// Returns [`SharedStateError::InvalidTimeout`] if timeout is outside [1ms, 300s].
121    pub async fn wait_for_key(
122        &self,
123        key: &str,
124        timeout: Duration,
125    ) -> Result<Value, SharedStateError> {
126        // Validate timeout range
127        if timeout < MIN_TIMEOUT || timeout > MAX_TIMEOUT {
128            return Err(SharedStateError::InvalidTimeout { timeout });
129        }
130
131        // Check if key already exists
132        if let Some(value) = self.data.read().await.get(key).cloned() {
133            return Ok(value);
134        }
135
136        // Get or create a Notify for this key
137        let notify = {
138            let mut notifiers = self.notifiers.write().await;
139            notifiers.entry(key.to_string()).or_insert_with(|| Arc::new(Notify::new())).clone()
140        };
141
142        // Wait with timeout, re-checking after each notification
143        let deadline = tokio::time::Instant::now() + timeout;
144        loop {
145            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
146            if remaining.is_zero() {
147                return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
148            }
149
150            match tokio::time::timeout(remaining, notify.notified()).await {
151                Ok(()) => {
152                    // Check if our key was set
153                    if let Some(value) = self.data.read().await.get(key).cloned() {
154                        return Ok(value);
155                    }
156                    // Spurious wake or different key — loop and wait again
157                }
158                Err(_) => {
159                    return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
160                }
161            }
162        }
163    }
164
165    /// Returns a snapshot of all current entries.
166    pub async fn snapshot(&self) -> HashMap<String, Value> {
167        self.data.read().await.clone()
168    }
169}
170
171impl Default for SharedState {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl Serialize for SharedState {
178    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
179        // Use try_read to avoid blocking in a sync context.
180        // If the lock is held, serialize as empty map.
181        match self.data.try_read() {
182            Ok(data) => {
183                let mut map = serializer.serialize_map(Some(data.len()))?;
184                for (k, v) in data.iter() {
185                    map.serialize_entry(k, v)?;
186                }
187                map.end()
188            }
189            Err(_) => serializer.serialize_map(Some(0))?.end(),
190        }
191    }
192}
193
194/// Validates a shared state key.
195fn validate_key(key: &str) -> Result<(), SharedStateError> {
196    if key.is_empty() {
197        return Err(SharedStateError::EmptyKey);
198    }
199    if key.len() > MAX_KEY_LEN {
200        return Err(SharedStateError::KeyTooLong { len: key.len() });
201    }
202    Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn new_shared_state_is_empty() {
        let state = SharedState::new();
        assert!(state.snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn set_and_get() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("value")).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!("value")));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let state = SharedState::new();
        assert_eq!(state.get_shared("missing").await, None);
    }

    #[tokio::test]
    async fn overwrite_replaces_value() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!(1)).await.unwrap();
        state.set_shared("key", serde_json::json!(2)).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!(2)));
    }

    #[tokio::test]
    async fn empty_key_rejected() {
        let state = SharedState::new();
        let err = state.set_shared("", serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::EmptyKey));
    }

    #[tokio::test]
    async fn long_key_rejected() {
        let state = SharedState::new();
        let long_key = "x".repeat(257);
        let err = state.set_shared(long_key, serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::KeyTooLong { .. }));
    }

    #[tokio::test]
    async fn key_at_256_bytes_accepted() {
        let state = SharedState::new();
        let key = "x".repeat(256);
        state.set_shared(key.clone(), serde_json::json!(1)).await.unwrap();
        assert_eq!(state.get_shared(&key).await, Some(serde_json::json!(1)));
    }

    #[tokio::test]
    async fn wait_for_existing_key_returns_immediately() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("val")).await.unwrap();
        let val = state.wait_for_key("key", Duration::from_secs(1)).await.unwrap();
        assert_eq!(val, serde_json::json!("val"));
    }

    #[tokio::test]
    async fn wait_for_key_timeout() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", Duration::from_millis(10)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_small() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_nanos(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_large() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_secs(301)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_wakes_when_key_is_set_later() {
        let state = Arc::new(SharedState::new());
        let waiter = {
            let state = Arc::clone(&state);
            tokio::spawn(async move { state.wait_for_key("late", Duration::from_secs(5)).await })
        };

        // Let the waiter start waiting before the key is published.
        tokio::time::sleep(Duration::from_millis(20)).await;
        state.set_shared("late", serde_json::json!(42)).await.unwrap();

        let value = waiter.await.unwrap().unwrap();
        assert_eq!(value, serde_json::json!(42));
    }

    #[tokio::test]
    async fn all_waiters_on_a_key_are_woken() {
        let state = Arc::new(SharedState::new());
        let waiters: Vec<_> = (0..3)
            .map(|_| {
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    state.wait_for_key("shared", Duration::from_secs(5)).await
                })
            })
            .collect();

        tokio::time::sleep(Duration::from_millis(20)).await;
        state.set_shared("shared", serde_json::json!("go")).await.unwrap();

        for waiter in waiters {
            assert_eq!(waiter.await.unwrap().unwrap(), serde_json::json!("go"));
        }
    }

    #[tokio::test]
    async fn serializes_as_json_object() {
        let state = SharedState::new();
        state.set_shared("a", serde_json::json!(1)).await.unwrap();
        state.set_shared("b", serde_json::json!("two")).await.unwrap();
        let json = serde_json::to_value(&state).unwrap();
        assert_eq!(json, serde_json::json!({ "a": 1, "b": "two" }));
    }

    #[test]
    fn error_messages_state_limits() {
        assert_eq!(
            SharedStateError::KeyTooLong { len: 300 }.to_string(),
            "shared state key exceeds 256 bytes: 300 bytes"
        );
        assert_eq!(
            SharedStateError::InvalidTimeout { timeout: Duration::ZERO }.to_string(),
            "invalid timeout 0ns: must be between 1ms and 300s"
        );
    }

    #[tokio::test]
    async fn error_converts_to_adk_error() {
        let err = SharedStateError::EmptyKey;
        let adk_err: AdkError = err.into();
        assert!(adk_err.to_string().contains("empty"));
    }
}