use serde::{Deserialize, Serialize};
use crate::moe_model::{MoEModel, MoESize};
use crate::{Error, Result};
pub use super::dense::{DenseConfig, DenseModel, QualityModel};
/// Which family of network a [`ModelArch`] describes.
///
/// Missing from older serialized descriptors is tolerated: the `model_kind`
/// field on `ModelArch` is `#[serde(default)]` and the `Default` impl for
/// this enum yields `ModelKind::MoE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelKind {
    /// Mixture-of-experts model; `ModelArch::build` turns it into a `MoEModel`.
    MoE,
    /// Dense model; only `ModelArch::build_quality_model` can construct it
    /// (as a `DenseModel`) — `ModelArch::build` rejects this kind.
    Dense,
}
impl Default for ModelKind {
fn default() -> Self {
ModelKind::MoE
}
}
/// Routing strategy recorded in a [`ModelArch`].
///
/// NOTE(review): variant meanings are inferred from names only (presumably
/// sheaf + p-adic routing, sheaf-only routing, and plain softmax gating) —
/// confirm against the router implementation. Within this file the value is
/// stored and reported in error messages but is not passed to
/// `MoEModel::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RouterKind {
    SheafPadic,
    SheafOnly,
    SoftmaxOnly,
}
impl Default for RouterKind {
fn default() -> Self {
RouterKind::SheafPadic
}
}
/// Serializable description of a model's architecture, sufficient to
/// rebuild the model via `ModelArch::build` / `ModelArch::build_quality_model`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelArch {
    /// MoE size preset. For MoE models `size.dims()` must agree with
    /// `hidden_dim` / `expert_depth`; for dense models it is a placeholder.
    pub size: MoESize,
    /// Input feature width.
    pub input_dim: usize,
    /// Hidden width.
    pub hidden_dim: usize,
    /// Output width.
    pub output_dim: usize,
    /// Number of experts; must be 0 for `ModelKind::Dense`.
    pub n_experts: usize,
    /// Experts selected per input; must be 0 for `ModelKind::Dense`.
    pub top_k: usize,
    /// Depth of each expert for MoE models; number of blocks (`n_blocks`)
    /// for dense models.
    pub expert_depth: usize,
    /// Seed forwarded to the model constructor.
    pub seed: u64,
    /// Model family; absent in older serialized data, where it defaults to
    /// `ModelKind::MoE`.
    #[serde(default)]
    pub model_kind: ModelKind,
    /// Router strategy; absent in older serialized data, where it defaults
    /// to `RouterKind::SheafPadic`.
    #[serde(default)]
    pub router_kind: RouterKind,
}
impl Default for ModelArch {
fn default() -> Self {
Self::from_size(MoESize::Tiny, 0)
}
}
impl ModelArch {
pub fn from_moe_model(model: &MoEModel) -> Self {
let (hidden, depth) = model.size.dims();
Self {
size: model.size,
input_dim: crate::moe_model::topology::IN_DIM,
hidden_dim: hidden,
output_dim: crate::moe_model::topology::OUT_DIM,
n_experts: crate::moe_model::topology::N_EXPERTS,
top_k: crate::moe_model::topology::TOP_K,
expert_depth: depth,
seed: 0,
model_kind: ModelKind::MoE,
router_kind: RouterKind::SheafPadic,
}
}
pub fn from_size(size: MoESize, seed: u64) -> Self {
let (hidden, depth) = size.dims();
Self {
size,
input_dim: crate::moe_model::topology::IN_DIM,
hidden_dim: hidden,
output_dim: crate::moe_model::topology::OUT_DIM,
n_experts: crate::moe_model::topology::N_EXPERTS,
top_k: crate::moe_model::topology::TOP_K,
expert_depth: depth,
seed,
model_kind: ModelKind::MoE,
router_kind: RouterKind::SheafPadic,
}
}
pub fn from_dense(cfg: &DenseConfig, seed: u64) -> Self {
Self {
size: MoESize::Tiny,
input_dim: cfg.input_dim,
hidden_dim: cfg.hidden_dim,
output_dim: cfg.output_dim,
n_experts: 0,
top_k: 0,
expert_depth: cfg.n_blocks,
seed,
model_kind: ModelKind::Dense,
router_kind: RouterKind::SheafPadic,
}
}
pub fn build(&self) -> Result<MoEModel> {
if self.model_kind == ModelKind::Dense {
return Err(Error::backend(format!(
"arch: build() refuses ModelKind::Dense; use build_quality_model() \
to construct the matching DenseModel (size={:?}, \
model_kind={:?}, router_kind={:?})",
self.size, self.model_kind, self.router_kind
)));
}
let (hidden, depth) = self.size.dims();
if self.hidden_dim != hidden {
return Err(Error::backend(format!(
"arch: hidden_dim {} does not match size {:?} dims {}",
self.hidden_dim, self.size, hidden
)));
}
if self.expert_depth != depth {
return Err(Error::backend(format!(
"arch: expert_depth {} does not match size {:?} dims {}",
self.expert_depth, self.size, depth
)));
}
if self.input_dim != crate::moe_model::topology::IN_DIM {
return Err(Error::backend(format!(
"arch: input_dim {} does not match topology IN_DIM {}",
self.input_dim,
crate::moe_model::topology::IN_DIM
)));
}
if self.output_dim != crate::moe_model::topology::OUT_DIM {
return Err(Error::backend(format!(
"arch: output_dim {} does not match topology OUT_DIM {}",
self.output_dim,
crate::moe_model::topology::OUT_DIM
)));
}
if self.n_experts != crate::moe_model::topology::N_EXPERTS {
return Err(Error::backend(format!(
"arch: n_experts {} does not match topology N_EXPERTS {}",
self.n_experts,
crate::moe_model::topology::N_EXPERTS
)));
}
if self.top_k != crate::moe_model::topology::TOP_K {
return Err(Error::backend(format!(
"arch: top_k {} does not match topology TOP_K {}",
self.top_k,
crate::moe_model::topology::TOP_K
)));
}
Ok(MoEModel::new(self.size, self.seed))
}
pub fn build_quality_model(&self) -> Result<QualityModel> {
match self.model_kind {
ModelKind::MoE => Ok(QualityModel::MoE(self.build()?)),
ModelKind::Dense => {
if self.n_experts != 0 {
return Err(Error::backend(format!(
"arch: ModelKind::Dense requires n_experts=0, got {}",
self.n_experts
)));
}
if self.top_k != 0 {
return Err(Error::backend(format!(
"arch: ModelKind::Dense requires top_k=0, got {}",
self.top_k
)));
}
let cfg = DenseConfig {
input_dim: self.input_dim,
hidden_dim: self.hidden_dim,
intermediate: 4096,
output_dim: self.output_dim,
n_blocks: self.expert_depth,
};
Ok(QualityModel::Dense(DenseModel::new(cfg, self.seed)))
}
}
}
}