1use std::collections::{HashMap, VecDeque};
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// How long `HostBridge::call` waits for the host's response before failing
/// with a timeout error: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35 session_id: std::sync::Mutex<String>,
37 script_name: std::sync::Mutex<String>,
39 queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>>,
41}
42
/// When a queued user message should be handed to the running script.
///
/// Parsed from the host-supplied `mode` string by
/// `QueuedUserMessageMode::from_str`. This is a plain fieldless tag, so it is
/// `Copy` (matching `DeliveryCheckpoint`) and `Hash` for use as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum QueuedUserMessageMode {
    /// Released at the `DeliveryCheckpoint::InterruptImmediate` checkpoint.
    InterruptImmediate,
    /// Released at the `DeliveryCheckpoint::AfterCurrentOperation` checkpoint.
    FinishStep,
    /// Released at the `DeliveryCheckpoint::EndOfInteraction` checkpoint.
    /// This is also the fallback for unrecognised mode strings.
    WaitForCompletion,
}
49
/// A point during execution at which queued user messages may be delivered.
///
/// Passed to `HostBridge::take_queued_user_messages_for`, which releases only
/// the messages whose `QueuedUserMessageMode` matches the checkpoint.
#[derive(Clone, Copy)]
#[derive(Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    /// Releases `QueuedUserMessageMode::InterruptImmediate` messages.
    InterruptImmediate,
    /// Releases `QueuedUserMessageMode::FinishStep` messages.
    AfterCurrentOperation,
    /// Releases `QueuedUserMessageMode::WaitForCompletion` messages.
    EndOfInteraction,
}
56
57impl QueuedUserMessageMode {
58 fn from_str(value: &str) -> Self {
59 match value {
60 "interrupt_immediate" | "interrupt" => Self::InterruptImmediate,
61 "finish_step" | "after_current_operation" => Self::FinishStep,
62 _ => Self::WaitForCompletion,
63 }
64 }
65}
66
67#[derive(Clone, Debug, PartialEq, Eq)]
68pub struct QueuedUserMessage {
69 pub content: String,
70 pub mode: QueuedUserMessageMode,
71}
72
73#[allow(clippy::new_without_default)]
75impl HostBridge {
76 pub fn new() -> Self {
81 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
82 Arc::new(Mutex::new(HashMap::new()));
83 let cancelled = Arc::new(AtomicBool::new(false));
84 let queued_user_messages: Arc<Mutex<VecDeque<QueuedUserMessage>>> =
85 Arc::new(Mutex::new(VecDeque::new()));
86
87 let pending_clone = pending.clone();
89 let cancelled_clone = cancelled.clone();
90 let queued_clone = queued_user_messages.clone();
91 tokio::task::spawn_local(async move {
92 let stdin = tokio::io::stdin();
93 let reader = tokio::io::BufReader::new(stdin);
94 let mut lines = reader.lines();
95
96 while let Ok(Some(line)) = lines.next_line().await {
97 let line = line.trim().to_string();
98 if line.is_empty() {
99 continue;
100 }
101
102 let msg: serde_json::Value = match serde_json::from_str(&line) {
103 Ok(v) => v,
104 Err(_) => continue, };
106
107 if msg.get("id").is_none() {
109 if let Some(method) = msg["method"].as_str() {
110 if method == "cancel" {
111 cancelled_clone.store(true, Ordering::SeqCst);
112 } else if method == "user_message"
113 || method == "session/input"
114 || method == "agent/user_message"
115 {
116 let params = &msg["params"];
117 let content = params
118 .get("content")
119 .and_then(|v| v.as_str())
120 .unwrap_or("")
121 .to_string();
122 if !content.is_empty() {
123 let mode = QueuedUserMessageMode::from_str(
124 params
125 .get("mode")
126 .and_then(|v| v.as_str())
127 .unwrap_or("wait_for_completion"),
128 );
129 queued_clone
130 .lock()
131 .await
132 .push_back(QueuedUserMessage { content, mode });
133 }
134 }
135 }
136 continue;
137 }
138
139 if let Some(id) = msg["id"].as_u64() {
141 let mut pending = pending_clone.lock().await;
142 if let Some(sender) = pending.remove(&id) {
143 let _ = sender.send(msg);
144 }
145 }
146 }
147
148 let mut pending = pending_clone.lock().await;
150 pending.clear();
151 });
152
153 Self {
154 next_id: AtomicU64::new(1),
155 pending,
156 cancelled,
157 stdout_lock: Arc::new(std::sync::Mutex::new(())),
158 session_id: std::sync::Mutex::new(String::new()),
159 script_name: std::sync::Mutex::new(String::new()),
160 queued_user_messages,
161 }
162 }
163
164 pub fn from_parts(
170 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
171 cancelled: Arc<AtomicBool>,
172 stdout_lock: Arc<std::sync::Mutex<()>>,
173 start_id: u64,
174 ) -> Self {
175 Self {
176 next_id: AtomicU64::new(start_id),
177 pending,
178 cancelled,
179 stdout_lock,
180 session_id: std::sync::Mutex::new(String::new()),
181 script_name: std::sync::Mutex::new(String::new()),
182 queued_user_messages: Arc::new(Mutex::new(VecDeque::new())),
183 }
184 }
185
186 pub fn set_session_id(&self, id: &str) {
188 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
189 }
190
191 pub fn set_script_name(&self, name: &str) {
193 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
194 }
195
196 fn get_script_name(&self) -> String {
198 self.script_name
199 .lock()
200 .unwrap_or_else(|e| e.into_inner())
201 .clone()
202 }
203
204 fn get_session_id(&self) -> String {
206 self.session_id
207 .lock()
208 .unwrap_or_else(|e| e.into_inner())
209 .clone()
210 }
211
212 fn write_line(&self, line: &str) -> Result<(), VmError> {
214 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
215 let mut stdout = std::io::stdout().lock();
216 stdout
217 .write_all(line.as_bytes())
218 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
219 stdout
220 .write_all(b"\n")
221 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
222 stdout
223 .flush()
224 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
225 Ok(())
226 }
227
228 pub async fn call(
231 &self,
232 method: &str,
233 params: serde_json::Value,
234 ) -> Result<serde_json::Value, VmError> {
235 if self.is_cancelled() {
236 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
237 }
238
239 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
240
241 let request = serde_json::json!({
242 "jsonrpc": "2.0",
243 "id": id,
244 "method": method,
245 "params": params,
246 });
247
248 let (tx, rx) = oneshot::channel();
250 {
251 let mut pending = self.pending.lock().await;
252 pending.insert(id, tx);
253 }
254
255 let line = serde_json::to_string(&request)
257 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
258 if let Err(e) = self.write_line(&line) {
259 let mut pending = self.pending.lock().await;
261 pending.remove(&id);
262 return Err(e);
263 }
264
265 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
267 Ok(Ok(msg)) => msg,
268 Ok(Err(_)) => {
269 return Err(VmError::Runtime(
271 "Bridge: host closed connection before responding".into(),
272 ));
273 }
274 Err(_) => {
275 let mut pending = self.pending.lock().await;
277 pending.remove(&id);
278 return Err(VmError::Runtime(format!(
279 "Bridge: host did not respond to '{method}' within {}s",
280 DEFAULT_TIMEOUT.as_secs()
281 )));
282 }
283 };
284
285 if let Some(error) = response.get("error") {
287 let message = error["message"].as_str().unwrap_or("Unknown host error");
288 let code = error["code"].as_i64().unwrap_or(-1);
289 if code == -32001 {
291 return Err(VmError::CategorizedError {
292 message: message.to_string(),
293 category: ErrorCategory::ToolRejected,
294 });
295 }
296 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
297 }
298
299 Ok(response["result"].clone())
300 }
301
302 pub fn notify(&self, method: &str, params: serde_json::Value) {
305 let notification = serde_json::json!({
306 "jsonrpc": "2.0",
307 "method": method,
308 "params": params,
309 });
310 if let Ok(line) = serde_json::to_string(¬ification) {
311 let _ = self.write_line(&line);
312 }
313 }
314
315 pub fn is_cancelled(&self) -> bool {
317 self.cancelled.load(Ordering::SeqCst)
318 }
319
320 pub async fn push_queued_user_message(&self, content: String, mode: &str) {
321 self.queued_user_messages
322 .lock()
323 .await
324 .push_back(QueuedUserMessage {
325 content,
326 mode: QueuedUserMessageMode::from_str(mode),
327 });
328 }
329
330 pub async fn take_queued_user_messages(
331 &self,
332 include_interrupt_immediate: bool,
333 include_finish_step: bool,
334 include_wait_for_completion: bool,
335 ) -> Vec<QueuedUserMessage> {
336 let mut queue = self.queued_user_messages.lock().await;
337 let mut selected = Vec::new();
338 let mut retained = VecDeque::new();
339 while let Some(message) = queue.pop_front() {
340 let should_take = match message.mode {
341 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
342 QueuedUserMessageMode::FinishStep => include_finish_step,
343 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
344 };
345 if should_take {
346 selected.push(message);
347 } else {
348 retained.push_back(message);
349 }
350 }
351 *queue = retained;
352 selected
353 }
354
355 pub async fn take_queued_user_messages_for(
356 &self,
357 checkpoint: DeliveryCheckpoint,
358 ) -> Vec<QueuedUserMessage> {
359 match checkpoint {
360 DeliveryCheckpoint::InterruptImmediate => {
361 self.take_queued_user_messages(true, false, false).await
362 }
363 DeliveryCheckpoint::AfterCurrentOperation => {
364 self.take_queued_user_messages(false, true, false).await
365 }
366 DeliveryCheckpoint::EndOfInteraction => {
367 self.take_queued_user_messages(false, false, true).await
368 }
369 }
370 }
371
372 pub fn send_output(&self, text: &str) {
374 self.notify("output", serde_json::json!({"text": text}));
375 }
376
377 pub fn send_progress(
379 &self,
380 phase: &str,
381 message: &str,
382 progress: Option<i64>,
383 total: Option<i64>,
384 data: Option<serde_json::Value>,
385 ) {
386 let mut payload = serde_json::json!({"phase": phase, "message": message});
387 if let Some(p) = progress {
388 payload["progress"] = serde_json::json!(p);
389 }
390 if let Some(t) = total {
391 payload["total"] = serde_json::json!(t);
392 }
393 if let Some(d) = data {
394 payload["data"] = d;
395 }
396 self.notify("progress", payload);
397 }
398
399 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
401 let mut payload = serde_json::json!({"level": level, "message": message});
402 if let Some(f) = fields {
403 payload["fields"] = f;
404 }
405 self.notify("log", payload);
406 }
407
408 pub fn send_call_start(
411 &self,
412 call_id: &str,
413 call_type: &str,
414 name: &str,
415 metadata: serde_json::Value,
416 ) {
417 let session_id = self.get_session_id();
418 let script = self.get_script_name();
419 self.notify(
420 "session/update",
421 serde_json::json!({
422 "sessionId": session_id,
423 "update": {
424 "sessionUpdate": "call_start",
425 "content": {
426 "call_id": call_id,
427 "call_type": call_type,
428 "name": name,
429 "script": script,
430 "metadata": metadata,
431 },
432 },
433 }),
434 );
435 }
436
437 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
440 let session_id = self.get_session_id();
441 self.notify(
442 "session/update",
443 serde_json::json!({
444 "sessionId": session_id,
445 "update": {
446 "sessionUpdate": "call_progress",
447 "content": {
448 "call_id": call_id,
449 "delta": delta,
450 "accumulated_tokens": accumulated_tokens,
451 },
452 },
453 }),
454 );
455 }
456
457 pub fn send_call_end(
459 &self,
460 call_id: &str,
461 call_type: &str,
462 name: &str,
463 duration_ms: u64,
464 status: &str,
465 metadata: serde_json::Value,
466 ) {
467 let session_id = self.get_session_id();
468 let script = self.get_script_name();
469 self.notify(
470 "session/update",
471 serde_json::json!({
472 "sessionId": session_id,
473 "update": {
474 "sessionUpdate": "call_end",
475 "content": {
476 "call_id": call_id,
477 "call_type": call_type,
478 "name": name,
479 "script": script,
480 "duration_ms": duration_ms,
481 "status": status,
482 "metadata": metadata,
483 },
484 },
485 }),
486 );
487 }
488}
489
490pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
492 crate::stdlib::json_to_vm_value(val)
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge over fresh shared state (no stdin reader) that uses the
    /// given cancellation flag.
    fn bridge_with_flag(cancelled: Arc<AtomicBool>) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    /// Builds an uncancelled bridge over fresh shared state.
    fn test_bridge() -> HostBridge {
        bridge_with_flag(Arc::new(AtomicBool::new(false)))
    }

    /// Runs `fut` to completion on a single-threaded runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Collects message contents for compact assertions.
    fn contents(messages: &[QueuedUserMessage]) -> Vec<&str> {
        messages.iter().map(|m| m.content.as_str()).collect()
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        // The bridge must observe a flag flipped through the shared Arc.
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = bridge_with_flag(Arc::clone(&cancelled));
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn call_fails_fast_when_cancelled() {
        block_on(async {
            // A cancelled bridge returns before writing anything to stdout.
            let bridge = bridge_with_flag(Arc::new(AtomicBool::new(true)));
            let err = bridge
                .call("llm_call", serde_json::json!({}))
                .await
                .expect_err("cancelled bridge must not send requests");
            assert!(matches!(&err, VmError::Runtime(msg) if msg.contains("cancelled")));
        });
    }

    #[test]
    fn mode_strings_map_to_expected_variants() {
        let cases = [
            ("interrupt_immediate", QueuedUserMessageMode::InterruptImmediate),
            ("interrupt", QueuedUserMessageMode::InterruptImmediate),
            ("finish_step", QueuedUserMessageMode::FinishStep),
            ("after_current_operation", QueuedUserMessageMode::FinishStep),
            ("wait_for_completion", QueuedUserMessageMode::WaitForCompletion),
            ("something_else", QueuedUserMessageMode::WaitForCompletion),
            ("", QueuedUserMessageMode::WaitForCompletion),
        ];
        for (raw, expected) in cases {
            assert_eq!(QueuedUserMessageMode::from_str(raw), expected, "mode {raw:?}");
        }
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn delivery_checkpoints_release_matching_modes() {
        block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("now".to_string(), "interrupt")
                .await;
            bridge
                .push_queued_user_message("after".to_string(), "after_current_operation")
                .await;
            bridge
                .push_queued_user_message("later".to_string(), "unrecognised")
                .await;

            let interrupts = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::InterruptImmediate)
                .await;
            assert_eq!(contents(&interrupts), ["now"]);

            let after_op = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert_eq!(contents(&after_op), ["after"]);

            let at_end = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(contents(&at_end), ["later"]);

            assert!(bridge
                .take_queued_user_messages(true, true, true)
                .await
                .is_empty());
        });
    }

    #[test]
    fn unselected_messages_keep_their_order() {
        block_on(async {
            let bridge = test_bridge();
            for (content, mode) in [
                ("a", "wait_for_completion"),
                ("b", "finish_step"),
                ("c", "wait_for_completion"),
            ] {
                bridge
                    .push_queued_user_message(content.to_string(), mode)
                    .await;
            }

            let step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(contents(&step), ["b"]);

            let rest = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(contents(&rest), ["a", "c"]);
        });
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}