1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::process::{ChildStdin, ChildStdout};
9use tokio::sync::{Mutex, mpsc, oneshot};
10
/// A JSON-RPC 2.0 request or notification, as written to the child's stdin.
///
/// `id` is `Some` for a request that expects a response and `None` for a
/// notification; `None` fields are omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Request {
    /// Protocol version string; every construction in this file uses `"2.0"`.
    pub jsonrpc: String,
    /// Name of the method to invoke on the other side.
    pub method: String,
    /// Method parameters; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
    /// Id used to match the response; omitted when `None`, which makes the
    /// message a notification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<u64>,
}
25
/// A JSON-RPC 2.0 response.
///
/// Used both for responses delivered to [`JsonRpcClient::request`] callers and for
/// responses written back to the child by [`JsonRpcClient::respond`]. JSON-RPC
/// expects exactly one of `result` / `error` to be set; this type does not enforce
/// that. `id` has no `skip_serializing_if`, so it is always written (as `null` when
/// `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    /// Protocol version string (`"2.0"` for responses built in this file).
    pub jsonrpc: String,
    /// Success payload; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error payload; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<RpcError>,
    /// Id of the request this answers.
    pub id: Option<u64>,
}
36
/// The `error` object carried by a JSON-RPC response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcError {
    /// Numeric error code (this client itself uses `-32003` when the agent
    /// process's stdout closes with requests still pending).
    pub code: i64,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
