use inferd_engine::Backend;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Default number of in-window failures at which a backend's breaker opens.
pub const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// Default length of the sliding window over which failures are counted.
pub const DEFAULT_FAILURE_WINDOW: Duration = Duration::from_secs(60);
/// Default time a tripped breaker keeps its backend out of rotation.
pub const DEFAULT_COOLDOWN: Duration = Duration::from_secs(30);
/// Tunable thresholds for the per-backend circuit breaker.
///
/// A backend's breaker opens once `failure_threshold` failures have been
/// recorded within `failure_window`, and stays open for `cooldown`.
#[derive(Clone, Copy, Debug)]
pub struct BreakerPolicy {
    /// Number of failures inside the window at which the breaker opens.
    pub failure_threshold: u32,
    /// How far back in time a failure still counts towards the threshold.
    pub failure_window: Duration,
    /// How long the breaker stays open after it trips.
    pub cooldown: Duration,
}
impl Default for BreakerPolicy {
fn default() -> Self {
Self {
failure_threshold: DEFAULT_FAILURE_THRESHOLD,
failure_window: DEFAULT_FAILURE_WINDOW,
cooldown: DEFAULT_COOLDOWN,
}
}
}
/// Circuit-breaker bookkeeping for a single backend.
#[derive(Debug, Clone)]
struct BreakerState {
    /// Timestamps of failures that are still inside the failure window
    /// (as of the last call to [`BreakerState::prune`]).
    failures: Vec<Instant>,
    /// Deadline until which the breaker is open; `None` means it has not
    /// tripped (or has been reset).
    open_until: Option<Instant>,
}

impl BreakerState {
    /// Creates a closed breaker with no recorded failures.
    fn new() -> Self {
        Self {
            failures: Vec::new(),
            open_until: None,
        }
    }

    /// Drops every recorded failure older than `window` relative to `now`.
    ///
    /// A failure exactly `window` old is kept.
    fn prune(&mut self, now: Instant, window: Duration) {
        // When `now - window` is not representable the window reaches back
        // past the earliest instant the clock can express, so no recorded
        // failure can be older than it: keep everything. (Falling back to
        // `now` as the cutoff would instead wipe the whole history.)
        if let Some(cutoff) = now.checked_sub(window) {
            self.failures.retain(|&t| t >= cutoff);
        }
    }

