use crate::backend::{Backend, BackendResult};
use crate::tensor::{DType, Tensor};
use super::error::{ModelError, ModelResult};
/// A linear (fully-connected) projection: `out = x · weight (+ bias)`.
#[derive(Debug)]
pub struct Linear {
    /// 2-D weight tensor, laid out `[in_features, out_features]` (see `new`).
    pub weight: Tensor,
    /// Optional bias of shape `[out_features]`.
    pub bias: Option<Tensor>,
    /// Input dimension (first axis of `weight`).
    pub in_features: usize,
    /// Output dimension (second axis of `weight`).
    pub out_features: usize,
}
impl Linear {
    /// Builds a `Linear` from a 2-D weight (`[in_features, out_features]`)
    /// and an optional bias that, when present, must be `[out_features]`.
    ///
    /// # Errors
    /// `ModelError::ConfigError` when the weight is not 2-D;
    /// `ModelError::TensorShapeMismatch` when the bias shape is wrong.
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> ModelResult<Self> {
        if weight.ndim() != 2 {
            return Err(ModelError::ConfigError("Linear weight must be 2D".into()));
        }
        let in_features = weight.shape()[0];
        let out_features = weight.shape()[1];
        if let Some(ref b) = bias
            && b.shape() != [out_features]
        {
            return Err(ModelError::TensorShapeMismatch {
                name: "bias".into(),
                expected: vec![out_features],
                got: b.shape().to_vec(),
            });
        }
        Ok(Self {
            weight,
            bias,
            in_features,
            out_features,
        })
    }

    /// Projection followed by the bias add (when a bias exists).
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        // Previously this duplicated both the kernel dispatch and the bias
        // loop; delegate to the two primitives instead so there is a single
        // copy of each.
        self.forward_no_bias(x, out, backend)?;
        self.apply_bias(out, backend)
    }

    /// Projection only — picks the quantized kernel when the weight dtype is
    /// quantized, the f32 kernel otherwise.
    pub fn forward_no_bias(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        if self.weight.dtype().is_quantized() {
            backend.vec_mat_q(x, &self.weight, out)?;
        } else {
            backend.vec_mat(x, &self.weight, out)?;
        }
        Ok(())
    }

    /// Adds the bias element-wise into `out`; no-op when the layer is
    /// bias-free. The backend argument is currently unused (CPU-side add).
    pub fn apply_bias(&self, out: &mut Tensor, _backend: &dyn Backend) -> BackendResult<()> {
        if let Some(ref bias) = self.bias {
            let out_data = out.as_f32_mut()?;
            let bias_data = bias.as_f32()?;
            for (o, &b) in out_data.iter_mut().zip(bias_data.iter()) {
                *o += b;
            }
        }
        Ok(())
    }
}
/// Root-mean-square normalization with a learned 1-D scale.
#[derive(Debug)]
pub struct RMSNorm {
    /// 1-D scale vector; its length defines `hidden_size`.
    pub weight: Tensor,
    /// Numerical-stability epsilon added to the mean square.
    pub eps: f32,
    /// Length of `weight`.
    pub hidden_size: usize,
}
impl RMSNorm {
    /// Validates that `weight` is 1-D and records its length as the hidden
    /// size.
    ///
    /// # Errors
    /// `ModelError::ConfigError` when the weight is not 1-D.
    pub fn new(weight: Tensor, eps: f32) -> ModelResult<Self> {
        match weight.ndim() {
            1 => {
                let hidden_size = weight.shape()[0];
                Ok(Self {
                    weight,
                    eps,
                    hidden_size,
                })
            }
            _ => Err(ModelError::ConfigError("RMSNorm weight must be 1D".into())),
        }
    }

    /// Normalizes `x` into `out`; the whole computation is delegated to the
    /// backend's `rms_norm` kernel.
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        backend.rms_norm(x, &self.weight, self.eps, out)
    }
}
/// Classic layer normalization with learned scale and shift.
#[derive(Debug)]
pub struct LayerNorm {
    /// 1-D scale vector; its length defines `hidden_size`.
    pub weight: Tensor,
    /// 1-D shift vector (expected to match `weight` in length).
    pub bias: Tensor,
    /// Numerical-stability epsilon added to the variance.
    pub eps: f32,
    /// Length of `weight`.
    pub hidden_size: usize,
}
impl LayerNorm {
    /// Builds a `LayerNorm` from a 1-D weight and a matching 1-D bias.
    ///
    /// # Errors
    /// `ModelError::ConfigError` when the weight is not 1-D;
    /// `ModelError::TensorShapeMismatch` when the bias shape does not match
    /// the weight (consistent with the validation done in `Linear::new`).
    pub fn new(weight: Tensor, bias: Tensor, eps: f32) -> ModelResult<Self> {
        if weight.ndim() != 1 {
            return Err(ModelError::ConfigError("LayerNorm weight must be 1D".into()));
        }
        let hidden_size = weight.shape()[0];
        // Previously the bias was never validated; a mismatched bias would
        // only surface later as a panic (or silent truncation) in `forward`.
        if bias.shape() != [hidden_size] {
            return Err(ModelError::TensorShapeMismatch {
                name: "bias".into(),
                expected: vec![hidden_size],
                got: bias.shape().to_vec(),
            });
        }
        Ok(Self { weight, bias, eps, hidden_size })
    }

