1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18use crate::visible_text::VisibleTextState;
19
/// How long [`HostBridge::call`] waits for the host's reply (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
22
/// JSON-RPC 2.0 bridge between the VM and its host process.
///
/// Requests and notifications are written as single JSON lines to stdout;
/// replies and host notifications are read from stdin by the reader task that
/// `HostBridge::new` spawns.
pub struct HostBridge {
    /// Id for the next outgoing request; incremented by every `call`.
    next_id: AtomicU64,
    /// Waiters for in-flight `call`s keyed by request id; the stdin reader
    /// removes an entry and sends the matching response through it.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
    /// Set when the host sends a `cancel` notification; `call` refuses to
    /// start new requests once it is set.
    cancelled: Arc<AtomicBool>,
    /// Serializes writes to stdout so each JSON line is emitted whole.
    stdout_lock: Arc<std::sync::Mutex<()>>,
    /// Reported as `sessionId` in every `session/update` notification.
    session_id: std::sync::Mutex<String>,
    /// Reported as `script` in call/worker `session/update` notifications.
    script_name: std::sync::Mutex<String>,
    /// User messages pushed by the host, drained at delivery checkpoints.
    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
    /// Set by the host's `agent/resume` notification or `signal_resume`;
    /// consumed by `take_resume_signal`.
    resume_requested: Arc<AtomicBool>,
    /// Per-call visible-text state built from `call_progress` deltas; an
    /// entry is dropped when its call ends.
    visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
    /// Per-call `stream_publicly` flag taken from `call_start` metadata.
    visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
}
50
/// When a user message queued by the host should be delivered to the VM.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Deliver at the `InterruptImmediate` checkpoint.
    InterruptImmediate,
    /// Deliver once the current operation finishes.
    FinishStep,
    /// Deliver at the end of the interaction (the fallback for unknown modes).
    WaitForCompletion,
}

/// A point at which queued user messages are drained.
///
/// `InterruptImmediate` drains [`QueuedUserMessageMode::InterruptImmediate`]
/// messages, `AfterCurrentOperation` drains [`QueuedUserMessageMode::FinishStep`]
/// messages and `EndOfInteraction` drains
/// [`QueuedUserMessageMode::WaitForCompletion`] messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    InterruptImmediate,
    AfterCurrentOperation,
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Parses a host-supplied delivery mode.
    ///
    /// `"interrupt_immediate"` and `"interrupt"` select `InterruptImmediate`;
    /// `"finish_step"` and `"after_current_operation"` select `FinishStep`.
    /// Anything else, including `"wait_for_completion"`, an empty string or a
    /// differently-cased spelling, selects `WaitForCompletion`.
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            Self::InterruptImmediate
        } else if matches!(value, "finish_step" | "after_current_operation") {
            Self::FinishStep
        } else {
            Self::WaitForCompletion
        }
    }
}

