1#[cfg(test)]
2mod acp_tests;
3mod schema;
4
5use futures::{
6 AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, FutureExt as _,
7 StreamExt as _,
8 channel::{
9 mpsc::{self, UnboundedReceiver, UnboundedSender},
10 oneshot,
11 },
12 future::LocalBoxFuture,
13 io::BufReader,
14 select_biased,
15};
16use parking_lot::Mutex;
17pub use schema::*;
18use semver::Comparator;
19use serde::{Deserialize, Serialize};
20use serde_json::value::RawValue;
21use std::{
22 collections::HashMap,
23 rc::Rc,
24 sync::{
25 Arc,
26 atomic::{AtomicI32, Ordering::SeqCst},
27 },
28};
29
/// Connection used to talk to an agent: receives `AnyClientRequest`s from the peer
/// (served by a `Client` handler, see `connect_to_agent`) and sends `AnyAgentRequest`s.
pub struct AgentConnection(Connection<AnyClientRequest, AnyAgentRequest>);

/// Connection used to talk to a client: receives `AnyAgentRequest`s from the peer
/// (served by an `Agent` handler, see `connect_to_client`) and sends `AnyClientRequest`s.
pub struct ClientConnection(Connection<AnyAgentRequest, AnyClientRequest>);
35
36impl AgentConnection {
37 pub fn connect_to_agent<H: 'static + Client>(
40 handler: H,
41 outgoing_bytes: impl Unpin + AsyncWrite,
42 incoming_bytes: impl Unpin + AsyncRead,
43 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
44 ) -> (Self, impl Future<Output = Result<(), Error>>) {
45 let handler = Arc::new(handler);
46 let (connection, io_task) = Connection::new(
47 Box::new(move |request| {
48 let handler = handler.clone();
49 async move { handler.call(request).await }.boxed_local()
50 }),
51 outgoing_bytes,
52 incoming_bytes,
53 spawn,
54 );
55 (Self(connection), io_task)
56 }
57
58 pub fn request<R: AgentRequest + 'static>(
60 &self,
61 params: R,
62 ) -> impl Future<Output = Result<R::Response, Error>> {
63 let params = params.into_any();
64 let result = self.0.request(params.method_name(), params);
65 async move {
66 let result = result.await?;
67 R::response_from_any(result)
68 }
69 }
70
71 pub fn request_any(
73 &self,
74 params: AnyAgentRequest,
75 ) -> impl use<> + Future<Output = Result<AnyAgentResult, Error>> {
76 self.0.request(params.method_name(), params)
77 }
78
79 pub async fn initialize(
82 &self,
83 context_servers: HashMap<String, ContextServer>,
84 ) -> Result<InitializeResponse, Error> {
85 let protocol_version = ProtocolVersion::latest();
86 let version_requirement = Comparator {
88 op: semver::Op::Caret,
89 major: protocol_version.major,
90 minor: (protocol_version.major == 0).then_some(protocol_version.minor),
91 patch: (protocol_version.major == 0 && protocol_version.minor == 0)
92 .then_some(protocol_version.patch),
93 pre: protocol_version.pre.clone(),
94 };
95 let response = self
96 .request(InitializeParams {
97 protocol_version,
98 context_servers,
99 })
100 .await?;
101
102 let server_version = &response.protocol_version;
103
104 if version_requirement.matches(server_version) {
105 Ok(response)
106 } else {
107 Err(Error::invalid_request().with_data(format!(
108 "Incompatible versions: Server {server_version} / Client: {version_requirement}"
109 )))
110 }
111 }
112}
113
114impl ClientConnection {
115 pub fn connect_to_client<H: 'static + Agent>(
116 handler: H,
117 outgoing_bytes: impl Unpin + AsyncWrite,
118 incoming_bytes: impl Unpin + AsyncRead,
119 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
120 ) -> (Self, impl Future<Output = Result<(), Error>>) {
121 let handler = Arc::new(handler);
122 let (connection, io_task) = Connection::new(
123 Box::new(move |request| {
124 let handler = handler.clone();
125 async move { handler.call(request).await }.boxed_local()
126 }),
127 outgoing_bytes,
128 incoming_bytes,
129 spawn,
130 );
131 (Self(connection), io_task)
132 }
133
134 pub fn request<R: ClientRequest>(
135 &self,
136 params: R,
137 ) -> impl use<R> + Future<Output = Result<R::Response, Error>> {
138 let params = params.into_any();
139 let result = self.0.request(params.method_name(), params);
140 async move {
141 let result = result.await?;
142 R::response_from_any(result)
143 }
144 }
145
146 pub fn request_any(
148 &self,
149 method: &'static str,
150 params: AnyClientRequest,
151 ) -> impl Future<Output = Result<AnyClientResult, Error>> {
152 self.0.request(method, params)
153 }
154}
155
156struct Connection<In, Out>
157where
158 In: AnyRequest,
159 Out: AnyRequest,
160{
161 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
162 response_senders: ResponseSenders<Out::Response>,
163 next_id: AtomicI32,
164}
165
/// Outstanding outgoing requests keyed by request id. Each entry holds the request's
/// method name (used to pick the response type when decoding the result) and the
/// oneshot sender that delivers the outcome to the waiting caller.
type ResponseSenders<T> =
    Arc<Mutex<HashMap<i32, (&'static str, oneshot::Sender<Result<T, Error>>)>>>;
168
/// Any single message read from the peer, borrowing from the input line.
///
/// It is treated as a request when `method` is present, and otherwise as a response
/// to one of our requests (`error` if set, else `result`).
///
/// `id` is a required field, so messages without an id do not deserialize.
/// NOTE(review): `method` is a borrowed `&str`; serde_json presumably rejects method
/// strings containing JSON escape sequences in that case. Confirm this cannot occur.
#[derive(Debug, Deserialize)]
struct IncomingMessage<'a> {
    id: i32,
    method: Option<&'a str>,
    params: Option<&'a RawValue>,
    result: Option<&'a RawValue>,
    error: Option<Error>,
}
177
/// A message we send to the peer. Serialized `untagged`, so only the variant's own
/// fields appear; the `jsonrpc` version field is added by `OutJsonRpcMessage`.
#[derive(Serialize)]
#[serde(untagged)]
enum OutgoingMessage<Req, Resp> {
    /// A request to the peer.
    Request {
        id: i32,
        method: Box<str>,
        // Omitted entirely when absent or when it serializes to JSON `null`.
        #[serde(skip_serializing_if = "is_none_or_null")]
        params: Option<Req>,
    },
    /// A successful reply to a request the peer sent us.
    OkResponse {
        id: i32,
        result: Resp,
    },
    /// A failed reply to a request the peer sent us.
    ErrorResponse {
        id: i32,
        error: Error,
    },
}
196
197fn is_none_or_null<T: Serialize>(opt: &Option<T>) -> bool {
198 match opt {
199 None => true,
200 Some(value) => {
201 matches!(serde_json::to_value(value), Ok(serde_json::Value::Null))
202 }
203 }
204}
205
/// Value of the `jsonrpc` envelope field; its only variant is (de)serialized as `"2.0"`.
///
/// NOTE(review): despite the name, this is the JSON-RPC protocol version, not a JSON
/// Schema version.
#[derive(Debug, Deserialize, Serialize)]
enum JsonSchemaVersion {
    #[serde(rename = "2.0")]
    V2,
}
211
/// Wire envelope for outgoing messages: the `jsonrpc` version field followed by the
/// flattened fields of the wrapped `OutgoingMessage`.
#[derive(Serialize)]
struct OutJsonRpcMessage<Req, Resp> {
    jsonrpc: JsonSchemaVersion,
    #[serde(flatten)]
    message: OutgoingMessage<Req, Resp>,
}
218
/// Callback invoked once per incoming request; returns a local future that resolves
/// to the response (or error) to send back to the peer.
/// NOTE(review): the name suggests it handles responses, but it handles requests.
type ResponseHandler<In, Resp> =
    Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<Resp, Error>>>;
