a2a_protocol_server/push/
sender.rs1use std::future::Future;
16use std::net::IpAddr;
17use std::pin::Pin;
18
19use a2a_protocol_types::error::{A2aError, A2aResult};
20use a2a_protocol_types::events::StreamResponse;
21use a2a_protocol_types::push::TaskPushNotificationConfig;
22use bytes::Bytes;
23use http_body_util::Full;
24use hyper_util::client::legacy::Client;
25use hyper_util::rt::TokioExecutor;
26
27pub trait PushSender: Send + Sync + 'static {
31 fn send<'a>(
37 &'a self,
38 url: &'a str,
39 event: &'a StreamResponse,
40 config: &'a TaskPushNotificationConfig,
41 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
42}
43
/// Total number of delivery attempts per push notification (first try included).
const MAX_PUSH_ATTEMPTS: usize = 3;

/// Delay slept after each failed attempt except the last one.
///
/// The array length is derived from [`MAX_PUSH_ATTEMPTS`], so changing the
/// attempt count without updating this table is a compile error instead of
/// silently skipping (or never using) a backoff step.
const PUSH_RETRY_BACKOFF: [std::time::Duration; MAX_PUSH_ATTEMPTS - 1] = [
    std::time::Duration::from_secs(1),
    std::time::Duration::from_secs(2),
];

/// Default upper bound on a single webhook request attempt.
const DEFAULT_PUSH_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
55
56#[derive(Debug)]
68pub struct HttpPushSender {
69 client: Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>,
70 request_timeout: std::time::Duration,
71 allow_private_urls: bool,
73}
74
75impl Default for HttpPushSender {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl HttpPushSender {
82 #[must_use]
84 pub fn new() -> Self {
85 Self::with_timeout(DEFAULT_PUSH_REQUEST_TIMEOUT)
86 }
87
88 #[must_use]
90 pub fn with_timeout(request_timeout: std::time::Duration) -> Self {
91 let client = Client::builder(TokioExecutor::new()).build_http();
92 Self {
93 client,
94 request_timeout,
95 allow_private_urls: false,
96 }
97 }
98
99 #[must_use]
104 pub const fn allow_private_urls(mut self) -> Self {
105 self.allow_private_urls = true;
106 self
107 }
108}
109
/// Returns `true` if a webhook must not target `ip`.
///
/// Covered ranges: loopback, RFC 1918 private, link-local, 0.0.0.0/8
/// ("this network", which includes the unspecified address), broadcast,
/// carrier-grade NAT (100.64.0.0/10), IPv6 loopback/unspecified, IPv6
/// unique-local (fc00::/7) and IPv6 link-local (fe80::/10).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their
/// embedded IPv4 address, so e.g. `[::ffff:127.0.0.1]` cannot be used to slip
/// past the IPv4 checks.
// clippy (nursery) would suggest `const fn`; this stays a plain fn so it can
// use any std address predicate regardless of its const-stability.
#[allow(clippy::missing_const_for_fn)]
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || (first & 0xfe00) == 0xfc00 // fc00::/7 unique-local
                || (first & 0xffc0) == 0xfe80 // fe80::/10 link-local
        }
    }
}

