Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host (e.g., Burin IDE) handles these requests using its own providers.
6
7use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{VmError, VmValue};
17
/// How long `HostBridge::call` waits for the host to answer a request before
/// giving up: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35}
36
37// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
38#[allow(clippy::new_without_default)]
39impl HostBridge {
40    /// Create a new bridge and spawn the stdin reader task.
41    ///
42    /// Must be called within a tokio LocalSet (uses spawn_local for the
43    /// stdin reader since it's single-threaded).
44    pub fn new() -> Self {
45        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
46            Arc::new(Mutex::new(HashMap::new()));
47        let cancelled = Arc::new(AtomicBool::new(false));
48
49        // Stdin reader: reads JSON-RPC lines and dispatches responses
50        let pending_clone = pending.clone();
51        let cancelled_clone = cancelled.clone();
52        tokio::task::spawn_local(async move {
53            let stdin = tokio::io::stdin();
54            let reader = tokio::io::BufReader::new(stdin);
55            let mut lines = reader.lines();
56
57            while let Ok(Some(line)) = lines.next_line().await {
58                let line = line.trim().to_string();
59                if line.is_empty() {
60                    continue;
61                }
62
63                let msg: serde_json::Value = match serde_json::from_str(&line) {
64                    Ok(v) => v,
65                    Err(_) => continue, // Skip malformed lines
66                };
67
68                // Check if this is a notification from the host (no id)
69                if msg.get("id").is_none() {
70                    if let Some(method) = msg["method"].as_str() {
71                        if method == "cancel" {
72                            cancelled_clone.store(true, Ordering::SeqCst);
73                        }
74                    }
75                    continue;
76                }
77
78                // This is a response — dispatch to the waiting caller
79                if let Some(id) = msg["id"].as_u64() {
80                    let mut pending = pending_clone.lock().await;
81                    if let Some(sender) = pending.remove(&id) {
82                        let _ = sender.send(msg);
83                    }
84                }
85            }
86
87            // stdin closed — cancel any remaining pending requests by dropping senders
88            let mut pending = pending_clone.lock().await;
89            pending.clear();
90        });
91
92        Self {
93            next_id: AtomicU64::new(1),
94            pending,
95            cancelled,
96            stdout_lock: Arc::new(std::sync::Mutex::new(())),
97        }
98    }
99
100    /// Create a bridge from pre-existing shared state.
101    ///
102    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
103    /// responsible for dispatching responses into `pending`.  This is used
104    /// by ACP mode which already has its own stdin reader.
105    pub fn from_parts(
106        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
107        cancelled: Arc<AtomicBool>,
108        stdout_lock: Arc<std::sync::Mutex<()>>,
109        start_id: u64,
110    ) -> Self {
111        Self {
112            next_id: AtomicU64::new(start_id),
113            pending,
114            cancelled,
115            stdout_lock,
116        }
117    }
118
119    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
120    fn write_line(&self, line: &str) -> Result<(), VmError> {
121        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
122        let mut stdout = std::io::stdout().lock();
123        stdout
124            .write_all(line.as_bytes())
125            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
126        stdout
127            .write_all(b"\n")
128            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
129        stdout
130            .flush()
131            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
132        Ok(())
133    }
134
135    /// Send a JSON-RPC request to the host and wait for the response.
136    /// Times out after 5 minutes to prevent deadlocks.
137    pub async fn call(
138        &self,
139        method: &str,
140        params: serde_json::Value,
141    ) -> Result<serde_json::Value, VmError> {
142        if self.is_cancelled() {
143            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
144        }
145
146        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
147
148        let request = serde_json::json!({
149            "jsonrpc": "2.0",
150            "id": id,
151            "method": method,
152            "params": params,
153        });
154
155        // Register a oneshot channel to receive the response
156        let (tx, rx) = oneshot::channel();
157        {
158            let mut pending = self.pending.lock().await;
159            pending.insert(id, tx);
160        }
161
162        // Send the request (serialized through stdout mutex)
163        let line = serde_json::to_string(&request)
164            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
165        if let Err(e) = self.write_line(&line) {
166            // Clean up pending entry on write failure
167            let mut pending = self.pending.lock().await;
168            pending.remove(&id);
169            return Err(e);
170        }
171
172        // Wait for the response with timeout
173        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
174            Ok(Ok(msg)) => msg,
175            Ok(Err(_)) => {
176                // Sender dropped — host closed or stdin reader exited
177                return Err(VmError::Runtime(
178                    "Bridge: host closed connection before responding".into(),
179                ));
180            }
181            Err(_) => {
182                // Timeout — clean up pending entry
183                let mut pending = self.pending.lock().await;
184                pending.remove(&id);
185                return Err(VmError::Runtime(format!(
186                    "Bridge: host did not respond to '{method}' within {}s",
187                    DEFAULT_TIMEOUT.as_secs()
188                )));
189            }
190        };
191
192        // Check for JSON-RPC error
193        if let Some(error) = response.get("error") {
194            let message = error["message"].as_str().unwrap_or("Unknown host error");
195            let code = error["code"].as_i64().unwrap_or(-1);
196            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
197        }
198
199        Ok(response["result"].clone())
200    }
201
202    /// Send a JSON-RPC notification to the host (no response expected).
203    /// Serialized through the stdout mutex to prevent interleaving.
204    pub fn notify(&self, method: &str, params: serde_json::Value) {
205        let notification = serde_json::json!({
206            "jsonrpc": "2.0",
207            "method": method,
208            "params": params,
209        });
210        if let Ok(line) = serde_json::to_string(&notification) {
211            let _ = self.write_line(&line);
212        }
213    }
214
215    /// Check if the host has sent a cancel notification.
216    pub fn is_cancelled(&self) -> bool {
217        self.cancelled.load(Ordering::SeqCst)
218    }
219
220    /// Send an output notification (for log/print in bridge mode).
221    pub fn send_output(&self, text: &str) {
222        self.notify("output", serde_json::json!({"text": text}));
223    }
224
225    /// Send a progress notification.
226    pub fn send_progress(&self, phase: &str, message: &str) {
227        self.notify(
228            "progress",
229            serde_json::json!({"phase": phase, "message": message}),
230        );
231    }
232}
233
234/// Convert a serde_json::Value to a VmValue.
235pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
236    crate::stdlib::json_to_vm_value(val)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Compact single-line encoding, as the bridge writes it on stdout.
    fn encode(value: &serde_json::Value) -> String {
        serde_json::to_string(value).unwrap()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let encoded = encode(&serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {"prompt": "Hello", "system": "Be helpful"},
        }));
        for fragment in [r#""jsonrpc":"2.0""#, r#""id":1"#, r#""method":"llm_call""#] {
            assert!(encoded.contains(fragment));
        }
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let encoded = encode(&serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        }));
        assert!(encoded.contains(r#""method":"output""#));
        // Notifications must never carry an id.
        assert!(!encoded.contains(r#""id""#));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {"code": -32600, "message": "Invalid request"},
        });
        let error = response.get("error");
        assert!(error.is_some());
        assert_eq!(
            error.and_then(|e| e["message"].as_str()),
            Some("Invalid request")
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"text": "Hello world", "input_tokens": 10, "output_tokens": 5},
        });
        let result = response.get("result");
        assert!(result.is_some());
        assert_eq!(result.and_then(|r| r["text"].as_str()), Some("Hello world"));
    }

    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        flag.store(true, Ordering::SeqCst);
        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let converted = json_result_to_vm_value(&serde_json::json!("hello"));
        assert_eq!(converted.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let vm_val = json_result_to_vm_value(&serde_json::json!({"name": "test", "count": 42}));
        match &vm_val {
            VmValue::Dict(d) => {
                assert_eq!(d.get("name").unwrap().display(), "test");
                assert_eq!(d.get("count").unwrap().display(), "42");
            }
            _ => panic!("Expected Dict, got {:?}", vm_val),
        }
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let converted = json_result_to_vm_value(&serde_json::Value::Null);
        assert!(matches!(converted, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let payload = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        match &json_result_to_vm_value(&payload) {
            VmValue::Dict(d) => {
                assert_eq!(d.get("text").unwrap().display(), "response");
                match d.get("tool_calls").unwrap() {
                    VmValue::List(list) => assert_eq!(list.len(), 1),
                    _ => panic!("Expected List for tool_calls"),
                }
            }
            _ => panic!("Expected Dict"),
        }
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}