Skip to main content

cortexai_tools/
registry.rs

1//! Enhanced tool registry with circuit breaker, retry, and timeout
2
3use parking_lot::RwLock;
4use cortexai_core::{
5    errors::ToolError,
6    tool::{ExecutionContext, Tool, ToolSchema},
7};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::time::timeout;
13
14/// Circuit breaker state
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum CircuitState {
17    /// Circuit is closed, requests flow normally
18    Closed,
19    /// Circuit is open, requests are rejected
20    Open,
21    /// Circuit is half-open, testing if service recovered
22    HalfOpen,
23}
24
/// Thresholds that decide when a circuit breaker trips and when it recovers.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// How many failures are tolerated before the circuit opens.
    pub failure_threshold: u32,
    /// How long an open circuit rejects calls before a probe is allowed.
    pub reset_timeout: Duration,
    /// How many successes in the half-open state close the circuit again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures, probe after 30 seconds, close after 2 successes.
    fn default() -> Self {
        let failure_threshold = 5;
        let success_threshold = 2;
        let reset_timeout = Duration::from_secs(30);

        Self {
            failure_threshold,
            reset_timeout,
            success_threshold,
        }
    }
}
45
46/// Circuit breaker state tracker
47#[derive(Debug)]
48struct CircuitBreaker {
49    state: CircuitState,
50    failure_count: u32,
51    success_count: u32,
52    last_failure_time: Option<Instant>,
53    config: CircuitBreakerConfig,
54}
55
56impl CircuitBreaker {
57    fn new(config: CircuitBreakerConfig) -> Self {
58        Self {
59            state: CircuitState::Closed,
60            failure_count: 0,
61            success_count: 0,
62            last_failure_time: None,
63            config,
64        }
65    }
66
67    fn can_execute(&mut self) -> bool {
68        match self.state {
69            CircuitState::Closed => true,
70            CircuitState::Open => {
71                // Check if reset timeout has elapsed
72                if let Some(last_failure) = self.last_failure_time {
73                    if last_failure.elapsed() >= self.config.reset_timeout {
74                        self.state = CircuitState::HalfOpen;
75                        self.success_count = 0;
76                        return true;
77                    }
78                }
79                false
80            }
81            CircuitState::HalfOpen => true,
82        }
83    }
84
85    fn record_success(&mut self) {
86        match self.state {
87            CircuitState::Closed => {
88                self.failure_count = 0;
89            }
90            CircuitState::HalfOpen => {
91                self.success_count += 1;
92                if self.success_count >= self.config.success_threshold {
93                    self.state = CircuitState::Closed;
94                    self.failure_count = 0;
95                    self.success_count = 0;
96                }
97            }
98            CircuitState::Open => {}
99        }
100    }
101
102    fn record_failure(&mut self) {
103        self.failure_count += 1;
104        self.last_failure_time = Some(Instant::now());
105
106        match self.state {
107            CircuitState::Closed => {
108                if self.failure_count >= self.config.failure_threshold {
109                    self.state = CircuitState::Open;
110                }
111            }
112            CircuitState::HalfOpen => {
113                self.state = CircuitState::Open;
114                self.success_count = 0;
115            }
116            CircuitState::Open => {}
117        }
118    }
119}
120
/// Retry policy with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Backoff before the first retry
    pub initial_backoff: Duration,
    /// Upper bound on any single backoff, jitter included
    pub max_backoff: Duration,
    /// Factor the backoff grows by with each further attempt
    pub multiplier: f64,
    /// Whether to add up to 30% random jitter to each backoff
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Defaults: 3 retries, 100 ms initial backoff doubling up to 10 s, with jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: true,
        }
    }
}

impl RetryConfig {
    /// Calculates the backoff for the given zero-based attempt.
    ///
    /// The base value is `initial_backoff * multiplier^attempt`, capped at `max_backoff`.
    /// With `jitter` enabled, up to 30% of that value is added on top, and the sum is capped
    /// at `max_backoff` again so the configured maximum is never exceeded.
    pub fn backoff_duration(&self, attempt: u32) -> Duration {
        // A plain `attempt as i32` wraps to a negative exponent above `i32::MAX`, which would
        // shrink the backoff instead of growing it; saturate the conversion instead.
        let exponent = attempt.min(i32::MAX as u32) as i32;

        let ceiling = self.max_backoff.as_millis() as f64;
        let base = self.initial_backoff.as_millis() as f64;
        let capped = (base * self.multiplier.powi(exponent)).min(ceiling);

        let final_backoff = if self.jitter {
            // Re-apply the ceiling after adding jitter so `max_backoff` stays a hard limit.
            (capped + rand_jitter() * 0.3 * capped).min(ceiling)
        } else {
            capped
        };

        Duration::from_millis(final_backoff as u64)
    }
}

