use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use a2a_protocol_types::error::{A2aError, A2aResult};
use a2a_protocol_types::events::StreamResponse;
use a2a_protocol_types::push::TaskPushNotificationConfig;
use bytes::Bytes;
use http_body_util::Full;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
/// Delivers push notifications for task events to client-supplied webhooks.
pub trait PushSender: Send + Sync + 'static {
    /// Reports whether this sender permits webhook URLs that target
    /// private/loopback destinations. The default implementation returns
    /// `false`; `HttpPushSender` reports its `allow_private_urls` setting.
    fn allows_private_urls(&self) -> bool {
        false
    }

    /// Sends `event` to the webhook at `url`, using the delivery settings in
    /// `config`.
    ///
    /// The returned boxed future borrows every argument for `'a`.
    /// Implementations such as `HttpPushSender` resolve to `Ok(())` once the
    /// webhook has accepted the notification.
    fn send<'a>(
        &'a self,
        url: &'a str,
        event: &'a StreamResponse,
        config: &'a TaskPushNotificationConfig,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
}
/// Per-attempt request timeout used by `HttpPushSender::new` (and hence its
/// `Default` impl): 30 seconds.
const DEFAULT_PUSH_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::new(30, 0);
/// Retry behaviour for push-notification delivery.
///
/// `max_attempts` bounds how many delivery attempts are made. `backoff[i]` is
/// the pause taken after failed attempt `i`; once the list runs out its last
/// entry is reused, and an empty list means attempts follow each other
/// without pausing.
#[derive(Clone, Debug)]
pub struct PushRetryPolicy {
    /// Maximum number of delivery attempts.
    pub max_attempts: usize,
    /// Pauses between attempts, indexed by the attempt that just failed.
    pub backoff: Vec<std::time::Duration>,
}

impl Default for PushRetryPolicy {
    /// Three attempts, pausing 1 s after the first failure and 2 s after the
    /// second.
    fn default() -> Self {
        let backoff = [1, 2]
            .iter()
            .copied()
            .map(std::time::Duration::from_secs)
            .collect();
        Self {
            max_attempts: 3,
            backoff,
        }
    }
}

impl PushRetryPolicy {
    /// Returns this policy with `max_attempts` replaced by `max`.
    #[must_use]
    pub const fn with_max_attempts(mut self, max: usize) -> Self {
        self.max_attempts = max;
        self
    }

    /// Returns this policy with its backoff schedule replaced by `backoff`.
    #[must_use]
    pub fn with_backoff(self, backoff: Vec<std::time::Duration>) -> Self {
        Self { backoff, ..self }
    }
}
/// [`PushSender`] that POSTs JSON-encoded events to webhooks with a hyper
/// client, retrying according to a [`PushRetryPolicy`].
///
/// NOTE(review): the client is built with `build_http()` over a plain
/// `HttpConnector`; `https://` webhooks presumably need a TLS-capable
/// connector. Confirm before relying on HTTPS delivery.
#[derive(Debug)]
pub struct HttpPushSender {
    /// HTTP client used for every delivery attempt.
    client: Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>,
    /// Time limit applied to each individual request attempt.
    request_timeout: std::time::Duration,
    /// How many attempts are made and how long to pause between them.
    retry_policy: PushRetryPolicy,
    /// When `true`, webhook URLs are not screened for private/loopback targets.
    allow_private_urls: bool,
}

impl Default for HttpPushSender {
    fn default() -> Self {
        Self::with_timeout(DEFAULT_PUSH_REQUEST_TIMEOUT)
    }
}

impl HttpPushSender {
    /// Creates a sender with the default per-request timeout
    /// (`DEFAULT_PUSH_REQUEST_TIMEOUT`), the default retry policy, and
    /// private-URL screening enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a sender whose individual requests time out after
    /// `request_timeout`, with the default retry policy and private-URL
    /// screening enabled.
    #[must_use]
    pub fn with_timeout(request_timeout: std::time::Duration) -> Self {
        Self {
            client: Client::builder(TokioExecutor::new()).build_http(),
            request_timeout,
            retry_policy: PushRetryPolicy::default(),
            allow_private_urls: false,
        }
    }

