use agent_client_protocol::{
ConnectionTo, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, Responder, SentRequest,
role::UntypedRole,
};
use futures::{AsyncRead, AsyncWrite};
use serde::{Deserialize, Serialize};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
async fn recv<T: JsonRpcResponse + Send>(
response: SentRequest<T>,
) -> Result<T, agent_client_protocol::Error> {
let (tx, rx) = tokio::sync::oneshot::channel();
response.on_receiving_result(async move |result| {
tx.send(result)
.map_err(|_| agent_client_protocol::Error::internal_error())
})?;
rx.await
.map_err(|_| agent_client_protocol::Error::internal_error())?
}
/// Builds two in-memory byte pipes that connect a "server" side to a "client" side.
///
/// Returns `(server_reader, server_writer, client_reader, client_writer)`:
/// bytes written to `client_writer` are read from `server_reader`, and bytes
/// written to `server_writer` are read from `client_reader`. Each tokio duplex
/// half is adapted to the `futures` I/O traits via `tokio_util::compat`.
fn setup_test_streams() -> (
    impl AsyncRead,
    impl AsyncWrite,
    impl AsyncRead,
    impl AsyncWrite,
) {
    // Each duplex pipe is used in one direction only.
    let (client_to_server_tx, client_to_server_rx) = tokio::io::duplex(1024);
    let (server_to_client_tx, server_to_client_rx) = tokio::io::duplex(1024);
    (
        client_to_server_rx.compat(),
        server_to_client_tx.compat_write(),
        server_to_client_rx.compat(),
        client_to_server_tx.compat_write(),
    )
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct PingRequest {
value: u32,
}
impl JsonRpcMessage for PingRequest {
fn matches_method(method: &str) -> bool {
method == "ping"
}
fn method(&self) -> &'static str {
"ping"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
agent_client_protocol::util::json_cast(params)
}
}
impl JsonRpcRequest for PingRequest {
type Response = PongResponse;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PongResponse {
value: u32,
}
impl JsonRpcResponse for PongResponse {
fn into_json(self, _method: &str) -> Result<serde_json::Value, agent_client_protocol::Error> {
serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error)
}
fn from_value(
_method: &str,
value: serde_json::Value,
) -> Result<Self, agent_client_protocol::Error> {
agent_client_protocol::util::json_cast(&value)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SlowRequest {
delay_ms: u64,
id: u32,
}
impl JsonRpcMessage for SlowRequest {
fn matches_method(method: &str) -> bool {
method == "slow"
}
fn method(&self) -> &'static str {
"slow"
}
fn to_untyped_message(
&self,
) -> Result<agent_client_protocol::UntypedMessage, agent_client_protocol::Error> {
agent_client_protocol::UntypedMessage::new(self.method(), self)
}
fn parse_message(
method: &str,
params: &impl serde::Serialize,
) -> Result<Self, agent_client_protocol::Error> {
if !Self::matches_method(method) {
return Err(agent_client_protocol::Error::method_not_found());
}
agent_client_protocol::util::json_cast(params)
}
}
impl JsonRpcRequest for SlowRequest {
type Response = SlowResponse;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SlowResponse {
id: u32,
}
impl JsonRpcResponse for SlowResponse {
fn into_json(self, _method: &str) -> Result<serde_json::Value, agent_client_protocol::Error> {
serde_json::to_value(self).map_err(agent_client_protocol::Error::into_internal_error)
}
fn from_value(
_method: &str,
value: serde_json::Value,
) -> Result<Self, agent_client_protocol::Error> {
agent_client_protocol::util::json_cast(&value)
}
}
/// Round-trips one `PingRequest` between two in-process peers and checks that
/// the handler's reply (`value + 1`) reaches the caller.
#[tokio::test(flavor = "current_thread")]
async fn test_bidirectional_communication() {
    use tokio::task::LocalSet;
    let local = LocalSet::new();
    local
        .run_until(async {
            let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
            let side_a_transport =
                agent_client_protocol::ByteStreams::new(server_writer, server_reader);
            // Side A answers every `ping` with the incremented value.
            let side_a = UntypedRole.builder().on_receive_request(
                async |request: PingRequest,
                       responder: Responder<PongResponse>,
                       _connection: ConnectionTo<UntypedRole>| {
                    responder.respond(PongResponse {
                        value: request.value + 1,
                    })
                },
                agent_client_protocol::on_receive_request!(),
            );
            let side_b_transport =
                agent_client_protocol::ByteStreams::new(client_writer, client_reader);
            // Side A runs in the background. Its outcome is deliberately ignored;
            // presumably it ends with an error once side B disconnects.
            tokio::task::spawn_local(async move {
                side_a.connect_to(side_a_transport).await.ok();
            });
            let result = UntypedRole
                .builder()
                .connect_with(
                    side_b_transport,
                    async |cx| -> Result<(), agent_client_protocol::Error> {
                        // Propagate request failures with `?` instead of asserting
                        // `is_ok()`, so the error itself is available to the caller.
                        let response: PongResponse =
                            recv(cx.send_request(PingRequest { value: 10 })).await?;
                        assert_eq!(response.value, 11);
                        Ok(())
                    },
                )
                .await;
            assert!(result.is_ok(), "Test failed: {result:?}");
        })
        .await;
}
/// Issues several `ping` requests before awaiting any reply and checks that
/// each reply is matched to the request that produced it.
#[tokio::test(flavor = "current_thread")]
async fn test_request_ids() {
    use tokio::task::LocalSet;
    let local = LocalSet::new();
    local
        .run_until(async {
            let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
            let server_transport =
                agent_client_protocol::ByteStreams::new(server_writer, server_reader);
            // The server answers every `ping` with the incremented value.
            let server = UntypedRole.builder().on_receive_request(
                async |request: PingRequest,
                       responder: Responder<PongResponse>,
                       _connection: ConnectionTo<UntypedRole>| {
                    responder.respond(PongResponse {
                        value: request.value + 1,
                    })
                },
                agent_client_protocol::on_receive_request!(),
            );
            let client_transport =
                agent_client_protocol::ByteStreams::new(client_writer, client_reader);
            let client = UntypedRole.builder();
            // The server's outcome is deliberately ignored (best-effort background peer).
            tokio::task::spawn_local(async move {
                server.connect_to(server_transport).await.ok();
            });
            let result = client
                .connect_with(
                    client_transport,
                    async |cx| -> Result<(), agent_client_protocol::Error> {
                        // All three requests are issued before any reply is awaited.
                        let resp1_future = recv(cx.send_request(PingRequest { value: 1 }));
                        let resp2_future = recv(cx.send_request(PingRequest { value: 2 }));
                        let resp3_future = recv(cx.send_request(PingRequest { value: 3 }));
                        let resp1: PongResponse = resp1_future.await?;
                        let resp2: PongResponse = resp2_future.await?;
                        let resp3: PongResponse = resp3_future.await?;
                        // Distinct payloads make a mixed-up correlation visible.
                        assert_eq!(resp1.value, 2);
                        assert_eq!(resp2.value, 3);
                        assert_eq!(resp3.value, 4);
                        Ok(())
                    },
                )
                .await;
            assert!(result.is_ok(), "Test failed: {result:?}");
        })
        .await;
}
#[tokio::test(flavor = "current_thread")]
async fn test_out_of_order_responses() {
use tokio::task::LocalSet;
let local = LocalSet::new();
local
.run_until(async {
let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();
let server_transport =
agent_client_protocol::ByteStreams::new(server_writer, server_reader);
let server = UntypedRole.builder().on_receive_request(
async |request: SlowRequest,
responder: Responder<SlowResponse>,
_connection: ConnectionTo<UntypedRole>| {
tokio::time::sleep(tokio::time::Duration::from_millis(request.delay_ms)).await;
responder.respond(SlowResponse { id: request.id })
},
agent_client_protocol::on_receive_request!(),
);
let client_transport =
agent_client_protocol::ByteStreams::new(client_writer, client_reader);
let client = UntypedRole.builder();
tokio::task::spawn_local(async move {
server.connect_to(server_transport).await.ok();
});
let result = client
.connect_with(
client_transport,
async |cx| -> Result<(), agent_client_protocol::Error> {
let req1 = SlowRequest {
delay_ms: 100,
id: 1,
};
let req2 = SlowRequest {
delay_ms: 50,
id: 2,
};
let req3 = SlowRequest {
delay_ms: 10,
id: 3,
};
let resp1_future = recv(cx.send_request(req1));
let resp2_future = recv(cx.send_request(req2));
let resp3_future = recv(cx.send_request(req3));
let resp1: Result<SlowResponse, _> = resp1_future.await;
let resp2: Result<SlowResponse, _> = resp2_future.await;
let resp3: Result<SlowResponse, _> = resp3_future.await;
assert_eq!(resp1.unwrap().id, 1);
assert_eq!(resp2.unwrap().id, 2);
assert_eq!(resp3.unwrap().id, 3);
Ok(())
},
)
.await;
assert!(result.is_ok(), "Test failed: {result:?}");
})
.await;
}