Skip to main content

aivcs_core/sandbox/
execution.rs

1//! Execution controls: timeout, retry with exponential backoff, circuit breaker.
2
3use std::future::Future;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use super::error::{SandboxError, SandboxResult};
11
12/// Configuration for sandboxed tool execution.
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct SandboxConfig {
15    /// Maximum wall-clock time for a single attempt (milliseconds).
16    pub timeout_ms: u64,
17    /// Maximum number of retries (0 = no retries, run once).
18    pub max_retries: u32,
19    /// Base delay for exponential backoff between retries (milliseconds).
20    pub backoff_base_ms: u64,
21}
22
23impl Default for SandboxConfig {
24    fn default() -> Self {
25        Self {
26            timeout_ms: 30_000,
27            max_retries: 2,
28            backoff_base_ms: 500,
29        }
30    }
31}
32
/// Lock-free circuit breaker that opens after `threshold` consecutive failures.
///
/// The failure count is kept in an [`AtomicU32`], so one breaker can be shared
/// between tasks or threads (e.g. behind an `Arc`). Recording a success resets
/// the count to zero and closes the breaker. A `threshold` of `0` means the
/// breaker reports itself open from the start.
#[derive(Debug)]
pub struct CircuitBreaker {
    consecutive_failures: AtomicU32,
    threshold: u32,
}

impl CircuitBreaker {
    /// Create a closed breaker that opens once `threshold` consecutive
    /// failures have been recorded.
    pub fn new(threshold: u32) -> Self {
        Self {
            consecutive_failures: AtomicU32::new(0),
            threshold,
        }
    }

    /// Returns `true` if the breaker is open, i.e. the consecutive failure
    /// count has reached `threshold`.
    pub fn is_open(&self) -> bool {
        // `Relaxed` suffices: the counter guards no other memory, only its own value matters.
        self.consecutive_failures.load(Ordering::Relaxed) >= self.threshold
    }

