a2a_protocol_server/push/
sender.rs1use std::future::Future;
18use std::net::{IpAddr, SocketAddr};
19use std::pin::Pin;
20
21use a2a_protocol_types::error::{A2aError, A2aResult};
22use a2a_protocol_types::events::StreamResponse;
23use a2a_protocol_types::push::TaskPushNotificationConfig;
24use bytes::Bytes;
25use http_body_util::Full;
26use hyper_util::client::legacy::Client;
27use hyper_util::rt::TokioExecutor;
28
29pub trait PushSender: Send + Sync + 'static {
33 fn send<'a>(
39 &'a self,
40 url: &'a str,
41 event: &'a StreamResponse,
42 config: &'a TaskPushNotificationConfig,
43 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
44
45 fn allows_private_urls(&self) -> bool {
51 false
52 }
53}
54
/// Per-request timeout used by `HttpPushSender::new` (30 seconds per attempt).
const DEFAULT_PUSH_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
57
/// Retry schedule used by `HttpPushSender` when a webhook delivery fails.
///
/// Delivery is attempted up to `max_attempts` times. Between attempt `i` and
/// attempt `i + 1` the sender sleeps for `backoff[i]`. Once the list is
/// exhausted its last entry is reused, and an empty list means no delay
/// between attempts.
#[derive(Debug, Clone)]
pub struct PushRetryPolicy {
    /// Total number of delivery attempts, including the first one.
    pub max_attempts: usize,
    /// Delays slept between consecutive attempts (indexed as described above).
    pub backoff: Vec<std::time::Duration>,
}

impl Default for PushRetryPolicy {
    /// Three attempts: wait 1 s after the first failure and 2 s after the second.
    fn default() -> Self {
        let backoff = [1, 2].map(std::time::Duration::from_secs).to_vec();
        Self {
            max_attempts: 3,
            backoff,
        }
    }
}

impl PushRetryPolicy {
    /// Returns this policy with `max_attempts` replaced by `max`; the backoff
    /// schedule is unchanged.
    #[must_use]
    pub const fn with_max_attempts(mut self, max: usize) -> Self {
        self.max_attempts = max;
        self
    }

