Skip to main content

a2a_protocol_server/push/
sender.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 Tom F.
3
4//! Push notification sender trait and HTTP implementation.
5//!
6//! [`PushSender`] abstracts the delivery of streaming events to client webhook
7//! endpoints. [`HttpPushSender`] uses hyper to POST events over HTTP(S).
8//!
9//! # Security
10//!
11//! [`HttpPushSender`] validates webhook URLs to reject private/loopback
12//! addresses (SSRF protection) and sanitizes authentication credentials
13//! to prevent HTTP header injection.
14
15use std::future::Future;
16use std::net::IpAddr;
17use std::pin::Pin;
18
19use a2a_protocol_types::error::{A2aError, A2aResult};
20use a2a_protocol_types::events::StreamResponse;
21use a2a_protocol_types::push::TaskPushNotificationConfig;
22use bytes::Bytes;
23use http_body_util::Full;
24use hyper_util::client::legacy::Client;
25use hyper_util::rt::TokioExecutor;
26
27/// Trait for delivering push notifications to client webhooks.
28///
29/// Object-safe; used as `Box<dyn PushSender>`.
30pub trait PushSender: Send + Sync + 'static {
31    /// Sends a streaming event to the client's webhook URL.
32    ///
33    /// # Errors
34    ///
35    /// Returns an [`A2aError`] if delivery fails after all retries.
36    fn send<'a>(
37        &'a self,
38        url: &'a str,
39        event: &'a StreamResponse,
40        config: &'a TaskPushNotificationConfig,
41    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>>;
42}
43
/// Maximum number of delivery attempts before giving up.
///
/// This counts the initial request as well, so the number of *retries* is
/// `MAX_PUSH_ATTEMPTS - 1`.
const MAX_PUSH_ATTEMPTS: usize = 3;

/// Backoff durations between retry attempts.
///
/// Entry `i` is the pause inserted after failed attempt `i` (zero-based).
/// The array length is derived from [`MAX_PUSH_ATTEMPTS`] so that changing
/// the attempt count without updating this table is a compile error rather
/// than a silently skipped delay.
const PUSH_RETRY_BACKOFF: [std::time::Duration; MAX_PUSH_ATTEMPTS - 1] = [
    std::time::Duration::from_secs(1),
    std::time::Duration::from_secs(2),
];

