1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7 StreamExt as _,
8 channel::{
9 mpsc::{self, UnboundedReceiver, UnboundedSender},
10 oneshot,
11 },
12 future::LocalBoxFuture,
13 io::BufReader,
14 select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use serde::{Deserialize, Serialize};
19use serde_json::value::RawValue;
20use std::{
21 collections::HashMap,
22 rc::Rc,
23 sync::{
24 Arc,
25 atomic::{AtomicI32, Ordering::SeqCst},
26 },
27};
28
/// Connection to an agent.
///
/// Outgoing requests are [`AnyAgentRequest`]s; incoming requests are [`AnyClientRequest`]s,
/// which are answered by the [`Client`] handler passed to `AgentConnection::connect_to_agent`.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// Connection to a client.
///
/// Outgoing requests are [`AnyClientRequest`]s; incoming requests are [`AnyAgentRequest`]s,
/// which are answered by the [`Agent`] handler passed to `ClientConnection::connect_to_client`.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
34
35impl AgentConnection {
36 pub fn connect_to_agent<H: 'static + Client>(
39 handler: H,
40 outgoing_bytes: impl Unpin + AsyncWrite,
41 incoming_bytes: impl Unpin + AsyncRead,
42 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
43 ) -> (Self, impl Future<Output = Result<(), Error>>) {
44 let handler = Arc::new(handler);
45 let (connection, io_task) = Connection::new(
46 Box::new(move |request| {
47 let handler = handler.clone();
48 async move { handler.call(request).await }.boxed_local()
49 }),
50 outgoing_bytes,
51 incoming_bytes,
52 spawn,
53 );
54 (Self(connection), io_task)
55 }
56
57 pub fn request<R: AgentRequest + 'static>(
59 &self,
60 params: R,
61 ) -> impl Future<Output = Result<R::Response, Error>> {
62 let params = params.into_any();
63 let result = self.0.request(params.method_name(), params);
64 async move {
65 let result = result.await?;
66 R::response_from_any(result)
67 }
68 }
69}
70
71impl ClientConnection {
72 pub fn connect_to_client<H: 'static + Agent>(
73 handler: H,
74 outgoing_bytes: impl Unpin + AsyncWrite,
75 incoming_bytes: impl Unpin + AsyncRead,
76 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
77 ) -> (Self, impl Future<Output = Result<(), Error>>) {
78 let handler = Arc::new(handler);
79 let (connection, io_task) = Connection::new(
80 Box::new(move |request| {
81 let handler = handler.clone();
82 async move { handler.call(request).await }.boxed_local()
83 }),
84 outgoing_bytes,
85 incoming_bytes,
86 spawn,
87 );
88 (Self(connection), io_task)
89 }
90
91 pub fn request<R: ClientRequest>(
92 &self,
93 params: R,
94 ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
95 let params = params.into_any();
96 let result = self.0.request(params.method_name(), params);
97 async move {
98 let result = result.await?;
99 R::response_from_any(result)
100 }
101 }
102}
103
/// Shared state of one JSON-RPC peer connection.
///
/// `In` is the type of requests received from the peer; `Out` is the type of requests sent to
/// it. The actual reading and writing happens in the I/O future returned by `Connection::new`.
struct Connection<In, Out>
where
    In: AnyRequest,
    Out: AnyRequest,
{
    /// Queue feeding the I/O task. It carries both our outgoing requests (`Out`) and our
    /// responses to the peer's requests (`In::Response`).
    outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
    /// Outgoing requests still waiting for the peer's response, keyed by request id.
    response_senders: ResponseSenders<Out::Response>,
    /// Id to assign to the next outgoing request (incremented atomically).
    next_id: AtomicI32,
}
113
/// Pending outgoing requests: request id -> (method name it was sent with, sender that
/// completes the caller's future).
///
/// `Connection::request` inserts entries. The I/O task removes an entry when the matching
/// response arrives and clears the whole map when it stops. The method name is kept so the
/// raw result can be decoded into the right response type.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
116
/// Borrowed view of one JSON message received from the peer.
///
/// A message with a `method` is treated as a request (`params` holds its arguments).
/// Otherwise it is treated as a response to one of our requests, carrying either `error` or
/// `result`.
///
/// NOTE(review): `id` is not optional, so a message without one (e.g. a JSON-RPC
/// notification) fails to deserialize and is only logged. `method` is a borrowed `&str`,
/// which serde_json can only produce for strings without escape sequences.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
125
/// A message queued for the I/O task to write to the peer.
///
/// `#[serde(untagged)]` serializes each variant as a plain object containing only its
/// fields, with no variant tag:
/// - `Request`: `id`, `method`, `params`
/// - `OkResponse`: `id`, `result`
/// - `ErrorResponse`: `id`, `error`
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request we are sending; the peer's reply is matched back by `id`.
    Request {
        id: i32,
        method: Box<str>,
        params: Req,
    },
    /// A successful reply to the peer's request with the same `id`.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed reply to the peer's request with the same `id`.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
