Skip to main content

par_term_acp/
jsonrpc.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::process::{ChildStdin, ChildStdout};
9use tokio::sync::{Mutex, mpsc, oneshot};
10
11// ---------------------------------------------------------------------------
12// Wire types
13// ---------------------------------------------------------------------------
14
15/// A JSON-RPC 2.0 request (or notification when `id` is `None`).
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Request {
18    pub jsonrpc: String,
19    pub method: String,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub params: Option<Value>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub id: Option<u64>,
24}
25
26/// A JSON-RPC 2.0 response.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Response {
29    pub jsonrpc: String,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub result: Option<Value>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub error: Option<RpcError>,
34    pub id: Option<u64>,
35}
36
37/// A JSON-RPC 2.0 error object.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct RpcError {
40    pub code: i64,
41    pub message: String,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub data: Option<Value>,
44}
45
46impl std::fmt::Display for RpcError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "RPC error {}: {}", self.code, self.message)
49    }
50}
51
52impl std::error::Error for RpcError {}
53
54/// A raw incoming JSON-RPC message that can be classified as a response,
55/// notification, or an RPC call from the remote side.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct IncomingMessage {
58    pub jsonrpc: String,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub id: Option<u64>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub method: Option<String>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub params: Option<Value>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub result: Option<Value>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub error: Option<RpcError>,
69}
70
71impl IncomingMessage {
72    /// A response has no `method` and carries `result` or `error`.
73    pub fn is_response(&self) -> bool {
74        self.method.is_none() && (self.result.is_some() || self.error.is_some())
75    }
76
77    /// A notification has a `method` but no `id`.
78    pub fn is_notification(&self) -> bool {
79        self.method.is_some() && self.id.is_none()
80    }
81
82    /// An RPC call from the remote side has both `method` and `id`.
83    pub fn is_rpc_call(&self) -> bool {
84        self.method.is_some() && self.id.is_some()
85    }
86
87    /// Convert into a [`Response`] (only valid when [`is_response`] is true).
88    pub fn into_response(self) -> Response {
89        Response {
90            jsonrpc: self.jsonrpc,
91            result: self.result,
92            error: self.error,
93            id: self.id,
94        }
95    }
96}
97
98// ---------------------------------------------------------------------------
99// Client
100// ---------------------------------------------------------------------------
101
102/// A JSON-RPC 2.0 client that communicates over line-delimited JSON on the
103/// stdin/stdout of a child process.
104pub struct JsonRpcClient {
105    /// Writer half — protected by a mutex so multiple tasks can send.
106    writer: Arc<Mutex<ChildStdin>>,
107    /// Monotonically increasing request id counter.
108    next_id: Arc<AtomicU64>,
109    /// Pending requests awaiting a response, keyed by request id.
110    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
111    /// Receiver side — handed out exactly once via `take_incoming()`.
112    incoming_rx: Option<mpsc::UnboundedReceiver<IncomingMessage>>,
113}
114
115impl std::fmt::Debug for JsonRpcClient {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        f.debug_struct("JsonRpcClient").finish_non_exhaustive()
118    }
119}
120
121impl JsonRpcClient {
122    /// Create a new client.
123    ///
124    /// Spawns a background tokio task that reads line-delimited JSON from
125    /// `stdout`, routing responses to their pending futures and everything
126    /// else (notifications / incoming RPC calls) to an mpsc channel
127    /// retrievable via [`take_incoming`].
128    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
129        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>> =
130            Arc::new(Mutex::new(HashMap::new()));
131        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<IncomingMessage>();
132
133        // Spawn the reader task.
134        let reader_pending = Arc::clone(&pending);
135        let reader_tx = incoming_tx;
136        tokio::spawn(async move {
137            let mut reader = BufReader::new(stdout);
138            let mut line = String::new();
139
140            loop {
141                line.clear();
142                match reader.read_line(&mut line).await {
143                    Ok(0) => {
144                        // EOF — child process closed stdout.
145                        break;
146                    }
147                    Ok(_) => {
148                        let trimmed = line.trim();
149                        if trimmed.is_empty() {
150                            continue;
151                        }
152
153                        let msg: IncomingMessage = match serde_json::from_str(trimmed) {
154                            Ok(m) => m,
155                            Err(e) => {
156                                log::error!("Failed to parse JSON-RPC message: {e}");
157                                continue;
158                            }
159                        };
160
161                        if msg.is_response() {
162                            // Route to the pending request future.
163                            if let Some(id) = msg.id {
164                                let mut map = reader_pending.lock().await;
165                                if let Some(tx) = map.remove(&id) {
166                                    let _ = tx.send(msg.into_response());
167                                } else {
168                                    log::error!("Received response for unknown request id {id}");
169                                }
170                            } else {
171                                log::error!("Received response without id: {trimmed}");
172                            }
173                        } else {
174                            // Notification or incoming RPC call.
175                            if reader_tx.send(msg).is_err() {
176                                // Receiver dropped — stop reading.
177                                break;
178                            }
179                        }
180                    }
181                    Err(e) => {
182                        log::error!("Error reading from child stdout: {e}");
183                        break;
184                    }
185                }
186            }
187
188            // Agent process closed stdout — fail any pending requests so
189            // callers don't hang forever waiting for a response.
190            let mut map = reader_pending.lock().await;
191            for (id, tx) in map.drain() {
192                let _ = tx.send(Response {
193                    jsonrpc: "2.0".to_string(),
194                    result: None,
195                    error: Some(RpcError {
196                        code: -32003,
197                        message: "Agent process exited".to_string(),
198                        data: None,
199                    }),
200                    id: Some(id),
201                });
202            }
203        });
204
205        Self {
206            writer: Arc::new(Mutex::new(stdin)),
207            next_id: Arc::new(AtomicU64::new(1)),
208            pending,
209            incoming_rx: Some(incoming_rx),
210        }
211    }
212
213    /// Take the receiver for incoming notifications and RPC calls.
214    ///
215    /// This can only be called once — subsequent calls return `None`.
216    pub fn take_incoming(&mut self) -> Option<mpsc::UnboundedReceiver<IncomingMessage>> {
217        self.incoming_rx.take()
218    }
219
220    /// Send a JSON-RPC request and wait for the matching response.
221    pub async fn request(
222        &self,
223        method: &str,
224        params: Option<Value>,
225    ) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
226        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
227
228        let req = Request {
229            jsonrpc: "2.0".to_string(),
230            method: method.to_string(),
231            params,
232            id: Some(id),
233        };
234
235        let (tx, rx) = oneshot::channel::<Response>();
236
237        // Register the pending request before writing to avoid races.
238        {
239            let mut map = self.pending.lock().await;
240            map.insert(id, tx);
241        }
242
243        // Serialize and send.
244        let json = serde_json::to_string(&req)?;
245        {
246            let mut writer = self.writer.lock().await;
247            writer.write_all(format!("{json}\n").as_bytes()).await?;
248            writer.flush().await?;
249        }
250
251        // Wait for the response.
252        let response = rx.await?;
253        Ok(response)
254    }
255
256    /// Send a JSON-RPC notification (no id, no response expected).
257    pub async fn notify(
258        &self,
259        method: &str,
260        params: Option<Value>,
261    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
262        let req = Request {
263            jsonrpc: "2.0".to_string(),
264            method: method.to_string(),
265            params,
266            id: None,
267        };
268
269        let json = serde_json::to_string(&req)?;
270        let mut writer = self.writer.lock().await;
271        writer.write_all(format!("{json}\n").as_bytes()).await?;
272        writer.flush().await?;
273        Ok(())
274    }
275
276    /// Send a JSON-RPC response to an incoming RPC call from the agent.
277    pub async fn respond(
278        &self,
279        id: u64,
280        result: Option<Value>,
281        error: Option<RpcError>,
282    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
283        let resp = Response {
284            jsonrpc: "2.0".to_string(),
285            result,
286            error,
287            id: Some(id),
288        };
289
290        let json = serde_json::to_string(&resp)?;
291        log::info!("ACP WIRE OUT: {json}");
292        let mut writer = self.writer.lock().await;
293        writer.write_all(format!("{json}\n").as_bytes()).await?;
294        writer.flush().await?;
295        Ok(())
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Tests
301// ---------------------------------------------------------------------------
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `raw` into an [`IncomingMessage`], panicking on malformed input.
    fn parse(raw: &str) -> IncomingMessage {
        serde_json::from_str(raw).unwrap()
    }

    /// Each kind of incoming message satisfies exactly one classifier.
    #[test]
    fn test_incoming_message_classification() {
        // (wire text, is_response, is_notification, is_rpc_call)
        let cases = [
            (
                r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#,
                true,
                false,
                false,
            ),
            (
                r#"{"jsonrpc":"2.0","method":"session/update","params":{}}"#,
                false,
                true,
                false,
            ),
            (
                r#"{"jsonrpc":"2.0","id":5,"method":"session/request_permission","params":{}}"#,
                false,
                false,
                true,
            ),
        ];

        for (raw, response, notification, rpc_call) in cases {
            let parsed = parse(raw);
            assert_eq!(parsed.is_response(), response, "is_response: {raw}");
            assert_eq!(
                parsed.is_notification(),
                notification,
                "is_notification: {raw}"
            );
            assert_eq!(parsed.is_rpc_call(), rpc_call, "is_rpc_call: {raw}");
        }
    }

    /// The method name and the params both appear in a serialized request.
    #[test]
    fn test_request_serialization() {
        let wire = serde_json::to_string(&Request {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: Some(serde_json::json!({"protocolVersion": 1})),
            id: Some(1),
        })
        .unwrap();

        for needle in ["initialize", "protocolVersion"] {
            assert!(wire.contains(needle));
        }
    }

    /// A request with `id: None` serializes without any `id` key.
    #[test]
    fn test_notification_has_no_id() {
        let notification = Request {
            jsonrpc: "2.0".to_string(),
            method: "session/update".to_string(),
            params: Some(serde_json::json!({"status": "active"})),
            id: None,
        };
        let wire = serde_json::to_string(&notification).unwrap();
        assert!(!wire.contains("\"id\""));
    }

    /// A successful response serializes its result and omits `error`.
    #[test]
    fn test_response_serialization() {
        let wire = serde_json::to_string(&Response {
            jsonrpc: "2.0".to_string(),
            result: Some(serde_json::json!({"capabilities": {}})),
            error: None,
            id: Some(1),
        })
        .unwrap();

        assert!(wire.contains("capabilities"));
        assert!(!wire.contains("error"));
    }

    /// `Display` renders the code and the message.
    #[test]
    fn test_rpc_error_display() {
        let rendered = format!(
            "{}",
            RpcError {
                code: -32600,
                message: "Invalid Request".to_string(),
                data: None,
            }
        );
        assert_eq!(rendered, "RPC error -32600: Invalid Request");
    }

    /// Converting a parsed response keeps its id and result.
    #[test]
    fn test_incoming_into_response() {
        let parsed = parse(r#"{"jsonrpc":"2.0","id":42,"result":{"data":"hello"}}"#);
        assert!(parsed.is_response());

        let Response {
            id, result, error, ..
        } = parsed.into_response();
        assert_eq!(id, Some(42));
        assert!(result.is_some());
        assert!(error.is_none());
    }
}