    /// CPU LayerNorm over the first `hidden_size` elements of `x`:
    /// `(x - mean) / sqrt(var + eps) * weight + bias`. The backend handle is
    /// unused — this path runs on the host.
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        _backend: &dyn Backend,
    ) -> BackendResult<()> {
        let x_data = x.as_f32()?;
        let out_data = out.as_f32_mut()?;
        let w_data = self.weight.as_f32()?;
        let b_data = self.bias.as_f32()?;
        let n = self.hidden_size;
        let mean: f32 = x_data[..n].iter().sum::<f32>() / n as f32;
        let var: f32 =
            x_data[..n].iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / n as f32;
        let std_inv = 1.0 / (var + self.eps).sqrt();
        for i in 0..n {
            out_data[i] = (x_data[i] - mean) * std_inv * w_data[i] + b_data[i];
        }
        Ok(())
    }
}
/// Either normalization flavor, so callers can hold one type per layer.
#[derive(Debug)]
pub enum NormLayer {
    /// RMS normalization (scale only).
    RMS(RMSNorm),
    /// Full layer normalization (scale and shift).
    Layer(LayerNorm),
}
impl NormLayer {
    /// Runs whichever normalization this wraps.
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        match self {
            Self::RMS(inner) => inner.forward(x, out, backend),
            Self::Layer(inner) => inner.forward(x, out, backend),
        }
    }

    /// Hidden dimension the wrapped norm was built for.
    pub fn hidden_size(&self) -> usize {
        match self {
            Self::RMS(inner) => inner.hidden_size,
            Self::Layer(inner) => inner.hidden_size,
        }
    }

    /// Borrow of the scale tensor, whichever variant this is.
    pub fn weight(&self) -> &Tensor {
        match self {
            Self::RMS(inner) => &inner.weight,
            Self::Layer(inner) => &inner.weight,
        }
    }

    /// Epsilon of the wrapped norm.
    pub fn eps(&self) -> f32 {
        match self {
            Self::RMS(inner) => inner.eps,
            Self::Layer(inner) => inner.eps,
        }
    }

    /// Downcast to the RMS variant.
    pub fn as_rms(&self) -> Option<&RMSNorm> {
        if let Self::RMS(inner) = self { Some(inner) } else { None }
    }

    /// Downcast to the layer-norm variant.
    pub fn as_layer_norm(&self) -> Option<&LayerNorm> {
        if let Self::Layer(inner) = self { Some(inner) } else { None }
    }

    /// True when this wraps a `LayerNorm`.
    pub fn is_layer_norm(&self) -> bool {
        self.as_layer_norm().is_some()
    }

    /// Borrow of the shift tensor; `None` for RMS (it has no bias).
    pub fn bias(&self) -> Option<&Tensor> {
        if let Self::Layer(inner) = self { Some(&inner.bias) } else { None }
    }
}
/// Multi-head attention block supporting grouped-query KV, optional per-head
/// QK RMS norms, optional partial RoPE, an optional attention-logit soft cap,
/// and an optional sigmoid gate fused into the Q projection.
#[derive(Debug)]
pub struct Attention {
    /// Query projection; when `has_attention_gate` is set its output also
    /// carries per-head gate values (split apart in `forward`).
    pub wq: Linear,
    /// Key projection.
    pub wk: Linear,
    /// Value projection.
    pub wv: Linear,
    /// Output projection back to the hidden size.
    pub wo: Linear,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (grouped-query when < `num_heads`).
    pub num_kv_heads: usize,
    /// Nominal per-head dimension.
    pub head_dim: usize,
    /// Per-head key (and query) length; may differ from `head_dim`.
    pub key_length: usize,
    /// Per-head value length; may differ from `head_dim`.
    pub value_length: usize,
    /// How many dims of each head receive RoPE; equal to `key_length` means
    /// the whole head is rotated.
    pub rope_dims: usize,
    /// Attention score scale (1/sqrt of the key length by construction).
    pub scale: f32,
    /// RoPE-variant flag passed straight through to `backend.rope`
    /// (NeoX-style rotation when true — semantics live in the backend).
    pub use_neox_rope: bool,
    /// True when the Q projection carries fused per-head gate values.
    pub has_attention_gate: bool,
    /// Optional per-head RMS norm applied to Q before RoPE.
    pub q_norm: Option<RMSNorm>,
    /// Optional per-head RMS norm applied to K before RoPE.
    pub k_norm: Option<RMSNorm>,
    /// Tanh soft cap on attention logits; 0.0 disables it.
    pub attn_logit_softcap: f32,
    /// When true, the RoPE-rotated slice is the tail of each head rather
    /// than the front.
    pub rope_partial_at_end: bool,
}
impl Attention {
    /// Convenience constructor using the default (non-NeoX) RoPE flag.
    pub fn new(
        wq: Linear,
        wk: Linear,
        wv: Linear,
        wo: Linear,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        Self::with_rope_type(wq, wk, wv, wo, num_heads, num_kv_heads, head_dim, false)
    }

    /// Constructor selecting the RoPE flag. Key/value lengths and rope dims
    /// all default to `head_dim`; scale defaults to `1/sqrt(head_dim)`.
    #[allow(clippy::too_many_arguments)]
    pub fn with_rope_type(
        wq: Linear,
        wk: Linear,
        wv: Linear,
        wo: Linear,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        use_neox_rope: bool,
    ) -> Self {
        Self {
            wq,
            wk,
            wv,
            wo,
            num_heads,
            num_kv_heads,
            head_dim,
            key_length: head_dim,
            value_length: head_dim,
            rope_dims: head_dim,
            scale: 1.0 / (head_dim as f32).sqrt(),
            use_neox_rope,
            has_attention_gate: false,
            q_norm: None,
            k_norm: None,
            attn_logit_softcap: 0.0,
            rope_partial_at_end: false,
        }
    }

    /// Constructor for models whose per-head key/value lengths differ from
    /// `head_dim` (partial RoPE, asymmetric K/V). Note the score scale is
    /// derived from `key_length`, not `head_dim`.
    #[allow(clippy::too_many_arguments)]
    pub fn with_kv_dims(
        wq: Linear,
        wk: Linear,
        wv: Linear,
        wo: Linear,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        key_length: usize,
        value_length: usize,
        rope_dims: usize,
        use_neox_rope: bool,
        has_attention_gate: bool,
    ) -> Self {
        Self {
            wq,
            wk,
            wv,
            wo,
            num_heads,
            num_kv_heads,
            head_dim,
            key_length,
            value_length,
            rope_dims,
            scale: 1.0 / (key_length as f32).sqrt(),
            use_neox_rope,
            has_attention_gate,
            q_norm: None,
            k_norm: None,
            attn_logit_softcap: 0.0,
            rope_partial_at_end: false,
        }
    }

    /// Selects whether the partial-RoPE slice sits at the end (`true`) or
    /// the front (`false`) of each head.
    pub fn set_rope_partial_at_end(&mut self, at_end: bool) {
        self.rope_partial_at_end = at_end;
    }

    /// Installs per-head RMS norms applied to Q and K (before RoPE).
    pub fn set_qk_norms(&mut self, q_norm: RMSNorm, k_norm: RMSNorm) {
        self.q_norm = Some(q_norm);
        self.k_norm = Some(k_norm);
    }

    /// Sets the attention-logit soft cap; 0.0 disables it.
    pub fn set_attn_logit_softcap(&mut self, cap: f32) {
        self.attn_logit_softcap = cap;
    }

