use crate::error::{ModelError, ModelResult};
use kizzasi_core::SignalPredictor;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::fmt;
use tracing::{debug, trace};
/// Strategy the router uses to assign an input to experts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Dense routing: every expert receives a softmax-normalized weight.
    Softmax,
    /// Sparse routing: only the k highest-scoring experts are selected.
    TopK,
    /// Top-k variant that additionally allocates noise-scale weights in the router.
    NoisyTopK,
}
impl fmt::Display for RoutingStrategy {
    /// Writes the human-readable name of the strategy.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Softmax => "Softmax",
            Self::TopK => "Top-K",
            Self::NoisyTopK => "Noisy Top-K",
        };
        f.write_str(name)
    }
}
/// Configuration for a mixture-of-experts layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MoEConfig {
    /// Total number of experts in the layer.
    pub num_experts: usize,
    /// Number of experts selected per input under the top-k strategies.
    pub top_k: usize,
    /// Dimensionality of input vectors.
    pub input_dim: usize,
    /// Dimensionality of output vectors.
    pub output_dim: usize,
    /// How inputs are assigned to experts.
    pub routing_strategy: RoutingStrategy,
    /// Scale applied to the load-balancing auxiliary loss.
    pub load_balance_coeff: f32,
    // NOTE(review): not referenced by any code in this file — presumably a
    // training-time knob; confirm intended usage.
    pub expert_dropout: f32,
    // NOTE(review): not referenced by any code in this file — presumably the
    // std-dev for noisy top-k routing; confirm intended usage.
    pub noise_std: f32,
}
impl Default for MoEConfig {
fn default() -> Self {
Self {
num_experts: 8,
top_k: 2,
input_dim: 256,
output_dim: 256,
routing_strategy: RoutingStrategy::TopK,
load_balance_coeff: 0.01,
expert_dropout: 0.0,
noise_std: 1.0,
}
}
}
/// Gating network mapping an input vector to a set of experts plus weights.
#[derive(Debug)]
pub struct Router {
    // Gating weight matrix, shape (input_dim, num_experts).
    weights: Array2<f32>,
    // Noise-scale weights; allocated only for the NoisyTopK strategy.
    noise_weights: Option<Array2<f32>>,
    // Configuration this router was built from.
    config: MoEConfig,
}
impl Router {
pub fn new(config: MoEConfig) -> ModelResult<Self> {
debug!(
"Creating router: {} experts, top-k={}, strategy={}",
config.num_experts, config.top_k, config.routing_strategy
);
let weights = Array2::zeros((config.input_dim, config.num_experts));
let noise_weights = if config.routing_strategy == RoutingStrategy::NoisyTopK {
Some(Array2::zeros((config.input_dim, config.num_experts)))
} else {
None
};
Ok(Self {
weights,
noise_weights,
config,
})
}
pub fn route(&self, input: &Array1<f32>) -> ModelResult<(Vec<usize>, Vec<f32>)> {
trace!("Computing routing for input shape: {:?}", input.shape());
if input.len() != self.config.input_dim {
return Err(ModelError::dimension_mismatch(
"router input",
self.config.input_dim,
input.len(),
));
}
let logits = self.weights.t().dot(input);
let logits = if let Some(ref noise_weights) = self.noise_weights {
let _noise_logits = noise_weights.t().dot(input);
logits
} else {
logits
};
match self.config.routing_strategy {
RoutingStrategy::Softmax => self.softmax_route(&logits),
RoutingStrategy::TopK | RoutingStrategy::NoisyTopK => self.topk_route(&logits),
}
}
fn softmax_route(&self, logits: &Array1<f32>) -> ModelResult<(Vec<usize>, Vec<f32>)> {
let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probabilities: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
let indices: Vec<usize> = (0..self.config.num_experts).collect();
Ok((indices, probabilities))
}
fn topk_route(&self, logits: &Array1<f32>) -> ModelResult<(Vec<usize>, Vec<f32>)> {
let mut indexed_logits: Vec<(usize, f32)> =
logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let top_k = self.config.top_k.min(self.config.num_experts);
let top_experts: Vec<(usize, f32)> = indexed_logits.into_iter().take(top_k).collect();
let top_logits: Vec<f32> = top_experts.iter().map(|(_, v)| *v).collect();
let max_logit = top_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = top_logits.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: f32 = exp_logits.iter().sum();
let indices: Vec<usize> = top_experts.iter().map(|(i, _)| *i).collect();
let weights: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
trace!(
"Selected {} experts: {:?} with weights {:?}",
top_k,
indices,
weights
);
Ok((indices, weights))
}
pub fn load_balance_loss(&self, all_routes: &[Vec<usize>]) -> f32 {
if all_routes.is_empty() {
return 0.0;
}
let mut expert_counts = vec![0.0f32; self.config.num_experts];
for routes in all_routes {
for &expert_idx in routes {
expert_counts[expert_idx] += 1.0;
}
}
let total: f32 = expert_counts.iter().sum();
if total == 0.0 {
return 0.0;
}
let mean = total / self.config.num_experts as f32;
let variance: f32 = expert_counts
.iter()
.map(|&count| (count - mean).powi(2))
.sum::<f32>()
/ self.config.num_experts as f32;
let std_dev = variance.sqrt();
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
cv * self.config.load_balance_coeff
}
}
/// A single linear expert with optional input/output projections.
pub struct Expert {
    // Identifier; used only in trace logging.
    id: usize,
    // Optional (input_dim, hidden_dim) projection; None when dims match.
    input_proj: Option<Array2<f32>>,
    // Optional output projection; `new` currently never allocates one.
    output_proj: Option<Array2<f32>>,
    // Main weight matrix, shape (hidden_dim, output_dim).
    weights: Array2<f32>,
}
impl Expert {
    /// Creates expert `id` with zero-initialized parameters.
    ///
    /// An input projection is allocated only when the input and hidden
    /// dimensions differ; the output projection is never allocated here.
    pub fn new(
        id: usize,
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
    ) -> ModelResult<Self> {
        debug!(
            "Creating expert {}: {}→{}→{}",
            id, input_dim, hidden_dim, output_dim
        );
        let needs_projection = input_dim != hidden_dim;
        let input_proj = if needs_projection {
            Some(Array2::zeros((input_dim, hidden_dim)))
        } else {
            None
        };
        Ok(Self {
            id,
            input_proj,
            output_proj: None,
            weights: Array2::zeros((hidden_dim, output_dim)),
        })
    }

