a2a_rs_server/
webhook_store.rs1use a2a_rs_core::PushNotificationConfig;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use url::Url;
10
11#[derive(Debug, thiserror::Error)]
13pub enum WebhookValidationError {
14 #[error("Invalid URL format: {0}")]
15 InvalidUrl(String),
16 #[error("URL scheme must be http or https, got: {0}")]
17 InvalidScheme(String),
18}
19
20#[derive(Debug, Clone)]
22pub struct StoredWebhookConfig {
23 pub config_id: String,
24 pub config: PushNotificationConfig,
25}
26
27#[derive(Clone)]
29pub struct WebhookStore {
30 configs: Arc<RwLock<HashMap<String, HashMap<String, PushNotificationConfig>>>>,
31}
32
33impl Default for WebhookStore {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl WebhookStore {
40 pub fn new() -> Self {
41 Self {
42 configs: Arc::new(RwLock::new(HashMap::new())),
43 }
44 }
45
46 pub async fn set(
47 &self,
48 task_id: &str,
49 config_id: &str,
50 config: PushNotificationConfig,
51 ) -> Result<(), WebhookValidationError> {
52 let parsed = Url::parse(&config.url)
53 .map_err(|e| WebhookValidationError::InvalidUrl(e.to_string()))?;
54
55 match parsed.scheme() {
56 "http" | "https" => {}
57 scheme => return Err(WebhookValidationError::InvalidScheme(scheme.to_string())),
58 }
59
60 self.configs
61 .write()
62 .await
63 .entry(task_id.to_string())
64 .or_default()
65 .insert(config_id.to_string(), config);
66
67 Ok(())
68 }
69
70 pub async fn get(&self, task_id: &str, config_id: &str) -> Option<PushNotificationConfig> {
71 self.configs
72 .read()
73 .await
74 .get(task_id)?
75 .get(config_id)
76 .cloned()
77 }
78
79 pub async fn list(&self, task_id: &str) -> Vec<StoredWebhookConfig> {
80 let guard = self.configs.read().await;
81 guard
82 .get(task_id)
83 .map(|configs| {
84 configs
85 .iter()
86 .map(|(id, config)| StoredWebhookConfig {
87 config_id: id.clone(),
88 config: config.clone(),
89 })
90 .collect()
91 })
92 .unwrap_or_default()
93 }
94
95 pub async fn delete(&self, task_id: &str, config_id: &str) -> bool {
96 let mut guard = self.configs.write().await;
97 if let Some(configs) = guard.get_mut(task_id) {
98 let removed = configs.remove(config_id).is_some();
99 if configs.is_empty() {
100 guard.remove(task_id);
101 }
102 return removed;
103 }
104 false
105 }
106
107 pub async fn get_configs_for_task(&self, task_id: &str) -> Vec<PushNotificationConfig> {
108 let guard = self.configs.read().await;
109 guard
110 .get(task_id)
111 .map(|configs| configs.values().cloned().collect())
112 .unwrap_or_default()
113 }
114
115 pub async fn remove_task(&self, task_id: &str) {
116 self.configs.write().await.remove(task_id);
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal config that only sets the URL.
    fn make_config(url: &str) -> PushNotificationConfig {
        PushNotificationConfig {
            id: None,
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let store = WebhookStore::new();
        let config = make_config("https://example.com/webhook");

        store.set("task-1", "config-1", config.clone()).await.unwrap();

        let retrieved = store
            .get("task-1", "config-1")
            .await
            .expect("config was just stored");
        assert_eq!(retrieved.url, "https://example.com/webhook");
    }

    #[tokio::test]
    async fn test_get_missing() {
        let store = WebhookStore::new();
        assert!(store.get("no-such-task", "config-1").await.is_none());

        store.set("task-1", "config-1", make_config("https://a.com")).await.unwrap();
        assert!(store.get("task-1", "no-such-config").await.is_none());
    }

    #[tokio::test]
    async fn test_list() {
        let store = WebhookStore::new();
        store.set("task-1", "config-1", make_config("https://a.com")).await.unwrap();
        store.set("task-1", "config-2", make_config("https://b.com")).await.unwrap();

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 2);
    }

    #[tokio::test]
    async fn test_list_unknown_task_is_empty() {
        let store = WebhookStore::new();
        assert!(store.list("no-such-task").await.is_empty());
        assert!(store.get_configs_for_task("no-such-task").await.is_empty());
    }

    #[tokio::test]
    async fn test_get_configs_for_task() {
        let store = WebhookStore::new();
        store.set("task-1", "config-1", make_config("https://a.com")).await.unwrap();
        store.set("task-1", "config-2", make_config("https://b.com")).await.unwrap();
        store.set("task-2", "config-1", make_config("https://c.com")).await.unwrap();

        // Sort locally so this test does not depend on the store's ordering.
        let mut urls: Vec<String> = store
            .get_configs_for_task("task-1")
            .await
            .into_iter()
            .map(|c| c.url)
            .collect();
        urls.sort();
        assert_eq!(urls, vec!["https://a.com", "https://b.com"]);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = WebhookStore::new();
        store.set("task-1", "config-1", make_config("https://a.com")).await.unwrap();

        assert!(store.delete("task-1", "config-1").await);
        assert!(store.get("task-1", "config-1").await.is_none());
        assert!(!store.delete("task-1", "config-1").await);
        // Deleting the last config leaves nothing listed for the task.
        assert!(store.list("task-1").await.is_empty());
        // Deleting from an unknown task reports nothing removed.
        assert!(!store.delete("no-such-task", "config-1").await);
    }

    #[tokio::test]
    async fn test_remove_task() {
        let store = WebhookStore::new();
        store.set("task-1", "config-1", make_config("https://a.com")).await.unwrap();
        store.set("task-1", "config-2", make_config("https://b.com")).await.unwrap();
        store.set("task-2", "config-1", make_config("https://c.com")).await.unwrap();

        store.remove_task("task-1").await;

        assert!(store.list("task-1").await.is_empty());
        assert!(store.get("task-1", "config-1").await.is_none());
        // Other tasks are untouched.
        assert_eq!(store.list("task-2").await.len(), 1);
        // Removing an unknown task is a no-op.
        store.remove_task("no-such-task").await;
        assert_eq!(store.list("task-2").await.len(), 1);
    }

    #[tokio::test]
    async fn test_replace_existing() {
        let store = WebhookStore::new();
        store.set("task-1", "config-1", make_config("https://old.com")).await.unwrap();
        store.set("task-1", "config-1", make_config("https://new.com")).await.unwrap();

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 1);
        assert_eq!(configs[0].config.url, "https://new.com");
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let store = WebhookStore::new();
        let config = make_config("not-a-valid-url");
        let result = store.set("task-1", "config-1", config).await;
        assert!(matches!(result, Err(WebhookValidationError::InvalidUrl(_))));
        // A rejected config must not be stored.
        assert!(store.list("task-1").await.is_empty());
    }

    #[tokio::test]
    async fn test_invalid_scheme() {
        let store = WebhookStore::new();
        let config = make_config("ftp://example.com/webhook");
        let result = store.set("task-1", "config-1", config).await;
        assert!(matches!(result, Err(WebhookValidationError::InvalidScheme(_))));
        assert!(store.list("task-1").await.is_empty());
    }

    #[tokio::test]
    async fn test_concurrent_sets() {
        let store = Arc::new(WebhookStore::new());

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    let task_id = format!("task-{}", i % 10);
                    let config_id = format!("config-{}", i);
                    store
                        .set(&task_id, &config_id, make_config(&format!("https://example{}.com", i)))
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_ok());
        }

        for i in 0..10 {
            let configs = store.list(&format!("task-{}", i)).await;
            assert_eq!(configs.len(), 10);
        }
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(WebhookStore::new());

        for i in 0..10 {
            store
                .set("task-1", &format!("config-{}", i), make_config(&format!("https://pre{}.com", i)))
                .await
                .unwrap();
        }

        let mut handles = Vec::new();

        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .set("task-1", &format!("config-{}", i), make_config(&format!("https://new{}.com", i)))
                    .await
                    .unwrap();
            }));
        }

        for _ in 0..50 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.list("task-1").await;
            }));
        }

        for i in 0..5 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.delete("task-1", &format!("config-{}", i)).await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // 10 pre-populated + 50 concurrent inserts - 5 deletes of pre-populated ids.
        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 55);
        for i in 0..5 {
            assert!(store.get("task-1", &format!("config-{}", i)).await.is_none());
        }
    }
}