#[cfg(feature = "cuda")]
pub use cuda_impl::*;
#[cfg(feature = "cuda")]
mod cuda_impl {
use cudarc::driver::sys;
use numr::runtime::Graph;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
use crate::error::{Error, Result};
/// Device-resident 1-element scalar tensors that captured CUDA graphs read.
/// They are updated in place between graph replays (see `update`), so the
/// graph itself never has to be re-captured.
pub struct DeviceScalars {
/// Current KV sequence length, stored as a single i32 element on device.
pub seq_len_k: Tensor<CudaRuntime>,
/// Current KV-cache write position, stored as a single i32 element on device.
pub write_pos: Tensor<CudaRuntime>,
}
impl DeviceScalars {
/// Allocate both device scalars on `device`, each seeded with
/// `initial_seq_len` stored as one i32 element.
// NOTE(review): assumes `initial_seq_len` fits in i32 — TODO confirm upstream bound.
pub fn new(initial_seq_len: usize, device: &numr::runtime::cuda::CudaDevice) -> Self {
    let seed = [initial_seq_len as i32];
    Self {
        seq_len_k: Tensor::<CudaRuntime>::from_slice(&seed, &[1], device),
        write_pos: Tensor::<CudaRuntime>::from_slice(&seed, &[1], device),
    }
}
/// Raw device pointer to the 1-element `seq_len_k` tensor, for passing
/// directly to CUDA driver calls or kernel launches.
pub fn seq_len_k_ptr(&self) -> u64 {
self.seq_len_k.ptr()
}
/// Raw device pointer to the 1-element `write_pos` tensor.
pub fn write_pos_ptr(&self) -> u64 {
self.write_pos.ptr()
}
/// Stage the RoPE cos/sin values for `position` by enqueueing async
/// device-to-device copies from the precomputed caches into the graph's
/// input slices. No synchronization happens here; the copies are ordered
/// on the client's stream.
#[allow(clippy::too_many_arguments)]
pub fn update_rope_slices(
    &self,
    client: &CudaClient,
    rope_cos_cache: &Tensor<CudaRuntime>,
    rope_sin_cache: &Tensor<CudaRuntime>,
    cos_slice: &numr::autograd::Var<CudaRuntime>,
    sin_slice: &numr::autograd::Var<CudaRuntime>,
    position: usize,
    half_dim: usize,
) -> Result<()> {
    let cu_stream = client.stream().cu_stream();
    // Both caches are laid out as `half_dim` elements per position.
    let elem_offset = position * half_dim;
    for (cache, slice) in [(rope_cos_cache, cos_slice), (rope_sin_cache, sin_slice)] {
        copy_rope_slice_async(cache, elem_offset, slice.tensor(), half_dim, cu_stream)?;
    }
    Ok(())
}
/// Push fresh values into the device scalars for the next decode step:
/// `seq_len_k <- seq_len + 1`, `write_pos <- seq_len`. Both writes are
/// 32-bit async memsets on the client's stream.
pub fn update(&self, client: &CudaClient, seq_len: usize) -> Result<()> {
    let stream = client.stream().cu_stream();
    // Both scalars are written the same way; `label` only feeds the error text.
    let set_scalar = |ptr: u64, value: u32, label: &str| -> Result<()> {
        // SAFETY: `ptr` comes from a live 1-element device tensor owned by
        // `self`, so a single 32-bit memset stays in bounds.
        let result = unsafe { sys::cuMemsetD32Async(ptr, value, 1, stream) };
        if result != sys::CUresult::CUDA_SUCCESS {
            return Err(Error::InferenceError {
                reason: format!("cuMemsetD32Async for {} failed: {:?}", label, result),
            });
        }
        Ok(())
    };
    set_scalar(self.seq_len_k.ptr(), (seq_len + 1) as u32, "seq_len_k")?;
    set_scalar(self.write_pos.ptr(), seq_len as u32, "write_pos")?;
    Ok(())
}
}
/// Enqueue an async device-to-device copy of `head_dim` consecutive f32
/// elements from `src`, starting at element `src_elem_off`, into the start
/// of `dst`.
///
/// NOTE(review): offset and length are scaled by `size_of::<f32>()`, so this
/// assumes both tensors hold f32 data — TODO confirm at the call sites.
fn copy_rope_slice_async(
    src: &Tensor<CudaRuntime>,
    src_elem_off: usize,
    dst: &Tensor<CudaRuntime>,
    head_dim: usize,
    stream: sys::CUstream,
) -> Result<()> {
    const ELEM_BYTES: usize = std::mem::size_of::<f32>();
    let byte_count = head_dim * ELEM_BYTES;
    let src_addr = src.ptr() + (src_elem_off * ELEM_BYTES) as u64;
    // SAFETY: both pointers come from live device tensors; the caller is
    // responsible for `src_elem_off + head_dim` staying within `src` and
    // for `dst` holding at least `head_dim` elements.
    let result = unsafe { sys::cuMemcpyDtoDAsync_v2(dst.ptr(), src_addr, byte_count, stream) };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(Error::InferenceError {
            reason: format!("cuMemcpyDtoDAsync_v2 for RoPE slice failed: {:?}", result),
        });
    }
    Ok(())
}
/// Argmax `logits` along its last axis and asynchronously copy the first
/// resulting token id (one i64) into `out`.
///
/// NOTE(review): assumes `client.argmax` yields i64 ids and `logits` has
/// rank >= 1 — TODO confirm. Also note `token_ids` is a temporary dropped
/// before the async copy is known to complete; presumably the allocator
/// frees stream-ordered, so this is safe — verify.
pub fn argmax_to_buf(
    client: &CudaClient,
    logits: &Tensor<CudaRuntime>,
    out: &Tensor<CudaRuntime>,
) -> numr::error::Result<()> {
    use numr::ops::traits::IndexingOps;
    let reduce_dim = logits.shape().len() - 1;
    let token_ids = client.argmax(logits, reduce_dim, false)?;
    // SAFETY: both pointers come from live device tensors each holding at
    // least one i64; the copy moves exactly 8 bytes.
    let result = unsafe {
        sys::cuMemcpyAsync(
            out.ptr(),
            token_ids.ptr(),
            std::mem::size_of::<i64>(),
            client.stream().cu_stream(),
        )
    };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(numr::error::Error::Backend(format!(
            "argmax_to_buf cuMemcpyAsync failed: {:?}",
            result
        )));
    }
    Ok(())
}
/// Argmax `logits` along its last axis and asynchronously copy the first
/// `batch_size` resulting token ids (i64 each) into `out`.
///
/// NOTE(review): same assumptions as `argmax_to_buf` — i64 ids, rank >= 1,
/// and a stream-ordered free of the temporary `token_ids` — TODO confirm.
pub fn batch_argmax_to_buf(
    client: &CudaClient,
    logits: &Tensor<CudaRuntime>,
    out: &Tensor<CudaRuntime>,
    batch_size: usize,
) -> numr::error::Result<()> {
    use numr::ops::traits::IndexingOps;
    let reduce_dim = logits.shape().len() - 1;
    let token_ids = client.argmax(logits, reduce_dim, false)?;
    let byte_count = batch_size * std::mem::size_of::<i64>();
    // SAFETY: both pointers come from live device tensors; the caller must
    // ensure each holds at least `batch_size` i64 elements.
    let result = unsafe {
        sys::cuMemcpyAsync(out.ptr(), token_ids.ptr(), byte_count, client.stream().cu_stream())
    };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(numr::error::Error::Backend(format!(
            "batch_argmax_to_buf cuMemcpyAsync failed: {:?}",
            result
        )));
    }
    Ok(())
}
/// A captured CUDA graph for single-token decode, plus the device buffers
/// that must be refreshed before each replay (see `pre_replay_and_launch`).
pub struct DecodeGraph {
/// The captured graph, replayed once per decode step.
pub graph: numr::runtime::cuda::CudaGraph,
/// Device scalars (`seq_len_k`, `write_pos`) updated before each replay.
pub device_scalars: DeviceScalars,
/// Graph input: the token id (i64) consumed by the captured forward pass.
pub token_buf: Tensor<CudaRuntime>,
/// Graph input: RoPE cos values for the current position.
pub cos_slice: Tensor<CudaRuntime>,
/// Graph input: RoPE sin values for the current position.
pub sin_slice: Tensor<CudaRuntime>,
/// Precomputed RoPE cos cache, `head_dim` elements per position.
pub rope_cos_cache: Tensor<CudaRuntime>,
/// Precomputed RoPE sin cache, `head_dim` elements per position.
pub rope_sin_cache: Tensor<CudaRuntime>,
/// Graph output: the sampled token id (i64) written by the captured pass.
pub next_token_buf: Tensor<CudaRuntime>,
/// Elements per RoPE cache row — presumably half the rotary dimension,
/// since `update_rope_slices` callers pass `half_dim`; TODO confirm.
pub head_dim: usize,
/// Host-side sequence position, incremented after each successful replay.
pub seq_len: usize,
}
impl DecodeGraph {
/// Write `token` (i64) into the device-side `next_token_buf` as two 32-bit
/// async memsets: low word at offset 0, high word at offset 4.
///
/// NOTE(review): the ptr / ptr+4 split assumes little-endian i64 layout on
/// the device, which holds for CUDA GPUs.
pub fn seed_next_token(&self, client: &CudaClient, token: i64) -> Result<()> {
    let bits = token as u64;
    let stream = client.stream().cu_stream();
    // (byte offset, 32-bit word, label for the error message)
    let words = [(0u64, bits as u32, "lo"), (4u64, (bits >> 32) as u32, "hi")];
    for (offset, word, label) in words {
        // SAFETY: `next_token_buf` holds at least one i64 (8 bytes), so
        // both 4-byte memsets stay in bounds.
        let result =
            unsafe { sys::cuMemsetD32Async(self.next_token_buf.ptr() + offset, word, 1, stream) };
        if result != sys::CUresult::CUDA_SUCCESS {
            return Err(Error::InferenceError {
                reason: format!("seed_next_token cuMemsetD32Async {} failed: {:?}", label, result),
            });
        }
    }
    Ok(())
}
/// Refresh every device input for the next decode step, then replay the
/// captured graph: (1) copy the previously sampled token into the graph's
/// token input, (2) push the device scalars for the current `seq_len`,
/// (3) stage the RoPE cos/sin rows for this position, (4) launch the graph
/// and advance `seq_len`. All work is enqueued on the client's stream.
pub fn pre_replay_and_launch(&mut self, client: &CudaClient) -> Result<()> {
    let stream = client.stream().cu_stream();
    // SAFETY: both buffers are live device tensors holding at least one i64.
    let result = unsafe {
        sys::cuMemcpyDtoDAsync_v2(
            self.token_buf.ptr(),
            self.next_token_buf.ptr(),
            std::mem::size_of::<i64>(),
            stream,
        )
    };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(Error::InferenceError {
            reason: format!("cuMemcpyDtoDAsync_v2 for token_buf failed: {:?}", result),
        });
    }
    self.device_scalars.update(client, self.seq_len)?;
    // RoPE caches are laid out `head_dim` elements per position.
    let row_offset = self.seq_len * self.head_dim;
    copy_rope_slice_async(&self.rope_cos_cache, row_offset, &self.cos_slice, self.head_dim, stream)?;
    copy_rope_slice_async(&self.rope_sin_cache, row_offset, &self.sin_slice, self.head_dim, stream)?;
    self.graph.launch()?;
    self.seq_len += 1;
    Ok(())
}
}
/// A captured CUDA graph for single-token decode against a paged KV cache.
/// Identical to `DecodeGraph` except for the extra `slot_mapping` buffer
/// that tells the paged cache where to write the new KV entry.
pub struct PagedDecodeGraph {
/// The captured graph, replayed once per decode step.
pub graph: numr::runtime::cuda::CudaGraph,
/// Device scalars (`seq_len_k`, `write_pos`) updated before each replay.
pub device_scalars: DeviceScalars,
/// Graph input: the token id (i64) consumed by the captured forward pass.
pub token_buf: Tensor<CudaRuntime>,
/// Graph input: RoPE cos values for the current position.
pub cos_slice: Tensor<CudaRuntime>,
/// Graph input: RoPE sin values for the current position.
pub sin_slice: Tensor<CudaRuntime>,
/// Precomputed RoPE cos cache, `head_dim` elements per position.
pub rope_cos_cache: Tensor<CudaRuntime>,
/// Precomputed RoPE sin cache, `head_dim` elements per position.
pub rope_sin_cache: Tensor<CudaRuntime>,
/// Graph output: the sampled token id (i64) written by the captured pass.
pub next_token_buf: Tensor<CudaRuntime>,
/// Device-side KV slot index, refreshed with `seq_len` before each replay
/// via a 32-bit memset — presumably a single u32/i32 slot; TODO confirm.
pub slot_mapping: Tensor<CudaRuntime>,
/// Elements per RoPE cache row — presumably half the rotary dimension;
/// TODO confirm against the RoPE kernel.
pub head_dim: usize,
/// Host-side sequence position, incremented after each successful replay.
pub seq_len: usize,
}
impl PagedDecodeGraph {
/// Write `token` (i64) into the device-side `next_token_buf` as two 32-bit
/// async memsets: low word at offset 0, high word at offset 4.
///
/// NOTE(review): the ptr / ptr+4 split assumes little-endian i64 layout on
/// the device, which holds for CUDA GPUs.
pub fn seed_next_token(&self, client: &CudaClient, token: i64) -> Result<()> {
    let bits = token as u64;
    let stream = client.stream().cu_stream();
    // (byte offset, 32-bit word, label for the error message)
    let words = [(0u64, bits as u32, "lo"), (4u64, (bits >> 32) as u32, "hi")];
    for (offset, word, label) in words {
        // SAFETY: `next_token_buf` holds at least one i64 (8 bytes), so
        // both 4-byte memsets stay in bounds.
        let result =
            unsafe { sys::cuMemsetD32Async(self.next_token_buf.ptr() + offset, word, 1, stream) };
        if result != sys::CUresult::CUDA_SUCCESS {
            return Err(Error::InferenceError {
                reason: format!("PagedDecodeGraph seed {} failed: {:?}", label, result),
            });
        }
    }
    Ok(())
}
/// Refresh every device input for the next decode step, then replay the
/// captured graph: (1) copy the previously sampled token into the graph's
/// token input, (2) push the device scalars for the current `seq_len`,
/// (3) point `slot_mapping` at the current KV slot, (4) stage the RoPE
/// cos/sin rows for this position, (5) launch the graph and advance
/// `seq_len`. All work is enqueued on the client's stream.
pub fn pre_replay_and_launch(&mut self, client: &CudaClient) -> Result<()> {
    let stream = client.stream().cu_stream();
    // SAFETY: both buffers are live device tensors holding at least one i64.
    let result = unsafe {
        sys::cuMemcpyDtoDAsync_v2(
            self.token_buf.ptr(),
            self.next_token_buf.ptr(),
            std::mem::size_of::<i64>(),
            stream,
        )
    };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(Error::InferenceError {
            reason: format!("PagedDecodeGraph token_buf copy failed: {:?}", result),
        });
    }
    self.device_scalars.update(client, self.seq_len)?;
    // SAFETY: `slot_mapping` is a live device tensor holding at least one
    // 32-bit element; the memset writes exactly one.
    let result =
        unsafe { sys::cuMemsetD32Async(self.slot_mapping.ptr(), self.seq_len as u32, 1, stream) };
    if result != sys::CUresult::CUDA_SUCCESS {
        return Err(Error::InferenceError {
            reason: format!("PagedDecodeGraph slot_mapping update failed: {:?}", result),
        });
    }
    // RoPE caches are laid out `head_dim` elements per position.
    let row_offset = self.seq_len * self.head_dim;
    copy_rope_slice_async(&self.rope_cos_cache, row_offset, &self.cos_slice, self.head_dim, stream)?;
    copy_rope_slice_async(&self.rope_sin_cache, row_offset, &self.sin_slice, self.head_dim, stream)?;
    self.graph.launch()?;
    self.seq_len += 1;
    Ok(())
}
}
}