1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18use crate::visible_text::VisibleTextState;
19
/// How long `HostBridge::call` waits for the host to answer a request
/// before failing with a timeout error (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
22
23pub struct HostBridge {
30 next_id: AtomicU64,
31 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
33 cancelled: Arc<AtomicBool>,
35 stdout_lock: Arc<std::sync::Mutex<()>>,
37 session_id: std::sync::Mutex<String>,
39 script_name: std::sync::Mutex<String>,
41 queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
43 resume_requested: Arc<AtomicBool>,
45 visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
47 visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
49 #[cfg(test)]
50 recorded_notifications: Arc<std::sync::Mutex<Vec<serde_json::Value>>>,
51}
52
/// When a queued user message should be handed to the running agent.
///
/// Each mode is drained at exactly one [`DeliveryCheckpoint`] (see
/// `HostBridge::take_queued_user_messages_for`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Drained at [`DeliveryCheckpoint::InterruptImmediate`].
    InterruptImmediate,
    /// Drained at [`DeliveryCheckpoint::AfterCurrentOperation`].
    FinishStep,
    /// Drained at [`DeliveryCheckpoint::EndOfInteraction`]; also the fallback
    /// for unrecognised mode strings.
    WaitForCompletion,
}

/// A point in the agent loop at which queued user messages may be delivered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Delivers `QueuedUserMessageMode::InterruptImmediate` messages.
    InterruptImmediate,
    /// Delivers `QueuedUserMessageMode::FinishStep` messages.
    AfterCurrentOperation,
    /// Delivers `QueuedUserMessageMode::WaitForCompletion` messages.
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Parses the wire name of a delivery mode.
    ///
    /// `interrupt_immediate` / `interrupt` map to `InterruptImmediate`,
    /// `finish_step` / `after_current_operation` map to `FinishStep`, and every
    /// other string (matching is exact and case-sensitive) maps to
    /// `WaitForCompletion`.
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            Self::InterruptImmediate
        } else if matches!(value, "finish_step" | "after_current_operation") {
            Self::FinishStep
        } else {
            Self::WaitForCompletion
        }
    }
}