    /// Returns `true` while a deadline is set and `now` is still before it.
    fn is_open(&self, now: Instant) -> bool {
        self.open_until.is_some_and(|until| now < until)
    }
}
/// One registered backend together with its breaker bookkeeping.
struct Slot {
    /// Shared handle to the backend implementation.
    backend: Arc<dyn Backend>,
    /// The backend's name, captured when the router was built.
    name: String,
    /// Failure history and open/closed status for this backend.
    state: BreakerState,
}
/// The backend chosen by a successful routing decision.
pub struct Dispatch {
    /// Shared handle to the selected backend.
    pub backend: Arc<dyn Backend>,
    /// Name of the selected backend; the key to pass back when reporting the
    /// outcome of the request to the router.
    pub name: String,
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum RouterError {
#[error("no backends registered")]
NoBackends,
#[error("no backend available")]
NoneAvailable,
}
/// Priority-ordered backend selector with a per-backend circuit breaker.
pub struct Router {
    /// Registered backends in priority order, guarded by a lock because the
    /// breaker state inside each slot is mutated through `&self`.
    slots: RwLock<Vec<Slot>>,
    /// Thresholds applied to every backend's breaker.
    policy: BreakerPolicy,
    /// Backend name -> index into `slots`, built once at construction.
    name_index: HashMap<String, usize>,
}
impl Router {
pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
Self::with_policy(backends, BreakerPolicy::default())
}
pub fn with_policy(backends: Vec<Arc<dyn Backend>>, policy: BreakerPolicy) -> Self {
let mut name_index = HashMap::with_capacity(backends.len());
let slots: Vec<Slot> = backends
.into_iter()
.enumerate()
.map(|(i, b)| {
let name = b.name().to_string();
name_index.insert(name.clone(), i);
Slot {
backend: b,
name,
state: BreakerState::new(),
}
})
.collect();
Self {
slots: RwLock::new(slots),
policy,
name_index,
}
}
pub fn len(&self) -> usize {
self.slots.read().expect("router rwlock poisoned").len()
}
pub fn is_empty(&self) -> bool {
self.slots
.read()
.expect("router rwlock poisoned")
.is_empty()
}
pub fn dispatch(&self) -> Result<Dispatch, RouterError> {
let now = Instant::now();
let mut guard = self.slots.write().expect("router rwlock poisoned");
if guard.is_empty() {
return Err(RouterError::NoBackends);
}
for slot in guard.iter_mut() {
slot.state.prune(now, self.policy.failure_window);
if !slot.backend.ready() {
continue;
}
if slot.state.is_open(now) {
continue;
}
if slot.state.open_until.is_some() {
slot.state.open_until = None;
slot.state.failures.clear();
}
return Ok(Dispatch {
backend: Arc::clone(&slot.backend),
name: slot.name.clone(),
});
}
Err(RouterError::NoneAvailable)
}
pub fn all_ready(&self) -> bool {
let guard = self.slots.read().expect("router rwlock poisoned");
!guard.is_empty() && guard.iter().all(|s| s.backend.ready())
}
pub fn record_success(&self, name: &str) {
let Some(&idx) = self.name_index.get(name) else {
return;
};
let mut guard = self.slots.write().expect("router rwlock poisoned");
if let Some(slot) = guard.get_mut(idx) {
slot.state.failures.clear();
slot.state.open_until = None;
}
}
pub fn record_failure(&self, name: &str) {
let Some(&idx) = self.name_index.get(name) else {
return;
};
let now = Instant::now();
let mut guard = self.slots.write().expect("router rwlock poisoned");
let Some(slot) = guard.get_mut(idx) else {
return;
};
slot.state.prune(now, self.policy.failure_window);
slot.state.failures.push(now);
if slot.state.failures.len() as u32 >= self.policy.failure_threshold {
slot.state.open_until = Some(now + self.policy.cooldown);
}
}
pub fn breaker_open(&self, name: &str) -> bool {
let Some(&idx) = self.name_index.get(name) else {
return false;
};
let now = Instant::now();
let guard = self.slots.read().expect("router rwlock poisoned");
guard.get(idx).is_some_and(|slot| slot.state.is_open(now))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use inferd_engine::mock::Mock;

    /// Policy with a low threshold and short timings so the sleep-based
    /// tests below stay quick.
    fn fast_policy() -> BreakerPolicy {
        let failure_window = Duration::from_millis(500);
        let cooldown = Duration::from_millis(100);
        BreakerPolicy {
            failure_threshold: 2,
            failure_window,
            cooldown,
        }
    }

    #[test]
    fn empty_router_dispatch_returns_no_backends() {
        let router = Router::new(Vec::new());
        assert!(router.is_empty());
        let outcome = router.dispatch().err();
        assert_eq!(outcome, Some(RouterError::NoBackends));
    }

    #[test]
    fn dispatch_returns_ready_backend() {
        let backend = Arc::new(Mock::new());
        let router = Router::new(vec![backend.clone()]);
        let picked = router.dispatch().expect("dispatch ok");
        assert_eq!(picked.name, "mock");
        assert_eq!(picked.backend.name(), "mock");
        assert!(router.all_ready());
    }

    #[test]
    fn unready_backend_returns_none_available() {
        let backend = Arc::new(Mock::new());
        backend.set_ready(false);
        let router = Router::new(vec![backend]);
        let outcome = router.dispatch().err();
        assert_eq!(outcome, Some(RouterError::NoneAvailable));
        assert!(!router.all_ready());
    }

    #[test]
    fn priority_picks_first_ready_backend() {
        /// Wraps a `Mock` so that several backends with distinct names can
        /// be registered side by side.
        struct Named {
            inner: Mock,
            name: &'static str,
        }

        #[async_trait::async_trait]
        impl Backend for Named {
            fn name(&self) -> &str {
                self.name
            }

            fn ready(&self) -> bool {
                self.inner.ready()
            }

            async fn generate(
                &self,
                req: inferd_proto::Resolved,
            ) -> Result<inferd_engine::TokenStream, inferd_engine::GenerateError> {
                self.inner.generate(req).await
            }
        }

        let high = Arc::new(Named {
            inner: Mock::new(),
            name: "high",
        });
        let low = Arc::new(Named {
            inner: Mock::new(),
            name: "low",
        });

        // While the first backend is not ready the second one is chosen.
        high.inner.set_ready(false);
        let router = Router::new(vec![high.clone(), low.clone()]);
        let first_pick = router.dispatch().unwrap();
        assert_eq!(first_pick.name, "low");

        // Once the first backend is ready again it takes precedence.
        high.inner.set_ready(true);
        let second_pick = router.dispatch().unwrap();
        assert_eq!(second_pick.name, "high");
    }

    #[test]
    fn breaker_opens_after_threshold_failures() {
        let backend = Arc::new(Mock::new());
        let router = Router::with_policy(vec![backend], fast_policy());

        router.record_failure("mock");
        assert!(!router.breaker_open("mock"));

        router.record_failure("mock");
        assert!(router.breaker_open("mock"));

        let outcome = router.dispatch().err();
        assert_eq!(outcome, Some(RouterError::NoneAvailable));
    }

    #[test]
    fn success_resets_failure_count() {
        let backend = Arc::new(Mock::new());
        let router = Router::with_policy(vec![backend], fast_policy());

        router.record_failure("mock");
        router.record_success("mock");
        router.record_failure("mock");

        assert!(!router.breaker_open("mock"));
    }

    #[test]
    fn breaker_recovers_after_cooldown() {
        let backend = Arc::new(Mock::new());
        let router = Router::with_policy(vec![backend], fast_policy());

        router.record_failure("mock");
        router.record_failure("mock");
        assert!(router.breaker_open("mock"));

        // Wait past the 100 ms cooldown of `fast_policy`.
        std::thread::sleep(Duration::from_millis(150));

        assert!(!router.breaker_open("mock"));
        let picked = router.dispatch().unwrap();
        assert_eq!(picked.name, "mock");
    }

    #[test]
    fn old_failures_outside_window_dont_open_breaker() {
        let backend = Arc::new(Mock::new());
        let router = Router::with_policy(vec![backend], fast_policy());

        router.record_failure("mock");
        // Wait past the 500 ms failure window of `fast_policy`.
        std::thread::sleep(Duration::from_millis(600));
        router.record_failure("mock");

        assert!(!router.breaker_open("mock"));
    }

    #[test]
    fn record_failure_for_unknown_backend_is_a_noop() {
        let backend = Arc::new(Mock::new());
        let router = Router::new(vec![backend]);

        router.record_failure("does-not-exist");
        router.record_success("does-not-exist");

        assert!(!router.breaker_open("does-not-exist"));
    }
}