use std::cell::RefCell;
use crate::model::layer::{GELU, Layer, LayerNorm, Linear};
use crate::model::sequential::{Model, Sequential};
use crate::moe_model::MoEModel;
use crate::object::{Shape, Tensor};
use crate::{Error, Result};
/// State recorded by a successful `DenseModel::forward` call and consumed
/// (via `Option::take`) by the next `DenseModel::backward` call.
pub struct DenseForwardCache {
/// Clone of the input tensor passed to the forward pass that produced this cache.
/// NOTE(review): `DenseModel::backward` only checks that a cache is present and
/// never reads this field — confirm whether storing a full input clone is needed.
pub input: Tensor<f32>,
}
/// Architecture hyper-parameters for [`DenseModel`].
///
/// Serializable with serde; field order is part of the serialized layout for
/// order-sensitive formats, so do not reorder fields casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DenseConfig {
/// Input feature width consumed by the input projection's `Linear` layer.
pub input_dim: usize,
/// Width of the representation between blocks (output of the input
/// projection and of every hidden block).
pub hidden_dim: usize,
/// Inner width of each hidden block (`hidden_dim -> intermediate -> hidden_dim`).
pub intermediate: usize,
/// Width of the logits produced by the output projection.
pub output_dim: usize,
/// Number of hidden blocks; `0` connects the input projection directly to
/// the output projection.
pub n_blocks: usize,
}
/// A stack of fully-connected blocks: input projection, `config.n_blocks`
/// hidden blocks, then an output projection (see `DenseModel::new`).
pub struct DenseModel {
/// Configuration the model was built from.
pub config: DenseConfig,
/// `Linear(input_dim -> hidden_dim)` followed by `LayerNorm(hidden_dim)`.
pub input_proj: Sequential,
/// Hidden blocks, applied in order during `forward`.
pub blocks: Vec<Sequential>,
/// `Linear(hidden_dim -> output_dim)`.
pub output_proj: Sequential,
/// Interior-mutable slot so `forward(&self)` can record state that
/// `backward(&self)` later takes; `None` until a forward pass succeeds.
last_cache: RefCell<Option<DenseForwardCache>>,
}
impl DenseModel {
pub fn new(cfg: DenseConfig, seed: u64) -> Self {
let body_seed = ((seed >> 32) as u32).wrapping_add(0xB0B0_0001);
let proj_seed = (seed & 0xFFFF_FFFF) as u32;
let mut input_layers: Vec<Box<dyn Layer>> = Vec::new();
input_layers.push(Box::new(Linear::new(
cfg.input_dim,
cfg.hidden_dim,
proj_seed.wrapping_add(0x10),
)));
input_layers.push(Box::new(LayerNorm::new(cfg.hidden_dim, 1e-5)));
let input_proj = Sequential::new(input_layers);
let mut blocks: Vec<Sequential> = Vec::with_capacity(cfg.n_blocks);
for bi in 0..cfg.n_blocks {
let block_seed = body_seed.wrapping_add((bi as u32).wrapping_mul(0x1000));
let mut layers: Vec<Box<dyn Layer>> = Vec::new();
layers.push(Box::new(Linear::new(
cfg.hidden_dim,
cfg.intermediate,
block_seed.wrapping_add(0x20),
)));
layers.push(Box::new(LayerNorm::new(cfg.intermediate, 1e-5)));
layers.push(Box::new(GELU));
layers.push(Box::new(Linear::new(
cfg.intermediate,
cfg.hidden_dim,
block_seed.wrapping_add(0x30),
)));
layers.push(Box::new(LayerNorm::new(cfg.hidden_dim, 1e-5)));
layers.push(Box::new(GELU));
blocks.push(Sequential::new(layers));
}
let mut output_layers: Vec<Box<dyn Layer>> = Vec::new();
output_layers.push(Box::new(Linear::new(
cfg.hidden_dim,
cfg.output_dim,
proj_seed.wrapping_add(0x40),
)));
let output_proj = Sequential::new(output_layers);
Self {
config: cfg,
input_proj,
blocks,
output_proj,
last_cache: RefCell::new(None),
}
}
pub fn scalar_param_count(&self) -> usize {
let mut n = 0usize;
for p in self.input_proj.parameters() {
n += p.numel();
}
for block in &self.blocks {
for p in block.parameters() {
n += p.numel();
}
}
for p in self.output_proj.parameters() {
n += p.numel();
}
n
}
pub fn parameter_count(&self) -> usize {
let mut n = self.input_proj.parameters().len();
for block in &self.blocks {
n += block.parameters().len();
}
n += self.output_proj.parameters().len();
n
}
pub fn block_mut(&mut self, idx: usize) -> &mut Sequential {
&mut self.blocks[idx]
}
pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let current = self
.input_proj
.forward(&[input.clone()])
.map_err(|e| Error::backend(format!("DenseModel::forward input_proj: {e}")))?;
if current.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::forward input_proj returned {} tensors, expected 1",
current.len()
)));
}
let mut current = current.into_iter().next().expect("len=1");
for (bi, block) in self.blocks.iter().enumerate() {
let outs = block
.forward(&[current.clone()])
.map_err(|e| Error::backend(format!("DenseModel::forward block {}: {e}", bi)))?;
if outs.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::forward block {} returned {} tensors, expected 1",
bi,
outs.len()
)));
}
current = outs.into_iter().next().expect("len=1");
}
let logits_vec = self
.output_proj
.forward(&[current])
.map_err(|e| Error::backend(format!("DenseModel::forward output_proj: {e}")))?;
if logits_vec.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::forward output_proj returned {} tensors, expected 1",
logits_vec.len()
)));
}
let logits = logits_vec.into_iter().next().expect("len=1");
*self.last_cache.borrow_mut() = Some(DenseForwardCache {
input: input.clone(),
});
Ok(logits)
}
pub fn backward(&self, grad_output: &Tensor<f32>) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
let cache = self
.last_cache
.borrow_mut()
.take()
.ok_or_else(|| Error::backend("DenseModel::backward called before forward"))?;
let (per_input_grads, mut param_grads) = self
.output_proj
.backward(&[grad_output.clone()])
.map_err(|e| Error::backend(format!("DenseModel::backward output_proj: {e}")))?;
if per_input_grads.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::backward output_proj returned {} input grads, expected 1",
per_input_grads.len()
)));
}
let mut current_grad = per_input_grads.into_iter().next().expect("len=1");
for bi in (0..self.blocks.len()).rev() {
let (g, p) = self.blocks[bi]
.backward(&[current_grad])
.map_err(|e| Error::backend(format!("DenseModel::backward block {}: {e}", bi)))?;
if g.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::backward block {} returned {} input grads, expected 1",
bi,
g.len()
)));
}
current_grad = g.into_iter().next().expect("len=1");
param_grads.extend(p);
}
let (g, p) = self
.input_proj
.backward(&[current_grad])
.map_err(|e| Error::backend(format!("DenseModel::backward input_proj: {e}")))?;
if g.len() != 1 {
return Err(Error::backend(format!(
"DenseModel::backward input_proj returned {} input grads, expected 1",
g.len()
)));
}
let grad_input = g.into_iter().next().expect("len=1");
param_grads.extend(p);
param_grads.reverse();
let _ = cache;
Ok((grad_input, param_grads))
}
}
/// Scalar parameter count of `dense`, plus that of `maybe_moe` when present.
fn total_scalar_params(maybe_moe: Option<&MoEModel>, dense: &DenseModel) -> usize {
    let dense_params = dense.scalar_param_count();
    dense_params + maybe_moe.map_or(0, |moe| moe.scalar_param_count())
}
/// Either a mixture-of-experts model or a dense model, behind one
/// forward/backward/parameter interface (see the `impl QualityModel` methods).
pub enum QualityModel {
/// Mixture-of-experts model; routing weights come from the model itself.
MoE(MoEModel),
/// Dense model; `forward` reports all-ones routing weights for it.
Dense(DenseModel),
}
impl QualityModel {
    /// Total number of scalar parameters in the wrapped model.
    pub fn scalar_param_count(&self) -> usize {
        match self {
            QualityModel::MoE(m) => m.scalar_param_count(),
            QualityModel::Dense(m) => m.scalar_param_count(),
        }
    }