    /// Single-token attention step. Projects (the last row of) `x` to Q/K/V,
    /// applies optional per-head QK RMS norms and (possibly partial) RoPE,
    /// writes K/V for position `pos` into the caches, attends over positions
    /// `0..=pos`, and returns the output projection (`hidden_size` vector).
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        x: &Tensor,
        k_cache: &mut Tensor,
        v_cache: &mut Tensor,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        backend: &dyn Backend,
    ) -> ModelResult<Tensor> {
        let hidden_size = x.shape().last().copied().unwrap_or(0);
        let seq_len = if x.ndim() == 1 { 1 } else { x.shape()[0] };
        let kl = self.key_length;
        let vl = self.value_length;
        // For a 2-D input only the last row (most recent token) is processed.
        let x_vec = if x.ndim() == 2 {
            let x_data = x.as_f32()?;
            let start = (seq_len - 1) * hidden_size;
            Tensor::from_f32(&x_data[start..start + hidden_size], vec![hidden_size])?
        } else {
            x.clone()
        };
        let q_out_size = self.wq.out_features;
        let mut q_raw = Tensor::zeros(vec![q_out_size], DType::F32);
        let mut k = Tensor::zeros(vec![self.num_kv_heads * kl], DType::F32);
        let mut v = Tensor::zeros(vec![self.num_kv_heads * vl], DType::F32);
        self.wq.forward(&x_vec, &mut q_raw, backend)?;
        self.wk.forward(&x_vec, &mut k, backend)?;
        self.wv.forward(&x_vec, &mut v, backend)?;
        // When a gate is fused into the Q projection, each head's slice of
        // q_raw is laid out [kl query dims | vl gate dims]; split them here.
        // Falls back to the plain layout if q_raw is too short to hold both.
        let (q_proper_data, gate_data) = if self.has_attention_gate {
            let raw = q_raw.as_f32()?;
            let q_proper_len = self.num_heads * kl;
            let gate_len = self.num_heads * vl;
            if raw.len() >= q_proper_len + gate_len {
                let per_head_total = kl + vl;
                let mut q_buf = vec![0.0f32; q_proper_len];
                let mut g_buf = vec![0.0f32; gate_len];
                for h in 0..self.num_heads {
                    let src = h * per_head_total;
                    let q_dst = h * kl;
                    let g_dst = h * vl;
                    q_buf[q_dst..q_dst + kl].copy_from_slice(&raw[src..src + kl]);
                    g_buf[g_dst..g_dst + vl].copy_from_slice(&raw[src + kl..src + kl + vl]);
                }
                (q_buf, Some(g_buf))
            } else {
                (raw.to_vec(), None)
            }
        } else {
            (q_raw.as_f32()?.to_vec(), None)
        };
        let q_flat = Tensor::from_f32(&q_proper_data, vec![self.num_heads * kl])?;
        let mut q_reshaped = q_flat.reshape(vec![self.num_heads, 1, kl])?;
        let mut k_reshaped = k.reshape(vec![self.num_kv_heads, 1, kl])?;
        let v_reshaped = v.reshape(vec![self.num_kv_heads, 1, vl])?;
        // Optional per-head RMS norm on Q. The mean square is taken over the
        // first norm_dim elements (norm weight length); assumes norm_dim <= kl
        // — the write loop clamps with min(kl) but the sum slice does not.
        if let Some(ref q_norm) = self.q_norm {
            let q_data = q_reshaped.as_f32()?.to_vec();
            let q_out = q_reshaped.as_f32_mut()?;
            let norm_w = q_norm.weight.as_f32()?;
            let norm_dim = norm_w.len();
            for h in 0..self.num_heads {
                let offset = h * kl;
                let head_slice = &q_data[offset..offset + kl];
                let ss: f32 = head_slice[..norm_dim].iter().map(|x| x * x).sum::<f32>()
                    / norm_dim as f32;
                let rms = (ss + q_norm.eps).sqrt();
                for d in 0..norm_dim.min(kl) {
                    q_out[offset + d] = head_slice[d] / rms * norm_w[d];
                }
            }
        }
        // Same per-head RMS norm for K, over the kv heads.
        if let Some(ref k_norm) = self.k_norm {
            let k_data = k_reshaped.as_f32()?.to_vec();
            let k_out = k_reshaped.as_f32_mut()?;
            let norm_w = k_norm.weight.as_f32()?;
            let norm_dim = norm_w.len();
            for h in 0..self.num_kv_heads {
                let offset = h * kl;
                let head_slice = &k_data[offset..offset + kl];
                let ss: f32 = head_slice[..norm_dim].iter().map(|x| x * x).sum::<f32>()
                    / norm_dim as f32;
                let rms = (ss + k_norm.eps).sqrt();
                for d in 0..norm_dim.min(kl) {
                    k_out[offset + d] = head_slice[d] / rms * norm_w[d];
                }
            }
        }
        // Partial RoPE: only rope_dims of each head are rotated. Gather the
        // slice into a contiguous buffer, rotate via the backend, scatter it
        // back. Full-width RoPE goes straight to the backend.
        if self.rope_dims > 0 && self.rope_dims < kl {
            let rope_offset = if self.rope_partial_at_end { kl - self.rope_dims } else { 0 };
            let q_data = q_reshaped.as_f32()?.to_vec();
            let k_data = k_reshaped.as_f32()?.to_vec();
            let mut q_rope = vec![0.0f32; self.num_heads * self.rope_dims];
            let mut k_rope = vec![0.0f32; self.num_kv_heads * self.rope_dims];
            for h in 0..self.num_heads {
                let src = h * kl + rope_offset;
                let dst = h * self.rope_dims;
                q_rope[dst..dst + self.rope_dims]
                    .copy_from_slice(&q_data[src..src + self.rope_dims]);
            }
            for h in 0..self.num_kv_heads {
                let src = h * kl + rope_offset;
                let dst = h * self.rope_dims;
                k_rope[dst..dst + self.rope_dims]
                    .copy_from_slice(&k_data[src..src + self.rope_dims]);
            }
            let mut q_rope_t =
                Tensor::from_f32(&q_rope, vec![self.num_heads, 1, self.rope_dims])?;
            let mut k_rope_t =
                Tensor::from_f32(&k_rope, vec![self.num_kv_heads, 1, self.rope_dims])?;
            backend.rope(
                &mut q_rope_t,
                &mut k_rope_t,
                pos,
                freq_base,
                freq_scale,
                self.use_neox_rope,
            )?;
            let q_rope_out = q_rope_t.as_f32()?;
            let k_rope_out = k_rope_t.as_f32()?;
            let q_out = q_reshaped.as_f32_mut()?;
            let k_out = k_reshaped.as_f32_mut()?;
            for h in 0..self.num_heads {
                let dst = h * kl + rope_offset;
                let src = h * self.rope_dims;
                q_out[dst..dst + self.rope_dims]
                    .copy_from_slice(&q_rope_out[src..src + self.rope_dims]);
            }
            for h in 0..self.num_kv_heads {
                let dst = h * kl + rope_offset;
                let src = h * self.rope_dims;
                k_out[dst..dst + self.rope_dims]
                    .copy_from_slice(&k_rope_out[src..src + self.rope_dims]);
            }
        } else {
            backend.rope(
                &mut q_reshaped,
                &mut k_reshaped,
                pos,
                freq_base,
                freq_scale,
                self.use_neox_rope,
            )?;
        }
        // Cache layout is [num_kv_heads, max_seq_len, head_len]; write this
        // token's K and V at position `pos` for every kv head.
        let max_seq_len = k_cache.shape()[1];
        let num_kv_heads = self.num_kv_heads;
        {
            let k_cache_data = k_cache.as_f32_mut()?;
            let k_new_data = k_reshaped.as_f32()?;
            for h in 0..num_kv_heads {
                let cache_offset = h * max_seq_len * kl + pos * kl;
                let new_offset = h * kl;
                k_cache_data[cache_offset..cache_offset + kl]
                    .copy_from_slice(&k_new_data[new_offset..new_offset + kl]);
            }
        }
        {
            let v_cache_data = v_cache.as_f32_mut()?;
            let v_new_data = v_reshaped.as_f32()?;
            for h in 0..num_kv_heads {
                let cache_offset = h * max_seq_len * vl + pos * vl;
                let new_offset = h * vl;
                v_cache_data[cache_offset..cache_offset + vl]
                    .copy_from_slice(&v_new_data[new_offset..new_offset + vl]);
            }
        }
        let kv_len = pos + 1;
        let mut attn_out = Tensor::zeros(vec![self.num_heads, 1, vl], DType::F32);
        if self.attn_logit_softcap > 0.0 {
            // Soft-capped attention is computed inline (logits tanh-capped
            // before softmax); the `attention_cached` backend call below
            // takes no cap parameter.
            let cap = self.attn_logit_softcap;
            let num_queries_per_kv = self.num_heads / self.num_kv_heads;
            let max_seq_len = k_cache.shape()[1];
            let q_data = q_reshaped.as_f32()?;
            let k_data = k_cache.as_f32()?;
            let v_data = v_cache.as_f32()?;
            let out_data = attn_out.as_f32_mut()?;
            let k_head_stride = max_seq_len * kl;
            let v_head_stride = max_seq_len * vl;
            for head in 0..self.num_heads {
                // Grouped-query mapping: several query heads share a kv head.
                let kv_head = head / num_queries_per_kv;
                let q_offset = head * kl;
                let q_vec = &q_data[q_offset..q_offset + kl];
                let mut scores = vec![0.0f32; kv_len];
                let k_base = kv_head * k_head_stride;
                let v_base = kv_head * v_head_stride;
                for (kv_pos, score) in scores.iter_mut().enumerate() {
                    // Causal mask (defensive: kv_len is pos + 1, so this
                    // branch should not trigger).
                    if kv_pos > pos {
                        *score = f32::NEG_INFINITY;
                        continue;
                    }
                    let k_offset = k_base + kv_pos * kl;
                    let k_vec = &k_data[k_offset..k_offset + kl];
                    let mut dot = 0.0f32;
                    for d in 0..kl {
                        dot += q_vec[d] * k_vec[d];
                    }
                    *score = dot * self.scale;
                    // cap > 0.0 always holds here (outer branch condition).
                    if cap > 0.0 {
                        *score = cap * (*score / cap).tanh();
                    }
                }
                // Numerically stable softmax over the capped scores.
                let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for s in &mut scores {
                    *s = (*s - max_score).exp();
                    sum += *s;
                }
                let inv_sum = 1.0 / sum;
                for s in &mut scores {
                    *s *= inv_sum;
                }
                let out_offset = head * vl;
                let out_vec = &mut out_data[out_offset..out_offset + vl];
                out_vec.fill(0.0);
                // Weighted sum of cached values; near-zero weights skipped.
                for (kv_pos, &score_val) in scores.iter().enumerate() {
                    if score_val > 1e-8 {
                        let v_offset = v_base + kv_pos * vl;
                        let v_vec = &v_data[v_offset..v_offset + vl];
                        for d in 0..vl {
                            out_vec[d] += score_val * v_vec[d];
                        }
                    }
                }
            }
        } else {
            backend.attention_cached(
                &q_reshaped,
                k_cache,
                v_cache,
                &mut attn_out,
                self.scale,
                kv_len,
            )?;
        }
        // Apply the per-element sigmoid gate (if present) to the attention
        // output before the final projection.
        let attn_flat = if let Some(ref gate) = gate_data {
            let attn_data = attn_out.as_f32()?;
            let total = self.num_heads * vl;
            let mut gated = vec![0.0f32; total];
            for i in 0..total {
                let sigmoid_g = 1.0 / (1.0 + (-gate[i]).exp());
                gated[i] = sigmoid_g * attn_data[i];
            }
            Tensor::from_f32(&gated, vec![total])?
        } else {
            attn_out.reshape(vec![self.num_heads * vl])?
        };
        let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
        self.wo.forward(&attn_flat, &mut out, backend)?;
        Ok(out)
    }

    /// Same flow as [`Attention::forward`], but K/V for this token are handed
    /// to the layer's slot of a `TurboQuantKVCache` (any quantization happens
    /// inside `write_kv` — not visible here) and the score/weighted-sum step
    /// is delegated to `backend.attention_turboquant`. Note: no softcap path
    /// exists on this route.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_turboquant(
        &self,
        x: &Tensor,
        tq_cache: &mut super::kv_turboquant::TurboQuantKVCache,
        layer_idx: usize,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        backend: &dyn Backend,
    ) -> ModelResult<Tensor> {
        let hidden_size = x.shape().last().copied().unwrap_or(0);
        let seq_len = if x.ndim() == 1 { 1 } else { x.shape()[0] };
        let kl = self.key_length;
        let vl = self.value_length;
        // Only the most recent row of a 2-D input is processed.
        let x_vec = if x.ndim() == 2 {
            let x_data = x.as_f32()?;
            let start = (seq_len - 1) * hidden_size;
            Tensor::from_f32(&x_data[start..start + hidden_size], vec![hidden_size])?
        } else {
            x.clone()
        };
        let q_out_size = self.wq.out_features;
        let mut q_raw = Tensor::zeros(vec![q_out_size], DType::F32);
        let mut k = Tensor::zeros(vec![self.num_kv_heads * kl], DType::F32);
        let mut v = Tensor::zeros(vec![self.num_kv_heads * vl], DType::F32);
        self.wq.forward(&x_vec, &mut q_raw, backend)?;
        self.wk.forward(&x_vec, &mut k, backend)?;
        self.wv.forward(&x_vec, &mut v, backend)?;
        // Split the fused [query | gate] per-head layout when gating is on.
        let (q_proper_data, gate_data) = if self.has_attention_gate {
            let raw = q_raw.as_f32()?;
            let q_proper_len = self.num_heads * kl;
            let gate_len = self.num_heads * vl;
            if raw.len() >= q_proper_len + gate_len {
                let per_head_total = kl + vl;
                let mut q_buf = vec![0.0f32; q_proper_len];
                let mut g_buf = vec![0.0f32; gate_len];
                for h in 0..self.num_heads {
                    let src = h * per_head_total;
                    let q_dst = h * kl;
                    let g_dst = h * vl;
                    q_buf[q_dst..q_dst + kl].copy_from_slice(&raw[src..src + kl]);
                    g_buf[g_dst..g_dst + vl].copy_from_slice(&raw[src + kl..src + kl + vl]);
                }
                (q_buf, Some(g_buf))
            } else {
                (raw.to_vec(), None)
            }
        } else {
            (q_raw.as_f32()?.to_vec(), None)
        };
        let q_flat = Tensor::from_f32(&q_proper_data, vec![self.num_heads * kl])?;
        let mut q_reshaped = q_flat.reshape(vec![self.num_heads, 1, kl])?;
        let mut k_reshaped = k.reshape(vec![self.num_kv_heads, 1, kl])?;
        let v_reshaped = v.reshape(vec![self.num_kv_heads, 1, vl])?;
        // Optional per-head RMS norms on Q and K, identical to `forward`.
        if let Some(ref q_norm) = self.q_norm {
            let q_data = q_reshaped.as_f32()?.to_vec();
            let q_out = q_reshaped.as_f32_mut()?;
            let norm_w = q_norm.weight.as_f32()?;
            let norm_dim = norm_w.len();
            for h in 0..self.num_heads {
                let offset = h * kl;
                let head_slice = &q_data[offset..offset + kl];
                let ss: f32 = head_slice[..norm_dim].iter().map(|x| x * x).sum::<f32>()
                    / norm_dim as f32;
                let rms = (ss + q_norm.eps).sqrt();
                for d in 0..norm_dim.min(kl) {
                    q_out[offset + d] = head_slice[d] / rms * norm_w[d];
                }
            }
        }
        if let Some(ref k_norm) = self.k_norm {
            let k_data = k_reshaped.as_f32()?.to_vec();
            let k_out = k_reshaped.as_f32_mut()?;
            let norm_w = k_norm.weight.as_f32()?;
            let norm_dim = norm_w.len();
            for h in 0..self.num_kv_heads {
                let offset = h * kl;
                let head_slice = &k_data[offset..offset + kl];
                let ss: f32 = head_slice[..norm_dim].iter().map(|x| x * x).sum::<f32>()
                    / norm_dim as f32;
                let rms = (ss + k_norm.eps).sqrt();
                for d in 0..norm_dim.min(kl) {
                    k_out[offset + d] = head_slice[d] / rms * norm_w[d];
                }
            }
        }
        // Partial RoPE: gather → rotate via backend → scatter back; full-width
        // RoPE goes directly through the backend (same as in `forward`).
        if self.rope_dims > 0 && self.rope_dims < kl {
            let rope_offset = if self.rope_partial_at_end { kl - self.rope_dims } else { 0 };
            let q_data = q_reshaped.as_f32()?.to_vec();
            let k_data = k_reshaped.as_f32()?.to_vec();
            let mut q_rope = vec![0.0f32; self.num_heads * self.rope_dims];
            let mut k_rope = vec![0.0f32; self.num_kv_heads * self.rope_dims];
            for h in 0..self.num_heads {
                let src = h * kl + rope_offset;
                let dst = h * self.rope_dims;
                q_rope[dst..dst + self.rope_dims].copy_from_slice(&q_data[src..src + self.rope_dims]);
            }
            for h in 0..self.num_kv_heads {
                let src = h * kl + rope_offset;
                let dst = h * self.rope_dims;
                k_rope[dst..dst + self.rope_dims].copy_from_slice(&k_data[src..src + self.rope_dims]);
            }
            let mut q_rope_t = Tensor::from_f32(&q_rope, vec![self.num_heads, 1, self.rope_dims])?;
            let mut k_rope_t = Tensor::from_f32(&k_rope, vec![self.num_kv_heads, 1, self.rope_dims])?;
            backend.rope(&mut q_rope_t, &mut k_rope_t, pos, freq_base, freq_scale, self.use_neox_rope)?;
            let q_rope_out = q_rope_t.as_f32()?;
            let k_rope_out = k_rope_t.as_f32()?;
            let q_out = q_reshaped.as_f32_mut()?;
            let k_out = k_reshaped.as_f32_mut()?;
            for h in 0..self.num_heads {
                let dst = h * kl + rope_offset;
                let src = h * self.rope_dims;
                q_out[dst..dst + self.rope_dims].copy_from_slice(&q_rope_out[src..src + self.rope_dims]);
            }
            for h in 0..self.num_kv_heads {
                let dst = h * kl + rope_offset;
                let src = h * self.rope_dims;
                k_out[dst..dst + self.rope_dims].copy_from_slice(&k_rope_out[src..src + self.rope_dims]);
            }
        } else {
            backend.rope(&mut q_reshaped, &mut k_reshaped, pos, freq_base, freq_scale, self.use_neox_rope)?;
        }
        // Store this token's K/V in the TurboQuant cache, then score Q
        // against the whole cached history in the backend.
        let k_data = k_reshaped.as_f32()?;
        let v_data = v_reshaped.as_f32()?;
        tq_cache.write_kv(layer_idx, k_data, v_data);
        let q_data = q_reshaped.as_f32()?;
        let attn_flat = backend.attention_turboquant(
            q_data, tq_cache, layer_idx, self.num_heads, self.scale,
        )?;
        // Optional sigmoid gate on the attention output.
        let attn_vec = if let Some(ref gate) = gate_data {
            let total = self.num_heads * vl;
            let mut gated = vec![0.0f32; total];
            for i in 0..total {
                let sigmoid_g = 1.0 / (1.0 + (-gate[i]).exp());
                gated[i] = sigmoid_g * attn_flat[i];
            }
            gated
        } else {
            attn_flat
        };
        let attn_tensor = Tensor::from_f32(&attn_vec, vec![self.num_heads * vl])?;
        let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
        self.wo.forward(&attn_tensor, &mut out, backend)?;
        Ok(out)
    }
}
/// Gated feed-forward block: `down(silu(gate(x)) * up(x))`
/// (the activation is applied via `silu_mul_inplace` in `forward`).
#[derive(Debug)]
pub struct FeedForward {
    /// Gate projection, hidden → intermediate.
    pub w_gate: Linear,
    /// Up projection, hidden → intermediate.
    pub w_up: Linear,
    /// Down projection, intermediate → hidden.
    pub w_down: Linear,
    /// Output width (`w_down.out_features`).
    pub hidden_size: usize,
    /// Inner width (`w_gate.out_features`).
    pub intermediate_size: usize,
}
impl FeedForward {
    /// Builds the block; the two derived sizes come straight from the
    /// projection weights.
    pub fn new(w_gate: Linear, w_up: Linear, w_down: Linear) -> Self {
        let intermediate_size = w_gate.out_features;
        let hidden_size = w_down.out_features;
        Self {
            w_gate,
            w_up,
            w_down,
            hidden_size,
            intermediate_size,
        }
    }

