1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7 StreamExt as _,
8 channel::{
9 mpsc::{self, UnboundedReceiver, UnboundedSender},
10 oneshot,
11 },
12 future::LocalBoxFuture,
13 io::BufReader,
14 select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use serde::{Deserialize, Serialize};
19use serde_json::value::RawValue;
20use std::{
21 collections::HashMap,
22 fmt::Display,
23 rc::Rc,
24 sync::{
25 Arc,
26 atomic::{AtomicI32, Ordering::SeqCst},
27 },
28};
29
/// A connection to an agent, held by the side acting as the client.
///
/// Requests sent through it are `AnyAgentRequest`s. Requests arriving from the peer
/// (`AnyClientRequest`s) are answered by the [`Client`] handler given to
/// [`AgentConnection::connect_to_agent`].
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// A connection to a client, held by the side acting as the agent.
///
/// Requests sent through it are `AnyClientRequest`s. Requests arriving from the peer
/// (`AnyAgentRequest`s) are answered by the [`Agent`] handler given to
/// [`ClientConnection::connect_to_client`].
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37 pub fn connect_to_agent<H: 'static + Client>(
40 handler: H,
41 outgoing_bytes: impl Unpin + AsyncWrite,
42 incoming_bytes: impl Unpin + AsyncRead,
43 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44 ) -> (Self, impl Future<Output = Result<(), Error>>) {
45 let handler = Arc::new(handler);
46 let (connection, io_task) = Connection::new(
47 Box::new(move |request| {
48 let handler = handler.clone();
49 async move { handler.call(request).await }.boxed_local()
50 }),
51 outgoing_bytes,
52 incoming_bytes,
53 spawn,
54 );
55 (Self(connection), io_task)
56 }
57
58 pub fn request<R: AgentRequest + 'static>(
60 &self,
61 params: R,
62 ) -> impl Future<Output = Result<R::Response, Error>> {
63 let params = params.into_any();
64 let result = self.0.request(params.method_name(), params);
65 async move {
66 let result = result.await?;
67 R::response_from_any(result)
68 }
69 }
70}
71
72impl ClientConnection {
73 pub fn connect_to_client<H: 'static + Agent>(
74 handler: H,
75 outgoing_bytes: impl Unpin + AsyncWrite,
76 incoming_bytes: impl Unpin + AsyncRead,
77 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
78 ) -> (Self, impl Future<Output = Result<(), Error>>) {
79 let handler = Arc::new(handler);
80 let (connection, io_task) = Connection::new(
81 Box::new(move |request| {
82 let handler = handler.clone();
83 async move { handler.call(request).await }.boxed_local()
84 }),
85 outgoing_bytes,
86 incoming_bytes,
87 spawn,
88 );
89 (Self(connection), io_task)
90 }
91
92 pub fn request<R: ClientRequest>(
93 &self,
94 params: R,
95 ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
96 let params = params.into_any();
97 let result = self.0.request(params.method_name(), params);
98 async move {
99 let result = result.await?;
100 R::response_from_any(result)
101 }
102 }
103}
104
/// Bidirectional request/response connection over a pair of byte streams, shared by
/// [`AgentConnection`] and [`ClientConnection`].
///
/// `In` is the type of requests received from the peer; `Out` is the type of requests sent
/// to it.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue read by the I/O task: our requests plus our responses to the peer's requests.
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Our requests that are still waiting for the peer's response, keyed by request id.
    response_senders: ResponseSenders<Out::Response>,
    /// Id handed to the next outgoing request.
    next_id: AtomicI32,
}
114
/// Requests awaiting a reply, keyed by request id.
///
/// Each entry holds the method name, which is needed to decode the raw result into the
/// right response type. It also holds the channel that completes the requester's future.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
117
/// One message received from the peer, borrowing from the text it was parsed from.
///
/// It is a request when `method` is present. Otherwise it is a response carrying either
/// `error` or `result`.
///
/// NOTE(review): `id` is mandatory, so an id-less message (e.g. a JSON-RPC notification)
/// fails to deserialize. `method` is a borrowed `&str`, which serde_json cannot produce for
/// strings containing escape sequences.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    /// Id of the peer's request, or of our request that this response answers.
    id: i32,
    /// Method name; present only on requests.
    method: Option<&'a str>,
    /// Raw request parameters, decoded later according to `method`.
    params: Option<&'a RawValue>,
    /// Raw success payload of a response.
    result: Option<&'a RawValue>,
    /// Error payload of a failed response.
    error: Option<Error>,
}
126
/// A message this side sends to the peer: either one of our requests, or a success or error
/// response to one of the peer's requests.
///
/// `#[serde(untagged)]` means only the variant's fields are serialized; no variant name
/// appears in the output.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request we are making; `id` is echoed back in the peer's response.
    Request {
        id: i32,
        method: Box<str>,
        params: Req,
    },
    /// Successful answer to the peer's request `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// Failed answer to the peer's request `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
