use ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Normal, SeedableRng, StdRng};
use super::error::{MoeError, MoeResult};
/// A single expert: a fallible map from an input vector of length
/// [`Expert::input_dim`] to an output vector of length [`Expert::output_dim`].
///
/// The `Send + Sync` supertrait bounds allow expert instances to be moved to
/// and shared between threads.
pub trait Expert: Send + Sync {
    /// Number of elements [`Expert::forward`] expects in its input vector.
    fn input_dim(&self) -> usize;
    /// Number of elements in the vector [`Expert::forward`] returns.
    fn output_dim(&self) -> usize;
    /// Evaluates the expert on `x` and returns a freshly allocated output.
    ///
    /// # Errors
    ///
    /// The implementations in this module return `MoeError::ShapeMismatch`
    /// when `x.len()` differs from `input_dim()`.
    fn forward(&self, x: &ArrayView1<f64>) -> MoeResult<Array1<f64>>;
}
/// An affine expert that maps an input vector `x` to `weights · x + bias`.
///
/// `weights` is stored as a `(d_out, d_in)` matrix and `bias` has `d_out`
/// entries; the constructors in this module keep those shapes consistent.
#[derive(Debug, Clone)]
pub struct LinearExpert {
    /// Weight matrix of shape `(d_out, d_in)`: rows index outputs, columns index inputs.
    weights: Array2<f64>,
    /// Bias vector of length `d_out`, added after the matrix-vector product.
    bias: Array1<f64>,
}
impl LinearExpert {
    /// Builds an expert from explicit parameters.
    ///
    /// `weights` is interpreted as a `(d_out, d_in)` matrix, so `bias` must
    /// contain exactly `weights.nrows()` entries.
    ///
    /// # Errors
    ///
    /// Returns [`MoeError::ShapeMismatch`] with `expected = weights.nrows()`
    /// and `got = bias.len()` when the bias length does not match.
    pub fn from_arrays(weights: Array2<f64>, bias: Array1<f64>) -> MoeResult<Self> {
        let rows = weights.nrows();
        if bias.len() == rows {
            Ok(Self { weights, bias })
        } else {
            Err(MoeError::ShapeMismatch {
                expected: rows,
                got: bias.len(),
            })
        }
    }

    /// Rejects a zero input or output dimension.
    ///
    /// The reported error is `expected = max(d_in, 1)`, `got = 0`, matching
    /// what the public constructors below have always returned.
    fn ensure_nonzero_dims(d_in: usize, d_out: usize) -> MoeResult<()> {
        if d_in > 0 && d_out > 0 {
            return Ok(());
        }
        Err(MoeError::ShapeMismatch {
            expected: d_in.max(1),
            got: 0,
        })
    }

    /// Creates an expert whose weights and bias are all zero, so
    /// [`Expert::forward`] yields a zero vector of length `d_out`.
    ///
    /// # Errors
    ///
    /// Returns [`MoeError::ShapeMismatch`] if `d_in` or `d_out` is zero.
    pub fn zeros(d_in: usize, d_out: usize) -> MoeResult<Self> {
        Self::ensure_nonzero_dims(d_in, d_out)?;
        let weights = Array2::zeros((d_out, d_in));
        let bias = Array1::zeros(d_out);
        Ok(Self { weights, bias })
    }

    /// Creates an expert with Glorot/Xavier-normal weights and a zero bias.
    ///
    /// Every weight is drawn from a normal distribution with mean 0 and
    /// standard deviation `sqrt(2 / (d_in + d_out))`. The draws come from a
    /// `StdRng` seeded with `seed` and fill the `(d_out, d_in)` matrix in
    /// logical row-major order, so a given seed reproduces the same weights.
    ///
    /// # Errors
    ///
    /// Returns [`MoeError::ShapeMismatch`] if either dimension is zero, or if
    /// the standard deviation (or the normal distribution built from it) is
    /// invalid.
    pub fn xavier_init(d_in: usize, d_out: usize, seed: u64) -> MoeResult<Self> {
        Self::ensure_nonzero_dims(d_in, d_out)?;

        let fan_total = d_in + d_out;
        let std_dev = (2.0_f64 / fan_total as f64).sqrt();
        // Defensive guard: with both dims >= 1 this only trips if the fan
        // total degenerates (e.g. wraps to zero), which would make it infinite.
        if !std_dev.is_finite() || std_dev <= 0.0 {
            return Err(MoeError::ShapeMismatch {
                expected: 1,
                got: 0,
            });
        }

        let normal = Normal::new(0.0, std_dev).map_err(|err| MoeError::ShapeMismatch {
            expected: 1,
            got: format!("{err}").len(),
        })?;

        let mut rng = StdRng::seed_from_u64(seed);
        let mut weights = Array2::<f64>::zeros((d_out, d_in));
        weights
            .iter_mut()
            .for_each(|slot| *slot = rng.sample(normal));

        Ok(Self {
            weights,
            bias: Array1::<f64>::zeros(d_out),
        })
    }

    /// Borrows the `(d_out, d_in)` weight matrix.
    pub fn weights(&self) -> &Array2<f64> {
        &self.weights
    }

    /// Borrows the length-`d_out` bias vector.
    pub fn bias(&self) -> &Array1<f64> {
        &self.bias
    }
}
impl Expert for LinearExpert {
    /// Input width: the column count of the weight matrix.
    fn input_dim(&self) -> usize {
        self.weights.ncols()
    }

    /// Output width: the row count of the weight matrix.
    fn output_dim(&self) -> usize {
        self.weights.nrows()
    }

    /// Computes `weights · x + bias`.
    ///
    /// # Errors
    ///
    /// Returns [`MoeError::ShapeMismatch`] when `x.len()` differs from
    /// [`Expert::input_dim`].
    fn forward(&self, x: &ArrayView1<f64>) -> MoeResult<Array1<f64>> {
        let expected = self.input_dim();
        match x.len() {
            // Owned + &borrowed reuses the product's buffer for the bias add.
            got if got == expected => Ok(self.weights.dot(x) + &self.bias),
            got => Err(MoeError::ShapeMismatch { expected, got }),
        }
    }
}
/// Test-only expert whose `forward` returns a copy of its input.
///
/// Its input and output dimensions are both `dim`.
#[cfg(test)]
// NOTE(review): presumably constructed by tests elsewhere in the parent
// module; the allow keeps the build warning-free if none of them use it.
#[allow(dead_code)]
pub(super) struct IdentityExpert {
    /// Length of the vectors this expert accepts and returns.
    dim: usize,
}
#[cfg(test)]
impl IdentityExpert {
    /// Builds an identity expert over vectors of length `dim`.
    // Allowed because tests in the parent module may not all construct it.
    #[allow(dead_code)]
    pub(super) fn new(dim: usize) -> Self {
        IdentityExpert { dim }
    }
}
#[cfg(test)]
impl Expert for IdentityExpert {
    /// Accepts vectors of length `dim`.
    fn input_dim(&self) -> usize {
        self.dim
    }

    /// Produces vectors of the same length it accepts.
    fn output_dim(&self) -> usize {
        self.input_dim()
    }

    /// Returns an owned copy of `x`.
    ///
    /// # Errors
    ///
    /// Returns [`MoeError::ShapeMismatch`] when `x.len() != dim`.
    fn forward(&self, x: &ArrayView1<f64>) -> MoeResult<Array1<f64>> {
        if x.len() == self.dim {
            return Ok(x.to_owned());
        }
        Err(MoeError::ShapeMismatch {
            expected: self.dim,
            got: x.len(),
        })
    }
}