45
46impl std::fmt::Display for RpcError {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 write!(f, "RPC error {}: {}", self.code, self.message)
49 }
50}
51
52impl std::error::Error for RpcError {}
53
/// Any message read from the child's stdout, before it is classified.
///
/// Every field except `jsonrpc` is optional so that responses, notifications and
/// child-to-client calls all deserialize into this one shape; the `is_response`,
/// `is_notification` and `is_rpc_call` methods tell them apart.
///
/// NOTE(review): serde maps a JSON `null` to `None` for these `Option` fields, so
/// `"result": null` cannot be distinguished from a missing `result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncomingMessage {
    /// Protocol version string as sent by the child.
    pub jsonrpc: String,
    /// Present on responses and on calls that expect a reply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<u64>,
    /// Present on notifications and calls; absent on responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<String>,
    /// Parameters of a notification or call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
    /// Success payload of a response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error payload of a response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<RpcError>,
}
70
71impl IncomingMessage {
72 pub fn is_response(&self) -> bool {
74 self.method.is_none() && (self.result.is_some() || self.error.is_some())
75 }
76
77 pub fn is_notification(&self) -> bool {
79 self.method.is_some() && self.id.is_none()
80 }
81
82 pub fn is_rpc_call(&self) -> bool {
84 self.method.is_some() && self.id.is_some()
85 }
86
87 pub fn into_response(self) -> Response {
89 Response {
90 jsonrpc: self.jsonrpc,
91 result: self.result,
92 error: self.error,
93 id: self.id,
94 }
95 }
96}
97
/// Newline-delimited JSON-RPC 2.0 client over a child process's stdin/stdout.
///
/// Construct with [`JsonRpcClient::new`], which also spawns the task that reads
/// the child's stdout.
pub struct JsonRpcClient {
    /// Child's stdin; every send path holds this lock across `write_all` + `flush`
    /// so outgoing messages are written one at a time.
    writer: Arc<Mutex<ChildStdin>>,
    /// Source of request ids; starts at 1 and is incremented once per `request`.
    next_id: Arc<AtomicU64>,
    /// Requests awaiting a response, keyed by id. Shared with the reader task,
    /// which removes and completes entries as responses arrive.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
    /// Receiver for non-response messages from the child; `None` once handed out
    /// by `take_incoming`.
    incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
}
114
115impl std::fmt::Debug for JsonRpcClient {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("JsonRpcClient").finish_non_exhaustive()
118 }
119}
120
121impl JsonRpcClient {
122 pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
129 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>> =
130 Arc::new(Mutex::new(HashMap::new()));
131 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<IncomingMessage>();
132
133 let reader_pending = Arc::clone(&pending);
135 let reader_tx = incoming_tx;
136 tokio::spawn(async move {
137 let mut reader = BufReader::new(stdout);
138 let mut line = String::new();
139
140 loop {
141 line.clear();
142 match reader.read_line(&mut line).await {
143 Ok(0) => {
144 break;
146 }
147 Ok(_) => {
148 let trimmed = line.trim();
149 if trimmed.is_empty() {
150 continue;
151 }
152
153 let msg: IncomingMessage = match serde_json::from_str(trimmed) {
154 Ok(m) => m,
155 Err(e) => {
156 log::error!("Failed to parse JSON-RPC message: {e}");
157 continue;
158 }
159 };
160
161 if msg.is_response() {
162 if let Some(id) = msg.id {
164 let mut map = reader_pending.lock().await;
165 if let Some(tx) = map.remove(&id) {
166 let _ = tx.send(msg.into_response());
167 } else {
168 log::error!("Received response for unknown request id {id}");
169 }
170 } else {
171 log::error!("Received response without id: {trimmed}");
172 }
173 } else {
174 if reader_tx.send(msg).is_err() {
176 break;
178 }
179 }
180 }
181 Err(e) => {
182 log::error!("Error reading from child stdout: {e}");
183 break;
184 }
185 }
186 }
187
188 let mut map = reader_pending.lock().await;
191 for (id, tx) in map.drain() {
192 let _ = tx.send(Response {
193 jsonrpc: "2.0".to_string(),
194 result: None,
195 error: Some(RpcError {
196 code: -32003,
197 message: "Agent process exited".to_string(),
198 data: None,
199 }),
200 id: Some(id),
201 });
202 }
203 });
204
205 Self {
206 writer: Arc::new(Mutex::new(stdin)),
207 next_id: Arc::new(AtomicU64::new(1)),
208 pending,
209 incoming_rx: Some(incoming_rx),
210 }
211 }
212
213 pub fn take_incoming(&mut self) -> Option<mpsc::UnboundedReceiver<IncomingMessage>> {
217 self.incoming_rx.take()
218 }
219
220 pub async fn request(
222 &self,
223 method: &str,
224 params: Option<Value>,
225 ) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
226 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
227
228 let req = Request {
229 jsonrpc: "2.0".to_string(),
230 method: method.to_string(),
231 params,
232 id: Some(id),
233 };
234
235 let (tx, rx) = oneshot::channel::<Response>();
236
237 {
239 let mut map = self.pending.lock().await;
240 map.insert(id, tx);
241 }
242
243 let json = serde_json::to_string(&req)?;
245 {
246 let mut writer = self.writer.lock().await;
247 writer.write_all(format!("{json}\n").as_bytes()).await?;
248 writer.flush().await?;
249 }
250
251 let response = rx.await?;
253 Ok(response)
254 }
255
256 pub async fn notify(
258 &self,
259 method: &str,
260 params: Option<Value>,
261 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
262 let req = Request {
263 jsonrpc: "2.0".to_string(),
264 method: method.to_string(),
265 params,
266 id: None,
267 };
268
269 let json = serde_json::to_string(&req)?;
270 let mut writer = self.writer.lock().await;
271 writer.write_all(format!("{json}\n").as_bytes()).await?;
272 writer.flush().await?;
273 Ok(())
274 }
275
276 pub async fn respond(
278 &self,
279 id: u64,
280 result: Option<Value>,
281 error: Option<RpcError>,
282 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
283 let resp = Response {
284 jsonrpc: "2.0".to_string(),
285 result,
286 error,
287 id: Some(id),
288 };
289
290 let json = serde_json::to_string(&resp)?;
291 log::info!("ACP WIRE OUT: {json}");
292 let mut writer = self.writer.lock().await;
293 writer.write_all(format!("{json}\n").as_bytes()).await?;
294 writer.flush().await?;
295 Ok(())
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes one wire line, failing the test if it is not a valid message.
    fn parse_incoming(raw: &str) -> IncomingMessage {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn test_incoming_message_classification() {
        // Each case: (wire line, (is_response, is_notification, is_rpc_call)).
        let cases = [
            (
                r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#,
                (true, false, false),
            ),
            (
                r#"{"jsonrpc":"2.0","method":"session/update","params":{}}"#,
                (false, true, false),
            ),
            (
                r#"{"jsonrpc":"2.0","id":5,"method":"session/request_permission","params":{}}"#,
                (false, false, true),
            ),
        ];

        for (raw, expected) in cases {
            let msg = parse_incoming(raw);
            let actual = (msg.is_response(), msg.is_notification(), msg.is_rpc_call());
            assert_eq!(actual, expected, "classification of {raw}");
        }
    }

    #[test]
    fn test_request_serialization() {
        let json = serde_json::to_string(&Request {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({"protocolVersion": 1})),
            id: Some(1),
        })
        .unwrap();

        for needle in ["initialize", "protocolVersion"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_notification_has_no_id() {
        let notification = Request {
            jsonrpc: "2.0".to_string(),
            method: "session/update".to_string(),
            params: Some(serde_json::json!({"status": "active"})),
            id: None,
        };
        let json = serde_json::to_string(&notification).unwrap();
        assert!(!json.contains("\"id\""));
    }

    #[test]
    fn test_response_serialization() {
        let json = serde_json::to_string(&Response {
            jsonrpc: "2.0".to_string(),
            result: Some(serde_json::json!({"capabilities": {}})),
            error: None,
            id: Some(1),
        })
        .unwrap();

        assert!(json.contains("capabilities"));
        assert!(!json.contains("error"));
    }

    #[test]
    fn test_rpc_error_display() {
        let err = RpcError {
            code: -32600,
            message: "Invalid Request".to_string(),
            data: None,
        };
        assert_eq!(err.to_string(), "RPC error -32600: Invalid Request");
    }

    #[test]
    fn test_incoming_into_response() {
        let msg = parse_incoming(r#"{"jsonrpc":"2.0","id":42,"result":{"data":"hello"}}"#);
        assert!(msg.is_response());

        let Response { id, result, error, .. } = msg.into_response();
        assert_eq!(id, Some(42));
        assert!(result.is_some());
        assert!(error.is_none());
    }
}