use crate::riir::backend::cpu::cpu_matvec::{project_4bit_cpu, CpuMatvecError};
use crate::riir::variants::VARIANT;
use crate::riir::io::weight_file::WeightFile;
/// Errors returned by the CPU SwiGLU feed-forward routines in this file.
#[derive(Debug, thiserror::Error)]
pub enum MlpForwardError {
/// The `hidden` input slice does not have exactly `VARIANT.hidden_dim` elements.
/// `got` is the slice length that was passed, `expected` is `VARIANT.hidden_dim`.
#[error("hidden buffer length {got} != hidden_dim ({expected})")]
HiddenLen { got: usize, expected: usize },
/// The `out` output slice does not have exactly `VARIANT.hidden_dim` elements.
/// `got` is the slice length that was passed, `expected` is `VARIANT.hidden_dim`.
#[error("output buffer length {got} != hidden_dim ({expected})")]
OutLen { got: usize, expected: usize },
/// A failure reported by `project_4bit_cpu`. The `#[from]` conversion is what lets
/// the `?` operator on each projection call turn a `CpuMatvecError` into this variant.
#[error("4-bit matvec error in MLP: {0}")]
Matvec(#[from] CpuMatvecError),
}
/// Runs the SwiGLU feed-forward block for the dense MLP of layer `layer_idx`.
///
/// Thin wrapper over [`swiglu_ffn_4bit_cpu`]: it selects the tensor-name prefix
/// `model.layers.{layer_idx}.mlp` and the intermediate width
/// `VARIANT.dense_intermediate`, then forwards `hidden` and `out` unchanged.
///
/// # Errors
///
/// Returns whatever [`swiglu_ffn_4bit_cpu`] returns: a length error when `hidden`
/// or `out` is not `VARIANT.hidden_dim` long, or a wrapped matvec error.
pub fn dense_mlp_swiglu_cpu(
    wf: &WeightFile,
    layer_idx: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), MlpForwardError> {
    let intermediate_width = VARIANT.dense_intermediate;
    let tensor_prefix = format!("model.layers.{layer_idx}.mlp");
    swiglu_ffn_4bit_cpu(wf, tensor_prefix.as_str(), intermediate_width, hidden, out)
}
/// Runs the SwiGLU feed-forward block for the shared experts of layer `layer_idx`.
///
/// Delegates to [`swiglu_ffn_4bit_cpu`] with the tensor-name prefix
/// `model.layers.{layer_idx}.mlp.shared_experts` and the intermediate width
/// `VARIANT.shared_intermediate`; `hidden` and `out` are passed through as-is.
///
/// # Errors
///
/// Propagates every error from [`swiglu_ffn_4bit_cpu`] (buffer-length mismatches
/// and wrapped matvec failures).
pub fn shared_expert_swiglu_cpu(
    wf: &WeightFile,
    layer_idx: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), MlpForwardError> {
    swiglu_ffn_4bit_cpu(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_experts"),
        VARIANT.shared_intermediate,
        hidden,
        out,
    )
}
/// Evaluates a SwiGLU feed-forward block on the CPU and writes the result to `out`.
///
/// The computation is `down_proj(silu(gate_proj(hidden)) * up_proj(hidden))`, where
/// each projection is a `project_4bit_cpu` call against the tensor named
/// `{prefix}.gate_proj`, `{prefix}.up_proj`, or `{prefix}.down_proj` in `wf`, and
/// `silu(g) = g / (1 + exp(-g))` is applied element-wise.
///
/// * `wf` - weight file handed to every projection call.
/// * `prefix` - tensor-name prefix shared by the three projections.
/// * `intermediate` - length of the gate/up activations (the FFN's inner width).
/// * `hidden` - input vector; must be exactly `VARIANT.hidden_dim` long.
/// * `out` - output vector; must be exactly `VARIANT.hidden_dim` long.
///
/// # Errors
///
/// * [`MlpForwardError::HiddenLen`] if `hidden.len() != VARIANT.hidden_dim`
///   (checked first).
/// * [`MlpForwardError::OutLen`] if `out.len() != VARIANT.hidden_dim`.
/// * [`MlpForwardError::Matvec`] if any projection fails; the gate projection runs
///   first, then up, then down, and the first failure is returned.
pub fn swiglu_ffn_4bit_cpu(
    wf: &WeightFile,
    prefix: &str,
    intermediate: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), MlpForwardError> {
    let model_dim = VARIANT.hidden_dim;

    // Reject mis-sized buffers before touching any weights.
    let hidden_len = hidden.len();
    if hidden_len != model_dim {
        return Err(MlpForwardError::HiddenLen {
            got: hidden_len,
            expected: model_dim,
        });
    }
    let out_len = out.len();
    if out_len != model_dim {
        return Err(MlpForwardError::OutLen {
            got: out_len,
            expected: model_dim,
        });
    }

    // `activated` starts as the gate projection and is overwritten in place with
    // silu(gate) * up, so it doubles as the input to the down projection.
    let mut activated = vec![0.0f32; intermediate];
    let mut up_branch = vec![0.0f32; intermediate];

    project_4bit_cpu(
        wf,
        &format!("{prefix}.gate_proj"),
        model_dim,
        intermediate,
        hidden,
        &mut activated,
    )?;
    project_4bit_cpu(
        wf,
        &format!("{prefix}.up_proj"),
        model_dim,
        intermediate,
        hidden,
        &mut up_branch,
    )?;

    // Both buffers are `intermediate` long, so the zip visits every element.
    for (slot, &up_val) in activated.iter_mut().zip(up_branch.iter()) {
        let g = *slot;
        let silu = g / (1.0 + (-g).exp());
        *slot = silu * up_val;
    }

    project_4bit_cpu(
        wf,
        &format!("{prefix}.down_proj"),
        intermediate,
        model_dim,
        &activated,
        out,
    )?;
    Ok(())
}
#[cfg(test)]
mod tests {
    /// Smoke test for the dense MLP of layer 0 against real Cogito-V2 weights.
    ///
    /// Feeds a one-hot hidden vector through `dense_mlp_swiglu_cpu` and checks that
    /// the output is finite, not all zero, and not absurdly large. Ignored by
    /// default because it needs the weight files at the hard-coded paths below.
    #[cfg(feature = "model-cogito-v2-671b")]
    #[test]
    #[ignore = "needs Cogito-V2 weights mmap'd from /Volumes/Temp Backup"]
    fn dense_mlp_layer0_smoke() {
        // Imported here rather than at module scope: this test is the module's only
        // consumer, so a module-level glob import becomes an `unused_imports`
        // warning whenever the feature is disabled.
        use super::*;
        use std::path::Path;

        let bin = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.bin",
        );
        let manifest = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.json",
        );
        let wf = WeightFile::open(bin, manifest).expect("open weights");

        let v = VARIANT;
        // One-hot input: a single non-zero element keeps the expected output
        // non-trivial without depending on any particular activation values.
        let mut hidden = vec![0.0f32; v.hidden_dim];
        hidden[42] = 1.0;
        let mut out = vec![0.0f32; v.hidden_dim];

        dense_mlp_swiglu_cpu(&wf, 0, &hidden, &mut out).expect("dense MLP layer 0");

        assert!(
            out.iter().all(|x| x.is_finite()),
            "non-finite value in dense MLP output"
        );
        let max_abs = out.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
        assert!(max_abs > 0.0, "all-zero output — likely a wiring bug");
        assert!(max_abs < 1e6, "magnitude {max_abs} suspiciously large");
    }
}