1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use serde::ser::{Serialize, SerializeMap, Serializer};
12use serde_json::Value;
13use tokio::sync::{Notify, RwLock};
14
15use crate::AdkError;
16
/// Shortest timeout accepted by `SharedState::wait_for_key` (one millisecond).
const MIN_TIMEOUT: Duration = Duration::from_millis(1);

/// Longest timeout accepted by `SharedState::wait_for_key` (five minutes).
const MAX_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Maximum key length, measured in bytes of the UTF-8 encoding.
const MAX_KEY_LEN: usize = 256;
23
24#[derive(Debug, thiserror::Error)]
26pub enum SharedStateError {
27 #[error("shared state key must not be empty")]
29 EmptyKey,
30
31 #[error("shared state key exceeds 256 bytes: {len} bytes")]
33 KeyTooLong { len: usize },
34
35 #[error("wait_for_key timed out after {timeout:?} for key \"{key}\"")]
37 Timeout { key: String, timeout: Duration },
38
39 #[error("invalid timeout {timeout:?}: must be between 1ms and 300s")]
41 InvalidTimeout { timeout: Duration },
42}
43
44impl From<SharedStateError> for AdkError {
45 fn from(err: SharedStateError) -> Self {
46 AdkError::agent(err.to_string())
47 }
48}
49
50#[derive(Debug)]
71pub struct SharedState {
72 data: RwLock<HashMap<String, Value>>,
73 notifiers: RwLock<HashMap<String, Arc<Notify>>>,
74}
75
76impl SharedState {
77 #[must_use]
79 pub fn new() -> Self {
80 Self { data: RwLock::new(HashMap::new()), notifiers: RwLock::new(HashMap::new()) }
81 }
82
83 pub async fn set_shared(
90 &self,
91 key: impl Into<String>,
92 value: Value,
93 ) -> Result<(), SharedStateError> {
94 let key = key.into();
95 validate_key(&key)?;
96
97 self.data.write().await.insert(key.clone(), value);
98
99 let notifiers = self.notifiers.read().await;
101 if let Some(notify) = notifiers.get(&key) {
102 notify.notify_waiters();
103 }
104
105 Ok(())
106 }
107
108 pub async fn get_shared(&self, key: &str) -> Option<Value> {
110 self.data.read().await.get(key).cloned()
111 }
112
113 pub async fn wait_for_key(
122 &self,
123 key: &str,
124 timeout: Duration,
125 ) -> Result<Value, SharedStateError> {
126 if timeout < MIN_TIMEOUT || timeout > MAX_TIMEOUT {
128 return Err(SharedStateError::InvalidTimeout { timeout });
129 }
130
131 if let Some(value) = self.data.read().await.get(key).cloned() {
133 return Ok(value);
134 }
135
136 let notify = {
138 let mut notifiers = self.notifiers.write().await;
139 notifiers.entry(key.to_string()).or_insert_with(|| Arc::new(Notify::new())).clone()
140 };
141
142 let deadline = tokio::time::Instant::now() + timeout;
144 loop {
145 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
146 if remaining.is_zero() {
147 return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
148 }
149
150 match tokio::time::timeout(remaining, notify.notified()).await {
151 Ok(()) => {
152 if let Some(value) = self.data.read().await.get(key).cloned() {
154 return Ok(value);
155 }
156 }
158 Err(_) => {
159 return Err(SharedStateError::Timeout { key: key.to_string(), timeout });
160 }
161 }
162 }
163 }
164
165 pub async fn snapshot(&self) -> HashMap<String, Value> {
167 self.data.read().await.clone()
168 }
169}
170
171impl Default for SharedState {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl Serialize for SharedState {
178 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
179 match self.data.try_read() {
182 Ok(data) => {
183 let mut map = serializer.serialize_map(Some(data.len()))?;
184 for (k, v) in data.iter() {
185 map.serialize_entry(k, v)?;
186 }
187 map.end()
188 }
189 Err(_) => serializer.serialize_map(Some(0))?.end(),
190 }
191 }
192}
193
194fn validate_key(key: &str) -> Result<(), SharedStateError> {
196 if key.is_empty() {
197 return Err(SharedStateError::EmptyKey);
198 }
199 if key.len() > MAX_KEY_LEN {
200 return Err(SharedStateError::KeyTooLong { len: key.len() });
201 }
202 Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::json;

    #[tokio::test]
    async fn new_shared_state_is_empty() {
        let state = SharedState::new();
        assert!(state.snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn default_shared_state_is_empty() {
        let state = SharedState::default();
        assert!(state.snapshot().await.is_empty());
    }

    #[tokio::test]
    async fn set_and_get() {
        let state = SharedState::new();
        state.set_shared("key", json!("value")).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(json!("value")));
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let state = SharedState::new();
        assert_eq!(state.get_shared("missing").await, None);
    }

    #[tokio::test]
    async fn overwrite_replaces_value() {
        let state = SharedState::new();
        state.set_shared("key", json!(1)).await.unwrap();
        state.set_shared("key", json!(2)).await.unwrap();
        assert_eq!(state.get_shared("key").await, Some(json!(2)));
    }

    #[tokio::test]
    async fn empty_key_rejected() {
        let state = SharedState::new();
        let err = state.set_shared("", json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::EmptyKey));
    }

    #[tokio::test]
    async fn long_key_rejected() {
        let state = SharedState::new();
        let long_key = "x".repeat(257);
        let err = state.set_shared(long_key, json!(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::KeyTooLong { .. }));
    }

    #[tokio::test]
    async fn key_at_256_bytes_accepted() {
        let state = SharedState::new();
        let key = "x".repeat(256);
        state.set_shared(key.clone(), json!(1)).await.unwrap();
        assert_eq!(state.get_shared(&key).await, Some(json!(1)));
    }

    #[tokio::test]
    async fn snapshot_contains_all_entries() {
        let state = SharedState::new();
        state.set_shared("a", json!(1)).await.unwrap();
        state.set_shared("b", json!("two")).await.unwrap();

        let snapshot = state.snapshot().await;
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot.get("a"), Some(&json!(1)));
        assert_eq!(snapshot.get("b"), Some(&json!("two")));
    }

    #[tokio::test]
    async fn serialize_emits_entries_as_map() {
        let state = SharedState::new();
        state.set_shared("a", json!(1)).await.unwrap();
        state.set_shared("b", json!("two")).await.unwrap();

        let rendered = serde_json::to_value(&state).unwrap();
        assert_eq!(rendered, json!({ "a": 1, "b": "two" }));
    }

    #[tokio::test]
    async fn wait_for_existing_key_returns_immediately() {
        let state = SharedState::new();
        state.set_shared("key", json!("val")).await.unwrap();
        let val = state.wait_for_key("key", Duration::from_secs(1)).await.unwrap();
        assert_eq!(val, json!("val"));
    }

    #[tokio::test]
    async fn wait_for_key_wakes_when_key_is_set() {
        let state = SharedState::new();

        // The waiter is polled first and parks; the writer yields once so the
        // waiter is definitely registered before the value is stored.
        let (waited, stored) = tokio::join!(
            state.wait_for_key("key", Duration::from_secs(5)),
            async {
                tokio::task::yield_now().await;
                state.set_shared("key", json!(42)).await
            }
        );

        stored.unwrap();
        assert_eq!(waited.unwrap(), json!(42));
    }

    #[tokio::test]
    async fn wait_for_key_timeout() {
        let state = SharedState::new();
        let err = state.wait_for_key("missing", Duration::from_millis(10)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::Timeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_small() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_nanos(1)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn wait_for_key_invalid_timeout_too_large() {
        let state = SharedState::new();
        let err = state.wait_for_key("key", Duration::from_secs(301)).await.unwrap_err();
        assert!(matches!(err, SharedStateError::InvalidTimeout { .. }));
    }

    #[tokio::test]
    async fn error_converts_to_adk_error() {
        let err = SharedStateError::EmptyKey;
        let adk_err: AdkError = err.into();
        assert!(adk_err.to_string().contains("empty"));
    }
}