a2a_rs_server/
webhook_delivery.rs1use a2a_rs_core::{PushNotificationConfig, StreamResponse};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::{broadcast, Semaphore};
9use tracing::{debug, error, warn};
10
11use crate::webhook_store::WebhookStore;
12
/// Number of webhook deliveries allowed in flight at once when
/// `WebhookDelivery::new` is used instead of `WebhookDelivery::with_config`.
const DEFAULT_MAX_CONCURRENT: usize = 100;
14
/// Retry policy for webhook delivery.
///
/// A delivery makes up to `max_retries + 1` attempts in total. Before each retry the sender
/// waits; the wait starts at `initial_delay`, is multiplied by `backoff_multiplier` after each
/// retry, and the grown wait is capped at `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the first attempt (`0` means a single attempt).
    pub max_retries: u32,
    /// Wait before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to the wait as it grows between attempts.
    pub max_delay: Duration,
    /// Factor the wait is multiplied by after each retry; values above `1.0` grow it.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three retries, waiting 500 ms first and doubling the wait up to 30 s.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}
33
34pub struct WebhookDelivery {
35 client: reqwest::Client,
36 webhook_store: WebhookStore,
37 retry_config: RetryConfig,
38 concurrency_limit: Arc<Semaphore>,
39}
40
41impl WebhookDelivery {
42 pub fn new(webhook_store: WebhookStore) -> Self {
43 Self::with_config(
44 webhook_store,
45 RetryConfig::default(),
46 DEFAULT_MAX_CONCURRENT,
47 )
48 }
49
50 pub fn with_config(
51 webhook_store: WebhookStore,
52 retry_config: RetryConfig,
53 max_concurrent: usize,
54 ) -> Self {
55 let client = reqwest::Client::builder()
56 .timeout(Duration::from_secs(30))
57 .build()
58 .expect("Failed to create HTTP client");
59
60 Self {
61 client,
62 webhook_store,
63 retry_config,
64 concurrency_limit: Arc::new(Semaphore::new(max_concurrent)),
65 }
66 }
67
68 pub fn start(self: Arc<Self>, mut event_rx: broadcast::Receiver<StreamResponse>) {
69 tokio::spawn(async move {
70 loop {
71 match event_rx.recv().await {
72 Ok(event) => {
73 self.clone().handle_event(event).await;
74 }
75 Err(broadcast::error::RecvError::Lagged(n)) => {
76 warn!("Webhook delivery lagged, missed {} events", n);
77 }
78 Err(broadcast::error::RecvError::Closed) => {
79 debug!("Event channel closed, stopping webhook delivery");
80 break;
81 }
82 }
83 }
84 });
85 }
86
87 async fn handle_event(self: Arc<Self>, event: StreamResponse) {
88 let task_id = match &event {
89 StreamResponse::Task(t) => &t.id,
90 StreamResponse::StatusUpdate(e) => &e.task_id,
91 StreamResponse::ArtifactUpdate(e) => &e.task_id,
92 StreamResponse::Message(_) => return,
93 };
94
95 let configs = self.webhook_store.get_configs_for_task(task_id).await;
96 if configs.is_empty() {
97 return;
98 }
99
100 for config in configs {
101 let self_clone = self.clone();
102 let event_clone = event.clone();
103 let config_clone = config.clone();
104 let semaphore = self.concurrency_limit.clone();
105
106 tokio::spawn(async move {
107 let _permit = match semaphore.acquire_owned().await {
108 Ok(permit) => permit,
109 Err(_) => {
110 error!("Semaphore closed, cannot deliver webhook");
111 return;
112 }
113 };
114
115 if let Err(e) = self_clone
116 .deliver_with_retry(&config_clone, &event_clone)
117 .await
118 {
119 error!("Failed to deliver webhook to {}: {}", config_clone.url, e);
120 }
121 });
122 }
123 }
124
125 async fn deliver_with_retry(
126 &self,
127 config: &PushNotificationConfig,
128 event: &StreamResponse,
129 ) -> Result<(), WebhookError> {
130 let payload =
131 serde_json::to_string(event).map_err(|e| WebhookError::Serialization(e.to_string()))?;
132
133 let mut delay = self.retry_config.initial_delay;
134 let mut last_error = None;
135
136 for attempt in 0..=self.retry_config.max_retries {
137 if attempt > 0 {
138 debug!("Retry attempt {} for webhook {}", attempt, config.url);
139 tokio::time::sleep(delay).await;
140 delay = std::cmp::min(
141 Duration::from_secs_f64(
142 delay.as_secs_f64() * self.retry_config.backoff_multiplier,
143 ),
144 self.retry_config.max_delay,
145 );
146 }
147
148 match self.send_request(config, &payload).await {
149 Ok(()) => {
150 debug!("Successfully delivered webhook to {}", config.url);
151 return Ok(());
152 }
153 Err(e) => {
154 warn!(
155 "Webhook delivery to {} failed (attempt {}): {}",
156 config.url,
157 attempt + 1,
158 e
159 );
160 last_error = Some(e);
161 }
162 }
163 }
164
165 Err(last_error.unwrap_or(WebhookError::Unknown))
166 }
167
168 async fn send_request(
169 &self,
170 config: &PushNotificationConfig,
171 payload: &str,
172 ) -> Result<(), WebhookError> {
173 let mut request = self
174 .client
175 .post(&config.url)
176 .header("Content-Type", "application/json")
177 .body(payload.to_string());
178
179 if let Some(auth) = &config.authentication {
181 if let Some(credentials) = &auth.credentials {
182 match auth.scheme.as_str() {
183 "bearer" => {
184 request =
185 request.header("Authorization", format!("Bearer {}", credentials));
186 }
187 _ => {
188 request = request
189 .header("Authorization", format!("{} {}", auth.scheme, credentials));
190 }
191 }
192 }
193 } else if let Some(token) = &config.token {
194 request = request.header("Authorization", format!("Bearer {}", token));
195 }
196
197 let response = request
198 .send()
199 .await
200 .map_err(|e| WebhookError::Network(e.to_string()))?;
201
202 let status = response.status();
203 if status.is_success() {
204 Ok(())
205 } else if status.is_server_error() {
206 Err(WebhookError::ServerError(status.as_u16()))
207 } else {
208 Err(WebhookError::ClientError(status.as_u16()))
209 }
210 }
211}
212
213#[derive(Debug, thiserror::Error)]
214pub enum WebhookError {
215 #[error("Serialization error: {0}")]
216 Serialization(String),
217 #[error("Network error: {0}")]
218 Network(String),
219 #[error("Server error: {0}")]
220 ServerError(u16),
221 #[error("Client error: {0}")]
222 ClientError(u16),
223 #[error("Unknown error")]
224 Unknown,
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of the default retry policy, not just the first two.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 2.0);
        // The default policy must not start above its own cap.
        assert!(config.initial_delay <= config.max_delay);
    }
}