Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host (e.g., Burin IDE) handles these requests using its own providers.
6
7use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{VmError, VmValue};
17
/// How long a bridge call waits for the host's response before giving up:
/// five minutes, expressed in seconds.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35}
36
37// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
38#[allow(clippy::new_without_default)]
39impl HostBridge {
40    /// Create a new bridge and spawn the stdin reader task.
41    ///
42    /// Must be called within a tokio LocalSet (uses spawn_local for the
43    /// stdin reader since it's single-threaded).
44    pub fn new() -> Self {
45        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
46            Arc::new(Mutex::new(HashMap::new()));
47        let cancelled = Arc::new(AtomicBool::new(false));
48
49        // Stdin reader: reads JSON-RPC lines and dispatches responses
50        let pending_clone = pending.clone();
51        let cancelled_clone = cancelled.clone();
52        tokio::task::spawn_local(async move {
53            let stdin = tokio::io::stdin();
54            let reader = tokio::io::BufReader::new(stdin);
55            let mut lines = reader.lines();
56
57            while let Ok(Some(line)) = lines.next_line().await {
58                let line = line.trim().to_string();
59                if line.is_empty() {
60                    continue;
61                }
62
63                let msg: serde_json::Value = match serde_json::from_str(&line) {
64                    Ok(v) => v,
65                    Err(_) => continue, // Skip malformed lines
66                };
67
68                // Check if this is a notification from the host (no id)
69                if msg.get("id").is_none() {
70                    if let Some(method) = msg["method"].as_str() {
71                        if method == "cancel" {
72                            cancelled_clone.store(true, Ordering::SeqCst);
73                        }
74                    }
75                    continue;
76                }
77
78                // This is a response — dispatch to the waiting caller
79                if let Some(id) = msg["id"].as_u64() {
80                    let mut pending = pending_clone.lock().await;
81                    if let Some(sender) = pending.remove(&id) {
82                        let _ = sender.send(msg);
83                    }
84                }
85            }
86
87            // stdin closed — cancel any remaining pending requests by dropping senders
88            let mut pending = pending_clone.lock().await;
89            pending.clear();
90        });
91
92        Self {
93            next_id: AtomicU64::new(1),
94            pending,
95            cancelled,
96            stdout_lock: Arc::new(std::sync::Mutex::new(())),
97        }
98    }
99
100    /// Create a bridge from pre-existing shared state.
101    ///
102    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
103    /// responsible for dispatching responses into `pending`.  This is used
104    /// by ACP mode which already has its own stdin reader.
105    pub fn from_parts(
106        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
107        cancelled: Arc<AtomicBool>,
108        stdout_lock: Arc<std::sync::Mutex<()>>,
109        start_id: u64,
110    ) -> Self {
111        Self {
112            next_id: AtomicU64::new(start_id),
113            pending,
114            cancelled,
115            stdout_lock,
116        }
117    }
118
119    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
120    fn write_line(&self, line: &str) -> Result<(), VmError> {
121        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
122        let mut stdout = std::io::stdout().lock();
123        stdout
124            .write_all(line.as_bytes())
125            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
126        stdout
127            .write_all(b"\n")
128            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
129        stdout
130            .flush()
131            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
132        Ok(())
133    }
134
135    /// Send a JSON-RPC request to the host and wait for the response.
136    /// Times out after 5 minutes to prevent deadlocks.
137    pub async fn call(
138        &self,
139        method: &str,
140        params: serde_json::Value,
141    ) -> Result<serde_json::Value, VmError> {
142        if self.is_cancelled() {
143            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
144        }
145
146        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
147
148        let request = serde_json::json!({
149            "jsonrpc": "2.0",
150            "id": id,
151            "method": method,
152            "params": params,
153        });
154
155        // Register a oneshot channel to receive the response
156        let (tx, rx) = oneshot::channel();
157        {
158            let mut pending = self.pending.lock().await;
159            pending.insert(id, tx);
160        }
161
162        // Send the request (serialized through stdout mutex)
163        let line = serde_json::to_string(&request)
164            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
165        if let Err(e) = self.write_line(&line) {
166            // Clean up pending entry on write failure
167            let mut pending = self.pending.lock().await;
168            pending.remove(&id);
169            return Err(e);
170        }
171
172        // Wait for the response with timeout
173        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
174            Ok(Ok(msg)) => msg,
175            Ok(Err(_)) => {
176                // Sender dropped — host closed or stdin reader exited
177                return Err(VmError::Runtime(
178                    "Bridge: host closed connection before responding".into(),
179                ));
180            }
181            Err(_) => {
182                // Timeout — clean up pending entry
183                let mut pending = self.pending.lock().await;
184                pending.remove(&id);
185                return Err(VmError::Runtime(format!(
186                    "Bridge: host did not respond to '{method}' within {}s",
187                    DEFAULT_TIMEOUT.as_secs()
188                )));
189            }
190        };
191
192        // Check for JSON-RPC error
193        if let Some(error) = response.get("error") {
194            let message = error["message"].as_str().unwrap_or("Unknown host error");
195            let code = error["code"].as_i64().unwrap_or(-1);
196            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
197        }
198
199        Ok(response["result"].clone())
200    }
201
202    /// Send a JSON-RPC notification to the host (no response expected).
203    /// Serialized through the stdout mutex to prevent interleaving.
204    pub fn notify(&self, method: &str, params: serde_json::Value) {
205        let notification = serde_json::json!({
206            "jsonrpc": "2.0",
207            "method": method,
208            "params": params,
209        });
210        if let Ok(line) = serde_json::to_string(&notification) {
211            let _ = self.write_line(&line);
212        }
213    }
214
215    /// Check if the host has sent a cancel notification.
216    pub fn is_cancelled(&self) -> bool {
217        self.cancelled.load(Ordering::SeqCst)
218    }
219
220    /// Send an output notification (for log/print in bridge mode).
221    pub fn send_output(&self, text: &str) {
222        self.notify("output", serde_json::json!({"text": text}));
223    }
224
225    /// Send a progress notification with optional structured data payload.
226    pub fn send_progress(
227        &self,
228        phase: &str,
229        message: &str,
230        data: Option<serde_json::Value>,
231    ) {
232        let mut payload = serde_json::json!({"phase": phase, "message": message});
233        if let Some(d) = data {
234            payload["data"] = d;
235        }
236        self.notify("progress", payload);
237    }
238
239    /// Send a structured log notification.
240    pub fn send_log(
241        &self,
242        level: &str,
243        message: &str,
244        fields: Option<serde_json::Value>,
245    ) {
246        let mut payload = serde_json::json!({"level": level, "message": message});
247        if let Some(f) = fields {
248            payload["fields"] = f;
249        }
250        self.notify("log", payload);
251    }
252}
253
254/// Convert a serde_json::Value to a VmValue.
255pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
256    crate::stdlib::json_to_vm_value(val)
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized request carries the protocol version, the id and the method.
    #[test]
    fn test_json_rpc_request_format() {
        let payload = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let encoded = serde_json::to_string(&payload).unwrap();
        for fragment in [
            r#""jsonrpc":"2.0""#,
            r#""id":1"#,
            r#""method":"llm_call""#,
        ] {
            assert!(encoded.contains(fragment));
        }
    }

    /// A serialized notification names its method and has no id member.
    #[test]
    fn test_json_rpc_notification_format() {
        let payload = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let encoded = serde_json::to_string(&payload).unwrap();
        assert!(encoded.contains(r#""method":"output""#));
        assert!(!encoded.contains(r#""id""#));
    }

    /// The error object and its message can be read from an error response.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let reply = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(reply.get("error").is_some());
        let message = reply["error"]["message"].as_str().unwrap();
        assert_eq!(message, "Invalid request");
    }

    /// The result object and its fields can be read from a success response.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let reply = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(reply.get("result").is_some());
        let text = reply["result"]["text"].as_str().unwrap();
        assert_eq!(text, "Hello world");
    }

    /// A shared atomic flag starts cleared and reads back as set after a store.
    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        flag.store(true, Ordering::SeqCst);
        assert!(flag.load(Ordering::SeqCst));
    }

    /// A JSON string converts to a value that displays as the bare string.
    #[test]
    fn test_json_result_to_vm_value_string() {
        let converted = json_result_to_vm_value(&serde_json::json!("hello"));
        assert_eq!(converted.display(), "hello");
    }

    /// A JSON object converts to a dict whose entries keep their values.
    #[test]
    fn test_json_result_to_vm_value_dict() {
        let source = serde_json::json!({"name": "test", "count": 42});
        let converted = json_result_to_vm_value(&source);
        let entries = match &converted {
            VmValue::Dict(entries) => entries,
            other => unreachable!("Expected Dict, got {:?}", other),
        };
        assert_eq!(entries.get("name").unwrap().display(), "test");
        assert_eq!(entries.get("count").unwrap().display(), "42");
    }

    /// JSON null converts to the VM's nil value.
    #[test]
    fn test_json_result_to_vm_value_null() {
        let converted = json_result_to_vm_value(&serde_json::json!(null));
        assert!(matches!(converted, VmValue::Nil));
    }

    /// Nested arrays and objects survive the conversion.
    #[test]
    fn test_json_result_to_vm_value_nested() {
        let source = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let converted = json_result_to_vm_value(&source);
        let entries = match &converted {
            VmValue::Dict(entries) => entries,
            other => unreachable!("Expected Dict, got {:?}", other),
        };
        assert_eq!(entries.get("text").unwrap().display(), "response");
        let calls = match entries.get("tool_calls").unwrap() {
            VmValue::List(calls) => calls,
            _ => unreachable!("Expected List for tool_calls"),
        };
        assert_eq!(calls.len(), 1);
    }

    /// The default bridge timeout is 300 seconds.
    #[test]
    fn test_timeout_duration() {
        let seconds = DEFAULT_TIMEOUT.as_secs();
        assert_eq!(seconds, 300);
    }
}