Skip to main content

harn_vm/
bridge.rs

1//! JSON-RPC 2.0 bridge for host communication.
2//!
3//! When `harn run --bridge` is used, the VM delegates builtins (llm_call,
4//! file I/O, tool execution) to a host process over stdin/stdout JSON-RPC.
5//! The host application handles these requests using its own providers.
6
7use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18
/// Upper bound on how long a bridge call waits for the host's response
/// (5 minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
21
22/// A JSON-RPC 2.0 bridge to a host process over stdin/stdout.
23///
24/// The bridge sends requests to the host on stdout and receives responses
25/// on stdin. A background task reads stdin and dispatches responses to
26/// waiting callers by request ID. All stdout writes are serialized through
27/// a mutex to prevent interleaving.
28pub struct HostBridge {
29    next_id: AtomicU64,
30    /// Pending request waiters, keyed by JSON-RPC id.
31    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
32    /// Whether the host has sent a cancel notification.
33    cancelled: Arc<AtomicBool>,
34    /// Mutex protecting stdout writes to prevent interleaving.
35    stdout_lock: Arc<std::sync::Mutex<()>>,
36    /// ACP session ID (set in ACP mode for session-scoped notifications).
37    session_id: std::sync::Mutex<String>,
38    /// Name of the currently executing Harn script (without .harn suffix).
39    script_name: std::sync::Mutex<String>,
40    /// User messages injected by the host while a run is active.
41    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
42    /// Host-triggered resume signal for daemon agents.
43    resume_requested: Arc<AtomicBool>,
44}
45
/// Delivery mode attached to a user message queued by the host.
///
/// Each variant is collected at the matching [`DeliveryCheckpoint`]:
/// `InterruptImmediate` at `InterruptImmediate`, `FinishStep` at
/// `AfterCurrentOperation`, and `WaitForCompletion` at `EndOfInteraction`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    InterruptImmediate,
    FinishStep,
    WaitForCompletion,
}

