1use crate::{SessionResult, StreamHandler};
8use ralph_proto::json_rpc::{RpcEvent, emit_event_line};
9use serde_json::Value;
10use std::io::{self, Write};
11use std::sync::{Arc, Mutex};
12use std::time::Instant;
13use tracing::warn;
14
/// [`StreamHandler`] that reports agent stream activity as JSON-RPC events, one line per
/// event, on a shared writer.
///
/// The `Write + Send` requirement lives on the `impl` blocks rather than on the struct
/// itself, so the type can be named without repeating the bound.
///
/// NOTE(review): `hat` and `backend` are stored and settable, but no event emitted in this
/// file reads them — confirm whether they should be attached to events.
pub struct JsonRpcStreamHandler<W> {
    /// Destination for serialized events, shared via `Arc<Mutex<_>>` so the creator can
    /// keep its own handle to the same writer.
    writer: Arc<Mutex<W>>,
    /// Iteration number stamped on every emitted event.
    iteration: u32,
    /// Current hat name (see the review note above).
    hat: Option<String>,
    /// Current backend name (see the review note above).
    backend: Option<String>,
    /// Start instants of in-flight tool calls, keyed by tool call id; used to compute the
    /// duration reported when the tool result arrives.
    tool_start_times: std::collections::HashMap<String, Instant>,
    /// Set once output has failed permanently (e.g. a broken pipe); every later event is
    /// dropped without touching the writer.
    poisoned: bool,
}
34
35impl<W: Write + Send> JsonRpcStreamHandler<W> {
36 pub fn new(
44 writer: Arc<Mutex<W>>,
45 iteration: u32,
46 hat: Option<String>,
47 backend: Option<String>,
48 ) -> Self {
49 Self {
50 writer,
51 iteration,
52 hat,
53 backend,
54 tool_start_times: std::collections::HashMap::new(),
55 poisoned: false,
56 }
57 }
58
59 pub fn set_iteration(&mut self, iteration: u32) {
61 self.iteration = iteration;
62 }
63
64 pub fn set_hat(&mut self, hat: Option<String>) {
66 self.hat = hat;
67 }
68
69 pub fn set_backend(&mut self, backend: Option<String>) {
71 self.backend = backend;
72 }
73
74 fn emit(&mut self, event: RpcEvent) {
76 if self.poisoned {
77 return;
78 }
79 let line = emit_event_line(&event);
80 if let Ok(mut writer) = self.writer.lock() {
81 if let Err(e) = writer.write_all(line.as_bytes()) {
82 warn!(error = %e, "Failed to write JSON-RPC event");
83 if e.kind() == io::ErrorKind::BrokenPipe {
84 self.poisoned = true;
85 }
86 return;
87 }
88 if let Err(e) = writer.flush() {
90 warn!(error = %e, "Failed to flush JSON-RPC event");
91 if e.kind() == io::ErrorKind::BrokenPipe {
92 self.poisoned = true;
93 }
94 }
95 }
96 }
97}
98
99impl<W: Write + Send> StreamHandler for JsonRpcStreamHandler<W> {
100 fn on_text(&mut self, text: &str) {
101 self.emit(RpcEvent::TextDelta {
102 iteration: self.iteration,
103 delta: text.to_string(),
104 });
105 }
106
107 fn on_tool_call(&mut self, name: &str, id: &str, input: &Value) {
108 self.tool_start_times.insert(id.to_string(), Instant::now());
110
111 self.emit(RpcEvent::ToolCallStart {
112 iteration: self.iteration,
113 tool_name: name.to_string(),
114 tool_call_id: id.to_string(),
115 input: input.clone(),
116 });
117 }
118
119 fn on_tool_result(&mut self, id: &str, output: &str) {
120 let duration_ms = self
122 .tool_start_times
123 .remove(id)
124 .map(|start| start.elapsed().as_millis() as u64)
125 .unwrap_or(0);
126
127 self.emit(RpcEvent::ToolCallEnd {
128 iteration: self.iteration,
129 tool_call_id: id.to_string(),
130 output: output.to_string(),
131 is_error: false,
132 duration_ms,
133 });
134 }
135
136 fn on_error(&mut self, error: &str) {
137 self.emit(RpcEvent::Error {
138 iteration: self.iteration,
139 code: "EXECUTION_ERROR".to_string(),
140 message: error.to_string(),
141 recoverable: true,
142 });
143 }
144
145 fn on_complete(&mut self, result: &SessionResult) {
146 self.emit(RpcEvent::IterationEnd {
147 iteration: self.iteration,
148 duration_ms: result.duration_ms,
149 cost_usd: result.total_cost_usd,
150 input_tokens: result.input_tokens,
151 output_tokens: result.output_tokens,
152 cache_read_tokens: result.cache_read_tokens,
153 cache_write_tokens: result.cache_write_tokens,
154 loop_complete_triggered: false, });
156 }
157}
158
159pub fn stdout_json_rpc_handler(
161 iteration: u32,
162 hat: Option<String>,
163 backend: Option<String>,
164) -> JsonRpcStreamHandler<io::Stdout> {
165 JsonRpcStreamHandler::new(Arc::new(Mutex::new(io::stdout())), iteration, hat, backend)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a handler (iteration 3, hat "builder", backend "claude") that writes into an
    /// in-memory buffer, and returns that buffer so tests can inspect the output.
    fn capture_handler() -> (JsonRpcStreamHandler<Vec<u8>>, Arc<Mutex<Vec<u8>>>) {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let handler = JsonRpcStreamHandler::new(
            buffer.clone(),
            3,
            Some("builder".to_string()),
            Some("claude".to_string()),
        );
        (handler, buffer)
    }

    /// Returns everything written to `buffer` so far, decoded lossily as UTF-8.
    fn get_output(buffer: &Arc<Mutex<Vec<u8>>>) -> String {
        let guard = buffer.lock().unwrap();
        String::from_utf8_lossy(&guard).to_string()
    }

    /// Parses a single output line as JSON, panicking if it is not valid JSON.
    fn parse_json_line(line: &str) -> serde_json::Value {
        serde_json::from_str(line).expect("should be valid JSON")
    }

    /// Parses every line written to `buffer`, first asserting that exactly `expected`
    /// lines were emitted, so a missing or extra event fails with a clear message instead
    /// of an index-out-of-bounds panic.
    fn parse_output_lines(
        buffer: &Arc<Mutex<Vec<u8>>>,
        expected: usize,
    ) -> Vec<serde_json::Value> {
        let output = get_output(buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        assert_eq!(
            lines.len(),
            expected,
            "unexpected number of output lines in {output:?}"
        );
        lines.into_iter().map(parse_json_line).collect()
    }

    #[test]
    fn test_text_delta_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("hello world");

        let events = parse_output_lines(&buffer, 1);
        let json = &events[0];

        assert_eq!(json["type"], "text_delta");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["delta"], "hello world");
    }

    #[test]
    fn test_tool_call_start_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Bash", "call-1", &json!({"command": "ls -la"}));

        let events = parse_output_lines(&buffer, 1);
        let json = &events[0];

        assert_eq!(json["type"], "tool_call_start");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["tool_name"], "Bash");
        assert_eq!(json["tool_call_id"], "call-1");
        assert_eq!(json["input"]["command"], "ls -la");
    }

    #[test]
    fn test_tool_call_end_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Read", "call-2", &json!({"file_path": "/tmp/test"}));
        handler.on_tool_result("call-2", "file contents here");

        let events = parse_output_lines(&buffer, 2);
        let end_json = &events[1];

        assert_eq!(end_json["type"], "tool_call_end");
        assert_eq!(end_json["iteration"], 3);
        assert_eq!(end_json["tool_call_id"], "call-2");
        assert_eq!(end_json["output"], "file contents here");
        assert_eq!(end_json["is_error"], false);
    }

    #[test]
    fn test_error_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_error("Connection timeout");

        let events = parse_output_lines(&buffer, 1);
        let json = &events[0];

        assert_eq!(json["type"], "error");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["code"], "EXECUTION_ERROR");
        assert_eq!(json["message"], "Connection timeout");
        assert_eq!(json["recoverable"], true);
    }

    #[test]
    fn test_iteration_end_event() {
        let (mut handler, buffer) = capture_handler();

        let result = SessionResult {
            duration_ms: 5432,
            total_cost_usd: 0.0054,
            num_turns: 3,
            is_error: false,
            ..Default::default()
        };
        handler.on_complete(&result);

        let events = parse_output_lines(&buffer, 1);
        let json = &events[0];

        assert_eq!(json["type"], "iteration_end");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["duration_ms"], 5432);
        assert_eq!(json["cost_usd"], 0.0054);
    }

    #[test]
    fn test_one_line_per_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("first");
        handler.on_text("second");
        handler.on_tool_call("Grep", "t1", &json!({"pattern": "test"}));

        // Asserts the line count and that every line parses as JSON.
        parse_output_lines(&buffer, 3);
    }

    #[test]
    fn test_iteration_metadata_included() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("test");
        handler.on_error("error");

        for json in parse_output_lines(&buffer, 2) {
            assert_eq!(json["iteration"], 3, "iteration should be present");
        }
    }

    #[test]
    fn test_set_iteration_updates_subsequent_events() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("at iter 3");
        handler.set_iteration(4);
        handler.on_text("at iter 4");

        let events = parse_output_lines(&buffer, 2);

        assert_eq!(events[0]["iteration"], 3);
        assert_eq!(events[1]["iteration"], 4);
    }

    #[test]
    fn test_tool_duration_tracking() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Bash", "slow-call", &json!({"command": "sleep 0.01"}));
        std::thread::sleep(std::time::Duration::from_millis(10));
        handler.on_tool_result("slow-call", "done");

        let events = parse_output_lines(&buffer, 2);
        let duration = events[1]["duration_ms"]
            .as_u64()
            .expect("duration_ms should be an unsigned integer");

        assert!(duration >= 10, "duration should be at least 10ms, got {duration}");
    }

    #[test]
    fn test_unknown_tool_result_has_zero_duration() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_result("unknown-id", "output");

        let events = parse_output_lines(&buffer, 1);

        assert_eq!(events[0]["duration_ms"], 0);
    }

    #[test]
    fn test_tool_result_consumes_start_time() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Bash", "once", &json!({"command": "true"}));
        handler.on_tool_result("once", "first");
        assert!(
            handler.tool_start_times.is_empty(),
            "start time should be removed once the result is reported"
        );

        // A duplicate result for the same id no longer has a start time to measure from.
        handler.on_tool_result("once", "second");

        let events = parse_output_lines(&buffer, 3);
        assert_eq!(events[2]["duration_ms"], 0);
    }

    /// Writer whose `write` always fails with `BrokenPipe`, counting attempts.
    struct BrokenPipeWriter {
        write_attempts: std::cell::Cell<u32>,
    }

    impl BrokenPipeWriter {
        fn new() -> Self {
            Self {
                write_attempts: std::cell::Cell::new(0),
            }
        }

        fn attempts(&self) -> u32 {
            self.write_attempts.get()
        }
    }

    impl Write for BrokenPipeWriter {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            self.write_attempts.set(self.write_attempts.get() + 1);
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Broken pipe (os error 32)",
            ))
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Writer whose `write` succeeds but whose `flush` always fails with `BrokenPipe`.
    struct FlushBrokenPipeWriter {
        write_attempts: u32,
        flush_attempts: u32,
    }

    impl Write for FlushBrokenPipeWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.write_attempts += 1;
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            self.flush_attempts += 1;
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Broken pipe (os error 32)",
            ))
        }
    }

    /// Writer whose `write` always fails with an error other than `BrokenPipe`.
    struct PermissionDeniedWriter {
        write_attempts: u32,
    }

    impl Write for PermissionDeniedWriter {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            self.write_attempts += 1;
            Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "permission denied",
            ))
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_broken_pipe_stops_emitting_after_first_failure() {
        let writer = Arc::new(Mutex::new(BrokenPipeWriter::new()));
        let mut handler = JsonRpcStreamHandler::new(
            writer.clone(),
            1,
            Some("builder".to_string()),
            Some("claude".to_string()),
        );

        for i in 0..10 {
            handler.on_text(&format!("event {i}"));
        }

        let attempts = writer.lock().unwrap().attempts();
        assert_eq!(
            attempts, 1,
            "should stop writing after first broken pipe, but attempted {attempts} writes"
        );
    }

    #[test]
    fn test_broken_pipe_on_flush_stops_emitting() {
        let writer = Arc::new(Mutex::new(FlushBrokenPipeWriter {
            write_attempts: 0,
            flush_attempts: 0,
        }));
        let mut handler = JsonRpcStreamHandler::new(writer.clone(), 1, None, None);

        for i in 0..5 {
            handler.on_text(&format!("event {i}"));
        }

        let guard = writer.lock().unwrap();
        assert_eq!(guard.write_attempts, 1, "only the first event should be written");
        assert_eq!(guard.flush_attempts, 1, "only the first event should be flushed");
    }

    #[test]
    fn test_non_broken_pipe_errors_keep_retrying() {
        let writer = Arc::new(Mutex::new(PermissionDeniedWriter { write_attempts: 0 }));
        let mut handler = JsonRpcStreamHandler::new(writer.clone(), 1, None, None);

        for i in 0..3 {
            handler.on_text(&format!("event {i}"));
        }

        assert_eq!(
            writer.lock().unwrap().write_attempts,
            3,
            "transient errors should not disable output"
        );
    }
}