    /// Replaces the retry policy.
    #[must_use]
    pub fn with_retry_policy(self, policy: PushRetryPolicy) -> Self {
        Self {
            retry_policy: policy,
            ..self
        }
    }

    /// Opts out of the private/loopback webhook URL screening that is
    /// otherwise performed before each delivery.
    #[must_use]
    pub const fn allow_private_urls(mut self) -> Self {
        self.allow_private_urls = true;
        self
    }
}
/// Returns `true` if a webhook must not target `ip`.
///
/// Rejected ranges:
/// - IPv4: loopback, RFC 1918 private, link-local, `0.0.0.0/8`
///   ("this network", including the unspecified address) and carrier-grade
///   NAT `100.64.0.0/10`.
/// - IPv6: loopback `::1`, unspecified `::`, unique-local `fc00::/7` and
///   link-local `fe80::/10`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by the IPv4
/// address they embed. Otherwise `[::ffff:127.0.0.1]` would slip past the IPv6
/// checks while still reaching the IPv4 loopback on a dual-stack host.
#[allow(clippy::missing_const_for_fn)]
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private_ipv4(v4);
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || (first & 0xfe00) == 0xfc00 // fc00::/7 unique local
                || (first & 0xffc0) == 0xfe80 // fe80::/10 link-local
        }
    }
}

