ralph_workflow/reducer/
fault_tolerant_executor.rs1use crate::agents::{AgentRole, JsonParserType};
13use crate::pipeline::{run_with_prompt, PipelineRuntime, PromptCommand};
14use crate::reducer::event::{AgentErrorKind, PipelineEvent};
15use anyhow::Result;
16use std::io;
17
18#[derive(Clone, Copy)]
20pub struct AgentExecutionConfig<'a> {
21 pub role: AgentRole,
23 pub agent_name: &'a str,
25 pub cmd_str: &'a str,
27 pub parser_type: JsonParserType,
29 pub env_vars: &'a std::collections::HashMap<String, String>,
31 pub prompt: &'a str,
33 pub display_name: &'a str,
35 pub logfile: &'a str,
37}
38
39pub fn execute_agent_fault_tolerantly(
61 config: AgentExecutionConfig<'_>,
62 runtime: &mut PipelineRuntime<'_>,
63) -> Result<PipelineEvent> {
64 let role = config.role;
65
66 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
67 try_agent_execution(config, runtime)
68 }));
69
70 match result {
71 Ok(event_result) => event_result,
72 Err(_) => {
73 let error_kind = AgentErrorKind::InternalError;
74 let retriable = is_retriable_agent_error(&error_kind);
75
76 Ok(PipelineEvent::AgentInvocationFailed {
77 role,
78 agent: config.agent_name.to_string(),
79 exit_code: 1,
80 error_kind,
81 retriable,
82 })
83 }
84 }
85}
86
87fn try_agent_execution(
93 config: AgentExecutionConfig<'_>,
94 runtime: &mut PipelineRuntime<'_>,
95) -> Result<PipelineEvent> {
96 let prompt_cmd = PromptCommand {
97 label: config.agent_name,
98 display_name: config.display_name,
99 cmd_str: config.cmd_str,
100 prompt: config.prompt,
101 logfile: config.logfile,
102 parser_type: config.parser_type,
103 env_vars: config.env_vars,
104 };
105
106 match run_with_prompt(&prompt_cmd, runtime) {
107 Ok(result) if result.exit_code == 0 => Ok(PipelineEvent::AgentInvocationSucceeded {
108 role: config.role,
109 agent: config.agent_name.to_string(),
110 }),
111 Ok(result) => {
112 let exit_code = result.exit_code;
113 let error_kind = classify_agent_error(exit_code, &result.stderr);
114 let retriable = is_retriable_agent_error(&error_kind);
115
116 Ok(PipelineEvent::AgentInvocationFailed {
117 role: config.role,
118 agent: config.agent_name.to_string(),
119 exit_code,
120 error_kind,
121 retriable,
122 })
123 }
124 Err(e) => {
125 let error_kind = if let Ok(io_err) = e.downcast::<io::Error>() {
126 classify_io_error(&io_err)
127 } else {
128 AgentErrorKind::InternalError
129 };
130 let retriable = is_retriable_agent_error(&error_kind);
131
132 Ok(PipelineEvent::AgentInvocationFailed {
133 role: config.role,
134 agent: config.agent_name.to_string(),
135 exit_code: 1,
136 error_kind,
137 retriable,
138 })
139 }
140 }
141}
142
143fn classify_agent_error(exit_code: i32, stderr: &str) -> AgentErrorKind {
145 const SIGSEGV: i32 = 139;
146 const SIGABRT: i32 = 134;
147 const SIGTERM: i32 = 143;
148
149 match exit_code {
150 SIGSEGV | SIGABRT => AgentErrorKind::InternalError,
151 SIGTERM => AgentErrorKind::Timeout,
152 _ => {
153 let stderr_lower = stderr.to_lowercase();
154
155 if stderr_lower.contains("network")
156 || stderr_lower.contains("connection")
157 || stderr_lower.contains("timeout")
158 {
159 AgentErrorKind::Network
160 } else if stderr_lower.contains("auth")
161 || stderr_lower.contains("api key")
162 || stderr_lower.contains("unauthorized")
163 {
164 AgentErrorKind::Authentication
165 } else if stderr_lower.contains("rate limit")
166 || stderr_lower.contains("quota")
167 || stderr_lower.contains("too many requests")
168 {
169 AgentErrorKind::RateLimit
170 } else if stderr_lower.contains("model")
171 && (stderr_lower.contains("not found") || stderr_lower.contains("unavailable"))
172 {
173 AgentErrorKind::ModelUnavailable
174 } else if stderr_lower.contains("parse")
175 || stderr_lower.contains("invalid")
176 || stderr_lower.contains("malformed")
177 {
178 AgentErrorKind::ParsingError
179 } else if stderr_lower.contains("permission")
180 || stderr_lower.contains("access denied")
181 || stderr_lower.contains("file")
182 {
183 AgentErrorKind::FileSystem
184 } else {
185 AgentErrorKind::InternalError
186 }
187 }
188 }
189}
190
191fn classify_io_error(error: &io::Error) -> AgentErrorKind {
193 let error_msg = error.to_string().to_lowercase();
194
195 if error_msg.contains("timeout") {
196 AgentErrorKind::Timeout
197 } else if error_msg.contains("permission")
198 || error_msg.contains("access denied")
199 || error_msg.contains("no such file")
200 || error_msg.contains("not found")
201 {
202 AgentErrorKind::FileSystem
203 } else if error_msg.contains("broken pipe") || error_msg.contains("connection") {
204 AgentErrorKind::Network
205 } else {
206 AgentErrorKind::InternalError
207 }
208}
209
210fn is_retriable_agent_error(error_kind: &AgentErrorKind) -> bool {
215 matches!(
216 error_kind,
217 AgentErrorKind::Network
218 | AgentErrorKind::RateLimit
219 | AgentErrorKind::Timeout
220 | AgentErrorKind::ModelUnavailable
221 )
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // --- classify_agent_error: exit codes that short-circuit stderr checks ---

    #[test]
    fn test_classify_agent_error_sigsegv() {
        assert_eq!(classify_agent_error(139, ""), AgentErrorKind::InternalError);
    }

    #[test]
    fn test_classify_agent_error_sigabrt() {
        assert_eq!(classify_agent_error(134, ""), AgentErrorKind::InternalError);
    }

    #[test]
    fn test_classify_agent_error_sigterm() {
        assert_eq!(classify_agent_error(143, ""), AgentErrorKind::Timeout);
    }

    // --- classify_agent_error: stderr keyword matching ---

    #[test]
    fn test_classify_agent_error_network() {
        assert_eq!(
            classify_agent_error(1, "Connection timeout"),
            AgentErrorKind::Network
        );
    }

    #[test]
    fn test_classify_agent_error_rate_limit() {
        assert_eq!(
            classify_agent_error(1, "Rate limit exceeded"),
            AgentErrorKind::RateLimit
        );
    }

    #[test]
    fn test_classify_agent_error_authentication() {
        assert_eq!(
            classify_agent_error(1, "Invalid API key"),
            AgentErrorKind::Authentication
        );
    }

    #[test]
    fn test_classify_agent_error_model_unavailable() {
        assert_eq!(
            classify_agent_error(1, "Model not found"),
            AgentErrorKind::ModelUnavailable
        );
    }

    // --- is_retriable_agent_error ---

    #[test]
    fn test_is_retriable_agent_error() {
        let retriable = [
            AgentErrorKind::Network,
            AgentErrorKind::RateLimit,
            AgentErrorKind::Timeout,
            AgentErrorKind::ModelUnavailable,
        ];
        for kind in &retriable {
            assert!(
                is_retriable_agent_error(kind),
                "{:?} should be retriable",
                kind
            );
        }

        let permanent = [
            AgentErrorKind::Authentication,
            AgentErrorKind::ParsingError,
            AgentErrorKind::FileSystem,
            AgentErrorKind::InternalError,
        ];
        for kind in &permanent {
            assert!(
                !is_retriable_agent_error(kind),
                "{:?} should not be retriable",
                kind
            );
        }
    }

    // --- classify_io_error ---

    #[test]
    fn test_classify_io_error_timeout() {
        let timed_out = io::Error::new(io::ErrorKind::TimedOut, "Operation timeout");
        assert_eq!(classify_io_error(&timed_out), AgentErrorKind::Timeout);
    }

    #[test]
    fn test_classify_io_error_filesystem() {
        let denied = io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied");
        assert_eq!(classify_io_error(&denied), AgentErrorKind::FileSystem);
    }

    #[test]
    fn test_classify_io_error_network() {
        let broken_pipe = io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe");
        assert_eq!(classify_io_error(&broken_pipe), AgentErrorKind::Network);
    }
}