use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::StreamExt;
use oatf::primitives::interpolate_value;
use serde_json::{Value, json};
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use crate::engine::driver::PhaseDriver;
use crate::engine::types::{Direction, DriveResult, ProtocolEvent};
use crate::error::EngineError;
// --- HTTP timeouts -------------------------------------------------------
/// Per-request timeout for JSON-RPC calls to the agent.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Timeout for fetching the agent's well-known Agent Card.
const AGENT_CARD_TIMEOUT: Duration = Duration::from_secs(5);

// --- 429 retry policy ------------------------------------------------------
/// Extra attempts made after an HTTP 429 response before giving up.
const MAX_RETRIES: u32 = 3;
/// Delay before the first 429 retry; doubled after every retry.
const INITIAL_RETRY_BACKOFF: Duration = Duration::from_secs(1);

// --- SSE streaming ---------------------------------------------------------
/// Failed SSE events tolerated in a row before the stream is abandoned.
const MAX_CONSECUTIVE_SSE_ERRORS: usize = 10;
/// Longest wait for the next SSE event before the stream is abandoned.
const DEFAULT_STREAM_TIMEOUT: Duration = Duration::from_secs(60);
/// Incremental SSE decoder for A2A streaming responses.
///
/// Wraps the generic `crate::transport::sse::SseParser` and turns each SSE
/// `data` payload into the `result` member of a JSON-RPC 2.0 response. It also
/// counts failed events since the last successfully decoded one.
struct A2aSseParser {
    /// Underlying SSE framer.
    inner: crate::transport::sse::SseParser,
    /// Failed events since the last successful one; reset to 0 on success.
    consecutive_errors: usize,
}

impl A2aSseParser {
    /// Creates a parser with an empty buffer and a zero error count.
    const fn new() -> Self {
        Self {
            consecutive_errors: 0,
            inner: crate::transport::sse::SseParser::new(),
        }
    }

    /// Feeds a chunk of response bytes and returns every event it completes,
    /// in arrival order.
    ///
    /// Framing errors and undecodable payloads come back as `Err` entries
    /// instead of aborting the batch. Each one bumps the consecutive-error
    /// count.
    fn feed(&mut self, bytes: &[u8]) -> Vec<Result<Value, String>> {
        let completed = self.inner.feed(bytes);
        completed
            .into_iter()
            .map(|frame| match frame {
                Ok(event) => self.dispatch_raw_event(&event.data),
                Err(err) => {
                    self.consecutive_errors += 1;
                    Err(format!("A2A SSE parse error: {err}"))
                }
            })
            .collect()
    }

    /// Decodes one SSE `data` payload as a JSON-RPC response and returns its
    /// `result`.
    ///
    /// Malformed JSON and JSON-RPC-level failures both count as errors. A
    /// success resets the consecutive-error count.
    fn dispatch_raw_event(&mut self, data_str: &str) -> Result<Value, String> {
        let outcome = serde_json::from_str::<Value>(data_str)
            .map_err(|e| format!("malformed JSON in A2A SSE data: {e}"))
            .and_then(|parsed| extract_jsonrpc_result(&parsed));
        if outcome.is_ok() {
            self.consecutive_errors = 0;
        } else {
            self.consecutive_errors += 1;
        }
        outcome
    }

    /// Number of failed events since the last successful one.
    const fn consecutive_errors(&self) -> usize {
        self.consecutive_errors
    }

    /// Flushes whatever the inner parser still holds once the byte stream has
    /// ended. Error handling is the same as in `feed`.
    fn finish(&mut self) -> Vec<Result<Value, String>> {
        let remaining = self.inner.finish();
        remaining
            .into_iter()
            .map(|frame| match frame {
                Ok(event) => self.dispatch_raw_event(&event.data),
                Err(err) => {
                    self.consecutive_errors += 1;
                    Err(format!("A2A SSE parse error: {err}"))
                }
            })
            .collect()
    }
}
/// Pull-based reader that yields decoded A2A events from a streaming HTTP
/// response body.
struct A2aSseStream {
    /// Decoder for the SSE framing and the JSON-RPC envelopes.
    parser: A2aSseParser,
    /// Raw response body chunks.
    stream: Pin<Box<dyn futures_util::Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
    /// Decoded events not yet returned, in arrival order.
    pending: std::collections::VecDeque<Result<Value, String>>,
    /// Set once `stream` has ended or failed.
    ///
    /// A finished stream is never polled again: the `Stream` contract leaves
    /// `poll_next` after completion unspecified.
    finished: bool,
}

impl A2aSseStream {
    /// Wraps `response`. Its body is expected to be `text/event-stream`.
    fn new(response: reqwest::Response) -> Self {
        Self {
            parser: A2aSseParser::new(),
            stream: Box::pin(response.bytes_stream()),
            pending: std::collections::VecDeque::new(),
            finished: false,
        }
    }

