use ndarray::{Array1, Array2, ArrayView1};
use super::error::{MoeError, MoeResult};
use super::expert::Expert;
use super::gate::{GatingDecision, TopKGate};
use super::load_balance::BatchGatingStats;
/// Default per-expert capacity factor (1.25).
///
/// This value is not applied automatically: [`MoELayer::new`] starts with no capacity
/// limit. It is presumably meant to be passed to [`MoELayer::with_capacity_factor`] by
/// callers — NOTE(review): confirm intended usage at call sites.
pub const DEFAULT_CAPACITY_FACTOR: f64 = 1.25;
/// A mixture-of-experts layer: a top-k gate that selects experts for each input token,
/// plus the pool of experts whose outputs are combined using the gate's weights.
///
/// Construct with [`MoELayer::new`], which checks that the gate and every expert agree on
/// dimensions.
pub struct MoELayer {
    /// Router producing the top-k expert indices and weights for each token.
    gate: TopKGate,
    /// Expert pool; position `i` is the expert referred to by index `i` in gating decisions.
    experts: Vec<Box<dyn Expert>>,
    /// Optional per-expert capacity factor read by the batched forward pass;
    /// `None` means assignments are never dropped.
    capacity_factor: Option<f64>,
    /// Input dimension shared by the gate and all experts (validated in `new`).
    d_in: usize,
    /// Output dimension shared by all experts (validated in `new`).
    d_out: usize,
}
impl std::fmt::Debug for MoELayer {
    /// Summarises the layer by its shape and configuration.
    ///
    /// The gate and the experts themselves are not printed; only the expert count is shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let expert_count = self.experts.len();
        let mut summary = formatter.debug_struct("MoELayer");
        summary.field("num_experts", &expert_count);
        summary.field("d_in", &self.d_in);
        summary.field("d_out", &self.d_out);
        summary.field("capacity_factor", &self.capacity_factor);
        summary.finish()
    }
}
impl MoELayer {
    /// Builds a layer from a top-k gate and a pool of experts.
    ///
    /// The first expert defines the layer's input and output dimensions; every other expert
    /// must match them, the gate must route over exactly `experts.len()` experts, and the
    /// gate's model dimension must equal the experts' input dimension. Capacity limiting
    /// starts disabled; enable it with [`MoELayer::with_capacity_factor`].
    ///
    /// # Errors
    ///
    /// * [`MoeError::EmptyExpertPool`] if `experts` is empty.
    /// * [`MoeError::ShapeMismatch`] if the gate's expert count differs from `experts.len()`,
    ///   if the gate's model dimension differs from the first expert's input dimension, or
    ///   if any expert's input/output dimension differs from the first expert's.
    pub fn new(gate: TopKGate, experts: Vec<Box<dyn Expert>>) -> MoeResult<Self> {
        if experts.is_empty() {
            return Err(MoeError::EmptyExpertPool);
        }
        if gate.num_experts() != experts.len() {
            return Err(MoeError::ShapeMismatch {
                expected: gate.num_experts(),
                got: experts.len(),
            });
        }
        // The first expert defines the layer's shape; everything else must agree with it.
        let d_in = experts[0].input_dim();
        let d_out = experts[0].output_dim();
        if gate.d_model() != d_in {
            return Err(MoeError::ShapeMismatch {
                expected: gate.d_model(),
                got: d_in,
            });
        }
        for expert in experts.iter().skip(1) {
            if expert.input_dim() != d_in {
                return Err(MoeError::ShapeMismatch {
                    expected: d_in,
                    got: expert.input_dim(),
                });
            }
            if expert.output_dim() != d_out {
                return Err(MoeError::ShapeMismatch {
                    expected: d_out,
                    got: expert.output_dim(),
                });
            }
        }
        Ok(Self {
            gate,
            experts,
            capacity_factor: None,
            d_in,
            d_out,
        })
    }

