use super::flash_standard::{StandardAttentionClient, StandardAttnConfig, standard_attention_bwd};
use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::{IndexingOps, ScatterReduceOp};
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Borrowed view of a paged key/value cache together with its block table.
///
/// `paged_attention_bwd_impl` flattens both pools to `[pool_rows, head_dim]`
/// and addresses them with rows of the form
/// `(physical_block * block_size + offset) * num_kv_heads + kv_head`, which
/// corresponds to a `[num_blocks, block_size, num_kv_heads, head_dim]` pool.
pub struct PagedKv<'kv, R: Runtime> {
    /// Key pool holding every physical block.
    pub k_blocks: &'kv Tensor<R>,
    /// Value pool; flattened with the same row count as `k_blocks`.
    pub v_blocks: &'kv Tensor<R>,
    /// Logical-to-physical block mapping, read back as `i32` and indexed as
    /// `[batch, max_num_blocks]`.
    pub block_table: &'kv Tensor<R>,
}
/// Static parameters of a paged-attention backward call.
#[derive(Clone, Copy, Debug)]
pub struct PagedAttnConfig {
    /// Number of query heads; forwarded to the dense attention backward.
    pub num_heads: usize,
    /// Number of key/value heads stored in the paged pools.
    pub num_kv_heads: usize,
    /// Number of key/value positions gathered for every sequence.
    pub seq_len_k: usize,
    /// Size of the last (feature) dimension of each pool row.
    pub head_dim: usize,
    /// Number of token slots in one physical block.
    pub block_size: usize,
    /// Whether the dense attention backward applies causal masking.
    pub causal: bool,
}
/// Translates a block table into flat row indices of a paged KV pool.
///
/// The pool is addressed as a 2-D `[pool_rows, head_dim]` matrix in which the
/// row of token `t`, KV head `h` of sequence `b` is
/// `(physical_block * block_size + t % block_size) * num_kv_heads + h`, with
/// `physical_block = block_table[b * max_num_blocks + t / block_size]`.
///
/// The returned vector has `batch_size * num_kv_heads * seq_len_k` entries
/// ordered `(batch, kv_head, token)`, so it can be reshaped to a dense
/// `[batch, kv_head, token]` layout after gathering. It is empty when any of
/// `batch_size`, `num_kv_heads` or `seq_len_k` is zero.
///
/// # Panics
///
/// Panics (when there is at least one row to produce) if:
/// - `block_size` is zero;
/// - `seq_len_k` needs more blocks than `max_num_blocks`, which would
///   otherwise read the block-table row of the *next* sequence;
/// - `block_table` is too short for the blocks that are read;
/// - a referenced block id is negative.
fn paged_kv_row_indices(
    block_table: &[i32],
    batch_size: usize,
    num_kv_heads: usize,
    seq_len_k: usize,
    block_size: usize,
    max_num_blocks: usize,
) -> Vec<i64> {
    if batch_size == 0 || num_kv_heads == 0 || seq_len_k == 0 {
        return Vec::new();
    }
    assert!(block_size > 0, "block_size must be non-zero");

    // Number of logical blocks touched per sequence (ceil division; seq_len_k > 0 here).
    let blocks_needed = (seq_len_k - 1) / block_size + 1;
    assert!(
        blocks_needed <= max_num_blocks,
        "seq_len_k {} needs {} blocks of size {} but each block-table row holds only {}",
        seq_len_k,
        blocks_needed,
        block_size,
        max_num_blocks
    );

    let mut rows = Vec::with_capacity(batch_size * num_kv_heads * seq_len_k);
    // Per-token base rows (head 0) of the current sequence; reused across sequences.
    let mut token_bases = Vec::with_capacity(seq_len_k);
    for b in 0..batch_size {
        let start = b * max_num_blocks;
        let seq_blocks = &block_table[start..start + blocks_needed];

        // The block lookup does not depend on the head, so resolve it once per token.
        token_bases.clear();
        token_bases.extend((0..seq_len_k).map(|t| {
            let id = seq_blocks[t / block_size];
            assert!(
                id >= 0,
                "negative physical block id {} in block table of sequence {}",
                id,
                b
            );
            // `id` is non-negative, so the cast cannot wrap.
            (id as usize * block_size + t % block_size) * num_kv_heads
        }));

        for h_kv in 0..num_kv_heads {
            rows.extend(token_bases.iter().map(|&base| (base + h_kv) as i64));
        }
    }
    rows
}
pub fn paged_attention_bwd_impl<R, C>(
client: &C,
dout: &Tensor<R>,
q: &Tensor<R>,
kv: &PagedKv<R>,
output: &Tensor<R>,
cfg: PagedAttnConfig,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>
where
R: Runtime<DType = DType>,
C: StandardAttentionClient<R> + IndexingOps<R>,
{
let batch_size = q.shape()[0];
let max_num_blocks = kv.block_table.shape()[1];
let d = cfg.head_dim;
let nkv = cfg.num_kv_heads;
let sk = cfg.seq_len_k;
let device = q.device();
let pool_rows: usize = kv.k_blocks.shape().iter().product::<usize>() / d;
let n = batch_size * nkv * sk;
let bt = kv.block_table.to_vec::<i32>();
let rows = paged_kv_row_indices(&bt, batch_size, nkv, sk, cfg.block_size, max_num_blocks);
let idx = Tensor::<R>::from_slice(&rows, &[n], device);
let k_pool = kv.k_blocks.reshape(&[pool_rows, d]).map_err(Error::Numr)?;
let v_pool = kv.v_blocks.reshape(&[pool_rows, d]).map_err(Error::Numr)?;
let k_dense = client
.index_select(&k_pool, 0, &idx)
.map_err(Error::Numr)?
.reshape(&[batch_size, nkv, sk, d])
.map_err(Error::Numr)?;
let v_dense = client
.index_select(&v_pool, 0, &idx)
.map_err(Error::Numr)?
.reshape(&[batch_size, nkv, sk, d])
.map_err(Error::Numr)?;
let (dq, dk_dense, dv_dense) = standard_attention_bwd(
client,
dout,
q,
&k_dense,
&v_dense,
output,
StandardAttnConfig {
num_heads: cfg.num_heads,
num_kv_heads: nkv,
causal: cfg.causal,
window_size: 0,
},
)?;
let dk_blocks = scatter_kv_grad(
client,
&dk_dense,
&idx,
pool_rows,
n,
d,
kv.k_blocks.shape(),
)?;
let dv_blocks = scatter_kv_grad(
client,
&dv_dense,
&idx,
pool_rows,
n,
d,
kv.v_blocks.shape(),
)?;
Ok((dq, dk_blocks, dv_blocks))
}
/// Scatter-adds dense K/V gradient rows back into a pool-shaped gradient.
///
/// `grad_dense` is viewed as `[n, d]`; row `i` is added to row `idx[i]` of a
/// zero-initialised `[pool_rows, d]` accumulator, which is finally reshaped to
/// `pool_shape`. Rows of the pool that `idx` never references stay zero, and
/// rows referenced more than once receive the sum of their contributions.
///
/// The accumulator is created with the dtype of `grad_dense` so that
/// non-F32 gradients are not scattered into a mismatched destination.
///
/// # Errors
///
/// Propagates any error from the reshape, broadcast, contiguous or
/// scatter-reduce operations.
fn scatter_kv_grad<R, C>(
    client: &C,
    grad_dense: &Tensor<R>,
    idx: &Tensor<R>,
    pool_rows: usize,
    n: usize,
    d: usize,
    pool_shape: &[usize],
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: StandardAttentionClient<R> + IndexingOps<R>,
{
    let grad_rows = grad_dense.reshape(&[n, d]).map_err(Error::Numr)?;

    // The scatter is given an index of the same `[n, d]` shape as the source,
    // so the per-row index is repeated across the feature dimension.
    let index2d = idx
        .reshape(&[n, 1])
        .map_err(Error::Numr)?
        .broadcast_to(&[n, d])
        .map_err(Error::Numr)?
        .contiguous()?;

    // Match the gradient's dtype instead of assuming F32.
    let dst = Tensor::<R>::zeros(&[pool_rows, d], grad_dense.dtype(), grad_dense.device());

    // NOTE(review): the trailing `true` is presumably `include_self`; with a
    // zero destination either setting yields the same sums — confirm against numr.
    let scattered = client
        .scatter_reduce(&dst, 0, &index2d, &grad_rows, ScatterReduceOp::Sum, true)
        .map_err(Error::Numr)?;

    scattered.reshape(pool_shape).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::impl_generic::attention::flash_standard::standard_attention_fwd;
    use crate::test_utils::cpu_setup;
    use numr::ops::IndexingOps;
    use numr::runtime::cpu::CpuRuntime;

    /// Checks the flat row formula for two sequences, two KV heads and a
    /// sequence length that is not a multiple of the block size.
    #[test]
    fn row_indices_match_layout() {
        let bt = [0, 1, 2, 3];
        let rows = paged_kv_row_indices(&bt, 2, 2, 3, 2, 2);
        assert_eq!(rows.len(), 12);
        assert_eq!(&rows[0..3], &[0, 2, 4]);
        assert_eq!(&rows[3..6], &[1, 3, 5]);
        assert_eq!(&rows[6..9], &[8, 10, 12]);
        // Second head of the second sequence was previously left unchecked.
        assert_eq!(&rows[9..12], &[9, 11, 13]);
    }

    /// Compares the paged backward against the dense backward run on
    /// manually gathered K/V. The block table assigns distinct blocks to each
    /// sequence, so re-gathering the pool gradients with the same indices
    /// must reproduce the dense K and V gradients.
    #[test]
    fn paged_bwd_matches_dense_reference() {
        let (client, device) = cpu_setup();
        let (b, h, nkv, d) = (2usize, 2usize, 1usize, 4usize);
        let block_size = 2usize;
        let seq_len_k = 4usize;
        let seq_len_q = 4usize;
        let max_num_blocks = 2usize;
        let num_blocks = 4usize;
        let nrows = num_blocks * block_size * nkv;

        let pool: Vec<f32> = (0..nrows * d).map(|i| (i as f32 * 0.017).sin()).collect();
        let k_blocks =
            Tensor::<CpuRuntime>::from_slice(&pool, &[num_blocks, block_size, nkv, d], &device);
        let vpool: Vec<f32> = (0..nrows * d).map(|i| (i as f32 * 0.023).cos()).collect();
        let v_blocks =
            Tensor::<CpuRuntime>::from_slice(&vpool, &[num_blocks, block_size, nkv, d], &device);
        let bt = Tensor::<CpuRuntime>::from_slice(&[0i32, 1, 2, 3], &[b, max_num_blocks], &device);

        let qd: Vec<f32> = (0..b * h * seq_len_q * d)
            .map(|i| (i as f32 * 0.01).cos())
            .collect();
        let q = Tensor::<CpuRuntime>::from_slice(&qd, &[b, h, seq_len_q, d], &device);
        let dod: Vec<f32> = (0..b * h * seq_len_q * d)
            .map(|i| (i as f32 * 0.03).sin())
            .collect();
        let dout = Tensor::<CpuRuntime>::from_slice(&dod, &[b, h, seq_len_q, d], &device);

        let cfg = PagedAttnConfig {
            num_heads: h,
            num_kv_heads: nkv,
            seq_len_k,
            head_dim: d,
            block_size,
            causal: true,
        };

        // Dense reference: gather K/V by hand with the same row indices.
        let rows =
            paged_kv_row_indices(&[0, 1, 2, 3], b, nkv, seq_len_k, block_size, max_num_blocks);
        let n = rows.len();
        let idx = Tensor::<CpuRuntime>::from_slice(&rows, &[n], &device);
        let k_pool = k_blocks.reshape(&[nrows, d]).unwrap();
        let v_pool = v_blocks.reshape(&[nrows, d]).unwrap();
        let k_dense = client
            .index_select(&k_pool, 0, &idx)
            .unwrap()
            .reshape(&[b, nkv, seq_len_k, d])
            .unwrap();
        let v_dense = client
            .index_select(&v_pool, 0, &idx)
            .unwrap()
            .reshape(&[b, nkv, seq_len_k, d])
            .unwrap();
        let scfg = StandardAttnConfig {
            num_heads: h,
            num_kv_heads: nkv,
            causal: true,
            window_size: 0,
        };
        let (output, _lse) = standard_attention_fwd(&client, &q, &k_dense, &v_dense, scfg).unwrap();
        let (dq_ref, dk_ref_dense, dv_ref_dense) =
            standard_attention_bwd(&client, &dout, &q, &k_dense, &v_dense, &output, scfg).unwrap();

        let kv = PagedKv {
            k_blocks: &k_blocks,
            v_blocks: &v_blocks,
            block_table: &bt,
        };
        let (dq, dk_blocks, dv_blocks) =
            paged_attention_bwd_impl(&client, &dout, &q, &kv, &output, cfg).unwrap();

        // dq must be identical to the dense path.
        let dq_a: Vec<f32> = dq.to_vec();
        let dq_b: Vec<f32> = dq_ref.to_vec();
        assert_eq!(dq_a.len(), dq_b.len(), "dq length mismatch");
        for (x, y) in dq_a.iter().zip(dq_b.iter()) {
            assert!((x - y).abs() < 1e-5, "dq mismatch: {x} vs {y}");
        }

        // dk: re-gather the pool gradient and compare with the dense gradient.
        let dk_pool = dk_blocks.reshape(&[nrows, d]).unwrap();
        let dk_regathered = client.index_select(&dk_pool, 0, &idx).unwrap();
        let dk_re: Vec<f32> = dk_regathered.to_vec();
        let dk_ref: Vec<f32> = dk_ref_dense.reshape(&[n, d]).unwrap().to_vec();
        assert_eq!(dk_re.len(), dk_ref.len(), "dk length mismatch");
        for (x, y) in dk_re.iter().zip(dk_ref.iter()) {
            assert!((x - y).abs() < 1e-5, "dk mismatch: {x} vs {y}");
        }

        // dv: same check; this path was previously computed but never verified.
        let dv_pool = dv_blocks.reshape(&[nrows, d]).unwrap();
        let dv_regathered = client.index_select(&dv_pool, 0, &idx).unwrap();
        let dv_re: Vec<f32> = dv_regathered.to_vec();
        let dv_ref: Vec<f32> = dv_ref_dense.reshape(&[n, d]).unwrap().to_vec();
        assert_eq!(dv_re.len(), dv_ref.len(), "dv length mismatch");
        for (x, y) in dv_re.iter().zip(dv_ref.iter()) {
            assert!((x - y).abs() < 1e-5, "dv mismatch: {x} vs {y}");
        }
    }
}