use crate::gguf::GgufQuantType;
pub use super::ExpertId;
/// Precision tier an expert is placed in, based on how often it is activated.
///
/// `PrecisionAllocator::allocate` produces these values by comparing an
/// expert's activation count with the allocator's hot and cold thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExpertPrecision {
    /// Activation count is at or above the hot threshold.
    Hot,
    /// Activation count is non-zero, at or above the cold threshold, and below
    /// the hot threshold.
    Warm,
    /// Everything else, including experts that were never activated.
    Cold,
}

impl ExpertPrecision {
    /// Returns the lowercase label of the tier: `"hot"`, `"warm"` or `"cold"`.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Hot => "hot",
            Self::Warm => "warm",
            Self::Cold => "cold",
        }
    }
}
/// Settings that control how activation counts are mapped to precision tiers.
///
/// Both percentile fields are fractions of the highest activation count seen
/// across all experts; `PrecisionAllocator::recompute_thresholds` multiplies
/// them by that maximum to obtain absolute thresholds. The three format fields
/// name the quantization format used for each tier.
#[derive(Debug, Clone)]
pub struct PrecisionConfig {
    /// Fraction of the peak activation count at or above which an expert is
    /// hot. Must lie in `(0.0, 1.0]`.
    pub hot_percentile: f32,
    /// Fraction of the peak activation count below which an expert is cold.
    /// Must lie in `[0.0, 1.0)` and be strictly less than `hot_percentile`.
    pub cold_percentile: f32,
    /// Quantization format for hot experts.
    pub hot_format: GgufQuantType,
    /// Quantization format for warm experts.
    pub warm_format: GgufQuantType,
    /// Quantization format for cold experts.
    pub cold_format: GgufQuantType,
}

impl Default for PrecisionConfig {
    /// Returns the baseline configuration: thresholds at 0.9 / 0.3 of the peak
    /// count, with `Q4_K`, `Q3_K` and `Q2_K` for hot, warm and cold experts.
    fn default() -> Self {
        Self {
            hot_percentile: 0.9,
            cold_percentile: 0.3,
            hot_format: GgufQuantType::Q4_K,
            warm_format: GgufQuantType::Q3_K,
            cold_format: GgufQuantType::Q2_K,
        }
    }
}

impl PrecisionConfig {
    /// Preset with thresholds at 0.95 / 0.4 in which both the warm and the
    /// cold tier use `Q2_K`; only hot experts keep `Q4_K`.
    pub fn memory_constrained() -> Self {
        Self {
            hot_percentile: 0.95,
            cold_percentile: 0.4,
            hot_format: GgufQuantType::Q4_K,
            warm_format: GgufQuantType::Q2_K,
            cold_format: GgufQuantType::Q2_K,
        }
    }

    /// Preset with thresholds at 0.8 / 0.2 that uses `Q5_K`, `Q4_K` and `Q3_K`
    /// for hot, warm and cold experts respectively.
    pub fn quality_focused() -> Self {
        Self {
            hot_percentile: 0.8,
            cold_percentile: 0.2,
            hot_format: GgufQuantType::Q5_K,
            warm_format: GgufQuantType::Q4_K,
            cold_format: GgufQuantType::Q3_K,
        }
    }

    /// Checks that the percentile fields describe a usable pair of thresholds.
    ///
    /// # Errors
    ///
    /// Returns a static message when
    /// * `hot_percentile` is NaN or outside `(0.0, 1.0]`,
    /// * `cold_percentile` is NaN or outside `[0.0, 1.0)`, or
    /// * `cold_percentile` is not strictly less than `hot_percentile`.
    pub fn validate(&self) -> Result<(), &'static str> {
        let hot = self.hot_percentile;
        let cold = self.cold_percentile;

