use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::{
AgentSpec, ArtifactId, ExecutionContext, ExecutionHandle, ExecutionMetrics, ExecutionResult,
ResourceLimits, Run, RunError, RunId, RunStatus, RuntimeAdapter, RuntimeKind, StatusResult,
};
/// Largest response body (10 MiB) accepted from an A2A agent.
const MAX_RESPONSE_BYTES: usize = 10 * 1024 * 1024;
/// Per-request HTTP client timeout and overall task deadline, in seconds.
const TOTAL_TIMEOUT_SECS: u64 = 30;
/// How long a fetched agent card is served from the cache, in seconds.
const AGENT_CARD_TTL_SECS: u64 = 300;
/// Delays (seconds) between task polls; the last entry repeats once reached.
const BACKOFF_DELAYS: &[u64] = &[1, 2, 4, 8, 10];
/// Agent metadata document served at `<base>/.well-known/agent.json`.
///
/// Optional fields become `None` when missing from the JSON; extra fields in the
/// document are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentCard {
    /// Agent name; `execute` also passes it to `resolve_auth` to pick a bearer token.
    pub name: String,
    /// URL the agent advertises for itself (this module sends requests to the run's
    /// target URL instead).
    pub url: String,
    /// Free-form description, if provided.
    #[serde(default)]
    pub description: Option<String>,
    /// Capability identifiers, if listed.
    #[serde(default)]
    pub capabilities: Option<Vec<String>>,
}
/// Status object of an A2A task; only the `state` string is modelled.
///
/// `send_task_and_wait` treats `"completed"`, `"failed"` and `"canceled"` as final
/// and keeps polling on any other value.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct A2ATaskStatus {
    /// Raw task state reported by the agent.
    pub state: String,
}

/// An output artifact attached to an A2A task.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct A2AArtifact {
    /// Content parts; empty when the field is absent.
    #[serde(default)]
    pub parts: Vec<A2APart>,
    /// Optional artifact name.
    #[serde(default)]
    pub name: Option<String>,
}

/// A single content part of an artifact.
///
/// `part_type` is carried under the JSON key `"type"`. `artifacts_to_value` only
/// consumes parts whose type is `"text"` and that carry a `text` payload.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct A2APart {
    #[serde(rename = "type")]
    pub part_type: String,
    #[serde(default)]
    pub text: Option<String>,
}

