use std::any::Any;
use std::cell::RefCell;
use crate::model::layer::{Layer, Router};
use crate::model::sequential::{Model, Sequential};
use crate::object::Tensor;
pub mod backward;
pub mod builder;
pub mod diagnose_ref;
pub mod forward;
pub mod topology;
pub use topology::{IN_DIM, N_EXPERTS, OUT_DIM, TOP_K, param_count, params_for_size};
/// Size presets for the mixture-of-experts model.
///
/// Derives serde `Serialize`/`Deserialize`; the per-preset numbers are
/// returned by [`MoESize::dims`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum MoESize {
    /// `dims() == (128, 2)`.
    Nano,
    /// `dims() == (2048, 5)`.
    Tiny,
    /// `dims() == (4096, 5)`.
    Medium,
    /// `dims() == (4096, 12)`.
    Full,
}

impl MoESize {
    /// Returns the pair of size numbers configured for this preset.
    ///
    /// NOTE(review): presumably `(hidden width, number of hidden layers)` per
    /// expert (cf. `ExpertConfig::hidden` / `ExpertConfig::n_hidden_layers`) —
    /// confirm against `builder` / `topology`.
    pub fn dims(self) -> (usize, usize) {
        // First component: only Nano and Tiny differ from the 4096 baseline.
        let width = match self {
            Self::Nano => 128,
            Self::Tiny => 2048,
            Self::Medium | Self::Full => 4096,
        };
        // Second component: Tiny and Medium share the same value.
        let depth = match self {
            Self::Nano => 2,
            Self::Tiny | Self::Medium => 5,
            Self::Full => 12,
        };
        (width, depth)
    }

    /// Human-readable label for this preset; identical to the variant name.
    pub fn name(self) -> &'static str {
        match self {
            Self::Nano => "Nano",
            Self::Tiny => "Tiny",
            Self::Medium => "Medium",
            Self::Full => "Full",
        }
    }

    /// Parameter count for this preset, delegated to `topology::param_count`.
    ///
    /// NOTE(review): presumably the total number of scalar parameters — confirm
    /// in `topology`.
    pub fn param_count(self) -> usize {
        topology::param_count(self)
    }
}
/// Configuration numbers for a router.
///
/// Plain `Copy` data: no invariant (such as `top_k <= n_experts`) is checked here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RouterConfig {
    /// Input feature width seen by the router.
    pub in_features: usize,
    /// Number of experts the router chooses between.
    pub n_experts: usize,
    /// Number of experts selected per routing decision (top-k).
    pub top_k: usize,
}
/// Layer dimensions for a single expert network.
///
/// NOTE(review): presumably an MLP shaped
/// `in_features -> hidden (x n_hidden_layers) -> out_features`; confirm in `builder`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ExpertConfig {
    /// Input feature width.
    pub in_features: usize,
    /// Width of each hidden layer.
    pub hidden: usize,
    /// Number of hidden layers.
    pub n_hidden_layers: usize,
    /// Output feature width.
    pub out_features: usize,
}
/// Output bundle of the mixture-of-experts model: logits together with the
/// router weights reported alongside them.
///
/// NOTE(review): tensor shapes are not fixed by this type — presumably logits
/// are `(batch, OUT_DIM)` and router weights `(batch, N_EXPERTS)`; confirm in
/// the `forward` module.
#[derive(Clone, Debug)]
pub struct MoEOutput {
    /// Model logits.
    pub logits: Tensor<f32>,
    /// Router weights accompanying the logits.
    pub router_weights: Tensor<f32>,
}
/// Intermediate values from a forward pass; this is the type held by
/// `MoEModel::last_cache`.
///
/// NOTE(review): presumably populated by the `forward` module and read back by
/// `backward` — confirm.
pub struct MoEForwardCache {
/// Input of the cached pass (presumably the tensor passed to forward).
pub input: Tensor<f32>,
/// Router-specific cache, type-erased behind `dyn Any + Send`; readers must
/// downcast it to whatever concrete type was stored.
pub router_cache: Box<dyn Any + Send>,
/// Router weights from the cached pass.
pub router_weights: Tensor<f32>,
/// Per-expert outputs. NOTE(review): presumably indexed like
/// `MoEModel::experts` — confirm.
pub expert_outputs: Vec<Tensor<f32>>,
}
/// Mixture-of-experts model: one [`Router`] plus a bank of [`Sequential`] experts.
///
/// `last_cache` sits behind a [`RefCell`], so it can be replaced through `&self`;
/// as a consequence `MoEModel` is not `Sync`.
pub struct MoEModel {
    /// Size preset associated with this model.
    pub size: MoESize,
    /// Router network.
    pub router: Router,
    /// Expert networks, addressed by index via [`MoEModel::expert`] /
    /// [`MoEModel::expert_mut`].
    pub experts: Vec<Sequential>,
    /// Most recently stored forward cache, if any.
    pub last_cache: RefCell<Option<MoEForwardCache>>,
}

impl MoEModel {
    /// Number of parameter entries, not scalar values, across the router and every expert.
    ///
    /// Each `parameters()` list contributes its length. For the total number of
    /// scalars, see [`MoEModel::scalar_param_count`].
    pub fn parameter_count(&self) -> usize {
        let router_entries = self.router.parameters().len();
        // Seed with the router so additions happen in router-then-experts order.
        self.experts
            .iter()
            .fold(router_entries, |total, expert| total + expert.parameters().len())
    }

    /// Total number of scalar values (sum of `numel()`) over every parameter of
    /// the router and all experts.
    pub fn scalar_param_count(&self) -> usize {
        let router_scalars = self
            .router
            .parameters()
            .into_iter()
            .fold(0usize, |total, param| total + param.numel());
        self.experts
            .iter()
            .flat_map(|expert| expert.parameters())
            .fold(router_scalars, |total, param| total + param.numel())
    }

    /// Shared borrow of the expert at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.experts.len()`.
    pub fn expert(&self, idx: usize) -> &Sequential {
        &self.experts[idx]
    }

    /// Mutable borrow of the expert at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.experts.len()`.
    pub fn expert_mut(&mut self, idx: usize) -> &mut Sequential {
        &mut self.experts[idx]
    }
}