1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18use crate::visible_text::VisibleTextState;
19
/// Longest time `HostBridge::call` waits for the host's response (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
22
23pub struct HostBridge {
30 next_id: AtomicU64,
31 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
33 cancelled: Arc<AtomicBool>,
35 stdout_lock: Arc<std::sync::Mutex<()>>,
37 session_id: std::sync::Mutex<String>,
39 script_name: std::sync::Mutex<String>,
41 queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
43 resume_requested: Arc<AtomicBool>,
45 skills_reload_requested: Arc<AtomicBool>,
50 visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
52 visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
54 #[cfg(test)]
55 recorded_notifications: Arc<std::sync::Mutex<Vec<serde_json::Value>>>,
56}
57
/// When a queued user message should be handed to the running agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Deliver at the next interrupt checkpoint.
    InterruptImmediate,
    /// Deliver once the current step/operation finishes.
    FinishStep,
    /// Deliver only at the end of the interaction.
    WaitForCompletion,
}
64
/// A point in the agent loop at which queued user messages may be delivered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Takes messages queued with `QueuedUserMessageMode::InterruptImmediate`.
    InterruptImmediate,
    /// Takes messages queued with `QueuedUserMessageMode::FinishStep`.
    AfterCurrentOperation,
    /// Takes messages queued with `QueuedUserMessageMode::WaitForCompletion`.
    EndOfInteraction,
}
71
72impl QueuedUserMessageMode {
73 fn from_str(value: &str) -> Self {
74 match value {
75 "interrupt_immediate" | "interrupt" => Self::InterruptImmediate,
76 "finish_step" | "after_current_operation" => Self::FinishStep,
77 _ => Self::WaitForCompletion,
78 }
79 }
80}
81
82#[derive(Clone, Debug, PartialEq, Eq)]
83pub struct QueuedUserMessage {
84 pub content: String,
85 pub mode: QueuedUserMessageMode,
86}
87
88#[allow(clippy::new_without_default)]
90impl HostBridge {
91 pub fn new() -> Self {
96 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
97 Arc::new(Mutex::new(HashMap::new()));
98 let cancelled = Arc::new(AtomicBool::new(false));
99 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
100 Arc::new(Mutex::new(VecDeque::new()));
101 let resume_requested = Arc::new(AtomicBool::new(false));
102 let skills_reload_requested = Arc::new(AtomicBool::new(false));
103
104 let pending_clone = pending.clone();
106 let cancelled_clone = cancelled.clone();
107 let queued_clone = queued_user_messages.clone();
108 let resume_clone = resume_requested.clone();
109 let skills_reload_clone = skills_reload_requested.clone();
110 tokio::task::spawn_local(async move {
111 let stdin = tokio::io::stdin();
112 let reader = tokio::io::BufReader::new(stdin);
113 let mut lines = reader.lines();
114
115 while let Ok(Some(line)) = lines.next_line().await {
116 let line = line.trim().to_string();
117 if line.is_empty() {
118 continue;
119 }
120
121 let msg: serde_json::Value = match serde_json::from_str(&line) {
122 Ok(v) => v,
123 Err(_) => continue,
124 };
125
126 if msg.get("id").is_none() {
128 if let Some(method) = msg["method"].as_str() {
129 if method == "cancel" {
130 cancelled_clone.store(true, Ordering::SeqCst);
131 } else if method == "agent/resume" {
132 resume_clone.store(true, Ordering::SeqCst);
133 } else if method == "skills/update" {
134 skills_reload_clone.store(true, Ordering::SeqCst);
135 } else if method == "user_message"
136 || method == "session/input"
137 || method == "agent/user_message"
138 {
139 let params = &msg["params"];
140 let content = params
141 .get("content")
142 .and_then(|v| v.as_str())
143 .unwrap_or("")
144 .to_string();
145 if !content.is_empty() {
146 let mode = QueuedUserMessageMode::from_str(
147 params
148 .get("mode")
149 .and_then(|v| v.as_str())
150 .unwrap_or("wait_for_completion"),
151 );
152 queued_clone
153 .lock()
154 .await
155 .push_back(QueuedUserMessage { content, mode });
156 }
157 }
158 }
159 continue;
160 }
161
162 if let Some(id) = msg["id"].as_u64() {
163 let mut pending = pending_clone.lock().await;
164 if let Some(sender) = pending.remove(&id) {
165 let _ = sender.send(msg);
166 }
167 }
168 }
169
170 let mut pending = pending_clone.lock().await;
172 pending.clear();
173 });
174
175 Self {
176 next_id: AtomicU64::new(1),
177 pending,
178 cancelled,
179 stdout_lock: Arc::new(std::sync::Mutex::new(())),
180 session_id: std::sync::Mutex::new(String::new()),
181 script_name: std::sync::Mutex::new(String::new()),
182 queued_user_messages,
183 resume_requested,
184 skills_reload_requested,
185 visible_call_states: std::sync::Mutex::new(HashMap::new()),
186 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
187 #[cfg(test)]
188 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
189 }
190 }
191
192 pub fn from_parts(
198 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
199 cancelled: Arc<AtomicBool>,
200 stdout_lock: Arc<std::sync::Mutex<()>>,
201 start_id: u64,
202 ) -> Self {
203 Self {
204 next_id: AtomicU64::new(start_id),
205 pending,
206 cancelled,
207 stdout_lock,
208 session_id: std::sync::Mutex::new(String::new()),
209 script_name: std::sync::Mutex::new(String::new()),
210 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
211 resume_requested: Arc::new(AtomicBool::new(false)),
212 skills_reload_requested: Arc::new(AtomicBool::new(false)),
213 visible_call_states: std::sync::Mutex::new(HashMap::new()),
214 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
215 #[cfg(test)]
216 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
217 }
218 }
219
220 pub fn set_session_id(&self, id: &str) {
222 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
223 }
224
225 pub fn set_script_name(&self, name: &str) {
227 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
228 }
229
230 fn get_script_name(&self) -> String {
232 self.script_name
233 .lock()
234 .unwrap_or_else(|e| e.into_inner())
235 .clone()
236 }
237
238 fn get_session_id(&self) -> String {
240 self.session_id
241 .lock()
242 .unwrap_or_else(|e| e.into_inner())
243 .clone()
244 }
245
246 fn write_line(&self, line: &str) -> Result<(), VmError> {
248 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
249 let mut stdout = std::io::stdout().lock();
250 stdout
251 .write_all(line.as_bytes())
252 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
253 stdout
254 .write_all(b"\n")
255 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
256 stdout
257 .flush()
258 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
259 Ok(())
260 }
261
262 pub async fn call(
265 &self,
266 method: &str,
267 params: serde_json::Value,
268 ) -> Result<serde_json::Value, VmError> {
269 if self.is_cancelled() {
270 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
271 }
272
273 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
274
275 let request = crate::jsonrpc::request(id, method, params);
276
277 let (tx, rx) = oneshot::channel();
278 {
279 let mut pending = self.pending.lock().await;
280 pending.insert(id, tx);
281 }
282
283 let line = serde_json::to_string(&request)
284 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
285 if let Err(e) = self.write_line(&line) {
286 let mut pending = self.pending.lock().await;
287 pending.remove(&id);
288 return Err(e);
289 }
290
291 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
292 Ok(Ok(msg)) => msg,
293 Ok(Err(_)) => {
294 return Err(VmError::Runtime(
296 "Bridge: host closed connection before responding".into(),
297 ));
298 }
299 Err(_) => {
300 let mut pending = self.pending.lock().await;
301 pending.remove(&id);
302 return Err(VmError::Runtime(format!(
303 "Bridge: host did not respond to '{method}' within {}s",
304 DEFAULT_TIMEOUT.as_secs()
305 )));
306 }
307 };
308
309 if let Some(error) = response.get("error") {
310 let message = error["message"].as_str().unwrap_or("Unknown host error");
311 let code = error["code"].as_i64().unwrap_or(-1);
312 if code == -32001 {
314 return Err(VmError::CategorizedError {
315 message: message.to_string(),
316 category: ErrorCategory::ToolRejected,
317 });
318 }
319 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
320 }
321
322 Ok(response["result"].clone())
323 }
324
325 pub fn notify(&self, method: &str, params: serde_json::Value) {
328 let notification = crate::jsonrpc::notification(method, params);
329 #[cfg(test)]
330 self.recorded_notifications
331 .lock()
332 .unwrap_or_else(|e| e.into_inner())
333 .push(notification.clone());
334 if let Ok(line) = serde_json::to_string(¬ification) {
335 let _ = self.write_line(&line);
336 }
337 }
338
339 #[cfg(test)]
340 pub(crate) fn recorded_notifications(&self) -> Vec<serde_json::Value> {
341 self.recorded_notifications
342 .lock()
343 .unwrap_or_else(|e| e.into_inner())
344 .clone()
345 }
346
347 pub fn is_cancelled(&self) -> bool {
349 self.cancelled.load(Ordering::SeqCst)
350 }
351
352 pub fn take_resume_signal(&self) -> bool {
353 self.resume_requested.swap(false, Ordering::SeqCst)
354 }
355
356 pub fn signal_resume(&self) {
357 self.resume_requested.store(true, Ordering::SeqCst);
358 }
359
360 pub fn take_skills_reload_signal(&self) -> bool {
365 self.skills_reload_requested.swap(false, Ordering::SeqCst)
366 }
367
368 pub fn signal_skills_reload(&self) {
372 self.skills_reload_requested.store(true, Ordering::SeqCst);
373 }
374
375 pub async fn list_host_skills(&self) -> Result<Vec<serde_json::Value>, VmError> {
381 let result = self.call("skills/list", serde_json::json!({})).await?;
382 match result {
383 serde_json::Value::Array(items) => Ok(items),
384 serde_json::Value::Object(map) => match map.get("skills") {
385 Some(serde_json::Value::Array(items)) => Ok(items.clone()),
386 _ => Err(VmError::Runtime(
387 "skills/list: host response must be an array or { skills: [...] }".into(),
388 )),
389 },
390 _ => Err(VmError::Runtime(
391 "skills/list: unexpected response shape".into(),
392 )),
393 }
394 }
395
396 pub async fn fetch_host_skill(&self, id: &str) -> Result<serde_json::Value, VmError> {
400 self.call("skills/fetch", serde_json::json!({ "id": id }))
401 .await
402 }
403
404 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
405 self.queued_user_messages
406 .lock()
407 .await
408 .push_back(QueuedUserMessage {
409 content,
410 mode: QueuedUserMessageMode::from_str(mode),
411 });
412 }
413
414 pub async fn take_queued_user_messages(
415 &self,
416 include_interrupt_immediate: bool,
417 include_finish_step: bool,
418 include_wait_for_completion: bool,
419 ) -> Vec<QueuedUserMessage> {
420 let mut queue = self.queued_user_messages.lock().await;
421 let mut selected = Vec::new();
422 let mut retained = VecDeque::new();
423 while let Some(message) = queue.pop_front() {
424 let should_take = match message.mode {
425 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
426 QueuedUserMessageMode::FinishStep => include_finish_step,
427 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
428 };
429 if should_take {
430 selected.push(message);
431 } else {
432 retained.push_back(message);
433 }
434 }
435 *queue = retained;
436 selected
437 }
438
439 pub async fn take_queued_user_messages_for(
440 &self,
441 checkpoint: DeliveryCheckpoint,
442 ) -> Vec<QueuedUserMessage> {
443 match checkpoint {
444 DeliveryCheckpoint::InterruptImmediate => {
445 self.take_queued_user_messages(true, false, false).await
446 }
447 DeliveryCheckpoint::AfterCurrentOperation => {
448 self.take_queued_user_messages(false, true, false).await
449 }
450 DeliveryCheckpoint::EndOfInteraction => {
451 self.take_queued_user_messages(false, false, true).await
452 }
453 }
454 }
455
456 pub fn send_output(&self, text: &str) {
458 self.notify("output", serde_json::json!({"text": text}));
459 }
460
461 pub fn send_progress(
463 &self,
464 phase: &str,
465 message: &str,
466 progress: Option<i64>,
467 total: Option<i64>,
468 data: Option<serde_json::Value>,
469 ) {
470 let mut payload = serde_json::json!({"phase": phase, "message": message});
471 if let Some(p) = progress {
472 payload["progress"] = serde_json::json!(p);
473 }
474 if let Some(t) = total {
475 payload["total"] = serde_json::json!(t);
476 }
477 if let Some(d) = data {
478 payload["data"] = d;
479 }
480 self.notify("progress", payload);
481 }
482
483 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
485 let mut payload = serde_json::json!({"level": level, "message": message});
486 if let Some(f) = fields {
487 payload["fields"] = f;
488 }
489 self.notify("log", payload);
490 }
491
492 pub fn send_call_start(
495 &self,
496 call_id: &str,
497 call_type: &str,
498 name: &str,
499 metadata: serde_json::Value,
500 ) {
501 let session_id = self.get_session_id();
502 let script = self.get_script_name();
503 let stream_publicly = metadata
504 .get("stream_publicly")
505 .and_then(|value| value.as_bool())
506 .unwrap_or(true);
507 self.visible_call_streams
508 .lock()
509 .unwrap_or_else(|e| e.into_inner())
510 .insert(call_id.to_string(), stream_publicly);
511 self.notify(
512 "session/update",
513 serde_json::json!({
514 "sessionId": session_id,
515 "update": {
516 "sessionUpdate": "call_start",
517 "content": {
518 "toolCallId": call_id,
519 "call_type": call_type,
520 "name": name,
521 "script": script,
522 "metadata": metadata,
523 },
524 },
525 }),
526 );
527 }
528
529 pub fn send_call_progress(
532 &self,
533 call_id: &str,
534 delta: &str,
535 accumulated_tokens: u64,
536 user_visible: bool,
537 ) {
538 let session_id = self.get_session_id();
539 let (visible_text, visible_delta) = {
540 let stream_publicly = self
541 .visible_call_streams
542 .lock()
543 .unwrap_or_else(|e| e.into_inner())
544 .get(call_id)
545 .copied()
546 .unwrap_or(true);
547 let mut states = self
548 .visible_call_states
549 .lock()
550 .unwrap_or_else(|e| e.into_inner());
551 let state = states.entry(call_id.to_string()).or_default();
552 state.push(delta, stream_publicly)
553 };
554 self.notify(
555 "session/update",
556 serde_json::json!({
557 "sessionId": session_id,
558 "update": {
559 "sessionUpdate": "call_progress",
560 "content": {
561 "toolCallId": call_id,
562 "delta": delta,
563 "accumulated_tokens": accumulated_tokens,
564 "visible_text": visible_text,
565 "visible_delta": visible_delta,
566 "user_visible": user_visible,
567 },
568 },
569 }),
570 );
571 }
572
573 pub fn send_call_end(
575 &self,
576 call_id: &str,
577 call_type: &str,
578 name: &str,
579 duration_ms: u64,
580 status: &str,
581 metadata: serde_json::Value,
582 ) {
583 let session_id = self.get_session_id();
584 let script = self.get_script_name();
585 self.visible_call_states
586 .lock()
587 .unwrap_or_else(|e| e.into_inner())
588 .remove(call_id);
589 self.visible_call_streams
590 .lock()
591 .unwrap_or_else(|e| e.into_inner())
592 .remove(call_id);
593 self.notify(
594 "session/update",
595 serde_json::json!({
596 "sessionId": session_id,
597 "update": {
598 "sessionUpdate": "call_end",
599 "content": {
600 "toolCallId": call_id,
601 "call_type": call_type,
602 "name": name,
603 "script": script,
604 "duration_ms": duration_ms,
605 "status": status,
606 "metadata": metadata,
607 },
608 },
609 }),
610 );
611 }
612
613 pub fn send_worker_update(
615 &self,
616 worker_id: &str,
617 worker_name: &str,
618 status: &str,
619 metadata: serde_json::Value,
620 audit: Option<&MutationSessionRecord>,
621 ) {
622 let session_id = self.get_session_id();
623 let script = self.get_script_name();
624 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
625 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
626 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
627 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
628 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
629 let lifecycle = serde_json::json!({
630 "event": status,
631 "worker_id": worker_id,
632 "worker_name": worker_name,
633 "started_at": started_at,
634 "finished_at": finished_at,
635 });
636 self.notify(
637 "session/update",
638 serde_json::json!({
639 "sessionId": session_id,
640 "update": {
641 "sessionUpdate": "worker_update",
642 "content": {
643 "worker_id": worker_id,
644 "worker_name": worker_name,
645 "status": status,
646 "script": script,
647 "started_at": started_at,
648 "finished_at": finished_at,
649 "snapshot_path": snapshot_path,
650 "run_id": run_id,
651 "run_path": run_path,
652 "lifecycle": lifecycle,
653 "audit": audit,
654 "metadata": metadata,
655 },
656 },
657 }),
658 );
659 }
660}
661
662pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
664 crate::stdlib::json_to_vm_value(val)
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge that never touches stdin, sharing the given cancel flag.
    fn bridge_with_cancel_flag(cancelled: Arc<AtomicBool>) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Builds a bridge that never touches stdin.
    fn test_bridge() -> HostBridge {
        bridge_with_cancel_flag(Arc::new(AtomicBool::new(false)))
    }

    /// Runs `future` to completion on a fresh current-thread runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    /// Message contents in delivery order.
    fn contents(messages: &[QueuedUserMessage]) -> Vec<&str> {
        messages.iter().map(|m| m.content.as_str()).collect()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = crate::jsonrpc::request(
            1,
            "llm_call",
            serde_json::json!({
                "prompt": "Hello",
                "system": "Be helpful",
            }),
        );
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification =
            crate::jsonrpc::notification("output", serde_json::json!({"text": "[harn] hello\n"}));
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = crate::jsonrpc::error_response(1, -32600, "Invalid request");
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = crate::jsonrpc::response(
            1,
            serde_json::json!({
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            }),
        );
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = bridge_with_cancel_flag(cancelled.clone());
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn cancelled_bridge_rejects_calls_before_sending() {
        let bridge = bridge_with_cancel_flag(Arc::new(AtomicBool::new(true)));
        let result = block_on(bridge.call("llm_call", serde_json::json!({})));
        assert!(matches!(
            &result,
            Err(VmError::Runtime(message)) if message.contains("cancelled")
        ));
    }

    #[test]
    fn resume_and_skills_signals_are_consumed_once() {
        let bridge = test_bridge();

        assert!(!bridge.take_resume_signal());
        bridge.signal_resume();
        assert!(bridge.take_resume_signal());
        assert!(!bridge.take_resume_signal());

        assert!(!bridge.take_skills_reload_signal());
        bridge.signal_skills_reload();
        assert!(bridge.take_skills_reload_signal());
        assert!(!bridge.take_skills_reload_signal());
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn checkpoints_route_mode_aliases_in_order() {
        block_on(async {
            let bridge = test_bridge();
            bridge.push_queued_user_message("a".to_string(), "interrupt").await;
            bridge
                .push_queued_user_message("b".to_string(), "after_current_operation")
                .await;
            bridge
                .push_queued_user_message("c".to_string(), "unknown_mode")
                .await;
            bridge
                .push_queued_user_message("d".to_string(), "interrupt_immediate")
                .await;

            let interrupts = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                .await;
            assert_eq!(contents(&interrupts), vec!["a", "d"]);

            let after_op = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert_eq!(contents(&after_op), vec!["b"]);

            // Unknown modes fall back to end-of-interaction delivery.
            let end = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(contents(&end), vec!["c"]);
            assert_eq!(end[0].mode, QueuedUserMessageMode::WaitForCompletion);

            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}