/// Response body returned by the agent's `tasks/send` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct A2ATaskResponse {
    /// Task identifier as reported by the agent.
    pub id: String,
    /// Current task status.
    pub status: A2ATaskStatus,
    /// Produced artifacts; empty when absent.
    #[serde(default)]
    pub artifacts: Vec<A2AArtifact>,
}
/// Checks that `url` is an acceptable target for outbound A2A requests.
///
/// The URL must parse, use the `https` scheme, have a host, and must not point at a
/// loopback/private/link-local/multicast/broadcast address (see [`is_private_ip`]).
/// `localhost`, `localhost.` and any `*.localhost` name are rejected outright.
/// IP-literal hosts (including bracketed IPv6 literals) are checked directly; other
/// hostnames are resolved with the system resolver and every resolved address is
/// checked.
///
/// # Errors
/// Returns [`RunError::InvalidConfig`] describing the first rule that is violated.
///
/// # Notes
/// - Hostname resolution is blocking (`ToSocketAddrs`).
/// - Resolution failures are deliberately *not* an error (fail-open); an unresolvable
///   host fails later at connect time. The check therefore cannot stop DNS rebinding
///   between validation and connection.
pub fn validate_a2a_url(url: &str) -> Result<(), RunError> {
    let parsed = reqwest::Url::parse(url).map_err(|e| RunError::InvalidConfig {
        message: format!("Invalid A2A URL '{}': {}", url, e),
    })?;
    if parsed.scheme() != "https" {
        return Err(RunError::InvalidConfig {
            message: format!(
                "A2A URL must use HTTPS (got '{}'): {}",
                parsed.scheme(),
                url
            ),
        });
    }
    let host = parsed.host_str().ok_or_else(|| RunError::InvalidConfig {
        message: format!("A2A URL has no host: {}", url),
    })?;
    // `host_str` keeps the brackets of IPv6 literals ("[::1]"); strip them so the literal
    // is checked as an IP address below instead of being handed to the resolver.
    let bare_host = host.trim_start_matches('[').trim_end_matches(']');
    // "localhost." names the same host as "localhost", and every `*.localhost` name is
    // reserved for loopback (RFC 6761), so reject them all without relying on DNS.
    let dns_name = bare_host.trim_end_matches('.').to_ascii_lowercase();
    if dns_name == "localhost" || dns_name.ends_with(".localhost") {
        return Err(RunError::InvalidConfig {
            message: format!("A2A URL targets private/loopback address: {}", url),
        });
    }
    if let Ok(ip) = bare_host.parse::<IpAddr>() {
        if is_private_ip(ip) {
            return Err(RunError::InvalidConfig {
                message: format!("A2A URL targets private IP address: {}", url),
            });
        }
        return Ok(());
    }
    let port = parsed.port().unwrap_or(443);
    let resolve_target = format!("{}:{}", host, port);
    // Fail open on resolution errors (see the doc comment above).
    if let Ok(addrs) = std::net::ToSocketAddrs::to_socket_addrs(&resolve_target) {
        if let Some(private) = addrs.map(|addr| addr.ip()).find(|ip| is_private_ip(*ip)) {
            return Err(RunError::InvalidConfig {
                message: format!(
                    "A2A hostname '{}' resolves to private IP {}: {}",
                    host, private, url
                ),
            });
        }
    }
    Ok(())
}
/// Returns `true` if `ip` must not be used as an outbound A2A target.
///
/// The following are blocked:
/// - IPv4: loopback (127/8), RFC 1918 (10/8, 172.16/12, 192.168/16), link-local
///   (169.254/16), "this network" (0/8), multicast (224/4) and the limited broadcast
///   address.
/// - IPv6: loopback, unspecified (`::`), multicast (ff00::/8), unique-local
///   (fc00::/7) and link-local (fe80::/10).
/// - Any IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) whose embedded IPv4 address is
///   blocked. Without this rule, `https://[::ffff:127.0.0.1]` would slip past the
///   SSRF filter.
pub fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();
            o[0] == 127
                || o[0] == 10
                || (o[0] == 172 && (16..=31).contains(&o[1]))
                || (o[0] == 192 && o[1] == 168)
                || (o[0] == 169 && o[1] == 254)
                || o[0] == 0
                || (o[0] & 0xf0) == 224
                || o == [255, 255, 255, 255]
        }
        IpAddr::V6(v6) => {
            // Judge IPv4-mapped addresses by the IPv4 rules above; the OS will route them
            // to the embedded IPv4 address.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ip(IpAddr::V4(mapped));
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (first & 0xfe00) == 0xfc00
                || (first & 0xffc0) == 0xfe80
        }
    }
}
/// Looks up the bearer token for `agent_name` in the environment.
///
/// The name is upper-cased and tried verbatim as `BZZZ_A2A_TOKEN_<NAME>`. If that
/// variable is unset (or not valid Unicode), a second lookup is made with every
/// non-alphanumeric character replaced by `_`, but only when that actually changes
/// the name. Returns `None` if neither variable yields a value.
pub fn resolve_auth(agent_name: &str) -> Option<String> {
    let lookup = |suffix: &str| std::env::var(format!("BZZZ_A2A_TOKEN_{}", suffix)).ok();

    let upper = agent_name.to_uppercase();
    if let Some(token) = lookup(&upper) {
        return Some(token);
    }

    let sanitized: String = upper
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect();
    if sanitized == upper {
        None
    } else {
        lookup(&sanitized)
    }
}
/// Collapses the text parts of `artifacts` into a single JSON value.
///
/// Every part whose type is `"text"` and that carries a `text` payload is decoded as
/// JSON; text that is not valid JSON becomes `{"text": <raw>}`.
///
/// Returns `None` when there are no such parts, the decoded value itself when there is
/// exactly one, and otherwise a JSON array of the decoded values in artifact/part order.
pub fn artifacts_to_value(artifacts: &[A2AArtifact]) -> Option<serde_json::Value> {
    let decode = |raw: &str| -> serde_json::Value {
        serde_json::from_str(raw).unwrap_or_else(|_| serde_json::json!({ "text": raw }))
    };

    let mut decoded: Vec<serde_json::Value> = artifacts
        .iter()
        .flat_map(|artifact| &artifact.parts)
        .filter(|part| part.part_type == "text")
        .filter_map(|part| part.text.as_deref())
        .map(decode)
        .collect();

    match decoded.len() {
        0 => None,
        1 => decoded.pop(),
        _ => Some(serde_json::Value::Array(decoded)),
    }
}
/// An agent card together with the wall-clock time it was fetched.
struct CachedCard {
    card: AgentCard,
    fetched_at: SystemTime,
}

