use agent_client_protocol::{
ConnectionTo, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, Responder, SentRequest,
role::UntypedRole,
};
use expect_test::expect;
use futures::{AsyncRead, AsyncWrite};
use serde::{Deserialize, Serialize};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
async fn recv<T: JsonRpcResponse + Send>(
response: SentRequest<T>,
) -> Result<T, agent_client_protocol::Error> {
let (tx, rx) = tokio::sync::oneshot::channel();
response.on_receiving_result(async move |result| {
tx.send(result)
.map_err(|_| agent_client_protocol::Error::internal_error())
})?;
rx.await
.map_err(|_| agent_client_protocol::Error::internal_error())?
}
/// Builds two in-memory duplex pipes (1024-byte buffers) and returns their
/// ends wrapped in the `futures`-compatible adapters.
///
/// The tuple is ordered `(server_reader, server_writer, client_reader,
/// client_writer)`: what the client writer sends is read by the server
/// reader, and what the server writer sends is read by the client reader.
fn setup_test_streams() -> (
    impl AsyncRead,
    impl AsyncWrite,
    impl AsyncRead,
    impl AsyncWrite,
) {
    // Client -> server direction.
    let (client_tx, server_rx) = tokio::io::duplex(1024);
    // Server -> client direction.
    let (server_tx, client_rx) = tokio::io::duplex(1024);
    (
        server_rx.compat(),
        server_tx.compat_write(),
        client_rx.compat(),
        client_tx.compat_write(),
    )
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SimpleRequest {
message: String,
}
impl JsonRpcMessage for SimpleRequest {
fn matches_method(method: &str) -> bool {
method == "simple_method"
}
fn method(&self) -> &'static str {
"simple_method"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
agent_client_protocol::util::json_cast_params(params)
}
}
impl JsonRpcRequest for SimpleRequest {
type Response = SimpleResponse;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SimpleResponse {
result: String,
}
impl JsonRpcResponse for SimpleResponse {
fn into_json(self, _method: &str) -> Result<serde_json::Value, agent_client_protocol::Error> {
serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error)
}
fn from_value(
_method: &str,
value: serde_json::Value,
) -> Result<Self, agent_client_protocol::Error> {
agent_client_protocol::util::json_cast(&value)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SimpleNotification {
message: String,
}
impl JsonRpcMessage for SimpleNotification {
fn matches_method(method: &str) -> bool {
method == "simple_notification"
}
fn method(&self) -> &'static str {
"simple_notification"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
agent_client_protocol::util::json_cast_params(params)
}
}
impl agent_client_protocol::JsonRpcNotification for SimpleNotification {}
/// Feeds the server a line that is not valid JSON and checks the reply.
///
/// The snapshot pins the expected reply: an error object with code -32700
/// and message "Parse error" whose `data.line` echoes the offending line,
/// with no `id` member.
#[tokio::test(flavor = "current_thread")]
async fn test_invalid_json() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task::LocalSet;
let local = LocalSet::new();
local
.run_until(async {
// The test plays the client by hand: it keeps the raw tokio ends of both
// pipes and only wraps the server ends for the transport.
let (mut client_writer, server_reader) = tokio::io::duplex(1024);
let (server_writer, mut client_reader) = tokio::io::duplex(1024);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let server_transport =
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
let server = UntypedRole.builder();
// Run the server in the background; its final result is discarded.
tokio::task::spawn_local(async move {
drop(server.connect_to(server_transport).await);
});
// `INVALID` is a bare token, which makes the whole line unparseable as JSON.
let invalid_json = b"{\"method\": \"test\", \"id\": 1, INVALID}\n";
client_writer.write_all(invalid_json).await.unwrap();
client_writer.flush().await.unwrap();
// A single read is issued; whatever it returns is parsed as one JSON document.
let mut buffer = vec![0u8; 1024];
let n = client_reader.read(&mut buffer).await.unwrap();
let response_str = String::from_utf8_lossy(&buffer[..n]);
let response: serde_json::Value =
serde_json::from_str(response_str.trim()).expect("Response should be valid JSON");
expect![[r#"
{
"error": {
"code": -32700,
"data": {
"line": "{\"method\": \"test\", \"id\": 1, INVALID}"
},
"message": "Parse error"
},
"jsonrpc": "2.0"
}"#]]
.assert_eq(&serde_json::to_string_pretty(&response).unwrap());
})
.await;
}
/// Connects to a transport whose input ends in the middle of a JSON line
/// (no terminating newline) and checks that the connection terminates.
///
/// Either an `Ok` or an `Err` outcome is acceptable; the only failure mode is
/// not finishing at all. The connection is bounded by a timeout so that, once
/// the test is un-ignored, a regression fails instead of hanging the run.
#[tokio::test]
#[ignore = "hangs indefinitely - see https://github.com/agentclientprotocol/rust-sdk/issues/64"]
async fn test_incomplete_line() {
    use futures::io::Cursor;
    let incomplete_json = b"{\"method\": \"test\", \"id\": 1";
    let input = Cursor::new(incomplete_json.to_vec());
    let output = Cursor::new(Vec::new());
    let transport = agent_client_protocol::ByteStreams::new(output, input);
    let connection = UntypedRole.builder();
    let outcome = tokio::time::timeout(
        tokio::time::Duration::from_secs(5),
        connection.connect_to(transport),
    )
    .await;
    // The inner result is deliberately not inspected: success and failure are
    // both valid ways for the connection to end.
    assert!(
        outcome.is_ok(),
        "connection did not terminate after the input ended mid-line"
    );
}
/// Sends a `SimpleRequest` to a server that registered no handlers and checks
/// that the client observes a "method not found" error.
#[tokio::test(flavor = "current_thread")]
async fn test_unknown_method() {
    use tokio::task::LocalSet;
    let local = LocalSet::new();
    local
        .run_until(async {
            let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
            let server_transport =
                agent_client_protocol::ByteStreams::new(server_writer, server_reader);
            let server = UntypedRole.builder();
            let client_transport =
                agent_client_protocol::ByteStreams::new(client_writer, client_reader);
            let client = UntypedRole.builder();
            // Run the server in the background; its final result is irrelevant here.
            tokio::task::spawn_local(async move {
                server.connect_to(server_transport).await.ok();
            });
            let result = client
                .connect_with(
                    client_transport,
                    async |cx| -> Result<(), agent_client_protocol::Error> {
                        let request = SimpleRequest {
                            message: "test".to_string(),
                        };
                        let result: Result<SimpleResponse, _> =
                            recv(cx.send_request(request)).await;
                        let err = result
                            .expect_err("a request for an unhandled method should be rejected");
                        assert!(
                            matches!(err.code, agent_client_protocol::ErrorCode::MethodNotFound),
                            "expected MethodNotFound, got {err:?}"
                        );
                        Ok(())
                    },
                )
                .await;
            assert!(result.is_ok(), "Test failed: {result:?}");
        })
        .await;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorRequest {
value: String,
}
impl JsonRpcMessage for ErrorRequest {
fn matches_method(method: &str) -> bool {
method == "error_method"
}
fn method(&self) -> &'static str {
"error_method"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
agent_client_protocol::util::json_cast_params(params)
}
}
impl JsonRpcRequest for ErrorRequest {
type Response = SimpleResponse;
}
/// Registers a handler that answers every `ErrorRequest` with an internal
/// error and checks that the client receives that error code.
#[tokio::test(flavor = "current_thread")]
async fn test_handler_returns_error() {
    use tokio::task::LocalSet;
    let local = LocalSet::new();
    local
        .run_until(async {
            let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
            let server_transport =
                agent_client_protocol::ByteStreams::new(server_writer, server_reader);
            let server = UntypedRole.builder().on_receive_request(
                async |_request: ErrorRequest,
                       responder: Responder<SimpleResponse>,
                       _connection: ConnectionTo<UntypedRole>| {
                    responder.respond_with_error(agent_client_protocol::Error::internal_error())
                },
                agent_client_protocol::on_receive_request!(),
            );
            let client_transport =
                agent_client_protocol::ByteStreams::new(client_writer, client_reader);
            let client = UntypedRole.builder();
            // Run the server in the background; its final result is irrelevant here.
            tokio::task::spawn_local(async move {
                server.connect_to(server_transport).await.ok();
            });
            let result = client
                .connect_with(
                    client_transport,
                    async |cx| -> Result<(), agent_client_protocol::Error> {
                        let request = ErrorRequest {
                            value: "trigger error".to_string(),
                        };
                        let result: Result<SimpleResponse, _> =
                            recv(cx.send_request(request)).await;
                        let err =
                            result.expect_err("the handler's error should reach the client");
                        assert!(
                            matches!(err.code, agent_client_protocol::ErrorCode::InternalError),
                            "expected InternalError, got {err:?}"
                        );
                        Ok(())
                    },
                )
                .await;
            assert!(result.is_ok(), "Test failed: {result:?}");
        })
        .await;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct EmptyRequest;
impl JsonRpcMessage for EmptyRequest {
fn matches_method(method: &str) -> bool {
method == "strict_method"
}
fn method(&self) -> &'static str {
"strict_method"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
_params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
Ok(EmptyRequest)
}
}
impl JsonRpcRequest for EmptyRequest {
type Response = SimpleResponse;
}
/// Checks that an "invalid params" error produced by a handler reaches the
/// client with the matching error code.
///
/// The handler registered here answers every `EmptyRequest` with
/// `invalid_params()` unconditionally; the params themselves are not examined.
#[tokio::test(flavor = "current_thread")]
async fn test_missing_required_params() {
    use tokio::task::LocalSet;
    let local = LocalSet::new();
    local
        .run_until(async {
            let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
            let server_transport =
                agent_client_protocol::ByteStreams::new(server_writer, server_reader);
            let server = UntypedRole.builder().on_receive_request(
                async |_request: EmptyRequest,
                       responder: Responder<SimpleResponse>,
                       _connection: ConnectionTo<UntypedRole>| {
                    responder.respond_with_error(agent_client_protocol::Error::invalid_params())
                },
                agent_client_protocol::on_receive_request!(),
            );
            let client_transport =
                agent_client_protocol::ByteStreams::new(client_writer, client_reader);
            let client = UntypedRole.builder();
            // Run the server in the background; its final result is irrelevant here.
            tokio::task::spawn_local(async move {
                server.connect_to(server_transport).await.ok();
            });
            let result = client
                .connect_with(
                    client_transport,
                    async |cx| -> Result<(), agent_client_protocol::Error> {
                        let request = EmptyRequest;
                        let result: Result<SimpleResponse, _> =
                            recv(cx.send_request(request)).await;
                        let err = result
                            .expect_err("the handler's invalid-params error should reach the client");
                        assert!(
                            matches!(err.code, agent_client_protocol::ErrorCode::InvalidParams),
                            "expected InvalidParams, got {err:?}"
                        );
                        Ok(())
                    },
                )
                .await;
            assert!(result.is_ok(), "Test failed: {result:?}");
        })
        .await;
}
#[tokio::test(flavor = "current_thread")]
async fn test_invalid_params_keeps_connection_alive() {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::task::LocalSet;
let local = LocalSet::new();
local
.run_until(async {
let (mut client_writer, server_reader) = tokio::io::duplex(4096);
let (server_writer, mut client_reader) = tokio::io::duplex(4096);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let server_transport =
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
let server = UntypedRole.builder().on_receive_request(
async |request: SimpleRequest,
responder: Responder<SimpleResponse>,
_connection: ConnectionTo<UntypedRole>| {
responder.respond(SimpleResponse {
result: format!("echo: {}", request.message),
})
},
agent_client_protocol::on_receive_request!(),
);
tokio::task::spawn_local(async move {
drop(server.connect_to(server_transport).await);
});
let bad_request =
b"{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"simple_method\",\"params\":{\"wrong_field\":\"hello\"}}\n";
client_writer.write_all(bad_request).await.unwrap();
client_writer.flush().await.unwrap();
let mut buffer = vec![0u8; 4096];
let n = client_reader.read(&mut buffer).await.unwrap();
let response_str = String::from_utf8_lossy(&buffer[..n]);
let response: serde_json::Value =
serde_json::from_str(response_str.trim()).expect("Response should be valid JSON");
assert_eq!(response["id"], 1);
assert!(response["error"].is_object(), "Expected error object");
assert_eq!(
response["error"]["code"], -32602,
"Expected invalid params (-32602)"
);
let good_request =
b"{\"jsonrpc\":\"2.0\",\"id\":2,\"method\":\"simple_method\",\"params\":{\"message\":\"hello\"}}\n";
client_writer.write_all(good_request).await.unwrap();
client_writer.flush().await.unwrap();
let n = client_reader.read(&mut buffer).await.unwrap();
let response_str = String::from_utf8_lossy(&buffer[..n]);
let response: serde_json::Value =
serde_json::from_str(response_str.trim()).expect("Response should be valid JSON");
assert_eq!(response["id"], 2);
assert_eq!(response["result"]["result"], "echo: hello");
})
.await;
}
/// Reads one line from `reader` and parses it as a JSON value.
///
/// The read is bounded by a one-second timeout so a missing response fails
/// the calling test instead of blocking it.
///
/// # Panics
///
/// Panics when the timeout elapses, when the stream reaches end-of-file
/// before any byte is read, when the read itself fails, or when the trimmed
/// line is not valid JSON.
async fn read_jsonrpc_response_line(
    reader: &mut tokio::io::BufReader<tokio::io::DuplexStream>,
) -> serde_json::Value {
    use tokio::io::AsyncBufReadExt as _;
    let mut line = String::new();
    match tokio::time::timeout(
        tokio::time::Duration::from_secs(1),
        reader.read_line(&mut line),
    )
    .await
    {
        Err(_) => panic!("timed out waiting for JSON-RPC response"),
        // Zero bytes means the peer closed the stream; report that distinctly
        // from a timeout so the failure points at the real cause.
        Ok(Ok(0)) => panic!("stream closed before a JSON-RPC response arrived"),
        Ok(Ok(_)) => serde_json::from_str(line.trim()).expect("response should be valid JSON"),
        Ok(Err(e)) => panic!("failed to read JSON-RPC response line: {e}"),
    }
}
/// Sends a `simple_method` request whose params use `content` instead of
/// `message`, then a well-formed request on the same connection.
///
/// The first snapshot pins the invalid-params reply (code -32602, with the
/// deserialization error and the offending params under `data`, for id 3);
/// the second pins the echo reply for id 4. The spawned server task panics
/// if the server connection ends with an error.
#[tokio::test(flavor = "current_thread")]
async fn test_bad_request_params_return_invalid_params_and_connection_stays_alive() {
use tokio::io::{AsyncWriteExt, BufReader};
use tokio::task::LocalSet;
let local = LocalSet::new();
local
.run_until(async {
// The test plays the client by hand: it keeps the raw tokio ends of both
// pipes and only wraps the server ends for the transport.
let (mut client_writer, server_reader) = tokio::io::duplex(2048);
let (server_writer, client_reader) = tokio::io::duplex(2048);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let server_transport =
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
// Echo handler: replies with "echo: <message>" for every SimpleRequest.
let server = UntypedRole.builder().on_receive_request(
async |request: SimpleRequest,
responder: Responder<SimpleResponse>,
_connection: ConnectionTo<UntypedRole>| {
responder.respond(SimpleResponse {
result: format!("echo: {}", request.message),
})
},
agent_client_protocol::on_receive_request!(),
);
tokio::task::spawn_local(async move {
if let Err(err) = server.connect_to(server_transport).await {
panic!("server should stay alive: {err:?}");
}
});
let mut client_reader = BufReader::new(client_reader);
// The raw byte string ends with a literal newline, which terminates the line.
client_writer
.write_all(
br#"{"jsonrpc":"2.0","id":3,"method":"simple_method","params":{"content":"hello"}}
"#,
)
.await
.unwrap();
client_writer.flush().await.unwrap();
let invalid_response = read_jsonrpc_response_line(&mut client_reader).await;
expect![[r#"
{
"error": {
"code": -32602,
"data": {
"error": "missing field `message`",
"json": {
"content": "hello"
},
"phase": "deserialization"
},
"message": "Invalid params"
},
"id": 3,
"jsonrpc": "2.0"
}"#]]
.assert_eq(&serde_json::to_string_pretty(&invalid_response).unwrap());
// Second request on the same connection, this time with the expected field.
client_writer
.write_all(
br#"{"jsonrpc":"2.0","id":4,"method":"simple_method","params":{"message":"hello"}}
"#,
)
.await
.unwrap();
client_writer.flush().await.unwrap();
let ok_response = read_jsonrpc_response_line(&mut client_reader).await;
expect![[r#"
{
"id": 4,
"jsonrpc": "2.0",
"result": {
"result": "echo: hello"
}
}"#]]
.assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap());
})
.await;
}
/// Sends a `simple_notification` whose params use `wrong_field` instead of
/// `message`, then a well-formed `simple_method` request on the same
/// connection.
///
/// The first snapshot pins the error message written back for the bad
/// notification (code -32602, with the deserialization error and offending
/// params under `data`, and no `id` member); the second pins the echo reply
/// for id 10. The spawned server task panics if the server connection ends
/// with an error.
#[tokio::test(flavor = "current_thread")]
async fn test_bad_notification_params_send_error_notification_and_connection_stays_alive() {
use tokio::io::{AsyncWriteExt, BufReader};
use tokio::task::LocalSet;
let local = LocalSet::new();
local
.run_until(async {
// The test plays the client by hand: it keeps the raw tokio ends of both
// pipes and only wraps the server ends for the transport.
let (mut client_writer, server_reader) = tokio::io::duplex(2048);
let (server_writer, client_reader) = tokio::io::duplex(2048);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let server_transport =
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
// Two handlers: a notification handler that accepts every
// SimpleNotification, and an echo handler for SimpleRequest.
let server = UntypedRole
.builder()
.on_receive_notification(
async |_notif: SimpleNotification,
_connection: ConnectionTo<UntypedRole>| {
Ok(())
},
agent_client_protocol::on_receive_notification!(),
)
.on_receive_request(
async |request: SimpleRequest,
responder: Responder<SimpleResponse>,
_connection: ConnectionTo<UntypedRole>| {
responder.respond(SimpleResponse {
result: format!("echo: {}", request.message),
})
},
agent_client_protocol::on_receive_request!(),
);
tokio::task::spawn_local(async move {
if let Err(err) = server.connect_to(server_transport).await {
panic!("server should stay alive: {err:?}");
}
});
let mut client_reader = BufReader::new(client_reader);
// The raw byte string ends with a literal newline, which terminates the line.
client_writer
.write_all(
br#"{"jsonrpc":"2.0","method":"simple_notification","params":{"wrong_field":"hello"}}
"#,
)
.await
.unwrap();
client_writer.flush().await.unwrap();
let error_notification = read_jsonrpc_response_line(&mut client_reader).await;
expect![[r#"
{
"error": {
"code": -32602,
"data": {
"error": "missing field `message`",
"json": {
"wrong_field": "hello"
},
"phase": "deserialization"
},
"message": "Invalid params"
},
"jsonrpc": "2.0"
}"#]]
.assert_eq(&serde_json::to_string_pretty(&error_notification).unwrap());
// A well-formed request afterwards shows the connection is still usable.
client_writer
.write_all(
br#"{"jsonrpc":"2.0","id":10,"method":"simple_method","params":{"message":"after bad notification"}}
"#,
)
.await
.unwrap();
client_writer.flush().await.unwrap();
let ok_response = read_jsonrpc_response_line(&mut client_reader).await;
expect![[r#"
{
"id": 10,
"jsonrpc": "2.0",
"result": {
"result": "echo: after bad notification"
}
}"#]]
.assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap());
})
.await;
}