/// Point in a run at which queued user messages are collected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    InterruptImmediate,
    AfterCurrentOperation,
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Map a wire-format mode name to a variant.
    ///
    /// `"interrupt_immediate"` / `"interrupt"` select `InterruptImmediate`,
    /// `"finish_step"` / `"after_current_operation"` select `FinishStep`,
    /// and every other string (matching is exact and case-sensitive) falls
    /// back to `WaitForCompletion`.
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            return Self::InterruptImmediate;
        }
        if matches!(value, "finish_step" | "after_current_operation") {
            return Self::FinishStep;
        }
        Self::WaitForCompletion
    }
}
69
70#[derive(Clone, Debug, PartialEq, Eq)]
71pub struct QueuedUserMessage {
72    pub content: String,
73    pub mode: QueuedUserMessageMode,
74}
75
76// Default doesn't apply — new() spawns async tasks requiring a tokio LocalSet.
77#[allow(clippy::new_without_default)]
78impl HostBridge {
79    /// Create a new bridge and spawn the stdin reader task.
80    ///
81    /// Must be called within a tokio LocalSet (uses spawn_local for the
82    /// stdin reader since it's single-threaded).
83    pub fn new() -> Self {
84        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
85            Arc::new(Mutex::new(HashMap::new()));
86        let cancelled = Arc::new(AtomicBool::new(false));
87        let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
88            Arc::new(Mutex::new(VecDeque::new()));
89        let resume_requested = Arc::new(AtomicBool::new(false));
90
91        // Stdin reader: reads JSON-RPC lines and dispatches responses
92        let pending_clone = pending.clone();
93        let cancelled_clone = cancelled.clone();
94        let queued_clone = queued_user_messages.clone();
95        let resume_clone = resume_requested.clone();
96        tokio::task::spawn_local(async move {
97            let stdin = tokio::io::stdin();
98            let reader = tokio::io::BufReader::new(stdin);
99            let mut lines = reader.lines();
100
101            while let Ok(Some(line)) = lines.next_line().await {
102                let line = line.trim().to_string();
103                if line.is_empty() {
104                    continue;
105                }
106
107                let msg: serde_json::Value = match serde_json::from_str(&line) {
108                    Ok(v) => v,
109                    Err(_) => continue, // Skip malformed lines
110                };
111
112                // Check if this is a notification from the host (no id)
113                if msg.get("id").is_none() {
114                    if let Some(method) = msg["method"].as_str() {
115                        if method == "cancel" {
116                            cancelled_clone.store(true, Ordering::SeqCst);
117                        } else if method == "agent/resume" {
118                            resume_clone.store(true, Ordering::SeqCst);
119                        } else if method == "user_message"
120                            || method == "session/input"
121                            || method == "agent/user_message"
122                        {
123                            let params = &msg["params"];
124                            let content = params
125                                .get("content")
126                                .and_then(|v| v.as_str())
127                                .unwrap_or("")
128                                .to_string();
129                            if !content.is_empty() {
130                                let mode = QueuedUserMessageMode::from_str(
131                                    params
132                                        .get("mode")
133                                        .and_then(|v| v.as_str())
134                                        .unwrap_or("wait_for_completion"),
135                                );
136                                queued_clone
137                                    .lock()
138                                    .await
139                                    .push_back(QueuedUserMessage { content, mode });
140                            }
141                        }
142                    }
143                    continue;
144                }
145
146                // This is a response — dispatch to the waiting caller
147                if let Some(id) = msg["id"].as_u64() {
148                    let mut pending = pending_clone.lock().await;
149                    if let Some(sender) = pending.remove(&id) {
150                        let _ = sender.send(msg);
151                    }
152                }
153            }
154
155            // stdin closed — cancel any remaining pending requests by dropping senders
156            let mut pending = pending_clone.lock().await;
157            pending.clear();
158        });
159
160        Self {
161            next_id: AtomicU64::new(1),
162            pending,
163            cancelled,
164            stdout_lock: Arc::new(std::sync::Mutex::new(())),
165            session_id: std::sync::Mutex::new(String::new()),
166            script_name: std::sync::Mutex::new(String::new()),
167            queued_user_messages,
168            resume_requested,
169        }
170    }
171
172    /// Create a bridge from pre-existing shared state.
173    ///
174    /// Unlike `new()`, does **not** spawn a stdin reader — the caller is
175    /// responsible for dispatching responses into `pending`.  This is used
176    /// by ACP mode which already has its own stdin reader.
177    pub fn from_parts(
178        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
179        cancelled: Arc<AtomicBool>,
180        stdout_lock: Arc<std::sync::Mutex<()>>,
181        start_id: u64,
182    ) -> Self {
183        Self {
184            next_id: AtomicU64::new(start_id),
185            pending,
186            cancelled,
187            stdout_lock,
188            session_id: std::sync::Mutex::new(String::new()),
189            script_name: std::sync::Mutex::new(String::new()),
190            queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
191            resume_requested: Arc::new(AtomicBool::new(false)),
192        }
193    }
194
195    /// Set the ACP session ID for session-scoped notifications.
196    pub fn set_session_id(&self, id: &str) {
197        *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
198    }
199
200    /// Set the currently executing script name (without .harn suffix).
201    pub fn set_script_name(&self, name: &str) {
202        *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
203    }
204
205    /// Get the current script name.
206    fn get_script_name(&self) -> String {
207        self.script_name
208            .lock()
209            .unwrap_or_else(|e| e.into_inner())
210            .clone()
211    }
212
213    /// Get the session ID.
214    fn get_session_id(&self) -> String {
215        self.session_id
216            .lock()
217            .unwrap_or_else(|e| e.into_inner())
218            .clone()
219    }
220
221    /// Write a complete JSON-RPC line to stdout, serialized through a mutex.
222    fn write_line(&self, line: &str) -> Result<(), VmError> {
223        let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
224        let mut stdout = std::io::stdout().lock();
225        stdout
226            .write_all(line.as_bytes())
227            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
228        stdout
229            .write_all(b"\n")
230            .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
231        stdout
232            .flush()
233            .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
234        Ok(())
235    }
236
237    /// Send a JSON-RPC request to the host and wait for the response.
238    /// Times out after 5 minutes to prevent deadlocks.
239    pub async fn call(
240        &self,
241        method: &str,
242        params: serde_json::Value,
243    ) -> Result<serde_json::Value, VmError> {
244        if self.is_cancelled() {
245            return Err(VmError::Runtime("Bridge: operation cancelled".into()));
246        }
247
248        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
249
250        let request = serde_json::json!({
251            "jsonrpc": "2.0",
252            "id": id,
253            "method": method,
254            "params": params,
255        });
256
257        // Register a oneshot channel to receive the response
258        let (tx, rx) = oneshot::channel();
259        {
260            let mut pending = self.pending.lock().await;
261            pending.insert(id, tx);
262        }
263
264        // Send the request (serialized through stdout mutex)
265        let line = serde_json::to_string(&request)
266            .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
267        if let Err(e) = self.write_line(&line) {
268            // Clean up pending entry on write failure
269            let mut pending = self.pending.lock().await;
270            pending.remove(&id);
271            return Err(e);
272        }
273
274        // Wait for the response with timeout
275        let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
276            Ok(Ok(msg)) => msg,
277            Ok(Err(_)) => {
278                // Sender dropped — host closed or stdin reader exited
279                return Err(VmError::Runtime(
280                    "Bridge: host closed connection before responding".into(),
281                ));
282            }
283            Err(_) => {
284                // Timeout — clean up pending entry
285                let mut pending = self.pending.lock().await;
286                pending.remove(&id);
287                return Err(VmError::Runtime(format!(
288                    "Bridge: host did not respond to '{method}' within {}s",
289                    DEFAULT_TIMEOUT.as_secs()
290                )));
291            }
292        };
293
294        // Check for JSON-RPC error
295        if let Some(error) = response.get("error") {
296            let message = error["message"].as_str().unwrap_or("Unknown host error");
297            let code = error["code"].as_i64().unwrap_or(-1);
298            // -32001: tool rejected by host (not permitted / not in allowlist)
299            if code == -32001 {
300                return Err(VmError::CategorizedError {
301                    message: message.to_string(),
302                    category: ErrorCategory::ToolRejected,
303                });
304            }
305            return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
306        }
307
308        Ok(response["result"].clone())
309    }
310
311    /// Send a JSON-RPC notification to the host (no response expected).
312    /// Serialized through the stdout mutex to prevent interleaving.
313    pub fn notify(&self, method: &str, params: serde_json::Value) {
314        let notification = serde_json::json!({
315            "jsonrpc": "2.0",
316            "method": method,
317            "params": params,
318        });
319        if let Ok(line) = serde_json::to_string(&notification) {
320            let _ = self.write_line(&line);
321        }
322    }
323
324    /// Check if the host has sent a cancel notification.
325    pub fn is_cancelled(&self) -> bool {
326        self.cancelled.load(Ordering::SeqCst)
327    }
328
329    pub fn take_resume_signal(&self) -> bool {
330        self.resume_requested.swap(false, Ordering::SeqCst)
331    }
332
333    pub fn signal_resume(&self) {
334        self.resume_requested.store(true, Ordering::SeqCst);
335    }
336
337    pub async fn push_queued_user_message(&self, content: String, mode: &str) {
338        self.queued_user_messages
339            .lock()
340            .await
341            .push_back(QueuedUserMessage {
342                content,
343                mode: QueuedUserMessageMode::from_str(mode),
344            });
345    }
346
347    pub async fn take_queued_user_messages(
348        &self,
349        include_interrupt_immediate: bool,
350        include_finish_step: bool,
351        include_wait_for_completion: bool,
352    ) -> Vec<QueuedUserMessage> {
353        let mut queue = self.queued_user_messages.lock().await;
354        let mut selected = Vec::new();
355        let mut retained = VecDeque::new();
356        while let Some(message) = queue.pop_front() {
357            let should_take = match message.mode {
358                QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
359                QueuedUserMessageMode::FinishStep => include_finish_step,
360                QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
361            };
362            if should_take {
363                selected.push(message);
364            } else {
365                retained.push_back(message);
366            }
367        }
368        *queue = retained;
369        selected
370    }
371
372    pub async fn take_queued_user_messages_for(
373        &self,
374        checkpoint: DeliveryCheckpoint,
375    ) -> Vec<QueuedUserMessage> {
376        match checkpoint {
377            DeliveryCheckpoint::InterruptImmediate => {
378                self.take_queued_user_messages(true, false, false).await
379            }
380            DeliveryCheckpoint::AfterCurrentOperation => {
381                self.take_queued_user_messages(false, true, false).await
382            }
383            DeliveryCheckpoint::EndOfInteraction => {
384                self.take_queued_user_messages(false, false, true).await
385            }
386        }
387    }
388
389    /// Send an output notification (for log/print in bridge mode).
390    pub fn send_output(&self, text: &str) {
391        self.notify("output", serde_json::json!({"text": text}));
392    }
393
394    /// Send a progress notification with optional numeric progress and structured data.
395    pub fn send_progress(
396        &self,
397        phase: &str,
398        message: &str,
399        progress: Option<i64>,
400        total: Option<i64>,
401        data: Option<serde_json::Value>,
402    ) {
403        let mut payload = serde_json::json!({"phase": phase, "message": message});
404        if let Some(p) = progress {
405            payload["progress"] = serde_json::json!(p);
406        }
407        if let Some(t) = total {
408            payload["total"] = serde_json::json!(t);
409        }
410        if let Some(d) = data {
411            payload["data"] = d;
412        }
413        self.notify("progress", payload);
414    }
415
416    /// Send a structured log notification.
417    pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
418        let mut payload = serde_json::json!({"level": level, "message": message});
419        if let Some(f) = fields {
420            payload["fields"] = f;
421        }
422        self.notify("log", payload);
423    }
424
425    /// Send a `session/update` with `call_start` — signals the beginning of
426    /// an LLM call, tool call, or builtin call for observability.
427    pub fn send_call_start(
428        &self,
429        call_id: &str,
430        call_type: &str,
431        name: &str,
432        metadata: serde_json::Value,
433    ) {
434        let session_id = self.get_session_id();
435        let script = self.get_script_name();
436        self.notify(
437            "session/update",
438            serde_json::json!({
439                "sessionId": session_id,
440                "update": {
441                    "sessionUpdate": "call_start",
442                    "content": {
443                        "call_id": call_id,
444                        "call_type": call_type,
445                        "name": name,
446                        "script": script,
447                        "metadata": metadata,
448                    },
449                },
450            }),
451        );
452    }
453
454    /// Send a `session/update` with `call_progress` — a streaming token delta
455    /// from an in-flight LLM call.
456    pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
457        let session_id = self.get_session_id();
458        self.notify(
459            "session/update",
460            serde_json::json!({
461                "sessionId": session_id,
462                "update": {
463                    "sessionUpdate": "call_progress",
464                    "content": {
465                        "call_id": call_id,
466                        "delta": delta,
467                        "accumulated_tokens": accumulated_tokens,
468                    },
469                },
470            }),
471        );
472    }
473
474    /// Send a `session/update` with `call_end` — signals completion of a call.
475    pub fn send_call_end(
476        &self,
477        call_id: &str,
478        call_type: &str,
479        name: &str,
480        duration_ms: u64,
481        status: &str,
482        metadata: serde_json::Value,
483    ) {
484        let session_id = self.get_session_id();
485        let script = self.get_script_name();
486        self.notify(
487            "session/update",
488            serde_json::json!({
489                "sessionId": session_id,
490                "update": {
491                    "sessionUpdate": "call_end",
492                    "content": {
493                        "call_id": call_id,
494                        "call_type": call_type,
495                        "name": name,
496                        "script": script,
497                        "duration_ms": duration_ms,
498                        "status": status,
499                        "metadata": metadata,
500                    },
501                },
502            }),
503        );
504    }
505
506    /// Send a worker lifecycle update for delegated/background execution.
507    pub fn send_worker_update(
508        &self,
509        worker_id: &str,
510        worker_name: &str,
511        status: &str,
512        metadata: serde_json::Value,
513        audit: Option<&MutationSessionRecord>,
514    ) {
515        let session_id = self.get_session_id();
516        let script = self.get_script_name();
517        let started_at = metadata.get("started_at").cloned().unwrap_or_default();
518        let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
519        let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
520        let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
521        let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
522        let lifecycle = serde_json::json!({
523            "event": status,
524            "worker_id": worker_id,
525            "worker_name": worker_name,
526            "started_at": started_at,
527            "finished_at": finished_at,
528        });
529        self.notify(
530            "session/update",
531            serde_json::json!({
532                "sessionId": session_id,
533                "update": {
534                    "sessionUpdate": "worker_update",
535                    "content": {
536                        "worker_id": worker_id,
537                        "worker_name": worker_name,
538                        "status": status,
539                        "script": script,
540                        "started_at": started_at,
541                        "finished_at": finished_at,
542                        "snapshot_path": snapshot_path,
543                        "run_id": run_id,
544                        "run_path": run_path,
545                        "lifecycle": lifecycle,
546                        "audit": audit,
547                        "metadata": metadata,
548                    },
549                },
550            }),
551        );
552    }
553}
554
555/// Convert a serde_json::Value to a VmValue.
556pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
557    crate::stdlib::json_to_vm_value(val)
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// A request object serializes with its `jsonrpc`, `id` and `method` members.
    #[test]
    fn test_json_rpc_request_format() {
        let params = serde_json::json!({
            "prompt": "Hello",
            "system": "Be helpful",
        });
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": params,
        });
        let encoded = serde_json::to_string(&request).unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"id\":1", "\"method\":\"llm_call\""] {
            assert!(encoded.contains(fragment));
        }
    }

    /// A notification serializes with a method and without an `id`.
    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let encoded = serde_json::to_string(&notification).unwrap();
        assert!(encoded.contains("\"method\":\"output\""));
        assert!(!encoded.contains("\"id\""));
    }

    /// The `error` member and its message are reachable on an error response.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        let message = response["error"]["message"].as_str().unwrap();
        assert_eq!(message, "Invalid request");
    }

    /// The `result` member and its fields are reachable on a success response.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        let text = response["result"]["text"].as_str().unwrap();
        assert_eq!(text, "Hello world");
    }

    /// A shared atomic flag starts false and reads true after being stored.
    #[test]
    fn test_cancelled_flag() {
        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        flag.store(true, Ordering::SeqCst);
        assert!(flag.load(Ordering::SeqCst));
    }

    /// Taking messages by mode returns only the matching ones and leaves the rest queued.
    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let pending = Arc::new(Mutex::new(HashMap::new()));
            let cancelled = Arc::new(AtomicBool::new(false));
            let stdout_lock = Arc::new(std::sync::Mutex::new(()));
            let bridge = HostBridge::from_parts(pending, cancelled, stdout_lock, 1);

            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let after_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(after_step.len(), 1);
            assert_eq!(after_step[0].content, "first");

            let at_turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(at_turn_end.len(), 1);
            assert_eq!(at_turn_end[0].content, "second");
        });
    }

    /// A JSON string converts to a value that displays as the bare string.
    #[test]
    fn test_json_result_to_vm_value_string() {
        let json = serde_json::json!("hello");
        let converted = json_result_to_vm_value(&json);
        assert_eq!(converted.display(), "hello");
    }

    /// A JSON object converts to a dict whose entries keep their values.
    #[test]
    fn test_json_result_to_vm_value_dict() {
        let json = serde_json::json!({"name": "test", "count": 42});
        let converted = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &converted else {
            unreachable!("Expected Dict, got {:?}", converted);
        };
        assert_eq!(entries.get("name").unwrap().display(), "test");
        assert_eq!(entries.get("count").unwrap().display(), "42");
    }

    /// JSON null converts to `VmValue::Nil`.
    #[test]
    fn test_json_result_to_vm_value_null() {
        let json = serde_json::json!(null);
        let converted = json_result_to_vm_value(&json);
        assert!(matches!(converted, VmValue::Nil));
    }

    /// Nested arrays and objects survive conversion.
    #[test]
    fn test_json_result_to_vm_value_nested() {
        let json = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let converted = json_result_to_vm_value(&json);
        let VmValue::Dict(entries) = &converted else {
            unreachable!("Expected Dict, got {:?}", converted);
        };
        assert_eq!(entries.get("text").unwrap().display(), "response");
        let VmValue::List(tool_calls) = entries.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(tool_calls.len(), 1);
    }

    /// The bridge timeout is five minutes.
    #[test]
    fn test_timeout_duration() {
        let seconds = DEFAULT_TIMEOUT.as_secs();
        assert_eq!(seconds, 300);
    }
}