use sacp::schema::{McpServer, NewSessionRequest, NewSessionResponse};
use sacp::{JrHandlerChain, JrRequestCx};
use sacp_proxy::{AcpProxyExt, McpServiceRegistry};
use std::sync::{Arc, Mutex};
use tokio::task::LocalSet;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
async fn recv<R: sacp::JrResponsePayload + Send>(
response: sacp::JrResponse<R>,
) -> Result<R, sacp::Error> {
let (tx, rx) = tokio::sync::oneshot::channel();
response.await_when_result_received(async move |result| {
tx.send(result).map_err(|_| sacp::Error::internal_error())
})?;
rx.await.map_err(|_| sacp::Error::internal_error())?
}
struct MockMcpServer;
impl sacp::Component for MockMcpServer {
fn serve(
self,
client: impl sacp::Component,
) -> impl std::future::Future<Output = Result<(), sacp::Error>> + Send {
async move {
sacp::JrHandlerChain::new().serve(client).await
}
}
}
#[tokio::test(flavor = "current_thread")]
async fn test_provide_mcp_handler_chain() {
let local = LocalSet::new();
local
.run_until(async {
let (client_writer, server_reader) = tokio::io::duplex(1024);
let (server_writer, client_reader) = tokio::io::duplex(1024);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let client_reader = client_reader.compat();
let client_writer = client_writer.compat_write();
let registry = McpServiceRegistry::new();
registry
.add_mcp_server("test-server", || MockMcpServer)
.expect("Failed to add MCP server");
let seen_servers = Arc::new(Mutex::new(Vec::new()));
let seen_servers_clone = seen_servers.clone();
let server_transport = sacp::ByteStreams::new(server_writer, server_reader);
let server = JrHandlerChain::new()
.provide_mcp(registry)
.on_receive_request(
async move |request: NewSessionRequest,
request_cx: JrRequestCx<NewSessionResponse>| {
let mut servers = seen_servers_clone.lock().unwrap();
for server in &request.mcp_servers {
if let McpServer::Http { name, .. } = server {
servers.push(name.clone());
}
}
request_cx.respond(NewSessionResponse {
session_id: "test-session".to_string().into(),
meta: None,
modes: None,
})
},
)
.proxy();
let client_transport = sacp::ByteStreams::new(client_writer, client_reader);
let client = JrHandlerChain::new();
tokio::task::spawn_local(async move {
if let Err(e) = server.serve(server_transport).await {
eprintln!("Server error: {e:?}");
}
});
let result = client
.with_client(
client_transport,
async |cx| -> std::result::Result<(), sacp::Error> {
let request = NewSessionRequest {
mcp_servers: vec![],
cwd: std::env::current_dir()
.unwrap_or_else(|_| std::path::PathBuf::from("/")),
meta: None,
};
let response =
recv(cx.send_request(request))
.await
.map_err(|e| -> sacp::Error {
sacp::util::internal_error(format!(
"NewSession request failed: {e:?}"
))
})?;
assert_eq!(response.session_id, "test-session".into());
Ok(())
},
)
.await;
assert!(result.is_ok(), "Test failed: {:?}", result);
let servers = seen_servers.lock().unwrap();
assert_eq!(
servers.len(),
1,
"Expected custom handler to see 1 MCP server"
);
assert_eq!(servers[0], "test-server");
})
.await;
}
#[tokio::test(flavor = "current_thread")]
async fn test_provide_mcp_preserves_existing_servers() {
let local = LocalSet::new();
local
.run_until(async {
let (client_writer, server_reader) = tokio::io::duplex(1024);
let (server_writer, client_reader) = tokio::io::duplex(1024);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let client_reader = client_reader.compat();
let client_writer = client_writer.compat_write();
let registry = McpServiceRegistry::new();
registry
.add_mcp_server("injected-server", || MockMcpServer)
.expect("Failed to add MCP server");
let seen_servers = Arc::new(Mutex::new(Vec::new()));
let seen_servers_clone = seen_servers.clone();
let server_transport = sacp::ByteStreams::new(server_writer, server_reader);
let server = JrHandlerChain::new()
.provide_mcp(registry)
.on_receive_request(
async move |request: NewSessionRequest,
request_cx: JrRequestCx<NewSessionResponse>| {
let mut servers = seen_servers_clone.lock().unwrap();
for server in &request.mcp_servers {
if let McpServer::Http { name, .. } = server {
servers.push(name.clone());
}
}
request_cx.respond(NewSessionResponse {
session_id: "test-session".to_string().into(),
meta: None,
modes: None,
})
},
)
.proxy();
let client_transport = sacp::ByteStreams::new(client_writer, client_reader);
let client = JrHandlerChain::new();
tokio::task::spawn_local(async move {
if let Err(e) = server.serve(server_transport).await {
eprintln!("Server error: {e:?}");
}
});
let result = client
.with_client(
client_transport,
async |cx| -> std::result::Result<(), sacp::Error> {
let request = NewSessionRequest {
mcp_servers: vec![McpServer::Http {
name: "existing-server".to_string(),
url: "http://example.com".to_string(),
headers: vec![],
}],
cwd: std::env::current_dir()
.unwrap_or_else(|_| std::path::PathBuf::from("/")),
meta: None,
};
let response =
recv(cx.send_request(request))
.await
.map_err(|e| -> sacp::Error {
sacp::util::internal_error(format!(
"NewSession request failed: {e:?}"
))
})?;
assert_eq!(response.session_id, "test-session".into());
Ok(())
},
)
.await;
assert!(result.is_ok(), "Test failed: {:?}", result);
let servers = seen_servers.lock().unwrap();
assert_eq!(
servers.len(),
2,
"Expected custom handler to see 2 MCP servers"
);
assert!(servers.contains(&"existing-server".to_string()));
assert!(servers.contains(&"injected-server".to_string()));
})
.await;
}
#[tokio::test(flavor = "current_thread")]
async fn test_provide_mcp_multiple_servers() {
let local = LocalSet::new();
local
.run_until(async {
let (client_writer, server_reader) = tokio::io::duplex(1024);
let (server_writer, client_reader) = tokio::io::duplex(1024);
let server_reader = server_reader.compat();
let server_writer = server_writer.compat_write();
let client_reader = client_reader.compat();
let client_writer = client_writer.compat_write();
let registry = McpServiceRegistry::new();
registry
.add_mcp_server("server-1", || MockMcpServer)
.expect("Failed to add MCP server 1");
registry
.add_mcp_server("server-2", || MockMcpServer)
.expect("Failed to add MCP server 2");
registry
.add_mcp_server("server-3", || MockMcpServer)
.expect("Failed to add MCP server 3");
let seen_servers = Arc::new(Mutex::new(Vec::new()));
let seen_servers_clone = seen_servers.clone();
let server_transport = sacp::ByteStreams::new(server_writer, server_reader);
let server = JrHandlerChain::new()
.provide_mcp(registry)
.on_receive_request(
async move |request: NewSessionRequest,
request_cx: JrRequestCx<NewSessionResponse>| {
let mut servers = seen_servers_clone.lock().unwrap();
for server in &request.mcp_servers {
if let McpServer::Http { name, .. } = server {
servers.push(name.clone());
}
}
request_cx.respond(NewSessionResponse {
session_id: "test-session".to_string().into(),
meta: None,
modes: None,
})
},
)
.proxy();
let client_transport = sacp::ByteStreams::new(client_writer, client_reader);
let client = JrHandlerChain::new();
tokio::task::spawn_local(async move {
if let Err(e) = server.serve(server_transport).await {
eprintln!("Server error: {e:?}");
}
});
let result = client
.with_client(
client_transport,
async |cx| -> std::result::Result<(), sacp::Error> {
let request = NewSessionRequest {
mcp_servers: vec![],
cwd: std::env::current_dir()
.unwrap_or_else(|_| std::path::PathBuf::from("/")),
meta: None,
};
let response =
recv(cx.send_request(request))
.await
.map_err(|e| -> sacp::Error {
sacp::util::internal_error(format!(
"NewSession request failed: {e:?}"
))
})?;
assert_eq!(response.session_id, "test-session".into());
Ok(())
},
)
.await;
assert!(result.is_ok(), "Test failed: {:?}", result);
let servers = seen_servers.lock().unwrap();
assert_eq!(
servers.len(),
3,
"Expected custom handler to see 3 MCP servers"
);
assert!(servers.contains(&"server-1".to_string()));
assert!(servers.contains(&"server-2".to_string()));
assert!(servers.contains(&"server-3".to_string()));
})
.await;
}