/// IPv4 half of `is_private_ip`.
#[allow(clippy::missing_const_for_fn)]
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [a, b, _, _] = v4.octets();
    v4.is_loopback() // 127.0.0.0/8
        || v4.is_private() // 10/8, 172.16/12, 192.168/16
        || v4.is_link_local() // 169.254/16
        || a == 0 // 0.0.0.0/8 "this network", includes unspecified 0.0.0.0
        || (a == 100 && (b & 0xC0) == 64) // 100.64.0.0/10 carrier-grade NAT
}
/// Synchronously screens a webhook URL for SSRF-prone targets.
///
/// Accepts only `http`/`https` URLs that have a host. Rejects hosts that are
/// private/loopback IP literals (per `is_private_ip`) or local/internal
/// hostnames (`localhost`, `*.localhost`, `*.local`, `*.internal`).
///
/// Trailing dots on the host (the fully-qualified form, e.g. `localhost.` or
/// `127.0.0.1.`) are ignored for these checks so they cannot be used to
/// bypass them.
///
/// No DNS resolution happens here. Only IP literals accepted by Rust's
/// `IpAddr` parser are checked as addresses; hostnames that resolve to private
/// addresses are only caught by `validate_webhook_url_with_dns`.
///
/// # Errors
///
/// Returns an invalid-params [`A2aError`] describing the first failed check.
#[allow(clippy::case_sensitive_file_extension_comparisons)]
pub(crate) fn validate_webhook_url(url: &str) -> A2aResult<()> {
    let uri: hyper::Uri = url
        .parse()
        .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
    match uri.scheme_str() {
        Some("http" | "https") => {}
        Some(other) => {
            return Err(A2aError::invalid_params(format!(
                "webhook URL has unsupported scheme: {other} (expected http or https)"
            )));
        }
        None => {
            return Err(A2aError::invalid_params(
                "webhook URL missing scheme (expected http:// or https://)",
            ));
        }
    }
    let host = uri
        .host()
        .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
    // Strip IPv6 brackets and any trailing FQDN dots so that `127.0.0.1.` and
    // `localhost.` are judged exactly like their undotted forms.
    let host_bare = host
        .trim_start_matches('[')
        .trim_end_matches(']')
        .trim_end_matches('.');
    if let Ok(ip) = host_bare.parse::<IpAddr>() {
        if is_private_ip(ip) {
            return Err(A2aError::invalid_params(format!(
                "webhook URL targets private/loopback address: {host}"
            )));
        }
    }
    let host_lower = host_bare.to_ascii_lowercase();
    if host_lower == "localhost"
        || host_lower.ends_with(".localhost")
        || host_lower.ends_with(".local")
        || host_lower.ends_with(".internal")
    {
        return Err(A2aError::invalid_params(format!(
            "webhook URL targets local/internal hostname: {host}"
        )));
    }
    Ok(())
}
/// Runs `validate_webhook_url`, then resolves the host and rejects the URL if
/// any resolved address is private/loopback.
///
/// IP-literal hosts skip resolution, because the synchronous check has
/// already screened them. When the URL has no explicit port, the scheme
/// default (443 for `https`, 80 otherwise) is used for the lookup.
///
/// NOTE(review): the HTTP client presumably resolves the host again when it
/// connects. A DNS answer that changes between this check and the request
/// (DNS rebinding) would therefore not be caught here.
///
/// # Errors
///
/// Returns an invalid-params [`A2aError`] if the URL fails the synchronous
/// checks, cannot be resolved, resolves to no addresses, or resolves to a
/// private/loopback address.
pub(crate) async fn validate_webhook_url_with_dns(url: &str) -> A2aResult<()> {
    validate_webhook_url(url)?;

    let uri: hyper::Uri = url
        .parse()
        .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
    let host = uri
        .host()
        .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
    let host_bare = host.trim_start_matches('[').trim_end_matches(']');

    // Literal addresses were already screened by `validate_webhook_url`.
    if host_bare.parse::<IpAddr>().is_ok() {
        return Ok(());
    }

    let default_port = match uri.scheme_str() {
        Some("https") => 443,
        _ => 80,
    };
    let lookup_target = format!("{host_bare}:{}", uri.port_u16().unwrap_or(default_port));
    let addresses: Vec<std::net::SocketAddr> = tokio::net::lookup_host(&lookup_target)
        .await
        .map_err(|e| {
            A2aError::invalid_params(format!(
                "webhook URL hostname could not be resolved: {host_bare}: {e}"
            ))
        })?
        .collect();

    if addresses.is_empty() {
        return Err(A2aError::invalid_params(format!(
            "webhook URL hostname {host_bare} did not resolve to any addresses"
        )));
    }

    match addresses
        .iter()
        .map(std::net::SocketAddr::ip)
        .find(|candidate| is_private_ip(*candidate))
    {
        Some(ip) => Err(A2aError::invalid_params(format!(
            "webhook URL hostname {host_bare} resolves to private/loopback address: {ip}"
        ))),
        None => Ok(()),
    }
}
/// Rejects header values that cannot be sent safely as an HTTP header.
///
/// CR/LF get a dedicated message because they are the classic header-injection
/// vector. Every other ASCII control character except horizontal tab is
/// rejected too, because such bytes are not valid in an HTTP header value.
/// Catching them here reports an invalid-params error naming the offending
/// field (`name`), instead of an internal "push request build" failure later
/// on.
///
/// # Errors
///
/// Returns an invalid-params [`A2aError`] if `value` contains a disallowed
/// character.
fn validate_header_value(value: &str, name: &str) -> A2aResult<()> {
    if value.contains(|c: char| c == '\r' || c == '\n') {
        return Err(A2aError::invalid_params(format!(
            "{name} contains invalid characters (CR/LF)"
        )));
    }
    if value.chars().any(|c| c.is_ascii_control() && c != '\t') {
        return Err(A2aError::invalid_params(format!(
            "{name} contains invalid control characters"
        )));
    }
    Ok(())
}
/// Builds the `authorization` header value for a push request, if any.
///
/// The configured scheme is compared case-insensitively, because HTTP
/// authentication scheme names are case-insensitive. So `"Bearer"`, `"bearer"`
/// and `"BEARER"` all produce `Bearer <credentials>`.
///
/// Unknown schemes are logged and yield `None`, in which case the request is
/// sent without an `authorization` header.
fn push_authorization_value(config: &TaskPushNotificationConfig) -> Option<String> {
    let auth = config.authentication.as_ref()?;
    let prefix = if auth.scheme.eq_ignore_ascii_case("bearer") {
        "Bearer"
    } else if auth.scheme.eq_ignore_ascii_case("basic") {
        "Basic"
    } else {
        trace_warn!(
            scheme = auth.scheme.as_str(),
            "unknown authentication scheme; no auth header set"
        );
        return None;
    };
    Some(format!("{prefix} {}", auth.credentials))
}

