1use anyhow::Result;
2use std::process::Stdio;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::process::Command;
6use tokio::sync::Mutex;
7use tracing::{debug, info, warn};
8
9use crate::agents::types::SessionConfig;
10use crate::agents::types::{
11 RetryConfig, SuccessCheck, DEFAULT_ON_FAILURE_TIMEOUT_SECONDS, DEFAULT_RETRY_TIMEOUT_SECONDS,
12};
13use crate::config::Config;
14use crate::conversation::message::Message;
15use crate::conversation::Conversation;
16use crate::tool_monitor::RepetitionInspector;
17
/// Outcome of a single pass through [`RetryManager::handle_retry_logic`].
#[derive(Debug, Clone, PartialEq)]
pub enum RetryResult {
    /// The session has no retry configuration; nothing was checked or done.
    Skipped,
    /// The configured maximum number of retry attempts has been exhausted.
    MaxAttemptsReached,
    /// All success checks passed, so no retry is necessary.
    SuccessChecksPassed,
    /// Checks failed; session state was reset and the attempt counter bumped.
    Retried,
}
30
/// Config key (read via [`Config::global`]) overriding the default timeout,
/// in seconds, applied to each success-check command.
const ASTER_RECIPE_RETRY_TIMEOUT_SECONDS: &str = "ASTER_RECIPE_RETRY_TIMEOUT_SECONDS";

