// a2a_protocol_server/push/config_store.rs

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;

use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::push::TaskPushNotificationConfig;
use tokio::sync::RwLock;
15
16pub trait PushConfigStore: Send + Sync + 'static {
20 fn set<'a>(
26 &'a self,
27 config: TaskPushNotificationConfig,
28 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>>;
29
30 fn get<'a>(
36 &'a self,
37 task_id: &'a str,
38 id: &'a str,
39 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>;
40
41 fn list<'a>(
47 &'a self,
48 task_id: &'a str,
49 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>>;
50
51 fn delete<'a>(
57 &'a self,
58 task_id: &'a str,
59 id: &'a str,
60 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
61}
62
/// Default upper bound on how many push configs one task may register.
const DEFAULT_MAX_PUSH_CONFIGS_PER_TASK: usize = 100;

/// Default upper bound on push configs held across all tasks combined.
const DEFAULT_MAX_TOTAL_PUSH_CONFIGS: usize = 100_000;
69
70#[derive(Debug)]
75pub struct InMemoryPushConfigStore {
76 configs: RwLock<HashMap<(String, String), TaskPushNotificationConfig>>,
77 task_counts: RwLock<HashMap<String, usize>>,
79 max_configs_per_task: usize,
81 max_total_configs: usize,
83}
84
85impl Default for InMemoryPushConfigStore {
86 fn default() -> Self {
87 Self {
88 configs: RwLock::new(HashMap::new()),
89 task_counts: RwLock::new(HashMap::new()),
90 max_configs_per_task: DEFAULT_MAX_PUSH_CONFIGS_PER_TASK,
91 max_total_configs: DEFAULT_MAX_TOTAL_PUSH_CONFIGS,
92 }
93 }
94}
95
96impl InMemoryPushConfigStore {
97 #[must_use]
99 pub fn new() -> Self {
100 Self::default()
101 }
102
103 #[must_use]
105 pub fn with_max_configs_per_task(max: usize) -> Self {
106 Self {
107 configs: RwLock::new(HashMap::new()),
108 task_counts: RwLock::new(HashMap::new()),
109 max_configs_per_task: max,
110 max_total_configs: DEFAULT_MAX_TOTAL_PUSH_CONFIGS,
111 }
112 }
113
114 #[must_use]
119 pub const fn with_max_total_configs(mut self, max: usize) -> Self {
120 self.max_total_configs = max;
121 self
122 }
123}
124
125#[allow(clippy::manual_async_fn)]
126impl PushConfigStore for InMemoryPushConfigStore {
127 fn set<'a>(
128 &'a self,
129 mut config: TaskPushNotificationConfig,
130 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
131 Box::pin(async move {
132 let id = config
134 .id
135 .clone()
136 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
137 config.id = Some(id.clone());
138
139 let key = (config.task_id.clone(), id);
140 let mut store = self.configs.write().await;
141 let mut counts = self.task_counts.write().await;
142
143 let is_new = !store.contains_key(&key);
145 if is_new {
146 let total = store.len();
148 if total >= self.max_total_configs {
149 drop(counts);
150 drop(store);
151 return Err(a2a_protocol_types::error::A2aError::invalid_params(
152 format!(
153 "global push config limit exceeded: {total} configs (max {})",
154 self.max_total_configs,
155 ),
156 ));
157 }
158 let task_id = &config.task_id;
161 let count = counts.get(task_id).copied().unwrap_or(0);
162 let max = self.max_configs_per_task;
163 if count >= max {
164 drop(counts);
165 drop(store);
166 return Err(a2a_protocol_types::error::A2aError::invalid_params(format!(
167 "push config limit exceeded: task {task_id} already has {count} configs (max {max})"
168 )));
169 }
170 }
171
172 store.insert(key, config.clone());
173 if is_new {
174 *counts.entry(config.task_id.clone()).or_insert(0) += 1;
175 }
176 drop(counts);
177 drop(store);
178 Ok(config)
179 })
180 }
181
182 fn get<'a>(
183 &'a self,
184 task_id: &'a str,
185 id: &'a str,
186 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
187 {
188 Box::pin(async move {
189 let store = self.configs.read().await;
190 let key = (task_id.to_owned(), id.to_owned());
191 let result = store.get(&key).cloned();
192 drop(store);
193 Ok(result)
194 })
195 }
196
197 fn list<'a>(
198 &'a self,
199 task_id: &'a str,
200 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
201 Box::pin(async move {
202 let store = self.configs.read().await;
203 let configs: Vec<_> = store
204 .iter()
205 .filter(|((tid, _), _)| tid == task_id)
206 .map(|(_, v)| v.clone())
207 .collect();
208 drop(store);
209 Ok(configs)
210 })
211 }
212
213 fn delete<'a>(
214 &'a self,
215 task_id: &'a str,
216 id: &'a str,
217 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
218 Box::pin(async move {
219 let mut store = self.configs.write().await;
220 let mut counts = self.task_counts.write().await;
221 let key = (task_id.to_owned(), id.to_owned());
222 if store.remove(&key).is_some() {
223 if let Some(count) = counts.get_mut(task_id) {
225 *count = count.saturating_sub(1);
226 if *count == 0 {
227 counts.remove(task_id);
228 }
229 }
230 }
231 drop(counts);
232 drop(store);
233 Ok(())
234 })
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::push::TaskPushNotificationConfig;

    /// Builds a config with only the fields these tests exercise populated.
    fn config_for(task: &str, cfg_id: Option<&str>, hook: &str) -> TaskPushNotificationConfig {
        TaskPushNotificationConfig {
            tenant: None,
            id: cfg_id.map(ToOwned::to_owned),
            task_id: task.to_owned(),
            url: hook.to_owned(),
            token: None,
            authentication: None,
        }
    }

    /// Stores each `(task, id, url)` triple in order, unwrapping every result.
    async fn seed(store: &InMemoryPushConfigStore, entries: &[(&str, &str, &str)]) {
        for &(task, cfg_id, hook) in entries {
            store
                .set(config_for(task, Some(cfg_id), hook))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn set_assigns_id_when_none() {
        let subject = InMemoryPushConfigStore::new();
        let saved = subject
            .set(config_for("task-1", None, "https://example.com/hook"))
            .await
            .expect("set should succeed");
        assert!(
            saved.id.is_some(),
            "set should assign an id when none is provided"
        );
    }

    #[tokio::test]
    async fn set_preserves_explicit_id() {
        let subject = InMemoryPushConfigStore::new();
        let saved = subject
            .set(config_for("task-1", Some("my-id"), "https://example.com/hook"))
            .await
            .expect("set should succeed");
        assert_eq!(
            saved.id.as_deref(),
            Some("my-id"),
            "set should preserve the explicitly provided id"
        );
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_config() {
        let subject = InMemoryPushConfigStore::new();
        let found = subject
            .get("no-task", "no-id")
            .await
            .expect("get should succeed");
        assert!(
            found.is_none(),
            "get should return None for a non-existent config"
        );
    }

    #[tokio::test]
    async fn set_then_get_round_trip() {
        let subject = InMemoryPushConfigStore::new();
        subject
            .set(config_for("task-1", Some("cfg-1"), "https://example.com/hook"))
            .await
            .expect("set should succeed");

        let found = subject
            .get("task-1", "cfg-1")
            .await
            .expect("get should succeed")
            .expect("config should exist after set");
        assert_eq!(found.task_id, "task-1");
        assert_eq!(found.url, "https://example.com/hook");
    }

    #[tokio::test]
    async fn overwrite_existing_config() {
        let subject = InMemoryPushConfigStore::new();
        for (hook, why) in [
            ("https://example.com/v1", "first set should succeed"),
            ("https://example.com/v2", "overwrite set should succeed"),
        ] {
            subject
                .set(config_for("task-1", Some("cfg-1"), hook))
                .await
                .expect(why);
        }

        let found = subject
            .get("task-1", "cfg-1")
            .await
            .expect("get should succeed")
            .expect("config should exist");
        assert_eq!(
            found.url, "https://example.com/v2",
            "overwrite should update the URL"
        );
    }

    #[tokio::test]
    async fn list_returns_empty_for_unknown_task() {
        let subject = InMemoryPushConfigStore::new();
        let listed = subject
            .list("no-such-task")
            .await
            .expect("list should succeed");
        assert!(
            listed.is_empty(),
            "list should return empty vec for unknown task"
        );
    }

    #[tokio::test]
    async fn list_returns_only_configs_for_given_task() {
        let subject = InMemoryPushConfigStore::new();
        seed(
            &subject,
            &[
                ("task-a", "c1", "https://a.com/1"),
                ("task-a", "c2", "https://a.com/2"),
                ("task-b", "c3", "https://b.com/1"),
            ],
        )
        .await;

        for (task, expected, why) in [
            ("task-a", 2, "task-a should have exactly 2 configs"),
            ("task-b", 1, "task-b should have exactly 1 config"),
        ] {
            let listed = subject.list(task).await.expect("list should succeed");
            assert_eq!(listed.len(), expected, "{why}");
        }
    }

    #[tokio::test]
    async fn delete_removes_config() {
        let subject = InMemoryPushConfigStore::new();
        seed(&subject, &[("task-1", "cfg-1", "https://example.com")]).await;

        subject
            .delete("task-1", "cfg-1")
            .await
            .expect("delete should succeed");

        let found = subject.get("task-1", "cfg-1").await.unwrap();
        assert!(found.is_none(), "config should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let subject = InMemoryPushConfigStore::new();
        let outcome = subject.delete("no-task", "no-id").await;
        assert!(
            outcome.is_ok(),
            "deleting a non-existent config should not error"
        );
    }

    #[tokio::test]
    async fn max_configs_per_task_limit_enforced() {
        let subject = InMemoryPushConfigStore::with_max_configs_per_task(2);
        seed(
            &subject,
            &[
                ("task-1", "c1", "https://a.com"),
                ("task-1", "c2", "https://b.com"),
            ],
        )
        .await;

        let rejection = subject
            .set(config_for("task-1", Some("c3"), "https://c.com"))
            .await
            .expect_err("third config should exceed per-task limit");
        let text = rejection.to_string();
        assert!(
            text.contains("limit exceeded"),
            "error message should mention limit exceeded, got: {text}"
        );
    }

    #[tokio::test]
    async fn per_task_limit_does_not_block_other_tasks() {
        let subject = InMemoryPushConfigStore::with_max_configs_per_task(1);
        seed(&subject, &[("task-1", "c1", "https://a.com")]).await;

        let outcome = subject
            .set(config_for("task-2", Some("c1"), "https://b.com"))
            .await;
        assert!(
            outcome.is_ok(),
            "per-task limit should not block a different task"
        );
    }

    #[tokio::test]
    async fn overwrite_does_not_count_toward_per_task_limit() {
        let subject = InMemoryPushConfigStore::with_max_configs_per_task(1);
        seed(&subject, &[("task-1", "c1", "https://a.com")]).await;

        let outcome = subject
            .set(config_for("task-1", Some("c1"), "https://b.com"))
            .await;
        assert!(
            outcome.is_ok(),
            "overwriting an existing config should not count toward the limit"
        );
    }

    #[tokio::test]
    async fn max_total_configs_limit_enforced() {
        let subject =
            InMemoryPushConfigStore::with_max_configs_per_task(100).with_max_total_configs(2);
        seed(
            &subject,
            &[("t1", "c1", "https://a.com"), ("t2", "c2", "https://b.com")],
        )
        .await;

        let rejection = subject
            .set(config_for("t3", Some("c3"), "https://c.com"))
            .await
            .expect_err("third config should exceed global limit");
        let text = rejection.to_string();
        assert!(
            text.contains("global push config limit exceeded"),
            "error should mention global limit, got: {text}"
        );
    }

    #[tokio::test]
    async fn overwrite_does_not_count_toward_global_limit() {
        let subject =
            InMemoryPushConfigStore::with_max_configs_per_task(100).with_max_total_configs(1);
        seed(&subject, &[("t1", "c1", "https://a.com")]).await;

        let outcome = subject
            .set(config_for("t1", Some("c1"), "https://b.com"))
            .await;
        assert!(
            outcome.is_ok(),
            "overwriting should not count toward global limit"
        );
    }
}