#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
/// Snapshot of load-related metrics: several maps keyed by endpoint name plus
/// two global figures.
///
/// All fields are public, so values can be written directly (or arrive through
/// deserialisation) without passing through the clamping setters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LoadContext {
/// Health per endpoint. `set_health` clamps stored values to `0.0..=1.0`;
/// `get_health` reports `1.0` for endpoints with no entry.
pub target_health: HashMap<String, f32>,
/// Estimated response time per endpoint, in milliseconds.
/// `get_response_time` reports `1000` for endpoints with no entry.
pub estimated_response_ms: HashMap<String, u32>,
/// Queue depth per endpoint. `availability_score` scales the score down
/// linearly with depth, reaching zero at a depth of 1000.
pub queue_depth: HashMap<String, u32>,
/// Connection count per endpoint. Not read by any scoring method visible in
/// this file; NOTE(review): confirm who consumes it.
pub active_connections: HashMap<String, u32>,
/// Error rate per endpoint; `availability_score` multiplies the score by
/// `1.0 - error_rate`, so values are presumably fractions in `0.0..=1.0`.
pub error_rates: HashMap<String, f32>,
/// Circuit-breaker state per endpoint; an `Open` entry makes the endpoint
/// unavailable.
pub circuit_breakers: HashMap<String, CircuitState>,
/// Global load figure; `is_overloaded` triggers above `0.9`, and
/// `recommended_concurrency` steps down above `0.5` and `0.7`.
pub global_load: f32,
/// Number of pending tasks; `is_overloaded` triggers above `1000`.
pub pending_tasks: u32,
}
impl LoadContext {
#[must_use]
pub fn new() -> Self {
Self {
target_health: HashMap::new(),
estimated_response_ms: HashMap::new(),
queue_depth: HashMap::new(),
active_connections: HashMap::new(),
error_rates: HashMap::new(),
circuit_breakers: HashMap::new(),
global_load: 0.0,
pending_tasks: 0,
}
}
pub fn set_health(&mut self, endpoint: impl Into<String>, health: f32) {
self.target_health
.insert(endpoint.into(), health.clamp(0.0, 1.0));
}
pub fn set_response_time(&mut self, endpoint: impl Into<String>, ms: u32) {
self.estimated_response_ms.insert(endpoint.into(), ms);
}
pub fn set_queue_depth(&mut self, endpoint: impl Into<String>, depth: u32) {
self.queue_depth.insert(endpoint.into(), depth);
}
pub fn set_circuit_state(&mut self, endpoint: impl Into<String>, state: CircuitState) {
self.circuit_breakers.insert(endpoint.into(), state);
}
#[must_use]
pub fn get_health(&self, endpoint: &str) -> f32 {
self.target_health.get(endpoint).copied().unwrap_or(1.0)
}
#[must_use]
pub fn get_response_time(&self, endpoint: &str) -> u32 {
self.estimated_response_ms
.get(endpoint)
.copied()
.unwrap_or(1000)
}
#[must_use]
pub fn is_available(&self, endpoint: &str) -> bool {
match self.circuit_breakers.get(endpoint) {
Some(CircuitState::Open) => false,
_ => true,
}
}
#[must_use]
pub fn availability_score(&self, endpoint: &str) -> f32 {
if !self.is_available(endpoint) {
return 0.0;
}
let health = self.get_health(endpoint);
let error_rate = self.error_rates.get(endpoint).copied().unwrap_or(0.0);
let queue_factor = self
.queue_depth
.get(endpoint)
.map(|&d| 1.0 - (d as f32 / 1000.0).min(1.0))
.unwrap_or(1.0);
health * (1.0 - error_rate) * queue_factor
}
#[must_use]
pub fn best_endpoints(&self, limit: usize) -> Vec<(String, f32)> {
let mut scores: Vec<_> = self
.target_health
.keys()
.map(|ep| (ep.clone(), self.availability_score(ep)))
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal));
scores.truncate(limit);
scores
}
#[must_use]
pub fn is_overloaded(&self) -> bool {
self.global_load > 0.9 || self.pending_tasks > 1000
}
#[must_use]
pub fn recommended_concurrency(&self) -> u32 {
if self.is_overloaded() {
1
} else if self.global_load > 0.7 {
2
} else if self.global_load > 0.5 {
4
} else {
8
}
}
}
/// Circuit-breaker state recorded per endpoint in `LoadContext::circuit_breakers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CircuitState {
/// Default state; the endpoint is treated as available.
Closed,
/// The endpoint is treated as unavailable: `LoadContext::is_available`
/// returns `false` and its availability score is `0.0`.
Open,
/// Treated as available by `LoadContext::is_available`.
/// NOTE(review): presumably a probing state between `Open` and `Closed`;
/// nothing in this file drives the transitions, so confirm with the caller.
HalfOpen,
}
impl Default for CircuitState {
fn default() -> Self {
Self::Closed
}
}
#[cfg(feature = "load")]
impl LoadContext {
    /// Builds a context from a `celers_core::WorkerStats` snapshot.
    ///
    /// Mapping performed:
    /// - `pending_tasks` is taken from `stats.active_tasks`.
    /// - `global_load` is the first `loadavg` entry divided by the pool's
    ///   `max_concurrency`, clamped to `0.0..=1.0`. The divisor falls back to
    ///   `1.0` when there is no pool or the concurrency is not positive.
    ///   It stays `0.0` when `loadavg` is absent.
    /// - `"__pool__"`: queue depth is `pool_size - available` (saturating) and
    ///   active connections is `pool_size`.
    /// - `"__broker__"`: circuit `Closed` with health `1.0` when the broker is
    ///   connected, otherwise `Open` with health `0.0`.
    /// - `"__global__"`: error rate is `failed / total_tasks`, clamped to
    ///   `0.0..=1.0`; recorded only when `total_tasks > 0`.
    pub fn from_celers_stats(stats: &celers_core::WorkerStats) -> Self {
        let mut context = Self::new();
        context.pending_tasks = stats.active_tasks;

        if let Some(load_average) = stats.loadavg {
            // Normalise the first load-average sample (presumably the 1-minute
            // figure) by worker capacity; a missing or zero capacity divides by 1.
            let capacity = match stats.pool.as_ref().map(|p| f64::from(p.max_concurrency)) {
                Some(limit) if limit > 0.0 => limit,
                _ => 1.0,
            };
            context.global_load = (load_average[0] / capacity).clamp(0.0, 1.0) as f32;
        }

        if let Some(pool) = stats.pool.as_ref() {
            let in_use = pool.pool_size.saturating_sub(pool.available);
            context.queue_depth.insert(String::from("__pool__"), in_use);
            context
                .active_connections
                .insert(String::from("__pool__"), pool.pool_size);
        }

        if let Some(broker) = stats.broker.as_ref() {
            // A disconnected broker is modelled as an open breaker with zero health.
            let (state, health) = if broker.connected {
                (CircuitState::Closed, 1.0_f32)
            } else {
                (CircuitState::Open, 0.0_f32)
            };
            context
                .circuit_breakers
                .insert(String::from("__broker__"), state);
            context
                .target_health
                .insert(String::from("__broker__"), health);
        }

        // Guard against dividing by zero when no task has run yet.
        if stats.total_tasks > 0 {
            let failure_ratio = stats.failed as f32 / stats.total_tasks as f32;
            context
                .error_rates
                .insert(String::from("__global__"), failure_ratio.clamp(0.0, 1.0));
        }

        context
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Health written through the setter is readable, and an endpoint with no
    /// breaker state counts as available.
    #[test]
    fn test_load_context() {
        let mut load = LoadContext::new();
        load.set_health("endpoint1", 0.9);
        load.set_response_time("endpoint1", 100);

        let health = load.get_health("endpoint1");
        assert!(health > 0.8);
        assert!(load.is_available("endpoint1"));
    }

    /// An open breaker makes the endpoint unavailable and zeroes its score.
    #[test]
    fn test_circuit_breaker() {
        let mut load = LoadContext::new();
        load.set_circuit_state("endpoint1", CircuitState::Open);

        let available = load.is_available("endpoint1");
        assert!(!available);
        assert_eq!(load.availability_score("endpoint1"), 0.0);
    }

    /// `best_endpoints` honours the limit and puts the healthiest endpoint first.
    #[test]
    fn test_best_endpoints() {
        let mut load = LoadContext::new();
        for &(name, health) in &[("good", 0.9_f32), ("bad", 0.3), ("medium", 0.6)] {
            load.set_health(name, health);
        }

        let ranked = load.best_endpoints(2);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].0, "good");
    }

    /// A fresh context is not overloaded; a global load above 0.9 is.
    #[test]
    fn test_overload_detection() {
        let mut load = LoadContext::new();
        assert!(!load.is_overloaded());

        load.global_load = 0.95;
        assert!(load.is_overloaded());
    }
}