Skip to main content

aster/agents/
retry.rs

1use anyhow::Result;
2use std::process::Stdio;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::process::Command;
6use tokio::sync::Mutex;
7use tracing::{debug, info, warn};
8
9use crate::agents::types::SessionConfig;
10use crate::agents::types::{
11    RetryConfig, SuccessCheck, DEFAULT_ON_FAILURE_TIMEOUT_SECONDS, DEFAULT_RETRY_TIMEOUT_SECONDS,
12};
13use crate::config::Config;
14use crate::conversation::message::Message;
15use crate::conversation::Conversation;
16use crate::tool_monitor::RepetitionInspector;
17
/// Result of a retry logic evaluation.
///
/// Returned by [`RetryManager::handle_retry_logic`] so the caller can decide
/// whether to loop the agent again (`Retried`) or stop (all other variants).
#[derive(Debug, Clone, PartialEq)]
pub enum RetryResult {
    /// No retry configuration or session available, retry logic skipped
    Skipped,
    /// Maximum retry attempts reached, cannot retry further
    MaxAttemptsReached,
    /// Success checks passed, no retry needed
    SuccessChecksPassed,
    /// Retry is needed and will be performed
    Retried,
}
30
/// Environment variable for configuring retry timeout globally.
/// Used as a fallback when `RetryConfig.timeout_seconds` is not set
/// (see `get_retry_timeout`).
const ASTER_RECIPE_RETRY_TIMEOUT_SECONDS: &str = "ASTER_RECIPE_RETRY_TIMEOUT_SECONDS";

