1use crate::{ErrorKind, ToolError, ToolResult};
4use async_trait::async_trait;
5use futures::stream::{self, StreamExt};
6use parking_lot::RwLock;
7use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{watch, Semaphore};
11use tracing::{debug, instrument, warn};
12
/// Per-execution state handed to a tool while it runs.
///
/// Cloning shares the cancellation receiver and the metadata map with the
/// original, because both live behind an `Arc`.
#[derive(Clone)]
pub struct ExecutionContext {
    /// Receiving half of a `watch` channel; a value of `true` means
    /// cancellation has been requested (see `is_cancelled`).
    cancellation: Arc<watch::Receiver<bool>>,
    /// Per-execution timeout. When set, `Executor` prefers it over its own
    /// default timeout.
    pub timeout: Option<Duration>,
    /// Optional memory limit.
    /// NOTE(review): no code visible alongside this definition reads the
    /// field, so its unit and enforcement are unconfirmed.
    pub max_memory: Option<usize>,
    /// Free-form key/value data, shared by every clone of this context.
    pub metadata: Arc<RwLock<ExecutionMetadata>>,
    /// Moment the context was constructed; the reference point for `elapsed`.
    started_at: Instant,
}
22
23impl ExecutionContext {
24 pub fn new() -> Self {
25 let (tx, rx) = watch::channel(false);
26 std::mem::drop(tx);
27 Self {
28 cancellation: Arc::new(rx),
29 timeout: None,
30 max_memory: None,
31 metadata: Arc::new(RwLock::new(ExecutionMetadata::default())),
32 started_at: Instant::now(),
33 }
34 }
35
36 pub fn with_cancellation(cancellation: watch::Receiver<bool>) -> Self {
37 Self {
38 cancellation: Arc::new(cancellation),
39 timeout: None,
40 max_memory: None,
41 metadata: Arc::new(RwLock::new(ExecutionMetadata::default())),
42 started_at: Instant::now(),
43 }
44 }
45
46 pub fn with_timeout(mut self, timeout: Duration) -> Self {
47 self.timeout = Some(timeout);
48 self
49 }
50
51 pub fn elapsed(&self) -> Duration {
52 self.started_at.elapsed()
53 }
54
55 pub fn is_cancelled(&self) -> bool {
56 *self.cancellation.borrow()
57 }
58
59 pub fn check_cancelled(&self) -> ToolResult<()> {
60 if self.is_cancelled() {
61 Err(ToolError::Cancelled)
62 } else {
63 Ok(())
64 }
65 }
66
67 pub fn set_metadata<V: serde::Serialize>(&self, key: impl Into<String>, value: V) {
68 if let Ok(v) = serde_json::to_value(value) {
69 self.metadata.write().fields.insert(key.into(), v);
70 }
71 }
72
73 pub fn get_metadata<T: for<'de> serde::Deserialize<'de>>(&self, key: &str) -> Option<T> {
74 self.metadata
75 .read()
76 .fields
77 .get(key)
78 .and_then(|v| serde_json::from_value(v.clone()).ok())
79 }
80}
81
82impl Default for ExecutionContext {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
/// Key/value store attached to an `ExecutionContext`.
#[derive(Default)]
pub struct ExecutionMetadata {
    /// Entries keyed by name; values are kept as JSON so callers can store
    /// any serializable type.
    pub fields: std::collections::HashMap<String, serde_json::Value>,
}
92
93#[async_trait]
95pub trait ToolExecutor: Send + Sync {
96 type Output: serde::Serialize + Send;
97 type Error: std::error::Error + Send + Sync + 'static;
98
99 async fn execute(&self, ctx: &ExecutionContext) -> Result<Self::Output, Self::Error>;
100
101 async fn execute_tool(&self, ctx: &ExecutionContext) -> ToolResult<Self::Output> {
102 self.execute(ctx).await.map_err(|e| ToolError::custom(e))
103 }
104}
105
/// Describes whether and how long to wait before re-running a failed tool.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Upper bound on retries: `Executor` re-runs a failing tool at most this
    /// many times after the first try.
    pub max_attempts: u32,
    /// Delay unit that the strategy scales.
    pub base_delay: Duration,
    /// Cap applied to the computed delay before jitter is applied.
    pub max_delay: Duration,
    /// How the delay grows with the attempt number.
    pub strategy: RetryStrategy,
    /// Error kinds for which a retry is allowed (see `should_retry`).
    pub retryable_errors: Vec<ErrorKind>,
    /// When `true`, the capped delay is multiplied by 0.5 plus a sample from
    /// the crate-local `rand::random::<f64>()`.
    pub jitter: bool,
}
116
/// Growth rule used by `RetryPolicy::calculate_backoff`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// Always wait `base_delay`.
    Fixed,
    /// Wait `base_delay` times two to the power of `attempt - 1`.
    Exponential,
    /// Wait `base_delay` times `attempt`.
    Linear,
}
123
124impl RetryPolicy {
125 pub fn exponential(max_attempts: u32) -> Self {
126 Self {
127 max_attempts,
128 base_delay: Duration::from_millis(100),
129 max_delay: Duration::from_secs(30),
130 strategy: RetryStrategy::Exponential,
131 retryable_errors: vec![ErrorKind::Network, ErrorKind::Timeout, ErrorKind::Resource],
132 jitter: true,
133 }
134 }
135
136 pub fn fixed(max_attempts: u32, delay: Duration) -> Self {
137 Self {
138 max_attempts,
139 base_delay: delay,
140 max_delay: delay,
141 strategy: RetryStrategy::Fixed,
142 retryable_errors: vec![ErrorKind::Network, ErrorKind::Timeout, ErrorKind::Resource],
143 jitter: false,
144 }
145 }
146
147 pub fn with_backoff(mut self, delay: Duration) -> Self {
148 self.base_delay = delay;
149 self
150 }
151
152 pub fn with_max_delay(mut self, delay: Duration) -> Self {
153 self.max_delay = delay;
154 self
155 }
156
157 pub fn with_jitter(mut self, jitter: bool) -> Self {
158 self.jitter = jitter;
159 self
160 }
161
162 pub fn should_retry(&self, error: &ToolError) -> bool {
163 self.retryable_errors.contains(&error.kind())
164 }
165
166 pub fn calculate_backoff(&self, attempt: u32) -> Duration {
167 let delay = match self.strategy {
168 RetryStrategy::Fixed => self.base_delay,
169 RetryStrategy::Exponential => {
170 let multiplier = 2u32.pow(attempt.saturating_sub(1));
171 self.base_delay.saturating_mul(multiplier)
172 }
173 RetryStrategy::Linear => self.base_delay.saturating_mul(attempt),
174 };
175
176 let delay = delay.min(self.max_delay);
177
178 if self.jitter {
179 let jitter_factor = 0.5 + (rand::random::<f64>() * 1.0);
181 Duration::from_secs_f64(delay.as_secs_f64() * jitter_factor)
182 } else {
183 delay
184 }
185 }
186}
187
/// Stops calling an operation once it has failed too often.
///
/// Clones share the same state and counters, because every mutable part is
/// behind an `Arc`.
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Number of recorded failures at which the breaker opens.
    failure_threshold: u32,
    /// Number of recorded successes at which the breaker closes and both
    /// counters are reset.
    success_threshold: u32,
    /// How long an open breaker keeps rejecting calls after the most recent
    /// failure.
    timeout: Duration,
    /// Current breaker state.
    state: Arc<RwLock<CircuitBreakerState>>,
    /// Failures recorded since the counters were last reset.
    failures: Arc<AtomicU64>,
    /// Successes recorded since the counters were last reset.
    successes: Arc<AtomicU64>,
    /// Time of the most recent failure, if any.
    last_failure_time: Arc<RwLock<Option<Instant>>>,
}
199
/// Internal state of a `CircuitBreaker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitBreakerState {
    /// Calls are let through.
    Closed,
    /// Calls are rejected until the breaker's timeout has passed since the
    /// last recorded failure.
    Open,
    /// Recovery state; `CircuitBreaker::call` lets calls through in this
    /// state just as it does in `Closed`.
    HalfOpen,
}
206
207impl CircuitBreaker {
208 pub fn new(failure_threshold: u32, timeout: Duration) -> Self {
209 Self {
210 failure_threshold,
211 success_threshold: 2,
212 timeout,
213 state: Arc::new(RwLock::new(CircuitBreakerState::Closed)),
214 failures: Arc::new(AtomicU64::new(0)),
215 successes: Arc::new(AtomicU64::new(0)),
216 last_failure_time: Arc::new(RwLock::new(None)),
217 }
218 }
219
220 pub fn call<F, Fut, T>(&self, f: F) -> impl std::future::Future<Output = ToolResult<T>>
221 where
222 F: FnOnce() -> Fut,
223 Fut: std::future::Future<Output = ToolResult<T>>,
224 {
225 let state = *self.state.read();
226 let should_attempt = match state {
227 CircuitBreakerState::Open => {
228 if let Some(last_failure) = *self.last_failure_time.read() {
229 last_failure.elapsed() > self.timeout
230 } else {
231 false
232 }
233 }
234 _ => true,
235 };
236
237 let failures = self.failures.clone();
238 let successes = self.successes.clone();
239 let state_arc = self.state.clone();
240 let last_failure = self.last_failure_time.clone();
241 let failure_threshold = self.failure_threshold;
242 let success_threshold = self.success_threshold;
243
244 async move {
245 if !should_attempt {
246 return Err(ToolError::execution_failed("Circuit breaker is open"));
247 }
248
249 match f().await {
250 Ok(result) => {
251 successes.fetch_add(1, Ordering::Relaxed);
252 let success_count = successes.load(Ordering::Relaxed);
253
254 if success_count >= success_threshold as u64 {
255 *state_arc.write() = CircuitBreakerState::Closed;
256 failures.store(0, Ordering::Relaxed);
257 successes.store(0, Ordering::Relaxed);
258 }
259
260 Ok(result)
261 }
262 Err(err) => {
263 failures.fetch_add(1, Ordering::Relaxed);
264 *last_failure.write() = Some(Instant::now());
265
266 if failures.load(Ordering::Relaxed) >= failure_threshold as u64 {
267 *state_arc.write() = CircuitBreakerState::Open;
268 }
269
270 Err(err)
271 }
272 }
273 }
274 }
275}
276
/// Runs tools with concurrency limiting, optional timeout, retry and circuit
/// breaking, and records metrics.
///
/// Clones share configuration, permits, metrics and circuit breaker, because
/// each is behind an `Arc`.
#[derive(Clone)]
pub struct Executor {
    /// Immutable settings chosen at construction time.
    config: Arc<ExecutorConfig>,
    /// Limits how many executions run at once; one permit per execution.
    semaphore: Arc<Semaphore>,
    /// Counters updated by every execution.
    metrics: Arc<ExecutorMetrics>,
    /// Optional breaker that wraps each execution.
    circuit_breaker: Option<Arc<CircuitBreaker>>,
}
285
/// Settings of an `Executor`.
#[derive(Debug)]
struct ExecutorConfig {
    /// Timeout used when the `ExecutionContext` does not carry its own.
    default_timeout: Option<Duration>,
    /// Number of semaphore permits, and the concurrency of `execute_batch`.
    max_concurrent: usize,
    /// Retry behaviour; `None` means a failing tool is not retried.
    retry_policy: Option<RetryPolicy>,
    /// Tracing switch.
    /// NOTE(review): stored by the builder, but no code visible alongside
    /// this definition reads it.
    enable_tracing: bool,
}
293
294impl Default for ExecutorConfig {
295 fn default() -> Self {
296 Self {
297 default_timeout: Some(Duration::from_secs(30)),
298 max_concurrent: 100,
299 retry_policy: None,
300 enable_tracing: true,
301 }
302 }
303}
304
/// Counters maintained by an `Executor`.
#[derive(Debug, Default)]
pub struct ExecutorMetrics {
    /// Executions started, counted after a concurrency permit was acquired.
    pub total_executions: AtomicUsize,
    /// Executions that returned `Ok`.
    pub successful_executions: AtomicUsize,
    /// Executions that returned `Err`.
    pub failed_executions: AtomicUsize,
    /// Sum of the measured execution durations, in milliseconds.
    pub total_duration_ms: AtomicU64,
}
312
313impl ExecutorMetrics {
314 pub fn success_rate(&self) -> f64 {
315 let total = self.total_executions.load(Ordering::Relaxed);
316 if total == 0 {
317 return 0.0;
318 }
319 let successful = self.successful_executions.load(Ordering::Relaxed);
320 (successful as f64 / total as f64) * 100.0
321 }
322
323 pub fn avg_duration_ms(&self) -> f64 {
324 let total = self.total_executions.load(Ordering::Relaxed);
325 if total == 0 {
326 return 0.0;
327 }
328 let duration = self.total_duration_ms.load(Ordering::Relaxed);
329 duration as f64 / total as f64
330 }
331}
332
333impl Executor {
334 pub fn new() -> Self {
335 let config = ExecutorConfig::default();
336 let max_concurrent = config.max_concurrent;
337 Self {
338 config: Arc::new(config),
339 semaphore: Arc::new(Semaphore::new(max_concurrent)),
340 metrics: Arc::new(ExecutorMetrics::default()),
341 circuit_breaker: None,
342 }
343 }
344
345 pub fn builder() -> ExecutorBuilder {
346 ExecutorBuilder::new()
347 }
348
349 pub fn metrics(&self) -> &ExecutorMetrics {
350 &self.metrics
351 }
352
353 #[instrument(skip(self, tool))]
354 pub async fn execute<T>(&self, tool: &T) -> ToolResult<T::Output>
355 where
356 T: ToolExecutor,
357 {
358 let ctx = ExecutionContext::new();
359 self.execute_with_context(tool, &ctx).await
360 }
361
362 pub async fn execute_with_context<T>(
363 &self,
364 tool: &T,
365 ctx: &ExecutionContext,
366 ) -> ToolResult<T::Output>
367 where
368 T: ToolExecutor,
369 {
370 let _permit = self
372 .semaphore
373 .acquire()
374 .await
375 .map_err(|_| ToolError::execution_failed("Failed to acquire execution permit"))?;
376
377 let start = Instant::now();
378 self.metrics
379 .total_executions
380 .fetch_add(1, Ordering::Relaxed);
381
382 let result = if let Some(ref cb) = self.circuit_breaker {
383 cb.call(|| self.execute_internal(tool, ctx)).await
384 } else {
385 self.execute_internal(tool, ctx).await
386 };
387
388 let duration = start.elapsed();
389 self.metrics
390 .total_duration_ms
391 .fetch_add(duration.as_millis() as u64, Ordering::Relaxed);
392
393 match &result {
394 Ok(_) => {
395 self.metrics
396 .successful_executions
397 .fetch_add(1, Ordering::Relaxed);
398 debug!("Tool execution succeeded in {:?}", duration);
399 }
400 Err(e) => {
401 self.metrics
402 .failed_executions
403 .fetch_add(1, Ordering::Relaxed);
404 warn!("Tool execution failed: {} (duration: {:?})", e, duration);
405 }
406 }
407
408 result
409 }
410
411 async fn execute_internal<T>(&self, tool: &T, ctx: &ExecutionContext) -> ToolResult<T::Output>
412 where
413 T: ToolExecutor,
414 {
415 let timeout = ctx.timeout.or(self.config.default_timeout);
416
417 if let Some(ref retry_policy) = self.config.retry_policy {
418 self.execute_with_retry(tool, ctx, retry_policy, timeout)
419 .await
420 } else if let Some(timeout_duration) = timeout {
421 self.execute_with_timeout(tool, ctx, timeout_duration).await
422 } else {
423 tool.execute_tool(ctx).await
424 }
425 }
426
427 async fn execute_with_timeout<T>(
428 &self,
429 tool: &T,
430 ctx: &ExecutionContext,
431 timeout: Duration,
432 ) -> ToolResult<T::Output>
433 where
434 T: ToolExecutor,
435 {
436 tokio::time::timeout(timeout, tool.execute_tool(ctx))
437 .await
438 .map_err(|_| ToolError::Timeout(timeout))?
439 }
440
441 async fn execute_with_retry<T>(
442 &self,
443 tool: &T,
444 ctx: &ExecutionContext,
445 policy: &RetryPolicy,
446 timeout: Option<Duration>,
447 ) -> ToolResult<T::Output>
448 where
449 T: ToolExecutor,
450 {
451 let mut attempts = 0;
452 let mut last_error = None;
453
454 while attempts <= policy.max_attempts {
455 let result = if let Some(timeout_duration) = timeout {
456 self.execute_with_timeout(tool, ctx, timeout_duration).await
457 } else {
458 tool.execute_tool(ctx).await
459 };
460
461 match result {
462 Ok(output) => return Ok(output),
463 Err(err) => {
464 attempts += 1;
465 if !policy.should_retry(&err) || attempts > policy.max_attempts {
466 return Err(err);
467 }
468 last_error = Some(err);
469 let delay = policy.calculate_backoff(attempts);
470 debug!("Retrying after {:?} (attempt {})", delay, attempts);
471 tokio::time::sleep(delay).await;
472 }
473 }
474 }
475
476 Err(last_error
477 .unwrap_or_else(|| ToolError::execution_failed("Max retry attempts exceeded")))
478 }
479
480 pub async fn execute_batch<T>(&self, tools: Vec<T>) -> Vec<ToolResult<T::Output>>
482 where
483 T: ToolExecutor + Clone,
484 {
485 stream::iter(tools)
486 .map(|tool| async move { self.execute(&tool).await })
487 .buffer_unordered(self.config.max_concurrent)
488 .collect()
489 .await
490 }
491}
492
493impl Default for Executor {
494 fn default() -> Self {
495 Self::new()
496 }
497}
498
/// Step-by-step construction of an `Executor`.
#[derive(Default)]
pub struct ExecutorBuilder {
    /// Settings collected so far.
    config: ExecutorConfig,
    /// Circuit breaker to install, if one was requested.
    circuit_breaker: Option<CircuitBreaker>,
}
505
506impl ExecutorBuilder {
507 pub fn new() -> Self {
508 Self {
509 config: ExecutorConfig::default(),
510 circuit_breaker: None,
511 }
512 }
513
514 pub fn timeout(mut self, timeout: Duration) -> Self {
515 self.config.default_timeout = Some(timeout);
516 self
517 }
518
519 pub fn max_concurrent(mut self, max: usize) -> Self {
520 self.config.max_concurrent = max;
521 self
522 }
523
524 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
525 self.config.retry_policy = Some(policy);
526 self
527 }
528
529 pub fn circuit_breaker(mut self, failure_threshold: u32, timeout: Duration) -> Self {
530 self.circuit_breaker = Some(CircuitBreaker::new(failure_threshold, timeout));
531 self
532 }
533
534 pub fn enable_tracing(mut self, enable: bool) -> Self {
535 self.config.enable_tracing = enable;
536 self
537 }
538
539 pub fn build(self) -> Executor {
540 let max_concurrent = self.config.max_concurrent;
541 Executor {
542 config: Arc::new(self.config),
543 semaphore: Arc::new(Semaphore::new(max_concurrent)),
544 metrics: Arc::new(ExecutorMetrics::default()),
545 circuit_breaker: self.circuit_breaker.map(Arc::new),
546 }
547 }
548}
549
/// Small crate-local random source used for retry jitter.
///
/// Not suitable for anything security-sensitive.
mod rand {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    /// Used only if the derived seed happens to be zero; a xorshift generator
    /// seeded with zero would produce zero forever.
    const FALLBACK_SEED: u64 = 0x4d595df4d0f33173;

    thread_local! {
        // Seeded lazily per thread, so threads and processes do not all
        // replay the same jitter sequence.
        static RNG: Cell<u64> = Cell::new(initial_seed());
    }

    /// Derives a seed from the randomly keyed hasher in the standard library.
    fn initial_seed() -> u64 {
        match RandomState::new().build_hasher().finish() {
            0 => FALLBACK_SEED,
            seed => seed,
        }
    }

    /// Returns a random value of type `T`.
    pub fn random<T: SampleUniform>() -> T {
        T::sample_uniform()
    }

    /// Types that can be sampled by `random`.
    pub trait SampleUniform: Sized {
        /// Draws one value.
        fn sample_uniform() -> Self;
    }

    impl SampleUniform for f64 {
        /// Draws a value in `[0, 1)` from the thread-local xorshift state.
        fn sample_uniform() -> Self {
            RNG.with(|rng| {
                let mut x = rng.get();
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                rng.set(x);
                // Keep the top 53 bits: every such integer is exactly
                // representable as an f64, and the quotient stays below 1.
                (x >> 11) as f64 / (1u64 << 53) as f64
            })
        }
    }
}