use crate::error::{Error, Result};
use crate::nn::{Linear, MaybeQuantLinear, RmsNorm, RoPE, VarBuilder};
use crate::ops::RoPEOps;
use crate::ops::impl_generic::attention::mla::scaled_dot_product_attention_impl;
use crate::ops::impl_generic::attention::rope::apply_rope_impl;
use crate::quant::traits::QuantMatmulOps;
use numr::autograd::{Var, var_broadcast_to, var_cat, var_narrow, var_permute, var_reshape};
use numr::dtype::DType;
use numr::ops::{
BinaryOps, NormalizationOps, ReduceOps, ScalarOps, ShapeOps, TensorOps, TypeConversionOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
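
/// Configuration for Multi-head Latent Attention (MLA), the attention
/// variant introduced by DeepSeek-V2.
///
/// Keys and values are decompressed from a shared low-rank latent of rank
/// `kv_lora_rank`; queries optionally take their own low-rank path when
/// `q_lora_rank > 0`. Each head's query/key splits into a non-positional
/// slice of width `head_dim` and a rotary slice of width `rope_head_dim`,
/// while values use `head_dim_v`. `use_norm` and `norm_eps` control the
/// RMSNorm applied to the low-rank latents.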
#[derive(Debug, Clone)]
pub struct MlaConfig {
pub hidden_size: usize,
pub num_heads: usize,
pub head_dim: usize,
pub head_dim_v: usize,
pub kv_lora_rank: usize,
pub q_lora_rank: usize,
pub rope_head_dim: usize,
pub max_seq_len: usize,
pub rope_theta: f32,
pub use_norm: bool,
pub norm_eps: f32,
}
impl MlaConfig {
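    /// DeepSeek-V2-style defaults: `head_dim` is derived as
    /// `hidden_size / num_heads` and reused for `head_dim_v`, with
    /// `rope_theta = 10000.0` and RMSNorm enabled.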
pub fn deepseek_v2(
hidden_size: usize,
num_heads: usize,
kv_lora_rank: usize,
q_lora_rank: usize,
rope_head_dim: usize,
max_seq_len: usize,
) -> Self {
let head_dim = hidden_size / num_heads;
Self {
hidden_size,
num_heads,
head_dim,
head_dim_v: head_dim,
kv_lora_rank,
q_lora_rank,
rope_head_dim,
max_seq_len,
rope_theta: 10000.0,
use_norm: true,
norm_eps: 1e-6,
}
}
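
    /// Checks that the dimensions form a usable MLA configuration.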
pub fn validate(&self) -> Result<()> {
if self.hidden_size == 0 || self.num_heads == 0 {
return Err(Error::ModelError {
reason: "hidden_size and num_heads must be > 0".into(),
});
}
if self.kv_lora_rank == 0 {
return Err(Error::ModelError {
reason: "kv_lora_rank must be > 0 for MLA".into(),
});
}
if self.rope_head_dim > self.head_dim {
return Err(Error::ModelError {
reason: format!(
"rope_head_dim ({}) > head_dim ({})",
self.rope_head_dim, self.head_dim
),
});
}
Ok(())
}
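
    /// Per-head query/key width: the non-positional slice plus the RoPE slice
    /// (e.g. 128 + 64 = 192 for DeepSeek-V2-style dimensions).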
pub fn qk_head_dim(&self) -> usize {
self.head_dim + self.rope_head_dim
}
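
    /// Whether queries use a low-rank down/up projection (`q_lora_rank > 0`).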
pub fn q_uses_lora(&self) -> bool {
self.q_lora_rank > 0
}
}
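
/// Multi-head Latent Attention layer.
///
/// Holds the query projections (optionally low-rank with an RMSNorm between
/// them), the joint KV compression/decompression projections, the output
/// projection, and a precomputed RoPE table covering `max_seq_len` positions.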
pub struct Mla<R: Runtime> {
q_down: Option<MaybeQuantLinear<R>>,
q_up: MaybeQuantLinear<R>,
q_norm: Option<RmsNorm<R>>,
kv_compress: MaybeQuantLinear<R>,
kv_norm: Option<RmsNorm<R>>,
kv_decompress: MaybeQuantLinear<R>,
o_proj: MaybeQuantLinear<R>,
rope: RoPE<R>,
num_heads: usize,
head_dim: usize,
head_dim_v: usize,
rope_head_dim: usize,
kv_lora_rank: usize,
scale: f64,
}
impl<R: Runtime<DType = DType>> Mla<R> {
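    /// Builds an MLA layer with zero-initialized placeholder weights in
    /// `DType::F32`; primarily useful for shape checks and tests.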
pub fn from_config(config: &MlaConfig, device: &R::Device) -> Result<Self> {
config.validate()?;
let h = config.hidden_size;
let nh = config.num_heads;
let qk_dim = config.qk_head_dim();
let dt = DType::F32;
let (q_down, q_up, q_norm) = if config.q_uses_lora() {
let q_down = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(&[config.q_lora_rank, h], dt, device),
None,
true,
));
let q_up = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(&[nh * qk_dim, config.q_lora_rank], dt, device),
None,
true,
));
let q_norm = if config.use_norm {
Some(RmsNorm::new(
Tensor::<R>::ones(&[config.q_lora_rank], dt, device),
config.norm_eps,
true,
))
} else {
None
};
(Some(q_down), q_up, q_norm)
} else {
let q_up = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(&[nh * qk_dim, h], dt, device),
None,
true,
));
(None, q_up, None)
};
let kv_compress = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(&[config.kv_lora_rank + config.rope_head_dim, h], dt, device),
None,
true,
));
let kv_norm = if config.use_norm {
Some(RmsNorm::new(
Tensor::<R>::ones(&[config.kv_lora_rank], dt, device),
config.norm_eps,
true,
))
} else {
None
};
let kv_decompress = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(
&[
nh * (config.head_dim + config.head_dim_v),
config.kv_lora_rank,
],
dt,
device,
),
None,
true,
));
let o_proj = MaybeQuantLinear::Standard(Linear::new(
Tensor::<R>::zeros(&[h, nh * config.head_dim_v], dt, device),
None,
true,
));
let rope = RoPE::<R>::precompute_freqs(
config.max_seq_len,
config.rope_head_dim,
config.rope_theta,
None,
device,
);
let scale = 1.0 / (qk_dim as f64).sqrt();
Ok(Self {
q_down,
q_up,
q_norm,
kv_compress,
kv_norm,
kv_decompress,
o_proj,
rope,
num_heads: nh,
head_dim: config.head_dim,
head_dim_v: config.head_dim_v,
rope_head_dim: config.rope_head_dim,
kv_lora_rank: config.kv_lora_rank,
scale,
})
}
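
    /// Loads weights using the checkpoint naming used by DeepSeek models:
    /// `q_a_proj`/`q_b_proj` (or `q_proj` without the low-rank query path),
    /// `kv_a_proj_with_mqa`, `kv_b_proj`, `o_proj`, and the `q_a_layernorm`/
    /// `kv_a_layernorm` weights.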
pub fn from_varbuilder(vb: &mut VarBuilder<R>, config: &MlaConfig) -> Result<Self> {
config.validate()?;
let nh = config.num_heads;
let qk_dim = config.qk_head_dim();
let (q_down, q_up, q_norm) = if config.q_uses_lora() {
let q_down = vb.pp("q_a_proj").take_maybe_quant_linear("weight", None)?;
let q_up = vb.pp("q_b_proj").take_maybe_quant_linear("weight", None)?;
let q_norm = if config.use_norm {
let mut qn_vb = vb.pp("q_a_layernorm");
Some(RmsNorm::new(
qn_vb.take_tensor("weight")?,
config.norm_eps,
false,
))
} else {
None
};
(Some(q_down), q_up, q_norm)
} else {
let q_up = vb.pp("q_proj").take_maybe_quant_linear("weight", None)?;
(None, q_up, None)
};
let kv_compress = vb
.pp("kv_a_proj_with_mqa")
.take_maybe_quant_linear("weight", None)?;
let kv_norm = if config.use_norm {
let mut kvn_vb = vb.pp("kv_a_layernorm");
Some(RmsNorm::new(
kvn_vb.take_tensor("weight")?,
config.norm_eps,
false,
))
} else {
None
};
let kv_decompress = vb.pp("kv_b_proj").take_maybe_quant_linear("weight", None)?;
let o_proj = vb.pp("o_proj").take_maybe_quant_linear("weight", None)?;
let rope = RoPE::<R>::precompute_freqs(
config.max_seq_len,
config.rope_head_dim,
config.rope_theta,
None,
vb.device(),
);
let scale = 1.0 / (qk_dim as f64).sqrt();
Ok(Self {
q_down,
q_up,
q_norm,
kv_compress,
kv_norm,
kv_decompress,
o_proj,
rope,
num_heads: nh,
head_dim: config.head_dim,
head_dim_v: config.head_dim_v,
rope_head_dim: config.rope_head_dim,
kv_lora_rank: config.kv_lora_rank,
scale,
})
}
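
    /// Full-sequence MLA forward pass (no KV cache).
    ///
    /// `hidden` is `[batch, seq_len, hidden_size]` and the output has the
    /// same shape. Queries take the optional low-rank path, keys/values are
    /// decompressed from the shared latent, RoPE is applied to the positional
    /// slices only, and causal scaled-dot-product attention runs at scale
    /// `1 / sqrt(head_dim + rope_head_dim)`.
    ///
    /// A minimal usage sketch; `MyRuntime`, `device`, and `client` are
    /// illustrative stand-ins for a concrete runtime setup, not names defined
    /// here:
    ///
    /// ```ignore
    /// let cfg = MlaConfig::deepseek_v2(4096, 32, 512, 1536, 64, 4096);
    /// let mla = Mla::<MyRuntime>::from_config(&cfg, &device)?;
    /// // hidden: Var with shape [batch, seq_len, 4096]
    /// let out = mla.forward(&client, &hidden)?;
    /// assert_eq!(out.shape(), hidden.shape());
    /// ```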
pub fn forward<C>(&self, client: &C, hidden: &Var<R>) -> Result<Var<R>>
where
C: RuntimeClient<R>
+ TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ NormalizationOps<R>
+ ShapeOps<R>
+ BinaryOps<R>
+ TypeConversionOps<R>
+ QuantMatmulOps<R>
+ RoPEOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
let shape = hidden.shape().to_vec();
let batch = shape[0];
let seq_len = shape[1];
let qk_dim = self.head_dim + self.rope_head_dim;
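        // Query path: optional low-rank down-projection -> RMSNorm -> up-projection.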
let q = if let Some(q_down) = &self.q_down {
let q_latent = q_down.forward(client, hidden)?;
let q_latent = if let Some(norm) = &self.q_norm {
norm.forward(client, &q_latent)?
} else {
q_latent
};
self.q_up.forward(client, &q_latent)?
} else {
self.q_up.forward(client, hidden)?
};
let q = var_reshape(&q, &[batch, seq_len, self.num_heads, qk_dim]).map_err(Error::Numr)?;
let q = var_permute(&q, &[0, 2, 1, 3]).map_err(Error::Numr)?;
let q = var_contiguous(&q);
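        // Split each query head into its non-positional ("nope") and RoPE slices.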
let q_nope = var_narrow(&q, 3, 0, self.head_dim).map_err(Error::Numr)?;
let q_nope = var_contiguous(&q_nope);
let q_pe = var_narrow(&q, 3, self.head_dim, self.rope_head_dim).map_err(Error::Numr)?;
let q_pe = var_contiguous(&q_pe);
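        // Joint KV compression: one projection yields the KV latent (first
        // `kv_lora_rank` features) plus the shared RoPE key slice.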
let kv_compressed = self.kv_compress.forward(client, hidden)?;
let c_kv = var_narrow(&kv_compressed, 2, 0, self.kv_lora_rank).map_err(Error::Numr)?;
let c_kv = var_contiguous(&c_kv);
let k_pe_raw = var_narrow(&kv_compressed, 2, self.kv_lora_rank, self.rope_head_dim)
.map_err(Error::Numr)?;
let k_pe_raw = var_contiguous(&k_pe_raw);
let c_kv = if let Some(norm) = &self.kv_norm {
norm.forward(client, &c_kv)?
} else {
c_kv
};
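        // Decompress the normalized latent into per-head keys ("nope" part) and values.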
let kv = self.kv_decompress.forward(client, &c_kv)?;
let kv = var_reshape(
&kv,
&[
batch,
seq_len,
self.num_heads,
self.head_dim + self.head_dim_v,
],
)
.map_err(Error::Numr)?;
let kv = var_permute(&kv, &[0, 2, 1, 3]).map_err(Error::Numr)?;
let kv = var_contiguous(&kv);
let k_nope = var_narrow(&kv, 3, 0, self.head_dim).map_err(Error::Numr)?;
let k_nope = var_contiguous(&k_nope);
let v = var_narrow(&kv, 3, self.head_dim, self.head_dim_v).map_err(Error::Numr)?;
let v = var_contiguous(&v);
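        // The RoPE key slice is shared across heads (MQA-style); broadcast it
        // to every head.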
let k_pe = var_reshape(&k_pe_raw, &[batch, 1, seq_len, self.rope_head_dim])
.map_err(Error::Numr)?;
let k_pe = var_broadcast_to(&k_pe, &[batch, self.num_heads, seq_len, self.rope_head_dim])
.map_err(Error::Numr)?;
let k_pe = var_contiguous(&k_pe);
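        // Rotate only the positional slices of q and k.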
let q_pe = apply_rope_impl(client, &q_pe, self.rope.cos_cache(), self.rope.sin_cache())?;
let k_pe = apply_rope_impl(client, &k_pe, self.rope.cos_cache(), self.rope.sin_cache())?;
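        // Reassemble full queries/keys: [nope | rope] along the last dim.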
let q = var_cat(&[&q_nope, &q_pe], 3, client).map_err(Error::Numr)?;
let k = var_cat(&[&k_nope, &k_pe], 3, client).map_err(Error::Numr)?;
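        // Causal SDPA over [batch, heads, seq, qk_dim], then fold heads back
        // into [batch, seq, heads * head_dim_v] for the output projection.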
let attn_out = scaled_dot_product_attention_impl(client, &q, &k, &v, self.scale, true)?;
let attn_out = var_permute(&attn_out, &[0, 2, 1, 3]).map_err(Error::Numr)?;
let attn_out = var_contiguous(&attn_out);
let attn_out = var_reshape(
&attn_out,
&[batch, seq_len, self.num_heads * self.head_dim_v],
)
.map_err(Error::Numr)?;
self.o_proj.forward(client, &attn_out)
}
}
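
/// Materializes a contiguous copy of a `Var`'s tensor.
///
/// Note: the result is wrapped in a fresh `Var`, so the autograd graph does
/// not appear to be threaded through the copy; adequate for inference-style
/// forward passes.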
fn var_contiguous<R: Runtime>(v: &Var<R>) -> Var<R> {
Var::new(v.tensor().contiguous(), v.requires_grad())
}