1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// Maximum time `HostBridge::call` waits for the host to answer a request (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35 session_id: std::sync::Mutex<String>,
37 script_name: std::sync::Mutex<String>,
39 queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
41 resume_requested: Arc<AtomicBool>,
43}
44
/// Delivery mode attached to a queued user message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Taken at [`DeliveryCheckpoint::InterruptImmediate`].
    InterruptImmediate,
    /// Taken at [`DeliveryCheckpoint::AfterCurrentOperation`].
    FinishStep,
    /// Taken at [`DeliveryCheckpoint::EndOfInteraction`].
    WaitForCompletion,
}

impl QueuedUserMessageMode {
    /// Parses the wire name of a delivery mode.
    ///
    /// `"interrupt_immediate"` / `"interrupt"` and `"finish_step"` /
    /// `"after_current_operation"` map to their variants; every other string
    /// (including the empty string) falls back to [`Self::WaitForCompletion`].
    fn from_str(value: &str) -> Self {
        if matches!(value, "interrupt_immediate" | "interrupt") {
            return Self::InterruptImmediate;
        }
        if matches!(value, "finish_step" | "after_current_operation") {
            return Self::FinishStep;
        }
        Self::WaitForCompletion
    }
}

/// Point at which queued user messages are collected; each checkpoint selects exactly one
/// [`QueuedUserMessageMode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Collects messages queued as [`QueuedUserMessageMode::InterruptImmediate`].
    InterruptImmediate,
    /// Collects messages queued as [`QueuedUserMessageMode::FinishStep`].
    AfterCurrentOperation,
    /// Collects messages queued as [`QueuedUserMessageMode::WaitForCompletion`].
    EndOfInteraction,
}
68
69#[derive(Clone, Debug, PartialEq, Eq)]
70pub struct QueuedUserMessage {
71 pub content: String,
72 pub mode: QueuedUserMessageMode,
73}
74
75#[allow(clippy::new_without_default)]
77impl HostBridge {
78 pub fn new() -> Self {
83 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
84 Arc::new(Mutex::new(HashMap::new()));
85 let cancelled = Arc::new(AtomicBool::new(false));
86 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
87 Arc::new(Mutex::new(VecDeque::new()));
88 let resume_requested = Arc::new(AtomicBool::new(false));
89
90 let pending_clone = pending.clone();
92 let cancelled_clone = cancelled.clone();
93 let queued_clone = queued_user_messages.clone();
94 let resume_clone = resume_requested.clone();
95 tokio::task::spawn_local(async move {
96 let stdin = tokio::io::stdin();
97 let reader = tokio::io::BufReader::new(stdin);
98 let mut lines = reader.lines();
99
100 while let Ok(Some(line)) = lines.next_line().await {
101 let line = line.trim().to_string();
102 if line.is_empty() {
103 continue;
104 }
105
106 let msg: serde_json::Value = match serde_json::from_str(&line) {
107 Ok(v) => v,
108 Err(_) => continue, };
110
111 if msg.get("id").is_none() {
113 if let Some(method) = msg["method"].as_str() {
114 if method == "cancel" {
115 cancelled_clone.store(true, Ordering::SeqCst);
116 } else if method == "agent/resume" {
117 resume_clone.store(true, Ordering::SeqCst);
118 } else if method == "user_message"
119 || method == "session/input"
120 || method == "agent/user_message"
121 {
122 let params = &msg["params"];
123 let content = params
124 .get("content")
125 .and_then(|v| v.as_str())
126 .unwrap_or("")
127 .to_string();
128 if !content.is_empty() {
129 let mode = QueuedUserMessageMode::from_str(
130 params
131 .get("mode")
132 .and_then(|v| v.as_str())
133 .unwrap_or("wait_for_completion"),
134 );
135 queued_clone
136 .lock()
137 .await
138 .push_back(QueuedUserMessage { content, mode });
139 }
140 }
141 }
142 continue;
143 }
144
145 if let Some(id) = msg["id"].as_u64() {
147 let mut pending = pending_clone.lock().await;
148 if let Some(sender) = pending.remove(&id) {
149 let _ = sender.send(msg);
150 }
151 }
152 }
153
154 let mut pending = pending_clone.lock().await;
156 pending.clear();
157 });
158
159 Self {
160 next_id: AtomicU64::new(1),
161 pending,
162 cancelled,
163 stdout_lock: Arc::new(std::sync::Mutex::new(())),
164 session_id: std::sync::Mutex::new(String::new()),
165 script_name: std::sync::Mutex::new(String::new()),
166 queued_user_messages,
167 resume_requested,
168 }
169 }
170
171 pub fn from_parts(
177 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
178 cancelled: Arc<AtomicBool>,
179 stdout_lock: Arc<std::sync::Mutex<()>>,
180 start_id: u64,
181 ) -> Self {
182 Self {
183 next_id: AtomicU64::new(start_id),
184 pending,
185 cancelled,
186 stdout_lock,
187 session_id: std::sync::Mutex::new(String::new()),
188 script_name: std::sync::Mutex::new(String::new()),
189 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
190 resume_requested: Arc::new(AtomicBool::new(false)),
191 }
192 }
193
194 pub fn set_session_id(&self, id: &str) {
196 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
197 }
198
199 pub fn set_script_name(&self, name: &str) {
201 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
202 }
203
204 fn get_script_name(&self) -> String {
206 self.script_name
207 .lock()
208 .unwrap_or_else(|e| e.into_inner())
209 .clone()
210 }
211
212 fn get_session_id(&self) -> String {
214 self.session_id
215 .lock()
216 .unwrap_or_else(|e| e.into_inner())
217 .clone()
218 }
219
220 fn write_line(&self, line: &str) -> Result<(), VmError> {
222 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
223 let mut stdout = std::io::stdout().lock();
224 stdout
225 .write_all(line.as_bytes())
226 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
227 stdout
228 .write_all(b"\n")
229 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
230 stdout
231 .flush()
232 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
233 Ok(())
234 }
235
236 pub async fn call(
239 &self,
240 method: &str,
241 params: serde_json::Value,
242 ) -> Result<serde_json::Value, VmError> {
243 if self.is_cancelled() {
244 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
245 }
246
247 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
248
249 let request = serde_json::json!({
250 "jsonrpc": "2.0",
251 "id": id,
252 "method": method,
253 "params": params,
254 });
255
256 let (tx, rx) = oneshot::channel();
258 {
259 let mut pending = self.pending.lock().await;
260 pending.insert(id, tx);
261 }
262
263 let line = serde_json::to_string(&request)
265 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
266 if let Err(e) = self.write_line(&line) {
267 let mut pending = self.pending.lock().await;
269 pending.remove(&id);
270 return Err(e);
271 }
272
273 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
275 Ok(Ok(msg)) => msg,
276 Ok(Err(_)) => {
277 return Err(VmError::Runtime(
279 "Bridge: host closed connection before responding".into(),
280 ));
281 }
282 Err(_) => {
283 let mut pending = self.pending.lock().await;
285 pending.remove(&id);
286 return Err(VmError::Runtime(format!(
287 "Bridge: host did not respond to '{method}' within {}s",
288 DEFAULT_TIMEOUT.as_secs()
289 )));
290 }
291 };
292
293 if let Some(error) = response.get("error") {
295 let message = error["message"].as_str().unwrap_or("Unknown host error");
296 let code = error["code"].as_i64().unwrap_or(-1);
297 if code == -32001 {
299 return Err(VmError::CategorizedError {
300 message: message.to_string(),
301 category: ErrorCategory::ToolRejected,
302 });
303 }
304 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
305 }
306
307 Ok(response["result"].clone())
308 }
309
310 pub fn notify(&self, method: &str, params: serde_json::Value) {
313 let notification = serde_json::json!({
314 "jsonrpc": "2.0",
315 "method": method,
316 "params": params,
317 });
318 if let Ok(line) = serde_json::to_string(¬ification) {
319 let _ = self.write_line(&line);
320 }
321 }
322
323 pub fn is_cancelled(&self) -> bool {
325 self.cancelled.load(Ordering::SeqCst)
326 }
327
328 pub fn take_resume_signal(&self) -> bool {
329 self.resume_requested.swap(false, Ordering::SeqCst)
330 }
331
332 pub fn signal_resume(&self) {
333 self.resume_requested.store(true, Ordering::SeqCst);
334 }
335
336 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
337 self.queued_user_messages
338 .lock()
339 .await
340 .push_back(QueuedUserMessage {
341 content,
342 mode: QueuedUserMessageMode::from_str(mode),
343 });
344 }
345
346 pub async fn take_queued_user_messages(
347 &self,
348 include_interrupt_immediate: bool,
349 include_finish_step: bool,
350 include_wait_for_completion: bool,
351 ) -> Vec<QueuedUserMessage> {
352 let mut queue = self.queued_user_messages.lock().await;
353 let mut selected = Vec::new();
354 let mut retained = VecDeque::new();
355 while let Some(message) = queue.pop_front() {
356 let should_take = match message.mode {
357 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
358 QueuedUserMessageMode::FinishStep => include_finish_step,
359 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
360 };
361 if should_take {
362 selected.push(message);
363 } else {
364 retained.push_back(message);
365 }
366 }
367 *queue = retained;
368 selected
369 }
370
371 pub async fn take_queued_user_messages_for(
372 &self,
373 checkpoint: DeliveryCheckpoint,
374 ) -> Vec<QueuedUserMessage> {
375 match checkpoint {
376 DeliveryCheckpoint::InterruptImmediate => {
377 self.take_queued_user_messages(true, false, false).await
378 }
379 DeliveryCheckpoint::AfterCurrentOperation => {
380 self.take_queued_user_messages(false, true, false).await
381 }
382 DeliveryCheckpoint::EndOfInteraction => {
383 self.take_queued_user_messages(false, false, true).await
384 }
385 }
386 }
387
388 pub fn send_output(&self, text: &str) {
390 self.notify("output", serde_json::json!({"text": text}));
391 }
392
393 pub fn send_progress(
395 &self,
396 phase: &str,
397 message: &str,
398 progress: Option<i64>,
399 total: Option<i64>,
400 data: Option<serde_json::Value>,
401 ) {
402 let mut payload = serde_json::json!({"phase": phase, "message": message});
403 if let Some(p) = progress {
404 payload["progress"] = serde_json::json!(p);
405 }
406 if let Some(t) = total {
407 payload["total"] = serde_json::json!(t);
408 }
409 if let Some(d) = data {
410 payload["data"] = d;
411 }
412 self.notify("progress", payload);
413 }
414
415 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
417 let mut payload = serde_json::json!({"level": level, "message": message});
418 if let Some(f) = fields {
419 payload["fields"] = f;
420 }
421 self.notify("log", payload);
422 }
423
424 pub fn send_call_start(
427 &self,
428 call_id: &str,
429 call_type: &str,
430 name: &str,
431 metadata: serde_json::Value,
432 ) {
433 let session_id = self.get_session_id();
434 let script = self.get_script_name();
435 self.notify(
436 "session/update",
437 serde_json::json!({
438 "sessionId": session_id,
439 "update": {
440 "sessionUpdate": "call_start",
441 "content": {
442 "call_id": call_id,
443 "call_type": call_type,
444 "name": name,
445 "script": script,
446 "metadata": metadata,
447 },
448 },
449 }),
450 );
451 }
452
453 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
456 let session_id = self.get_session_id();
457 self.notify(
458 "session/update",
459 serde_json::json!({
460 "sessionId": session_id,
461 "update": {
462 "sessionUpdate": "call_progress",
463 "content": {
464 "call_id": call_id,
465 "delta": delta,
466 "accumulated_tokens": accumulated_tokens,
467 },
468 },
469 }),
470 );
471 }
472
473 pub fn send_call_end(
475 &self,
476 call_id: &str,
477 call_type: &str,
478 name: &str,
479 duration_ms: u64,
480 status: &str,
481 metadata: serde_json::Value,
482 ) {
483 let session_id = self.get_session_id();
484 let script = self.get_script_name();
485 self.notify(
486 "session/update",
487 serde_json::json!({
488 "sessionId": session_id,
489 "update": {
490 "sessionUpdate": "call_end",
491 "content": {
492 "call_id": call_id,
493 "call_type": call_type,
494 "name": name,
495 "script": script,
496 "duration_ms": duration_ms,
497 "status": status,
498 "metadata": metadata,
499 },
500 },
501 }),
502 );
503 }
504
505 pub fn send_worker_update(
507 &self,
508 worker_id: &str,
509 worker_name: &str,
510 status: &str,
511 metadata: serde_json::Value,
512 ) {
513 let session_id = self.get_session_id();
514 let script = self.get_script_name();
515 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
516 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
517 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
518 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
519 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
520 self.notify(
521 "session/update",
522 serde_json::json!({
523 "sessionId": session_id,
524 "update": {
525 "sessionUpdate": "worker_update",
526 "content": {
527 "worker_id": worker_id,
528 "worker_name": worker_name,
529 "status": status,
530 "script": script,
531 "started_at": started_at,
532 "finished_at": finished_at,
533 "snapshot_path": snapshot_path,
534 "run_id": run_id,
535 "run_path": run_path,
536 "metadata": metadata,
537 },
538 },
539 }),
540 );
541 }
542}
543
544pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
546 crate::stdlib::json_to_vm_value(val)
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// A request serializes with the protocol version, id and method.
    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    /// A notification serializes without an `id` member.
    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    /// The error member and its message can be read from an error response.
    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    /// The result member and its fields can be read from a success response.
    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    /// A shared atomic flag starts clear and reads back as set after a store.
    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        assert!(!cancelled.load(Ordering::SeqCst));
        cancelled.store(true, Ordering::SeqCst);
        assert!(cancelled.load(Ordering::SeqCst));
    }

    /// Taking messages for one delivery mode leaves messages of other modes queued.
    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = HostBridge::from_parts(
                Arc::new(Mutex::new(HashMap::new())),
                Arc::new(AtomicBool::new(false)),
                Arc::new(std::sync::Mutex::new(())),
                1,
            );
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        // A mismatch here is a test failure, not an impossible state, so panic explicitly.
        let VmValue::Dict(d) = &vm_val else {
            panic!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            panic!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            panic!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    /// The request timeout is five minutes.
    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}