1use crate::diagnostics::agent_output::AgentOutputLogger;
4use std::sync::{Arc, Mutex};
5
6#[allow(dead_code)]
11pub struct DiagnosticStreamHandler<H> {
12 inner: H,
13 logger: Arc<Mutex<AgentOutputLogger>>,
14}
15
16impl<H> DiagnosticStreamHandler<H> {
17 pub fn new(inner: H, logger: Arc<Mutex<AgentOutputLogger>>) -> Self {
19 Self { inner, logger }
20 }
21}
22
23#[cfg(test)]
24mod tests {
25 use super::*;
26 use crate::diagnostics::agent_output::{AgentOutputContent, AgentOutputEntry};
27 use std::fs::File;
28 use std::io::{BufRead, BufReader};
29 use tempfile::TempDir;
30
/// Test double that stores what it is given in shared, lockable buffers.
///
/// Each field sits behind `Arc<Mutex<_>>` so a test can clone the handle
/// before moving the mock elsewhere and still inspect the contents afterwards.
struct MockStreamHandler {
    /// Recorded text strings.
    text_calls: Arc<Mutex<Vec<String>>>,
    /// Recorded tool-call strings.
    tool_calls: Arc<Mutex<Vec<String>>>,
    /// Recorded tool-result strings.
    tool_results: Arc<Mutex<Vec<String>>>,
    /// Recorded error strings.
    errors: Arc<Mutex<Vec<String>>>,
    /// Counter of completion notifications.
    completes: Arc<Mutex<usize>>,
}
39
40 impl MockStreamHandler {
41 fn new() -> Self {
42 Self {
43 text_calls: Arc::new(Mutex::new(Vec::new())),
44 tool_calls: Arc::new(Mutex::new(Vec::new())),
45 tool_results: Arc::new(Mutex::new(Vec::new())),
46 errors: Arc::new(Mutex::new(Vec::new())),
47 completes: Arc::new(Mutex::new(0)),
48 }
49 }
50 }
51
52 trait StreamHandler: Send {
55 fn on_text(&mut self, text: &str);
56 fn on_tool_call(&mut self, name: &str, id: &str, input: &serde_json::Value);
57 fn on_tool_result(&mut self, id: &str, output: &str);
58 fn on_error(&mut self, error: &str);
59 fn on_complete(&mut self, result: &SessionResult);
60 }
61
/// Summary value passed to the completion callback in these tests.
// The visible code only constructs this value; the fields are never read
// individually, hence the lint allowance.
#[derive(Debug)]
#[allow(dead_code)]
struct SessionResult {
    /// Error flag for the session.
    is_error: bool,
    /// Duration; presumably milliseconds, going by the field name.
    duration_ms: u64,
    /// Cost; presumably US dollars, going by the field name.
    total_cost_usd: f64,
    /// Turn count.
    num_turns: u32,
}
71
72 impl StreamHandler for MockStreamHandler {
73 fn on_text(&mut self, text: &str) {
74 self.text_calls.lock().unwrap().push(text.to_string());
75 }
76
77 fn on_tool_call(&mut self, name: &str, id: &str, _input: &serde_json::Value) {
78 self.tool_calls
79 .lock()
80 .unwrap()
81 .push(format!("{}:{}", name, id));
82 }
83
84 fn on_tool_result(&mut self, id: &str, output: &str) {
85 self.tool_results
86 .lock()
87 .unwrap()
88 .push(format!("{}:{}", id, output));
89 }
90
91 fn on_error(&mut self, error: &str) {
92 self.errors.lock().unwrap().push(error.to_string());
93 }
94
95 fn on_complete(&mut self, _result: &SessionResult) {
96 *self.completes.lock().unwrap() += 1;
97 }
98 }
99
100 impl<H: StreamHandler> StreamHandler for DiagnosticStreamHandler<H> {
104 fn on_text(&mut self, text: &str) {
105 let _ = self.logger.lock().unwrap().log(AgentOutputContent::Text {
106 text: text.to_string(),
107 });
108 self.inner.on_text(text);
109 }
110
111 fn on_tool_call(&mut self, name: &str, id: &str, input: &serde_json::Value) {
112 let _ = self
113 .logger
114 .lock()
115 .unwrap()
116 .log(AgentOutputContent::ToolCall {
117 name: name.to_string(),
118 id: id.to_string(),
119 input: input.clone(),
120 });
121 self.inner.on_tool_call(name, id, input);
122 }
123
124 fn on_tool_result(&mut self, id: &str, output: &str) {
125 let _ = self
126 .logger
127 .lock()
128 .unwrap()
129 .log(AgentOutputContent::ToolResult {
130 id: id.to_string(),
131 output: output.to_string(),
132 });
133 self.inner.on_tool_result(id, output);
134 }
135
136 fn on_error(&mut self, error: &str) {
137 let _ = self.logger.lock().unwrap().log(AgentOutputContent::Error {
138 message: error.to_string(),
139 });
140 self.inner.on_error(error);
141 }
142
143 fn on_complete(&mut self, result: &SessionResult) {
144 let _ = self
145 .logger
146 .lock()
147 .unwrap()
148 .log(AgentOutputContent::Complete {
149 input_tokens: None,
150 output_tokens: None,
151 });
152 self.inner.on_complete(result);
153 }
154 }
155
/// Every event sent to the wrapper must also arrive at the wrapped handler.
#[test]
fn test_wrapper_calls_inner_handler() {
    let scratch = TempDir::new().unwrap();
    let logger = AgentOutputLogger::new(scratch.path()).unwrap();
    let logger = Arc::new(Mutex::new(logger));
    logger.lock().unwrap().set_context(1, "ralph");

    // Keep handles to the mock's buffers before the mock is moved into the
    // wrapper.
    let inner = MockStreamHandler::new();
    let seen_text = Arc::clone(&inner.text_calls);
    let seen_tool_calls = Arc::clone(&inner.tool_calls);
    let seen_errors = Arc::clone(&inner.errors);

    let mut handler = DiagnosticStreamHandler::new(inner, logger);
    handler.on_text("Hello");
    handler.on_tool_call("Read", "t1", &serde_json::json!({"file": "test.rs"}));
    handler.on_error("Failed");

    // Comparing the whole buffer checks both the count and the content.
    assert_eq!(*seen_text.lock().unwrap(), ["Hello"]);
    assert_eq!(*seen_tool_calls.lock().unwrap(), ["Read:t1"]);
    assert_eq!(*seen_errors.lock().unwrap(), ["Failed"]);
}
183
/// One event of each kind must produce one JSONL line, in call order.
#[test]
fn test_wrapper_logs_all_events() {
    let scratch = TempDir::new().unwrap();
    let logger = AgentOutputLogger::new(scratch.path()).unwrap();
    let logger = Arc::new(Mutex::new(logger));
    logger.lock().unwrap().set_context(1, "ralph");

    let mut handler = DiagnosticStreamHandler::new(MockStreamHandler::new(), logger);

    handler.on_text("Building");
    handler.on_tool_call("Execute", "t1", &serde_json::json!({"cmd": "cargo test"}));
    handler.on_tool_result("t1", "Tests passed");
    handler.on_error("Parse error");
    let summary = SessionResult {
        is_error: false,
        duration_ms: 1000,
        total_cost_usd: 0.05,
        num_turns: 3,
    };
    handler.on_complete(&summary);

    // The wrapper holds the only logger handle; release it before reading
    // the file back.
    drop(handler);

    let log_path = scratch.path().join("agent-output.jsonl");
    let lines = BufReader::new(File::open(log_path).unwrap())
        .lines()
        .collect::<Result<Vec<String>, _>>()
        .unwrap();

    assert_eq!(lines.len(), 5);

    let entries = lines
        .iter()
        .map(|line| serde_json::from_str::<AgentOutputEntry>(line).unwrap())
        .collect::<Vec<_>>();

    // Walk the entries in file order and check each variant in turn.
    let mut contents = entries.iter().map(|entry| &entry.content);
    assert!(matches!(
        contents.next(),
        Some(AgentOutputContent::Text { .. })
    ));
    assert!(matches!(
        contents.next(),
        Some(AgentOutputContent::ToolCall { .. })
    ));
    assert!(matches!(
        contents.next(),
        Some(AgentOutputContent::ToolResult { .. })
    ));
    assert!(matches!(
        contents.next(),
        Some(AgentOutputContent::Error { .. })
    ));
    assert!(matches!(
        contents.next(),
        Some(AgentOutputContent::Complete { .. })
    ));
}
240
/// Two threads logging through clones of one logger must yield every line.
#[test]
fn test_thread_safety() {
    use std::thread;

    let scratch = TempDir::new().unwrap();
    let logger = AgentOutputLogger::new(scratch.path()).unwrap();
    let logger = Arc::new(Mutex::new(logger));
    logger.lock().unwrap().set_context(1, "ralph");

    // Collecting eagerly starts both workers before either one is joined.
    let workers: Vec<_> = ["Thread1", "Thread2"]
        .iter()
        .map(|&label| {
            let shared = Arc::clone(&logger);
            thread::spawn(move || {
                let mut handler = DiagnosticStreamHandler::new(MockStreamHandler::new(), shared);
                for i in 0..10 {
                    handler.on_text(&format!("{}-{}", label, i));
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let log_path = scratch.path().join("agent-output.jsonl");
    let line_count = BufReader::new(File::open(log_path).unwrap())
        .lines()
        .map(|line| line.unwrap())
        .count();

    // 2 workers x 10 messages each.
    assert_eq!(line_count, 20);
}
278}