1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use anyhow::{Result, anyhow};
6use futures::{
7 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
8 StreamExt as _,
9 channel::{
10 mpsc::{self, UnboundedReceiver, UnboundedSender},
11 oneshot,
12 },
13 future::LocalBoxFuture,
14 io::BufReader,
15 select_biased,
16};
17use parking_lot::Mutex;
18pub use schema::*;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22 collections::HashMap,
23 fmt::Display,
24 rc::Rc,
25 sync::{
26 Arc,
27 atomic::{AtomicI32, Ordering::SeqCst},
28 },
29};
30
31pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);
33
34pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
36
37impl AgentConnection {
38 pub fn connect_to_agent<H: 'static + Client>(
41 handler: H,
42 outgoing_bytes: impl Unpin + AsyncWrite,
43 incoming_bytes: impl Unpin + AsyncRead,
44 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
45 ) -> (Self, impl Future<Output = Result<()>>) {
46 let handler = Arc::new(handler);
47 let (connection, io_task) = Connection::new(
48 Box::new(move |request| {
49 let handler = handler.clone();
50 async move { handler.call(request).await }.boxed_local()
51 }),
52 outgoing_bytes,
53 incoming_bytes,
54 spawn,
55 );
56 (Self(connection), io_task)
57 }
58
59 pub fn request<R: AgentRequest + 'static>(
61 &self,
62 params: R,
63 ) -> impl Future<Output = Result<R::Response>> {
64 let params = params.into_any();
65 let result = self.0.request(params.method_name(), params);
66 async move {
67 let result = result.await?;
68 R::response_from_any(result).ok_or_else(|| {
69 anyhow!(crate::Error {
70 code: -32700,
71 message: "Unexpected Response".to_string(),
72 })
73 })
74 }
75 }
76}
77
78impl ClientConnection {
79 pub fn connect_to_client<H: 'static + Agent>(
80 handler: H,
81 outgoing_bytes: impl Unpin + AsyncWrite,
82 incoming_bytes: impl Unpin + AsyncRead,
83 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
84 ) -> (Self, impl Future<Output = Result<()>>) {
85 let handler = Arc::new(handler);
86 let (connection, io_task) = Connection::new(
87 Box::new(move |request| {
88 let handler = handler.clone();
89 async move { handler.call(request).await }.boxed_local()
90 }),
91 outgoing_bytes,
92 incoming_bytes,
93 spawn,
94 );
95 (Self(connection), io_task)
96 }
97
98 pub fn request<R: ClientRequest>(
99 &self,
100 params: R,
101 ) -> impl use<R> + Future<Output = Result<R::Response>> {
102 let params = params.into_any();
103 let result = self.0.request(params.method_name(), params);
104 async move {
105 let result = result.await?;
106 R::response_from_any(result).ok_or_else(|| {
107 anyhow!(Error {
108 code: -32700,
109 message: "Could not parse".to_string(),
110 })
111 })
112 }
113 }
114}
115
116struct Connection<In, Out>
117where
118 In: AnyRequest,
119 Out: AnyRequest,
120{
121 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
122 response_senders: ResponseSenders<Out::Response>,
123 next_id: AtomicI32,
124}
125
126type ResponseSenders<T> =
127 Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, crate::Error>>)>>>;
128
129#[derive(Debug, Deserialize)]
130struct IncomingMessage<'a> {
131 id: i32,
132 method: Option<&'a str>,
133 params: Option<&'a RawValue>,
134 result: Option<&'a RawValue>,
135 error: Option<Error>,
136}
137
138#[derive(Serialize)]
139#[serde(untagged)]
140enum OutgoingMessage<Req, Resp> {
141 Request {
142 id: i32,
143 method: Box<str>,
144 params: Req,
145 },
146 OkResponse {
147 id: i32,
148 result: Resp,
149 },
150 ErrorResponse {
151 id: i32,
152 error: Error,
153 },
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Error {
158 pub code: i32,
159 pub message: String,
160}
161
162impl std::error::Error for Error {}
163impl Display for Error {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 if self.message.is_empty() {
166 write!(f, "{}", self.code)
167 } else {
168 write!(f, "{}", self.message)
169 }
170 }
171}
172
173#[derive(Serialize)]
174pub struct JsonRpcMessage<Req, Resp> {
175 pub jsonrpc: &'static str,
176 #[serde(flatten)]
177 message: OutgoingMessage<Req, Resp>,
178}
179
180impl<In, Out> Connection<In, Out>
181where
182 In: AnyRequest,
183 Out: AnyRequest,
184{
185 fn new(
186 request_handler: Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>>,
187 outgoing_bytes: impl Unpin + AsyncWrite,
188 incoming_bytes: impl Unpin + AsyncRead,
189 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190 ) -> (Self, impl Future<Output = Result<()>>) {
191 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
192 let (incoming_tx, incoming_rx) = mpsc::unbounded();
193 let this = Self {
194 response_senders: ResponseSenders::default(),
195 outgoing_tx: outgoing_tx.clone(),
196 next_id: AtomicI32::new(0),
197 };
198 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
199 let io_task = Self::handle_io(
200 outgoing_rx,
201 incoming_tx,
202 this.response_senders.clone(),
203 outgoing_bytes,
204 incoming_bytes,
205 );
206 (this, io_task)
207 }
208
209 fn request(
210 &self,
211 method: &'static str,
212 params: Out,
213 ) -> impl use<In, Out> + Future<Output = Result<Out::Response>> {
214 let (tx, rx) = oneshot::channel();
215 let id = self.next_id.fetch_add(1, SeqCst);
216 self.response_senders.lock().insert(id, (method, tx));
217 if self
218 .outgoing_tx
219 .unbounded_send(OutgoingMessage::Request {
220 id,
221 method: method.into(),
222 params,
223 })
224 .is_err()
225 {
226 self.response_senders.lock().remove(&id);
227 }
228 async move { rx.await?.map_err(|e| anyhow!(e)) }
229 }
230
231 async fn handle_io(
232 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
233 incoming_tx: UnboundedSender<(i32, In)>,
234 response_senders: ResponseSenders<Out::Response>,
235 mut outgoing_bytes: impl Unpin + AsyncWrite,
236 incoming_bytes: impl Unpin + AsyncRead,
237 ) -> Result<()> {
238 let mut output_reader = BufReader::new(incoming_bytes);
239 let mut outgoing_line = Vec::new();
240 let mut incoming_line = String::new();
241 loop {
242 select_biased! {
243 message = outgoing_rx.next() => {
244 if let Some(message) = message {
245 outgoing_line.clear();
246 serde_json::to_writer(&mut outgoing_line, &message)?;
247 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
248 outgoing_line.push(b'\n');
249 outgoing_bytes.write_all(&outgoing_line).await.ok();
250 } else {
251 break;
252 }
253 }
254 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
255 if bytes_read? == 0 {
256 break
257 }
258 log::trace!("recv: {}", &incoming_line);
259 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
260 Ok(message) => {
261 if let Some(method) = message.method {
262 match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
263 Ok(params) => {
264 incoming_tx.unbounded_send((message.id, params)).ok();
265 }
266 Err(error) => {
267 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
268 }
269 }
270 } else if let Some(error) = message.error {
271 if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
272 tx.send(Err(error)).ok();
273 }
274 } else {
275 let result = message.result.unwrap_or(RawValue::NULL);
276 if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
277 match Out::response_from_method_and_result(method, result) {
278 Ok(result) => {
279 tx.send(Ok(result)).ok();
280 }
281 Err(error) => {
282 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
283 }
284 }
285 }
286 }
287 }
288 Err(error) => {
289 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
290 }
291 }
292 incoming_line.clear();
293 }
294 }
295 }
296 response_senders.lock().clear();
297 Ok(())
298 }
299
300 fn handle_incoming(
301 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
302 mut incoming_rx: UnboundedReceiver<(i32, In)>,
303 incoming_handler: Box<
304 dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>,
305 >,
306 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
307 ) {
308 let spawn = Rc::new(spawn);
309 let spawn2 = spawn.clone();
310 spawn(
311 async move {
312 while let Some((id, params)) = incoming_rx.next().await {
313 let result = incoming_handler(params);
314 let outgoing_tx = outgoing_tx.clone();
315 spawn2(
316 async move {
317 let result = result.await;
318 match result {
319 Ok(result) => {
320 outgoing_tx
321 .unbounded_send(OutgoingMessage::OkResponse { id, result })
322 .ok();
323 }
324 Err(error) => {
325 outgoing_tx
326 .unbounded_send(OutgoingMessage::ErrorResponse {
327 id,
328 error: Error {
329 code: -32603,
330 message: error.to_string(),
331 },
332 })
333 .ok();
334 }
335 }
336 }
337 .boxed_local(),
338 )
339 }
340 }
341 .boxed_local(),
342 )
343 }
344}