use super::router::RoutingDecision;
use log::info;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Cumulative statistics for Mixture-of-Experts token routing.
///
/// Updated incrementally via `record_routing`; all fields are public and
/// serializable so they can be exported for monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingStats {
// Number of routing decisions recorded so far.
pub total_routings: u64,
// Total tokens seen across all routing decisions.
pub total_tokens: u64,
// Tokens that could not be assigned to any expert.
pub tokens_dropped: u64,
// Running mean of each expert's per-decision share of the batch.
pub expert_utilization: Vec<f32>,
// Running mean of the auxiliary load-balance loss.
pub average_load_balance_loss: f32,
// Running mean of the router z-loss.
pub average_router_z_loss: f32,
// Percentage of tokens successfully routed (not dropped).
pub routing_efficiency: f32,
// Cumulative token count assigned to each expert.
pub expert_token_counts: Vec<u64>,
// Latency measurements for routing calls.
pub routing_latency_stats: LatencyStats,
// Per-decision variance of expert loads; capped at the 1000 most recent.
pub load_variance_history: Vec<f32>,
// Expert-capacity usage statistics.
pub capacity_stats: CapacityStats,
}
impl RoutingStats {
    /// Creates an empty accumulator with all counters zeroed.
    pub fn new() -> Self {
        Self {
            total_routings: 0,
            total_tokens: 0,
            tokens_dropped: 0,
            expert_utilization: Vec::new(),
            average_load_balance_loss: 0.0,
            average_router_z_loss: 0.0,
            routing_efficiency: 0.0,
            expert_token_counts: Vec::new(),
            routing_latency_stats: LatencyStats::new(),
            load_variance_history: Vec::new(),
            capacity_stats: CapacityStats::new(),
        }
    }

    /// Folds one routing decision into the running statistics.
    ///
    /// Losses and per-expert utilization are maintained as exact running
    /// means (the EMA coefficient `alpha = 1/n` makes the update a
    /// streaming average). The load-variance history is capped at the
    /// 1000 most recent samples.
    pub fn record_routing(&mut self, routing_decision: &RoutingDecision) {
        self.total_routings += 1;
        self.total_tokens += routing_decision.total_tokens as u64;
        self.tokens_dropped += routing_decision.tokens_dropped as u64;
        // alpha = 1/n turns the EMA form below into an exact running mean.
        let alpha = 1.0 / self.total_routings as f32;
        self.average_load_balance_loss = alpha * routing_decision.load_balance_loss
            + (1.0 - alpha) * self.average_load_balance_loss;
        self.average_router_z_loss =
            alpha * routing_decision.router_z_loss + (1.0 - alpha) * self.average_router_z_loss;
        if self.total_tokens > 0 {
            self.routing_efficiency =
                (self.total_tokens - self.tokens_dropped) as f32 / self.total_tokens as f32 * 100.0;
        }
        // Re-initialize the per-expert vectors when the expert count changes.
        // BUGFIX: check BOTH vectors. The fields are `pub` and deserializable,
        // so their lengths can legitimately diverge; the previous code checked
        // only `expert_utilization` and could panic on the
        // `expert_token_counts[i]` index below.
        let num_experts = routing_decision.expert_capacities.len();
        if self.expert_utilization.len() != num_experts
            || self.expert_token_counts.len() != num_experts
        {
            self.expert_utilization = vec![0.0; num_experts];
            self.expert_token_counts = vec![0; num_experts];
        }
        for (i, &capacity) in routing_decision.expert_capacities.iter().enumerate() {
            // Fraction of this batch assigned to expert `i`.
            let utilization = if routing_decision.total_tokens > 0 {
                capacity as f32 / routing_decision.total_tokens as f32
            } else {
                0.0
            };
            self.expert_utilization[i] =
                alpha * utilization + (1.0 - alpha) * self.expert_utilization[i];
            self.expert_token_counts[i] += capacity as u64;
        }
        let load_variance = self.calculate_load_variance(&routing_decision.expert_capacities);
        self.load_variance_history.push(load_variance);
        if self.load_variance_history.len() > 1000 {
            self.load_variance_history.remove(0);
        }
        self.capacity_stats.update(routing_decision);
    }

    /// Records the wall-clock latency of a single routing call.
    pub fn record_routing_latency(&mut self, latency: Duration) {
        self.routing_latency_stats.record_latency(latency);
    }

