use std::{
sync::atomic::{AtomicU64, Ordering},
time::Duration,
};
use anyhow::{Context, Result, anyhow};
use axum::{
Json,
extract::{
Query, State,
ws::{Message as AxumWsMessage, WebSocket, WebSocketUpgrade},
},
http::HeaderMap,
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{debug, info, warn};
use uuid::Uuid;
use crate::{
a2a::{
auth::A2aIdentity,
relay_identity,
types::{AgentCard, JsonRpcRequest, JsonRpcResponse},
},
server::{AppState, constant_time_eq},
};
use rsclaw_config::runtime::{
A2aRelayModeRuntime, A2aRelayNodeRuntime, A2aRelayRuntime, A2aRelayStrategyRuntime,
};
/// Protocol identifier placed in `Hello` frames; the hub logs a warning when a
/// post-registration `Hello` announces a different value.
const RELAY_PROTOCOL: &str = "rsclaw.a2a.relay.v1";
/// TTL (milliseconds) a spoke puts in every route lease it sends; the spoke
/// renews its lease at one third of this interval.
const ROUTE_TTL_MS: u64 = 30_000;
/// How long `invoke_jsonrpc` waits for a node's response; also sent as
/// `deadline_ms` in request frames.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Lifetime cap of a relayed stream, enforced by `sweep_expired_streams`.
const STREAM_MAX_LIFETIME: Duration = Duration::from_secs(1800);
/// Maximum wait for each frame of the hub-side keypair handshake.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound for the spoke's reconnect backoff delay.
const FAILOVER_BACKOFF_MAX: Duration = Duration::from_secs(60);
/// Optional capability flags a node announces in its `Hello` frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelloCapabilities {
/// Set to `true` by spokes built from this file; the hub only logs the value.
pub streaming_relay: bool,
}
/// Frames exchanged as JSON text messages over the relay WebSocket.
///
/// Serialized with an internal `"type"` tag whose value is the snake_case
/// variant name (e.g. `"route_lease"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RelayFrame {
/// First frame a spoke sends after connecting.
Hello {
protocol: String,
node_id: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
node_version: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
agent_card: Option<AgentCard>,
#[serde(default, skip_serializing_if = "Option::is_none")]
capabilities: Option<HelloCapabilities>,
/// Spoke-generated nonce; required by the hub when the node is
/// configured with a public key (keypair handshake).
#[serde(default, skip_serializing_if = "Option::is_none")]
nonce_node: Option<String>,
},
/// Hub reply to a keypair-mode `Hello`, carrying the hub's own nonce.
Challenge {
relay_id: String,
nonce_relay: String,
},
/// Spoke answer to `Challenge`: the handshake signature.
Auth {
signature: String,
},
/// Advertisement of the `<node>/<agent>` refs a node serves, valid for
/// `ttl_ms` milliseconds.
RouteLease {
node_id: String,
agents: Vec<String>,
ttl_ms: u64,
epoch: u64,
},
/// JSON-RPC call forwarded from the hub to the node owning `target`.
Request {
request_id: String,
target: String,
method: String,
params: Value,
principal: String,
deadline_ms: u64,
},
/// Final JSON-RPC response for `request_id`.
Response {
request_id: String,
response: JsonRpcResponse,
},
/// One streamed result belonging to `request_id`.
Event {
request_id: String,
seq: u64,
result: Value,
},
/// Hub request to abort the work behind `request_id`.
Cancel {
request_id: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
task_id: Option<String>,
},
/// Application-level keepalive sent by the hub; the spoke echoes `ts` back
/// in a `Pong`.
Ping {
ts: u64,
},
/// Reply to `Ping`.
Pong {
ts: u64,
},
/// Error report tied to `request_id`. Neither frame dispatcher in this file
/// handles it specifically (it falls through to their catch-all arms).
Error {
request_id: String,
message: String,
},
}
/// Hub-side handle for one connected node.
#[derive(Debug, Clone)]
struct Connection {
/// Queue feeding the socket's writer task.
tx: mpsc::UnboundedSender<AxumWsMessage>,
/// Identifier of this particular socket; `unregister_connection` only
/// removes the entry when the epochs match.
epoch: u64,
}
/// Book-keeping for one in-flight relayed stream, keyed by request id.
#[derive(Debug, Clone)]
struct StreamPending {
/// Broadcast channel on which stream events are delivered to the caller.
tx: broadcast::Sender<Value>,
/// `<node>/<agent>` target the stream was opened against.
agent_ref: String,
/// Node the request was sent to.
node_id: String,
/// Instant after which `sweep_expired_streams` aborts the stream.
deadline: std::time::Instant,
}
/// A leased route from an agent ref to the node that advertised it.
#[derive(Debug, Clone)]
pub struct RouteEntry {
/// `<node>/<agent>` reference (also the key in the route table).
pub agent_ref: String,
/// Node that advertised the agent.
pub node_id: String,
/// Epoch carried by the lease that produced this entry.
pub epoch: u64,
/// Instant at which the lease stops being valid.
pub expires_at: std::time::Instant,
}
/// Monotonic relay counters, updated with relaxed atomics.
#[derive(Default, Debug)]
pub struct RelayMetrics {
/// Requests sent through `invoke_jsonrpc`.
pub request_count: AtomicU64,
/// Sum of `invoke_jsonrpc` wait times in milliseconds.
pub request_latency_ms_total: AtomicU64,
/// Spoke sessions that ended with an error.
pub ws_reconnects: AtomicU64,
/// Rejected connection attempts (unknown node, bad token, failed handshake).
pub auth_failures: AtomicU64,
/// Scope/ACL rejections.
pub acl_denials: AtomicU64,
/// Routes dropped by `route_for` because their lease had expired.
pub route_expirations: AtomicU64,
/// Times the spoke moved on to another hub URL.
pub failovers: AtomicU64,
/// Streams terminated by the hub (node disconnect or lifetime cap).
pub inflight_losses: AtomicU64,
}
impl RelayMetrics {
pub fn snapshot(&self, connected_nodes: u64, route_count: u64) -> serde_json::Value {
let req = self.request_count.load(Ordering::Relaxed);
let lat = self.request_latency_ms_total.load(Ordering::Relaxed);
let avg = if req > 0 { lat / req } else { 0 };
serde_json::json!({
"connected_nodes": connected_nodes,
"route_count": route_count,
"request_count": req,
"request_latency_ms_avg": avg,
"ws_reconnects": self.ws_reconnects.load(Ordering::Relaxed),
"auth_failures": self.auth_failures.load(Ordering::Relaxed),
"acl_denials": self.acl_denials.load(Ordering::Relaxed),
"route_expirations": self.route_expirations.load(Ordering::Relaxed),
"failovers": self.failovers.load(Ordering::Relaxed),
"inflight_losses": self.inflight_losses.load(Ordering::Relaxed),
})
}
}
/// Shared relay state: connected nodes, leased routes and in-flight requests.
#[derive(Default)]
pub struct RelayHub {
/// Connected nodes, keyed by node id.
connections: DashMap<String, Connection>,
/// Leased routes, keyed by `<node>/<agent>` ref.
routes: DashMap<String, RouteEntry>,
/// Unary requests awaiting a response: request id -> (reply channel, node id).
pending: DashMap<String, (oneshot::Sender<JsonRpcResponse>, String)>,
/// Streaming requests in flight, keyed by request id.
stream_pending: DashMap<String, StreamPending>,
/// Spoke side: relay request id -> local task id, consulted when a `Cancel`
/// frame arrives or the session ends.
spoke_stream_tasks: DashMap<String, String>,
/// Task id -> agent ref, used to route follow-up calls that only carry a
/// task id.
task_routes: DashMap<String, String>,
/// Relay counters.
pub metrics: RelayMetrics,
}
/// Emits one structured audit record on the `a2a.audit` tracing target.
///
/// Every argument becomes a field of the event; the optional
/// `matched_scope` and `reason` are logged as empty strings when absent.
pub fn audit_relay(
    decision: &str,
    principal: &str,
    action: &str,
    resource: &str,
    relay_id: &str,
    node_id: &str,
    matched_scope: Option<&str>,
    reason: Option<&str>,
) {
    // Resolve the optional fields up front so the event uses shorthand fields.
    let matched_scope = matched_scope.unwrap_or_default();
    let reason = reason.unwrap_or_default();
    tracing::info!(
        target: "a2a.audit",
        decision,
        principal,
        action,
        resource,
        relay_id,
        node_id,
        matched_scope,
        reason,
        "a2a relay audit"
    );
}
impl RelayHub {
pub fn new() -> Self {
Self::default()
}
/// Number of nodes currently registered in the connection table.
pub fn connection_count(&self) -> usize {
self.connections.len()
}
/// Returns the ids of all registered nodes in ascending order.
pub fn connected_nodes(&self) -> Vec<String> {
    let mut ids = Vec::with_capacity(self.connections.len());
    for conn in self.connections.iter() {
        ids.push(conn.key().clone());
    }
    // Map keys are unique, so an unstable sort gives the same ordering.
    ids.sort_unstable();
    ids
}
/// Looks up the live route for `agent_ref`.
///
/// Returns a copy of the entry while its lease is valid. An expired entry is
/// evicted, counted in `route_expirations`, and `None` is returned.
///
/// The eviction is conditional: the entry is only removed if it is *still*
/// expired when the removal runs. This prevents a lease that was renewed
/// between the read and the removal from being deleted (and from being
/// miscounted as an expiration).
pub fn route_for(&self, agent_ref: &str) -> Option<RouteEntry> {
    let now = std::time::Instant::now();
    {
        // Scope the read guard so it is released before `remove_if` locks
        // the same shard.
        let entry = self.routes.get(agent_ref)?;
        if entry.expires_at > now {
            return Some(entry.value().clone());
        }
    }
    let evicted = self
        .routes
        .remove_if(agent_ref, |_, route| route.expires_at <= now)
        .is_some();
    if evicted {
        self.metrics
            .route_expirations
            .fetch_add(1, Ordering::Relaxed);
    }
    None
}
/// Number of entries in the route table. Expired leases that have not been
/// looked up since expiring are still counted, because eviction only happens
/// in `route_for`.
pub fn route_count(&self) -> usize {
self.routes.len()
}
pub fn apply_route_lease(
&self,
node_id: &str,
agents: &[String],
ttl_ms: u64,
epoch: u64,
) -> Result<()> {
let ttl = Duration::from_millis(ttl_ms.max(1));
let expires_at = std::time::Instant::now() + ttl;
for agent_ref in agents {
validate_agent_ref(agent_ref)?;
if !agent_ref.starts_with(&format!("{node_id}/")) {
anyhow::bail!("node '{node_id}' cannot advertise '{agent_ref}'");
}
if let Some(existing) = self.routes.get(agent_ref)
&& existing.epoch > epoch
&& existing.expires_at > std::time::Instant::now()
{
continue;
}
self.routes.insert(
agent_ref.clone(),
RouteEntry {
agent_ref: agent_ref.clone(),
node_id: node_id.to_owned(),
epoch,
expires_at,
},
);
}
Ok(())
}
/// Forwards a unary JSON-RPC call to the node that owns `target` and waits
/// for its response.
///
/// The request frame is serialised *before* the pending entry is registered,
/// so a serialisation failure cannot leave a stale entry behind. The node's
/// sender is cloned out of the connection table so no map guard is held while
/// waiting.
///
/// `request_count` is incremented once the frame has been queued, and the
/// time spent waiting is added to `request_latency_ms_total` whether the call
/// succeeds or not.
///
/// # Errors
/// Fails when no live route exists, the node is not connected, the frame
/// cannot be serialised or queued, the reply channel closes, or no response
/// arrives within `REQUEST_TIMEOUT`.
pub async fn invoke_jsonrpc(
    &self,
    target: &str,
    method: &str,
    params: Value,
    principal: &str,
) -> Result<JsonRpcResponse> {
    let route = self
        .route_for(target)
        .ok_or_else(|| anyhow!("no live relay route for {target}"))?;
    let node_id = route.node_id;
    let sender = self
        .connections
        .get(&node_id)
        .map(|conn| conn.tx.clone())
        .ok_or_else(|| anyhow!("node '{}' is not connected", node_id))?;

    let request_id = format!("relay:{}", Uuid::new_v4());
    let frame = RelayFrame::Request {
        request_id: request_id.clone(),
        target: target.to_owned(),
        method: method.to_owned(),
        params,
        principal: principal.to_owned(),
        deadline_ms: REQUEST_TIMEOUT.as_millis() as u64,
    };
    // Serialise first: nothing has been registered yet, so `?` is leak-free.
    let payload = serde_json::to_string(&frame)?;

    let (tx, rx) = oneshot::channel();
    self.pending
        .insert(request_id.clone(), (tx, node_id.clone()));
    if let Err(e) = sender.send(AxumWsMessage::Text(payload.into())) {
        self.pending.remove(&request_id);
        anyhow::bail!("relay send to node '{}' failed: {e}", node_id);
    }
    // Do not keep the node's outbound queue alive while waiting.
    drop(sender);

    let started = std::time::Instant::now();
    self.metrics.request_count.fetch_add(1, Ordering::Relaxed);
    let result = match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
        Ok(Ok(response)) => Ok(response),
        Ok(Err(_)) => Err(anyhow!("relay response channel closed")),
        Err(_) => {
            self.pending.remove(&request_id);
            Err(anyhow!("relay request timed out"))
        }
    };
    let elapsed_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
    self.metrics
        .request_latency_ms_total
        .fetch_add(elapsed_ms, Ordering::Relaxed);
    result
}
/// Registers (or replaces) the connection for `node_id`.
///
/// `tx` is the queue feeding the socket's writer task and `epoch` identifies
/// this particular socket. An existing entry for the same node id is
/// overwritten.
fn register_connection(
&self,
node_id: &str,
tx: mpsc::UnboundedSender<AxumWsMessage>,
epoch: u64,
) {
self.connections
.insert(node_id.to_owned(), Connection { tx, epoch });
}
/// Tears down the state owned by a node's connection.
///
/// Does nothing when the node is currently registered under a different
/// `epoch` (i.e. a newer socket replaced this one). Otherwise, in order:
/// 1. removes the connection entry;
/// 2. removes every route whose key starts with `"<node_id>/"`;
/// 3. for every stream targeting that prefix, forwards a synthetic final
///    `failed` status-update, removes the stream and bumps `inflight_losses`;
/// 4. completes every unary request pending on the node with a `-32004`
///    JSON-RPC error.
fn unregister_connection(&self, node_id: &str, epoch: u64) {
// A mismatching epoch means the registered connection is not ours.
if let Some(conn) = self.connections.get(node_id)
&& conn.epoch != epoch
{
return;
}
self.connections.remove(node_id);
let prefix = format!("{node_id}/");
// Collect keys first so no map iterator is alive while removing.
let stale: Vec<String> = self
.routes
.iter()
.filter(|entry| entry.key().starts_with(&prefix))
.map(|entry| entry.key().clone())
.collect();
for key in stale {
self.routes.remove(&key);
}
let lost: Vec<String> = self
.stream_pending
.iter()
.filter(|entry| entry.value().agent_ref.starts_with(&prefix))
.map(|entry| entry.key().clone())
.collect();
for request_id in lost {
// Tell stream subscribers the stream is over before dropping it.
let synthetic = serde_json::json!({
"kind": "status-update",
"taskId": "",
"contextId": "",
"status": {
"state": "failed",
"message": {
"role": "agent",
"messageId": format!("relay-loss-{}", Uuid::new_v4()),
"parts": [{
"kind": "text",
"text": format!("relay route lost: node '{node_id}' disconnected"),
}],
}
},
"final": true,
});
self.forward_stream_event(&request_id, synthetic);
self.stream_pending.remove(&request_id);
self.metrics.inflight_losses.fetch_add(1, Ordering::Relaxed);
}
let pending_keys: Vec<String> = self
.pending
.iter()
.filter_map(|e| (e.value().1 == node_id).then(|| e.key().clone()))
.collect();
for k in pending_keys {
if let Some((_, (tx, _))) = self.pending.remove(&k) {
// The waiter may already be gone; a failed send is ignored.
let _ = tx.send(JsonRpcResponse::err(
Value::Null,
-32004,
format!("relay node '{node_id}' disconnected"),
));
}
}
}
/// Delivers `response` to the caller waiting on `request_id`.
///
/// Unknown request ids are ignored, as is a waiter that has already gone
/// away.
fn complete_pending(&self, request_id: &str, response: JsonRpcResponse) {
    let Some((_, (reply_tx, _))) = self.pending.remove(request_id) else {
        return;
    };
    let _ = reply_tx.send(response);
}
/// Opens a relayed stream towards the node that owns `target`.
///
/// Returns `(request_id, node_id, receiver)`. Stream events forwarded by the
/// hub arrive on `receiver`. The stream is registered with a deadline of
/// `STREAM_MAX_LIFETIME` from now.
///
/// The request frame is serialised *before* the stream is registered, so a
/// serialisation failure cannot leave a stale `stream_pending` entry behind.
///
/// # Errors
/// Fails when no live route exists, the node is not connected, or the frame
/// cannot be serialised or queued.
pub async fn invoke_streaming(
    &self,
    target: &str,
    method: &str,
    params: Value,
    principal: &str,
) -> Result<(String, String, broadcast::Receiver<Value>)> {
    let route = self
        .route_for(target)
        .ok_or_else(|| anyhow!("no live relay route for {target}"))?;
    let sender = self
        .connections
        .get(&route.node_id)
        .map(|conn| conn.tx.clone())
        .ok_or_else(|| anyhow!("node '{}' is not connected", route.node_id))?;

    let request_id = format!("relay:stream:{}", Uuid::new_v4());
    let frame = RelayFrame::Request {
        request_id: request_id.clone(),
        target: target.to_owned(),
        method: method.to_owned(),
        params,
        principal: principal.to_owned(),
        deadline_ms: REQUEST_TIMEOUT.as_millis() as u64,
    };
    // Serialise first: nothing has been registered yet, so `?` is leak-free.
    let payload = serde_json::to_string(&frame)?;

    let (event_tx, event_rx) = broadcast::channel(128);
    self.stream_pending.insert(
        request_id.clone(),
        StreamPending {
            tx: event_tx,
            agent_ref: target.to_owned(),
            node_id: route.node_id.clone(),
            deadline: std::time::Instant::now() + STREAM_MAX_LIFETIME,
        },
    );
    if let Err(e) = sender.send(AxumWsMessage::Text(payload.into())) {
        self.stream_pending.remove(&request_id);
        anyhow::bail!("relay send to node '{}' failed: {e}", route.node_id);
    }
    Ok((request_id, route.node_id, event_rx))
}
/// Best-effort: queues a `Cancel` frame for `request_id` on `node_id`'s
/// connection.
///
/// Nothing happens if the node is not connected; serialisation and queueing
/// failures are ignored.
fn send_cancel_to(&self, node_id: &str, request_id: &str) {
    let cancel = RelayFrame::Cancel {
        request_id: request_id.to_owned(),
        task_id: None,
    };
    let Ok(payload) = serde_json::to_string(&cancel) else {
        return;
    };
    if let Some(conn) = self.connections.get(node_id) {
        let _ = conn.tx.send(AxumWsMessage::Text(payload.into()));
    }
}
/// Removes the stream registered under `request_id`.
///
/// Returns `true` when a stream was registered and has now been removed,
/// `false` when the id was unknown.
fn complete_streaming(&self, request_id: &str) -> bool {
self.stream_pending.remove(request_id).is_some()
}
fn forward_stream_event(&self, request_id: &str, value: Value) -> usize {
let Some(entry) = self.stream_pending.get(request_id) else {
return 0;
};
if let Some(task_id) = value.get("taskId").and_then(|v| v.as_str()) {
self.task_routes
.insert(task_id.to_owned(), entry.agent_ref.clone());
}
entry.tx.send(value).unwrap_or(0)
}
/// Aborts every stream whose deadline has passed.
///
/// For each expired stream: forwards a synthetic final `failed`
/// status-update to its subscribers, removes it, sends a `Cancel` frame to the
/// node, bumps `inflight_losses` and logs a warning.
///
/// Returns the number of streams that were found expired.
pub fn sweep_expired_streams(&self) -> usize {
let now = std::time::Instant::now();
// Snapshot the expired ids first so no map iterator is alive while removing.
let expired: Vec<(String, String)> = self
.stream_pending
.iter()
.filter(|e| e.value().deadline <= now)
.map(|e| (e.key().clone(), e.value().node_id.clone()))
.collect();
for (request_id, node_id) in &expired {
let synthetic = serde_json::json!({
"kind": "status-update",
"taskId": "",
"contextId": "",
"status": {
"state": "failed",
"message": {
"role": "agent",
"messageId": format!("relay-deadline-{}", Uuid::new_v4()),
"parts": [{
"kind": "text",
"text": format!(
"relay stream exceeded {}s lifetime cap; aborting",
STREAM_MAX_LIFETIME.as_secs()
),
}],
}
},
"final": true,
});
// Notify subscribers before the stream entry disappears.
self.forward_stream_event(request_id, synthetic);
self.stream_pending.remove(request_id);
self.send_cancel_to(node_id, request_id);
self.metrics.inflight_losses.fetch_add(1, Ordering::Relaxed);
warn!(
request_id = %request_id,
node_id = %node_id,
"relay stream hit deadline — synthetic failure emitted"
);
}
expired.len()
}
/// Remembers that `task_id` is served by `agent_ref`, replacing any earlier
/// mapping for the same task.
pub fn record_task_route(&self, task_id: &str, agent_ref: &str) {
    let key = String::from(task_id);
    let route = String::from(agent_ref);
    self.task_routes.insert(key, route);
}
/// Returns the agent ref recorded for `task_id`, if any.
pub fn route_for_task(&self, task_id: &str) -> Option<String> {
    let recorded = self.task_routes.get(task_id)?;
    Some(recorded.value().clone())
}
}
/// Guard tied to one relayed stream; dropping it releases the stream on the
/// hub and, if the stream was still registered, asks the node to cancel it.
pub struct RelayStreamGuard {
/// Hub that owns the stream registration.
relay_hub: std::sync::Arc<RelayHub>,
/// Node the stream request was sent to.
node_id: String,
/// Relay request id of the stream.
request_id: String,
}
impl RelayStreamGuard {
/// Builds a guard for the stream `request_id` that was sent to `node_id`
/// through `relay_hub`.
pub fn new(relay_hub: std::sync::Arc<RelayHub>, node_id: String, request_id: String) -> Self {
Self {
relay_hub,
node_id,
request_id,
}
}
}
impl Drop for RelayStreamGuard {
    /// Unregisters the stream. A `Cancel` frame is sent to the node only when
    /// the stream was still registered, i.e. it had not already completed.
    fn drop(&mut self) {
        let was_registered = self.relay_hub.complete_streaming(&self.request_id);
        if !was_registered {
            return;
        }
        self.relay_hub
            .send_cancel_to(&self.node_id, &self.request_id);
    }
}
/// Checks that `agent_ref` has the shape `<node>/<agent>`.
///
/// # Errors
/// Fails when there is no `/`, when either side of the first `/` is empty, or
/// when the agent part contains a further `/`.
pub fn validate_agent_ref(agent_ref: &str) -> Result<()> {
    match agent_ref.split_once('/') {
        None => anyhow::bail!("agent_ref must be '<node>/<agent>'"),
        Some((node, agent)) => {
            let well_formed = !node.is_empty() && !agent.is_empty() && !agent.contains('/');
            if !well_formed {
                anyhow::bail!("invalid agent_ref '{agent_ref}'");
            }
            Ok(())
        }
    }
}
/// Extracts the agent name from `agent_ref` when it belongs to `node_id`.
///
/// Returns `None` unless the ref is exactly `<node_id>/<agent>` with a
/// non-empty agent part that contains no further `/`.
pub fn local_agent_from_ref(agent_ref: &str, node_id: &str) -> Option<String> {
    match agent_ref.split_once('/') {
        Some((owner, name)) if owner == node_id && !name.is_empty() && !name.contains('/') => {
            Some(name.to_owned())
        }
        _ => None,
    }
}
/// Tells whether any scope in `scopes` grants `namespace:action` on `target`.
///
/// A scope matches when it is:
/// * exactly `namespace:action:target`;
/// * the action-wide wildcard `namespace:action:*`; or
/// * a path wildcard `<prefix>/*` where `namespace:action:target` starts with
///   `<prefix>/`.
pub fn scope_allows(scopes: &[String], namespace: &str, action: &str, target: &str) -> bool {
    let wanted = format!("{namespace}:{action}:{target}");
    let wildcard = format!("{namespace}:{action}:*");
    for scope in scopes {
        if *scope == wanted || *scope == wildcard {
            return true;
        }
        // `<prefix>/*` covers everything below `<prefix>/`.
        let covered_by_path = scope.strip_suffix("/*").is_some_and(|prefix| {
            wanted
                .strip_prefix(prefix)
                .is_some_and(|rest| rest.starts_with('/'))
        });
        if covered_by_path {
            return true;
        }
    }
    false
}
/// Decides whether `identity` may invoke `target`.
///
/// A missing identity and the identity whose id is `"gateway-auth"` are
/// always allowed; any other identity needs a scope granting `a2a:invoke` on
/// `target`.
pub fn can_invoke(identity: Option<&A2aIdentity>, target: &str) -> bool {
    let Some(id) = identity else {
        return true;
    };
    id.id == "gateway-auth" || scope_allows(&id.scopes, "a2a", "invoke", target)
}
/// Scopes given to a node that has none configured: connect to this relay,
/// and advertise / receive for agents under its own node id.
fn default_node_scopes(node_id: &str, relay_id: &str) -> Vec<String> {
    let mut scopes = Vec::with_capacity(3);
    scopes.push(format!("relay:connect:{relay_id}"));
    scopes.push(format!("relay:advertise:{node_id}/*"));
    scopes.push(format!("relay:receive:{node_id}/*"));
    scopes
}
/// Returns a copy of the configured node `node_id`, or `None` when the node
/// is listed as revoked or is not configured at all.
fn resolve_node(relay: &A2aRelayRuntime, node_id: &str) -> Option<A2aRelayNodeRuntime> {
    let revoked = relay.revoked_nodes.iter().any(|n| n == node_id);
    if revoked {
        return None;
    }
    relay
        .nodes
        .iter()
        .find(|node| node.node_id == node_id)
        .cloned()
}
/// Compares `token` with the node's configured token via `constant_time_eq`.
/// A node without a configured token never matches.
fn verify_node_token(node: &A2aRelayNodeRuntime, token: &str) -> bool {
    if node.token.is_empty() {
        return false;
    }
    constant_time_eq(&node.token, token)
}
/// Pre-upgrade token gate for a connecting node.
///
/// * No token configured: allowed only when the node has a public key.
/// * Token configured: the presented token must match it.
fn relay_connect_token_allows(node: &A2aRelayNodeRuntime, presented: Option<&str>) -> bool {
    match (node.token.is_empty(), presented) {
        (true, _) => node.public_key.is_some(),
        (false, Some(token)) => verify_node_token(node, token),
        (false, None) => false,
    }
}
/// Query-string parameters of the relay WebSocket endpoint.
#[derive(Debug, Deserialize)]
pub struct RelayWsQuery {
/// Id of the node that wants to connect.
node_id: String,
/// Optional shared token; `relay_ws_handler` falls back to the
/// `Authorization: Bearer` header when this is absent.
#[serde(default)]
token: Option<String>,
}
/// HTTP entry point for nodes connecting to the hub over WebSocket.
///
/// Checks, in order:
/// 1. the gateway runs in hub mode, otherwise `404`;
/// 2. the node is configured and not revoked, otherwise `401`;
/// 3. the token gate passes (query `token`, else bearer header), otherwise
///    `401`;
/// 4. the node's scopes (defaults applied when none are configured) grant
///    `relay:connect` on this relay, otherwise `403`.
///
/// Every denial bumps a metric and writes an audit record. On success the
/// connection is upgraded and handed to `handle_hub_socket`.
pub async fn relay_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(query): Query<RelayWsQuery>,
headers: HeaderMap,
) -> Response {
let relay = &state.config.gateway.a2a_relay;
if relay.mode != A2aRelayModeRuntime::Hub {
return axum::http::StatusCode::NOT_FOUND.into_response();
}
let Some(mut node) = resolve_node(relay, &query.node_id) else {
state
.relay_hub
.metrics
.auth_failures
.fetch_add(1, Ordering::Relaxed);
audit_relay(
"deny",
&format!("node:{}", query.node_id),
"connect",
&format!("relay:{}", relay.relay_id),
&relay.relay_id,
&query.node_id,
None,
Some("unknown or revoked node"),
);
return axum::http::StatusCode::UNAUTHORIZED.into_response();
};
// The query parameter takes precedence over the Authorization header.
let presented = query.token.as_deref().or_else(|| bearer_token(&headers));
if !relay_connect_token_allows(&node, presented) {
state
.relay_hub
.metrics
.auth_failures
.fetch_add(1, Ordering::Relaxed);
let reason = if node.token.is_empty() {
"no token configured; keypair handshake required"
} else if presented.is_none() {
"no token presented"
} else {
"token mismatch"
};
audit_relay(
"deny",
&format!("node:{}", node.node_id),
"connect",
&format!("relay:{}", relay.relay_id),
&relay.relay_id,
&node.node_id,
None,
Some(reason),
);
return axum::http::StatusCode::UNAUTHORIZED.into_response();
}
// Nodes configured without scopes get the default set for their own id.
if node.scopes.is_empty() {
node.scopes = default_node_scopes(&node.node_id, &relay.relay_id);
}
if !scope_allows(&node.scopes, "relay", "connect", &relay.relay_id) {
state
.relay_hub
.metrics
.acl_denials
.fetch_add(1, Ordering::Relaxed);
audit_relay(
"deny",
&format!("node:{}", node.node_id),
"connect",
&format!("relay:{}", relay.relay_id),
&relay.relay_id,
&node.node_id,
None,
Some("relay:connect scope missing"),
);
return axum::http::StatusCode::FORBIDDEN.into_response();
}
ws.on_upgrade(move |socket| handle_hub_socket(socket, state, node))
}
/// Returns the credential from an `Authorization: Bearer <token>` header.
///
/// Yields `None` when the header is missing, cannot be read as a string, or
/// does not start with the exact prefix `"Bearer "`.
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
    let raw = headers.get(axum::http::header::AUTHORIZATION)?;
    let text = raw.to_str().ok()?;
    text.strip_prefix("Bearer ")
}
/// Hub side of the keypair handshake.
///
/// Sequence: read `Hello` (must be a text frame, carry `nonce_node`, and claim
/// the expected node id) -> send `Challenge` with a fresh relay nonce -> read
/// `Auth` -> verify the signature with `relay_identity::verify_handshake`
/// against `public_key_b64`.
///
/// Each read is bounded by `HANDSHAKE_TIMEOUT`.
///
/// # Errors
/// Returns a human-readable reason for any protocol violation, timeout,
/// socket error or verification failure.
async fn hub_keypair_handshake<S, R>(
sink: &mut S,
stream: &mut R,
node: &A2aRelayNodeRuntime,
public_key_b64: &str,
relay_id: &str,
) -> std::result::Result<(), String>
where
S: futures::Sink<AxumWsMessage> + Unpin,
R: futures::Stream<Item = std::result::Result<AxumWsMessage, axum::Error>> + Unpin,
{
// Step 1: the first frame must be a text `Hello`.
let hello = match tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.next()).await {
Ok(Some(Ok(AxumWsMessage::Text(text)))) => text,
Ok(Some(Ok(_))) => return Err("first frame was not Text".to_owned()),
Ok(Some(Err(e))) => return Err(format!("ws error: {e}")),
Ok(None) => return Err("stream closed before Hello".to_owned()),
Err(_) => return Err("Hello timed out".to_owned()),
};
let nonce_node = match serde_json::from_str::<RelayFrame>(&hello) {
Ok(RelayFrame::Hello {
nonce_node: Some(n),
node_id,
..
}) => {
// The claimed id must match the node authenticated at upgrade time.
if node_id != node.node_id {
return Err(format!("hello node_id mismatch: claimed {node_id}"));
}
n
}
Ok(RelayFrame::Hello {
nonce_node: None, ..
}) => {
return Err("hello missing nonce_node (keypair mode required)".to_owned());
}
Ok(_) => return Err("first frame was not Hello".to_owned()),
Err(e) => return Err(format!("invalid Hello frame: {e}")),
};
// Step 2: challenge the node with a fresh relay nonce.
let nonce_relay = relay_identity::fresh_nonce_b64();
let challenge = RelayFrame::Challenge {
relay_id: relay_id.to_owned(),
nonce_relay: nonce_relay.clone(),
};
let payload =
serde_json::to_string(&challenge).map_err(|e| format!("serialize Challenge: {e}"))?;
if sink
.send(AxumWsMessage::Text(payload.into()))
.await
.is_err()
{
return Err("send Challenge failed".to_owned());
}
// Step 3: the second frame must be a text `Auth`.
let auth_text = match tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.next()).await {
Ok(Some(Ok(AxumWsMessage::Text(text)))) => text,
Ok(Some(Ok(_))) => return Err("second frame was not Text".to_owned()),
Ok(Some(Err(e))) => return Err(format!("ws error: {e}")),
Ok(None) => return Err("stream closed before Auth".to_owned()),
Err(_) => return Err("Auth timed out".to_owned()),
};
let signature = match serde_json::from_str::<RelayFrame>(&auth_text) {
Ok(RelayFrame::Auth { signature }) => signature,
Ok(_) => return Err("second frame was not Auth".to_owned()),
Err(e) => return Err(format!("invalid Auth frame: {e}")),
};
// Step 4: verify the signature over the node id, relay id and both nonces.
relay_identity::verify_handshake(
public_key_b64,
&node.node_id,
relay_id,
&nonce_node,
&nonce_relay,
&signature,
)
.map_err(|e| e.to_string())
}
/// Drives one upgraded node connection on the hub until it closes.
///
/// * When the node has a public key, runs the keypair handshake first and
///   closes the socket if it fails.
/// * Registers the connection under an epoch derived from the wall clock
///   (milliseconds since the Unix epoch).
/// * Spawns a writer task draining the outbound queue, and a ping task that
///   every 15 s queues a WebSocket ping plus a `RelayFrame::Ping`.
/// * Dispatches every inbound text frame to `handle_hub_frame`; non-text
///   frames are skipped and unparsable ones are logged.
/// * On exit unregisters the connection and aborts both helper tasks.
async fn handle_hub_socket(socket: WebSocket, state: AppState, node: A2aRelayNodeRuntime) {
let epoch = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis().min(u64::MAX as u128) as u64)
.unwrap_or(0);
let (mut sink, mut stream) = socket.split();
if let Some(public_key_b64) = node.public_key.as_deref() {
let relay_id = state.config.gateway.a2a_relay.relay_id.clone();
match hub_keypair_handshake(&mut sink, &mut stream, &node, public_key_b64, &relay_id).await
{
Ok(()) => {
audit_relay(
"allow",
&format!("node:{}", node.node_id),
"connect",
&format!("relay:{}", relay_id),
&relay_id,
&node.node_id,
Some("ed25519_handshake"),
None,
);
}
Err(reason) => {
state
.relay_hub
.metrics
.auth_failures
.fetch_add(1, Ordering::Relaxed);
audit_relay(
"deny",
&format!("node:{}", node.node_id),
"connect",
&format!("relay:{}", relay_id),
&relay_id,
&node.node_id,
None,
Some(&format!("keypair handshake failed: {reason}")),
);
// Best-effort close; the node is never registered.
let _ = sink.send(AxumWsMessage::Close(None)).await;
return;
}
}
}
let (tx, mut rx) = mpsc::unbounded_channel::<AxumWsMessage>();
let ping_tx = tx.clone();
state
.relay_hub
.register_connection(&node.node_id, tx, epoch);
info!(node = %node.node_id, "a2a relay node connected");
// Writer: the only task that touches the sink after registration.
let writer = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sink.send(msg).await.is_err() {
break;
}
}
});
// Keepalive: a WebSocket ping and an application-level Ping every 15 s.
let ping = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
interval.tick().await;
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
if ping_tx
.send(AxumWsMessage::Ping(Vec::new().into()))
.is_err()
{
break;
}
let frame = RelayFrame::Ping { ts };
if let Ok(msg) = serde_json::to_string(&frame) {
if ping_tx.send(AxumWsMessage::Text(msg.into())).is_err() {
break;
}
}
}
});
while let Some(msg) = stream.next().await {
let Ok(msg) = msg else {
break;
};
// Only text frames carry relay frames; everything else is skipped.
let AxumWsMessage::Text(text) = msg else {
continue;
};
match serde_json::from_str::<RelayFrame>(&text) {
Ok(frame) => handle_hub_frame(&state, &node, frame).await,
Err(e) => warn!(node = %node.node_id, error = %e, "invalid a2a relay frame"),
}
}
state.relay_hub.unregister_connection(&node.node_id, epoch);
writer.abort();
ping.abort();
info!(node = %node.node_id, "a2a relay node disconnected");
}
/// Handles one frame received by the hub from a registered node.
///
/// * `Hello`: logs a warning on protocol/node-id mismatch and logs announced
///   capabilities.
/// * `RouteLease`: rejected (metric + audit) when it names another node or
///   when any advertised agent lacks the `relay:advertise` scope; otherwise
///   applied via `apply_route_lease`.
/// * `Auth` / `Challenge`: ignored after registration.
/// * `Response`: ends the matching stream if one exists, otherwise completes
///   the matching unary request.
/// * `Event`: forwarded to the stream's subscribers.
/// * `Pong`: ignored. Anything else is logged at debug level.
async fn handle_hub_frame(state: &AppState, node: &A2aRelayNodeRuntime, frame: RelayFrame) {
match frame {
RelayFrame::Hello {
protocol,
node_id,
capabilities,
..
} => {
if protocol != RELAY_PROTOCOL || node_id != node.node_id {
warn!(node = %node.node_id, protocol, claimed = %node_id, "relay hello mismatch");
}
if let Some(caps) = capabilities {
info!(
node = %node.node_id,
streaming_relay = caps.streaming_relay,
"relay node capabilities"
);
}
}
RelayFrame::RouteLease {
node_id,
agents,
ttl_ms,
epoch,
} => {
let relay_id = state.config.gateway.a2a_relay.relay_id.as_str();
// A node may only lease routes under its own authenticated id.
if node_id != node.node_id {
warn!(node = %node.node_id, claimed = %node_id, "relay route lease node mismatch");
state
.relay_hub
.metrics
.acl_denials
.fetch_add(1, Ordering::Relaxed);
audit_relay(
"deny",
&format!("node:{}", node.node_id),
"advertise",
&format!("node:{node_id}"),
relay_id,
&node.node_id,
None,
Some("route lease node mismatch"),
);
return;
}
// One unauthorised agent rejects the whole lease.
for agent in &agents {
if !scope_allows(&node.scopes, "relay", "advertise", agent) {
warn!(node = %node.node_id, agent, "relay advertise denied");
state
.relay_hub
.metrics
.acl_denials
.fetch_add(1, Ordering::Relaxed);
audit_relay(
"deny",
&format!("node:{}", node.node_id),
"advertise",
&format!("agent:{agent}"),
relay_id,
&node.node_id,
None,
Some("relay:advertise scope missing"),
);
return;
}
}
if let Err(e) = state
.relay_hub
.apply_route_lease(&node.node_id, &agents, ttl_ms, epoch)
{
warn!(node = %node.node_id, error = %e, "relay route lease rejected");
}
}
RelayFrame::Auth { .. } | RelayFrame::Challenge { .. } => {
debug!(node = %node.node_id, "handshake frame after registration; ignored");
}
RelayFrame::Response {
request_id,
response,
} => {
// A Response for a streaming request only terminates the stream.
if !state.relay_hub.complete_streaming(&request_id) {
state.relay_hub.complete_pending(&request_id, response);
}
}
RelayFrame::Event {
request_id, result, ..
} => {
if state.relay_hub.forward_stream_event(&request_id, result) == 0 {
debug!(request_id, "relay event for unknown stream");
}
}
RelayFrame::Pong { .. } => {}
other => debug!(node = %node.node_id, frame = ?other, "hub ignored relay frame"),
}
}
/// Reads `params.metadata.agentId` and returns it when it is a string
/// containing a `/` (i.e. it looks like a `<node>/<agent>` ref).
pub fn relay_target_from_params(params: &Value) -> Option<String> {
    let agent_id = params.get("metadata")?.get("agentId")?.as_str()?;
    agent_id.contains('/').then(|| agent_id.to_owned())
}
/// JSON-RPC methods that `relay_target_from_request` considers for
/// forwarding; any other method is never relayed.
const FORWARDABLE_METHODS: &[&str] = &[
"SendMessage",
"GetTask",
"CancelTask",
"CreateTaskPushNotificationConfig",
"GetTaskPushNotificationConfig",
"ListTaskPushNotificationConfigs",
"DeleteTaskPushNotificationConfig",
];
pub fn task_id_from_params(params: &Value) -> Option<&str> {
params
.get("id")
.and_then(|v| v.as_str())
.or_else(|| params.get("taskId").and_then(|v| v.as_str()))
}
/// Works out which remote agent, if any, `req` should be relayed to.
///
/// Only methods in `FORWARDABLE_METHODS` are considered. An explicit
/// `<node>/<agent>` in the params wins; otherwise the task id in the params is
/// looked up in the hub's task routes.
pub fn relay_target_from_request(hub: &RelayHub, req: &JsonRpcRequest) -> Option<String> {
    let forwardable = FORWARDABLE_METHODS
        .iter()
        .any(|method| *method == req.method.as_str());
    if !forwardable {
        return None;
    }
    relay_target_from_params(&req.params).or_else(|| {
        let task_id = task_id_from_params(&req.params)?;
        hub.route_for_task(task_id)
    })
}
/// Relays `req` to a remote agent when it targets one.
///
/// Returns `None` when the request is not relay-bound or no live route exists
/// for its target, leaving the caller to handle it locally. Otherwise returns
/// the response to send back:
/// * a `-32003` error when `caller` may not invoke the target (audited as
///   `deny`);
/// * the node's response, with its id rewritten to `req.id`, on success. If
///   the result carries a string `id`, it is recorded as a task route;
/// * a `-32004` error when relaying fails.
pub async fn try_forward_jsonrpc(
    state: &AppState,
    caller: Option<&A2aIdentity>,
    req: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    let target = relay_target_from_request(&state.relay_hub, req)?;
    // Bail out (handle locally) unless a live route exists.
    state.relay_hub.route_for(&target)?;

    let relay_id = state.config.gateway.a2a_relay.relay_id.as_str();
    let principal_id = match caller {
        Some(identity) => identity.id.as_str(),
        None => "anonymous-dev",
    };
    let target_node = target.split('/').next().unwrap_or("");
    let resource = format!("agent:{target}");

    if !can_invoke(caller, &target) {
        state
            .relay_hub
            .metrics
            .acl_denials
            .fetch_add(1, Ordering::Relaxed);
        audit_relay(
            "deny",
            principal_id,
            "invoke",
            &resource,
            relay_id,
            target_node,
            None,
            Some("a2a:invoke scope missing"),
        );
        let denial = JsonRpcResponse::err(
            req.id.clone(),
            -32003,
            format!("not authorized to invoke {target}"),
        );
        return Some(denial);
    }

    audit_relay(
        "allow",
        principal_id,
        "invoke",
        &resource,
        relay_id,
        target_node,
        None,
        Some("cross_node"),
    );

    // The spoke expects a bare agent id, not the `<node>/<agent>` ref.
    let mut forwarded_params = req.params.clone();
    rewrite_target_agent_for_spoke(&mut forwarded_params, &target);

    let outcome = state
        .relay_hub
        .invoke_jsonrpc(&target, &req.method, forwarded_params, principal_id)
        .await;
    let mut response = match outcome {
        Ok(response) => response,
        Err(e) => return Some(JsonRpcResponse::err(req.id.clone(), -32004, e.to_string())),
    };
    response.id = req.id.clone();
    let created_task = response
        .result
        .as_ref()
        .and_then(|r| r.get("id"))
        .and_then(|v| v.as_str());
    if let Some(task_id) = created_task {
        state.relay_hub.record_task_route(task_id, &target);
    }
    Some(response)
}
/// Replaces `params.metadata.agentId` with the bare agent part of `target`.
///
/// Nothing changes when `target` has no `/` or when `params.metadata` is not
/// a JSON object.
pub(crate) fn rewrite_target_agent_for_spoke(params: &mut Value, target: &str) {
    if let Some((_, agent)) = target.split_once('/') {
        let metadata = params.get_mut("metadata").and_then(Value::as_object_mut);
        if let Some(fields) = metadata {
            fields.insert("agentId".to_owned(), Value::String(agent.to_owned()));
        }
    }
}
pub async fn relay_stats_handler(State(state): State<AppState>) -> Json<serde_json::Value> {
let nodes = state.relay_hub.connected_nodes();
let snapshot = state
.relay_hub
.metrics
.snapshot(nodes.len() as u64, state.relay_hub.route_count() as u64);
Json(serde_json::json!({
"relay_id": state.config.gateway.a2a_relay.relay_id,
"mode": match state.config.gateway.a2a_relay.mode {
A2aRelayModeRuntime::Disabled => "disabled",
A2aRelayModeRuntime::Hub => "hub",
A2aRelayModeRuntime::Spoke => "spoke",
},
"connected_node_ids": nodes,
"metrics": snapshot,
}))
}
/// Starts the background spoke loop when the gateway is configured as a
/// relay spoke with at least one hub URL; otherwise does nothing.
///
/// The loop connects to one hub at a time, starting with the first URL:
/// * A session that ends cleanly resets to the primary hub and the base
///   delay, then pauses for that base delay before reconnecting. Without this
///   pause, a hub that accepts and immediately closes the socket (for example
///   after a failed handshake) would be hammered in a tight loop.
/// * A session that fails after more than 60 s keeps the same hub and resets
///   the delay.
/// * A session that fails sooner rotates to the next hub (when there is more
///   than one) and doubles the delay, capped at `FAILOVER_BACKOFF_MAX`.
pub fn start_spoke_if_configured(state: AppState) {
    if state.config.gateway.a2a_relay.mode != A2aRelayModeRuntime::Spoke {
        return;
    }
    let relay = state.config.gateway.a2a_relay.clone();
    if relay.hub_urls.is_empty() {
        warn!("a2a relay spoke mode set but no hub URLs configured");
        return;
    }
    tokio::spawn(async move {
        if relay.strategy == A2aRelayStrategyRuntime::MultiHome {
            warn!("a2a relay strategy=multi_home not yet supported, using primary_standby");
        }
        let urls = relay.hub_urls.clone();
        let base_delay = Duration::from_secs(1);
        let mut idx: usize = 0;
        let mut per_relay_delay = base_delay;
        loop {
            // `urls` is non-empty (checked above) and `idx` is always reduced
            // modulo its length, so this index is in bounds.
            let hub_url = &urls[idx];
            let connect_start = std::time::Instant::now();
            match run_spoke_once(state.clone(), &relay, hub_url).await {
                Ok(()) => {
                    idx = 0;
                    per_relay_delay = base_delay;
                    info!(hub = %hub_url, "a2a relay spoke session ended cleanly, returning to primary");
                }
                Err(e) => {
                    let was_long_lived = connect_start.elapsed() > Duration::from_secs(60);
                    warn!(
                        error = %e,
                        hub = %hub_url,
                        idx,
                        "a2a relay spoke disconnected"
                    );
                    state
                        .relay_hub
                        .metrics
                        .ws_reconnects
                        .fetch_add(1, Ordering::Relaxed);
                    if was_long_lived {
                        per_relay_delay = base_delay;
                    } else {
                        if urls.len() > 1 {
                            idx = (idx + 1) % urls.len();
                            state
                                .relay_hub
                                .metrics
                                .failovers
                                .fetch_add(1, Ordering::Relaxed);
                            info!(next_hub = %urls[idx], "a2a relay failing over");
                        }
                        per_relay_delay = (per_relay_delay * 2).min(FAILOVER_BACKOFF_MAX);
                    }
                }
            }
            // Always pause between sessions, including after a clean end.
            tokio::time::sleep(per_relay_delay).await;
        }
    });
}
/// Runs a single spoke session against `hub_url` until the socket closes or
/// an error occurs.
///
/// Flow:
/// 1. Validate configuration (`node_id`, plus a token and/or a private key)
///    and connect, passing `node_id` (and the token, if any) as query
///    parameters.
/// 2. Spawn helper tasks: a writer owning the socket's write half, an adapter
///    feeding `RelayFrame`s to the writer, a 15 s WebSocket pinger, and a
///    lease renewer firing every `ROUTE_TTL_MS / 3`.
/// 3. Send `Hello`. In token-only mode the first route lease follows at once;
///    in keypair mode it is sent after answering the hub's `Challenge`.
/// 4. Process inbound frames: `Challenge`, `Request`, `Ping`, `Cancel`.
/// 5. Cancel every relayed stream task still registered and abort the helper
///    tasks.
///
/// Step 5 runs on *every* exit path once the helpers exist. Previously an
/// error inside the session (socket error, malformed frame, protocol
/// violation) returned early and left the detached helper tasks running,
/// still renewing leases and pinging on a socket nobody was reading.
///
/// # Errors
/// Configuration problems, connection failure, socket read errors,
/// unparsable frames, or a handshake the spoke cannot complete.
async fn run_spoke_once(state: AppState, relay: &A2aRelayRuntime, hub_url: &str) -> Result<()> {
    let node_id = relay
        .node_id
        .as_deref()
        .ok_or_else(|| anyhow!("a2a relay spoke node_id is required"))?;
    let token = relay.token.as_deref();
    let signing_key = match relay.private_key.as_deref() {
        Some(pk) => Some(
            relay_identity::signing_key_from_b64(pk).context("parse spoke private_key (base64)")?,
        ),
        None => None,
    };
    if token.is_none() && signing_key.is_none() {
        anyhow::bail!("a2a relay spoke requires either token or private_key");
    }
    let sep = if hub_url.contains('?') { '&' } else { '?' };
    let mut url = format!("{hub_url}{sep}node_id={}", urlencoding::encode(node_id));
    if let Some(t) = token {
        url.push_str(&format!("&token={}", urlencoding::encode(t)));
    }
    let (stream, _) = tokio_tungstenite::connect_async(&url)
        .await
        .with_context(|| format!("connect relay hub {hub_url}"))?;
    info!(node = %node_id, hub = %hub_url, keypair = signing_key.is_some(), "a2a relay spoke connected");
    let (mut write, mut read) = stream.split();

    /// Items accepted by the writer task.
    enum SpokeWriteItem {
        Frame(RelayFrame),
        WsPing,
    }

    // Writer: the only task that owns the socket's write half.
    let (write_tx, mut write_rx) = mpsc::unbounded_channel::<SpokeWriteItem>();
    let writer = tokio::spawn(async move {
        while let Some(item) = write_rx.recv().await {
            let result = match item {
                SpokeWriteItem::Frame(frame) => send_spoke_frame(&mut write, &frame).await,
                SpokeWriteItem::WsPing => write
                    .send(tokio_tungstenite::tungstenite::Message::Ping(
                        Vec::new().into(),
                    ))
                    .await
                    .map_err(anyhow::Error::from),
            };
            if let Err(e) = result {
                warn!(error = %e, "spoke write error");
                break;
            }
        }
    });

    // Adapter: lets the rest of the session deal in plain `RelayFrame`s.
    let (spoke_tx, mut frame_rx) = mpsc::unbounded_channel::<RelayFrame>();
    let frame_adapter_tx = write_tx.clone();
    let frame_adapter = tokio::spawn(async move {
        while let Some(frame) = frame_rx.recv().await {
            if frame_adapter_tx.send(SpokeWriteItem::Frame(frame)).is_err() {
                break;
            }
        }
    });

    // A node nonce is only generated in keypair mode.
    let nonce_node = signing_key
        .as_ref()
        .map(|_| relay_identity::fresh_nonce_b64());

    // Pinger: the first tick completes immediately and is discarded, so the
    // first ping goes out after 15 s.
    let ping_tx = write_tx.clone();
    let pinger = tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(15));
        interval.tick().await;
        loop {
            interval.tick().await;
            if ping_tx.send(SpokeWriteItem::WsPing).is_err() {
                break;
            }
        }
    });

    // Renewer: re-advertises routes well inside the lease TTL. Epoch 1 is the
    // initial lease sent by the session itself, so renewals start at 2.
    let renew_tx = spoke_tx.clone();
    let renew_state = state.clone();
    let renew_node_id = node_id.to_owned();
    let renewer = tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_millis(ROUTE_TTL_MS / 3));
        interval.tick().await;
        let mut epoch: u64 = 2;
        loop {
            interval.tick().await;
            let frame = spoke_route_lease(&renew_state, &renew_node_id, epoch);
            if renew_tx.send(frame).is_err() {
                break;
            }
            epoch = epoch.saturating_add(1);
        }
    });

    // Everything fallible from here on runs inside this block, so `?` and
    // `bail!` end the session without skipping the cleanup below.
    let session = async {
        spoke_tx
            .send(spoke_hello(&state, node_id, nonce_node.clone()))
            .map_err(|_| anyhow!("spoke writer closed"))?;
        if signing_key.is_none() {
            spoke_tx
                .send(spoke_route_lease(&state, node_id, 1))
                .map_err(|_| anyhow!("spoke writer closed"))?;
        }

        while let Some(msg) = read.next().await {
            let msg = msg?;
            let tokio_tungstenite::tungstenite::Message::Text(text) = msg else {
                continue;
            };
            let frame: RelayFrame = serde_json::from_str(&text)?;
            match frame {
                RelayFrame::Challenge {
                    relay_id,
                    nonce_relay,
                } => {
                    let Some(sk) = signing_key.as_ref() else {
                        anyhow::bail!("hub sent Challenge but spoke has no private_key");
                    };
                    let Some(nn) = nonce_node.as_deref() else {
                        anyhow::bail!("Challenge received without our nonce_node — protocol drift");
                    };
                    let sig =
                        relay_identity::sign_handshake(sk, node_id, &relay_id, nn, &nonce_relay);
                    spoke_tx
                        .send(RelayFrame::Auth { signature: sig })
                        .map_err(|_| anyhow!("spoke writer closed"))?;
                    // In keypair mode the first lease follows the Auth frame.
                    spoke_tx
                        .send(spoke_route_lease(&state, node_id, 1))
                        .map_err(|_| anyhow!("spoke writer closed"))?;
                }
                RelayFrame::Request {
                    request_id,
                    target,
                    method,
                    params,
                    principal,
                    ..
                } => {
                    let response = handle_spoke_request(
                        &state,
                        node_id,
                        &request_id,
                        &target,
                        &method,
                        params,
                        principal,
                        spoke_tx.clone(),
                    )
                    .await;
                    if let Some(response) = response {
                        let _ = spoke_tx.send(RelayFrame::Response {
                            request_id,
                            response,
                        });
                    }
                }
                RelayFrame::Ping { ts } => {
                    let _ = spoke_tx.send(RelayFrame::Pong { ts });
                }
                RelayFrame::Cancel { request_id, .. } => {
                    if let Some((_, task_id)) =
                        state.relay_hub.spoke_stream_tasks.remove(&request_id)
                        && let Some((_, token)) = state.task_cancels.remove(&task_id)
                    {
                        token.cancel();
                    }
                }
                _ => {}
            }
        }
        Ok::<(), anyhow::Error>(())
    }
    .await;

    // Cleanup — runs whether the session ended cleanly or with an error.
    let request_ids: Vec<String> = state
        .relay_hub
        .spoke_stream_tasks
        .iter()
        .map(|e| e.key().clone())
        .collect();
    for rid in request_ids {
        if let Some((_, task_id)) = state.relay_hub.spoke_stream_tasks.remove(&rid)
            && let Some((_, token)) = state.task_cancels.remove(&task_id)
        {
            token.cancel();
        }
    }
    writer.abort();
    renewer.abort();
    pinger.abort();
    frame_adapter.abort();
    session
}
/// Builds the `Hello` frame a spoke sends to open a relay session.
///
/// The frame carries the relay protocol identifier, the package version baked in at
/// compile time, the agent card produced by `build_agent_card(state, false)`, and a
/// capabilities block with `streaming_relay` enabled. `nonce_node` is passed through
/// unchanged.
fn spoke_hello(state: &AppState, node_id: &str, nonce_node: Option<String>) -> RelayFrame {
    let agent_card = crate::a2a::server::build_agent_card(state, false);
    let capabilities = HelloCapabilities {
        streaming_relay: true,
    };
    RelayFrame::Hello {
        protocol: String::from(RELAY_PROTOCOL),
        node_id: String::from(node_id),
        node_version: Some(String::from(env!("CARGO_PKG_VERSION"))),
        agent_card: Some(agent_card),
        capabilities: Some(capabilities),
        nonce_node,
    }
}
/// Builds a `RouteLease` frame advertising every agent returned by `state.agents.all()`.
///
/// Each agent is advertised as `<node_id>/<agent id>`. The lease always uses
/// `ROUTE_TTL_MS` as its TTL and carries the caller-supplied `epoch`.
fn spoke_route_lease(state: &AppState, node_id: &str, epoch: u64) -> RelayFrame {
    let advertised: Vec<String> = state
        .agents
        .all()
        .into_iter()
        .map(|hosted| format!("{}/{}", node_id, hosted.id))
        .collect();
    RelayFrame::RouteLease {
        node_id: String::from(node_id),
        agents: advertised,
        ttl_ms: ROUTE_TTL_MS,
        epoch,
    }
}
/// Serializes `frame` to JSON and sends it through `write` as a WebSocket text message.
///
/// # Errors
///
/// Returns an error if JSON serialization fails or if the sink rejects the message.
async fn send_spoke_frame<W>(write: &mut W, frame: &RelayFrame) -> Result<()>
where
    W: futures::Sink<tokio_tungstenite::tungstenite::Message> + Unpin,
    W::Error: std::error::Error + Send + Sync + 'static,
{
    // Serialize first so a serialization failure never touches the sink.
    let payload = serde_json::to_string(frame)?;
    let message = tokio_tungstenite::tungstenite::Message::Text(payload.into());
    write.send(message).await?;
    Ok(())
}
/// Serves one relayed `Request` frame on the spoke side.
///
/// * If `target` does not resolve to a local agent (per `local_agent_from_ref`), a
///   JSON-RPC error with code `-32003` is returned.
/// * If `params` has an object-valued `metadata` field, its `agentId` entry is
///   overwritten with the resolved local agent.
/// * For `SendStreamingMessage` / `SubscribeToTask`, a streaming task is spawned, the
///   `request_id -> task_id` mapping is recorded in `relay_hub.spoke_stream_tasks`, and a
///   background task forwards every event as a `RelayFrame::Event` through `spoke_tx`.
///   The function then returns `None`; the closing `Response` frame is sent by the
///   background task once forwarding stops.
/// * Any other method is dispatched to `a2a_rpc_handler_inner` and its response is
///   returned as `Some(..)`.
///
/// In both paths the caller identity is built from `principal` with an empty scope list.
async fn handle_spoke_request(
    state: &AppState,
    node_id: &str,
    request_id: &str,
    target: &str,
    method: &str,
    mut params: Value,
    principal: String,
    spoke_tx: mpsc::UnboundedSender<RelayFrame>,
) -> Option<JsonRpcResponse> {
    let local_agent = match local_agent_from_ref(target, node_id) {
        Some(agent) => agent,
        None => {
            let message = format!("target not hosted here: {target}");
            return Some(JsonRpcResponse::err(Value::Null, -32003, message));
        }
    };

    // Only an existing object-valued `metadata` is rewritten; absent metadata is left absent.
    if let Some(Value::Object(metadata)) = params.get_mut("metadata") {
        metadata.insert("agentId".to_owned(), Value::String(local_agent));
    }

    let caller = Some(A2aIdentity {
        id: principal,
        scopes: Vec::new(),
    });

    // Unary path: run the RPC locally and hand the response back to the read loop.
    if !matches!(method, "SendStreamingMessage" | "SubscribeToTask") {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_owned(),
            id: Value::String(format!("spoke:{}", Uuid::new_v4())),
            method: method.to_owned(),
            params,
        };
        let reply = crate::a2a::server::a2a_rpc_handler_inner(state.clone(), caller, req)
            .await
            .0;
        return Some(reply);
    }

    // Streaming path: start the task and remember which relay request owns it.
    let (task_id, event_rx) =
        crate::a2a::streaming::spawn_streaming_task(state.clone(), caller, params).await;
    let relay_request_id = request_id.to_owned();
    state
        .relay_hub
        .spoke_stream_tasks
        .insert(relay_request_id.clone(), task_id.clone());
    let relay_hub = state.relay_hub.clone();

    tokio::spawn(async move {
        use futures::StreamExt;
        use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};

        let mut events = BroadcastStream::new(event_rx);
        // `seq` counts frames actually handed to the writer; lagged events do not advance it.
        let mut seq = 0u64;
        loop {
            let event = match events.next().await {
                Some(Ok(event)) => event,
                Some(Err(BroadcastStreamRecvError::Lagged(n))) => {
                    warn!(lagged = n, "spoke relay event lagged");
                    continue;
                }
                None => break,
            };
            let frame = RelayFrame::Event {
                request_id: relay_request_id.clone(),
                seq,
                result: event.to_wire_event(),
            };
            if spoke_tx.send(frame).is_err() {
                break;
            }
            seq += 1;
            if event.is_final() {
                break;
            }
        }

        // Forwarding is over: drop the mapping, then send the terminal Response
        // (best effort; the writer may already be gone).
        relay_hub.spoke_stream_tasks.remove(&relay_request_id);
        let _ = spoke_tx.send(RelayFrame::Response {
            request_id: relay_request_id,
            response: JsonRpcResponse::ok(
                Value::String(task_id),
                serde_json::json!({"ok": true}),
            ),
        });
    });

    None
}
#[cfg(test)]
mod tests {
use super::*;
/// A `a3/*` invoke scope must match targets under `a3/` only, and only for `invoke`.
#[test]
fn suffix_scopes_match_agent_children_only() {
    let granted = vec![String::from("a2a:invoke:a3/*")];
    // Direct child of `a3` with the granted verb.
    assert!(scope_allows(&granted, "a2a", "invoke", "a3/main"));
    // `a30` merely shares a textual prefix with `a3`.
    assert!(!scope_allows(&granted, "a2a", "invoke", "a30/main"));
    // Same target, different verb.
    assert!(!scope_allows(&granted, "a2a", "cancel", "a3/main"));
}
/// Node `a1` must not be able to lease a route that belongs to node `a3`.
#[test]
fn route_lease_rejects_cross_node_advertise() {
    let hub = RelayHub::new();
    let foreign_routes = ["a3/main".to_owned()];
    let outcome = hub.apply_route_lease("a1", &foreign_routes, 10_000, 1);
    let err = outcome.expect_err("cross-node route should fail");
    assert!(err.to_string().contains("cannot advertise"));
}
/// A lease for the node's own agent becomes resolvable through `route_for`.
#[test]
fn route_lease_adds_live_route() {
    let hub = RelayHub::new();
    let own_routes = ["a1/main".to_owned()];
    hub.apply_route_lease("a1", &own_routes, 10_000, 1).unwrap();

    let resolved = hub.route_for("a1/main").expect("route");
    assert_eq!(resolved.node_id, "a1");
}
/// The `gateway-auth` identity is allowed to invoke a target even with no scopes.
#[test]
fn gateway_auth_can_invoke_everything() {
    let gateway = A2aIdentity {
        id: String::from("gateway-auth"),
        scopes: vec![],
    };
    let allowed = can_invoke(Some(&gateway), "a3/main");
    assert!(allowed);
}
/// An identity scoped to `a3/main` may invoke that target but not a sibling agent.
#[test]
fn scoped_identity_can_only_invoke_allowed_target() {
    let scoped = A2aIdentity {
        id: String::from("node:a1"),
        scopes: vec![String::from("a2a:invoke:a3/main")],
    };
    let caller = Some(&scoped);
    assert!(can_invoke(caller, "a3/main"));
    assert!(!can_invoke(caller, "a3/coder"));
}
/// A node configured with both a public key and a token still needs the right token.
#[test]
fn keypair_node_with_token_still_requires_matching_token() {
    let node = A2aRelayNodeRuntime {
        node_id: String::from("a1"),
        token: String::from("secret"),
        public_key: Some(String::from("pk")),
        roles: Vec::new(),
        scopes: Vec::new(),
    };
    // (presented token, expected verdict)
    let cases = [
        (Some("secret"), true),
        (None, false),
        (Some("wrong"), false),
    ];
    for (presented, expected) in cases {
        assert_eq!(
            relay_connect_token_allows(&node, presented),
            expected,
            "presented token: {presented:?}"
        );
    }
}
/// `invoke_jsonrpc` sends a `Request` frame to the owning connection and resolves once
/// `complete_pending` is called with the frame's `request_id`.
#[tokio::test]
async fn hub_invocation_sends_request_and_returns_response() {
    let hub = std::sync::Arc::new(RelayHub::new());
    let (conn_tx, mut conn_rx) = mpsc::unbounded_channel();
    hub.register_connection("a3", conn_tx, 1);
    hub.apply_route_lease("a3", &["a3/main".to_owned()], 10_000, 1)
        .unwrap();

    // Run the invocation concurrently so this test can play the spoke's role.
    let caller_hub = std::sync::Arc::clone(&hub);
    let pending_call = tokio::spawn(async move {
        let params = serde_json::json!({"metadata": {"agentId": "main"}});
        caller_hub
            .invoke_jsonrpc("a3/main", "SendMessage", params, "node:a1")
            .await
    });

    let outbound = tokio::time::timeout(Duration::from_secs(1), conn_rx.recv())
        .await
        .unwrap()
        .unwrap();
    let text = match outbound {
        AxumWsMessage::Text(text) => text,
        _ => panic!("expected text relay frame"),
    };
    let (request_id, target, principal) = match serde_json::from_str::<RelayFrame>(&text).unwrap()
    {
        RelayFrame::Request {
            request_id,
            target,
            principal,
            ..
        } => (request_id, target, principal),
        _ => panic!("expected request frame"),
    };
    assert_eq!(target, "a3/main");
    assert_eq!(principal, "node:a1");

    // Answer as the spoke would, then check the caller sees the result.
    let reply = JsonRpcResponse::ok(
        Value::String("client-id".into()),
        serde_json::json!({"ok": true}),
    );
    hub.complete_pending(&request_id, reply);

    let response = pending_call.await.unwrap().unwrap();
    assert_eq!(response.result.unwrap()["ok"], true);
}
/// Dropping a `RelayStreamGuard` while the stream is still pending sends a `Cancel`
/// frame for that request and removes it from `stream_pending`.
#[tokio::test]
async fn drop_guard_sends_cancel_when_stream_drops_early() {
    let hub = std::sync::Arc::new(RelayHub::new());
    let (conn_tx, mut conn_rx) = mpsc::unbounded_channel();
    hub.register_connection("a3", conn_tx, 1);
    hub.apply_route_lease("a3", &["a3/main".to_owned()], 10_000, 1)
        .unwrap();

    let params = serde_json::json!({"metadata": {"agentId": "main"}});
    let (request_id, node_id, _event_rx) = hub
        .invoke_streaming("a3/main", "SendStreamingMessage", params, "node:a1")
        .await
        .unwrap();
    assert_eq!(node_id, "a3");

    // Drain the initial frame so the next receive observes what the guard sends.
    let _req = conn_rx.recv().await.unwrap();

    drop(RelayStreamGuard::new(
        hub.clone(),
        node_id,
        request_id.clone(),
    ));

    let outbound = tokio::time::timeout(Duration::from_millis(100), conn_rx.recv())
        .await
        .expect("cancel frame should arrive")
        .unwrap();
    let AxumWsMessage::Text(text) = outbound else {
        panic!("expected text relay frame");
    };
    match serde_json::from_str::<RelayFrame>(&text).unwrap() {
        RelayFrame::Cancel {
            request_id: cancelled,
            ..
        } => assert_eq!(cancelled, request_id),
        other => panic!("expected Cancel, got {other:?}"),
    }
    assert!(!hub.stream_pending.contains_key(&request_id));
}
/// Requests that carry only a task id resolve to the route recorded for that task;
/// `ListTasks` does not, even when its params contain an `id`.
#[test]
fn relay_target_falls_back_to_task_id_route() {
    let hub = RelayHub::new();
    hub.record_task_route("task-abc", "a3/main");

    // All three requests differ only in method name and params.
    let request = |method: &str, params: Value| JsonRpcRequest {
        jsonrpc: "2.0".to_owned(),
        id: Value::Null,
        method: method.to_owned(),
        params,
    };

    let by_id = request("GetTask", serde_json::json!({"id": "task-abc"}));
    assert_eq!(
        relay_target_from_request(&hub, &by_id).as_deref(),
        Some("a3/main")
    );

    let by_task_id = request(
        "GetTaskPushNotificationConfig",
        serde_json::json!({"taskId": "task-abc", "pushNotificationConfigId": "p1"}),
    );
    assert_eq!(
        relay_target_from_request(&hub, &by_task_id).as_deref(),
        Some("a3/main")
    );

    let listing = request("ListTasks", serde_json::json!({"id": "task-abc"}));
    assert!(relay_target_from_request(&hub, &listing).is_none());
}
/// Forwarding a stream event that carries a `taskId` records that task's route as the
/// pending stream's `agent_ref`.
#[test]
fn forward_stream_event_records_task_route() {
    let hub = RelayHub::new();
    // The receiver stays bound for the whole test so the channel keeps a subscriber.
    let (event_tx, _event_rx) = broadcast::channel::<Value>(4);
    let pending = StreamPending {
        tx: event_tx,
        agent_ref: String::from("a3/main"),
        node_id: String::from("a3"),
        deadline: std::time::Instant::now() + Duration::from_secs(60),
    };
    hub.stream_pending.insert(String::from("req-1"), pending);

    let status_update = serde_json::json!({
        "kind": "status-update",
        "taskId": "task-xyz",
        "contextId": "ctx-1",
        "status": {"state": "submitted"},
        "final": false,
    });
    hub.forward_stream_event("req-1", status_update);

    let recorded = hub.route_for_task("task-xyz");
    assert_eq!(recorded.as_deref(), Some("a3/main"));
}
/// Unregistering a connection with an in-flight stream delivers a final `failed`
/// status-update to the stream's subscriber and bumps `inflight_losses`.
#[tokio::test]
async fn unregister_surfaces_inflight_stream_as_failed() {
    let hub = std::sync::Arc::new(RelayHub::new());
    // The receiving half is kept alive but never read.
    let (conn_tx, _conn_rx) = mpsc::unbounded_channel();
    hub.register_connection("home-mac", conn_tx, 1);
    hub.apply_route_lease("home-mac", &["home-mac/main".to_owned()], 10_000, 1)
        .unwrap();

    let params = serde_json::json!({"metadata": {"agentId": "main"}});
    let (_request_id, _node_id, mut event_rx) = hub
        .invoke_streaming(
            "home-mac/main",
            "SendStreamingMessage",
            params,
            "node:hub",
        )
        .await
        .unwrap();

    hub.unregister_connection("home-mac", 1);

    let failure = tokio::time::timeout(Duration::from_millis(200), event_rx.recv())
        .await
        .expect("synthetic failure event must arrive")
        .expect("recv ok");
    assert_eq!(failure["kind"], "status-update");
    assert_eq!(failure["status"]["state"], "failed");
    assert_eq!(failure["final"], true);

    let losses = hub.metrics.inflight_losses.load(Ordering::Relaxed);
    assert_eq!(losses, 1, "inflight_losses metric must increment");
}
/// Unregistering node `a` must fail `a`'s pending JSON-RPC call promptly with a
/// "disconnected" error, while a call pending on node `b` stays untouched.
///
/// Both halves of the name are asserted: without an in-flight call on `b`, an
/// implementation that resolved *every* pending call on any disconnect would still pass.
#[tokio::test]
async fn unregister_resolves_pending_jsonrpc_for_owning_node_only() {
    let hub = std::sync::Arc::new(RelayHub::new());
    let (tx_a, mut rx_a) = mpsc::unbounded_channel();
    let (tx_b, mut rx_b) = mpsc::unbounded_channel();
    hub.register_connection("a", tx_a, 1);
    hub.register_connection("b", tx_b, 1);
    hub.apply_route_lease("a", &["a/main".to_owned()], 10_000, 1)
        .unwrap();
    hub.apply_route_lease("b", &["b/main".to_owned()], 10_000, 1)
        .unwrap();

    let a_hub = hub.clone();
    let a_call = tokio::spawn(async move {
        a_hub
            .invoke_jsonrpc(
                "a/main",
                "SendMessage",
                serde_json::json!({"metadata": {"agentId": "main"}}),
                "test",
            )
            .await
    });
    let b_hub = hub.clone();
    let mut b_call = tokio::spawn(async move {
        b_hub
            .invoke_jsonrpc(
                "b/main",
                "SendMessage",
                serde_json::json!({"metadata": {"agentId": "main"}}),
                "test",
            )
            .await
    });

    // Wait until both requests are on the wire, i.e. both calls are truly pending.
    let _ = tokio::time::timeout(Duration::from_millis(200), rx_a.recv())
        .await
        .expect("a should have received request frame");
    let _ = tokio::time::timeout(Duration::from_millis(200), rx_b.recv())
        .await
        .expect("b should have received request frame");

    hub.unregister_connection("a", 1);

    let response = tokio::time::timeout(Duration::from_millis(500), a_call)
        .await
        .expect("a's call must unblock fast")
        .unwrap()
        .unwrap();
    assert!(response.error.is_some(), "must surface as JSON-RPC error");
    assert!(response.error.unwrap().message.contains("disconnected"));

    // `b` is still connected, so its call must still be waiting; a short timeout that
    // elapses proves the join handle has not completed.
    let b_still_pending = tokio::time::timeout(Duration::from_millis(50), &mut b_call)
        .await
        .is_err();
    assert!(
        b_still_pending,
        "b's call must not be resolved by a's disconnect"
    );
    b_call.abort();
}
/// A node listed in `revoked_nodes` is not returned by `resolve_node`, even though it
/// is also present in `nodes`.
#[test]
fn revoked_node_is_not_resolvable() {
    let revoked = A2aRelayNodeRuntime {
        node_id: String::from("bad-node"),
        token: String::from("anything"),
        public_key: None,
        roles: Vec::new(),
        scopes: Vec::new(),
    };
    let relay = A2aRelayRuntime {
        mode: A2aRelayModeRuntime::Hub,
        relay_id: String::from("main"),
        revoked_nodes: vec![String::from("bad-node")],
        nodes: vec![revoked],
        ..Default::default()
    };
    let lookup = resolve_node(&relay, "bad-node");
    assert!(lookup.is_none());
}
/// A key-pair node with an empty token is still found by `resolve_node`, and an empty
/// presented token does not pass `verify_node_token` for it.
#[test]
fn keypair_node_skips_token_verification_at_resolve_time() {
    let keypair_node = A2aRelayNodeRuntime {
        node_id: String::from("kp-node"),
        token: String::new(),
        public_key: Some(String::from("dummy")),
        roles: Vec::new(),
        scopes: Vec::new(),
    };
    let relay = A2aRelayRuntime {
        mode: A2aRelayModeRuntime::Hub,
        relay_id: String::from("main"),
        nodes: vec![keypair_node],
        ..Default::default()
    };

    let node = resolve_node(&relay, "kp-node").expect("found");
    assert!(node.public_key.is_some());
    assert!(!verify_node_token(&node, ""));
}
/// `RelayMetrics::snapshot` reports the supplied node/route counts, the stored counters,
/// an average latency of total / count, and a key for every remaining counter.
#[test]
fn metrics_snapshot_includes_all_spec_counters() {
    let metrics = RelayMetrics::default();
    metrics.request_count.store(42, Ordering::Relaxed);
    metrics
        .request_latency_ms_total
        .store(4200, Ordering::Relaxed);
    metrics.acl_denials.store(7, Ordering::Relaxed);

    let snapshot = metrics.snapshot(3, 5);

    // Counters with a known value: 4200 ms over 42 requests averages to 100 ms.
    let expected_values = [
        ("connected_nodes", 3),
        ("route_count", 5),
        ("request_count", 42),
        ("request_latency_ms_avg", 100),
        ("acl_denials", 7),
    ];
    for (key, want) in expected_values {
        assert_eq!(snapshot[key], want, "counter: {key}");
    }

    // Counters that only need to be present.
    let presence_only = [
        "ws_reconnects",
        "auth_failures",
        "route_expirations",
        "failovers",
        "inflight_losses",
    ];
    for key in presence_only {
        assert!(snapshot.get(key).is_some(), "missing counter: {key}");
    }
}
/// A lease with a 1 ms TTL is no longer resolvable after 10 ms, and looking it up
/// counts exactly one route expiration.
#[test]
fn route_expiration_increments_metric() {
    let hub = RelayHub::new();
    let short_lived = ["a/main".to_owned()];
    hub.apply_route_lease("a", &short_lived, 1, 1).unwrap();

    // Outlive the 1 ms TTL; this is a synchronous test, so a blocking sleep is fine.
    std::thread::sleep(Duration::from_millis(10));

    assert!(hub.route_for("a/main").is_none());
    let expirations = hub.metrics.route_expirations.load(Ordering::Relaxed);
    assert_eq!(expirations, 1);
}
/// Once `complete_streaming` has succeeded for a request, dropping a `RelayStreamGuard`
/// for it must not put a `Cancel` frame on the connection.
#[tokio::test]
async fn drop_guard_no_cancel_after_normal_completion() {
    let hub = std::sync::Arc::new(RelayHub::new());
    let (conn_tx, mut conn_rx) = mpsc::unbounded_channel();
    hub.register_connection("a3", conn_tx, 1);
    hub.apply_route_lease("a3", &["a3/main".to_owned()], 10_000, 1)
        .unwrap();

    let params = serde_json::json!({"metadata": {"agentId": "main"}});
    let (request_id, node_id, _event_rx) = hub
        .invoke_streaming("a3/main", "SendStreamingMessage", params, "node:a1")
        .await
        .unwrap();

    // Drain the initial frame so any later frame would be the guard's doing.
    let _req = conn_rx.recv().await.unwrap();

    assert!(hub.complete_streaming(&request_id));
    drop(RelayStreamGuard::new(
        hub.clone(),
        node_id,
        request_id.clone(),
    ));

    // An elapsed timeout means nothing was sent.
    let silence = tokio::time::timeout(Duration::from_millis(50), conn_rx.recv()).await;
    assert!(
        silence.is_err(),
        "no Cancel frame expected after normal completion"
    );
}
}