Skip to main content

a2a_rs_server/
webhook_delivery.rs

1//! Webhook delivery engine
2//!
3//! Handles delivery of push notification events to registered webhooks.
4
5use a2a_rs_core::{PushNotificationConfig, StreamResponse};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::{broadcast, Semaphore};
9use tracing::{debug, error, warn};
10
11use crate::webhook_store::WebhookStore;
12
/// Default cap on the number of webhook deliveries allowed to run at the same time.
const DEFAULT_MAX_CONCURRENT: usize = 100;
14
/// Retry policy for webhook delivery.
///
/// A delivery is attempted once and then retried up to `max_retries` times. The wait
/// before the first retry is `initial_delay`; each later wait is the previous one
/// multiplied by `backoff_multiplier` and capped at `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt (`0` disables retrying).
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to each backoff-grown delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after every retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and doubling, capped at 30 s.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        Self {
            max_retries: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
        }
    }
}
33
34pub struct WebhookDelivery {
35    client: reqwest::Client,
36    webhook_store: WebhookStore,
37    retry_config: RetryConfig,
38    concurrency_limit: Arc<Semaphore>,
39}
40
41impl WebhookDelivery {
42    pub fn new(webhook_store: WebhookStore) -> Self {
43        Self::with_config(
44            webhook_store,
45            RetryConfig::default(),
46            DEFAULT_MAX_CONCURRENT,
47        )
48    }
49
50    pub fn with_config(
51        webhook_store: WebhookStore,
52        retry_config: RetryConfig,
53        max_concurrent: usize,
54    ) -> Self {
55        let client = reqwest::Client::builder()
56            .timeout(Duration::from_secs(30))
57            .build()
58            .expect("Failed to create HTTP client");
59
60        Self {
61            client,
62            webhook_store,
63            retry_config,
64            concurrency_limit: Arc::new(Semaphore::new(max_concurrent)),
65        }
66    }
67
68    pub fn start(self: Arc<Self>, mut event_rx: broadcast::Receiver<StreamResponse>) {
69        tokio::spawn(async move {
70            loop {
71                match event_rx.recv().await {
72                    Ok(event) => {
73                        self.clone().handle_event(event).await;
74                    }
75                    Err(broadcast::error::RecvError::Lagged(n)) => {
76                        warn!("Webhook delivery lagged, missed {} events", n);
77                    }
78                    Err(broadcast::error::RecvError::Closed) => {
79                        debug!("Event channel closed, stopping webhook delivery");
80                        break;
81                    }
82                }
83            }
84        });
85    }
86
87    async fn handle_event(self: Arc<Self>, event: StreamResponse) {
88        let task_id = match &event {
89            StreamResponse::Task(t) => &t.id,
90            StreamResponse::StatusUpdate(e) => &e.task_id,
91            StreamResponse::ArtifactUpdate(e) => &e.task_id,
92            StreamResponse::Message(_) => return,
93        };
94
95        let configs = self.webhook_store.get_configs_for_task(task_id).await;
96        if configs.is_empty() {
97            return;
98        }
99
100        for config in configs {
101            let self_clone = self.clone();
102            let event_clone = event.clone();
103            let config_clone = config.clone();
104            let semaphore = self.concurrency_limit.clone();
105
106            tokio::spawn(async move {
107                let _permit = match semaphore.acquire_owned().await {
108                    Ok(permit) => permit,
109                    Err(_) => {
110                        error!("Semaphore closed, cannot deliver webhook");
111                        return;
112                    }
113                };
114
115                if let Err(e) = self_clone
116                    .deliver_with_retry(&config_clone, &event_clone)
117                    .await
118                {
119                    error!("Failed to deliver webhook to {}: {}", config_clone.url, e);
120                }
121            });
122        }
123    }
124
125    async fn deliver_with_retry(
126        &self,
127        config: &PushNotificationConfig,
128        event: &StreamResponse,
129    ) -> Result<(), WebhookError> {
130        let payload =
131            serde_json::to_string(event).map_err(|e| WebhookError::Serialization(e.to_string()))?;
132
133        let mut delay = self.retry_config.initial_delay;
134        let mut last_error = None;
135
136        for attempt in 0..=self.retry_config.max_retries {
137            if attempt > 0 {
138                debug!("Retry attempt {} for webhook {}", attempt, config.url);
139                tokio::time::sleep(delay).await;
140                delay = std::cmp::min(
141                    Duration::from_secs_f64(
142                        delay.as_secs_f64() * self.retry_config.backoff_multiplier,
143                    ),
144                    self.retry_config.max_delay,
145                );
146            }
147
148            match self.send_request(config, &payload).await {
149                Ok(()) => {
150                    debug!("Successfully delivered webhook to {}", config.url);
151                    return Ok(());
152                }
153                Err(e) => {
154                    warn!(
155                        "Webhook delivery to {} failed (attempt {}): {}",
156                        config.url,
157                        attempt + 1,
158                        e
159                    );
160                    last_error = Some(e);
161                }
162            }
163        }
164
165        Err(last_error.unwrap_or(WebhookError::Unknown))
166    }
167
168    async fn send_request(
169        &self,
170        config: &PushNotificationConfig,
171        payload: &str,
172    ) -> Result<(), WebhookError> {
173        let mut request = self
174            .client
175            .post(&config.url)
176            .header("Content-Type", "application/json")
177            .body(payload.to_string());
178
179        // Add authentication if configured
180        if let Some(auth) = &config.authentication {
181            if let Some(credentials) = &auth.credentials {
182                match auth.scheme.as_str() {
183                    "bearer" => {
184                        request =
185                            request.header("Authorization", format!("Bearer {}", credentials));
186                    }
187                    _ => {
188                        request = request
189                            .header("Authorization", format!("{} {}", auth.scheme, credentials));
190                    }
191                }
192            }
193        } else if let Some(token) = &config.token {
194            request = request.header("Authorization", format!("Bearer {}", token));
195        }
196
197        let response = request
198            .send()
199            .await
200            .map_err(|e| WebhookError::Network(e.to_string()))?;
201
202        let status = response.status();
203        if status.is_success() {
204            Ok(())
205        } else if status.is_server_error() {
206            Err(WebhookError::ServerError(status.as_u16()))
207        } else {
208            Err(WebhookError::ClientError(status.as_u16()))
209        }
210    }
211}
212
213#[derive(Debug, thiserror::Error)]
214pub enum WebhookError {
215    #[error("Serialization error: {0}")]
216    Serialization(String),
217    #[error("Network error: {0}")]
218    Network(String),
219    #[error("Server error: {0}")]
220    ServerError(u16),
221    #[error("Client error: {0}")]
222    ClientError(u16),
223    #[error("Unknown error")]
224    Unknown,
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of the default retry policy, not just a subset.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert!((config.backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    /// The default starting delay must not already exceed the cap.
    #[test]
    fn test_retry_config_default_initial_within_max() {
        let config = RetryConfig::default();
        assert!(config.initial_delay <= config.max_delay);
    }

    /// Error messages are part of the logged output; keep them stable.
    #[test]
    fn test_webhook_error_messages() {
        assert_eq!(WebhookError::ServerError(503).to_string(), "Server error: 503");
        assert_eq!(WebhookError::ClientError(404).to_string(), "Client error: 404");
        assert_eq!(
            WebhookError::Network("timeout".to_string()).to_string(),
            "Network error: timeout"
        );
        assert_eq!(
            WebhookError::Serialization("bad".to_string()).to_string(),
            "Serialization error: bad"
        );
        assert_eq!(WebhookError::Unknown.to_string(), "Unknown error");
    }
}