144
/// A JSON-RPC style error: a numeric `code`, a `message`, and optional extra `data`.
///
/// `Error::parse_error`, `Error::invalid_request`, `Error::method_not_found`,
/// `Error::invalid_params` and `Error::internal_error` build the standard codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Error {
    /// Numeric error code.
    pub code: i32,
    /// Short description; `Display` shows the code instead when this is empty.
    pub message: String,
    /// Optional extra information; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<ErrorData>,
}
152
153impl Error {
154 pub fn new(code: i32, message: impl Into<String>) -> Self {
155 Error {
156 code,
157 message: message.into(),
158 data: None,
159 }
160 }
161
162 pub fn with_details(mut self, details: impl Into<String>) -> Self {
163 self.data = Some(ErrorData::new(details));
164 self
165 }
166
167 pub fn parse_error() -> Self {
169 Error::new(-32700, "Parse error")
170 }
171
172 pub fn invalid_request() -> Self {
174 Error::new(-32600, "Invalid Request")
175 }
176
177 pub fn method_not_found() -> Self {
179 Error::new(-32601, "Method not found")
180 }
181
182 pub fn invalid_params() -> Self {
184 Error::new(-32602, "Invalid params")
185 }
186
187 pub fn internal_error() -> Self {
189 Error::new(-32603, "Internal error")
190 }
191}
192
193impl std::error::Error for Error {}
194impl Display for Error {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 if self.message.is_empty() {
197 write!(f, "{}", self.code)?;
198 } else {
199 write!(f, "{}", self.message)?;
200 }
201
202 if let Some(data) = &self.data {
203 write!(f, ": {}", data.details)?;
204 }
205
206 Ok(())
207 }
208}
209
/// Extra information attached to an [`Error`] through its `data` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorData {
    /// Human-readable details; `Display for Error` appends them after the message.
    pub details: String,
}
214
215impl ErrorData {
216 pub fn new(details: impl Into<String>) -> Self {
217 ErrorData {
218 details: details.into(),
219 }
220 }
221}
222
/// Wire envelope that adds the `jsonrpc` version field to an [`OutgoingMessage`].
///
/// Because of `#[serde(flatten)]`, the inner message's fields are serialized into the same
/// JSON object as `jsonrpc`.
#[derive(Serialize)]
pub struct JsonRpcMessage<Req, Resp> {
    /// Protocol version string; presumably always `"2.0"` (TODO confirm at construction sites).
    pub jsonrpc: &'static str,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
229
/// Callback invoked with each request received from the peer. It returns a future that
/// resolves to the response, or error, to send back.
///
/// Despite the name, it handles incoming *requests* and produces the responses to them.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
232
233impl<In, Out> Connection<In, Out>
234where
235 In: AnyRequest,
236 Out: AnyRequest,
237{
238 fn new(
239 request_handler: ResponseHandler<In, In::Response>,
240 outgoing_bytes: impl Unpin + AsyncWrite,
241 incoming_bytes: impl Unpin + AsyncRead,
242 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
243 ) -> (Self, impl Future<Output = Result<(), Error>>) {
244 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
245 let (incoming_tx, incoming_rx) = mpsc::unbounded();
246 let this = Self {
247 response_senders: ResponseSenders::default(),
248 outgoing_tx: outgoing_tx.clone(),
249 next_id: AtomicI32::new(0),
250 };
251 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
252 let io_task = Self::handle_io(
253 outgoing_rx,
254 incoming_tx,
255 this.response_senders.clone(),
256 outgoing_bytes,
257 incoming_bytes,
258 );
259 (this, io_task)
260 }
261
262 fn request(
263 &self,
264 method: &'static str,
265 params: Out,
266 ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
267 let (tx, rx) = oneshot::channel();
268 let id = self.next_id.fetch_add(1, SeqCst);
269 self.response_senders.lock().insert(id, (method, tx));
270 if self
271 .outgoing_tx
272 .unbounded_send(OutgoingMessage::Request {
273 id,
274 method: method.into(),
275 params,
276 })
277 .is_err()
278 {
279 self.response_senders.lock().remove(&id);
280 }
281 async move {
282 rx.await
283 .map_err(|e| Error::internal_error().with_details(e.to_string()))?
284 }
285 }
286
287 async fn handle_io(
288 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
289 incoming_tx: UnboundedSender<(i32, In)>,
290 response_senders: ResponseSenders<Out::Response>,
291 mut outgoing_bytes: impl Unpin + AsyncWrite,
292 incoming_bytes: impl Unpin + AsyncRead,
293 ) -> Result<(), Error> {
294 let mut output_reader = BufReader::new(incoming_bytes);
295 let mut outgoing_line = Vec::new();
296 let mut incoming_line = String::new();
297 loop {
298 select_biased! {
299 message = outgoing_rx.next() => {
300 if let Some(message) = message {
301 outgoing_line.clear();
302 serde_json::to_writer(&mut outgoing_line, &message).map_err(|e| Error::internal_error()
303 .with_details(e.to_string()))?;
304 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
305 outgoing_line.push(b'\n');
306 outgoing_bytes.write_all(&outgoing_line).await.ok();
307 } else {
308 break;
309 }
310 }
311 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
312 if bytes_read.map_err(|e| Error::internal_error().with_details(e.to_string()))? == 0 {
313 break
314 }
315 log::trace!("recv: {}", &incoming_line);
316 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
317 Ok(message) => {
318 if let Some(method) = message.method {
319 match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
320 Ok(params) => {
321 incoming_tx.unbounded_send((message.id, params)).ok();
322 }
323 Err(error) => {
324 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
325 }
326 }
327 } else if let Some(error) = message.error {
328 if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
329 tx.send(Err(error)).ok();
330 }
331 } else {
332 let result = message.result.unwrap_or(RawValue::NULL);
333 if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
334 match Out::response_from_method_and_result(method, result) {
335 Ok(result) => {
336 tx.send(Ok(result)).ok();
337 }
338 Err(error) => {
339 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
340 }
341 }
342 }
343 }
344 }
345 Err(error) => {
346 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
347 }
348 }
349 incoming_line.clear();
350 }
351 }
352 }
353 response_senders.lock().clear();
354 Ok(())
355 }
356
357 fn handle_incoming(
358 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
359 mut incoming_rx: UnboundedReceiver<(i32, In)>,
360 incoming_handler: ResponseHandler<In, In::Response>,
361 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
362 ) {
363 let spawn = Rc::new(spawn);
364 let spawn2 = spawn.clone();
365 spawn(
366 async move {
367 while let Some((id, params)) = incoming_rx.next().await {
368 let result = incoming_handler(params);
369 let outgoing_tx = outgoing_tx.clone();
370 spawn2(
371 async move {
372 let result = result.await;
373 match result {
374 Ok(result) => {
375 outgoing_tx
376 .unbounded_send(OutgoingMessage::OkResponse { id, result })
377 .ok();
378 }
379 Err(error) => {
380 outgoing_tx
381 .unbounded_send(OutgoingMessage::ErrorResponse {
382 id,
383 error: Error::internal_error()
384 .with_details(error.to_string()),
385 })
386 .ok();
387 }
388 }
389 }
390 .boxed_local(),
391 )
392 }
393 }
394 .boxed_local(),
395 )
396 }
397}