/// Returns a pseudo-random jitter factor in `[0.0, 1.0)` with a granularity of 0.001.
///
/// The current sub-second timestamp is run through a randomly keyed hasher rather than used
/// directly: on clocks coarser than a nanosecond the lowest digits of the timestamp take only
/// a few distinct values (possibly just zero), which would make the jitter ineffective.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u32(nanos);
    (hasher.finish() % 1000) as f64 / 1000.0
}
174
175/// Tool execution statistics
176#[derive(Debug, Clone, Default, Serialize, Deserialize)]
177pub struct ToolStats {
178    pub total_calls: u64,
179    pub successful_calls: u64,
180    pub failed_calls: u64,
181    pub total_retries: u64,
182    pub circuit_breaks: u64,
183    pub timeouts: u64,
184    pub total_latency_ms: f64,
185}
186
187impl ToolStats {
188    pub fn success_rate(&self) -> f64 {
189        if self.total_calls > 0 {
190            self.successful_calls as f64 / self.total_calls as f64
191        } else {
192            1.0
193        }
194    }
195
196    pub fn avg_latency_ms(&self) -> f64 {
197        if self.successful_calls > 0 {
198            self.total_latency_ms / self.successful_calls as f64
199        } else {
200            0.0
201        }
202    }
203}
204
/// A registered tool bundled with its own resilience state and settings.
struct EnhancedTool {
    /// The wrapped tool implementation.
    tool: Arc<dyn Tool>,
    /// Circuit breaker guarding calls to this tool.
    circuit_breaker: RwLock<CircuitBreaker>,
    /// Call counters for this tool.
    stats: RwLock<ToolStats>,
    /// Retry policy used when executing this tool.
    retry_config: RetryConfig,
    /// Time limit applied to each individual execution attempt.
    timeout_duration: Duration,
}
213
/// Enhanced tool registry with circuit breaker, retry, and timeout.
///
/// The `default_*` fields are the settings given to tools registered without explicit
/// configuration.
pub struct EnhancedToolRegistry {
    /// Registered tools keyed by their schema name.
    tools: Arc<RwLock<HashMap<String, Arc<EnhancedTool>>>>,
    /// Circuit-breaker settings for tools registered with defaults.
    default_circuit_config: CircuitBreakerConfig,
    /// Retry policy for tools registered with defaults.
    default_retry_config: RetryConfig,
    /// Per-attempt timeout for tools registered with defaults.
    default_timeout: Duration,
}
221
222impl EnhancedToolRegistry {
223    pub fn new() -> Self {
224        Self {
225            tools: Arc::new(RwLock::new(HashMap::new())),
226            default_circuit_config: CircuitBreakerConfig::default(),
227            default_retry_config: RetryConfig::default(),
228            default_timeout: Duration::from_secs(30),
229        }
230    }
231
232    /// Configure default circuit breaker settings
233    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
234        self.default_circuit_config = config;
235        self
236    }
237
238    /// Configure default retry settings
239    pub fn with_retry(mut self, config: RetryConfig) -> Self {
240        self.default_retry_config = config;
241        self
242    }
243
244    /// Configure default timeout
245    pub fn with_timeout(mut self, timeout: Duration) -> Self {
246        self.default_timeout = timeout;
247        self
248    }
249
250    /// Register a tool with default settings
251    pub fn register(&self, tool: Arc<dyn Tool>) {
252        self.register_with_config(
253            tool,
254            self.default_circuit_config.clone(),
255            self.default_retry_config.clone(),
256            self.default_timeout,
257        );
258    }
259
260    /// Register a tool with custom settings
261    pub fn register_with_config(
262        &self,
263        tool: Arc<dyn Tool>,
264        circuit_config: CircuitBreakerConfig,
265        retry_config: RetryConfig,
266        timeout_duration: Duration,
267    ) {
268        let schema = tool.schema();
269        let enhanced = Arc::new(EnhancedTool {
270            tool,
271            circuit_breaker: RwLock::new(CircuitBreaker::new(circuit_config)),
272            stats: RwLock::new(ToolStats::default()),
273            retry_config,
274            timeout_duration,
275        });
276
277        self.tools.write().insert(schema.name.clone(), enhanced);
278    }
279
280    /// Execute a tool with circuit breaker, retry, and timeout
281    pub async fn execute(
282        &self,
283        name: &str,
284        context: &ExecutionContext,
285        arguments: serde_json::Value,
286    ) -> Result<serde_json::Value, ToolError> {
287        let enhanced_tool = {
288            let tools = self.tools.read();
289            tools.get(name).cloned()
290        };
291
292        let enhanced_tool = enhanced_tool.ok_or_else(|| ToolError::NotFound(name.to_string()))?;
293
294        // Check circuit breaker
295        {
296            let mut cb = enhanced_tool.circuit_breaker.write();
297            if !cb.can_execute() {
298                enhanced_tool.stats.write().circuit_breaks += 1;
299                return Err(ToolError::CircuitOpen(name.to_string()));
300            }
301        }
302
303        // Update call count
304        enhanced_tool.stats.write().total_calls += 1;
305
306        let start = Instant::now();
307        let mut last_error = None;
308        let mut retries = 0;
309
310        // Retry loop
311        for attempt in 0..=enhanced_tool.retry_config.max_retries {
312            if attempt > 0 {
313                retries += 1;
314                let backoff = enhanced_tool.retry_config.backoff_duration(attempt - 1);
315                tokio::time::sleep(backoff).await;
316            }
317
318            // Execute with timeout
319            let result = timeout(
320                enhanced_tool.timeout_duration,
321                enhanced_tool.tool.execute(context, arguments.clone()),
322            )
323            .await;
324
325            match result {
326                Ok(Ok(value)) => {
327                    // Success
328                    let latency = start.elapsed().as_millis() as f64;
329                    {
330                        let mut stats = enhanced_tool.stats.write();
331                        stats.successful_calls += 1;
332                        stats.total_retries += retries;
333                        stats.total_latency_ms += latency;
334                    }
335                    enhanced_tool.circuit_breaker.write().record_success();
336                    return Ok(value);
337                }
338                Ok(Err(e)) => {
339                    // Tool error
340                    last_error = Some(e);
341                }
342                Err(_) => {
343                    // Timeout
344                    enhanced_tool.stats.write().timeouts += 1;
345                    last_error = Some(ToolError::Timeout(name.to_string()));
346                }
347            }
348        }
349
350        // All retries exhausted
351        {
352            let mut stats = enhanced_tool.stats.write();
353            stats.failed_calls += 1;
354            stats.total_retries += retries;
355        }
356        enhanced_tool.circuit_breaker.write().record_failure();
357
358        Err(last_error.unwrap_or_else(|| ToolError::Execution("Unknown error".to_string())))
359    }
360
361    /// Get tool by name
362    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
363        self.tools.read().get(name).map(|et| et.tool.clone())
364    }
365
366    /// List all tool schemas
367    pub fn list_schemas(&self) -> Vec<ToolSchema> {
368        self.tools
369            .read()
370            .values()
371            .map(|et| et.tool.schema())
372            .collect()
373    }
374
375    /// Get statistics for a tool
376    pub fn get_stats(&self, name: &str) -> Option<ToolStats> {
377        self.tools
378            .read()
379            .get(name)
380            .map(|et| et.stats.read().clone())
381    }
382
383    /// Get circuit state for a tool
384    pub fn get_circuit_state(&self, name: &str) -> Option<CircuitState> {
385        self.tools
386            .read()
387            .get(name)
388            .map(|et| et.circuit_breaker.read().state)
389    }
390
391    /// Get all tool statistics
392    pub fn all_stats(&self) -> HashMap<String, ToolStats> {
393        self.tools
394            .read()
395            .iter()
396            .map(|(name, et)| (name.clone(), et.stats.read().clone()))
397            .collect()
398    }
399
400    /// Reset circuit breaker for a tool
401    pub fn reset_circuit(&self, name: &str) -> bool {
402        if let Some(et) = self.tools.read().get(name) {
403            let mut cb = et.circuit_breaker.write();
404            cb.state = CircuitState::Closed;
405            cb.failure_count = 0;
406            cb.success_count = 0;
407            cb.last_failure_time = None;
408            true
409        } else {
410            false
411        }
412    }
413
414    /// Reset all circuit breakers
415    pub fn reset_all_circuits(&self) {
416        for et in self.tools.read().values() {
417            let mut cb = et.circuit_breaker.write();
418            cb.state = CircuitState::Closed;
419            cb.failure_count = 0;
420            cb.success_count = 0;
421            cb.last_failure_time = None;
422        }
423    }
424
425    /// Check if tool exists
426    pub fn has(&self, name: &str) -> bool {
427        self.tools.read().contains_key(name)
428    }
429
430    /// Get tool count
431    pub fn len(&self) -> usize {
432        self.tools.read().len()
433    }
434
435    /// Check if registry is empty
436    pub fn is_empty(&self) -> bool {
437        self.tools.read().is_empty()
438    }
439
440    /// Print tool health report
441    pub fn print_health_report(&self) {
442        println!("\n╔══════════════════════════════════════════════════════════════╗");
443        println!("║               TOOL REGISTRY HEALTH REPORT                    ║");
444        println!("╠══════════════════════════════════════════════════════════════╣");
445
446        let tools = self.tools.read();
447        for (name, et) in tools.iter() {
448            let stats = et.stats.read();
449            let cb = et.circuit_breaker.read();
450
451            let state_icon = match cb.state {
452                CircuitState::Closed => "🟢",
453                CircuitState::HalfOpen => "🟡",
454                CircuitState::Open => "🔴",
455            };
456
457            println!(
458                "║ {} {:<30} {:>6.1}% success              ║",
459                state_icon,
460                if name.len() > 30 { &name[..30] } else { name },
461                stats.success_rate() * 100.0
462            );
463            println!(
464                "║   Calls: {:>8} | Retries: {:>6} | Avg: {:>6.0}ms           ║",
465                stats.total_calls,
466                stats.total_retries,
467                stats.avg_latency_ms()
468            );
469        }
470
471        println!("╚══════════════════════════════════════════════════════════════╝\n");
472    }
473}
474
475impl Default for EnhancedToolRegistry {
476    fn default() -> Self {
477        Self::new()
478    }
479}
480
481impl Clone for EnhancedToolRegistry {
482    fn clone(&self) -> Self {
483        Self {
484            tools: self.tools.clone(),
485            default_circuit_config: self.default_circuit_config.clone(),
486            default_retry_config: self.default_retry_config.clone(),
487            default_timeout: self.default_timeout,
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cortexai_core::types::AgentId;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Execution context for a throwaway test agent.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(AgentId::new("test-agent"))
    }

    /// Tool that fails or succeeds depending on a flag the test can flip at runtime.
    struct TestTool {
        should_fail: AtomicBool,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema::new("test_tool", "A test tool")
        }

        async fn execute(
            &self,
            _context: &ExecutionContext,
            _arguments: serde_json::Value,
        ) -> Result<serde_json::Value, ToolError> {
            if self.should_fail.load(Ordering::SeqCst) {
                Err(ToolError::Execution("Test failure".to_string()))
            } else {
                Ok(serde_json::json!({"result": "success"}))
            }
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        // Retries are disabled so each failing call returns immediately instead of sleeping
        // through the backoff schedule; this test is about the breaker, not the retry loop.
        let registry = EnhancedToolRegistry::new()
            .with_circuit_breaker(CircuitBreakerConfig {
                failure_threshold: 2,
                reset_timeout: Duration::from_millis(100),
                success_threshold: 1,
            })
            .with_retry(RetryConfig {
                max_retries: 0,
                ..RetryConfig::default()
            });

        let tool = Arc::new(TestTool {
            should_fail: AtomicBool::new(true),
        });

        registry.register(tool.clone());

        let ctx = test_ctx();

        // Fail twice to open circuit
        let first = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;
        assert!(first.is_err());
        let second = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;
        assert!(second.is_err());

        // Circuit should be open
        assert_eq!(
            registry.get_circuit_state("test_tool"),
            Some(CircuitState::Open)
        );

        // While open, calls are rejected without reaching the tool.
        let rejected = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;
        assert!(matches!(rejected, Err(ToolError::CircuitOpen(_))));

        // Wait for reset timeout
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The next call is let through as a probe; one success closes the circuit again.
        tool.should_fail.store(false, Ordering::SeqCst);
        let result = registry
            .execute("test_tool", &ctx, serde_json::json!({}))
            .await;

        assert!(result.is_ok());
        assert_eq!(
            registry.get_circuit_state("test_tool"),
            Some(CircuitState::Closed)
        );

        let stats = registry
            .get_stats("test_tool")
            .expect("tool was registered above");
        assert_eq!(stats.total_calls, 3);
        assert_eq!(stats.failed_calls, 2);
        assert_eq!(stats.successful_calls, 1);
        assert_eq!(stats.circuit_breaks, 1);
        assert_eq!(stats.total_retries, 0);
    }

    #[test]
    fn test_retry_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(config.backoff_duration(0), Duration::from_millis(100));
        assert_eq!(config.backoff_duration(1), Duration::from_millis(200));
        assert_eq!(config.backoff_duration(2), Duration::from_millis(400));
        // 100 ms * 2^10 exceeds the 10 s maximum and must be capped.
        assert_eq!(config.backoff_duration(10), Duration::from_secs(10));
    }
}