    /// Computes `out = W_down(silu(W_gate x) * W_up x)`. The SiLU-and-multiply
    /// step runs in place on the gate buffer via the SIMD helper.
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        let mut gate_buf = Tensor::zeros(vec![self.intermediate_size], DType::F32);
        let mut up_buf = Tensor::zeros(vec![self.intermediate_size], DType::F32);
        self.w_gate.forward(x, &mut gate_buf, backend)?;
        self.w_up.forward(x, &mut up_buf, backend)?;
        crate::backend::cpu::simd::silu_mul_inplace(gate_buf.as_f32_mut()?, up_buf.as_f32()?);
        self.w_down.forward(&gate_buf, out, backend)
    }
}
/// Ungated two-projection feed-forward block: `down(act(up(x)))`.
#[derive(Debug)]
pub struct NoGateFeedForward {
    /// Up projection, hidden → intermediate.
    pub w_up: Linear,
    /// Down projection, intermediate → hidden.
    pub w_down: Linear,
    /// Output width (`w_down.out_features`).
    pub hidden_size: usize,
    /// Inner width (`w_up.out_features`).
    pub intermediate_size: usize,
    /// Activation: tanh-approximation GELU when true, SiLU otherwise.
    pub use_gelu: bool,
}
impl NoGateFeedForward {
    /// Builds the block; sizes are derived from the projection weights.
    pub fn new(w_up: Linear, w_down: Linear, use_gelu: bool) -> Self {
        let intermediate_size = w_up.out_features;
        let hidden_size = w_down.out_features;
        Self {
            w_up,
            w_down,
            hidden_size,
            intermediate_size,
            use_gelu,
        }
    }

