use agent_client_protocol::schema::{
AgentCapabilities, InitializeProxyRequest, InitializeRequest, InitializeResponse,
ProtocolVersion,
};
use agent_client_protocol::{Agent, Client, Conductor, ConnectTo, DynConnectTo, Proxy};
use agent_client_protocol_conductor::{ConductorImpl, McpBridgeMode, ProxiesAndAgent};
use agent_client_protocol_test::testy::Testy;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::io::duplex;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
async fn recv<T: agent_client_protocol::JsonRpcResponse + Send>(
response: agent_client_protocol::SentRequest<T>,
) -> Result<T, agent_client_protocol::Error> {
let (tx, rx) = tokio::sync::oneshot::channel();
response.on_receiving_result(async move |result| {
tx.send(result)
.map_err(|_| agent_client_protocol::Error::internal_error())
})?;
rx.await
.map_err(|_| agent_client_protocol::Error::internal_error())?
}
/// Which kind of initialization request a test component observed.
///
/// Equality on this fieldless enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum InitRequestType {
    /// A plain `InitializeRequest` was received.
    Initialize,
    /// An `InitializeProxyRequest` was received.
    InitializeProxy,
}
/// Shared record of the initialization request kind a component has seen.
struct InitConfig {
    /// Request kind stored so far; `None` until something is recorded.
    received_init_type: Mutex<Option<InitRequestType>>,
}

impl InitConfig {
    /// Creates an empty record (nothing received yet) wrapped in an `Arc`.
    fn new() -> Arc<Self> {
        let empty = Self {
            received_init_type: Mutex::new(None),
        };
        Arc::new(empty)
    }

    /// Returns a clone of the currently recorded request kind.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn read_init_type(&self) -> Option<InitRequestType> {
        let recorded = self.received_init_type.lock().expect("not poisoned");
        recorded.clone()
    }
}
/// Test component that holds a handle to a shared [`InitConfig`].
struct InitComponent {
    /// Shared record this component writes its observations into.
    config: Arc<InitConfig>,
}

