use cudarc::nvrtc::compile_ptx;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use tracing::debug;
use super::cudagraph_type::CudaGraph;
use super::types::CudaGraphError;
/// Computes the 64-bit FNV-1a hash of `data`.
///
/// Each byte is XOR-ed into the running state, which is then multiplied by
/// the FNV prime with wrapping arithmetic. Empty input yields the offset
/// basis unchanged.
fn fnv1a_64(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Builds the on-disk location of the cached PTX for a given source hash and tag.
///
/// The file lives directly in the OS temp directory and is named
/// `oxibonsai_ptx_<hash as 16 hex digits>_<tag>.ptx`.
///
/// NOTE(review): `tag` is interpolated verbatim into the file name, and the
/// temp directory may be shared between users — confirm this is acceptable
/// for the deployment environment.
fn ptx_cache_path(src_hash: u64, tag: &str) -> std::path::PathBuf {
    let file_name = format!("oxibonsai_ptx_{src_hash:016x}_{tag}.ptx");
    let mut location = std::env::temp_dir();
    location.push(file_name);
    location
}
/// Attempts to read previously cached PTX text for `(src_hash, tag)`.
///
/// Returns `None` when the cache file cannot be read as UTF-8 text (for
/// example because it does not exist); any I/O error is treated as a miss.
/// The file contents are not validated here.
fn load_ptx_cache(src_hash: u64, tag: &str) -> Option<cudarc::nvrtc::Ptx> {
    std::fs::read_to_string(ptx_cache_path(src_hash, tag))
        .ok()
        .map(cudarc::nvrtc::Ptx::from_src)
}
fn save_ptx_cache(ptx: &cudarc::nvrtc::Ptx, src_hash: u64, tag: &str) {
let path = ptx_cache_path(src_hash, tag);
let _ = std::fs::write(&path, ptx.to_src());
}
/// Returns PTX for the CUDA source `src`, using the on-disk cache when possible.
///
/// The cache key is the FNV-1a hash of `src` combined with `tag`. On a hit the
/// cached PTX is returned without invoking the compiler; on a miss the source
/// is compiled, the result is stored in the cache, and the fresh PTX is returned.
///
/// # Errors
///
/// Returns [`CudaGraphError::CompilationFailed`] carrying the compiler's
/// message when `compile_ptx` fails.
pub(crate) fn compile_or_load_ptx(
    src: &str,
    tag: &str,
) -> Result<cudarc::nvrtc::Ptx, CudaGraphError> {
    let hash = fnv1a_64(src.as_bytes());

    match load_ptx_cache(hash, tag) {
        Some(cached) => {
            debug!("PTX cache hit for tag={tag} hash={hash:016x}");
            Ok(cached)
        }
        None => {
            debug!("PTX cache miss for tag={tag}, compiling...");
            let compiled = compile_ptx(src)
                .map_err(|err| CudaGraphError::CompilationFailed(err.to_string()))?;
            save_ptx_cache(&compiled, hash, tag);
            debug!("PTX compiled and cached: tag={tag}");
            Ok(compiled)
        }
    }
}
/// Next handle ID to hand out; starts at 1 and only ever increases.
static NEXT_HANDLE_ID: AtomicU64 = AtomicU64::new(1);

/// Allocates a fresh handle ID.
///
/// Every call returns the current counter value and bumps it by one, so IDs
/// are distinct across threads. Relaxed ordering is used: the counter only
/// has to produce distinct values, it does not order any other memory access.
pub(crate) fn alloc_handle_id() -> u64 {
    AtomicU64::fetch_add(&NEXT_HANDLE_ID, 1, Ordering::Relaxed)
}
/// Process-wide slot holding the shared [`CudaGraph`] instance.
///
/// The outer `OnceLock` creates the mutex at most once; the inner
/// `Option<Arc<CudaGraph>>` allows the slot to be empty.
/// NOTE(review): presumably populated and read by `CudaGraph::global()` —
/// confirm against the `cudagraph_type` module.
pub(super) static GLOBAL_CUDA_GRAPH: OnceLock<Mutex<Option<Arc<CudaGraph>>>> = OnceLock::new();
/// Dispatches the FFN phase to the global [`CudaGraph`].
///
/// Three weights are obtained from the graph, in this order:
/// the attention projection (`attn_proj_handle_id` / `attn_proj_bytes`),
/// a fused gate+up weight (`gate_up_handle_id`; built lazily as `gate_bytes`
/// followed by `up_bytes`), and the down projection (`down_handle_id` /
/// `down_bytes`). They are then forwarded, together with `hidden`,
/// `attn_out`, `norm_weight`, `eps` and the two sizes, to
/// `CudaGraph::encode_ffn_phase`.
///
/// NOTE(review): `hidden` is passed mutably, so it is presumably updated in
/// place by the encoded phase — confirm against `encode_ffn_phase`.
///
/// # Errors
///
/// Propagates any [`CudaGraphError`] from acquiring the global graph,
/// from obtaining any of the weights, or from encoding the phase.
// The argument list mirrors the flat per-layer call made by the caller.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_ffn(
    hidden: &mut [f32],
    attn_out: &[f32],
    norm_weight: &[f32],
    eps: f32,
    attn_proj_handle_id: u64,
    attn_proj_bytes: &[u8],
    gate_up_handle_id: u64,
    gate_bytes: &[u8],
    up_bytes: &[u8],
    down_handle_id: u64,
    down_bytes: &[u8],
    hidden_size: usize,
    intermediate_size: usize,
) -> Result<(), CudaGraphError> {
    let cuda = CudaGraph::global()?;

    let attn_proj_weight = cuda.get_or_upload_weight_soa(attn_proj_handle_id, attn_proj_bytes)?;
    // Gate bytes first, then up bytes, concatenated into a single buffer.
    let gate_up_weight = cuda
        .get_or_upload_weight_soa_lazy(gate_up_handle_id, || [gate_bytes, up_bytes].concat())?;
    let down_weight = cuda.get_or_upload_weight_soa(down_handle_id, down_bytes)?;

    cuda.encode_ffn_phase(
        hidden,
        attn_out,
        norm_weight,
        eps,
        &attn_proj_weight,
        &gate_up_weight,
        &down_weight,
        hidden_size,
        intermediate_size,
    )
}
/// Dispatches the QKV phase to the global [`CudaGraph`].
///
/// A single fused weight is obtained under `weight_handle_id`; when the graph
/// asks for its bytes they are built as `q_bytes`, then `k_bytes`, then
/// `v_bytes`, concatenated. The weight is forwarded with `input`, `output`,
/// `n_rows` and `k` to `CudaGraph::encode_qkv_phase`.
///
/// # Errors
///
/// Propagates any [`CudaGraphError`] from acquiring the global graph,
/// from obtaining the weight, or from encoding the phase.
// The argument list mirrors the flat per-layer call made by the caller.
#[allow(clippy::too_many_arguments)]
pub fn try_cuda_qkv(
    input: &[f32],
    output: &mut [f32],
    weight_handle_id: u64,
    q_bytes: &[u8],
    k_bytes: &[u8],
    v_bytes: &[u8],
    n_rows: usize,
    k: usize,
) -> Result<(), CudaGraphError> {
    let cuda = CudaGraph::global()?;

    // Q, K and V bytes back to back, in that order.
    let fused_qkv_weight = cuda.get_or_upload_weight_soa_lazy(weight_handle_id, || {
        [q_bytes, k_bytes, v_bytes].concat()
    })?;

    cuda.encode_qkv_phase(input, output, &fused_qkv_weight, n_rows, k)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: acquiring the global graph must not panic. An error is
    /// tolerated and only reported on stderr.
    #[test]
    fn test_cuda_graph_global_init() {
        if let Err(e) = CudaGraph::global() {
            eprintln!("CudaGraph::global() not available (expected in CPU-only CI): {e}");
        }
    }

    /// Builds `N` 18-byte blocks (2-byte little-endian block index followed by
    /// sixteen `0xAB` bytes) and checks that the reformatted output has all
    /// 2-byte headers first, followed by all 16-byte payloads.
    #[test]
    fn test_reformat_aos_to_soa_round_trip() {
        const N: usize = 10;

        let mut aos = vec![0u8; N * 18];
        for (i, block) in aos.chunks_exact_mut(18).enumerate() {
            let (header, payload) = block.split_at_mut(2);
            header.copy_from_slice(&(i as u16).to_le_bytes());
            payload.fill(0xABu8);
        }

        let soa = CudaGraph::reformat_q1_aos_to_soa(&aos).expect("reformat failed");
        assert_eq!(soa.len(), aos.len());

        // Header region: two bytes per block, in block order.
        for i in 0..N {
            let [low, high] = (i as u16).to_le_bytes();
            assert_eq!(soa[i * 2], low, "scale byte 0 wrong at block {i}");
            assert_eq!(soa[i * 2 + 1], high, "scale byte 1 wrong at block {i}");
        }

        // Payload region: sixteen bytes per block, starting after all headers.
        for offset in 0..N * 16 {
            let (i, j) = (offset / 16, offset % 16);
            assert_eq!(
                soa[N * 2 + offset],
                0xABu8,
                "data wrong at block {i} byte {j}"
            );
        }
    }

    /// Consecutively allocated handle IDs must be strictly increasing.
    #[test]
    fn test_handle_id_uniqueness() {
        let ids: Vec<u64> = std::iter::repeat_with(alloc_handle_id).take(64).collect();
        for pair in ids.windows(2) {
            let (earlier, later) = (pair[0], pair[1]);
            assert!(later > earlier, "handle IDs not strictly increasing");
        }
    }

    /// The kernel source must mention the fused gate/up SwiGLU entry point.
    #[test]
    fn test_fused_gate_up_swiglu_source_has_entry_point() {
        let src = crate::gpu_backend::cuda_kernels::CUDA_V7_KERNELS_SRC;
        let has_entry_point = src.contains("fused_gate_up_swiglu_q1");
        assert!(
            has_entry_point,
            "CUDA_V7_KERNELS_SRC must contain the fused_gate_up_swiglu_q1 kernel entry point"
        );
    }

    /// The kernel source must contain the SiLU-times-up epilogue expression.
    #[test]
    fn test_fused_gate_up_swiglu_source_has_silu_epilogue() {
        let has_epilogue = crate::gpu_backend::cuda_kernels::CUDA_V7_KERNELS_SRC
            .contains("silu(gate_partial) * up_partial");
        assert!(
            has_epilogue,
            "fused kernel epilogue 'silu(gate_partial) * up_partial' not found in kernel source"
        );
    }

    /// The kernel source must mention both partial accumulators.
    #[test]
    fn test_fused_gate_up_swiglu_source_has_dual_accumulators() {
        let src = crate::gpu_backend::cuda_kernels::CUDA_V7_KERNELS_SRC;
        let has_gate = src.contains("gate_partial");
        assert!(has_gate, "fused kernel must have 'gate_partial' accumulator");
        let has_up = src.contains("up_partial");
        assert!(has_up, "fused kernel must have 'up_partial' accumulator");
    }

    /// Acquires the global graph and tolerates its absence.
    /// NOTE(review): presumably `CudaGraph::global()` compiles the kernels at
    /// runtime, which is what this test is named for — confirm.
    #[test]
    fn test_fused_gate_up_swiglu_runtime_compile() {
        if let Err(e) = CudaGraph::global() {
            eprintln!(
                "test_fused_gate_up_swiglu_runtime_compile: no CUDA device (expected in CPU-only CI): {e}"
            );
        }
    }
}