a2a_protocol_server/push/
config_store.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11
12use a2a_protocol_types::error::A2aResult;
13use a2a_protocol_types::push::TaskPushNotificationConfig;
14use tokio::sync::RwLock;
15
16pub trait PushConfigStore: Send + Sync + 'static {
20 fn set<'a>(
26 &'a self,
27 config: TaskPushNotificationConfig,
28 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>>;
29
30 fn get<'a>(
36 &'a self,
37 task_id: &'a str,
38 id: &'a str,
39 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>;
40
41 fn list<'a>(
47 &'a self,
48 task_id: &'a str,
49 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>>;
50
51 fn delete<'a>(
57 &'a self,
58 task_id: &'a str,
59 id: &'a str,
60 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
61}
62
/// Default cap on how many push configs a single task may hold at once.
const DEFAULT_MAX_PUSH_CONFIGS_PER_TASK: usize = 100;

/// Default cap on push configs held across all tasks combined.
const DEFAULT_MAX_TOTAL_PUSH_CONFIGS: usize = 100_000;
69
70#[derive(Debug)]
75pub struct InMemoryPushConfigStore {
76 configs: RwLock<HashMap<(String, String), TaskPushNotificationConfig>>,
77 task_counts: RwLock<HashMap<String, usize>>,
79 max_configs_per_task: usize,
81 max_total_configs: usize,
83}
84
85impl Default for InMemoryPushConfigStore {
86 fn default() -> Self {
87 Self {
88 configs: RwLock::new(HashMap::new()),
89 task_counts: RwLock::new(HashMap::new()),
90 max_configs_per_task: DEFAULT_MAX_PUSH_CONFIGS_PER_TASK,
91 max_total_configs: DEFAULT_MAX_TOTAL_PUSH_CONFIGS,
92 }
93 }
94}
95
96impl InMemoryPushConfigStore {
97 #[must_use]
99 pub fn new() -> Self {
100 Self::default()
101 }
102
103 #[must_use]
105 pub fn with_max_configs_per_task(max: usize) -> Self {
106 Self {
107 configs: RwLock::new(HashMap::new()),
108 task_counts: RwLock::new(HashMap::new()),
109 max_configs_per_task: max,
110 max_total_configs: DEFAULT_MAX_TOTAL_PUSH_CONFIGS,
111 }
112 }
113
114 #[must_use]
119 pub const fn with_max_total_configs(mut self, max: usize) -> Self {
120 self.max_total_configs = max;
121 self
122 }
123}
124
125#[allow(clippy::manual_async_fn)]
126impl PushConfigStore for InMemoryPushConfigStore {
127 fn set<'a>(
128 &'a self,
129 mut config: TaskPushNotificationConfig,
130 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
131 Box::pin(async move {
132 let id = config
134 .id
135 .clone()
136 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
137 config.id = Some(id.clone());
138
139 let key = (config.task_id.clone(), id);
140 let mut store = self.configs.write().await;
141 let mut counts = self.task_counts.write().await;
142
143 let is_new = !store.contains_key(&key);
145 if is_new {
146 let total = store.len();
148 if total >= self.max_total_configs {
149 drop(counts);
150 drop(store);
151 return Err(a2a_protocol_types::error::A2aError::invalid_params(
152 format!(
153 "global push config limit exceeded: {total} configs (max {})",
154 self.max_total_configs,
155 ),
156 ));
157 }
158 let task_id = &config.task_id;
161 let count = counts.get(task_id).copied().unwrap_or(0);
162 let max = self.max_configs_per_task;
163 if count >= max {
164 drop(counts);
165 drop(store);
166 return Err(a2a_protocol_types::error::A2aError::invalid_params(format!(
167 "push config limit exceeded: task {task_id} already has {count} configs (max {max})"
168 )));
169 }
170 }
171
172 store.insert(key, config.clone());
173 if is_new {
174 *counts.entry(config.task_id.clone()).or_insert(0) += 1;
175 }
176 drop(counts);
177 drop(store);
178 Ok(config)
179 })
180 }
181
182 fn get<'a>(
183 &'a self,
184 task_id: &'a str,
185 id: &'a str,
186 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
187 {
188 Box::pin(async move {
189 let store = self.configs.read().await;
190 let key = (task_id.to_owned(), id.to_owned());
191 let result = store.get(&key).cloned();
192 drop(store);
193 Ok(result)
194 })
195 }
196
197 fn list<'a>(
198 &'a self,
199 task_id: &'a str,
200 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
201 Box::pin(async move {
202 let store = self.configs.read().await;
203 let mut configs: Vec<_> = store
204 .iter()
205 .filter(|((tid, _), _)| tid == task_id)
206 .map(|(_, v)| v.clone())
207 .collect();
208 drop(store);
209 configs.sort_by(|a, b| a.task_id.cmp(&b.task_id).then_with(|| a.id.cmp(&b.id)));
211 Ok(configs)
212 })
213 }
214
215 fn delete<'a>(
216 &'a self,
217 task_id: &'a str,
218 id: &'a str,
219 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
220 Box::pin(async move {
221 let mut store = self.configs.write().await;
222 let mut counts = self.task_counts.write().await;
223 let key = (task_id.to_owned(), id.to_owned());
224 if store.remove(&key).is_some() {
225 if let Some(count) = counts.get_mut(task_id) {
227 *count = count.saturating_sub(1);
228 if *count == 0 {
229 counts.remove(task_id);
230 }
231 }
232 }
233 drop(counts);
234 drop(store);
235 Ok(())
236 })
237 }
238}
239
#[cfg(test)]
mod tests {
    //! Unit tests for `InMemoryPushConfigStore`: CRUD behavior, per-task and
    //! global caps, and the per-task count bookkeeping performed on delete.

