Skip to main content

adk_core/
shared_state.rs

1//! Thread-safe shared state for parallel agent coordination.
2//!
3//! [`SharedState`] is a concurrent key-value store scoped to a single
4//! `ParallelAgent::run()` invocation. Sub-agents use [`SharedState::set_shared`],
5//! [`SharedState::get_shared`], and [`SharedState::wait_for_key`] to exchange data and coordinate.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use serde::ser::{Serialize, SerializeMap, Serializer};
12use serde_json::Value;
13use tokio::sync::{Notify, RwLock};
14
15use crate::AdkError;
16
/// Smallest timeout accepted by [`SharedState::wait_for_key`] (one millisecond).
const MIN_TIMEOUT: Duration = Duration::from_millis(1);
/// Largest timeout accepted by [`SharedState::wait_for_key`] (five minutes).
const MAX_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Upper bound on key length, measured in bytes (not characters).
const MAX_KEY_LEN: usize = 256;
23
24/// Errors from [`SharedState`] operations.
25#[derive(Debug, thiserror::Error)]
26pub enum SharedStateError {
27    /// Key must not be empty.
28    #[error("shared state key must not be empty")]
29    EmptyKey,
30
31    /// Key exceeds the maximum length.
32    #[error("shared state key exceeds 256 bytes: {len} bytes")]
33    KeyTooLong {
34        /// The actual length of the key in bytes.
35        len: usize,
36    },
37
38    /// `wait_for_key` timed out.
39    #[error("wait_for_key timed out after {timeout:?} for key \"{key}\"")]
40    Timeout {
41        /// The key that was being waited on.
42        key: String,
43        /// The timeout duration that expired.
44        timeout: Duration,
45    },
46
47    /// Timeout value is outside the valid range.
48    #[error("invalid timeout {timeout:?}: must be between 1ms and 300s")]
49    InvalidTimeout {
50        /// The invalid timeout value that was provided.
51        timeout: Duration,
52    },
53}
54
55impl From<SharedStateError> for AdkError {
56    fn from(err: SharedStateError) -> Self {
57        AdkError::agent(err.to_string())
58    }
59}
60
61/// Thread-safe key-value store for parallel agent coordination.
62///
63/// Scoped to a single `ParallelAgent::run()` invocation. All sub-agents
64/// share the same `Arc<SharedState>` instance.
65///
66/// # Example
67///
68/// ```rust,ignore
69/// use adk_core::SharedState;
70/// use std::sync::Arc;
71/// use std::time::Duration;
72///
73/// let state = Arc::new(SharedState::new());
74///
75/// // Agent A publishes a workbook handle
76/// state.set_shared("workbook_id", serde_json::json!("wb-123")).await?;
77///
78/// // Agent B waits for the handle
79/// let handle = state.wait_for_key("workbook_id", Duration::from_secs(30)).await?;
80/// ```
81#[derive(Debug)]
82pub struct SharedState {
83    data: RwLock<HashMap<String, Value>>,
84    notifiers: RwLock<HashMap<String, Arc<Notify>>>,
85}
86
87impl SharedState {
88    /// Creates a new empty `SharedState`.
89    #[must_use]
90    pub fn new() -> Self {
91        Self { data: RwLock::new(HashMap::new()), notifiers: RwLock::new(HashMap::new()) }
92    }
93
94    /// Inserts a key-value pair. Notifies all waiters on that key.
95    ///
96    /// # Errors
97    ///
98    /// Returns [`SharedStateError::EmptyKey`] if key is empty.
99    /// Returns [`SharedStateError::KeyTooLong`] if key exceeds 256 bytes.
100    pub async fn set_shared(
101        &self,
102        key: impl Into<String>,
103        value: Value,
104    ) -> Result<(), SharedStateError> {
105        let key = key.into();
106        validate_key(&key)?;
107
108        self.data.write().await.insert(key.clone(), value);
109
110        // Notify all waiters for this key
111        let notifiers = self.notifiers.read().await;
112        if let Some(notify) = notifiers.get(&key) {
113            notify.notify_waiters();
114        }
115
116        Ok(())
117    }
118
119    /// Returns the value for a key, or `None` if not present.
120    pub async fn get_shared(&self, key: &str) -> Option<Value> {
121        self.data.read().await.get(key).cloned()
122    }
123
124    /// Blocks until the key appears, or the timeout expires.
125    ///
126    /// If the key already exists, returns immediately.
127    ///
128    /// # Errors
129    ///
130    /// Returns [`SharedStateError::Timeout`] if the timeout expires.
131    /// Returns [`SharedStateError::InvalidTimeout`] if timeout is outside [1ms, 300s].
132    pub async fn wait_for_key(
133        &self,
134        key: &str,
135        timeout: Duration,
136    ) -> Result<Value, SharedStateError> {
137        // Validate timeout range
138        if timeout < MIN_TIMEOUT || timeout > MAX_TIMEOUT {
139            return Err(SharedStateError::InvalidTimeout { timeout });
140        }
141
142        // Check if key already exists
143        if let Some(value) = self.data.read().await.get(key).cloned() {
144            return Ok(value);
145        }
146
147        // Get or create a Notify for this key
148        let notify = {
149            let mut notifiers = self.notifiers.write().await;
150            notifiers.entry(key.to_string()).or_insert_with(|| Arc::new(Notify::new())).clone()
151        };
152
153        // Wait with timeout, re-checking after each notification
154        let deadline = tokio::time::Instant::now() + timeout;
155        loop {
156            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
157            if remaining.is_zero() {
158                return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
159            }
160
161            match tokio::time::timeout(remaining, notify.notified()).await {
162                Ok(()) => {
163                    // Check if our key was set
164                    if let Some(value) = self.data.read().await.get(key).cloned() {
165                        return Ok(value);
166                    }
167                    // Spurious wake or different key — loop and wait again
168                }
169                Err(_) => {
170                    return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
171                }
172            }
173        }
174    }
175
176    /// Returns a snapshot of all current entries.
177    pub async fn snapshot(&self) -> HashMap<String, Value> {
178        self.data.read().await.clone()
179    }
180}
181
182impl Default for SharedState {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl Serialize for SharedState {
189    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
190        // Use try_read to avoid blocking in a sync context.
191        // If the lock is held, serialize as empty map.
192        match self.data.try_read() {
193            Ok(data) => {
194                let mut map = serializer.serialize_map(Some(data.len()))?;
195                for (k, v) in data.iter() {
196                    map.serialize_entry(k, v)?;
197                }
198                map.end()
199            }
200            Err(_) => serializer.serialize_map(Some(0))?.end(),
201        }
202    }
203}
204
205/// Validates a shared state key.
206fn validate_key(key: &str) -> Result<(), SharedStateError> {
207    if key.is_empty() {
208        return Err(SharedStateError::EmptyKey);
209    }
210    if key.len() > MAX_KEY_LEN {
211        return Err(SharedStateError::KeyTooLong { len: key.len() });
212    }
213    Ok(())
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn new_shared_state_is_empty() {
        let state = SharedState::new();
        assert!(state.snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn set_and_get() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("value")).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!("value")));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let state = SharedState::new();
        assert_eq!(state.get_shared("missing").await, None);
    }

    #[tokio::test]
    async fn overwrite_replaces_value() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!(1)).await.unwrap();
        state.set_shared("key", serde_json::json!(2)).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!(2)));
    }

    #[tokio::test]
    async fn empty_key_rejected() {
        let state = SharedState::new();
        let err = state.set_shared("", serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::EmptyKey));
    }

    #[tokio::test]
    async fn long_key_rejected() {
        let state = SharedState::new();
        let long_key = "x".repeat(257);
        let err = state.set_shared(long_key, serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::KeyTooLong { .. }));
    }

    #[tokio::test]
    async fn key_at_256_bytes_accepted() {
        let state = SharedState::new();
        let key = "x".repeat(256);
        state.set_shared(key.clone(), serde_json::json!(1)).await.unwrap();
        assert_eq!(state.get_shared(&key).await, Some(serde_json::json!(1)));
    }

    #[tokio::test]
    async fn snapshot_contains_all_entries() {
        let state = SharedState::new();
        state.set_shared("a", serde_json::json!(1)).await.unwrap();
        state.set_shared("b", serde_json::json!("two")).await.unwrap();
        let snap = state.snapshot().await;
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!("two")));
    }

    #[tokio::test]
    async fn serializes_as_json_object() {
        let state = SharedState::new();
        state.set_shared("a", serde_json::json!(1)).await.unwrap();
        let json = serde_json::to_value(&state).unwrap();
        assert_eq!(json, serde_json::json!({ "a": 1 }));
    }

    #[tokio::test]
    async fn wait_for_existing_key_returns_immediately() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("val")).await.unwrap();
        let val = state.wait_for_key("key", Duration::from_secs(1)).await.unwrap();
        assert_eq!(val, serde_json::json!("val"));
    }

    #[tokio::test]
    async fn wait_for_key_wakes_when_key_is_set_later() {
        let state = Arc::new(SharedState::new());
        let setter = {
            let state = Arc::clone(&state);
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(20)).await;
                state.set_shared("late", serde_json::json!("arrived")).await.unwrap();
            })
        };

        let val = state.wait_for_key("late", Duration::from_secs(5)).await.unwrap();
        assert_eq!(val, serde_json::json!("arrived"));
        setter.await.unwrap();
    }

    #[tokio::test]
    async fn all_waiters_are_woken_by_one_set() {
        let state = Arc::new(SharedState::new());
        let waiters: Vec<_> = (0..4)
            .map(|_| {
                let state = Arc::clone(&state);
                tokio::spawn(async move { state.wait_for_key("k", Duration::from_secs(5)).await })
            })
            .collect();

        // Give the waiters a chance to register before publishing.
        tokio::time::sleep(Duration::from_millis(20)).await;
        state.set_shared("k", serde_json::json!(42)).await.unwrap();

        for waiter in waiters {
            assert_eq!(waiter.await.unwrap().unwrap(), serde_json::json!(42));
        }
    }

    #[tokio::test]
    async fn wait_for_key_timeout() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", Duration::from_millis(10)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_accepts_timeout_bounds() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", MIN_TIMEOUT).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));

        state.set_shared("key", serde_json::json!(true)).await.unwrap();
        let val = state.wait_for_key("key", MAX_TIMEOUT).await.unwrap();
        assert_eq!(val, serde_json::json!(true));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_small() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_nanos(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_large() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_secs(301)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[test]
    fn error_messages_report_limits() {
        let err = SharedStateError::KeyTooLong { len: 300 };
        assert_eq!(err.to_string(), "shared state key exceeds 256 bytes: 300 bytes");

        let err = SharedStateError::InvalidTimeout { timeout: Duration::from_secs(301) };
        assert_eq!(err.to_string(), "invalid timeout 301s: must be between 1ms and 300s");
    }

    #[tokio::test]
    async fn error_converts_to_adk_error() {
        let err = SharedStateError::EmptyKey;
        let adk_err: AdkError = err.into();
        assert!(adk_err.to_string().contains("empty"));
    }
}