use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::lstm::LSTMCell;
use crate::error::Result;
#[derive(Debug)]
pub struct MixtureOfExperts<F: Float + Debug> {
#[allow(dead_code)]
num_experts: usize,
#[allow(dead_code)]
expert_weights: Vec<Array2<F>>,
#[allow(dead_code)]
expert_biases: Vec<Array1<F>>,
#[allow(dead_code)]
gate_weights: Array2<F>,
#[allow(dead_code)]
gate_bias: Array1<F>,
}
impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand>
MixtureOfExperts<F>
{
pub fn new(input_dim: usize, output_dim: usize, num_experts: usize) -> Self {
let scale = F::from(2.0).expect("Failed to convert constant to float")
/ F::from(input_dim).expect("Failed to convert to float");
let std_dev = scale.sqrt();
let mut expert_weights = Vec::new();
let mut expert_biases = Vec::new();
for _ in 0..num_experts {
expert_weights.push(LSTMCell::random_matrix(output_dim, input_dim, std_dev));
expert_biases.push(Array1::zeros(output_dim));
}
Self {
num_experts,
expert_weights,
expert_biases,
gate_weights: LSTMCell::random_matrix(num_experts, input_dim, std_dev),
gate_bias: Array1::zeros(num_experts),
}
}
pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
Ok(input.clone())
}
pub fn compute_gates(&self, input: &Array1<F>) -> Array1<F> {
Array1::ones(self.num_experts)
/ F::from(self.num_experts).expect("Failed to convert to float")
}
}