    /// Runs the expert on one input vector: optional input projection, the
    /// main weight matrix, then an optional output projection.
    pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        trace!(
            "Expert {} forward: input shape {:?}",
            self.id,
            input.shape()
        );
        let hidden = match self.input_proj {
            Some(ref proj) => proj.t().dot(input),
            None => input.clone(),
        };
        let mut result = self.weights.t().dot(&hidden);
        if let Some(ref proj) = self.output_proj {
            result = proj.t().dot(&result);
        }
        Ok(result)
    }
}
/// Sparse mixture-of-experts layer: a gating router plus a pool of experts.
pub struct MixtureOfExperts {
    // Gating network that selects experts per input.
    router: Router,
    // Expert pool; position in the Vec is the expert index.
    experts: Vec<Expert>,
    // Layer configuration.
    config: MoEConfig,
    // Expert indices chosen on each forward call; grows until cleared.
    routing_history: Vec<Vec<usize>>,
}
impl MixtureOfExperts {
    /// Builds the layer: one router plus `num_experts` experts whose hidden
    /// dimension equals `input_dim`.
    ///
    /// # Errors
    /// Propagates construction failures from the router or any expert.
    pub fn new(config: MoEConfig) -> ModelResult<Self> {
        debug!(
            "Creating MixtureOfExperts: {} experts, strategy={}",
            config.num_experts, config.routing_strategy
        );
        let router = Router::new(config.clone())?;
        let mut experts = Vec::with_capacity(config.num_experts);
        for i in 0..config.num_experts {
            // hidden_dim == input_dim, so experts skip the input projection.
            let expert = Expert::new(i, config.input_dim, config.input_dim, config.output_dim)?;
            experts.push(expert);
        }
        Ok(Self {
            router,
            experts,
            config,
            routing_history: Vec::new(),
        })
    }

