use crate::error::{LmError, LmResult};
use crate::weights::WeightTensor;
/// Gaussian Error Linear Unit, computed with the tanh approximation
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
pub fn gelu(x: f32) -> f32 {
    let sqrt_two_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
    let cubic_term = x + 0.044_715 * x * x * x;
    let squashed = (sqrt_two_over_pi * cubic_term).tanh();
    0.5 * x * (1.0 + squashed)
}
/// Sigmoid Linear Unit (swish): `x / (1 + e^(-x))`, i.e. `x * sigmoid(x)`.
#[inline]
pub fn silu(x: f32) -> f32 {
    let sigmoid_denominator = 1.0 + (-x).exp();
    x / sigmoid_denominator
}
/// Applies a dense layer to one input vector: `out[i] = dot(row_i(w), x) (+ b[i])`.
///
/// `w.shape` is read as `[out_dim, in_dim]` and `w.data` as a row-major matrix, so row `i`
/// occupies `w.data[i * in_dim..(i + 1) * in_dim]`. When `b` is `Some`, it is added
/// element-wise to the result.
///
/// # Errors
///
/// Returns [`LmError::DimensionMismatch`] when
/// - `w.shape` does not have exactly two entries,
/// - `x.len()` differs from `in_dim`,
/// - a bias is supplied whose length differs from `out_dim`, or
/// - `w.data` holds fewer than `out_dim * in_dim` values.
fn linear_vec(w: &WeightTensor, b: Option<&[f32]>, x: &[f32]) -> LmResult<Vec<f32>> {
    if w.shape.len() != 2 {
        return Err(LmError::DimensionMismatch {
            expected: 2,
            got: w.shape.len(),
        });
    }
    let (out_dim, in_dim) = (w.shape[0], w.shape[1]);
    if x.len() != in_dim {
        return Err(LmError::DimensionMismatch {
            expected: in_dim,
            got: x.len(),
        });
    }
    // Reject a wrongly sized bias up front instead of after the whole product was computed.
    if let Some(bias) = b {
        if bias.len() != out_dim {
            return Err(LmError::DimensionMismatch {
                expected: out_dim,
                got: bias.len(),
            });
        }
    }
    // A saturated product can never be satisfied by a real buffer, so an overflowing shape
    // is reported as a mismatch rather than wrapping or panicking in the row slicing below.
    let required = out_dim.saturating_mul(in_dim);
    if w.data.len() < required {
        return Err(LmError::DimensionMismatch {
            expected: required,
            got: w.data.len(),
        });
    }

    let mut out: Vec<f32> = (0..out_dim)
        .map(|i| {
            w.data[i * in_dim..(i + 1) * in_dim]
                .iter()
                .zip(x)
                .map(|(&wi, &xi)| wi * xi)
                .sum::<f32>()
        })
        .collect();
    if let Some(bias) = b {
        for (o, &bi) in out.iter_mut().zip(bias) {
            *o += bi;
        }
    }
    Ok(out)
}
/// Applies a dense layer to `n_tokens` input rows stored back to back in `x`.
///
/// `w.shape` is read as `[out_dim, in_dim]` and `w.data` as a row-major matrix. Token `t`
/// reads `x[t * in_dim..(t + 1) * in_dim]` and writes
/// `out[t * out_dim..(t + 1) * out_dim]`; when `b` is `Some`, it is added to every output row.
///
/// A single token (`n_tokens == 1`) is forwarded to `linear_vec`, which performs its own
/// validation. With `n_tokens == 0` and an empty `x` the result is an empty vector.
///
/// # Errors
///
/// On the multi-token path, returns [`LmError::DimensionMismatch`] when
/// - `w.shape` does not have exactly two entries,
/// - `x.len()` differs from `n_tokens * in_dim`,
/// - a bias is supplied whose length differs from `out_dim`, or
/// - `w.data` holds fewer than `out_dim * in_dim` values.
pub(crate) fn linear_batch(
    w: &WeightTensor,
    b: Option<&[f32]>,
    x: &[f32],
    n_tokens: usize,
) -> LmResult<Vec<f32>> {
    if n_tokens == 1 {
        return linear_vec(w, b, x);
    }
    if w.shape.len() != 2 {
        return Err(LmError::DimensionMismatch {
            expected: 2,
            got: w.shape.len(),
        });
    }
    let (out_dim, in_dim) = (w.shape[0], w.shape[1]);

    // Saturating keeps an overflowing product from wrapping around to a length that
    // happens to match; a saturated value can never equal a real slice length.
    let expected_input = n_tokens.saturating_mul(in_dim);
    if x.len() != expected_input {
        return Err(LmError::DimensionMismatch {
            expected: expected_input,
            got: x.len(),
        });
    }
    // `zip` below would silently ignore a bias of the wrong length, so check it explicitly
    // (the single-token path performs the same check).
    if let Some(bias) = b {
        if bias.len() != out_dim {
            return Err(LmError::DimensionMismatch {
                expected: out_dim,
                got: bias.len(),
            });
        }
    }
    // Guard the row slicing of `w.data` so a truncated tensor yields an error, not a panic.
    let required = out_dim.saturating_mul(in_dim);
    if w.data.len() < required {
        return Err(LmError::DimensionMismatch {
            expected: required,
            got: w.data.len(),
        });
    }

    let mut out = vec![0.0_f32; n_tokens * out_dim];
    for t in 0..n_tokens {
        let x_row = &x[t * in_dim..(t + 1) * in_dim];
        let o_row = &mut out[t * out_dim..(t + 1) * out_dim];
        for (i, o) in o_row.iter_mut().enumerate() {
            let w_row = &w.data[i * in_dim..(i + 1) * in_dim];
            *o = w_row.iter().zip(x_row).map(|(&wi, &xi)| wi * xi).sum();
        }
        if let Some(bias) = b {
            for (o, &bi) in o_row.iter_mut().zip(bias) {
                *o += bi;
            }
        }
    }
    Ok(out)
}
#[derive(Debug, Clone)]
pub struct MlpFfn {
pub hidden_dim: usize,
pub intermediate_dim: usize,
pub w_fc: WeightTensor,
pub b_fc: Vec<f32>,
pub w_proj: WeightTensor,
pub b_proj: Vec<f32>,
}
impl MlpFfn {
pub fn new(hidden_dim: usize, intermediate_dim: usize) -> LmResult<Self> {
if hidden_dim == 0 || intermediate_dim == 0 {
return Err(LmError::InvalidConfig {
msg: "MlpFfn: dimensions must be > 0".into(),
});
}
Ok(Self {
hidden_dim,
intermediate_dim,
w_fc: WeightTensor::zeros(&[intermediate_dim, hidden_dim]),
b_fc: vec![0.0; intermediate_dim],
w_proj: WeightTensor::zeros(&[hidden_dim, intermediate_dim]),
b_proj: vec![0.0; hidden_dim],
})
}
pub fn forward(&self, x: &[f32], n_tokens: usize) -> LmResult<Vec<f32>> {
if n_tokens == 0 {
return Err(LmError::EmptyInput {
context: "MlpFfn::forward n_tokens",
});
}
let mut h = linear_batch(&self.w_fc, Some(&self.b_fc), x, n_tokens)?;
for v in &mut h {
*v = gelu(*v);
}
linear_batch(&self.w_proj, Some(&self.b_proj), &h, n_tokens)
}
}
/// Bias-free gated feed-forward block: `w_down · (silu(w_gate · x) * (w_up · x))`,
/// where `*` is the element-wise product.
#[derive(Debug, Clone)]
pub struct SwiGluFfn {
    /// Model width passed to [`SwiGluFfn::new`]; `forward` takes sizes from the weights.
    pub hidden_dim: usize,
    /// Inner width passed to [`SwiGluFfn::new`]; `forward` takes sizes from the weights.
    pub intermediate_dim: usize,
    /// Gate projection; created with shape `[intermediate_dim, hidden_dim]`.
    pub w_gate: WeightTensor,
    /// Up projection; created with shape `[intermediate_dim, hidden_dim]`.
    pub w_up: WeightTensor,
    /// Down projection; created with shape `[hidden_dim, intermediate_dim]`.
    pub w_down: WeightTensor,
}

