1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use serde::ser::{Serialize, SerializeMap, Serializer};
12use serde_json::Value;
13use tokio::sync::{Notify, RwLock};
14
15use crate::AdkError;
16
/// Shortest timeout `SharedState::wait_for_key` accepts.
const MIN_TIMEOUT: Duration = Duration::from_micros(1_000);
/// Longest timeout `SharedState::wait_for_key` accepts (five minutes).
const MAX_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Upper bound on key length, counted in UTF-8 bytes rather than characters.
const MAX_KEY_LEN: usize = 256;
23
24#[derive(Debug, thiserror::Error)]
26pub enum SharedStateError {
27 #[error("shared state key must not be empty")]
29 EmptyKey,
30
31 #[error("shared state key exceeds 256 bytes: {len} bytes")]
33 KeyTooLong {
34 len: usize,
36 },
37
38 #[error("wait_for_key timed out after {timeout:?} for key \"{key}\"")]
40 Timeout {
41 key: String,
43 timeout: Duration,
45 },
46
47 #[error("invalid timeout {timeout:?}: must be between 1ms and 300s")]
49 InvalidTimeout {
50 timeout: Duration,
52 },
53}
54
55impl From<SharedStateError> for AdkError {
56 fn from(err: SharedStateError) -> Self {
57 AdkError::agent(err.to_string())
58 }
59}
60
61#[derive(Debug)]
82pub struct SharedState {
83 data: RwLock<HashMap<String, Value>>,
84 notifiers: RwLock<HashMap<String, Arc<Notify>>>,
85}
86
87impl SharedState {
88 #[must_use]
90 pub fn new() -> Self {
91 Self { data: RwLock::new(HashMap::new()), notifiers: RwLock::new(HashMap::new()) }
92 }
93
94 pub async fn set_shared(
101 &self,
102 key: impl Into<String>,
103 value: Value,
104 ) -> Result<(), SharedStateError> {
105 let key = key.into();
106 validate_key(&key)?;
107
108 self.data.write().await.insert(key.clone(), value);
109
110 let notifiers = self.notifiers.read().await;
112 if let Some(notify) = notifiers.get(&key) {
113 notify.notify_waiters();
114 }
115
116 Ok(())
117 }
118
119 pub async fn get_shared(&self, key: &str) -> Option<Value> {
121 self.data.read().await.get(key).cloned()
122 }
123
124 pub async fn wait_for_key(
133 &self,
134 key: &str,
135 timeout: Duration,
136 ) -> Result<Value, SharedStateError> {
137 if timeout < MIN_TIMEOUT || timeout > MAX_TIMEOUT {
139 return Err(SharedStateError::InvalidTimeout { timeout });
140 }
141
142 if let Some(value) = self.data.read().await.get(key).cloned() {
144 return Ok(value);
145 }
146
147 let notify = {
149 let mut notifiers = self.notifiers.write().await;
150 notifiers.entry(key.to_string()).or_insert_with(|| Arc::new(Notify::new())).clone()
151 };
152
153 let deadline = tokio::time::Instant::now() + timeout;
155 loop {
156 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
157 if remaining.is_zero() {
158 return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
159 }
160
161 match tokio::time::timeout(remaining, notify.notified()).await {
162 Ok(()) => {
163 if let Some(value) = self.data.read().await.get(key).cloned() {
165 return Ok(value);
166 }
167 }
169 Err(_) => {
170 return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
171 }
172 }
173 }
174 }
175
176 pub async fn snapshot(&self) -> HashMap<String, Value> {
178 self.data.read().await.clone()
179 }
180}
181
182impl Default for SharedState {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl Serialize for SharedState {
189 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
190 match self.data.try_read() {
193 Ok(data) => {
194 let mut map = serializer.serialize_map(Some(data.len()))?;
195 for (k, v) in data.iter() {
196 map.serialize_entry(k, v)?;
197 }
198 map.end()
199 }
200 Err(_) => serializer.serialize_map(Some(0))?.end(),
201 }
202 }
203}
204
205fn validate_key(key: &str) -> Result<(), SharedStateError> {
207 if key.is_empty() {
208 return Err(SharedStateError::EmptyKey);
209 }
210 if key.len() > MAX_KEY_LEN {
211 return Err(SharedStateError::KeyTooLong { len: key.len() });
212 }
213 Ok(())
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn new_shared_state_is_empty() {
        let state = SharedState::new();
        assert!(state.snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn set_and_get() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("value")).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!("value")));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let state = SharedState::new();
        assert_eq!(state.get_shared("missing").await, None);
    }

    #[tokio::test]
    async fn overwrite_replaces_value() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!(1)).await.unwrap();
        state.set_shared("key", serde_json::json!(2)).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(serde_json::json!(2)));
    }

    #[tokio::test]
    async fn empty_key_rejected() {
        let state = SharedState::new();
        let err = state.set_shared("", serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::EmptyKey));
    }

    #[tokio::test]
    async fn long_key_rejected() {
        let state = SharedState::new();
        let long_key = "x".repeat(257);
        let err = state.set_shared(long_key, serde_json::json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::KeyTooLong { .. }));
    }

    #[tokio::test]
    async fn key_at_256_bytes_accepted() {
        let state = SharedState::new();
        let key = "x".repeat(256);
        state.set_shared(key.clone(), serde_json::json!(1)).await.unwrap();
        assert_eq!(state.get_shared(&key).await, Some(serde_json::json!(1)));
    }

    #[tokio::test]
    async fn wait_for_existing_key_returns_immediately() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!("val")).await.unwrap();
        let val = state.wait_for_key("key", Duration::from_secs(1)).await.unwrap();
        assert_eq!(val, serde_json::json!("val"));
    }

    #[tokio::test]
    async fn wait_for_key_wakes_when_value_is_set_later() {
        let state = Arc::new(SharedState::new());
        let waiter = {
            let state = Arc::clone(&state);
            tokio::spawn(async move { state.wait_for_key("late", Duration::from_secs(5)).await })
        };
        // Give the waiter a chance to register before the value is written.
        tokio::task::yield_now().await;
        state.set_shared("late", serde_json::json!(42)).await.unwrap();
        assert_eq!(waiter.await.unwrap().unwrap(), serde_json::json!(42));
    }

    #[tokio::test]
    async fn wait_for_key_timeout() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", Duration::from_millis(10)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_minimum_timeout_is_accepted() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", Duration::from_millis(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_maximum_timeout_is_accepted() {
        let state = SharedState::new();
        state.set_shared("key", serde_json::json!(true)).await.unwrap();
        let val = state.wait_for_key("key", Duration::from_secs(300)).await.unwrap();
        assert_eq!(val, serde_json::json!(true));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_small() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_nanos(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_large() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_secs(301)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[test]
    fn error_messages_reflect_limits() {
        assert_eq!(
            SharedStateError::KeyTooLong { len: 300 }.to_string(),
            "shared state key exceeds 256 bytes: 300 bytes"
        );
        assert_eq!(
            SharedStateError::InvalidTimeout { timeout: Duration::from_secs(301) }.to_string(),
            "invalid timeout 301s: must be between 1ms and 300s"
        );
    }

    #[tokio::test]
    async fn error_converts_to_adk_error() {
        let err = SharedStateError::EmptyKey;
        let adk_err: AdkError = err.into();
        assert!(adk_err.to_string().contains("empty"));
    }
}