use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use cudarc::driver::CudaStream;
use ferrum_bench_core::{global_profile, profile_fields_from_json};
use ferrum_types::{FerrumError, Result};
use super::{decode_state_slot_for_ordinal, CudaBackend};
use crate::backend::{Backend, BackendGraph};
/// Environment-driven knobs that tune how captured CUDA graphs are replayed.
///
/// All flags default to `false`, which selects the regular replay path
/// (upload, launch, then synchronise, with no timing output).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct CudaGraphRuntimeConfig {
    /// Emit replay timing diagnostics. Enabled by the mere presence of
    /// `FERRUM_GRAPH_PROF`, whatever its value.
    graph_prof: bool,
    /// Skip the `cuGraphUpload` call before each launch. Enabled only when
    /// `FERRUM_GRAPH_SKIP_UPLOAD` is exactly `"1"`.
    skip_upload: bool,
    /// Skip the stream synchronisation after each launch. Enabled only when
    /// `FERRUM_GRAPH_SKIP_SYNC` is exactly `"1"`.
    skip_sync: bool,
}

impl CudaGraphRuntimeConfig {
    /// Builds the configuration from the current process environment.
    ///
    /// Iterates `std::env::vars_os` rather than `std::env::vars`: the latter
    /// panics as soon as *any* variable holds a non-Unicode key or value, which
    /// would let an unrelated variable crash graph replay. Variables whose name
    /// is not valid Unicode cannot match any knob and are skipped; values are
    /// converted lossily, so a malformed value still counts as "present" for
    /// `FERRUM_GRAPH_PROF` but can never compare equal to `"1"`.
    fn from_env() -> Self {
        let vars = std::env::vars_os().filter_map(|(name, value)| {
            let name = name.into_string().ok()?;
            Some((name, value.to_string_lossy().into_owned()))
        });
        Self::from_env_vars(vars)
    }

    /// Builds the configuration from an arbitrary `(name, value)` sequence.
    ///
    /// Unknown names are ignored. If a name occurs more than once, the last
    /// occurrence wins for the `"1"`-valued knobs, while `FERRUM_GRAPH_PROF`
    /// stays enabled once seen.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self::default();
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                // Presence-only switch: the value is deliberately not inspected.
                "FERRUM_GRAPH_PROF" => config.graph_prof = true,
                "FERRUM_GRAPH_SKIP_UPLOAD" => config.skip_upload = value == "1",
                "FERRUM_GRAPH_SKIP_SYNC" => config.skip_sync = value == "1",
                _ => {}
            }
        }
        config
    }
}
/// Returns the process-wide graph runtime configuration.
///
/// The environment is consulted on the first call only; the resulting value is
/// cached for the lifetime of the process, so later changes to the environment
/// are not observed.
fn cuda_graph_runtime_config() -> &'static CudaGraphRuntimeConfig {
    static RUNTIME_CONFIG: OnceLock<CudaGraphRuntimeConfig> = OnceLock::new();
    let load = CudaGraphRuntimeConfig::from_env;
    RUNTIME_CONFIG.get_or_init(load)
}
/// CUDA graph capture and replay hooks for the CUDA backend.
///
/// Instantiated graphs are kept in the process-wide `DECODE_GRAPHS` table and
/// addressed by the caller-supplied `key`.
impl BackendGraph for CudaBackend {
    /// Writes the per-step decode scalars into the device-side state buffers.
    ///
    /// Copies `token`, `step` (as `i32`) and `step + 1` into the `token`, `pos`
    /// and `kv` buffers of the slot selected by `ctx.ordinal`, using
    /// `ctx.stream` for the host-to-device copies.
    ///
    /// # Panics
    /// Panics if the decode-state lock is poisoned, if the buffers have not been
    /// initialised, or if any of the three copies fails.
    fn set_decode_state(ctx: &mut Self::Context, token: u32, step: u32) {
        let valid_kv = (step as i32) + 1;
        let step_i = step as i32;
        let stream = ctx.stream.clone();
        let mut w = decode_state_slot_for_ordinal(ctx.ordinal)
            .write()
            .expect("DECODE_STATE poisoned");
        let bufs = w.as_mut().expect("DecodeStateBufs not initialised");
        stream
            .memcpy_htod(&[token], &mut bufs.token)
            .expect("token_buf memcpy");
        stream
            .memcpy_htod(&[step_i], &mut bufs.pos)
            .expect("pos_buf memcpy");
        stream
            .memcpy_htod(&[valid_kv], &mut bufs.kv)
            .expect("kv_buf memcpy");
    }

