use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
},
time::{Duration, Instant},
};
use crate::error::{RealizarError, Result};
/// Configuration for [`CapacityFactorRouter`].
#[derive(Debug, Clone)]
pub struct CapacityConfig {
    /// Maximum number of in-flight requests allowed per expert queue.
    pub capacity: usize,
    /// Number of experts the router distributes requests across.
    pub num_experts: usize,
}
/// Routes each request to the highest-scoring expert that still has queue
/// capacity, falling back to the runner-up when the primary is saturated.
pub struct CapacityFactorRouter {
    config: CapacityConfig,
    // Per-expert in-flight request counters, indexed by expert id.
    // Atomics allow lock-free updates from concurrent request handlers.
    queue_depths: Vec<AtomicUsize>,
}
impl CapacityFactorRouter {
    /// Creates a router with every expert queue starting at depth zero.
    #[must_use]
    pub fn new(config: CapacityConfig) -> Self {
        let queue_depths = (0..config.num_experts)
            .map(|_| AtomicUsize::new(0))
            .collect();
        Self {
            config,
            queue_depths,
        }
    }

    /// Routes to the highest-scoring expert with spare capacity.
    ///
    /// The top-scoring expert is tried first; if its queue is at capacity the
    /// second-best expert is used, *provided it also has spare capacity*.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::MoeError`] when `scores.len()` does not match
    /// the configured number of experts, and
    /// [`RealizarError::ExpertCapacityExceeded`] when no candidate expert has
    /// spare capacity.
    pub fn route(&self, scores: &[f32]) -> Result<usize> {
        if scores.len() != self.config.num_experts {
            return Err(RealizarError::MoeError(format!(
                "Expected {} scores, got {}",
                self.config.num_experts,
                scores.len()
            )));
        }
        let top2 = Self::top_k_indices(scores, 2);
        let primary = top2[0];
        // NOTE: depth reads are Relaxed and advisory; a concurrent
        // record_start may race past the capacity check. This matches the
        // pre-existing best-effort semantics of the router.
        if self.queue_depths[primary].load(Ordering::Relaxed) < self.config.capacity {
            return Ok(primary);
        }
        // BUG FIX: the fallback expert must also be capacity-checked.
        // Previously it was selected unconditionally whenever the primary was
        // full, allowing the fallback queue to exceed `capacity`. This now
        // mirrors PowerOfTwoChoicesRouter, which errors when both candidates
        // are saturated.
        if let Some(&secondary) = top2.get(1) {
            if self.queue_depths[secondary].load(Ordering::Relaxed) < self.config.capacity {
                return Ok(secondary);
            }
        }
        Err(RealizarError::ExpertCapacityExceeded {
            expert_id: primary,
            queue_depth: self.queue_depths[primary].load(Ordering::Relaxed),
            capacity: self.config.capacity,
        })
    }

    /// Marks one request as in flight on `expert_id`.
    pub fn record_start(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_add(1, Ordering::Relaxed);
    }

    /// Marks one request on `expert_id` as completed.
    pub fn record_end(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_sub(1, Ordering::Relaxed);
    }

    /// Current number of in-flight requests for `expert_id`.
    #[must_use]
    pub fn queue_depth(&self, expert_id: usize) -> usize {
        self.queue_depths[expert_id].load(Ordering::Relaxed)
    }

    /// Indices of the `k` highest scores, best first. The sort is stable, so
    /// equal scores keep ascending-index order; NaN compares as equal.
    fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.into_iter().take(k).map(|(i, _)| i).collect()
    }
}
/// Configuration for [`PowerOfTwoChoicesRouter`].
#[derive(Debug, Clone)]
pub struct PowerOfTwoConfig {
    /// Number of experts the router distributes requests across.
    pub num_experts: usize,
    /// Maximum number of in-flight requests allowed per expert queue.
    pub capacity: usize,
}
/// Power-of-two-choices router: considers the two highest-scoring experts and
/// sends the request to whichever has the shallower queue.
pub struct PowerOfTwoChoicesRouter {
    config: PowerOfTwoConfig,
    // Per-expert in-flight request counters, indexed by expert id.
    queue_depths: Vec<AtomicUsize>,
}
impl PowerOfTwoChoicesRouter {
    /// Builds a router with every expert queue starting at depth zero.
    #[must_use]
    pub fn new(config: PowerOfTwoConfig) -> Self {
        let mut queue_depths = Vec::with_capacity(config.num_experts);
        for _ in 0..config.num_experts {
            queue_depths.push(AtomicUsize::new(0));
        }
        Self {
            config,
            queue_depths,
        }
    }

