1use crate::notifications::{ElicitationParams, ElicitationResponse, McpNotification};
15use agent_client_protocol::schema::SessionNotification;
16use agent_client_protocol::{
17 self as acp, Agent, Builder, ByteStreams, Client, ConnectionTo, HandleDispatchFrom, NullRun, Responder,
18};
19use rmcp::model::{CreateElicitationRequestParams, ElicitationSchema};
20use std::collections::VecDeque;
21use std::sync::{Arc, Mutex};
22use tokio::io::DuplexStream;
23use tokio::sync::{mpsc, oneshot};
24use tokio::task::spawn_local;
25use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
26
27pub type DuplexByteStreams = ByteStreams<Compat<DuplexStream>, Compat<DuplexStream>>;
28
/// Client-side test double that records what an agent sends over the
/// agent-client protocol.
///
/// Obtain one with [`TestPeer::new`]. That function also returns the client
/// builder whose handlers feed the channels and shared state held here.
pub struct TestPeer {
    /// Session notifications received by the client handler, in arrival order.
    session_notifications: mpsc::UnboundedReceiver<SessionNotification>,
    /// MCP notifications received by the client handler, in arrival order.
    mcp_notifications: mpsc::UnboundedReceiver<McpNotification>,
    /// Params of elicitation requests that were not captured by `fake_elicitation`.
    elicitation_requests: mpsc::UnboundedReceiver<ElicitationParams>,
    /// Canned responses, handed out oldest-first to incoming elicitation requests.
    /// Shared with the client handler.
    elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>>,
    /// When set, the next elicitation request's responder is sent through this
    /// channel instead of being answered by the handler.
    responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>>,
}
36
37impl TestPeer {
38 pub fn new() -> (Self, Builder<Client, impl HandleDispatchFrom<Agent>, NullRun>) {
43 let (sn_tx, sn_rx) = mpsc::unbounded_channel::<SessionNotification>();
44 let (mcp_tx, mcp_rx) = mpsc::unbounded_channel::<McpNotification>();
45 let (el_tx, el_rx) = mpsc::unbounded_channel::<ElicitationParams>();
46 let elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
47 let responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>> =
48 Arc::new(Mutex::new(None));
49
50 let builder = Client
51 .builder()
52 .on_receive_notification(
53 {
54 let tx = sn_tx;
55 async move |n: SessionNotification, _cx| {
56 let _ = tx.send(n);
57 Ok(())
58 }
59 },
60 acp::on_receive_notification!(),
61 )
62 .on_receive_notification(
63 {
64 let tx = mcp_tx;
65 async move |n: McpNotification, _cx| {
66 let _ = tx.send(n);
67 Ok(())
68 }
69 },
70 acp::on_receive_notification!(),
71 )
72 .on_receive_request(
73 {
74 let tx = el_tx;
75 let responses = elicitation_responses.clone();
76 let capture = responder_capture.clone();
77 async move |req: ElicitationParams, responder: Responder<ElicitationResponse>, _cx| {
78 if let Some(capture_tx) = capture.lock().unwrap().take() {
79 return match capture_tx.send(responder) {
80 Ok(()) => Ok(()),
81 Err(responder) => responder.respond_with_error(acp::Error::internal_error()),
82 };
83 }
84 let _ = tx.send(req);
85 let queued = responses.lock().unwrap().pop_front();
86 match queued {
87 Some(response) => responder.respond(response),
88 None => responder.respond_with_error(acp::Error::method_not_found()),
89 }
90 }
91 },
92 acp::on_receive_request!(),
93 );
94
95 let peer = Self {
96 session_notifications: sn_rx,
97 mcp_notifications: mcp_rx,
98 elicitation_requests: el_rx,
99 elicitation_responses,
100 responder_capture,
101 };
102 (peer, builder)
103 }
104
105 pub async fn next_session_notification(&mut self) -> SessionNotification {
106 self.session_notifications.recv().await.expect("peer channel closed")
107 }
108
109 pub async fn next_mcp_notification(&mut self) -> McpNotification {
110 self.mcp_notifications.recv().await.expect("peer channel closed")
111 }
112
113 pub async fn next_elicitation_request(&mut self) -> ElicitationParams {
114 self.elicitation_requests.recv().await.expect("peer channel closed")
115 }
116
117 pub fn queue_elicitation_response(&self, response: ElicitationResponse) {
122 self.elicitation_responses.lock().unwrap().push_back(response);
123 }
124
125 pub async fn fake_elicitation(
134 &mut self,
135 cx: &ConnectionTo<Client>,
136 ) -> (Responder<ElicitationResponse>, oneshot::Receiver<ElicitationResponse>) {
137 let (responder_tx, responder_rx) = oneshot::channel::<Responder<ElicitationResponse>>();
138 *self.responder_capture.lock().unwrap() = Some(responder_tx);
139
140 let (response_tx, response_rx) = oneshot::channel::<ElicitationResponse>();
141 let cx = cx.clone();
142 spawn_local(async move {
143 if let Ok(resp) = cx.send_request(placeholder_params()).block_task().await {
144 let _ = response_tx.send(resp);
145 }
146 });
147
148 let responder = responder_rx.await.expect("client handler must capture responder");
149 (responder, response_rx)
150 }
151}
152
153pub fn duplex_pair() -> (DuplexByteStreams, DuplexByteStreams) {
157 let (agent_writer, client_reader) = tokio::io::duplex(4096);
158 let (client_writer, agent_reader) = tokio::io::duplex(4096);
159 let agent_transport = ByteStreams::new(agent_writer.compat_write(), agent_reader.compat());
160 let client_transport = ByteStreams::new(client_writer.compat_write(), client_reader.compat());
161 (agent_transport, client_transport)
162}
163
164pub async fn test_connection() -> (ConnectionTo<Client>, TestPeer) {
167 let (peer, client_builder) = TestPeer::new();
168 let (agent_transport, client_transport) = duplex_pair();
169
170 spawn_local(async move {
171 let _ = client_builder.connect_to(client_transport).await;
172 });
173
174 let (cx_tx, cx_rx) = oneshot::channel::<ConnectionTo<Client>>();
175 spawn_local(async move {
176 let _ = Agent
177 .builder()
178 .connect_with(agent_transport, async move |cx: ConnectionTo<Client>| {
179 let _ = cx_tx.send(cx);
180 std::future::pending::<()>().await;
181 Ok(())
182 })
183 .await;
184 });
185
186 let cx = cx_rx.await.expect("agent side connect_with produced a ConnectionTo");
187 (cx, peer)
188}
189
190fn placeholder_params() -> ElicitationParams {
191 ElicitationParams {
192 server_name: String::new(),
193 request: CreateElicitationRequestParams::FormElicitationParams {
194 meta: None,
195 message: String::new(),
196 requested_schema: ElicitationSchema::builder().build().expect("empty schema is valid"),
197 },
198 }
199}