1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18use crate::visible_text::VisibleTextState;
19
/// Longest time [`HostBridge::call`] waits for the host to answer one request
/// (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
22
23pub struct HostBridge {
30 next_id: AtomicU64,
31 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
33 cancelled: Arc<AtomicBool>,
35 stdout_lock: Arc<std::sync::Mutex<()>>,
37 session_id: std::sync::Mutex<String>,
39 script_name: std::sync::Mutex<String>,
41 queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
43 resume_requested: Arc<AtomicBool>,
45 visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
47 visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
49 #[cfg(test)]
50 recorded_notifications: Arc<std::sync::Mutex<Vec<serde_json::Value>>>,
51}
52
/// Delivery mode attached to a queued user message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Taken at the `InterruptImmediate` checkpoint.
    InterruptImmediate,
    /// Taken at the `AfterCurrentOperation` checkpoint.
    FinishStep,
    /// Taken at the `EndOfInteraction` checkpoint.
    WaitForCompletion,
}

/// Point at which queued messages are collected; each checkpoint selects
/// exactly one [`QueuedUserMessageMode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Selects messages queued as `InterruptImmediate`.
    InterruptImmediate,
    /// Selects messages queued as `FinishStep`.
    AfterCurrentOperation,
    /// Selects messages queued as `WaitForCompletion`.
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Parses the wire name of a delivery mode.
    ///
    /// `"interrupt_immediate"` and `"interrupt"` map to `InterruptImmediate`;
    /// `"finish_step"` and `"after_current_operation"` map to `FinishStep`.
    /// Every other string, including `"wait_for_completion"`, yields
    /// `WaitForCompletion`. Matching is exact, so case matters.
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            return Self::InterruptImmediate;
        }
        if matches!(value, "finish_step" | "after_current_operation") {
            return Self::FinishStep;
        }
        Self::WaitForCompletion
    }
}
76
77#[derive(Clone, Debug, PartialEq, Eq)]
78pub struct QueuedUserMessage {
79 pub content: String,
80 pub mode: QueuedUserMessageMode,
81}
82
83#[allow(clippy::new_without_default)]
85impl HostBridge {
86 pub fn new() -> Self {
91 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
92 Arc::new(Mutex::new(HashMap::new()));
93 let cancelled = Arc::new(AtomicBool::new(false));
94 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
95 Arc::new(Mutex::new(VecDeque::new()));
96 let resume_requested = Arc::new(AtomicBool::new(false));
97
98 let pending_clone = pending.clone();
100 let cancelled_clone = cancelled.clone();
101 let queued_clone = queued_user_messages.clone();
102 let resume_clone = resume_requested.clone();
103 tokio::task::spawn_local(async move {
104 let stdin = tokio::io::stdin();
105 let reader = tokio::io::BufReader::new(stdin);
106 let mut lines = reader.lines();
107
108 while let Ok(Some(line)) = lines.next_line().await {
109 let line = line.trim().to_string();
110 if line.is_empty() {
111 continue;
112 }
113
114 let msg: serde_json::Value = match serde_json::from_str(&line) {
115 Ok(v) => v,
116 Err(_) => continue,
117 };
118
119 if msg.get("id").is_none() {
121 if let Some(method) = msg["method"].as_str() {
122 if method == "cancel" {
123 cancelled_clone.store(true, Ordering::SeqCst);
124 } else if method == "agent/resume" {
125 resume_clone.store(true, Ordering::SeqCst);
126 } else if method == "user_message"
127 || method == "session/input"
128 || method == "agent/user_message"
129 {
130 let params = &msg["params"];
131 let content = params
132 .get("content")
133 .and_then(|v| v.as_str())
134 .unwrap_or("")
135 .to_string();
136 if !content.is_empty() {
137 let mode = QueuedUserMessageMode::from_str(
138 params
139 .get("mode")
140 .and_then(|v| v.as_str())
141 .unwrap_or("wait_for_completion"),
142 );
143 queued_clone
144 .lock()
145 .await
146 .push_back(QueuedUserMessage { content, mode });
147 }
148 }
149 }
150 continue;
151 }
152
153 if let Some(id) = msg["id"].as_u64() {
154 let mut pending = pending_clone.lock().await;
155 if let Some(sender) = pending.remove(&id) {
156 let _ = sender.send(msg);
157 }
158 }
159 }
160
161 let mut pending = pending_clone.lock().await;
163 pending.clear();
164 });
165
166 Self {
167 next_id: AtomicU64::new(1),
168 pending,
169 cancelled,
170 stdout_lock: Arc::new(std::sync::Mutex::new(())),
171 session_id: std::sync::Mutex::new(String::new()),
172 script_name: std::sync::Mutex::new(String::new()),
173 queued_user_messages,
174 resume_requested,
175 visible_call_states: std::sync::Mutex::new(HashMap::new()),
176 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
177 #[cfg(test)]
178 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
179 }
180 }
181
182 pub fn from_parts(
188 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
189 cancelled: Arc<AtomicBool>,
190 stdout_lock: Arc<std::sync::Mutex<()>>,
191 start_id: u64,
192 ) -> Self {
193 Self {
194 next_id: AtomicU64::new(start_id),
195 pending,
196 cancelled,
197 stdout_lock,
198 session_id: std::sync::Mutex::new(String::new()),
199 script_name: std::sync::Mutex::new(String::new()),
200 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
201 resume_requested: Arc::new(AtomicBool::new(false)),
202 visible_call_states: std::sync::Mutex::new(HashMap::new()),
203 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
204 #[cfg(test)]
205 recorded_notifications: Arc::new(std::sync::Mutex::new(Vec::new())),
206 }
207 }
208
209 pub fn set_session_id(&self, id: &str) {
211 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
212 }
213
214 pub fn set_script_name(&self, name: &str) {
216 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
217 }
218
219 fn get_script_name(&self) -> String {
221 self.script_name
222 .lock()
223 .unwrap_or_else(|e| e.into_inner())
224 .clone()
225 }
226
227 fn get_session_id(&self) -> String {
229 self.session_id
230 .lock()
231 .unwrap_or_else(|e| e.into_inner())
232 .clone()
233 }
234
235 fn write_line(&self, line: &str) -> Result<(), VmError> {
237 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
238 let mut stdout = std::io::stdout().lock();
239 stdout
240 .write_all(line.as_bytes())
241 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
242 stdout
243 .write_all(b"\n")
244 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
245 stdout
246 .flush()
247 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
248 Ok(())
249 }
250
251 pub async fn call(
254 &self,
255 method: &str,
256 params: serde_json::Value,
257 ) -> Result<serde_json::Value, VmError> {
258 if self.is_cancelled() {
259 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
260 }
261
262 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
263
264 let request = crate::jsonrpc::request(id, method, params);
265
266 let (tx, rx) = oneshot::channel();
267 {
268 let mut pending = self.pending.lock().await;
269 pending.insert(id, tx);
270 }
271
272 let line = serde_json::to_string(&request)
273 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
274 if let Err(e) = self.write_line(&line) {
275 let mut pending = self.pending.lock().await;
276 pending.remove(&id);
277 return Err(e);
278 }
279
280 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
281 Ok(Ok(msg)) => msg,
282 Ok(Err(_)) => {
283 return Err(VmError::Runtime(
285 "Bridge: host closed connection before responding".into(),
286 ));
287 }
288 Err(_) => {
289 let mut pending = self.pending.lock().await;
290 pending.remove(&id);
291 return Err(VmError::Runtime(format!(
292 "Bridge: host did not respond to '{method}' within {}s",
293 DEFAULT_TIMEOUT.as_secs()
294 )));
295 }
296 };
297
298 if let Some(error) = response.get("error") {
299 let message = error["message"].as_str().unwrap_or("Unknown host error");
300 let code = error["code"].as_i64().unwrap_or(-1);
301 if code == -32001 {
303 return Err(VmError::CategorizedError {
304 message: message.to_string(),
305 category: ErrorCategory::ToolRejected,
306 });
307 }
308 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
309 }
310
311 Ok(response["result"].clone())
312 }
313
314 pub fn notify(&self, method: &str, params: serde_json::Value) {
317 let notification = crate::jsonrpc::notification(method, params);
318 #[cfg(test)]
319 self.recorded_notifications
320 .lock()
321 .unwrap_or_else(|e| e.into_inner())
322 .push(notification.clone());
323 if let Ok(line) = serde_json::to_string(¬ification) {
324 let _ = self.write_line(&line);
325 }
326 }
327
/// Returns a copy of every notification passed to `notify` so far (tests only).
#[cfg(test)]
pub(crate) fn recorded_notifications(&self) -> Vec<serde_json::Value> {
    let recorded = self
        .recorded_notifications
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    (*recorded).clone()
}
335
336 pub fn is_cancelled(&self) -> bool {
338 self.cancelled.load(Ordering::SeqCst)
339 }
340
341 pub fn take_resume_signal(&self) -> bool {
342 self.resume_requested.swap(false, Ordering::SeqCst)
343 }
344
345 pub fn signal_resume(&self) {
346 self.resume_requested.store(true, Ordering::SeqCst);
347 }
348
349 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
350 self.queued_user_messages
351 .lock()
352 .await
353 .push_back(QueuedUserMessage {
354 content,
355 mode: QueuedUserMessageMode::from_str(mode),
356 });
357 }
358
359 pub async fn take_queued_user_messages(
360 &self,
361 include_interrupt_immediate: bool,
362 include_finish_step: bool,
363 include_wait_for_completion: bool,
364 ) -> Vec<QueuedUserMessage> {
365 let mut queue = self.queued_user_messages.lock().await;
366 let mut selected = Vec::new();
367 let mut retained = VecDeque::new();
368 while let Some(message) = queue.pop_front() {
369 let should_take = match message.mode {
370 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
371 QueuedUserMessageMode::FinishStep => include_finish_step,
372 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
373 };
374 if should_take {
375 selected.push(message);
376 } else {
377 retained.push_back(message);
378 }
379 }
380 *queue = retained;
381 selected
382 }
383
384 pub async fn take_queued_user_messages_for(
385 &self,
386 checkpoint: DeliveryCheckpoint,
387 ) -> Vec<QueuedUserMessage> {
388 match checkpoint {
389 DeliveryCheckpoint::InterruptImmediate => {
390 self.take_queued_user_messages(true, false, false).await
391 }
392 DeliveryCheckpoint::AfterCurrentOperation => {
393 self.take_queued_user_messages(false, true, false).await
394 }
395 DeliveryCheckpoint::EndOfInteraction => {
396 self.take_queued_user_messages(false, false, true).await
397 }
398 }
399 }
400
401 pub fn send_output(&self, text: &str) {
403 self.notify("output", serde_json::json!({"text": text}));
404 }
405
406 pub fn send_progress(
408 &self,
409 phase: &str,
410 message: &str,
411 progress: Option<i64>,
412 total: Option<i64>,
413 data: Option<serde_json::Value>,
414 ) {
415 let mut payload = serde_json::json!({"phase": phase, "message": message});
416 if let Some(p) = progress {
417 payload["progress"] = serde_json::json!(p);
418 }
419 if let Some(t) = total {
420 payload["total"] = serde_json::json!(t);
421 }
422 if let Some(d) = data {
423 payload["data"] = d;
424 }
425 self.notify("progress", payload);
426 }
427
428 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
430 let mut payload = serde_json::json!({"level": level, "message": message});
431 if let Some(f) = fields {
432 payload["fields"] = f;
433 }
434 self.notify("log", payload);
435 }
436
437 pub fn send_call_start(
440 &self,
441 call_id: &str,
442 call_type: &str,
443 name: &str,
444 metadata: serde_json::Value,
445 ) {
446 let session_id = self.get_session_id();
447 let script = self.get_script_name();
448 let stream_publicly = metadata
449 .get("stream_publicly")
450 .and_then(|value| value.as_bool())
451 .unwrap_or(true);
452 self.visible_call_streams
453 .lock()
454 .unwrap_or_else(|e| e.into_inner())
455 .insert(call_id.to_string(), stream_publicly);
456 self.notify(
457 "session/update",
458 serde_json::json!({
459 "sessionId": session_id,
460 "update": {
461 "sessionUpdate": "call_start",
462 "content": {
463 "toolCallId": call_id,
464 "call_type": call_type,
465 "name": name,
466 "script": script,
467 "metadata": metadata,
468 },
469 },
470 }),
471 );
472 }
473
474 pub fn send_call_progress(
477 &self,
478 call_id: &str,
479 delta: &str,
480 accumulated_tokens: u64,
481 user_visible: bool,
482 ) {
483 let session_id = self.get_session_id();
484 let (visible_text, visible_delta) = {
485 let stream_publicly = self
486 .visible_call_streams
487 .lock()
488 .unwrap_or_else(|e| e.into_inner())
489 .get(call_id)
490 .copied()
491 .unwrap_or(true);
492 let mut states = self
493 .visible_call_states
494 .lock()
495 .unwrap_or_else(|e| e.into_inner());
496 let state = states.entry(call_id.to_string()).or_default();
497 state.push(delta, stream_publicly)
498 };
499 self.notify(
500 "session/update",
501 serde_json::json!({
502 "sessionId": session_id,
503 "update": {
504 "sessionUpdate": "call_progress",
505 "content": {
506 "toolCallId": call_id,
507 "delta": delta,
508 "accumulated_tokens": accumulated_tokens,
509 "visible_text": visible_text,
510 "visible_delta": visible_delta,
511 "user_visible": user_visible,
512 },
513 },
514 }),
515 );
516 }
517
518 pub fn send_call_end(
520 &self,
521 call_id: &str,
522 call_type: &str,
523 name: &str,
524 duration_ms: u64,
525 status: &str,
526 metadata: serde_json::Value,
527 ) {
528 let session_id = self.get_session_id();
529 let script = self.get_script_name();
530 self.visible_call_states
531 .lock()
532 .unwrap_or_else(|e| e.into_inner())
533 .remove(call_id);
534 self.visible_call_streams
535 .lock()
536 .unwrap_or_else(|e| e.into_inner())
537 .remove(call_id);
538 self.notify(
539 "session/update",
540 serde_json::json!({
541 "sessionId": session_id,
542 "update": {
543 "sessionUpdate": "call_end",
544 "content": {
545 "toolCallId": call_id,
546 "call_type": call_type,
547 "name": name,
548 "script": script,
549 "duration_ms": duration_ms,
550 "status": status,
551 "metadata": metadata,
552 },
553 },
554 }),
555 );
556 }
557
558 pub fn send_worker_update(
560 &self,
561 worker_id: &str,
562 worker_name: &str,
563 status: &str,
564 metadata: serde_json::Value,
565 audit: Option<&MutationSessionRecord>,
566 ) {
567 let session_id = self.get_session_id();
568 let script = self.get_script_name();
569 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
570 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
571 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
572 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
573 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
574 let lifecycle = serde_json::json!({
575 "event": status,
576 "worker_id": worker_id,
577 "worker_name": worker_name,
578 "started_at": started_at,
579 "finished_at": finished_at,
580 });
581 self.notify(
582 "session/update",
583 serde_json::json!({
584 "sessionId": session_id,
585 "update": {
586 "sessionUpdate": "worker_update",
587 "content": {
588 "worker_id": worker_id,
589 "worker_name": worker_name,
590 "status": status,
591 "script": script,
592 "started_at": started_at,
593 "finished_at": finished_at,
594 "snapshot_path": snapshot_path,
595 "run_id": run_id,
596 "run_path": run_path,
597 "lifecycle": lifecycle,
598 "audit": audit,
599 "metadata": metadata,
600 },
601 },
602 }),
603 );
604 }
605}
606
607pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
609 crate::stdlib::json_to_vm_value(val)
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized request carries the JSON-RPC version, id and method.
    #[test]
    fn test_json_rpc_request_format() {
        let request = crate::jsonrpc::request(
            1,
            "llm_call",
            serde_json::json!({
                "prompt": "Hello",
                "system": "Be helpful",
            }),
        );
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    /// A serialized notification carries the method and no `id`.
    #[test]
    fn test_json_rpc_notification_format() {
        let notification =
            crate::jsonrpc::notification("output", serde_json::json!({"text": "[harn] hello\n"}));
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    /// An error response exposes its message under `error.message`.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = crate::jsonrpc::error_response(1, -32600, "Invalid request");
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    /// A success response exposes its payload under `result`.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = crate::jsonrpc::response(
            1,
            serde_json::json!({
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            }),
        );
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    /// The bridge reports the state of the shared cancellation flag.
    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled.clone(),
            Arc::new(std::sync::Mutex::new(())),
            1,
        );
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    /// Taking messages for one mode leaves messages of other modes queued.
    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = HostBridge::from_parts(
                Arc::new(Mutex::new(HashMap::new())),
                Arc::new(AtomicBool::new(false)),
                Arc::new(std::sync::Mutex::new(())),
                1,
            );
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    /// Guards the request timeout against accidental changes.
    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}