1use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{ErrorCategory, VmError, VmValue};
17
/// How long `HostBridge::call` waits for the host's reply before failing (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35 session_id: std::sync::Mutex<String>,
37 script_name: std::sync::Mutex<String>,
39}
40
41#[allow(clippy::new_without_default)]
43impl HostBridge {
44 pub fn new() -> Self {
49 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
50 Arc::new(Mutex::new(HashMap::new()));
51 let cancelled = Arc::new(AtomicBool::new(false));
52
53 let pending_clone = pending.clone();
55 let cancelled_clone = cancelled.clone();
56 tokio::task::spawn_local(async move {
57 let stdin = tokio::io::stdin();
58 let reader = tokio::io::BufReader::new(stdin);
59 let mut lines = reader.lines();
60
61 while let Ok(Some(line)) = lines.next_line().await {
62 let line = line.trim().to_string();
63 if line.is_empty() {
64 continue;
65 }
66
67 let msg: serde_json::Value = match serde_json::from_str(&line) {
68 Ok(v) => v,
69 Err(_) => continue, };
71
72 if msg.get("id").is_none() {
74 if let Some(method) = msg["method"].as_str() {
75 if method == "cancel" {
76 cancelled_clone.store(true, Ordering::SeqCst);
77 }
78 }
79 continue;
80 }
81
82 if let Some(id) = msg["id"].as_u64() {
84 let mut pending = pending_clone.lock().await;
85 if let Some(sender) = pending.remove(&id) {
86 let _ = sender.send(msg);
87 }
88 }
89 }
90
91 let mut pending = pending_clone.lock().await;
93 pending.clear();
94 });
95
96 Self {
97 next_id: AtomicU64::new(1),
98 pending,
99 cancelled,
100 stdout_lock: Arc::new(std::sync::Mutex::new(())),
101 session_id: std::sync::Mutex::new(String::new()),
102 script_name: std::sync::Mutex::new(String::new()),
103 }
104 }
105
106 pub fn from_parts(
112 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
113 cancelled: Arc<AtomicBool>,
114 stdout_lock: Arc<std::sync::Mutex<()>>,
115 start_id: u64,
116 ) -> Self {
117 Self {
118 next_id: AtomicU64::new(start_id),
119 pending,
120 cancelled,
121 stdout_lock,
122 session_id: std::sync::Mutex::new(String::new()),
123 script_name: std::sync::Mutex::new(String::new()),
124 }
125 }
126
127 pub fn set_session_id(&self, id: &str) {
129 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
130 }
131
132 pub fn set_script_name(&self, name: &str) {
134 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
135 }
136
137 fn get_script_name(&self) -> String {
139 self.script_name
140 .lock()
141 .unwrap_or_else(|e| e.into_inner())
142 .clone()
143 }
144
145 fn get_session_id(&self) -> String {
147 self.session_id
148 .lock()
149 .unwrap_or_else(|e| e.into_inner())
150 .clone()
151 }
152
153 fn write_line(&self, line: &str) -> Result<(), VmError> {
155 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
156 let mut stdout = std::io::stdout().lock();
157 stdout
158 .write_all(line.as_bytes())
159 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
160 stdout
161 .write_all(b"\n")
162 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
163 stdout
164 .flush()
165 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
166 Ok(())
167 }
168
169 pub async fn call(
172 &self,
173 method: &str,
174 params: serde_json::Value,
175 ) -> Result<serde_json::Value, VmError> {
176 if self.is_cancelled() {
177 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
178 }
179
180 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
181
182 let request = serde_json::json!({
183 "jsonrpc": "2.0",
184 "id": id,
185 "method": method,
186 "params": params,
187 });
188
189 let (tx, rx) = oneshot::channel();
191 {
192 let mut pending = self.pending.lock().await;
193 pending.insert(id, tx);
194 }
195
196 let line = serde_json::to_string(&request)
198 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
199 if let Err(e) = self.write_line(&line) {
200 let mut pending = self.pending.lock().await;
202 pending.remove(&id);
203 return Err(e);
204 }
205
206 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
208 Ok(Ok(msg)) => msg,
209 Ok(Err(_)) => {
210 return Err(VmError::Runtime(
212 "Bridge: host closed connection before responding".into(),
213 ));
214 }
215 Err(_) => {
216 let mut pending = self.pending.lock().await;
218 pending.remove(&id);
219 return Err(VmError::Runtime(format!(
220 "Bridge: host did not respond to '{method}' within {}s",
221 DEFAULT_TIMEOUT.as_secs()
222 )));
223 }
224 };
225
226 if let Some(error) = response.get("error") {
228 let message = error["message"].as_str().unwrap_or("Unknown host error");
229 let code = error["code"].as_i64().unwrap_or(-1);
230 if code == -32001 {
232 return Err(VmError::CategorizedError {
233 message: message.to_string(),
234 category: ErrorCategory::ToolRejected,
235 });
236 }
237 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
238 }
239
240 Ok(response["result"].clone())
241 }
242
243 pub fn notify(&self, method: &str, params: serde_json::Value) {
246 let notification = serde_json::json!({
247 "jsonrpc": "2.0",
248 "method": method,
249 "params": params,
250 });
251 if let Ok(line) = serde_json::to_string(¬ification) {
252 let _ = self.write_line(&line);
253 }
254 }
255
256 pub fn is_cancelled(&self) -> bool {
258 self.cancelled.load(Ordering::SeqCst)
259 }
260
261 pub fn send_output(&self, text: &str) {
263 self.notify("output", serde_json::json!({"text": text}));
264 }
265
266 pub fn send_progress(
268 &self,
269 phase: &str,
270 message: &str,
271 progress: Option<i64>,
272 total: Option<i64>,
273 data: Option<serde_json::Value>,
274 ) {
275 let mut payload = serde_json::json!({"phase": phase, "message": message});
276 if let Some(p) = progress {
277 payload["progress"] = serde_json::json!(p);
278 }
279 if let Some(t) = total {
280 payload["total"] = serde_json::json!(t);
281 }
282 if let Some(d) = data {
283 payload["data"] = d;
284 }
285 self.notify("progress", payload);
286 }
287
288 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
290 let mut payload = serde_json::json!({"level": level, "message": message});
291 if let Some(f) = fields {
292 payload["fields"] = f;
293 }
294 self.notify("log", payload);
295 }
296
297 pub fn send_call_start(
300 &self,
301 call_id: &str,
302 call_type: &str,
303 name: &str,
304 metadata: serde_json::Value,
305 ) {
306 let session_id = self.get_session_id();
307 let script = self.get_script_name();
308 self.notify(
309 "session/update",
310 serde_json::json!({
311 "sessionId": session_id,
312 "update": {
313 "sessionUpdate": "call_start",
314 "content": {
315 "call_id": call_id,
316 "call_type": call_type,
317 "name": name,
318 "script": script,
319 "metadata": metadata,
320 },
321 },
322 }),
323 );
324 }
325
326 pub fn send_call_progress(&self, call_id: &str, delta: &str, accumulated_tokens: u64) {
329 let session_id = self.get_session_id();
330 self.notify(
331 "session/update",
332 serde_json::json!({
333 "sessionId": session_id,
334 "update": {
335 "sessionUpdate": "call_progress",
336 "content": {
337 "call_id": call_id,
338 "delta": delta,
339 "accumulated_tokens": accumulated_tokens,
340 },
341 },
342 }),
343 );
344 }
345
346 pub fn send_call_end(
348 &self,
349 call_id: &str,
350 call_type: &str,
351 name: &str,
352 duration_ms: u64,
353 status: &str,
354 metadata: serde_json::Value,
355 ) {
356 let session_id = self.get_session_id();
357 let script = self.get_script_name();
358 self.notify(
359 "session/update",
360 serde_json::json!({
361 "sessionId": session_id,
362 "update": {
363 "sessionUpdate": "call_end",
364 "content": {
365 "call_id": call_id,
366 "call_type": call_type,
367 "name": name,
368 "script": script,
369 "duration_ms": duration_ms,
370 "status": status,
371 "metadata": metadata,
372 },
373 },
374 }),
375 );
376 }
377}
378
379pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
381 crate::stdlib::json_to_vm_value(val)
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge over fresh shared state without spawning the stdin reader,
    /// so it can be used from plain (non-async) tests.
    fn detached_bridge(cancelled: Arc<AtomicBool>, start_id: u64) -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            cancelled,
            Arc::new(std::sync::Mutex::new(())),
            start_id,
        )
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        // The bridge must observe cancellation set through the shared flag.
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = detached_bridge(Arc::clone(&cancelled), 1);
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    #[test]
    fn test_from_parts_uses_start_id() {
        let bridge = detached_bridge(Arc::new(AtomicBool::new(false)), 42);
        assert_eq!(bridge.next_id.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn test_session_metadata_roundtrip() {
        let bridge = detached_bridge(Arc::new(AtomicBool::new(false)), 1);
        assert_eq!(bridge.get_session_id(), "");
        assert_eq!(bridge.get_script_name(), "");
        bridge.set_session_id("sess-1");
        bridge.set_script_name("main.harn");
        assert_eq!(bridge.get_session_id(), "sess-1");
        assert_eq!(bridge.get_script_name(), "main.harn");
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}