impl InitComponent {
    /// Builds a component with its own `Arc` handle to `config`.
    ///
    /// The caller keeps its handle, so it can inspect the record afterwards.
    fn new(config: &Arc<InitConfig>) -> Self {
        let config = Arc::clone(config);
        Self { config }
    }
}
impl ConnectTo<Conductor> for InitComponent {
async fn connect_to(
self,
client: impl ConnectTo<Proxy>,
) -> Result<(), agent_client_protocol::Error> {
let config = self.config;
let config2 = Arc::clone(&config);
Proxy
.builder()
.name("init-component")
.on_receive_request_from(
Client,
async move |request: InitializeProxyRequest, responder, cx| {
*config.received_init_type.lock().expect("unpoisoned") =
Some(InitRequestType::InitializeProxy);
cx.send_request_to(agent_client_protocol::Agent, request.initialize)
.on_receiving_result(async move |response| {
let response: InitializeResponse = response?;
responder.respond(response)
})
},
agent_client_protocol::on_receive_request!(),
)
.on_receive_request_from(
Client,
async move |request: InitializeRequest, responder, _cx| {
*config2.received_init_type.lock().expect("unpoisoned") =
Some(InitRequestType::Initialize);
let response = InitializeResponse::new(request.protocol_version)
.agent_capabilities(AgentCapabilities::new());
responder.respond(response)
},
agent_client_protocol::on_receive_request!(),
)
.connect_to(client)
.await
}
}
async fn run_test_with_components(
proxies: Vec<InitComponent>,
editor_task: impl AsyncFnOnce(
agent_client_protocol::ConnectionTo<Agent>,
) -> Result<(), agent_client_protocol::Error>,
) -> Result<(), agent_client_protocol::Error> {
let (editor_out, conductor_in) = duplex(1024);
let (conductor_out, editor_in) = duplex(1024);
let transport =
agent_client_protocol::ByteStreams::new(editor_out.compat_write(), editor_in.compat());
agent_client_protocol::Client
.builder()
.name("editor-to-connector")
.with_spawned(|_cx| async move {
ConductorImpl::new_agent(
"conductor".to_string(),
ProxiesAndAgent::new(Testy::new()).proxies(proxies),
McpBridgeMode::default(),
)
.run(agent_client_protocol::ByteStreams::new(
conductor_out.compat_write(),
conductor_in.compat(),
))
.await
})
.connect_with(transport, editor_task)
.await
}
/// With no proxies configured, an `InitializeRequest` sent by the editor must
/// succeed; the only component in the chain is the agent supplied by the
/// harness.
#[tokio::test]
async fn test_single_component_gets_initialize_request() -> Result<(), agent_client_protocol::Error>
{
    let no_proxies = Vec::new();

    run_test_with_components(no_proxies, async |editor_cx| {
        let request = InitializeRequest::new(ProtocolVersion::LATEST);
        let init_response = recv(editor_cx.send_request(request)).await;

        assert!(
            init_response.is_ok(),
            "Initialize should succeed: {init_response:?}"
        );
        Ok::<(), agent_client_protocol::Error>(())
    })
    .await
}
/// With one proxy in front of the agent, initialization must succeed and the
/// proxy must have recorded an `InitializeProxyRequest`.
#[tokio::test]
async fn test_two_components_proxy_gets_initialize_proxy()
-> Result<(), agent_client_protocol::Error> {
    let proxy_record = InitConfig::new();
    let proxies = vec![InitComponent::new(&proxy_record)];

    run_test_with_components(proxies, async |editor_cx| {
        let request = InitializeRequest::new(ProtocolVersion::LATEST);
        let init_response = recv(editor_cx.send_request(request)).await;

        assert!(
            init_response.is_ok(),
            "Initialize should succeed: {init_response:?}"
        );
        Ok::<(), agent_client_protocol::Error>(())
    })
    .await?;

    // The harness has finished; inspect what the proxy recorded.
    let observed = proxy_record.read_init_type();
    assert_eq!(
        observed,
        Some(InitRequestType::InitializeProxy),
        "Proxy component should receive InitializeProxyRequest"
    );
    Ok(())
}
/// With two proxies in front of the agent, initialization must succeed and
/// both proxies must have recorded an `InitializeProxyRequest`.
#[tokio::test]
async fn test_three_components_all_proxies_get_initialize_proxy()
-> Result<(), agent_client_protocol::Error> {
    let first_record = InitConfig::new();
    let second_record = InitConfig::new();
    let proxies = vec![
        InitComponent::new(&first_record),
        InitComponent::new(&second_record),
    ];

    run_test_with_components(proxies, async |editor_cx| {
        let request = InitializeRequest::new(ProtocolVersion::LATEST);
        let init_response = recv(editor_cx.send_request(request)).await;

        assert!(
            init_response.is_ok(),
            "Initialize should succeed: {init_response:?}"
        );
        Ok::<(), agent_client_protocol::Error>(())
    })
    .await?;

    // Check the proxies in chain order, each with its own failure message.
    let expectations = [
        (
            &first_record,
            "First proxy should receive InitializeProxyRequest",
        ),
        (
            &second_record,
            "Second proxy should receive InitializeProxyRequest",
        ),
    ];
    for (record, message) in expectations {
        assert_eq!(
            record.read_init_type(),
            Some(InitRequestType::InitializeProxy),
            "{message}"
        );
    }
    Ok(())
}
/// Deliberately misbehaving proxy used by the rejection tests below.
struct BadProxy;