    /// Number of parameter objects, as reported by the wrapped model.
    pub fn parameter_count(&self) -> usize {
        match self {
            QualityModel::MoE(m) => m.parameter_count(),
            QualityModel::Dense(m) => m.parameter_count(),
        }
    }

    /// Runs a forward pass and returns logits plus routing weights.
    ///
    /// For [`QualityModel::Dense`], which has no router, `router_weights` is a
    /// `(b, 1)` tensor of ones, where `b` is the logits' leading dimension.
    ///
    /// # Errors
    /// Propagates errors from the wrapped model. For the dense variant, also
    /// returns a shape error if the logits have no dimensions or a non-static
    /// leading dimension.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<QualityOutput> {
        match self {
            QualityModel::MoE(m) => {
                let out = m.forward(input)?;
                Ok(QualityOutput {
                    logits: out.logits,
                    router_weights: out.router_weights,
                })
            }
            QualityModel::Dense(m) => {
                let logits = m.forward(input)?;
                // `first()` rather than `[0]`: rank-0 logits become an error, not a panic.
                let b = match logits.meta.shape.dims.first() {
                    Some(crate::object::Dim::Static(v)) => *v,
                    Some(_) => {
                        return Err(Error::shape(
                            "QualityModel::Dense::forward logits batch dim must be static",
                        ));
                    }
                    None => {
                        return Err(Error::shape(
                            "QualityModel::Dense::forward logits must have a batch dim",
                        ));
                    }
                };
                let router_weights = Tensor::dense_cpu(
                    logits.meta.domain.clone(),
                    Shape::from(vec![b, 1]),
                    vec![1.0f32; b],
                );
                Ok(QualityOutput {
                    logits,
                    router_weights,
                })
            }
        }
    }

    /// Back-propagates `grad_output` through the wrapped model.
    ///
    /// Returns `(grad_input, param_grads)`, exactly as produced by the wrapped
    /// model's own `backward`.
    pub fn backward(&self, grad_output: &Tensor<f32>) -> Result<(Tensor<f32>, Vec<Tensor<f32>>)> {
        match self {
            QualityModel::MoE(m) => m.backward(grad_output),
            QualityModel::Dense(m) => m.backward(grad_output),
        }
    }

    /// Mutable references to the trainable parameters.
    ///
    /// Order: MoE yields router parameters, then each expert's in order; Dense
    /// yields `input_proj`, each block in order, then `output_proj`.
    /// NOTE(review): for MoE only `router` and `experts` are visited — confirm
    /// `MoEModel` has no other trainable sub-modules.
    pub fn all_parameters_mut(&mut self) -> Vec<&mut crate::model::parameter::Parameter> {
        match self {
            QualityModel::MoE(m) => {
                let mut out: Vec<&mut crate::model::parameter::Parameter> =
                    Vec::with_capacity(m.parameter_count());
                out.extend(m.router.parameters_mut());
                for expert in m.experts.iter_mut() {
                    out.extend(expert.parameters_mut());
                }
                out
            }
            QualityModel::Dense(m) => {
                let mut out: Vec<&mut crate::model::parameter::Parameter> =
                    Vec::with_capacity(m.parameter_count());
                out.extend(m.input_proj.parameters_mut());
                for block in m.blocks.iter_mut() {
                    out.extend(block.parameters_mut());
                }
                out.extend(m.output_proj.parameters_mut());
                out
            }
        }
    }

    /// Scalar parameter count of `dense` plus, if given, that of `maybe_moe`.
    // NOTE(review): `dead_code` is allowed without a recorded reason; presumably
    // this is unused inside the crate — confirm before removing either.
    #[allow(dead_code)]
    pub fn total_with(maybe_moe: Option<&MoEModel>, dense: &DenseModel) -> usize {
        total_scalar_params(maybe_moe, dense)
    }
}
/// Result of `QualityModel::forward`.
pub struct QualityOutput {
/// Model output logits.
pub logits: Tensor<f32>,
/// Routing weights: taken from the MoE model's output, or a `(b, 1)` tensor of
/// ones (b = leading logits dimension) for the dense variant.
pub router_weights: Tensor<f32>,
}