    /// Picks between the two highest-scoring experts, preferring the one with
    /// the smaller queue. Candidates whose queues are at capacity are
    /// excluded; ties on queue depth go to the higher-scoring expert.
    ///
    /// # Errors
    ///
    /// Returns [`RealizarError::MoeError`] if `scores.len()` differs from the
    /// configured expert count, and
    /// [`RealizarError::ExpertCapacityExceeded`] when both candidates are at
    /// capacity.
    pub fn route(&self, scores: &[f32]) -> Result<usize> {
        if scores.len() != self.config.num_experts {
            return Err(RealizarError::MoeError(format!(
                "Expected {} scores, got {}",
                self.config.num_experts,
                scores.len()
            )));
        }
        let candidates = Self::top_k_indices(scores, 2);
        // `min_by_key` keeps the first of equally-loaded candidates, so the
        // score ordering of `candidates` breaks depth ties in favor of the
        // better-scoring expert.
        let chosen = candidates
            .iter()
            .map(|&id| (id, self.queue_depths[id].load(Ordering::Relaxed)))
            .filter(|&(_, depth)| depth < self.config.capacity)
            .min_by_key(|&(_, depth)| depth)
            .map(|(id, _)| id);
        chosen.ok_or_else(|| RealizarError::ExpertCapacityExceeded {
            expert_id: candidates[0],
            queue_depth: self.queue_depths[candidates[0]].load(Ordering::Relaxed),
            capacity: self.config.capacity,
        })
    }

    /// Marks one request as in flight on `expert_id`.
    pub fn record_start(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_add(1, Ordering::Relaxed);
    }

    /// Marks one request on `expert_id` as completed.
    pub fn record_end(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_sub(1, Ordering::Relaxed);
    }

    /// Current number of in-flight requests for `expert_id`.
    #[must_use]
    pub fn queue_depth(&self, expert_id: usize) -> usize {
        self.queue_depths[expert_id].load(Ordering::Relaxed)
    }

    /// Indices of the `k` highest scores, best first. The stable sort keeps
    /// ascending-index order for equal scores; NaN compares as equal.
    fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
        let mut order: Vec<usize> = (0..scores.len()).collect();
        order.sort_by(|&a, &b| {
            scores[b]
                .partial_cmp(&scores[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        order.truncate(k);
        order
    }
}
/// Observable state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    Closed,
    Open,
    HalfOpen,
}

/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures that trip the breaker open.
    pub failure_threshold: usize,
    /// Successes required in half-open state to close the breaker again.
    pub success_threshold: usize,
    /// How long the breaker stays open before probing (half-open), in ms.
    pub timeout_ms: u64,
}

/// Thread-safe circuit breaker: opens after repeated failures, blocks
/// requests while open, probes again after a timeout, and closes once enough
/// probe requests succeed.
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    state: Mutex<CircuitBreakerState>,
}

// Mutable breaker bookkeeping, always accessed under the mutex.
struct CircuitBreakerState {
    current: CircuitState,
    failure_count: usize,
    success_count: usize,
    last_failure_time: Option<Instant>,
}

impl CircuitBreaker {
    /// Creates a breaker in the closed (request-passing) state.
    #[must_use]
    pub fn new(config: CircuitBreakerConfig) -> Self {
        let initial = CircuitBreakerState {
            current: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
        };
        Self {
            config,
            state: Mutex::new(initial),
        }
    }

    /// Returns the current state, first applying the open -> half-open
    /// transition if the timeout has elapsed.
    #[must_use]
    pub fn state(&self) -> CircuitState {
        let mut guard = self.lock_state();
        self.maybe_transition_to_half_open(&mut guard);
        guard.current
    }

