Skip to main content

a2a_rs_server/
webhook_store.rs

1//! Webhook configuration storage
2//!
3//! Provides thread-safe storage for push notification webhook configurations.
4
5use a2a_rs_core::PushNotificationConfig;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use url::Url;
10
11/// Error when setting a webhook configuration
12#[derive(Debug, thiserror::Error)]
13pub enum WebhookValidationError {
14    #[error("Invalid URL format: {0}")]
15    InvalidUrl(String),
16    #[error("URL scheme must be http or https, got: {0}")]
17    InvalidScheme(String),
18}
19
20/// Stored webhook configuration with ID
21#[derive(Debug, Clone)]
22pub struct StoredWebhookConfig {
23    pub config_id: String,
24    pub config: PushNotificationConfig,
25}
26
27/// Thread-safe in-memory webhook configuration store
28#[derive(Clone)]
29pub struct WebhookStore {
30    configs: Arc<RwLock<HashMap<String, HashMap<String, PushNotificationConfig>>>>,
31}
32
33impl Default for WebhookStore {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl WebhookStore {
40    pub fn new() -> Self {
41        Self {
42            configs: Arc::new(RwLock::new(HashMap::new())),
43        }
44    }
45
46    pub async fn set(
47        &self,
48        task_id: &str,
49        config_id: &str,
50        config: PushNotificationConfig,
51    ) -> Result<(), WebhookValidationError> {
52        let parsed = Url::parse(&config.url)
53            .map_err(|e| WebhookValidationError::InvalidUrl(e.to_string()))?;
54
55        match parsed.scheme() {
56            "http" | "https" => {}
57            scheme => return Err(WebhookValidationError::InvalidScheme(scheme.to_string())),
58        }
59
60        self.configs
61            .write()
62            .await
63            .entry(task_id.to_string())
64            .or_default()
65            .insert(config_id.to_string(), config);
66
67        Ok(())
68    }
69
70    pub async fn get(&self, task_id: &str, config_id: &str) -> Option<PushNotificationConfig> {
71        self.configs
72            .read()
73            .await
74            .get(task_id)?
75            .get(config_id)
76            .cloned()
77    }
78
79    pub async fn list(&self, task_id: &str) -> Vec<StoredWebhookConfig> {
80        let guard = self.configs.read().await;
81        guard
82            .get(task_id)
83            .map(|configs| {
84                configs
85                    .iter()
86                    .map(|(id, config)| StoredWebhookConfig {
87                        config_id: id.clone(),
88                        config: config.clone(),
89                    })
90                    .collect()
91            })
92            .unwrap_or_default()
93    }
94
95    pub async fn delete(&self, task_id: &str, config_id: &str) -> bool {
96        let mut guard = self.configs.write().await;
97        if let Some(configs) = guard.get_mut(task_id) {
98            let removed = configs.remove(config_id).is_some();
99            if configs.is_empty() {
100                guard.remove(task_id);
101            }
102            return removed;
103        }
104        false
105    }
106
107    pub async fn get_configs_for_task(&self, task_id: &str) -> Vec<PushNotificationConfig> {
108        let guard = self.configs.read().await;
109        guard
110            .get(task_id)
111            .map(|configs| configs.values().cloned().collect())
112            .unwrap_or_default()
113    }
114
115    pub async fn remove_task(&self, task_id: &str) {
116        self.configs.write().await.remove(task_id);
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(url: &str) -> PushNotificationConfig {
        PushNotificationConfig {
            id: None,
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let store = WebhookStore::new();

        store
            .set("task-1", "config-1", make_config("https://example.com/webhook"))
            .await
            .unwrap();

        let retrieved = store
            .get("task-1", "config-1")
            .await
            .expect("config should have been stored");
        assert_eq!(retrieved.url, "https://example.com/webhook");
    }

    #[tokio::test]
    async fn test_get_missing() {
        let store = WebhookStore::new();
        assert!(store.get("task-1", "config-1").await.is_none());

        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        assert!(store.get("task-1", "config-2").await.is_none());
        assert!(store.get("task-2", "config-1").await.is_none());
    }

    #[tokio::test]
    async fn test_list() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();

        // Sort locally so the assertion does not depend on the store's listing order.
        let mut ids: Vec<String> = store
            .list("task-1")
            .await
            .into_iter()
            .map(|stored| stored.config_id)
            .collect();
        ids.sort();
        assert_eq!(ids, ["config-1", "config-2"]);

        assert!(store.list("task-2").await.is_empty());
    }

    #[tokio::test]
    async fn test_get_configs_for_task() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();

        let mut urls: Vec<String> = store
            .get_configs_for_task("task-1")
            .await
            .into_iter()
            .map(|config| config.url)
            .collect();
        urls.sort();
        assert_eq!(urls, ["https://a.com", "https://b.com"]);

        assert!(store.get_configs_for_task("task-2").await.is_empty());
    }

    #[tokio::test]
    async fn test_delete() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();

        assert!(store.delete("task-1", "config-1").await);
        assert!(store.get("task-1", "config-1").await.is_none());
        assert!(store.list("task-1").await.is_empty());
        assert!(!store.delete("task-1", "config-1").await);
        assert!(!store.delete("no-such-task", "config-1").await);
    }

    #[tokio::test]
    async fn test_replace_existing() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://old.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-1", make_config("https://new.com"))
            .await
            .unwrap();

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 1);
        assert_eq!(configs[0].config.url, "https://new.com");
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let store = WebhookStore::new();
        let result = store
            .set("task-1", "config-1", make_config("not-a-valid-url"))
            .await;
        assert!(matches!(result, Err(WebhookValidationError::InvalidUrl(_))));
        // A rejected config must not be stored.
        assert!(store.get("task-1", "config-1").await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_scheme() {
        let store = WebhookStore::new();
        let result = store
            .set("task-1", "config-1", make_config("ftp://example.com/webhook"))
            .await;
        assert!(matches!(
            result,
            Err(WebhookValidationError::InvalidScheme(_))
        ));
        // A rejected config must not be stored.
        assert!(store.get("task-1", "config-1").await.is_none());
    }

    #[tokio::test]
    async fn test_remove_task() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", make_config("https://a.com"))
            .await
            .unwrap();
        store
            .set("task-1", "config-2", make_config("https://b.com"))
            .await
            .unwrap();
        store
            .set("task-2", "config-1", make_config("https://c.com"))
            .await
            .unwrap();

        store.remove_task("task-1").await;
        assert!(store.list("task-1").await.is_empty());
        assert_eq!(store.list("task-2").await.len(), 1);

        // Removing a task that has no configurations is a no-op.
        store.remove_task("no-such-task").await;
        assert_eq!(store.list("task-2").await.len(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_sets() {
        // Clones of the store share one underlying map, so no extra `Arc` is needed.
        let store = WebhookStore::new();

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    let task_id = format!("task-{}", i % 10);
                    let config_id = format!("config-{}", i);
                    store
                        .set(
                            &task_id,
                            &config_id,
                            make_config(&format!("https://example{}.com", i)),
                        )
                        .await
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        for i in 0..10 {
            let configs = store.list(&format!("task-{}", i)).await;
            assert_eq!(configs.len(), 10);
        }
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = WebhookStore::new();

        for i in 0..10 {
            store
                .set(
                    "task-1",
                    &format!("config-{}", i),
                    make_config(&format!("https://pre{}.com", i)),
                )
                .await
                .unwrap();
        }

        let mut handles = Vec::new();

        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .set(
                        "task-1",
                        &format!("config-{}", i),
                        make_config(&format!("https://new{}.com", i)),
                    )
                    .await
                    .unwrap();
            }));
        }

        for _ in 0..50 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.list("task-1").await;
            }));
        }

        for i in 0..5 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                // config-0..4 were inserted before any task was spawned and nothing else
                // deletes them, so every one of these deletes must succeed.
                assert!(store.delete("task-1", &format!("config-{}", i)).await);
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let configs = store.list("task-1").await;
        assert_eq!(configs.len(), 55);
        for i in 0..5 {
            assert!(store
                .get("task-1", &format!("config-{}", i))
                .await
                .is_none());
        }
    }
}