Skip to main content

a2a_rs_server/
webhook_store.rs

//! Webhook configuration storage
//!
//! Provides a thread-safe, in-memory store for push notification webhook
//! configurations, keyed by task ID and then by configuration ID.
4
5use a2a_rs_core::PushNotificationConfig;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use url::Url;
10
11/// Error when setting a webhook configuration
12#[derive(Debug, thiserror::Error)]
13pub enum WebhookValidationError {
14    #[error("Invalid URL format: {0}")]
15    InvalidUrl(String),
16    #[error("URL scheme must be http or https, got: {0}")]
17    InvalidScheme(String),
18}
19
20/// Stored webhook configuration with ID
21#[derive(Debug, Clone)]
22pub struct StoredWebhookConfig {
23    pub config_id: String,
24    pub config: PushNotificationConfig,
25}
26
27/// Thread-safe in-memory webhook configuration store
28#[derive(Clone)]
29pub struct WebhookStore {
30    configs: Arc<RwLock<HashMap<String, HashMap<String, PushNotificationConfig>>>>,
31}
32
33impl Default for WebhookStore {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl WebhookStore {
40    pub fn new() -> Self {
41        Self {
42            configs: Arc::new(RwLock::new(HashMap::new())),
43        }
44    }
45
46    pub async fn set(
47        &self,
48        task_id: &str,
49        config_id: &str,
50        config: PushNotificationConfig,
51    ) -> Result<(), WebhookValidationError> {
52        let parsed = Url::parse(&config.url)
53            .map_err(|e| WebhookValidationError::InvalidUrl(e.to_string()))?;
54
55        match parsed.scheme() {
56            "http" | "https" => {}
57            scheme => return Err(WebhookValidationError::InvalidScheme(scheme.to_string())),
58        }
59
60        self.configs
61            .write()
62            .await
63            .entry(task_id.to_string())
64            .or_default()
65            .insert(config_id.to_string(), config);
66
67        Ok(())
68    }
69
70    pub async fn get(&self, task_id: &str, config_id: &str) -> Option<PushNotificationConfig> {
71        self.configs
72            .read()
73            .await
74            .get(task_id)?
75            .get(config_id)
76            .cloned()
77    }
78
79    pub async fn list(&self, task_id: &str) -> Vec<StoredWebhookConfig> {
80        let guard = self.configs.read().await;
81        guard
82            .get(task_id)
83            .map(|configs| {
84                configs
85                    .iter()
86                    .map(|(id, config)| StoredWebhookConfig {
87                        config_id: id.clone(),
88                        config: config.clone(),
89                    })
90                    .collect()
91            })
92            .unwrap_or_default()
93    }
94
95    pub async fn delete(&self, task_id: &str, config_id: &str) -> bool {
96        let mut guard = self.configs.write().await;
97        if let Some(configs) = guard.get_mut(task_id) {
98            let removed = configs.remove(config_id).is_some();
99            if configs.is_empty() {
100                guard.remove(task_id);
101            }
102            return removed;
103        }
104        false
105    }
106
107    pub async fn get_configs_for_task(&self, task_id: &str) -> Vec<PushNotificationConfig> {
108        let guard = self.configs.read().await;
109        guard
110            .get(task_id)
111            .map(|configs| configs.values().cloned().collect())
112            .unwrap_or_default()
113    }
114
115    pub async fn remove_task(&self, task_id: &str) {
116        self.configs.write().await.remove(task_id);
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config carrying only `url`; every optional field is left unset.
    fn webhook(url: &str) -> PushNotificationConfig {
        PushNotificationConfig {
            id: None,
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", webhook("https://example.com/webhook"))
            .await
            .unwrap();

        let fetched = store.get("task-1", "config-1").await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().url, "https://example.com/webhook");
    }

    #[tokio::test]
    async fn test_list() {
        let store = WebhookStore::new();
        for (id, url) in [("config-1", "https://a.com"), ("config-2", "https://b.com")] {
            store.set("task-1", id, webhook(url)).await.unwrap();
        }

        assert_eq!(store.list("task-1").await.len(), 2);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = WebhookStore::new();
        store
            .set("task-1", "config-1", webhook("https://a.com"))
            .await
            .unwrap();

        let first_delete = store.delete("task-1", "config-1").await;
        assert!(first_delete);
        assert!(store.get("task-1", "config-1").await.is_none());
        let second_delete = store.delete("task-1", "config-1").await;
        assert!(!second_delete);
    }

    #[tokio::test]
    async fn test_replace_existing() {
        let store = WebhookStore::new();
        for url in ["https://old.com", "https://new.com"] {
            store.set("task-1", "config-1", webhook(url)).await.unwrap();
        }

        let listed = store.list("task-1").await;
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].config.url, "https://new.com");
    }

    #[tokio::test]
    async fn test_invalid_url() {
        let store = WebhookStore::new();
        let outcome = store
            .set("task-1", "config-1", webhook("not-a-valid-url"))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_invalid_scheme() {
        let store = WebhookStore::new();
        let outcome = store
            .set("task-1", "config-1", webhook("ftp://example.com/webhook"))
            .await;
        assert!(matches!(outcome, Err(WebhookValidationError::InvalidScheme(_))));
    }

    #[tokio::test]
    async fn test_concurrent_sets() {
        let store = Arc::new(WebhookStore::new());

        // 100 writers spread evenly over 10 tasks, each with a unique config ID.
        let mut handles = Vec::with_capacity(100);
        for n in 0..100 {
            let shared = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                let task_id = format!("task-{}", n % 10);
                let config_id = format!("config-{}", n);
                let url = format!("https://example{}.com", n);
                shared.set(&task_id, &config_id, webhook(&url)).await
            }));
        }

        for handle in handles {
            let outcome = handle.await.unwrap();
            assert!(outcome.is_ok());
        }

        for task in 0..10 {
            let listed = store.list(&format!("task-{}", task)).await;
            assert_eq!(listed.len(), 10);
        }
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(WebhookStore::new());

        // Seed config-0..config-9 before any concurrency starts.
        for n in 0..10 {
            let url = format!("https://pre{}.com", n);
            store
                .set("task-1", &format!("config-{}", n), webhook(&url))
                .await
                .unwrap();
        }

        let mut handles = Vec::new();

        // Writers: add config-10..config-59.
        handles.extend((10..60).map(|n| {
            let shared = Arc::clone(&store);
            tokio::spawn(async move {
                let url = format!("https://new{}.com", n);
                shared
                    .set("task-1", &format!("config-{}", n), webhook(&url))
                    .await
                    .unwrap();
            })
        }));

        // Readers: list concurrently with the writers.
        handles.extend((0..50).map(|_| {
            let shared = Arc::clone(&store);
            tokio::spawn(async move {
                let _ = shared.list("task-1").await;
            })
        }));

        // Deleters: remove the seeded config-0..config-4.
        handles.extend((0..5).map(|n| {
            let shared = Arc::clone(&store);
            tokio::spawn(async move {
                shared.delete("task-1", &format!("config-{}", n)).await;
            })
        }));

        for handle in handles {
            handle.await.unwrap();
        }

        // 10 seeded + 50 added - 5 deleted.
        let listed = store.list("task-1").await;
        assert_eq!(listed.len(), 55);
    }
}