impl ConnectTo<Conductor> for BadProxy {
    /// Runs a proxy named `bad-proxy` that passes every
    /// `InitializeProxyRequest` on to the agent side *as is* (instead of
    /// forwarding its inner `initialize` request) and relays the reply.
    async fn connect_to(
        self,
        client: impl ConnectTo<Proxy>,
    ) -> Result<(), agent_client_protocol::Error> {
        Proxy
            .builder()
            .name("bad-proxy")
            .on_receive_request_from(
                Client,
                async move |proxy_init: InitializeProxyRequest, responder, cx| {
                    // The whole proxy request is sent onwards, unmodified.
                    let forwarded = cx.send_request_to(Agent, proxy_init);
                    forwarded.on_receiving_result(async move |reply| {
                        let init_response: InitializeResponse = reply?;
                        responder.respond(init_response)
                    })
                },
                agent_client_protocol::on_receive_request!(),
            )
            .connect_to(client)
            .await
    }
}
async fn run_bad_proxy_test(
proxies: Vec<DynConnectTo<Conductor>>,
agent: DynConnectTo<Client>,
editor_task: impl AsyncFnOnce(
agent_client_protocol::ConnectionTo<Agent>,
) -> Result<(), agent_client_protocol::Error>,
) -> Result<(), agent_client_protocol::Error> {
let (editor_out, conductor_in) = duplex(1024);
let (conductor_out, editor_in) = duplex(1024);
let transport =
agent_client_protocol::ByteStreams::new(editor_out.compat_write(), editor_in.compat());
agent_client_protocol::Client
.builder()
.name("editor-to-connector")
.with_spawned(|_cx| async move {
ConductorImpl::new_agent(
"conductor".to_string(),
ProxiesAndAgent::new(agent).proxies(proxies),
McpBridgeMode::default(),
)
.run(agent_client_protocol::ByteStreams::new(
conductor_out.compat_write(),
conductor_in.compat(),
))
.await
})
.connect_with(transport, editor_task)
.await
}
/// A proxy that forwards `InitializeProxyRequest` straight to the agent must
/// make the run fail with an error mentioning `initialize/proxy`.
#[tokio::test]
async fn test_conductor_rejects_initialize_proxy_forwarded_to_agent()
-> Result<(), agent_client_protocol::Error> {
    let proxies = vec![DynConnectTo::new(BadProxy)];
    let agent = DynConnectTo::new(Testy::new());

    let outcome = run_bad_proxy_test(proxies, agent, async |editor_cx| {
        let request = InitializeRequest::new(ProtocolVersion::LATEST);
        let init_response = recv(editor_cx.send_request(request)).await;

        // An error on the response itself is only checked when present;
        // the overall run result is checked unconditionally below.
        if let Some(err) = init_response.err() {
            assert!(
                err.to_string().contains("initialize/proxy"),
                "Error should mention initialize/proxy: {err:?}"
            );
        }
        Ok::<(), agent_client_protocol::Error>(())
    })
    .await;

    let Err(err) = outcome else {
        panic!("Expected error when proxy forwards InitializeProxyRequest to agent");
    };
    assert!(
        err.to_string().contains("initialize/proxy"),
        "Error should mention initialize/proxy: {err:?}"
    );
    Ok(())
}
/// A proxy that forwards `InitializeProxyRequest` unchanged, with another
/// proxy placed after it in the list, must make the run fail with an error
/// mentioning `initialize/proxy`.
#[tokio::test]
async fn test_conductor_rejects_initialize_proxy_forwarded_to_proxy()
-> Result<(), agent_client_protocol::Error> {
    // `BadProxy` comes first; a well-behaved `InitComponent` follows it.
    let proxies = vec![
        DynConnectTo::new(BadProxy),
        DynConnectTo::new(InitComponent::new(&InitConfig::new())),
    ];
    let agent = DynConnectTo::new(Testy::new());

    let outcome = run_bad_proxy_test(proxies, agent, async |editor_cx| {
        let request = InitializeRequest::new(ProtocolVersion::LATEST);
        let init_response = recv(editor_cx.send_request(request)).await;

        // An error on the response itself is only checked when present;
        // the overall run result is checked unconditionally below.
        if let Some(err) = init_response.err() {
            assert!(
                err.to_string().contains("initialize/proxy"),
                "Error should mention initialize/proxy: {err:?}"
            );
        }
        Ok::<(), agent_client_protocol::Error>(())
    })
    .await;

    let Err(err) = outcome else {
        panic!("Expected error when proxy forwards InitializeProxyRequest to proxy");
    };
    assert!(
        err.to_string().contains("initialize/proxy"),
        "Error should mention initialize/proxy: {err:?}"
    );
    Ok(())
}