    /// Computes `out = W_down(act(W_up x))`, with the activation applied in
    /// place on the intermediate buffer. GELU uses the tanh approximation
    /// (0.797_884_6 ≈ sqrt(2/π)); the alternative is SiLU `x·σ(x)`.
    pub fn forward(
        &self,
        x: &Tensor,
        out: &mut Tensor,
        backend: &dyn Backend,
    ) -> BackendResult<()> {
        let mut mid = Tensor::zeros(vec![self.intermediate_size], DType::F32);
        self.w_up.forward(x, &mut mid, backend)?;
        {
            let vals = mid.as_f32_mut()?;
            if self.use_gelu {
                for slot in vals.iter_mut() {
                    let x = *slot;
                    *slot = 0.5 * x
                        * (1.0 + (0.797_884_6 * (x + 0.044715 * x * x * x)).tanh());
                }
            } else {
                for slot in vals.iter_mut() {
                    *slot = *slot / (1.0 + (-*slot).exp());
                }
            }
        }
        self.w_down.forward(&mid, out, backend)
    }
}
/// Feed-forward variant held by a [`TransformerLayer`].
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum FfnLayer {
    /// Gated (SiLU) dense FFN.
    Dense(FeedForward),
    /// Ungated FFN (GELU or SiLU activation).
    NoGate(NoGateFeedForward),
    /// Mixture-of-experts FFN.
    Moe(super::moe::MoeLayer),
    /// No FFN at all; the layer is attention + residual only.
    Identity,
}
/// Attention (token-mixing) variant held by a [`TransformerLayer`].
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum AttentionLayer {
    /// Standard KV-cached multi-head attention.
    FullAttention(Attention),
    /// DeltaNet recurrent layer (boxed: large payload).
    DeltaNet(Box<super::deltanet::DeltaNetLayer>),
    /// Mamba recurrent layer (boxed: large payload).
    Mamba(Box<super::mamba::MambaLayer>),
}
/// One decoder block: attention sub-layer + FFN sub-layer with their norms.
#[derive(Debug)]
pub struct TransformerLayer {
    /// Norm applied to the block input before the attention sub-layer.
    pub attn_norm: NormLayer,
    /// Attention variant (full attention, DeltaNet, or Mamba).
    pub attn_layer: AttentionLayer,
    /// Optional norm on the attention output before its residual add
    /// (sequential wiring only).
    pub post_attn_norm: Option<NormLayer>,
    /// Norm applied before the FFN (not used in parallel-residual wiring,
    /// where the FFN reads the `attn_norm` output instead).
    pub ffn_norm: NormLayer,
    /// Feed-forward variant.
    pub ffn_layer: FfnLayer,
    /// Optional norm on the FFN output before its residual add.
    pub post_ffn_norm: Option<NormLayer>,
    /// Index of this layer in the model; used to address the shared
    /// TurboQuant KV cache in `forward_tq`.
    pub layer_idx: usize,
    /// When true, attention and FFN both consume the same pre-norm input and
    /// are summed with the residual in a single step.
    pub use_parallel_residual: bool,
}
impl TransformerLayer {
    /// The full-attention block, if this layer uses one.
    pub fn attention(&self) -> Option<&Attention> {
        match &self.attn_layer {
            AttentionLayer::FullAttention(attn) => Some(attn),
            AttentionLayer::DeltaNet(_) | AttentionLayer::Mamba(_) => None,
        }
    }