    /// Enables per-expert capacity limiting for [`MoELayer::forward_batch`].
    ///
    /// # Errors
    ///
    /// [`MoeError::InvalidCapacityFactor`] if `factor` is NaN, infinite, or not strictly
    /// positive.
    pub fn with_capacity_factor(mut self, factor: f64) -> MoeResult<Self> {
        if !factor.is_finite() || factor <= 0.0 {
            return Err(MoeError::InvalidCapacityFactor { value: factor });
        }
        self.capacity_factor = Some(factor);
        Ok(self)
    }

    /// Disables capacity limiting, so no routing assignment is ever dropped.
    pub fn without_capacity_factor(mut self) -> Self {
        self.capacity_factor = None;
        self
    }

    /// Shared access to the gate.
    pub fn gate(&self) -> &TopKGate {
        &self.gate
    }

    /// Mutable access to the gate.
    pub fn gate_mut(&mut self) -> &mut TopKGate {
        &mut self.gate
    }

    /// Number of experts in the pool.
    pub fn num_experts(&self) -> usize {
        self.experts.len()
    }

    /// The configured capacity factor, or `None` when capacity limiting is disabled.
    pub fn capacity_factor(&self) -> Option<f64> {
        self.capacity_factor
    }

    /// Input dimension shared by the gate and all experts.
    pub fn input_dim(&self) -> usize {
        self.d_in
    }

    /// Output dimension shared by all experts.
    pub fn output_dim(&self) -> usize {
        self.d_out
    }

    /// Routes a single token through its top-k experts.
    ///
    /// Returns the gate-weighted sum of the selected experts' outputs together with the
    /// gating decision. No capacity limit applies to this single-token path.
    ///
    /// # Errors
    ///
    /// * [`MoeError::ShapeMismatch`] if `x.len()` differs from [`MoELayer::input_dim`], or
    ///   if a selected expert returns an output whose length differs from
    ///   [`MoELayer::output_dim`].
    /// * Any error produced by the gate or by a selected expert.
    pub fn forward(&self, x: &ArrayView1<f64>) -> MoeResult<(Array1<f64>, GatingDecision)> {
        if x.len() != self.d_in {
            return Err(MoeError::ShapeMismatch {
                expected: self.d_in,
                got: x.len(),
            });
        }
        let decision = self.gate.forward(x)?;
        let mut output = Array1::<f64>::zeros(self.d_out);
        for (slot, &expert_idx) in decision.top_k_indices.iter().enumerate() {
            let weight = decision.top_k_softmax_weights[slot];
            self.accumulate_expert(expert_idx, weight, x, output.view_mut())?;
        }
        Ok((output, decision))
    }

    /// Routes every row of `batch` (one token per row) through its top-k experts.
    ///
    /// Gating runs for all tokens first. The returned statistics record each token's full
    /// softmax gate scores and its primary (first top-k) expert. These are recorded *before*
    /// capacity limiting, so they reflect the gate's preferences rather than what was
    /// actually dispatched.
    ///
    /// Dispatch then proceeds in row order. When a capacity factor is set, each expert
    /// accepts at most [`Self::expert_capacity`] assignments. Later assignments to a full
    /// expert are skipped: they contribute nothing to that token's output, and the remaining
    /// weights are not renormalised. A token whose every assignment is dropped gets an
    /// all-zero output row.
    ///
    /// # Errors
    ///
    /// * [`MoeError::ShapeMismatch`] if `batch.ncols()` differs from
    ///   [`MoELayer::input_dim`], or if an expert returns an output of the wrong length.
    /// * Any error produced by the gate or by a dispatched expert.
    ///
    /// # Panics
    ///
    /// If the gate returns a decision with no selected experts, because the primary expert
    /// is read as `top_k_indices[0]`.
    pub fn forward_batch(
        &self,
        batch: &ndarray::ArrayView2<f64>,
    ) -> MoeResult<(Array2<f64>, BatchGatingStats)> {
        if batch.ncols() != self.d_in {
            return Err(MoeError::ShapeMismatch {
                expected: self.d_in,
                got: batch.ncols(),
            });
        }
        let batch_size = batch.nrows();
        let num_experts = self.experts.len();

        // Pass 1: gate every token and record the gate's view of the batch.
        let mut decisions: Vec<GatingDecision> = Vec::with_capacity(batch_size);
        let mut stats = BatchGatingStats::empty(batch_size, num_experts);
        for (t, row) in batch.rows().into_iter().enumerate() {
            let decision = self.gate.forward(&row)?;
            let full = decision.full_softmax();
            for (i, v) in full.iter().enumerate() {
                stats.gate_scores_per_token[(t, i)] = *v;
            }
            stats.routed_expert_per_token.push(decision.top_k_indices[0]);
            decisions.push(decision);
        }

        // Pass 2: dispatch first-come-first-served until each expert's capacity is used up.
        let capacity = self.expert_capacity(batch_size);
        let mut assigned_counts = vec![0_usize; num_experts];
        let mut outputs = Array2::<f64>::zeros((batch_size, self.d_out));
        for (t, decision) in decisions.iter().enumerate() {
            let row = batch.row(t);
            for (slot, &expert_idx) in decision.top_k_indices.iter().enumerate() {
                if assigned_counts[expert_idx] >= capacity {
                    continue;
                }
                assigned_counts[expert_idx] += 1;
                let weight = decision.top_k_softmax_weights[slot];
                self.accumulate_expert(expert_idx, weight, &row, outputs.row_mut(t))?;
            }
        }
        Ok((outputs, stats))
    }

