use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
},
time::{Duration, Instant},
};
use crate::error::{RealizarError, Result};
/// Settings for `CapacityFactorRouter`.
#[derive(Debug, Clone)]
pub struct CapacityConfig {
/// Queue depth at which an expert is treated as full: the router only picks
/// its top-scoring expert while that expert's depth is strictly below this.
pub capacity: usize,
/// Number of experts. The router keeps one queue-depth counter per expert
/// and requires exactly this many scores per routing call.
pub num_experts: usize,
}
/// Router that prefers the highest-scoring expert and falls back to the
/// second-highest once the preferred expert's queue depth reaches capacity.
pub struct CapacityFactorRouter {
// Capacity limit and expert count supplied at construction.
config: CapacityConfig,
// One counter per expert, indexed by expert id; adjusted by
// `record_start` / `record_end` and read by `route` / `queue_depth`.
queue_depths: Vec<AtomicUsize>,
}
impl CapacityFactorRouter {
    /// Creates a router with one zero-initialised queue-depth counter per
    /// expert (`config.num_experts` counters in total).
    #[must_use]
    pub fn new(config: CapacityConfig) -> Self {
        let queue_depths = (0..config.num_experts)
            .map(|_| AtomicUsize::new(0))
            .collect();
        Self {
            config,
            queue_depths,
        }
    }

    /// Chooses an expert for a request, given one score per expert.
    ///
    /// The highest-scoring expert is returned while its queue depth is
    /// strictly below `config.capacity`. Once it is full, the
    /// second-highest-scoring expert is returned instead; the fallback's own
    /// queue depth is not checked. `NaN` scores rank below every real score.
    ///
    /// This method only reads the counters; callers report load changes via
    /// [`Self::record_start`] and [`Self::record_end`].
    ///
    /// # Errors
    ///
    /// * `RealizarError::MoeError` if `scores.len()` differs from
    ///   `config.num_experts`, or if the router has no experts at all.
    /// * `RealizarError::ExpertCapacityExceeded` if the preferred expert is
    ///   full and there is no second expert to fall back to.
    pub fn route(&self, scores: &[f32]) -> Result<usize> {
        if scores.len() != self.config.num_experts {
            return Err(RealizarError::MoeError(format!(
                "Expected {} scores, got {}",
                self.config.num_experts,
                scores.len()
            )));
        }

        let top2 = Self::top_k_indices(scores, 2);

        // With zero experts the length check above passes for an empty slice,
        // so there may be no candidate at all; report that instead of
        // indexing out of bounds.
        let primary = match top2.first() {
            Some(&idx) => idx,
            None => {
                return Err(RealizarError::MoeError(
                    "Cannot route: router is configured with zero experts".to_string(),
                ));
            }
        };

        let primary_depth = self.queue_depths[primary].load(Ordering::Relaxed);
        if primary_depth < self.config.capacity {
            return Ok(primary);
        }

        match top2.get(1) {
            Some(&secondary) => Ok(secondary),
            None => Err(RealizarError::ExpertCapacityExceeded {
                expert_id: primary,
                queue_depth: primary_depth,
                capacity: self.config.capacity,
            }),
        }
    }

    /// Increments the queue depth of `expert_id`.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    pub fn record_start(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_add(1, Ordering::Relaxed);
    }

    /// Decrements the queue depth of `expert_id`, stopping at zero.
    ///
    /// A plain `fetch_sub` wraps to `usize::MAX` when the counter is already
    /// zero, which would make the expert look permanently full; an unmatched
    /// call is therefore ignored instead.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    pub fn record_end(&self, expert_id: usize) {
        // `checked_sub` yields `None` at zero, which tells `fetch_update` to
        // leave the counter untouched; the returned `Err` is expected then.
        let _ = self.queue_depths[expert_id].fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |depth| depth.checked_sub(1),
        );
    }

    /// Returns the current queue depth of `expert_id`.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    #[must_use]
    pub fn queue_depth(&self, expert_id: usize) -> usize {
        self.queue_depths[expert_id].load(Ordering::Relaxed)
    }

    /// Returns the indices of the `k` highest scores, best first.
    ///
    /// Equal scores keep their original index order because the sort is
    /// stable. `NaN` scores are ordered after every real score so the
    /// comparator stays a valid total order for `sort_by`.
    fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN here, so `partial_cmp` always succeeds;
            // operands are swapped to sort in descending order.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        });
        indexed.into_iter().take(k).map(|(i, _)| i).collect()
    }
}
/// Settings for `PowerOfTwoChoicesRouter`.
#[derive(Debug, Clone)]
pub struct PowerOfTwoConfig {
/// Number of experts. The router keeps one queue-depth counter per expert
/// and requires exactly this many scores per routing call.
pub num_experts: usize,
/// Queue depth at which an expert is treated as full: only experts whose
/// depth is strictly below this value can be selected.
pub capacity: usize,
}
/// Router that looks at the two highest-scoring experts and picks whichever
/// of them has the smaller queue depth, provided it is below capacity.
pub struct PowerOfTwoChoicesRouter {
// Capacity limit and expert count supplied at construction.
config: PowerOfTwoConfig,
// One counter per expert, indexed by expert id; adjusted by
// `record_start` / `record_end` and read by `route` / `queue_depth`.
queue_depths: Vec<AtomicUsize>,
}
impl PowerOfTwoChoicesRouter {
    /// Creates a router with one zero-initialised queue-depth counter per
    /// expert (`config.num_experts` counters in total).
    #[must_use]
    pub fn new(config: PowerOfTwoConfig) -> Self {
        let queue_depths = (0..config.num_experts)
            .map(|_| AtomicUsize::new(0))
            .collect();
        Self {
            config,
            queue_depths,
        }
    }

