1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::orchestration::MutationSessionRecord;
17use crate::value::{ErrorCategory, VmError, VmValue};
18
/// How long `HostBridge::call` waits for the host's response before giving up
/// (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
21
/// JSON-RPC 2.0 bridge between the VM and its host process: requests and
/// notifications are written to stdout, responses and host messages are read
/// from stdin.
pub struct HostBridge {
    /// Id assigned to the next outgoing request (incremented per `call`).
    next_id: AtomicU64,
    /// In-flight requests keyed by id, each waiting for its response message.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
    /// Cancellation flag; once set, `call` refuses to send new requests.
    cancelled: Arc<AtomicBool>,
    /// Held for the duration of each line write so output lines are not interleaved.
    stdout_lock: Arc<std::sync::Mutex<()>>,
    /// Session id stamped onto every `session/update` notification.
    session_id: std::sync::Mutex<String>,
    /// Script name reported in call-start/call-end/worker updates.
    script_name: std::sync::Mutex<String>,
    /// User messages injected by the host, waiting for a delivery checkpoint.
    queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
    /// Resume request flag; consumed (reset) by `take_resume_signal`.
    resume_requested: Arc<AtomicBool>,
}
45
/// When a host-injected user message should be delivered to the running agent.
///
/// This is a fieldless enum, so it is `Copy` like [`DeliveryCheckpoint`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Parsed from `"interrupt_immediate"` / `"interrupt"`; drained at
    /// `DeliveryCheckpoint::InterruptImmediate`.
    InterruptImmediate,
    /// Parsed from `"finish_step"` / `"after_current_operation"`; drained at
    /// `DeliveryCheckpoint::AfterCurrentOperation`.
    FinishStep,
    /// Fallback for any other mode string; drained at
    /// `DeliveryCheckpoint::EndOfInteraction`.
    WaitForCompletion,
}
52
/// A point in execution at which queued user messages may be handed to the agent.
///
/// `HostBridge::take_queued_user_messages_for` drains exactly one
/// `QueuedUserMessageMode` per checkpoint.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Drains `QueuedUserMessageMode::InterruptImmediate` messages.
    InterruptImmediate,
    /// Drains `QueuedUserMessageMode::FinishStep` messages.
    AfterCurrentOperation,
    /// Drains `QueuedUserMessageMode::WaitForCompletion` messages.
    EndOfInteraction,
}
59
60impl QueuedUserMessageMode {
61 fn from_str(value: &str) -> Self {
62 match value {
63 "interrupt_immediate" | "interrupt" => Self::InterruptImmediate,
64 "finish_step" | "after_current_operation" => Self::FinishStep,
65 _ => Self::WaitForCompletion,
66 }
67 }
68}
69
/// A user message injected by the host while the agent is running.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueuedUserMessage {
    /// Message text; the stdin reader never enqueues an empty string.
    pub content: String,
    /// Which delivery checkpoint may hand this message to the agent.
    pub mode: QueuedUserMessageMode,
}
75
76#[allow(clippy::new_without_default)]
78impl HostBridge {
79 pub fn new() -> Self {
84 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
85 Arc::new(Mutex::new(HashMap::new()));
86 let cancelled = Arc::new(AtomicBool::new(false));
87 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
88 Arc::new(Mutex::new(VecDeque::new()));
89 let resume_requested = Arc::new(AtomicBool::new(false));
90
91 let pending_clone = pending.clone();
93 let cancelled_clone = cancelled.clone();
94 let queued_clone = queued_user_messages.clone();
95 let resume_clone = resume_requested.clone();
96 tokio::task::spawn_local(async move {
97 let stdin = tokio::io::stdin();
98 let reader = tokio::io::BufReader::new(stdin);
99 let mut lines = reader.lines();
100
101 while let Ok(Some(line)) = lines.next_line().await {
102 let line = line.trim().to_string();
103 if line.is_empty() {
104 continue;
105 }
106
107 let msg: serde_json::Value = match serde_json::from_str(&line) {
108 Ok(v) => v,
109 Err(_) => continue, };
111
112 if msg.get("id").is_none() {
114 if let Some(method) = msg["method"].as_str() {
115 if method == "cancel" {
116 cancelled_clone.store(true, Ordering::SeqCst);
117 } else if method == "agent/resume" {
118 resume_clone.store(true, Ordering::SeqCst);
119 } else if method == "user_message"
120 || method == "session/input"
121 || method == "agent/user_message"
122 {
123 let params = &msg["params"];
124 let content = params
125 .get("content")
126 .and_then(|v| v.as_str())
127 .unwrap_or("")
128 .to_string();
129 if !content.is_empty() {
130 let mode = QueuedUserMessageMode::from_str(
131 params
132 .get("mode")
133 .and_then(|v| v.as_str())
134 .unwrap_or("wait_for_completion"),
135 );
136 queued_clone
137 .lock()
138 .await
139 .push_back(QueuedUserMessage { content, mode });
140 }
141 }
142 }
143 continue;
144 }
145
146 if let Some(id) = msg["id"].as_u64() {
148 let mut pending = pending_clone.lock().await;
149 if let Some(sender) = pending.remove(&id) {
150 let _ = sender.send(msg);
151 }
152 }
153 }
154
155 let mut pending = pending_clone.lock().await;
157 pending.clear();
158 });
159
160 Self {
161 next_id: AtomicU64::new(1),
162 pending,
163 cancelled,
164 stdout_lock: Arc::new(std::sync::Mutex::new(())),
165 session_id: std::sync::Mutex::new(String::new()),
166 script_name: std::sync::Mutex::new(String::new()),
167 queued_user_messages,
168 resume_requested,
169 }
170 }
171
172 pub fn from_parts(
178 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
179 cancelled: Arc<AtomicBool>,
180 stdout_lock: Arc<std::sync::Mutex<()>>,
181 start_id: u64,
182 ) -> Self {
183 Self {
184 next_id: AtomicU64::new(start_id),
185 pending,
186 cancelled,
187 stdout_lock,
188 session_id: std::sync::Mutex::new(String::new()),
189 script_name: std::sync::Mutex::new(String::new()),
190 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
191 resume_requested: Arc::new(AtomicBool::new(false)),
192 }
193 }
194
195 pub fn set_session_id(&self, id: &str) {
197 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
198 }
199
200 pub fn set_script_name(&self, name: &str) {
202 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
203 }
204
205 fn get_script_name(&self) -> String {
207 self.script_name
208 .lock()
209 .unwrap_or_else(|e| e.into_inner())
210 .clone()
211 }
212
213 fn get_session_id(&self) -> String {
215 self.session_id
216 .lock()
217 .unwrap_or_else(|e| e.into_inner())
218 .clone()
219 }
220
221 fn write_line(&self, line: &str) -> Result<(), VmError> {
223 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
224 let mut stdout = std::io::stdout().lock();
225 stdout
226 .write_all(line.as_bytes())
227 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
228 stdout
229 .write_all(b"\n")
230 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
231 stdout
232 .flush()
233 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
234 Ok(())
235 }
236
237 pub async fn call(
240 &self,
241 method: &str,
242 params: serde_json::Value,
243 ) -> Result<serde_json::Value, VmError> {
244 if self.is_cancelled() {
245 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
246 }
247
248 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
249
250 let request = serde_json::json!({
251 "jsonrpc": "2.0",
252 "id": id,
253 "method": method,
254 "params": params,
255 });
256
257 let (tx, rx) = oneshot::channel();
259 {
260 let mut pending = self.pending.lock().await;
261 pending.insert(id, tx);
262 }
263
264 let line = serde_json::to_string(&request)
266 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
267 if let Err(e) = self.write_line(&line) {
268 let mut pending = self.pending.lock().await;
270 pending.remove(&id);
271 return Err(e);
272 }
273
274 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
276 Ok(Ok(msg)) => msg,
277 Ok(Err(_)) => {
278 return Err(VmError::Runtime(
280 "Bridge: host closed connection before responding".into(),
281 ));
282 }
283 Err(_) => {
284 let mut pending = self.pending.lock().await;
286 pending.remove(&id);
287 return Err(VmError::Runtime(format!(
288 "Bridge: host did not respond to '{method}' within {}s",
289 DEFAULT_TIMEOUT.as_secs()
290 )));
291 }
292 };
293
294 if let Some(error) = response.get("error") {
296 let message = error["message"].as_str().unwrap_or("Unknown host error");
297 let code = error["code"].as_i64().unwrap_or(-1);
298 if code == -32001 {
300 return Err(VmError::CategorizedError {
301 message: message.to_string(),
302 category: ErrorCategory::ToolRejected,
303 });
304 }
305 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
306 }
307
308 Ok(response["result"].clone())
309 }
310
311 pub fn notify(&self, method: &str, params: serde_json::Value) {
314 let notification = serde_json::json!({
315 "jsonrpc": "2.0",
316 "method": method,
317 "params": params,
318 });
319 if let Ok(line) = serde_json::to_string(¬ification) {
320 let _ = self.write_line(&line);
321 }
322 }
323
324 pub fn is_cancelled(&self) -> bool {
326 self.cancelled.load(Ordering::SeqCst)
327 }
328
329 pub fn take_resume_signal(&self) -> bool {
330 self.resume_requested.swap(false, Ordering::SeqCst)
331 }
332
333 pub fn signal_resume(&self) {
334 self.resume_requested.store(true, Ordering::SeqCst);
335 }
336
337 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
338 self.queued_user_messages
339 .lock()
340 .await
341 .push_back(QueuedUserMessage {
342 content,
343 mode: QueuedUserMessageMode::from_str(mode),
344 });
345 }
346
347 pub async fn take_queued_user_messages(
348 &self,
349 include_interrupt_immediate: bool,
350 include_finish_step: bool,
351 include_wait_for_completion: bool,
352 ) -> Vec<QueuedUserMessage> {
353 let mut queue = self.queued_user_messages.lock().await;
354 let mut selected = Vec::new();
355 let mut retained = VecDeque::new();
356 while let Some(message) = queue.pop_front() {
357 let should_take = match message.mode {
358 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
359 QueuedUserMessageMode::FinishStep => include_finish_step,
360 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
361 };
362 if should_take {
363 selected.push(message);
364 } else {
365 retained.push_back(message);
366 }
367 }
368 *queue = retained;
369 selected
370 }
371
372 pub async fn take_queued_user_messages_for(
373 &self,
374 checkpoint: DeliveryCheckpoint,
375 ) -> Vec<QueuedUserMessage> {
376 match checkpoint {
377 DeliveryCheckpoint::InterruptImmediate => {
378 self.take_queued_user_messages(true, false, false).await
379 }
380 DeliveryCheckpoint::AfterCurrentOperation => {
381 self.take_queued_user_messages(false, true, false).await
382 }
383 DeliveryCheckpoint::EndOfInteraction => {
384 self.take_queued_user_messages(false, false, true).await
385 }
386 }
387 }
388
389 pub fn send_output(&self, text: &str) {
391 self.notify("output", serde_json::json!({"text": text}));
392 }
393
394 pub fn send_progress(
396 &self,
397 phase: &str,
398 message: &str,
399 progress: Option<i64>,
400 total: Option<i64>,
401 data: Option<serde_json::Value>,
402 ) {
403 let mut payload = serde_json::json!({"phase": phase, "message": message});
404 if let Some(p) = progress {
405 payload["progress"] = serde_json::json!(p);
406 }
407 if let Some(t) = total {
408 payload["total"] = serde_json::json!(t);
409 }
410 if let Some(d) = data {
411 payload["data"] = d;
412 }
413 self.notify("progress", payload);
414 }
415
416 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
418 let mut payload = serde_json::json!({"level": level, "message": message});
419 if let Some(f) = fields {
420 payload["fields"] = f;
421 }
422 self.notify("log", payload);
423 }
424
425 pub fn send_call_start(
428 &self,
429 call_id: &str,
430 call_type: &str,
431 name: &str,
432 metadata: serde_json::Value,
433 ) {
434 let session_id = self.get_session_id();
435 let script = self.get_script_name();
436 self.notify(
437 "session/update",
438 serde_json::json!({
439 "sessionId": session_id,
440 "update": {
441 "sessionUpdate": "call_start",
442 "content": {
443 "call_id": call_id,
444 "call_type": call_type,
445 "name": name,
446 "script": script,
447 "metadata": metadata,
448 },
449 },
450 }),
451 );
452 }
453
454 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
457 let session_id = self.get_session_id();
458 self.notify(
459 "session/update",
460 serde_json::json!({
461 "sessionId": session_id,
462 "update": {
463 "sessionUpdate": "call_progress",
464 "content": {
465 "call_id": call_id,
466 "delta": delta,
467 "accumulated_tokens": accumulated_tokens,
468 },
469 },
470 }),
471 );
472 }
473
474 pub fn send_call_end(
476 &self,
477 call_id: &str,
478 call_type: &str,
479 name: &str,
480 duration_ms: u64,
481 status: &str,
482 metadata: serde_json::Value,
483 ) {
484 let session_id = self.get_session_id();
485 let script = self.get_script_name();
486 self.notify(
487 "session/update",
488 serde_json::json!({
489 "sessionId": session_id,
490 "update": {
491 "sessionUpdate": "call_end",
492 "content": {
493 "call_id": call_id,
494 "call_type": call_type,
495 "name": name,
496 "script": script,
497 "duration_ms": duration_ms,
498 "status": status,
499 "metadata": metadata,
500 },
501 },
502 }),
503 );
504 }
505
506 pub fn send_worker_update(
508 &self,
509 worker_id: &str,
510 worker_name: &str,
511 status: &str,
512 metadata: serde_json::Value,
513 audit: Option<&MutationSessionRecord>,
514 ) {
515 let session_id = self.get_session_id();
516 let script = self.get_script_name();
517 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
518 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
519 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
520 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
521 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
522 let lifecycle = serde_json::json!({
523 "event": status,
524 "worker_id": worker_id,
525 "worker_name": worker_name,
526 "started_at": started_at,
527 "finished_at": finished_at,
528 });
529 self.notify(
530 "session/update",
531 serde_json::json!({
532 "sessionId": session_id,
533 "update": {
534 "sessionUpdate": "worker_update",
535 "content": {
536 "worker_id": worker_id,
537 "worker_name": worker_name,
538 "status": status,
539 "script": script,
540 "started_at": started_at,
541 "finished_at": finished_at,
542 "snapshot_path": snapshot_path,
543 "run_id": run_id,
544 "run_path": run_path,
545 "lifecycle": lifecycle,
546 "audit": audit,
547 "metadata": metadata,
548 },
549 },
550 }),
551 );
552 }
553}
554
555pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
557 crate::stdlib::json_to_vm_value(val)
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Bridge with fresh shared state and no stdin reader (safe outside a LocalSet).
    fn test_bridge() -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(AtomicBool::new(false)),
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Runs `future` to completion on a single-threaded tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    /// Extracts message contents, in order, for compact assertions.
    fn contents(messages: Vec<QueuedUserMessage>) -> Vec<String> {
        messages.into_iter().map(|m| m.content).collect()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::clone(&cancelled),
            Arc::new(std::sync::Mutex::new(())),
            1,
        );
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn cancelled_bridge_rejects_calls_without_io() {
        let bridge = test_bridge();
        bridge.cancelled.store(true, Ordering::SeqCst);
        let result = block_on(bridge.call("llm_call", serde_json::json!({})));
        assert!(matches!(result, Err(VmError::Runtime(_))));
    }

    #[test]
    fn resume_signal_is_consumed_once() {
        let bridge = test_bridge();
        assert!(!bridge.take_resume_signal());
        bridge.signal_resume();
        assert!(bridge.take_resume_signal());
        assert!(!bridge.take_resume_signal());
    }

    #[test]
    fn queued_message_mode_parses_aliases() {
        use QueuedUserMessageMode as Mode;
        assert_eq!(Mode::from_str("interrupt_immediate"), Mode::InterruptImmediate);
        assert_eq!(Mode::from_str("interrupt"), Mode::InterruptImmediate);
        assert_eq!(Mode::from_str("finish_step"), Mode::FinishStep);
        assert_eq!(Mode::from_str("after_current_operation"), Mode::FinishStep);
        assert_eq!(Mode::from_str("wait_for_completion"), Mode::WaitForCompletion);
        assert_eq!(Mode::from_str("something_else"), Mode::WaitForCompletion);
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn checkpoints_drain_only_their_mode_in_arrival_order() {
        block_on(async {
            let bridge = test_bridge();
            for (content, mode) in [
                ("a", "interrupt"),
                ("b", "wait_for_completion"),
                ("c", "interrupt_immediate"),
            ] {
                bridge
                    .push_queued_user_message(content.to_string(), mode)
                    .await;
            }

            let interrupts = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                .await;
            assert_eq!(contents(interrupts), ["a", "c"]);

            let after_op = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert!(after_op.is_empty());

            let at_end = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(contents(at_end), ["b"]);

            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}