use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::push::{AuthenticationInfo, TaskPushNotificationConfig};
use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
use a2a_protocol_server::push::{HttpPushSender, PushSender};
fn status_event() -> StreamResponse {
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: TaskId::new("task-1"),
context_id: ContextId::new("ctx"),
status: TaskStatus::new(TaskState::Working),
metadata: None,
})
}
fn base_config(url: &str) -> TaskPushNotificationConfig {
TaskPushNotificationConfig::new("task-1", url)
}
async fn mock_server(
status: u16,
request_counter: Arc<AtomicUsize>,
) -> (SocketAddr, tokio::task::JoinHandle<()>) {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
for _ in 0..5 {
let Ok((stream, _)) = listener.accept().await else {
break;
};
let counter = Arc::clone(&request_counter);
tokio::spawn(async move {
counter.fetch_add(1, Ordering::SeqCst);
stream.readable().await.unwrap();
let mut buf = vec![0u8; 4096];
let _ = stream.try_read(&mut buf);
let response = format!(
"HTTP/1.1 {status} OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"
);
stream.writable().await.unwrap();
let _ = stream.try_write(response.as_bytes());
});
}
});
(addr, handle)
}
async fn mock_server_with_headers(
captured: Arc<std::sync::Mutex<Vec<String>>>,
) -> (SocketAddr, tokio::task::JoinHandle<()>) {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
for _ in 0..3 {
let Ok((stream, _)) = listener.accept().await else {
break;
};
let captured = Arc::clone(&captured);
tokio::spawn(async move {
stream.readable().await.unwrap();
let mut buf = vec![0u8; 4096];
let n = stream.try_read(&mut buf).unwrap_or(0);
let request = String::from_utf8_lossy(&buf[..n]).to_string();
captured.lock().unwrap().push(request);
let response = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
stream.writable().await.unwrap();
let _ = stream.try_write(response.as_bytes());
});
}
});
(addr, handle)
}
#[tokio::test]
async fn successful_delivery_on_first_attempt() {
let counter = Arc::new(AtomicUsize::new(0));
let (addr, handle) = mock_server(200, Arc::clone(&counter)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let config = base_config(&url);
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert_eq!(
counter.load(Ordering::SeqCst),
1,
"should succeed on first attempt"
);
handle.abort();
}
#[tokio::test]
async fn retries_on_server_error_and_eventually_fails() {
let counter = Arc::new(AtomicUsize::new(0));
let (addr, handle) = mock_server(500, Arc::clone(&counter)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let config = base_config(&url);
let result = sender.send(&url, &status_event(), &config).await;
assert!(result.is_err(), "should fail after all retries");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("HTTP 500"),
"error should mention HTTP status: {err_msg}"
);
assert_eq!(
counter.load(Ordering::SeqCst),
3,
"should retry exactly 3 times"
);
handle.abort();
}
#[tokio::test]
async fn retries_on_client_error_status() {
let counter = Arc::new(AtomicUsize::new(0));
let (addr, handle) = mock_server(403, Arc::clone(&counter)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let config = base_config(&url);
let result = sender.send(&url, &status_event(), &config).await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3);
handle.abort();
}
#[tokio::test]
async fn connection_refused_returns_error() {
let sender = HttpPushSender::new().allow_private_urls();
let url = "http://127.0.0.1:1/webhook";
let config = base_config(url);
let result = sender.send(url, &status_event(), &config).await;
assert!(result.is_err(), "should fail on connection refused");
}
#[tokio::test]
async fn bearer_auth_header_is_sent() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let mut config = base_config(&url);
config.authentication = Some(AuthenticationInfo {
scheme: "bearer".into(),
credentials: "my-secret-token".into(),
});
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(
!reqs.is_empty(),
"should have captured at least one request"
);
let req = &reqs[0];
assert!(
req.contains("authorization: Bearer my-secret-token")
|| req.contains("Authorization: Bearer my-secret-token"),
"should contain Bearer auth header, got: {req}"
);
handle.abort();
}
#[tokio::test]
async fn basic_auth_header_is_sent() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let mut config = base_config(&url);
config.authentication = Some(AuthenticationInfo {
scheme: "basic".into(),
credentials: "dXNlcjpwYXNz".into(),
});
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(!reqs.is_empty());
let req = &reqs[0];
assert!(
req.contains("authorization: Basic dXNlcjpwYXNz")
|| req.contains("Authorization: Basic dXNlcjpwYXNz"),
"should contain Basic auth header, got: {req}"
);
handle.abort();
}
#[tokio::test]
async fn notification_token_header_is_sent() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let mut config = base_config(&url);
config.token = Some("my-notification-token".into());
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(!reqs.is_empty());
let req = &reqs[0];
assert!(
req.contains("a2a-notification-token: my-notification-token"),
"should contain notification token header, got: {req}"
);
handle.abort();
}
#[tokio::test]
async fn both_auth_and_token_headers_are_sent() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let mut config = base_config(&url);
config.authentication = Some(AuthenticationInfo {
scheme: "bearer".into(),
credentials: "token-123".into(),
});
config.token = Some("notif-456".into());
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(!reqs.is_empty());
let req = &reqs[0];
assert!(
req.contains("Bearer token-123") || req.contains("bearer token-123"),
"should contain Bearer auth"
);
assert!(
req.contains("a2a-notification-token: notif-456"),
"should contain notification token"
);
handle.abort();
}
#[tokio::test]
async fn request_has_json_content_type() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let config = base_config(&url);
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(!reqs.is_empty());
let req = &reqs[0];
assert!(
req.contains("content-type: application/json")
|| req.contains("Content-Type: application/json"),
"should have JSON content type, got: {req}"
);
handle.abort();
}
#[tokio::test]
async fn request_uses_post_method() {
let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
let (addr, handle) = mock_server_with_headers(Arc::clone(&captured)).await;
let sender = HttpPushSender::new().allow_private_urls();
let url = format!("http://{addr}/webhook");
let config = base_config(&url);
sender.send(&url, &status_event(), &config).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let reqs = captured.lock().unwrap();
assert!(!reqs.is_empty());
assert!(
reqs[0].starts_with("POST "),
"should use POST method, got: {}",
&reqs[0][..50.min(reqs[0].len())]
);
handle.abort();
}
#[test]
fn http_push_sender_default_creates_instance() {
let sender = HttpPushSender::default();
let dbg = format!("{sender:?}");
assert!(dbg.contains("HttpPushSender"));
}
#[test]
fn http_push_sender_debug_impl() {
let sender = HttpPushSender::new().allow_private_urls();
let dbg = format!("{sender:?}");
assert!(dbg.contains("HttpPushSender"));
}