/// A user message sent by the host while the VM is running.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueuedUserMessage {
    /// Message text (the notification's `params.content`).
    pub content: String,
    /// Delivery timing parsed from the notification's `params.mode`.
    pub mode: QueuedUserMessageMode,
}
80
81#[allow(clippy::new_without_default)]
83impl HostBridge {
84 pub fn new() -> Self {
89 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
90 Arc::new(Mutex::new(HashMap::new()));
91 let cancelled = Arc::new(AtomicBool::new(false));
92 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
93 Arc::new(Mutex::new(VecDeque::new()));
94 let resume_requested = Arc::new(AtomicBool::new(false));
95
96 let pending_clone = pending.clone();
98 let cancelled_clone = cancelled.clone();
99 let queued_clone = queued_user_messages.clone();
100 let resume_clone = resume_requested.clone();
101 tokio::task::spawn_local(async move {
102 let stdin = tokio::io::stdin();
103 let reader = tokio::io::BufReader::new(stdin);
104 let mut lines = reader.lines();
105
106 while let Ok(Some(line)) = lines.next_line().await {
107 let line = line.trim().to_string();
108 if line.is_empty() {
109 continue;
110 }
111
112 let msg: serde_json::Value = match serde_json::from_str(&line) {
113 Ok(v) => v,
114 Err(_) => continue, };
116
117 if msg.get("id").is_none() {
119 if let Some(method) = msg["method"].as_str() {
120 if method == "cancel" {
121 cancelled_clone.store(true, Ordering::SeqCst);
122 } else if method == "agent/resume" {
123 resume_clone.store(true, Ordering::SeqCst);
124 } else if method == "user_message"
125 || method == "session/input"
126 || method == "agent/user_message"
127 {
128 let params = &msg["params"];
129 let content = params
130 .get("content")
131 .and_then(|v| v.as_str())
132 .unwrap_or("")
133 .to_string();
134 if !content.is_empty() {
135 let mode = QueuedUserMessageMode::from_str(
136 params
137 .get("mode")
138 .and_then(|v| v.as_str())
139 .unwrap_or("wait_for_completion"),
140 );
141 queued_clone
142 .lock()
143 .await
144 .push_back(QueuedUserMessage { content, mode });
145 }
146 }
147 }
148 continue;
149 }
150
151 if let Some(id) = msg["id"].as_u64() {
153 let mut pending = pending_clone.lock().await;
154 if let Some(sender) = pending.remove(&id) {
155 let _ = sender.send(msg);
156 }
157 }
158 }
159
160 let mut pending = pending_clone.lock().await;
162 pending.clear();
163 });
164
165 Self {
166 next_id: AtomicU64::new(1),
167 pending,
168 cancelled,
169 stdout_lock: Arc::new(std::sync::Mutex::new(())),
170 session_id: std::sync::Mutex::new(String::new()),
171 script_name: std::sync::Mutex::new(String::new()),
172 queued_user_messages,
173 resume_requested,
174 visible_call_states: std::sync::Mutex::new(HashMap::new()),
175 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
176 }
177 }
178
179 pub fn from_parts(
185 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
186 cancelled: Arc<AtomicBool>,
187 stdout_lock: Arc<std::sync::Mutex<()>>,
188 start_id: u64,
189 ) -> Self {
190 Self {
191 next_id: AtomicU64::new(start_id),
192 pending,
193 cancelled,
194 stdout_lock,
195 session_id: std::sync::Mutex::new(String::new()),
196 script_name: std::sync::Mutex::new(String::new()),
197 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
198 resume_requested: Arc::new(AtomicBool::new(false)),
199 visible_call_states: std::sync::Mutex::new(HashMap::new()),
200 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
201 }
202 }
203
204 pub fn set_session_id(&self, id: &str) {
206 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
207 }
208
209 pub fn set_script_name(&self, name: &str) {
211 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
212 }
213
214 fn get_script_name(&self) -> String {
216 self.script_name
217 .lock()
218 .unwrap_or_else(|e| e.into_inner())
219 .clone()
220 }
221
222 fn get_session_id(&self) -> String {
224 self.session_id
225 .lock()
226 .unwrap_or_else(|e| e.into_inner())
227 .clone()
228 }
229
230 fn write_line(&self, line: &str) -> Result<(), VmError> {
232 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
233 let mut stdout = std::io::stdout().lock();
234 stdout
235 .write_all(line.as_bytes())
236 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
237 stdout
238 .write_all(b"\n")
239 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
240 stdout
241 .flush()
242 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
243 Ok(())
244 }
245
246 pub async fn call(
249 &self,
250 method: &str,
251 params: serde_json::Value,
252 ) -> Result<serde_json::Value, VmError> {
253 if self.is_cancelled() {
254 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
255 }
256
257 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
258
259 let request = serde_json::json!({
260 "jsonrpc": "2.0",
261 "id": id,
262 "method": method,
263 "params": params,
264 });
265
266 let (tx, rx) = oneshot::channel();
268 {
269 let mut pending = self.pending.lock().await;
270 pending.insert(id, tx);
271 }
272
273 let line = serde_json::to_string(&request)
275 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
276 if let Err(e) = self.write_line(&line) {
277 let mut pending = self.pending.lock().await;
279 pending.remove(&id);
280 return Err(e);
281 }
282
283 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
285 Ok(Ok(msg)) => msg,
286 Ok(Err(_)) => {
287 return Err(VmError::Runtime(
289 "Bridge: host closed connection before responding".into(),
290 ));
291 }
292 Err(_) => {
293 let mut pending = self.pending.lock().await;
295 pending.remove(&id);
296 return Err(VmError::Runtime(format!(
297 "Bridge: host did not respond to '{method}' within {}s",
298 DEFAULT_TIMEOUT.as_secs()
299 )));
300 }
301 };
302
303 if let Some(error) = response.get("error") {
305 let message = error["message"].as_str().unwrap_or("Unknown host error");
306 let code = error["code"].as_i64().unwrap_or(-1);
307 if code == -32001 {
309 return Err(VmError::CategorizedError {
310 message: message.to_string(),
311 category: ErrorCategory::ToolRejected,
312 });
313 }
314 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
315 }
316
317 Ok(response["result"].clone())
318 }
319
320 pub fn notify(&self, method: &str, params: serde_json::Value) {
323 let notification = serde_json::json!({
324 "jsonrpc": "2.0",
325 "method": method,
326 "params": params,
327 });
328 if let Ok(line) = serde_json::to_string(¬ification) {
329 let _ = self.write_line(&line);
330 }
331 }
332
333 pub fn is_cancelled(&self) -> bool {
335 self.cancelled.load(Ordering::SeqCst)
336 }
337
338 pub fn take_resume_signal(&self) -> bool {
339 self.resume_requested.swap(false, Ordering::SeqCst)
340 }
341
342 pub fn signal_resume(&self) {
343 self.resume_requested.store(true, Ordering::SeqCst);
344 }
345
346 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
347 self.queued_user_messages
348 .lock()
349 .await
350 .push_back(QueuedUserMessage {
351 content,
352 mode: QueuedUserMessageMode::from_str(mode),
353 });
354 }
355
356 pub async fn take_queued_user_messages(
357 &self,
358 include_interrupt_immediate: bool,
359 include_finish_step: bool,
360 include_wait_for_completion: bool,
361 ) -> Vec<QueuedUserMessage> {
362 let mut queue = self.queued_user_messages.lock().await;
363 let mut selected = Vec::new();
364 let mut retained = VecDeque::new();
365 while let Some(message) = queue.pop_front() {
366 let should_take = match message.mode {
367 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
368 QueuedUserMessageMode::FinishStep => include_finish_step,
369 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
370 };
371 if should_take {
372 selected.push(message);
373 } else {
374 retained.push_back(message);
375 }
376 }
377 *queue = retained;
378 selected
379 }
380
381 pub async fn take_queued_user_messages_for(
382 &self,
383 checkpoint: DeliveryCheckpoint,
384 ) -> Vec<QueuedUserMessage> {
385 match checkpoint {
386 DeliveryCheckpoint::InterruptImmediate => {
387 self.take_queued_user_messages(true, false, false).await
388 }
389 DeliveryCheckpoint::AfterCurrentOperation => {
390 self.take_queued_user_messages(false, true, false).await
391 }
392 DeliveryCheckpoint::EndOfInteraction => {
393 self.take_queued_user_messages(false, false, true).await
394 }
395 }
396 }
397
398 pub fn send_output(&self, text: &str) {
400 self.notify("output", serde_json::json!({"text": text}));
401 }
402
403 pub fn send_progress(
405 &self,
406 phase: &str,
407 message: &str,
408 progress: Option<i64>,
409 total: Option<i64>,
410 data: Option<serde_json::Value>,
411 ) {
412 let mut payload = serde_json::json!({"phase": phase, "message": message});
413 if let Some(p) = progress {
414 payload["progress"] = serde_json::json!(p);
415 }
416 if let Some(t) = total {
417 payload["total"] = serde_json::json!(t);
418 }
419 if let Some(d) = data {
420 payload["data"] = d;
421 }
422 self.notify("progress", payload);
423 }
424
425 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
427 let mut payload = serde_json::json!({"level": level, "message": message});
428 if let Some(f) = fields {
429 payload["fields"] = f;
430 }
431 self.notify("log", payload);
432 }
433
434 pub fn send_call_start(
437 &self,
438 call_id: &str,
439 call_type: &str,
440 name: &str,
441 metadata: serde_json::Value,
442 ) {
443 let session_id = self.get_session_id();
444 let script = self.get_script_name();
445 let stream_publicly = metadata
446 .get("stream_publicly")
447 .and_then(|value| value.as_bool())
448 .unwrap_or(true);
449 self.visible_call_streams
450 .lock()
451 .unwrap_or_else(|e| e.into_inner())
452 .insert(call_id.to_string(), stream_publicly);
453 self.notify(
454 "session/update",
455 serde_json::json!({
456 "sessionId": session_id,
457 "update": {
458 "sessionUpdate": "call_start",
459 "content": {
460 "call_id": call_id,
461 "call_type": call_type,
462 "name": name,
463 "script": script,
464 "metadata": metadata,
465 },
466 },
467 }),
468 );
469 }
470
471 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
474 let session_id = self.get_session_id();
475 let (visible_text, visible_delta) = {
476 let stream_publicly = self
477 .visible_call_streams
478 .lock()
479 .unwrap_or_else(|e| e.into_inner())
480 .get(call_id)
481 .copied()
482 .unwrap_or(true);
483 let mut states = self
484 .visible_call_states
485 .lock()
486 .unwrap_or_else(|e| e.into_inner());
487 let state = states.entry(call_id.to_string()).or_default();
488 state.push(delta, stream_publicly)
489 };
490 self.notify(
491 "session/update",
492 serde_json::json!({
493 "sessionId": session_id,
494 "update": {
495 "sessionUpdate": "call_progress",
496 "content": {
497 "call_id": call_id,
498 "delta": delta,
499 "accumulated_tokens": accumulated_tokens,
500 "visible_text": visible_text,
501 "visible_delta": visible_delta,
502 },
503 },
504 }),
505 );
506 }
507
508 pub fn send_call_end(
510 &self,
511 call_id: &str,
512 call_type: &str,
513 name: &str,
514 duration_ms: u64,
515 status: &str,
516 metadata: serde_json::Value,
517 ) {
518 let session_id = self.get_session_id();
519 let script = self.get_script_name();
520 self.visible_call_states
521 .lock()
522 .unwrap_or_else(|e| e.into_inner())
523 .remove(call_id);
524 self.visible_call_streams
525 .lock()
526 .unwrap_or_else(|e| e.into_inner())
527 .remove(call_id);
528 self.notify(
529 "session/update",
530 serde_json::json!({
531 "sessionId": session_id,
532 "update": {
533 "sessionUpdate": "call_end",
534 "content": {
535 "call_id": call_id,
536 "call_type": call_type,
537 "name": name,
538 "script": script,
539 "duration_ms": duration_ms,
540 "status": status,
541 "metadata": metadata,
542 },
543 },
544 }),
545 );
546 }
547
548 pub fn send_worker_update(
550 &self,
551 worker_id: &str,
552 worker_name: &str,
553 status: &str,
554 metadata: serde_json::Value,
555 audit: Option<&MutationSessionRecord>,
556 ) {
557 let session_id = self.get_session_id();
558 let script = self.get_script_name();
559 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
560 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
561 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
562 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
563 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
564 let lifecycle = serde_json::json!({
565 "event": status,
566 "worker_id": worker_id,
567 "worker_name": worker_name,
568 "started_at": started_at,
569 "finished_at": finished_at,
570 });
571 self.notify(
572 "session/update",
573 serde_json::json!({
574 "sessionId": session_id,
575 "update": {
576 "sessionUpdate": "worker_update",
577 "content": {
578 "worker_id": worker_id,
579 "worker_name": worker_name,
580 "status": status,
581 "script": script,
582 "started_at": started_at,
583 "finished_at": finished_at,
584 "snapshot_path": snapshot_path,
585 "run_id": run_id,
586 "run_path": run_path,
587 "lifecycle": lifecycle,
588 "audit": audit,
589 "metadata": metadata,
590 },
591 },
592 }),
593 );
594 }
595}
596
/// Converts a JSON value received from the host into a [`VmValue`].
///
/// Delegates to `crate::stdlib::json_to_vm_value`. Per this module's tests,
/// JSON strings, objects, arrays and `null` become string, `Dict`, `List` and
/// `Nil` values respectively.
pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
    crate::stdlib::json_to_vm_value(val)
}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge with fresh shared state and no stdin reader.
    fn detached_bridge() -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(AtomicBool::new(false)),
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Runs `fut` to completion on a fresh current-thread runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Collects the contents of drained messages, in order.
    fn contents(messages: Vec<QueuedUserMessage>) -> Vec<String> {
        messages.into_iter().map(|m| m.content).collect()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = HostBridge::from_parts(
            pending.clone(),
            cancelled.clone(),
            Arc::new(std::sync::Mutex::new(())),
            1,
        );
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());

        // A cancelled bridge refuses new calls before registering a waiter.
        block_on(async {
            let err = bridge
                .call("llm_call", serde_json::json!({}))
                .await
                .unwrap_err();
            let VmError::Runtime(message) = err else {
                unreachable!("expected a runtime error for a cancelled bridge");
            };
            assert_eq!(message, "Bridge: operation cancelled");
            assert!(pending.lock().await.is_empty());
        });
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let bridge = detached_bridge();
        block_on(async {
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn queued_messages_are_drained_per_checkpoint_in_order() {
        let bridge = detached_bridge();
        block_on(async {
            bridge
                .push_queued_user_message("a".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("b".to_string(), "interrupt")
                .await;
            bridge
                .push_queued_user_message("c".to_string(), "after_current_operation")
                .await;
            bridge
                .push_queued_user_message("d".to_string(), "anything_else")
                .await;

            assert_eq!(
                contents(
                    bridge
                        .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                        .await
                ),
                ["a", "c"]
            );
            assert_eq!(
                contents(
                    bridge
                        .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                        .await
                ),
                ["b"]
            );
            assert_eq!(
                contents(
                    bridge
                        .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                        .await
                ),
                ["d"]
            );
            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn queued_message_mode_parsing_accepts_aliases() {
        type Mode = QueuedUserMessageMode;
        assert_eq!(Mode::from_str("interrupt_immediate"), Mode::InterruptImmediate);
        assert_eq!(Mode::from_str("interrupt"), Mode::InterruptImmediate);
        assert_eq!(Mode::from_str("finish_step"), Mode::FinishStep);
        assert_eq!(Mode::from_str("after_current_operation"), Mode::FinishStep);
        assert_eq!(Mode::from_str("wait_for_completion"), Mode::WaitForCompletion);
        assert_eq!(Mode::from_str(""), Mode::WaitForCompletion);
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}