/// Config key (read via [`Config::global`]) overriding the default timeout,
/// in seconds, applied to the `on_failure` cleanup command.
const ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS: &str = "ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS";
36
/// Tracks how many retry attempts a session has made and resets auxiliary
/// state (message history, final-output tool, repetition inspector) between
/// attempts.
#[derive(Debug, Default)]
pub struct RetryManager {
    // Shared attempt counter; Arc<Mutex<..>> so async tasks observe one value.
    attempts: Arc<Mutex<u32>>,
    // Optional inspector that is reset alongside the attempt counter.
    repetition_inspector: Option<Arc<Mutex<Option<RepetitionInspector>>>>,
}
45
46impl RetryManager {
47 pub fn new() -> Self {
49 Self {
50 attempts: Arc::new(Mutex::new(0)),
51 repetition_inspector: None,
52 }
53 }
54
55 pub fn with_repetition_inspector(
57 repetition_inspector: Arc<Mutex<Option<RepetitionInspector>>>,
58 ) -> Self {
59 Self {
60 attempts: Arc::new(Mutex::new(0)),
61 repetition_inspector: Some(repetition_inspector),
62 }
63 }
64
65 pub async fn reset_attempts(&self) {
67 let mut attempts = self.attempts.lock().await;
68 *attempts = 0;
69
70 if let Some(inspector) = &self.repetition_inspector {
72 if let Some(inspector) = inspector.lock().await.as_mut() {
73 inspector.reset();
74 }
75 }
76 }
77
78 pub async fn increment_attempts(&self) -> u32 {
80 let mut attempts = self.attempts.lock().await;
81 *attempts += 1;
82 *attempts
83 }
84
85 pub async fn get_attempts(&self) -> u32 {
87 *self.attempts.lock().await
88 }
89
90 async fn reset_status_for_retry(
92 messages: &mut Conversation,
93 initial_messages: &[Message],
94 final_output_tool: &Arc<Mutex<Option<crate::agents::final_output_tool::FinalOutputTool>>>,
95 ) {
96 *messages = Conversation::new_unvalidated(initial_messages.to_vec());
97 info!("Reset message history to initial state for retry");
98
99 if let Some(final_output_tool) = final_output_tool.lock().await.as_mut() {
100 final_output_tool.final_output = None;
101 info!("Cleared final output tool state for retry");
102 }
103 }
104
105 pub async fn handle_retry_logic(
106 &self,
107 messages: &mut Conversation,
108 session_config: &SessionConfig,
109 initial_messages: &[Message],
110 final_output_tool: &Arc<Mutex<Option<crate::agents::final_output_tool::FinalOutputTool>>>,
111 ) -> Result<RetryResult> {
112 let Some(retry_config) = &session_config.retry_config else {
113 return Ok(RetryResult::Skipped);
114 };
115
116 let success = execute_success_checks(&retry_config.checks, retry_config).await?;
117
118 if success {
119 info!("All success checks passed, no retry needed");
120 return Ok(RetryResult::SuccessChecksPassed);
121 }
122
123 let current_attempts = self.get_attempts().await;
124 if current_attempts >= retry_config.max_retries {
125 let error_msg = Message::assistant().with_text(format!(
126 "Maximum retry attempts ({}) exceeded. Unable to complete the task successfully.",
127 retry_config.max_retries
128 ));
129 messages.push(error_msg);
130 warn!(
131 "Maximum retry attempts ({}) exceeded",
132 retry_config.max_retries
133 );
134 crate::posthog::emit_error(
135 "retry_max_exceeded",
136 &format!("Max retries ({}) exceeded", retry_config.max_retries),
137 );
138 return Ok(RetryResult::MaxAttemptsReached);
139 }
140
141 if let Some(on_failure_cmd) = &retry_config.on_failure {
142 info!("Executing on_failure command: {}", on_failure_cmd);
143 execute_on_failure_command(on_failure_cmd, retry_config).await?;
144 }
145
146 Self::reset_status_for_retry(messages, initial_messages, final_output_tool).await;
147
148 let new_attempts = self.increment_attempts().await;
149 info!("Incrementing retry attempts to {}", new_attempts);
150
151 Ok(RetryResult::Retried)
152 }
153}
154
155fn get_retry_timeout(retry_config: &RetryConfig) -> Duration {
158 let timeout_seconds = retry_config
159 .timeout_seconds
160 .or_else(|| {
161 let config = Config::global();
162 config.get_param(ASTER_RECIPE_RETRY_TIMEOUT_SECONDS).ok()
163 })
164 .unwrap_or(DEFAULT_RETRY_TIMEOUT_SECONDS);
165
166 Duration::from_secs(timeout_seconds)
167}
168
169fn get_on_failure_timeout(retry_config: &RetryConfig) -> Duration {
172 let timeout_seconds = retry_config
173 .on_failure_timeout_seconds
174 .or_else(|| {
175 let config = Config::global();
176 config
177 .get_param(ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS)
178 .ok()
179 })
180 .unwrap_or(DEFAULT_ON_FAILURE_TIMEOUT_SECONDS);
181
182 Duration::from_secs(timeout_seconds)
183}
184
185pub async fn execute_success_checks(
187 checks: &[SuccessCheck],
188 retry_config: &RetryConfig,
189) -> Result<bool> {
190 let timeout = get_retry_timeout(retry_config);
191
192 for check in checks {
193 match check {
194 SuccessCheck::Shell { command } => {
195 let result = execute_shell_command(command, timeout).await?;
196 if !result.status.success() {
197 warn!(
198 "Success check failed: command '{}' exited with status {}, stderr: {}",
199 command,
200 result.status,
201 String::from_utf8_lossy(&result.stderr)
202 );
203 return Ok(false);
204 }
205 info!(
206 "Success check passed: command '{}' completed successfully",
207 command
208 );
209 }
210 }
211 }
212 Ok(true)
213}
214
215pub async fn execute_shell_command(
217 command: &str,
218 timeout: std::time::Duration,
219) -> Result<std::process::Output> {
220 debug!(
221 "Executing shell command with timeout {:?}: {}",
222 timeout, command
223 );
224
225 let future = async {
226 let mut cmd = if cfg!(target_os = "windows") {
227 let mut cmd = Command::new("cmd");
228 cmd.args(["/C", command]);
229 cmd.env("ASTER_TERMINAL", "1");
230 cmd
231 } else {
232 let mut cmd = Command::new("sh");
233 cmd.args(["-c", command]);
234 cmd.env("ASTER_TERMINAL", "1");
235 cmd
236 };
237
238 let output = cmd
239 .stdout(Stdio::piped())
240 .stderr(Stdio::piped())
241 .stdin(Stdio::null())
242 .kill_on_drop(true)
243 .output()
244 .await?;
245
246 debug!(
247 "Shell command completed with status: {}, stdout: {}, stderr: {}",
248 output.status,
249 String::from_utf8_lossy(&output.stdout),
250 String::from_utf8_lossy(&output.stderr)
251 );
252
253 Ok(output)
254 };
255
256 match tokio::time::timeout(timeout, future).await {
257 Ok(result) => result,
258 Err(_) => {
259 let error_msg = format!("Shell command timed out after {:?}: {}", timeout, command);
260 warn!("{}", error_msg);
261 Err(anyhow::anyhow!("{}", error_msg))
262 }
263 }
264}
265
266pub async fn execute_on_failure_command(command: &str, retry_config: &RetryConfig) -> Result<()> {
268 let timeout = get_on_failure_timeout(retry_config);
269 info!(
270 "Executing on_failure command with timeout {:?}: {}",
271 timeout, command
272 );
273
274 let output = match execute_shell_command(command, timeout).await {
275 Ok(output) => output,
276 Err(e) => {
277 if e.to_string().contains("timed out") {
278 let error_msg = format!(
279 "On_failure command timed out after {:?}: {}",
280 timeout, command
281 );
282 warn!("{}", error_msg);
283 return Err(anyhow::anyhow!(error_msg));
284 } else {
285 warn!("On_failure command execution error: {}", e);
286 return Err(e);
287 }
288 }
289 };
290
291 if !output.status.success() {
292 let error_msg = format!(
293 "On_failure command failed: command '{}' exited with status {}, stderr: {}",
294 command,
295 output.status,
296 String::from_utf8_lossy(&output.stderr)
297 );
298 warn!("{}", error_msg);
299 return Err(anyhow::anyhow!(error_msg));
300 } else {
301 info!("On_failure command completed successfully: {}", command);
302 }
303
304 Ok(())
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::types::SuccessCheck;

    // Baseline retry config with explicit timeouts so tests don't depend on
    // the global Config or environment overrides.
    fn create_test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 3,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(60),
            on_failure_timeout_seconds: Some(120),
        }
    }

    // Sanity checks on RetryResult's derived PartialEq/Clone/Debug impls.
    #[test]
    fn test_retry_result_enum() {
        assert_ne!(RetryResult::Skipped, RetryResult::MaxAttemptsReached);
        assert_ne!(RetryResult::Skipped, RetryResult::SuccessChecksPassed);
        assert_ne!(RetryResult::Skipped, RetryResult::Retried);
        assert_ne!(
            RetryResult::MaxAttemptsReached,
            RetryResult::SuccessChecksPassed
        );
        assert_ne!(RetryResult::MaxAttemptsReached, RetryResult::Retried);
        assert_ne!(RetryResult::SuccessChecksPassed, RetryResult::Retried);

        let result = RetryResult::Retried;
        let cloned = result.clone();
        assert_eq!(result, cloned);

        let debug_str = format!("{:?}", RetryResult::MaxAttemptsReached);
        assert!(debug_str.contains("MaxAttemptsReached"));
    }

    // Two always-succeeding shell commands => checks report success.
    #[tokio::test]
    async fn test_execute_success_checks_all_pass() {
        let checks = vec![
            SuccessCheck::Shell {
                command: "echo 'test'".to_string(),
            },
            SuccessCheck::Shell {
                command: "true".to_string(),
            },
        ];
        let retry_config = create_test_retry_config();

        let result = execute_success_checks(&checks, &retry_config).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    // A failing command short-circuits the checks with Ok(false), not Err.
    #[tokio::test]
    async fn test_execute_success_checks_one_fails() {
        let checks = vec![
            SuccessCheck::Shell {
                command: "echo 'test'".to_string(),
            },
            SuccessCheck::Shell {
                command: "false".to_string(),
            },
        ];
        let retry_config = create_test_retry_config();

        let result = execute_success_checks(&checks, &retry_config).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    // Successful command: Ok result, success status, stdout captured.
    #[tokio::test]
    async fn test_execute_shell_command_success() {
        let result = execute_shell_command("echo 'hello world'", Duration::from_secs(30)).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.status.success());
        assert!(String::from_utf8_lossy(&output.stdout).contains("hello world"));
    }

    // Non-zero exit is NOT an Err — caller inspects the exit status.
    #[tokio::test]
    async fn test_execute_shell_command_failure() {
        let result = execute_shell_command("false", Duration::from_secs(30)).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(!output.status.success());
    }

    // on_failure command exiting 0 => Ok(()).
    #[tokio::test]
    async fn test_execute_on_failure_command_success() {
        let retry_config = create_test_retry_config();
        let result = execute_on_failure_command("echo 'cleanup'", &retry_config).await;
        assert!(result.is_ok());
    }

    // Unlike success checks, a failing on_failure command IS an error.
    #[tokio::test]
    async fn test_execute_on_failure_command_failure() {
        let retry_config = create_test_retry_config();
        let result = execute_on_failure_command("false", &retry_config).await;
        assert!(result.is_err());
    }

    // A 1-second sleep against a 100ms timeout must surface as Err.
    #[tokio::test]
    async fn test_shell_command_timeout() {
        let timeout = std::time::Duration::from_millis(100);
        let result = if cfg!(target_os = "windows") {
            execute_shell_command("timeout /t 1", timeout).await
        } else {
            execute_shell_command("sleep 1", timeout).await
        };

        assert!(result.is_err());
    }

    // No explicit timeout => falls back to the compiled-in default
    // (assumes the global config has no override set — TODO confirm in CI).
    #[tokio::test]
    async fn test_get_retry_timeout_uses_config_default() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: None,
        };

        let timeout = get_retry_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(DEFAULT_RETRY_TIMEOUT_SECONDS));
    }

    // Explicit recipe timeout takes precedence over config/default.
    #[tokio::test]
    async fn test_get_retry_timeout_uses_retry_config() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(120),
            on_failure_timeout_seconds: None,
        };

        let timeout = get_retry_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(120));
    }

    // No explicit on_failure timeout => compiled-in default
    // (assumes no global config override — TODO confirm in CI).
    #[tokio::test]
    async fn test_get_on_failure_timeout_uses_config_default() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: None,
        };

        let timeout = get_on_failure_timeout(&retry_config);
        assert_eq!(
            timeout,
            Duration::from_secs(DEFAULT_ON_FAILURE_TIMEOUT_SECONDS)
        );
    }

    // Explicit on_failure timeout takes precedence.
    #[tokio::test]
    async fn test_get_on_failure_timeout_uses_retry_config() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: Some(900),
        };

        let timeout = get_on_failure_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(900));
    }

    // The two timeout knobs are independent of each other.
    #[tokio::test]
    async fn test_on_failure_timeout_different_from_retry_timeout() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(60),
            on_failure_timeout_seconds: Some(300),
        };

        let retry_timeout = get_retry_timeout(&retry_config);
        let on_failure_timeout = get_on_failure_timeout(&retry_config);

        assert_eq!(retry_timeout, Duration::from_secs(60));
        assert_eq!(on_failure_timeout, Duration::from_secs(300));
        assert_ne!(retry_timeout, on_failure_timeout);
    }
}