Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host (e.g., Burin IDE) handles these requests using its own providers.
6
7use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// How long [`HostBridge::call`] waits for the host to answer a request
/// before giving up: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35    /// ACP session ID (set in ACP mode for session-scoped notifications).
36    session_id: std::sync::Mutex<String>,
37    /// Name of the currently executing Harn script (without .harn suffix).
38    script_name: std::sync::Mutex<String>,
39}
40
41// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
42#[allow(clippy::new_without_default)]
43impl HostBridge {
44    /// Create a new bridge and spawn the stdin reader task.
45    ///
46    /// Must be called within a tokio LocalSet (uses spawn_local for the
47    /// stdin reader since it's single-threaded).
48    pub fn new() -> Self {
49        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
50            Arc::new(Mutex::new(HashMap::new()));
51        let cancelled = Arc::new(AtomicBool::new(false));
52
53        // Stdin reader: reads JSON-RPC lines and dispatches responses
54        let pending_clone = pending.clone();
55        let cancelled_clone = cancelled.clone();
56        tokio::task::spawn_local(async move {
57            let stdin = tokio::io::stdin();
58            let reader = tokio::io::BufReader::new(stdin);
59            let mut lines = reader.lines();
60
61            while let Ok(Some(line)) = lines.next_line().await {
62                let line = line.trim().to_string();
63                if line.is_empty() {
64                    continue;
65                }
66
67                let msg: serde_json::Value = match serde_json::from_str(&line) {
68                    Ok(v) => v,
69                    Err(_) => continue, // Skip malformed lines
70                };
71
72                // Check if this is a notification from the host (no id)
73                if msg.get("id").is_none() {
74                    if let Some(method) = msg["method"].as_str() {
75                        if method == "cancel" {
76                            cancelled_clone.store(true, Ordering::SeqCst);
77                        }
78                    }
79                    continue;
80                }
81
82                // This is a response — dispatch to the waiting caller
83                if let Some(id) = msg["id"].as_u64() {
84                    let mut pending = pending_clone.lock().await;
85                    if let Some(sender) = pending.remove(&id) {
86                        let _ = sender.send(msg);
87                    }
88                }
89            }
90
91            // stdin closed — cancel any remaining pending requests by dropping senders
92            let mut pending = pending_clone.lock().await;
93            pending.clear();
94        });
95
96        Self {
97            next_id: AtomicU64::new(1),
98            pending,
99            cancelled,
100            stdout_lock: Arc::new(std::sync::Mutex::new(())),
101            session_id: std::sync::Mutex::new(String::new()),
102            script_name: std::sync::Mutex::new(String::new()),
103        }
104    }
105
106    /// Create a bridge from pre-existing shared state.
107    ///
108    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
109    /// responsible for dispatching responses into `pending`.  This is used
110    /// by ACP mode which already has its own stdin reader.
111    pub fn from_parts(
112        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
113        cancelled: Arc<AtomicBool>,
114        stdout_lock: Arc<std::sync::Mutex<()>>,
115        start_id: u64,
116    ) -> Self {
117        Self {
118            next_id: AtomicU64::new(start_id),
119            pending,
120            cancelled,
121            stdout_lock,
122            session_id: std::sync::Mutex::new(String::new()),
123            script_name: std::sync::Mutex::new(String::new()),
124        }
125    }
126
127    /// Set the ACP session ID for session-scoped notifications.
128    pub fn set_session_id(&self, id: &str) {
129        *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
130    }
131
132    /// Set the currently executing script name (without .harn suffix).
133    pub fn set_script_name(&self, name: &str) {
134        *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
135    }
136
137    /// Get the current script name.
138    fn get_script_name(&self) -> String {
139        self.script_name
140            .lock()
141            .unwrap_or_else(|e| e.into_inner())
142            .clone()
143    }
144
145    /// Get the session ID.
146    fn get_session_id(&self) -> String {
147        self.session_id
148            .lock()
149            .unwrap_or_else(|e| e.into_inner())
150            .clone()
151    }
152
153    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
154    fn write_line(&self, line: &str) -> Result<(), VmError> {
155        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
156        let mut stdout = std::io::stdout().lock();
157        stdout
158            .write_all(line.as_bytes())
159            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
160        stdout
161            .write_all(b"\n")
162            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
163        stdout
164            .flush()
165            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
166        Ok(())
167    }
168
169    /// Send a JSON-RPC request to the host and wait for the response.
170    /// Times out after 5 minutes to prevent deadlocks.
171    pub async fn call(
172        &self,
173        method: &str,
174        params: serde_json::Value,
175    ) -> Result<serde_json::Value, VmError> {
176        if self.is_cancelled() {
177            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
178        }
179
180        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
181
182        let request = serde_json::json!({
183            "jsonrpc": "2.0",
184            "id": id,
185            "method": method,
186            "params": params,
187        });
188
189        // Register a oneshot channel to receive the response
190        let (tx, rx) = oneshot::channel();
191        {
192            let mut pending = self.pending.lock().await;
193            pending.insert(id, tx);
194        }
195
196        // Send the request (serialized through stdout mutex)
197        let line = serde_json::to_string(&request)
198            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
199        if let Err(e) = self.write_line(&line) {
200            // Clean up pending entry on write failure
201            let mut pending = self.pending.lock().await;
202            pending.remove(&id);
203            return Err(e);
204        }
205
206        // Wait for the response with timeout
207        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
208            Ok(Ok(msg)) => msg,
209            Ok(Err(_)) => {
210                // Sender dropped — host closed or stdin reader exited
211                return Err(VmError::Runtime(
212                    "Bridge: host closed connection before responding".into(),
213                ));
214            }
215            Err(_) => {
216                // Timeout — clean up pending entry
217                let mut pending = self.pending.lock().await;
218                pending.remove(&id);
219                return Err(VmError::Runtime(format!(
220                    "Bridge: host did not respond to '{method}' within {}s",
221                    DEFAULT_TIMEOUT.as_secs()
222                )));
223            }
224        };
225
226        // Check for JSON-RPC error
227        if let Some(error) = response.get("error") {
228            let message = error["message"].as_str().unwrap_or("Unknown host error");
229            let code = error["code"].as_i64().unwrap_or(-1);
230            // -32001: tool rejected by host (not permitted / not in allowlist)
231            if code == -32001 {
232                return Err(VmError::CategorizedError {
233                    message: message.to_string(),
234                    category: ErrorCategory::ToolRejected,
235                });
236            }
237            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
238        }
239
240        Ok(response["result"].clone())
241    }
242
243    /// Send a JSON-RPC notification to the host (no response expected).
244    /// Serialized through the stdout mutex to prevent interleaving.
245    pub fn notify(&self, method: &str, params: serde_json::Value) {
246        let notification = serde_json::json!({
247            "jsonrpc": "2.0",
248            "method": method,
249            "params": params,
250        });
251        if let Ok(line) = serde_json::to_string(&notification) {
252            let _ = self.write_line(&line);
253        }
254    }
255
256    /// Check if the host has sent a cancel notification.
257    pub fn is_cancelled(&self) -> bool {
258        self.cancelled.load(Ordering::SeqCst)
259    }
260
261    /// Send an output notification (for log/print in bridge mode).
262    pub fn send_output(&self, text: &str) {
263        self.notify("output", serde_json::json!({"text": text}));
264    }
265
266    /// Send a progress notification with optional numeric progress and structured data.
267    pub fn send_progress(
268        &self,
269        phase: &str,
270        message: &str,
271        progress: Option<i64>,
272        total: Option<i64>,
273        data: Option<serde_json::Value>,
274    ) {
275        let mut payload = serde_json::json!({"phase": phase, "message": message});
276        if let Some(p) = progress {
277            payload["progress"] = serde_json::json!(p);
278        }
279        if let Some(t) = total {
280            payload["total"] = serde_json::json!(t);
281        }
282        if let Some(d) = data {
283            payload["data"] = d;
284        }
285        self.notify("progress", payload);
286    }
287
288    /// Send a structured log notification.
289    pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
290        let mut payload = serde_json::json!({"level": level, "message": message});
291        if let Some(f) = fields {
292            payload["fields"] = f;
293        }
294        self.notify("log", payload);
295    }
296
297    /// Send a `session/update` with `call_start` — signals the beginning of
298    /// an LLM call, tool call, or builtin call for observability.
299    pub fn send_call_start(
300        &self,
301        call_id: &str,
302        call_type: &str,
303        name: &str,
304        metadata: serde_json::Value,
305    ) {
306        let session_id = self.get_session_id();
307        let script = self.get_script_name();
308        self.notify(
309            "session/update",
310            serde_json::json!({
311                "sessionId": session_id,
312                "update": {
313                    "sessionUpdate": "call_start",
314                    "content": {
315                        "call_id": call_id,
316                        "call_type": call_type,
317                        "name": name,
318                        "script": script,
319                        "metadata": metadata,
320                    },
321                },
322            }),
323        );
324    }
325
326    /// Send a `session/update` with `call_progress` — a streaming token delta
327    /// from an in-flight LLM call.
328    pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
329        let session_id = self.get_session_id();
330        self.notify(
331            "session/update",
332            serde_json::json!({
333                "sessionId": session_id,
334                "update": {
335                    "sessionUpdate": "call_progress",
336                    "content": {
337                        "call_id": call_id,
338                        "delta": delta,
339                        "accumulated_tokens": accumulated_tokens,
340                    },
341                },
342            }),
343        );
344    }
345
346    /// Send a `session/update` with `call_end` — signals completion of a call.
347    pub fn send_call_end(
348        &self,
349        call_id: &str,
350        call_type: &str,
351        name: &str,
352        duration_ms: u64,
353        status: &str,
354        metadata: serde_json::Value,
355    ) {
356        let session_id = self.get_session_id();
357        let script = self.get_script_name();
358        self.notify(
359            "session/update",
360            serde_json::json!({
361                "sessionId": session_id,
362                "update": {
363                    "sessionUpdate": "call_end",
364                    "content": {
365                        "call_id": call_id,
366                        "call_type": call_type,
367                        "name": name,
368                        "script": script,
369                        "duration_ms": duration_ms,
370                        "status": status,
371                        "metadata": metadata,
372                    },
373                },
374            }),
375        );
376    }
377}
378
379/// Convert a serde_json::Value to a VmValue.
380pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
381    crate::stdlib::json_to_vm_value(val)
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a bridge over fresh shared state without spawning a stdin
    /// reader, returning the shared cancel flag alongside it.
    fn detached_bridge(start_id: u64) -> (HostBridge, Arc<AtomicBool>) {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::clone(&cancelled),
            Arc::new(std::sync::Mutex::new(())),
            start_id,
        );
        (bridge, cancelled)
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    /// The bridge must observe the shared cancel flag it was built with.
    #[test]
    fn test_cancelled_flag() {
        let (bridge, cancelled) = detached_bridge(1);
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn test_from_parts_uses_start_id() {
        let (bridge, _) = detached_bridge(42);
        assert_eq!(bridge.next_id.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_session_and_script_name_setters() {
        let (bridge, _) = detached_bridge(1);
        assert_eq!(bridge.get_session_id(), "");
        assert_eq!(bridge.get_script_name(), "");
        bridge.set_session_id("session-1");
        bridge.set_script_name("main");
        assert_eq!(bridge.get_session_id(), "session-1");
        assert_eq!(bridge.get_script_name(), "main");
    }

    /// A cancelled bridge must fail fast without consuming an id, registering
    /// a waiter, or writing to stdout.
    #[test]
    fn test_call_short_circuits_when_cancelled() {
        let (bridge, cancelled) = detached_bridge(7);
        cancelled.store(true, Ordering::SeqCst);
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("build current-thread runtime");
        let err = rt
            .block_on(bridge.call("llm_call", serde_json::json!({})))
            .unwrap_err();
        assert!(matches!(&err, VmError::Runtime(m) if m.contains("cancelled")));
        assert_eq!(bridge.next_id.load(Ordering::SeqCst), 7);
        assert!(bridge
            .pending
            .try_lock()
            .expect("pending map is not locked")
            .is_empty());
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}