1use crate::notifications::{ElicitationParams, ElicitationResponse, McpNotification};
15use agent_client_protocol::schema::SessionNotification;
16use agent_client_protocol::{
17 self as acp, Agent, Builder, ByteStreams, Client, ConnectionTo, HandleDispatchFrom, NullRun, Responder,
18};
19use rmcp::model::{CreateElicitationRequestParams, ElicitationSchema};
20use std::collections::VecDeque;
21use std::sync::{Arc, Mutex};
22use tokio::io::DuplexStream;
23use tokio::sync::{mpsc, oneshot};
24use tokio::task::spawn_local;
25use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
26
27pub type DuplexByteStreams = ByteStreams<Compat<DuplexStream>, Compat<DuplexStream>>;
28
29pub struct TestPeer {
30 session_notifications: mpsc::UnboundedReceiver<SessionNotification>,
31 mcp_notifications: mpsc::UnboundedReceiver<McpNotification>,
32 elicitation_requests: mpsc::UnboundedReceiver<ElicitationParams>,
33 elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>>,
34 responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>>,
35}
36
37impl TestPeer {
38 pub fn new() -> (Self, Builder<Client, impl HandleDispatchFrom<Agent>, NullRun>) {
43 let (sn_tx, sn_rx) = mpsc::unbounded_channel::<SessionNotification>();
44 let (mcp_tx, mcp_rx) = mpsc::unbounded_channel::<McpNotification>();
45 let (el_tx, el_rx) = mpsc::unbounded_channel::<ElicitationParams>();
46 let elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
47 let responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>> =
48 Arc::new(Mutex::new(None));
49
50 let builder = Client
51 .builder()
52 .on_receive_notification(
53 {
54 let tx = sn_tx;
55 async move |n: SessionNotification, _cx| {
56 let _ = tx.send(n);
57 Ok(())
58 }
59 },
60 acp::on_receive_notification!(),
61 )
62 .on_receive_notification(
63 {
64 let tx = mcp_tx;
65 async move |n: McpNotification, _cx| {
66 let _ = tx.send(n);
67 Ok(())
68 }
69 },
70 acp::on_receive_notification!(),
71 )
72 .on_receive_request(
73 {
74 let tx = el_tx;
75 let responses = elicitation_responses.clone();
76 let capture = responder_capture.clone();
77 async move |req: ElicitationParams, responder: Responder<ElicitationResponse>, _cx| {
78 if let Some(capture_tx) = capture.lock().unwrap().take() {
79 return match capture_tx.send(responder) {
80 Ok(()) => Ok(()),
81 Err(responder) => responder.respond_with_error(acp::Error::internal_error()),
82 };
83 }
84 let _ = tx.send(req);
85 let queued = responses.lock().unwrap().pop_front();
86 match queued {
87 Some(response) => responder.respond(response),
88 None => responder.respond_with_error(acp::Error::method_not_found()),
89 }
90 }
91 },
92 acp::on_receive_request!(),
93 );
94
95 let peer = Self {
96 session_notifications: sn_rx,
97 mcp_notifications: mcp_rx,
98 elicitation_requests: el_rx,
99 elicitation_responses,
100 responder_capture,
101 };
102 (peer, builder)
103 }
104
105 pub async fn next_session_notification(&mut self) -> SessionNotification {
106 self.session_notifications.recv().await.expect("peer channel closed")
107 }
108
109 pub async fn next_mcp_notification(&mut self) -> McpNotification {
110 self.mcp_notifications.recv().await.expect("peer channel closed")
111 }
112
113 pub async fn next_elicitation_request(&mut self) -> ElicitationParams {
114 self.elicitation_requests.recv().await.expect("peer channel closed")
115 }
116
117 pub fn queue_elicitation_response(&self, response: ElicitationResponse) {
122 self.elicitation_responses.lock().unwrap().push_back(response);
123 }
124
125 pub fn capture_next_elicitation(&self) -> oneshot::Receiver<Responder<ElicitationResponse>> {
126 let (responder_tx, responder_rx) = oneshot::channel::<Responder<ElicitationResponse>>();
127 *self.responder_capture.lock().unwrap() = Some(responder_tx);
128 responder_rx
129 }
130
131 pub async fn fake_elicitation(
140 &mut self,
141 cx: &ConnectionTo<Client>,
142 ) -> (Responder<ElicitationResponse>, oneshot::Receiver<ElicitationResponse>) {
143 let (responder_tx, responder_rx) = oneshot::channel::<Responder<ElicitationResponse>>();
144 *self.responder_capture.lock().unwrap() = Some(responder_tx);
145
146 let (response_tx, response_rx) = oneshot::channel::<ElicitationResponse>();
147 let cx = cx.clone();
148 spawn_local(async move {
149 if let Ok(resp) = cx.send_request(placeholder_params()).block_task().await {
150 let _ = response_tx.send(resp);
151 }
152 });
153
154 let responder = responder_rx.await.expect("client handler must capture responder");
155 (responder, response_rx)
156 }
157}
158
159pub fn duplex_pair() -> (DuplexByteStreams, DuplexByteStreams) {
163 let (agent_writer, client_reader) = tokio::io::duplex(4096);
164 let (client_writer, agent_reader) = tokio::io::duplex(4096);
165 let agent_transport = ByteStreams::new(agent_writer.compat_write(), agent_reader.compat());
166 let client_transport = ByteStreams::new(client_writer.compat_write(), client_reader.compat());
167 (agent_transport, client_transport)
168}
169
170pub async fn test_connection() -> (ConnectionTo<Client>, TestPeer) {
173 let (peer, client_builder) = TestPeer::new();
174 let (agent_transport, client_transport) = duplex_pair();
175
176 spawn_local(async move {
177 let _ = client_builder.connect_to(client_transport).await;
178 });
179
180 let (cx_tx, cx_rx) = oneshot::channel::<ConnectionTo<Client>>();
181 spawn_local(async move {
182 let _ = Agent
183 .builder()
184 .connect_with(agent_transport, async move |cx: ConnectionTo<Client>| {
185 let _ = cx_tx.send(cx);
186 std::future::pending::<()>().await;
187 Ok(())
188 })
189 .await;
190 });
191
192 let cx = cx_rx.await.expect("agent side connect_with produced a ConnectionTo");
193 (cx, peer)
194}
195
196fn placeholder_params() -> ElicitationParams {
197 ElicitationParams {
198 server_name: String::new(),
199 request: CreateElicitationRequestParams::FormElicitationParams {
200 meta: None,
201 message: String::new(),
202 requested_schema: ElicitationSchema::builder().build().expect("empty schema is valid"),
203 },
204 }
205}