    /// Whether a request may proceed: anything but `Open` passes.
    #[must_use]
    pub fn allow_request(&self) -> bool {
        self.state() != CircuitState::Open
    }

    /// Records a successful request. In the closed state this clears the
    /// failure streak; in half-open it counts toward re-closing the breaker.
    pub fn record_success(&self) {
        let mut guard = self.lock_state();
        self.maybe_transition_to_half_open(&mut guard);
        if guard.current == CircuitState::Closed {
            guard.failure_count = 0;
        } else if guard.current == CircuitState::HalfOpen {
            guard.success_count += 1;
            if guard.success_count >= self.config.success_threshold {
                guard.current = CircuitState::Closed;
                guard.failure_count = 0;
                guard.success_count = 0;
            }
        }
        // A success while fully open is ignored.
    }

    /// Records a failed request; trips the breaker open once the failure
    /// streak reaches the configured threshold. Because the streak is only
    /// cleared on close or on a success in the closed state, a single failure
    /// in half-open re-opens the breaker.
    pub fn record_failure(&self) {
        let mut guard = self.lock_state();
        guard.failure_count += 1;
        guard.last_failure_time = Some(Instant::now());
        if guard.failure_count >= self.config.failure_threshold {
            guard.current = CircuitState::Open;
            guard.success_count = 0;
        }
    }

    // Acquires the state mutex; poisoning indicates a panicked holder, which
    // is treated as a fatal invariant violation.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, CircuitBreakerState> {
        self.state.lock().expect("CircuitBreaker mutex poisoned")
    }

    // Moves Open -> HalfOpen once `timeout_ms` has elapsed since the last
    // recorded failure; resets the half-open success counter.
    fn maybe_transition_to_half_open(&self, state: &mut CircuitBreakerState) {
        if state.current != CircuitState::Open {
            return;
        }
        if let Some(last_failure) = state.last_failure_time {
            if last_failure.elapsed() >= Duration::from_millis(self.config.timeout_ms) {
                state.current = CircuitState::HalfOpen;
                state.success_count = 0;
            }
        }
    }
}
/// Targets for heijunka (load-leveling) decisions.
#[derive(Debug, Clone)]
pub struct HeijunkaConfig {
    /// Latency the controller tries to hold the system at.
    pub target_latency_ms: f64,
    /// Hard upper bound on concurrent in-flight requests.
    pub max_concurrency: usize,
}

/// Outcome of a load-leveling evaluation.
#[derive(Debug, Clone)]
pub struct LoadSheddingDecision {
    /// True when load should be shed (over latency target at max concurrency).
    pub shed_load: bool,
    /// Concurrency level the controller recommends, clamped to
    /// `1..=max_concurrency`.
    pub recommended_concurrency: usize,
}

/// Stateless controller that sizes concurrency from observed arrival rate and
/// latency, and decides when to shed load.
pub struct HeijunkaController {
    config: HeijunkaConfig,
}

impl HeijunkaController {
    /// Creates a controller with the given targets.
    #[must_use]
    pub fn new(config: HeijunkaConfig) -> Self {
        Self { config }
    }

    /// Concurrency needed to sustain `arrival_rate` (req/s) at `latency_ms`
    /// per request (Little's law: L = lambda * W), clamped to
    /// `1..=max_concurrency`.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    pub fn optimal_concurrency(&self, arrival_rate: f64, latency_ms: f64) -> usize {
        let required = (arrival_rate * latency_ms / 1000.0).ceil() as usize;
        required.clamp(1, self.config.max_concurrency)
    }

    /// Decides whether to shed load and what concurrency to aim for.
    ///
    /// Load is shed only when latency is over target *and* concurrency is
    /// already maxed out. The recommendation scales current concurrency by
    /// `target / current` latency, clamped to `1..=max_concurrency`.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::cast_precision_loss)]
    pub fn should_shed_load(
        &self,
        current_latency_ms: f64,
        current_concurrency: usize,
    ) -> LoadSheddingDecision {
        let over_latency = current_latency_ms > self.config.target_latency_ms;
        let at_capacity = current_concurrency >= self.config.max_concurrency;
        let scale = self.config.target_latency_ms / current_latency_ms;
        let concurrency_f64: f64 = current_concurrency as f64;
        let scaled = (concurrency_f64 * scale).ceil() as usize;
        LoadSheddingDecision {
            shed_load: over_latency && at_capacity,
            recommended_concurrency: scaled.clamp(1, self.config.max_concurrency),
        }
    }

    /// The configured target latency in milliseconds.
    #[must_use]
    pub fn target_latency_ms(&self) -> f64 {
        self.config.target_latency_ms
    }
}
/// Conditions that trip the andon (stop-the-line) alerting system.
#[derive(Debug, Clone, PartialEq)]
pub enum AndonTrigger {
    ModelChecksumMismatch {
        model_id: String,
    },
    LatencyExceeded {
        p99_ms: f64,
        threshold_ms: f64,
    },
    ErrorRateThreshold {
        rate: f64,
        threshold: f64,
    },
    ExpertImbalance {
        imbalance_ratio: f64,
    },
}

