Skip to main content

acp_utils/
testing.rs

//! Duplex-backed test harness for ACP connections.
//!
//! [`test_connection`] returns a full `(ConnectionTo<Client>, TestPeer)` pair
//! over an in-memory duplex transport. Use it for integration-style tests that
//! need to exercise the full serialize/dispatch path (so wire-format
//! regressions like extension method-name typos surface in tests).
//!
//! When a test needs to pass a real [`Responder<ElicitationResponse>`] into a
//! component under test (e.g. an elicitation UI) and observe what that
//! component eventually sends, call [`TestPeer::fake_elicitation`]: it kicks
//! off a placeholder elicitation request, hands back the captured responder,
//! and returns a receiver that resolves when the responder is consumed.
13
14use crate::notifications::{ElicitationParams, ElicitationResponse, McpNotification};
15use agent_client_protocol::schema::SessionNotification;
16use agent_client_protocol::{
17    self as acp, Agent, Builder, ByteStreams, Client, ConnectionTo, HandleDispatchFrom, NullRun, Responder,
18};
19use rmcp::model::{CreateElicitationRequestParams, ElicitationSchema};
20use std::collections::VecDeque;
21use std::sync::{Arc, Mutex};
22use tokio::io::DuplexStream;
23use tokio::sync::{mpsc, oneshot};
24use tokio::task::spawn_local;
25use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
26
27pub type DuplexByteStreams = ByteStreams<Compat<DuplexStream>, Compat<DuplexStream>>;
28
29pub struct TestPeer {
30    session_notifications: mpsc::UnboundedReceiver<SessionNotification>,
31    mcp_notifications: mpsc::UnboundedReceiver<McpNotification>,
32    elicitation_requests: mpsc::UnboundedReceiver<ElicitationParams>,
33    elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>>,
34    responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>>,
35}
36
37impl TestPeer {
38    /// Build a `TestPeer` plus a pre-wired `Client.builder()` whose
39    /// notification handlers route session/mcp/elicitation traffic into the
40    /// peer. The caller decides whether to run the builder via `connect_to`
41    /// (drop the agent-side cx) or `connect_with` (capture the agent-side cx).
42    pub fn new() -> (Self, Builder<Client, impl HandleDispatchFrom<Agent>, NullRun>) {
43        let (sn_tx, sn_rx) = mpsc::unbounded_channel::<SessionNotification>();
44        let (mcp_tx, mcp_rx) = mpsc::unbounded_channel::<McpNotification>();
45        let (el_tx, el_rx) = mpsc::unbounded_channel::<ElicitationParams>();
46        let elicitation_responses: Arc<Mutex<VecDeque<ElicitationResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
47        let responder_capture: Arc<Mutex<Option<oneshot::Sender<Responder<ElicitationResponse>>>>> =
48            Arc::new(Mutex::new(None));
49
50        let builder = Client
51            .builder()
52            .on_receive_notification(
53                {
54                    let tx = sn_tx;
55                    async move |n: SessionNotification, _cx| {
56                        let _ = tx.send(n);
57                        Ok(())
58                    }
59                },
60                acp::on_receive_notification!(),
61            )
62            .on_receive_notification(
63                {
64                    let tx = mcp_tx;
65                    async move |n: McpNotification, _cx| {
66                        let _ = tx.send(n);
67                        Ok(())
68                    }
69                },
70                acp::on_receive_notification!(),
71            )
72            .on_receive_request(
73                {
74                    let tx = el_tx;
75                    let responses = elicitation_responses.clone();
76                    let capture = responder_capture.clone();
77                    async move |req: ElicitationParams, responder: Responder<ElicitationResponse>, _cx| {
78                        if let Some(capture_tx) = capture.lock().unwrap().take() {
79                            return match capture_tx.send(responder) {
80                                Ok(()) => Ok(()),
81                                Err(responder) => responder.respond_with_error(acp::Error::internal_error()),
82                            };
83                        }
84                        let _ = tx.send(req);
85                        let queued = responses.lock().unwrap().pop_front();
86                        match queued {
87                            Some(response) => responder.respond(response),
88                            None => responder.respond_with_error(acp::Error::method_not_found()),
89                        }
90                    }
91                },
92                acp::on_receive_request!(),
93            );
94
95        let peer = Self {
96            session_notifications: sn_rx,
97            mcp_notifications: mcp_rx,
98            elicitation_requests: el_rx,
99            elicitation_responses,
100            responder_capture,
101        };
102        (peer, builder)
103    }
104
105    pub async fn next_session_notification(&mut self) -> SessionNotification {
106        self.session_notifications.recv().await.expect("peer channel closed")
107    }
108
109    pub async fn next_mcp_notification(&mut self) -> McpNotification {
110        self.mcp_notifications.recv().await.expect("peer channel closed")
111    }
112
113    pub async fn next_elicitation_request(&mut self) -> ElicitationParams {
114        self.elicitation_requests.recv().await.expect("peer channel closed")
115    }
116
117    /// Queue a response the peer will hand back for the next incoming
118    /// elicitation request. If the queue is empty when a request arrives, the
119    /// peer responds with a protocol error, which exercises the
120    /// `cancel_result()` fallback path in the caller.
121    pub fn queue_elicitation_response(&self, response: ElicitationResponse) {
122        self.elicitation_responses.lock().unwrap().push_back(response);
123    }
124
125    /// Kick off a placeholder elicitation request from the agent side of `cx`,
126    /// hand back the [`Responder<ElicitationResponse>`] captured on the client
127    /// side, and return a receiver that resolves when the responder is
128    /// consumed.
129    ///
130    /// Use in tests that pass a `Responder<ElicitationResponse>` into code
131    /// under test and want to observe the response without driving a full ACP
132    /// round-trip themselves.
133    pub async fn fake_elicitation(
134        &mut self,
135        cx: &ConnectionTo<Client>,
136    ) -> (Responder<ElicitationResponse>, oneshot::Receiver<ElicitationResponse>) {
137        let (responder_tx, responder_rx) = oneshot::channel::<Responder<ElicitationResponse>>();
138        *self.responder_capture.lock().unwrap() = Some(responder_tx);
139
140        let (response_tx, response_rx) = oneshot::channel::<ElicitationResponse>();
141        let cx = cx.clone();
142        spawn_local(async move {
143            if let Ok(resp) = cx.send_request(placeholder_params()).block_task().await {
144                let _ = response_tx.send(resp);
145            }
146        });
147
148        let responder = responder_rx.await.expect("client handler must capture responder");
149        (responder, response_rx)
150    }
151}
152
153/// In-memory ACP transport pair: `(agent_transport, client_transport)`. Hand
154/// each half to a `connect_to` / `connect_with` call on the corresponding
155/// side. Must be used inside a `LocalSet` since the runners are `spawn_local`'d.
156pub fn duplex_pair() -> (DuplexByteStreams, DuplexByteStreams) {
157    let (agent_writer, client_reader) = tokio::io::duplex(4096);
158    let (client_writer, agent_reader) = tokio::io::duplex(4096);
159    let agent_transport = ByteStreams::new(agent_writer.compat_write(), agent_reader.compat());
160    let client_transport = ByteStreams::new(client_writer.compat_write(), client_reader.compat());
161    (agent_transport, client_transport)
162}
163
164/// Build a live `ConnectionTo<Client>` over an in-memory duplex transport with
165/// a peer on the other end. Must be called inside a `LocalSet`.
166pub async fn test_connection() -> (ConnectionTo<Client>, TestPeer) {
167    let (peer, client_builder) = TestPeer::new();
168    let (agent_transport, client_transport) = duplex_pair();
169
170    spawn_local(async move {
171        let _ = client_builder.connect_to(client_transport).await;
172    });
173
174    let (cx_tx, cx_rx) = oneshot::channel::<ConnectionTo<Client>>();
175    spawn_local(async move {
176        let _ = Agent
177            .builder()
178            .connect_with(agent_transport, async move |cx: ConnectionTo<Client>| {
179                let _ = cx_tx.send(cx);
180                std::future::pending::<()>().await;
181                Ok(())
182            })
183            .await;
184    });
185
186    let cx = cx_rx.await.expect("agent side connect_with produced a ConnectionTo");
187    (cx, peer)
188}
189
190fn placeholder_params() -> ElicitationParams {
191    ElicitationParams {
192        server_name: String::new(),
193        request: CreateElicitationRequestParams::FormElicitationParams {
194            meta: None,
195            message: String::new(),
196            requested_schema: ElicitationSchema::builder().build().expect("empty schema is valid"),
197        },
198    }
199}