impl CachedCard {
    /// Returns `true` while the entry's age, in whole seconds, is below
    /// `AGENT_CARD_TTL_SECS`.
    ///
    /// If the system clock has moved backwards since `fetched_at`, the age cannot be
    /// computed and the entry is treated as stale.
    fn is_fresh(&self) -> bool {
        match self.fetched_at.elapsed() {
            Ok(age) => age.as_secs() < AGENT_CARD_TTL_SECS,
            Err(_) => false,
        }
    }
}
/// Book-keeping for one run handled by [`A2ARuntime`].
struct A2AExecution {
    // Identity and routing.
    run_id: RunId,
    /// Base URL of the agent the run was sent to; reused by `cancel`.
    agent_url: String,
    /// Bearer token used for the run, if any; reused by `cancel`.
    auth_token: Option<String>,
    // Outcome.
    status: RunStatus,
    started_at: Instant,
    output: Option<serde_json::Value>,
    error: Option<RunError>,
}
/// Runtime adapter that executes runs on remote agents over A2A HTTP endpoints.
///
/// Mutable state sits behind async `RwLock`s so that `&self` methods can update it.
pub struct A2ARuntime {
    /// HTTP client used for cached card lookups and task submission/polling.
    client: reqwest::Client,
    /// Recorded executions, keyed by run id string.
    executions: Arc<RwLock<HashMap<String, A2AExecution>>>,
    /// Agent cards, keyed by the base URL they were fetched from.
    card_cache: Arc<RwLock<HashMap<String, CachedCard>>>,
}
impl A2ARuntime {
    /// Creates a runtime with an empty card cache and no recorded executions.
    ///
    /// The shared HTTP client uses a `TOTAL_TIMEOUT_SECS` per-request timeout and does
    /// not follow redirects (see `build_client`).
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed.
    pub fn new() -> Self {
        A2ARuntime {
            client: Self::build_client(
                Duration::from_secs(TOTAL_TIMEOUT_SECS),
                "Failed to build A2A reqwest client",
            ),
            card_cache: Arc::new(RwLock::new(HashMap::new())),
            executions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Builds an HTTP client with the given per-request `timeout` that never follows
    /// redirects.
    ///
    /// Redirects are disabled because `validate_a2a_url` only vets the initial URL.
    /// Following a 3xx would let a remote agent bounce requests to a loopback/private/
    /// metadata address and defeat that SSRF check. A redirect now shows up as a
    /// non-success HTTP status instead.
    ///
    /// # Panics
    /// Panics with `context` if the client cannot be constructed.
    fn build_client(timeout: Duration, context: &str) -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(timeout)
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .expect(context)
    }

    /// Reads `response`'s body, stopping as soon as it is known to exceed `limit` bytes.
    ///
    /// Unlike `Response::bytes`, this never buffers more than `limit` bytes plus one
    /// network chunk, so an oversized or malicious body cannot exhaust memory.
    ///
    /// Return values:
    /// - `Ok(Ok(body))` on success.
    /// - `Ok(Err(n))` when the body is too large. `n` is the advertised
    ///   `Content-Length`, or the byte count seen when reading stopped.
    /// - `Err(e)` on a transport error.
    async fn read_body_limited(
        mut response: reqwest::Response,
        limit: usize,
    ) -> Result<Result<Vec<u8>, u64>, reqwest::Error> {
        let advertised = response.content_length();
        if let Some(n) = advertised {
            if n > limit as u64 {
                return Ok(Err(n));
            }
        }
        // `advertised` is at most `limit` here, so the cast cannot truncate.
        let mut body = Vec::with_capacity(advertised.unwrap_or(0) as usize);
        while let Some(chunk) = response.chunk().await? {
            let seen = body.len() + chunk.len();
            if seen > limit {
                return Ok(Err(seen as u64));
            }
            body.extend_from_slice(&chunk);
        }
        Ok(Ok(body))
    }

    /// Fetches and parses `<base_url>/.well-known/agent.json` using `client`.
    ///
    /// # Errors
    /// Returns [`RunError::StartupFailed`] in any of these cases:
    /// - the request fails;
    /// - the server answers with a non-success status (including redirects);
    /// - the body exceeds `MAX_RESPONSE_BYTES`;
    /// - the body is not a valid [`AgentCard`].
    async fn fetch_agent_card(
        client: &reqwest::Client,
        base_url: &str,
    ) -> Result<AgentCard, RunError> {
        let card_url = format!("{}/.well-known/agent.json", base_url.trim_end_matches('/'));
        let response = client
            .get(&card_url)
            .send()
            .await
            .map_err(|e| RunError::StartupFailed {
                message: format!("Failed to fetch A2A agent card from {}: {}", card_url, e),
            })?;
        if !response.status().is_success() {
            return Err(RunError::StartupFailed {
                message: format!(
                    "A2A agent card returned HTTP {} for {}",
                    response.status(),
                    card_url
                ),
            });
        }
        let bytes = match Self::read_body_limited(response, MAX_RESPONSE_BYTES).await {
            Ok(Ok(bytes)) => bytes,
            Ok(Err(_)) => {
                return Err(RunError::StartupFailed {
                    message: format!(
                        "Agent card response exceeds {} byte limit",
                        MAX_RESPONSE_BYTES
                    ),
                });
            }
            Err(e) => {
                return Err(RunError::StartupFailed {
                    message: format!("Failed to read agent card response: {}", e),
                });
            }
        };
        serde_json::from_slice(&bytes).map_err(|e| RunError::StartupFailed {
            message: format!("Failed to parse agent card JSON: {}", e),
        })
    }

    /// Validates `url` and fetches its agent card with a dedicated 10-second client.
    ///
    /// This bypasses the runtime's card cache.
    ///
    /// # Errors
    /// Returns [`RunError::InvalidConfig`] if `url` fails [`validate_a2a_url`]. Otherwise
    /// returns the same errors as `fetch_agent_card` (all [`RunError::StartupFailed`]).
    ///
    /// # Panics
    /// Panics if the discovery HTTP client cannot be constructed.
    pub async fn discover_agent_card(url: &str) -> Result<AgentCard, RunError> {
        validate_a2a_url(url)?;
        let client = Self::build_client(Duration::from_secs(10), "Failed to build discovery client");
        Self::fetch_agent_card(&client, url).await
    }

    /// Returns the agent card for `base_url`, serving it from the cache while fresh.
    ///
    /// On a miss or stale entry the card is fetched with the shared client and cached.
    /// The caller is expected to have validated `base_url` already.
    async fn get_agent_card(&self, base_url: &str) -> Result<AgentCard, RunError> {
        {
            let cache = self.card_cache.read().await;
            if let Some(entry) = cache.get(base_url).filter(|entry| entry.is_fresh()) {
                return Ok(entry.card.clone());
            }
        }
        let card = Self::fetch_agent_card(&self.client, base_url).await?;
        self.card_cache.write().await.insert(
            base_url.to_string(),
            CachedCard {
                card: card.clone(),
                fetched_at: SystemTime::now(),
            },
        );
        Ok(card)
    }

    /// Submits `run` to `<agent_url>/tasks/send` and polls until the task finishes.
    ///
    /// The run input, as a JSON string (empty if absent), is sent as a single text part.
    /// The task id is the run id. Polling backs off per `BACKOFF_DELAYS`. The whole
    /// exchange, including in-flight requests and sleeps, is bounded by
    /// `TOTAL_TIMEOUT_SECS`.
    ///
    /// Returns the collapsed artifact value (see [`artifacts_to_value`]), or JSON `null`
    /// if the completed task has no text artifacts.
    ///
    /// # Errors
    /// - [`RunError::Timeout`] when the deadline passes.
    /// - [`RunError::ExecutionFailed`] on transport errors, oversized bodies, non-success
    ///   statuses, unparsable responses, or a `"failed"` task.
    /// - [`RunError::Cancelled`] when the agent reports `"canceled"`.
    // NOTE(review): polling re-POSTs the full `tasks/send` payload each time. This relies
    // on the agent treating a repeated task id idempotently; a `tasks/get` poll would be
    // safer if the target agents support it — confirm.
    async fn send_task_and_wait(
        &self,
        agent_url: &str,
        run: &Run,
        auth: Option<&str>,
    ) -> Result<serde_json::Value, RunError> {
        let task_id = run.id.as_str().to_string();
        let input_text = run
            .input
            .as_ref()
            .map(|v| v.to_string())
            .unwrap_or_default();
        let payload = serde_json::json!({
            "id": task_id,
            "message": {
                "role": "user",
                "parts": [{"type": "text", "text": input_text}]
            }
        });
        let task_url = format!("{}/tasks/send", agent_url.trim_end_matches('/'));
        let total = Duration::from_secs(TOTAL_TIMEOUT_SECS);
        let deadline = Instant::now() + total;
        let mut step = 0usize;
        loop {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                return Err(RunError::Timeout { after: total });
            }
            // Cap this attempt (connect + body) at the time left so the deadline is total.
            let mut req = self.client.post(&task_url).timeout(remaining).json(&payload);
            if let Some(token) = auth {
                req = req.bearer_auth(token);
            }
            let response = req.send().await.map_err(|e| {
                if e.is_timeout() {
                    RunError::Timeout { after: total }
                } else {
                    RunError::ExecutionFailed {
                        message: format!("A2A task request failed: {}", e),
                    }
                }
            })?;
            let http_status = response.status();
            let bytes = match Self::read_body_limited(response, MAX_RESPONSE_BYTES).await {
                Ok(Ok(bytes)) => bytes,
                Ok(Err(seen)) => {
                    return Err(RunError::ExecutionFailed {
                        message: format!(
                            "A2A response exceeds {} byte limit ({} bytes advertised or received)",
                            MAX_RESPONSE_BYTES, seen
                        ),
                    });
                }
                Err(e) if e.is_timeout() => {
                    return Err(RunError::Timeout { after: total });
                }
                Err(e) => {
                    return Err(RunError::ExecutionFailed {
                        message: format!("Failed to read A2A response: {}", e),
                    });
                }
            };
            if !http_status.is_success() {
                return Err(RunError::ExecutionFailed {
                    message: format!(
                        "A2A task returned HTTP {}: {}",
                        http_status,
                        String::from_utf8_lossy(&bytes)
                    ),
                });
            }
            let task: A2ATaskResponse =
                serde_json::from_slice(&bytes).map_err(|e| RunError::ExecutionFailed {
                    message: format!("Failed to parse A2A task response: {}", e),
                })?;
            match task.status.state.as_str() {
                "completed" => {
                    return Ok(
                        artifacts_to_value(&task.artifacts).unwrap_or(serde_json::Value::Null)
                    );
                }
                "failed" => {
                    return Err(RunError::ExecutionFailed {
                        message: format!("A2A task '{}' failed", task.id),
                    });
                }
                "canceled" => {
                    return Err(RunError::Cancelled {
                        reason: format!("A2A task '{}' was canceled", task.id),
                    });
                }
                _ => {
                    let delay =
                        Duration::from_secs(BACKOFF_DELAYS[step.min(BACKOFF_DELAYS.len() - 1)]);
                    // Never sleep past the deadline; the check at the loop top then fires.
                    let left = deadline.saturating_duration_since(Instant::now());
                    tokio::time::sleep(delay.min(left)).await;
                    step = (step + 1).min(BACKOFF_DELAYS.len() - 1);
                }
            }
        }
    }
}
impl Default for A2ARuntime {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RuntimeAdapter for A2ARuntime {
    /// A2A agents are reached over HTTP, so this runtime reports `RuntimeKind::Http`.
    fn kind(&self) -> RuntimeKind {
        RuntimeKind::Http
    }