/// IPv4 part of [`is_private_ip`].
#[allow(clippy::missing_const_for_fn)]
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, ..] = v4.octets();
    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_broadcast()
        || a == 0 // 0.0.0.0/8, includes the unspecified address
        || (a == 100 && (b & 0xC0) == 64) // 100.64.0.0/10 carrier-grade NAT
}
131
132#[allow(clippy::case_sensitive_file_extension_comparisons)] fn validate_webhook_url(url: &str) -> A2aResult<()> {
137 let uri: hyper::Uri = url
139 .parse()
140 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
141
142 let host = uri
143 .host()
144 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
145
146 let host_bare = host.trim_start_matches('[').trim_end_matches(']');
148
149 if let Ok(ip) = host_bare.parse::<IpAddr>() {
151 if is_private_ip(ip) {
152 return Err(A2aError::invalid_params(format!(
153 "webhook URL targets private/loopback address: {host}"
154 )));
155 }
156 }
157
158 let host_lower = host.to_ascii_lowercase();
160 if host_lower == "localhost"
161 || host_lower.ends_with(".local")
162 || host_lower.ends_with(".internal")
163 {
164 return Err(A2aError::invalid_params(format!(
165 "webhook URL targets local/internal hostname: {host}"
166 )));
167 }
168
169 Ok(())
170}
171
172fn validate_header_value(value: &str, name: &str) -> A2aResult<()> {
174 if value.contains('\r') || value.contains('\n') {
175 return Err(A2aError::invalid_params(format!(
176 "{name} contains invalid characters (CR/LF)"
177 )));
178 }
179 Ok(())
180}
181
182#[allow(clippy::manual_async_fn, clippy::too_many_lines)]
183impl PushSender for HttpPushSender {
184 fn send<'a>(
185 &'a self,
186 url: &'a str,
187 event: &'a StreamResponse,
188 config: &'a TaskPushNotificationConfig,
189 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
190 Box::pin(async move {
191 trace_info!(url, "delivering push notification");
192
193 if !self.allow_private_urls {
195 validate_webhook_url(url)?;
196 }
197
198 if let Some(ref auth) = config.authentication {
200 validate_header_value(&auth.credentials, "authentication credentials")?;
201 validate_header_value(&auth.scheme, "authentication scheme")?;
202 }
203 if let Some(ref token) = config.token {
204 validate_header_value(token, "notification token")?;
205 }
206
207 let body_bytes = serde_json::to_vec(event)
208 .map_err(|e| A2aError::internal(format!("push serialization: {e}")))?;
209
210 let mut last_err = String::new();
211
212 for attempt in 0..MAX_PUSH_ATTEMPTS {
213 let mut builder = hyper::Request::builder()
214 .method(hyper::Method::POST)
215 .uri(url)
216 .header("content-type", "application/json");
217
218 if let Some(ref auth) = config.authentication {
220 match auth.scheme.as_str() {
221 "bearer" => {
222 builder = builder
223 .header("authorization", format!("Bearer {}", auth.credentials));
224 }
225 "basic" => {
226 builder = builder
227 .header("authorization", format!("Basic {}", auth.credentials));
228 }
229 _ => {
230 trace_warn!(
231 scheme = auth.scheme.as_str(),
232 "unknown authentication scheme; no auth header set"
233 );
234 }
235 }
236 }
237
238 if let Some(ref token) = config.token {
240 builder = builder.header("a2a-notification-token", token.as_str());
241 }
242
243 let req = builder
244 .body(Full::new(Bytes::from(body_bytes.clone())))
245 .map_err(|e| A2aError::internal(format!("push request build: {e}")))?;
246
247 let request_result =
248 tokio::time::timeout(self.request_timeout, self.client.request(req)).await;
249
250 match request_result {
251 Ok(Ok(resp)) if resp.status().is_success() => {
252 trace_debug!(url, "push notification delivered");
253 return Ok(());
254 }
255 Ok(Ok(resp)) => {
256 last_err = format!("push notification got HTTP {}", resp.status());
257 trace_warn!(url, attempt, status = %resp.status(), "push delivery failed");
258 }
259 Ok(Err(e)) => {
260 last_err = format!("push notification failed: {e}");
261 trace_warn!(url, attempt, error = %e, "push delivery error");
262 }
263 Err(_) => {
264 last_err = format!(
265 "push notification timed out after {}s",
266 self.request_timeout.as_secs()
267 );
268 trace_warn!(url, attempt, "push delivery timed out");
269 }
270 }
271
272 if attempt < MAX_PUSH_ATTEMPTS - 1 {
274 if let Some(delay) = PUSH_RETRY_BACKOFF.get(attempt) {
275 tokio::time::sleep(*delay).await;
276 }
277 }
278 }
279
280 Err(A2aError::internal(last_err))
281 })
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_loopback_ipv4() {
        assert!(validate_webhook_url("http://127.0.0.1:8080/webhook").is_err());
    }

    #[test]
    fn rejects_private_10_range() {
        assert!(validate_webhook_url("http://10.0.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_private_172_range() {
        assert!(validate_webhook_url("http://172.16.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_private_192_168_range() {
        assert!(validate_webhook_url("http://192.168.1.1/webhook").is_err());
    }

    #[test]
    fn rejects_link_local() {
        assert!(validate_webhook_url("http://169.254.169.254/latest").is_err());
    }

    #[test]
    fn rejects_cgnat_range() {
        assert!(validate_webhook_url("http://100.64.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_localhost() {
        assert!(validate_webhook_url("http://localhost:8080/webhook").is_err());
    }

    #[test]
    fn rejects_uppercase_localhost() {
        assert!(validate_webhook_url("http://LOCALHOST/webhook").is_err());
    }

    #[test]
    fn rejects_dot_local() {
        assert!(validate_webhook_url("http://myservice.local/webhook").is_err());
    }

    #[test]
    fn rejects_dot_internal() {
        assert!(validate_webhook_url("http://metadata.internal/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert!(validate_webhook_url("http://[::1]:8080/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert!(validate_webhook_url("http://[fd00::1]/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert!(validate_webhook_url("http://[fe80::1]/webhook").is_err());
    }

    #[test]
    fn rejects_unparseable_url() {
        assert!(validate_webhook_url("not a url").is_err());
    }

    #[test]
    fn accepts_public_url() {
        assert!(validate_webhook_url("https://example.com/webhook").is_ok());
    }

    #[test]
    fn accepts_public_ip() {
        assert!(validate_webhook_url("https://203.0.113.1/webhook").is_ok());
    }

    #[test]
    fn rejects_header_with_crlf() {
        assert!(validate_header_value("token\r\nX-Injected: value", "test").is_err());
    }

    #[test]
    fn rejects_header_with_cr() {
        assert!(validate_header_value("token\rvalue", "test").is_err());
    }

    #[test]
    fn rejects_header_with_lf() {
        assert!(validate_header_value("token\nvalue", "test").is_err());
    }

    #[test]
    fn accepts_clean_header_value() {
        assert!(validate_header_value("Bearer abc123+/=", "test").is_ok());
    }

    #[test]
    fn accepts_header_with_tab() {
        assert!(validate_header_value("a\tb", "test").is_ok());
    }
}