use ferrum_types::{FerrumError, Result};
use super::{default_stream, CudaBackend};
/// Selects the CUDA storage types used when the KV cache dtype is `KvInt8`.
impl crate::backend::BackendKvDtype<crate::backend::KvInt8> for CudaBackend {
/// Storage for the INT8 K/V payload (`CudaSlice<i8>` wrapped in an `Option`).
type KvBuffer = OptionalCudaInt8;
/// Storage for the f16 scales that accompany the INT8 payload.
type KvScales = OptionalCudaScalesF16;
}
/// An optionally-allocated device buffer of `i8` values.
///
/// The derived `Default` yields the unallocated state (`None`); use
/// [`OptionalCudaInt8::alloc`] to obtain a backed buffer.
#[derive(Default)]
pub struct OptionalCudaInt8(pub Option<cudarc::driver::CudaSlice<i8>>);

impl OptionalCudaInt8 {
    /// Allocates `len` elements with `alloc_zeros` on the stream returned by
    /// `default_stream()`.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn alloc(len: usize) -> Self {
        Self(Some(
            default_stream()
                .alloc_zeros::<i8>(len)
                .expect("alloc int8 KV buffer"),
        ))
    }

    /// Shared access to the underlying device slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has not been allocated.
    pub fn buffer(&self) -> &cudarc::driver::CudaSlice<i8> {
        match &self.0 {
            Some(slice) => slice,
            None => panic!("OptionalCudaInt8 not allocated"),
        }
    }

    /// Exclusive access to the underlying device slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has not been allocated.
    pub fn buffer_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i8> {
        match &mut self.0 {
            Some(slice) => slice,
            None => panic!("OptionalCudaInt8 not allocated"),
        }
    }
}
/// An optionally-allocated device buffer of `half::f16` scale values.
///
/// The derived `Default` yields the unallocated state (`None`); use
/// [`OptionalCudaScalesF16::alloc`] to obtain a backed buffer.
#[derive(Default)]
pub struct OptionalCudaScalesF16(pub Option<cudarc::driver::CudaSlice<half::f16>>);

impl OptionalCudaScalesF16 {
    /// Allocates `len` elements with `alloc_zeros` on the stream returned by
    /// `default_stream()`.
    ///
    /// # Panics
    ///
    /// Panics if the allocation fails.
    pub fn alloc(len: usize) -> Self {
        let zeroed = default_stream().alloc_zeros::<half::f16>(len);
        Self(Some(zeroed.expect("alloc int8 KV scales")))
    }

    /// Shared access to the underlying device slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has not been allocated.
    pub fn buffer(&self) -> &cudarc::driver::CudaSlice<half::f16> {
        let Some(scales) = self.0.as_ref() else {
            panic!("OptionalCudaScalesF16 not allocated");
        };
        scales
    }

