use crate::error::{Error, Result};
use crate::inference::memory::BlockId;
use crate::ops::cuda::kernels::{self, PREFIX_CACHE_LOOKUP_MODULE};
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::LaunchConfig;
use numr::dtype::DType;
use numr::runtime::Device;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
/// Number of CUDA threads per block used when launching the lookup kernel.
const BLOCK_SIZE: u32 = 256;
/// Looks up a batch of query hashes in a GPU-resident hash table and returns
/// a 1D `I32` tensor with one slot per query, written by the
/// `prefix_cache_lookup` kernel (negative values denote "not found" — see
/// [`result_to_options`], which maps `v < 0` to `None`).
///
/// # Arguments
/// * `client` - CUDA client whose stream/context the kernel is launched on.
/// * `query_hashes` - 1D tensor of hashes to look up.
/// * `table_keys` - 1D tensor of table keys; length must be a power of two.
/// * `table_values` - 1D tensor of table values, same length as `table_keys`.
///
/// # Errors
/// Returns `Error::InvalidArgument` for shape/capacity/device mismatches and
/// `Error::KernelError` if the kernel launch fails.
pub fn gpu_prefix_cache_lookup(
    client: &CudaClient,
    query_hashes: &Tensor<CudaRuntime>,
    table_keys: &Tensor<CudaRuntime>,
    table_values: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
    let qshape = query_hashes.shape();
    if qshape.len() != 1 {
        return Err(Error::InvalidArgument {
            arg: "query_hashes",
            reason: format!("expected 1D tensor, got {}D", qshape.len()),
        });
    }
    let num_queries = qshape[0];
    let kshape = table_keys.shape();
    let vshape = table_values.shape();
    if kshape.len() != 1 || vshape.len() != 1 || kshape[0] != vshape[0] {
        return Err(Error::InvalidArgument {
            arg: "table_keys / table_values",
            reason: format!(
                "both must be 1D with the same length; got keys {:?} values {:?}",
                kshape, vshape
            ),
        });
    }
    let capacity = kshape[0];
    // Power-of-two capacity lets the kernel probe with `hash & (capacity - 1)`
    // instead of a modulo.
    if capacity == 0 || (capacity & (capacity - 1)) != 0 {
        return Err(Error::InvalidArgument {
            arg: "table_keys",
            reason: format!("capacity {} is not a power of two", capacity),
        });
    }
    let device = query_hashes.device();
    let device_index = device.id();
    // All three tensors must live on the same device; passing the kernel a
    // pointer into another device's memory would read garbage or fault.
    if table_keys.device().id() != device_index || table_values.device().id() != device_index {
        return Err(Error::InvalidArgument {
            arg: "table_keys / table_values",
            reason: "must reside on the same device as query_hashes".to_string(),
        });
    }
    let out_block_ids = Tensor::<CudaRuntime>::empty(&[num_queries], DType::I32, device);
    // No queries: skip the launch. A zero grid dimension is rejected by the
    // CUDA driver, so launching with grid_size == 0 would fail spuriously.
    if num_queries == 0 {
        return Ok(out_block_ids);
    }
    let module =
        kernels::get_or_load_module(client.context(), device_index, PREFIX_CACHE_LOOKUP_MODULE)?;
    let func = kernels::get_kernel_function(&module, "prefix_cache_lookup")?;
    // One thread per query, rounded up to whole blocks.
    let grid_size = (num_queries as u32).div_ceil(BLOCK_SIZE);
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    };
    let q_ptr = query_hashes.ptr();
    let k_ptr = table_keys.ptr();
    let v_ptr = table_values.ptr();
    let out_ptr = out_block_ids.ptr();
    let cap_i32 = capacity as i32;
    let nq_i32 = num_queries as i32;
    // SAFETY: all four device pointers come from tensors borrowed for the
    // whole function, so they remain valid while the launch is enqueued on
    // the client's stream; the argument order (q, k, v, out, capacity,
    // num_queries) must match the kernel's signature — verify against the
    // `prefix_cache_lookup` kernel source if it changes.
    unsafe {
        let mut builder = client.stream().launch_builder(&func);
        builder.arg(&q_ptr);
        builder.arg(&k_ptr);
        builder.arg(&v_ptr);
        builder.arg(&out_ptr);
        builder.arg(&cap_i32);
        builder.arg(&nq_i32);
        builder.launch(cfg).map_err(|e| Error::KernelError {
            reason: format!("prefix_cache_lookup kernel launch failed: {:?}", e),
        })?;
    }
    Ok(out_block_ids)
}
/// Converts the raw `i32` output of the lookup kernel into per-query options:
/// negative slots become `None` ("not found"), non-negative slots become
/// `Some(block_id)`.
pub fn result_to_options(result: &Tensor<CudaRuntime>) -> Vec<Option<BlockId>> {
    let raw: Vec<i32> = result.to_vec::<i32>();
    let mut options = Vec::with_capacity(raw.len());
    for slot in raw {
        // The kernel writes a negative sentinel for misses.
        let entry = if slot >= 0 { Some(slot as BlockId) } else { None };
        options.push(entry);
    }
    options
}