1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7 StreamExt as _,
8 channel::{
9 mpsc::{self, UnboundedReceiver, UnboundedSender},
10 oneshot,
11 },
12 future::LocalBoxFuture,
13 io::BufReader,
14 select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22 collections::HashMap,
23 rc::Rc,
24 sync::{
25 Arc,
26 atomic::{AtomicI32, Ordering::SeqCst},
27 },
28};
29
/// Connection to an agent.
///
/// Outgoing requests are `AnyAgentRequest`s. Incoming `AnyClientRequest`s from the
/// agent are answered by the `Client` handler passed to
/// `AgentConnection::connect_to_agent`.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// Connection to a client.
///
/// Outgoing requests are `AnyClientRequest`s. Incoming `AnyAgentRequest`s from the
/// client are answered by the `Agent` handler passed to
/// `ClientConnection::connect_to_client`.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37 pub fn connect_to_agent<H: 'static + Client>(
40 handler: H,
41 outgoing_bytes: impl Unpin + AsyncWrite,
42 incoming_bytes: impl Unpin + AsyncRead,
43 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44 ) -> (Self, impl Future<Output = Result<(), Error>>) {
45 let handler = Arc::new(handler);
46 let (connection, io_task) = Connection::new(
47 Box::new(move |request| {
48 let handler = handler.clone();
49 async move { handler.call(request).await }.boxed_local()
50 }),
51 outgoing_bytes,
52 incoming_bytes,
53 spawn,
54 );
55 (Self(connection), io_task)
56 }
57
58 pub fn request<R: AgentRequest + 'static>(
60 &self,
61 params: R,
62 ) -> impl Future<Output = Result<R::Response, Error>> {
63 let params = params.into_any();
64 let result = self.0.request(params.method_name(), params);
65 async move {
66 let result = result.await?;
67 R::response_from_any(result)
68 }
69 }
70
71 pub fn request_any(
73 &self,
74 params: AnyAgentRequest,
75 ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76 self.0.request(params.method_name(), params)
77 }
78
79 pub async fn initialize(&self) -> Result<InitializeResponse, Error> {
82 let protocol_version = ProtocolVersion::latest();
83 let version_requirement = Comparator {
85 op: semver::Op::Caret,
86 major: protocol_version.major,
87 minor: (protocol_version.major == 0).then_some(protocol_version.minor),
88 patch: (protocol_version.major == 0 && protocol_version.minor == 0)
89 .then_some(protocol_version.patch),
90 pre: protocol_version.pre.clone(),
91 };
92 let response = self.request(InitializeParams { protocol_version }).await?;
93
94 let server_version = &response.protocol_version;
95
96 if version_requirement.matches(server_version) {
97 Ok(response)
98 } else {
99 Err(Error::invalid_request().with_data(format!(
100 "Incompatible versions: Server {server_version} / Client: {version_requirement}"
101 )))
102 }
103 }
104}
105
106impl ClientConnection {
107 pub fn connect_to_client<H: 'static + Agent>(
108 handler: H,
109 outgoing_bytes: impl Unpin + AsyncWrite,
110 incoming_bytes: impl Unpin + AsyncRead,
111 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
112 ) -> (Self, impl Future<Output = Result<(), Error>>) {
113 let handler = Arc::new(handler);
114 let (connection, io_task) = Connection::new(
115 Box::new(move |request| {
116 let handler = handler.clone();
117 async move { handler.call(request).await }.boxed_local()
118 }),
119 outgoing_bytes,
120 incoming_bytes,
121 spawn,
122 );
123 (Self(connection), io_task)
124 }
125
126 pub fn request<R: ClientRequest>(
127 &self,
128 params: R,
129 ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
130 let params = params.into_any();
131 let result = self.0.request(params.method_name(), params);
132 async move {
133 let result = result.await?;
134 R::response_from_any(result)
135 }
136 }
137
138 pub fn request_any(
140 &self,
141 method: &'static str,
142 params: AnyClientRequest,
143 ) -> impl Future<Output = Result<AnyClientResult, Error>> {
144 self.0.request(method, params)
145 }
146}
147
/// Shared JSON-RPC connection state wrapped by both `AgentConnection` and
/// `ClientConnection`.
///
/// `In` is the request type this side receives and answers. `Out` is the request
/// type this side sends. The struct only holds the sending side; reading and
/// writing happen in the I/O future returned by `Connection::new`.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    // Queue of messages to write: our requests and our responses to peer requests.
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    // Outgoing requests still waiting for the peer's response, keyed by id.
    response_senders: ResponseSenders<Out::Response>,
    // Id assigned to the next outgoing request (incremented per request).
    next_id: AtomicI32,
}
157
/// Pending outgoing requests, keyed by JSON-RPC request id.
///
/// Each entry stores the request's method name and the oneshot sender that
/// completes the caller's future. The method name is needed because a response
/// does not repeat it, yet decoding the result depends on it. The map is shared
/// between `Connection::request` and the I/O future.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
160
/// A raw incoming JSON-RPC message, parsed just enough to route it.
///
/// A message with a `method` is a request from the peer. Otherwise it is a
/// response to one of our requests, carrying either `error` or `result`.
/// `params` and `result` stay unparsed (`RawValue`) until the method is known.
///
/// NOTE(review): `id` is required, so id-less messages (JSON-RPC notifications)
/// fail to deserialize and are only logged. Also, serde_json cannot deserialize a
/// borrowed `&str` from a JSON string containing escape sequences, so such a
/// `method` would be rejected.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
169
/// A JSON-RPC message queued for sending. The `jsonrpc` field is not included
/// here; it is added by wrapping the message in `OutJsonRpcMessage`.
///
/// `#[serde(untagged)]` serializes only the variant's fields, so the variant name
/// never appears on the wire.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    // A request we send to the peer.
    Request {
        id: i32,
        method: Box<str>,
        // Omitted entirely when absent or when it serializes to JSON `null`.
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    // A successful response to a request the peer sent us.
    OkResponse {
        id: i32,
        result: Resp,
    },
    // An error response to a request the peer sent us.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
188
189fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
190 match opt {
191 None => true,
192 Some(value) => {
193 matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
194 }
195 }
196}
197
/// The JSON-RPC protocol version tag, serialized as the string `"2.0"` in each
/// outgoing message's `jsonrpc` field.
///
/// NOTE(review): despite the name, this is the JSON-RPC version, not a JSON Schema
/// version.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    #[serde(rename = "2.0")]
    V2,
}
203
/// Wire form of an outgoing message: the `jsonrpc` version tag plus the message's
/// own fields, flattened into a single JSON object.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
210
/// Callback invoked for each incoming request of type `In`.
///
/// It returns a future resolving to the response value (`Resp`) or to an `Error`
/// to send back to the peer. NOTE(review): despite the name, this handles
/// incoming *requests* and produces their responses.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
213
214impl<In, Out> Connection<In, Out>
215where
216 In: AnyRequest,
217 Out: AnyRequest,
218{
219 fn new(
220 request_handler: ResponseHandler<In, In::Response>,
221 outgoing_bytes: impl Unpin + AsyncWrite,
222 incoming_bytes: impl Unpin + AsyncRead,
223 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
224 ) -> (Self, impl Future<Output = Result<(), Error>>) {
225 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
226 let (incoming_tx, incoming_rx) = mpsc::unbounded();
227 let this = Self {
228 response_senders: ResponseSenders::default(),
229 outgoing_tx: outgoing_tx.clone(),
230 next_id: AtomicI32::new(0),
231 };
232 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
233 let io_task = Self::handle_io(
234 outgoing_rx,
235 incoming_tx,
236 this.response_senders.clone(),
237 outgoing_bytes,
238 incoming_bytes,
239 );
240 (this, io_task)
241 }
242
243 fn request(
244 &self,
245 method: &'static str,
246 params: Out,
247 ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
248 let (tx, rx) = oneshot::channel();
249 let id = self.next_id.fetch_add(1, SeqCst);
250 self.response_senders.lock().insert(id, (method, tx));
251 if self
252 .outgoing_tx
253 .unbounded_send(OutgoingMessage::Request {
254 id,
255 method: method.into(),
256 params: Some(params),
257 })
258 .is_err()
259 {
260 self.response_senders.lock().remove(&id);
261 }
262 async move {
263 rx.await
264 .map_err(|e| Error::internal_error().with_data(e.to_string()))?
265 }
266 }
267
268 async fn handle_io(
269 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
270 incoming_tx: UnboundedSender<(i32, In)>,
271 response_senders: ResponseSenders<Out::Response>,
272 mut outgoing_bytes: impl Unpin + AsyncWrite,
273 incoming_bytes: impl Unpin + AsyncRead,
274 ) -> Result<(), Error> {
275 let mut output_reader = BufReader::new(incoming_bytes);
276 let mut outgoing_line = Vec::new();
277 let mut incoming_line = String::new();
278 loop {
279 select_biased! {
280 message = outgoing_rx.next() => {
281 if let Some(message) = message {
282 let message = OutJsonRpcMessage {
283 jsonrpc: JsonSchemaVersion::V2,
284 message,
285 };
286 outgoing_line.clear();
287 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
288 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
289 outgoing_line.push(b'\n');
290 outgoing_bytes.write_all(&outgoing_line).await.ok();
291 } else {
292 break;
293 }
294 }
295 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
296 if bytes_read.map_err(Error::into_internal_error)? == 0 {
297 break
298 }
299 log::trace!("recv: {}", &incoming_line);
300 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
301 Ok(IncomingMessage { id, method, params, result, error }) => {
302 if let Some(method) = method {
303 match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
304 Ok(params) => {
305 incoming_tx.unbounded_send((id, params)).ok();
306 }
307 Err(error) => {
308 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
309 }
310 }
311 } else if let Some(error) = error {
312 if let Some((_, tx)) = response_senders.lock().remove(&id) {
313 tx.send(Err(error)).ok();
314 }
315 } else {
316 let result = result.unwrap_or(RawValue::NULL);
317 if let Some((method, tx)) = response_senders.lock().remove(&id) {
318 match Out::response_from_method_and_result(method, result) {
319 Ok(result) => {
320 tx.send(Ok(result)).ok();
321 }
322 Err(error) => {
323 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
324 }
325 }
326 }
327 }
328 }
329 Err(error) => {
330 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
331 }
332 }
333 incoming_line.clear();
334 }
335 }
336 }
337 response_senders.lock().clear();
338 Ok(())
339 }
340
341 fn handle_incoming(
342 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
343 mut incoming_rx: UnboundedReceiver<(i32, In)>,
344 incoming_handler: ResponseHandler<In, In::Response>,
345 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
346 ) {
347 let spawn = Rc::new(spawn);
348 let spawn2 = spawn.clone();
349 spawn(
350 async move {
351 while let Some((id, params)) = incoming_rx.next().await {
352 let result = incoming_handler(params);
353 let outgoing_tx = outgoing_tx.clone();
354 spawn2(
355 async move {
356 let result = result.await;
357 match result {
358 Ok(result) => {
359 outgoing_tx
360 .unbounded_send(OutgoingMessage::OkResponse { id, result })
361 .ok();
362 }
363 Err(error) => {
364 outgoing_tx
365 .unbounded_send(OutgoingMessage::ErrorResponse {
366 id,
367 error: Error::into_internal_error(error),
368 })
369 .ok();
370 }
371 }
372 }
373 .boxed_local(),
374 )
375 }
376 }
377 .boxed_local(),
378 )
379 }
380}