    /// Returns this policy with its backoff schedule replaced by `backoff`;
    /// `max_attempts` is unchanged.
    #[must_use]
    pub fn with_backoff(self, backoff: Vec<std::time::Duration>) -> Self {
        Self { backoff, ..self }
    }
}
112
113#[derive(Debug)]
124pub struct HttpPushSender {
125 client: Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>,
126 request_timeout: std::time::Duration,
127 retry_policy: PushRetryPolicy,
128 allow_private_urls: bool,
130}
131
132impl Default for HttpPushSender {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl HttpPushSender {
139 #[must_use]
142 pub fn new() -> Self {
143 Self::with_timeout(DEFAULT_PUSH_REQUEST_TIMEOUT)
144 }
145
146 #[must_use]
148 pub fn with_timeout(request_timeout: std::time::Duration) -> Self {
149 let client = Client::builder(TokioExecutor::new()).build_http();
150 Self {
151 client,
152 request_timeout,
153 retry_policy: PushRetryPolicy::default(),
154 allow_private_urls: false,
155 }
156 }
157
158 #[must_use]
160 pub fn with_retry_policy(mut self, policy: PushRetryPolicy) -> Self {
161 self.retry_policy = policy;
162 self
163 }
164
165 #[must_use]
170 pub const fn allow_private_urls(mut self) -> Self {
171 self.allow_private_urls = true;
172 self
173 }
174}
175
/// Returns `true` when `ip` must not be used as an outbound webhook target.
///
/// IPv4 addresses are refused when they fall in:
/// - loopback `127.0.0.0/8`;
/// - RFC 1918 private ranges;
/// - link-local `169.254.0.0/16`, which includes cloud metadata endpoints;
/// - `0.0.0.0/8`, which includes the unspecified address;
/// - CGNAT `100.64.0.0/10`;
/// - multicast `224.0.0.0/4`;
/// - reserved `240.0.0.0/4`, which includes the broadcast address.
///
/// IPv6 addresses are refused when they are:
/// - loopback or unspecified;
/// - unique-local `fc00::/7`;
/// - link-local `fe80::/10`;
/// - deprecated site-local `fec0::/10`;
/// - multicast `ff00::/8`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their
/// embedded IPv4 address. Otherwise `::ffff:127.0.0.1` would slip past the
/// IPv6 checks while a dual-stack socket still connects to loopback.
// `missing_const_for_fn` is allowed so further non-`const` checks can be added
// without changing the signature.
#[allow(clippy::missing_const_for_fn)]
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, ..] = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                // 0.0.0.0/8 "this network" (covers 0.0.0.0 itself)
                || a == 0
                // 100.64.0.0/10 carrier-grade NAT: second octet 64..=127
                || (a == 100 && (b & 0xC0) == 64)
                || v4.is_multicast()
                // 240.0.0.0/4 reserved, including 255.255.255.255
                || a >= 240
        }
        IpAddr::V6(v6) => {
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(mapped));
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fc00::/7 unique local
                || (first & 0xfe00) == 0xfc00
                // fe80::/10 link-local
                || (first & 0xffc0) == 0xfe80
                // fec0::/10 deprecated site-local
                || (first & 0xffc0) == 0xfec0
                || v6.is_multicast()
        }
    }
}
197
198#[allow(clippy::case_sensitive_file_extension_comparisons)] pub(crate) fn validate_webhook_url(url: &str) -> A2aResult<()> {
204 let uri: hyper::Uri = url
206 .parse()
207 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
208
209 match uri.scheme_str() {
211 Some("http" | "https") => {}
212 Some(other) => {
213 return Err(A2aError::invalid_params(format!(
214 "webhook URL has unsupported scheme: {other} (expected http or https)"
215 )));
216 }
217 None => {
218 return Err(A2aError::invalid_params(
219 "webhook URL missing scheme (expected http:// or https://)",
220 ));
221 }
222 }
223
224 let host = uri
225 .host()
226 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
227
228 let host_bare = host.trim_start_matches('[').trim_end_matches(']');
230
231 if let Ok(ip) = host_bare.parse::<IpAddr>() {
233 if is_private_ip(ip) {
234 return Err(A2aError::invalid_params(format!(
235 "webhook URL targets private/loopback address: {host}"
236 )));
237 }
238 }
239
240 let host_lower = host.to_ascii_lowercase();
242 if host_lower == "localhost"
243 || host_lower.ends_with(".local")
244 || host_lower.ends_with(".internal")
245 {
246 return Err(A2aError::invalid_params(format!(
247 "webhook URL targets local/internal hostname: {host}"
248 )));
249 }
250
251 Ok(())
252}
253
254pub(crate) async fn validate_webhook_url_with_dns(url: &str) -> A2aResult<Option<SocketAddr>> {
271 validate_webhook_url(url)?;
273
274 let uri: hyper::Uri = url
276 .parse()
277 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
278
279 let host = uri
280 .host()
281 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
282
283 let host_bare = host.trim_start_matches('[').trim_end_matches(']');
285
286 if host_bare.parse::<IpAddr>().is_ok() {
289 return Ok(None);
290 }
291
292 let port = uri.port_u16().unwrap_or_else(|| {
294 if uri.scheme_str() == Some("https") {
295 443
296 } else {
297 80
298 }
299 });
300
301 let addr = format!("{host_bare}:{port}");
302 let resolved = tokio::net::lookup_host(&addr).await.map_err(|e| {
303 A2aError::invalid_params(format!(
304 "webhook URL hostname could not be resolved: {host_bare}: {e}"
305 ))
306 })?;
307
308 let mut pinned: Option<SocketAddr> = None;
309 for socket_addr in resolved {
310 let ip = socket_addr.ip();
311 if is_private_ip(ip) {
312 return Err(A2aError::invalid_params(format!(
313 "webhook URL hostname {host_bare} resolves to private/loopback address: {ip}"
314 )));
315 }
316 if pinned.is_none() {
317 pinned = Some(socket_addr);
318 }
319 }
320
321 pinned
322 .ok_or_else(|| {
323 A2aError::invalid_params(format!(
324 "webhook URL hostname {host_bare} did not resolve to any addresses"
325 ))
326 })
327 .map(Some)
328}
329
330fn rewrite_uri_with_pinned_addr(url: &str, pinned: SocketAddr) -> A2aResult<hyper::Uri> {
340 let uri: hyper::Uri = url
341 .parse()
342 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
343
344 let scheme = uri
345 .scheme_str()
346 .ok_or_else(|| A2aError::invalid_params("webhook URL missing scheme"))?;
347
348 let host_str = match pinned.ip() {
350 IpAddr::V4(v4) => v4.to_string(),
351 IpAddr::V6(v6) => format!("[{v6}]"),
352 };
353
354 let path_and_query = uri
355 .path_and_query()
356 .map_or_else(|| "/".to_string(), std::string::ToString::to_string);
357
358 let rewritten = format!(
359 "{scheme}://{host_str}:{port}{path_and_query}",
360 port = pinned.port()
361 );
362
363 rewritten
364 .parse()
365 .map_err(|e| A2aError::invalid_params(format!("could not rewrite webhook URL: {e}")))
366}
367
368fn host_header_from_url(url: &str) -> A2aResult<String> {
374 let uri: hyper::Uri = url
375 .parse()
376 .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
377 let host = uri
378 .host()
379 .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
380 Ok(uri
381 .port_u16()
382 .map_or_else(|| host.to_string(), |port| format!("{host}:{port}")))
383}
384
385fn validate_header_value(value: &str, name: &str) -> A2aResult<()> {
387 if value.contains('\r') || value.contains('\n') {
388 return Err(A2aError::invalid_params(format!(
389 "{name} contains invalid characters (CR/LF)"
390 )));
391 }
392 Ok(())
393}
394
395#[allow(clippy::manual_async_fn, clippy::too_many_lines)]
396impl PushSender for HttpPushSender {
397 fn allows_private_urls(&self) -> bool {
398 self.allow_private_urls
399 }
400
401 fn send<'a>(
402 &'a self,
403 url: &'a str,
404 event: &'a StreamResponse,
405 config: &'a TaskPushNotificationConfig,
406 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
407 Box::pin(async move {
408 trace_info!(url, "delivering push notification");
409
410 let pinned_addr = if self.allow_private_urls {
419 None
420 } else {
421 validate_webhook_url_with_dns(url).await?
422 };
423
424 let (pinned_uri, pinned_host_header) = if let Some(addr) = pinned_addr {
425 (
426 Some(rewrite_uri_with_pinned_addr(url, addr)?),
427 Some(host_header_from_url(url)?),
428 )
429 } else {
430 (None, None)
431 };
432
433 if let Some(ref auth) = config.authentication {
435 validate_header_value(&auth.credentials, "authentication credentials")?;
436 validate_header_value(&auth.scheme, "authentication scheme")?;
437 }
438 if let Some(ref token) = config.token {
439 validate_header_value(token, "notification token")?;
440 }
441
442 let body_bytes: Bytes = serde_json::to_vec(event)
443 .map(Bytes::from)
444 .map_err(|e| A2aError::internal(format!("push serialization: {e}")))?;
445
446 let mut last_err = String::new();
447
448 for attempt in 0..self.retry_policy.max_attempts {
449 let mut builder = hyper::Request::builder()
450 .method(hyper::Method::POST)
451 .header("content-type", "application/json");
452
453 if let Some(uri) = pinned_uri.as_ref() {
454 builder = builder.uri(uri.clone());
455 if let Some(host) = pinned_host_header.as_deref() {
456 builder = builder.header("host", host);
457 }
458 } else {
459 builder = builder.uri(url);
460 }
461
462 if let Some(ref auth) = config.authentication {
464 match auth.scheme.as_str() {
465 "bearer" => {
466 builder = builder
467 .header("authorization", format!("Bearer {}", auth.credentials));
468 }
469 "basic" => {
470 builder = builder
471 .header("authorization", format!("Basic {}", auth.credentials));
472 }
473 _ => {
474 trace_warn!(
475 scheme = auth.scheme.as_str(),
476 "unknown authentication scheme; no auth header set"
477 );
478 }
479 }
480 }
481
482 if let Some(ref token) = config.token {
484 builder = builder.header("a2a-notification-token", token.as_str());
485 }
486
487 let req = builder
488 .body(Full::new(body_bytes.clone()))
489 .map_err(|e| A2aError::internal(format!("push request build: {e}")))?;
490
491 let request_result =
492 tokio::time::timeout(self.request_timeout, self.client.request(req)).await;
493
494 match request_result {
495 Ok(Ok(resp)) if resp.status().is_success() => {
496 trace_debug!(url, "push notification delivered");
497 return Ok(());
498 }
499 Ok(Ok(resp)) => {
500 last_err = format!("push notification got HTTP {}", resp.status());
501 trace_warn!(url, attempt, status = %resp.status(), "push delivery failed");
502 }
503 Ok(Err(e)) => {
504 last_err = format!("push notification failed: {e}");
505 trace_warn!(url, attempt, error = %e, "push delivery error");
506 }
507 Err(_) => {
508 last_err = format!(
509 "push notification timed out after {}s",
510 self.request_timeout.as_secs()
511 );
512 trace_warn!(url, attempt, "push delivery timed out");
513 }
514 }
515
516 if attempt < self.retry_policy.max_attempts - 1 {
518 let delay = self
519 .retry_policy
520 .backoff
521 .get(attempt)
522 .or_else(|| self.retry_policy.backoff.last());
523 if let Some(delay) = delay {
524 tokio::time::sleep(*delay).await;
525 }
526 }
527 }
528
529 Err(A2aError::internal(last_err))
530 })
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Fails the test unless the static validator refuses `url`.
    fn assert_rejected(url: &str) {
        assert!(validate_webhook_url(url).is_err(), "expected rejection: {url}");
    }

    /// Fails the test unless the static validator accepts `url`.
    fn assert_accepted(url: &str) {
        assert!(validate_webhook_url(url).is_ok(), "expected acceptance: {url}");
    }

    // ---- PushRetryPolicy / HttpPushSender configuration --------------------

    #[test]
    fn push_retry_policy_default() {
        let policy = PushRetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.backoff, [Duration::from_secs(1), Duration::from_secs(2)]);
    }

    #[test]
    fn push_retry_policy_with_max_attempts() {
        let policy = PushRetryPolicy::default().with_max_attempts(5);
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.backoff.len(), 2, "backoff must be left untouched");
    }

    #[test]
    fn push_retry_policy_with_backoff() {
        let schedule = vec![
            Duration::from_millis(100),
            Duration::from_millis(500),
            Duration::from_secs(1),
        ];
        let policy = PushRetryPolicy::default().with_backoff(schedule.clone());
        assert_eq!(policy.backoff, schedule);
        assert_eq!(policy.max_attempts, 3, "max_attempts must be left untouched");
    }

    #[test]
    fn http_push_sender_with_retry_policy() {
        let sender = HttpPushSender::new()
            .with_retry_policy(PushRetryPolicy::default().with_max_attempts(10));
        assert_eq!(sender.retry_policy.max_attempts, 10);
    }

    #[test]
    fn http_push_sender_allow_private_urls() {
        assert!(HttpPushSender::new().allow_private_urls().allow_private_urls);
    }

    #[test]
    fn http_push_sender_default() {
        let sender = HttpPushSender::default();
        assert_eq!(sender.request_timeout, DEFAULT_PUSH_REQUEST_TIMEOUT);
        assert!(!sender.allow_private_urls);
    }

    // ---- validate_webhook_url: scheme and host shape -----------------------

    #[test]
    fn rejects_url_without_host() {
        assert_rejected("http:///path");
    }

    #[test]
    fn rejects_url_without_scheme() {
        assert_rejected("example.com/webhook");
    }

    #[test]
    fn rejects_ftp_scheme() {
        assert_rejected("ftp://example.com/webhook");
    }

    #[test]
    fn rejects_file_scheme() {
        assert_rejected("file:///etc/passwd");
    }

    #[test]
    fn accepts_http_scheme() {
        assert_accepted("http://example.com/webhook");
    }

    #[test]
    fn accepts_public_url() {
        assert_accepted("https://example.com/webhook");
    }

    // ---- validate_webhook_url: address ranges and local names ---------------

    #[test]
    fn rejects_loopback_ipv4() {
        assert_rejected("http://127.0.0.1:8080/webhook");
    }

    #[test]
    fn rejects_private_10_range() {
        assert_rejected("http://10.0.0.1/webhook");
    }

    #[test]
    fn rejects_private_172_range() {
        assert_rejected("http://172.16.0.1/webhook");
    }

    #[test]
    fn rejects_private_192_168_range() {
        assert_rejected("http://192.168.1.1/webhook");
    }

    #[test]
    fn rejects_link_local() {
        assert_rejected("http://169.254.169.254/latest");
    }

    #[test]
    fn rejects_cgnat_range() {
        assert_rejected("http://100.64.0.1/webhook");
    }

    #[test]
    fn rejects_unspecified_ipv4() {
        assert_rejected("http://0.0.0.0/webhook");
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert_rejected("http://[::1]:8080/webhook");
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert_rejected("http://[fc00::1]:8080/webhook");
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert_rejected("http://[fe80::1]:8080/webhook");
    }

    #[test]
    fn rejects_localhost() {
        assert_rejected("http://localhost:8080/webhook");
    }

    #[test]
    fn rejects_dot_local() {
        assert_rejected("http://myservice.local/webhook");
    }

    #[test]
    fn rejects_dot_internal() {
        assert_rejected("http://metadata.internal/webhook");
    }

    #[test]
    fn accepts_public_ip() {
        assert_accepted("https://203.0.113.1/webhook");
    }

    // ---- validate_header_value ---------------------------------------------

    #[test]
    fn rejects_header_with_crlf() {
        assert!(validate_header_value("token\r\nX-Injected: value", "test").is_err());
    }

    #[test]
    fn rejects_header_with_cr() {
        assert!(validate_header_value("token\rvalue", "test").is_err());
    }

    #[test]
    fn rejects_header_with_lf() {
        assert!(validate_header_value("token\nvalue", "test").is_err());
    }

    #[test]
    fn accepts_clean_header_value() {
        assert!(validate_header_value("Bearer abc123+/=", "test").is_ok());
    }

    // ---- validate_webhook_url_with_dns -------------------------------------

    #[tokio::test]
    async fn dns_rejects_loopback_ip_literal() {
        let outcome = validate_webhook_url_with_dns("http://127.0.0.1:8080/webhook").await;
        assert!(outcome.is_err(), "loopback IP should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_private_ip_literal() {
        let outcome = validate_webhook_url_with_dns("http://10.0.0.1/webhook").await;
        assert!(outcome.is_err(), "private IP should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_localhost_hostname() {
        let outcome = validate_webhook_url_with_dns("http://localhost:8080/webhook").await;
        assert!(outcome.is_err(), "localhost should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_invalid_scheme() {
        let outcome = validate_webhook_url_with_dns("ftp://example.com/webhook").await;
        assert!(outcome.is_err(), "ftp scheme should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_missing_host() {
        let outcome = validate_webhook_url_with_dns("http:///path").await;
        assert!(outcome.is_err(), "missing host should be rejected");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dns_rejects_unresolvable_hostname() {
        // The lookup runs on its own OS thread and runtime, so a resolver that
        // hangs cannot stall this test's runtime; the wait below is capped at 5 s.
        let (tx, rx) = tokio::sync::oneshot::channel();
        std::thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let outcome = runtime.block_on(validate_webhook_url_with_dns(
                "https://this-hostname-definitely-does-not-exist-a2a-test.invalid/webhook",
            ));
            let _ = tx.send(outcome);
        });
        match tokio::time::timeout(Duration::from_secs(5), rx).await {
            Ok(Ok(outcome)) => {
                assert!(outcome.is_err(), "unresolvable hostname should be rejected");
            }
            Ok(Err(_)) => panic!("sender dropped without sending"),
            // No answer within the cap (presumably no usable resolver in this
            // environment); treated as inconclusive rather than a failure.
            Err(_elapsed) => {}
        }
    }

    #[tokio::test]
    async fn dns_accepts_ip_literal_public() {
        let outcome = validate_webhook_url_with_dns("https://203.0.113.1/webhook").await;
        assert!(
            matches!(outcome, Ok(None)),
            "public IP literal should be accepted with no pinning (got {outcome:?})",
        );
    }

    // ---- rewrite_uri_with_pinned_addr / host_header_from_url --------------

    #[test]
    fn rewrite_uri_preserves_scheme_path_and_query() {
        let pinned: SocketAddr = "203.0.113.1:8080".parse().unwrap();
        let rewritten =
            rewrite_uri_with_pinned_addr("http://example.com:8080/webhook?x=1", pinned).unwrap();
        assert_eq!(rewritten.to_string(), "http://203.0.113.1:8080/webhook?x=1");
    }

    #[test]
    fn rewrite_uri_uses_ipv6_brackets() {
        let pinned: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
        let rewritten =
            rewrite_uri_with_pinned_addr("https://example.com/webhook", pinned).unwrap();
        let rendered = rewritten.to_string();
        assert!(
            rendered.contains("[2001:db8::1]:443"),
            "IPv6 literal should be bracketed: {rendered}",
        );
    }

    #[test]
    fn rewrite_uri_default_path_when_missing() {
        let pinned: SocketAddr = "203.0.113.1:80".parse().unwrap();
        let rewritten = rewrite_uri_with_pinned_addr("http://example.com", pinned).unwrap();
        assert_eq!(rewritten.to_string(), "http://203.0.113.1:80/");
    }

    #[test]
    fn host_header_includes_port_when_present() {
        assert_eq!(
            host_header_from_url("http://example.com:8080/webhook").unwrap(),
            "example.com:8080"
        );
    }

    #[test]
    fn host_header_omits_default_port() {
        assert_eq!(
            host_header_from_url("https://example.com/webhook").unwrap(),
            "example.com"
        );
    }

    #[test]
    fn host_header_from_url_rejects_missing_host() {
        assert!(host_header_from_url("http:///path").is_err());
    }
}