a2a_protocol_server/push/
sender.rs1use std::future::Future;
18use std::net::IpAddr;
19use std::pin::Pin;
20
21use a2a_protocol_types::error::{A2aError, A2aResult};
22use a2a_protocol_types::events::StreamResponse;
23use a2a_protocol_types::push::TaskPushNotificationConfig;
24use bytes::Bytes;
25use http_body_util::Full;
26use hyper_util::client::legacy::Client;
27use hyper_util::rt::TokioExecutor;
28
29pub trait PushSender: Send + Sync + 'static {
33 fn send<'a>(
39 &'a self,
40 url: &'a str,
41 event: &'a StreamResponse,
42 config: &'a TaskPushNotificationConfig,
43 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
44
45 fn allows_private_urls(&self) -> bool {
51 false
52 }
53}
54
/// Per-attempt request timeout used by `HttpPushSender::new` (30 seconds).
const DEFAULT_PUSH_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
57
/// Retry behaviour for HTTP push deliveries.
///
/// `max_attempts` counts every delivery attempt, the first one included.
/// After failed attempt `i` (0-based) the sender waits `backoff[i]`. Once the
/// list is exhausted, its last entry is reused. An empty list means retries
/// happen back-to-back.
#[derive(Debug, Clone)]
pub struct PushRetryPolicy {
    /// Total number of delivery attempts, including the first.
    pub max_attempts: usize,
    /// Delays between consecutive attempts, indexed by the failed attempt.
    pub backoff: Vec<std::time::Duration>,
}

impl Default for PushRetryPolicy {
    /// Three attempts, waiting 1s and then 2s between them.
    fn default() -> Self {
        let backoff = [1_u64, 2]
            .iter()
            .copied()
            .map(std::time::Duration::from_secs)
            .collect();
        Self {
            max_attempts: 3,
            backoff,
        }
    }
}

impl PushRetryPolicy {
    /// Returns this policy with `max_attempts` replaced by `max`; the backoff
    /// schedule is left untouched.
    #[must_use]
    pub const fn with_max_attempts(mut self, max: usize) -> Self {
        // Field assignment (rather than struct-update syntax) keeps this a
        // valid `const fn` while `self` owns a `Vec`.
        self.max_attempts = max;
        self
    }