    use super::*;
    use a2a_protocol_types::push::TaskPushNotificationConfig;

    /// Builds a minimal config for `task_id` with an optional explicit `id`.
    fn make_config(task_id: &str, id: Option<&str>, url: &str) -> TaskPushNotificationConfig {
        TaskPushNotificationConfig {
            tenant: None,
            id: id.map(String::from),
            task_id: task_id.to_string(),
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn set_assigns_id_when_none() {
        let store = InMemoryPushConfigStore::new();
        let config = make_config("task-1", None, "https://example.com/hook");
        let result = store.set(config).await.expect("set should succeed");
        assert!(
            result.id.is_some(),
            "set should assign an id when none is provided"
        );
    }

    #[tokio::test]
    async fn set_preserves_explicit_id() {
        let store = InMemoryPushConfigStore::new();
        let config = make_config("task-1", Some("my-id"), "https://example.com/hook");
        let result = store.set(config).await.expect("set should succeed");
        assert_eq!(
            result.id.as_deref(),
            Some("my-id"),
            "set should preserve the explicitly provided id"
        );
    }

    #[tokio::test]
    async fn generated_ids_are_distinct() {
        let store = InMemoryPushConfigStore::new();
        let first = store
            .set(make_config("task-1", None, "https://a.com"))
            .await
            .expect("first set should succeed");
        let second = store
            .set(make_config("task-1", None, "https://b.com"))
            .await
            .expect("second set should succeed");
        assert_ne!(first.id, second.id, "generated ids should not collide");
        let listed = store.list("task-1").await.expect("list should succeed");
        assert_eq!(listed.len(), 2, "both generated configs should be stored");
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_config() {
        let store = InMemoryPushConfigStore::new();
        let result = store
            .get("no-task", "no-id")
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a non-existent config"
        );
    }

    #[tokio::test]
    async fn set_then_get_round_trip() {
        let store = InMemoryPushConfigStore::new();
        let config = make_config("task-1", Some("cfg-1"), "https://example.com/hook");
        store.set(config).await.expect("set should succeed");

        let retrieved = store
            .get("task-1", "cfg-1")
            .await
            .expect("get should succeed")
            .expect("config should exist after set");
        assert_eq!(retrieved.task_id, "task-1");
        assert_eq!(retrieved.url, "https://example.com/hook");
    }

    #[tokio::test]
    async fn overwrite_existing_config() {
        let store = InMemoryPushConfigStore::new();
        let config1 = make_config("task-1", Some("cfg-1"), "https://example.com/v1");
        store.set(config1).await.expect("first set should succeed");

        let config2 = make_config("task-1", Some("cfg-1"), "https://example.com/v2");
        store
            .set(config2)
            .await
            .expect("overwrite set should succeed");

        let retrieved = store
            .get("task-1", "cfg-1")
            .await
            .expect("get should succeed")
            .expect("config should exist");
        assert_eq!(
            retrieved.url, "https://example.com/v2",
            "overwrite should update the URL"
        );
    }

    #[tokio::test]
    async fn list_returns_empty_for_unknown_task() {
        let store = InMemoryPushConfigStore::new();
        let configs = store
            .list("no-such-task")
            .await
            .expect("list should succeed");
        assert!(
            configs.is_empty(),
            "list should return empty vec for unknown task"
        );
    }

    #[tokio::test]
    async fn list_returns_only_configs_for_given_task() {
        let store = InMemoryPushConfigStore::new();
        store
            .set(make_config("task-a", Some("c1"), "https://a.com/1"))
            .await
            .unwrap();
        store
            .set(make_config("task-a", Some("c2"), "https://a.com/2"))
            .await
            .unwrap();
        store
            .set(make_config("task-b", Some("c3"), "https://b.com/1"))
            .await
            .unwrap();

        let a_configs = store.list("task-a").await.expect("list should succeed");
        assert_eq!(a_configs.len(), 2, "task-a should have exactly 2 configs");

        let b_configs = store.list("task-b").await.expect("list should succeed");
        assert_eq!(b_configs.len(), 1, "task-b should have exactly 1 config");
    }

    #[tokio::test]
    async fn list_is_sorted_by_id() {
        let store = InMemoryPushConfigStore::new();
        for id in ["c3", "c1", "c2"] {
            store
                .set(make_config("task-1", Some(id), "https://a.com"))
                .await
                .unwrap();
        }

        let ids: Vec<String> = store
            .list("task-1")
            .await
            .expect("list should succeed")
            .into_iter()
            .filter_map(|c| c.id)
            .collect();
        assert_eq!(ids, ["c1", "c2", "c3"], "list should be ordered by id");
    }

    #[tokio::test]
    async fn delete_removes_config() {
        let store = InMemoryPushConfigStore::new();
        store
            .set(make_config("task-1", Some("cfg-1"), "https://example.com"))
            .await
            .unwrap();

        store
            .delete("task-1", "cfg-1")
            .await
            .expect("delete should succeed");

        let result = store.get("task-1", "cfg-1").await.unwrap();
        assert!(result.is_none(), "config should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = InMemoryPushConfigStore::new();
        let result = store.delete("no-task", "no-id").await;
        assert!(
            result.is_ok(),
            "deleting a non-existent config should not error"
        );
    }

    #[tokio::test]
    async fn delete_frees_per_task_slot() {
        let store = InMemoryPushConfigStore::with_max_configs_per_task(1);
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store.delete("task-1", "c1").await.unwrap();

        let result = store
            .set(make_config("task-1", Some("c2"), "https://b.com"))
            .await;
        assert!(
            result.is_ok(),
            "deleting a config should free its per-task slot"
        );
    }

    #[tokio::test]
    async fn delete_of_missing_config_does_not_free_a_slot() {
        let store = InMemoryPushConfigStore::with_max_configs_per_task(1);
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store.delete("task-1", "not-there").await.unwrap();

        let result = store
            .set(make_config("task-1", Some("c2"), "https://b.com"))
            .await;
        assert!(
            result.is_err(),
            "a no-op delete must not decrement the per-task count"
        );
    }

    #[tokio::test]
    async fn delete_frees_global_slot() {
        let store =
            InMemoryPushConfigStore::with_max_configs_per_task(100).with_max_total_configs(1);
        store
            .set(make_config("t1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store.delete("t1", "c1").await.unwrap();

        let result = store
            .set(make_config("t2", Some("c2"), "https://b.com"))
            .await;
        assert!(
            result.is_ok(),
            "deleting a config should free its global slot"
        );
    }

    #[tokio::test]
    async fn max_configs_per_task_limit_enforced() {
        let store = InMemoryPushConfigStore::with_max_configs_per_task(2);
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store
            .set(make_config("task-1", Some("c2"), "https://b.com"))
            .await
            .unwrap();

        let err = store
            .set(make_config("task-1", Some("c3"), "https://c.com"))
            .await
            .expect_err("third config should exceed per-task limit");
        let msg = format!("{err}");
        assert!(
            msg.contains("limit exceeded"),
            "error message should mention limit exceeded, got: {msg}"
        );
    }

    #[tokio::test]
    async fn per_task_limit_does_not_block_other_tasks() {
        let store = InMemoryPushConfigStore::with_max_configs_per_task(1);
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();

        let result = store
            .set(make_config("task-2", Some("c1"), "https://b.com"))
            .await;
        assert!(
            result.is_ok(),
            "per-task limit should not block a different task"
        );
    }

    #[tokio::test]
    async fn overwrite_does_not_count_toward_per_task_limit() {
        let store = InMemoryPushConfigStore::with_max_configs_per_task(1);
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();

        let result = store
            .set(make_config("task-1", Some("c1"), "https://b.com"))
            .await;
        assert!(
            result.is_ok(),
            "overwriting an existing config should not count toward the limit"
        );
    }

    #[tokio::test]
    async fn max_total_configs_limit_enforced() {
        let store =
            InMemoryPushConfigStore::with_max_configs_per_task(100).with_max_total_configs(2);
        store
            .set(make_config("t1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store
            .set(make_config("t2", Some("c2"), "https://b.com"))
            .await
            .unwrap();

        let err = store
            .set(make_config("t3", Some("c3"), "https://c.com"))
            .await
            .expect_err("third config should exceed global limit");
        let msg = format!("{err}");
        assert!(
            msg.contains("global push config limit exceeded"),
            "error should mention global limit, got: {msg}"
        );
    }

    #[tokio::test]
    async fn overwrite_does_not_count_toward_global_limit() {
        let store =
            InMemoryPushConfigStore::with_max_configs_per_task(100).with_max_total_configs(1);
        store
            .set(make_config("t1", Some("c1"), "https://a.com"))
            .await
            .unwrap();

        let result = store
            .set(make_config("t1", Some("c1"), "https://b.com"))
            .await;
        assert!(
            result.is_ok(),
            "overwriting should not count toward global limit"
        );
    }
}