    /// Creates an execution context named `a2a-<spec id>` with default resource limits.
    ///
    /// No remote call is made here; the agent is first contacted in `execute`.
    async fn create(&self, spec: &AgentSpec) -> Result<ExecutionContext, RunError> {
        Ok(
            ExecutionContext::new(format!("a2a-{}", spec.id.as_str()), RuntimeKind::Http)
                .with_limits(ResourceLimits::default()),
        )
    }

    /// Runs `run` on its A2A agent to completion and records the outcome.
    ///
    /// Steps:
    /// 1. Validate the target URL.
    /// 2. Fetch the agent card (cached) and resolve a bearer token from the card name.
    /// 3. Submit and poll the task.
    /// 4. Record the result under the run id for `status`/`wait`/`cancel`.
    ///
    /// # Errors
    /// - [`RunError::InvalidConfig`] for a non-A2A target or a rejected URL.
    /// - Card-fetch errors; these are not recorded as an execution.
    /// - The task error itself, which is recorded with status `Failed` before being
    ///   returned.
    // NOTE(review): `validate_a2a_url` resolves DNS synchronously on this async task;
    // consider `tokio::task::spawn_blocking` if the `rt` feature is available.
    async fn execute(
        &self,
        ctx: &ExecutionContext,
        run: &Run,
    ) -> Result<ExecutionHandle, RunError> {
        let url = match &run.target {
            crate::RunTarget::A2AAgent { url } => url.clone(),
            _ => {
                return Err(RunError::InvalidConfig {
                    message: "A2ARuntime requires a RunTarget::A2AAgent target".into(),
                });
            }
        };
        validate_a2a_url(&url)?;
        let started = Instant::now();
        let card = self.get_agent_card(&url).await?;
        // NOTE(review): the token is chosen by the card's *self-declared* name, which the
        // remote agent controls. An agent claiming another agent's name would receive that
        // agent's token. Keying tokens by the configured URL/host would be safer.
        let auth = resolve_auth(&card.name);
        let (status, output, error) =
            match self.send_task_and_wait(&url, run, auth.as_deref()).await {
                Ok(value) => (RunStatus::Completed, Some(value), None),
                Err(e) => (RunStatus::Failed, None, Some(e)),
            };
        self.executions.write().await.insert(
            run.id.as_str().to_string(),
            A2AExecution {
                run_id: run.id.clone(),
                status,
                started_at: started,
                output,
                error: error.clone(),
                agent_url: url,
                auth_token: auth,
            },
        );
        if let Some(e) = error {
            return Err(e);
        }
        Ok(ExecutionHandle::new(
            run.id.clone(),
            RuntimeKind::Http,
            format!("a2a:{}", ctx.id),
        ))
    }