    /// Returns this policy with the backoff schedule replaced by `backoff`;
    /// `max_attempts` is left untouched.
    #[must_use]
    pub fn with_backoff(self, backoff: Vec<std::time::Duration>) -> Self {
        Self { backoff, ..self }
    }
}
112
113#[derive(Debug)]
124pub struct HttpPushSender {
125 client: Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>,
126 request_timeout: std::time::Duration,
127 retry_policy: PushRetryPolicy,
128 allow_private_urls: bool,
130}
131
132impl Default for HttpPushSender {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl HttpPushSender {
139 #[must_use]
142 pub fn new() -> Self {
143 Self::with_timeout(DEFAULT_PUSH_REQUEST_TIMEOUT)
144 }
145
146 #[must_use]
148 pub fn with_timeout(request_timeout: std::time::Duration) -> Self {
149 let client = Client::builder(TokioExecutor::new()).build_http();
150 Self {
151 client,
152 request_timeout,
153 retry_policy: PushRetryPolicy::default(),
154 allow_private_urls: false,
155 }
156 }
157
158 #[must_use]
160 pub fn with_retry_policy(mut self, policy: PushRetryPolicy) -> Self {
161 self.retry_policy = policy;
162 self
163 }
164
165 #[must_use]
170 pub const fn allow_private_urls(mut self) -> Self {
171 self.allow_private_urls = true;
172 self
173 }
174}
175
/// Returns `true` when `ip` must not be used as a webhook target.
///
/// - IPv4: loopback, RFC 1918 private, link-local, broadcast, the whole
///   0.0.0.0/8 "this network" block (which includes the unspecified address),
///   and 100.64.0.0/10 carrier-grade NAT.
/// - IPv6: loopback, unspecified, fc00::/7 unique-local, fe80::/10
///   link-local and the deprecated fec0::/10 site-local range.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are unwrapped and judged as
/// the embedded IPv4 address. Otherwise `[::ffff:127.0.0.1]` or
/// `[::ffff:169.254.169.254]` would pass every IPv6 check while still
/// reaching the IPv4 host.
// Kept non-`const`: it only runs at runtime, and `const` would tie the MSRV to
// when each `Ipv*Addr` predicate used here became a `const fn`.
#[allow(clippy::missing_const_for_fn)]
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, ..] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_broadcast()
                // 0.0.0.0/8 "this network", including 0.0.0.0 itself.
                || a == 0
                // 100.64.0.0/10 carrier-grade NAT (RFC 6598).
                || (a == 100 && (b & 0xC0) == 64)
        }
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(v4));
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fc00::/7 unique local (RFC 4193).
                || (first & 0xfe00) == 0xfc00
                // fe80::/10 link-local.
                || (first & 0xffc0) == 0xfe80
                // fec0::/10 deprecated site-local (RFC 3879).
                || (first & 0xffc0) == 0xfec0
        }
    }
}
197
198#[allow(clippy::case_sensitive_file_extension_comparisons)] pub(crate) fn validate_webhook_url(url: &str) -> A2aResult<()> {
204 let uri: hyper::Uri = url
206 .parse()
207 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
208
209 match uri.scheme_str() {
211 Some("http" | "https") => {}
212 Some(other) => {
213 return Err(A2aError::invalid_params(format!(
214 "webhook URL has unsupported scheme: {other} (expected http or https)"
215 )));
216 }
217 None => {
218 return Err(A2aError::invalid_params(
219 "webhook URL missing scheme (expected http:// or https://)",
220 ));
221 }
222 }
223
224 let host = uri
225 .host()
226 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
227
228 let host_bare = host.trim_start_matches('[').trim_end_matches(']');
230
231 if let Ok(ip) = host_bare.parse::<IpAddr>() {
233 if is_private_ip(ip) {
234 return Err(A2aError::invalid_params(format!(
235 "webhook URL targets private/loopback address: {host}"
236 )));
237 }
238 }
239
240 let host_lower = host.to_ascii_lowercase();
242 if host_lower == "localhost"
243 || host_lower.ends_with(".local")
244 || host_lower.ends_with(".internal")
245 {
246 return Err(A2aError::invalid_params(format!(
247 "webhook URL targets local/internal hostname: {host}"
248 )));
249 }
250
251 Ok(())
252}
253
254pub(crate) async fn validate_webhook_url_with_dns(url: &str) -> A2aResult<()> {
261 validate_webhook_url(url)?;
263
264 let uri: hyper::Uri = url
266 .parse()
267 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
268
269 let host = uri
270 .host()
271 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
272
273 let host_bare = host.trim_start_matches('[').trim_end_matches(']');
275
276 if host_bare.parse::<IpAddr>().is_ok() {
278 return Ok(());
279 }
280
281 let port = uri.port_u16().unwrap_or_else(|| {
283 if uri.scheme_str() == Some("https") {
284 443
285 } else {
286 80
287 }
288 });
289
290 let addr = format!("{host_bare}:{port}");
291 let resolved = tokio::net::lookup_host(&addr).await.map_err(|e| {
292 A2aError::invalid_params(format!(
293 "webhook URL hostname could not be resolved: {host_bare}: {e}"
294 ))
295 })?;
296
297 let mut found_any = false;
298 for socket_addr in resolved {
299 found_any = true;
300 let ip = socket_addr.ip();
301 if is_private_ip(ip) {
302 return Err(A2aError::invalid_params(format!(
303 "webhook URL hostname {host_bare} resolves to private/loopback address: {ip}"
304 )));
305 }
306 }
307
308 if !found_any {
309 return Err(A2aError::invalid_params(format!(
310 "webhook URL hostname {host_bare} did not resolve to any addresses"
311 )));
312 }
313
314 Ok(())
315}
316
317fn validate_header_value(value: &str, name: &str) -> A2aResult<()> {
319 if value.contains('\r') || value.contains('\n') {
320 return Err(A2aError::invalid_params(format!(
321 "{name} contains invalid characters (CR/LF)"
322 )));
323 }
324 Ok(())
325}
326
327#[allow(clippy::manual_async_fn, clippy::too_many_lines)]
328impl PushSender for HttpPushSender {
329 fn allows_private_urls(&self) -> bool {
330 self.allow_private_urls
331 }
332
333 fn send<'a>(
334 &'a self,
335 url: &'a str,
336 event: &'a StreamResponse,
337 config: &'a TaskPushNotificationConfig,
338 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
339 Box::pin(async move {
340 trace_info!(url, "delivering push notification");
341
342 if !self.allow_private_urls {
344 validate_webhook_url_with_dns(url).await?;
345 }
346
347 if let Some(ref auth) = config.authentication {
349 validate_header_value(&auth.credentials, "authentication credentials")?;
350 validate_header_value(&auth.scheme, "authentication scheme")?;
351 }
352 if let Some(ref token) = config.token {
353 validate_header_value(token, "notification token")?;
354 }
355
356 let body_bytes: Bytes = serde_json::to_vec(event)
357 .map(Bytes::from)
358 .map_err(|e| A2aError::internal(format!("push serialization: {e}")))?;
359
360 let mut last_err = String::new();
361
362 for attempt in 0..self.retry_policy.max_attempts {
363 let mut builder = hyper::Request::builder()
364 .method(hyper::Method::POST)
365 .uri(url)
366 .header("content-type", "application/json");
367
368 if let Some(ref auth) = config.authentication {
370 match auth.scheme.as_str() {
371 "bearer" => {
372 builder = builder
373 .header("authorization", format!("Bearer {}", auth.credentials));
374 }
375 "basic" => {
376 builder = builder
377 .header("authorization", format!("Basic {}", auth.credentials));
378 }
379 _ => {
380 trace_warn!(
381 scheme = auth.scheme.as_str(),
382 "unknown authentication scheme; no auth header set"
383 );
384 }
385 }
386 }
387
388 if let Some(ref token) = config.token {
390 builder = builder.header("a2a-notification-token", token.as_str());
391 }
392
393 let req = builder
394 .body(Full::new(body_bytes.clone()))
395 .map_err(|e| A2aError::internal(format!("push request build: {e}")))?;
396
397 let request_result =
398 tokio::time::timeout(self.request_timeout, self.client.request(req)).await;
399
400 match request_result {
401 Ok(Ok(resp)) if resp.status().is_success() => {
402 trace_debug!(url, "push notification delivered");
403 return Ok(());
404 }
405 Ok(Ok(resp)) => {
406 last_err = format!("push notification got HTTP {}", resp.status());
407 trace_warn!(url, attempt, status = %resp.status(), "push delivery failed");
408 }
409 Ok(Err(e)) => {
410 last_err = format!("push notification failed: {e}");
411 trace_warn!(url, attempt, error = %e, "push delivery error");
412 }
413 Err(_) => {
414 last_err = format!(
415 "push notification timed out after {}s",
416 self.request_timeout.as_secs()
417 );
418 trace_warn!(url, attempt, "push delivery timed out");
419 }
420 }
421
422 if attempt < self.retry_policy.max_attempts - 1 {
424 let delay = self
425 .retry_policy
426 .backoff
427 .get(attempt)
428 .or_else(|| self.retry_policy.backoff.last());
429 if let Some(delay) = delay {
430 tokio::time::sleep(*delay).await;
431 }
432 }
433 }
434
435 Err(A2aError::internal(last_err))
436 })
437 }
438}
439
// Unit tests for the retry-policy builders, sender construction, static and
// DNS-backed webhook URL validation, and header-value validation.
#[cfg(test)]
mod tests {
    use super::*;

