1use std::collections::HashMap;
8use std::io::Write;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::io::AsyncBufReadExt;
14use tokio::sync::{oneshot, Mutex};
15
16use crate::value::{VmError, VmValue};
17
/// How long a bridge request waits for the host to answer before it is
/// abandoned with a timeout error (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
20
21pub struct HostBridge {
28 next_id: AtomicU64,
29 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
31 cancelled: Arc<AtomicBool>,
33 stdout_lock: Arc<std::sync::Mutex<()>>,
35}
36
37#[allow(clippy::new_without_default)]
39impl HostBridge {
40 pub fn new() -> Self {
45 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
46 Arc::new(Mutex::new(HashMap::new()));
47 let cancelled = Arc::new(AtomicBool::new(false));
48
49 let pending_clone = pending.clone();
51 let cancelled_clone = cancelled.clone();
52 tokio::task::spawn_local(async move {
53 let stdin = tokio::io::stdin();
54 let reader = tokio::io::BufReader::new(stdin);
55 let mut lines = reader.lines();
56
57 while let Ok(Some(line)) = lines.next_line().await {
58 let line = line.trim().to_string();
59 if line.is_empty() {
60 continue;
61 }
62
63 let msg: serde_json::Value = match serde_json::from_str(&line) {
64 Ok(v) => v,
65 Err(_) => continue, };
67
68 if msg.get("id").is_none() {
70 if let Some(method) = msg["method"].as_str() {
71 if method == "cancel" {
72 cancelled_clone.store(true, Ordering::SeqCst);
73 }
74 }
75 continue;
76 }
77
78 if let Some(id) = msg["id"].as_u64() {
80 let mut pending = pending_clone.lock().await;
81 if let Some(sender) = pending.remove(&id) {
82 let _ = sender.send(msg);
83 }
84 }
85 }
86
87 let mut pending = pending_clone.lock().await;
89 pending.clear();
90 });
91
92 Self {
93 next_id: AtomicU64::new(1),
94 pending,
95 cancelled,
96 stdout_lock: Arc::new(std::sync::Mutex::new(())),
97 }
98 }
99
100 fn write_line(&self, line: &str) -> Result<(), VmError> {
102 let _guard = self.stdout_lock.lock().unwrap_or_else(|e| e.into_inner());
103 let mut stdout = std::io::stdout().lock();
104 stdout
105 .write_all(line.as_bytes())
106 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
107 stdout
108 .write_all(b"\n")
109 .map_err(|e| VmError::Runtime(format!("Bridge write error: {e}")))?;
110 stdout
111 .flush()
112 .map_err(|e| VmError::Runtime(format!("Bridge flush error: {e}")))?;
113 Ok(())
114 }
115
116 pub async fn call(
119 &self,
120 method: &str,
121 params: serde_json::Value,
122 ) -> Result<serde_json::Value, VmError> {
123 if self.is_cancelled() {
124 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
125 }
126
127 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
128
129 let request = serde_json::json!({
130 "jsonrpc": "2.0",
131 "id": id,
132 "method": method,
133 "params": params,
134 });
135
136 let (tx, rx) = oneshot::channel();
138 {
139 let mut pending = self.pending.lock().await;
140 pending.insert(id, tx);
141 }
142
143 let line = serde_json::to_string(&request)
145 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
146 if let Err(e) = self.write_line(&line) {
147 let mut pending = self.pending.lock().await;
149 pending.remove(&id);
150 return Err(e);
151 }
152
153 let response = match tokio::time::timeout(DEFAULT_TIMEOUT, rx).await {
155 Ok(Ok(msg)) => msg,
156 Ok(Err(_)) => {
157 return Err(VmError::Runtime(
159 "Bridge: host closed connection before responding".into(),
160 ));
161 }
162 Err(_) => {
163 let mut pending = self.pending.lock().await;
165 pending.remove(&id);
166 return Err(VmError::Runtime(format!(
167 "Bridge: host did not respond to '{method}' within {}s",
168 DEFAULT_TIMEOUT.as_secs()
169 )));
170 }
171 };
172
173 if let Some(error) = response.get("error") {
175 let message = error["message"].as_str().unwrap_or("Unknown host error");
176 let code = error["code"].as_i64().unwrap_or(-1);
177 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
178 }
179
180 Ok(response["result"].clone())
181 }
182
183 pub fn notify(&self, method: &str, params: serde_json::Value) {
186 let notification = serde_json::json!({
187 "jsonrpc": "2.0",
188 "method": method,
189 "params": params,
190 });
191 if let Ok(line) = serde_json::to_string(¬ification) {
192 let _ = self.write_line(&line);
193 }
194 }
195
196 pub fn is_cancelled(&self) -> bool {
198 self.cancelled.load(Ordering::SeqCst)
199 }
200
201 pub fn send_output(&self, text: &str) {
203 self.notify("output", serde_json::json!({"text": text}));
204 }
205
206 pub fn send_progress(&self, phase: &str, message: &str) {
208 self.notify(
209 "progress",
210 serde_json::json!({"phase": phase, "message": message}),
211 );
212 }
213}
214
215pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
217 crate::stdlib::json_to_vm_value(val)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // Wire-format checks: these pin the JSON shapes the bridge produces and
    // consumes, independent of the stdin/stdout transport.

    #[test]
    fn test_json_rpc_request_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "llm_call",
            "params": {
                "prompt": "Hello",
                "system": "Be helpful",
            },
        });
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "output",
            "params": {"text": "[harn] hello\n"},
        });
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        // Notifications must not carry an id, or the host would try to reply.
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": {
                "code": -32600,
                "message": "Invalid request",
            },
        });
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            },
        });
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        assert!(!cancelled.load(Ordering::SeqCst));
        cancelled.store(true, Ordering::SeqCst);
        assert!(cancelled.load(Ordering::SeqCst));
    }

    // JSON -> VmValue conversion of host results.

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        if let VmValue::Dict(d) = &vm_val {
            assert_eq!(d.get("name").unwrap().display(), "test");
            assert_eq!(d.get("count").unwrap().display(), "42");
        } else {
            panic!("Expected Dict, got {:?}", vm_val);
        }
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        if let VmValue::Dict(d) = &vm_val {
            assert_eq!(d.get("text").unwrap().display(), "response");
            if let VmValue::List(list) = d.get("tool_calls").unwrap() {
                assert_eq!(list.len(), 1);
            } else {
                panic!("Expected List for tool_calls");
            }
        } else {
            panic!("Expected Dict");
        }
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}