Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host (e.g., Burin IDE) handles these requests using its own providers.
6
7use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{VmError, VmValue};
17
/// How long a single bridge call waits for the host before giving up:
/// 5 minutes, expressed in seconds.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35}
36
37// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
38#[allow(clippy::new_without_default)]
39impl HostBridge {
40    /// Create a new bridge and spawn the stdin reader task.
41    ///
42    /// Must be called within a tokio LocalSet (uses spawn_local for the
43    /// stdin reader since it's single-threaded).
44    pub fn new() -> Self {
45        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
46            Arc::new(Mutex::new(HashMap::new()));
47        let cancelled = Arc::new(AtomicBool::new(false));
48
49        // Stdin reader: reads JSON-RPC lines and dispatches responses
50        let pending_clone = pending.clone();
51        let cancelled_clone = cancelled.clone();
52        tokio::task::spawn_local(async move {
53            let stdin = tokio::io::stdin();
54            let reader = tokio::io::BufReader::new(stdin);
55            let mut lines = reader.lines();
56
57            while let Ok(Some(line)) = lines.next_line().await {
58                let line = line.trim().to_string();
59                if line.is_empty() {
60                    continue;
61                }
62
63                let msg: serde_json::Value = match serde_json::from_str(&line) {
64                    Ok(v) => v,
65                    Err(_) => continue, // Skip malformed lines
66                };
67
68                // Check if this is a notification from the host (no id)
69                if msg.get("id").is_none() {
70                    if let Some(method) = msg["method"].as_str() {
71                        if method == "cancel" {
72                            cancelled_clone.store(true, Ordering::SeqCst);
73                        }
74                    }
75                    continue;
76                }
77
78                // This is a response — dispatch to the waiting caller
79                if let Some(id) = msg["id"].as_u64() {
80                    let mut pending = pending_clone.lock().await;
81                    if let Some(sender) = pending.remove(&id) {
82                        let _ = sender.send(msg);
83                    }
84                }
85            }
86
87            // stdin closed — cancel any remaining pending requests by dropping senders
88            let mut pending = pending_clone.lock().await;
89            pending.clear();
90        });
91
92        Self {
93            next_id: AtomicU64::new(1),
94            pending,
95            cancelled,
96            stdout_lock: Arc::new(std::sync::Mutex::new(())),
97        }
98    }
99
100    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
101    fn write_line(&self, line: &str) -> Result<(), VmError> {
102        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
103        let mut stdout = std::io::stdout().lock();
104        stdout
105            .write_all(line.as_bytes())
106            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
107        stdout
108            .write_all(b"\n")
109            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
110        stdout
111            .flush()
112            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
113        Ok(())
114    }
115
116    /// Send a JSON-RPC request to the host and wait for the response.
117    /// Times out after 5 minutes to prevent deadlocks.
118    pub async fn call(
119        &self,
120        method: &str,
121        params: serde_json::Value,
122    ) -> Result<serde_json::Value, VmError> {
123        if self.is_cancelled() {
124            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
125        }
126
127        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
128
129        let request = serde_json::json!({
130            "jsonrpc": "2.0",
131            "id": id,
132            "method": method,
133            "params": params,
134        });
135
136        // Register a oneshot channel to receive the response
137        let (tx, rx) = oneshot::channel();
138        {
139            let mut pending = self.pending.lock().await;
140            pending.insert(id, tx);
141        }
142
143        // Send the request (serialized through stdout mutex)
144        let line = serde_json::to_string(&request)
145            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
146        if let Err(e) = self.write_line(&line) {
147            // Clean up pending entry on write failure
148            let mut pending = self.pending.lock().await;
149            pending.remove(&id);
150            return Err(e);
151        }
152
153        // Wait for the response with timeout
154        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
155            Ok(Ok(msg)) => msg,
156            Ok(Err(_)) => {
157                // Sender dropped — host closed or stdin reader exited
158                return Err(VmError::Runtime(
159                    "Bridge: host closed connection before responding".into(),
160                ));
161            }
162            Err(_) => {
163                // Timeout — clean up pending entry
164                let mut pending = self.pending.lock().await;
165                pending.remove(&id);
166                return Err(VmError::Runtime(format!(
167                    "Bridge: host did not respond to '{method}' within {}s",
168                    DEFAULT_TIMEOUT.as_secs()
169                )));
170            }
171        };
172
173        // Check for JSON-RPC error
174        if let Some(error) = response.get("error") {
175            let message = error["message"].as_str().unwrap_or("Unknown host error");
176            let code = error["code"].as_i64().unwrap_or(-1);
177            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
178        }
179
180        Ok(response["result"].clone())
181    }
182
183    /// Send a JSON-RPC notification to the host (no response expected).
184    /// Serialized through the stdout mutex to prevent interleaving.
185    pub fn notify(&self, method: &str, params: serde_json::Value) {
186        let notification = serde_json::json!({
187            "jsonrpc": "2.0",
188            "method": method,
189            "params": params,
190        });
191        if let Ok(line) = serde_json::to_string(&notification) {
192            let _ = self.write_line(&line);
193        }
194    }
195
196    /// Check if the host has sent a cancel notification.
197    pub fn is_cancelled(&self) -> bool {
198        self.cancelled.load(Ordering::SeqCst)
199    }
200
201    /// Send an output notification (for log/print in bridge mode).
202    pub fn send_output(&self, text: &str) {
203        self.notify("output", serde_json::json!({"text": text}));
204    }
205
206    /// Send a progress notification.
207    pub fn send_progress(&self, phase: &str, message: &str) {
208        self.notify(
209            "progress",
210            serde_json::json!({"phase": phase, "message": message}),
211        );
212    }
213}
214
215/// Convert a serde_json::Value to a VmValue.
216pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
217    crate::stdlib::json_to_vm_value(val)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// A request serializes with the protocol version, id and method.
    #[test]
    fn test_json_rpc_request_format() {
        let params = serde_json::json!({
            "prompt": "Hello",
            "system": "Be helpful",
        });
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": params,
        });
        let encoded = serde_json::to_string(&request).unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":1", "\"method\":\"llm_call\""] {
            assert!(encoded.contains(fragment));
        }
    }

    /// A notification serializes with its method and without any id.
    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let encoded = serde_json::to_string(&notification).unwrap();
        assert!(encoded.contains("\"method\":\"output\""));
        assert!(!encoded.contains("\"id\""));
    }

    /// The error member of an error response is present and readable.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        let has_error = response.get("error").is_some();
        assert!(has_error);
        let message = response["error"]["message"].as_str().unwrap();
        assert_eq!(message, "Invalid request");
    }

    /// The result member of a success response is present and readable.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        let has_result = response.get("result").is_some();
        assert!(has_result);
        let text = response["result"]["text"].as_str().unwrap();
        assert_eq!(text, "Hello world");
    }

    /// The shared atomic flag starts false and reads true after a store.
    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        flag.store(true, Ordering::SeqCst);
        assert!(flag.load(Ordering::SeqCst));
    }

    /// A JSON string converts to a value that displays as that string.
    #[test]
    fn test_json_result_to_vm_value_string() {
        let converted = json_result_to_vm_value(&serde_json::json!("hello"));
        assert_eq!(converted.display(), "hello");
    }

    /// A JSON object converts to a dict with matching entries.
    #[test]
    fn test_json_result_to_vm_value_dict() {
        let source = serde_json::json!({"name": "test", "count": 42});
        let converted = json_result_to_vm_value(&source);
        match &converted {
            VmValue::Dict(entries) => {
                assert_eq!(entries.get("name").unwrap().display(), "test");
                assert_eq!(entries.get("count").unwrap().display(), "42");
            }
            other => panic!("Expected Dict, got {:?}", other),
        }
    }

    /// JSON null converts to `VmValue::Nil`.
    #[test]
    fn test_json_result_to_vm_value_null() {
        let converted = json_result_to_vm_value(&serde_json::json!(null));
        assert!(matches!(converted, VmValue::Nil));
    }

    /// Nested objects and arrays convert to nested dicts and lists.
    #[test]
    fn test_json_result_to_vm_value_nested() {
        let source = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let converted = json_result_to_vm_value(&source);
        match &converted {
            VmValue::Dict(entries) => {
                assert_eq!(entries.get("text").unwrap().display(), "response");
                match entries.get("tool_calls").unwrap() {
                    VmValue::List(calls) => assert_eq!(calls.len(), 1),
                    _ => panic!("Expected List for tool_calls"),
                }
            }
            _ => panic!("Expected Dict"),
        }
    }

    /// The default bridge timeout is 300 seconds.
    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}
346}