/// Automated response chosen for a trigger.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AndonResponse {
    Rollback,
    Notify,
    Quarantine,
}

impl AndonTrigger {
    /// Maps a trigger to its automated response.
    ///
    /// Checksum mismatches always roll back. Error rates more than double the
    /// threshold quarantine; otherwise (including latency and imbalance
    /// triggers) the response is a notification.
    #[must_use]
    pub fn response(&self) -> AndonResponse {
        match self {
            Self::ModelChecksumMismatch { .. } => AndonResponse::Rollback,
            Self::ErrorRateThreshold { rate, threshold } => {
                let quarantine_rate = threshold * 2.0;
                if *rate > quarantine_rate {
                    AndonResponse::Quarantine
                } else {
                    AndonResponse::Notify
                }
            },
            Self::LatencyExceeded { .. } | Self::ExpertImbalance { .. } => AndonResponse::Notify,
        }
    }

    /// True when the response requires intervention (rollback or quarantine).
    #[must_use]
    pub fn is_critical(&self) -> bool {
        match self.response() {
            AndonResponse::Rollback | AndonResponse::Quarantine => true,
            AndonResponse::Notify => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_power_of_two_choices_selects_least_loaded() {
        let router = PowerOfTwoChoicesRouter::new(PowerOfTwoConfig {
            num_experts: 4,
            capacity: 100,
        });
        // Preload expert 1 so the runner-up (expert 2) is less loaded.
        for _ in 0..50 {
            router.record_start(1);
        }
        let scores = vec![0.1, 0.9, 0.8, 0.1];
        assert_eq!(router.route(&scores).unwrap(), 2);
    }

    #[test]
    fn test_power_of_two_choices_equal_load_picks_best_score() {
        let router = PowerOfTwoChoicesRouter::new(PowerOfTwoConfig {
            num_experts: 4,
            capacity: 100,
        });
        // All queues empty: the depth tie must break toward the top score.
        let scores = vec![0.1, 0.9, 0.8, 0.1];
        assert_eq!(router.route(&scores).unwrap(), 1);
    }

    #[test]
    fn test_power_of_two_choices_respects_capacity() {
        let router = PowerOfTwoChoicesRouter::new(PowerOfTwoConfig {
            num_experts: 2,
            capacity: 5,
        });
        // Fill both experts to capacity so no candidate remains.
        for _ in 0..5 {
            router.record_start(0);
            router.record_start(1);
        }
        let scores = vec![0.9, 0.8];
        assert!(router.route(&scores).is_err());
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 5,
            success_threshold: 3,
            timeout_ms: 1000,
        });
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            timeout_ms: 1000,
        });
        cb.record_failure();
        cb.record_failure();
        // Still one failure short of the threshold.
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_blocks_when_open() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 1,
            timeout_ms: 100_000,
        });
        cb.record_failure();
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_circuit_breaker_half_open_after_timeout() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 1,
            timeout_ms: 1,
        });
        cb.record_failure();
        // Sleep past the 1 ms timeout so the breaker starts probing.
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_circuit_breaker_closes_on_success_in_half_open() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 2,
            timeout_ms: 1,
        });
        cb.record_failure();
        std::thread::sleep(std::time::Duration::from_millis(5));
        cb.record_success();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_heijunka_calculates_optimal_concurrency() {
        let controller = HeijunkaController::new(HeijunkaConfig {
            target_latency_ms: 100.0,
            max_concurrency: 100,
        });
        // 10 req/s * 0.1 s = 1 in-flight request.
        assert_eq!(controller.optimal_concurrency(10.0, 100.0), 1);
    }

    #[test]
    fn test_heijunka_caps_at_max_concurrency() {
        let controller = HeijunkaController::new(HeijunkaConfig {
            target_latency_ms: 100.0,
            max_concurrency: 10,
        });
        // Demand calls for 100 but the configured cap is 10.
        assert_eq!(controller.optimal_concurrency(1000.0, 100.0), 10);
    }

    #[test]
    fn test_heijunka_load_leveling_decision() {
        let controller = HeijunkaController::new(HeijunkaConfig {
            target_latency_ms: 100.0,
            max_concurrency: 50,
        });
        // Over target latency at max concurrency: shed.
        let decision = controller.should_shed_load(150.0, 50);
        assert!(decision.shed_load);
        // Under target latency, well below max: keep going.
        let decision = controller.should_shed_load(50.0, 10);
        assert!(!decision.shed_load);
    }

    #[test]
    fn test_route_to_best_expert() {
        let router = CapacityFactorRouter::new(CapacityConfig {
            capacity: 10,
            num_experts: 4,
        });
        let scores = vec![0.1, 0.5, 0.3, 0.1];
        assert_eq!(router.route(&scores).unwrap(), 1);
    }

    #[test]
    fn test_fallback_when_primary_full() {
        let router = CapacityFactorRouter::new(CapacityConfig {
            capacity: 1,
            num_experts: 4,
        });
        // Saturate the top-scoring expert so routing spills to the runner-up.
        router.record_start(1);
        let scores = vec![0.1, 0.5, 0.3, 0.1];
        assert_eq!(router.route(&scores).unwrap(), 2);
    }

    #[test]
    fn test_queue_depth_tracking() {
        let router = CapacityFactorRouter::new(CapacityConfig {
            capacity: 10,
            num_experts: 2,
        });
        assert_eq!(router.queue_depth(0), 0);
        router.record_start(0);
        assert_eq!(router.queue_depth(0), 1);
        router.record_end(0);
        assert_eq!(router.queue_depth(0), 0);
    }

    #[test]
    fn test_wrong_score_count_error() {
        let router = CapacityFactorRouter::new(CapacityConfig {
            capacity: 10,
            num_experts: 4,
        });
        // Two scores for four experts must be rejected.
        let scores = vec![0.5, 0.5];
        assert!(router.route(&scores).is_err());
    }

    #[test]
    fn test_andon_checksum_triggers_rollback() {
        let trigger = AndonTrigger::ModelChecksumMismatch {
            model_id: "model-1".to_string(),
        };
        assert_eq!(trigger.response(), AndonResponse::Rollback);
        assert!(trigger.is_critical());
    }

    #[test]
    fn test_andon_latency_triggers_notify() {
        let trigger = AndonTrigger::LatencyExceeded {
            p99_ms: 150.0,
            threshold_ms: 100.0,
        };
        assert_eq!(trigger.response(), AndonResponse::Notify);
        assert!(!trigger.is_critical());
    }

    #[test]
    fn test_andon_high_error_rate_quarantines() {
        // 0.25 > 2 * 0.1, so this crosses the quarantine line.
        let trigger = AndonTrigger::ErrorRateThreshold {
            rate: 0.25,
            threshold: 0.1,
        };
        assert_eq!(trigger.response(), AndonResponse::Quarantine);
        assert!(trigger.is_critical());
    }

    #[test]
    fn test_andon_moderate_error_rate_notifies() {
        // 0.15 exceeds the threshold but not double it: notify only.
        let trigger = AndonTrigger::ErrorRateThreshold {
            rate: 0.15,
            threshold: 0.1,
        };
        assert_eq!(trigger.response(), AndonResponse::Notify);
    }
}