/// Environment variable for configuring on_failure timeout globally.
/// Used as a fallback when `RetryConfig.on_failure_timeout_seconds` is not set
/// (see `get_on_failure_timeout`).
const ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS: &str = "ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS";
36
/// Manages retry state and operations for agent execution.
///
/// Both fields are behind `Arc<Mutex<..>>` so the manager can be shared
/// across async tasks; all accessors are `async` and take the lock briefly.
#[derive(Debug, Default)]
pub struct RetryManager {
    /// Current number of retry attempts
    attempts: Arc<Mutex<u32>>,
    /// Optional repetition inspector for reset operations
    // When present, it is reset together with the attempt counter in
    // `reset_attempts` so stale repetition state doesn't leak across runs.
    repetition_inspector: Option<Arc<Mutex<Option<RepetitionInspector>>>>,
}
45
46impl RetryManager {
47    /// Create a new retry manager
48    pub fn new() -> Self {
49        Self {
50            attempts: Arc::new(Mutex::new(0)),
51            repetition_inspector: None,
52        }
53    }
54
55    /// Create a new retry manager with repetition inspector
56    pub fn with_repetition_inspector(
57        repetition_inspector: Arc<Mutex<Option<RepetitionInspector>>>,
58    ) -> Self {
59        Self {
60            attempts: Arc::new(Mutex::new(0)),
61            repetition_inspector: Some(repetition_inspector),
62        }
63    }
64
65    /// Reset the retry attempts counter to 0
66    pub async fn reset_attempts(&self) {
67        let mut attempts = self.attempts.lock().await;
68        *attempts = 0;
69
70        // Reset repetition inspector if available
71        if let Some(inspector) = &self.repetition_inspector {
72            if let Some(inspector) = inspector.lock().await.as_mut() {
73                inspector.reset();
74            }
75        }
76    }
77
78    /// Increment the retry attempts counter and return the new value
79    pub async fn increment_attempts(&self) -> u32 {
80        let mut attempts = self.attempts.lock().await;
81        *attempts += 1;
82        *attempts
83    }
84
85    /// Get the current retry attempts count
86    pub async fn get_attempts(&self) -> u32 {
87        *self.attempts.lock().await
88    }
89
90    /// Reset status for retry: clear message history and final output tool state
91    async fn reset_status_for_retry(
92        messages: &mut Conversation,
93        initial_messages: &[Message],
94        final_output_tool: &Arc<Mutex<Option<crate::agents::final_output_tool::FinalOutputTool>>>,
95    ) {
96        *messages = Conversation::new_unvalidated(initial_messages.to_vec());
97        info!("Reset message history to initial state for retry");
98
99        if let Some(final_output_tool) = final_output_tool.lock().await.as_mut() {
100            final_output_tool.final_output = None;
101            info!("Cleared final output tool state for retry");
102        }
103    }
104
105    pub async fn handle_retry_logic(
106        &self,
107        messages: &mut Conversation,
108        session_config: &SessionConfig,
109        initial_messages: &[Message],
110        final_output_tool: &Arc<Mutex<Option<crate::agents::final_output_tool::FinalOutputTool>>>,
111    ) -> Result<RetryResult> {
112        let Some(retry_config) = &session_config.retry_config else {
113            return Ok(RetryResult::Skipped);
114        };
115
116        let success = execute_success_checks(&retry_config.checks, retry_config).await?;
117
118        if success {
119            info!("All success checks passed, no retry needed");
120            return Ok(RetryResult::SuccessChecksPassed);
121        }
122
123        let current_attempts = self.get_attempts().await;
124        if current_attempts >= retry_config.max_retries {
125            let error_msg = Message::assistant().with_text(format!(
126                "Maximum retry attempts ({}) exceeded. Unable to complete the task successfully.",
127                retry_config.max_retries
128            ));
129            messages.push(error_msg);
130            warn!(
131                "Maximum retry attempts ({}) exceeded",
132                retry_config.max_retries
133            );
134            crate::posthog::emit_error(
135                "retry_max_exceeded",
136                &format!("Max retries ({}) exceeded", retry_config.max_retries),
137            );
138            return Ok(RetryResult::MaxAttemptsReached);
139        }
140
141        if let Some(on_failure_cmd) = &retry_config.on_failure {
142            info!("Executing on_failure command: {}", on_failure_cmd);
143            execute_on_failure_command(on_failure_cmd, retry_config).await?;
144        }
145
146        Self::reset_status_for_retry(messages, initial_messages, final_output_tool).await;
147
148        let new_attempts = self.increment_attempts().await;
149        info!("Incrementing retry attempts to {}", new_attempts);
150
151        Ok(RetryResult::Retried)
152    }
153}
154
155/// Get the configured timeout duration for retry operations
156/// retry_config.timeout_seconds -> env var -> default
157fn get_retry_timeout(retry_config: &RetryConfig) -> Duration {
158    let timeout_seconds = retry_config
159        .timeout_seconds
160        .or_else(|| {
161            let config = Config::global();
162            config.get_param(ASTER_RECIPE_RETRY_TIMEOUT_SECONDS).ok()
163        })
164        .unwrap_or(DEFAULT_RETRY_TIMEOUT_SECONDS);
165
166    Duration::from_secs(timeout_seconds)
167}
168
169/// Get the configured timeout duration for on_failure operations
170/// retry_config.on_failure_timeout_seconds -> env var -> default
171fn get_on_failure_timeout(retry_config: &RetryConfig) -> Duration {
172    let timeout_seconds = retry_config
173        .on_failure_timeout_seconds
174        .or_else(|| {
175            let config = Config::global();
176            config
177                .get_param(ASTER_RECIPE_ON_FAILURE_TIMEOUT_SECONDS)
178                .ok()
179        })
180        .unwrap_or(DEFAULT_ON_FAILURE_TIMEOUT_SECONDS);
181
182    Duration::from_secs(timeout_seconds)
183}
184
185/// Execute all success checks and return true if all pass
186pub async fn execute_success_checks(
187    checks: &[SuccessCheck],
188    retry_config: &RetryConfig,
189) -> Result<bool> {
190    let timeout = get_retry_timeout(retry_config);
191
192    for check in checks {
193        match check {
194            SuccessCheck::Shell { command } => {
195                let result = execute_shell_command(command, timeout).await?;
196                if !result.status.success() {
197                    warn!(
198                        "Success check failed: command '{}' exited with status {}, stderr: {}",
199                        command,
200                        result.status,
201                        String::from_utf8_lossy(&result.stderr)
202                    );
203                    return Ok(false);
204                }
205                info!(
206                    "Success check passed: command '{}' completed successfully",
207                    command
208                );
209            }
210        }
211    }
212    Ok(true)
213}
214
215/// Execute a shell command with cross-platform compatibility and mandatory timeout
216pub async fn execute_shell_command(
217    command: &str,
218    timeout: std::time::Duration,
219) -> Result<std::process::Output> {
220    debug!(
221        "Executing shell command with timeout {:?}: {}",
222        timeout, command
223    );
224
225    let future = async {
226        let mut cmd = if cfg!(target_os = "windows") {
227            let mut cmd = Command::new("cmd");
228            cmd.args(["/C", command]);
229            cmd.env("ASTER_TERMINAL", "1");
230            cmd
231        } else {
232            let mut cmd = Command::new("sh");
233            cmd.args(["-c", command]);
234            cmd.env("ASTER_TERMINAL", "1");
235            cmd
236        };
237
238        let output = cmd
239            .stdout(Stdio::piped())
240            .stderr(Stdio::piped())
241            .stdin(Stdio::null())
242            .kill_on_drop(true)
243            .output()
244            .await?;
245
246        debug!(
247            "Shell command completed with status: {}, stdout: {}, stderr: {}",
248            output.status,
249            String::from_utf8_lossy(&output.stdout),
250            String::from_utf8_lossy(&output.stderr)
251        );
252
253        Ok(output)
254    };
255
256    match tokio::time::timeout(timeout, future).await {
257        Ok(result) => result,
258        Err(_) => {
259            let error_msg = format!("Shell command timed out after {:?}: {}", timeout, command);
260            warn!("{}", error_msg);
261            Err(anyhow::anyhow!("{}", error_msg))
262        }
263    }
264}
265
266/// Execute an on_failure command and return an error if it fails
267pub async fn execute_on_failure_command(command: &str, retry_config: &RetryConfig) -> Result<()> {
268    let timeout = get_on_failure_timeout(retry_config);
269    info!(
270        "Executing on_failure command with timeout {:?}: {}",
271        timeout, command
272    );
273
274    let output = match execute_shell_command(command, timeout).await {
275        Ok(output) => output,
276        Err(e) => {
277            if e.to_string().contains("timed out") {
278                let error_msg = format!(
279                    "On_failure command timed out after {:?}: {}",
280                    timeout, command
281                );
282                warn!("{}", error_msg);
283                return Err(anyhow::anyhow!(error_msg));
284            } else {
285                warn!("On_failure command execution error: {}", e);
286                return Err(e);
287            }
288        }
289    };
290
291    if !output.status.success() {
292        let error_msg = format!(
293            "On_failure command failed: command '{}' exited with status {}, stderr: {}",
294            command,
295            output.status,
296            String::from_utf8_lossy(&output.stderr)
297        );
298        warn!("{}", error_msg);
299        return Err(anyhow::anyhow!(error_msg));
300    } else {
301        info!("On_failure command completed successfully: {}", command);
302    }
303
304    Ok(())
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::types::SuccessCheck;

    // Shared fixture: a RetryConfig with explicit (non-default) timeouts so
    // tests don't depend on env vars or global config.
    fn create_test_retry_config() -> RetryConfig {
        RetryConfig {
            max_retries: 3,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(60),
            on_failure_timeout_seconds: Some(120),
        }
    }

    // Sanity-checks the derived PartialEq/Clone/Debug impls on RetryResult.
    #[test]
    fn test_retry_result_enum() {
        assert_ne!(RetryResult::Skipped, RetryResult::MaxAttemptsReached);
        assert_ne!(RetryResult::Skipped, RetryResult::SuccessChecksPassed);
        assert_ne!(RetryResult::Skipped, RetryResult::Retried);
        assert_ne!(
            RetryResult::MaxAttemptsReached,
            RetryResult::SuccessChecksPassed
        );
        assert_ne!(RetryResult::MaxAttemptsReached, RetryResult::Retried);
        assert_ne!(RetryResult::SuccessChecksPassed, RetryResult::Retried);

        let result = RetryResult::Retried;
        let cloned = result.clone();
        assert_eq!(result, cloned);

        let debug_str = format!("{:?}", RetryResult::MaxAttemptsReached);
        assert!(debug_str.contains("MaxAttemptsReached"));
    }

    // All checks exit 0 -> Ok(true).
    #[tokio::test]
    async fn test_execute_success_checks_all_pass() {
        let checks = vec![
            SuccessCheck::Shell {
                command: "echo 'test'".to_string(),
            },
            SuccessCheck::Shell {
                command: "true".to_string(),
            },
        ];
        let retry_config = create_test_retry_config();

        let result = execute_success_checks(&checks, &retry_config).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    // A failing check short-circuits to Ok(false) — NOT an Err.
    #[tokio::test]
    async fn test_execute_success_checks_one_fails() {
        let checks = vec![
            SuccessCheck::Shell {
                command: "echo 'test'".to_string(),
            },
            SuccessCheck::Shell {
                command: "false".to_string(),
            },
        ];
        let retry_config = create_test_retry_config();

        let result = execute_success_checks(&checks, &retry_config).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[tokio::test]
    async fn test_execute_shell_command_success() {
        let result = execute_shell_command("echo 'hello world'", Duration::from_secs(30)).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.status.success());
        assert!(String::from_utf8_lossy(&output.stdout).contains("hello world"));
    }

    // Non-zero exit is Ok(output) with a failing status, not an Err.
    #[tokio::test]
    async fn test_execute_shell_command_failure() {
        let result = execute_shell_command("false", Duration::from_secs(30)).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(!output.status.success());
    }

    #[tokio::test]
    async fn test_execute_on_failure_command_success() {
        let retry_config = create_test_retry_config();
        let result = execute_on_failure_command("echo 'cleanup'", &retry_config).await;
        assert!(result.is_ok());
    }

    // Unlike success checks, a failing on_failure command IS an error.
    #[tokio::test]
    async fn test_execute_on_failure_command_failure() {
        let retry_config = create_test_retry_config();
        let result = execute_on_failure_command("false", &retry_config).await;
        assert!(result.is_err());
    }

    // A command that outlives its 100ms timeout must return Err.
    #[tokio::test]
    async fn test_shell_command_timeout() {
        let timeout = std::time::Duration::from_millis(100);
        let result = if cfg!(target_os = "windows") {
            execute_shell_command("timeout /t 1", timeout).await
        } else {
            execute_shell_command("sleep 1", timeout).await
        };

        assert!(result.is_err());
    }

    // With no explicit value (and assuming no env override), the compiled-in
    // default applies.
    #[tokio::test]
    async fn test_get_retry_timeout_uses_config_default() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: None,
        };

        let timeout = get_retry_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(DEFAULT_RETRY_TIMEOUT_SECONDS));
    }

    // An explicit timeout_seconds takes precedence over env/default.
    #[tokio::test]
    async fn test_get_retry_timeout_uses_retry_config() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(120),
            on_failure_timeout_seconds: None,
        };

        let timeout = get_retry_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(120));
    }

    #[tokio::test]
    async fn test_get_on_failure_timeout_uses_config_default() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: None,
        };

        let timeout = get_on_failure_timeout(&retry_config);
        assert_eq!(
            timeout,
            Duration::from_secs(DEFAULT_ON_FAILURE_TIMEOUT_SECONDS)
        );
    }

    #[tokio::test]
    async fn test_get_on_failure_timeout_uses_retry_config() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: None,
            on_failure_timeout_seconds: Some(900),
        };

        let timeout = get_on_failure_timeout(&retry_config);
        assert_eq!(timeout, Duration::from_secs(900));
    }

    // The two timeout knobs are independent: each reads its own field.
    #[tokio::test]
    async fn test_on_failure_timeout_different_from_retry_timeout() {
        let retry_config = RetryConfig {
            max_retries: 1,
            checks: vec![],
            on_failure: None,
            timeout_seconds: Some(60),
            on_failure_timeout_seconds: Some(300),
        };

        let retry_timeout = get_retry_timeout(&retry_config);
        let on_failure_timeout = get_on_failure_timeout(&retry_config);

        assert_eq!(retry_timeout, Duration::from_secs(60));
        assert_eq!(on_failure_timeout, Duration::from_secs(300));
        assert_ne!(retry_timeout, on_failure_timeout);
    }
}