    /// Population variance of the per-expert token counts; 0.0 when empty.
    fn calculate_load_variance(&self, capacities: &[usize]) -> f32 {
        if capacities.is_empty() {
            return 0.0;
        }
        let mean = capacities.iter().sum::<usize>() as f32 / capacities.len() as f32;
        capacities
            .iter()
            .map(|&cap| {
                let diff = cap as f32 - mean;
                diff * diff
            })
            .sum::<f32>()
            / capacities.len() as f32
    }

    /// Coefficient of variation (stddev / mean) of expert utilization.
    /// Lower means better balance; 0.0 when empty or the mean is non-positive.
    pub fn utilization_cv(&self) -> f32 {
        if self.expert_utilization.is_empty() {
            return 0.0;
        }
        let mean =
            self.expert_utilization.iter().sum::<f32>() / self.expert_utilization.len() as f32;
        if mean <= 0.0 {
            return 0.0;
        }
        let variance = self
            .expert_utilization
            .iter()
            .map(|&util| {
                let diff = util - mean;
                diff * diff
            })
            .sum::<f32>()
            / self.expert_utilization.len() as f32;
        variance.sqrt() / mean
    }

    /// Index and utilization of the busiest expert, if any are tracked.
    pub fn most_utilized_expert(&self) -> Option<(usize, f32)> {
        self.expert_utilization
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, &util)| (idx, util))
    }

    /// Index and utilization of the least-busy expert, if any are tracked.
    pub fn least_utilized_expert(&self) -> Option<(usize, f32)> {
        self.expert_utilization
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, &util)| (idx, util))
    }

    /// Key metrics as a flat map, suitable for dashboards/export.
    pub fn utilization_summary(&self) -> HashMap<String, f32> {
        let mut summary = HashMap::new();
        if !self.expert_utilization.is_empty() {
            let mean =
                self.expert_utilization.iter().sum::<f32>() / self.expert_utilization.len() as f32;
            let min = self
                .expert_utilization
                .iter()
                .copied()
                .fold(f32::INFINITY, f32::min);
            let max = self
                .expert_utilization
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max);
            summary.insert("mean_utilization".to_string(), mean);
            summary.insert("min_utilization".to_string(), min);
            summary.insert("max_utilization".to_string(), max);
            summary.insert("utilization_cv".to_string(), self.utilization_cv());
        }
        summary.insert("routing_efficiency".to_string(), self.routing_efficiency);
        summary.insert(
            "average_load_balance_loss".to_string(),
            self.average_load_balance_loss,
        );
        summary.insert(
            "average_router_z_loss".to_string(),
            self.average_router_z_loss,
        );
        summary
    }

    /// Slope of a least-squares line fitted over the last `window` variance
    /// samples (positive = load imbalance is growing). Returns 0.0 with
    /// fewer than two samples or a degenerate fit.
    pub fn recent_load_variance_trend(&self, window: usize) -> f32 {
        if self.load_variance_history.len() < 2 {
            return 0.0;
        }
        let start_idx = self.load_variance_history.len().saturating_sub(window);
        let recent_variances = &self.load_variance_history[start_idx..];
        if recent_variances.len() < 2 {
            return 0.0;
        }
        let n = recent_variances.len() as f32;
        let sum_x: f32 = (0..recent_variances.len()).map(|i| i as f32).sum();
        let sum_y: f32 = recent_variances.iter().sum();
        let sum_xy: f32 = recent_variances
            .iter()
            .enumerate()
            .map(|(i, &y)| i as f32 * y)
            .sum();
        let sum_x2: f32 = (0..recent_variances.len())
            .map(|i| (i as f32).powi(2))
            .sum();
        let denominator = n * sum_x2 - sum_x.powi(2);
        if denominator.abs() < f32::EPSILON {
            0.0
        } else {
            (n * sum_xy - sum_x * sum_y) / denominator
        }
    }

    /// Resets all statistics to their initial (empty) state.
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Snapshot of throughput-oriented metrics derived from the counters.
    pub fn throughput_stats(&self) -> ThroughputStats {
        ThroughputStats {
            total_tokens: self.total_tokens,
            total_routings: self.total_routings,
            tokens_per_routing: if self.total_routings > 0 {
                self.total_tokens as f32 / self.total_routings as f32
            } else {
                0.0
            },
            routing_efficiency: self.routing_efficiency,
            average_latency: self.routing_latency_stats.average_latency(),
        }
    }
}
impl Default for RoutingStats {
fn default() -> Self {
Self::new()
}
}
/// Routing-latency measurements, tracked in milliseconds.
///
/// NOTE(review): `min_latency_ms` starts at `f64::INFINITY` until the first
/// sample; non-finite floats do not round-trip through every serde format
/// (e.g. JSON) — confirm the serialization format in use tolerates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyStats {
// Count of recorded samples.
pub total_measurements: u64,
// Sum of all samples, used to compute the mean.
pub total_latency_ms: f64,
// Smallest sample seen; +infinity before the first sample.
pub min_latency_ms: f64,
// Largest sample seen.
pub max_latency_ms: f64,
// Bounded window (most recent 1000) used for percentile estimates.
pub recent_latencies: Vec<f64>,
}
impl LatencyStats {
    /// Creates an empty tracker. `min` starts at +infinity so that the
    /// first recorded sample always replaces it.
    pub fn new() -> Self {
        Self {
            total_measurements: 0,
            total_latency_ms: 0.0,
            min_latency_ms: f64::INFINITY,
            max_latency_ms: 0.0,
            recent_latencies: Vec::new(),
        }
    }