        // NaN compares false against everything, so the range checks alone
        // would silently let it through; reject it explicitly.
        if hot.is_nan() || hot <= 0.0 || hot > 1.0 {
            return Err("hot_percentile must be in (0.0, 1.0]");
        }
        if cold.is_nan() || cold < 0.0 || cold >= 1.0 {
            return Err("cold_percentile must be in [0.0, 1.0)");
        }
        if cold >= hot {
            return Err("cold_percentile must be less than hot_percentile");
        }
        Ok(())
    }
}
/// Tracks per-expert activation counts and maps each expert to a precision
/// tier and quantization format.
///
/// Thresholds are *not* updated automatically when activations are recorded;
/// call [`PrecisionAllocator::recompute_thresholds`] after recording to refresh
/// them. Until then (and after [`PrecisionAllocator::reset`]) both thresholds
/// are zero and every expert is reported as cold.
#[derive(Debug, Clone)]
pub struct PrecisionAllocator {
    /// Number of experts tracked; always equal to `counts.len()`.
    num_experts: usize,
    /// Saturating activation counter for each expert, indexed by expert id.
    counts: Vec<u64>,
    /// Validated percentile and format settings.
    config: PrecisionConfig,
    /// Minimum count for the hot tier; zero means "not computed".
    hot_threshold: u64,
    /// Minimum count for the warm tier; zero means "not computed".
    cold_threshold: u64,
}