    /// Returns the next decoded event.
    ///
    /// * `Ok(Some(value))`: the JSON-RPC `result` of the next event.
    /// * `Ok(None)`: the body has ended (or previously failed) and every
    ///   buffered event has already been returned. Later calls keep
    ///   returning `Ok(None)`.
    ///
    /// # Errors
    /// `EngineError::Driver` in two cases:
    /// * An event could not be decoded. The stream stays usable.
    /// * A transport error occurred while reading the body. The stream is
    ///   then treated as finished.
    async fn next_event(&mut self) -> Result<Option<Value>, EngineError> {
        loop {
            if let Some(result) = self.pending.pop_front() {
                return result.map(Some).map_err(EngineError::Driver);
            }
            if self.finished {
                return Ok(None);
            }
            match self.stream.next().await {
                Some(Ok(bytes)) => {
                    let events = self.parser.feed(&bytes);
                    self.pending.extend(events);
                }
                Some(Err(e)) => {
                    // A body read error is not recoverable for this response.
                    // Stop polling so a caller that keeps reading sees a clean
                    // end instead of re-polling a failed body.
                    self.finished = true;
                    return Err(EngineError::Driver(format!("A2A SSE stream error: {e}")));
                }
                None => {
                    self.finished = true;
                    // Flush whatever the parser still has buffered.
                    let events = self.parser.finish();
                    self.pending.extend(events);
                }
            }
        }
    }
}
/// Extracts the `result` member from a JSON-RPC 2.0 response object.
///
/// # Errors
/// Returns a human-readable message when any of these holds:
/// * `response` is not an object.
/// * Its `jsonrpc` member is not `"2.0"`.
/// * It carries a non-null `error` member (rendered as JSON in the message).
/// * It has no `result` member.
fn extract_jsonrpc_result(response: &Value) -> Result<Value, String> {
    let Some(obj) = response.as_object() else {
        return Err("invalid A2A JSON-RPC response: expected object".to_string());
    };
    if obj.get("jsonrpc").and_then(Value::as_str) != Some("2.0") {
        return Err("invalid A2A JSON-RPC response: expected jsonrpc='2.0'".to_string());
    }
    // Some servers send `"error": null` alongside a successful `result`.
    // Only a non-null `error` member signals failure.
    if let Some(error) = obj.get("error").filter(|e| !e.is_null()) {
        return Err(format!("A2A JSON-RPC error: {error}"));
    }
    obj.get("result")
        .cloned()
        .ok_or_else(|| "invalid A2A JSON-RPC response: missing result".to_string())
}
/// HTTP client side of an A2A JSON-RPC connection to a single agent.
struct A2aClientTransport {
    /// JSON-RPC endpoint; also the base for the Agent Card URL.
    agent_url: String,
    client: reqwest::Client,
    /// Extra headers sent with every request (for example authorization).
    headers: Vec<(String, String)>,
    /// Context id learned from agent responses. The driver attaches it to
    /// later messages.
    context_id: Option<String>,
}

impl A2aClientTransport {
    /// Creates a transport for `endpoint` with no known context id.
    fn new(endpoint: &str, headers: Vec<(String, String)>) -> Self {
        Self {
            agent_url: endpoint.to_string(),
            client: reqwest::Client::new(),
            headers,
            context_id: None,
        }
    }

    /// Appends the configured extra headers to `request`, in configuration
    /// order.
    fn apply_headers(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        self.headers
            .iter()
            .fold(request, |req, (key, value)| req.header(key.as_str(), value.as_str()))
    }

    /// Reads an error response body for diagnostics. Falls back to a
    /// placeholder when the body cannot be read as text.
    async fn error_body(response: reqwest::Response) -> String {
        response
            .text()
            .await
            .unwrap_or_else(|_| "<unreadable>".to_string())
    }