    /// Chooses an expert for a request, given one score per expert.
    ///
    /// The two highest-scoring experts are the candidates. Among those whose
    /// queue depth is strictly below `config.capacity`, the one with the
    /// smaller depth wins; on equal depth the higher-scoring candidate wins.
    /// `NaN` scores rank below every real score.
    ///
    /// This method only reads the counters; callers report load changes via
    /// [`Self::record_start`] and [`Self::record_end`].
    ///
    /// # Errors
    ///
    /// * `RealizarError::MoeError` if `scores.len()` differs from
    ///   `config.num_experts`, or if the router has no experts at all.
    /// * `RealizarError::ExpertCapacityExceeded` (describing the top-scoring
    ///   expert) if every candidate is at or above capacity.
    pub fn route(&self, scores: &[f32]) -> Result<usize> {
        if scores.len() != self.config.num_experts {
            return Err(RealizarError::MoeError(format!(
                "Expected {} scores, got {}",
                self.config.num_experts,
                scores.len()
            )));
        }

        let top2 = Self::top_k_indices(scores, 2);

        // With zero experts the length check above passes for an empty slice,
        // so there may be no candidate at all; report that instead of
        // indexing out of bounds when building the capacity error below.
        let top_scoring = match top2.first() {
            Some(&idx) => idx,
            None => {
                return Err(RealizarError::MoeError(
                    "Cannot route: router is configured with zero experts".to_string(),
                ));
            }
        };

        // `min_by_key` returns the first minimum, so a tie in load goes to
        // the earlier (higher-scoring) candidate.
        let least_loaded = top2
            .iter()
            .map(|&expert_id| {
                (
                    expert_id,
                    self.queue_depths[expert_id].load(Ordering::Relaxed),
                )
            })
            .filter(|&(_, load)| load < self.config.capacity)
            .min_by_key(|&(_, load)| load)
            .map(|(expert_id, _)| expert_id);

        least_loaded.ok_or_else(|| RealizarError::ExpertCapacityExceeded {
            expert_id: top_scoring,
            queue_depth: self.queue_depths[top_scoring].load(Ordering::Relaxed),
            capacity: self.config.capacity,
        })
    }

    /// Increments the queue depth of `expert_id`.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    pub fn record_start(&self, expert_id: usize) {
        self.queue_depths[expert_id].fetch_add(1, Ordering::Relaxed);
    }

    /// Decrements the queue depth of `expert_id`, stopping at zero.
    ///
    /// A plain `fetch_sub` wraps to `usize::MAX` when the counter is already
    /// zero, which would make the expert look permanently full; an unmatched
    /// call is therefore ignored instead.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    pub fn record_end(&self, expert_id: usize) {
        // `checked_sub` yields `None` at zero, which tells `fetch_update` to
        // leave the counter untouched; the returned `Err` is expected then.
        let _ = self.queue_depths[expert_id].fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |depth| depth.checked_sub(1),
        );
    }

    /// Returns the current queue depth of `expert_id`.
    ///
    /// # Panics
    ///
    /// Panics if `expert_id` is not smaller than the configured expert count.
    #[must_use]
    pub fn queue_depth(&self, expert_id: usize) -> usize {
        self.queue_depths[expert_id].load(Ordering::Relaxed)
    }

    /// Returns the indices of the `k` highest scores, best first.
    ///
    /// Equal scores keep their original index order because the sort is
    /// stable. `NaN` scores are ordered after every real score so the
    /// comparator stays a valid total order for `sort_by`.
    fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
        indexed.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN here, so `partial_cmp` always succeeds;
            // operands are swapped to sort in descending order.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
        });
        indexed.into_iter().take(k).map(|(i, _)| i).collect()
    }
}
/// The three states a `CircuitBreaker` can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests are allowed; a recorded success resets the failure count.
Closed,
/// Requests are rejected until the configured timeout has elapsed since
/// the last recorded failure.
Open,
/// Requests are allowed again and successes are counted; enough of them
/// move the breaker back to `Closed`.
HalfOpen,
}
/// Thresholds and timing for a `CircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// The breaker opens once the failure count reaches this value.
pub failure_threshold: usize,
/// Number of successes recorded while half-open that closes the breaker.
pub success_threshold: usize,
/// Milliseconds that must pass after the last recorded failure before an
/// open breaker becomes half-open.
pub timeout_ms: u64,
}
/// Circuit breaker whose mutable state lives behind a `Mutex`, so all of its
/// methods take `&self`.
pub struct CircuitBreaker {
// Thresholds and timeout supplied at construction.
config: CircuitBreakerConfig,
// Current state, counters and last-failure timestamp.
state: Mutex<CircuitBreakerState>,
}
// Mutable bookkeeping guarded by the mutex inside `CircuitBreaker`.
struct CircuitBreakerState {
// State as of the last update; an elapsed timeout is applied lazily, so
// this can still read `Open` until a method re-evaluates it.
current: CircuitState,
// Failures recorded since the count was last reset (a success while
// closed, or the half-open -> closed transition, resets it).
failure_count: usize,
// Successes recorded while half-open; reset on every state transition.
success_count: usize,
// When `record_failure` last ran; `None` until the first failure.
last_failure_time: Option<Instant>,
}
impl CircuitBreaker {
    /// Builds a breaker that starts out `Closed`, with both counters at zero
    /// and no failure timestamp recorded.
    #[must_use]
    pub fn new(config: CircuitBreakerConfig) -> Self {
        let initial = CircuitBreakerState {
            current: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
        };
        Self {
            config,
            state: Mutex::new(initial),
        }
    }