    /// Record a failure and return the updated consecutive failure count.
    ///
    /// The counter saturates at `u32::MAX` instead of wrapping. Wrapping would
    /// silently reset an open breaker to zero failures and close it.
    pub fn record_failure(&self) -> u32 {
        let previous = self
            .consecutive_failures
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
                Some(n.saturating_add(1))
            })
            // The closure always returns `Some`, so this is always `Ok`; both
            // variants carry the previous value anyway.
            .unwrap_or_else(|n| n);
        previous.saturating_add(1)
    }

    /// Record a success, resetting the consecutive failure count to zero
    /// (which closes the breaker unless `threshold` is `0`).
    pub fn record_success(&self) {
        self.consecutive_failures.store(0, Ordering::Relaxed);
    }

    /// Current consecutive failure count.
    pub fn failure_count(&self) -> u32 {
        self.consecutive_failures.load(Ordering::Relaxed)
    }
}
71
72/// The result of a tool execution attempt.
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
74pub struct ToolExecutionResult {
75    /// Whether the tool succeeded.
76    pub success: bool,
77    /// Number of attempts made (1 = no retries used).
78    pub attempts: u32,
79    /// Tool output payload (present on success).
80    pub output: Option<serde_json::Value>,
81    /// Error message (present on failure).
82    pub error: Option<String>,
83}
84
85/// Execute a tool with timeout, retry, and circuit-breaker controls.
86///
87/// `tool_fn` is an async closure that performs the actual tool work and returns
88/// `Ok(serde_json::Value)` on success or `Err(String)` on failure.
89///
90/// The circuit breaker is checked before each attempt and updated after.
91pub async fn execute_with_controls<F, Fut>(
92    config: &SandboxConfig,
93    breaker: &Arc<CircuitBreaker>,
94    tool_fn: F,
95) -> SandboxResult<ToolExecutionResult>
96where
97    F: Fn() -> Fut,
98    Fut: Future<Output = Result<serde_json::Value, String>>,
99{
100    let max_attempts = config.max_retries + 1;
101
102    for attempt in 1..=max_attempts {
103        // Check circuit breaker
104        if breaker.is_open() {
105            return Err(SandboxError::CircuitBreakerOpen {
106                consecutive_failures: breaker.failure_count(),
107                threshold: breaker.threshold,
108            });
109        }
110
111        let timeout = Duration::from_millis(config.timeout_ms);
112        let result = tokio::time::timeout(timeout, tool_fn()).await;
113
114        match result {
115            Ok(Ok(value)) => {
116                breaker.record_success();
117                return Ok(ToolExecutionResult {
118                    success: true,
119                    attempts: attempt,
120                    output: Some(value),
121                    error: None,
122                });
123            }
124            Ok(Err(err_msg)) => {
125                breaker.record_failure();
126                if attempt == max_attempts {
127                    return Ok(ToolExecutionResult {
128                        success: false,
129                        attempts: attempt,
130                        output: None,
131                        error: Some(err_msg),
132                    });
133                }
134                // Exponential backoff before retry
135                let delay = Duration::from_millis(config.backoff_base_ms * 2u64.pow(attempt - 1));
136                tokio::time::sleep(delay).await;
137            }
138            Err(_elapsed) => {
139                breaker.record_failure();
140                if attempt == max_attempts {
141                    return Err(SandboxError::Timeout {
142                        elapsed_ms: config.timeout_ms,
143                        limit_ms: config.timeout_ms,
144                    });
145                }
146                let delay = Duration::from_millis(config.backoff_base_ms * 2u64.pow(attempt - 1));
147                tokio::time::sleep(delay).await;
148            }
149        }
150    }
151
152    // Unreachable, but satisfy the compiler.
153    Err(SandboxError::ExecutionFailed {
154        attempts: max_attempts,
155        reason: "exhausted all attempts".into(),
156    })
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(3);
        assert!(!cb.is_open());
        assert_eq!(cb.failure_count(), 0);
    }

    #[test]
    fn test_circuit_breaker_opens_at_threshold() {
        let cb = CircuitBreaker::new(3);
        assert_eq!(cb.record_failure(), 1);
        assert_eq!(cb.record_failure(), 2);
        assert!(!cb.is_open());
        assert_eq!(cb.record_failure(), 3);
        assert!(cb.is_open());
    }

    #[test]
    fn test_circuit_breaker_resets_on_success() {
        let cb = CircuitBreaker::new(3);
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert!(!cb.is_open());
    }

    #[test]
    fn test_sandbox_config_default() {
        let cfg = SandboxConfig::default();
        assert_eq!(cfg.timeout_ms, 30_000);
        assert_eq!(cfg.max_retries, 2);
        assert_eq!(cfg.backoff_base_ms, 500);
    }

    #[test]
    fn test_sandbox_config_serde_roundtrip() {
        let cfg = SandboxConfig {
            timeout_ms: 5000,
            max_retries: 1,
            backoff_base_ms: 100,
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let back: SandboxConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, back);
    }

    #[tokio::test]
    async fn test_execute_success_on_first_attempt() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 2,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(5));

        let result = execute_with_controls(&cfg, &breaker, || async {
            Ok(serde_json::json!({"ok": true}))
        })
        .await
        .unwrap();

        assert!(result.success);
        assert_eq!(result.attempts, 1);
        assert_eq!(result.output, Some(serde_json::json!({"ok": true})));
        assert_eq!(result.error, None);
        assert_eq!(breaker.failure_count(), 0);
    }

    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 2,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(5));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = execute_with_controls(&cfg, &breaker, move || {
            let c = counter_clone.clone();
            async move {
                let n = c.fetch_add(1, Ordering::Relaxed);
                if n < 2 {
                    Err("not yet".into())
                } else {
                    Ok(serde_json::json!({"ok": true}))
                }
            }
        })
        .await
        .unwrap();

        assert!(result.success);
        assert_eq!(result.attempts, 3);
        // The tool ran exactly once per attempt.
        assert_eq!(counter.load(Ordering::Relaxed), 3);
        // The final success resets the failures recorded by earlier attempts.
        assert_eq!(breaker.failure_count(), 0);
    }

    #[tokio::test]
    async fn test_execute_exhausts_retries() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 1,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(10));

        let result = execute_with_controls(&cfg, &breaker, || async {
            Err::<serde_json::Value, _>("always fails".to_string())
        })
        .await
        .unwrap();

        assert!(!result.success);
        assert_eq!(result.attempts, 2);
        assert_eq!(result.output, None);
        assert!(result.error.unwrap().contains("always fails"));
        assert_eq!(breaker.failure_count(), 2);
    }

    #[tokio::test]
    async fn test_execute_circuit_breaker_blocks() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 0,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(1));
        breaker.record_failure(); // open the breaker
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_tool = Arc::clone(&calls);

        let result = execute_with_controls(&cfg, &breaker, move || {
            calls_in_tool.fetch_add(1, Ordering::Relaxed);
            async { Ok(serde_json::json!({"ok": true})) }
        })
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            SandboxError::CircuitBreakerOpen {
                consecutive_failures,
                threshold,
            } => {
                assert_eq!(consecutive_failures, 1);
                assert_eq!(threshold, 1);
            }
            other => panic!("expected CircuitBreakerOpen, got {:?}", other),
        }
        // An open breaker must short-circuit before the tool is invoked.
        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_execute_breaker_trips_during_retries() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 5,
            backoff_base_ms: 1,
        };
        let breaker = Arc::new(CircuitBreaker::new(2));
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_tool = Arc::clone(&calls);

        let result = execute_with_controls(&cfg, &breaker, move || {
            calls_in_tool.fetch_add(1, Ordering::Relaxed);
            async { Err::<serde_json::Value, _>("boom".to_string()) }
        })
        .await;

        match result {
            Err(SandboxError::CircuitBreakerOpen {
                consecutive_failures,
                threshold,
            }) => {
                assert_eq!(consecutive_failures, 2);
                assert_eq!(threshold, 2);
            }
            other => panic!("expected CircuitBreakerOpen, got {:?}", other),
        }
        // Retries stop as soon as the breaker opens, well before max_retries.
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let cfg = SandboxConfig {
            timeout_ms: 50,
            max_retries: 0,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(10));

        let result = execute_with_controls(&cfg, &breaker, || async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok(serde_json::json!({"ok": true}))
        })
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            SandboxError::Timeout { limit_ms, .. } => assert_eq!(limit_ms, 50),
            other => panic!("expected Timeout, got {:?}", other),
        }
        assert_eq!(breaker.failure_count(), 1);
    }
}