#[allow(clippy::manual_async_fn, clippy::too_many_lines)]
impl PushSender for HttpPushSender {
    /// Returns this sender's `allow_private_urls` setting.
    fn allows_private_urls(&self) -> bool {
        self.allow_private_urls
    }

    /// POSTs `event` as JSON to `url`, retrying according to the retry policy.
    ///
    /// Before any request is built:
    /// - unless private URLs are allowed, the URL is validated (including DNS
    ///   resolution) against private/loopback targets;
    /// - authentication and token values are checked with
    ///   `validate_header_value`.
    ///
    /// Each attempt is bounded by `request_timeout`. A 2xx response ends
    /// delivery successfully; other statuses, transport errors and timeouts
    /// are retried. Between attempts the sender sleeps for `backoff[attempt]`,
    /// falling back to the last backoff entry, and does not sleep if the
    /// schedule is empty. A policy with `max_attempts == 0` still makes one
    /// attempt.
    ///
    /// # Errors
    ///
    /// - An invalid-params error from URL or header validation.
    /// - An internal error if serialization or request construction fails.
    /// - An internal error describing the last failure once every attempt has
    ///   failed.
    fn send<'a>(
        &'a self,
        url: &'a str,
        event: &'a StreamResponse,
        config: &'a TaskPushNotificationConfig,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            trace_info!(url, "delivering push notification");
            if !self.allow_private_urls {
                validate_webhook_url_with_dns(url).await?;
            }
            if let Some(ref auth) = config.authentication {
                validate_header_value(&auth.credentials, "authentication credentials")?;
                validate_header_value(&auth.scheme, "authentication scheme")?;
            }
            if let Some(ref token) = config.token {
                validate_header_value(token, "notification token")?;
            }
            let body_bytes: Bytes = serde_json::to_vec(event)
                .map(Bytes::from)
                .map_err(|e| A2aError::internal(format!("push serialization: {e}")))?;
            // Loop-invariant: computed once, so an unknown scheme is warned
            // about once per delivery rather than once per attempt.
            let authorization = push_authorization_value(config);
            // Always make at least one attempt. With zero attempts the loop
            // would never run, and the caller would get an empty error message.
            let max_attempts = self.retry_policy.max_attempts.max(1);
            let mut last_err = String::new();
            for attempt in 0..max_attempts {
                let mut builder = hyper::Request::builder()
                    .method(hyper::Method::POST)
                    .uri(url)
                    .header("content-type", "application/json");
                if let Some(ref value) = authorization {
                    builder = builder.header("authorization", value.as_str());
                }
                if let Some(ref token) = config.token {
                    builder = builder.header("a2a-notification-token", token.as_str());
                }
                let req = builder
                    .body(Full::new(body_bytes.clone()))
                    .map_err(|e| A2aError::internal(format!("push request build: {e}")))?;
                let request_result =
                    tokio::time::timeout(self.request_timeout, self.client.request(req)).await;
                match request_result {
                    Ok(Ok(resp)) if resp.status().is_success() => {
                        trace_debug!(url, "push notification delivered");
                        return Ok(());
                    }
                    Ok(Ok(resp)) => {
                        last_err = format!("push notification got HTTP {}", resp.status());
                        trace_warn!(url, attempt, status = %resp.status(), "push delivery failed");
                    }
                    Ok(Err(e)) => {
                        last_err = format!("push notification failed: {e}");
                        trace_warn!(url, attempt, error = %e, "push delivery error");
                    }
                    Err(_) => {
                        // `Duration`'s Debug output keeps sub-second precision
                        // ("500ms"), whereas `as_secs()` would report "0s".
                        last_err = format!(
                            "push notification timed out after {:?}",
                            self.request_timeout
                        );
                        trace_warn!(url, attempt, "push delivery timed out");
                    }
                }
                if attempt + 1 < max_attempts {
                    let backoff = &self.retry_policy.backoff;
                    if let Some(delay) = backoff.get(attempt).or_else(|| backoff.last()) {
                        tokio::time::sleep(*delay).await;
                    }
                }
            }
            Err(A2aError::internal(last_err))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// `true` when the synchronous URL validator refuses `url`.
    fn rejected(url: &str) -> bool {
        validate_webhook_url(url).is_err()
    }

