a2a_rs_server/
webhook_delivery.rs1use a2a_rs_core::{PushNotificationConfig, StreamResponse};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::{broadcast, Semaphore};
9use tracing::{debug, error, warn};
10
11use crate::webhook_store::WebhookStore;
12
/// Number of webhook deliveries allowed to run at the same time when the
/// caller does not supply an explicit limit.
const DEFAULT_MAX_CONCURRENT: usize = 100;
14
/// Retry policy applied to a single webhook delivery.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many additional attempts are made after the first one fails.
    pub max_retries: u32,
    /// Pause before the first retry.
    pub initial_delay: Duration,
    /// Upper bound for the pause between two attempts.
    pub max_delay: Duration,
    /// Factor by which the pause is multiplied after every retry.
    pub backoff_multiplier: f64,
}
22
23impl Default for RetryConfig {
24 fn default() -> Self {
25 Self {
26 max_retries: 3,
27 initial_delay: Duration::from_millis(500),
28 max_delay: Duration::from_secs(30),
29 backoff_multiplier: 2.0,
30 }
31 }
32}
33
34pub struct WebhookDelivery {
35 client: reqwest::Client,
36 webhook_store: WebhookStore,
37 retry_config: RetryConfig,
38 concurrency_limit: Arc<Semaphore>,
39}
40
41impl WebhookDelivery {
42 pub fn new(webhook_store: WebhookStore) -> Self {
43 Self::with_config(webhook_store, RetryConfig::default(), DEFAULT_MAX_CONCURRENT)
44 }
45
46 pub fn with_config(
47 webhook_store: WebhookStore,
48 retry_config: RetryConfig,
49 max_concurrent: usize,
50 ) -> Self {
51 let client = reqwest::Client::builder()
52 .timeout(Duration::from_secs(30))
53 .build()
54 .expect("Failed to create HTTP client");
55
56 Self {
57 client,
58 webhook_store,
59 retry_config,
60 concurrency_limit: Arc::new(Semaphore::new(max_concurrent)),
61 }
62 }
63
64 pub fn start(self: Arc<Self>, mut event_rx: broadcast::Receiver<StreamResponse>) {
65 tokio::spawn(async move {
66 loop {
67 match event_rx.recv().await {
68 Ok(event) => {
69 self.clone().handle_event(event).await;
70 }
71 Err(broadcast::error::RecvError::Lagged(n)) => {
72 warn!("Webhook delivery lagged, missed {} events", n);
73 }
74 Err(broadcast::error::RecvError::Closed) => {
75 debug!("Event channel closed, stopping webhook delivery");
76 break;
77 }
78 }
79 }
80 });
81 }
82
83 async fn handle_event(self: Arc<Self>, event: StreamResponse) {
84 let task_id = match &event {
85 StreamResponse::Task(t) => &t.id,
86 StreamResponse::StatusUpdate(e) => &e.task_id,
87 StreamResponse::ArtifactUpdate(e) => &e.task_id,
88 StreamResponse::Message(_) => return,
89 };
90
91 let configs = self.webhook_store.get_configs_for_task(task_id).await;
92 if configs.is_empty() {
93 return;
94 }
95
96 for config in configs {
97 let self_clone = self.clone();
98 let event_clone = event.clone();
99 let config_clone = config.clone();
100 let semaphore = self.concurrency_limit.clone();
101
102 tokio::spawn(async move {
103 let _permit = match semaphore.acquire_owned().await {
104 Ok(permit) => permit,
105 Err(_) => {
106 error!("Semaphore closed, cannot deliver webhook");
107 return;
108 }
109 };
110
111 if let Err(e) = self_clone
112 .deliver_with_retry(&config_clone, &event_clone)
113 .await
114 {
115 error!("Failed to deliver webhook to {}: {}", config_clone.url, e);
116 }
117 });
118 }
119 }
120
121 async fn deliver_with_retry(
122 &self,
123 config: &PushNotificationConfig,
124 event: &StreamResponse,
125 ) -> Result<(), WebhookError> {
126 let payload =
127 serde_json::to_string(event).map_err(|e| WebhookError::Serialization(e.to_string()))?;
128
129 let mut delay = self.retry_config.initial_delay;
130 let mut last_error = None;
131
132 for attempt in 0..=self.retry_config.max_retries {
133 if attempt > 0 {
134 debug!("Retry attempt {} for webhook {}", attempt, config.url);
135 tokio::time::sleep(delay).await;
136 delay = std::cmp::min(
137 Duration::from_secs_f64(delay.as_secs_f64() * self.retry_config.backoff_multiplier),
138 self.retry_config.max_delay,
139 );
140 }
141
142 match self.send_request(config, &payload).await {
143 Ok(()) => {
144 debug!("Successfully delivered webhook to {}", config.url);
145 return Ok(());
146 }
147 Err(e) => {
148 warn!(
149 "Webhook delivery to {} failed (attempt {}): {}",
150 config.url,
151 attempt + 1,
152 e
153 );
154 last_error = Some(e);
155 }
156 }
157 }
158
159 Err(last_error.unwrap_or(WebhookError::Unknown))
160 }
161
162 async fn send_request(
163 &self,
164 config: &PushNotificationConfig,
165 payload: &str,
166 ) -> Result<(), WebhookError> {
167 let mut request = self
168 .client
169 .post(&config.url)
170 .header("Content-Type", "application/json")
171 .body(payload.to_string());
172
173 if let Some(auth) = &config.authentication {
175 if let Some(credentials) = &auth.credentials {
176 match auth.scheme.as_str() {
177 "bearer" => {
178 request = request.header("Authorization", format!("Bearer {}", credentials));
179 }
180 _ => {
181 request = request.header("Authorization", format!("{} {}", auth.scheme, credentials));
182 }
183 }
184 }
185 } else if let Some(token) = &config.token {
186 request = request.header("Authorization", format!("Bearer {}", token));
187 }
188
189 let response = request
190 .send()
191 .await
192 .map_err(|e| WebhookError::Network(e.to_string()))?;
193
194 let status = response.status();
195 if status.is_success() {
196 Ok(())
197 } else if status.is_server_error() {
198 Err(WebhookError::ServerError(status.as_u16()))
199 } else {
200 Err(WebhookError::ClientError(status.as_u16()))
201 }
202 }
203}
204
205#[derive(Debug, thiserror::Error)]
206pub enum WebhookError {
207 #[error("Serialization error: {0}")]
208 Serialization(String),
209 #[error("Network error: {0}")]
210 Network(String),
211 #[error("Server error: {0}")]
212 ServerError(u16),
213 #[error("Client error: {0}")]
214 ClientError(u16),
215 #[error("Unknown error")]
216 Unknown,
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy allows three retries and starts with a 500 ms pause.
    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            initial_delay,
            ..
        } = RetryConfig::default();

        assert_eq!(max_retries, 3);
        assert_eq!(initial_delay, Duration::from_millis(500));
    }
}