221
222impl<In, Out> Connection<In, Out>
223where
224 In: AnyRequest,
225 Out: AnyRequest,
226{
227 fn new(
228 request_handler: ResponseHandler<In, In::Response>,
229 outgoing_bytes: impl Unpin + AsyncWrite,
230 incoming_bytes: impl Unpin + AsyncRead,
231 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
232 ) -> (Self, impl Future<Output = Result<(), Error>>) {
233 let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
234 let (incoming_tx, incoming_rx) = mpsc::unbounded();
235 let this = Self {
236 response_senders: ResponseSenders::default(),
237 outgoing_tx: outgoing_tx.clone(),
238 next_id: AtomicI32::new(0),
239 };
240 Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
241 let io_task = Self::handle_io(
242 outgoing_rx,
243 incoming_tx,
244 this.response_senders.clone(),
245 outgoing_bytes,
246 incoming_bytes,
247 );
248 (this, io_task)
249 }
250
251 fn request(
252 &self,
253 method: &'static str,
254 params: Out,
255 ) -> impl use<In, Out> + Future<Output = Result<Out::Response, Error>> {
256 let (tx, rx) = oneshot::channel();
257 let id = self.next_id.fetch_add(1, SeqCst);
258 self.response_senders.lock().insert(id, (method, tx));
259 if self
260 .outgoing_tx
261 .unbounded_send(OutgoingMessage::Request {
262 id,
263 method: method.into(),
264 params: Some(params),
265 })
266 .is_err()
267 {
268 self.response_senders.lock().remove(&id);
269 }
270 async move {
271 rx.await
272 .map_err(|e| Error::internal_error().with_data(e.to_string()))?
273 }
274 }
275
276 async fn handle_io(
277 mut outgoing_rx: UnboundedReceiver<OutgoingMessage<Out, In::Response>>,
278 incoming_tx: UnboundedSender<(i32, In)>,
279 response_senders: ResponseSenders<Out::Response>,
280 mut outgoing_bytes: impl Unpin + AsyncWrite,
281 incoming_bytes: impl Unpin + AsyncRead,
282 ) -> Result<(), Error> {
283 let mut output_reader = BufReader::new(incoming_bytes);
284 let mut outgoing_line = Vec::new();
285 let mut incoming_line = String::new();
286 loop {
287 select_biased! {
288 message = outgoing_rx.next() => {
289 if let Some(message) = message {
290 let message = OutJsonRpcMessage {
291 jsonrpc: JsonSchemaVersion::V2,
292 message,
293 };
294 outgoing_line.clear();
295 serde_json::to_writer(&mut outgoing_line, &message).map_err(Error::into_internal_error)?;
296 log::trace!("send: {}", String::from_utf8_lossy(&outgoing_line));
297 outgoing_line.push(b'\n');
298 outgoing_bytes.write_all(&outgoing_line).await.ok();
299 } else {
300 break;
301 }
302 }
303 bytes_read = output_reader.read_line(&mut incoming_line).fuse() => {
304 if bytes_read.map_err(Error::into_internal_error)? == 0 {
305 break
306 }
307 log::trace!("recv: {}", &incoming_line);
308 match serde_json::from_str::<IncomingMessage>(&incoming_line) {
309 Ok(IncomingMessage { id, method, params, result, error }) => {
310 if let Some(method) = method {
311 match In::from_method_and_params(method, params.unwrap_or(RawValue::NULL)) {
312 Ok(params) => {
313 incoming_tx.unbounded_send((id, params)).ok();
314 }
315 Err(error) => {
316 log::error!("failed to parse incoming {method} message params: {error}. Raw: {incoming_line}");
317 }
318 }
319 } else if let Some(error) = error {
320 if let Some((_, tx)) = response_senders.lock().remove(&id) {
321 tx.send(Err(error)).ok();
322 }
323 } else {
324 let result = result.unwrap_or(RawValue::NULL);
325 if let Some((method, tx)) = response_senders.lock().remove(&id) {
326 match Out::response_from_method_and_result(method, result) {
327 Ok(result) => {
328 tx.send(Ok(result)).ok();
329 }
330 Err(error) => {
331 log::error!("failed to parse {method} message result: {error}. Raw: {result}");
332 }
333 }
334 }
335 }
336 }
337 Err(error) => {
338 log::error!("failed to parse incoming message: {error}. Raw: {incoming_line}");
339 }
340 }
341 incoming_line.clear();
342 }
343 }
344 }
345 response_senders.lock().clear();
346 Ok(())
347 }
348
349 fn handle_incoming(
350 outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
351 mut incoming_rx: UnboundedReceiver<(i32, In)>,
352 incoming_handler: ResponseHandler<In, In::Response>,
353 spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
354 ) {
355 let spawn = Rc::new(spawn);
356 let spawn2 = spawn.clone();
357 spawn(
358 async move {
359 while let Some((id, params)) = incoming_rx.next().await {
360 let result = incoming_handler(params);
361 let outgoing_tx = outgoing_tx.clone();
362 spawn2(
363 async move {
364 let result = result.await;
365 match result {
366 Ok(result) => {
367 outgoing_tx
368 .unbounded_send(OutgoingMessage::OkResponse { id, result })
369 .ok();
370 }
371 Err(error) => {
372 outgoing_tx
373 .unbounded_send(OutgoingMessage::ErrorResponse {
374 id,
375 error: Error::into_internal_error(error),
376 })
377 .ok();
378 }
379 }
380 }
381 .boxed_local(),
382 )
383 }
384 }
385 .boxed_local(),
386 )
387 }
388}