    /// Reports the current state, first applying the open -> half-open
    /// transition if the timeout has elapsed.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn state(&self) -> CircuitState {
        let mut guard = self.state.lock().expect("CircuitBreaker mutex poisoned");
        self.maybe_transition_to_half_open(&mut guard);
        guard.current
    }

    /// Returns `true` unless the breaker is open, after applying the
    /// open -> half-open transition if the timeout has elapsed.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn allow_request(&self) -> bool {
        let mut guard = self.state.lock().expect("CircuitBreaker mutex poisoned");
        self.maybe_transition_to_half_open(&mut guard);
        guard.current != CircuitState::Open
    }

    /// Records a successful request.
    ///
    /// While closed this clears the failure count. While half-open it counts
    /// toward `success_threshold`, and reaching the threshold closes the
    /// breaker and clears both counters. While open it does nothing.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn record_success(&self) {
        let mut guard = self.state.lock().expect("CircuitBreaker mutex poisoned");
        self.maybe_transition_to_half_open(&mut guard);
        match guard.current {
            CircuitState::Open => {}
            CircuitState::Closed => guard.failure_count = 0,
            CircuitState::HalfOpen => {
                guard.success_count += 1;
                let recovered = guard.success_count >= self.config.success_threshold;
                if recovered {
                    guard.success_count = 0;
                    guard.failure_count = 0;
                    guard.current = CircuitState::Closed;
                }
            }
        }
    }

    /// Records a failed request: bumps the failure count, stamps the failure
    /// time, and opens the breaker once `failure_threshold` is reached.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn record_failure(&self) {
        let mut guard = self.state.lock().expect("CircuitBreaker mutex poisoned");
        guard.failure_count += 1;
        guard.last_failure_time = Some(Instant::now());
        let tripped = guard.failure_count >= self.config.failure_threshold;
        if tripped {
            // Any progress made while half-open is discarded when re-opening.
            guard.success_count = 0;
            guard.current = CircuitState::Open;
        }
    }

    /// Moves an open breaker to half-open once `timeout_ms` has passed since
    /// the last recorded failure; otherwise leaves `state` untouched.
    fn maybe_transition_to_half_open(&self, state: &mut CircuitBreakerState) {
        if state.current != CircuitState::Open {
            return;
        }
        let cooldown = Duration::from_millis(self.config.timeout_ms);
        let cooled_down = state
            .last_failure_time
            .map_or(false, |failed_at| failed_at.elapsed() >= cooldown);
        if cooled_down {
            state.success_count = 0;
            state.current = CircuitState::HalfOpen;
        }
    }
}
/// Configuration held by `HeijunkaController`.
///
/// NOTE(review): the code that consumes these fields is not in this part of
/// the file; the field descriptions below are read off the names only.
#[derive(Debug, Clone)]
pub struct HeijunkaConfig {
/// Target latency; the `_ms` suffix suggests milliseconds — TODO confirm.
pub target_latency_ms: f64,
/// Presumably an upper bound on concurrency — verify against the consumer.
pub max_concurrency: usize,
}
/// A load-shedding verdict paired with a concurrency recommendation.
///
/// NOTE(review): nothing in this part of the file constructs this type;
/// presumably it is produced by `HeijunkaController` — confirm.
#[derive(Debug, Clone)]
pub struct LoadSheddingDecision {
/// Whether load should be shed.
pub shed_load: bool,
/// The concurrency level being recommended.
pub recommended_concurrency: usize,
}
/// Controller that owns a `HeijunkaConfig`.
///
/// NOTE(review): no methods are defined in this part of the file; they
/// presumably come from the file pulled in by the `include!` below — confirm.
pub struct HeijunkaController {
// Settings this controller was built with.
config: HeijunkaConfig,
}
// Splices the contents of `mod_optimal_concurrency_heijunka.rs` into this
// module at this point, so its items share this file's scope and imports.
// NOTE(review): presumably holds the `HeijunkaController` implementation —
// confirm against that file.
include!("mod_optimal_concurrency_heijunka.rs");