    /// Same as `execute`: the run is driven to completion before this returns.
    async fn execute_background(
        &self,
        ctx: &ExecutionContext,
        run: &Run,
    ) -> Result<ExecutionHandle, RunError> {
        self.execute(ctx, run).await
    }

    /// Reports the recorded status of the run behind `handle`.
    ///
    /// Progress is 100 for completed runs and 0 otherwise.
    ///
    /// # Errors
    /// [`RunError::NotFound`] if no execution is recorded for the handle's run id.
    // NOTE(review): `elapsed_ms` is measured from `started_at` to *now*, so it keeps
    // growing after the run finishes; recording a finish time would fix this.
    async fn status(&self, handle: &ExecutionHandle) -> Result<StatusResult, RunError> {
        let execs = self.executions.read().await;
        let exec = execs
            .get(handle.run_id.as_str())
            .ok_or_else(|| RunError::NotFound {
                resource_type: "a2a-execution".into(),
                id: handle.run_id.as_str().to_string(),
            })?;
        Ok(StatusResult {
            run_id: exec.run_id.clone(),
            status: exec.status,
            current_step: None,
            progress: if exec.status == RunStatus::Completed {
                100
            } else {
                0
            },
            elapsed_ms: exec.started_at.elapsed().as_millis() as u64,
            artifacts: Vec::new(),
        })
    }

