pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Circuit breaker -- fail-fast on repeated transient LLM errors.
//!
//! States: **Closed** (normal), **Open** (rejecting), **HalfOpen** (probing).
//! Uses atomic operations for lock-free state management.

use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;

use async_trait::async_trait;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;

/// Circuit breaker state encoded as u32 for atomic storage.
///
/// Closed: normal operation, calls are passed through.
const STATE_CLOSED: u32 = 0;
/// Open: the breaker has tripped and is rejecting calls.
const STATE_OPEN: u32 = 1;
/// Half-open: the breaker is probing whether the provider has recovered.
const STATE_HALF_OPEN: u32 = 2;

/// Lock-free circuit breaker for LLM providers.
///
/// Tracks consecutive transient failures. When `failure_threshold` is reached,
/// the breaker opens and rejects calls immediately. After `recovery_timeout`,
/// it moves to half-open and lets a probe call through to test the provider.
///
/// All mutable state lives in atomics, so a shared reference is enough to
/// drive the breaker. `Debug` is derived so the breaker can be logged and
/// embedded in other `Debug` types; atomic fields print their current value.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Number of consecutive transient failures that trips the breaker.
    failure_threshold: u32,
    /// How long the breaker stays open before a probe is attempted.
    recovery_timeout: Duration,
    /// Current state (STATE_CLOSED / STATE_OPEN / STATE_HALF_OPEN).
    state: AtomicU32,
    /// Consecutive failure count.
    failure_count: AtomicU32,
    /// Wall-clock timestamp (millis since the Unix epoch) used to time the
    /// recovery window; written when the breaker opens.
    opened_at: AtomicU64,
}

impl CircuitBreaker {
    /// Create a circuit breaker.
    ///
    /// - `failure_threshold`: consecutive transient failures before opening.
    /// - `recovery_timeout`: time to wait before probing in half-open state.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            state: AtomicU32::new(STATE_CLOSED),
            failure_count: AtomicU32::new(0),
            opened_at: AtomicU64::new(0),
        }
    }

    /// Current state as a human-readable string (for diagnostics).
    pub fn state_name(&self) -> &'static str {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => "closed",
            STATE_OPEN => "open",
            STATE_HALF_OPEN => "half-open",
            _ => "unknown",
        }
    }

    /// Current consecutive failure count.
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    fn record_success(&self) {
        self.failure_count.store(0, Ordering::SeqCst);
        self.state.store(STATE_CLOSED, Ordering::SeqCst);
    }

    fn record_failure(&self) {
        let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
        if count >= self.failure_threshold {
            self.state.store(STATE_OPEN, Ordering::SeqCst);
            self.opened_at.store(Self::now_millis(), Ordering::SeqCst);
        }
    }

    fn should_allow(&self) -> bool {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => true,
            STATE_HALF_OPEN => {
                // Only one probe at a time: CAS from HalfOpen to Closed (probe in flight).
                // All other callers are rejected until the probe completes.
                self.state
                    .compare_exchange(
                        STATE_HALF_OPEN,
                        STATE_CLOSED,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
            }
            STATE_OPEN => {
                let opened = self.opened_at.load(Ordering::SeqCst);
                let elapsed = Self::now_millis().saturating_sub(opened);
                if elapsed >= self.recovery_timeout.as_millis() as u64 {
                    // CAS from Open to HalfOpen — only one thread wins
                    self.state
                        .compare_exchange(
                            STATE_OPEN,
                            STATE_HALF_OPEN,
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                } else {
                    false
                }
            }
            _ => false,
        }
    }
}

