1use parking_lot::RwLock;
4use cortexai_core::{
5 errors::ToolError,
6 tool::{ExecutionContext, Tool, ToolSchema},
7};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::time::timeout;
13
/// The three states of a circuit breaker protecting a single tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitState {
    /// Normal operation: calls pass through and failures are counted.
    Closed,
    /// Too many failures: calls are rejected until the reset timeout elapses.
    Open,
    /// Probing after the reset timeout: a limited number of calls are allowed
    /// through to test whether the tool has recovered.
    HalfOpen,
}
24
/// Tuning knobs for a per-tool circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while Closed) that trip the circuit to Open.
    pub failure_threshold: u32,
    /// How long the circuit stays Open before allowing a Half-Open probe.
    pub reset_timeout: Duration,
    /// Consecutive successes (while Half-Open) needed to close the circuit.
    pub success_threshold: u32,
}
35
36impl Default for CircuitBreakerConfig {
37 fn default() -> Self {
38 Self {
39 failure_threshold: 5,
40 reset_timeout: Duration::from_secs(30),
41 success_threshold: 2,
42 }
43 }
44}
45
/// Mutable circuit-breaker state for one tool. Guarded by an `RwLock` in
/// `EnhancedTool`; all transitions happen under a write lock.
#[derive(Debug)]
struct CircuitBreaker {
    state: CircuitState,
    // Consecutive failures observed while Closed (reset on success).
    failure_count: u32,
    // Consecutive successes observed while Half-Open.
    success_count: u32,
    // Timestamp of the most recent failure; drives the Open -> HalfOpen
    // transition after `config.reset_timeout`.
    last_failure_time: Option<Instant>,
    config: CircuitBreakerConfig,
}
55
56impl CircuitBreaker {
57 fn new(config: CircuitBreakerConfig) -> Self {
58 Self {
59 state: CircuitState::Closed,
60 failure_count: 0,
61 success_count: 0,
62 last_failure_time: None,
63 config,
64 }
65 }
66
67 fn can_execute(&mut self) -> bool {
68 match self.state {
69 CircuitState::Closed => true,
70 CircuitState::Open => {
71 if let Some(last_failure) = self.last_failure_time {
73 if last_failure.elapsed() >= self.config.reset_timeout {
74 self.state = CircuitState::HalfOpen;
75 self.success_count = 0;
76 return true;
77 }
78 }
79 false
80 }
81 CircuitState::HalfOpen => true,
82 }
83 }
84
85 fn record_success(&mut self) {
86 match self.state {
87 CircuitState::Closed => {
88 self.failure_count = 0;
89 }
90 CircuitState::HalfOpen => {
91 self.success_count += 1;
92 if self.success_count >= self.config.success_threshold {
93 self.state = CircuitState::Closed;
94 self.failure_count = 0;
95 self.success_count = 0;
96 }
97 }
98 CircuitState::Open => {}
99 }
100 }
101
102 fn record_failure(&mut self) {
103 self.failure_count += 1;
104 self.last_failure_time = Some(Instant::now());
105
106 match self.state {
107 CircuitState::Closed => {
108 if self.failure_count >= self.config.failure_threshold {
109 self.state = CircuitState::Open;
110 }
111 }
112 CircuitState::HalfOpen => {
113 self.state = CircuitState::Open;
114 self.success_count = 0;
115 }
116 CircuitState::Open => {}
117 }
118 }
119}
120
/// Tuning knobs for exponential-backoff retry.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Additional attempts after the first (so total attempts = max_retries + 1).
    pub max_retries: u32,
    /// Backoff before the first retry.
    pub initial_backoff: Duration,
    /// Upper bound on the computed backoff.
    pub max_backoff: Duration,
    /// Geometric growth factor applied per retry attempt.
    pub multiplier: f64,
    /// When true, add up to 30% random extra delay to de-synchronize
    /// concurrent retriers.
    pub jitter: bool,
}
135
136impl Default for RetryConfig {
137 fn default() -> Self {
138 Self {
139 max_retries: 3,
140 initial_backoff: Duration::from_millis(100),
141 max_backoff: Duration::from_secs(10),
142 multiplier: 2.0,
143 jitter: true,
144 }
145 }
146}
147
148impl RetryConfig {
149 pub fn backoff_duration(&self, attempt: u32) -> Duration {
151 let base = self.initial_backoff.as_millis() as f64;
152 let backoff = base * self.multiplier.powi(attempt as i32);
153 let capped = backoff.min(self.max_backoff.as_millis() as f64);
154
155 let final_backoff = if self.jitter {
156 let jitter = rand_jitter() * 0.3 * capped;
157 capped + jitter
158 } else {
159 capped
160 };
161
162 Duration::from_millis(final_backoff as u64)
163 }
164}
165
/// Cheap pseudo-random value in `[0.0, 1.0)` used for retry jitter.
///
/// Derived from the sub-second part of the system clock. The bits are mixed
/// (SplitMix64-style) before reduction because on platforms whose clock has
/// microsecond-or-coarser resolution the lowest decimal digits of
/// `subsec_nanos()` are always zero — the previous `nanos % 1000` would then
/// yield a constant 0.0 and provide no jitter at all.
/// Not cryptographically secure; good enough to de-synchronize retries.
fn rand_jitter() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as u64;
    // Spread whatever entropy the clock provides across all 64 bits.
    let mut x = nanos.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    x ^= x >> 30;
    x = x.wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x ^= x >> 27;
    (x % 1000) as f64 / 1000.0
}
174
/// Cumulative execution statistics for one registered tool.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolStats {
    /// Calls that passed the circuit-breaker check (each may span retries).
    pub total_calls: u64,
    pub successful_calls: u64,
    pub failed_calls: u64,
    /// Retry attempts across all calls (first attempts are not counted).
    pub total_retries: u64,
    /// Calls rejected because the circuit was open.
    pub circuit_breaks: u64,
    /// Individual attempts that hit the per-attempt timeout.
    pub timeouts: u64,
    /// Sum of end-to-end latencies (ms) of successful calls, including any
    /// backoff sleeps — used to derive the average.
    pub total_latency_ms: f64,
}
186
187impl ToolStats {
188 pub fn success_rate(&self) -> f64 {
189 if self.total_calls > 0 {
190 self.successful_calls as f64 / self.total_calls as f64
191 } else {
192 1.0
193 }
194 }
195
196 pub fn avg_latency_ms(&self) -> f64 {
197 if self.successful_calls > 0 {
198 self.total_latency_ms / self.successful_calls as f64
199 } else {
200 0.0
201 }
202 }
203}
204
/// A registered tool bundled with its resilience machinery: circuit breaker,
/// stats, retry policy, and per-attempt timeout. Shared via `Arc` so
/// `execute` can drop the registry lock before awaiting.
struct EnhancedTool {
    tool: Arc<dyn Tool>,
    circuit_breaker: RwLock<CircuitBreaker>,
    stats: RwLock<ToolStats>,
    retry_config: RetryConfig,
    // Applied to each individual attempt, not the whole retried call.
    timeout_duration: Duration,
}
213
/// Tool registry that wraps every registered tool with circuit breaking,
/// timeout enforcement, and retry with exponential backoff. Clones share the
/// same underlying tool map (see the `Clone` impl).
pub struct EnhancedToolRegistry {
    // Keyed by the tool's schema name.
    tools: Arc<RwLock<HashMap<String, Arc<EnhancedTool>>>>,
    // Defaults applied by `register`; overridable per tool via
    // `register_with_config`.
    default_circuit_config: CircuitBreakerConfig,
    default_retry_config: RetryConfig,
    default_timeout: Duration,
}
221
222impl EnhancedToolRegistry {
223 pub fn new() -> Self {
224 Self {
225 tools: Arc::new(RwLock::new(HashMap::new())),
226 default_circuit_config: CircuitBreakerConfig::default(),
227 default_retry_config: RetryConfig::default(),
228 default_timeout: Duration::from_secs(30),
229 }
230 }
231
232 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
234 self.default_circuit_config = config;
235 self
236 }
237
238 pub fn with_retry(mut self, config: RetryConfig) -> Self {
240 self.default_retry_config = config;
241 self
242 }
243
244 pub fn with_timeout(mut self, timeout: Duration) -> Self {
246 self.default_timeout = timeout;
247 self
248 }
249
250 pub fn register(&self, tool: Arc<dyn Tool>) {
252 self.register_with_config(
253 tool,
254 self.default_circuit_config.clone(),
255 self.default_retry_config.clone(),
256 self.default_timeout,
257 );
258 }
259
260 pub fn register_with_config(
262 &self,
263 tool: Arc<dyn Tool>,
264 circuit_config: CircuitBreakerConfig,
265 retry_config: RetryConfig,
266 timeout_duration: Duration,
267 ) {
268 let schema = tool.schema();
269 let enhanced = Arc::new(EnhancedTool {
270 tool,
271 circuit_breaker: RwLock::new(CircuitBreaker::new(circuit_config)),
272 stats: RwLock::new(ToolStats::default()),
273 retry_config,
274 timeout_duration,
275 });
276
277 self.tools.write().insert(schema.name.clone(), enhanced);
278 }
279
280 pub async fn execute(
282 &self,
283 name: &str,
284 context: &ExecutionContext,
285 arguments: serde_json::Value,
286 ) -> Result<serde_json::Value, ToolError> {
287 let enhanced_tool = {
288 let tools = self.tools.read();
289 tools.get(name).cloned()
290 };
291
292 let enhanced_tool = enhanced_tool.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
293
294 {
296 let mut cb = enhanced_tool.circuit_breaker.write();
297 if !cb.can_execute() {
298 enhanced_tool.stats.write().circuit_breaks += 1;
299 return Err(ToolError::CircuitOpen(name.to_string()));
300 }
301 }
302
303 enhanced_tool.stats.write().total_calls += 1;
305
306 let start = Instant::now();
307 let mut last_error = None;
308 let mut retries = 0;
309
310 for attempt in 0..=enhanced_tool.retry_config.max_retries {
312 if attempt > 0 {
313 retries += 1;
314 let backoff = enhanced_tool.retry_config.backoff_duration(attempt - 1);
315 tokio::time::sleep(backoff).await;
316 }
317
318 let result = timeout(
320 enhanced_tool.timeout_duration,
321 enhanced_tool.tool.execute(context, arguments.clone()),
322 )
323 .await;
324
325 match result {
326 Ok(Ok(value)) => {
327 let latency = start.elapsed().as_millis() as f64;
329 {
330 let mut stats = enhanced_tool.stats.write();
331 stats.successful_calls += 1;
332 stats.total_retries += retries;
333 stats.total_latency_ms += latency;
334 }
335 enhanced_tool.circuit_breaker.write().record_success();
336 return Ok(value);
337 }
338 Ok(Err(e)) => {
339 last_error = Some(e);
341 }
342 Err(_) => {
343 enhanced_tool.stats.write().timeouts += 1;
345 last_error = Some(ToolError::Timeout(name.to_string()));
346 }
347 }
348 }
349
350 {
352 let mut stats = enhanced_tool.stats.write();
353 stats.failed_calls += 1;
354 stats.total_retries += retries;
355 }
356 enhanced_tool.circuit_breaker.write().record_failure();
357
358 Err(last_error.unwrap_or_else(|| ToolError::Execution("Unknown error".to_string())))
359 }
360
361 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
363 self.tools.read().get(name).map(|et| et.tool.clone())
364 }
365
366 pub fn list_schemas(&self) -> Vec<ToolSchema> {
368 self.tools
369 .read()
370 .values()
371 .map(|et| et.tool.schema())
372 .collect()
373 }
374
375 pub fn get_stats(&self, name: &str) -> Option<ToolStats> {
377 self.tools
378 .read()
379 .get(name)
380 .map(|et| et.stats.read().clone())
381 }
382
383 pub fn get_circuit_state(&self, name: &str) -> Option<CircuitState> {
385 self.tools
386 .read()
387 .get(name)
388 .map(|et| et.circuit_breaker.read().state)
389 }
390
391 pub fn all_stats(&self) -> HashMap<String, ToolStats> {
393 self.tools
394 .read()
395 .iter()
396 .map(|(name, et)| (name.clone(), et.stats.read().clone()))
397 .collect()
398 }
399
400 pub fn reset_circuit(&self, name: &str) -> bool {
402 if let Some(et) = self.tools.read().get(name) {
403 let mut cb = et.circuit_breaker.write();
404 cb.state = CircuitState::Closed;
405 cb.failure_count = 0;
406 cb.success_count = 0;
407 cb.last_failure_time = None;
408 true
409 } else {
410 false
411 }
412 }
413
414 pub fn reset_all_circuits(&self) {
416 for et in self.tools.read().values() {
417 let mut cb = et.circuit_breaker.write();
418 cb.state = CircuitState::Closed;
419 cb.failure_count = 0;
420 cb.success_count = 0;
421 cb.last_failure_time = None;
422 }
423 }
424
425 pub fn has(&self, name: &str) -> bool {
427 self.tools.read().contains_key(name)
428 }
429
430 pub fn len(&self) -> usize {
432 self.tools.read().len()
433 }
434
435 pub fn is_empty(&self) -> bool {
437 self.tools.read().is_empty()
438 }
439
440 pub fn print_health_report(&self) {
442 println!("\n╔══════════════════════════════════════════════════════════════╗");
443 println!("║ TOOL REGISTRY HEALTH REPORT ║");
444 println!("╠══════════════════════════════════════════════════════════════╣");
445
446 let tools = self.tools.read();
447 for (name, et) in tools.iter() {
448 let stats = et.stats.read();
449 let cb = et.circuit_breaker.read();
450
451 let state_icon = match cb.state {
452 CircuitState::Closed => "🟢",
453 CircuitState::HalfOpen => "🟡",
454 CircuitState::Open => "🔴",
455 };
456
457 println!(
458 "║ {} {:<30} {:>6.1}% success ║",
459 state_icon,
460 if name.len() > 30 { &name[..30] } else { name },
461 stats.success_rate() * 100.0
462 );
463 println!(
464 "║ Calls: {:>8} | Retries: {:>6} | Avg: {:>6.0}ms ║",
465 stats.total_calls,
466 stats.total_retries,
467 stats.avg_latency_ms()
468 );
469 }
470
471 println!("╚══════════════════════════════════════════════════════════════╝\n");
472 }
473}
474
475impl Default for EnhancedToolRegistry {
476 fn default() -> Self {
477 Self::new()
478 }
479}
480
481impl Clone for EnhancedToolRegistry {
482 fn clone(&self) -> Self {
483 Self {
484 tools: self.tools.clone(),
485 default_circuit_config: self.default_circuit_config.clone(),
486 default_retry_config: self.default_retry_config.clone(),
487 default_timeout: self.default_timeout,
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cortexai_core::types::AgentId;

    // Minimal execution context for driving the registry in tests.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(AgentId::new("test-agent"))
    }

    // Tool whose outcome can be toggled at runtime via an atomic flag.
    struct TestTool {
        should_fail: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema::new("test_tool", "A test tool")
        }

        // Fails or succeeds depending on the current `should_fail` flag.
        async fn execute(
            &self,
            _context: &ExecutionContext,
            _arguments: serde_json::Value,
        ) -> Result<serde_json::Value, ToolError> {
            if self.should_fail.load(std::sync::atomic::Ordering::SeqCst) {
                Err(ToolError::Execution("Test failure".to_string()))
            } else {
                Ok(serde_json::json!({"result": "success"}))
            }
        }
    }

    // Full breaker lifecycle: Closed -> Open after the failure threshold,
    // then Open -> HalfOpen -> Closed after the reset timeout plus one
    // successful probe (success_threshold = 1).
    #[tokio::test]
    async fn test_circuit_breaker() {
        let registry = EnhancedToolRegistry::new().with_circuit_breaker(CircuitBreakerConfig {
            failure_threshold: 2,
            reset_timeout: Duration::from_millis(100),
            success_threshold: 1,
        });

        let tool = Arc::new(TestTool {
            should_fail: std::sync::atomic::AtomicBool::new(true),
        });

        registry.register(tool.clone());

        let ctx = test_ctx();

        // Two failing calls; each records a single breaker failure
        // (retries within a call do not count individually).
        let _ = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;
        let _ = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;

        assert_eq!(
            registry.get_circuit_state("test_tool"),
            Some(CircuitState::Open)
        );

        // Wait past the reset timeout so the next call is a half-open probe.
        tokio::time::sleep(Duration::from_millis(150)).await;

        tool.should_fail
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let result = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;

        assert!(result.is_ok());
        assert_eq!(
            registry.get_circuit_state("test_tool"),
            Some(CircuitState::Closed)
        );
    }

    // With jitter disabled the backoff schedule is deterministic:
    // 100ms * 2^attempt.
    #[test]
    fn test_retry_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(config.backoff_duration(0), Duration::from_millis(100));
        assert_eq!(config.backoff_duration(1), Duration::from_millis(200));
        assert_eq!(config.backoff_duration(2), Duration::from_millis(400));
    }
}