    /// `true` when `validate_header_value` refuses `value`.
    fn header_rejected(value: &str) -> bool {
        validate_header_value(value, "test").is_err()
    }

    // --- PushRetryPolicy ---------------------------------------------------

    #[test]
    fn push_retry_policy_default() {
        let policy = PushRetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(
            policy.backoff,
            [Duration::from_secs(1), Duration::from_secs(2)]
        );
    }

    #[test]
    fn push_retry_policy_with_max_attempts() {
        let policy = PushRetryPolicy::default().with_max_attempts(5);
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.backoff.len(), 2, "backoff must be left untouched");
    }

    #[test]
    fn push_retry_policy_with_backoff() {
        let schedule = vec![
            Duration::from_millis(100),
            Duration::from_millis(500),
            Duration::from_secs(1),
        ];
        let policy = PushRetryPolicy::default().with_backoff(schedule.clone());
        assert_eq!(policy.backoff, schedule);
        assert_eq!(policy.max_attempts, 3, "max_attempts must be left untouched");
    }

    // --- HttpPushSender ----------------------------------------------------

    #[test]
    fn http_push_sender_default() {
        let sender = HttpPushSender::default();
        assert_eq!(sender.request_timeout, DEFAULT_PUSH_REQUEST_TIMEOUT);
        assert!(!sender.allow_private_urls);
    }

    #[test]
    fn http_push_sender_allow_private_urls() {
        let sender = HttpPushSender::new().allow_private_urls();
        assert!(sender.allow_private_urls);
    }

    #[test]
    fn http_push_sender_with_retry_policy() {
        let sender = HttpPushSender::new()
            .with_retry_policy(PushRetryPolicy::default().with_max_attempts(10));
        assert_eq!(sender.retry_policy.max_attempts, 10);
    }

    // --- validate_webhook_url: scheme and structure ------------------------

    #[test]
    fn accepts_http_scheme() {
        assert!(!rejected("http://example.com/webhook"));
    }

    #[test]
    fn rejects_url_without_scheme() {
        assert!(rejected("example.com/webhook"));
    }

    #[test]
    fn rejects_ftp_scheme() {
        assert!(rejected("ftp://example.com/webhook"));
    }

    #[test]
    fn rejects_file_scheme() {
        assert!(rejected("file:///etc/passwd"));
    }

    #[test]
    fn rejects_url_without_host() {
        assert!(rejected("http:///path"));
    }

    // --- validate_webhook_url: destination screening -----------------------

    #[test]
    fn accepts_public_url() {
        assert!(!rejected("https://example.com/webhook"));
    }

    #[test]
    fn accepts_public_ip() {
        assert!(!rejected("https://203.0.113.1/webhook"));
    }

    #[test]
    fn rejects_loopback_ipv4() {
        assert!(rejected("http://127.0.0.1:8080/webhook"));
    }

    #[test]
    fn rejects_private_10_range() {
        assert!(rejected("http://10.0.0.1/webhook"));
    }

    #[test]
    fn rejects_private_172_range() {
        assert!(rejected("http://172.16.0.1/webhook"));
    }

    #[test]
    fn rejects_private_192_168_range() {
        assert!(rejected("http://192.168.1.1/webhook"));
    }

    #[test]
    fn rejects_link_local() {
        assert!(rejected("http://169.254.169.254/latest"));
    }

    #[test]
    fn rejects_cgnat_range() {
        assert!(rejected("http://100.64.0.1/webhook"));
    }

    #[test]
    fn rejects_unspecified_ipv4() {
        assert!(rejected("http://0.0.0.0/webhook"));
    }

    #[test]
    fn rejects_ipv6_loopback() {
        assert!(rejected("http://[::1]:8080/webhook"));
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert!(rejected("http://[fc00::1]:8080/webhook"));
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert!(rejected("http://[fe80::1]:8080/webhook"));
    }

    #[test]
    fn rejects_localhost() {
        assert!(rejected("http://localhost:8080/webhook"));
    }

    #[test]
    fn rejects_dot_local() {
        assert!(rejected("http://myservice.local/webhook"));
    }

    #[test]
    fn rejects_dot_internal() {
        assert!(rejected("http://metadata.internal/webhook"));
    }

    // --- validate_header_value ---------------------------------------------

    #[test]
    fn rejects_header_with_crlf() {
        assert!(header_rejected("token\r\nX-Injected: value"));
    }

    #[test]
    fn rejects_header_with_cr() {
        assert!(header_rejected("token\rvalue"));
    }

    #[test]
    fn rejects_header_with_lf() {
        assert!(header_rejected("token\nvalue"));
    }

    #[test]
    fn accepts_clean_header_value() {
        assert!(!header_rejected("Bearer abc123+/="));
    }

    // --- validate_webhook_url_with_dns -------------------------------------

    #[tokio::test]
    async fn dns_rejects_loopback_ip_literal() {
        let outcome = validate_webhook_url_with_dns("http://127.0.0.1:8080/webhook").await;
        assert!(outcome.is_err(), "loopback IP should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_private_ip_literal() {
        let outcome = validate_webhook_url_with_dns("http://10.0.0.1/webhook").await;
        assert!(outcome.is_err(), "private IP should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_localhost_hostname() {
        let outcome = validate_webhook_url_with_dns("http://localhost:8080/webhook").await;
        assert!(outcome.is_err(), "localhost should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_invalid_scheme() {
        let outcome = validate_webhook_url_with_dns("ftp://example.com/webhook").await;
        assert!(outcome.is_err(), "ftp scheme should be rejected");
    }

    #[tokio::test]
    async fn dns_rejects_missing_host() {
        let outcome = validate_webhook_url_with_dns("http:///path").await;
        assert!(outcome.is_err(), "missing host should be rejected");
    }

    #[tokio::test]
    async fn dns_accepts_ip_literal_public() {
        let outcome = validate_webhook_url_with_dns("https://203.0.113.1/webhook").await;
        assert!(outcome.is_ok(), "public IP literal should be accepted");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dns_rejects_unresolvable_hostname() {
        // Resolve on a dedicated OS thread with its own runtime, so a slow or
        // hanging system resolver cannot stall this test's runtime.
        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
        std::thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let outcome = runtime.block_on(validate_webhook_url_with_dns(
                "https://this-hostname-definitely-does-not-exist-a2a-test.invalid/webhook",
            ));
            let _ = result_tx.send(outcome);
        });
        // Environments without working DNS may never answer. A timeout is
        // tolerated rather than reported as a failure.
        if let Ok(received) = tokio::time::timeout(Duration::from_secs(5), result_rx).await {
            match received {
                Ok(outcome) => {
                    assert!(outcome.is_err(), "unresolvable hostname should be rejected");
                }
                Err(_) => panic!("sender dropped without sending"),
            }
        }
    }
}