    /// Maximum number of routing assignments each expert accepts for a batch of
    /// `batch_size` tokens.
    ///
    /// Returns `usize::MAX` (unlimited) when no capacity factor is set. Otherwise returns
    /// `ceil(factor * batch_size / num_experts)`, clamped to at least 1.
    fn expert_capacity(&self, batch_size: usize) -> usize {
        match self.capacity_factor {
            // NOTE(review): this budget counts tokens, but every top-k slot consumes it, so
            // with k > 1 assignments are dropped even under perfectly balanced routing.
            // Confirm whether the capacity should be scaled by k.
            Some(factor) => {
                let cap = (factor * batch_size as f64 / self.experts.len() as f64).ceil() as usize;
                cap.max(1)
            }
            None => usize::MAX,
        }
    }

    /// Runs expert `expert_idx` on `x` and adds `weight * output` into `acc`.
    ///
    /// The expert's output length is checked against the layer's output dimension, so a
    /// misbehaving expert yields an error. Without this check, a wrong-length output would be
    /// silently broadcast, truncated, or zero-padded, or would panic.
    fn accumulate_expert(
        &self,
        expert_idx: usize,
        weight: f64,
        x: &ArrayView1<f64>,
        mut acc: ndarray::ArrayViewMut1<'_, f64>,
    ) -> MoeResult<()> {
        let expert_out = self.experts[expert_idx].forward(x)?;
        if expert_out.len() != self.d_out {
            return Err(MoeError::ShapeMismatch {
                expected: self.d_out,
                got: expert_out.len(),
            });
        }
        acc.scaled_add(weight, &expert_out);
        Ok(())
    }
}
#[cfg(test)]
mod local_tests {
    use super::*;
    use crate::moe::expert::LinearExpert;
    use ndarray::{array, Array2};

    /// Builds a `dim x dim` identity expert (identity weights, zero bias).
    fn linear_identity(dim: usize) -> LinearExpert {
        let weights = Array2::<f64>::eye(dim);
        let bias = Array1::<f64>::zeros(dim);
        LinearExpert::from_arrays(weights, bias).expect("construct identity expert")
    }

    /// Two 2-D identity experts behind a gate that selects both of them (k = 2).
    ///
    /// Because every token is routed to both identity experts, each token's output
    /// reproduces its input whenever no assignment is dropped.
    fn identity_pair_layer() -> MoELayer {
        let gate = TopKGate::xavier_init(2, 2, 2, 7).expect("gate");
        let experts: Vec<Box<dyn Expert>> =
            vec![Box::new(linear_identity(2)), Box::new(linear_identity(2))];
        MoELayer::new(gate, experts).expect("layer")
    }

