#![allow(unused_variables, dead_code, unused_imports, unused_mut)]
pub mod collective;
pub mod fa2_ffi;
#[cfg(feature = "fa2-source")]
pub mod fa2_source;
pub mod graph;
pub mod int8_kv;
pub mod moe;
pub mod paged;
pub mod quant;
pub mod cublas;
pub mod cuda_decode;
pub mod cuda_graph;
pub mod decode_attention;
pub mod decode_buffers;
pub mod fused_add_rms_norm;
pub mod fused_silu_mul;
pub mod gpu_paged_kv;
pub mod marlin;
pub mod nccl_comm;
pub mod residual_add;
pub mod rms_norm;
pub mod rope;
pub mod tp_decode;
pub mod weight_store;
#[cfg(feature = "triton-kernels")]
pub mod triton_add_bias;
#[cfg(feature = "triton-kernels")]
pub mod triton_fused_add_rms_norm;
#[cfg(feature = "triton-kernels")]
pub mod triton_fused_moe;
#[cfg(feature = "triton-kernels")]
pub mod triton_fused_silu_mul;
#[cfg(feature = "triton-kernels")]
pub mod triton_gelu;
#[cfg(feature = "triton-kernels")]
pub mod triton_layer_norm;
#[cfg(feature = "triton-kernels")]
pub mod triton_meta;
#[cfg(feature = "triton-kernels")]
pub mod triton_ptx;
#[cfg(feature = "triton-kernels")]
pub mod triton_residual_add;
#[cfg(feature = "triton-kernels")]
pub mod triton_residual_add_inplace;
#[cfg(feature = "triton-kernels")]
pub mod triton_rms_norm;
#[cfg(feature = "triton-kernels")]
pub mod triton_softmax;
#[cfg(feature = "triton-kernels")]
pub mod triton_w4a16;
#[cfg(feature = "vllm-marlin")]
pub mod vllm_marlin;
#[cfg(feature = "vllm-paged-attn-v2")]
pub mod vllm_paged_attn;
pub(super) use super::MAX_LAYERS_FOR_GRAPH;
pub use int8_kv::{OptionalCudaInt8, OptionalCudaScalesF16};
#[cfg(feature = "marlin")]
pub use quant::pregrow_marlin_gather_scratch;
pub use quant::{marlin_gemm_with_perm, GptqStoreCuda};
use super::{
AttnConfig, Backend, BackendCollective, BackendGraph, BackendMoeFused, BackendPagedKv,
BackendQuantGguf, BackendQuantMarlin, QuantKind, QuantWeights, ReduceOp,
};
use crate::ptx;
use cudarc::cublas::CudaBlas;
use cudarc::driver::{
CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream, DeviceRepr, LaunchConfig,
PushKernelArg,
};
use cudarc::nvrtc::Ptx;
use ferrum_types::{FerrumError, Result};
use half::f16;
use std::collections::HashMap;
use std::sync::Arc;
/// Runtime knobs for the CUDA backend, sourced from `FERRUM_*` environment variables.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CudaBackendRuntimeEnv {
    /// Value of `FERRUM_MOE_STREAMS`; defaults to 4 and is clamped to at least 1.
    moe_streams: usize,
    /// Value of `FERRUM_CUDA_MAX_KV`; `None` when unset or not a valid `usize`.
    cuda_max_kv: Option<usize>,
    /// Value of `FERRUM_CUDA_DEVICE`; defaults to 0 when unset or not a valid `usize`.
    cuda_device: usize,
}

impl CudaBackendRuntimeEnv {
    /// Builds the configuration from the process environment.
    ///
    /// Variables whose key or value is not valid Unicode are skipped instead of
    /// aborting: `std::env::vars()` panics on such entries, and an unrelated
    /// non-Unicode variable must not take the backend down.
    fn from_env() -> Self {
        let unicode_vars = std::env::vars_os().filter_map(|(key, value)| {
            Some((key.into_string().ok()?, value.into_string().ok()?))
        });
        Self::from_env_vars(unicode_vars)
    }

    /// Builds the configuration from an arbitrary `(key, value)` sequence.
    ///
    /// Unknown keys are ignored. When a recognised key occurs more than once the
    /// last occurrence wins, and a value that fails to parse as `usize` resets
    /// that setting to "unset" (so the default applies).
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: Into<String>,
    {
        let mut moe_streams: Option<usize> = None;
        let mut cuda_max_kv: Option<usize> = None;
        let mut cuda_device: Option<usize> = None;
        for (key, value) in vars {
            let slot = match key.as_ref() {
                "FERRUM_MOE_STREAMS" => &mut moe_streams,
                "FERRUM_CUDA_MAX_KV" => &mut cuda_max_kv,
                "FERRUM_CUDA_DEVICE" => &mut cuda_device,
                _ => continue,
            };
            let value: String = value.into();
            *slot = value.parse::<usize>().ok();
        }
        Self {
            // A pool of zero streams is meaningless, so clamp to one.
            moe_streams: moe_streams.unwrap_or(4).max(1),
            cuda_max_kv,
            cuda_device: cuda_device.unwrap_or(0),
        }
    }
}
/// Returns the process-wide runtime configuration, reading the environment on
/// first use and caching the result for every later call.
fn cuda_backend_runtime_env() -> &'static CudaBackendRuntimeEnv {
    use std::sync::OnceLock;
    static RUNTIME_ENV: OnceLock<CudaBackendRuntimeEnv> = OnceLock::new();
    RUNTIME_ENV.get_or_init(|| CudaBackendRuntimeEnv::from_env())
}
thread_local! {
    /// Per-thread device-ordinal override; `None` means no scope is active.
    static CUDA_DEVICE_SCOPE: std::cell::Cell<Option<usize>> = const { std::cell::Cell::new(None) };
}

/// Scope guard that installs a device ordinal in `CUDA_DEVICE_SCOPE` and puts
/// the previous value back when dropped, so scopes nest correctly.
struct CudaDeviceScopeGuard {
    /// Value that was active before this guard was entered.
    previous: Option<usize>,
}

impl CudaDeviceScopeGuard {
    /// Makes `ordinal` the current thread's scoped device and remembers the old value.
    fn enter(ordinal: usize) -> Self {
        Self {
            previous: CUDA_DEVICE_SCOPE.with(|cell| cell.replace(Some(ordinal))),
        }
    }
}

impl Drop for CudaDeviceScopeGuard {
    /// Restores whatever ordinal (or absence of one) was active before `enter`.
    fn drop(&mut self) {
        let restore = self.previous;
        CUDA_DEVICE_SCOPE.with(|cell| cell.set(restore));
    }
}
/// Device ordinal that applies to the calling thread: the innermost active
/// scope if there is one, otherwise the ordinal from the runtime environment.
pub(super) fn current_device_ordinal() -> usize {
    let scoped = CUDA_DEVICE_SCOPE.with(|cell| cell.get());
    let fallback = cuda_backend_runtime_env().cuda_device;
    match scoped {
        Some(ordinal) => ordinal,
        None => fallback,
    }
}
/// Runs `body`, inside a thread-local device scope when `device_ordinal` is `Some`.
fn with_cuda_device_ordinal<R>(device_ordinal: Option<usize>, body: impl FnOnce() -> R) -> R {
    // The optional guard is a local, so it is dropped only after `body` has
    // returned (or while unwinding), which restores the previous scope.
    let _scope = device_ordinal.map(CudaDeviceScopeGuard::enter);
    body()
}
/// Per-context CUDA state handed to every `CudaBackend` operation.
pub struct CudaState {
/// Device ordinal this context was created for.
pub ordinal: usize,
/// CUDA context that owns `stream`.
pub ctx: Arc<CudaContext>,
/// Stream on which kernels and copies issued through this context are enqueued.
pub stream: Arc<CudaStream>,
/// cuBLAS handle used by `gemm`.
pub blas: Arc<CudaBlas>,
/// Loaded-module cache keyed by a static module name (see `CudaState::module`).
modules: HashMap<&'static str, Arc<CudaModule>>,
/// When true, decode-path kernels select their `_dyn` variants, which take
/// position / length arguments from device-side decode-state buffers.
pub use_dev_state: bool,
/// NOTE(review): not read in this part of the file; presumably set while a
/// CUDA graph capture is active — confirm against the graph module.
pub capture_in_flight: bool,
// Device scratch and host staging arrays for the batched per-cache kernels.
// NOTE(review): the batched methods visible here go through
// `with_batched_scratch_mut` rather than these fields — confirm which is live.
batched_scratch_u64_k: Option<CudaSlice<u64>>,
batched_scratch_u64_v: Option<CudaSlice<u64>>,
batched_scratch_u64_cache: Option<CudaSlice<u64>>,
batched_scratch_i32_kv_lens: Option<CudaSlice<i32>>,
batched_scratch_i32_cache_lens: Option<CudaSlice<i32>>,
batched_host_k_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
batched_host_v_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
batched_host_cache_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
batched_host_kv_lens: Box<[i32; HOST_STAGING_TOTAL]>,
batched_host_cache_lens: Box<[i32; HOST_STAGING_TOTAL]>,
/// Lazily created stream pool (see `moe_stream_pool`).
moe_streams: Option<Vec<Arc<CudaStream>>>,
/// Raw `CUevent` handle stored as an integer (see `moe_sync_events`).
moe_entry_event: Option<usize>,
/// Raw `CUevent` handles, one per pool stream, stored as integers.
moe_exit_events: Option<Vec<usize>>,
// MoE routing scratch. NOTE(review): not used in this part of the file.
moe_route_ids: Option<CudaSlice<f16>>,
moe_route_weights: Option<CudaSlice<f16>>,
moe_route_capacity: usize,
// Paged-attention output scratch. NOTE(review): not used in this part of the file.
paged_attn_out_tm: Option<crate::backend::CudaBuf>,
paged_attn_out_tm_capacity: usize,
}
// Process-wide scratch buffers, one per device ordinal, created lazily.
// `VLLM_MOE_C_TMP` holds an f32 scratch buffer (see `with_vllm_moe_c_tmp`).
#[cfg(feature = "vllm-moe-marlin")]
static VLLM_MOE_C_TMP: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, CudaSlice<f32>>>> =
std::sync::OnceLock::new();
// `ARGMAX_OUT` holds the i32 output buffer reused by `with_argmax_out`.
static ARGMAX_OUT: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, CudaSlice<i32>>>> =
std::sync::OnceLock::new();
/// Returns the global per-device argmax output map, creating it empty on first use.
fn argmax_out_slots() -> &'static std::sync::RwLock<HashMap<usize, CudaSlice<i32>>> {
    ARGMAX_OUT.get_or_init(Default::default)
}
/// Runs `body` with exclusive access to the argmax output buffer of device
/// `ordinal`, first (re)allocating it if it is missing or holds fewer than `m`
/// elements.
///
/// A single write lock is taken for the whole call. The earlier read-lock
/// probe always released and re-acquired a write lock anyway, so it bought no
/// concurrency and only added a window between the check and the use.
///
/// # Panics
/// Panics if the lock is poisoned or the device allocation fails.
fn with_argmax_out<R>(
    stream: &Arc<CudaStream>,
    ordinal: usize,
    m: usize,
    body: impl FnOnce(&mut CudaSlice<i32>) -> R,
) -> R {
    let mut slots = argmax_out_slots().write().expect("ARGMAX_OUT poisoned");
    let too_small = slots.get(&ordinal).map_or(true, |buf| buf.len() < m);
    if too_small {
        // Round up so that slowly growing batch sizes do not reallocate every call.
        let capacity = m.max(64).next_power_of_two();
        // SAFETY: `alloc` hands back uninitialised device memory. The buffer is
        // only exposed through `body`, which is expected to write the elements
        // it later reads — NOTE(review): confirm for any new caller.
        let fresh = unsafe { stream.alloc::<i32>(capacity) }.expect("argmax_out alloc");
        slots.insert(ordinal, fresh);
    }
    body(slots.get_mut(&ordinal).expect("alloc above"))
}
/// Returns the global per-device vLLM MoE scratch map, creating it empty on first use.
#[cfg(feature = "vllm-moe-marlin")]
fn vllm_moe_c_tmp_slots() -> &'static std::sync::RwLock<HashMap<usize, CudaSlice<f32>>> {
    VLLM_MOE_C_TMP.get_or_init(Default::default)
}
/// Runs `body` with exclusive access to the f32 scratch buffer of device
/// `ordinal`, allocating the zero-filled buffer the first time that device is seen.
///
/// One write lock covers lookup, allocation and `body`. The earlier read-lock
/// probe always fell through to a write lock, so it added nothing.
///
/// # Panics
/// Panics if the lock is poisoned or the device allocation fails.
#[cfg(feature = "vllm-moe-marlin")]
pub fn with_vllm_moe_c_tmp<R>(
    stream: &Arc<CudaStream>,
    ordinal: usize,
    body: impl FnOnce(&mut CudaSlice<f32>) -> R,
) -> R {
    /// Scratch size in f32 elements (4 Mi elements, 16 MiB).
    const C_TMP_SIZE_F32: usize = 4 * 1024 * 1024;
    let mut slots = vllm_moe_c_tmp_slots()
        .write()
        .expect("VLLM_MOE_C_TMP poisoned");
    let scratch = slots.entry(ordinal).or_insert_with(|| {
        let buf = stream
            .alloc_zeros::<f32>(C_TMP_SIZE_F32)
            .expect("alloc_zeros vllm_moe_c_tmp_f32 (per-device)");
        tracing::info!(
            "vLLM moe c_tmp scratch allocated (device {ordinal}): {} fp32 ({:.1} MB)",
            C_TMP_SIZE_F32,
            (C_TMP_SIZE_F32 * 4) as f32 / 1e6
        );
        buf
    });
    body(scratch)
}
/// Maximum batch size `m` accepted by the batched per-cache kernels; also the
/// width of one slot's window in the pointer staging arrays.
pub(super) const BATCHED_SCRATCH_CAP: usize = 64;
/// Number of staging slots; the batched methods reject `slot >= MAX_GRAPH_SLOTS`.
const MAX_GRAPH_SLOTS: usize = 2 * super::MAX_LAYERS_FOR_GRAPH;
/// Total staging-array length: one `BATCHED_SCRATCH_CAP`-wide window per slot.
pub(super) const HOST_STAGING_TOTAL: usize = MAX_GRAPH_SLOTS * BATCHED_SCRATCH_CAP;
impl CudaState {
    /// Returns the MoE stream pool, creating it on first use.
    ///
    /// The pool size comes from `cuda_backend_runtime_env().moe_streams`.
    ///
    /// # Panics
    /// Panics if a stream cannot be created.
    pub fn moe_stream_pool(&mut self) -> &[Arc<CudaStream>] {
        let ctx = &self.ctx;
        self.moe_streams
            .get_or_insert_with(|| {
                let n = cuda_backend_runtime_env().moe_streams;
                let pool: Vec<_> = (0..n)
                    .map(|_| {
                        ctx.new_stream()
                            .expect("CudaState::moe_stream_pool: new_stream failed")
                    })
                    .collect();
                tracing::info!("MoE stream pool initialized: {} streams", n);
                pool
            })
            .as_slice()
    }