    // Overriding the attempt count keeps the default backoff schedule.
    #[test]
    fn push_retry_policy_with_max_attempts() {
        let policy = PushRetryPolicy::default().with_max_attempts(5);
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.backoff.len(), 2);
    }

    // Overriding the backoff schedule keeps the default attempt count.
    #[test]
    fn push_retry_policy_with_backoff() {
        let backoff = vec![
            std::time::Duration::from_millis(100),
            std::time::Duration::from_millis(500),
            std::time::Duration::from_secs(1),
        ];
        let policy = PushRetryPolicy::default().with_backoff(backoff.clone());
        assert_eq!(policy.backoff, backoff);
        assert_eq!(policy.max_attempts, 3);
    }

    // A custom policy passed to the sender is stored as-is.
    #[test]
    fn http_push_sender_with_retry_policy() {
        let policy = PushRetryPolicy::default().with_max_attempts(10);
        let sender = HttpPushSender::new().with_retry_policy(policy);
        assert_eq!(sender.retry_policy.max_attempts, 10);
    }

    #[test]
    fn rejects_url_without_host() {
        assert!(validate_webhook_url("http:///path").is_err());
    }

    #[test]
    fn http_push_sender_allow_private_urls() {
        let sender = HttpPushSender::new().allow_private_urls();
        assert!(sender.allow_private_urls);
    }

    // `Default` uses the default timeout and keeps private-URL checks on.
    #[test]
    fn http_push_sender_default() {
        let sender = HttpPushSender::default();
        assert_eq!(sender.request_timeout, DEFAULT_PUSH_REQUEST_TIMEOUT);
        assert!(!sender.allow_private_urls);
    }

    // Pins the default retry schedule: 3 attempts with 1s then 2s delays.
    #[test]
    fn push_retry_policy_default() {
        let policy = PushRetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.backoff.len(), 2);
        assert_eq!(policy.backoff[0], std::time::Duration::from_secs(1));
        assert_eq!(policy.backoff[1], std::time::Duration::from_secs(2));
    }

    // --- Static validation: private / loopback IP literals ---

    #[test]
    fn rejects_loopback_ipv4() {
        assert!(validate_webhook_url("http://127.0.0.1:8080/webhook").is_err());
    }

    #[test]
    fn rejects_private_10_range() {
        assert!(validate_webhook_url("http://10.0.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_private_172_range() {
        assert!(validate_webhook_url("http://172.16.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_private_192_168_range() {
        assert!(validate_webhook_url("http://192.168.1.1/webhook").is_err());
    }

    // 169.254.169.254 is the link-local address used by cloud metadata endpoints.
    #[test]
    fn rejects_link_local() {
        assert!(validate_webhook_url("http://169.254.169.254/latest").is_err());
    }

    // --- Static validation: local / internal hostnames ---

    #[test]
    fn rejects_localhost() {
        assert!(validate_webhook_url("http://localhost:8080/webhook").is_err());
    }

    #[test]
    fn rejects_dot_local() {
        assert!(validate_webhook_url("http://myservice.local/webhook").is_err());
    }

    #[test]
    fn rejects_dot_internal() {
        assert!(validate_webhook_url("http://metadata.internal/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert!(validate_webhook_url("http://[::1]:8080/webhook").is_err());
    }

    #[test]
    fn accepts_public_url() {
        assert!(validate_webhook_url("https://example.com/webhook").is_ok());
    }

    // 203.0.113.0/24 (TEST-NET-3) is not classified as private by
    // `is_private_ip`, so it stands in for a public address here.
    #[test]
    fn accepts_public_ip() {
        assert!(validate_webhook_url("https://203.0.113.1/webhook").is_ok());
    }

    // --- Header-value validation (CR/LF injection) ---

    #[test]
    fn rejects_header_with_crlf() {
        assert!(validate_header_value("token\r\nX-Injected: value", "test").is_err());
    }

    #[test]
    fn rejects_header_with_cr() {
        assert!(validate_header_value("token\rvalue", "test").is_err());
    }

    #[test]
    fn rejects_header_with_lf() {
        assert!(validate_header_value("token\nvalue", "test").is_err());
    }

    #[test]
    fn accepts_clean_header_value() {
        assert!(validate_header_value("Bearer abc123+/=", "test").is_ok());
    }

    // --- Static validation: scheme handling ---

    #[test]
    fn rejects_url_without_scheme() {
        assert!(validate_webhook_url("example.com/webhook").is_err());
    }

    #[test]
    fn rejects_ftp_scheme() {
        assert!(validate_webhook_url("ftp://example.com/webhook").is_err());
    }

    #[test]
    fn rejects_file_scheme() {
        assert!(validate_webhook_url("file:///etc/passwd").is_err());
    }

    #[test]
    fn accepts_http_scheme() {
        assert!(validate_webhook_url("http://example.com/webhook").is_ok());
    }

    // --- Static validation: less common private ranges ---

    // 100.64.0.0/10 carrier-grade NAT.
    #[test]
    fn rejects_cgnat_range() {
        assert!(validate_webhook_url("http://100.64.0.1/webhook").is_err());
    }

    #[test]
    fn rejects_unspecified_ipv4() {
        assert!(validate_webhook_url("http://0.0.0.0/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert!(validate_webhook_url("http://[fc00::1]:8080/webhook").is_err());
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert!(validate_webhook_url("http://[fe80::1]:8080/webhook").is_err());
    }

    // --- DNS-backed validation ---

    // IP literals are rejected by the static check before any lookup happens.
    #[tokio::test]
    async fn dns_rejects_loopback_ip_literal() {
        let result = validate_webhook_url_with_dns("http://127.0.0.1:8080/webhook").await;
        assert!(result.is_err(), "loopback IP should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_private_ip_literal() {
        let result = validate_webhook_url_with_dns("http://10.0.0.1/webhook").await;
        assert!(result.is_err(), "private IP should be rejected");
    }

    // `localhost` is caught by the static hostname check, before resolution.
    #[tokio::test]
    async fn dns_rejects_localhost_hostname() {
        let result = validate_webhook_url_with_dns("http://localhost:8080/webhook").await;
        assert!(result.is_err(), "localhost should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_invalid_scheme() {
        let result = validate_webhook_url_with_dns("ftp://example.com/webhook").await;
        assert!(result.is_err(), "ftp scheme should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_missing_host() {
        let result = validate_webhook_url_with_dns("http:///path").await;
        assert!(result.is_err(), "missing host should be rejected");
    }

    // The lookup runs on a separate OS thread with its own runtime and is
    // awaited with a 5s cap. Presumably this keeps a stalled resolver (offline
    // or sandboxed CI) from hanging the test. A timeout is treated as
    // inconclusive rather than as a failure.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dns_rejects_unresolvable_hostname() {
        let (tx, rx) = tokio::sync::oneshot::channel();
        std::thread::spawn(move || {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let result = rt.block_on(validate_webhook_url_with_dns(
                "https://this-hostname-definitely-does-not-exist-a2a-test.invalid/webhook",
            ));
            let _ = tx.send(result);
        });
        match tokio::time::timeout(std::time::Duration::from_secs(5), rx).await {
            Ok(Ok(result)) => {
                assert!(result.is_err(), "unresolvable hostname should be rejected");
            }
            Ok(Err(_)) => panic!("sender dropped without sending"),
            Err(_elapsed) => {
                // Resolution did not finish in time; intentionally not a failure.
            }
        }
    }

    // Public IP literals skip the DNS lookup, so no network access is needed.
    #[tokio::test]
    async fn dns_accepts_ip_literal_public() {
        let result = validate_webhook_url_with_dns("https://203.0.113.1/webhook").await;
        assert!(result.is_ok(), "public IP literal should be accepted");
    }
}