/// Default per-request timeout for push notification delivery.
const DEFAULT_PUSH_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
55
56/// HTTP-based [`PushSender`] using hyper.
57///
58/// Retries up to 3 times with exponential backoff on
59/// transient HTTP errors.
60///
61/// # Security
62///
63/// - Rejects webhook URLs targeting private/loopback/link-local addresses
64///   to prevent SSRF attacks.
65/// - Validates authentication credentials to prevent HTTP header injection
66///   (rejects values containing CR/LF characters).
67#[derive(Debug)]
68pub struct HttpPushSender {
69    client: Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>,
70    request_timeout: std::time::Duration,
71    /// Whether to skip SSRF URL validation (for testing only).
72    allow_private_urls: bool,
73}
74
75impl Default for HttpPushSender {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl HttpPushSender {
82    /// Creates a new [`HttpPushSender`] with the default 30-second request timeout.
83    #[must_use]
84    pub fn new() -> Self {
85        Self::with_timeout(DEFAULT_PUSH_REQUEST_TIMEOUT)
86    }
87
88    /// Creates a new [`HttpPushSender`] with a custom per-request timeout.
89    #[must_use]
90    pub fn with_timeout(request_timeout: std::time::Duration) -> Self {
91        let client = Client::builder(TokioExecutor::new()).build_http();
92        Self {
93            client,
94            request_timeout,
95            allow_private_urls: false,
96        }
97    }
98
99    /// Creates an [`HttpPushSender`] that allows private/loopback URLs.
100    ///
101    /// **Warning:** This disables SSRF protection and should only be used
102    /// in testing or trusted environments.
103    #[must_use]
104    pub const fn allow_private_urls(mut self) -> Self {
105        self.allow_private_urls = true;
106        self
107    }
108}
109
/// Returns `true` if the given IP address is private, loopback, link-local,
/// unspecified, or in the carrier-grade NAT range.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are classified by their
/// embedded IPv4 address. Without this, `::ffff:127.0.0.1` or
/// `::ffff:169.254.169.254` would be treated as public even though they
/// reach the same hosts as their IPv4 counterparts.
#[allow(clippy::missing_const_for_fn)] // IpAddr methods aren't const-stable everywhere
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            v4.is_loopback()           // 127.0.0.0/8
                || v4.is_private()     // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
                || v4.is_link_local()  // 169.254.0.0/16
                || v4.is_unspecified() // 0.0.0.0
                || (first == 100 && (second & 0xC0) == 64) // 100.64.0.0/10 (CGNAT)
        }
        IpAddr::V6(v6) => {
            // ::ffff:0:0/96 — judge the embedded IPv4 address instead.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(mapped));
            }
            let first_segment = v6.segments()[0];
            v6.is_loopback()           // ::1
                || v6.is_unspecified() // ::
                // fc00::/7 (unique local)
                || (first_segment & 0xfe00) == 0xfc00
                // fe80::/10 (link-local)
                || (first_segment & 0xffc0) == 0xfe80
        }
    }
}
131
132/// Validates a webhook URL to prevent SSRF attacks.
133///
134/// Rejects URLs targeting private/loopback/link-local addresses.
135#[allow(clippy::case_sensitive_file_extension_comparisons)] // host_lower is already lowercased
136fn validate_webhook_url(url: &str) -> A2aResult<()> {
137    // Parse the URL to extract the host.
138    let uri: hyper::Uri = url
139        .parse()
140        .map_err(|e| A2aError::invalid_params(format!("invalid webhook URL: {e}")))?;
141
142    let host = uri
143        .host()
144        .ok_or_else(|| A2aError::invalid_params("webhook URL missing host"))?;
145
146    // Strip brackets from IPv6 addresses (hyper::Uri returns "[::1]" as host).
147    let host_bare = host.trim_start_matches('[').trim_end_matches(']');
148
149    // Try to parse the host as an IP address directly.
150    if let Ok(ip) = host_bare.parse::<IpAddr>() {
151        if is_private_ip(ip) {
152            return Err(A2aError::invalid_params(format!(
153                "webhook URL targets private/loopback address: {host}"
154            )));
155        }
156    }
157
158    // Check for well-known private hostnames.
159    let host_lower = host.to_ascii_lowercase();
160    if host_lower == "localhost"
161        || host_lower.ends_with(".local")
162        || host_lower.ends_with(".internal")
163    {
164        return Err(A2aError::invalid_params(format!(
165            "webhook URL targets local/internal hostname: {host}"
166        )));
167    }
168
169    Ok(())
170}
171
172/// Validates that a header value contains no CR/LF characters.
173fn validate_header_value(value: &str, name: &str) -> A2aResult<()> {
174    if value.contains('\r') || value.contains('\n') {
175        return Err(A2aError::invalid_params(format!(
176            "{name} contains invalid characters (CR/LF)"
177        )));
178    }
179    Ok(())
180}
181
182#[allow(clippy::manual_async_fn, clippy::too_many_lines)]
183impl PushSender for HttpPushSender {
184    fn send<'a>(
185        &'a self,
186        url: &'a str,
187        event: &'a StreamResponse,
188        config: &'a TaskPushNotificationConfig,
189    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
190        Box::pin(async move {
191            trace_info!(url, "delivering push notification");
192
193            // SSRF protection: reject private/loopback addresses.
194            if !self.allow_private_urls {
195                validate_webhook_url(url)?;
196            }
197
198            // Header injection protection: validate credentials.
199            if let Some(ref auth) = config.authentication {
200                validate_header_value(&auth.credentials, "authentication credentials")?;
201                validate_header_value(&auth.scheme, "authentication scheme")?;
202            }
203            if let Some(ref token) = config.token {
204                validate_header_value(token, "notification token")?;
205            }
206
207            let body_bytes = serde_json::to_vec(event)
208                .map_err(|e| A2aError::internal(format!("push serialization: {e}")))?;
209
210            let mut last_err = String::new();
211
212            for attempt in 0..MAX_PUSH_ATTEMPTS {
213                let mut builder = hyper::Request::builder()
214                    .method(hyper::Method::POST)
215                    .uri(url)
216                    .header("content-type", "application/json");
217
218                // Set authentication headers from config.
219                if let Some(ref auth) = config.authentication {
220                    match auth.scheme.as_str() {
221                        "bearer" => {
222                            builder = builder
223                                .header("authorization", format!("Bearer {}", auth.credentials));
224                        }
225                        "basic" => {
226                            builder = builder
227                                .header("authorization", format!("Basic {}", auth.credentials));
228                        }
229                        _ => {
230                            trace_warn!(
231                                scheme = auth.scheme.as_str(),
232                                "unknown authentication scheme; no auth header set"
233                            );
234                        }
235                    }
236                }
237
238                // Set notification token header if present.
239                if let Some(ref token) = config.token {
240                    builder = builder.header("a2a-notification-token", token.as_str());
241                }
242
243                let req = builder
244                    .body(Full::new(Bytes::from(body_bytes.clone())))
245                    .map_err(|e| A2aError::internal(format!("push request build: {e}")))?;
246
247                let request_result =
248                    tokio::time::timeout(self.request_timeout, self.client.request(req)).await;
249
250                match request_result {
251                    Ok(Ok(resp)) if resp.status().is_success() => {
252                        trace_debug!(url, "push notification delivered");
253                        return Ok(());
254                    }
255                    Ok(Ok(resp)) => {
256                        last_err = format!("push notification got HTTP {}", resp.status());
257                        trace_warn!(url, attempt, status = %resp.status(), "push delivery failed");
258                    }
259                    Ok(Err(e)) => {
260                        last_err = format!("push notification failed: {e}");
261                        trace_warn!(url, attempt, error = %e, "push delivery error");
262                    }
263                    Err(_) => {
264                        last_err = format!(
265                            "push notification timed out after {}s",
266                            self.request_timeout.as_secs()
267                        );
268                        trace_warn!(url, attempt, "push delivery timed out");
269                    }
270                }
271
272                // Retry with backoff (except on last attempt).
273                if attempt < MAX_PUSH_ATTEMPTS - 1 {
274                    if let Some(delay) = PUSH_RETRY_BACKOFF.get(attempt) {
275                        tokio::time::sleep(*delay).await;
276                    }
277                }
278            }
279
280            Err(A2aError::internal(last_err))
281        })
282    }
283}
284
285// ── Tests ─────────────────────────────────────────────────────────────────────
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validate_webhook_url` refuses `url`.
    #[track_caller]
    fn assert_url_rejected(url: &str) {
        assert!(
            validate_webhook_url(url).is_err(),
            "expected {url} to be rejected"
        );
    }

    /// Asserts that `validate_webhook_url` accepts `url`.
    #[track_caller]
    fn assert_url_accepted(url: &str) {
        assert!(
            validate_webhook_url(url).is_ok(),
            "expected {url} to be accepted"
        );
    }

    // ── IPv4 literals ────────────────────────────────────────────────────────

    #[test]
    fn rejects_loopback_ipv4() {
        assert_url_rejected("http://127.0.0.1:8080/webhook");
    }

    #[test]
    fn rejects_private_10_range() {
        assert_url_rejected("http://10.0.0.1/webhook");
    }

    #[test]
    fn rejects_private_172_range() {
        assert_url_rejected("http://172.16.0.1/webhook");
    }

    #[test]
    fn rejects_private_192_168_range() {
        assert_url_rejected("http://192.168.1.1/webhook");
    }

    #[test]
    fn rejects_link_local() {
        assert_url_rejected("http://169.254.169.254/latest");
    }

    #[test]
    fn rejects_unspecified_ipv4() {
        assert_url_rejected("http://0.0.0.0/webhook");
    }

    #[test]
    fn rejects_cgnat_range() {
        assert_url_rejected("http://100.64.0.1/webhook");
        assert_url_rejected("http://100.127.255.255/webhook");
    }

    #[test]
    fn accepts_addresses_adjacent_to_blocked_ranges() {
        // Just outside 100.64.0.0/10 and 172.16.0.0/12 respectively.
        assert_url_accepted("http://100.128.0.1/webhook");
        assert_url_accepted("http://172.32.0.1/webhook");
    }

    // ── Hostnames ────────────────────────────────────────────────────────────

    #[test]
    fn rejects_localhost() {
        assert_url_rejected("http://localhost:8080/webhook");
    }

    #[test]
    fn rejects_dot_local() {
        assert_url_rejected("http://myservice.local/webhook");
    }

    #[test]
    fn rejects_dot_internal() {
        assert_url_rejected("http://metadata.internal/webhook");
    }

    // ── IPv6 literals ────────────────────────────────────────────────────────

    #[test]
    fn rejects_ipv6_loopback() {
        assert_url_rejected("http://[::1]:8080/webhook");
    }

    #[test]
    fn rejects_ipv6_unique_local() {
        assert_url_rejected("http://[fd00::1]/webhook");
    }

    #[test]
    fn rejects_ipv6_link_local() {
        assert_url_rejected("http://[fe80::1]/webhook");
    }

    // ── Malformed input ──────────────────────────────────────────────────────

    #[test]
    fn rejects_url_without_host() {
        assert_url_rejected("/webhook");
    }

    #[test]
    fn rejects_unparseable_url() {
        assert_url_rejected("http://exa mple.com/webhook");
    }

    // ── Accepted targets ─────────────────────────────────────────────────────

    #[test]
    fn accepts_public_url() {
        assert_url_accepted("https://example.com/webhook");
    }

    #[test]
    fn accepts_public_ip() {
        assert_url_accepted("https://203.0.113.1/webhook");
    }

    // ── Header values ────────────────────────────────────────────────────────

    #[test]
    fn rejects_header_with_crlf() {
        assert!(validate_header_value("token\r\nX-Injected: value", "test").is_err());
    }

    #[test]
    fn rejects_header_with_cr() {
        assert!(validate_header_value("token\rvalue", "test").is_err());
    }

    #[test]
    fn rejects_header_with_lf() {
        assert!(validate_header_value("token\nvalue", "test").is_err());
    }

    #[test]
    fn accepts_clean_header_value() {
        assert!(validate_header_value("Bearer abc123+/=", "test").is_ok());
    }
}