1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7 StreamExt as _,
8 channel::{
9 mpsc::{self, UnboundedReceiver, UnboundedSender},
10 oneshot,
11 },
12 future::LocalBoxFuture,
13 io::BufReader,
14 select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22 collections::HashMap,
23 rc::Rc,
24 sync::{
25 Arc,
26 atomic::{AtomicI32, Ordering::SeqCst},
27 },
28};
29
/// Handle used to talk to an agent.
///
/// Wraps a [`Connection`] that receives [`AnyClientRequest`]s from the peer and sends
/// [`AnyAgentRequest`]s to it.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
32
/// Handle used to talk to a client.
///
/// Wraps a [`Connection`] that receives [`AnyAgentRequest`]s from the peer and sends
/// [`AnyClientRequest`]s to it.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37 pub fn connect_to_agent<H: 'static + Client>(
40 handler: H,
41 outgoing_bytes: impl Unpin + AsyncWrite,
42 incoming_bytes: impl Unpin + AsyncRead,
43 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44 ) -> (Self, impl Future<Output = Result<(), Error>>) {
45 let handler = Arc::new(handler);
46 let (connection, io_task) = Connection::new(
47 Box::new(move |request| {
48 let handler = handler.clone();
49 async move { handler.call(request).await }.boxed_local()
50 }),
51 outgoing_bytes,
52 incoming_bytes,
53 spawn,
54 );
55 (Self(connection), io_task)
56 }
57
58 pub fn request<R: AgentRequest + 'static>(
60 &self,
61 params: R,
62 ) -> impl Future<Output = Result<R::Response, Error>> {
63 let params = params.into_any();
64 let result = self.0.request(params.method_name(), params);
65 async move {
66 let result = result.await?;
67 R::response_from_any(result)
68 }
69 }
70
71 pub fn request_any(
73 &self,
74 params: AnyAgentRequest,
75 ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76 self.0.request(params.method_name(), params)
77 }
78
79 pub async fn initialize(&self) -> Result<InitializeResponse, Error> {
82 let protocol_version = ProtocolVersion::latest();
83 let version_requirement = Comparator {
84 op: semver::Op::Caret,
85 major: protocol_version.major,
86 minor: Some(protocol_version.minor),
87 patch: Some(protocol_version.patch),
88 pre: protocol_version.pre.clone(),
89 };
90 let response = self.request(InitializeParams { protocol_version }).await?;
91
92 let server_version = &response.protocol_version;
93
94 if version_requirement.matches(server_version) {
95 Ok(response)
96 } else {
97 Err(Error::invalid_request().with_data(format!(
98 "Incompatible versions: Server {server_version} / Client: {version_requirement}"
99 )))
100 }
101 }
102}
103
104impl ClientConnection {
105 pub fn connect_to_client<H: 'static + Agent>(
106 handler: H,
107 outgoing_bytes: impl Unpin + AsyncWrite,
108 incoming_bytes: impl Unpin + AsyncRead,
109 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
110 ) -> (Self, impl Future<Output = Result<(), Error>>) {
111 let handler = Arc::new(handler);
112 let (connection, io_task) = Connection::new(
113 Box::new(move |request| {
114 let handler = handler.clone();
115 async move { handler.call(request).await }.boxed_local()
116 }),
117 outgoing_bytes,
118 incoming_bytes,
119 spawn,
120 );
121 (Self(connection), io_task)
122 }
123
124 pub fn request<R: ClientRequest>(
125 &self,
126 params: R,
127 ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
128 let params = params.into_any();
129 let result = self.0.request(params.method_name(), params);
130 async move {
131 let result = result.await?;
132 R::response_from_any(result)
133 }
134 }
135
136 pub fn request_any(
138 &self,
139 method: &'static str,
140 params: AnyClientRequest,
141 ) -> impl Future<Output = Result<AnyClientResult, Error>> {
142 self.0.request(method, params)
143 }
144}
145
/// State shared by one side of a bidirectional request/response connection.
///
/// `In` is the type of the requests received from the peer and `Out` the type of the
/// requests sent to it.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue of messages (our requests and our responses) waiting to be written to the peer.
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Requests that were sent and are still waiting for the peer's response, keyed by id.
    response_senders: ResponseSenders<Out::Response>,
    /// Id that will be assigned to the next outgoing request.
    next_id: AtomicI32,
}
155
/// Table of in-flight outgoing requests: request id -> (method name, channel on which the
/// response or error for that request is delivered).
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
158
/// One message as read from the peer, borrowing from the received line.
///
/// The same shape covers requests and responses; the I/O loop tells them apart by which
/// optional fields are present (`method` first, then `error`, otherwise `result`).
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    /// Request id; for a response, the id of the request being answered.
    id: i32,
    /// Present when the message is a request.
    method: Option<&'a str>,
    /// Raw, not yet decoded request parameters.
    params: Option<&'a RawValue>,
    /// Raw, not yet decoded result of a successful response.
    result: Option<&'a RawValue>,
    /// Present when the message is an error response.
    error: Option<Error>,
}
167
/// One message to be written to the peer.
///
/// Serialized untagged, so each variant appears on the wire as just its fields.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request sent to the peer.
    Request {
        id: i32,
        method: Box<str>,
        // Left out of the output entirely when absent or when it would serialize to `null`.
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    /// A successful answer to the peer's request `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed answer to the peer's request `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
186
187fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
188 match opt {
189 None => true,
190 Some(value) => {
191 matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
192 }
193 }
194}
195
/// Value of the `jsonrpc` field attached to outgoing messages.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    /// Serialized as the string `"2.0"`.
    #[serde(rename = "2.0")]
    V2,
}
201
/// Wire envelope for an outgoing message: the `jsonrpc` version marker plus the message's
/// own fields, flattened into the same JSON object.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
208
/// Callback that turns one incoming request into a (non-`Send`) future yielding its
/// response or an error.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
211
212impl<In, Out> Connection<In, Out>
213where
214 In: AnyRequest,
215 Out: AnyRequest,
216{
217 fn new(
218 request_handler: ResponseHandler<In, In::Response>,
219 outgoing_bytes: impl Unpin + AsyncWrite,
220 incoming_bytes: impl Unpin + AsyncRead,
221 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
222 ) -> (Self, impl Future<Output = Result<(), Error>>) {
223 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
224 let (incoming_tx, incoming_rx) = mpsc::unbounded();
225 let this = Self {
226 response_senders: ResponseSenders::default(),
227 outgoing_tx: outgoing_tx.clone(),
228 next_id: AtomicI32::new(0),
229 };
230 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
231 let io_task = Self::handle_io(
232 outgoing_rx,
233 incoming_tx,
234 this.response_senders.clone(),
235 outgoing_bytes,
236 incoming_bytes,
237 );
238 (this, io_task)
239 }
240
241 fn request(
242 &self,
243 method: &'static str,
244 params: Out,
245 ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
246 let (tx, rx) = oneshot::channel();
247 let id = self.next_id.fetch_add(1, SeqCst);
248 self.response_senders.lock().insert(id, (method, tx));
249 if self
250 .outgoing_tx
251 .unbounded_send(OutgoingMessage::Request {
252 id,
253 method: method.into(),
254 params: Some(params),
255 })
256 .is_err()
257 {
258 self.response_senders.lock().remove(&id);
259 }
260 async move {
261 rx.await
262 .map_err(|e| Error::internal_error().with_data(e.to_string()))?
263 }
264 }
265
266 async fn handle_io(
267 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
268 incoming_tx: UnboundedSender<(i32, In)>,
269 response_senders: ResponseSenders<Out::Response>,
270 mut outgoing_bytes: impl Unpin + AsyncWrite,
271 incoming_bytes: impl Unpin + AsyncRead,
272 ) -> Result<(), Error> {
273 let mut output_reader = BufReader::new(incoming_bytes);
274 let mut outgoing_line = Vec::new();
275 let mut incoming_line = String::new();
276 loop {
277 select_biased! {
278 message = outgoing_rx.next() => {
279 if let Some(message) = message {
280 let message = OutJsonRpcMessage {
281 jsonrpc: JsonSchemaVersion::V2,
282 message,
283 };
284 outgoing_line.clear();
285 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
286 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
287 outgoing_line.push(b'\n');
288 outgoing_bytes.write_all(&outgoing_line).await.ok();
289 } else {
290 break;
291 }
292 }
293 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
294 if bytes_read.map_err(Error::into_internal_error)? == 0 {
295 break
296 }
297 log::trace!("recv: {}", &incoming_line);
298 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
299 Ok(IncomingMessage { id, method, params, result, error }) => {
300 if let Some(method) = method {
301 match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
302 Ok(params) => {
303 incoming_tx.unbounded_send((id, params)).ok();
304 }
305 Err(error) => {
306 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
307 }
308 }
309 } else if let Some(error) = error {
310 if let Some((_, tx)) = response_senders.lock().remove(&id) {
311 tx.send(Err(error)).ok();
312 }
313 } else {
314 let result = result.unwrap_or(RawValue::NULL);
315 if let Some((method, tx)) = response_senders.lock().remove(&id) {
316 match Out::response_from_method_and_result(method, result) {
317 Ok(result) => {
318 tx.send(Ok(result)).ok();
319 }
320 Err(error) => {
321 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
322 }
323 }
324 }
325 }
326 }
327 Err(error) => {
328 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
329 }
330 }
331 incoming_line.clear();
332 }
333 }
334 }
335 response_senders.lock().clear();
336 Ok(())
337 }
338
339 fn handle_incoming(
340 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
341 mut incoming_rx: UnboundedReceiver<(i32, In)>,
342 incoming_handler: ResponseHandler<In, In::Response>,
343 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
344 ) {
345 let spawn = Rc::new(spawn);
346 let spawn2 = spawn.clone();
347 spawn(
348 async move {
349 while let Some((id, params)) = incoming_rx.next().await {
350 let result = incoming_handler(params);
351 let outgoing_tx = outgoing_tx.clone();
352 spawn2(
353 async move {
354 let result = result.await;
355 match result {
356 Ok(result) => {
357 outgoing_tx
358 .unbounded_send(OutgoingMessage::OkResponse { id, result })
359 .ok();
360 }
361 Err(error) => {
362 outgoing_tx
363 .unbounded_send(OutgoingMessage::ErrorResponse {
364 id,
365 error: Error::into_internal_error(error),
366 })
367 .ok();
368 }
369 }
370 }
371 .boxed_local(),
372 )
373 }
374 }
375 .boxed_local(),
376 )
377 }
378}