use std::future::Future;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::Mutex;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::chat_model::{
BaseChatModel, ChatStream, ModelProfile, ToolChoice,
};
use cognis_core::messages::Message;
use cognis_core::outputs::ChatResult;
use cognis_core::tools::ToolSchema;
/// Where a [`CircuitBreaker`] currently sits in its closed → open → half-open cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls pass through to the guarded operation; consecutive failures are counted.
    Closed,
    /// Calls are rejected without invoking the guarded operation.
    Open,
    /// The reset timeout has elapsed: calls are let through as trials, a success closes
    /// the circuit and a failure re-opens it.
    HalfOpen,
}
/// Tracks consecutive failures of a guarded operation and short-circuits calls once a
/// configurable threshold is reached.
///
/// Code that needs both `state` and `last_failure_time` must lock `state` first (as
/// `state()` and `reset()` do); taking them in the opposite order can deadlock against a
/// concurrent caller.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current position in the closed/open/half-open cycle.
    state: Arc<Mutex<CircuitState>>,
    /// Number of consecutive failures that trips the circuit open.
    failure_threshold: u32,
    /// How long the circuit stays open after the most recent failure before trial calls
    /// are allowed.
    reset_timeout: Duration,
    /// Failures recorded since the last success or reset.
    consecutive_failures: Arc<AtomicU32>,
    /// When the most recent failure was recorded; `None` before any failure or after a reset.
    last_failure_time: Arc<Mutex<Option<Instant>>>,
}
impl CircuitBreaker {
    /// Creates a breaker in the [`CircuitState::Closed`] state.
    ///
    /// * `failure_threshold` – consecutive failures after which the circuit opens. A value
    ///   of `0` behaves like `1`: the first failure opens the circuit.
    /// * `reset_timeout` – time since the most recent failure after which an open circuit
    ///   reports [`CircuitState::HalfOpen`] and lets trial calls through.
    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitState::Closed)),
            failure_threshold,
            reset_timeout,
            consecutive_failures: Arc::new(AtomicU32::new(0)),
            last_failure_time: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns the current state.
    ///
    /// An `Open` circuit whose `reset_timeout` has elapsed since the last failure is
    /// promoted to (and stored as) `HalfOpen` before being returned.
    pub async fn state(&self) -> CircuitState {
        // Lock order: `state` before `last_failure_time`.
        let mut state = self.state.lock().await;
        if *state == CircuitState::Open {
            let last_failure = self.last_failure_time.lock().await;
            if let Some(t) = *last_failure {
                if t.elapsed() >= self.reset_timeout {
                    *state = CircuitState::HalfOpen;
                }
            }
        }
        *state
    }

    /// Forces the circuit back to `Closed` and clears the failure count and failure time.
    pub async fn reset(&self) {
        // Lock order: `state` before `last_failure_time`.
        let mut state = self.state.lock().await;
        *state = CircuitState::Closed;
        self.consecutive_failures.store(0, Ordering::SeqCst);
        let mut last = self.last_failure_time.lock().await;
        *last = None;
    }

    /// Runs `f` unless the circuit is open, then updates the breaker from the outcome.
    ///
    /// * On success the failure count is reset and the circuit is closed.
    /// * On failure the failure time is recorded and the count incremented. The circuit
    ///   opens when the count reaches `failure_threshold`, or immediately if this call was
    ///   made while half-open.
    ///
    /// NOTE: while half-open, every concurrent caller is let through, not only one probe.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Other` without invoking `f` while the circuit is open;
    /// otherwise propagates the error produced by `f` unchanged.
    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let current_state = self.state().await;
        if current_state == CircuitState::Open {
            return Err(CognisError::Other(
                "Circuit breaker is open: too many consecutive failures".into(),
            ));
        }
        match f().await {
            Ok(result) => {
                self.record_success().await;
                Ok(result)
            }
            Err(e) => {
                self.record_failure(current_state).await;
                Err(e)
            }
        }
    }

    /// Clears the failure count and closes the circuit after a successful call.
    async fn record_success(&self) {
        let mut state = self.state.lock().await;
        self.consecutive_failures.store(0, Ordering::SeqCst);
        *state = CircuitState::Closed;
    }

    /// Records a failed call that was started while the circuit was in `observed` state.
    async fn record_failure(&self, observed: CircuitState) {
        // Take `state` first, matching `state()`/`reset()`; acquiring
        // `last_failure_time` first while another task holds `state` and waits for
        // `last_failure_time` would deadlock.
        let mut state = self.state.lock().await;
        // Saturate so an extreme failure count cannot overflow-panic in debug builds.
        let failures = self
            .consecutive_failures
            .fetch_add(1, Ordering::SeqCst)
            .saturating_add(1);
        *self.last_failure_time.lock().await = Some(Instant::now());
        if observed == CircuitState::HalfOpen || failures >= self.failure_threshold {
            *state = CircuitState::Open;
        }
    }
}
/// A [`BaseChatModel`] wrapper that routes requests to `inner` through a [`CircuitBreaker`].
///
/// After the breaker's failure threshold is reached, requests are rejected without reaching
/// `inner` until the reset timeout elapses.
pub struct CircuitBreakerChatModel {
    /// The wrapped model that actually serves requests.
    inner: Box<dyn BaseChatModel>,
    /// Breaker consulted and updated on every `_generate` / `_stream` call.
    breaker: CircuitBreaker,
}
impl CircuitBreakerChatModel {
    /// Wraps `inner` with a new breaker that starts closed.
    ///
    /// `failure_threshold` and `reset_timeout` are forwarded to [`CircuitBreaker::new`].
    pub fn new(
        inner: Box<dyn BaseChatModel>,
        failure_threshold: u32,
        reset_timeout: Duration,
    ) -> Self {
        let breaker = CircuitBreaker::new(failure_threshold, reset_timeout);
        Self { inner, breaker }
    }

