1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18use crate::visible_text::VisibleTextState;
19
/// Upper bound on how long `HostBridge::call` waits for the host to answer a
/// request before giving up with a timeout error.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
22
/// JSON-RPC 2.0 bridge between the VM and a host process: requests and
/// notifications are written as single lines to stdout, and host messages are
/// read line by line from stdin (see `HostBridge::new`).
pub struct HostBridge {
    /// Id assigned to the next outgoing request; incremented once per `call`.
    next_id: AtomicU64,
    /// Requests awaiting a host response, keyed by request id. `call` inserts
    /// the sender; the stdin reader removes it and delivers the response.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
    /// Set when the host sends a `cancel` notification; checked by `call`.
    cancelled: Arc<AtomicBool>,
    /// Serializes whole-line writes to stdout (can be shared via `from_parts`).
    stdout_lock: Arc<std::sync::Mutex<()>>,
    /// Session id echoed as `sessionId` in `session/update` notifications.
    session_id: std::sync::Mutex<String>,
    /// Script name echoed in call and worker `session/update` notifications.
    script_name: std::sync::Mutex<String>,
    /// User messages queued by host notifications or `push_queued_user_message`.
    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
    /// Set by an `agent/resume` notification or `signal_resume`; consumed by
    /// `take_resume_signal`.
    resume_requested: Arc<AtomicBool>,
    /// Per-call visible-text accumulator fed by `send_call_progress`.
    visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
    /// Per-call `stream_publicly` flag captured by `send_call_start`.
    visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
    /// Test-only copy of every notification passed to `notify`.
    #[cfg(test)]
    recorded_notifications: Arc<std::sync::Mutex<Vec<serde_json::Value>>>,
}
52
/// When a queued user message should be handed to the running script.
///
/// Parsed from the host's `mode` string by `QueuedUserMessageMode::from_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Wire values `"interrupt_immediate"` and `"interrupt"`.
    InterruptImmediate,
    /// Wire values `"finish_step"` and `"after_current_operation"`.
    FinishStep,
    /// `"wait_for_completion"` and any unrecognised value.
    WaitForCompletion,
}

/// A point in execution at which queued user messages may be drained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    InterruptImmediate,
    AfterCurrentOperation,
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Maps a wire-level mode string to a mode.
    ///
    /// Never fails: anything not recognised falls back to `WaitForCompletion`.
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            Self::InterruptImmediate
        } else if matches!(value, "finish_step" | "after_current_operation") {
            Self::FinishStep
        } else {
            Self::WaitForCompletion
        }
    }
}

