use ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Normal, SeedableRng, StdRng};
use smallvec::SmallVec;
use super::error::{MoeError, MoeResult};
/// Outcome of routing a single input through a [`TopKGate`].
///
/// `TopKGate::forward` fills the two `top_k_*` fields in lock-step: entry `i`
/// of `top_k_softmax_weights` is the weight of the expert named by entry `i`
/// of `top_k_indices`.
#[derive(Debug, Clone, PartialEq)]
pub struct GatingDecision {
/// Indices of the selected experts, in the order the gate ranked them
/// (highest logit first; equal logits ordered by ascending index).
pub top_k_indices: SmallVec<[usize; 8]>,
/// Softmax weights computed over the selected experts' logits only
/// (not over the whole expert pool).
pub top_k_softmax_weights: SmallVec<[f64; 8]>,
/// Unnormalized gate logits, one entry per expert in the pool.
pub raw_logits: Vec<f64>,
}
impl GatingDecision {
    /// Number of experts that were selected (length of `top_k_indices`).
    pub fn k(&self) -> usize {
        let selected = &self.top_k_indices;
        selected.len()
    }

    /// Size of the expert pool the gate scored (length of `raw_logits`).
    pub fn num_experts(&self) -> usize {
        let scored = &self.raw_logits;
        scored.len()
    }

    /// Softmax over *all* raw logits, not just the selected top-k.
    ///
    /// The largest logit is subtracted before exponentiating so the
    /// exponentials cannot overflow. Returns an empty vector when there are no
    /// logits, and a uniform distribution when the normalizer is not a
    /// positive number (for example when it is NaN).
    pub fn full_softmax(&self) -> Vec<f64> {
        let logits = self.raw_logits.as_slice();
        if logits.is_empty() {
            return Vec::new();
        }

        // `f64::max` skips NaN operands, so `peak` is the largest non-NaN logit.
        let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        let mut probs: Vec<f64> = logits.iter().map(|&z| (z - peak).exp()).collect();
        // Left-to-right accumulation starting from zero.
        let total = probs.iter().fold(0.0_f64, |acc, &e| acc + e);

        if total > 0.0 {
            probs.iter_mut().for_each(|p| *p /= total);
        } else {
            let uniform = 1.0_f64 / logits.len() as f64;
            probs.iter_mut().for_each(|p| *p = uniform);
        }
        probs
    }
}
/// Linear gate that scores every expert for an input vector and keeps the
/// `k` highest-scoring ones.
#[derive(Debug, Clone)]
pub struct TopKGate {
/// Gate projection: one row per expert, one column per input feature
/// (`num_experts x d_model`).
weights: Array2<f64>,
/// Number of experts selected per call; the constructors and `set_k`
/// keep it within `1..=num_experts`.
k: usize,
}
impl TopKGate {
    /// Checks that `k` selects at least one and at most `num_experts` experts.
    fn validate_top_k(k: usize, num_experts: usize) -> MoeResult<()> {
        if k == 0 || k > num_experts {
            return Err(MoeError::InvalidTopK { k, num_experts });
        }
        Ok(())
    }

    /// Checks the gate dimensions and `k`, in the order callers rely on:
    /// empty pool first, then zero-width input, then an out-of-range `k`.
    fn validate_config(num_experts: usize, d_model: usize, k: usize) -> MoeResult<()> {
        if num_experts == 0 {
            return Err(MoeError::EmptyExpertPool);
        }
        if d_model == 0 {
            return Err(MoeError::ShapeMismatch {
                expected: 1,
                got: 0,
            });
        }
        Self::validate_top_k(k, num_experts)
    }

    /// Total "best first" ordering of two logits: larger values come first and
    /// NaN always comes last.
    ///
    /// `partial_cmp` alone yields `None` for NaN; treating that as `Equal`
    /// breaks transitivity once an index tie-break is added, which makes the
    /// sort result unspecified and may make the sort panic. Handling NaN
    /// explicitly keeps the ordering total while leaving the order of NaN-free
    /// inputs untouched (`-0.0` and `0.0` still compare equal).
    fn rank_descending(a: f64, b: f64) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Builds a gate from an existing `num_experts x d_model` weight matrix.
    ///
    /// # Errors
    ///
    /// * `MoeError::EmptyExpertPool` if `weights` has no rows.
    /// * `MoeError::ShapeMismatch` if `weights` has no columns.
    /// * `MoeError::InvalidTopK` if `k` is zero or exceeds the number of rows.
    pub fn from_weights(weights: Array2<f64>, k: usize) -> MoeResult<Self> {
        Self::validate_config(weights.nrows(), weights.ncols(), k)?;
        Ok(Self { weights, k })
    }