    /// Sets the context's `use_dev_state` flag to `enable`.
    fn set_dev_state_mode(ctx: &mut Self::Context, enable: bool) {
        ctx.use_dev_state = enable;
    }

    /// Starts relaxed-mode stream capture on `ctx.stream` and records that a
    /// capture is in flight.
    ///
    /// # Errors
    /// Returns an `unsupported` error if the driver refuses to begin capture; in
    /// that case the in-flight flag is left untouched.
    fn begin_graph_capture(ctx: &mut Self::Context) -> Result<()> {
        use cudarc::driver::sys::CUstreamCaptureMode;
        ctx.stream
            .begin_capture(CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED)
            .map_err(|e| FerrumError::unsupported(format!("begin_capture: {e}")))?;
        ctx.capture_in_flight = true;
        Ok(())
    }

    /// Ends the capture started by `begin_graph_capture`, instantiates and
    /// uploads the resulting graph, and stores it under `key`.
    ///
    /// Any graph previously stored under the same `key` is replaced.
    ///
    /// # Errors
    /// Returns an `unsupported` error if no capture is in flight, if binding the
    /// context fails, or if any of the end-capture / instantiate / upload driver
    /// calls fails. Handles created before a failing step are destroyed.
    fn end_graph_capture(ctx: &mut Self::Context, key: u64) -> Result<()> {
        use cudarc::driver::sys;
        if !ctx.capture_in_flight {
            return Err(FerrumError::unsupported("end_capture without begin"));
        }
        // Cleared up front so every exit path below leaves the flag reset.
        ctx.capture_in_flight = false;
        ctx.ctx
            .bind_to_thread()
            .map_err(|e| FerrumError::unsupported(format!("bind pre-end: {e}")))?;
        let cu_stream = ctx.stream.cu_stream();
        let mut cu_graph: sys::CUgraph = std::ptr::null_mut();
        // SAFETY: `cu_stream` is the raw handle of `ctx.stream`, which outlives this
        // call, and `cu_graph` is a valid out-pointer to a local.
        let st1 = unsafe { sys::cuStreamEndCapture(cu_stream, &mut cu_graph) };
        if st1 != sys::CUresult::CUDA_SUCCESS {
            return Err(FerrumError::unsupported(format!(
                "cuStreamEndCapture failed: {st1:?}"
            )));
        }
        if cu_graph.is_null() {
            return Err(FerrumError::unsupported(
                "cuStreamEndCapture returned null graph",
            ));
        }
        let flags = 0u64;
        let mut cu_graph_exec: sys::CUgraphExec = std::ptr::null_mut();
        // SAFETY: `cu_graph` was just returned non-null by a successful end-capture,
        // and `cu_graph_exec` is a valid out-pointer to a local.
        let st2 = unsafe { sys::cuGraphInstantiateWithFlags(&mut cu_graph_exec, cu_graph, flags) };
        if st2 != sys::CUresult::CUDA_SUCCESS {
            // SAFETY: `cu_graph` is still exclusively owned here and never used again.
            unsafe {
                sys::cuGraphDestroy(cu_graph);
            }
            return Err(FerrumError::unsupported(format!(
                "cuGraphInstantiate failed: {st2:?}"
            )));
        }
        // SAFETY: both handles come from the successful calls above.
        let st3 = unsafe { sys::cuGraphUpload(cu_graph_exec, cu_stream) };
        if st3 != sys::CUresult::CUDA_SUCCESS {
            // SAFETY: both handles are still exclusively owned here and never used again.
            unsafe {
                sys::cuGraphExecDestroy(cu_graph_exec);
                sys::cuGraphDestroy(cu_graph);
            }
            return Err(FerrumError::unsupported(format!(
                "cuGraphUpload failed: {st3:?}"
            )));
        }
        // Ownership of both handles moves into the global table from here on.
        install_decode_graph_raw(key, cu_graph, cu_graph_exec, ctx.stream.clone());
        Ok(())
    }

    /// Drops the graph stored under `key`, if any.
    fn reset_graph(_ctx: &mut Self::Context, key: u64) {
        invalidate_decode_graph(key);
    }

    /// Drops every stored graph.
    fn reset_all_graphs(_ctx: &mut Self::Context) {
        invalidate_all_decode_graphs();
    }

