1use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// Maximum time `HostBridge::call` waits for the host's reply before failing (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35 session_id: std::sync::Mutex<String>,
37 script_name: std::sync::Mutex<String>,
39}
40
41#[allow(clippy::new_without_default)]
43impl HostBridge {
44 pub fn new() -> Self {
49 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
50 Arc::new(Mutex::new(HashMap::new()));
51 let cancelled = Arc::new(AtomicBool::new(false));
52
53 let pending_clone = pending.clone();
55 let cancelled_clone = cancelled.clone();
56 tokio::task::spawn_local(async move {
57 let stdin = tokio::io::stdin();
58 let reader = tokio::io::BufReader::new(stdin);
59 let mut lines = reader.lines();
60
61 while let Ok(Some(line)) = lines.next_line().await {
62 let line = line.trim().to_string();
63 if line.is_empty() {
64 continue;
65 }
66
67 let msg: serde_json::Value = match serde_json::from_str(&line) {
68 Ok(v) => v,
69 Err(_) => continue, };
71
72 if msg.get("id").is_none() {
74 if let Some(method) = msg["method"].as_str() {
75 if method == "cancel" {
76 cancelled_clone.store(true, Ordering::SeqCst);
77 }
78 }
79 continue;
80 }
81
82 if let Some(id) = msg["id"].as_u64() {
84 let mut pending = pending_clone.lock().await;
85 if let Some(sender) = pending.remove(&id) {
86 let _ = sender.send(msg);
87 }
88 }
89 }
90
91 let mut pending = pending_clone.lock().await;
93 pending.clear();
94 });
95
96 Self {
97 next_id: AtomicU64::new(1),
98 pending,
99 cancelled,
100 stdout_lock: Arc::new(std::sync::Mutex::new(())),
101 session_id: std::sync::Mutex::new(String::new()),
102 script_name: std::sync::Mutex::new(String::new()),
103 }
104 }
105
106 pub fn from_parts(
112 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
113 cancelled: Arc<AtomicBool>,
114 stdout_lock: Arc<std::sync::Mutex<()>>,
115 start_id: u64,
116 ) -> Self {
117 Self {
118 next_id: AtomicU64::new(start_id),
119 pending,
120 cancelled,
121 stdout_lock,
122 session_id: std::sync::Mutex::new(String::new()),
123 script_name: std::sync::Mutex::new(String::new()),
124 }
125 }
126
127 pub fn set_session_id(&self, id: &str) {
129 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
130 }
131
132 pub fn set_script_name(&self, name: &str) {
134 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
135 }
136
137 fn get_script_name(&self) -> String {
139 self.script_name.lock().unwrap_or_else(|e| e.into_inner()).clone()
140 }
141
142 fn get_session_id(&self) -> String {
144 self.session_id.lock().unwrap_or_else(|e| e.into_inner()).clone()
145 }
146
147 fn write_line(&self, line: &str) -> Result<(), VmError> {
149 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
150 let mut stdout = std::io::stdout().lock();
151 stdout
152 .write_all(line.as_bytes())
153 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
154 stdout
155 .write_all(b"\n")
156 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
157 stdout
158 .flush()
159 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
160 Ok(())
161 }
162
163 pub async fn call(
166 &self,
167 method: &str,
168 params: serde_json::Value,
169 ) -> Result<serde_json::Value, VmError> {
170 if self.is_cancelled() {
171 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
172 }
173
174 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
175
176 let request = serde_json::json!({
177 "jsonrpc": "2.0",
178 "id": id,
179 "method": method,
180 "params": params,
181 });
182
183 let (tx, rx) = oneshot::channel();
185 {
186 let mut pending = self.pending.lock().await;
187 pending.insert(id, tx);
188 }
189
190 let line = serde_json::to_string(&request)
192 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
193 if let Err(e) = self.write_line(&line) {
194 let mut pending = self.pending.lock().await;
196 pending.remove(&id);
197 return Err(e);
198 }
199
200 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
202 Ok(Ok(msg)) => msg,
203 Ok(Err(_)) => {
204 return Err(VmError::Runtime(
206 "Bridge: host closed connection before responding".into(),
207 ));
208 }
209 Err(_) => {
210 let mut pending = self.pending.lock().await;
212 pending.remove(&id);
213 return Err(VmError::Runtime(format!(
214 "Bridge: host did not respond to '{method}' within {}s",
215 DEFAULT_TIMEOUT.as_secs()
216 )));
217 }
218 };
219
220 if let Some(error) = response.get("error") {
222 let message = error["message"].as_str().unwrap_or("Unknown host error");
223 let code = error["code"].as_i64().unwrap_or(-1);
224 if code == -32001 {
226 return Err(VmError::CategorizedError {
227 message: message.to_string(),
228 category: ErrorCategory::ToolRejected,
229 });
230 }
231 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
232 }
233
234 Ok(response["result"].clone())
235 }
236
237 pub fn notify(&self, method: &str, params: serde_json::Value) {
240 let notification = serde_json::json!({
241 "jsonrpc": "2.0",
242 "method": method,
243 "params": params,
244 });
245 if let Ok(line) = serde_json::to_string(¬ification) {
246 let _ = self.write_line(&line);
247 }
248 }
249
250 pub fn is_cancelled(&self) -> bool {
252 self.cancelled.load(Ordering::SeqCst)
253 }
254
255 pub fn send_output(&self, text: &str) {
257 self.notify("output", serde_json::json!({"text": text}));
258 }
259
260 pub fn send_progress(&self, phase: &str, message: &str, data: Option<serde_json::Value>) {
262 let mut payload = serde_json::json!({"phase": phase, "message": message});
263 if let Some(d) = data {
264 payload["data"] = d;
265 }
266 self.notify("progress", payload);
267 }
268
269 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
271 let mut payload = serde_json::json!({"level": level, "message": message});
272 if let Some(f) = fields {
273 payload["fields"] = f;
274 }
275 self.notify("log", payload);
276 }
277
278 pub fn send_call_start(
281 &self,
282 call_id: &str,
283 call_type: &str,
284 name: &str,
285 metadata: serde_json::Value,
286 ) {
287 let session_id = self.get_session_id();
288 let script = self.get_script_name();
289 self.notify(
290 "session/update",
291 serde_json::json!({
292 "sessionId": session_id,
293 "update": {
294 "sessionUpdate": "call_start",
295 "content": {
296 "call_id": call_id,
297 "call_type": call_type,
298 "name": name,
299 "script": script,
300 "metadata": metadata,
301 },
302 },
303 }),
304 );
305 }
306
307 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
310 let session_id = self.get_session_id();
311 self.notify(
312 "session/update",
313 serde_json::json!({
314 "sessionId": session_id,
315 "update": {
316 "sessionUpdate": "call_progress",
317 "content": {
318 "call_id": call_id,
319 "delta": delta,
320 "accumulated_tokens": accumulated_tokens,
321 },
322 },
323 }),
324 );
325 }
326
327 pub fn send_call_end(
329 &self,
330 call_id: &str,
331 call_type: &str,
332 name: &str,
333 duration_ms: u64,
334 status: &str,
335 metadata: serde_json::Value,
336 ) {
337 let session_id = self.get_session_id();
338 let script = self.get_script_name();
339 self.notify(
340 "session/update",
341 serde_json::json!({
342 "sessionId": session_id,
343 "update": {
344 "sessionUpdate": "call_end",
345 "content": {
346 "call_id": call_id,
347 "call_type": call_type,
348 "name": name,
349 "script": script,
350 "duration_ms": duration_ms,
351 "status": status,
352 "metadata": metadata,
353 },
354 },
355 }),
356 );
357 }
358}
359
360pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
362 crate::stdlib::json_to_vm_value(val)
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge that spawns no stdin reader, so tests need no tokio runtime.
    fn detached_bridge(cancelled: Arc<AtomicBool>, start_id: u64) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            start_id,
        )
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        // The bridge must observe the shared flag, which the stdin reader sets on `cancel`.
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = detached_bridge(cancelled.clone(), 1);
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn test_from_parts_start_id() {
        let bridge = detached_bridge(Arc::new(AtomicBool::new(false)), 42);
        assert_eq!(bridge.next_id.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_session_metadata_setters() {
        let bridge = detached_bridge(Arc::new(AtomicBool::new(false)), 1);
        assert_eq!(bridge.get_session_id(), "");
        assert_eq!(bridge.get_script_name(), "");
        bridge.set_session_id("sess-1");
        bridge.set_script_name("main.harn");
        assert_eq!(bridge.get_session_id(), "sess-1");
        assert_eq!(bridge.get_script_name(), "main.harn");
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}