143
/// Envelope that adds a top-level `jsonrpc` member to an [`OutgoingMessage`].
///
/// `#[serde(flatten)]` merges the wrapped message's fields into the same JSON object as
/// `jsonrpc`. Per the JSON-RPC 2.0 specification the value is presumably meant to be
/// `"2.0"`; confirm against callers.
#[derive(Serialize)]
pub struct JsonRpcMessage<Req, Resp> {
    pub jsonrpc: &'static str,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
150
/// Callback that answers one incoming request of type `In`.
///
/// It returns a boxed, non-`Send` future resolving to the response (or error) sent back to
/// the peer. Despite the name, it handles the peer's *requests* and produces our responses.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
153
154impl<In, Out> Connection<In, Out>
155where
156 In: AnyRequest,
157 Out: AnyRequest,
158{
159 fn new(
160 request_handler: ResponseHandler<In, In::Response>,
161 outgoing_bytes: impl Unpin + AsyncWrite,
162 incoming_bytes: impl Unpin + AsyncRead,
163 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
164 ) -> (Self, impl Future<Output = Result<(), Error>>) {
165 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
166 let (incoming_tx, incoming_rx) = mpsc::unbounded();
167 let this = Self {
168 response_senders: ResponseSenders::default(),
169 outgoing_tx: outgoing_tx.clone(),
170 next_id: AtomicI32::new(0),
171 };
172 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
173 let io_task = Self::handle_io(
174 outgoing_rx,
175 incoming_tx,
176 this.response_senders.clone(),
177 outgoing_bytes,
178 incoming_bytes,
179 );
180 (this, io_task)
181 }
182
183 fn request(
184 &self,
185 method: &'static str,
186 params: Out,
187 ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
188 let (tx, rx) = oneshot::channel();
189 let id = self.next_id.fetch_add(1, SeqCst);
190 self.response_senders.lock().insert(id, (method, tx));
191 if self
192 .outgoing_tx
193 .unbounded_send(OutgoingMessage::Request {
194 id,
195 method: method.into(),
196 params,
197 })
198 .is_err()
199 {
200 self.response_senders.lock().remove(&id);
201 }
202 async move {
203 rx.await
204 .map_err(|e| Error::internal_error().with_data(e.to_string()))?
205 }
206 }
207
208 async fn handle_io(
209 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
210 incoming_tx: UnboundedSender<(i32, In)>,
211 response_senders: ResponseSenders<Out::Response>,
212 mut outgoing_bytes: impl Unpin + AsyncWrite,
213 incoming_bytes: impl Unpin + AsyncRead,
214 ) -> Result<(), Error> {
215 let mut output_reader = BufReader::new(incoming_bytes);
216 let mut outgoing_line = Vec::new();
217 let mut incoming_line = String::new();
218 loop {
219 select_biased! {
220 message = outgoing_rx.next() => {
221 if let Some(message) = message {
222 outgoing_line.clear();
223 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
224 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
225 outgoing_line.push(b'\n');
226 outgoing_bytes.write_all(&outgoing_line).await.ok();
227 } else {
228 break;
229 }
230 }
231 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
232 if bytes_read.map_err(Error::into_internal_error)? == 0 {
233 break
234 }
235 log::trace!("recv: {}", &incoming_line);
236 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
237 Ok(message) => {
238 if let Some(method) = message.method {
239 match In::from_method_and_params(method, message.params.unwrap_or(RawValue::NULL)) {
240 Ok(params) => {
241 incoming_tx.unbounded_send((message.id, params)).ok();
242 }
243 Err(error) => {
244 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
245 }
246 }
247 } else if let Some(error) = message.error {
248 if let Some((_, tx)) = response_senders.lock().remove(&message.id) {
249 tx.send(Err(error)).ok();
250 }
251 } else {
252 let result = message.result.unwrap_or(RawValue::NULL);
253 if let Some((method, tx)) = response_senders.lock().remove(&message.id) {
254 match Out::response_from_method_and_result(method, result) {
255 Ok(result) => {
256 tx.send(Ok(result)).ok();
257 }
258 Err(error) => {
259 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
260 }
261 }
262 }
263 }
264 }
265 Err(error) => {
266 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
267 }
268 }
269 incoming_line.clear();
270 }
271 }
272 }
273 response_senders.lock().clear();
274 Ok(())
275 }
276
277 fn handle_incoming(
278 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
279 mut incoming_rx: UnboundedReceiver<(i32, In)>,
280 incoming_handler: ResponseHandler<In, In::Response>,
281 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
282 ) {
283 let spawn = Rc::new(spawn);
284 let spawn2 = spawn.clone();
285 spawn(
286 async move {
287 while let Some((id, params)) = incoming_rx.next().await {
288 let result = incoming_handler(params);
289 let outgoing_tx = outgoing_tx.clone();
290 spawn2(
291 async move {
292 let result = result.await;
293 match result {
294 Ok(result) => {
295 outgoing_tx
296 .unbounded_send(OutgoingMessage::OkResponse { id, result })
297 .ok();
298 }
299 Err(error) => {
300 outgoing_tx
301 .unbounded_send(OutgoingMessage::ErrorResponse {
302 id,
303 error: Error::into_internal_error(error),
304 })
305 .ok();
306 }
307 }
308 }
309 .boxed_local(),
310 )
311 }
312 }
313 .boxed_local(),
314 )
315 }
316}