    /// Launches the graph stored under `key` on `ctx.stream`.
    ///
    /// Returns `Ok(true)` when a graph was found and launched, `Ok(false)` when
    /// nothing is stored under `key`. Unless disabled through the runtime
    /// configuration, the graph is uploaded before the launch and the stream is
    /// synchronised afterwards. With profiling enabled, every 64th call logs
    /// its upload/launch/sync timings to stderr and, when the global profile is
    /// enabled, pushes a matching `graph_prof` event.
    ///
    /// # Errors
    /// Returns an `unsupported` error if binding the context fails or if the
    /// upload, launch or synchronise driver call fails.
    fn replay_graph(ctx: &mut Self::Context, key: u64) -> Result<bool> {
        use cudarc::driver::sys;
        let cu_stream = ctx.stream.cu_stream();
        ctx.ctx
            .bind_to_thread()
            .map_err(|e| FerrumError::unsupported(format!("bind pre-replay: {e}")))?;
        with_decode_graph(key, |g_opt| {
            let Some(g) = g_opt else {
                return Ok(false);
            };
            let runtime_config = cuda_graph_runtime_config();
            let prof = runtime_config.graph_prof;
            let t_pre = prof.then(std::time::Instant::now);
            if !runtime_config.skip_upload {
                // SAFETY: `g` is borrowed from the graph table for the duration of this
                // closure, so its exec handle cannot be destroyed underneath us.
                let st_up = unsafe { sys::cuGraphUpload(g.cu_graph_exec, cu_stream) };
                if st_up != sys::CUresult::CUDA_SUCCESS {
                    return Err(FerrumError::unsupported(format!(
                        "cuGraphUpload: {st_up:?}"
                    )));
                }
            }
            let t_after_upload = prof.then(std::time::Instant::now);
            // SAFETY: same handle validity argument as for the upload above.
            let st = unsafe { sys::cuGraphLaunch(g.cu_graph_exec, cu_stream) };
            if st != sys::CUresult::CUDA_SUCCESS {
                return Err(FerrumError::unsupported(format!("cuGraphLaunch: {st:?}")));
            }
            let t_after_launch = prof.then(std::time::Instant::now);
            if !runtime_config.skip_sync {
                // SAFETY: `cu_stream` is the raw handle of `ctx.stream`, alive for this call.
                let st_sync = unsafe { sys::cuStreamSynchronize(cu_stream) };
                if st_sync != sys::CUresult::CUDA_SUCCESS {
                    return Err(FerrumError::unsupported(format!(
                        "post-launch sync: {st_sync:?}"
                    )));
                }
            }
            if let (Some(t0), Some(t1), Some(t2)) = (t_pre, t_after_upload, t_after_launch) {
                static N: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
                let n = N.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n.is_multiple_of(64) {
                    // One end timestamp for both `sync` and `total`, taken before any
                    // logging, so the phases add up to the total and the profile event
                    // reports the same figures as the stderr line (not inflated by the
                    // time spent writing to stderr).
                    let t_end = std::time::Instant::now();
                    let upload = t1.duration_since(t0).as_micros();
                    let launch = t2.duration_since(t1).as_micros();
                    let sync = t_end.duration_since(t2).as_micros();
                    let total = t_end.duration_since(t0).as_micros();
                    eprintln!(
                        "[graph-prof] call#{n} upload={upload}us launch={launch}us sync={sync}us total={total}us"
                    );
                    let profile = global_profile();
                    if profile.is_enabled() {
                        // Profiling is best-effort: a failed push must not fail the replay.
                        let _ = profile.push_event(
                            "graph_prof",
                            profile_fields_from_json(serde_json::json!({
                                "call": n,
                            })),
                            profile_fields_from_json(serde_json::json!({
                                "upload": upload,
                                "launch": launch,
                                "sync": sync,
                                "total": total,
                            })),
                            true,
                        );
                    }
                }
            }
            Ok(true)
        })
    }
}
struct GraphSlotRaw {
cu_graph: cudarc::driver::sys::CUgraph,
cu_graph_exec: cudarc::driver::sys::CUgraphExec,
_stream: std::sync::Arc<cudarc::driver::CudaStream>,
}
impl Drop for GraphSlotRaw {
fn drop(&mut self) {
use cudarc::driver::sys;
unsafe {
sys::cuCtxSynchronize();
if !self.cu_graph_exec.is_null() {
sys::cuGraphExecDestroy(self.cu_graph_exec);
}
if !self.cu_graph.is_null() {
sys::cuGraphDestroy(self.cu_graph);
}
sys::cuCtxSynchronize();
}
}
}
unsafe impl Send for GraphSlotRaw {}
unsafe impl Sync for GraphSlotRaw {}
static DECODE_GRAPHS: std::sync::OnceLock<std::sync::RwLock<HashMap<u64, GraphSlotRaw>>> =
std::sync::OnceLock::new();
fn graph_slots() -> &'static std::sync::RwLock<HashMap<u64, GraphSlotRaw>> {
DECODE_GRAPHS.get_or_init(|| std::sync::RwLock::new(HashMap::new()))
}
/// Stores a freshly instantiated graph under `key`, taking ownership of both
/// raw handles.
///
/// A slot already stored under `key` is replaced and dropped (destroying its
/// handles) while the table's write lock is still held.
///
/// # Panics
/// Panics if the graph table lock is poisoned.
fn install_decode_graph_raw(
    key: u64,
    cu_graph: cudarc::driver::sys::CUgraph,
    cu_graph_exec: cudarc::driver::sys::CUgraphExec,
    stream: std::sync::Arc<cudarc::driver::CudaStream>,
) {
    let slot = GraphSlotRaw {
        cu_graph,
        cu_graph_exec,
        _stream: stream,
    };
    let mut table = graph_slots().write().expect("DECODE_GRAPHS poisoned");
    let replaced = table.insert(key, slot);
    drop(replaced);
}
/// Runs `f` with the slot stored under `key` (or `None` when absent) and
/// returns whatever `f` returns.
///
/// The table's read lock is held for the whole duration of `f`, so the slot
/// cannot be removed while `f` is using it.
///
/// # Panics
/// Panics if the graph table lock is poisoned.
fn with_decode_graph<R>(key: u64, f: impl FnOnce(Option<&GraphSlotRaw>) -> Result<R>) -> Result<R> {
    let table = graph_slots().read().expect("DECODE_GRAPHS poisoned");
    let slot = table.get(&key);
    f(slot)
}
/// Removes the graph stored under `key`, if any, dropping it (and thereby
/// destroying its handles) while the table's write lock is held.
///
/// # Panics
/// Panics if the graph table lock is poisoned.
pub fn invalidate_decode_graph(key: u64) {
    let mut table = graph_slots().write().expect("DECODE_GRAPHS poisoned");
    let removed = table.remove(&key);
    drop(removed);
}
/// Removes every stored graph, dropping each slot (and thereby destroying its
/// handles) while the table's write lock is held.
///
/// # Panics
/// Panics if the graph table lock is poisoned.
pub fn invalidate_all_decode_graphs() {
    let mut table = graph_slots().write().expect("DECODE_GRAPHS poisoned");
    table.clear();
}
#[cfg(test)]
mod tests {
    use super::CudaGraphRuntimeConfig;

