use a2a::*;
use a2a_pb::protojson_conv::{self, ProtoJsonPayload};
use async_trait::async_trait;
use futures::stream::{self, BoxStream, StreamExt};
use reqwest::Client;
use std::collections::VecDeque;
use crate::push_config_compat::{
deserialize_list_task_push_notification_configs_response,
deserialize_task_push_notification_config,
serialize_create_task_push_notification_config_request,
};
use crate::transport::{ServiceParams, Transport, TransportFactory};
pub struct JsonRpcTransport {
client: Client,
endpoint: String,
}
impl JsonRpcTransport {
pub fn new(client: Client, endpoint: String) -> Self {
JsonRpcTransport { client, endpoint }
}
async fn call_value_with_payload(
&self,
params: &ServiceParams,
method: &str,
payload: serde_json::Value,
) -> Result<serde_json::Value, A2AError> {
let id = JsonRpcId::String(uuid::Uuid::now_v7().to_string());
let rpc_request = JsonRpcRequest::new(id, method, Some(payload));
let mut builder = self.client.post(&self.endpoint);
for (key, values) in params {
for v in values {
builder = builder.header(key, v);
}
}
let response = builder
.json(&rpc_request)
.send()
.await
.map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))?;
let rpc_response: JsonRpcResponse = response
.json()
.await
.map_err(|e| A2AError::internal(format!("failed to parse JSON-RPC response: {e}")))?;
if let Some(err) = rpc_response.error {
return Err(A2AError::new(err.code, err.message));
}
rpc_response
.result
.ok_or_else(|| A2AError::internal("JSON-RPC response missing result"))
}
async fn call_value<Req>(
&self,
params: &ServiceParams,
method: &str,
request_params: &Req,
) -> Result<serde_json::Value, A2AError>
where
Req: ProtoJsonPayload,
{
let payload = protojson_conv::to_value(request_params).map_err(|e| {
A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
})?;
self.call_value_with_payload(params, method, payload).await
}
async fn call<Req, Resp>(
&self,
params: &ServiceParams,
method: &str,
request_params: &Req,
) -> Result<Resp, A2AError>
where
Req: ProtoJsonPayload,
Resp: ProtoJsonPayload,
{
let result = self.call_value(params, method, request_params).await?;
protojson_conv::from_value(result)
.map_err(|e| A2AError::internal(format!("failed to deserialize result: {e}")))
}
async fn call_streaming<Req>(
&self,
params: &ServiceParams,
method: &str,
request_params: &Req,
) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError>
where
Req: ProtoJsonPayload,
{
let id = JsonRpcId::String(uuid::Uuid::now_v7().to_string());
let payload = protojson_conv::to_value(request_params).map_err(|e| {
A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
})?;
let rpc_request = JsonRpcRequest::new(id, method, Some(payload));
let mut builder = self
.client
.post(&self.endpoint)
.header("Accept", "text/event-stream");
for (key, values) in params {
for v in values {
builder = builder.header(key, v);
}
}
let response = builder
.json(&rpc_request)
.send()
.await
.map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))?;
let stream = response.bytes_stream();
let event_stream = parse_sse_stream(stream);
Ok(event_stream)
}
}
fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> {
for i in 0..buf.len().saturating_sub(1) {
if buf[i] == b'\n' && buf[i + 1] == b'\n' {
return Some((i, i + 2));
}
if buf[i] == b'\r' && buf[i + 1] == b'\r' {
return Some((i, i + 2));
}
if i + 3 < buf.len() && &buf[i..i + 4] == b"\r\n\r\n" {
return Some((i, i + 4));
}
}
None
}
fn parse_sse_bytes<F>(
stream: impl futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
parse_event: F,
) -> BoxStream<'static, Result<StreamResponse, A2AError>>
where
F: Fn(&str) -> Option<Result<StreamResponse, A2AError>> + Clone + Send + 'static,
{
let mapped = stream::unfold(
(
Box::pin(stream),
Vec::<u8>::new(),
VecDeque::new(),
parse_event,
),
|(mut stream, mut buf, mut pending, parse_event)| async move {
loop {
if let Some(item) = pending.pop_front() {
return Some((item, (stream, buf, pending, parse_event)));
}
match stream.next().await {
Some(Ok(chunk)) => {
buf.extend_from_slice(&chunk);
while let Some((start, end)) = find_event_boundary(&buf) {
let event_bytes: Vec<u8> = buf.drain(..end).collect();
let event_text = match std::str::from_utf8(&event_bytes[..start]) {
Ok(s) => s,
Err(e) => {
pending.push_back(Err(A2AError::internal(format!(
"SSE UTF-8 decode error: {e}"
))));
continue;
}
};
if let Some(result) = parse_event(event_text) {
pending.push_back(result);
}
}
}
Some(Err(e)) => {
pending
.push_back(Err(A2AError::internal(format!("SSE stream error: {e}"))));
}
None => return None,
}
}
},
);
Box::pin(mapped)
}
fn parse_sse_stream(
stream: impl futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
parse_sse_bytes(stream, |event_text| {
let mut data = String::new();
for line in event_text.lines() {
if let Some(d) = line.strip_prefix("data: ") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(d);
} else if let Some(d) = line.strip_prefix("data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(d);
}
}
if data.is_empty() {
return None;
}
match serde_json::from_str::<JsonRpcResponse>(&data) {
Ok(rpc_resp) => {
if let Some(err) = rpc_resp.error {
return Some(Err(A2AError::new(err.code, err.message)));
}
if let Some(result) = rpc_resp.result {
match protojson_conv::from_value::<StreamResponse>(result) {
Ok(sr) => Some(Ok(sr)),
Err(e) => Some(Err(A2AError::internal(format!("SSE parse error: {e}")))),
}
} else {
None
}
}
Err(_) => {
match protojson_conv::from_str::<StreamResponse>(&data) {
Ok(sr) => Some(Ok(sr)),
Err(e) => Some(Err(A2AError::internal(format!("SSE parse error: {e}")))),
}
}
}
})
}
pub(crate) fn parse_sse_stream_rest(
stream: impl futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
parse_sse_bytes(stream, |event_text| {
let mut data = String::new();
for line in event_text.lines() {
if let Some(d) = line.strip_prefix("data: ") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(d);
} else if let Some(d) = line.strip_prefix("data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(d);
}
}
if data.is_empty() {
return None;
}
match protojson_conv::from_str::<StreamResponse>(&data) {
Ok(sr) => Some(Ok(sr)),
Err(e) => Some(Err(A2AError::internal(format!("SSE parse error: {e}")))),
}
})
}
#[async_trait]
impl Transport for JsonRpcTransport {
async fn send_message(
&self,
params: &ServiceParams,
req: &SendMessageRequest,
) -> Result<SendMessageResponse, A2AError> {
self.call(params, methods::SEND_MESSAGE, req).await
}
async fn send_streaming_message(
&self,
params: &ServiceParams,
req: &SendMessageRequest,
) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
self.call_streaming(params, methods::SEND_STREAMING_MESSAGE, req)
.await
}
async fn get_task(
&self,
params: &ServiceParams,
req: &GetTaskRequest,
) -> Result<Task, A2AError> {
self.call(params, methods::GET_TASK, req).await
}
async fn list_tasks(
&self,
params: &ServiceParams,
req: &ListTasksRequest,
) -> Result<ListTasksResponse, A2AError> {
self.call(params, methods::LIST_TASKS, req).await
}
async fn cancel_task(
&self,
params: &ServiceParams,
req: &CancelTaskRequest,
) -> Result<Task, A2AError> {
self.call(params, methods::CANCEL_TASK, req).await
}
async fn subscribe_to_task(
&self,
params: &ServiceParams,
req: &SubscribeToTaskRequest,
) -> Result<BoxStream<'static, Result<StreamResponse, A2AError>>, A2AError> {
self.call_streaming(params, methods::SUBSCRIBE_TO_TASK, req)
.await
}
async fn create_push_config(
&self,
params: &ServiceParams,
req: &CreateTaskPushNotificationConfigRequest,
) -> Result<TaskPushNotificationConfig, A2AError> {
let payload = serialize_create_task_push_notification_config_request(req)?;
let result = self
.call_value_with_payload(params, methods::CREATE_PUSH_CONFIG, payload)
.await?;
deserialize_task_push_notification_config(result)
}
async fn get_push_config(
&self,
params: &ServiceParams,
req: &GetTaskPushNotificationConfigRequest,
) -> Result<TaskPushNotificationConfig, A2AError> {
let result = self
.call_value(params, methods::GET_PUSH_CONFIG, req)
.await?;
deserialize_task_push_notification_config(result)
}
async fn list_push_configs(
&self,
params: &ServiceParams,
req: &ListTaskPushNotificationConfigsRequest,
) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
let result = self
.call_value(params, methods::LIST_PUSH_CONFIGS, req)
.await?;
deserialize_list_task_push_notification_configs_response(result)
}
async fn delete_push_config(
&self,
params: &ServiceParams,
req: &DeleteTaskPushNotificationConfigRequest,
) -> Result<(), A2AError> {
let id = JsonRpcId::String(uuid::Uuid::now_v7().to_string());
let request_params = protojson_conv::to_value(req).map_err(|e| {
A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
})?;
let rpc_request =
JsonRpcRequest::new(id, methods::DELETE_PUSH_CONFIG, Some(request_params));
let mut builder = self.client.post(&self.endpoint);
for (key, values) in params {
for v in values {
builder = builder.header(key, v);
}
}
let response = builder
.json(&rpc_request)
.send()
.await
.map_err(|e| A2AError::internal(format!("HTTP request failed: {e}")))?;
let rpc_response: JsonRpcResponse = response
.json()
.await
.map_err(|e| A2AError::internal(format!("failed to parse JSON-RPC response: {e}")))?;
if let Some(err) = rpc_response.error {
return Err(A2AError::new(err.code, err.message));
}
Ok(())
}
async fn get_extended_agent_card(
&self,
params: &ServiceParams,
req: &GetExtendedAgentCardRequest,
) -> Result<AgentCard, A2AError> {
self.call(params, methods::GET_EXTENDED_AGENT_CARD, req)
.await
}
async fn destroy(&self) -> Result<(), A2AError> {
Ok(())
}
}
pub struct JsonRpcTransportFactory {
client: Client,
}
impl JsonRpcTransportFactory {
pub fn new(client: Option<Client>) -> Self {
JsonRpcTransportFactory {
client: client.unwrap_or_default(),
}
}
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
pub fn with_root_certificates_pem(pem: &[u8]) -> Result<Self, A2AError> {
Ok(Self {
client: crate::build_reqwest_client_with_root_pem(pem)?,
})
}
}
#[async_trait]
impl TransportFactory for JsonRpcTransportFactory {
fn protocol(&self) -> &str {
TRANSPORT_PROTOCOL_JSONRPC
}
async fn create(
&self,
_card: &AgentCard,
iface: &AgentInterface,
) -> Result<Box<dyn Transport>, A2AError> {
Ok(Box::new(JsonRpcTransport::new(
self.client.clone(),
iface.url.clone(),
)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::rcgen_self_signed_ca_pem;
use a2a_pb::protojson_conv;
use futures::StreamExt;
use serde_json::{Value, json};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
fn byte_stream(
chunks: Vec<String>,
) -> impl futures::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static {
stream::iter(
chunks
.into_iter()
.map(|s| Ok(bytes::Bytes::from(s)))
.collect::<Vec<_>>(),
)
}
async fn spawn_jsonrpc_server(response_body: String) -> (String, oneshot::Receiver<String>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (request_tx, request_rx) = oneshot::channel();
tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let request = read_http_request(&mut socket).await;
let response = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
response_body.len(),
response_body,
);
let _ = request_tx.send(request);
socket.write_all(response.as_bytes()).await.unwrap();
});
(format!("http://{addr}"), request_rx)
}
async fn read_http_request(socket: &mut TcpStream) -> String {
let mut buffer = Vec::new();
let mut chunk = [0_u8; 1024];
let mut expected_len = None;
loop {
let read = socket.read(&mut chunk).await.unwrap();
if read == 0 {
break;
}
buffer.extend_from_slice(&chunk[..read]);
if expected_len.is_none() {
if let Some(header_end) = find_header_end(&buffer) {
let headers = String::from_utf8_lossy(&buffer[..header_end]);
expected_len = Some(header_end + parse_content_length(&headers));
}
}
if let Some(total_len) = expected_len {
if buffer.len() >= total_len {
break;
}
}
}
String::from_utf8(buffer).unwrap()
}
fn find_header_end(buffer: &[u8]) -> Option<usize> {
buffer
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|position| position + 4)
}
fn parse_content_length(headers: &str) -> usize {
headers
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().ok())
.flatten()
})
.unwrap_or(0)
}
fn sample_create_push_config_request() -> CreateTaskPushNotificationConfigRequest {
CreateTaskPushNotificationConfigRequest {
task_id: "task-1".into(),
config: PushNotificationConfig {
url: "https://example.invalid/webhook".into(),
id: Some("cfg-1".into()),
token: Some("secret-token".into()),
authentication: Some(AuthenticationInfo {
scheme: "Bearer".into(),
credentials: Some("credential".into()),
}),
},
tenant: Some("tenant-1".into()),
}
}
#[tokio::test]
async fn test_parse_sse_stream_jsonrpc_envelope() {
let status_update = TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Working,
message: None,
timestamp: None,
},
metadata: None,
};
let sr = StreamResponse::StatusUpdate(status_update);
let result_val = protojson_conv::to_value(&sr).unwrap();
let rpc_resp = JsonRpcResponse::success(JsonRpcId::Number(1), result_val);
let data = serde_json::to_string(&rpc_resp).unwrap();
let sse_text = format!("data: {}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
assert!(parsed.next().await.is_none());
}
#[tokio::test]
async fn test_parse_sse_stream_jsonrpc_error() {
let rpc_resp = JsonRpcResponse::error(
JsonRpcId::Number(1),
JsonRpcError {
code: -32600,
message: "bad request".into(),
data: None,
},
);
let data = serde_json::to_string(&rpc_resp).unwrap();
let sse_text = format!("data: {}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream(stream);
let item = parsed.next().await.unwrap();
assert!(item.is_err());
}
#[tokio::test]
async fn test_parse_sse_stream_direct_stream_response() {
let status_update = TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Completed,
message: None,
timestamp: None,
},
metadata: None,
};
let sr = StreamResponse::StatusUpdate(status_update);
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("data: {}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[tokio::test]
async fn test_parse_sse_stream_rest_ok() {
let status_update = TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Working,
message: None,
timestamp: None,
},
metadata: None,
};
let sr = StreamResponse::StatusUpdate(status_update);
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("data: {}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[tokio::test]
async fn test_parse_sse_stream_rest_parse_error() {
let sse_text = "data: not-valid-json\n\n".to_string();
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap();
assert!(item.is_err());
}
#[tokio::test]
async fn test_parse_sse_stream_empty_data_lines_skipped() {
let sr = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Completed,
message: None,
timestamp: None,
},
metadata: None,
});
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("event: ping\n\ndata: {}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[tokio::test]
async fn test_parse_sse_stream_chunked_delivery() {
let sr = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Completed,
message: None,
timestamp: None,
},
metadata: None,
});
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let full = format!("data: {}\n\n", data);
let mid = full.len() / 2;
let chunk1 = full[..mid].to_string();
let chunk2 = full[mid..].to_string();
let stream = byte_stream(vec![chunk1, chunk2]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[tokio::test]
async fn test_parse_sse_stream_data_no_space() {
let sr = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Completed,
message: None,
timestamp: None,
},
metadata: None,
});
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("data:{}\n\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[tokio::test]
async fn test_parse_sse_stream_multiple_events_separate_chunks() {
let make_sr = |state| {
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state,
message: None,
timestamp: None,
},
metadata: None,
})
};
let sr1 =
serde_json::to_string(&protojson_conv::to_value(&make_sr(TaskState::Working)).unwrap())
.unwrap();
let sr2 = serde_json::to_string(
&protojson_conv::to_value(&make_sr(TaskState::Completed)).unwrap(),
)
.unwrap();
let chunk1 = format!("data: {}\n\n", sr1);
let chunk2 = format!("data: {}\n\n", sr2);
let stream = byte_stream(vec![chunk1, chunk2]);
let items: Vec<_> = parse_sse_stream_rest(stream).collect().await;
assert_eq!(items.len(), 2);
assert!(items[0].is_ok());
assert!(items[1].is_ok());
}
#[tokio::test]
async fn test_parse_sse_stream_multiple_events_single_chunk() {
let make_status = |state| {
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state,
message: None,
timestamp: None,
},
metadata: None,
})
};
let rpc_payload = |sr: StreamResponse, id: i64| -> String {
let result = protojson_conv::to_value(&sr).unwrap();
serde_json::to_string(&JsonRpcResponse::success(JsonRpcId::Number(id), result)).unwrap()
};
let combined = format!(
"data: {}\n\ndata: {}\n\n",
rpc_payload(make_status(TaskState::Working), 1),
rpc_payload(make_status(TaskState::Completed), 2),
);
let stream = byte_stream(vec![combined]);
let items: Vec<_> = parse_sse_stream(stream).collect().await;
assert_eq!(
items.len(),
2,
"both events must be delivered even when packed into a single \
byte chunk followed by end-of-stream; got {items:?}",
);
assert!(items[0].is_ok());
assert!(items[1].is_ok());
}
#[tokio::test]
async fn test_parse_sse_stream_rest_multiple_events_single_chunk() {
let make_status = |state| {
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state,
message: None,
timestamp: None,
},
metadata: None,
})
};
let encode = |sr: StreamResponse| -> String {
serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap()
};
let combined = format!(
"data: {}\n\ndata: {}\n\n",
encode(make_status(TaskState::Working)),
encode(make_status(TaskState::Completed)),
);
let stream = byte_stream(vec![combined]);
let items: Vec<_> = parse_sse_stream_rest(stream).collect().await;
assert_eq!(
items.len(),
2,
"both events must be delivered even when packed into a single \
byte chunk followed by end-of-stream; got {items:?}",
);
assert!(items[0].is_ok());
assert!(items[1].is_ok());
}
#[tokio::test]
async fn test_parse_sse_stream_split_multibyte_utf8() {
let message = Message {
message_id: "m-1".into(),
role: Role::Agent,
parts: vec![Part::text("ä¸")],
context_id: None,
task_id: Some("t1".into()),
metadata: None,
extensions: None,
reference_task_ids: None,
};
let status_update = TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Working,
message: Some(message),
timestamp: None,
},
metadata: None,
};
let sr = StreamResponse::StatusUpdate(status_update);
let result = protojson_conv::to_value(&sr).unwrap();
let rpc_resp = JsonRpcResponse::success(JsonRpcId::Number(1), result);
let data = serde_json::to_string(&rpc_resp).unwrap();
let sse_text = format!("data: {data}\n\n");
let bytes = sse_text.as_bytes();
let zhong_start = bytes
.windows(3)
.position(|w| w == [0xE4, 0xB8, 0xAD])
.expect("payload must contain the UTF-8 encoding of 'ä¸'");
let split = zhong_start + 1;
let chunk1 = String::from_utf8_lossy(&bytes[..split]).into_owned();
let chunk2 = String::from_utf8_lossy(&bytes[split..]).into_owned();
assert!(
chunk1.ends_with(char::REPLACEMENT_CHARACTER)
|| chunk2.starts_with(char::REPLACEMENT_CHARACTER),
"test setup should split a multi-byte char so naive decode corrupts it",
);
let raw1 = bytes[..split].to_vec();
let raw2 = bytes[split..].to_vec();
let raw_stream = stream::iter(vec![
Ok::<_, reqwest::Error>(bytes::Bytes::from(raw1)),
Ok::<_, reqwest::Error>(bytes::Bytes::from(raw2)),
]);
let items: Vec<_> = parse_sse_stream(raw_stream).collect().await;
assert_eq!(items.len(), 1, "expected one event, got {items:?}");
let event = items.into_iter().next().unwrap().expect("decoded");
match event {
StreamResponse::StatusUpdate(su) => {
let msg = su.status.message.expect("message present");
let text = match &msg.parts[0].content {
PartContent::Text(t) => t.clone(),
other => panic!("expected text part, got {other:?}"),
};
assert_eq!(text, "ä¸", "multi-byte char must survive chunk split");
}
other => panic!("expected StatusUpdate, got {other:?}"),
}
}
#[tokio::test]
async fn test_parse_sse_stream_invalid_utf8_surfaces_error() {
let mut bytes = b"data: ".to_vec();
bytes.push(0xFF); bytes.extend_from_slice(b"\n\n");
let raw_stream = stream::iter(vec![Ok::<_, reqwest::Error>(bytes::Bytes::from(bytes))]);
let items: Vec<_> = parse_sse_stream(raw_stream).collect().await;
assert_eq!(items.len(), 1, "expected one error item, got {items:?}");
let err = items.into_iter().next().unwrap().expect_err("should error");
assert!(
err.to_string().contains("UTF-8"),
"error should mention UTF-8 decoding: {err}",
);
}
#[tokio::test]
async fn test_parse_sse_stream_invalid_protojson_result_surfaces_error() {
let rpc_resp = JsonRpcResponse::success(
JsonRpcId::Number(1),
serde_json::json!({"not_a_stream_response_field": 42}),
);
let data = serde_json::to_string(&rpc_resp).unwrap();
let sse_text = format!("data: {data}\n\n");
let stream = byte_stream(vec![sse_text]);
let items: Vec<_> = parse_sse_stream(stream).collect().await;
assert_eq!(items.len(), 1, "expected one error item, got {items:?}");
let err = items.into_iter().next().unwrap().expect_err("should error");
assert!(
err.to_string().contains("parse error"),
"error should mention parse error: {err}",
);
}
#[tokio::test]
async fn test_parse_sse_stream_direct_invalid_protojson_surfaces_error() {
let sse_text = "data: {\"foo\":\"bar\"}\n\n".to_string();
let stream = byte_stream(vec![sse_text]);
let items: Vec<_> = parse_sse_stream(stream).collect().await;
assert_eq!(items.len(), 1, "expected one error item, got {items:?}");
let err = items.into_iter().next().unwrap().expect_err("should error");
assert!(
err.to_string().contains("parse error"),
"error should mention parse error: {err}",
);
}
#[tokio::test]
async fn test_parse_sse_stream_surfaces_stream_errors() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let err = reqwest::get(format!("http://{addr}/"))
.await
.expect_err("connection should fail");
let s = stream::iter(vec![Err::<bytes::Bytes, reqwest::Error>(err)]);
let items: Vec<_> = parse_sse_stream(s).collect().await;
assert_eq!(items.len(), 1, "expected one error item, got {items:?}");
let got = items.into_iter().next().unwrap().expect_err("should error");
assert!(
got.to_string().contains("SSE stream error"),
"error should mention SSE stream error: {got}",
);
}
#[tokio::test]
async fn test_parse_sse_stream_rest_surfaces_stream_errors() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let err = reqwest::get(format!("http://{addr}/"))
.await
.expect_err("connection should fail");
let s = stream::iter(vec![Err::<bytes::Bytes, reqwest::Error>(err)]);
let items: Vec<_> = parse_sse_stream_rest(s).collect().await;
assert_eq!(items.len(), 1, "expected one error item, got {items:?}");
let got = items.into_iter().next().unwrap().expect_err("should error");
assert!(
got.to_string().contains("SSE stream error"),
"error should mention SSE stream error: {got}",
);
}
#[tokio::test]
async fn test_parse_sse_stream_crlf_line_endings() {
let sr = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Working,
message: None,
timestamp: None,
},
metadata: None,
});
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("data: {}\r\n\r\n", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
assert!(parsed.next().await.is_none());
}
#[tokio::test]
async fn test_parse_sse_stream_cr_line_endings() {
let sr = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: "t1".into(),
context_id: "c1".into(),
status: TaskStatus {
state: TaskState::Completed,
message: None,
timestamp: None,
},
metadata: None,
});
let data = serde_json::to_string(&protojson_conv::to_value(&sr).unwrap()).unwrap();
let sse_text = format!("data: {}\r\r", data);
let stream = byte_stream(vec![sse_text]);
let mut parsed = parse_sse_stream_rest(stream);
let item = parsed.next().await.unwrap().unwrap();
assert!(matches!(item, StreamResponse::StatusUpdate(_)));
}
#[test]
fn test_find_event_boundary_lf() {
assert_eq!(find_event_boundary(b"foo\n\nbar"), Some((3, 5)));
}
#[test]
fn test_find_event_boundary_cr() {
assert_eq!(find_event_boundary(b"foo\r\rbar"), Some((3, 5)));
}
#[test]
fn test_find_event_boundary_crlf() {
assert_eq!(find_event_boundary(b"foo\r\n\r\nbar"), Some((3, 7)));
}
#[test]
fn test_find_event_boundary_none() {
assert!(find_event_boundary(b"no boundary here").is_none());
}
#[test]
fn test_jsonrpc_transport_new() {
let t = JsonRpcTransport::new(Client::new(), "http://localhost:8080".into());
assert_eq!(t.endpoint, "http://localhost:8080");
}
#[test]
fn test_jsonrpc_transport_factory() {
let f = JsonRpcTransportFactory::new(None);
assert_eq!(f.protocol(), "JSONRPC");
}
#[tokio::test]
async fn test_jsonrpc_transport_factory_create() {
let f = JsonRpcTransportFactory::new(None);
let card = AgentCard {
name: "Test".into(),
description: "Test".into(),
version: "1.0".into(),
supported_interfaces: vec![],
capabilities: AgentCapabilities::default(),
default_input_modes: vec!["text/plain".into()],
default_output_modes: vec!["text/plain".into()],
skills: vec![],
provider: None,
documentation_url: None,
icon_url: None,
security_schemes: None,
security_requirements: None,
signatures: None,
};
let iface = AgentInterface::new("http://localhost:8080", "JSONRPC");
let transport = f.create(&card, &iface).await.unwrap();
transport.destroy().await.unwrap();
}
#[tokio::test]
async fn test_create_push_config_sends_nested_request_shape() {
let response = json!({
"jsonrpc": "2.0",
"id": "1",
"result": {
"taskId": "task-1",
"config": {
"url": "https://example.invalid/webhook",
"id": "cfg-1",
"token": "secret-token",
"authentication": {
"scheme": "Bearer",
"credentials": "credential"
}
},
"tenant": "tenant-1"
}
})
.to_string();
let (endpoint, request_rx) = spawn_jsonrpc_server(response).await;
let transport = JsonRpcTransport::new(Client::new(), endpoint);
let mut params = ServiceParams::new();
params.insert("x-trace".into(), vec!["alpha".into(), "beta".into()]);
let result = transport
.create_push_config(¶ms, &sample_create_push_config_request())
.await
.unwrap();
assert_eq!(result.task_id, "task-1");
assert_eq!(result.config.id.as_deref(), Some("cfg-1"));
let request = request_rx.await.unwrap();
let request_lower = request.to_ascii_lowercase();
assert!(request_lower.contains("x-trace: alpha"));
assert!(request_lower.contains("x-trace: beta"));
let body = request.split("\r\n\r\n").nth(1).unwrap();
let payload: Value = serde_json::from_str(body).unwrap();
assert_eq!(payload["method"], methods::CREATE_PUSH_CONFIG);
assert_eq!(
payload["params"],
json!({
"taskId": "task-1",
"config": {
"url": "https://example.invalid/webhook",
"id": "cfg-1",
"token": "secret-token",
"authentication": {
"scheme": "Bearer",
"credentials": "credential"
}
},
"tenant": "tenant-1"
})
);
}
#[tokio::test]
async fn test_create_push_config_surfaces_jsonrpc_error() {
let response = json!({
"jsonrpc": "2.0",
"id": "1",
"error": {
"code": error_code::INVALID_PARAMS,
"message": "invalid params",
"data": null
}
})
.to_string();
let (endpoint, _request_rx) = spawn_jsonrpc_server(response).await;
let transport = JsonRpcTransport::new(Client::new(), endpoint);
let error = transport
.create_push_config(&ServiceParams::new(), &sample_create_push_config_request())
.await
.unwrap_err();
assert_eq!(error.code, error_code::INVALID_PARAMS);
assert_eq!(error.message, "invalid params");
}
#[tokio::test]
async fn test_create_push_config_rejects_missing_result() {
let response = json!({
"jsonrpc": "2.0",
"id": "1"
})
.to_string();
let (endpoint, _request_rx) = spawn_jsonrpc_server(response).await;
let transport = JsonRpcTransport::new(Client::new(), endpoint);
let error = transport
.create_push_config(&ServiceParams::new(), &sample_create_push_config_request())
.await
.unwrap_err();
assert_eq!(error.code, error_code::INTERNAL_ERROR);
assert_eq!(error.message, "JSON-RPC response missing result");
}
#[tokio::test]
async fn test_get_push_config_uses_protojson_request_path() {
let response = json!({
"jsonrpc": "2.0",
"id": "1",
"result": {
"taskId": "task-1",
"config": {
"url": "https://example.invalid/webhook",
"id": "cfg-1",
"token": "secret-token"
},
"tenant": "tenant-1"
}
})
.to_string();
let (endpoint, request_rx) = spawn_jsonrpc_server(response).await;
let transport = JsonRpcTransport::new(Client::new(), endpoint);
let result = transport
.get_push_config(
&ServiceParams::new(),
&GetTaskPushNotificationConfigRequest {
task_id: "task-1".into(),
id: "cfg-1".into(),
tenant: Some("tenant-1".into()),
},
)
.await
.unwrap();
assert_eq!(result.task_id, "task-1");
assert_eq!(result.config.id.as_deref(), Some("cfg-1"));
let request = request_rx.await.unwrap();
let body = request.split("\r\n\r\n").nth(1).unwrap();
let payload: Value = serde_json::from_str(body).unwrap();
assert_eq!(payload["method"], methods::GET_PUSH_CONFIG);
assert_eq!(
payload["params"],
json!({
"taskId": "task-1",
"id": "cfg-1",
"tenant": "tenant-1"
})
);
}
#[test]
fn test_with_root_certificates_pem_valid() {
let pem = rcgen_self_signed_ca_pem();
let f = JsonRpcTransportFactory::with_root_certificates_pem(&pem).unwrap();
assert_eq!(f.protocol(), TRANSPORT_PROTOCOL_JSONRPC);
}
#[tokio::test]
async fn test_with_root_certificates_pem_factory_create() {
let pem = rcgen_self_signed_ca_pem();
let f = JsonRpcTransportFactory::with_root_certificates_pem(&pem).unwrap();
let card = AgentCard {
name: "Test".into(),
description: "Test".into(),
version: "1.0".into(),
supported_interfaces: vec![],
capabilities: AgentCapabilities::default(),
default_input_modes: vec!["text/plain".into()],
default_output_modes: vec!["text/plain".into()],
skills: vec![],
provider: None,
documentation_url: None,
icon_url: None,
security_schemes: None,
security_requirements: None,
signatures: None,
};
let iface = AgentInterface::new("https://localhost:3443/jsonrpc", "JSONRPC");
let transport = f.create(&card, &iface).await.unwrap();
transport.destroy().await.unwrap();
}
}