    /// Returns the MoE synchronisation events: one entry event plus one exit
    /// event per pool stream. They are created on first use and reused after.
    ///
    /// # Panics
    /// Panics if an event cannot be created. The driver status used to be
    /// ignored, which stored a null handle that only failed later, far from
    /// the cause.
    pub fn moe_sync_events(
        &mut self,
    ) -> (
        cudarc::driver::sys::CUevent,
        Vec<cudarc::driver::sys::CUevent>,
    ) {
        use cudarc::driver::sys as cu;

        /// Creates one event with flag `2` (`CU_EVENT_DISABLE_TIMING`).
        fn new_event() -> cu::CUevent {
            let mut event: cu::CUevent = std::ptr::null_mut();
            // SAFETY: `event` is a valid out-pointer for the duration of the call.
            unsafe { cu::cuEventCreate(&mut event, 2) }
                .result()
                .expect("CudaState::moe_sync_events: cuEventCreate failed");
            event
        }

        if self.moe_entry_event.is_none() {
            let n = self.moe_stream_pool().len();
            let entry = new_event();
            // Handles are kept as integers in the struct.
            let exits: Vec<usize> = (0..n).map(|_| new_event() as usize).collect();
            self.moe_entry_event = Some(entry as usize);
            self.moe_exit_events = Some(exits);
            tracing::info!("MoE sync events initialized: 1 entry + {} exits", n);
        }
        let entry = self.moe_entry_event.unwrap() as cu::CUevent;
        let exits: Vec<cu::CUevent> = self
            .moe_exit_events
            .as_ref()
            .unwrap()
            .iter()
            .map(|&raw| raw as cu::CUevent)
            .collect();
        (entry, exits)
    }

    /// Returns the module registered under `key`, loading it from `ptx_src`
    /// through `ensure_module` the first time it is requested by this context.
    fn module(&mut self, key: &'static str, ptx_src: &str) -> Arc<CudaModule> {
        let ordinal = self.ordinal;
        let ctx = &self.ctx;
        let loaded = self
            .modules
            .entry(key)
            .or_insert_with(|| ensure_module(ordinal, ctx, key, ptx_src));
        Arc::clone(loaded)
    }

