a2a_rs_server/
webhook_store.rs1use a2a_rs_core::PushNotificationConfig;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use url::Url;
10
11#[derive(Debug, thiserror::Error)]
13pub enum WebhookValidationError {
14 #[error("Invalid URL format: {0}")]
15 InvalidUrl(String),
16 #[error("URL scheme must be http or https, got: {0}")]
17 InvalidScheme(String),
18}
19
20#[derive(Debug, Clone)]
22pub struct StoredWebhookConfig {
23 pub config_id: String,
24 pub config: PushNotificationConfig,
25}
26
27#[derive(Clone)]
29pub struct WebhookStore {
30 configs: Arc<RwLock<HashMap<String, HashMap<String, PushNotificationConfig>>>>,
31}
32
33impl Default for WebhookStore {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl WebhookStore {
40 pub fn new() -> Self {
41 Self {
42 configs: Arc::new(RwLock::new(HashMap::new())),
43 }
44 }
45
46 pub async fn set(
47 &self,
48 task_id: &str,
49 config_id: &str,
50 config: PushNotificationConfig,
51 ) -> Result<(), WebhookValidationError> {
52 let parsed = Url::parse(&config.url)
53 .map_err(|e| WebhookValidationError::InvalidUrl(e.to_string()))?;
54
55 match parsed.scheme() {
56 "http" | "https" => {}
57 scheme => return Err(WebhookValidationError::InvalidScheme(scheme.to_string())),
58 }
59
60 self.configs
61 .write()
62 .await
63 .entry(task_id.to_string())
64 .or_default()
65 .insert(config_id.to_string(), config);
66
67 Ok(())
68 }
69
70 pub async fn get(&self, task_id: &str, config_id: &str) -> Option<PushNotificationConfig> {
71 self.configs
72 .read()
73 .await
74 .get(task_id)?
75 .get(config_id)
76 .cloned()
77 }
78
79 pub async fn list(&self, task_id: &str) -> Vec<StoredWebhookConfig> {
80 let guard = self.configs.read().await;
81 guard
82 .get(task_id)
83 .map(|configs| {
84 configs
85 .iter()
86 .map(|(id, config)| StoredWebhookConfig {
87 config_id: id.clone(),
88 config: config.clone(),
89 })
90 .collect()
91 })
92 .unwrap_or_default()
93 }
94
95 pub async fn delete(&self, task_id: &str, config_id: &str) -> bool {
96 let mut guard = self.configs.write().await;
97 if let Some(configs) = guard.get_mut(task_id) {
98 let removed = configs.remove(config_id).is_some();
99 if configs.is_empty() {
100 guard.remove(task_id);
101 }
102 return removed;
103 }
104 false
105 }
106
107 pub async fn get_configs_for_task(&self, task_id: &str) -> Vec<PushNotificationConfig> {
108 let guard = self.configs.read().await;
109 guard
110 .get(task_id)
111 .map(|configs| configs.values().cloned().collect())
112 .unwrap_or_default()
113 }
114
115 pub async fn remove_task(&self, task_id: &str) {
116 self.configs.write().await.remove(task_id);
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with only `url` set; every optional field is `None`.
    fn make_config(url: &str) -> PushNotificationConfig {
        PushNotificationConfig {
            id: None,
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://example.com/webhook"))
            .await
            .unwrap();

        let retrieved = store.get("task-1", "config-1").await;
        assert_eq!(
            retrieved.map(|c| c.url).as_deref(),
            Some("https://example.com/webhook")
        );
    }

    #[tokio::test]
    async fn test_missing_entries() {
        let store = WebhookStore::new();
        assert!(store.get("no-such-task", "config-1").await.is_none());
        assert!(store.list("no-such-task").await.is_empty());
        assert!(store.get_configs_for_task("no-such-task").await.is_empty());
        assert!(!store.delete("no-such-task", "config-1").await);

        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        assert!(store.get("task-1", "no-such-config").await.is_none());
        assert!(!store.delete("task-1", "no-such-config").await);
        // A failed delete must not disturb the existing config.
        assert!(store.get("task-1", "config-1").await.is_some());
    }

    #[tokio::test]
    async fn test_list() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 2);
    }

    #[tokio::test]
    async fn test_get_configs_for_task() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();

        let mut urls: Vec<String> = store
            .get_configs_for_task("task-1")
            .await
            .into_iter()
            .map(|c| c.url)
            .collect();
        // Compare order-insensitively; ordering is not what this test checks.
        urls.sort();
        assert_eq!(urls, ["https://a.com", "https://b.com"]);
    }

    #[tokio::test]
    async fn test_tasks_are_isolated() {
        let store = WebhookStore::new();
        store
            .set("task-1", "shared-id", make_config("https://one.com"))
            .await
            .unwrap();
        store
            .set("task-2", "shared-id", make_config("https://two.com"))
            .await
            .unwrap();

        assert_eq!(
            store.get("task-1", "shared-id").await.unwrap().url,
            "https://one.com"
        );
        assert_eq!(
            store.get("task-2", "shared-id").await.unwrap().url,
            "https://two.com"
        );

        assert!(store.delete("task-1", "shared-id").await);
        assert!(store.get("task-2", "shared-id").await.is_some());
    }

    #[tokio::test]
    async fn test_delete() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();

        assert!(store.delete("task-1", "config-1").await);
        assert!(store.get("task-1", "config-1").await.is_none());
        assert!(store.list("task-1").await.is_empty());
        assert!(!store.delete("task-1", "config-1").await);
    }

    #[tokio::test]
    async fn test_remove_task() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();
        store
            .set("task-2", "config-1", make_config("https://c.com"))
            .await
            .unwrap();

        store.remove_task("task-1").await;
        assert!(store.list("task-1").await.is_empty());
        assert!(store.get("task-1", "config-1").await.is_none());
        assert_eq!(store.list("task-2").await.len(), 1);

        // Removing an unknown task is a no-op.
        store.remove_task("no-such-task").await;
        assert_eq!(store.list("task-2").await.len(), 1);
    }

    #[tokio::test]
    async fn test_clones_share_state() {
        let store = WebhookStore::new();
        let handle = store.clone();
        handle
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        assert!(store.get("task-1", "config-1").await.is_some());
    }

    #[tokio::test]
    async fn test_replace_existing() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://old.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-1", make_config("https://new.com"))
            .await
            .unwrap();

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 1);
        assert_eq!(configs[0].config.url, "https://new.com");
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let store = WebhookStore::new();
        let result = store
            .set("task-1", "config-1", make_config("not-a-valid-url"))
            .await;
        assert!(matches!(result, Err(WebhookValidationError::InvalidUrl(_))));
        // A rejected config must not be stored.
        assert!(store.get("task-1", "config-1").await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_scheme() {
        let store = WebhookStore::new();
        let result = store
            .set("task-1", "config-1", make_config("ftp://example.com/webhook"))
            .await;
        match result {
            Err(WebhookValidationError::InvalidScheme(scheme)) => assert_eq!(scheme, "ftp"),
            other => panic!("expected InvalidScheme, got {:?}", other),
        }
        assert!(store.get("task-1", "config-1").await.is_none());
    }

    #[tokio::test]
    async fn test_concurrent_sets() {
        let store = Arc::new(WebhookStore::new());

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    let task_id = format!("task-{}", i % 10);
                    let config_id = format!("config-{}", i);
                    store
                        .set(
                            &task_id,
                            &config_id,
                            make_config(&format!("https://example{}.com", i)),
                        )
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_ok());
        }

        for i in 0..10 {
            let configs = store.list(&format!("task-{}", i)).await;
            assert_eq!(configs.len(), 10);
        }
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(WebhookStore::new());

        for i in 0..10 {
            store
                .set(
                    "task-1",
                    &format!("config-{}", i),
                    make_config(&format!("https://pre{}.com", i)),
                )
                .await
                .unwrap();
        }

        let mut handles = Vec::new();

        // 50 new configs (config-10..config-59) ...
        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .set(
                        "task-1",
                        &format!("config-{}", i),
                        make_config(&format!("https://new{}.com", i)),
                    )
                    .await
                    .unwrap();
            }));
        }

        // ... interleaved with 50 readers ...
        for _ in 0..50 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.list("task-1").await;
            }));
        }

        // ... and 5 deletions of pre-existing configs (config-0..config-4).
        for i in 0..5 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.delete("task-1", &format!("config-{}", i)).await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // 10 pre-existing + 50 new - 5 deleted.
        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 55);
    }
}