    /// Fetches the agent's Agent Card from `<agent_url>/.well-known/agent.json`.
    ///
    /// # Errors
    /// `EngineError::Driver` on a transport failure, a timeout
    /// (`AGENT_CARD_TIMEOUT`), a non-2xx status, or a body that is not JSON.
    async fn get_agent_card(&self) -> Result<Value, EngineError> {
        let url = format!(
            "{}/.well-known/agent.json",
            self.agent_url.trim_end_matches('/')
        );
        let response = self
            .apply_headers(self.client.get(&url))
            .timeout(AGENT_CARD_TIMEOUT)
            .send()
            .await
            .map_err(|e| EngineError::Driver(format!("Agent Card fetch failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            let body = Self::error_body(response).await;
            return Err(EngineError::Driver(format!(
                "Agent Card fetch returned HTTP {status}: {body}"
            )));
        }
        response
            .json::<Value>()
            .await
            .map_err(|e| EngineError::Driver(format!("Agent Card parse failed: {e}")))
    }

    /// POSTs `body` to the agent and returns the parsed JSON-RPC response.
    ///
    /// HTTP 429 responses are retried up to `MAX_RETRIES` times, with a backoff
    /// that starts at `INITIAL_RETRY_BACKOFF` and doubles after each retry. Any
    /// other non-2xx status fails immediately.
    ///
    /// # Errors
    /// `EngineError::Driver` on:
    /// * a transport failure or timeout (`DEFAULT_TIMEOUT`);
    /// * a non-2xx status, including 429 once retries are exhausted;
    /// * a body that is not JSON.
    async fn message_send(&self, body: &Value) -> Result<Value, EngineError> {
        let mut backoff = INITIAL_RETRY_BACKOFF;
        let mut attempt: u32 = 0;
        loop {
            let request = self.apply_headers(
                self.client
                    .post(&self.agent_url)
                    .header("Content-Type", "application/json"),
            );
            let response = request
                .json(body)
                .timeout(DEFAULT_TIMEOUT)
                .send()
                .await
                .map_err(|e| EngineError::Driver(format!("A2A message/send failed: {e}")))?;
            let status = response.status();
            if status.is_success() {
                return response
                    .json::<Value>()
                    .await
                    .map_err(|e| EngineError::Driver(format!("A2A response parse failed: {e}")));
            }
            if status == reqwest::StatusCode::TOO_MANY_REQUESTS && attempt < MAX_RETRIES {
                attempt += 1;
                tracing::warn!(
                    attempt,
                    max_retries = MAX_RETRIES,
                    backoff_ms = backoff.as_millis(),
                    "A2A agent returned 429, retrying"
                );
                tokio::time::sleep(backoff).await;
                backoff *= 2;
                continue;
            }
            let resp_body = Self::error_body(response).await;
            return Err(EngineError::Driver(format!(
                "A2A agent returned HTTP {status}: {resp_body}"
            )));
        }
    }

    /// POSTs `body` requesting `text/event-stream` and returns a reader over
    /// the streamed events.
    ///
    /// Only the wait for the response headers is bounded by `DEFAULT_TIMEOUT`.
    /// `RequestBuilder::timeout` would also cover reading the body, which would
    /// cut off any stream lasting longer than the timeout. Callers are expected
    /// to bound the idle time between events themselves.
    ///
    /// # Errors
    /// `EngineError::Driver` on a transport failure, on no response headers
    /// within `DEFAULT_TIMEOUT`, or on a non-2xx status.
    async fn message_stream(&self, body: &Value) -> Result<A2aSseStream, EngineError> {
        let request = self.apply_headers(
            self.client
                .post(&self.agent_url)
                .header("Content-Type", "application/json")
                .header("Accept", "text/event-stream"),
        );
        let response = tokio::time::timeout(DEFAULT_TIMEOUT, request.json(body).send())
            .await
            .map_err(|_| {
                EngineError::Driver(format!(
                    "A2A message/stream failed: no response within {DEFAULT_TIMEOUT:?}"
                ))
            })?
            .map_err(|e| EngineError::Driver(format!("A2A message/stream failed: {e}")))?;
        let status = response.status();
        if !status.is_success() {
            let resp_body = Self::error_body(response).await;
            return Err(EngineError::Driver(format!(
                "A2A agent returned HTTP {status}: {resp_body}"
            )));
        }
        Ok(A2aSseStream::new(response))
    }
}
/// Builds the JSON-RPC 2.0 request for an A2A `message/send` call, or a
/// `message/stream` call when `streaming` is true, from a phase's `state`.
///
/// `state.task_message` is interpolated with `extractors` and becomes
/// `params.message`, with these defaults filled in:
/// * `messageId`: a fresh UUID v4 when absent or null.
/// * `kind`: `"message"` when absent.
/// * `contextId`: the `context_id` argument, unless the message already has one.
///
/// The optional `state.configuration` and `state.metadata` are interpolated and
/// copied into `params`. The request `id` is a fresh UUID v4.
///
/// # Errors
/// `EngineError::Driver` in either of these cases:
/// * `state` has no `task_message`.
/// * `task_message` interpolates to something other than a JSON object or
///   `null` (a `null` message is filled in like an empty object).
fn build_task_message(
    state: &Value,
    extractors: &HashMap<String, String>,
    context_id: Option<&str>,
    streaming: bool,
) -> Result<Value, EngineError> {
    let task_message = state.get("task_message").ok_or_else(|| {
        EngineError::Driver(
            "A2A phase state missing 'task_message' key — \
             each A2A client phase must define state.task_message"
                .to_string(),
        )
    })?;
    let (mut interpolated, _) = interpolate_value(task_message, extractors, None, None);
    // serde_json's `value["key"] = ...` panics unless `value` is an object or
    // null (null is promoted to an empty object). A string, array or number
    // message must be rejected here instead of crashing the driver.
    if !(interpolated.is_object() || interpolated.is_null()) {
        return Err(EngineError::Driver(
            "A2A phase state 'task_message' must be a JSON object".to_string(),
        ));
    }
    // An explicit `"messageId": null` is treated the same as a missing id.
    if matches!(interpolated.get("messageId"), None | Some(Value::Null)) {
        interpolated["messageId"] = Value::String(uuid::Uuid::new_v4().to_string());
    }
    if interpolated.get("kind").is_none() {
        interpolated["kind"] = Value::String("message".to_string());
    }
    // A contextId written explicitly in the message wins over `context_id`.
    if let Some(ctx) = context_id
        && interpolated.get("contextId").is_none()
    {
        interpolated["contextId"] = Value::String(ctx.to_string());
    }
    let method = if streaming {
        "message/stream"
    } else {
        "message/send"
    };
    let mut params = json!({ "message": interpolated });
    if let Some(config) = state.get("configuration") {
        let (interpolated_config, _) = interpolate_value(config, extractors, None, None);
        params["configuration"] = interpolated_config;
    }
    if let Some(metadata) = state.get("metadata") {
        let (interpolated_metadata, _) = interpolate_value(metadata, extractors, None, None);
        params["metadata"] = interpolated_metadata;
    }
    Ok(json!({
        "jsonrpc": "2.0",
        "id": uuid::Uuid::new_v4().to_string(),
        "method": method,
        "params": params,
    }))
}
/// Maps the `kind` discriminator of an A2A result to the protocol event name
/// reported to the engine. Unrecognized or missing kinds map to `"unknown"`.
///
/// The returned name is a static string, so it does not keep `result`
/// borrowed.
fn detect_event_type(result: &Value) -> &'static str {
    match result.get("kind").and_then(Value::as_str) {
        Some("task") => "task/created",
        Some("message") => "message/response",
        Some("status-update") => "task/status",
        Some("artifact-update") => "task/artifact",
        _ => "unknown",
    }
}
/// Refines a `task/status` event name with the task state it reports, for
/// example `task/status:completed`.
///
/// Any other event type is returned unchanged. So is a status event without a
/// string `status.state`.
fn resolve_status_qualifier(event_type: &str, result: &Value) -> String {
    let task_state = if event_type == "task/status" {
        result.pointer("/status/state").and_then(Value::as_str)
    } else {
        None
    };
    match task_state {
        Some(state) => format!("task/status:{state}"),
        None => event_type.to_owned(),
    }
}
/// Phase driver that plays the client role against a remote A2A agent.
///
/// Construct it with `create_a2a_client_driver`.
pub struct A2aClientDriver {
    /// HTTP transport. Its `context_id` carries the agent's conversation
    /// context from one phase to the next.
    transport: A2aClientTransport,
    /// Accepted at construction but not read by this driver outside tests,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    raw_synthesize: bool,
}
#[async_trait]
impl PhaseDriver for A2aClientDriver {
    /// Runs one A2A client phase.
    ///
    /// Reads these keys from `state`:
    /// * `fetch_agent_card` (bool, default false): fetch the Agent Card first
    ///   and report it as an outgoing/incoming `agent_card/get` pair.
    /// * `streaming` (bool, default false): use `message/stream` instead of
    ///   `message/send`.
    /// * `task_message`, `configuration`, `metadata`: see `build_task_message`.
    ///
    /// Extractor values are snapshotted once at the start of the phase. Events
    /// are forwarded on `event_tx` best-effort; send failures are ignored.
    /// `cancel` is only observed while streaming.
    ///
    /// # Errors
    /// Propagates Agent Card fetch failures, request-building errors, and
    /// errors from the synchronous or streaming exchange.
    async fn drive_phase(
        &mut self,
        _phase_index: usize,
        state: &Value,
        extractors: watch::Receiver<HashMap<String, String>>,
        event_tx: mpsc::Sender<ProtocolEvent>,
        cancel: CancellationToken,
    ) -> Result<DriveResult, EngineError> {
        // Snapshot once so the whole phase sees consistent extractor values.
        let current_extractors = extractors.borrow().clone();
        // Boolean phase options default to `false` when absent or not a bool.
        let flag = |key: &str| state.get(key).and_then(Value::as_bool).unwrap_or(false);

        if flag("fetch_agent_card") {
            let _ = event_tx
                .send(ProtocolEvent {
                    direction: Direction::Outgoing,
                    method: "agent_card/get".to_string(),
                    content: json!({}),
                })
                .await;
            let card = self.transport.get_agent_card().await?;
            let _ = event_tx
                .send(ProtocolEvent {
                    direction: Direction::Incoming,
                    method: "agent_card/get".to_string(),
                    content: card,
                })
                .await;
        }

        let streaming = flag("streaming");
        let message = build_task_message(
            state,
            &current_extractors,
            self.transport.context_id.as_deref(),
            streaming,
        )?;
        let method = if streaming {
            "message/stream"
        } else {
            "message/send"
        };
        // Report only the JSON-RPC `params`, not the envelope.
        let _ = event_tx
            .send(ProtocolEvent {
                direction: Direction::Outgoing,
                method: method.to_string(),
                content: message.get("params").cloned().unwrap_or(Value::Null),
            })
            .await;

        if streaming {
            self.drive_streaming(message, event_tx, cancel).await
        } else {
            self.drive_synchronous(message, method, event_tx).await
        }
    }
}
impl A2aClientDriver {
async fn drive_synchronous(
&mut self,
message: Value,
_method: &str,
event_tx: mpsc::Sender<ProtocolEvent>,
) -> Result<DriveResult, EngineError> {
let response = self.transport.message_send(&message).await?;
let result = extract_jsonrpc_result(&response)
.map_err(|e| EngineError::Driver(format!("A2A request failed: {e}")))?;
let event_type = detect_event_type(&result);
if let Some(ctx) = result.get("contextId").and_then(Value::as_str) {
self.transport.context_id = Some(ctx.to_string());
}
let _ = event_tx
.send(ProtocolEvent {
direction: Direction::Incoming,
method: event_type.to_string(),
content: result,
})
.await;
Ok(DriveResult::Complete)
}
#[allow(clippy::cognitive_complexity)]
async fn drive_streaming(
&mut self,
message: Value,
event_tx: mpsc::Sender<ProtocolEvent>,
cancel: CancellationToken,
) -> Result<DriveResult, EngineError> {
let mut sse_stream = self.transport.message_stream(&message).await?;
let _ = event_tx
.send(ProtocolEvent {
direction: Direction::Incoming,
method: "message/stream".to_string(),
content: json!({"status": "connected"}),
})
.await;
let stream_timeout = DEFAULT_STREAM_TIMEOUT;
let mut received_final = false;
loop {
tokio::select! {
result = tokio::time::timeout(stream_timeout, sse_stream.next_event()) => {
match result {
Ok(Ok(Some(sse_result))) => {
let event_type = detect_event_type(&sse_result);
let qualified = resolve_status_qualifier(event_type, &sse_result);
if self.transport.context_id.is_none()
&& let Some(ctx) = sse_result.get("contextId").and_then(Value::as_str)
{
self.transport.context_id = Some(ctx.to_string());
}
let _ = event_tx.send(ProtocolEvent {
direction: Direction::Incoming,
method: qualified,
content: sse_result.clone(),
}).await;
if sse_result
.get("final")
.and_then(Value::as_bool)
.unwrap_or(false)
{
received_final = true;
break;
}
}
Ok(Ok(None)) => {
tracing::warn!("A2A SSE stream closed without final event");
break;
}
Ok(Err(e)) => {
tracing::warn!("A2A SSE parse error: {e}");
if sse_stream.parser.consecutive_errors() >= MAX_CONSECUTIVE_SSE_ERRORS {
tracing::warn!(
"closing A2A connection after {} consecutive parse errors",
MAX_CONSECUTIVE_SSE_ERRORS
);
break;
}
}
Err(_) => {
tracing::warn!(?stream_timeout, "A2A stream timed out");
break;
}
}
}
() = cancel.cancelled() => {
break;
}
}
}
if received_final {
Ok(DriveResult::Complete)
} else {
Ok(DriveResult::TransportClosed)
}
}
}
/// Creates an A2A client driver that talks to `endpoint` and attaches
/// `headers` to every HTTP request.
///
/// `raw_synthesize` is stored on the returned driver as-is.
#[must_use]
pub fn create_a2a_client_driver(
    endpoint: &str,
    headers: Vec<(String, String)>,
    raw_synthesize: bool,
) -> A2aClientDriver {
    A2aClientDriver {
        raw_synthesize,
        transport: A2aClientTransport::new(endpoint, headers),
    }
}
/// Fuzzing entry point.
///
/// Decodes `bytes` as a single chunk with a fresh `A2aSseParser` and returns
/// every event it yields.
#[cfg(fuzzing)]
#[must_use]
pub fn fuzz_a2a_sse_feed(bytes: &[u8]) -> Vec<Result<serde_json::Value, String>> {
    A2aSseParser::new().feed(bytes)
}
// Unit and property tests for the A2A client driver: SSE decoding, event-type
// mapping, status qualification, request building, and construction of the
// transport and driver.
#[cfg(test)]
mod tests {
    use super::*;

