Skip to main content

ralph_adapters/
json_rpc_handler.rs

1//! JSON-RPC stream handler for emitting orchestration events as JSON lines.
2//!
3//! This handler implements `StreamHandler` and writes one JSON line per event
4//! to a configurable writer (typically stdout). It's the event producer side
5//! of Ralph's JSON-RPC protocol, enabling machine-readable output for frontends.
6
7use crate::{SessionResult, StreamHandler};
8use ralph_proto::json_rpc::{RpcEvent, emit_event_line};
9use serde_json::Value;
10use std::io::{self, Write};
11use std::sync::{Arc, Mutex};
12use std::time::Instant;
13use tracing::warn;
14
/// A [`StreamHandler`] that serializes each orchestration callback as a
/// single JSON-RPC event line.
///
/// Events are built from the types in `ralph_proto::json_rpc` and written to a
/// shared writer. The handler remembers the current iteration number, which is
/// stamped onto every event. It also remembers per-tool-call start times, which
/// are used to report `duration_ms` when a tool result arrives.
pub struct JsonRpcStreamHandler<W>
where
    W: Write + Send,
{
    /// Shared output sink: stdout in production, an in-memory buffer in tests.
    writer: Arc<Mutex<W>>,
    /// Iteration number attached to each emitted event (1-indexed).
    iteration: u32,
    /// Identifier of the active hat, if any.
    /// NOTE(review): stored and updatable, but not read by any event emitted in this file.
    hat: Option<String>,
    /// Name of the active backend, if any.
    /// NOTE(review): stored and updatable, but not read by any event emitted in this file.
    backend: Option<String>,
    /// Start instant of each in-flight tool call, keyed by tool call ID.
    tool_start_times: std::collections::HashMap<String, Instant>,
    /// Latched after an unrecoverable output failure (e.g. a broken pipe);
    /// once set, no further writes are attempted.
    poisoned: bool,
}
34
35impl<W: Write + Send> JsonRpcStreamHandler<W> {
36    /// Creates a new JSON-RPC handler writing to the given writer.
37    ///
38    /// # Arguments
39    /// * `writer` - The output sink (wrapped in Arc<Mutex> for thread safety).
40    /// * `iteration` - Current iteration number (1-indexed).
41    /// * `hat` - Current hat ID (e.g., "builder", "planner").
42    /// * `backend` - Backend name (e.g., "claude", "gemini").
43    pub fn new(
44        writer: Arc<Mutex<W>>,
45        iteration: u32,
46        hat: Option<String>,
47        backend: Option<String>,
48    ) -> Self {
49        Self {
50            writer,
51            iteration,
52            hat,
53            backend,
54            tool_start_times: std::collections::HashMap::new(),
55            poisoned: false,
56        }
57    }
58
59    /// Updates the iteration number for subsequent events.
60    pub fn set_iteration(&mut self, iteration: u32) {
61        self.iteration = iteration;
62    }
63
64    /// Updates the hat for subsequent events.
65    pub fn set_hat(&mut self, hat: Option<String>) {
66        self.hat = hat;
67    }
68
69    /// Updates the backend for subsequent events.
70    pub fn set_backend(&mut self, backend: Option<String>) {
71        self.backend = backend;
72    }
73
74    /// Writes an event to the output, handling errors gracefully.
75    fn emit(&mut self, event: RpcEvent) {
76        if self.poisoned {
77            return;
78        }
79        let line = emit_event_line(&event);
80        if let Ok(mut writer) = self.writer.lock() {
81            if let Err(e) = writer.write_all(line.as_bytes()) {
82                warn!(error = %e, "Failed to write JSON-RPC event");
83                if e.kind() == io::ErrorKind::BrokenPipe {
84                    self.poisoned = true;
85                }
86                return;
87            }
88            // Flush immediately to ensure events are delivered promptly
89            if let Err(e) = writer.flush() {
90                warn!(error = %e, "Failed to flush JSON-RPC event");
91                if e.kind() == io::ErrorKind::BrokenPipe {
92                    self.poisoned = true;
93                }
94            }
95        }
96    }
97}
98
99impl<W: Write + Send> StreamHandler for JsonRpcStreamHandler<W> {
100    fn on_text(&mut self, text: &str) {
101        self.emit(RpcEvent::TextDelta {
102            iteration: self.iteration,
103            delta: text.to_string(),
104        });
105    }
106
107    fn on_tool_call(&mut self, name: &str, id: &str, input: &Value) {
108        // Track start time for duration calculation
109        self.tool_start_times.insert(id.to_string(), Instant::now());
110
111        self.emit(RpcEvent::ToolCallStart {
112            iteration: self.iteration,
113            tool_name: name.to_string(),
114            tool_call_id: id.to_string(),
115            input: input.clone(),
116        });
117    }
118
119    fn on_tool_result(&mut self, id: &str, output: &str) {
120        // Calculate duration from start time
121        let duration_ms = self
122            .tool_start_times
123            .remove(id)
124            .map(|start| start.elapsed().as_millis() as u64)
125            .unwrap_or(0);
126
127        self.emit(RpcEvent::ToolCallEnd {
128            iteration: self.iteration,
129            tool_call_id: id.to_string(),
130            output: output.to_string(),
131            is_error: false,
132            duration_ms,
133        });
134    }
135
136    fn on_error(&mut self, error: &str) {
137        self.emit(RpcEvent::Error {
138            iteration: self.iteration,
139            code: "EXECUTION_ERROR".to_string(),
140            message: error.to_string(),
141            recoverable: true,
142        });
143    }
144
145    fn on_complete(&mut self, result: &SessionResult) {
146        self.emit(RpcEvent::IterationEnd {
147            iteration: self.iteration,
148            duration_ms: result.duration_ms,
149            cost_usd: result.total_cost_usd,
150            input_tokens: result.input_tokens,
151            output_tokens: result.output_tokens,
152            cache_read_tokens: result.cache_read_tokens,
153            cache_write_tokens: result.cache_write_tokens,
154            loop_complete_triggered: false, // Determined externally by orchestration
155        });
156    }
157}
158
159/// Creates a JsonRpcStreamHandler writing to stdout.
160pub fn stdout_json_rpc_handler(
161    iteration: u32,
162    hat: Option<String>,
163    backend: Option<String>,
164) -> JsonRpcStreamHandler<io::Stdout> {
165    JsonRpcStreamHandler::new(Arc::new(Mutex::new(io::stdout())), iteration, hat, backend)
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a handler at iteration 3 (hat "builder", backend "claude") that
    /// writes into an in-memory buffer, and returns that buffer alongside it.
    fn capture_handler() -> (JsonRpcStreamHandler<Vec<u8>>, Arc<Mutex<Vec<u8>>>) {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let handler = JsonRpcStreamHandler::new(
            buffer.clone(),
            3,
            Some("builder".to_string()),
            Some("claude".to_string()),
        );
        (handler, buffer)
    }

    /// Returns everything written to `buffer` so far, decoded lossily as UTF-8.
    fn get_output(buffer: &Arc<Mutex<Vec<u8>>>) -> String {
        let guard = buffer.lock().unwrap();
        String::from_utf8_lossy(&guard).to_string()
    }

    /// Parses one output line, panicking if it is not valid JSON.
    fn parse_json_line(line: &str) -> serde_json::Value {
        serde_json::from_str(line).expect("should be valid JSON")
    }

    #[test]
    fn test_text_delta_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("hello world");

        let output = get_output(&buffer);
        let json = parse_json_line(output.trim());

        assert_eq!(json["type"], "text_delta");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["delta"], "hello world");
    }

    #[test]
    fn test_tool_call_start_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Bash", "call-1", &json!({"command": "ls -la"}));

        let output = get_output(&buffer);
        let json = parse_json_line(output.trim());

        assert_eq!(json["type"], "tool_call_start");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["tool_name"], "Bash");
        assert_eq!(json["tool_call_id"], "call-1");
        assert_eq!(json["input"]["command"], "ls -la");
    }

    #[test]
    fn test_tool_call_end_event() {
        let (mut handler, buffer) = capture_handler();

        // Simulate a tool call followed by result
        handler.on_tool_call("Read", "call-2", &json!({"file_path": "/tmp/test"}));
        handler.on_tool_result("call-2", "file contents here");

        let output = get_output(&buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        assert_eq!(lines.len(), 2);

        let end_json = parse_json_line(lines[1]);
        assert_eq!(end_json["type"], "tool_call_end");
        assert_eq!(end_json["iteration"], 3);
        assert_eq!(end_json["tool_call_id"], "call-2");
        assert_eq!(end_json["output"], "file contents here");
        assert_eq!(end_json["is_error"], false);
    }

    #[test]
    fn test_error_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_error("Connection timeout");

        let output = get_output(&buffer);
        let json = parse_json_line(output.trim());

        assert_eq!(json["type"], "error");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["code"], "EXECUTION_ERROR");
        assert_eq!(json["message"], "Connection timeout");
        assert_eq!(json["recoverable"], true);
    }

    #[test]
    fn test_iteration_end_event() {
        let (mut handler, buffer) = capture_handler();

        let result = SessionResult {
            duration_ms: 5432,
            total_cost_usd: 0.0054,
            num_turns: 3,
            is_error: false,
            ..Default::default()
        };
        handler.on_complete(&result);

        let output = get_output(&buffer);
        let json = parse_json_line(output.trim());

        assert_eq!(json["type"], "iteration_end");
        assert_eq!(json["iteration"], 3);
        assert_eq!(json["duration_ms"], 5432);
        assert_eq!(json["cost_usd"], 0.0054);
    }

    #[test]
    fn test_one_line_per_event() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("first");
        handler.on_text("second");
        handler.on_tool_call("Grep", "t1", &json!({"pattern": "test"}));

        let output = get_output(&buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        assert_eq!(lines.len(), 3);

        // Each line should be valid JSON
        for line in lines {
            let _ = parse_json_line(line);
        }
    }

    #[test]
    fn test_iteration_metadata_included() {
        let (mut handler, buffer) = capture_handler();

        // All events should include the iteration number
        handler.on_text("test");
        handler.on_error("error");

        let output = get_output(&buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        // Without this, the loop below would pass vacuously if nothing was written.
        assert_eq!(lines.len(), 2);
        for line in lines {
            let json = parse_json_line(line);
            assert_eq!(json["iteration"], 3, "iteration should be present");
        }
    }

    #[test]
    fn test_set_iteration_updates_subsequent_events() {
        let (mut handler, buffer) = capture_handler();

        handler.on_text("at iter 3");
        handler.set_iteration(4);
        handler.on_text("at iter 4");

        let output = get_output(&buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        assert_eq!(lines.len(), 2);

        let first = parse_json_line(lines[0]);
        let second = parse_json_line(lines[1]);

        assert_eq!(first["iteration"], 3);
        assert_eq!(second["iteration"], 4);
    }

    #[test]
    fn test_tool_duration_tracking() {
        let (mut handler, buffer) = capture_handler();

        handler.on_tool_call("Bash", "slow-call", &json!({"command": "sleep 0.01"}));
        std::thread::sleep(std::time::Duration::from_millis(10));
        handler.on_tool_result("slow-call", "done");

        let output = get_output(&buffer);
        let lines: Vec<&str> = output.trim().lines().collect();
        assert_eq!(lines.len(), 2);
        let end_json = parse_json_line(lines[1]);

        // `thread::sleep` waits at least the requested time, so the measured
        // duration must be at least 10ms.
        let duration = end_json["duration_ms"].as_u64().unwrap();
        assert!(
            duration >= 10,
            "duration should be at least 10ms, got {duration}ms"
        );
    }

    #[test]
    fn test_unknown_tool_result_has_zero_duration() {
        let (mut handler, buffer) = capture_handler();

        // Result without prior call
        handler.on_tool_result("unknown-id", "output");

        let output = get_output(&buffer);
        let json = parse_json_line(output.trim());

        assert_eq!(json["duration_ms"], 0);
    }

    /// A writer that returns BrokenPipe on every write, simulating a disconnected consumer.
    struct BrokenPipeWriter {
        write_attempts: std::cell::Cell<u32>,
    }

    impl BrokenPipeWriter {
        fn new() -> Self {
            Self {
                write_attempts: std::cell::Cell::new(0),
            }
        }

        fn attempts(&self) -> u32 {
            self.write_attempts.get()
        }
    }

    impl Write for BrokenPipeWriter {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            self.write_attempts.set(self.write_attempts.get() + 1);
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Broken pipe (os error 32)",
            ))
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// A writer that accepts every write but fails every flush with BrokenPipe.
    struct FlushBrokenPipeWriter {
        flush_attempts: u32,
    }

    impl Write for FlushBrokenPipeWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            self.flush_attempts += 1;
            Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "Broken pipe (os error 32)",
            ))
        }
    }

    /// A writer whose writes fail with an error kind other than BrokenPipe.
    struct TimedOutWriter {
        write_attempts: u32,
    }

    impl Write for TimedOutWriter {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            self.write_attempts += 1;
            Err(io::Error::new(io::ErrorKind::TimedOut, "timed out"))
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_broken_pipe_stops_emitting_after_first_failure() {
        let writer = Arc::new(Mutex::new(BrokenPipeWriter::new()));
        let mut handler = JsonRpcStreamHandler::new(
            writer.clone(),
            1,
            Some("builder".to_string()),
            Some("claude".to_string()),
        );

        // Emit many events — simulates the log spam from the bug report
        for i in 0..10 {
            handler.on_text(&format!("event {i}"));
        }

        let attempts = writer.lock().unwrap().attempts();
        // The first BrokenPipe latches the handler, so only one write (and one
        // warning) happens for all ten events.
        assert_eq!(
            attempts, 1,
            "should stop writing after first broken pipe, but attempted {attempts} writes"
        );
    }

    #[test]
    fn test_broken_pipe_on_flush_stops_emitting() {
        let writer = Arc::new(Mutex::new(FlushBrokenPipeWriter { flush_attempts: 0 }));
        let mut handler = JsonRpcStreamHandler::new(writer.clone(), 1, None, None);

        for i in 0..5 {
            handler.on_text(&format!("event {i}"));
        }

        let flushes = writer.lock().unwrap().flush_attempts;
        assert_eq!(
            flushes, 1,
            "should stop after first broken pipe on flush, but flushed {flushes} times"
        );
    }

    #[test]
    fn test_non_broken_pipe_errors_do_not_stop_emitting() {
        let writer = Arc::new(Mutex::new(TimedOutWriter { write_attempts: 0 }));
        let mut handler = JsonRpcStreamHandler::new(writer.clone(), 1, None, None);

        for i in 0..3 {
            handler.on_text(&format!("event {i}"));
        }

        // Only BrokenPipe latches the handler; other errors are retried per event.
        assert_eq!(writer.lock().unwrap().write_attempts, 3);
    }
}