    /// Nothing is allocated per context, so there is nothing to tear down.
    async fn destroy(&self, _ctx: &ExecutionContext) -> Result<(), RunError> {
        Ok(())
    }

    /// Cancels a still-running execution: best-effort remote cancel, then local status.
    ///
    /// Unknown runs and runs already `Completed`, `Failed` or `Cancelled` are left
    /// untouched, so a finished run keeps its recorded outcome. For other runs a JSON-RPC
    /// `tasks/cancel` request is POSTed with a 5-second timeout. Transport failures and
    /// non-success statuses are logged to stderr, and the local status becomes `Cancelled`
    /// either way.
    ///
    /// Always returns `Ok(())`.
    async fn cancel(&self, handle: &ExecutionHandle) -> Result<(), RunError> {
        let (agent_url, auth_token, task_id) = {
            let execs = self.executions.read().await;
            match execs.get(handle.run_id.as_str()) {
                None => return Ok(()),
                Some(exec)
                    if matches!(
                        exec.status,
                        RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled
                    ) =>
                {
                    return Ok(());
                }
                Some(exec) => (
                    exec.agent_url.clone(),
                    exec.auth_token.clone(),
                    exec.run_id.as_str().to_string(),
                ),
            }
        };
        let cancel_url = format!("{}/tasks/cancel", agent_url.trim_end_matches('/'));
        let payload = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "tasks/cancel",
            "params": {"id": task_id},
            "id": format!("cancel-{}", task_id)
        });
        // Short per-request timeout so an unreachable agent cannot stall local cancellation;
        // reuses the shared client instead of building (and possibly panicking on) a new one.
        let mut req = self
            .client
            .post(&cancel_url)
            .timeout(Duration::from_secs(5))
            .json(&payload);
        if let Some(token) = auth_token.as_deref() {
            req = req.bearer_auth(token);
        }
        match req.send().await {
            Ok(resp) if !resp.status().is_success() => {
                eprintln!(
                    "Warning: A2A remote cancel for task {} at {} returned HTTP {}",
                    task_id,
                    agent_url,
                    resp.status()
                );
            }
            Ok(_) => {}
            Err(e) => {
                eprintln!(
                    "Warning: A2A remote cancel failed for task {} at {}: {}",
                    task_id, agent_url, e
                );
            }
        }
        {
            let mut execs = self.executions.write().await;
            if let Some(exec) = execs.get_mut(handle.run_id.as_str()) {
                exec.status = RunStatus::Cancelled;
            }
        }
        Ok(())
    }

    /// Returns the recorded result of the run behind `handle` without blocking.
    ///
    /// # Errors
    /// [`RunError::NotFound`] if no execution is recorded for the handle's run id.
    async fn wait(&self, handle: &ExecutionHandle) -> Result<ExecutionResult, RunError> {
        let execs = self.executions.read().await;
        let exec = execs
            .get(handle.run_id.as_str())
            .ok_or_else(|| RunError::NotFound {
                resource_type: "a2a-execution".into(),
                id: handle.run_id.as_str().to_string(),
            })?;
        Ok(ExecutionResult {
            run_id: exec.run_id.clone(),
            status: exec.status,
            artifacts: Vec::<ArtifactId>::new(),
            error: exec.error.clone(),
            metrics: ExecutionMetrics {
                wall_time_ms: exec.started_at.elapsed().as_millis() as u64,
                ..Default::default()
            },
            output: exec.output.clone(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds an IPv4 `IpAddr` from its four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Parses an IPv6 literal into an `IpAddr`; panics if malformed.
    fn v6(literal: &str) -> IpAddr {
        IpAddr::V6(literal.parse::<Ipv6Addr>().unwrap())
    }

    /// An unnamed artifact holding exactly one part.
    fn one_part(part_type: &str, text: Option<&str>) -> A2AArtifact {
        A2AArtifact {
            parts: vec![A2APart {
                part_type: part_type.to_string(),
                text: text.map(String::from),
            }],
            name: None,
        }
    }

    // ---- validate_a2a_url ----

    #[test]
    fn test_validate_url_https_ok() {
        for url in ["https://agent.example.com", "https://api.acme.io/v1"] {
            assert!(validate_a2a_url(url).is_ok(), "expected {} to validate", url);
        }
    }

    #[test]
    fn test_validate_url_http_rejected() {
        let message = validate_a2a_url("http://agent.example.com")
            .unwrap_err()
            .to_string();
        assert!(message.contains("HTTPS"));
    }

    #[test]
    fn test_validate_url_localhost_rejected() {
        assert!(validate_a2a_url("https://localhost/agent").is_err());
    }

    #[test]
    fn test_validate_url_loopback_rejected() {
        assert!(validate_a2a_url("https://127.0.0.1").is_err());
    }

    #[test]
    fn test_validate_url_rfc1918_10_rejected() {
        assert!(validate_a2a_url("https://10.0.0.1").is_err());
    }

    #[test]
    fn test_validate_url_rfc1918_192_168_rejected() {
        assert!(validate_a2a_url("https://192.168.1.1").is_err());
    }

    #[test]
    fn test_validate_url_rfc1918_172_rejected() {
        assert!(validate_a2a_url("https://172.16.0.1").is_err());
        // Just outside 172.16.0.0/12 on either side.
        assert!(validate_a2a_url("https://172.15.0.1").is_ok());
        assert!(validate_a2a_url("https://172.32.0.1").is_ok());
    }

    #[test]
    fn test_validate_url_link_local_rejected() {
        assert!(validate_a2a_url("https://169.254.1.1").is_err());
    }

    #[test]
    fn test_validate_url_invalid() {
        assert!(validate_a2a_url("not-a-url").is_err());
    }

    #[test]
    fn test_validate_url_multicast_rejected() {
        assert!(validate_a2a_url("https://224.0.0.1").is_err());
    }

    #[test]
    fn test_validate_url_broadcast_rejected() {
        assert!(validate_a2a_url("https://255.255.255.255").is_err());
    }

    // ---- is_private_ip ----

    #[test]
    fn test_is_private_ip_loopback() {
        assert!(is_private_ip(v4(127, 0, 0, 1)));
    }

    #[test]
    fn test_is_private_ip_rfc1918() {
        assert!(is_private_ip(v4(10, 1, 2, 3)));
        assert!(is_private_ip(v4(172, 20, 0, 1)));
        assert!(is_private_ip(v4(192, 168, 0, 1)));
        assert!(!is_private_ip(v4(172, 15, 0, 1)));
        assert!(!is_private_ip(v4(172, 32, 0, 1)));
    }

    #[test]
    fn test_is_private_ip_link_local() {
        assert!(is_private_ip(v4(169, 254, 0, 1)));
    }

    #[test]
    fn test_is_private_ip_multicast() {
        assert!(is_private_ip(v4(224, 0, 0, 1)));
        assert!(is_private_ip(v4(239, 255, 255, 255)));
    }

    #[test]
    fn test_is_private_ip_broadcast() {
        assert!(is_private_ip(v4(255, 255, 255, 255)));
    }

    #[test]
    fn test_is_private_ip_public() {
        assert!(!is_private_ip(v4(8, 8, 8, 8)));
        assert!(!is_private_ip(v4(1, 1, 1, 1)));
    }

    #[test]
    fn test_is_private_ip_ipv6_loopback() {
        assert!(is_private_ip(v6("::1")));
    }

    #[test]
    fn test_is_private_ip_ipv6_link_local() {
        assert!(is_private_ip(v6("fe80::1")));
    }

    #[test]
    fn test_is_private_ip_ipv6_unique_local() {
        assert!(is_private_ip(v6("fc00::1")));
    }

    // ---- artifacts_to_value ----

    #[test]
    fn test_artifacts_empty() {
        assert_eq!(artifacts_to_value(&[]), None);
    }

    #[test]
    fn test_artifacts_single_json_text() {
        let value = artifacts_to_value(&[one_part("text", Some(r#"{"result":42}"#))]).unwrap();
        assert_eq!(value["result"], 42);
    }

    #[test]
    fn test_artifacts_single_plain_text() {
        let value = artifacts_to_value(&[one_part("text", Some("hello"))]).unwrap();
        assert_eq!(value["text"], "hello");
    }

    #[test]
    fn test_artifacts_multiple() {
        let artifacts = [
            one_part("text", Some(r#"{"a":1}"#)),
            one_part("text", Some(r#"{"b":2}"#)),
        ];
        let value = artifacts_to_value(&artifacts).unwrap();
        let items = value
            .as_array()
            .expect("multiple text parts should yield an array");
        assert_eq!(items.len(), 2);
    }

    #[test]
    fn test_artifacts_non_text_ignored() {
        assert_eq!(artifacts_to_value(&[one_part("data", None)]), None);
    }

    // ---- resolve_auth ----

    #[test]
    fn test_resolve_auth_unset() {
        assert!(resolve_auth("no-such-agent-xyz-99999").is_none());
    }

    #[test]
    fn test_resolve_auth_set() {
        const VAR: &str = "BZZZ_A2A_TOKEN_MY_TEST_AGENT_001";
        std::env::set_var(VAR, "tok123");
        assert_eq!(resolve_auth("my-test-agent-001"), Some("tok123".into()));
        std::env::remove_var(VAR);
    }

    // ---- RuntimeAdapter ----

    #[test]
    fn test_a2a_runtime_kind() {
        let runtime = A2ARuntime::new();
        assert_eq!(runtime.kind(), RuntimeKind::Http);
    }

    #[tokio::test]
    async fn test_a2a_runtime_create() {
        let runtime = A2ARuntime::new();
        let spec = AgentSpec::new("test", RuntimeKind::Http);
        let ctx = runtime.create(&spec).await.unwrap();
        assert!(ctx.id.starts_with("a2a-"));
        assert_eq!(ctx.runtime_kind, RuntimeKind::Http);
    }

    #[tokio::test]
    async fn test_cancel_updates_local_status_when_remote_unreachable() {
        const KEY: &str = "test-cancel-unreachable";
        let runtime = A2ARuntime::new();
        let run_id = RunId::from_string(KEY.to_string());
        let handle =
            ExecutionHandle::new(run_id.clone(), RuntimeKind::Http, "a2a:test".to_string());

        runtime.executions.write().await.insert(
            KEY.to_string(),
            A2AExecution {
                run_id: run_id.clone(),
                status: RunStatus::Running,
                started_at: Instant::now(),
                output: None,
                error: None,
                agent_url: "https://nonexistent.invalid".to_string(),
                auth_token: None,
            },
        );

        assert!(runtime.cancel(&handle).await.is_ok(), "cancel should never fail");

        let execs = runtime.executions.read().await;
        assert_eq!(execs.get(KEY).unwrap().status, RunStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_returns_ok_for_missing_execution() {
        let runtime = A2ARuntime::new();
        let handle = ExecutionHandle::new(
            RunId::from_string("test-cancel-missing".to_string()),
            RuntimeKind::Http,
            "a2a:test".to_string(),
        );
        assert!(
            runtime.cancel(&handle).await.is_ok(),
            "cancel for missing execution should return Ok"
        );
    }
}