    // A single complete `data:` frame yields the JSON-RPC `result` object.
    #[test]
    fn parse_basic_data_event() {
        let mut parser = A2aSseParser::new();
        let input = b"data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"kind\":\"task\",\"id\":\"t1\"}}\n\n";
        let events = parser.feed(input);
        assert_eq!(events.len(), 1);
        let event = events[0].as_ref().unwrap();
        assert_eq!(event["kind"], "task");
        assert_eq!(event["id"], "t1");
    }

    // Non-JSON data comes back as an `Err` entry (not dropped) and counts as
    // one consecutive error.
    #[test]
    fn parse_malformed_json_skipped() {
        let mut parser = A2aSseParser::new();
        let input = b"data: not json\n\n";
        let events = parser.feed(input);
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());
        assert_eq!(parser.consecutive_errors(), 1);
    }

    // Errors accumulate across feeds; one good event resets the count to 0.
    #[test]
    fn parse_consecutive_errors_tracked() {
        let mut parser = A2aSseParser::new();
        for _ in 0..5 {
            parser.feed(b"data: bad\n\n");
        }
        assert_eq!(parser.consecutive_errors(), 5);
        parser.feed(b"data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"ok\":true}}\n\n");
        assert_eq!(parser.consecutive_errors(), 0);
    }

