use crate::distributed::tensor_parallel::{ColumnParallelLinear, RowParallelLinear};
use crate::error::{Error, Result};
use crate::inference::kv_cache::KvCache;
use crate::model::traits::ModelClient;
use crate::nn::{RmsNorm, RoPE};
use crate::ops::impl_generic::attention::multi_head_attention_impl;
use crate::ops::impl_generic::attention::rope::apply_rope_interleaved_impl;
use numr::autograd::{Var, var_add, var_mul, var_narrow, var_permute, var_reshape, var_silu};
use numr::dtype::DType;
use numr::ops::{
ActivationOps, BinaryOps, CompareOps, ConditionalOps, IndexingOps, ReduceOps, ScalarOps,
ShapeOps, TensorOps, UnaryOps,
};
use numr::runtime::Runtime;
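/// One tensor-parallel Llama decoder layer: pre-attention RMSNorm, self-attention,
/// post-attention RMSNorm, and a SwiGLU MLP, each wrapped in a residual connection.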
pub(super) struct LlamaBlockTp<R: Runtime> {
pub(super) input_layernorm: RmsNorm<R>,
pub(super) self_attn: LlamaAttentionTp<R>,
pub(super) post_attention_layernorm: RmsNorm<R>,
pub(super) mlp: LlamaMlpTp<R>,
}
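/// Tensor-parallel multi-head attention: column-parallel Q/K/V projections and a
/// row-parallel output projection. Supports grouped-query attention
/// (`num_kv_heads < num_heads`) and optional per-head Q/K RMSNorm.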
pub(super) struct LlamaAttentionTp<R: Runtime> {
pub(super) q_proj: ColumnParallelLinear<R>,
pub(super) k_proj: ColumnParallelLinear<R>,
pub(super) v_proj: ColumnParallelLinear<R>,
pub(super) o_proj: RowParallelLinear<R>,
    pub(super) num_heads: usize,
    pub(super) num_kv_heads: usize,
    pub(super) head_dim: usize,
pub(super) q_norm: Option<RmsNorm<R>>,
pub(super) k_norm: Option<RmsNorm<R>>,
}
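/// Tensor-parallel SwiGLU MLP: column-parallel gate/up projections and a row-parallel
/// down projection.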
pub(super) struct LlamaMlpTp<R: Runtime> {
pub(super) gate_proj: ColumnParallelLinear<R>,
pub(super) up_proj: ColumnParallelLinear<R>,
pub(super) down_proj: RowParallelLinear<R>,
}
impl<R: Runtime<DType = DType>> LlamaBlockTp<R> {
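    /// Forward pass over a full sequence without a KV cache:
    /// `x + attn(norm(x))`, then `h + mlp(norm(h))`.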
pub(super) fn forward<C>(&self, client: &C, x: &Var<R>, rope: &RoPE<R>) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>,
{
let normed = self.input_layernorm.forward(client, x)?;
let attn_out = self.self_attn.forward(client, &normed, rope)?;
let h = var_add(x, &attn_out, client).map_err(Error::Numr)?;
let normed = self.post_attention_layernorm.forward(client, &h)?;
let mlp_out = self.mlp.forward(client, &normed)?;
var_add(&h, &mlp_out, client).map_err(Error::Numr)
}
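    /// Same block structure as [`Self::forward`], but attention reads from and appends to
    /// `kv_cache`, applying RoPE at the absolute `position`.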
pub(super) fn forward_with_kv_cache<C>(
&self,
client: &C,
x: &Var<R>,
rope: &RoPE<R>,
kv_cache: &mut KvCache<R>,
position: usize,
) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>,
{
let normed = self.input_layernorm.forward(client, x)?;
let attn_out = self
.self_attn
.forward_with_kv_cache(client, &normed, rope, kv_cache, position)?;
let h = var_add(x, &attn_out, client).map_err(Error::Numr)?;
let normed = self.post_attention_layernorm.forward(client, &h)?;
let mlp_out = self.mlp.forward(client, &normed)?;
var_add(&h, &mlp_out, client).map_err(Error::Numr)
}
}
impl<R: Runtime<DType = DType>> LlamaAttentionTp<R> {
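    /// Attention over a full sequence: project Q/K/V, reshape to per-head layout, apply
    /// optional Q/K norms and interleaved RoPE, expand KV heads for GQA, then run multi-head
    /// attention and the row-parallel output projection.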
pub(super) fn forward<C>(&self, client: &C, x: &Var<R>, rope: &RoPE<R>) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>,
{
let shape = x.shape().to_vec();
let batch = shape[0];
let seq_len = shape[1];
let q = self.q_proj.forward(client, x)?;
let k = self.k_proj.forward(client, x)?;
let v = self.v_proj.forward(client, x)?;
        // Split each projection into per-head views: [batch, seq, heads, head_dim].
        let q = var_reshape(&q, &[batch, seq_len, self.num_heads, self.head_dim])
            .map_err(Error::Numr)?;
        let k = var_reshape(&k, &[batch, seq_len, self.num_kv_heads, self.head_dim])
            .map_err(Error::Numr)?;
        let v = var_reshape(&v, &[batch, seq_len, self.num_kv_heads, self.head_dim])
            .map_err(Error::Numr)?;
        // [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] for attention.
        let q = var_permute(&q, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let q = var_contiguous(&q);
        let k = var_permute(&k, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let k = var_contiguous(&k);
        let v = var_permute(&v, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let v = var_contiguous(&v);
let q = match &self.q_norm {
Some(norm) => norm.forward(client, &q)?,
None => q,
};
let k = match &self.k_norm {
Some(norm) => norm.forward(client, &k)?,
None => k,
};
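        // Apply interleaved RoPE to Q and K using positions 0..seq_len.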
let q = apply_rope_interleaved_impl(client, &q, rope.cos_cache(), rope.sin_cache())?;
let k = apply_rope_interleaved_impl(client, &k, rope.cos_cache(), rope.sin_cache())?;
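        // Grouped-query attention: repeat K/V heads so each query head has a matching KV head.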
let (k, v) = if self.num_kv_heads < self.num_heads {
let repeat = self.num_heads / self.num_kv_heads;
let k = repeat_kv(&k, repeat).map_err(Error::Numr)?;
let v = repeat_kv(&v, repeat).map_err(Error::Numr)?;
(k, v)
} else {
(k, v)
};
let attn_out = multi_head_attention_impl(client, &q, &k, &v, None, self.num_heads)?;
        // Back to [batch, seq, heads, head_dim], then flatten the heads for the output projection.
        let attn_out = var_permute(&attn_out, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let attn_out = var_contiguous(&attn_out);
        let attn_out = var_reshape(&attn_out, &[batch, seq_len, self.num_heads * self.head_dim])
            .map_err(Error::Numr)?;
self.o_proj.forward(client, &attn_out)
}
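    /// Incremental attention: RoPE is applied at the absolute `position`, the new K/V are
    /// appended to `kv_cache`, and flash attention runs the new queries against the full cache.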
pub(super) fn forward_with_kv_cache<C>(
&self,
client: &C,
x: &Var<R>,
rope: &RoPE<R>,
kv_cache: &mut KvCache<R>,
position: usize,
) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>,
{
let shape = x.shape().to_vec();
let batch = shape[0];
let seq_len = shape[1];
let q = self.q_proj.forward(client, x)?;
let k = self.k_proj.forward(client, x)?;
let v = self.v_proj.forward(client, x)?;
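        // Split each projection into per-head views: [batch, seq, heads, head_dim].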
let q = var_reshape(&q, &[batch, seq_len, self.num_heads, self.head_dim])
.map_err(Error::Numr)?;
let k = var_reshape(&k, &[batch, seq_len, self.num_kv_heads, self.head_dim])
.map_err(Error::Numr)?;
let v = var_reshape(&v, &[batch, seq_len, self.num_kv_heads, self.head_dim])
.map_err(Error::Numr)?;
        // [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim] for attention.
        let q = var_permute(&q, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let q = var_contiguous(&q);
        let k = var_permute(&k, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let k = var_contiguous(&k);
        let v = var_permute(&v, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let v = var_contiguous(&v);
let q = match &self.q_norm {
Some(norm) => norm.forward(client, &q)?,
None => q,
};
let k = match &self.k_norm {
Some(norm) => norm.forward(client, &k)?,
None => k,
};
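        // Slice the RoPE caches so rotation uses the absolute positions of the new tokens.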
let cos_offset = var_narrow(rope.cos_cache(), 0, position, seq_len).map_err(Error::Numr)?;
let sin_offset = var_narrow(rope.sin_cache(), 0, position, seq_len).map_err(Error::Numr)?;
let q = apply_rope_interleaved_impl(client, &q, &cos_offset, &sin_offset)?;
let k = apply_rope_interleaved_impl(client, &k, &cos_offset, &sin_offset)?;
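        // Append the new K/V to the cache, then attend over the full cached sequence.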
kv_cache.update_fused(k.tensor(), v.tensor(), client)?;
let kv_seq_len = kv_cache.seq_len();
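        // More than one new token means this call is a prefill chunk rather than single-token decode.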
let is_prefill = seq_len > 1;
let (attn_out, _lse) = client.flash_attention_fwd(
q.tensor(),
kv_cache.k_cache_raw(),
kv_cache.v_cache_raw(),
self.num_heads,
self.num_kv_heads,
self.head_dim,
is_prefill,
0,
Some(kv_seq_len),
)?;
        // Wrap the raw attention output in a non-differentiable Var, then restore the
        // [batch, seq, hidden] layout for the output projection.
        let attn_out = Var::new(attn_out, false);
        let attn_out = var_permute(&attn_out, &[0, 2, 1, 3]).map_err(Error::Numr)?;
        let attn_out = var_contiguous(&attn_out);
        let attn_out = var_reshape(&attn_out, &[batch, seq_len, self.num_heads * self.head_dim])
            .map_err(Error::Numr)?;
self.o_proj.forward(client, &attn_out)
}
}
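/// Return a `Var` wrapping a contiguous copy of `v`'s tensor, preserving its `requires_grad` flag.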
fn var_contiguous<R: Runtime>(v: &Var<R>) -> Var<R> {
Var::new(v.tensor().contiguous(), v.requires_grad())
}
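/// Expand KV heads for grouped-query attention:
/// `[b, h_kv, s, d]` -> `[b, h_kv * repeat, s, d]` via broadcast along a new repeat axis.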
fn repeat_kv<R: Runtime>(x: &Var<R>, repeat: usize) -> numr::error::Result<Var<R>> {
if repeat == 1 {
return Ok(x.clone());
}
let shape = x.shape();
let [b, h_kv, s, d] = [shape[0], shape[1], shape[2], shape[3]];
let expanded = x.tensor().reshape(&[b, h_kv, 1, s, d])?;
let expanded = expanded.broadcast_to(&[b, h_kv, repeat, s, d])?;
let result = expanded.contiguous().reshape(&[b, h_kv * repeat, s, d])?;
Ok(Var::new(result, x.requires_grad()))
}
impl<R: Runtime<DType = DType>> LlamaMlpTp<R> {
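    /// SwiGLU forward: `down_proj(silu(gate_proj(x)) * up_proj(x))`.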
pub(super) fn forward<C>(&self, client: &C, x: &Var<R>) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>,
{
let gate = self.gate_proj.forward(client, x)?;
let up = self.up_proj.forward(client, x)?;
let gate_silu = var_silu(&gate, client).map_err(Error::Numr)?;
let hidden = var_mul(&gate_silu, &up, client).map_err(Error::Numr)?;
self.down_proj.forward(client, &hidden)
}
}