Skip to main content

a2a_rs_server/
webhook_delivery.rs

1//! Webhook delivery engine
2//!
3//! Handles delivery of push notification events to registered webhooks.
4
5use a2a_rs_core::{PushNotificationConfig, StreamResponse};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::{broadcast, Semaphore};
9use tracing::{debug, error, warn};
10
11use crate::webhook_store::WebhookStore;
12
/// Default upper bound on webhook deliveries running at the same time;
/// used by [`WebhookDelivery::new`].
const DEFAULT_MAX_CONCURRENT: usize = 100;
14
/// Exponential back-off policy applied when a webhook delivery fails.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Extra attempts made after the first one fails, so a delivery issues at
    /// most `max_retries + 1` requests.
    pub max_retries: u32,
    /// Time waited before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the wait between two attempts.
    pub max_delay: Duration,
    /// Factor applied to the wait after every retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and doubling each time, never waiting
    /// longer than 30 s between attempts.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        RetryConfig {
            max_retries: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }
}
33
34pub struct WebhookDelivery {
35    client: reqwest::Client,
36    webhook_store: WebhookStore,
37    retry_config: RetryConfig,
38    concurrency_limit: Arc<Semaphore>,
39}
40
41impl WebhookDelivery {
42    pub fn new(webhook_store: WebhookStore) -> Self {
43        Self::with_config(webhook_store, RetryConfig::default(), DEFAULT_MAX_CONCURRENT)
44    }
45
46    pub fn with_config(
47        webhook_store: WebhookStore,
48        retry_config: RetryConfig,
49        max_concurrent: usize,
50    ) -> Self {
51        let client = reqwest::Client::builder()
52            .timeout(Duration::from_secs(30))
53            .build()
54            .expect("Failed to create HTTP client");
55
56        Self {
57            client,
58            webhook_store,
59            retry_config,
60            concurrency_limit: Arc::new(Semaphore::new(max_concurrent)),
61        }
62    }
63
64    pub fn start(self: Arc<Self>, mut event_rx: broadcast::Receiver<StreamResponse>) {
65        tokio::spawn(async move {
66            loop {
67                match event_rx.recv().await {
68                    Ok(event) => {
69                        self.clone().handle_event(event).await;
70                    }
71                    Err(broadcast::error::RecvError::Lagged(n)) => {
72                        warn!("Webhook delivery lagged, missed {} events", n);
73                    }
74                    Err(broadcast::error::RecvError::Closed) => {
75                        debug!("Event channel closed, stopping webhook delivery");
76                        break;
77                    }
78                }
79            }
80        });
81    }
82
83    async fn handle_event(self: Arc<Self>, event: StreamResponse) {
84        let task_id = match &event {
85            StreamResponse::Task(t) => &t.id,
86            StreamResponse::StatusUpdate(e) => &e.task_id,
87            StreamResponse::ArtifactUpdate(e) => &e.task_id,
88            StreamResponse::Message(_) => return,
89        };
90
91        let configs = self.webhook_store.get_configs_for_task(task_id).await;
92        if configs.is_empty() {
93            return;
94        }
95
96        for config in configs {
97            let self_clone = self.clone();
98            let event_clone = event.clone();
99            let config_clone = config.clone();
100            let semaphore = self.concurrency_limit.clone();
101
102            tokio::spawn(async move {
103                let _permit = match semaphore.acquire_owned().await {
104                    Ok(permit) => permit,
105                    Err(_) => {
106                        error!("Semaphore closed, cannot deliver webhook");
107                        return;
108                    }
109                };
110
111                if let Err(e) = self_clone
112                    .deliver_with_retry(&config_clone, &event_clone)
113                    .await
114                {
115                    error!("Failed to deliver webhook to {}: {}", config_clone.url, e);
116                }
117            });
118        }
119    }
120
121    async fn deliver_with_retry(
122        &self,
123        config: &PushNotificationConfig,
124        event: &StreamResponse,
125    ) -> Result<(), WebhookError> {
126        let payload =
127            serde_json::to_string(event).map_err(|e| WebhookError::Serialization(e.to_string()))?;
128
129        let mut delay = self.retry_config.initial_delay;
130        let mut last_error = None;
131
132        for attempt in 0..=self.retry_config.max_retries {
133            if attempt > 0 {
134                debug!("Retry attempt {} for webhook {}", attempt, config.url);
135                tokio::time::sleep(delay).await;
136                delay = std::cmp::min(
137                    Duration::from_secs_f64(delay.as_secs_f64() * self.retry_config.backoff_multiplier),
138                    self.retry_config.max_delay,
139                );
140            }
141
142            match self.send_request(config, &payload).await {
143                Ok(()) => {
144                    debug!("Successfully delivered webhook to {}", config.url);
145                    return Ok(());
146                }
147                Err(e) => {
148                    warn!(
149                        "Webhook delivery to {} failed (attempt {}): {}",
150                        config.url,
151                        attempt + 1,
152                        e
153                    );
154                    last_error = Some(e);
155                }
156            }
157        }
158
159        Err(last_error.unwrap_or(WebhookError::Unknown))
160    }
161
162    async fn send_request(
163        &self,
164        config: &PushNotificationConfig,
165        payload: &str,
166    ) -> Result<(), WebhookError> {
167        let mut request = self
168            .client
169            .post(&config.url)
170            .header("Content-Type", "application/json")
171            .body(payload.to_string());
172
173        // Add authentication if configured
174        if let Some(auth) = &config.authentication {
175            if let Some(credentials) = &auth.credentials {
176                match auth.scheme.as_str() {
177                    "bearer" => {
178                        request = request.header("Authorization", format!("Bearer {}", credentials));
179                    }
180                    _ => {
181                        request = request.header("Authorization", format!("{} {}", auth.scheme, credentials));
182                    }
183                }
184            }
185        } else if let Some(token) = &config.token {
186            request = request.header("Authorization", format!("Bearer {}", token));
187        }
188
189        let response = request
190            .send()
191            .await
192            .map_err(|e| WebhookError::Network(e.to_string()))?;
193
194        let status = response.status();
195        if status.is_success() {
196            Ok(())
197        } else if status.is_server_error() {
198            Err(WebhookError::ServerError(status.as_u16()))
199        } else {
200            Err(WebhookError::ClientError(status.as_u16()))
201        }
202    }
203}
204
205#[derive(Debug, thiserror::Error)]
206pub enum WebhookError {
207    #[error("Serialization error: {0}")]
208    Serialization(String),
209    #[error("Network error: {0}")]
210    Network(String),
211    #[error("Server error: {0}")]
212    ServerError(u16),
213    #[error("Client error: {0}")]
214    ClientError(u16),
215    #[error("Unknown error")]
216    Unknown,
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default retry policy matches the documented values.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    /// The `Display` output of each error variant is stable; it ends up in logs.
    #[test]
    fn test_webhook_error_display() {
        assert_eq!(
            WebhookError::Serialization("bad".to_string()).to_string(),
            "Serialization error: bad"
        );
        assert_eq!(
            WebhookError::Network("refused".to_string()).to_string(),
            "Network error: refused"
        );
        assert_eq!(WebhookError::ServerError(503).to_string(), "Server error: 503");
        assert_eq!(WebhookError::ClientError(404).to_string(), "Client error: 404");
        assert_eq!(WebhookError::Unknown.to_string(), "Unknown error");
    }
}