1use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{VmError, VmValue};
17
/// Upper bound on how long [`HostBridge::call`] waits for the host's reply
/// before giving up and reporting a timeout (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35}
36
37#[allow(clippy::new_without_default)]
39impl HostBridge {
40 pub fn new() -> Self {
45 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
46 Arc::new(Mutex::new(HashMap::new()));
47 let cancelled = Arc::new(AtomicBool::new(false));
48
49 let pending_clone = pending.clone();
51 let cancelled_clone = cancelled.clone();
52 tokio::task::spawn_local(async move {
53 let stdin = tokio::io::stdin();
54 let reader = tokio::io::BufReader::new(stdin);
55 let mut lines = reader.lines();
56
57 while let Ok(Some(line)) = lines.next_line().await {
58 let line = line.trim().to_string();
59 if line.is_empty() {
60 continue;
61 }
62
63 let msg: serde_json::Value = match serde_json::from_str(&line) {
64 Ok(v) => v,
65 Err(_) => continue, };
67
68 if msg.get("id").is_none() {
70 if let Some(method) = msg["method"].as_str() {
71 if method == "cancel" {
72 cancelled_clone.store(true, Ordering::SeqCst);
73 }
74 }
75 continue;
76 }
77
78 if let Some(id) = msg["id"].as_u64() {
80 let mut pending = pending_clone.lock().await;
81 if let Some(sender) = pending.remove(&id) {
82 let _ = sender.send(msg);
83 }
84 }
85 }
86
87 let mut pending = pending_clone.lock().await;
89 pending.clear();
90 });
91
92 Self {
93 next_id: AtomicU64::new(1),
94 pending,
95 cancelled,
96 stdout_lock: Arc::new(std::sync::Mutex::new(())),
97 }
98 }
99
100 pub fn from_parts(
106 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
107 cancelled: Arc<AtomicBool>,
108 stdout_lock: Arc<std::sync::Mutex<()>>,
109 start_id: u64,
110 ) -> Self {
111 Self {
112 next_id: AtomicU64::new(start_id),
113 pending,
114 cancelled,
115 stdout_lock,
116 }
117 }
118
119 fn write_line(&self, line: &str) -> Result<(), VmError> {
121 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
122 let mut stdout = std::io::stdout().lock();
123 stdout
124 .write_all(line.as_bytes())
125 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
126 stdout
127 .write_all(b"\n")
128 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
129 stdout
130 .flush()
131 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
132 Ok(())
133 }
134
135 pub async fn call(
138 &self,
139 method: &str,
140 params: serde_json::Value,
141 ) -> Result<serde_json::Value, VmError> {
142 if self.is_cancelled() {
143 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
144 }
145
146 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
147
148 let request = serde_json::json!({
149 "jsonrpc": "2.0",
150 "id": id,
151 "method": method,
152 "params": params,
153 });
154
155 let (tx, rx) = oneshot::channel();
157 {
158 let mut pending = self.pending.lock().await;
159 pending.insert(id, tx);
160 }
161
162 let line = serde_json::to_string(&request)
164 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
165 if let Err(e) = self.write_line(&line) {
166 let mut pending = self.pending.lock().await;
168 pending.remove(&id);
169 return Err(e);
170 }
171
172 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
174 Ok(Ok(msg)) => msg,
175 Ok(Err(_)) => {
176 return Err(VmError::Runtime(
178 "Bridge: host closed connection before responding".into(),
179 ));
180 }
181 Err(_) => {
182 let mut pending = self.pending.lock().await;
184 pending.remove(&id);
185 return Err(VmError::Runtime(format!(
186 "Bridge: host did not respond to '{method}' within {}s",
187 DEFAULT_TIMEOUT.as_secs()
188 )));
189 }
190 };
191
192 if let Some(error) = response.get("error") {
194 let message = error["message"].as_str().unwrap_or("Unknown host error");
195 let code = error["code"].as_i64().unwrap_or(-1);
196 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
197 }
198
199 Ok(response["result"].clone())
200 }
201
202 pub fn notify(&self, method: &str, params: serde_json::Value) {
205 let notification = serde_json::json!({
206 "jsonrpc": "2.0",
207 "method": method,
208 "params": params,
209 });
210 if let Ok(line) = serde_json::to_string(¬ification) {
211 let _ = self.write_line(&line);
212 }
213 }
214
215 pub fn is_cancelled(&self) -> bool {
217 self.cancelled.load(Ordering::SeqCst)
218 }
219
220 pub fn send_output(&self, text: &str) {
222 self.notify("output", serde_json::json!({"text": text}));
223 }
224
225 pub fn send_progress(
227 &self,
228 phase: &str,
229 message: &str,
230 data: Option<serde_json::Value>,
231 ) {
232 let mut payload = serde_json::json!({"phase": phase, "message": message});
233 if let Some(d) = data {
234 payload["data"] = d;
235 }
236 self.notify("progress", payload);
237 }
238
239 pub fn send_log(
241 &self,
242 level: &str,
243 message: &str,
244 fields: Option<serde_json::Value>,
245 ) {
246 let mut payload = serde_json::json!({"level": level, "message": message});
247 if let Some(f) = fields {
248 payload["fields"] = f;
249 }
250 self.notify("log", payload);
251 }
252}
253
254pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
256 crate::stdlib::json_to_vm_value(val)
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bridge over fresh shared state without spawning a stdin reader.
    fn detached_bridge(
        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
        cancelled: Arc<AtomicBool>,
    ) -> HostBridge {
        HostBridge::from_parts(pending, cancelled, Arc::new(std::sync::Mutex::new(())), 1)
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    /// The bridge must observe the shared cancellation flag, not a private copy.
    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        let bridge = detached_bridge(Arc::new(Mutex::new(HashMap::new())), cancelled.clone());
        assert!(!bridge.is_cancelled());
        cancelled.store(true, Ordering::SeqCst);
        assert!(bridge.is_cancelled());
    }

    /// A cancelled bridge rejects calls immediately and registers no waiter.
    #[test]
    fn test_call_fails_fast_when_cancelled() {
        let pending = Arc::new(Mutex::new(HashMap::new()));
        let bridge = detached_bridge(pending.clone(), Arc::new(AtomicBool::new(true)));
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime");
        let result = runtime.block_on(bridge.call("llm_call", serde_json::json!({})));
        assert!(matches!(
            result,
            Err(VmError::Runtime(ref msg)) if msg.contains("cancelled")
        ));
        assert!(pending
            .try_lock()
            .expect("pending map should be unlocked")
            .is_empty());
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}