Skip to main content

harn_vm/
bridge.rs

//! JSON-RPC 2.0 bridge for host communication.
//!
//! When `harn run --bridge` is used, the VM delegates builtins (`llm_call`,
//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
//! The host (e.g., the Burin IDE) handles these requests using its own providers.
6
7use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// Upper bound on how long a single bridge call waits for the host's
/// response: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35    /// ACP session ID (set in ACP mode for session-scoped notifications).
36    session_id: std::sync::Mutex<String>,
37    /// Name of the currently executing Harn script (without .harn suffix).
38    script_name: std::sync::Mutex<String>,
39    /// User messages injected by the host while a run is active.
40    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
41}
42
/// Delivery mode attached to a user message queued by the host.
///
/// The mode decides at which [`DeliveryCheckpoint`] the message is handed
/// out of the queue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Parsed from `"interrupt_immediate"` or `"interrupt"`.
    InterruptImmediate,
    /// Parsed from `"finish_step"` or `"after_current_operation"`.
    FinishStep,
    /// Fallback for every other mode string.
    WaitForCompletion,
}
49
/// Point in a run at which queued user messages are collected.
///
/// Each checkpoint drains exactly one [`QueuedUserMessageMode`] from the
/// queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Collects messages queued as `InterruptImmediate`.
    InterruptImmediate,
    /// Collects messages queued as `FinishStep`.
    AfterCurrentOperation,
    /// Collects messages queued as `WaitForCompletion`.
    EndOfInteraction,
}
56
57impl QueuedUserMessageMode {
58    fn from_str(value: &str) -> Self {
59        match value {
60            "interrupt_immediate" | "interrupt" => Self::InterruptImmediate,
61            "finish_step" | "after_current_operation" => Self::FinishStep,
62            _ => Self::WaitForCompletion,
63        }
64    }
65}
66
67#[derive(Clone, Debug, PartialEq, Eq)]
68pub struct QueuedUserMessage {
69    pub content: String,
70    pub mode: QueuedUserMessageMode,
71}
72
73// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
74#[allow(clippy::new_without_default)]
75impl HostBridge {
76    /// Create a new bridge and spawn the stdin reader task.
77    ///
78    /// Must be called within a tokio LocalSet (uses spawn_local for the
79    /// stdin reader since it's single-threaded).
80    pub fn new() -> Self {
81        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
82            Arc::new(Mutex::new(HashMap::new()));
83        let cancelled = Arc::new(AtomicBool::new(false));
84        let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
85            Arc::new(Mutex::new(VecDeque::new()));
86
87        // Stdin reader: reads JSON-RPC lines and dispatches responses
88        let pending_clone = pending.clone();
89        let cancelled_clone = cancelled.clone();
90        let queued_clone = queued_user_messages.clone();
91        tokio::task::spawn_local(async move {
92            let stdin = tokio::io::stdin();
93            let reader = tokio::io::BufReader::new(stdin);
94            let mut lines = reader.lines();
95
96            while let Ok(Some(line)) = lines.next_line().await {
97                let line = line.trim().to_string();
98                if line.is_empty() {
99                    continue;
100                }
101
102                let msg: serde_json::Value = match serde_json::from_str(&line) {
103                    Ok(v) => v,
104                    Err(_) => continue, // Skip malformed lines
105                };
106
107                // Check if this is a notification from the host (no id)
108                if msg.get("id").is_none() {
109                    if let Some(method) = msg["method"].as_str() {
110                        if method == "cancel" {
111                            cancelled_clone.store(true, Ordering::SeqCst);
112                        } else if method == "user_message"
113                            || method == "session/input"
114                            || method == "agent/user_message"
115                        {
116                            let params = &msg["params"];
117                            let content = params
118                                .get("content")
119                                .and_then(|v| v.as_str())
120                                .unwrap_or("")
121                                .to_string();
122                            if !content.is_empty() {
123                                let mode = QueuedUserMessageMode::from_str(
124                                    params
125                                        .get("mode")
126                                        .and_then(|v| v.as_str())
127                                        .unwrap_or("wait_for_completion"),
128                                );
129                                queued_clone
130                                    .lock()
131                                    .await
132                                    .push_back(QueuedUserMessage { content, mode });
133                            }
134                        }
135                    }
136                    continue;
137                }
138
139                // This is a response — dispatch to the waiting caller
140                if let Some(id) = msg["id"].as_u64() {
141                    let mut pending = pending_clone.lock().await;
142                    if let Some(sender) = pending.remove(&id) {
143                        let _ = sender.send(msg);
144                    }
145                }
146            }
147
148            // stdin closed — cancel any remaining pending requests by dropping senders
149            let mut pending = pending_clone.lock().await;
150            pending.clear();
151        });
152
153        Self {
154            next_id: AtomicU64::new(1),
155            pending,
156            cancelled,
157            stdout_lock: Arc::new(std::sync::Mutex::new(())),
158            session_id: std::sync::Mutex::new(String::new()),
159            script_name: std::sync::Mutex::new(String::new()),
160            queued_user_messages,
161        }
162    }
163
164    /// Create a bridge from pre-existing shared state.
165    ///
166    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
167    /// responsible for dispatching responses into `pending`.  This is used
168    /// by ACP mode which already has its own stdin reader.
169    pub fn from_parts(
170        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
171        cancelled: Arc<AtomicBool>,
172        stdout_lock: Arc<std::sync::Mutex<()>>,
173        start_id: u64,
174    ) -> Self {
175        Self {
176            next_id: AtomicU64::new(start_id),
177            pending,
178            cancelled,
179            stdout_lock,
180            session_id: std::sync::Mutex::new(String::new()),
181            script_name: std::sync::Mutex::new(String::new()),
182            queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
183        }
184    }
185
186    /// Set the ACP session ID for session-scoped notifications.
187    pub fn set_session_id(&self, id: &str) {
188        *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
189    }
190
191    /// Set the currently executing script name (without .harn suffix).
192    pub fn set_script_name(&self, name: &str) {
193        *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
194    }
195
196    /// Get the current script name.
197    fn get_script_name(&self) -> String {
198        self.script_name
199            .lock()
200            .unwrap_or_else(|e| e.into_inner())
201            .clone()
202    }
203
204    /// Get the session ID.
205    fn get_session_id(&self) -> String {
206        self.session_id
207            .lock()
208            .unwrap_or_else(|e| e.into_inner())
209            .clone()
210    }
211
212    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
213    fn write_line(&self, line: &str) -> Result<(), VmError> {
214        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
215        let mut stdout = std::io::stdout().lock();
216        stdout
217            .write_all(line.as_bytes())
218            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
219        stdout
220            .write_all(b"\n")
221            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
222        stdout
223            .flush()
224            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
225        Ok(())
226    }
227
228    /// Send a JSON-RPC request to the host and wait for the response.
229    /// Times out after 5 minutes to prevent deadlocks.
230    pub async fn call(
231        &self,
232        method: &str,
233        params: serde_json::Value,
234    ) -> Result<serde_json::Value, VmError> {
235        if self.is_cancelled() {
236            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
237        }
238
239        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
240
241        let request = serde_json::json!({
242            "jsonrpc": "2.0",
243            "id": id,
244            "method": method,
245            "params": params,
246        });
247
248        // Register a oneshot channel to receive the response
249        let (tx, rx) = oneshot::channel();
250        {
251            let mut pending = self.pending.lock().await;
252            pending.insert(id, tx);
253        }
254
255        // Send the request (serialized through stdout mutex)
256        let line = serde_json::to_string(&request)
257            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
258        if let Err(e) = self.write_line(&line) {
259            // Clean up pending entry on write failure
260            let mut pending = self.pending.lock().await;
261            pending.remove(&id);
262            return Err(e);
263        }
264
265        // Wait for the response with timeout
266        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
267            Ok(Ok(msg)) => msg,
268            Ok(Err(_)) => {
269                // Sender dropped — host closed or stdin reader exited
270                return Err(VmError::Runtime(
271                    "Bridge: host closed connection before responding".into(),
272                ));
273            }
274            Err(_) => {
275                // Timeout — clean up pending entry
276                let mut pending = self.pending.lock().await;
277                pending.remove(&id);
278                return Err(VmError::Runtime(format!(
279                    "Bridge: host did not respond to '{method}' within {}s",
280                    DEFAULT_TIMEOUT.as_secs()
281                )));
282            }
283        };
284
285        // Check for JSON-RPC error
286        if let Some(error) = response.get("error") {
287            let message = error["message"].as_str().unwrap_or("Unknown host error");
288            let code = error["code"].as_i64().unwrap_or(-1);
289            // -32001: tool rejected by host (not permitted / not in allowlist)
290            if code == -32001 {
291                return Err(VmError::CategorizedError {
292                    message: message.to_string(),
293                    category: ErrorCategory::ToolRejected,
294                });
295            }
296            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
297        }
298
299        Ok(response["result"].clone())
300    }
301
302    /// Send a JSON-RPC notification to the host (no response expected).
303    /// Serialized through the stdout mutex to prevent interleaving.
304    pub fn notify(&self, method: &str, params: serde_json::Value) {
305        let notification = serde_json::json!({
306            "jsonrpc": "2.0",
307            "method": method,
308            "params": params,
309        });
310        if let Ok(line) = serde_json::to_string(&notification) {
311            let _ = self.write_line(&line);
312        }
313    }
314
315    /// Check if the host has sent a cancel notification.
316    pub fn is_cancelled(&self) -> bool {
317        self.cancelled.load(Ordering::SeqCst)
318    }
319
320    pub async fn push_queued_user_message(&self, content: String, mode: &str) {
321        self.queued_user_messages
322            .lock()
323            .await
324            .push_back(QueuedUserMessage {
325                content,
326                mode: QueuedUserMessageMode::from_str(mode),
327            });
328    }
329
330    pub async fn take_queued_user_messages(
331        &self,
332        include_interrupt_immediate: bool,
333        include_finish_step: bool,
334        include_wait_for_completion: bool,
335    ) -> Vec<QueuedUserMessage> {
336        let mut queue = self.queued_user_messages.lock().await;
337        let mut selected = Vec::new();
338        let mut retained = VecDeque::new();
339        while let Some(message) = queue.pop_front() {
340            let should_take = match message.mode {
341                QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
342                QueuedUserMessageMode::FinishStep => include_finish_step,
343                QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
344            };
345            if should_take {
346                selected.push(message);
347            } else {
348                retained.push_back(message);
349            }
350        }
351        *queue = retained;
352        selected
353    }
354
355    pub async fn take_queued_user_messages_for(
356        &self,
357        checkpoint: DeliveryCheckpoint,
358    ) -> Vec<QueuedUserMessage> {
359        match checkpoint {
360            DeliveryCheckpoint::InterruptImmediate => {
361                self.take_queued_user_messages(true, false, false).await
362            }
363            DeliveryCheckpoint::AfterCurrentOperation => {
364                self.take_queued_user_messages(false, true, false).await
365            }
366            DeliveryCheckpoint::EndOfInteraction => {
367                self.take_queued_user_messages(false, false, true).await
368            }
369        }
370    }
371
372    /// Send an output notification (for log/print in bridge mode).
373    pub fn send_output(&self, text: &str) {
374        self.notify("output", serde_json::json!({"text": text}));
375    }
376
377    /// Send a progress notification with optional numeric progress and structured data.
378    pub fn send_progress(
379        &self,
380        phase: &str,
381        message: &str,
382        progress: Option<i64>,
383        total: Option<i64>,
384        data: Option<serde_json::Value>,
385    ) {
386        let mut payload = serde_json::json!({"phase": phase, "message": message});
387        if let Some(p) = progress {
388            payload["progress"] = serde_json::json!(p);
389        }
390        if let Some(t) = total {
391            payload["total"] = serde_json::json!(t);
392        }
393        if let Some(d) = data {
394            payload["data"] = d;
395        }
396        self.notify("progress", payload);
397    }
398
399    /// Send a structured log notification.
400    pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
401        let mut payload = serde_json::json!({"level": level, "message": message});
402        if let Some(f) = fields {
403            payload["fields"] = f;
404        }
405        self.notify("log", payload);
406    }
407
408    /// Send a `session/update` with `call_start` — signals the beginning of
409    /// an LLM call, tool call, or builtin call for observability.
410    pub fn send_call_start(
411        &self,
412        call_id: &str,
413        call_type: &str,
414        name: &str,
415        metadata: serde_json::Value,
416    ) {
417        let session_id = self.get_session_id();
418        let script = self.get_script_name();
419        self.notify(
420            "session/update",
421            serde_json::json!({
422                "sessionId": session_id,
423                "update": {
424                    "sessionUpdate": "call_start",
425                    "content": {
426                        "call_id": call_id,
427                        "call_type": call_type,
428                        "name": name,
429                        "script": script,
430                        "metadata": metadata,
431                    },
432                },
433            }),
434        );
435    }
436
437    /// Send a `session/update` with `call_progress` — a streaming token delta
438    /// from an in-flight LLM call.
439    pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
440        let session_id = self.get_session_id();
441        self.notify(
442            "session/update",
443            serde_json::json!({
444                "sessionId": session_id,
445                "update": {
446                    "sessionUpdate": "call_progress",
447                    "content": {
448                        "call_id": call_id,
449                        "delta": delta,
450                        "accumulated_tokens": accumulated_tokens,
451                    },
452                },
453            }),
454        );
455    }
456
457    /// Send a `session/update` with `call_end` — signals completion of a call.
458    pub fn send_call_end(
459        &self,
460        call_id: &str,
461        call_type: &str,
462        name: &str,
463        duration_ms: u64,
464        status: &str,
465        metadata: serde_json::Value,
466    ) {
467        let session_id = self.get_session_id();
468        let script = self.get_script_name();
469        self.notify(
470            "session/update",
471            serde_json::json!({
472                "sessionId": session_id,
473                "update": {
474                    "sessionUpdate": "call_end",
475                    "content": {
476                        "call_id": call_id,
477                        "call_type": call_type,
478                        "name": name,
479                        "script": script,
480                        "duration_ms": duration_ms,
481                        "status": status,
482                        "metadata": metadata,
483                    },
484                },
485            }),
486        );
487    }
488
489    /// Send a worker lifecycle update for delegated/background execution.
490    pub fn send_worker_update(
491        &self,
492        worker_id: &str,
493        worker_name: &str,
494        status: &str,
495        metadata: serde_json::Value,
496    ) {
497        let session_id = self.get_session_id();
498        let script = self.get_script_name();
499        self.notify(
500            "session/update",
501            serde_json::json!({
502                "sessionId": session_id,
503                "update": {
504                    "sessionUpdate": "worker_update",
505                    "content": {
506                        "worker_id": worker_id,
507                        "worker_name": worker_name,
508                        "status": status,
509                        "script": script,
510                        "metadata": metadata,
511                    },
512                },
513            }),
514        );
515    }
516}
517
518/// Convert a serde_json::Value to a VmValue.
519pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
520    crate::stdlib::json_to_vm_value(val)
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// A request serializes with its version, id and method members.
    #[test]
    fn test_json_rpc_request_format() {
        let params = serde_json::json!({
            "prompt": "Hello",
            "system": "Be helpful",
        });
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": params,
        });
        let wire = serde_json::to_string(&request).unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":1", "\"method\":\"llm_call\""] {
            assert!(wire.contains(fragment));
        }
    }

    /// A notification serializes without an `id` member.
    #[test]
    fn test_json_rpc_notification_format() {
        let wire = serde_json::to_string(&serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        }))
        .unwrap();
        assert!(wire.contains("\"method\":\"output\""));
        assert!(!wire.contains("\"id\""));
    }

    /// The `error` member of an error response is reachable by key.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        let error = response.get("error");
        assert!(error.is_some());
        let message = response["error"]["message"].as_str().unwrap();
        assert_eq!(message, "Invalid request");
    }

    /// The `result` member of a success response is reachable by key.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        let result = response.get("result");
        assert!(result.is_some());
        let text = response["result"]["text"].as_str().unwrap();
        assert_eq!(text, "Hello world");
    }

    /// A shared atomic flag starts cleared and reads back as set after a store.
    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        let before = flag.load(Ordering::SeqCst);
        assert!(!before);
        flag.store(true, Ordering::SeqCst);
        let after = flag.load(Ordering::SeqCst);
        assert!(after);
    }

    /// Each take call returns only the messages of the selected mode.
    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let pending = Arc::new(Mutex::new(HashMap::new()));
            let cancelled = Arc::new(AtomicBool::new(false));
            let stdout_lock = Arc::new(std::sync::Mutex::new(()));
            let bridge = HostBridge::from_parts(pending, cancelled, stdout_lock, 1);

            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let after_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(after_step.len(), 1);
            assert_eq!(after_step[0].content, "first");

            let at_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(at_end.len(), 1);
            assert_eq!(at_end[0].content, "second");
        });
    }

    /// A JSON string converts to a value that displays as the bare text.
    #[test]
    fn test_json_result_to_vm_value_string() {
        let converted = json_result_to_vm_value(&serde_json::json!("hello"));
        assert_eq!(converted.display(), "hello");
    }

    /// A JSON object converts to a dict with the same entries.
    #[test]
    fn test_json_result_to_vm_value_dict() {
        let json = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        let name = entries.get("name").unwrap();
        let count = entries.get("count").unwrap();
        assert_eq!(name.display(), "test");
        assert_eq!(count.display(), "42");
    }

    /// JSON `null` converts to `Nil`.
    #[test]
    fn test_json_result_to_vm_value_null() {
        let converted = json_result_to_vm_value(&serde_json::json!(null));
        assert!(matches!(converted, VmValue::Nil));
    }

    /// Nested objects and arrays convert to nested dicts and lists.
    #[test]
    fn test_json_result_to_vm_value_nested() {
        let json = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(entries.get("text").unwrap().display(), "response");
        let VmValue::List(calls) = entries.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(calls.len(), 1);
    }

    /// The default bridge timeout is five minutes.
    #[test]
    fn test_timeout_duration() {
        let seconds = DEFAULT_TIMEOUT.as_secs();
        assert_eq!(seconds, 300);
    }
}