use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys as driver_sys;
use cudarc::driver::CudaGraph;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
pub mod record;
#[cfg(feature = "cufft")]
pub use record::fft_r2c::FftR2COp;
pub use record::memcpy::MemcpyOp;
#[cfg(feature = "curand")]
pub use record::rng_fill_uniform::RngFillUniformOp;
pub use record::sgemm::SgemmOp;
pub mod child;
#[cfg(feature = "graphs-conditional")]
pub mod conditional;
pub mod dot;
pub mod exec_update;
pub use child::ChildGraphOp;
pub use dot::{export_dot, DotFlags};
pub use exec_update::{exec_update, GraphExecUpdateOutcome};
const LIB: &str = "graph";
#[doc(hidden)]
pub struct MockGraphRecordCtx {
parent_graph: driver_sys::CUgraph,
stream: Option<Arc<cudarc::driver::CudaStream>>,
}
impl MockGraphRecordCtx {
pub fn new(parent_graph: driver_sys::CUgraph) -> Self {
Self {
parent_graph,
stream: None,
}
}
pub fn with_stream(mut self, stream: Arc<cudarc::driver::CudaStream>) -> Self {
self.stream = Some(stream);
self
}
pub fn parent_graph(&self) -> driver_sys::CUgraph {
self.parent_graph
}
pub fn stream(&self) -> Option<&Arc<cudarc::driver::CudaStream>> {
self.stream.as_ref()
}
pub fn as_ctx(&self) -> GraphRecordCtx<'_> {
GraphRecordCtx {
stream: self.stream.as_ref(),
blas: None,
#[cfg(feature = "curand")]
rng: None,
#[cfg(feature = "cufft")]
fft: None,
parent_graph: Some(self.parent_graph),
}
}
}
pub trait GraphOpRecord {
fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError>;
}
pub struct SendGraph(Arc<CudaGraph>);
unsafe impl Send for SendGraph {}
unsafe impl Sync for SendGraph {}
impl Clone for SendGraph {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
#[derive(Clone)]
pub struct GraphHandle {
graph: Option<SendGraph>,
generation: u64,
#[doc(hidden)]
synthetic_cu_graph: driver_sys::CUgraph,
#[doc(hidden)]
synthetic_cu_graph_exec: driver_sys::CUgraphExec,
}
unsafe impl Send for GraphHandle {}
unsafe impl Sync for GraphHandle {}
impl GraphHandle {
pub fn from_graph(graph: Arc<CudaGraph>, state: &Arc<DeviceState>) -> Self {
Self {
graph: Some(SendGraph(graph)),
generation: state.generation(),
synthetic_cu_graph: std::ptr::null_mut(),
synthetic_cu_graph_exec: std::ptr::null_mut(),
}
}
pub fn generation(&self) -> u64 {
self.generation
}
pub fn cu_graph(&self) -> driver_sys::CUgraph {
if let Some(g) = self.graph.as_ref() {
g.0.cu_graph()
} else {
self.synthetic_cu_graph
}
}
pub fn cu_graph_exec(&self) -> driver_sys::CUgraphExec {
if let Some(g) = self.graph.as_ref() {
g.0.cu_graph_exec()
} else {
self.synthetic_cu_graph_exec
}
}
#[doc(hidden)]
pub fn synthetic_for_tests() -> Self {
Self {
graph: None,
generation: 0,
synthetic_cu_graph: std::ptr::null_mut(),
synthetic_cu_graph_exec: std::ptr::null_mut(),
}
}
}
pub struct GraphRecordCtx<'a> {
pub stream: Option<&'a Arc<cudarc::driver::CudaStream>>,
pub blas: Option<&'a CudaBlas>,
#[cfg(feature = "curand")]
pub rng: Option<&'a cudarc::curand::CudaRng>,
#[cfg(feature = "cufft")]
pub fft: Option<&'a cudarc::cufft::CudaFft>,
pub parent_graph: Option<driver_sys::CUgraph>,
}
impl<'a> GraphRecordCtx<'a> {
pub fn require_stream(&self) -> Result<&'a Arc<cudarc::driver::CudaStream>, GpuError> {
self.stream.ok_or_else(|| {
GpuError::Unrecoverable("GraphRecordCtx: no captured stream available".into())
})
}
pub fn parent_graph(&self) -> driver_sys::CUgraph {
self.parent_graph.unwrap_or(std::ptr::null_mut())
}
}
pub trait GraphOp: Send + 'static {
fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError>;
fn op_name(&self) -> &'static str {
"graph_op"
}
}
impl GraphOp for Box<dyn GraphOp> {
fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
(**self).record(ctx)
}
fn op_name(&self) -> &'static str {
(**self).op_name()
}
}
#[deprecated(
since = "0.1.0",
note = "construct individual `impl GraphOp` types (e.g. `SgemmOp`, `MemcpyOp`) and \
push them as `Box<dyn GraphOp>` instead of using the closed enum"
)]
#[allow(deprecated)]
pub enum GraphOpLegacy {
Sgemm(Box<SgemmOp>),
Memcpy(Box<MemcpyOp>),
#[cfg(feature = "curand")]
RngFillUniform(Box<RngFillUniformOp>),
#[cfg(feature = "cufft")]
FftR2C(Box<FftR2COp>),
}
#[allow(deprecated)]
impl GraphOp for GraphOpLegacy {
fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
match self {
GraphOpLegacy::Sgemm(b) => b.record(ctx),
GraphOpLegacy::Memcpy(m) => m.record(ctx),
#[cfg(feature = "curand")]
GraphOpLegacy::RngFillUniform(r) => r.record(ctx),
#[cfg(feature = "cufft")]
GraphOpLegacy::FftR2C(r) => r.record(ctx),
}
}
fn op_name(&self) -> &'static str {
match self {
GraphOpLegacy::Sgemm(b) => b.op_name(),
GraphOpLegacy::Memcpy(m) => m.op_name(),
#[cfg(feature = "curand")]
GraphOpLegacy::RngFillUniform(r) => r.op_name(),
#[cfg(feature = "cufft")]
GraphOpLegacy::FftR2C(r) => r.op_name(),
}
}
}
pub enum GraphMsg {
Record {
script: Vec<Box<dyn GraphOp>>,
reply: oneshot::Sender<Result<GraphHandle, GpuError>>,
},
Launch {
handle: GraphHandle,
reply: oneshot::Sender<Result<(), GpuError>>,
},
#[cfg(feature = "cufft")]
SetFftPlan {
plan: cudarc::cufft::CudaFft,
reply: oneshot::Sender<()>,
},
}
struct SendBlas(CudaBlas);
unsafe impl Send for SendBlas {}
unsafe impl Sync for SendBlas {}
#[cfg(feature = "curand")]
struct SendRng(cudarc::curand::CudaRng);
#[cfg(feature = "curand")]
unsafe impl Send for SendRng {}
#[cfg(feature = "curand")]
unsafe impl Sync for SendRng {}
#[cfg(feature = "cufft")]
struct SendFft(cudarc::cufft::CudaFft);
#[cfg(feature = "cufft")]
unsafe impl Send for SendFft {}
#[cfg(feature = "cufft")]
unsafe impl Sync for SendFft {}
pub struct GraphActor {
inner: GraphInner,
}
#[allow(dead_code)]
enum GraphInner {
Real {
stream: Arc<cudarc::driver::CudaStream>,
completion: Arc<dyn CompletionStrategy>,
state: Arc<DeviceState>,
blas: Option<Mutex<SendBlas>>,
#[cfg(feature = "curand")]
rng: Option<Mutex<SendRng>>,
#[cfg(feature = "cufft")]
fft: Mutex<Option<SendFft>>,
},
Mock,
}
impl GraphActor {
pub fn props(
stream: Arc<cudarc::driver::CudaStream>,
completion: Arc<dyn CompletionStrategy>,
state: Arc<DeviceState>,
) -> Props<Self> {
Props::create(move || {
let blas = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
CudaBlas::new(stream.clone())
})) {
Ok(Ok(b)) => Some(Mutex::new(SendBlas(b))),
_ => None,
};
#[cfg(feature = "curand")]
let rng = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
cudarc::curand::CudaRng::new(0, stream.clone())
})) {
Ok(Ok(r)) => Some(Mutex::new(SendRng(r))),
_ => None,
};
GraphActor {
inner: GraphInner::Real {
stream: stream.clone(),
completion: completion.clone(),
state: state.clone(),
blas,
#[cfg(feature = "curand")]
rng,
#[cfg(feature = "cufft")]
fft: Mutex::new(None),
},
}
})
}
pub fn mock_props() -> Props<Self> {
Props::create(|| GraphActor {
inner: GraphInner::Mock,
})
}
}
fn run_record(
stream: &Arc<cudarc::driver::CudaStream>,
state: &Arc<DeviceState>,
blas: &Option<Mutex<SendBlas>>,
#[cfg(feature = "curand")] rng: &Option<Mutex<SendRng>>,
#[cfg(feature = "cufft")] fft: &Mutex<Option<SendFft>>,
mut script: Vec<Box<dyn GraphOp>>,
) -> Result<GraphHandle, GpuError> {
let begin_res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
stream.begin_capture(driver_sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_GLOBAL)
}));
match begin_res {
Ok(Ok(())) => {}
Ok(Err(e)) => {
return Err(GpuError::LibraryError {
lib: LIB,
msg: format!("begin_capture: {e}"),
});
}
Err(_) => {
return Err(GpuError::Unrecoverable(
"GraphActor::Record: CUDA driver not loadable".into(),
));
}
}
let bail = |e: GpuError, stream: &Arc<cudarc::driver::CudaStream>| -> GpuError {
let _ = stream.end_capture(
driver_sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH,
);
e
};
let blas_guard = blas.as_ref().map(|m| m.lock());
#[cfg(feature = "curand")]
let rng_guard = rng.as_ref().map(|m| m.lock());
#[cfg(feature = "cufft")]
let fft_guard = fft.lock();
let mut ctx = GraphRecordCtx {
stream: Some(stream),
blas: blas_guard.as_ref().map(|g| &g.0),
#[cfg(feature = "curand")]
rng: rng_guard.as_ref().map(|g| &g.0),
#[cfg(feature = "cufft")]
fft: fft_guard.as_ref().map(|g| &g.0),
parent_graph: None,
};
for op in script.iter_mut() {
if let Err(e) = op.record(&mut ctx) {
drop(ctx);
#[cfg(feature = "cufft")]
drop(fft_guard);
#[cfg(feature = "curand")]
drop(rng_guard);
drop(blas_guard);
return Err(bail(e, stream));
}
}
drop(ctx);
#[cfg(feature = "cufft")]
drop(fft_guard);
#[cfg(feature = "curand")]
drop(rng_guard);
drop(blas_guard);
let end_res = stream.end_capture(
driver_sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH,
);
let cuda_graph = match end_res {
Ok(Some(g)) => g,
Ok(None) => {
return Err(GpuError::LibraryError {
lib: LIB,
msg: "end_capture returned None".into(),
});
}
Err(e) => {
return Err(GpuError::LibraryError {
lib: LIB,
msg: format!("end_capture: {e}"),
});
}
};
Ok(GraphHandle::from_graph(Arc::new(cuda_graph), state))
}
#[async_trait]
impl Actor for GraphActor {
type Msg = GraphMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: GraphMsg) {
match &self.inner {
GraphInner::Mock => match msg {
GraphMsg::Record { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"GraphActor in mock mode".into(),
)));
}
GraphMsg::Launch { reply, .. } => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"GraphActor in mock mode".into(),
)));
}
#[cfg(feature = "cufft")]
GraphMsg::SetFftPlan { reply, .. } => {
let _ = reply.send(());
}
},
GraphInner::Real {
stream,
completion,
state,
blas,
#[cfg(feature = "curand")]
rng,
#[cfg(feature = "cufft")]
fft,
} => match msg {
GraphMsg::Record { script, reply } => {
let res = run_record(
stream,
state,
blas,
#[cfg(feature = "curand")]
rng,
#[cfg(feature = "cufft")]
fft,
script,
);
let _ = reply.send(res);
}
#[cfg(feature = "cufft")]
GraphMsg::SetFftPlan { plan, reply } => {
*fft.lock() = Some(SendFft(plan));
let _ = reply.send(());
}
GraphMsg::Launch { handle, reply } => {
if handle.generation != state.generation() {
let _ = reply.send(Err(GpuError::GpuRefStale(
"graph captured against rebuilt context",
)));
return;
}
let Some(graph) = handle.graph.as_ref() else {
let _ = reply.send(Err(GpuError::Unrecoverable(
"GraphActor::Launch: synthetic GraphHandle has no captured graph"
.into(),
)));
return;
};
let res = graph.0.launch().map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("launch: {e}"),
});
if let Err(e) = res {
let _ = reply.send(Err(e));
return;
}
let stream = stream.clone();
let completion = completion.clone();
tokio::spawn(async move {
let r = completion.await_completion(&stream).await;
let _ = reply.send(r);
});
}
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex as StdMutex;
struct MockOp {
name: &'static str,
trace: Arc<StdMutex<Vec<&'static str>>>,
}
impl GraphOp for MockOp {
fn record(&mut self, _ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
self.trace.lock().unwrap().push(self.name);
Ok(())
}
fn op_name(&self) -> &'static str {
self.name
}
}
struct CounterOp {
count: Arc<StdMutex<u32>>,
}
impl GraphOp for CounterOp {
fn record(&mut self, _ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
*self.count.lock().unwrap() += 1;
Ok(())
}
fn op_name(&self) -> &'static str {
"counter_op"
}
}
fn no_gpu_ctx<'a>() -> GraphRecordCtx<'a> {
GraphRecordCtx {
stream: None,
blas: None,
#[cfg(feature = "curand")]
rng: None,
#[cfg(feature = "cufft")]
fft: None,
parent_graph: None,
}
}
#[test]
fn external_graph_op_impls_can_be_appended_and_recorded() {
let trace: Arc<StdMutex<Vec<&'static str>>> = Arc::new(StdMutex::new(Vec::new()));
let count = Arc::new(StdMutex::new(0u32));
let mut script: Vec<Box<dyn GraphOp>> = Vec::new();
script.push(Box::new(MockOp {
name: "first_mock",
trace: trace.clone(),
}));
script.push(Box::new(CounterOp {
count: count.clone(),
}));
script.push(Box::new(MockOp {
name: "second_mock",
trace: trace.clone(),
}));
script.push(Box::new(CounterOp {
count: count.clone(),
}));
assert_eq!(script[0].op_name(), "first_mock");
assert_eq!(script[1].op_name(), "counter_op");
assert_eq!(script[2].op_name(), "second_mock");
assert_eq!(script[3].op_name(), "counter_op");
let mut ctx = no_gpu_ctx();
for op in script.iter_mut() {
op.record(&mut ctx).expect("mock op must record");
}
assert_eq!(
*trace.lock().unwrap(),
vec!["first_mock", "second_mock"],
"MockOp::record should append its name in script order"
);
assert_eq!(*count.lock().unwrap(), 2, "CounterOp ran twice");
}
#[test]
fn require_stream_returns_clean_error_in_no_gpu_ctx() {
let ctx = no_gpu_ctx();
let err = ctx.require_stream().unwrap_err();
assert!(matches!(err, GpuError::Unrecoverable(_)));
}
#[test]
fn graph_op_legacy_dispatches_to_inner_op() {
let trace: Arc<StdMutex<Vec<&'static str>>> = Arc::new(StdMutex::new(Vec::new()));
let mut boxed: Box<dyn GraphOp> = Box::new(MockOp {
name: "via_box_dyn",
trace: trace.clone(),
});
let mut ctx = no_gpu_ctx();
boxed.record(&mut ctx).unwrap();
assert_eq!(*trace.lock().unwrap(), vec!["via_box_dyn"]);
assert_eq!(boxed.op_name(), "via_box_dyn");
}
}