    /// Routes `input` to the selected experts and returns the gate-weighted
    /// sum of their outputs.
    ///
    /// Records the selected expert indices for later load-balance statistics.
    /// NOTE: the history grows without bound; call `clear_routing_history`
    /// periodically in long-running loops.
    ///
    /// # Errors
    /// Propagates dimension-mismatch errors from routing or expert forward.
    pub fn forward(&mut self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        trace!("MoE forward: input shape {:?}", input.shape());
        let (expert_indices, weights) = self.router.route(input)?;
        self.routing_history.push(expert_indices.clone());
        let mut output = Array1::zeros(self.config.output_dim);
        // Zip indices with their gate weights instead of indexing `weights`
        // by position — no panic path if lengths were ever to disagree.
        for (&expert_idx, &weight) in expert_indices.iter().zip(weights.iter()) {
            let expert_output = self.experts[expert_idx].forward(input)?;
            output = output + expert_output.mapv(|x| x * weight);
        }
        Ok(output)
    }

    /// Load-balance penalty over all routing decisions recorded so far.
    pub fn get_load_balance_loss(&self) -> f32 {
        self.router.load_balance_loss(&self.routing_history)
    }

    /// Discards all recorded routing decisions.
    pub fn clear_routing_history(&mut self) {
        self.routing_history.clear();
    }

    /// Number of times each expert was selected since the last clear.
    pub fn expert_usage_stats(&self) -> Vec<usize> {
        let mut counts = vec![0usize; self.config.num_experts];
        for routes in &self.routing_history {
            for &idx in routes {
                counts[idx] += 1;
            }
        }
        counts
    }
}
impl SignalPredictor for MixtureOfExperts {
    /// One prediction step: delegates to `forward`, translating model errors
    /// into core inference errors.
    fn step(&mut self, input: &Array1<f32>) -> kizzasi_core::CoreResult<Array1<f32>> {
        match self.forward(input) {
            Ok(output) => Ok(output),
            Err(e) => Err(kizzasi_core::CoreError::InferenceError(e.to_string())),
        }
    }

    /// Resets predictor state by dropping the routing history.
    fn reset(&mut self) {
        self.clear_routing_history();
    }

