use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use reqwest::StatusCode;
use tokio::sync::Semaphore;
use tokio::sync::broadcast::error::RecvError;
use tokio_util::sync::CancellationToken;
use crate::a2a::core::bus::Event;
use crate::a2a::core::push_notifications::PushNotificationConfig;
use crate::a2a::core::ssrf;
use crate::a2a::core::task_types::{Task, TaskId};
use crate::a2a::state::A2aState;
/// Time budget, in seconds, used for both the connect timeout and the overall
/// request timeout of the HTTP clients built in this file.
const DELIVERY_TIMEOUT_SECS: u64 = 10;
/// Number of retries after the first attempt; a delivery is attempted at most
/// `MAX_RETRIES + 1` times.
const MAX_RETRIES: u32 = 3;
/// Base backoff in milliseconds. The delay after failed attempt `n` (0-based)
/// is `BACKOFF_BASE_MS << n`.
const BACKOFF_BASE_MS: u64 = 200;
/// Request header that carries the config's `token`; only sent when the token
/// is non-empty.
const NOTIFICATION_TOKEN_HEADER: &str = "X-Basemind-Notification-Token";
/// Number of permits in the semaphore that gates concurrently running
/// deliveries.
const MAX_INFLIGHT_DELIVERIES: usize = 32;
pub fn spawn_delivery_worker(
state: A2aState,
cancel: CancellationToken,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
if let Err(error) = build_base_client() {
tracing::error!(%error, "failed to build webhook delivery client; worker exiting");
return;
}
let permits = Arc::new(Semaphore::new(MAX_INFLIGHT_DELIVERIES));
let mut rx = state.bus.subscribe();
loop {
tokio::select! {
() = cancel.cancelled() => {
tracing::debug!("webhook delivery worker cancelled");
return;
}
received = rx.recv() => match received {
Ok(event) => handle_event(&state, &permits, event).await,
Err(RecvError::Lagged(skipped)) => {
tracing::warn!(
skipped,
"webhook delivery worker lagged behind the bus; events dropped",
);
}
Err(RecvError::Closed) => {
tracing::debug!("bus closed; webhook delivery worker exiting");
return;
}
},
}
}
})
}
/// Builds an HTTP client with the delivery timeouts and redirects disabled.
///
/// `DELIVERY_TIMEOUT_SECS` is applied to both the connect phase and the whole
/// request.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by the client builder.
fn build_base_client() -> Result<reqwest::Client, reqwest::Error> {
    let budget = Duration::from_secs(DELIVERY_TIMEOUT_SECS);
    let builder = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .timeout(budget)
        .connect_timeout(budget);
    builder.build()
}
/// Fans one bus event out to every push-notification config registered for
/// the event's task.
///
/// Does nothing when the event has no task id or the task has no configs.
/// The event is serialized to JSON once and the bytes are shared between all
/// deliveries; a serialization failure is logged and the event is dropped.
async fn handle_event(state: &A2aState, permits: &Arc<Semaphore>, event: Event) {
    let Some(task_id) = task_id_for_event(&event) else {
        return;
    };

    // The read guard is a temporary of this statement, so the lock is
    // released before any serialization or spawning happens.
    let configs: Vec<PushNotificationConfig> = state
        .push_notifications
        .read()
        .await
        .list(&task_id)
        .to_vec();
    if configs.is_empty() {
        return;
    }

    let bytes = match serde_json::to_vec(&event) {
        Err(error) => {
            tracing::error!(%error, %task_id, "failed to serialize bus event for webhook delivery");
            return;
        }
        Ok(bytes) => bytes,
    };
    let body: Arc<[u8]> = Arc::from(bytes.into_boxed_slice());

    spawn_deliveries(permits, configs, body);
}
/// Starts one bounded delivery job per config, all sharing the same
/// serialized body.
fn spawn_deliveries(
    permits: &Arc<Semaphore>,
    configs: Vec<PushNotificationConfig>,
    body: Arc<[u8]>,
) {
    for config in configs {
        // Each job owns its config and a cheap reference-counted handle to the payload.
        let payload = Arc::clone(&body);
        let job = move || async move {
            deliver_with_retries(&config, &payload).await;
        };
        spawn_bounded(permits, job);
    }
}
/// Spawns a task that runs `work` only after obtaining a permit from
/// `permits`.
///
/// The task is spawned immediately; it is the execution of `work` that waits
/// for a permit. The permit is held until `work` has finished. If acquiring
/// the permit fails, `work` is never run.
fn spawn_bounded<Work, Fut>(permits: &Arc<Semaphore>, work: Work)
where
    Work: FnOnce() -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send,
{
    let semaphore = Arc::clone(permits);
    tokio::spawn(async move {
        // `_held` keeps the permit alive for the whole duration of `work`.
        if let Ok(_held) = semaphore.acquire().await {
            work().await;
        }
    });
}
/// Extracts the id of the task an event refers to.
///
/// Both event variants handled here carry a task, so this currently always
/// returns `Some`.
fn task_id_for_event(event: &Event) -> Option<TaskId> {
    match event {
        Event::TaskCreated(task) => Some(task.id),
        Event::TaskStatusChanged { task, .. } => Some(task.id),
    }
}
/// Delivers `body` to the webhook described by `config`, retrying retryable
/// failures with exponential backoff.
///
/// At most `MAX_RETRIES + 1` attempts are made. Success, an aborted attempt
/// and a 4xx response all end the delivery immediately. After a retryable
/// failure on attempt `n` (0-based) the task sleeps `BACKOFF_BASE_MS << n`
/// milliseconds before trying again.
async fn deliver_with_retries(config: &PushNotificationConfig, body: &[u8]) {
    let mut attempt: u32 = 0;
    loop {
        // Every outcome except `Retryable` ends the delivery here.
        let reason = match deliver_once(config, body).await {
            DeliveryOutcome::Success | DeliveryOutcome::Aborted => return,
            DeliveryOutcome::ClientError(status) => {
                tracing::warn!(
                    url = %config.url,
                    status = status.as_u16(),
                    "webhook rejected delivery with a 4xx; not retrying",
                );
                return;
            }
            DeliveryOutcome::Retryable(reason) => reason,
        };

        if attempt == MAX_RETRIES {
            tracing::warn!(
                url = %config.url,
                attempts = attempt + 1,
                reason = %reason,
                "webhook delivery exhausted retries",
            );
            return;
        }

        let delay_ms = BACKOFF_BASE_MS << attempt;
        tracing::debug!(
            url = %config.url,
            attempt = attempt + 1,
            delay_ms,
            reason = %reason,
            "webhook delivery failed; backing off before retry",
        );
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        attempt += 1;
    }
}
/// Result of a single delivery attempt, as consumed by the retry loop in
/// `deliver_with_retries`.
enum DeliveryOutcome {
/// The endpoint answered with a success (2xx) status.
Success,
/// The attempt was abandoned without a usable response, for example because
/// the URL or a resolved address failed the SSRF checks. Never retried.
Aborted,
/// The endpoint answered with a 4xx status. Logged and never retried.
ClientError(StatusCode),
/// The attempt failed in a way that may succeed later; the string is a
/// human-readable reason used in log messages.
Retryable(String),
}
/// Performs one delivery attempt of `body` to `config.url`.
///
/// Steps, in order:
/// 1. Re-validate the URL with the SSRF guard (done at delivery time, not only
///    at registration time).
/// 2. Resolve the host and reject the attempt if resolution fails or any
///    resolved address is blocked.
/// 3. Build a client pinned to the vetted address, so the connection goes to
///    the address that was checked.
/// 4. POST the JSON body, adding the notification token header when the token
///    is non-empty and an `Authorization` header when authentication is
///    configured.
///
/// Returns:
/// - `Aborted` when validation, resolution, client construction or request
///   construction fails. None of these can be fixed by retrying the same
///   request.
/// - `Success`, `ClientError` or `Retryable` according to the response status.
/// - `Retryable` for transport-level failures (connect errors, timeouts, ...).
async fn deliver_once(config: &PushNotificationConfig, body: &[u8]) -> DeliveryOutcome {
    let target = match ssrf::validate_webhook_url(&config.url) {
        Ok(target) => target,
        Err(rejected) => {
            tracing::warn!(
                url = %config.url,
                reason = %rejected.reason,
                "webhook url failed SSRF validation at delivery time; aborting",
            );
            return DeliveryOutcome::Aborted;
        }
    };

    let safe_addr = match resolve_safe_addr(&target).await {
        Ok(addr) => addr,
        Err(reason) => {
            tracing::warn!(
                url = %config.url,
                host = %target.host,
                reason = %reason,
                "webhook host resolution blocked or failed; aborting delivery",
            );
            return DeliveryOutcome::Aborted;
        }
    };

    let client = match build_pinned_client(&target.host, safe_addr) {
        Ok(client) => client,
        Err(error) => {
            tracing::error!(url = %config.url, %error, "failed to build pinned webhook client");
            return DeliveryOutcome::Aborted;
        }
    };

    // The request body must be owned, so the shared bytes are copied per attempt.
    let mut request = client
        .post(&config.url)
        .header(reqwest::header::CONTENT_TYPE, "application/json")
        .body(body.to_vec());
    if !config.token.is_empty() {
        request = request.header(NOTIFICATION_TOKEN_HEADER, &config.token);
    }
    if let Some(auth) = &config.authentication {
        request = request.header(
            reqwest::header::AUTHORIZATION,
            format!("{} {}", auth.scheme, auth.credentials),
        );
    }

    match request.send().await {
        Ok(response) => classify_response(response.status()),
        // Builder errors (e.g. a token or credentials string that is not a
        // valid header value) are recorded by the request builder and only
        // surface from `send()`. The request never left the process and the
        // same config will fail identically next time, so retrying is pointless.
        Err(error) if error.is_builder() => {
            tracing::warn!(
                url = %config.url,
                %error,
                "webhook request could not be constructed; aborting delivery",
            );
            DeliveryOutcome::Aborted
        }
        Err(error) => DeliveryOutcome::Retryable(format!("transport error: {error}")),
    }
}
/// Maps an HTTP response status onto a delivery outcome.
///
/// 2xx is a success, 4xx is a non-retryable client error, and every other
/// status is reported as retryable with the status in the reason text.
fn classify_response(status: StatusCode) -> DeliveryOutcome {
    match status {
        accepted if accepted.is_success() => DeliveryOutcome::Success,
        rejected if rejected.is_client_error() => DeliveryOutcome::ClientError(rejected),
        other => DeliveryOutcome::Retryable(format!("server responded {other}")),
    }
}
/// Resolves the target host and returns the first resolved address, provided
/// that none of the resolved addresses is blocked.
///
/// # Errors
///
/// Returns a human-readable reason when DNS resolution fails, when the host
/// resolves to no addresses, or when any resolved address is rejected by
/// `ssrf::ip_is_blocked`. A single blocked address rejects the whole host.
async fn resolve_safe_addr(target: &ssrf::WebhookTarget) -> Result<SocketAddr, String> {
    let resolved = tokio::net::lookup_host((target.host.as_str(), target.port))
        .await
        .map_err(|error| format!("dns resolution failed: {error}"))?;
    let addrs: Vec<SocketAddr> = resolved.collect();

    let Some(first) = addrs.first().copied() else {
        return Err("host resolved to no addresses".to_owned());
    };

    // Report the first blocked address in resolution order, if any.
    let blocked = addrs.iter().find_map(|addr| {
        ssrf::ip_is_blocked(addr.ip())
            .map(|reason| format!("resolved address {} is blocked: {reason}", addr.ip()))
    });

    match blocked {
        Some(message) => Err(message),
        None => Ok(first),
    }
}
/// Builds an HTTP client whose DNS resolution for `host` is overridden with
/// `addr`, with the delivery timeouts applied and redirects disabled.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced by the client builder.
fn build_pinned_client(host: &str, addr: SocketAddr) -> Result<reqwest::Client, reqwest::Error> {
    let budget = Duration::from_secs(DELIVERY_TIMEOUT_SECS);
    let no_redirects = reqwest::redirect::Policy::none();
    reqwest::Client::builder()
        .resolve(host, addr)
        .redirect(no_redirects)
        .timeout(budget)
        .connect_timeout(budget)
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use chrono::Utc;
    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
    use tokio::net::TcpListener;
    use crate::a2a::core::push_notifications::{PushNotificationAuth, PushNotificationId};
    use crate::a2a::core::task_types::{ContextId, TaskMessage, TaskState, TaskStatus};

    /// Returns `true` once `raw` contains a full HTTP/1.1 request: the header
    /// terminator plus as many body bytes as `Content-Length` announces
    /// (zero when the header is absent or unparsable).
    fn request_is_complete(raw: &[u8]) -> bool {
        let Some(head_len) = raw.windows(4).position(|window| window == b"\r\n\r\n") else {
            return false;
        };
        let head = String::from_utf8_lossy(&raw[..head_len]).to_ascii_lowercase();
        let body_len = head
            .lines()
            .find_map(|line| line.strip_prefix("content-length:"))
            .and_then(|value| value.trim().parse::<usize>().ok())
            .unwrap_or(0);
        raw.len() >= head_len + 4 + body_len
    }

    /// Starts a one-shot loopback HTTP server that answers the first request
    /// with `status_line` and returns the raw request text through the join
    /// handle.
    async fn capture_one(
        status_line: &'static str,
    ) -> (SocketAddr, tokio::task::JoinHandle<String>) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback");
        let addr = listener.local_addr().expect("local addr");
        let handle = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.expect("accept connection");
            // A single `read` may return only part of the request (for
            // example the headers without the body), so keep reading until
            // the whole request has arrived or the peer closes the stream.
            let mut raw = Vec::new();
            let mut chunk = vec![0_u8; 8192];
            loop {
                let n = stream.read(&mut chunk).await.expect("read request");
                if n == 0 {
                    break;
                }
                raw.extend_from_slice(&chunk[..n]);
                if request_is_complete(&raw) {
                    break;
                }
            }
            let request = String::from_utf8_lossy(&raw).into_owned();
            let response =
                format!("HTTP/1.1 {status_line}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n");
            let _ = stream.write_all(response.as_bytes()).await;
            let _ = stream.flush().await;
            request
        });
        (addr, handle)
    }

    /// Builds a config whose URL points at `addr` with the given token and
    /// optional authentication.
    fn config_for(
        addr: SocketAddr,
        token: &str,
        auth: Option<PushNotificationAuth>,
    ) -> PushNotificationConfig {
        PushNotificationConfig {
            id: PushNotificationId::new(),
            task_id: TaskId::new(),
            url: format!("http://{addr}/webhook"),
            token: token.to_owned(),
            authentication: auth,
        }
    }

    /// Minimal task in the `Working` state with no history or artifacts.
    fn sample_task() -> Task {
        Task {
            id: TaskId::new(),
            context_id: ContextId::new(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: Utc::now(),
            },
            artifacts: Vec::new(),
            history: Vec::<TaskMessage>::new(),
            metadata: None,
            assignee: None,
            creator: None,
            deadline: None,
        }
    }

    // NOTE: this test assembles the request by hand instead of calling
    // `deliver_once`, because `deliver_once` refuses loopback targets (see
    // `deliver_once_aborts_on_loopback_ssrf`). It pins the wire format of the
    // headers and body plus the classification of a 2xx response.
    #[tokio::test]
    async fn deliver_once_succeeds_on_2xx_and_sends_headers() {
        let (addr, handle) = capture_one("200 OK").await;
        let auth = PushNotificationAuth {
            scheme: "Bearer".to_owned(),
            credentials: "sekret".to_owned(),
        };
        let config = config_for(addr, "corr-token", Some(auth));
        let task = sample_task();
        let event = Event::TaskCreated(Arc::new(task));
        let body = serde_json::to_vec(&event).expect("serialize event");
        let client = reqwest::Client::builder()
            .resolve(&addr.ip().to_string(), addr)
            .build()
            .expect("build client");
        let mut request = client
            .post(&config.url)
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body.clone());
        request = request.header(NOTIFICATION_TOKEN_HEADER, &config.token);
        request = request.header(reqwest::header::AUTHORIZATION, "Bearer sekret");
        let response = request.send().await.expect("send request");
        assert!(
            matches!(
                classify_response(response.status()),
                DeliveryOutcome::Success
            ),
            "2xx must classify as success",
        );
        let captured = handle.await.expect("listener task");
        assert!(
            captured.starts_with("POST /webhook "),
            "must POST the path: {captured}"
        );
        assert!(
            captured
                .to_lowercase()
                .contains("content-type: application/json"),
            "must set JSON content-type: {captured}",
        );
        // Header names are case-insensitive on the wire, so compare lowercased.
        assert!(
            captured
                .to_lowercase()
                .contains("x-basemind-notification-token: corr-token"),
            "must forward the correlation token header: {captured}",
        );
        assert!(
            captured
                .to_lowercase()
                .contains("authorization: bearer sekret"),
            "must forward the authorization header: {captured}",
        );
        assert!(
            captured.contains("\"type\":\"task_created\""),
            "must POST the serialized event body: {captured}",
        );
    }

    #[tokio::test]
    async fn pinned_client_does_not_follow_redirects() {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback");
        let addr = listener.local_addr().expect("local addr");
        let server = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.expect("accept");
            let mut buf = vec![0_u8; 4096];
            let _ = stream.read(&mut buf).await;
            let response = "HTTP/1.1 302 Found\r\nLocation: http://169.254.169.254/\r\n\
Content-Length: 0\r\nConnection: close\r\n\r\n";
            let _ = stream.write_all(response.as_bytes()).await;
            let _ = stream.flush().await;
        });
        let client =
            build_pinned_client(&addr.ip().to_string(), addr).expect("build pinned client");
        let response = client
            .post(format!("http://{addr}/webhook"))
            .body(Vec::new())
            .send()
            .await
            .expect("send must not error");
        assert_eq!(
            response.status(),
            StatusCode::FOUND,
            "the 302 must be surfaced, not followed to the metadata host",
        );
        server.await.expect("server task");
    }

    #[test]
    fn classify_4xx_is_client_error() {
        assert!(matches!(
            classify_response(StatusCode::BAD_REQUEST),
            DeliveryOutcome::ClientError(_)
        ));
    }

    #[test]
    fn classify_5xx_is_retryable() {
        assert!(matches!(
            classify_response(StatusCode::INTERNAL_SERVER_ERROR),
            DeliveryOutcome::Retryable(_)
        ));
    }

    #[tokio::test]
    async fn deliver_once_aborts_on_loopback_ssrf() {
        let (addr, handle) = capture_one("200 OK").await;
        let config = config_for(addr, "", None);
        let outcome = deliver_once(&config, b"{}").await;
        assert!(
            matches!(outcome, DeliveryOutcome::Aborted),
            "loopback delivery must be aborted by the SSRF guard",
        );
        // The listener never receives a connection in this test; stop it.
        handle.abort();
    }

    #[test]
    fn task_events_expose_task_id() {
        let task = sample_task();
        let id = task.id;
        let event = Event::TaskCreated(Arc::new(task));
        assert_eq!(task_id_for_event(&event), Some(id));
    }
}