    /// The gated dense FFN, if that is this layer's FFN variant.
    pub fn ffn(&self) -> Option<&FeedForward> {
        match &self.ffn_layer {
            FfnLayer::Dense(ffn) => Some(ffn),
            FfnLayer::NoGate(_) | FfnLayer::Moe(_) | FfnLayer::Identity => None,
        }
    }

    /// The mixture-of-experts FFN, if that is this layer's FFN variant.
    pub fn moe(&self) -> Option<&super::moe::MoeLayer> {
        match &self.ffn_layer {
            FfnLayer::Dense(_) | FfnLayer::NoGate(_) | FfnLayer::Identity => None,
            FfnLayer::Moe(moe) => Some(moe),
        }
    }

    /// The ungated FFN, if that is this layer's FFN variant.
    pub fn no_gate_ffn(&self) -> Option<&NoGateFeedForward> {
        match &self.ffn_layer {
            FfnLayer::Dense(_) | FfnLayer::Moe(_) | FfnLayer::Identity => None,
            FfnLayer::NoGate(ffn) => Some(ffn),
        }
    }

    /// True for DeltaNet/Mamba layers, which take recurrent state instead of
    /// the K/V caches in `forward`.
    pub fn is_recurrent(&self) -> bool {
        matches!(
            &self.attn_layer,
            AttentionLayer::DeltaNet(_) | AttentionLayer::Mamba(_)
        )
    }