    /// Resolves kernel `fn_name` from the module `module_key`, loading the
    /// module from `ptx_src` if needed.
    ///
    /// # Panics
    /// Panics if the function is not present in the module.
    pub(crate) fn func(
        &mut self,
        module_key: &'static str,
        ptx_src: &str,
        fn_name: &'static str,
    ) -> CudaFunction {
        self.module(module_key, ptx_src)
            .load_function(fn_name)
            .unwrap_or_else(|e| panic!("CudaBackend: load_function({fn_name}): {e}"))
    }
}
/// Parameter block passed by value to the `flash_attn_full_f16` kernel.
///
/// `#[repr(C)]` fixes the field order. NOTE(review): the order presumably has
/// to match the struct declared on the kernel side — confirm before reordering.
#[repr(C)]
#[derive(Clone, Copy)]
struct FlashAttnParams {
batch: i32,
num_heads: i32,
num_kv_heads: i32,
q_len: i32,
kv_len: i32,
head_dim: i32,
causal: i32,
pos_offset: i32,
kv_seq_stride: i32,
}
// SAFETY: the struct is `repr(C)`, `Copy`, and made up solely of `i32` fields,
// so it has no padding and no pointers.
unsafe impl DeviceRepr for FlashAttnParams {}
/// Zero-sized marker type that carries the CUDA implementation of `Backend`.
pub struct CudaBackend;
impl Backend for CudaBackend {
// Device buffer type used for every tensor argument.
type Buffer = crate::backend::CudaBuf;
// Per-context state (stream, cuBLAS handle, module cache, ...).
type Context = CudaState;
// Timer type returned by `make_timer`.
type Timer = crate::backend::timer::CudaTimer;
/// Creates a new CUDA timer.
fn make_timer() -> Self::Timer {
    use crate::backend::timer::CudaTimer;
    CudaTimer::new()
}
fn new_context() -> Self::Context {
let ordinal = current_device_ordinal();
let stream = default_stream();
let ctx = stream.context().clone();
let blas = ensure_blas_handle(&stream);
ensure_decode_state_bufs(&stream);
ensure_batched_scratch(&stream);
unsafe {
ctx.disable_event_tracking();
}
Self::Context {
ordinal,
ctx,
stream,
blas,
modules: HashMap::new(),
use_dev_state: false,
capture_in_flight: false,
batched_scratch_u64_k: None,
batched_scratch_u64_v: None,
batched_scratch_u64_cache: None,
batched_scratch_i32_kv_lens: None,
batched_scratch_i32_cache_lens: None,
batched_host_k_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
batched_host_v_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
batched_host_cache_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
batched_host_kv_lens: Box::new([0i32; HOST_STAGING_TOTAL]),
batched_host_cache_lens: Box::new([0i32; HOST_STAGING_TOTAL]),
moe_streams: None,
moe_entry_event: None,
moe_exit_events: None,
moe_route_ids: None,
moe_route_weights: None,
moe_route_capacity: 0,
paged_attn_out_tm: None,
paged_attn_out_tm_capacity: 0,
}
}
/// Runs `body` with `device_ordinal` (when `Some`) installed as the calling
/// thread's scoped device; delegates to `with_cuda_device_ordinal`.
fn with_device_ordinal<R>(device_ordinal: Option<usize>, body: impl FnOnce() -> R) -> R {
with_cuda_device_ordinal(device_ordinal, body)
}
/// Always true: this backend honours the scope set by `with_device_ordinal`.
fn supports_device_ordinal_scope() -> bool {
true
}
/// Allocates a zero-filled device buffer of `n` elements of `dtype`.
///
/// A request for zero elements is rounded up to one.
///
/// # Panics
/// Panics if the device allocation fails.
fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
    use crate::backend::{CudaBuf, Dtype};
    let count = n.max(1);
    with_stream(|stream| match dtype {
        Dtype::F32 => {
            let raw = stream
                .alloc_zeros::<f32>(count)
                .expect("CudaBackend::alloc_typed: alloc_zeros<f32>");
            CudaBuf::from_f32(raw)
        }
        Dtype::F16 => {
            let raw = stream
                .alloc_zeros::<f16>(count)
                .expect("CudaBackend::alloc_typed: alloc_zeros<f16>");
            CudaBuf::from_f16(raw)
        }
        Dtype::U32 => {
            let raw = stream
                .alloc_zeros::<u32>(count)
                .expect("CudaBackend::alloc_typed: alloc_zeros<u32>");
            CudaBuf::from_u32(raw)
        }
        Dtype::I32 => {
            let raw = stream
                .alloc_zeros::<i32>(count)
                .expect("CudaBackend::alloc_typed: alloc_zeros<i32>");
            CudaBuf::from_i32(raw)
        }
        Dtype::I8 => {
            let raw = stream
                .alloc_zeros::<i8>(count)
                .expect("CudaBackend::alloc_typed: alloc_zeros<i8>");
            CudaBuf::from_i8(raw)
        }
    })
}
/// Uploads `data` to a new device buffer whose element type is `T::DTYPE`.
///
/// # Panics
/// Panics if the host-to-device copy fails.
fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
    use crate::backend::{CudaBuf, Dtype};

    /// Views `src` as a slice of `U` with the same element count.
    ///
    /// # Safety
    /// `U` must have the same layout as `T`.
    unsafe fn view_as<T, U>(src: &[T]) -> &[U] {
        std::slice::from_raw_parts(src.as_ptr() as *const U, src.len())
    }

    // SAFETY (all arms): each arm casts to the type named by `T::DTYPE`.
    // Assumes every `HostDtype` impl reports the dtype matching `T` — TODO confirm.
    with_stream(|stream| match T::DTYPE {
        Dtype::F32 => {
            let host = unsafe { view_as::<T, f32>(data) };
            CudaBuf::from_f32(stream.clone_htod(host).expect("cuda htod f32"))
        }
        Dtype::F16 => {
            let host = unsafe { view_as::<T, f16>(data) };
            CudaBuf::from_f16(stream.clone_htod(host).expect("cuda htod f16"))
        }
        Dtype::U32 => {
            let host = unsafe { view_as::<T, u32>(data) };
            CudaBuf::from_u32(stream.clone_htod(host).expect("cuda htod u32"))
        }
        Dtype::I32 => {
            let host = unsafe { view_as::<T, i32>(data) };
            CudaBuf::from_i32(stream.clone_htod(host).expect("cuda htod i32"))
        }
        Dtype::I8 => {
            let host = unsafe { view_as::<T, i8>(data) };
            CudaBuf::from_i8(stream.clone_htod(host).expect("cuda htod i8"))
        }
    })
}
/// Copies `data` from the host into `dst` on the context's stream, using the
/// typed view of `dst` selected by `T::DTYPE`. Empty input is a no-op.
///
/// # Panics
/// Panics if the host-to-device copy fails.
fn write_typed<T: crate::backend::HostDtype>(
    ctx: &mut Self::Context,
    dst: &mut Self::Buffer,
    data: &[T],
) {
    use crate::backend::Dtype;

    /// Views `src` as a slice of `U` with the same element count.
    ///
    /// # Safety
    /// `U` must have the same layout as `T`.
    unsafe fn view_as<T, U>(src: &[T]) -> &[U] {
        std::slice::from_raw_parts(src.as_ptr() as *const U, src.len())
    }

    if data.is_empty() {
        return;
    }
    let stream = ctx.stream.clone();
    // SAFETY (all arms): each arm casts to the type named by `T::DTYPE`.
    // Assumes every `HostDtype` impl reports the dtype matching `T` — TODO confirm.
    match T::DTYPE {
        Dtype::U32 => {
            let host = unsafe { view_as::<T, u32>(data) };
            stream
                .memcpy_htod(host, dst.as_u32_mut())
                .expect("cuda write_typed u32");
        }
        Dtype::I32 => {
            let host = unsafe { view_as::<T, i32>(data) };
            stream
                .memcpy_htod(host, dst.as_i32_mut())
                .expect("cuda write_typed i32");
        }
        Dtype::F32 => {
            let host = unsafe { view_as::<T, f32>(data) };
            stream
                .memcpy_htod(host, dst.as_f32_mut())
                .expect("cuda write_typed f32");
        }
        Dtype::F16 => {
            let host = unsafe { view_as::<T, f16>(data) };
            stream
                .memcpy_htod(host, dst.as_f16_mut())
                .expect("cuda write_typed f16");
        }
        Dtype::I8 => {
            let host = unsafe { view_as::<T, i8>(data) };
            stream
                .memcpy_htod(host, dst.as_i8_mut())
                .expect("cuda write_typed i8");
        }
    }
}
/// Blocks the host until all work queued on the context's stream has finished.
///
/// # Panics
/// Panics if stream synchronisation reports an error.
fn sync(ctx: &mut Self::Context) {
    let outcome = ctx.stream.synchronize();
    outcome.expect("CudaBackend: stream sync");
}
fn alloc(len: usize) -> Self::Buffer {
with_stream(|stream| {
crate::backend::CudaBuf::from_f16(
unsafe { stream.alloc::<f16>(len) }.expect("cuda alloc"),
)
})
}
fn from_slice(data: &[f32]) -> Self::Buffer {
let host: Vec<f16> = data.iter().map(|&x| f16::from_f32(x)).collect();
with_stream(|stream| {
crate::backend::CudaBuf::from_f16(stream.clone_htod(&host).expect("cuda htod"))
})
}
/// Converts `data` to f16 and copies it into the first `data.len()` elements
/// of `dst` on the context's stream. Empty input is a no-op.
///
/// # Panics
/// Panics if the copy fails.
fn write_f32_to_activation(ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32]) {
    if data.is_empty() {
        return;
    }
    let halves: Vec<f16> = data.iter().copied().map(f16::from_f32).collect();
    let mut window = dst.as_f16_mut().slice_mut(0..data.len());
    let copied = ctx.stream.memcpy_htod(&halves, &mut window);
    copied.expect("cuda write_f32_to_activation");
}
/// Copies the first `len` f16 elements of `buf` to the host, waits for the
/// copy to finish, and returns them widened to f32.
///
/// # Panics
/// Panics if the copy or the synchronisation fails.
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
    with_stream(|stream| {
        let mut halves = vec![f16::ZERO; len];
        let window = buf.as_f16().slice(0..len);
        stream.memcpy_dtoh(&window, &mut halves).expect("cuda dtoh");
        stream.synchronize().expect("cuda dtoh sync");
        halves.iter().map(|h| h.to_f32()).collect()
    })
}
/// Runs the `argmax_rows_f16` kernel over `logits` with one block per row
/// (`m` rows, row length `n`) and returns the `m` results, one per row.
///
/// The call blocks until the results have been copied back to the host.
///
/// # Errors
/// Returns an internal error if the launch, the device-to-host copy or the
/// synchronisation fails.
fn argmax_rows_f16(
    ctx: &mut Self::Context,
    logits: &Self::Buffer,
    m: usize,
    n: usize,
) -> Result<Vec<u32>> {
    // Zero rows would mean a launch with grid_dim.x == 0, which CUDA rejects
    // as an invalid configuration. There is nothing to compute, so return early.
    if m == 0 {
        return Ok(Vec::new());
    }
    let func = ctx.func("argmax_rows", ptx::ARGMAX_ROWS, "argmax_rows_f16");
    let stream = ctx.stream.clone();
    let host = with_argmax_out(&stream, ctx.ordinal, m, |out_dev| -> Result<Vec<i32>> {
        let n_i32 = n as i32;
        let mut b = stream.launch_builder(&func);
        b.arg(logits);
        b.arg(&n_i32);
        b.arg(&mut *out_dev);
        unsafe {
            b.launch(LaunchConfig {
                grid_dim: (m as u32, 1, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: 0,
            })
        }
        .map_err(|e| FerrumError::internal(format!("argmax_rows launch: {e}")))?;
        let mut host = vec![0i32; m];
        // The shared output buffer may be larger than `m`; read back only `m`.
        let view = out_dev.slice(0..m);
        stream
            .memcpy_dtoh(&view, &mut host)
            .map_err(|e| FerrumError::internal(format!("argmax_rows dtoh: {e}")))?;
        stream
            .synchronize()
            .map_err(|e| FerrumError::internal(format!("argmax_rows sync: {e}")))?;
        Ok(host)
    })?;
    Ok(host.into_iter().map(|x| x as u32).collect())
}
/// Launches `rms_norm_f16` with one block per token and `min(dim, 1024)`
/// threads per block, passing `x`, `w`, `out`, `dim` and `eps`.
///
/// # Panics
/// Panics if the launch fails.
fn rms_norm(
    ctx: &mut Self::Context,
    x: &Self::Buffer,
    w: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    let kernel = ctx.func("rms_norm", ptx::RMS_NORM, "rms_norm_f16");
    let stream = ctx.stream.clone();
    let dim_arg = dim as i32;
    let launch_cfg = LaunchConfig {
        grid_dim: (tokens as u32, 1, 1),
        block_dim: (dim.min(1024) as u32, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(x);
    launch.arg(w);
    launch.arg(out);
    launch.arg(&dim_arg);
    launch.arg(&eps);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("rms_norm launch");
}
/// Launches `fused_add_rms_norm_inplace_f16` with one block per token and
/// `min(dim, 1024)` threads per block.
///
/// Kernel arguments, in order: `x`, `residual`, `w`, `out`, `dim`, `eps`.
/// `residual` is passed mutably — presumably updated in place, as the kernel
/// name suggests; confirm against the kernel source.
///
/// # Panics
/// Panics if the launch fails.
fn fused_add_rms_norm(
    ctx: &mut Self::Context,
    residual: &mut Self::Buffer,
    x: &Self::Buffer,
    w: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    let kernel = ctx.func(
        "fused_add_rms_norm",
        ptx::FUSED_ADD_RMS_NORM,
        "fused_add_rms_norm_inplace_f16",
    );
    let stream = ctx.stream.clone();
    let dim_arg = dim as i32;
    let launch_cfg = LaunchConfig {
        grid_dim: (tokens as u32, 1, 1),
        block_dim: (dim.min(1024) as u32, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(x);
    launch.arg(residual);
    launch.arg(w);
    launch.arg(out);
    launch.arg(&dim_arg);
    launch.arg(&eps);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("fused_add_rms_norm launch");
}
/// f16 GEMM through `cublasGemmEx` with f32 accumulation.
///
/// cuBLAS is column-major, so the call passes `b` as the first operand
/// (transposed, leading dimension `k`) and `a` as the second (leading
/// dimension `k`), writing an `n x m` column-major result with leading
/// dimension `n`. In row-major terms that is `out[m, n] = a[m, k] * b[n, k]^T`.
///
/// # Panics
/// Panics if cuBLAS reports an error.
fn gemm(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    b: &Self::Buffer,
    out: &mut Self::Buffer,
    m: usize,
    n: usize,
    k: usize,
) {
    use cudarc::cublas::result::gemm_ex;
    use cudarc::cublas::sys::{
        cublasComputeType_t, cublasGemmAlgo_t, cublasOperation_t, cudaDataType_t,
    };
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    // The guards returned next to each raw pointer stay alive until the end
    // of the function, i.e. past the cuBLAS call.
    let (first_operand, _first_guard) = b.as_f16().device_ptr(&ctx.stream);
    let (second_operand, _second_guard) = a.as_f16().device_ptr(&ctx.stream);
    let (result_ptr, _result_guard) = out.as_f16_mut().device_ptr_mut(&ctx.stream);
    let (rows, cols, inner) = (m as i32, n as i32, k as i32);
    with_blas_scalars(ctx.ordinal, |alpha_f32, beta_f32| {
        let (alpha_ptr, _alpha_guard) = alpha_f32.device_ptr(&ctx.stream);
        let (beta_ptr, _beta_guard) = beta_f32.device_ptr(&ctx.stream);
        let status = unsafe {
            gemm_ex(
                *ctx.blas.handle(),
                cublasOperation_t::CUBLAS_OP_T,
                cublasOperation_t::CUBLAS_OP_N,
                cols,
                rows,
                inner,
                alpha_ptr as *const _,
                first_operand as *const _,
                cudaDataType_t::CUDA_R_16F,
                inner,
                second_operand as *const _,
                cudaDataType_t::CUDA_R_16F,
                inner,
                beta_ptr as *const _,
                result_ptr as *mut _,
                cudaDataType_t::CUDA_R_16F,
                cols,
                cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16F,
                cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
            )
        };
        status.expect("gemm (cublasGemmEx, compute=32F_FAST_16F, algo=TENSOR_OP)");
    });
}
/// Attention over `q`, `k`, `v` into `out`.
///
/// * `q_len == 1` takes the decode path: `decode_attention_head_major_f16`,
///   or its `_dyn` variant when `ctx.use_dev_state` is set. The `_dyn`
///   variant receives the decode-state `kv` buffer in place of the scalar
///   `kv_len`. One block per query head; dynamic shared memory is 4 bytes per
///   KV position, capped by `FERRUM_CUDA_MAX_KV` (default 8192).
/// * Otherwise `flash_attn_full_f16` runs with one block per
///   (16-row query tile, head, batch entry).
///
/// # Panics
/// Panics if a kernel launch fails or the decode-state lock is poisoned.
fn flash_attention(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k: &Self::Buffer,
    v: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    q_len: usize,
    kv_len: usize,
    pos_offset: usize,
    cfg: &AttnConfig,
) {
    if q_len == 1 {
        let use_dyn = ctx.use_dev_state;
        let func_name = if use_dyn {
            "decode_attention_head_major_f16_dyn"
        } else {
            "decode_attention_head_major_f16"
        };
        let func = ctx.func("decode_attention_hm", ptx::DECODE_ATTENTION_HM, func_name);
        let num_q = cfg.num_heads as i32;
        let num_kv = cfg.num_kv_heads as i32;
        let hd = cfg.head_dim as i32;
        // Per-head stride of the KV cache; falls back to `kv_len` when unset.
        let capacity = if cfg.kv_seq_stride > 0 {
            cfg.kv_seq_stride as i32
        } else {
            kv_len as i32
        };
        let valid_kv_scalar = kv_len as i32;
        let scale = cfg.scale;
        const DECODE_MAX_KV_POS_DEFAULT: usize = 8192;
        let env_cap = cuda_backend_runtime_env()
            .cuda_max_kv
            .unwrap_or(DECODE_MAX_KV_POS_DEFAULT);
        // Saturate rather than wrap: an oversized FERRUM_CUDA_MAX_KV must not
        // turn into a negative cap through `as i32`.
        let env_cap_i32 = i32::try_from(env_cap).unwrap_or(i32::MAX);
        let max_kv_pos = capacity.min(env_cap_i32) as u32;
        let shared_mem = max_kv_pos * 4;
        // More than 48 KiB of dynamic shared memory needs an explicit opt-in.
        // A failure here is deliberately ignored: the launch below reports it.
        if shared_mem > 48 * 1024 {
            let _ = func.set_attribute(
                cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                shared_mem as i32,
            );
        }
        let stream = ctx.stream.clone();
        // Hold the decode-state read lock across the launch so the buffer
        // passed to the kernel stays borrowed for that long.
        let dec_guard = if use_dyn {
            Some(
                decode_state_slot_for_ordinal(ctx.ordinal)
                    .read()
                    .expect("DECODE_STATE poisoned"),
            )
        } else {
            None
        };
        let mut bld = stream.launch_builder(&func);
        bld.arg(q);
        bld.arg(k);
        bld.arg(v);
        bld.arg(out);
        bld.arg(&num_q);
        bld.arg(&num_kv);
        bld.arg(&hd);
        bld.arg(&capacity);
        if use_dyn {
            let bufs = dec_guard.as_ref().unwrap().as_ref().unwrap();
            bld.arg(&bufs.kv);
        } else {
            bld.arg(&valid_kv_scalar);
        }
        bld.arg(&scale);
        unsafe {
            bld.launch(LaunchConfig {
                grid_dim: (cfg.num_heads as u32, 1, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: shared_mem,
            })
        }
        .expect("decode_attention_head_major launch");
        drop(dec_guard);
        return;
    }
    let func = ctx.func(
        "flash_attn_full",
        ptx::FLASH_ATTN_FULL,
        "flash_attn_full_f16",
    );
    let params = FlashAttnParams {
        batch: batch as i32,
        num_heads: cfg.num_heads as i32,
        num_kv_heads: cfg.num_kv_heads as i32,
        q_len: q_len as i32,
        kv_len: kv_len as i32,
        head_dim: cfg.head_dim as i32,
        causal: if cfg.causal { 1 } else { 0 },
        pos_offset: pos_offset as i32,
        kv_seq_stride: if cfg.kv_seq_stride > 0 {
            cfg.kv_seq_stride as i32
        } else {
            kv_len as i32
        },
    };
    const TILE_Q: usize = 16;
    let num_q_tiles = (q_len + TILE_Q - 1) / TILE_Q;
    let stream = ctx.stream.clone();
    let mut b = stream.launch_builder(&func);
    b.arg(q);
    b.arg(k);
    b.arg(v);
    b.arg(out);
    // The parameter block is passed by reference (was mis-encoded as `¶ms`).
    b.arg(&params);
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (num_q_tiles as u32, cfg.num_heads as u32, batch as u32),
            block_dim: (TILE_Q as u32, 1, 1),
            shared_mem_bytes: 0,
        })
    }
    .expect("flash_attn_full launch");
}
/// Device-to-device copy of `len` f16 elements from `src[src_offset..]` into
/// `dst[dst_offset..]` on the context's stream.
///
/// # Panics
/// Panics if the copy fails.
fn copy_slice(
    ctx: &mut Self::Context,
    src: &Self::Buffer,
    src_offset: usize,
    dst: &mut Self::Buffer,
    dst_offset: usize,
    len: usize,
) {
    let src_end = src_offset + len;
    let from = src.as_f16().slice(src_offset..src_end);
    let dst_end = dst_offset + len;
    let mut into = dst.as_f16_mut().slice_mut(dst_offset..dst_end);
    let copied = ctx.stream.memcpy_dtod(&from, &mut into);
    copied.expect("copy_slice dtod");
}
/// Launches `embedding_lookup_f16` with token ids already on the device.
///
/// Grid is `(ceil(dim / 256), batch)` with 256 threads per block. Kernel
/// arguments, in order: `table`, `ids`, `out`, `batch`, `dim`.
///
/// # Panics
/// Panics if the launch fails.
fn embedding_lookup_dev(
    ctx: &mut Self::Context,
    table: &Self::Buffer,
    ids: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    dim: usize,
) {
    let threads = 256u32;
    let blocks_x = ((dim as u32) + threads - 1) / threads;
    let (batch_arg, dim_arg) = (batch as i32, dim as i32);
    let kernel = ctx.func(
        "embedding_lookup",
        ptx::EMBEDDING_LOOKUP,
        "embedding_lookup_f16",
    );
    let stream = ctx.stream.clone();
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks_x, batch as u32, 1),
        block_dim: (threads, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(table);
    launch.arg(ids);
    launch.arg(out);
    launch.arg(&batch_arg);
    launch.arg(&dim_arg);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("embedding_lookup_dev launch");
}
/// Embedding lookup for host-side token ids.
///
/// * With `ctx.use_dev_state` set, launches `embedding_lookup_f16_dyn`, which
///   takes the token from the decode-state `token` buffer instead of `ids`
///   (debug builds assert `ids.len() == 1`).
/// * Otherwise uploads `ids` to the device and launches `embedding_lookup_f16`
///   with a `(ceil(dim / 256), ids.len())` grid.
///
/// # Panics
/// Panics if the upload or a launch fails, or the decode-state lock is poisoned.
fn embedding_lookup(
    ctx: &mut Self::Context,
    table: &Self::Buffer,
    ids: &[u32],
    out: &mut Self::Buffer,
    dim: usize,
) {
    let threads = 256u32;
    let blocks_x = ((dim as u32) + threads - 1) / threads;
    let dim_arg = dim as i32;

    if ctx.use_dev_state {
        debug_assert!(ids.len() == 1, "dev_state embedding requires batch=1");
        let kernel = ctx.func(
            "embedding_lookup",
            ptx::EMBEDDING_LOOKUP,
            "embedding_lookup_f16_dyn",
        );
        let stream = ctx.stream.clone();
        // Keep the read lock for the duration of the launch call.
        let state = decode_state_slot_for_ordinal(ctx.ordinal)
            .read()
            .expect("DECODE_STATE poisoned");
        let bufs = state.as_ref().expect("DecodeStateBufs");
        let launch_cfg = LaunchConfig {
            grid_dim: (blocks_x, 1, 1),
            block_dim: (threads, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut launch = stream.launch_builder(&kernel);
        launch.arg(table);
        launch.arg(&bufs.token);
        launch.arg(out);
        launch.arg(&dim_arg);
        let launched = unsafe { launch.launch(launch_cfg) };
        launched.expect("embedding_lookup_dyn launch");
        drop(state);
        return;
    }

    let batch = ids.len();
    let stream = ctx.stream.clone();
    let ids_on_device = stream.clone_htod(ids).expect("embedding_lookup ids htod");
    let kernel = ctx.func(
        "embedding_lookup",
        ptx::EMBEDDING_LOOKUP,
        "embedding_lookup_f16",
    );
    let batch_arg = batch as i32;
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks_x, batch as u32, 1),
        block_dim: (threads, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(table);
    launch.arg(&ids_on_device);
    launch.arg(out);
    launch.arg(&batch_arg);
    launch.arg(&dim_arg);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("embedding_lookup launch");
}
/// Launches `split_qkv_f16` with one 256-thread block per token.
///
/// Kernel arguments, in order: `qkv`, `q`, `k`, `v`, `tokens`, `q_dim`, `kv_dim`.
///
/// # Panics
/// Panics if the launch fails.
fn split_qkv(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q: &mut Self::Buffer,
    k: &mut Self::Buffer,
    v: &mut Self::Buffer,
    tokens: usize,
    q_dim: usize,
    kv_dim: usize,
) {
    let kernel = ctx.func("split_qkv", ptx::SPLIT_QKV, "split_qkv_f16");
    let stream = ctx.stream.clone();
    let (tokens_arg, q_dim_arg, kv_dim_arg) = (tokens as i32, q_dim as i32, kv_dim as i32);
    let launch_cfg = LaunchConfig {
        grid_dim: (tokens as u32, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(qkv);
    launch.arg(q);
    launch.arg(k);
    launch.arg(v);
    launch.arg(&tokens_arg);
    launch.arg(&q_dim_arg);
    launch.arg(&kv_dim_arg);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("split_qkv launch");
}
/// Launches `fused_silu_mul_interleaved_f16` over `tokens * im` output
/// elements, using `ceil(total / 256)` blocks of 256 threads.
///
/// Kernel arguments, in order: `gate_up`, `out`, `im`, `total`.
///
/// # Panics
/// Panics if the launch fails.
fn fused_silu_mul_split(
    ctx: &mut Self::Context,
    gate_up: &Self::Buffer,
    out: &mut Self::Buffer,
    tokens: usize,
    im: usize,
) {
    let kernel = ctx.func(
        "fused_silu_mul",
        ptx::FUSED_SILU_MUL,
        "fused_silu_mul_interleaved_f16",
    );
    let stream = ctx.stream.clone();
    let total = tokens * im;
    let (im_arg, total_arg) = (im as i32, total as i32);
    let threads = 256u32;
    let blocks = ((total as u32) + threads - 1) / threads;
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (threads, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(gate_up);
    launch.arg(out);
    launch.arg(&im_arg);
    launch.arg(&total_arg);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("fused_silu_mul_split launch");
}
/// Appends `new_data` to `m` separate KV caches with one launch of
/// `kv_cache_append_batched_per_cache_f16`.
///
/// The device pointers of `caches` are staged on the host, copied into this
/// slot's window of the per-device pointer scratch, and handed to the kernel
/// as a pointer table. Grid is `(ceil(nkv * hd / 256), m)`.
///
/// # Errors
/// Returns a model error if `caches.len() != m`, `m > BATCHED_SCRATCH_CAP`,
/// `slot >= MAX_GRAPH_SLOTS`, or the pointer upload or the launch fails.
/// `m == 0` is a successful no-op.
fn kv_cache_append_batched_per_cache(
    ctx: &mut Self::Context,
    caches: &[&Self::Buffer],
    new_data: &Self::Buffer,
    cache_lens: &Self::Buffer,
    capacity: usize,
    m: usize,
    nkv: usize,
    hd: usize,
    slot: usize,
) -> Result<()> {
    use cudarc::driver::DevicePtr;
    if m == 0 {
        return Ok(());
    }
    if caches.len() != m {
        return Err(FerrumError::model(
            "kv_cache_append_batched_per_cache: caches length != m",
        ));
    }
    let stream = ctx.stream.clone();
    if m > BATCHED_SCRATCH_CAP {
        return Err(FerrumError::model(format!(
            "kv_cache_append_batched_per_cache: m={m} exceeds BATCHED_SCRATCH_CAP={BATCHED_SCRATCH_CAP}",
        )));
    }
    if slot >= MAX_GRAPH_SLOTS {
        return Err(FerrumError::model(format!(
            "kv_cache_append_batched_per_cache: slot={slot} exceeds MAX_GRAPH_SLOTS={MAX_GRAPH_SLOTS}",
        )));
    }
    // Each slot owns a BATCHED_SCRATCH_CAP-wide window of the staging arrays.
    let first = slot * BATCHED_SCRATCH_CAP;
    let last = first + m;
    let kernel = ctx.func(
        "kv_cache_append_batched",
        ptx::KV_CACHE_APPEND,
        "kv_cache_append_batched_per_cache_f16",
    );
    let (m_arg, nkv_arg, hd_arg, capacity_arg) =
        (m as i32, nkv as i32, hd as i32, capacity as i32);
    let threads = 256u32;
    let blocks_x = ((nkv * hd) as u32 + threads - 1) / threads;
    with_batched_scratch_mut(ctx.ordinal, |scratch| {
        for (i, cache) in caches.iter().enumerate() {
            let (ptr, _) = cache.as_f16().device_ptr(&stream);
            scratch.host_cache_ptrs[first + i] = ptr;
        }
        {
            let staged: &[u64] = &scratch.host_cache_ptrs[first..last];
            let mut device_window = scratch.scratch_u64_cache.slice_mut(first..last);
            stream
                .memcpy_htod(staged, &mut device_window)
                .map_err(|e| FerrumError::model(format!("memcpy cache_ptrs: {e}")))?;
        }
        let cache_ptr_table = scratch.scratch_u64_cache.slice(first..last);
        let mut launch = stream.launch_builder(&kernel);
        launch.arg(&cache_ptr_table);
        launch.arg(new_data);
        launch.arg(cache_lens);
        launch.arg(&m_arg);
        launch.arg(&nkv_arg);
        launch.arg(&hd_arg);
        launch.arg(&capacity_arg);
        unsafe {
            launch.launch(LaunchConfig {
                grid_dim: (blocks_x, m as u32, 1),
                block_dim: (threads, 1, 1),
                shared_mem_bytes: 0,
            })
        }
        .map_err(|e| FerrumError::model(format!("kv_cache_append_batched: {e}")))?;
        Ok::<(), FerrumError>(())
    })?;
    Ok(())
}
/// Batched decode attention over `m = k_caches.len()` independent KV caches
/// with one launch of `batched_decode_attention_f16`.
///
/// K and V device pointers are staged on the host, copied into this slot's
/// window of the per-device pointer scratch, and handed to the kernel as
/// pointer tables. Grid is `(nq, m)` with 256 threads per block. Dynamic
/// shared memory is 4 bytes per KV position for
/// `max(1, min(max_valid_kv, capacity))` positions.
///
/// # Errors
/// Returns a model error if the K and V lists differ in length,
/// `m > BATCHED_SCRATCH_CAP`, `slot >= MAX_GRAPH_SLOTS`, or raising the
/// shared-memory limit, a pointer upload or the launch fails.
/// `m == 0` is a successful no-op.
fn flash_attention_batched_per_cache(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_caches: &[&Self::Buffer],
    v_caches: &[&Self::Buffer],
    kv_lens: &Self::Buffer,
    out: &mut Self::Buffer,
    nq: usize,
    nkv: usize,
    hd: usize,
    scale: f32,
    max_valid_kv: usize,
    capacity: usize,
    slot: usize,
) -> Result<()> {
    use cudarc::driver::DevicePtr;
    let m = k_caches.len();
    if m == 0 {
        return Ok(());
    }
    if v_caches.len() != m {
        return Err(FerrumError::model(
            "flash_attention_batched_per_cache: k/v length mismatch",
        ));
    }
    let stream = ctx.stream.clone();
    if m > BATCHED_SCRATCH_CAP {
        return Err(FerrumError::model(format!(
            "flash_attention_batched_per_cache: m={m} exceeds BATCHED_SCRATCH_CAP={BATCHED_SCRATCH_CAP}",
        )));
    }
    if slot >= MAX_GRAPH_SLOTS {
        return Err(FerrumError::model(format!(
            "flash_attention_batched_per_cache: slot={slot} exceeds MAX_GRAPH_SLOTS={MAX_GRAPH_SLOTS}",
        )));
    }
    // Each slot owns a BATCHED_SCRATCH_CAP-wide window of the staging arrays.
    let host_start = slot * BATCHED_SCRATCH_CAP;
    let func = ctx.func(
        "batched_decode_attn",
        ptx::BATCHED_DECODE_ATTENTION,
        "batched_decode_attention_f16",
    );
    let nq_i32 = nq as i32;
    let nkv_i32 = nkv as i32;
    let hd_i32 = hd as i32;
    let capacity_i32 = capacity as i32;
    let shared_bytes = (max_valid_kv.min(capacity).max(1) as u32) * 4;
    // More than 48 KiB of dynamic shared memory needs an explicit opt-in on
    // the function. The single-sequence decode path in `flash_attention` does
    // this; without it, KV lengths above 12288 fail at launch here.
    // NOTE(review): confirm the attribute is not already raised elsewhere.
    if shared_bytes > 48 * 1024 {
        func.set_attribute(
            cudarc::driver::sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            shared_bytes as i32,
        )
        .map_err(|e| {
            FerrumError::model(format!(
                "flash_attn_batched: raise dynamic shared memory to {shared_bytes} bytes: {e}"
            ))
        })?;
    }
    with_batched_scratch_mut(ctx.ordinal, |slot_g| {
        for i in 0..m {
            let (kp, _) = k_caches[i].as_f16().device_ptr(&stream);
            let (vp, _) = v_caches[i].as_f16().device_ptr(&stream);
            slot_g.host_k_ptrs[host_start + i] = kp;
            slot_g.host_v_ptrs[host_start + i] = vp;
        }
        {
            let k_host_slice: &[u64] = &slot_g.host_k_ptrs[host_start..host_start + m];
            let mut view = slot_g.scratch_u64_k.slice_mut(host_start..host_start + m);
            stream
                .memcpy_htod(k_host_slice, &mut view)
                .map_err(|e| FerrumError::model(format!("memcpy k_ptrs: {e}")))?;
        }
        {
            let v_host_slice: &[u64] = &slot_g.host_v_ptrs[host_start..host_start + m];
            let mut view = slot_g.scratch_u64_v.slice_mut(host_start..host_start + m);
            stream
                .memcpy_htod(v_host_slice, &mut view)
                .map_err(|e| FerrumError::model(format!("memcpy v_ptrs: {e}")))?;
        }
        let k_ptrs_view = slot_g.scratch_u64_k.slice(host_start..host_start + m);
        let v_ptrs_view = slot_g.scratch_u64_v.slice(host_start..host_start + m);
        let kv_lens_dev = kv_lens;
        let mut b = stream.launch_builder(&func);
        b.arg(q);
        b.arg(&k_ptrs_view);
        b.arg(&v_ptrs_view);
        b.arg(out);
        b.arg(kv_lens_dev);
        b.arg(&nq_i32);
        b.arg(&nkv_i32);
        b.arg(&hd_i32);
        b.arg(&capacity_i32);
        b.arg(&scale);
        unsafe {
            b.launch(LaunchConfig {
                grid_dim: (nq as u32, m as u32, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: shared_bytes,
            })
        }
        .map_err(|e| FerrumError::model(format!("flash_attn_batched: {e}")))?;
        Ok::<(), FerrumError>(())
    })?;
    Ok(())
}
/// Launches `qk_norm_rope_batched_decode_f16` with one 32-thread block per
/// (item, head) pair; positions come from the device buffer `positions`.
///
/// Kernel arguments, in order: `input`, `norm_w`, `cos`, `sin`, `output`,
/// `m`, `heads`, `head_dim`, `positions`, `eps`, `mode`.
///
/// # Errors
/// Returns a model error if the launch fails.
fn qk_norm_rope_batched_per_item(
    ctx: &mut Self::Context,
    input: &Self::Buffer,
    norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    output: &mut Self::Buffer,
    positions: &Self::Buffer,
    m: usize,
    heads: usize,
    head_dim: usize,
    eps: f32,
    mode: i32,
) -> Result<()> {
    let kernel = ctx.func(
        "qk_norm_rope_batched",
        ptx::QK_NORM_ROPE,
        "qk_norm_rope_batched_decode_f16",
    );
    let stream = ctx.stream.clone();
    let (m_arg, heads_arg, head_dim_arg) = (m as i32, heads as i32, head_dim as i32);
    let launch_cfg = LaunchConfig {
        grid_dim: (m as u32, heads as u32, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(input);
    launch.arg(norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(output);
    launch.arg(&m_arg);
    launch.arg(&heads_arg);
    launch.arg(&head_dim_arg);
    launch.arg(positions);
    launch.arg(&eps);
    launch.arg(&mode);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.map_err(|e| FerrumError::model(format!("qk_norm_rope_batched_per_item: {e}")))?;
    Ok(())
}
/// Launches `qk_norm_rope_transpose_f16` with one 32-thread block per
/// (token, head) pair.
///
/// When `ctx.use_dev_state` is set and `tokens == 1`, the `_dyn` kernel is
/// used and receives the decode-state `pos` buffer in place of the scalar
/// `pos_offset`.
///
/// # Panics
/// Panics if the launch fails or the decode-state lock is poisoned.
fn qk_norm_rope(
    ctx: &mut Self::Context,
    input: &Self::Buffer,
    norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    output: &mut Self::Buffer,
    tokens: usize,
    heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    mode: i32,
) {
    let dynamic_pos = ctx.use_dev_state && tokens == 1;
    let kernel_name = match dynamic_pos {
        true => "qk_norm_rope_transpose_f16_dyn",
        false => "qk_norm_rope_transpose_f16",
    };
    let kernel = ctx.func("qk_norm_rope", ptx::QK_NORM_ROPE, kernel_name);
    let (tokens_arg, heads_arg, head_dim_arg) = (tokens as i32, heads as i32, head_dim as i32);
    let pos_arg = pos_offset as i32;
    let stream = ctx.stream.clone();
    // Hold the decode-state read lock across the launch call when its buffer is used.
    let state = if dynamic_pos {
        Some(
            decode_state_slot_for_ordinal(ctx.ordinal)
                .read()
                .expect("DECODE_STATE poisoned"),
        )
    } else {
        None
    };
    let launch_cfg = LaunchConfig {
        grid_dim: (tokens as u32, heads as u32, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut launch = stream.launch_builder(&kernel);
    launch.arg(input);
    launch.arg(norm_w);
    launch.arg(cos);
    launch.arg(sin);
    launch.arg(output);
    launch.arg(&tokens_arg);
    launch.arg(&heads_arg);
    launch.arg(&head_dim_arg);
    if dynamic_pos {
        let bufs = state.as_ref().unwrap().as_ref().unwrap();
        launch.arg(&bufs.pos);
    } else {
        launch.arg(&pos_arg);
    }
    launch.arg(&eps);
    launch.arg(&mode);
    let launched = unsafe { launch.launch(launch_cfg) };
    launched.expect("qk_norm_rope launch");
    drop(state);
}
/// Splits a fused QKV buffer and runs each part through [`Self::qk_norm_rope`].
///
/// Three temporary buffers are allocated (`tokens * q_heads * head_dim` elements for Q,
/// `tokens * kv_heads * head_dim` for each of K and V), filled by `Self::split_qkv`, and
/// then passed through `qk_norm_rope` into `q_out`, `k_out` and `v_out` respectively.
/// Q and K use `qk_mode`; V always uses mode `0`.
///
/// Always returns `Ok(())`; failures in the underlying launches panic instead.
fn split_qkv_norm_rope(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    k_out: &mut Self::Buffer,
    v_out: &mut Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
) -> Result<()> {
    let (q_width, kv_width) = (q_heads * head_dim, kv_heads * head_dim);

    // Scratch destinations for the split, consumed by the three passes below.
    let mut q_split = <Self as Backend>::alloc(tokens * q_width);
    let mut k_split = <Self as Backend>::alloc(tokens * kv_width);
    let mut v_split = <Self as Backend>::alloc(tokens * kv_width);
    Self::split_qkv(
        ctx,
        qkv,
        &mut q_split,
        &mut k_split,
        &mut v_split,
        tokens,
        q_width,
        kv_width,
    );

    // Query pass.
    Self::qk_norm_rope(
        ctx, &q_split, q_norm_w, cos, sin, q_out, tokens, q_heads, head_dim, pos_offset, eps,
        qk_mode,
    );
    // Key pass.
    Self::qk_norm_rope(
        ctx, &k_split, k_norm_w, cos, sin, k_out, tokens, kv_heads, head_dim, pos_offset, eps,
        qk_mode,
    );
    // Value pass.
    // NOTE(review): V goes through the same kernel with mode 0 and `q_norm_w` as the weight
    // argument; presumably mode 0 skips norm/RoPE so the weight is unused — confirm
    // against the kernel source.
    Self::qk_norm_rope(
        ctx, &v_split, q_norm_w, cos, sin, v_out, tokens, kv_heads, head_dim, pos_offset, eps,
        0,
    );
    Ok(())
}
/// Appends new K and V rows to their caches with the `kv_cache_append_head_major_f16`
/// kernel (or its `_dyn` variant) from `ptx::KV_CACHE_APPEND`.
///
/// The same kernel is launched twice with identical scalars: once for
/// `(cache_k, new_k_head_major)` and once for `(cache_v, new_v_head_major)`. When
/// `ctx.use_dev_state` is set and `new_tokens == 1`, the `_dyn` variant is used and the
/// per-device `DecodeStateBufs::pos` buffer is passed in place of `cache_len`.
///
/// Each launch covers `nkv * new_tokens * hd` elements with 256-thread blocks.
///
/// # Panics
/// Panics if the decode-state lock is poisoned, if the `_dyn` variant is selected while
/// the decode-state slot is empty, or if either launch fails. In debug builds it also
/// asserts `cache_len + new_tokens <= cache_capacity`.
fn kv_cache_append_head_major(
    ctx: &mut Self::Context,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    cache_len: usize,
    cache_capacity: usize,
    new_k_head_major: &Self::Buffer,
    new_v_head_major: &Self::Buffer,
    new_tokens: usize,
    nkv: usize,
    hd: usize,
) {
    debug_assert!(cache_len + new_tokens <= cache_capacity);

    let dynamic_pos = ctx.use_dev_state && new_tokens == 1;
    let entry_point = match dynamic_pos {
        true => "kv_cache_append_head_major_f16_dyn",
        false => "kv_cache_append_head_major_f16",
    };
    let kernel = ctx.func("kv_cache_append", ptx::KV_CACHE_APPEND, entry_point);

    let (kv_heads, width) = (nkv as i32, hd as i32);
    let host_len = cache_len as i32;
    let appended_rows = new_tokens as i32;
    let capacity = cache_capacity as i32;

    const THREADS: u32 = 256;
    let elements = nkv * new_tokens * hd;
    let launch_cfg = LaunchConfig {
        grid_dim: (((elements as u32) + THREADS - 1) / THREADS, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = ctx.stream.clone();
    // Held across both launches: the argument lists borrow the buffer behind the guard.
    let state_guard = dynamic_pos.then(|| {
        decode_state_slot_for_ordinal(ctx.ordinal)
            .read()
            .expect("DECODE_STATE poisoned")
    });

    // K first, then V; only the buffers and the failure message differ.
    let passes = [
        (cache_k, new_k_head_major, "kv_cache_append K launch"),
        (cache_v, new_v_head_major, "kv_cache_append V launch"),
    ];
    for (cache, appended, failure) in passes {
        let mut args = stream.launch_builder(&kernel);
        args.arg(cache);
        args.arg(appended);
        args.arg(&kv_heads);
        args.arg(&width);
        match state_guard.as_ref() {
            Some(guard) => {
                let bufs = guard.as_ref().unwrap();
                args.arg(&bufs.pos);
            }
            None => {
                args.arg(&host_len);
            }
        }
        args.arg(&appended_rows);
        args.arg(&capacity);
        unsafe { args.launch(launch_cfg) }.expect(failure);
    }
    drop(state_guard);
}
/// Launches `transpose_head_to_token_f16` from the `ptx::TRANSPOSE` module, reading
/// `src` and writing `dst`.
///
/// `tokens`, `heads` and `dim` are passed to the kernel as `i32`; the launch covers
/// `tokens * heads * dim` elements with 256-thread blocks.
///
/// # Panics
/// Panics if the kernel launch fails.
fn transpose_head_to_token(
    ctx: &mut Self::Context,
    src: &Self::Buffer,
    dst: &mut Self::Buffer,
    tokens: usize,
    heads: usize,
    dim: usize,
) {
    const THREADS: u32 = 256;
    let kernel = ctx.func("transpose", ptx::TRANSPOSE, "transpose_head_to_token_f16");
    let (n_tokens, n_heads, width) = (tokens as i32, heads as i32, dim as i32);
    let elements = tokens * heads * dim;
    let launch_cfg = LaunchConfig {
        // Ceiling division: enough blocks to cover every element.
        grid_dim: (((elements as u32) + THREADS - 1) / THREADS, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(src);
    args.arg(dst);
    args.arg(&n_tokens);
    args.arg(&n_heads);
    args.arg(&width);
    unsafe { args.launch(launch_cfg) }.expect("transpose_head_to_token launch");
}
/// Launches `transpose_token_to_head_f16` from the `ptx::TRANSPOSE` module, reading
/// `src` and writing `dst`.
///
/// `tokens`, `heads` and `dim` are passed to the kernel as `i32`; the launch covers
/// `tokens * heads * dim` elements with 256-thread blocks.
///
/// # Panics
/// Panics if the kernel launch fails.
fn transpose_token_to_head(
    ctx: &mut Self::Context,
    src: &Self::Buffer,
    dst: &mut Self::Buffer,
    tokens: usize,
    heads: usize,
    dim: usize,
) {
    const THREADS: u32 = 256;
    let kernel = ctx.func("transpose", ptx::TRANSPOSE, "transpose_token_to_head_f16");
    let (n_tokens, n_heads, width) = (tokens as i32, heads as i32, dim as i32);
    let elements = tokens * heads * dim;
    let launch_cfg = LaunchConfig {
        // Ceiling division: enough blocks to cover every element.
        grid_dim: (((elements as u32) + THREADS - 1) / THREADS, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(src);
    args.arg(dst);
    args.arg(&n_tokens);
    args.arg(&n_heads);
    args.arg(&width);
    unsafe { args.launch(launch_cfg) }.expect("transpose_token_to_head launch");
}
/// Launches `residual_add_inplace_f16` from the `ptx::RESIDUAL_ADD` module with
/// `residual`, `x` and `len` (as `i32`) as kernel arguments.
///
/// The launch covers `len` elements with 256-thread blocks. A `len` of zero is a no-op:
/// it would otherwise request a zero-sized grid, which the driver rejects, turning an
/// empty input into a panic. This matches the guard in `scaled_add_inplace`.
///
/// # Panics
/// Panics if the kernel launch fails.
fn add_inplace(
    ctx: &mut Self::Context,
    residual: &mut Self::Buffer,
    x: &Self::Buffer,
    len: usize,
) {
    // Nothing to add; also avoids launching with grid_dim.x == 0.
    if len == 0 {
        return;
    }
    let func = ctx.func(
        "residual_add",
        ptx::RESIDUAL_ADD,
        "residual_add_inplace_f16",
    );
    let n_i32 = len as i32;
    let block = 256u32;
    // Ceiling division: enough blocks to cover every element.
    let grid = ((len as u32) + block - 1) / block;
    let stream = ctx.stream.clone();
    let mut b = stream.launch_builder(&func);
    b.arg(residual);
    b.arg(x);
    b.arg(&n_i32);
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        })
    }
    .expect("add_inplace (residual_add_inplace) launch");
}
fn scaled_add_inplace(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
src: &Self::Buffer,
scale: f32,
len: usize,
) {
if len == 0 {
return;
}
let dst_dtype = dst.dtype();
let src_dtype = src.dtype();
assert_eq!(
dst_dtype,
src_dtype,
"CudaBackend::scaled_add_inplace dtype mismatch: dst={} src={}",
dst_dtype.name(),
src_dtype.name()
);
assert!(
len <= dst.len() && len <= src.len(),
"CudaBackend::scaled_add_inplace len={len} exceeds dst_len={} src_len={}",
dst.len(),
src.len()
);
let fn_name = match dst_dtype {
crate::backend::Dtype::F16 => "scaled_add_inplace_f16",
crate::backend::Dtype::F32 => "scaled_add_inplace_f32",
other => panic!(
"CudaBackend::scaled_add_inplace unsupported dtype {}",
other.name()
),
};
let func = ctx.func("scaled_add_inplace", ptx::SCALED_ADD_INPLACE, fn_name);
let n_i32 = len as i32;
let block = 256u32;
let grid = ((len as u32) + block - 1) / block;
let stream = ctx.stream.clone();
let mut b = stream.launch_builder(&func);
b.arg(dst);
b.arg(src);
b.arg(&scale);
b.arg(&n_i32);
unsafe {
b.launch(LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
})
}
.expect("scaled_add_inplace launch");
}
/// Launches `fused_silu_mul_interleaved_f16` from the `ptx::FUSED_SILU_MUL` module on
/// row-offset views of `gate_up` and `out`.
///
/// The kernel receives raw device addresses rather than buffers:
/// * input address = base of `gate_up` + `in_row_offset * 2 * intermediate` f16 elements,
/// * output address = base of `out` + `out_row_offset * intermediate` f16 elements.
///
/// It is also given `intermediate` and `tokens * intermediate` as `i32`; the launch
/// covers `tokens * intermediate` elements with 256-thread blocks.
///
/// # Panics
/// Panics if the kernel launch fails.
fn fused_silu_mul_split_strided(
    ctx: &mut Self::Context,
    gate_up: &Self::Buffer,
    in_row_offset: usize,
    out: &mut Self::Buffer,
    out_row_offset: usize,
    tokens: usize,
    intermediate: usize,
) {
    use cudarc::driver::{DevicePtr, DevicePtrMut};
    const THREADS: u32 = 256;

    let kernel = ctx.func(
        "fused_silu_mul",
        ptx::FUSED_SILU_MUL,
        "fused_silu_mul_interleaved_f16",
    );
    let inner = intermediate as i32;
    let elements = tokens * intermediate;
    let element_count = elements as i32;
    let launch_cfg = LaunchConfig {
        grid_dim: (((elements as u32) + THREADS - 1) / THREADS, 1, 1),
        block_dim: (THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let stream = ctx.stream.clone();

    // Byte offsets of the first input / output row handled by this call.
    let elem_bytes = std::mem::size_of::<half::f16>();
    let input_skip = in_row_offset * 2 * intermediate * elem_bytes;
    let output_skip = out_row_offset * intermediate * elem_bytes;

    // The guards returned next to the base addresses stay alive until the launch is done.
    let (input_base, _input_guard) = gate_up.as_f16().device_ptr(&stream);
    let (output_base, _output_guard) = out.as_f16_mut().device_ptr_mut(&stream);
    let input_addr = input_base + input_skip as u64;
    let output_addr = output_base + output_skip as u64;

    let mut args = stream.launch_builder(&kernel);
    args.arg(&input_addr);
    args.arg(&output_addr);
    args.arg(&inner);
    args.arg(&element_count);
    unsafe { args.launch(launch_cfg) }.expect("fused_silu_mul_split_strided launch");
}
/// Launches `add_bias_f16` from the `ptx::ADD_BIAS` module with `data`, `bias`, `rows`
/// and `cols` (both as `i32`) as kernel arguments.
///
/// The grid has one block per row; each block has `min(cols, 1024)` threads.
///
/// # Panics
/// Panics if the kernel launch fails.
fn add_bias(
    ctx: &mut Self::Context,
    data: &mut Self::Buffer,
    bias: &Self::Buffer,
    rows: usize,
    cols: usize,
) {
    let kernel = ctx.func("add_bias", ptx::ADD_BIAS, "add_bias_f16");
    let (n_rows, n_cols) = (rows as i32, cols as i32);
    let launch_cfg = LaunchConfig {
        grid_dim: (rows as u32, 1, 1),
        // Block width follows the column count, capped at 1024 threads.
        block_dim: (cols.min(1024) as u32, 1, 1),
        shared_mem_bytes: 0,
    };
    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(data);
    args.arg(bias);
    args.arg(&n_rows);
    args.arg(&n_cols);
    unsafe { args.launch(launch_cfg) }.expect("add_bias launch");
}
/// Launches `layer_norm_f16` from the `ptx::LAYER_NORM` module with `x`, `gamma`,
/// `beta`, `out`, `dim` (as `i32`) and `eps` as kernel arguments.
///
/// `tokens` is not passed to the kernel; it only sets the grid size (one 32-thread
/// block per token).
///
/// # Panics
/// Panics if the kernel launch fails.
fn layer_norm(
    ctx: &mut Self::Context,
    x: &Self::Buffer,
    gamma: &Self::Buffer,
    beta: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    let kernel = ctx.func("layer_norm", ptx::LAYER_NORM, "layer_norm_f16");
    let width = dim as i32;
    let launch_cfg = LaunchConfig {
        grid_dim: (tokens as u32, 1, 1),
        block_dim: (32, 1, 1),
        shared_mem_bytes: 0,
    };
    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(x);
    args.arg(gamma);
    args.arg(beta);
    args.arg(out);
    args.arg(&width);
    args.arg(&eps);
    unsafe { args.launch(launch_cfg) }.expect("layer_norm launch");
}
/// Launches `gelu_f16` from the `ptx::GELU` module with `x`, `out` and `len` (as `i32`)
/// as kernel arguments.
///
/// The launch covers `len` elements with 256-thread blocks. A `len` of zero is a no-op:
/// it would otherwise request a zero-sized grid, which the driver rejects, turning an
/// empty input into a panic. This matches the guard in `scaled_add_inplace`.
///
/// # Panics
/// Panics if the kernel launch fails.
fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
    // Nothing to compute; also avoids launching with grid_dim.x == 0.
    if len == 0 {
        return;
    }
    let func = ctx.func("gelu", ptx::GELU, "gelu_f16");
    let n_i32 = len as i32;
    let block = 256u32;
    // Ceiling division: enough blocks to cover every element.
    let grid = ((len as u32) + block - 1) / block;
    let stream = ctx.stream.clone();
    let mut b = stream.launch_builder(&func);
    b.arg(x);
    b.arg(out);
    b.arg(&n_i32);
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        })
    }
    .expect("gelu launch");
}
/// Asynchronously zeroes the first `len` 16-bit elements of `buf` on the context's
/// stream using `cuMemsetD16Async`.
///
/// # Errors
/// Returns an error if `len` exceeds `buf.len()` (the memset would otherwise write past
/// the end of the allocation) or if the driver call fails.
fn zero_buffer(ctx: &mut Self::Context, buf: &mut Self::Buffer, len: usize) -> Result<()> {
    use cudarc::driver::DevicePtrMut;
    // Nothing to clear.
    if len == 0 {
        return Ok(());
    }
    // The raw memset does no bounds checking of its own, so validate the range here.
    let capacity = buf.len();
    if len > capacity {
        return Err(FerrumError::model(format!(
            "zero_buffer: len={len} exceeds buffer len={capacity}"
        )));
    }
    let stream = ctx.stream.clone();
    // This is a write, so request the mutable device pointer (as the other writers in
    // this impl do) rather than the read-only one.
    let (ptr, _guard) = buf.as_f16_mut().device_ptr_mut(&stream);
    unsafe {
        cudarc::driver::sys::cuMemsetD16Async(
            ptr as cudarc::driver::sys::CUdeviceptr,
            0,
            len,
            stream.cu_stream(),
        )
    }
    .result()
    .map_err(|e| FerrumError::model(format!("cuMemsetD16Async: {e}")))?;
    Ok(())
}
}
/// Process-wide map from device ordinal to the stream registered for that device.
/// Created empty on first access through [`stream_slots`].
static GLOBAL_STREAMS: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, Arc<CudaStream>>>> =
    std::sync::OnceLock::new();

/// Returns the lazily-initialised [`GLOBAL_STREAMS`] map.
fn stream_slots() -> &'static std::sync::RwLock<HashMap<usize, Arc<CudaStream>>> {
    GLOBAL_STREAMS.get_or_init(Default::default)
}
/// Returns the stream registered for the current device ordinal, creating it on first use.
///
/// The lookup first takes the read lock; on a miss it takes the write lock, re-checks,
/// and if still absent creates a `CudaContext` for the ordinal, disables event tracking
/// on it, creates a new stream and stores that stream in the map.
///
/// # Panics
/// Panics if the map lock is poisoned, or if context or stream creation fails.
pub(super) fn default_stream() -> Arc<CudaStream> {
    let ordinal = current_device_ordinal();

    // Common case: the stream already exists.
    {
        let slots = stream_slots().read().expect("GLOBAL_STREAMS poisoned");
        if let Some(existing) = slots.get(&ordinal) {
            return existing.clone();
        }
    }

    // Miss: create under the write lock. `entry` repeats the presence check, so a
    // stream inserted by another thread in between is reused rather than replaced.
    let mut slots = stream_slots().write().expect("GLOBAL_STREAMS poisoned");
    slots
        .entry(ordinal)
        .or_insert_with(|| {
            let ctx = CudaContext::new(ordinal).unwrap_or_else(|e| {
                panic!("CudaBackend: failed to init default context {ordinal}: {e}")
            });
            unsafe {
                ctx.disable_event_tracking();
            }
            ctx.new_stream()
                .unwrap_or_else(|e| panic!("CudaBackend: failed to create default stream: {e}"))
        })
        .clone()
}
/// Runs `f` with the stream returned by [`default_stream`] and returns its result.
fn with_stream<R>(f: impl FnOnce(&Arc<CudaStream>) -> R) -> R {
    f(&default_stream())
}
/// Registers `stream` as the stream for the current device ordinal, replacing any
/// stream previously stored for that ordinal.
///
/// Despite the name, the stream is stored in the process-wide map keyed by
/// `current_device_ordinal()`, not in per-thread storage.
///
/// # Panics
/// Panics if the map lock is poisoned.
pub fn install_thread_stream(stream: Arc<CudaStream>) {
    let mut slots = stream_slots().write().expect("GLOBAL_STREAMS poisoned");
    let ordinal = current_device_ordinal();
    slots.insert(ordinal, stream);
}
/// Device buffers holding per-device decode state.
///
/// `ensure_decode_state_bufs` allocates each buffer with exactly one element. The
/// `pos` buffer is what the `_dyn` kernel variants in this file receive in place of a
/// host-side position scalar.
pub struct DecodeStateBufs {
/// Single `u32` element. Presumably the current token id — TODO confirm with the writer.
pub token: CudaSlice<u32>,
/// Single `i32` element, passed to the `_dyn` kernels as the position argument.
pub pos: CudaSlice<i32>,
/// Single `i32` element. Presumably the current KV length — TODO confirm with the writer.
pub kv: CudaSlice<i32>,
}
// NOTE(review): these impls assert that the contained device slices may be moved to and
// shared between threads; confirm that this holds for `CudaSlice` as used here.
unsafe impl Send for DecodeStateBufs {}
unsafe impl Sync for DecodeStateBufs {}
/// Process-wide map from device ordinal to that device's decode-state slot. Each slot is
/// a leaked (`'static`) lock around an optional [`DecodeStateBufs`].
static DECODE_STATES: std::sync::OnceLock<
    std::sync::RwLock<HashMap<usize, &'static std::sync::RwLock<Option<DecodeStateBufs>>>>,
> = std::sync::OnceLock::new();

/// Returns the lazily-initialised [`DECODE_STATES`] map.
fn decode_state_slots(
) -> &'static std::sync::RwLock<HashMap<usize, &'static std::sync::RwLock<Option<DecodeStateBufs>>>>
{
    DECODE_STATES.get_or_init(Default::default)
}
/// Returns the decode-state slot for `ordinal`, creating an empty (`None`) slot on
/// first request.
///
/// New slots are heap-allocated and leaked so the returned reference is `'static`;
/// slots are never removed by this function.
///
/// # Panics
/// Panics if the slot-map lock is poisoned.
pub(super) fn decode_state_slot_for_ordinal(
    ordinal: usize,
) -> &'static std::sync::RwLock<Option<DecodeStateBufs>> {
    // Fast path under the read lock.
    let known = decode_state_slots()
        .read()
        .expect("DECODE_STATES poisoned")
        .get(&ordinal)
        .copied();
    if let Some(slot) = known {
        return slot;
    }

    // Slow path: insert under the write lock unless another thread already did.
    let mut registry = decode_state_slots()
        .write()
        .expect("DECODE_STATES poisoned");
    let slot = *registry
        .entry(ordinal)
        .or_insert_with(|| Box::leak(Box::new(std::sync::RwLock::new(None))));
    slot
}
/// Ensures the decode-state slot of the current device ordinal holds a
/// [`DecodeStateBufs`], allocating the three one-element device buffers on `stream` if
/// the slot is still empty.
///
/// # Panics
/// Panics if the slot lock is poisoned or if any allocation fails.
fn ensure_decode_state_bufs(stream: &Arc<CudaStream>) {
    let slot = decode_state_slot_for_ordinal(current_device_ordinal());

    // Already populated: nothing to do. The read guard is released before the write
    // lock is requested.
    let already_ready = {
        let current = slot.read().expect("DECODE_STATE poisoned");
        current.is_some()
    };
    if already_ready {
        return;
    }

    let mut state = slot.write().expect("DECODE_STATE poisoned");
    // Re-check: another thread may have filled the slot between the two locks.
    if state.is_some() {
        return;
    }
    let token = unsafe { stream.alloc::<u32>(1) }.expect("token_buf alloc");
    let pos = unsafe { stream.alloc::<i32>(1) }.expect("pos_buf alloc");
    let kv = unsafe { stream.alloc::<i32>(1) }.expect("kv_buf alloc");
    *state = Some(DecodeStateBufs { token, pos, kv });
}
/// Per-device cuBLAS state cached by `ensure_blas_handle`.
struct BlasSlot {
/// Shared cuBLAS handle created on the stream passed to `ensure_blas_handle`.
blas: Arc<CudaBlas>,
/// Device workspace registered with the handle via `cublasSetWorkspace_v2`; kept here
/// so the allocation lives as long as the slot.
_workspace: CudaSlice<u8>,
/// Device-resident scalars uploaded at creation: `alpha_f32` holds `1.0` and
/// `beta_f32` holds `0.0`.
pub alpha_f32: CudaSlice<f32>, pub beta_f32: CudaSlice<f32>, }
// NOTE(review): these impls assert that the handle and device slices may be moved to and
// shared between threads; confirm that this holds for the cudarc types involved.
unsafe impl Send for BlasSlot {}
unsafe impl Sync for BlasSlot {}
/// Process-wide map from device ordinal to that device's [`BlasSlot`].
static BLAS_HANDLES: std::sync::OnceLock<std::sync::RwLock<HashMap<usize, BlasSlot>>> =
    std::sync::OnceLock::new();

/// Returns the lazily-initialised [`BLAS_HANDLES`] map.
fn blas_slots() -> &'static std::sync::RwLock<HashMap<usize, BlasSlot>> {
    BLAS_HANDLES.get_or_init(Default::default)
}
/// Returns the cuBLAS handle cached for the current device ordinal, creating and
/// configuring it on first use.
///
/// Creation, done under the write lock:
/// 1. builds a `CudaBlas` on `stream`,
/// 2. allocates a 32 MiB device workspace and registers it with `cublasSetWorkspace_v2`,
/// 3. uploads `1.0` and `0.0` as device-resident `f32` scalars,
/// 4. switches the handle to `CUBLAS_POINTER_MODE_DEVICE`,
/// 5. stores everything in a `BlasSlot` keyed by the ordinal.
///
/// # Panics
/// Panics if the lock is poisoned, if any allocation or upload fails, or if either
/// cuBLAS configuration call does not return `CUBLAS_STATUS_SUCCESS`.
fn ensure_blas_handle(stream: &Arc<CudaStream>) -> Arc<CudaBlas> {
    const WS_BYTES: usize = 32 * 1024 * 1024;
    let ordinal = current_device_ordinal();

    // Fast path: handle already cached.
    {
        let slots = blas_slots().read().expect("BLAS poisoned");
        if let Some(existing) = slots.get(&ordinal) {
            return existing.blas.clone();
        }
    }

    let mut slots = blas_slots().write().expect("BLAS poisoned");
    // Re-check under the write lock in case another thread initialised the slot.
    if let Some(existing) = slots.get(&ordinal) {
        return existing.blas.clone();
    }

    let blas = Arc::new(CudaBlas::new(stream.clone()).expect("CudaBlas::new"));
    let workspace = unsafe { stream.alloc::<u8>(WS_BYTES) }.expect("blas ws alloc");
    let alpha_f32 = stream.clone_htod(&[1.0f32]).expect("alpha htod");
    let beta_f32 = stream.clone_htod(&[0.0f32]).expect("beta htod");
    unsafe {
        use cudarc::cublas::sys;
        use cudarc::driver::DevicePtr;
        let (ws_ptr, _ws_guard) = workspace.device_ptr(stream);
        let status = sys::cublasSetWorkspace_v2(*blas.handle(), ws_ptr as *mut _, WS_BYTES);
        assert_eq!(
            status,
            sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS,
            "set workspace"
        );
        let status = sys::cublasSetPointerMode_v2(
            *blas.handle(),
            sys::cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
        );
        assert_eq!(
            status,
            sys::cublasStatus_t::CUBLAS_STATUS_SUCCESS,
            "set pointer mode"
        );
    }

    let handle = blas.clone();
    slots.insert(
        ordinal,
        BlasSlot {
            blas,
            _workspace: workspace,
            alpha_f32,
            beta_f32,
        },
    );
    handle
}
fn with_blas_scalars<R>(
ordinal: usize,
f: impl FnOnce(&CudaSlice<f32>, &CudaSlice<f32>) -> R,
) -> R {
let g = blas_slots().read().expect("BLAS poisoned");
let s = g.get(&ordinal).expect("BLAS not init");
f(&s.alpha_f32, &s.beta_f32)
}
/// Per-device scratch storage created by `ensure_batched_scratch`: three device arrays
/// and three host arrays, each holding `HOST_STAGING_TOTAL` `u64` entries.
///
/// NOTE(review): the field names suggest these carry K / V / cache device pointers for
/// batched launches; the consumers are outside this section — confirm.
struct BatchedScratchSlot {
/// Device array of `HOST_STAGING_TOTAL` `u64` values (K side).
pub scratch_u64_k: CudaSlice<u64>,
/// Device array of `HOST_STAGING_TOTAL` `u64` values (V side).
pub scratch_u64_v: CudaSlice<u64>,
/// Device array of `HOST_STAGING_TOTAL` `u64` values (cache side).
pub scratch_u64_cache: CudaSlice<u64>,
/// Host array, zero-initialised at creation (K side).
pub host_k_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
/// Host array, zero-initialised at creation (V side).
pub host_v_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
/// Host array, zero-initialised at creation (cache side).
pub host_cache_ptrs: Box<[u64; HOST_STAGING_TOTAL]>,
}
// NOTE(review): these impls assert that the contained device slices may be moved to and
// shared between threads; confirm that this holds for `CudaSlice` as used here.
unsafe impl Send for BatchedScratchSlot {}
unsafe impl Sync for BatchedScratchSlot {}
/// Process-wide map from device ordinal to that device's batched-scratch slot. Each slot
/// is a leaked (`'static`) lock around an optional [`BatchedScratchSlot`].
static BATCHED_SCRATCH: std::sync::OnceLock<
    std::sync::RwLock<HashMap<usize, &'static std::sync::RwLock<Option<BatchedScratchSlot>>>>,
> = std::sync::OnceLock::new();

/// Returns the lazily-initialised [`BATCHED_SCRATCH`] map.
fn batched_scratch_slots() -> &'static std::sync::RwLock<
    HashMap<usize, &'static std::sync::RwLock<Option<BatchedScratchSlot>>>,
> {
    BATCHED_SCRATCH.get_or_init(Default::default)
}
/// Returns the batched-scratch slot for `ordinal`, creating an empty (`None`) slot on
/// first request.
///
/// New slots are heap-allocated and leaked so the returned reference is `'static`.
///
/// # Panics
/// Panics if the slot-map lock is poisoned.
fn batched_scratch_slot_for_ordinal(
    ordinal: usize,
) -> &'static std::sync::RwLock<Option<BatchedScratchSlot>> {
    // Fast path under the read lock.
    let known = batched_scratch_slots()
        .read()
        .expect("BATCHED_SCRATCH poisoned")
        .get(&ordinal)
        .copied();
    if let Some(slot) = known {
        return slot;
    }

    // Slow path: insert under the write lock unless another thread already did.
    let mut registry = batched_scratch_slots()
        .write()
        .expect("BATCHED_SCRATCH poisoned");
    let slot = *registry
        .entry(ordinal)
        .or_insert_with(|| Box::leak(Box::new(std::sync::RwLock::new(None))));
    slot
}
/// Ensures the batched-scratch slot of the current device ordinal is populated,
/// allocating three `HOST_STAGING_TOTAL`-element `u64` device arrays on `stream` and
/// three zero-filled host arrays of the same length if the slot is still empty.
///
/// # Panics
/// Panics if the slot lock is poisoned or if a device allocation fails.
fn ensure_batched_scratch(stream: &Arc<CudaStream>) {
    let slot = batched_scratch_slot_for_ordinal(current_device_ordinal());

    // Already populated: nothing to do. The read guard is released before the write
    // lock is requested.
    let already_ready = {
        let current = slot.read().expect("BATCHED_SCRATCH poisoned");
        current.is_some()
    };
    if already_ready {
        return;
    }

    let mut scratch = slot.write().expect("BATCHED_SCRATCH poisoned");
    // Re-check: another thread may have filled the slot between the two locks.
    if scratch.is_some() {
        return;
    }
    let scratch_u64_k = unsafe { stream.alloc::<u64>(HOST_STAGING_TOTAL) }
        .expect("batched scratch_u64_k alloc");
    let scratch_u64_v = unsafe { stream.alloc::<u64>(HOST_STAGING_TOTAL) }
        .expect("batched scratch_u64_v alloc");
    let scratch_u64_cache = unsafe { stream.alloc::<u64>(HOST_STAGING_TOTAL) }
        .expect("batched scratch_u64_cache alloc");
    *scratch = Some(BatchedScratchSlot {
        scratch_u64_k,
        scratch_u64_v,
        scratch_u64_cache,
        host_k_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
        host_v_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
        host_cache_ptrs: Box::new([0u64; HOST_STAGING_TOTAL]),
    });
}
/// Calls `f` with exclusive access to the [`BatchedScratchSlot`] of `ordinal` and
/// returns its result. The slot's write lock is held for the duration of `f`.
///
/// # Panics
/// Panics if the lock is poisoned or if the slot has not been populated yet.
fn with_batched_scratch_mut<R>(ordinal: usize, f: impl FnOnce(&mut BatchedScratchSlot) -> R) -> R {
    let slot = batched_scratch_slot_for_ordinal(ordinal);
    let mut guard = slot.write().expect("BATCHED_SCRATCH poisoned");
    let scratch = guard.as_mut().expect("BatchedScratchSlot not initialised");
    f(scratch)
}
/// Process-wide cache of loaded modules, keyed by `(device ordinal, module key)`.
static MODULES: std::sync::OnceLock<
    std::sync::Mutex<HashMap<(usize, &'static str), Arc<CudaModule>>>,
> = std::sync::OnceLock::new();

/// Returns the lazily-initialised [`MODULES`] cache.
fn modules_cache() -> &'static std::sync::Mutex<HashMap<(usize, &'static str), Arc<CudaModule>>> {
    MODULES.get_or_init(Default::default)
}
/// Returns the module cached under `(ordinal, key)`, loading it from `ptx_src` through
/// `ctx` on first use.
///
/// The cache is guarded by a plain `Mutex`, so lookup and insertion happen in one
/// critical section. The earlier lock / unlock / re-lock sequence bought nothing over
/// this, because a mutex has no cheaper shared mode to try first.
///
/// # Panics
/// Panics if the cache mutex is poisoned or if loading the module fails.
pub(super) fn ensure_module(
    ordinal: usize,
    ctx: &Arc<CudaContext>,
    key: &'static str,
    ptx_src: &str,
) -> Arc<CudaModule> {
    let mut cache = modules_cache().lock().expect("MODULES poisoned");
    let module = cache
        .entry((ordinal, key))
        .or_insert_with(|| {
            ctx.load_module(Ptx::from_src(ptx_src.to_string()))
                .unwrap_or_else(|e| panic!("CudaBackend: load_module({key}): {e}"))
        })
        .clone();
    module
}
/// FP16 KV-cache binding for the CUDA backend: the KV buffer is the backend's ordinary
/// `Buffer` type and the scales type is the unit type.
impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CudaBackend {
type KvBuffer = <Self as crate::backend::Backend>::Buffer;
type KvScales = ();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed values for all three variables are parsed as given.
    #[test]
    fn cuda_backend_runtime_env_parses_values() {
        let vars = [
            ("FERRUM_MOE_STREAMS", "8"),
            ("FERRUM_CUDA_MAX_KV", "16384"),
            ("FERRUM_CUDA_DEVICE", "2"),
        ];
        let parsed = CudaBackendRuntimeEnv::from_env_vars(vars);
        assert_eq!(parsed.moe_streams, 8);
        assert_eq!(parsed.cuda_max_kv, Some(16384));
        assert_eq!(parsed.cuda_device, 2);
    }

    /// A zero stream count and unparsable numbers fall back to the defaults.
    #[test]
    fn cuda_backend_runtime_env_defaults_invalid_values() {
        let vars = [
            ("FERRUM_MOE_STREAMS", "0"),
            ("FERRUM_CUDA_MAX_KV", "invalid"),
            ("FERRUM_CUDA_DEVICE", "invalid"),
        ];
        let parsed = CudaBackendRuntimeEnv::from_env_vars(vars);
        assert_eq!(parsed.moe_streams, 1);
        assert_eq!(parsed.cuda_max_kv, None);
        assert_eq!(parsed.cuda_device, 0);
    }

    /// Nested device scopes override the ordinal and restore the outer value on exit.
    #[test]
    fn cuda_device_scope_nests_and_restores() {
        let baseline = current_device_ordinal();
        with_cuda_device_ordinal(Some(1), || {
            assert_eq!(current_device_ordinal(), 1);
            with_cuda_device_ordinal(Some(2), || {
                assert_eq!(current_device_ordinal(), 2);
            });
            // Leaving the inner scope brings back the outer override.
            assert_eq!(current_device_ordinal(), 1);
        });
        assert_eq!(current_device_ordinal(), baseline);
    }
}