    /// Each step consumes a single sample; no past context is kept.
    fn context_window(&self) -> usize {
        1
    }
}
impl fmt::Display for MixtureOfExperts {
    /// Short summary of the layer's configuration for logs and diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cfg = &self.config;
        write!(
            f,
            "MixtureOfExperts({} experts, top_k={}, strategy={})",
            cfg.num_experts, cfg.top_k, cfg.routing_strategy
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Router construction allocates a (input_dim, num_experts) weight matrix.
    #[test]
    fn test_router_creation() {
        let config = MoEConfig::default();
        let router = Router::new(config).expect("Failed to create router");
        assert_eq!(router.weights.shape(), &[256, 8]);
    }

    // Top-k routing returns exactly k experts with weights normalized to 1.
    #[test]
    fn test_topk_routing() {
        let config = MoEConfig {
            num_experts: 4,
            top_k: 2,
            routing_strategy: RoutingStrategy::TopK,
            ..Default::default()
        };
        let router = Router::new(config).expect("Failed to create router");
        let input = Array1::from_vec(vec![0.1; 256]);
        let (indices, weights) = router.route(&input).expect("Routing failed");
        assert_eq!(indices.len(), 2, "Should select top-2 experts");
        assert_eq!(weights.len(), 2, "Should have 2 weights");
        let sum: f32 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1.0");
    }

    // Softmax routing is dense: all experts selected, weights sum to 1.
    #[test]
    fn test_softmax_routing() {
        let config = MoEConfig {
            num_experts: 4,
            routing_strategy: RoutingStrategy::Softmax,
            ..Default::default()
        };
        let router = Router::new(config.clone()).expect("Failed to create router");
        let input = Array1::from_vec(vec![0.1; 256]);
        let (indices, weights) = router.route(&input).expect("Routing failed");
        assert_eq!(indices.len(), config.num_experts, "Should use all experts");
        assert_eq!(
            weights.len(),
            config.num_experts,
            "Should have weights for all experts"
        );
        let sum: f32 = weights.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1.0");
    }

    // A single expert preserves the configured output dimensionality.
    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(0, 256, 256, 256).expect("Failed to create expert");
        let input = Array1::from_vec(vec![0.1; 256]);
        let output = expert.forward(&input).expect("Forward failed");
        assert_eq!(output.len(), 256, "Output should have correct dimension");
    }

    // End-to-end MoE forward pass produces an output_dim-sized vector.
    #[test]
    fn test_moe_forward() {
        let config = MoEConfig {
            num_experts: 4,
            top_k: 2,
            input_dim: 128,
            output_dim: 128,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config).expect("Failed to create MoE");
        let input = Array1::from_vec(vec![0.1; 128]);
        let output = moe.forward(&input).expect("Forward failed");
        assert_eq!(output.len(), 128, "Output should have correct dimension");
    }

    // Load-balance loss is well-defined (non-negative) after several steps.
    #[test]
    fn test_load_balance_loss() {
        let config = MoEConfig {
            num_experts: 4,
            top_k: 1,
            input_dim: 64,
            output_dim: 64,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config).expect("Failed to create MoE");
        for _ in 0..10 {
            let input = Array1::from_vec(vec![0.1; 64]);
            let _ = moe.forward(&input);
        }
        let loss = moe.get_load_balance_loss();
        assert!(loss >= 0.0, "Load balance loss should be non-negative");
    }

    // Usage stats track one counter per expert and record actual selections.
    #[test]
    fn test_expert_usage_stats() {
        let config = MoEConfig {
            num_experts: 4,
            top_k: 2,
            input_dim: 64,
            output_dim: 64,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config.clone()).expect("Failed to create MoE");
        for _ in 0..10 {
            let input = Array1::from_vec(vec![0.1; 64]);
            let _ = moe.forward(&input);
        }
        let stats = moe.expert_usage_stats();
        assert_eq!(stats.len(), config.num_experts);
        let total: usize = stats.iter().sum();
        assert!(total > 0, "At least some experts should be used");
    }

    // SignalPredictor trait: step/reset/context_window behave as documented.
    #[test]
    fn test_signal_predictor_trait() {
        let config = MoEConfig {
            input_dim: 64,
            output_dim: 64,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config).expect("Failed to create MoE");
        let input = Array1::from_vec(vec![0.5; 64]);
        let output = moe.step(&input).expect("Step failed");
        assert_eq!(output.len(), 64);
        assert_eq!(moe.context_window(), 1);
        moe.reset();
        assert_eq!(moe.routing_history.len(), 0);
    }

    // Wrong-sized input must be rejected rather than silently accepted.
    #[test]
    fn test_dimension_mismatch() {
        let config = MoEConfig {
            input_dim: 128,
            output_dim: 128,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config).expect("Failed to create MoE");
        let wrong_input = Array1::from_vec(vec![0.1; 64]);
        let result = moe.forward(&wrong_input);
        assert!(result.is_err(), "Should fail with dimension mismatch");
    }
}