    /// Records one latency sample: updates totals and extrema, and appends
    /// to the bounded (1000-entry) window of recent samples.
    pub fn record_latency(&mut self, latency: Duration) {
        let ms = latency.as_secs_f64() * 1000.0;
        self.total_measurements += 1;
        self.total_latency_ms += ms;
        if ms < self.min_latency_ms {
            self.min_latency_ms = ms;
        }
        if ms > self.max_latency_ms {
            self.max_latency_ms = ms;
        }
        self.recent_latencies.push(ms);
        // Keep only the most recent 1000 samples for percentile queries.
        if self.recent_latencies.len() > 1000 {
            self.recent_latencies.remove(0);
        }
    }

    /// Mean latency over all recorded samples; 0.0 when none recorded.
    pub fn average_latency(&self) -> f64 {
        match self.total_measurements {
            0 => 0.0,
            n => self.total_latency_ms / n as f64,
        }
    }

    /// The `p`-th percentile (0..=100) over the recent-sample window,
    /// using the lower-index (truncating) convention; 0.0 when empty.
    pub fn percentile(&self, p: f64) -> f64 {
        if self.recent_latencies.is_empty() {
            return 0.0;
        }
        let mut sorted = self.recent_latencies.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let last = sorted.len() - 1;
        let index = ((p / 100.0) * last as f64) as usize;
        sorted[index.min(last)]
    }

    /// 95th-percentile latency over the recent window.
    pub fn p95_latency(&self) -> f64 {
        self.percentile(95.0)
    }

    /// 99th-percentile latency over the recent window.
    pub fn p99_latency(&self) -> f64 {
        self.percentile(99.0)
    }
}
impl Default for LatencyStats {
fn default() -> Self {
Self::new()
}
}
/// Expert-capacity usage statistics.
///
/// NOTE(review): utilization figures assume a nominal capacity of 100
/// tokens per expert (hard-coded in `update`) — confirm this matches the
/// router's actual configured expert capacity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapacityStats {
// Exponential moving average (alpha = 0.1) of per-decision utilization.
pub average_utilization: f32,
// Highest per-decision utilization observed.
pub peak_utilization: f32,
// Number of decisions in which at least one token was dropped.
pub capacity_exceeded_count: u64,
// Cumulative nominal capacity (num_experts * 100 per decision).
pub total_capacity: u64,
// Cumulative tokens actually assigned to experts.
pub total_used: u64,
}
impl CapacityStats {
    /// Nominal per-expert capacity assumed when converting raw token counts
    /// into utilization. NOTE(review): this heuristic was hard-coded as a
    /// bare `100` in the original — confirm it matches the router's actual
    /// configured expert capacity.
    const ASSUMED_EXPERT_CAPACITY: f32 = 100.0;

    /// Creates an empty capacity tracker.
    pub fn new() -> Self {
        Self {
            average_utilization: 0.0,
            peak_utilization: 0.0,
            capacity_exceeded_count: 0,
            total_capacity: 0,
            total_used: 0,
        }
    }

