use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use rand::RngExt;
use url::Url;
use crate::push::claim::{
A2aPushDeliveryStore, AbandonedReason, DeliveryClaim, DeliveryErrorClass, DeliveryOutcome,
GaveUpReason,
};
use crate::push::secret::Secret;
use crate::push::ssrf::{
OutboundUrlValidator, SsrfBlockReason, SsrfDecision, decide as ssrf_decide,
};
use crate::storage::A2aStorageError;
/// Tunables for the push delivery worker: retry budget, backoff shape,
/// HTTP timeouts, claim lifetime, and payload/scheme policy.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PushDeliveryConfig {
    /// Maximum number of delivery attempts before the delivery is given up.
    pub max_attempts: u32,
    /// Base delay of the exponential backoff between attempts.
    pub backoff_base: Duration,
    /// Upper bound applied to the exponential backoff.
    pub backoff_cap: Duration,
    /// Jitter fraction applied to each backoff delay (0.25 = +/-25%).
    pub backoff_jitter: f32,
    /// Overall timeout for each outbound HTTP request.
    pub request_timeout: Duration,
    /// TCP connect timeout for each outbound HTTP request.
    pub connect_timeout: Duration,
    /// NOTE(review): not referenced by `post`, which only configures
    /// connect/request timeouts — confirm whether this is intentional.
    pub read_timeout: Duration,
    /// How long a delivery claim is held before other workers may reclaim it.
    pub claim_expiry: Duration,
    /// Payloads larger than this are rejected as `PayloadTooLarge`.
    pub max_payload_bytes: usize,
    /// When true, plain-HTTP webhook URLs are permitted (dev/test only).
    pub allow_insecure_urls: bool,
}
impl Default for PushDeliveryConfig {
    /// Production-oriented defaults: up to 8 attempts with 2s..60s backoff,
    /// +/-25% jitter, 1 MiB payload cap, and HTTPS-only delivery.
    fn default() -> Self {
        Self {
            // Retry schedule.
            max_attempts: 8,
            backoff_base: Duration::from_secs(2),
            backoff_cap: Duration::from_secs(60),
            backoff_jitter: 0.25,
            // HTTP timeouts.
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(30),
            read_timeout: Duration::from_secs(30),
            // Claim and policy limits.
            claim_expiry: Duration::from_secs(600),
            max_payload_bytes: 1024 * 1024,
            allow_insecure_urls: false,
        }
    }
}
/// Fully-resolved destination for one push delivery: the identity of the
/// event (tenant/task/sequence/config) plus the webhook URL and credentials.
#[derive(Clone)]
pub struct PushTarget {
    /// Tenant that owns the task.
    pub tenant: String,
    /// Principal recorded as owner when claiming the delivery.
    pub owner: String,
    /// Task whose event is being delivered.
    pub task_id: String,
    /// Sequence number of the event within the task; also sent in the
    /// `X-Turul-Event-Sequence` header.
    pub event_sequence: u64,
    /// Identifier of the push configuration that produced this target.
    pub config_id: String,
    /// Webhook endpoint the payload is POSTed to.
    pub url: Url,
    /// HTTP Authorization scheme (e.g. "Bearer"); empty disables the header.
    pub auth_scheme: String,
    /// Credential paired with `auth_scheme` in the Authorization header.
    pub auth_credentials: Secret,
    /// Optional token sent as `X-Turul-Push-Token` when present.
    pub token: Option<Secret>,
}
/// DNS resolution abstraction so tests can inject deterministic resolvers
/// and so SSRF checks can inspect resolved addresses before connecting.
pub trait PushDnsResolver: Send + Sync {
    /// Resolves `host`/`port` to candidate IP addresses.
    /// Errors are stringly typed because callers only classify/log them.
    fn resolve(
        &self,
        host: &str,
        port: u16,
    ) -> futures::future::BoxFuture<'_, Result<Vec<IpAddr>, String>>;
}
/// Production resolver backed by Tokio's `lookup_host` (system DNS).
pub struct TokioDnsResolver;