impl SwiGluFfn {
    /// Builds a block whose three weight tensors are all zero.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidConfig`] when either dimension is zero.
    pub fn new(hidden_dim: usize, intermediate_dim: usize) -> LmResult<Self> {
        if hidden_dim == 0 || intermediate_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "SwiGluFfn: dimensions must be > 0".into(),
            });
        }
        Ok(Self {
            hidden_dim,
            intermediate_dim,
            w_gate: WeightTensor::zeros(&[intermediate_dim, hidden_dim]),
            w_up: WeightTensor::zeros(&[intermediate_dim, hidden_dim]),
            w_down: WeightTensor::zeros(&[hidden_dim, intermediate_dim]),
        })
    }

    /// Runs the block over `n_tokens` rows stored back to back in `x`.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::EmptyInput`] when `n_tokens` is zero,
    /// [`LmError::DimensionMismatch`] when the gate and up projections produce outputs of
    /// different lengths, and propagates any error reported by `linear_batch`.
    pub fn forward(&self, x: &[f32], n_tokens: usize) -> LmResult<Vec<f32>> {
        if n_tokens == 0 {
            return Err(LmError::EmptyInput {
                context: "SwiGluFfn::forward n_tokens",
            });
        }
        let mut gated = linear_batch(&self.w_gate, None, x, n_tokens)?;
        let up = linear_batch(&self.w_up, None, x, n_tokens)?;
        // The weights are public and replaceable, so the two projections are not guaranteed
        // to agree; `zip` would otherwise silently truncate to the shorter one.
        if gated.len() != up.len() {
            return Err(LmError::DimensionMismatch {
                expected: gated.len(),
                got: up.len(),
            });
        }
        // Gate in place: the buffer size comes from the actual projections rather than the
        // cached `intermediate_dim` field, and no extra allocation is needed.
        for (g, &u) in gated.iter_mut().zip(&up) {
            *g = silu(*g) * u;
        }
        linear_batch(&self.w_down, None, &gated, n_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    // --- activations -------------------------------------------------------

    #[test]
    fn gelu_at_zero() {
        assert!(close(gelu(0.0), 0.0, 1e-6));
    }

    #[test]
    fn gelu_positive_input() {
        // Far to the right GELU behaves like the identity.
        assert!(close(gelu(10.0), 10.0, 1e-3));
    }

    #[test]
    fn gelu_negative_input() {
        // Far to the left GELU flattens out at zero.
        assert!(close(gelu(-10.0), 0.0, 1e-3));
    }

    #[test]
    fn silu_at_zero() {
        assert!(close(silu(0.0), 0.0, 1e-6));
    }

    #[test]
    fn silu_positive() {
        assert!(close(silu(10.0), 10.0, 1e-3));
    }

    #[test]
    fn silu_gradient_through_zero() {
        let (right, left) = (silu(1.0), silu(-1.0));
        assert!(right > 0.0);
        assert!(left < 0.0);
    }

    // --- linear_vec --------------------------------------------------------

    #[test]
    fn linear_vec_identity() {
        let identity = WeightTensor::eye(2, 2);
        let out = linear_vec(&identity, None, &[3.0, 4.0]).unwrap();
        assert!(close(out[0], 3.0, 1e-6));
        assert!(close(out[1], 4.0, 1e-6));
    }

    #[test]
    fn linear_vec_with_bias() {
        let identity = WeightTensor::eye(2, 2);
        let bias = vec![10.0_f32, 20.0];
        let out = linear_vec(&identity, Some(&bias), &[1.0, 2.0]).unwrap();
        assert!(close(out[0], 11.0, 1e-6));
        assert!(close(out[1], 22.0, 1e-6));
    }

    // --- MlpFfn ------------------------------------------------------------

    #[test]
    fn mlp_ffn_zero_weights_zero_output() {
        let ffn = MlpFfn::new(4, 8).unwrap();
        let input = vec![1.0_f32; 4];
        let out = ffn.forward(&input, 1).unwrap();
        assert!(out.iter().all(|&v| v.abs() < 1e-6));
    }

    #[test]
    fn mlp_ffn_identity_chain() {
        let mut ffn = MlpFfn::new(4, 4).unwrap();
        ffn.w_fc = WeightTensor::eye(4, 4);
        ffn.w_proj = WeightTensor::eye(4, 4);
        let input = vec![1.0_f32; 4];
        let out = ffn.forward(&input, 1).unwrap();
        // With identity weights and zero biases only the activation remains.
        let expected = gelu(1.0);
        for v in out.iter().copied() {
            assert!(close(v, expected, 1e-5), "v={v} expected={expected}");
        }
    }

    #[test]
    fn mlp_ffn_batch_tokens() {
        let ffn = MlpFfn::new(4, 8).unwrap();
        let input = vec![0.0_f32; 2 * 4];
        let out = ffn.forward(&input, 2).unwrap();
        assert_eq!(out.len(), 2 * 4);
    }

    #[test]
    fn mlp_ffn_zero_dim_error() {
        assert!(MlpFfn::new(0, 8).is_err());
    }

    // --- SwiGluFfn ---------------------------------------------------------

    #[test]
    fn swiglu_ffn_zero_weights_zero_output() {
        let ffn = SwiGluFfn::new(4, 8).unwrap();
        let input = vec![1.0_f32; 4];
        let out = ffn.forward(&input, 1).unwrap();
        assert!(out.iter().all(|&v| v.abs() < 1e-6));
    }

    #[test]
    fn swiglu_ffn_gate_identity() {
        let mut ffn = SwiGluFfn::new(4, 4).unwrap();
        ffn.w_gate = WeightTensor::eye(4, 4);
        ffn.w_up = WeightTensor::eye(4, 4);
        ffn.w_down = WeightTensor::eye(4, 4);
        let input = vec![2.0_f32; 4];
        let out = ffn.forward(&input, 1).unwrap();
        // Identity weights reduce the block to silu(x) * x per element.
        let expected = silu(2.0) * 2.0;
        for v in out.iter().copied() {
            assert!(close(v, expected, 1e-5), "v={v} expected={expected}");
        }
    }

    #[test]
    fn swiglu_ffn_batch_tokens() {
        let ffn = SwiGluFfn::new(4, 8).unwrap();
        let input = vec![0.0_f32; 3 * 4];
        let out = ffn.forward(&input, 3).unwrap();
        assert_eq!(out.len(), 3 * 4);
    }

    #[test]
    fn swiglu_ffn_zero_dim_error() {
        assert!(SwiGluFfn::new(4, 0).is_err());
    }
}