    /// Builds a gate whose weights are drawn from a zero-mean normal
    /// distribution with standard deviation
    /// `sqrt(2 / (d_model + num_experts))`, using an RNG seeded with `seed`.
    ///
    /// # Errors
    ///
    /// * `MoeError::EmptyExpertPool` if `num_experts` is zero.
    /// * `MoeError::ShapeMismatch` if `d_model` is zero, or if the normal
    ///   distribution cannot be constructed.
    /// * `MoeError::InvalidTopK` if `k` is zero or exceeds `num_experts`.
    pub fn xavier_init(d_model: usize, num_experts: usize, k: usize, seed: u64) -> MoeResult<Self> {
        Self::validate_config(num_experts, d_model, k)?;

        let std = (2.0_f64 / (d_model + num_experts) as f64).sqrt();
        let dist = Normal::new(0.0, std).map_err(|_| MoeError::ShapeMismatch {
            expected: 1,
            got: 0,
        })?;

        let mut rng = StdRng::seed_from_u64(seed);
        let mut weights = Array2::<f64>::zeros((num_experts, d_model));
        // Fill in the array's iteration order so a given seed always yields
        // the same matrix.
        for value in weights.iter_mut() {
            *value = rng.sample(dist);
        }
        Ok(Self { weights, k })
    }

    /// Number of experts scored by the gate (rows of the weight matrix).
    pub fn num_experts(&self) -> usize {
        self.weights.nrows()
    }

    /// Expected length of the input vector (columns of the weight matrix).
    pub fn d_model(&self) -> usize {
        self.weights.ncols()
    }

    /// Number of experts selected per call to [`TopKGate::forward`].
    pub fn k(&self) -> usize {
        self.k
    }

    /// Changes how many experts are selected per call.
    ///
    /// # Errors
    ///
    /// Returns `MoeError::InvalidTopK` (leaving the gate unchanged) if `k` is
    /// zero or exceeds the number of experts.
    pub fn set_k(&mut self, k: usize) -> MoeResult<()> {
        Self::validate_top_k(k, self.num_experts())?;
        self.k = k;
        Ok(())
    }

    /// Read-only view of the `num_experts x d_model` weight matrix.
    pub fn weights(&self) -> &Array2<f64> {
        &self.weights
    }

    /// Computes the raw gate logits `weights · x`, one per expert.
    ///
    /// # Errors
    ///
    /// Returns `MoeError::ShapeMismatch` if `x.len()` differs from `d_model()`.
    pub fn logits(&self, x: &ArrayView1<f64>) -> MoeResult<Array1<f64>> {
        if x.len() != self.d_model() {
            return Err(MoeError::ShapeMismatch {
                expected: self.d_model(),
                got: x.len(),
            });
        }
        Ok(self.weights.dot(x))
    }

    /// Scores all experts for `x`, keeps the `k` best and softmax-normalizes
    /// their logits.
    ///
    /// Experts are ranked by descending logit; equal logits are ordered by
    /// ascending expert index and NaN logits rank after every real value, so
    /// the selection is deterministic even for non-finite inputs. The softmax
    /// is taken over the selected logits only, with the maximum subtracted
    /// first for numerical stability. If the normalizer is not a positive
    /// number (e.g. a NaN logit had to be selected), the selected experts
    /// receive equal weight.
    ///
    /// # Errors
    ///
    /// Returns `MoeError::ShapeMismatch` if `x.len()` differs from `d_model()`.
    pub fn forward(&self, x: &ArrayView1<f64>) -> MoeResult<GatingDecision> {
        let raw_logits: Vec<f64> = self.logits(x)?.to_vec();

        // The index tie-break leaves no two elements equal, so an unstable
        // sort produces exactly one possible order.
        let mut order: Vec<usize> = (0..raw_logits.len()).collect();
        order.sort_unstable_by(|&a, &b| {
            Self::rank_descending(raw_logits[a], raw_logits[b]).then_with(|| a.cmp(&b))
        });

        let top_k_indices: SmallVec<[usize; 8]> = order.iter().copied().take(self.k).collect();
        let top_k_logits: SmallVec<[f64; 8]> =
            top_k_indices.iter().map(|&idx| raw_logits[idx]).collect();

        // `f64::max` skips NaN operands, so this is the largest real logit.
        let max_logit = top_k_logits
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max);
        let exp_values: SmallVec<[f64; 8]> = top_k_logits
            .iter()
            .map(|&lg| (lg - max_logit).exp())
            .collect();
        let sum = exp_values.iter().fold(0.0_f64, |acc, &e| acc + e);

        let top_k_softmax_weights: SmallVec<[f64; 8]> = if sum > 0.0 {
            exp_values.iter().map(|&e| e / sum).collect()
        } else {
            // Sized from the selected set so weights always align with indices.
            let uniform = 1.0_f64 / top_k_indices.len() as f64;
            top_k_indices.iter().map(|_| uniform).collect()
        };

        Ok(GatingDecision {
            top_k_indices,
            top_k_softmax_weights,
            raw_logits,
        })
    }
}