impl PrecisionAllocator {
    /// Creates an allocator for `num_experts` experts with all counts at zero.
    ///
    /// # Errors
    ///
    /// Returns the message produced by `PrecisionConfig::validate` when
    /// `config` is rejected.
    pub fn new(num_experts: usize, config: PrecisionConfig) -> Result<Self, &'static str> {
        config.validate()?;
        Ok(Self {
            num_experts,
            counts: vec![0; num_experts],
            config,
            hot_threshold: 0,
            cold_threshold: 0,
        })
    }

    /// Like [`PrecisionAllocator::new`], but panics instead of returning an
    /// error.
    ///
    /// Despite its name this constructor still validates `config`.
    ///
    /// # Panics
    ///
    /// Panics with "PrecisionConfig validation failed" when `config` does not
    /// pass validation.
    pub fn new_unchecked(num_experts: usize, config: PrecisionConfig) -> Self {
        Self::new(num_experts, config).expect("PrecisionConfig validation failed")
    }

    /// Increments the activation counter of `expert_id`, saturating at
    /// `u64::MAX`. Ids outside `0..num_experts` are silently ignored.
    #[inline]
    pub fn record_activation(&mut self, expert_id: ExpertId) {
        if let Some(count) = self.counts.get_mut(expert_id) {
            *count = count.saturating_add(1);
        }
    }

    /// Records one activation for every id in `expert_ids`, in order.
    pub fn record_activations(&mut self, expert_ids: &[ExpertId]) {
        for &expert_id in expert_ids {
            self.record_activation(expert_id);
        }
    }

    /// Returns the tier of `expert_id` according to the current thresholds.
    ///
    /// Unknown ids, experts with a zero count, and every expert while the
    /// thresholds have not been computed are reported as
    /// [`ExpertPrecision::Cold`].
    pub fn allocate(&self, expert_id: ExpertId) -> ExpertPrecision {
        let count = match self.counts.get(expert_id) {
            Some(&count) => count,
            None => return ExpertPrecision::Cold,
        };
        // Both thresholds at zero means they were never computed (or were reset).
        if self.hot_threshold == 0 && self.cold_threshold == 0 {
            return ExpertPrecision::Cold;
        }
        if self.hot_threshold > 0 && count >= self.hot_threshold {
            ExpertPrecision::Hot
        } else if count > 0 && count >= self.cold_threshold {
            ExpertPrecision::Warm
        } else {
            ExpertPrecision::Cold
        }
    }

    /// Returns the quantization format configured for the tier of `expert_id`.
    pub fn get_format(&self, expert_id: ExpertId) -> GgufQuantType {
        match self.allocate(expert_id) {
            ExpertPrecision::Hot => self.config.hot_format,
            ExpertPrecision::Warm => self.config.warm_format,
            ExpertPrecision::Cold => self.config.cold_format,
        }
    }

    /// Recomputes both thresholds from the current maximum activation count.
    ///
    /// With no recorded activations both thresholds become zero. Otherwise the
    /// hot threshold is `ceil(max * hot_percentile)`, capped at `max`, and the
    /// cold threshold is `floor(max * cold_percentile)`, but never below one.
    pub fn recompute_thresholds(&mut self) {
        let max_count = self.counts.iter().copied().max().unwrap_or(0);
        if max_count == 0 {
            self.hot_threshold = 0;
            self.cold_threshold = 0;
            return;
        }

        let peak = max_count as f64;
        let hot = (peak * f64::from(self.config.hot_percentile)).ceil() as u64;
        let cold = (peak * f64::from(self.config.cold_percentile)).floor() as u64;

        // `max_count as f64` can round upwards for counts above 2^53, which
        // would push the hot threshold past the real maximum and leave no
        // expert hot; cap it so the busiest expert always qualifies.
        self.hot_threshold = hot.min(max_count);
        // A zero cold threshold is reserved for "not computed".
        self.cold_threshold = cold.max(1);
    }

    /// Returns `(expert_id, tier)` for every expert, in ascending id order.
    pub fn get_precision_map(&self) -> Vec<(ExpertId, ExpertPrecision)> {
        (0..self.num_experts)
            .map(|id| (id, self.allocate(id)))
            .collect()
    }

    /// Returns the activation count of `expert_id`, or zero for unknown ids.
    pub fn get_count(&self, expert_id: ExpertId) -> u64 {
        self.counts.get(expert_id).copied().unwrap_or(0)
    }

    /// Returns the sum of all activation counts, saturating at `u64::MAX`.
    pub fn total_activations(&self) -> u64 {
        // Individual counters saturate, so a plain sum could overflow.
        self.counts
            .iter()
            .fold(0u64, |total, &count| total.saturating_add(count))
    }

    /// Returns the number of experts in the `(hot, warm, cold)` tiers.
    pub fn tier_counts(&self) -> (usize, usize, usize) {
        (0..self.num_experts).fold((0, 0, 0), |(hot, warm, cold), id| {
            match self.allocate(id) {
                ExpertPrecision::Hot => (hot + 1, warm, cold),
                ExpertPrecision::Warm => (hot, warm + 1, cold),
                ExpertPrecision::Cold => (hot, warm, cold + 1),
            }
        })
    }

    /// Clears every activation count and resets both thresholds to zero.
    pub fn reset(&mut self) {
        self.counts.fill(0);
        self.hot_threshold = 0;
        self.cold_threshold = 0;
    }

    /// Returns the number of experts tracked by this allocator.
    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    /// Returns the hot threshold from the last recomputation (zero if none).
    pub fn hot_threshold(&self) -> u64 {
        self.hot_threshold
    }

    /// Returns the cold threshold from the last recomputation (zero if none).
    pub fn cold_threshold(&self) -> u64 {
        self.cold_threshold
    }

    /// Returns the configuration this allocator was created with.
    pub fn config(&self) -> &PrecisionConfig {
        &self.config
    }

    /// Returns the ids of all experts currently in tier `precision`, in
    /// ascending order.
    pub fn experts_by_precision(&self, precision: ExpertPrecision) -> Vec<ExpertId> {
        (0..self.num_experts)
            .filter(|&id| self.allocate(id) == precision)
            .collect()
    }

    /// Returns the activation count of `expert_id` divided by the current
    /// maximum count, i.e. a value in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` for unknown ids and when nothing has been recorded yet.
    pub fn compute_percentile(&self, expert_id: ExpertId) -> f32 {
        let count = match self.counts.get(expert_id) {
            Some(&count) => count,
            None => return 0.0,
        };
        let max_count = self.counts.iter().copied().max().unwrap_or(0);
        if max_count == 0 {
            return 0.0;
        }
        // Divide in f64 (as the thresholds do) and narrow once at the end.
        (count as f64 / max_count as f64) as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 0.9 / 0.3 percentile configuration shared by most tests.
    fn standard_config() -> PrecisionConfig {
        PrecisionConfig {
            hot_percentile: 0.9,
            cold_percentile: 0.3,
            ..Default::default()
        }
    }

    /// Records `times` activations of `expert_id` through the public API.
    fn activate(allocator: &mut PrecisionAllocator, expert_id: ExpertId, times: usize) {
        for _ in 0..times {
            allocator.record_activation(expert_id);
        }
    }

    /// Builds an allocator in which expert `i` has been activated
    /// `activations[i]` times; experts past the end of the slice stay at zero.
    /// Thresholds are left untouched so each test recomputes them explicitly.
    fn allocator_with(
        num_experts: usize,
        config: PrecisionConfig,
        activations: &[usize],
    ) -> PrecisionAllocator {
        let mut allocator = PrecisionAllocator::new(num_experts, config).unwrap();
        for (expert_id, &times) in activations.iter().enumerate() {
            activate(&mut allocator, expert_id, times);
        }
        allocator
    }

    /// A fresh allocator has no activations, zero thresholds, and only cold experts.
    #[test]
    fn test_allocator_creation() {
        let config = PrecisionConfig::default();
        let allocator = PrecisionAllocator::new(8, config).unwrap();
        assert_eq!(allocator.num_experts(), 8);
        assert_eq!(allocator.total_activations(), 0);
        assert_eq!(allocator.hot_threshold(), 0);
        assert_eq!(allocator.cold_threshold(), 0);
        for id in 0..8 {
            assert_eq!(allocator.allocate(id), ExpertPrecision::Cold);
        }
    }

    /// The most active expert is hot and gets the hot format.
    #[test]
    fn test_hot_expert_allocation() {
        let mut allocator =
            allocator_with(8, standard_config(), &[100, 10, 10, 10, 10, 10, 10, 10]);
        allocator.recompute_thresholds();
        assert_eq!(allocator.allocate(0), ExpertPrecision::Hot);
        assert_eq!(allocator.get_format(0), GgufQuantType::Q4_K);
    }

    /// An expert at half the peak count is warm and gets the warm format.
    #[test]
    fn test_warm_expert_allocation() {
        let mut allocator = allocator_with(8, standard_config(), &[100, 50]);
        allocator.recompute_thresholds();
        assert_eq!(allocator.allocate(1), ExpertPrecision::Warm);
        assert_eq!(allocator.get_format(1), GgufQuantType::Q3_K);
    }

    /// An expert far below the cold threshold is cold and gets the cold format.
    #[test]
    fn test_cold_expert_allocation() {
        let mut allocator = allocator_with(8, standard_config(), &[100, 0, 0, 0, 0, 0, 0, 5]);
        allocator.recompute_thresholds();
        assert_eq!(allocator.allocate(7), ExpertPrecision::Cold);
        assert_eq!(allocator.get_format(7), GgufQuantType::Q2_K);
    }

    /// Thresholds scale with the configured percentiles of the peak count.
    #[test]
    fn test_percentile_thresholds() {
        let config = PrecisionConfig {
            hot_percentile: 0.8,
            cold_percentile: 0.2,
            ..Default::default()
        };
        let mut allocator = allocator_with(4, config, &[100, 75, 25, 5]);
        allocator.recompute_thresholds();
        // 0.8f32 is slightly above 0.8, so the ceiling may land on 81.
        assert!(
            allocator.hot_threshold() >= 80 && allocator.hot_threshold() <= 81,
            "hot_threshold {} should be 80 or 81",
            allocator.hot_threshold()
        );
        assert_eq!(allocator.cold_threshold(), 20);
        assert_eq!(allocator.allocate(0), ExpertPrecision::Hot);
        assert_eq!(allocator.allocate(1), ExpertPrecision::Warm);
        assert_eq!(allocator.allocate(2), ExpertPrecision::Warm);
        assert_eq!(allocator.allocate(3), ExpertPrecision::Cold);
    }

    /// Single and batched recording update counts; unknown ids are ignored.
    #[test]
    fn test_activation_recording() {
        let config = PrecisionConfig::default();
        let mut allocator = PrecisionAllocator::new(4, config).unwrap();
        allocator.record_activation(0);
        allocator.record_activation(0);
        allocator.record_activation(1);
        assert_eq!(allocator.get_count(0), 2);
        assert_eq!(allocator.get_count(1), 1);
        assert_eq!(allocator.get_count(2), 0);
        assert_eq!(allocator.total_activations(), 3);

        allocator.record_activations(&[2, 2, 3, 0]);
        assert_eq!(allocator.get_count(0), 3);
        assert_eq!(allocator.get_count(2), 2);
        assert_eq!(allocator.get_count(3), 1);
        assert_eq!(allocator.total_activations(), 7);

        allocator.record_activation(100);
        assert_eq!(allocator.total_activations(), 7);
    }

    /// Each tier maps to the format configured for it.
    #[test]
    fn test_format_mapping() {
        let config = PrecisionConfig {
            hot_percentile: 0.9,
            cold_percentile: 0.3,
            hot_format: GgufQuantType::Q5_K,
            warm_format: GgufQuantType::Q4_K,
            cold_format: GgufQuantType::Q3_K,
        };
        let mut allocator = allocator_with(3, config, &[100, 50, 10]);
        allocator.recompute_thresholds();
        assert_eq!(allocator.get_format(0), GgufQuantType::Q5_K);
        assert_eq!(allocator.get_format(1), GgufQuantType::Q4_K);
        assert_eq!(allocator.get_format(2), GgufQuantType::Q3_K);
    }

    /// Thresholds stay at zero without data and follow the peak count afterwards.
    #[test]
    fn test_recompute_thresholds() {
        let mut allocator = PrecisionAllocator::new(4, standard_config()).unwrap();
        allocator.recompute_thresholds();
        assert_eq!(allocator.hot_threshold(), 0);
        assert_eq!(allocator.cold_threshold(), 0);

        activate(&mut allocator, 0, 100);
        allocator.recompute_thresholds();
        assert_eq!(allocator.hot_threshold(), 90);
        assert_eq!(allocator.cold_threshold(), 30);

        activate(&mut allocator, 0, 100);
        allocator.recompute_thresholds();
        assert_eq!(allocator.hot_threshold(), 180);
        assert_eq!(allocator.cold_threshold(), 60);
    }

    /// The precision map lists every expert with its tier, in id order.
    #[test]
    fn test_precision_map() {
        let mut allocator = allocator_with(4, standard_config(), &[100, 50, 35, 10]);
        allocator.recompute_thresholds();
        let map = allocator.get_precision_map();
        assert_eq!(map.len(), 4);
        assert_eq!(map[0], (0, ExpertPrecision::Hot));
        assert_eq!(map[1], (1, ExpertPrecision::Warm));
        assert_eq!(map[2], (2, ExpertPrecision::Warm));
        assert_eq!(map[3], (3, ExpertPrecision::Cold));
    }

    /// Tier counts match the thresholds 90 (hot) and 30 (cold) exactly.
    #[test]
    fn test_tier_counts() {
        let mut allocator =
            allocator_with(8, standard_config(), &[100, 95, 50, 40, 35, 10, 5]);
        allocator.recompute_thresholds();
        let (hot, warm, cold) = allocator.tier_counts();
        assert_eq!(hot, 2, "Expected 2 hot experts");
        assert_eq!(warm, 3, "Expected 3 warm experts");
        assert_eq!(cold, 3, "Expected 3 cold experts");
        assert_eq!(hot + warm + cold, 8, "Total should equal num_experts");
    }

    /// `reset` clears counts and thresholds.
    #[test]
    fn test_reset() {
        let mut allocator = allocator_with(4, PrecisionConfig::default(), &[100]);
        allocator.recompute_thresholds();
        assert!(allocator.total_activations() > 0);
        assert!(allocator.hot_threshold() > 0);

        allocator.reset();
        assert_eq!(allocator.total_activations(), 0);
        assert_eq!(allocator.hot_threshold(), 0);
        assert_eq!(allocator.cold_threshold(), 0);
        for id in 0..4 {
            assert_eq!(allocator.get_count(id), 0);
        }
    }

    /// Experts are grouped by tier, each group in ascending id order.
    #[test]
    fn test_experts_by_precision() {
        let mut allocator = allocator_with(6, standard_config(), &[100, 92, 50, 40, 10]);
        allocator.recompute_thresholds();
        let hot_experts = allocator.experts_by_precision(ExpertPrecision::Hot);
        let warm_experts = allocator.experts_by_precision(ExpertPrecision::Warm);
        let cold_experts = allocator.experts_by_precision(ExpertPrecision::Cold);
        assert_eq!(hot_experts, vec![0, 1]);
        assert_eq!(warm_experts, vec![2, 3]);
        assert_eq!(cold_experts, vec![4, 5]);
    }

    /// The "percentile" is the count relative to the peak count.
    #[test]
    fn test_compute_percentile() {
        let config = PrecisionConfig::default();
        let mut allocator = PrecisionAllocator::new(4, config).unwrap();
        assert_eq!(allocator.compute_percentile(0), 0.0);

        activate(&mut allocator, 0, 100);
        activate(&mut allocator, 1, 50);
        activate(&mut allocator, 2, 25);
        assert!((allocator.compute_percentile(0) - 1.0).abs() < f32::EPSILON);
        assert!((allocator.compute_percentile(1) - 0.5).abs() < f32::EPSILON);
        assert!((allocator.compute_percentile(2) - 0.25).abs() < f32::EPSILON);
        assert!((allocator.compute_percentile(3) - 0.0).abs() < f32::EPSILON);
        assert_eq!(allocator.compute_percentile(100), 0.0);
    }

    /// Validation accepts the default and rejects out-of-range or inverted values.
    #[test]
    fn test_config_validation() {
        let valid = PrecisionConfig::default();
        assert!(valid.validate().is_ok());

        let invalid1 = PrecisionConfig {
            hot_percentile: 1.5,
            ..Default::default()
        };
        assert!(invalid1.validate().is_err());

        let invalid2 = PrecisionConfig {
            hot_percentile: 0.5,
            cold_percentile: 0.6,
            ..Default::default()
        };
        assert!(invalid2.validate().is_err());

        let invalid3 = PrecisionConfig {
            cold_percentile: -0.1,
            ..Default::default()
        };
        assert!(invalid3.validate().is_err());
    }

    /// Tier names are the lowercase variant names.
    #[test]
    fn test_precision_name() {
        assert_eq!(ExpertPrecision::Hot.name(), "hot");
        assert_eq!(ExpertPrecision::Warm.name(), "warm");
        assert_eq!(ExpertPrecision::Cold.name(), "cold");
    }

    /// Unknown expert ids are cold, use the cold format, and have a zero count.
    #[test]
    fn test_out_of_bounds_expert_id() {
        let config = PrecisionConfig::default();
        let allocator = PrecisionAllocator::new(4, config).unwrap();
        assert_eq!(allocator.allocate(100), ExpertPrecision::Cold);
        assert_eq!(allocator.get_format(100), GgufQuantType::Q2_K);
        assert_eq!(allocator.get_count(100), 0);
    }

    /// The memory-constrained preset is valid and uses `Q2_K` for warm and cold.
    #[test]
    fn test_memory_constrained_config() {
        let config = PrecisionConfig::memory_constrained();
        assert!(config.validate().is_ok());
        assert_eq!(config.hot_percentile, 0.95);
        assert_eq!(config.cold_percentile, 0.4);
        assert_eq!(config.warm_format, GgufQuantType::Q2_K);
        assert_eq!(config.cold_format, GgufQuantType::Q2_K);
    }

    /// The quality-focused preset is valid and shifts every tier one format up.
    #[test]
    fn test_quality_focused_config() {
        let config = PrecisionConfig::quality_focused();
        assert!(config.validate().is_ok());
        assert_eq!(config.hot_percentile, 0.8);
        assert_eq!(config.cold_percentile, 0.2);
        assert_eq!(config.hot_format, GgufQuantType::Q5_K);
        assert_eq!(config.warm_format, GgufQuantType::Q4_K);
        assert_eq!(config.cold_format, GgufQuantType::Q3_K);
    }

    /// A counter at `u64::MAX` stays there instead of wrapping.
    #[test]
    fn test_saturating_add_for_counts() {
        let config = PrecisionConfig::default();
        let mut allocator = PrecisionAllocator::new(1, config).unwrap();
        allocator.counts[0] = u64::MAX - 1;
        allocator.record_activation(0);
        assert_eq!(allocator.get_count(0), u64::MAX);
        allocator.record_activation(0);
        assert_eq!(allocator.get_count(0), u64::MAX);
    }
}