1use std::future::Future;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use super::error::{SandboxError, SandboxResult};
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub struct SandboxConfig {
15 pub timeout_ms: u64,
17 pub max_retries: u32,
19 pub backoff_base_ms: u64,
21}
22
23impl Default for SandboxConfig {
24 fn default() -> Self {
25 Self {
26 timeout_ms: 30_000,
27 max_retries: 2,
28 backoff_base_ms: 500,
29 }
30 }
31}
32
/// Counts consecutive failures and reports itself "open" once the count
/// reaches a configured threshold.
///
/// The counter is atomic, so one breaker can be shared (e.g. behind an
/// `Arc`) by concurrent callers. The counter is the only shared state, so
/// `Ordering::Relaxed` is sufficient for every operation.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Failures recorded since the last success (saturates at `u32::MAX`).
    consecutive_failures: AtomicU32,
    /// Failure count at or above which the breaker is open.
    threshold: u32,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` consecutive
    /// failures.
    ///
    /// A `threshold` of `0` produces a breaker that is open from the start.
    pub fn new(threshold: u32) -> Self {
        Self {
            consecutive_failures: AtomicU32::new(0),
            threshold,
        }
    }

    /// Returns `true` when the consecutive-failure count has reached the
    /// threshold.
    pub fn is_open(&self) -> bool {
        self.consecutive_failures.load(Ordering::Relaxed) >= self.threshold
    }

    /// Records one failure and returns the updated consecutive-failure count.
    ///
    /// The count saturates at `u32::MAX` rather than wrapping to zero, which
    /// would silently close an open breaker (and make the returned `+ 1`
    /// overflow).
    pub fn record_failure(&self) -> u32 {
        let previous = self
            .consecutive_failures
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_add(1))
            // `Err` carries the current value; it only occurs at `u32::MAX`,
            // where the counter is intentionally left unchanged.
            .unwrap_or_else(|current| current);
        previous.saturating_add(1)
    }

    /// Resets the consecutive-failure count to zero, closing the breaker.
    pub fn record_success(&self) {
        self.consecutive_failures.store(0, Ordering::Relaxed);
    }

    /// Returns the current consecutive-failure count.
    pub fn failure_count(&self) -> u32 {
        self.consecutive_failures.load(Ordering::Relaxed)
    }
}
71
/// Outcome of a tool run that finished without a sandbox-level error.
///
/// [`execute_with_controls`] returns this on success, and also when every
/// attempt returned an error message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolExecutionResult {
    /// `true` if one of the attempts succeeded.
    pub success: bool,
    /// Number of attempts made, counting the final one (1-based).
    pub attempts: u32,
    /// Output of the successful attempt; `None` when `success` is `false`.
    pub output: Option<serde_json::Value>,
    /// Error message from the last failed attempt; `None` when `success` is `true`.
    pub error: Option<String>,
}
84
85pub async fn execute_with_controls<F, Fut>(
92 config: &SandboxConfig,
93 breaker: &Arc<CircuitBreaker>,
94 tool_fn: F,
95) -> SandboxResult<ToolExecutionResult>
96where
97 F: Fn() -> Fut,
98 Fut: Future<Output = Result<serde_json::Value, String>>,
99{
100 let max_attempts = config.max_retries + 1;
101
102 for attempt in 1..=max_attempts {
103 if breaker.is_open() {
105 return Err(SandboxError::CircuitBreakerOpen {
106 consecutive_failures: breaker.failure_count(),
107 threshold: breaker.threshold,
108 });
109 }
110
111 let timeout = Duration::from_millis(config.timeout_ms);
112 let result = tokio::time::timeout(timeout, tool_fn()).await;
113
114 match result {
115 Ok(Ok(value)) => {
116 breaker.record_success();
117 return Ok(ToolExecutionResult {
118 success: true,
119 attempts: attempt,
120 output: Some(value),
121 error: None,
122 });
123 }
124 Ok(Err(err_msg)) => {
125 breaker.record_failure();
126 if attempt == max_attempts {
127 return Ok(ToolExecutionResult {
128 success: false,
129 attempts: attempt,
130 output: None,
131 error: Some(err_msg),
132 });
133 }
134 let delay = Duration::from_millis(config.backoff_base_ms * 2u64.pow(attempt - 1));
136 tokio::time::sleep(delay).await;
137 }
138 Err(_elapsed) => {
139 breaker.record_failure();
140 if attempt == max_attempts {
141 return Err(SandboxError::Timeout {
142 elapsed_ms: config.timeout_ms,
143 limit_ms: config.timeout_ms,
144 });
145 }
146 let delay = Duration::from_millis(config.backoff_base_ms * 2u64.pow(attempt - 1));
147 tokio::time::sleep(delay).await;
148 }
149 }
150 }
151
152 Err(SandboxError::ExecutionFailed {
154 attempts: max_attempts,
155 reason: "exhausted all attempts".into(),
156 })
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    // --- CircuitBreaker state transitions ---

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(3);
        assert!(!cb.is_open());
        assert_eq!(cb.failure_count(), 0);
    }

    #[test]
    fn test_circuit_breaker_opens_at_threshold() {
        let cb = CircuitBreaker::new(3);
        cb.record_failure();
        cb.record_failure();
        // Two failures are still below the threshold of 3.
        assert!(!cb.is_open());
        cb.record_failure();
        // The third failure reaches the threshold, so the breaker opens.
        assert!(cb.is_open());
    }

    #[test]
    fn test_circuit_breaker_resets_on_success() {
        let cb = CircuitBreaker::new(3);
        cb.record_failure();
        cb.record_failure();
        // A single success clears the consecutive-failure count.
        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert!(!cb.is_open());
    }

    // --- SandboxConfig ---

    #[test]
    fn test_sandbox_config_default() {
        let cfg = SandboxConfig::default();
        assert_eq!(cfg.timeout_ms, 30_000);
        assert_eq!(cfg.max_retries, 2);
        assert_eq!(cfg.backoff_base_ms, 500);
    }

    #[test]
    fn test_sandbox_config_serde_roundtrip() {
        let cfg = SandboxConfig {
            timeout_ms: 5000,
            max_retries: 1,
            backoff_base_ms: 100,
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let back: SandboxConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg, back);
    }

    // --- execute_with_controls ---
    // Backoff bases are kept at 10 ms so the retrying tests stay fast.

    #[tokio::test]
    async fn test_execute_success_on_first_attempt() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 2,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(5));

        let result = execute_with_controls(&cfg, &breaker, || async {
            Ok(serde_json::json!({"ok": true}))
        })
        .await
        .unwrap();

        assert!(result.success);
        assert_eq!(result.attempts, 1);
        assert!(result.output.is_some());
    }

    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 2,
            backoff_base_ms: 10,
        };
        // Threshold 5 keeps the breaker closed across the two failures.
        let breaker = Arc::new(CircuitBreaker::new(5));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = execute_with_controls(&cfg, &breaker, move || {
            let c = counter_clone.clone();
            async move {
                // Calls 0 and 1 fail; call 2 (the third and final attempt) succeeds.
                let n = c.fetch_add(1, Ordering::Relaxed);
                if n < 2 {
                    Err("not yet".into())
                } else {
                    Ok(serde_json::json!({"ok": true}))
                }
            }
        })
        .await
        .unwrap();

        assert!(result.success);
        assert_eq!(result.attempts, 3);
    }

    #[tokio::test]
    async fn test_execute_exhausts_retries() {
        // max_retries 1 means two attempts in total.
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 1,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(10));

        let result = execute_with_controls(&cfg, &breaker, || async {
            Err::<serde_json::Value, _>("always fails".to_string())
        })
        .await
        .unwrap();

        // Tool errors are reported in-band (Ok with success = false), not as Err.
        assert!(!result.success);
        assert_eq!(result.attempts, 2);
        assert!(result.error.unwrap().contains("always fails"));
    }

    #[tokio::test]
    async fn test_execute_circuit_breaker_blocks() {
        let cfg = SandboxConfig {
            timeout_ms: 1000,
            max_retries: 0,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(1));
        // One pre-recorded failure reaches the threshold of 1, so the breaker is
        // already open and the call is rejected before the tool runs.
        breaker.record_failure();
        let result = execute_with_controls(&cfg, &breaker, || async {
            Ok(serde_json::json!({"ok": true}))
        })
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            SandboxError::CircuitBreakerOpen { .. } => {}
            other => panic!("expected CircuitBreakerOpen, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        // The tool sleeps 200 ms against a 50 ms limit, and with no retries the
        // single timed-out attempt surfaces as an Err.
        let cfg = SandboxConfig {
            timeout_ms: 50,
            max_retries: 0,
            backoff_base_ms: 10,
        };
        let breaker = Arc::new(CircuitBreaker::new(10));

        let result = execute_with_controls(&cfg, &breaker, || async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            Ok(serde_json::json!({"ok": true}))
        })
        .await;

        assert!(result.is_err());
        match result.unwrap_err() {
            SandboxError::Timeout { .. } => {}
            other => panic!("expected Timeout, got {:?}", other),
        }
    }
}