#![allow(clippy::similar_names)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::many_single_char_names)]
#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
use candle_core::{D, DType, Module, Result, Tensor};
use candle_nn::{Linear, VarBuilder, linear, linear_no_bias};
use crate::rope::apply_rope;
/// Loads a layer norm that has a learned scale but no learned shift.
///
/// The scale vector (length `dim`) is read from the `gamma` entry of `vb`.
///
/// # Errors
/// Propagates any failure to fetch `gamma` from `vb`.
pub fn layer_norm_no_bias(dim: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    vb.get(dim, "gamma")
        .map(|weight| candle_nn::LayerNorm::new_no_bias(weight, eps))
}
/// Loads a layer norm with both a learned scale and a learned shift.
///
/// The scale is read from `gamma` and the shift from `beta`. Both are vectors of length
/// `dim`, and `gamma` is fetched first.
///
/// # Errors
/// Propagates any failure to fetch `gamma` or `beta` from `vb`.
pub fn layer_norm_affine(dim: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    let (weight, bias) = (vb.get(dim, "gamma")?, vb.get(dim, "beta")?);
    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}
/// Multi-head scaled dot-product attention layer.
///
/// Which projection fields are populated depends on the constructor used:
/// `new_self` sets `qkv`; `new_self_split` sets `q_proj`, `k_proj` and `v_proj`;
/// `new_cross` sets `q_proj` and `kv_proj`. All other projection fields stay `None`.
#[derive(Debug)]
pub struct Attention {
    /// Fused query/key/value projection (weights under `to_qkv`); only set by `new_self`.
    qkv: Option<Linear>,
    /// Query projection (weights under `to_q`); set by `new_self_split` and `new_cross`.
    q_proj: Option<Linear>,
    /// Key projection (weights under `to_k`); only set by `new_self_split`.
    k_proj: Option<Linear>,
    /// Value projection (weights under `to_v`); only set by `new_self_split`.
    v_proj: Option<Linear>,
    /// Fused key/value projection applied to the cross-attention context (weights under
    /// `to_kv`); only set by `new_cross`.
    kv_proj: Option<Linear>,
    /// Output projection applied after the heads are merged back together.
    out_proj: Linear,
    /// Number of heads the query projection is split into.
    num_heads: usize,
    /// Width of each attention head.
    head_dim: usize,
    /// `true` for cross-attention: keys/values come from `kv_input` and RoPE is not applied.
    is_cross: bool,
    /// When `true`, queries and keys are L2-normalised along the head dimension before the
    /// dot product.
    qk_norm: bool,
    /// Multiplier applied to the raw attention scores (`1 / sqrt(head_dim)`).
    scale: f64,
}
impl Attention {
pub fn new_self(
embed_dim: usize,
num_heads: usize,
head_dim: usize,
qk_norm: bool,
vb: VarBuilder,
) -> Result<Self> {
assert_eq!(embed_dim, num_heads * head_dim);
let qkv = linear_no_bias(embed_dim, embed_dim * 3, vb.pp("to_qkv"))?;
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("to_out"))?;
#[allow(clippy::cast_precision_loss)]
let scale = 1.0 / (head_dim as f64).sqrt();
Ok(Self {
qkv: Some(qkv),
q_proj: None,
k_proj: None,
v_proj: None,
kv_proj: None,
out_proj,
num_heads,
head_dim,
is_cross: false,
qk_norm,
scale,
})
}
pub fn new_self_split(
embed_dim: usize,
num_heads: usize,
head_dim: usize,
qk_norm: bool,
vb: VarBuilder,
) -> Result<Self> {
assert_eq!(embed_dim, num_heads * head_dim);
let q_proj = linear(embed_dim, embed_dim, vb.pp("to_q"))?;
let k_proj = linear(embed_dim, embed_dim, vb.pp("to_k"))?;
let v_proj = linear(embed_dim, embed_dim, vb.pp("to_v"))?;
let out_proj = linear(embed_dim, embed_dim, vb.pp("to_out").pp("0"))?;
#[allow(clippy::cast_precision_loss)]
let scale = 1.0 / (head_dim as f64).sqrt();
Ok(Self {
qkv: None,
q_proj: Some(q_proj),
k_proj: Some(k_proj),
v_proj: Some(v_proj),
kv_proj: None,
out_proj,
num_heads,
head_dim,
is_cross: false,
qk_norm,
scale,
})
}
pub fn new_cross(
embed_dim: usize,
kv_dim: usize,
num_heads: usize,
head_dim: usize,
qk_norm: bool,
vb: VarBuilder,
) -> Result<Self> {
assert_eq!(embed_dim, num_heads * head_dim);
let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("to_q"))?;
let kv_proj = linear_no_bias(kv_dim, kv_dim * 2, vb.pp("to_kv"))?;
let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("to_out"))?;
#[allow(clippy::cast_precision_loss)]
let scale = 1.0 / (head_dim as f64).sqrt();
Ok(Self {
qkv: None,
q_proj: Some(q_proj),
k_proj: None,
v_proj: None,
kv_proj: Some(kv_proj),
out_proj,
num_heads,
head_dim,
is_cross: true,
qk_norm,
scale,
})
}
fn split_heads(&self, x: &Tensor) -> Result<Tensor> {
let (b, t, _) = x.dims3()?;
x.reshape((b, t, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn merge_heads(x: &Tensor) -> Result<Tensor> {
let (b, h, t, d) = x.dims4()?;
x.transpose(1, 2)?.contiguous()?.reshape((b, t, h * d))
}
fn l2_normalize(t: &Tensor) -> Result<Tensor> {
let sq = t.sqr()?;
let sum = sq.sum_keepdim(D::Minus1)?;
let norm = (sum + 1e-12)?.sqrt()?;
t.broadcast_div(&norm)
}
pub fn forward(
&self,
x: &Tensor,
kv_input: Option<&Tensor>,
rope: Option<(&Tensor, &Tensor)>,
) -> Result<Tensor> {
let (b, t_q, _) = x.dims3()?;
let (q, k, v) = if self.is_cross {
let kv = kv_input.expect("cross-attention requires a kv_input");
let q = self
.q_proj
.as_ref()
.expect("cross-attn has q_proj")
.forward(x)?;
let kv = self
.kv_proj
.as_ref()
.expect("cross-attn has kv_proj")
.forward(kv)?;
let kv_dim = kv.dim(D::Minus1)? / 2;
let k = kv.narrow(D::Minus1, 0, kv_dim)?;
let v = kv.narrow(D::Minus1, kv_dim, kv_dim)?;
(q, k, v)
} else if let Some(qkv) = self.qkv.as_ref() {
let qkv = qkv.forward(x)?;
let d = qkv.dim(D::Minus1)? / 3;
let q = qkv.narrow(D::Minus1, 0, d)?;
let k = qkv.narrow(D::Minus1, d, d)?;
let v = qkv.narrow(D::Minus1, 2 * d, d)?;
(q, k, v)
} else {
let q = self
.q_proj
.as_ref()
.expect("split self-attn has q_proj")
.forward(x)?;
let k = self
.k_proj
.as_ref()
.expect("split self-attn has k_proj")
.forward(x)?;
let v = self
.v_proj
.as_ref()
.expect("split self-attn has v_proj")
.forward(x)?;
(q, k, v)
};
let mut q = self.split_heads(&q)?;
let mut k = self.split_heads(&k)?;
let v = self.split_heads(&v)?;
if self.qk_norm {
q = Self::l2_normalize(&q)?;
k = Self::l2_normalize(&k)?;
}
if !self.is_cross
&& let Some((cos, sin)) = rope
{
let orig_dtype = q.dtype();
let q_f = q.to_dtype(DType::F32)?;
let k_f = k.to_dtype(DType::F32)?;
let q_rot = apply_rope(&q_f, cos, sin)?.to_dtype(orig_dtype)?;
let k_rot = apply_rope(&k_f, cos, sin)?.to_dtype(orig_dtype)?;
q = q_rot;
k = k_rot;
}
let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
let attn_scores = q.matmul(&k_t)?;
let attn_scores = (attn_scores * self.scale)?;
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_scores)?;
let out = attn_probs.matmul(&v.contiguous()?)?;
let out = Self::merge_heads(&out)?;
debug_assert_eq!(out.dim(1)?, t_q);
debug_assert_eq!(out.dim(0)?, b);
self.out_proj.forward(&out)
}
}
/// Gated feed-forward block computing `out_proj(value * silu(gate))`.
///
/// `value` and `gate` are the two halves of `glu_proj(x)`, split along the last axis.
#[derive(Debug)]
pub struct FeedForward {
    /// Bias-free input projection (weights under `ff.0.proj`). Its output is twice
    /// `inner_dim` wide and holds the value half followed by the gate half.
    glu_proj: Linear,
    /// Bias-free projection from `inner_dim` back to `embed_dim` (weights under `ff.2`).
    out_proj: Linear,
}
impl FeedForward {
    /// Loads the gated feed-forward weights.
    ///
    /// The input projection `ff.0.proj` maps `embed_dim -> 2 * inner_dim` and the output
    /// projection `ff.2` maps `inner_dim -> embed_dim`. Both are bias-free.
    ///
    /// # Errors
    /// Propagates any failure to load a weight from `vb`.
    pub fn new(embed_dim: usize, inner_dim: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            glu_proj: linear_no_bias(embed_dim, inner_dim * 2, vb.pp("ff.0.proj"))?,
            out_proj: linear_no_bias(inner_dim, embed_dim, vb.pp("ff.2"))?,
        })
    }

    /// Computes `out_proj(value * silu(gate))`.
    ///
    /// `value` is the first half of `glu_proj(x)` along the last axis and `gate` is the
    /// second half.
    ///
    /// # Errors
    /// Propagates any tensor-operation failure.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let hidden = self.glu_proj.forward(x)?;
        let width = hidden.dim(D::Minus1)? / 2;
        let gate = hidden.narrow(D::Minus1, width, width)?;
        let value = hidden.narrow(D::Minus1, 0, width)?;
        let gated = value.mul(&candle_nn::ops::silu(&gate)?)?;
        self.out_proj.forward(&gated)
    }
}