    /// Runs one decoder layer for a single token and returns the new hidden
    /// state. For a 2-D `x`, only the last row participates in the residual
    /// adds (matching the attention path, which also keeps only the last row).
    ///
    /// Wiring, depending on the layer's flags:
    /// - `FfnLayer::Identity`:  `x + attn(norm(x))`
    /// - parallel residual:     `x + attn(norm(x)) + ffn(norm(x))`
    /// - sequential (default):  `h = x + [post_norm](attn(norm(x)))`,
    ///   then `h + [post_norm](ffn(ffn_norm(h)))`
    ///
    /// # Errors
    /// `ConfigError` when a recurrent layer is given no state, or a state of
    /// the wrong variant.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        x: &Tensor,
        k_cache: &mut Tensor,
        v_cache: &mut Tensor,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        backend: &dyn Backend,
        recurrent_state: Option<&mut super::deltanet::RecurrentLayerState>,
    ) -> ModelResult<Tensor> {
        let hidden_size = x.shape().last().copied().unwrap_or(0);
        let mut norm_out = Tensor::zeros(x.shape().to_vec(), DType::F32);
        self.attn_norm.forward(x, &mut norm_out, backend)?;
        // Dispatch to the attention variant; recurrent variants require the
        // matching RecurrentLayerState variant to be supplied.
        let attn_out = match &self.attn_layer {
            AttentionLayer::FullAttention(attn) => {
                attn.forward(&norm_out, k_cache, v_cache, pos, freq_base, freq_scale, backend)?
            }
            AttentionLayer::DeltaNet(dn) => {
                let state = recurrent_state.ok_or_else(|| {
                    ModelError::ConfigError(
                        "DeltaNet layer requires recurrent state".into(),
                    )
                })?;
                match state {
                    super::deltanet::RecurrentLayerState::DeltaNet(ds) => {
                        dn.forward(&norm_out, ds, backend)?
                    }
                    _ => {
                        return Err(ModelError::ConfigError(
                            "Expected DeltaNet state for DeltaNet layer".into(),
                        ))
                    }
                }
            }
            AttentionLayer::Mamba(mb) => {
                let state = recurrent_state.ok_or_else(|| {
                    ModelError::ConfigError(
                        "Mamba layer requires recurrent state".into(),
                    )
                })?;
                match state {
                    super::deltanet::RecurrentLayerState::Mamba(ms) => {
                        mb.forward(&norm_out, ms, backend)?
                    }
                    _ => {
                        return Err(ModelError::ConfigError(
                            "Expected Mamba state for Mamba layer".into(),
                        ))
                    }
                }
            }
        };
        // Residual source: the last row of x when x is 2-D, else all of x.
        let x_data = x.as_f32()?;
        let x_start = if x.ndim() == 2 {
            (x.shape()[0] - 1) * hidden_size
        } else {
            0
        };
        // Identity FFN: the layer reduces to attention + residual.
        if matches!(self.ffn_layer, FfnLayer::Identity) {
            let mut out = attn_out;
            {
                let out_data = out.as_f32_mut()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for (o, &xv) in out_data.iter_mut().zip(x_slice.iter()) {
                    *o += xv;
                }
            }
            return Ok(out);
        }
        if self.use_parallel_residual {
            // Parallel residual: the FFN consumes the same attn_norm output as
            // the attention, and both are summed with x in one step
            // (self.ffn_norm is not used on this path).
            let mut ffn_out = match &self.ffn_layer {
                FfnLayer::Dense(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::NoGate(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::Moe(moe) => moe.forward(&norm_out, backend)?,
                // Identity was handled by the early return above.
                FfnLayer::Identity => unreachable!(),
            };
            {
                let out_data = ffn_out.as_f32_mut()?;
                let attn_data = attn_out.as_f32()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for i in 0..hidden_size {
                    out_data[i] += attn_data[i] + x_slice[i];
                }
            }
            Ok(ffn_out)
        } else {
            // Sequential wiring: (optionally normed) attention output plus
            // input, then the FFN on the re-normed intermediate, plus that
            // intermediate again.
            let mut h = attn_out;
            if let Some(ref pan) = self.post_attn_norm {
                let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
                pan.forward(&h, &mut normed, backend)?;
                h = normed;
            }
            {
                let h_data = h.as_f32_mut()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for (a, &xv) in h_data.iter_mut().zip(x_slice.iter()) {
                    *a += xv;
                }
            }
            let mut ffn_norm_out = Tensor::zeros(vec![hidden_size], DType::F32);
            self.ffn_norm.forward(&h, &mut ffn_norm_out, backend)?;
            let mut ffn_out = match &self.ffn_layer {
                FfnLayer::Dense(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&ffn_norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::NoGate(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&ffn_norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::Moe(moe) => moe.forward(&ffn_norm_out, backend)?,
                FfnLayer::Identity => unreachable!(),
            };
            if let Some(ref norm) = self.post_ffn_norm {
                let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
                norm.forward(&ffn_out, &mut normed, backend)?;
                ffn_out = normed;
            }
            {
                let ffn_data = ffn_out.as_f32_mut()?;
                let h_data = h.as_f32()?;
                for (f, &hv) in ffn_data.iter_mut().zip(h_data.iter()) {
                    *f += hv;
                }
            }
            Ok(ffn_out)
        }
    }

    /// Same wiring as [`TransformerLayer::forward`], but full-attention
    /// layers route through the TurboQuant KV cache
    /// (`Attention::forward_turboquant`, addressed by `self.layer_idx`).
    /// `_k_cache`/`_v_cache` are unused here and kept only for signature
    /// parity with `forward`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_tq(
        &self,
        x: &Tensor,
        tq_cache: &mut super::kv_turboquant::TurboQuantKVCache,
        _k_cache: &mut Tensor,
        _v_cache: &mut Tensor,
        pos: usize,
        freq_base: f32,
        freq_scale: f32,
        backend: &dyn Backend,
        recurrent_state: Option<&mut super::deltanet::RecurrentLayerState>,
    ) -> ModelResult<Tensor> {
        let hidden_size = x.shape().last().copied().unwrap_or(0);
        let mut norm_out = Tensor::zeros(x.shape().to_vec(), DType::F32);
        self.attn_norm.forward(x, &mut norm_out, backend)?;
        // Attention dispatch mirrors `forward`; only the full-attention arm
        // differs (TurboQuant cache instead of the dense K/V tensors).
        let attn_out = match &self.attn_layer {
            AttentionLayer::FullAttention(attn) => {
                attn.forward_turboquant(
                    &norm_out, tq_cache, self.layer_idx,
                    pos, freq_base, freq_scale, backend,
                )?
            }
            AttentionLayer::DeltaNet(dn) => {
                let state = recurrent_state.ok_or_else(|| {
                    ModelError::ConfigError("DeltaNet layer requires recurrent state".into())
                })?;
                match state {
                    super::deltanet::RecurrentLayerState::DeltaNet(ds) => {
                        dn.forward(&norm_out, ds, backend)?
                    }
                    _ => return Err(ModelError::ConfigError(
                        "Expected DeltaNet state for DeltaNet layer".into(),
                    )),
                }
            }
            AttentionLayer::Mamba(mb) => {
                let state = recurrent_state.ok_or_else(|| {
                    ModelError::ConfigError("Mamba layer requires recurrent state".into())
                })?;
                match state {
                    super::deltanet::RecurrentLayerState::Mamba(ms) => {
                        mb.forward(&norm_out, ms, backend)?
                    }
                    _ => return Err(ModelError::ConfigError(
                        "Expected Mamba state for Mamba layer".into(),
                    )),
                }
            }
        };
        // Residual source: the last row of x when x is 2-D.
        let x_data = x.as_f32()?;
        let x_start = if x.ndim() == 2 {
            (x.shape()[0] - 1) * hidden_size
        } else {
            0
        };
        // Identity FFN: attention + residual only.
        if matches!(self.ffn_layer, FfnLayer::Identity) {
            let mut out = attn_out;
            {
                let out_data = out.as_f32_mut()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for (o, &xv) in out_data.iter_mut().zip(x_slice.iter()) {
                    *o += xv;
                }
            }
            return Ok(out);
        }
        if self.use_parallel_residual {
            // Parallel residual: FFN reads the same attn_norm output; sum
            // x + attn + ffn in one pass.
            let mut ffn_out = match &self.ffn_layer {
                FfnLayer::Dense(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::NoGate(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::Moe(moe) => moe.forward(&norm_out, backend)?,
                FfnLayer::Identity => unreachable!(),
            };
            {
                let out_data = ffn_out.as_f32_mut()?;
                let attn_data = attn_out.as_f32()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for i in 0..hidden_size {
                    out_data[i] += attn_data[i] + x_slice[i];
                }
            }
            Ok(ffn_out)
        } else {
            // Sequential wiring, identical to `forward`.
            let mut h = attn_out;
            if let Some(ref pan) = self.post_attn_norm {
                let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
                pan.forward(&h, &mut normed, backend)?;
                h = normed;
            }
            {
                let h_data = h.as_f32_mut()?;
                let x_slice = &x_data[x_start..x_start + hidden_size];
                for (a, &xv) in h_data.iter_mut().zip(x_slice.iter()) {
                    *a += xv;
                }
            }
            let mut ffn_norm_out = Tensor::zeros(vec![hidden_size], DType::F32);
            self.ffn_norm.forward(&h, &mut ffn_norm_out, backend)?;
            let mut ffn_out = match &self.ffn_layer {
                FfnLayer::Dense(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&ffn_norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::NoGate(ffn) => {
                    let mut out = Tensor::zeros(vec![hidden_size], DType::F32);
                    ffn.forward(&ffn_norm_out, &mut out, backend)?;
                    out
                }
                FfnLayer::Moe(moe) => moe.forward(&ffn_norm_out, backend)?,
                FfnLayer::Identity => unreachable!(),
            };
            if let Some(ref norm) = self.post_ffn_norm {
                let mut normed = Tensor::zeros(vec![hidden_size], DType::F32);
                norm.forward(&ffn_out, &mut normed, backend)?;
                ffn_out = normed;
            }
            {
                let ffn_data = ffn_out.as_f32_mut()?;
                let h_data = h.as_f32()?;
                for (f, &hv) in ffn_data.iter_mut().zip(h_data.iter()) {
                    *f += hv;
                }
            }
            Ok(ffn_out)
        }
    }
}