#[async_trait]
impl ProviderMiddleware for CircuitBreaker {
    /// Guard a completion call with the breaker.
    ///
    /// When the breaker refuses the call, `next` is never invoked and a
    /// `PeError::LlmProvider` is returned straight away. Otherwise the call is
    /// forwarded and its outcome is fed back into the breaker: a success
    /// resets it, a transient error counts towards the failure threshold, and
    /// any other error is passed through without touching breaker state.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        if self.should_allow() {
            let outcome = next.complete(messages, tools).await;
            match &outcome {
                Ok(_) => self.record_success(),
                Err(err) if err.is_transient() => self.record_failure(),
                // Permanent errors say nothing about provider availability.
                Err(_) => {}
            }
            outcome
        } else {
            Err(PeError::LlmProvider {
                details: "circuit breaker open — provider is unavailable".into(),
            })
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Provider error used by these tests to trip the breaker.
    fn llm_err() -> PeError {
        PeError::LlmProvider {
            details: "err".into(),
        }
    }

    /// A mock provider scripted with `n` consecutive error responses.
    fn fail_provider(n: usize) -> MockProvider {
        (0..n).fold(MockProvider::new(), |provider, _| {
            provider.respond_with_error(llm_err())
        })
    }

    /// A fresh breaker forwards the call and stays closed on success.
    #[tokio::test]
    async fn test_closed_allows_calls() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let provider = MockProvider::new().respond_with("ok");

        let response = breaker.wrap_complete(&[], &[], &provider).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("ok"));
        assert_eq!(breaker.state_name(), "closed");
    }

    /// The breaker stays closed below the threshold and opens on reaching it.
    #[tokio::test]
    async fn test_opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        let failing = fail_provider(2);

        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "closed");

        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// An open breaker rejects without consuming a response from the provider.
    #[tokio::test]
    async fn test_open_rejects_immediately() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let failing = fail_provider(1);
        let _ = breaker.wrap_complete(&[], &[], &failing).await;
        assert_eq!(breaker.state_name(), "open");

        let untouched = MockProvider::new().respond_with("should not reach");
        let rejection = breaker
            .wrap_complete(&[], &[], &untouched)
            .await
            .unwrap_err();
        assert!(matches!(rejection, PeError::LlmProvider { .. }));
        assert_eq!(untouched.remaining(), 1);
    }

    /// After the recovery timeout a successful probe closes the breaker and a
    /// failing probe opens it again.
    #[tokio::test]
    async fn test_half_open_recovery_and_reopen() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10));
        let first_trip = fail_provider(1);
        let _ = breaker.wrap_complete(&[], &[], &first_trip).await;
        assert_eq!(breaker.state_name(), "open");
        tokio::time::sleep(Duration::from_millis(20)).await;

        // Successful probe resets to closed
        let healthy = MockProvider::new().respond_with("recovered");
        let response = breaker.wrap_complete(&[], &[], &healthy).await.unwrap();
        assert_eq!(response.message.content.as_text(), Some("recovered"));
        assert_eq!(breaker.state_name(), "closed");

        // Trip again, wait, probe fails -> reopens
        let second_trip = fail_provider(1);
        let _ = breaker.wrap_complete(&[], &[], &second_trip).await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        let failing_probe = fail_provider(1);
        let _ = breaker.wrap_complete(&[], &[], &failing_probe).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// Non-transient errors leave both the state and the failure count alone.
    #[tokio::test]
    async fn test_permanent_errors_dont_trip_breaker() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let denied = PeError::PermissionDenied {
            action: "write".into(),
        };
        let provider = MockProvider::new().respond_with_error(denied);

        let _ = breaker.wrap_complete(&[], &[], &provider).await;

        assert_eq!(breaker.state_name(), "closed");
        assert_eq!(breaker.failure_count(), 0);
    }

    /// A success after a run of failures clears the consecutive-failure count.
    #[tokio::test]
    async fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let provider = fail_provider(2).respond_with("ok");

        for _ in 0..2 {
            let _ = breaker.wrap_complete(&[], &[], &provider).await;
        }
        assert_eq!(breaker.failure_count(), 2);

        let _ = breaker.wrap_complete(&[], &[], &provider).await;
        assert_eq!(breaker.failure_count(), 0);
    }
}