Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host (e.g., Burin IDE) handles these requests using its own providers.
6
7use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// Upper bound on how long a single bridge call waits for the host (5 minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
22///
23/// The bridge sends requests to the host on stdout and receives responses
24/// on stdin. A background task reads stdin and dispatches responses to
25/// waiting callers by request ID. All stdout writes are serialized through
26/// a mutex to prevent interleaving.
27pub struct HostBridge {
28    next_id: AtomicU64,
29    /// Pending request waiters, keyed by JSON-RPC id.
30    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31    /// Whether the host has sent a cancel notification.
32    cancelled: Arc<AtomicBool>,
33    /// Mutex protecting stdout writes to prevent interleaving.
34    stdout_lock: Arc<std::sync::Mutex<()>>,
35    /// ACP session ID (set in ACP mode for session-scoped notifications).
36    session_id: std::sync::Mutex<String>,
37    /// Name of the currently executing Harn script (without .harn suffix).
38    script_name: std::sync::Mutex<String>,
39    /// User messages injected by the host while a run is active.
40    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
41    /// Host-triggered resume signal for daemon agents.
42    resume_requested: Arc<AtomicBool>,
43}
44
/// Delivery mode attached to a user message queued by the host.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Wire values `"interrupt_immediate"` and `"interrupt"`.
    InterruptImmediate,
    /// Wire values `"finish_step"` and `"after_current_operation"`.
    FinishStep,
    /// Any other wire value, including `"wait_for_completion"`.
    WaitForCompletion,
}
51
/// Point in a run at which queued user messages are collected.
///
/// Each checkpoint selects the messages queued with the matching
/// `QueuedUserMessageMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Collects messages queued as `InterruptImmediate`.
    InterruptImmediate,
    /// Collects messages queued as `FinishStep`.
    AfterCurrentOperation,
    /// Collects messages queued as `WaitForCompletion`.
    EndOfInteraction,
}
58
59impl QueuedUserMessageMode {
60    fn from_str(value: &str) -> Self {
61        match value {
62            "interrupt_immediate" | "interrupt" => Self::InterruptImmediate,
63            "finish_step" | "after_current_operation" => Self::FinishStep,
64            _ => Self::WaitForCompletion,
65        }
66    }
67}
68
69#[derive(Clone, Debug, PartialEq, Eq)]
70pub struct QueuedUserMessage {
71    pub content: String,
72    pub mode: QueuedUserMessageMode,
73}
74
75// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
76#[allow(clippy::new_without_default)]
77impl HostBridge {
78    /// Create a new bridge and spawn the stdin reader task.
79    ///
80    /// Must be called within a tokio LocalSet (uses spawn_local for the
81    /// stdin reader since it's single-threaded).
82    pub fn new() -> Self {
83        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
84            Arc::new(Mutex::new(HashMap::new()));
85        let cancelled = Arc::new(AtomicBool::new(false));
86        let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
87            Arc::new(Mutex::new(VecDeque::new()));
88        let resume_requested = Arc::new(AtomicBool::new(false));
89
90        // Stdin reader: reads JSON-RPC lines and dispatches responses
91        let pending_clone = pending.clone();
92        let cancelled_clone = cancelled.clone();
93        let queued_clone = queued_user_messages.clone();
94        let resume_clone = resume_requested.clone();
95        tokio::task::spawn_local(async move {
96            let stdin = tokio::io::stdin();
97            let reader = tokio::io::BufReader::new(stdin);
98            let mut lines = reader.lines();
99
100            while let Ok(Some(line)) = lines.next_line().await {
101                let line = line.trim().to_string();
102                if line.is_empty() {
103                    continue;
104                }
105
106                let msg: serde_json::Value = match serde_json::from_str(&line) {
107                    Ok(v) => v,
108                    Err(_) => continue, // Skip malformed lines
109                };
110
111                // Check if this is a notification from the host (no id)
112                if msg.get("id").is_none() {
113                    if let Some(method) = msg["method"].as_str() {
114                        if method == "cancel" {
115                            cancelled_clone.store(true, Ordering::SeqCst);
116                        } else if method == "agent/resume" {
117                            resume_clone.store(true, Ordering::SeqCst);
118                        } else if method == "user_message"
119                            || method == "session/input"
120                            || method == "agent/user_message"
121                        {
122                            let params = &msg["params"];
123                            let content = params
124                                .get("content")
125                                .and_then(|v| v.as_str())
126                                .unwrap_or("")
127                                .to_string();
128                            if !content.is_empty() {
129                                let mode = QueuedUserMessageMode::from_str(
130                                    params
131                                        .get("mode")
132                                        .and_then(|v| v.as_str())
133                                        .unwrap_or("wait_for_completion"),
134                                );
135                                queued_clone
136                                    .lock()
137                                    .await
138                                    .push_back(QueuedUserMessage { content, mode });
139                            }
140                        }
141                    }
142                    continue;
143                }
144
145                // This is a response — dispatch to the waiting caller
146                if let Some(id) = msg["id"].as_u64() {
147                    let mut pending = pending_clone.lock().await;
148                    if let Some(sender) = pending.remove(&id) {
149                        let _ = sender.send(msg);
150                    }
151                }
152            }
153
154            // stdin closed — cancel any remaining pending requests by dropping senders
155            let mut pending = pending_clone.lock().await;
156            pending.clear();
157        });
158
159        Self {
160            next_id: AtomicU64::new(1),
161            pending,
162            cancelled,
163            stdout_lock: Arc::new(std::sync::Mutex::new(())),
164            session_id: std::sync::Mutex::new(String::new()),
165            script_name: std::sync::Mutex::new(String::new()),
166            queued_user_messages,
167            resume_requested,
168        }
169    }
170
171    /// Create a bridge from pre-existing shared state.
172    ///
173    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
174    /// responsible for dispatching responses into `pending`.  This is used
175    /// by ACP mode which already has its own stdin reader.
176    pub fn from_parts(
177        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
178        cancelled: Arc<AtomicBool>,
179        stdout_lock: Arc<std::sync::Mutex<()>>,
180        start_id: u64,
181    ) -> Self {
182        Self {
183            next_id: AtomicU64::new(start_id),
184            pending,
185            cancelled,
186            stdout_lock,
187            session_id: std::sync::Mutex::new(String::new()),
188            script_name: std::sync::Mutex::new(String::new()),
189            queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
190            resume_requested: Arc::new(AtomicBool::new(false)),
191        }
192    }
193
194    /// Set the ACP session ID for session-scoped notifications.
195    pub fn set_session_id(&self, id: &str) {
196        *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
197    }
198
199    /// Set the currently executing script name (without .harn suffix).
200    pub fn set_script_name(&self, name: &str) {
201        *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
202    }
203
204    /// Get the current script name.
205    fn get_script_name(&self) -> String {
206        self.script_name
207            .lock()
208            .unwrap_or_else(|e| e.into_inner())
209            .clone()
210    }
211
212    /// Get the session ID.
213    fn get_session_id(&self) -> String {
214        self.session_id
215            .lock()
216            .unwrap_or_else(|e| e.into_inner())
217            .clone()
218    }
219
220    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
221    fn write_line(&self, line: &str) -> Result<(), VmError> {
222        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
223        let mut stdout = std::io::stdout().lock();
224        stdout
225            .write_all(line.as_bytes())
226            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
227        stdout
228            .write_all(b"\n")
229            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
230        stdout
231            .flush()
232            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
233        Ok(())
234    }
235
236    /// Send a JSON-RPC request to the host and wait for the response.
237    /// Times out after 5 minutes to prevent deadlocks.
238    pub async fn call(
239        &self,
240        method: &str,
241        params: serde_json::Value,
242    ) -> Result<serde_json::Value, VmError> {
243        if self.is_cancelled() {
244            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
245        }
246
247        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
248
249        let request = serde_json::json!({
250            "jsonrpc": "2.0",
251            "id": id,
252            "method": method,
253            "params": params,
254        });
255
256        // Register a oneshot channel to receive the response
257        let (tx, rx) = oneshot::channel();
258        {
259            let mut pending = self.pending.lock().await;
260            pending.insert(id, tx);
261        }
262
263        // Send the request (serialized through stdout mutex)
264        let line = serde_json::to_string(&request)
265            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
266        if let Err(e) = self.write_line(&line) {
267            // Clean up pending entry on write failure
268            let mut pending = self.pending.lock().await;
269            pending.remove(&id);
270            return Err(e);
271        }
272
273        // Wait for the response with timeout
274        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
275            Ok(Ok(msg)) => msg,
276            Ok(Err(_)) => {
277                // Sender dropped — host closed or stdin reader exited
278                return Err(VmError::Runtime(
279                    "Bridge: host closed connection before responding".into(),
280                ));
281            }
282            Err(_) => {
283                // Timeout — clean up pending entry
284                let mut pending = self.pending.lock().await;
285                pending.remove(&id);
286                return Err(VmError::Runtime(format!(
287                    "Bridge: host did not respond to '{method}' within {}s",
288                    DEFAULT_TIMEOUT.as_secs()
289                )));
290            }
291        };
292
293        // Check for JSON-RPC error
294        if let Some(error) = response.get("error") {
295            let message = error["message"].as_str().unwrap_or("Unknown host error");
296            let code = error["code"].as_i64().unwrap_or(-1);
297            // -32001: tool rejected by host (not permitted / not in allowlist)
298            if code == -32001 {
299                return Err(VmError::CategorizedError {
300                    message: message.to_string(),
301                    category: ErrorCategory::ToolRejected,
302                });
303            }
304            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
305        }
306
307        Ok(response["result"].clone())
308    }
309
310    /// Send a JSON-RPC notification to the host (no response expected).
311    /// Serialized through the stdout mutex to prevent interleaving.
312    pub fn notify(&self, method: &str, params: serde_json::Value) {
313        let notification = serde_json::json!({
314            "jsonrpc": "2.0",
315            "method": method,
316            "params": params,
317        });
318        if let Ok(line) = serde_json::to_string(&notification) {
319            let _ = self.write_line(&line);
320        }
321    }
322
323    /// Check if the host has sent a cancel notification.
324    pub fn is_cancelled(&self) -> bool {
325        self.cancelled.load(Ordering::SeqCst)
326    }
327
328    pub fn take_resume_signal(&self) -> bool {
329        self.resume_requested.swap(false, Ordering::SeqCst)
330    }
331
332    pub fn signal_resume(&self) {
333        self.resume_requested.store(true, Ordering::SeqCst);
334    }
335
336    pub async fn push_queued_user_message(&self, content: String, mode: &str) {
337        self.queued_user_messages
338            .lock()
339            .await
340            .push_back(QueuedUserMessage {
341                content,
342                mode: QueuedUserMessageMode::from_str(mode),
343            });
344    }
345
346    pub async fn take_queued_user_messages(
347        &self,
348        include_interrupt_immediate: bool,
349        include_finish_step: bool,
350        include_wait_for_completion: bool,
351    ) -> Vec<QueuedUserMessage> {
352        let mut queue = self.queued_user_messages.lock().await;
353        let mut selected = Vec::new();
354        let mut retained = VecDeque::new();
355        while let Some(message) = queue.pop_front() {
356            let should_take = match message.mode {
357                QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
358                QueuedUserMessageMode::FinishStep => include_finish_step,
359                QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
360            };
361            if should_take {
362                selected.push(message);
363            } else {
364                retained.push_back(message);
365            }
366        }
367        *queue = retained;
368        selected
369    }
370
371    pub async fn take_queued_user_messages_for(
372        &self,
373        checkpoint: DeliveryCheckpoint,
374    ) -> Vec<QueuedUserMessage> {
375        match checkpoint {
376            DeliveryCheckpoint::InterruptImmediate => {
377                self.take_queued_user_messages(true, false, false).await
378            }
379            DeliveryCheckpoint::AfterCurrentOperation => {
380                self.take_queued_user_messages(false, true, false).await
381            }
382            DeliveryCheckpoint::EndOfInteraction => {
383                self.take_queued_user_messages(false, false, true).await
384            }
385        }
386    }
387
388    /// Send an output notification (for log/print in bridge mode).
389    pub fn send_output(&self, text: &str) {
390        self.notify("output", serde_json::json!({"text": text}));
391    }
392
393    /// Send a progress notification with optional numeric progress and structured data.
394    pub fn send_progress(
395        &self,
396        phase: &str,
397        message: &str,
398        progress: Option<i64>,
399        total: Option<i64>,
400        data: Option<serde_json::Value>,
401    ) {
402        let mut payload = serde_json::json!({"phase": phase, "message": message});
403        if let Some(p) = progress {
404            payload["progress"] = serde_json::json!(p);
405        }
406        if let Some(t) = total {
407            payload["total"] = serde_json::json!(t);
408        }
409        if let Some(d) = data {
410            payload["data"] = d;
411        }
412        self.notify("progress", payload);
413    }
414
415    /// Send a structured log notification.
416    pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
417        let mut payload = serde_json::json!({"level": level, "message": message});
418        if let Some(f) = fields {
419            payload["fields"] = f;
420        }
421        self.notify("log", payload);
422    }
423
424    /// Send a `session/update` with `call_start` — signals the beginning of
425    /// an LLM call, tool call, or builtin call for observability.
426    pub fn send_call_start(
427        &self,
428        call_id: &str,
429        call_type: &str,
430        name: &str,
431        metadata: serde_json::Value,
432    ) {
433        let session_id = self.get_session_id();
434        let script = self.get_script_name();
435        self.notify(
436            "session/update",
437            serde_json::json!({
438                "sessionId": session_id,
439                "update": {
440                    "sessionUpdate": "call_start",
441                    "content": {
442                        "call_id": call_id,
443                        "call_type": call_type,
444                        "name": name,
445                        "script": script,
446                        "metadata": metadata,
447                    },
448                },
449            }),
450        );
451    }
452
453    /// Send a `session/update` with `call_progress` — a streaming token delta
454    /// from an in-flight LLM call.
455    pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
456        let session_id = self.get_session_id();
457        self.notify(
458            "session/update",
459            serde_json::json!({
460                "sessionId": session_id,
461                "update": {
462                    "sessionUpdate": "call_progress",
463                    "content": {
464                        "call_id": call_id,
465                        "delta": delta,
466                        "accumulated_tokens": accumulated_tokens,
467                    },
468                },
469            }),
470        );
471    }
472
473    /// Send a `session/update` with `call_end` — signals completion of a call.
474    pub fn send_call_end(
475        &self,
476        call_id: &str,
477        call_type: &str,
478        name: &str,
479        duration_ms: u64,
480        status: &str,
481        metadata: serde_json::Value,
482    ) {
483        let session_id = self.get_session_id();
484        let script = self.get_script_name();
485        self.notify(
486            "session/update",
487            serde_json::json!({
488                "sessionId": session_id,
489                "update": {
490                    "sessionUpdate": "call_end",
491                    "content": {
492                        "call_id": call_id,
493                        "call_type": call_type,
494                        "name": name,
495                        "script": script,
496                        "duration_ms": duration_ms,
497                        "status": status,
498                        "metadata": metadata,
499                    },
500                },
501            }),
502        );
503    }
504
505    /// Send a worker lifecycle update for delegated/background execution.
506    pub fn send_worker_update(
507        &self,
508        worker_id: &str,
509        worker_name: &str,
510        status: &str,
511        metadata: serde_json::Value,
512    ) {
513        let session_id = self.get_session_id();
514        let script = self.get_script_name();
515        let started_at = metadata.get("started_at").cloned().unwrap_or_default();
516        let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
517        let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
518        let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
519        let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
520        self.notify(
521            "session/update",
522            serde_json::json!({
523                "sessionId": session_id,
524                "update": {
525                    "sessionUpdate": "worker_update",
526                    "content": {
527                        "worker_id": worker_id,
528                        "worker_name": worker_name,
529                        "status": status,
530                        "script": script,
531                        "started_at": started_at,
532                        "finished_at": finished_at,
533                        "snapshot_path": snapshot_path,
534                        "run_id": run_id,
535                        "run_path": run_path,
536                        "metadata": metadata,
537                    },
538                },
539            }),
540        );
541    }
542}
543
544/// Convert a serde_json::Value to a VmValue.
545pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
546    crate::stdlib::json_to_vm_value(val)
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// A request serializes with the protocol version, id, and method.
    #[test]
    fn test_json_rpc_request_format() {
        let payload = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let encoded = serde_json::to_string(&payload).unwrap();
        for fragment in [
            "\"jsonrpc\":\"2.0\"",
            "\"id\":1",
            "\"method\":\"llm_call\"",
        ] {
            assert!(encoded.contains(fragment));
        }
    }

    /// A notification serializes without an `id` member.
    #[test]
    fn test_json_rpc_notification_format() {
        let payload = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let encoded = serde_json::to_string(&payload).unwrap();
        assert!(encoded.contains("\"method\":\"output\""));
        assert!(!encoded.contains("\"id\""));
    }

    /// The `error` member and its message can be read from an error response.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let reply = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(reply.get("error").is_some());
        let message = reply["error"]["message"].as_str().unwrap();
        assert_eq!(message, "Invalid request");
    }

    /// The `result` member and its fields can be read from a success response.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let reply = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(reply.get("result").is_some());
        let text = reply["result"]["text"].as_str().unwrap();
        assert_eq!(text, "Hello world");
    }

    /// A shared atomic flag starts cleared and reads back as set after a store.
    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        flag.store(true, Ordering::SeqCst);
        assert!(flag.load(Ordering::SeqCst));
    }

    /// Taking messages for one delivery mode leaves the other modes queued.
    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let pending = Arc::new(Mutex::new(HashMap::new()));
            let cancelled = Arc::new(AtomicBool::new(false));
            let stdout_lock = Arc::new(std::sync::Mutex::new(()));
            let bridge = HostBridge::from_parts(pending, cancelled, stdout_lock, 1);

            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let after_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(after_step.len(), 1);
            assert_eq!(after_step[0].content, "first");

            let at_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(at_end.len(), 1);
            assert_eq!(at_end[0].content, "second");
        });
    }

    /// A JSON string converts to a value that displays as the bare text.
    #[test]
    fn test_json_result_to_vm_value_string() {
        let json = serde_json::json!("hello");
        let converted = json_result_to_vm_value(&json);
        assert_eq!(converted.display(), "hello");
    }

    /// A JSON object converts to a dict with each member converted.
    #[test]
    fn test_json_result_to_vm_value_dict() {
        let json = serde_json::json!({"name": "test", "count": 42});
        let converted = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &converted else {
            unreachable!("Expected Dict, got {:?}", converted);
        };
        assert_eq!(entries.get("name").unwrap().display(), "test");
        assert_eq!(entries.get("count").unwrap().display(), "42");
    }

    /// JSON `null` converts to `VmValue::Nil`.
    #[test]
    fn test_json_result_to_vm_value_null() {
        let json = serde_json::json!(null);
        let converted = json_result_to_vm_value(&json);
        assert!(matches!(converted, VmValue::Nil));
    }

    /// Nested arrays and objects convert recursively.
    #[test]
    fn test_json_result_to_vm_value_nested() {
        let json = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let converted = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &converted else {
            unreachable!("Expected Dict, got {:?}", converted);
        };
        assert_eq!(entries.get("text").unwrap().display(), "response");
        let VmValue::List(calls) = entries.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(calls.len(), 1);
    }

    /// The default bridge timeout is five minutes.
    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}