/// A user message waiting to be delivered, together with its delivery mode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueuedUserMessage {
    pub content: String,
    pub mode: QueuedUserMessageMode,
}
82
83#[allow(clippy::new_without_default)]
85impl HostBridge {
86 pub fn new() -> Self {
91 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
92 Arc::new(Mutex::new(HashMap::new()));
93 let cancelled = Arc::new(AtomicBool::new(false));
94 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
95 Arc::new(Mutex::new(VecDeque::new()));
96 let resume_requested = Arc::new(AtomicBool::new(false));
97
98 let pending_clone = pending.clone();
100 let cancelled_clone = cancelled.clone();
101 let queued_clone = queued_user_messages.clone();
102 let resume_clone = resume_requested.clone();
103 tokio::task::spawn_local(async move {
104 let stdin = tokio::io::stdin();
105 let reader = tokio::io::BufReader::new(stdin);
106 let mut lines = reader.lines();
107
108 while let Ok(Some(line)) = lines.next_line().await {
109 let line = line.trim().to_string();
110 if line.is_empty() {
111 continue;
112 }
113
114 let msg: serde_json::Value = match serde_json::from_str(&line) {
115 Ok(v) => v,
116 Err(_) => continue, };
118
119 if msg.get("id").is_none() {
121 if let Some(method) = msg["method"].as_str() {
122 if method == "cancel" {
123 cancelled_clone.store(true, Ordering::SeqCst);
124 } else if method == "agent/resume" {
125 resume_clone.store(true, Ordering::SeqCst);
126 } else if method == "user_message"
127 || method == "session/input"
128 || method == "agent/user_message"
129 {
130 let params = &msg["params"];
131 let content = params
132 .get("content")
133 .and_then(|v| v.as_str())
134 .unwrap_or("")
135 .to_string();
136 if !content.is_empty() {
137 let mode = QueuedUserMessageMode::from_str(
138 params
139 .get("mode")
140 .and_then(|v| v.as_str())
141 .unwrap_or("wait_for_completion"),
142 );
143 queued_clone
144 .lock()
145 .await
146 .push_back(QueuedUserMessage { content, mode });
147 }
148 }
149 }
150 continue;
151 }
152
153 if let Some(id) = msg["id"].as_u64() {
155 let mut pending = pending_clone.lock().await;
156 if let Some(sender) = pending.remove(&id) {
157 let _ = sender.send(msg);
158 }
159 }
160 }
161
162 let mut pending = pending_clone.lock().await;
164 pending.clear();
165 });
166
167 Self {
168 next_id: AtomicU64::new(1),
169 pending,
170 cancelled,
171 stdout_lock: Arc::new(std::sync::Mutex::new(())),
172 session_id: std::sync::Mutex::new(String::new()),
173 script_name: std::sync::Mutex::new(String::new()),
174 queued_user_messages,
175 resume_requested,
176 visible_call_states: std::sync::Mutex::new(HashMap::new()),
177 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
178 #[cfg(test)]
179 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
180 }
181 }
182
183 pub fn from_parts(
189 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
190 cancelled: Arc<AtomicBool>,
191 stdout_lock: Arc<std::sync::Mutex<()>>,
192 start_id: u64,
193 ) -> Self {
194 Self {
195 next_id: AtomicU64::new(start_id),
196 pending,
197 cancelled,
198 stdout_lock,
199 session_id: std::sync::Mutex::new(String::new()),
200 script_name: std::sync::Mutex::new(String::new()),
201 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
202 resume_requested: Arc::new(AtomicBool::new(false)),
203 visible_call_states: std::sync::Mutex::new(HashMap::new()),
204 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
205 #[cfg(test)]
206 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
207 }
208 }
209
210 pub fn set_session_id(&self, id: &str) {
212 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
213 }
214
215 pub fn set_script_name(&self, name: &str) {
217 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
218 }
219
220 fn get_script_name(&self) -> String {
222 self.script_name
223 .lock()
224 .unwrap_or_else(|e| e.into_inner())
225 .clone()
226 }
227
228 fn get_session_id(&self) -> String {
230 self.session_id
231 .lock()
232 .unwrap_or_else(|e| e.into_inner())
233 .clone()
234 }
235
236 fn write_line(&self, line: &str) -> Result<(), VmError> {
238 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
239 let mut stdout = std::io::stdout().lock();
240 stdout
241 .write_all(line.as_bytes())
242 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
243 stdout
244 .write_all(b"\n")
245 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
246 stdout
247 .flush()
248 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
249 Ok(())
250 }
251
252 pub async fn call(
255 &self,
256 method: &str,
257 params: serde_json::Value,
258 ) -> Result<serde_json::Value, VmError> {
259 if self.is_cancelled() {
260 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
261 }
262
263 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
264
265 let request = serde_json::json!({
266 "jsonrpc": "2.0",
267 "id": id,
268 "method": method,
269 "params": params,
270 });
271
272 let (tx, rx) = oneshot::channel();
274 {
275 let mut pending = self.pending.lock().await;
276 pending.insert(id, tx);
277 }
278
279 let line = serde_json::to_string(&request)
281 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
282 if let Err(e) = self.write_line(&line) {
283 let mut pending = self.pending.lock().await;
285 pending.remove(&id);
286 return Err(e);
287 }
288
289 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
291 Ok(Ok(msg)) => msg,
292 Ok(Err(_)) => {
293 return Err(VmError::Runtime(
295 "Bridge: host closed connection before responding".into(),
296 ));
297 }
298 Err(_) => {
299 let mut pending = self.pending.lock().await;
301 pending.remove(&id);
302 return Err(VmError::Runtime(format!(
303 "Bridge: host did not respond to '{method}' within {}s",
304 DEFAULT_TIMEOUT.as_secs()
305 )));
306 }
307 };
308
309 if let Some(error) = response.get("error") {
311 let message = error["message"].as_str().unwrap_or("Unknown host error");
312 let code = error["code"].as_i64().unwrap_or(-1);
313 if code == -32001 {
315 return Err(VmError::CategorizedError {
316 message: message.to_string(),
317 category: ErrorCategory::ToolRejected,
318 });
319 }
320 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
321 }
322
323 Ok(response["result"].clone())
324 }
325
326 pub fn notify(&self, method: &str, params: serde_json::Value) {
329 let notification = serde_json::json!({
330 "jsonrpc": "2.0",
331 "method": method,
332 "params": params,
333 });
334 #[cfg(test)]
335 self.recorded_notifications
336 .lock()
337 .unwrap_or_else(|e| e.into_inner())
338 .push(notification.clone());
339 if let Ok(line) = serde_json::to_string(¬ification) {
340 let _ = self.write_line(&line);
341 }
342 }
343
344 #[cfg(test)]
345 pub(crate) fn recorded_notifications(&self) -> Vec<serde_json::Value> {
346 self.recorded_notifications
347 .lock()
348 .unwrap_or_else(|e| e.into_inner())
349 .clone()
350 }
351
352 pub fn is_cancelled(&self) -> bool {
354 self.cancelled.load(Ordering::SeqCst)
355 }
356
357 pub fn take_resume_signal(&self) -> bool {
358 self.resume_requested.swap(false, Ordering::SeqCst)
359 }
360
361 pub fn signal_resume(&self) {
362 self.resume_requested.store(true, Ordering::SeqCst);
363 }
364
365 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
366 self.queued_user_messages
367 .lock()
368 .await
369 .push_back(QueuedUserMessage {
370 content,
371 mode: QueuedUserMessageMode::from_str(mode),
372 });
373 }
374
375 pub async fn take_queued_user_messages(
376 &self,
377 include_interrupt_immediate: bool,
378 include_finish_step: bool,
379 include_wait_for_completion: bool,
380 ) -> Vec<QueuedUserMessage> {
381 let mut queue = self.queued_user_messages.lock().await;
382 let mut selected = Vec::new();
383 let mut retained = VecDeque::new();
384 while let Some(message) = queue.pop_front() {
385 let should_take = match message.mode {
386 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
387 QueuedUserMessageMode::FinishStep => include_finish_step,
388 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
389 };
390 if should_take {
391 selected.push(message);
392 } else {
393 retained.push_back(message);
394 }
395 }
396 *queue = retained;
397 selected
398 }
399
400 pub async fn take_queued_user_messages_for(
401 &self,
402 checkpoint: DeliveryCheckpoint,
403 ) -> Vec<QueuedUserMessage> {
404 match checkpoint {
405 DeliveryCheckpoint::InterruptImmediate => {
406 self.take_queued_user_messages(true, false, false).await
407 }
408 DeliveryCheckpoint::AfterCurrentOperation => {
409 self.take_queued_user_messages(false, true, false).await
410 }
411 DeliveryCheckpoint::EndOfInteraction => {
412 self.take_queued_user_messages(false, false, true).await
413 }
414 }
415 }
416
417 pub fn send_output(&self, text: &str) {
419 self.notify("output", serde_json::json!({"text": text}));
420 }
421
422 pub fn send_progress(
424 &self,
425 phase: &str,
426 message: &str,
427 progress: Option<i64>,
428 total: Option<i64>,
429 data: Option<serde_json::Value>,
430 ) {
431 let mut payload = serde_json::json!({"phase": phase, "message": message});
432 if let Some(p) = progress {
433 payload["progress"] = serde_json::json!(p);
434 }
435 if let Some(t) = total {
436 payload["total"] = serde_json::json!(t);
437 }
438 if let Some(d) = data {
439 payload["data"] = d;
440 }
441 self.notify("progress", payload);
442 }
443
444 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
446 let mut payload = serde_json::json!({"level": level, "message": message});
447 if let Some(f) = fields {
448 payload["fields"] = f;
449 }
450 self.notify("log", payload);
451 }
452
453 pub fn send_call_start(
456 &self,
457 call_id: &str,
458 call_type: &str,
459 name: &str,
460 metadata: serde_json::Value,
461 ) {
462 let session_id = self.get_session_id();
463 let script = self.get_script_name();
464 let stream_publicly = metadata
465 .get("stream_publicly")
466 .and_then(|value| value.as_bool())
467 .unwrap_or(true);
468 self.visible_call_streams
469 .lock()
470 .unwrap_or_else(|e| e.into_inner())
471 .insert(call_id.to_string(), stream_publicly);
472 self.notify(
473 "session/update",
474 serde_json::json!({
475 "sessionId": session_id,
476 "update": {
477 "sessionUpdate": "call_start",
478 "content": {
479 "call_id": call_id,
480 "call_type": call_type,
481 "name": name,
482 "script": script,
483 "metadata": metadata,
484 },
485 },
486 }),
487 );
488 }
489
490 pub fn send_call_progress(
493 &self,
494 call_id: &str,
495 delta: &str,
496 accumulated_tokens: u64,
497 user_visible: bool,
498 ) {
499 let session_id = self.get_session_id();
500 let (visible_text, visible_delta) = {
501 let stream_publicly = self
502 .visible_call_streams
503 .lock()
504 .unwrap_or_else(|e| e.into_inner())
505 .get(call_id)
506 .copied()
507 .unwrap_or(true);
508 let mut states = self
509 .visible_call_states
510 .lock()
511 .unwrap_or_else(|e| e.into_inner());
512 let state = states.entry(call_id.to_string()).or_default();
513 state.push(delta, stream_publicly)
514 };
515 self.notify(
516 "session/update",
517 serde_json::json!({
518 "sessionId": session_id,
519 "update": {
520 "sessionUpdate": "call_progress",
521 "content": {
522 "call_id": call_id,
523 "delta": delta,
524 "accumulated_tokens": accumulated_tokens,
525 "visible_text": visible_text,
526 "visible_delta": visible_delta,
527 "user_visible": user_visible,
528 },
529 },
530 }),
531 );
532 }
533
534 pub fn send_call_end(
536 &self,
537 call_id: &str,
538 call_type: &str,
539 name: &str,
540 duration_ms: u64,
541 status: &str,
542 metadata: serde_json::Value,
543 ) {
544 let session_id = self.get_session_id();
545 let script = self.get_script_name();
546 self.visible_call_states
547 .lock()
548 .unwrap_or_else(|e| e.into_inner())
549 .remove(call_id);
550 self.visible_call_streams
551 .lock()
552 .unwrap_or_else(|e| e.into_inner())
553 .remove(call_id);
554 self.notify(
555 "session/update",
556 serde_json::json!({
557 "sessionId": session_id,
558 "update": {
559 "sessionUpdate": "call_end",
560 "content": {
561 "call_id": call_id,
562 "call_type": call_type,
563 "name": name,
564 "script": script,
565 "duration_ms": duration_ms,
566 "status": status,
567 "metadata": metadata,
568 },
569 },
570 }),
571 );
572 }
573
574 pub fn send_worker_update(
576 &self,
577 worker_id: &str,
578 worker_name: &str,
579 status: &str,
580 metadata: serde_json::Value,
581 audit: Option<&MutationSessionRecord>,
582 ) {
583 let session_id = self.get_session_id();
584 let script = self.get_script_name();
585 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
586 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
587 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
588 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
589 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
590 let lifecycle = serde_json::json!({
591 "event": status,
592 "worker_id": worker_id,
593 "worker_name": worker_name,
594 "started_at": started_at,
595 "finished_at": finished_at,
596 });
597 self.notify(
598 "session/update",
599 serde_json::json!({
600 "sessionId": session_id,
601 "update": {
602 "sessionUpdate": "worker_update",
603 "content": {
604 "worker_id": worker_id,
605 "worker_name": worker_name,
606 "status": status,
607 "script": script,
608 "started_at": started_at,
609 "finished_at": finished_at,
610 "snapshot_path": snapshot_path,
611 "run_id": run_id,
612 "run_path": run_path,
613 "lifecycle": lifecycle,
614 "audit": audit,
615 "metadata": metadata,
616 },
617 },
618 }),
619 );
620 }
621}
622
623pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
625 crate::stdlib::json_to_vm_value(val)
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge with `from_parts`, so no stdin reader task is spawned.
    /// The bridge uses the given cancellation flag.
    fn bridge_with_cancel_flag(cancelled: Arc<AtomicBool>) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Runs `future` to completion on a fresh single-threaded tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = bridge_with_cancel_flag(Arc::clone(&cancelled));
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn call_fails_fast_when_cancelled() {
        let bridge = bridge_with_cancel_flag(Arc::new(AtomicBool::new(true)));
        let result = block_on(bridge.call("llm_call", serde_json::json!({})));
        assert!(matches!(
            result,
            Err(VmError::Runtime(ref message)) if message.contains("cancelled")
        ));
    }

    #[test]
    fn resume_signal_is_consumed_once() {
        let bridge = bridge_with_cancel_flag(Arc::new(AtomicBool::new(false)));
        assert!(!bridge.take_resume_signal());
        bridge.signal_resume();
        assert!(bridge.take_resume_signal());
        assert!(!bridge.take_resume_signal());
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        block_on(async {
            let bridge = bridge_with_cancel_flag(Arc::new(AtomicBool::new(false)));
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn queued_messages_are_routed_to_matching_checkpoint() {
        block_on(async {
            let bridge = bridge_with_cancel_flag(Arc::new(AtomicBool::new(false)));
            bridge
                .push_queued_user_message("stop".to_string(), "interrupt")
                .await;
            bridge
                .push_queued_user_message("later".to_string(), "after_current_operation")
                .await;
            bridge
                .push_queued_user_message("end".to_string(), "unknown_mode")
                .await;

            let interrupt = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                .await;
            assert_eq!(interrupt.len(), 1);
            assert_eq!(interrupt[0].content, "stop");
            assert_eq!(interrupt[0].mode, QueuedUserMessageMode::InterruptImmediate);

            let after = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert_eq!(after.len(), 1);
            assert_eq!(after[0].content, "later");

            let end = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(end.len(), 1);
            assert_eq!(end[0].content, "end");

            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            panic!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            panic!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            panic!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}