    /// Asserts two 1-D views are element-wise equal within a tight tolerance.
    fn assert_close(actual: ArrayView1<f64>, expected: ArrayView1<f64>) {
        assert_eq!(actual.len(), expected.len());
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-12, "expected {}, got {}", b, a);
        }
    }

    #[test]
    fn new_rejects_empty_pool() {
        let gate = TopKGate::xavier_init(2, 2, 1, 0).expect("gate");
        let err = MoELayer::new(gate, Vec::new()).expect_err("must fail");
        assert_eq!(err, MoeError::EmptyExpertPool);
    }

    #[test]
    fn new_rejects_gate_expert_count_mismatch() {
        let gate = TopKGate::xavier_init(2, 3, 1, 0).expect("gate");
        let experts: Vec<Box<dyn Expert>> = vec![Box::new(linear_identity(2))];
        let err = MoELayer::new(gate, experts).expect_err("must fail");
        assert!(matches!(err, MoeError::ShapeMismatch { .. }));
    }

    #[test]
    fn single_forward_matches_hand_computation() {
        let layer = identity_pair_layer();
        let x = array![1.5_f64, -2.5];
        let (y, decision) = layer.forward(&x.view()).expect("forward");
        assert_eq!(decision.top_k_indices.len(), 2);
        assert_close(y.view(), x.view());
    }

    #[test]
    fn forward_rejects_wrong_input_length() {
        let layer = identity_pair_layer();
        let x = array![1.0_f64, 2.0, 3.0];
        assert!(matches!(
            layer.forward(&x.view()),
            Err(MoeError::ShapeMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn capacity_factor_must_be_positive_and_finite() {
        for bad in [0.0_f64, -1.0, f64::NAN, f64::INFINITY] {
            let err = identity_pair_layer()
                .with_capacity_factor(bad)
                .expect_err("must fail");
            assert!(matches!(err, MoeError::InvalidCapacityFactor { .. }));
        }
        let layer = identity_pair_layer()
            .with_capacity_factor(DEFAULT_CAPACITY_FACTOR)
            .expect("valid factor");
        assert_eq!(layer.capacity_factor(), Some(DEFAULT_CAPACITY_FACTOR));
        assert_eq!(layer.without_capacity_factor().capacity_factor(), None);
    }

    #[test]
    fn batch_forward_without_capacity_passes_identity_through() {
        let layer = identity_pair_layer();
        let batch = array![[1.0_f64, 2.0], [-0.5, 3.0], [0.0, -4.0]];
        let (out, stats) = layer.forward_batch(&batch.view()).expect("forward_batch");
        assert_eq!(out.dim(), (3, 2));
        assert_eq!(stats.routed_expert_per_token.len(), 3);
        for t in 0..3 {
            assert_close(out.row(t), batch.row(t));
        }
    }

    #[test]
    fn batch_forward_rejects_wrong_width() {
        let layer = identity_pair_layer();
        let batch = Array2::<f64>::zeros((1, 3));
        assert!(matches!(
            layer.forward_batch(&batch.view()),
            Err(MoeError::ShapeMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn batch_forward_handles_empty_batch() {
        let layer = identity_pair_layer();
        let batch = Array2::<f64>::zeros((0, 2));
        let (out, stats) = layer.forward_batch(&batch.view()).expect("forward_batch");
        assert_eq!(out.dim(), (0, 2));
        assert!(stats.routed_expert_per_token.is_empty());
    }

    #[test]
    fn tight_capacity_drops_later_tokens() {
        // Capacity = max(1, ceil(0.1 * 4 / 2)) = 1 per expert: token 0 fills both experts,
        // so tokens 1..=3 have every assignment dropped and come out as zero rows.
        let layer = identity_pair_layer()
            .with_capacity_factor(0.1)
            .expect("factor");
        let batch = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let (out, _stats) = layer.forward_batch(&batch.view()).expect("forward_batch");
        assert_close(out.row(0), batch.row(0));
        for t in 1..4 {
            assert!(out.row(t).iter().all(|v| v.abs() < 1e-12), "row {} not dropped", t);
        }
    }
}