    /// Exclusive access to the underlying device slice.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has not been allocated.
    pub fn buffer_mut(&mut self) -> &mut cudarc::driver::CudaSlice<half::f16> {
        let Some(scales) = self.0.as_mut() else {
            panic!("OptionalCudaScalesF16 not allocated");
        };
        scales
    }
}
impl crate::backend::BackendInt8KvOps for CudaBackend {
    /// Allocates the paged INT8 KV cache for one layer by delegating to
    /// `KvCacheQuant::new_paged_cuda` with the same arguments.
    fn alloc_paged_int8_layer(
        max_blocks_per_seq: usize,
        block_size: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> crate::backend::KvCacheQuant<Self, crate::backend::KvInt8> {
        crate::backend::KvCacheQuant::<CudaBackend, crate::backend::KvInt8>::new_paged_cuda(
            max_blocks_per_seq,
            block_size,
            num_kv_heads,
            head_dim,
        )
    }

    /// Appends `tokens` new K/V entries to one layer's paged INT8 cache.
    ///
    /// For every new token at global position `cache_len_before + t` the flat
    /// destination slot is
    /// `paged_block_indices[pos / block_size] * block_size + pos % block_size`.
    /// The resulting slot mapping is copied to the device and handed, together
    /// with the f16 inputs and the layer buffers, to
    /// `launch_int8_kv_cache_append`.
    ///
    /// A call with `tokens == 0` is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) when:
    /// - `block_size` is zero,
    /// - any of the layer buffers / scales is unallocated,
    /// - a token falls into a logical block that `paged_block_indices` does
    ///   not cover,
    /// - a flat slot index does not fit in `i32`,
    /// - the host-to-device copy or the launch itself fails.
    fn int8_kv_append_paged(
        ctx: &mut Self::Context,
        k_in: &Self::Buffer,
        v_in: &Self::Buffer,
        layer_k: &mut <Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvBuffer,
        layer_v: &mut <Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvBuffer,
        layer_k_scales: &mut <Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvScales,
        layer_v_scales: &mut <Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvScales,
        paged_block_indices: &[u32],
        cache_len_before: usize,
        tokens: usize,
        block_size: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Result<()> {
        if tokens == 0 {
            return Ok(());
        }
        if block_size == 0 {
            return Err(FerrumError::model(
                "int8_kv_append_paged: block_size must be non-zero",
            ));
        }
        // Validate the destination buffers before doing any device work so a
        // misconfigured layer does not cost a host-to-device copy.
        if layer_k.0.is_none() {
            return Err(FerrumError::model(
                "int8_kv_append_paged: layer_k not allocated",
            ));
        }
        if layer_v.0.is_none() || layer_k_scales.0.is_none() || layer_v_scales.0.is_none() {
            return Err(FerrumError::model(
                "int8_kv_append_paged: layer_v / scales not allocated",
            ));
        }

        let stream = ctx.stream.clone();

        // Translate each new token's logical position into a flat physical slot.
        let mut slot_mapping_host: Vec<i32> = Vec::with_capacity(tokens);
        for t in 0..tokens {
            let global_pos = cache_len_before + t;
            let block_logical = global_pos / block_size;
            let slot_in_block = global_pos % block_size;
            let block_physical = paged_block_indices
                .get(block_logical)
                .map(|&idx| idx as usize)
                .ok_or_else(|| {
                    FerrumError::model(format!(
                        "int8_kv_append_paged: logical block {block_logical} out of range \
                         ({} blocks mapped)",
                        paged_block_indices.len()
                    ))
                })?;
            // The slot mapping is i32 on the device side; reject anything that
            // would be silently truncated by a plain `as` cast.
            let slot = block_physical
                .checked_mul(block_size)
                .and_then(|base| base.checked_add(slot_in_block))
                .and_then(|flat| i32::try_from(flat).ok())
                .ok_or_else(|| {
                    FerrumError::model(format!(
                        "int8_kv_append_paged: slot index overflows i32 \
                         (physical block {block_physical}, block_size {block_size})"
                    ))
                })?;
            slot_mapping_host.push(slot);
        }

        let slot_mapping = stream
            .clone_htod(&slot_mapping_host)
            .map_err(|e| FerrumError::model(format!("htod slot_mapping: {e}")))?;

        crate::int8_kv::launch_int8_kv_cache_append(
            &ctx.ctx,
            k_in.as_f16(),
            v_in.as_f16(),
            layer_k.buffer_mut(),
            layer_v.buffer_mut(),
            layer_k_scales.buffer_mut(),
            layer_v_scales.buffer_mut(),
            &slot_mapping,
            tokens,
            num_kv_heads,
            head_dim,
        )
        .map_err(|e| FerrumError::model(format!("launch_int8_kv_cache_append: {e}")))?;
        Ok(())
    }

    /// Runs decode attention over one layer's paged INT8 KV cache.
    ///
    /// The first `max(ceil(valid_kv_len / block_size), 1)` entries of
    /// `block_table` are viewed as `i32` and passed, with the query, the layer
    /// buffers / scales and the shape parameters, to
    /// `launch_int8_paged_decode_attention`, which writes into `output`.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) when `block_size` is zero, when
    /// any layer buffer / scale is unallocated, when the `block_table` view
    /// cannot be created, or when the launch fails.
    fn int8_paged_decode_attention(
        ctx: &mut Self::Context,
        q: &Self::Buffer,
        layer_k: &<Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvBuffer,
        layer_v: &<Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvBuffer,
        layer_k_scales: &<Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvScales,
        layer_v_scales: &<Self as crate::backend::BackendKvDtype<crate::backend::KvInt8>>::KvScales,
        block_table: &Self::Buffer,
        output: &mut Self::Buffer,
        num_q_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        valid_kv_len: usize,
        block_size: usize,
        scale: f32,
    ) -> Result<()> {
        if block_size == 0 {
            return Err(FerrumError::model(
                "int8_paged_decode_attention: block_size must be non-zero",
            ));
        }
        // Mirror the append path: report unallocated storage as an error rather
        // than panicking inside `buffer()`.
        if layer_k.0.is_none()
            || layer_v.0.is_none()
            || layer_k_scales.0.is_none()
            || layer_v_scales.0.is_none()
        {
            return Err(FerrumError::model(
                "int8_paged_decode_attention: layer_k / layer_v / scales not allocated",
            ));
        }

        // Always view at least one block-table entry, even for an empty cache.
        let n_blocks = valid_kv_len.div_ceil(block_size).max(1);
        // SAFETY: the view only reinterprets the block-table words as `i32`,
        // which has the same size and alignment as the `u32` entries the table
        // is expected to hold. NOTE(review): this assumes `transmute` returns
        // `None` when `n_blocks` exceeds the buffer — confirm against
        // `Buffer::transmute`.
        let bt_i32_view = unsafe { block_table.transmute::<i32>(n_blocks) }
            .ok_or_else(|| FerrumError::model("block_table transmute<i32> failed"))?;

        crate::int8_kv::launch_int8_paged_decode_attention(
            &ctx.ctx,
            q.as_f16(),
            layer_k.buffer(),
            layer_v.buffer(),
            layer_k_scales.buffer(),
            layer_v_scales.buffer(),
            &bt_i32_view,
            output.as_f16_mut(),
            num_q_heads,
            num_kv_heads,
            head_dim,
            valid_kv_len,
            block_size,
            scale,
        )
        .map_err(|e| FerrumError::model(format!("launch_int8_paged_decode_attention: {e}")))?;
        Ok(())
    }
}
impl crate::backend::KvCacheQuant<CudaBackend, crate::backend::KvInt8> {
    /// Builds an empty paged INT8 KV cache for one layer on the CUDA backend.
    ///
    /// Sizes are derived from the arguments:
    /// - token capacity = `max_blocks_per_seq * block_size`
    /// - scale entries (per K and per V) = capacity * `num_kv_heads`
    /// - INT8 elements (per K and per V) = scale entries * `head_dim`
    ///
    /// A `U32` block table with `max_blocks_per_seq` entries and a one-element
    /// `U32` context-length buffer are allocated as well; the latter is written
    /// with `0` and the temporary context is synced before returning. The cache
    /// starts with `len == 0` and no paged block indices.
    ///
    /// # Panics
    ///
    /// Panics if any of the size products overflows `usize` (previously this
    /// wrapped silently in release builds and under-allocated the buffers), or
    /// if a device allocation performed by `OptionalCudaInt8::alloc` /
    /// `OptionalCudaScalesF16::alloc` fails.
    pub fn new_paged_cuda(
        max_blocks_per_seq: usize,
        block_size: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Self {
        use crate::backend::Backend;

        let pool_tokens = max_blocks_per_seq
            .checked_mul(block_size)
            .expect("paged int8 KV cache: token capacity overflows usize");
        let scale_count = pool_tokens
            .checked_mul(num_kv_heads)
            .expect("paged int8 KV cache: scale count overflows usize");
        let elem_count = scale_count
            .checked_mul(head_dim)
            .expect("paged int8 KV cache: element count overflows usize");

        let block_table =
            <CudaBackend as Backend>::alloc_typed(crate::backend::Dtype::U32, max_blocks_per_seq);
        let mut context_lens = <CudaBackend as Backend>::alloc_typed(crate::backend::Dtype::U32, 1);

        // Zero the context length through a throwaway context and sync it,
        // presumably so the write has landed before the cache is handed out.
        let mut init_ctx = <CudaBackend as Backend>::new_context();
        <CudaBackend as Backend>::write_typed::<u32>(&mut init_ctx, &mut context_lens, &[0u32]);
        <CudaBackend as Backend>::sync(&mut init_ctx);

        crate::backend::KvCacheQuant {
            k: OptionalCudaInt8::alloc(elem_count),
            v: OptionalCudaInt8::alloc(elem_count),
            k_scales: OptionalCudaScalesF16::alloc(scale_count),
            v_scales: OptionalCudaScalesF16::alloc(scale_count),
            len: 0,
            capacity: pool_tokens,
            num_kv_heads,
            head_dim,
            block_size,
            block_table: Some(block_table),
            context_lens: Some(context_lens),
            paged_block_indices: Vec::new(),
            _kv_dtype: std::marker::PhantomData,
        }
    }
}