impl PushDnsResolver for TokioDnsResolver {
    fn resolve(
        &self,
        host: &str,
        port: u16,
    ) -> futures::future::BoxFuture<'_, Result<Vec<IpAddr>, String>> {
        // Own the host name so the boxed future does not borrow `host`.
        let owned_host = host.to_string();
        Box::pin(async move {
            let addrs = tokio::net::lookup_host((owned_host.as_str(), port))
                .await
                .map_err(|err| err.to_string())?;
            let ips: Vec<IpAddr> = addrs.map(|sock_addr| sock_addr.ip()).collect();
            Ok(ips)
        })
    }
}
/// Worker that claims push deliveries from the store and performs outbound
/// webhook POSTs with SSRF protection and bounded, jittered retries.
#[derive(Clone)]
pub struct PushDeliveryWorker {
    /// Store used for claims, attempt counters, and delivery outcomes.
    pub push_delivery_store: Arc<dyn A2aPushDeliveryStore>,
    /// Resolver consulted before each attempt so SSRF checks see fresh IPs.
    pub dns_resolver: Arc<dyn PushDnsResolver>,
    /// Retry/backoff/timeout tunables.
    pub config: PushDeliveryConfig,
    /// Optional extra policy hook applied to outbound URLs.
    pub outbound_validator: Option<OutboundUrlValidator>,
    /// Identifies this worker process when claiming deliveries.
    pub instance_id: String,
}
impl PushDeliveryWorker {
pub fn new(
push_delivery_store: Arc<dyn A2aPushDeliveryStore>,
config: PushDeliveryConfig,
outbound_validator: Option<OutboundUrlValidator>,
instance_id: String,
) -> Result<Self, String> {
Ok(Self {
push_delivery_store,
dns_resolver: Arc::new(TokioDnsResolver),
config,
outbound_validator,
instance_id,
})
}
pub fn with_dns_resolver(mut self, resolver: Arc<dyn PushDnsResolver>) -> Self {
self.dns_resolver = resolver;
self
}
pub async fn deliver(&self, target: &PushTarget, payload: &[u8]) -> DeliveryReport {
if payload.len() > self.config.max_payload_bytes {
let claim = match self.claim(target).await {
Ok(c) => c,
Err(_e) => return DeliveryReport::UnclaimedSkip,
};
return self
.persist_terminal(
target,
&claim,
DeliveryOutcome::GaveUp {
reason: GaveUpReason::PayloadTooLarge,
last_error_class: DeliveryErrorClass::PayloadTooLarge,
last_http_status: None,
},
)
.await;
}
let claim = match self.claim(target).await {
Ok(c) => c,
Err(ClaimFailure::AlreadyHeld) => return DeliveryReport::ClaimLostOrFinal,
Err(ClaimFailure::Other(_)) => return DeliveryReport::UnclaimedSkip,
};
let mut current_count = claim.delivery_attempt_count;
loop {
if current_count >= self.config.max_attempts {
return self
.persist_terminal(
target,
&claim,
DeliveryOutcome::GaveUp {
reason: GaveUpReason::MaxAttemptsExhausted,
last_error_class: DeliveryErrorClass::Timeout,
last_http_status: None,
},
)
.await;
}
let ips = self
.resolve(&target.url)
.await
.unwrap_or_else(|_| Vec::new());
let decision = ssrf_decide(
&target.url,
&ips,
self.config.allow_insecure_urls,
self.outbound_validator.as_ref(),
);
let resolved_ip = match decision {
SsrfDecision::Allow { resolved_ip } => resolved_ip,
SsrfDecision::Block(reason) => {
let (gu, ec) = ssrf_block_to_diagnostics(reason);
return self
.persist_terminal(
target,
&claim,
DeliveryOutcome::GaveUp {
reason: gu,
last_error_class: ec,
last_http_status: None,
},
)
.await;
}
};
let new_count = match self
.push_delivery_store
.record_attempt_started(
&target.tenant,
&target.task_id,
target.event_sequence,
&target.config_id,
&claim.claimant,
claim.generation,
)
.await
{
Ok(n) => n,
Err(A2aStorageError::StaleDeliveryClaim { .. }) => {
return DeliveryReport::ClaimLostOrFinal;
}
Err(_) => return DeliveryReport::TransientStoreError,
};
current_count = new_count;
let result = self.post(target, payload, resolved_ip).await;
let outcome = match &result {
Ok(status) if (200..400).contains(&status.as_u16()) => DeliveryOutcome::Succeeded {
http_status: status.as_u16(),
},
Ok(status) if status.as_u16() == 429 => DeliveryOutcome::Retry {
next_attempt_at: SystemTime::now() + self.backoff_for(new_count),
http_status: Some(status.as_u16()),
error_class: DeliveryErrorClass::HttpError429,
},
Ok(status) if status.as_u16() == 408 => DeliveryOutcome::Retry {
next_attempt_at: SystemTime::now() + self.backoff_for(new_count),
http_status: Some(status.as_u16()),
error_class: DeliveryErrorClass::Timeout,
},
Ok(status) if (500..600).contains(&status.as_u16()) => DeliveryOutcome::Retry {
next_attempt_at: SystemTime::now() + self.backoff_for(new_count),
http_status: Some(status.as_u16()),
error_class: DeliveryErrorClass::HttpError5xx {
status: status.as_u16(),
},
},
Ok(status) => {
DeliveryOutcome::GaveUp {
reason: GaveUpReason::NonRetryableHttpStatus,
last_error_class: DeliveryErrorClass::HttpError4xx {
status: status.as_u16(),
},
last_http_status: Some(status.as_u16()),
}
}
Err(err_class) => DeliveryOutcome::Retry {
next_attempt_at: SystemTime::now() + self.backoff_for(new_count),
http_status: None,
error_class: *err_class,
},
};
match outcome {
DeliveryOutcome::Succeeded { .. }
| DeliveryOutcome::GaveUp { .. }
| DeliveryOutcome::Abandoned { .. } => {
return self.persist_terminal(target, &claim, outcome).await;
}
DeliveryOutcome::Retry {
http_status,
error_class,
..
} => {
let _ = self
.push_delivery_store
.record_delivery_outcome(
&target.tenant,
&target.task_id,
target.event_sequence,
&target.config_id,
&claim.claimant,
claim.generation,
DeliveryOutcome::Retry {
next_attempt_at: SystemTime::now() + self.backoff_for(new_count),
http_status,
error_class,
},
)
.await;
if current_count >= self.config.max_attempts {
return self
.persist_terminal(
target,
&claim,
DeliveryOutcome::GaveUp {
reason: GaveUpReason::MaxAttemptsExhausted,
last_error_class: error_class,
last_http_status: http_status,
},
)
.await;
}
tokio::time::sleep(self.backoff_for(new_count)).await;
continue;
}
}
}
}
async fn persist_terminal(
&self,
target: &PushTarget,
claim: &DeliveryClaim,
outcome: DeliveryOutcome,
) -> DeliveryReport {
let success_report = match &outcome {
DeliveryOutcome::Succeeded { http_status } => DeliveryReport::Succeeded(*http_status),
DeliveryOutcome::GaveUp { reason, .. } => DeliveryReport::GaveUp(*reason),
DeliveryOutcome::Abandoned { reason } => DeliveryReport::Abandoned(*reason),
DeliveryOutcome::Retry { .. } => return DeliveryReport::TransientStoreError,
};
const TERMINAL_PERSIST_MAX_ATTEMPTS: u32 = 3;
let backoffs = [
Duration::from_millis(50),
Duration::from_millis(150),
Duration::from_millis(500),
];
let mut last_error: Option<A2aStorageError> = None;
for attempt in 0..TERMINAL_PERSIST_MAX_ATTEMPTS {
if attempt > 0 {
tokio::time::sleep(backoffs[(attempt as usize).min(backoffs.len() - 1)]).await;
}
match self
.push_delivery_store
.record_delivery_outcome(
&target.tenant,
&target.task_id,
target.event_sequence,
&target.config_id,
&claim.claimant,
claim.generation,
outcome.clone(),
)
.await
{
Ok(()) => return success_report,
Err(A2aStorageError::StaleDeliveryClaim { .. }) => {
return DeliveryReport::ClaimLostOrFinal;
}
Err(e) => {
last_error = Some(e);
}
}
}
let _ = last_error; DeliveryReport::TransientStoreError
}
pub async fn abandon_reclaimed(
&self,
target: &PushTarget,
reason: AbandonedReason,
) -> DeliveryReport {
let claim = match self.claim(target).await {
Ok(c) => c,
Err(ClaimFailure::AlreadyHeld) => return DeliveryReport::ClaimLostOrFinal,
Err(ClaimFailure::Other(_)) => return DeliveryReport::UnclaimedSkip,
};
self.persist_terminal(target, &claim, DeliveryOutcome::Abandoned { reason })
.await
}
async fn claim(&self, target: &PushTarget) -> Result<DeliveryClaim, ClaimFailure> {
match self
.push_delivery_store
.claim_delivery(
&target.tenant,
&target.task_id,
target.event_sequence,
&target.config_id,
&self.instance_id,
&target.owner,
self.config.claim_expiry,
)
.await
{
Ok(c) => Ok(c),
Err(A2aStorageError::ClaimAlreadyHeld { .. }) => Err(ClaimFailure::AlreadyHeld),
Err(e) => Err(ClaimFailure::Other(e.to_string())),
}
}
async fn resolve(&self, url: &Url) -> Result<Vec<IpAddr>, String> {
let host = url
.host_str()
.ok_or_else(|| "url has no host".to_string())?;
let port = url.port_or_known_default().unwrap_or(443);
self.dns_resolver.resolve(host, port).await
}
async fn post(
&self,
target: &PushTarget,
payload: &[u8],
resolved_ip: IpAddr,
) -> Result<reqwest::StatusCode, DeliveryErrorClass> {
let host = target
.url
.host_str()
.ok_or(DeliveryErrorClass::NetworkError)?
.to_string();
let pinned = SocketAddr::new(resolved_ip, 0);
let client = reqwest::Client::builder()
.connect_timeout(self.config.connect_timeout)
.timeout(self.config.request_timeout)
.redirect(reqwest::redirect::Policy::none())
.resolve(&host, pinned)
.build()
.map_err(|_| DeliveryErrorClass::NetworkError)?;
let mut req = client
.post(target.url.clone())
.header("Content-Type", "application/json")
.header(
"User-Agent",
format!("turul-a2a/{}", env!("CARGO_PKG_VERSION")),
)
.header("X-Turul-Event-Sequence", target.event_sequence.to_string());
if !target.auth_scheme.is_empty() {
req = req.header(
"Authorization",
format!(
"{} {}",
target.auth_scheme,
target.auth_credentials.expose()
),
);
}
if let Some(tok) = &target.token {
req = req.header("X-Turul-Push-Token", tok.expose());
}
let resp = req.body(payload.to_vec()).send().await;
match resp {
Ok(r) => Ok(r.status()),
Err(e) => {
if e.is_timeout() {
Err(DeliveryErrorClass::Timeout)
} else if e.is_connect() {
Err(DeliveryErrorClass::NetworkError)
} else if e.to_string().to_lowercase().contains("tls") {
Err(DeliveryErrorClass::TlsRejected)
} else {
Err(DeliveryErrorClass::NetworkError)
}
}
}
}
fn backoff_for(&self, attempt: u32) -> Duration {
let base = self.config.backoff_base.as_millis() as u64;
let cap = self.config.backoff_cap.as_millis() as u64;
let raw = base.saturating_mul(1u64 << attempt.min(31).saturating_sub(1));
let target = raw.min(cap);
let jitter = self.config.backoff_jitter.max(0.0);
let delta = (target as f64 * jitter as f64) as i64;
let offset = if delta == 0 {
0i64
} else {
rand::rng().random_range(-delta..=delta)
};
let final_ms = (target as i64).saturating_add(offset).max(0) as u64;
Duration::from_millis(final_ms)
}
}
/// Internal reason a delivery claim could not be taken.
enum ClaimFailure {
    /// Another live worker currently holds the claim.
    AlreadyHeld,
    /// Any other storage error; the message is retained for future logging.
    Other(#[allow(dead_code)] String),
}
/// Final result of `deliver` / `abandon_reclaimed` as reported to the
/// caller (the durable outcome lives in the delivery store).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryReport {
    /// The webhook accepted the payload; carries the HTTP status code.
    Succeeded(u16),
    /// Delivery permanently failed for the given reason.
    GaveUp(GaveUpReason),
    /// Delivery was abandoned (e.g. after a reclaim) for the given reason.
    Abandoned(AbandonedReason),
    /// The claim was lost to another worker or the delivery is already final.
    ClaimLostOrFinal,
    /// The claim was never obtained; the delivery was skipped, not failed.
    UnclaimedSkip,
    /// The store kept erroring; the delivery's state is unchanged.
    TransientStoreError,
}
/// Maps an SSRF block reason to the (give-up reason, error class) pair
/// recorded on the terminal delivery outcome.
fn ssrf_block_to_diagnostics(reason: SsrfBlockReason) -> (GaveUpReason, DeliveryErrorClass) {
    match reason {
        // A plain-HTTP endpoint is surfaced as a TLS rejection.
        SsrfBlockReason::InsecureScheme => {
            (GaveUpReason::TlsRejected, DeliveryErrorClass::TlsRejected)
        }
        // All other block causes collapse into the generic SSRF class.
        SsrfBlockReason::PrivateIp
        | SsrfBlockReason::InvalidUrl
        | SsrfBlockReason::DnsResolutionFailed
        | SsrfBlockReason::ValidatorDenied => {
            (GaveUpReason::SsrfBlocked, DeliveryErrorClass::SSRFBlocked)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryA2aStorage;

    /// Builds a worker over the in-memory store with a unique instance id.
    fn worker_with_store(
        store: Arc<InMemoryA2aStorage>,
        cfg: PushDeliveryConfig,
    ) -> PushDeliveryWorker {
        PushDeliveryWorker::new(store, cfg, None, format!("worker-{}", uuid::Uuid::now_v7()))
            .expect("build")
    }

    /// Builds a PushTarget for `url` with a fresh task id and fixed config id.
    fn target(url: &str) -> PushTarget {
        PushTarget {
            tenant: "t".into(),
            owner: "anonymous".into(),
            task_id: format!("task-{}", uuid::Uuid::now_v7()),
            event_sequence: 1,
            config_id: "cfg-A".into(),
            url: Url::parse(url).unwrap(),
            auth_scheme: "Bearer".into(),
            auth_credentials: Secret::new("cred".into()),
            token: None,
        }
    }

    /// With jitter disabled, the delay must double per attempt and saturate
    /// at the configured cap.
    #[test]
    fn backoff_doubles_until_cap() {
        let store = Arc::new(InMemoryA2aStorage::new());
        let cfg = PushDeliveryConfig {
            backoff_base: Duration::from_secs(1),
            backoff_cap: Duration::from_secs(8),
            backoff_jitter: 0.0,
            ..Default::default()
        };
        let w = worker_with_store(store, cfg);
        assert_eq!(w.backoff_for(1), Duration::from_secs(1));
        assert_eq!(w.backoff_for(2), Duration::from_secs(2));
        assert_eq!(w.backoff_for(3), Duration::from_secs(4));
        assert_eq!(w.backoff_for(4), Duration::from_secs(8));
        assert_eq!(w.backoff_for(5), Duration::from_secs(8));
        assert_eq!(w.backoff_for(10), Duration::from_secs(8));
    }

    /// An oversized payload must short-circuit (no HTTP attempt) and record
    /// a single terminal GaveUp(PayloadTooLarge) in the store.
    #[tokio::test]
    async fn payload_too_large_short_circuits_with_gaveup() {
        let store = Arc::new(InMemoryA2aStorage::new());
        let cfg = PushDeliveryConfig {
            max_payload_bytes: 10,
            ..Default::default()
        };
        let w = worker_with_store(store.clone(), cfg);
        let t = target("https://example.com/");
        let payload = vec![0u8; 1024];
        let report = w.deliver(&t, &payload).await;
        assert_eq!(
            report,
            DeliveryReport::GaveUp(GaveUpReason::PayloadTooLarge)
        );
        let failed = store
            .list_failed_deliveries(&t.tenant, SystemTime::UNIX_EPOCH, 10)
            .await
            .unwrap();
        assert_eq!(failed.len(), 1);
        assert!(matches!(
            failed[0].last_error_class,
            DeliveryErrorClass::PayloadTooLarge
        ));
    }

    /// With insecure URLs disallowed, an http:// webhook must be blocked by
    /// the SSRF check and recorded as a failed delivery.
    #[tokio::test]
    async fn non_https_in_production_records_gaveup_ssrf() {
        let store = Arc::new(InMemoryA2aStorage::new());
        let cfg = PushDeliveryConfig {
            allow_insecure_urls: false,
            ..Default::default()
        };
        let w = worker_with_store(store.clone(), cfg);
        let t = target("http://webhook.example.com/");
        let report = w.deliver(&t, b"{}").await;
        assert!(matches!(report, DeliveryReport::GaveUp(_)));
        let failed = store
            .list_failed_deliveries(&t.tenant, SystemTime::UNIX_EPOCH, 10)
            .await
            .unwrap();
        assert_eq!(failed.len(), 1);
    }
}