    /// `FERRUM_GRAPH_PROF` is a presence-only switch (enabled even for "0"),
    /// while the skip knobs are enabled by the value "1".
    #[test]
    fn cuda_graph_runtime_config_parses_graph_knobs() {
        let env = [
            ("FERRUM_GRAPH_PROF", "0"),
            ("FERRUM_GRAPH_SKIP_UPLOAD", "1"),
            ("FERRUM_GRAPH_SKIP_SYNC", "1"),
        ];
        let expected = CudaGraphRuntimeConfig {
            graph_prof: true,
            skip_upload: true,
            skip_sync: true,
        };
        assert_eq!(CudaGraphRuntimeConfig::from_env_vars(env), expected);
    }

    /// Any value other than "1" leaves the skip knobs off, and an absent
    /// `FERRUM_GRAPH_PROF` leaves profiling off.
    #[test]
    fn cuda_graph_runtime_config_keeps_default_replay_path() {
        let env = [
            ("FERRUM_GRAPH_SKIP_UPLOAD", "true"),
            ("FERRUM_GRAPH_SKIP_SYNC", "0"),
        ];
        let expected = CudaGraphRuntimeConfig {
            graph_prof: false,
            skip_upload: false,
            skip_sync: false,
        };
        assert_eq!(CudaGraphRuntimeConfig::from_env_vars(env), expected);
    }
}