    /// Reports the breaker's current state; an expired `Open` state is promoted to
    /// `HalfOpen` as a side effect.
    pub async fn circuit_state(&self) -> CircuitState {
        let breaker = &self.breaker;
        breaker.state().await
    }

    /// Closes the circuit and clears its failure history.
    pub async fn reset(&self) {
        self.breaker.reset().await
    }
}
#[async_trait]
impl BaseChatModel for CircuitBreakerChatModel {
async fn _generate(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatResult> {
let inner = &self.inner;
self.breaker
.call(|| async move { inner._generate(messages, stop).await })
.await
}
fn llm_type(&self) -> &str {
self.inner.llm_type()
}
async fn _stream(&self, messages: &[Message], stop: Option<&[String]>) -> Result<ChatStream> {
let inner = &self.inner;
self.breaker
.call(|| async move { inner._stream(messages, stop).await })
.await
}
fn bind_tools(
&self,
tools: &[ToolSchema],
tool_choice: Option<ToolChoice>,
) -> Result<Box<dyn BaseChatModel>> {
self.inner.bind_tools(tools, tool_choice)
}
fn profile(&self) -> ModelProfile {
self.inner.profile()
}
fn get_num_tokens_from_messages(&self, messages: &[Message]) -> usize {
self.inner.get_num_tokens_from_messages(messages)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use cognis_core::messages::{AIMessage, HumanMessage};
    use cognis_core::outputs::ChatGeneration;
    use std::sync::atomic::AtomicU32;

    /// Stub model: the first `fail_count` calls return an HTTP 500 error and every later
    /// call returns a single "OK" generation.
    struct MockChatModel {
        fail_count: u32,
        attempts: AtomicU32,
    }

    impl MockChatModel {
        /// Builds a mock that fails its first `fail_count` calls.
        fn failing_first(fail_count: u32) -> Self {
            Self {
                fail_count,
                attempts: AtomicU32::new(0),
            }
        }

        /// Fails on every call (for the first `u32::MAX` attempts).
        fn always_fails() -> Self {
            Self::failing_first(u32::MAX)
        }

        /// Fails the first `n` calls, then succeeds.
        fn fails_n_times(n: u32) -> Self {
            Self::failing_first(n)
        }
    }

    #[async_trait]
    impl BaseChatModel for MockChatModel {
        async fn _generate(
            &self,
            _messages: &[Message],
            _stop: Option<&[String]>,
        ) -> Result<ChatResult> {
            let attempt = self.attempts.fetch_add(1, Ordering::SeqCst);
            if attempt >= self.fail_count {
                let generation = ChatGeneration {
                    text: "OK".into(),
                    message: Message::Ai(AIMessage::new("OK")),
                    generation_info: None,
                };
                return Ok(ChatResult {
                    generations: vec![generation],
                    llm_output: None,
                });
            }
            Err(CognisError::HttpError {
                status: 500,
                body: "Internal Server Error".into(),
            })
        }

        fn llm_type(&self) -> &str {
            "mock"
        }
    }

    /// One-message conversation shared by every test.
    fn hello() -> Vec<Message> {
        vec![Message::Human(HumanMessage::new("hi"))]
    }

    /// Wraps `mock` in a circuit-breaking model with the given settings.
    fn guarded(
        mock: MockChatModel,
        threshold: u32,
        timeout: Duration,
    ) -> CircuitBreakerChatModel {
        CircuitBreakerChatModel::new(Box::new(mock), threshold, timeout)
    }

    /// Issues `times` generate calls, discarding their results.
    async fn hammer(model: &CircuitBreakerChatModel, msgs: &[Message], times: u32) {
        for _ in 0..times {
            let _ = model._generate(msgs, None).await;
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_on_success() {
        let model = guarded(MockChatModel::fails_n_times(0), 3, Duration::from_secs(60));
        let msgs = hello();
        assert!(model._generate(&msgs, None).await.is_ok());
        assert_eq!(model.circuit_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_threshold() {
        let model = guarded(MockChatModel::always_fails(), 3, Duration::from_secs(60));
        let msgs = hello();
        hammer(&model, &msgs, 3).await;
        assert_eq!(model.circuit_state().await, CircuitState::Open);

        let outcome = model._generate(&msgs, None).await;
        assert!(outcome.is_err());
        let err = match outcome {
            Ok(_) => unreachable!(),
            Err(e) => e.to_string(),
        };
        assert!(
            err.contains("Circuit breaker is open"),
            "Expected circuit breaker error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_to_closed() {
        let model = guarded(MockChatModel::fails_n_times(3), 3, Duration::from_millis(50));
        let msgs = hello();
        hammer(&model, &msgs, 3).await;
        assert_eq!(model.circuit_state().await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(model.circuit_state().await, CircuitState::HalfOpen);

        // The mock's failures are used up, so the trial succeeds and closes the circuit.
        assert!(model._generate(&msgs, None).await.is_ok());
        assert_eq!(model.circuit_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_to_open() {
        let model = guarded(MockChatModel::always_fails(), 3, Duration::from_millis(50));
        let msgs = hello();
        hammer(&model, &msgs, 3).await;
        assert_eq!(model.circuit_state().await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(model.circuit_state().await, CircuitState::HalfOpen);

        // A failing trial re-opens the circuit immediately.
        assert!(model._generate(&msgs, None).await.is_err());
        assert_eq!(model.circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let model = guarded(MockChatModel::always_fails(), 2, Duration::from_secs(60));
        let msgs = hello();
        hammer(&model, &msgs, 2).await;
        assert_eq!(model.circuit_state().await, CircuitState::Open);

        model.reset().await;
        assert_eq!(model.circuit_state().await, CircuitState::Closed);
    }
}