    // Two frames in one chunk produce two successful events.
    #[test]
    fn parse_multiple_events_one_chunk() {
        let mut parser = A2aSseParser::new();
        let input = b"data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"kind\":\"task\"}}\n\ndata: {\"jsonrpc\":\"2.0\",\"id\":\"2\",\"result\":{\"kind\":\"status-update\"}}\n\n";
        let events = parser.feed(input);
        assert_eq!(events.len(), 2);
        assert!(events[0].is_ok());
        assert!(events[1].is_ok());
    }

    // A frame split across two feeds yields nothing until it is complete.
    #[test]
    fn parse_incremental_chunks() {
        let mut parser = A2aSseParser::new();
        let events1 = parser.feed(b"data: {\"res");
        assert!(events1.is_empty());
        let events2 =
            parser.feed(b"ult\":{\"kind\":\"task\"},\"jsonrpc\":\"2.0\",\"id\":\"1\"}\n\n");
        assert_eq!(events2.len(), 1);
    }

    // SSE comment lines (leading ':') do not produce events.
    #[test]
    fn parse_sse_comment_ignored() {
        let mut parser = A2aSseParser::new();
        let input =
            b": comment\ndata: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"ok\":true}}\n\n";
        let events = parser.feed(input);
        assert_eq!(events.len(), 1);
        assert!(events[0].is_ok());
    }