/// A user message queued for delivery to the agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueuedUserMessage {
    /// Message text.
    pub content: String,
    /// Checkpoint at which the message should be delivered.
    pub mode: QueuedUserMessageMode,
}
82
83#[allow(clippy::new_without_default)]
85impl HostBridge {
86 pub fn new() -> Self {
91 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
92 Arc::new(Mutex::new(HashMap::new()));
93 let cancelled = Arc::new(AtomicBool::new(false));
94 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
95 Arc::new(Mutex::new(VecDeque::new()));
96 let resume_requested = Arc::new(AtomicBool::new(false));
97
98 let pending_clone = pending.clone();
100 let cancelled_clone = cancelled.clone();
101 let queued_clone = queued_user_messages.clone();
102 let resume_clone = resume_requested.clone();
103 tokio::task::spawn_local(async move {
104 let stdin = tokio::io::stdin();
105 let reader = tokio::io::BufReader::new(stdin);
106 let mut lines = reader.lines();
107
108 while let Ok(Some(line)) = lines.next_line().await {
109 let line = line.trim().to_string();
110 if line.is_empty() {
111 continue;
112 }
113
114 let msg: serde_json::Value = match serde_json::from_str(&line) {
115 Ok(v) => v,
116 Err(_) => continue, };
118
119 if msg.get("id").is_none() {
121 if let Some(method) = msg["method"].as_str() {
122 if method == "cancel" {
123 cancelled_clone.store(true, Ordering::SeqCst);
124 } else if method == "agent/resume" {
125 resume_clone.store(true, Ordering::SeqCst);
126 } else if method == "user_message"
127 || method == "session/input"
128 || method == "agent/user_message"
129 {
130 let params = &msg["params"];
131 let content = params
132 .get("content")
133 .and_then(|v| v.as_str())
134 .unwrap_or("")
135 .to_string();
136 if !content.is_empty() {
137 let mode = QueuedUserMessageMode::from_str(
138 params
139 .get("mode")
140 .and_then(|v| v.as_str())
141 .unwrap_or("wait_for_completion"),
142 );
143 queued_clone
144 .lock()
145 .await
146 .push_back(QueuedUserMessage { content, mode });
147 }
148 }
149 }
150 continue;
151 }
152
153 if let Some(id) = msg["id"].as_u64() {
155 let mut pending = pending_clone.lock().await;
156 if let Some(sender) = pending.remove(&id) {
157 let _ = sender.send(msg);
158 }
159 }
160 }
161
162 let mut pending = pending_clone.lock().await;
164 pending.clear();
165 });
166
167 Self {
168 next_id: AtomicU64::new(1),
169 pending,
170 cancelled,
171 stdout_lock: Arc::new(std::sync::Mutex::new(())),
172 session_id: std::sync::Mutex::new(String::new()),
173 script_name: std::sync::Mutex::new(String::new()),
174 queued_user_messages,
175 resume_requested,
176 visible_call_states: std::sync::Mutex::new(HashMap::new()),
177 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
178 #[cfg(test)]
179 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
180 }
181 }
182
183 pub fn from_parts(
189 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
190 cancelled: Arc<AtomicBool>,
191 stdout_lock: Arc<std::sync::Mutex<()>>,
192 start_id: u64,
193 ) -> Self {
194 Self {
195 next_id: AtomicU64::new(start_id),
196 pending,
197 cancelled,
198 stdout_lock,
199 session_id: std::sync::Mutex::new(String::new()),
200 script_name: std::sync::Mutex::new(String::new()),
201 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
202 resume_requested: Arc::new(AtomicBool::new(false)),
203 visible_call_states: std::sync::Mutex::new(HashMap::new()),
204 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
205 #[cfg(test)]
206 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
207 }
208 }
209
210 pub fn set_session_id(&self, id: &str) {
212 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
213 }
214
215 pub fn set_script_name(&self, name: &str) {
217 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
218 }
219
220 fn get_script_name(&self) -> String {
222 self.script_name
223 .lock()
224 .unwrap_or_else(|e| e.into_inner())
225 .clone()
226 }
227
228 fn get_session_id(&self) -> String {
230 self.session_id
231 .lock()
232 .unwrap_or_else(|e| e.into_inner())
233 .clone()
234 }
235
236 fn write_line(&self, line: &str) -> Result<(), VmError> {
238 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
239 let mut stdout = std::io::stdout().lock();
240 stdout
241 .write_all(line.as_bytes())
242 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
243 stdout
244 .write_all(b"\n")
245 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
246 stdout
247 .flush()
248 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
249 Ok(())
250 }
251
252 pub async fn call(
255 &self,
256 method: &str,
257 params: serde_json::Value,
258 ) -> Result<serde_json::Value, VmError> {
259 if self.is_cancelled() {
260 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
261 }
262
263 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
264
265 let request = crate::jsonrpc::request(id, method, params);
266
267 let (tx, rx) = oneshot::channel();
269 {
270 let mut pending = self.pending.lock().await;
271 pending.insert(id, tx);
272 }
273
274 let line = serde_json::to_string(&request)
276 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
277 if let Err(e) = self.write_line(&line) {
278 let mut pending = self.pending.lock().await;
280 pending.remove(&id);
281 return Err(e);
282 }
283
284 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
286 Ok(Ok(msg)) => msg,
287 Ok(Err(_)) => {
288 return Err(VmError::Runtime(
290 "Bridge: host closed connection before responding".into(),
291 ));
292 }
293 Err(_) => {
294 let mut pending = self.pending.lock().await;
296 pending.remove(&id);
297 return Err(VmError::Runtime(format!(
298 "Bridge: host did not respond to '{method}' within {}s",
299 DEFAULT_TIMEOUT.as_secs()
300 )));
301 }
302 };
303
304 if let Some(error) = response.get("error") {
306 let message = error["message"].as_str().unwrap_or("Unknown host error");
307 let code = error["code"].as_i64().unwrap_or(-1);
308 if code == -32001 {
310 return Err(VmError::CategorizedError {
311 message: message.to_string(),
312 category: ErrorCategory::ToolRejected,
313 });
314 }
315 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
316 }
317
318 Ok(response["result"].clone())
319 }
320
321 pub fn notify(&self, method: &str, params: serde_json::Value) {
324 let notification = crate::jsonrpc::notification(method, params);
325 #[cfg(test)]
326 self.recorded_notifications
327 .lock()
328 .unwrap_or_else(|e| e.into_inner())
329 .push(notification.clone());
330 if let Ok(line) = serde_json::to_string(¬ification) {
331 let _ = self.write_line(&line);
332 }
333 }
334
335 #[cfg(test)]
336 pub(crate) fn recorded_notifications(&self) -> Vec<serde_json::Value> {
337 self.recorded_notifications
338 .lock()
339 .unwrap_or_else(|e| e.into_inner())
340 .clone()
341 }
342
343 pub fn is_cancelled(&self) -> bool {
345 self.cancelled.load(Ordering::SeqCst)
346 }
347
348 pub fn take_resume_signal(&self) -> bool {
349 self.resume_requested.swap(false, Ordering::SeqCst)
350 }
351
352 pub fn signal_resume(&self) {
353 self.resume_requested.store(true, Ordering::SeqCst);
354 }
355
356 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
357 self.queued_user_messages
358 .lock()
359 .await
360 .push_back(QueuedUserMessage {
361 content,
362 mode: QueuedUserMessageMode::from_str(mode),
363 });
364 }
365
366 pub async fn take_queued_user_messages(
367 &self,
368 include_interrupt_immediate: bool,
369 include_finish_step: bool,
370 include_wait_for_completion: bool,
371 ) -> Vec<QueuedUserMessage> {
372 let mut queue = self.queued_user_messages.lock().await;
373 let mut selected = Vec::new();
374 let mut retained = VecDeque::new();
375 while let Some(message) = queue.pop_front() {
376 let should_take = match message.mode {
377 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
378 QueuedUserMessageMode::FinishStep => include_finish_step,
379 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
380 };
381 if should_take {
382 selected.push(message);
383 } else {
384 retained.push_back(message);
385 }
386 }
387 *queue = retained;
388 selected
389 }
390
391 pub async fn take_queued_user_messages_for(
392 &self,
393 checkpoint: DeliveryCheckpoint,
394 ) -> Vec<QueuedUserMessage> {
395 match checkpoint {
396 DeliveryCheckpoint::InterruptImmediate => {
397 self.take_queued_user_messages(true, false, false).await
398 }
399 DeliveryCheckpoint::AfterCurrentOperation => {
400 self.take_queued_user_messages(false, true, false).await
401 }
402 DeliveryCheckpoint::EndOfInteraction => {
403 self.take_queued_user_messages(false, false, true).await
404 }
405 }
406 }
407
408 pub fn send_output(&self, text: &str) {
410 self.notify("output", serde_json::json!({"text": text}));
411 }
412
413 pub fn send_progress(
415 &self,
416 phase: &str,
417 message: &str,
418 progress: Option<i64>,
419 total: Option<i64>,
420 data: Option<serde_json::Value>,
421 ) {
422 let mut payload = serde_json::json!({"phase": phase, "message": message});
423 if let Some(p) = progress {
424 payload["progress"] = serde_json::json!(p);
425 }
426 if let Some(t) = total {
427 payload["total"] = serde_json::json!(t);
428 }
429 if let Some(d) = data {
430 payload["data"] = d;
431 }
432 self.notify("progress", payload);
433 }
434
435 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
437 let mut payload = serde_json::json!({"level": level, "message": message});
438 if let Some(f) = fields {
439 payload["fields"] = f;
440 }
441 self.notify("log", payload);
442 }
443
444 pub fn send_call_start(
447 &self,
448 call_id: &str,
449 call_type: &str,
450 name: &str,
451 metadata: serde_json::Value,
452 ) {
453 let session_id = self.get_session_id();
454 let script = self.get_script_name();
455 let stream_publicly = metadata
456 .get("stream_publicly")
457 .and_then(|value| value.as_bool())
458 .unwrap_or(true);
459 self.visible_call_streams
460 .lock()
461 .unwrap_or_else(|e| e.into_inner())
462 .insert(call_id.to_string(), stream_publicly);
463 self.notify(
464 "session/update",
465 serde_json::json!({
466 "sessionId": session_id,
467 "update": {
468 "sessionUpdate": "call_start",
469 "content": {
470 "toolCallId": call_id,
471 "call_type": call_type,
472 "name": name,
473 "script": script,
474 "metadata": metadata,
475 },
476 },
477 }),
478 );
479 }
480
481 pub fn send_call_progress(
484 &self,
485 call_id: &str,
486 delta: &str,
487 accumulated_tokens: u64,
488 user_visible: bool,
489 ) {
490 let session_id = self.get_session_id();
491 let (visible_text, visible_delta) = {
492 let stream_publicly = self
493 .visible_call_streams
494 .lock()
495 .unwrap_or_else(|e| e.into_inner())
496 .get(call_id)
497 .copied()
498 .unwrap_or(true);
499 let mut states = self
500 .visible_call_states
501 .lock()
502 .unwrap_or_else(|e| e.into_inner());
503 let state = states.entry(call_id.to_string()).or_default();
504 state.push(delta, stream_publicly)
505 };
506 self.notify(
507 "session/update",
508 serde_json::json!({
509 "sessionId": session_id,
510 "update": {
511 "sessionUpdate": "call_progress",
512 "content": {
513 "toolCallId": call_id,
514 "delta": delta,
515 "accumulated_tokens": accumulated_tokens,
516 "visible_text": visible_text,
517 "visible_delta": visible_delta,
518 "user_visible": user_visible,
519 },
520 },
521 }),
522 );
523 }
524
525 pub fn send_call_end(
527 &self,
528 call_id: &str,
529 call_type: &str,
530 name: &str,
531 duration_ms: u64,
532 status: &str,
533 metadata: serde_json::Value,
534 ) {
535 let session_id = self.get_session_id();
536 let script = self.get_script_name();
537 self.visible_call_states
538 .lock()
539 .unwrap_or_else(|e| e.into_inner())
540 .remove(call_id);
541 self.visible_call_streams
542 .lock()
543 .unwrap_or_else(|e| e.into_inner())
544 .remove(call_id);
545 self.notify(
546 "session/update",
547 serde_json::json!({
548 "sessionId": session_id,
549 "update": {
550 "sessionUpdate": "call_end",
551 "content": {
552 "toolCallId": call_id,
553 "call_type": call_type,
554 "name": name,
555 "script": script,
556 "duration_ms": duration_ms,
557 "status": status,
558 "metadata": metadata,
559 },
560 },
561 }),
562 );
563 }
564
565 pub fn send_worker_update(
567 &self,
568 worker_id: &str,
569 worker_name: &str,
570 status: &str,
571 metadata: serde_json::Value,
572 audit: Option<&MutationSessionRecord>,
573 ) {
574 let session_id = self.get_session_id();
575 let script = self.get_script_name();
576 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
577 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
578 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
579 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
580 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
581 let lifecycle = serde_json::json!({
582 "event": status,
583 "worker_id": worker_id,
584 "worker_name": worker_name,
585 "started_at": started_at,
586 "finished_at": finished_at,
587 });
588 self.notify(
589 "session/update",
590 serde_json::json!({
591 "sessionId": session_id,
592 "update": {
593 "sessionUpdate": "worker_update",
594 "content": {
595 "worker_id": worker_id,
596 "worker_name": worker_name,
597 "status": status,
598 "script": script,
599 "started_at": started_at,
600 "finished_at": finished_at,
601 "snapshot_path": snapshot_path,
602 "run_id": run_id,
603 "run_path": run_path,
604 "lifecycle": lifecycle,
605 "audit": audit,
606 "metadata": metadata,
607 },
608 },
609 }),
610 );
611 }
612}
613
614pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
616 crate::stdlib::json_to_vm_value(val)
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge that never touches stdin, sharing `cancelled` with the test.
    fn test_bridge(cancelled: Arc<AtomicBool>) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Single-threaded runtime for driving the bridge's async methods.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = crate::jsonrpc::request(
            1,
            "llm_call",
            serde_json::json!({
                "prompt": "Hello",
                "system": "Be helpful",
            }),
        );
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification =
            crate::jsonrpc::notification("output", serde_json::json!({"text": "[harn] hello\n"}));
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = crate::jsonrpc::error_response(1, -32600, "Invalid request");
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = crate::jsonrpc::response(
            1,
            serde_json::json!({
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            }),
        );
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        // Exercise the bridge itself, not just the atomic it wraps.
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = test_bridge(cancelled.clone());
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn call_fails_fast_once_cancelled() {
        let bridge = test_bridge(Arc::new(AtomicBool::new(true)));
        let result = current_thread_runtime()
            .block_on(bridge.call("llm_call", serde_json::json!({})));
        assert!(matches!(&result, Err(VmError::Runtime(msg)) if msg.contains("cancelled")));
    }

    #[test]
    fn resume_signal_is_consumed_once() {
        let bridge = test_bridge(Arc::new(AtomicBool::new(false)));
        assert!(!bridge.take_resume_signal());
        bridge.signal_resume();
        assert!(bridge.take_resume_signal());
        assert!(!bridge.take_resume_signal());
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        current_thread_runtime().block_on(async {
            let bridge = test_bridge(Arc::new(AtomicBool::new(false)));
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn queued_messages_are_routed_to_matching_checkpoint() {
        current_thread_runtime().block_on(async {
            let bridge = test_bridge(Arc::new(AtomicBool::new(false)));
            bridge
                .push_queued_user_message("later".to_string(), "some_unknown_mode")
                .await;
            bridge
                .push_queued_user_message("now".to_string(), "interrupt")
                .await;
            bridge
                .push_queued_user_message("next".to_string(), "after_current_operation")
                .await;

            let now = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                .await;
            assert_eq!(now.len(), 1);
            assert_eq!(now[0].content, "now");
            assert_eq!(now[0].mode, QueuedUserMessageMode::InterruptImmediate);

            let next = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert_eq!(next.len(), 1);
            assert_eq!(next[0].content, "next");
            assert_eq!(next[0].mode, QueuedUserMessageMode::FinishStep);

            // Unknown modes fall back to waiting for the end of the interaction.
            let later = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(later.len(), 1);
            assert_eq!(later[0].content, "later");
            assert_eq!(later[0].mode, QueuedUserMessageMode::WaitForCompletion);

            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}