    /// Folds one routing decision into the capacity statistics.
    pub fn update(&mut self, routing_decision: &RoutingDecision) {
        let num_experts = routing_decision.expert_capacities.len();
        let tokens_used = routing_decision.expert_capacities.iter().sum::<usize>();
        // BUGFIX: guard the empty case. The previous code divided by
        // `len * 100.0` unconditionally, so an empty capacity vector
        // produced 0.0/0.0 = NaN, permanently poisoning both the EMA and
        // `peak_utilization` (NaN propagates through `max`-style updates).
        let current_utilization = if num_experts > 0 {
            tokens_used as f32 / (num_experts as f32 * Self::ASSUMED_EXPERT_CAPACITY)
        } else {
            0.0
        };
        // Fixed-alpha EMA smooths the per-decision utilization signal.
        let alpha = 0.1;
        self.average_utilization =
            alpha * current_utilization + (1.0 - alpha) * self.average_utilization;
        self.peak_utilization = self.peak_utilization.max(current_utilization);
        if routing_decision.tokens_dropped > 0 {
            self.capacity_exceeded_count += 1;
        }
        self.total_used += tokens_used as u64;
        self.total_capacity += num_experts as u64 * Self::ASSUMED_EXPERT_CAPACITY as u64;
    }

    /// Lifetime utilization in percent across all recorded decisions;
    /// 0.0 before any decision with experts has been recorded.
    pub fn overall_utilization(&self) -> f32 {
        if self.total_capacity > 0 {
            (self.total_used as f32 / self.total_capacity as f32) * 100.0
        } else {
            0.0
        }
    }
}
impl Default for CapacityStats {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time throughput summary, produced by `RoutingStats::throughput_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputStats {
// Total tokens routed so far.
pub total_tokens: u64,
// Total routing decisions recorded.
pub total_routings: u64,
// Mean tokens per routing decision.
pub tokens_per_routing: f32,
// Percentage of tokens not dropped.
pub routing_efficiency: f32,
// Mean routing latency in milliseconds.
pub average_latency: f64,
}
impl ThroughputStats {
    /// Estimated tokens processed per second, derived from the mean tokens
    /// per routing and the average latency in milliseconds. Returns 0.0
    /// when no positive latency has been observed.
    pub fn tokens_per_second(&self) -> f64 {
        let latency_ms = self.average_latency;
        if latency_ms > 0.0 {
            (self.tokens_per_routing as f64 * 1000.0) / latency_ms
        } else {
            0.0
        }
    }

    /// Estimated routing decisions per second; 0.0 when no positive
    /// latency has been observed.
    pub fn routings_per_second(&self) -> f64 {
        let latency_ms = self.average_latency;
        if latency_ms > 0.0 {
            1000.0 / latency_ms
        } else {
            0.0
        }
    }
}
pub mod monitoring {
    use super::*;

    /// Wraps [`RoutingStats`] with wall-clock tracking and periodic
    /// log-based reporting.
    pub struct PerformanceMonitor {
        stats: RoutingStats,
        start_time: Instant,
        last_report_time: Instant,
        report_interval: Duration,
    }

    impl PerformanceMonitor {
        /// Creates a monitor that logs a report every `report_interval`.
        pub fn new(report_interval: Duration) -> Self {
            let created_at = Instant::now();
            Self {
                stats: RoutingStats::new(),
                start_time: created_at,
                last_report_time: created_at,
                report_interval,
            }
        }

        /// Records one routing decision and its latency, then emits a
        /// report if the configured interval has elapsed.
        pub fn record_routing(&mut self, routing_decision: &RoutingDecision, latency: Duration) {
            self.stats.record_routing(routing_decision);
            self.stats.record_routing_latency(latency);
            let report_due = self.last_report_time.elapsed() >= self.report_interval;
            if report_due {
                self.print_report();
                self.last_report_time = Instant::now();
            }
        }

        /// Logs a human-readable summary of the collected statistics.
        pub fn print_report(&self) {
            let uptime = self.start_time.elapsed();
            let throughput = self.stats.throughput_stats();
            let latency = &self.stats.routing_latency_stats;
            info!("🔍 Expert Routing Performance Report");
            info!(" Uptime: {:.2}s", uptime.as_secs_f64());
            info!(" Total routings: {}", self.stats.total_routings);
            info!(" Total tokens: {}", self.stats.total_tokens);
            info!(" Routing efficiency: {:.2}%", self.stats.routing_efficiency);
            info!(" Tokens/second: {:.2}", throughput.tokens_per_second());
            info!(" Average latency: {:.2}ms", latency.average_latency());
            info!(" P95 latency: {:.2}ms", latency.p95_latency());
            info!(" Utilization CV: {:.3}", self.stats.utilization_cv());
            if let Some((idx, util)) = self.stats.most_utilized_expert() {
                info!(" Most utilized expert: {} ({:.2}%)", idx, util * 100.0);
            }
            if let Some((idx, util)) = self.stats.least_utilized_expert() {
                info!(" Least utilized expert: {} ({:.2}%)", idx, util * 100.0);
            }
        }

        /// Read-only view of the accumulated statistics.
        pub fn stats(&self) -> &RoutingStats {
            &self.stats
        }

        /// Clears all statistics and restarts the uptime/report clocks.
        pub fn reset(&mut self) {
            let now = Instant::now();
            self.stats.reset();
            self.start_time = now;
            self.last_report_time = now;
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expert_parallelism::router::{ExpertAssignment, RoutingDecision};
// A freshly-created RoutingStats starts with all counters zeroed.
#[test]
fn test_routing_stats_creation() {
let stats = RoutingStats::new();
assert_eq!(stats.total_routings, 0);
assert_eq!(stats.total_tokens, 0);
assert_eq!(stats.routing_efficiency, 0.0);
}
// Recording one decision updates totals, efficiency and sizes the
// per-expert vectors to the number of experts in the decision.
#[test]
fn test_routing_stats_recording() {
let mut stats = RoutingStats::new();
let routing_decision = RoutingDecision {
expert_assignments: vec![vec![ExpertAssignment::new(0, 0.8, 0, 0)]],
expert_capacities: vec![5, 3, 2, 0],
total_tokens: 10,
tokens_dropped: 0,
load_balance_loss: 0.1,
router_z_loss: 0.05,
auxiliary_loss: 0.15,
};
stats.record_routing(&routing_decision);
assert_eq!(stats.total_routings, 1);
assert_eq!(stats.total_tokens, 10);
assert_eq!(stats.tokens_dropped, 0);
// No dropped tokens => 100% routing efficiency.
assert_eq!(stats.routing_efficiency, 100.0);
assert_eq!(stats.expert_utilization.len(), 4);
}
// Count, mean and extrema are tracked across latency samples.
#[test]
fn test_latency_stats() {
let mut latency_stats = LatencyStats::new();
latency_stats.record_latency(Duration::from_millis(10));
latency_stats.record_latency(Duration::from_millis(20));
latency_stats.record_latency(Duration::from_millis(30));
assert_eq!(latency_stats.total_measurements, 3);
assert_eq!(latency_stats.average_latency(), 20.0);
assert_eq!(latency_stats.min_latency_ms, 10.0);
assert_eq!(latency_stats.max_latency_ms, 30.0);
}
// Unequal utilization values must yield a strictly positive CV.
#[test]
fn test_utilization_cv() {
let mut stats = RoutingStats::new();
stats.expert_utilization = vec![0.1, 0.2, 0.3, 0.4];
let cv = stats.utilization_cv();
assert!(cv > 0.0); }
// One update drives both the EMA and the peak above zero.
#[test]
fn test_capacity_stats() {
let mut capacity_stats = CapacityStats::new();
let routing_decision = RoutingDecision {
expert_assignments: vec![],
expert_capacities: vec![50, 75, 25, 100], total_tokens: 250,
tokens_dropped: 0,
load_balance_loss: 0.0,
router_z_loss: 0.0,
auxiliary_loss: 0.0,
};
capacity_stats.update(&routing_decision);
assert!(capacity_stats.average_utilization > 0.0);
assert!(capacity_stats.peak_utilization > 0.0);
}
// 100 tokens/routing at 50ms average => 2000 tokens/s, 20 routings/s.
#[test]
fn test_throughput_stats() {
let throughput = ThroughputStats {
total_tokens: 1000,
total_routings: 10,
tokens_per_routing: 100.0,
routing_efficiency: 95.0,
average_latency: 50.0, };
assert_eq!(throughput.tokens_per_second(), 2000.0); assert_eq!(throughput.routings_per_second(), 20.0); }
// The monitor forwards routing decisions into its internal RoutingStats.
#[test]
fn test_performance_monitor() {
let mut monitor = monitoring::PerformanceMonitor::new(Duration::from_secs(1));
let routing_decision = RoutingDecision {
expert_assignments: vec![],
expert_capacities: vec![10, 20, 30],
total_tokens: 60,
tokens_dropped: 0,
load_balance_loss: 0.1,
router_z_loss: 0.05,
auxiliary_loss: 0.15,
};
monitor.record_routing(&routing_decision, Duration::from_millis(25));
assert_eq!(monitor.stats().total_routings, 1);
assert_eq!(monitor.stats().total_tokens, 60);
}
}