    // Blank lines with no data produce no events.
    #[test]
    fn parse_empty_data_ignored() {
        let mut parser = A2aSseParser::new();
        let input = b"\n\n";
        let events = parser.feed(input);
        assert!(events.is_empty());
    }

    // Only the JSON-RPC `result` member is returned, with its fields intact.
    #[test]
    fn parse_extracts_result_field() {
        let mut parser = A2aSseParser::new();
        let input = b"data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"kind\":\"status-update\",\"taskId\":\"t1\",\"status\":{\"state\":\"completed\"},\"final\":true}}\n\n";
        let events = parser.feed(input);
        let event = events[0].as_ref().unwrap();
        assert_eq!(event["kind"], "status-update");
        assert_eq!(event["taskId"], "t1");
        assert_eq!(event["final"], true);
    }

    // `kind` -> event-name mapping, including the "unknown" fallbacks.
    #[test]
    fn detect_task_created() {
        let result = json!({"kind": "task", "id": "t1"});
        assert_eq!(detect_event_type(&result), "task/created");
    }

    #[test]
    fn detect_message_response() {
        let result = json!({"kind": "message", "role": "agent"});
        assert_eq!(detect_event_type(&result), "message/response");
    }

    #[test]
    fn detect_status_update() {
        let result = json!({"kind": "status-update", "status": {"state": "working"}});
        assert_eq!(detect_event_type(&result), "task/status");
    }

    #[test]
    fn detect_artifact_update() {
        let result = json!({"kind": "artifact-update", "artifact": {}});
        assert_eq!(detect_event_type(&result), "task/artifact");
    }

    #[test]
    fn detect_unknown_kind() {
        let result = json!({"kind": "future-type"});
        assert_eq!(detect_event_type(&result), "unknown");
    }

    #[test]
    fn detect_missing_kind() {
        let result = json!({"data": "no kind"});
        assert_eq!(detect_event_type(&result), "unknown");
    }

    // `task/status` events are suffixed with `status.state`; other event
    // types pass through unchanged.
    #[test]
    fn status_qualifier_resolution() {
        let result =
            json!({"kind": "status-update", "status": {"state": "completed"}, "final": true});
        assert_eq!(
            resolve_status_qualifier("task/status", &result),
            "task/status:completed"
        );
    }

    #[test]
    fn status_qualifier_input_required() {
        let result = json!({"kind": "status-update", "status": {"state": "input-required"}});
        assert_eq!(
            resolve_status_qualifier("task/status", &result),
            "task/status:input-required"
        );
    }

    #[test]
    fn status_qualifier_auth_required() {
        let result = json!({"kind": "status-update", "status": {"state": "auth-required"}});
        assert_eq!(
            resolve_status_qualifier("task/status", &result),
            "task/status:auth-required"
        );
    }

    #[test]
    fn status_qualifier_rejected() {
        let result = json!({"kind": "status-update", "status": {"state": "rejected"}});
        assert_eq!(
            resolve_status_qualifier("task/status", &result),
            "task/status:rejected"
        );
    }

    #[test]
    fn status_qualifier_non_status_event() {
        let result = json!({"kind": "task"});
        assert_eq!(
            resolve_status_qualifier("task/created", &result),
            "task/created"
        );
    }

    // Non-streaming request: JSON-RPC envelope plus defaulted `kind` and
    // generated `messageId`.
    #[test]
    fn build_task_message_from_state() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}]
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, false).unwrap();
        assert_eq!(msg["jsonrpc"], "2.0");
        assert_eq!(msg["method"], "message/send");
        assert_eq!(msg["params"]["message"]["role"], "user");
        assert_eq!(msg["params"]["message"]["kind"], "message");
        assert!(msg["params"]["message"]["messageId"].is_string());
    }

    // `streaming = true` selects the `message/stream` method.
    #[test]
    fn build_task_message_streaming() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Stream this"}]
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, true).unwrap();
        assert_eq!(msg["method"], "message/stream");
    }

    // A state without `task_message` is an error that names the missing key.
    #[test]
    fn missing_task_message_errors() {
        let state = json!({"other_key": "value"});
        let result = build_task_message(&state, &HashMap::new(), None, false);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("task_message"), "got: {err}");
    }

    #[test]
    fn auto_generate_message_id() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "test"}]
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, false).unwrap();
        let message_id = msg["params"]["message"]["messageId"].as_str().unwrap();
        assert!(!message_id.is_empty());
    }

    // A caller-supplied `messageId` is kept as-is.
    #[test]
    fn preserve_existing_message_id() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "test"}],
                "messageId": "custom-id"
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, false).unwrap();
        assert_eq!(msg["params"]["message"]["messageId"], "custom-id");
    }

    // A known context id is injected when the message has none...
    #[test]
    fn context_id_persistence() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "test"}]
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), Some("ctx-123"), false).unwrap();
        assert_eq!(msg["params"]["message"]["contextId"], "ctx-123");
    }

    // ...but never overrides one written explicitly in the message.
    #[test]
    fn context_id_not_overridden() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [],
                "contextId": "explicit-ctx"
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), Some("auto-ctx"), false).unwrap();
        assert_eq!(msg["params"]["message"]["contextId"], "explicit-ctx");
    }

    // `{{name}}` placeholders in the message are filled from the extractors.
    #[test]
    fn build_with_interpolation() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Use {{tool_name}}"}]
            }
        });
        let mut extractors = HashMap::new();
        extractors.insert("tool_name".to_string(), "calculator".to_string());
        let msg = build_task_message(&state, &extractors, None, false).unwrap();
        let text = msg["params"]["message"]["parts"][0]["text"]
            .as_str()
            .unwrap();
        assert_eq!(text, "Use calculator");
    }

    // `state.configuration` is copied into `params.configuration`.
    #[test]
    fn build_with_configuration() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "test"}]
            },
            "configuration": {
                "acceptedOutputModes": ["text/plain"],
                "historyLength": 0
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, false).unwrap();
        assert!(msg["params"]["configuration"].is_object());
        assert_eq!(msg["params"]["configuration"]["historyLength"], 0);
    }

    // `state.metadata` is copied into `params.metadata`.
    #[test]
    fn build_with_metadata() {
        let state = json!({
            "task_message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "test"}]
            },
            "metadata": {
                "source": "test-harness"
            }
        });
        let msg = build_task_message(&state, &HashMap::new(), None, false).unwrap();
        assert_eq!(msg["params"]["metadata"]["source"], "test-harness");
    }

    // Event naming for the two result kinds a `message/send` call can return.
    #[test]
    fn sync_response_kind_task() {
        let result = json!({"kind": "task", "id": "t1", "status": {"state": "completed"}});
        assert_eq!(detect_event_type(&result), "task/created");
    }

    #[test]
    fn sync_response_kind_message() {
        let result = json!({"kind": "message", "role": "agent", "parts": []});
        assert_eq!(detect_event_type(&result), "message/response");
    }

    // A fresh transport stores the URL verbatim and has no context id.
    #[test]
    fn transport_creation() {
        let transport = A2aClientTransport::new("http://localhost:9090", vec![]);
        assert_eq!(transport.agent_url, "http://localhost:9090");
        assert!(transport.context_id.is_none());
        assert!(transport.headers.is_empty());
    }

    #[test]
    fn transport_with_headers() {
        let headers = vec![("Authorization".to_string(), "Bearer token".to_string())];
        let transport = A2aClientTransport::new("http://localhost:9090", headers);
        assert_eq!(transport.headers.len(), 1);
        assert_eq!(transport.headers[0].0, "Authorization");
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        // One well-formed SSE frame whose JSON-RPC id and task id derive from
        // `id`.
        fn arb_sse_frame() -> impl Strategy<Value = Vec<u8>> {
            (1..=100_i64).prop_map(|id| {
                format!(
                    "data: {{\"jsonrpc\":\"2.0\",\"id\":\"{id}\",\"result\":{{\"kind\":\"task\",\"id\":\"t{id}\"}}}}\n\n"
                )
                .into_bytes()
            })
        }

        // 1-5 concatenated frames, plus a sorted, de-duplicated set of split
        // offsets into the combined byte stream.
        fn arb_sse_stream_with_splits() -> impl Strategy<Value = (Vec<u8>, Vec<usize>)> {
            prop::collection::vec(arb_sse_frame(), 1..6).prop_flat_map(|frames| {
                let stream: Vec<u8> = frames.into_iter().flatten().collect();
                let len = stream.len();
                let splits = prop::collection::vec(0..len, 1..8).prop_map(|mut pts| {
                    pts.sort_unstable();
                    pts.dedup();
                    pts
                });
                (Just(stream), splits)
            })
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(256))]
            // Feeding the stream in arbitrary chunks yields as many successful
            // events as feeding it all at once.
            #[test]
            fn prop_a2a_sse_chunk_independence(
                (stream, splits) in arb_sse_stream_with_splits()
            ) {
                let mut one_shot = A2aSseParser::new();
                let one_shot_ok: Vec<_> = one_shot
                    .feed(&stream)
                    .into_iter()
                    .filter_map(Result::ok)
                    .collect();
                let mut chunked = A2aSseParser::new();
                let mut chunked_ok: Vec<_> = Vec::new();
                let mut prev = 0;
                for &split in &splits {
                    // A zero offset would produce an empty first chunk; skip it.
                    if split > prev {
                        chunked_ok.extend(
                            chunked.feed(&stream[prev..split]).into_iter().filter_map(Result::ok),
                        );
                        prev = split;
                    }
                }
                chunked_ok.extend(
                    chunked.feed(&stream[prev..]).into_iter().filter_map(Result::ok),
                );
                prop_assert_eq!(one_shot_ok.len(), chunked_ok.len(),
                    "chunk independence: one-shot={}, chunked={}",
                    one_shot_ok.len(), chunked_ok.len());
            }

            // Arbitrary bytes never make the parser panic.
            #[test]
            fn prop_a2a_sse_no_panic(data in prop::collection::vec(any::<u8>(), 0..512)) {
                let mut parser = A2aSseParser::new();
                let _ = parser.feed(&data);
            }

            // A JSON-RPC envelope yields its `result`; a bare object without
            // one yields an error.
            #[test]
            fn prop_a2a_result_extraction(id in 1..=1000_i64, has_result in any::<bool>()) {
                let mut parser = A2aSseParser::new();
                let input = if has_result {
                    format!(
                        "data: {{\"jsonrpc\":\"2.0\",\"id\":\"{id}\",\"result\":{{\"kind\":\"task\",\"extracted\":true}}}}\n\n"
                    )
                } else {
                    format!(
                        "data: {{\"kind\":\"task\",\"id\":\"{id}\",\"noResult\":true}}\n\n"
                    )
                };
                let events = parser.feed(input.as_bytes());
                prop_assert_eq!(events.len(), 1);
                if has_result {
                    let val = events[0].as_ref().unwrap();
                    prop_assert_eq!(val.get("extracted").and_then(Value::as_bool), Some(true));
                } else {
                    prop_assert!(events[0].is_err());
                }
            }
        }
    }

    // The factory wires the endpoint, headers and flag through unchanged.
    #[test]
    fn create_driver() {
        let driver = create_a2a_client_driver(
            "http://localhost:9090",
            vec![("Auth".to_string(), "Bearer x".to_string())],
            true,
        );
        assert!(driver.raw_synthesize);
        assert_eq!(driver.transport.agent_url, "http://localhost:9090");
        assert_eq!(driver.transport.headers.len(), 1);
    }
}