use std::sync::Arc;
use cudarc::driver::sys as driver_sys;
use crate::error::GpuError;
use crate::graph::{GraphHandle, GraphOpRecord, GraphRecordCtx};
/// Library tag placed in the `lib` field of `GpuError::LibraryError` values built in this file.
const LIB: &str = "graph";
/// Graph operation that embeds one graph inside another as a child-graph node.
///
/// Recording is done through the `GraphOpRecord` implementation for this type.
pub struct ChildGraphOp {
    /// Handle of the graph that gets inserted as the child.
    pub child: GraphHandle,
}
impl GraphOpRecord for ChildGraphOp {
    /// Adds `self.child` to the parent graph supplied by `ctx` as a child-graph node.
    ///
    /// The node is created with an empty dependency list (null pointer, count 0), and the
    /// resulting node handle is not retained once the call returns.
    ///
    /// # Errors
    ///
    /// * [`GpuError::LibraryError`] when `cuGraphAddChildGraphNode` returns any status other
    ///   than `CUDA_SUCCESS`; the status is included in the message.
    /// * [`GpuError::Unrecoverable`] when the driver call panics; the message attributes
    ///   this to the CUDA driver not being loadable.
    fn record(&self, ctx: &GraphRecordCtx<'_>) -> Result<(), GpuError> {
        let parent_graph = ctx.parent_graph();
        let child_graph = self.child.cu_graph();
        // Out-parameter for the driver; it is written by the call and never read afterwards.
        let mut new_node: driver_sys::CUgraphNode = std::ptr::null_mut();

        let add_child_node = std::panic::AssertUnwindSafe(|| {
            // SAFETY: `new_node` is a live local, so the out-pointer is valid and writable for
            // the whole call, and the null dependency array is paired with a count of 0.
            // NOTE(review): validity of the `parent_graph` / `child_graph` handles is assumed
            // to be guaranteed by `GraphRecordCtx` and `GraphHandle` — confirm.
            unsafe {
                driver_sys::cuGraphAddChildGraphNode(
                    &mut new_node as *mut _,
                    parent_graph,
                    std::ptr::null(),
                    0,
                    child_graph,
                )
            }
        });

        // A panic inside the driver binding is turned into an error instead of unwinding
        // through the caller.
        let status = std::panic::catch_unwind(add_child_node).map_err(|_| {
            GpuError::Unrecoverable("ChildGraphOp::record: CUDA driver not loadable".into())
        })?;

        if status == driver_sys::cudaError_enum::CUDA_SUCCESS {
            return Ok(());
        }
        Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("cuGraphAddChildGraphNode: {status:?}"),
        })
    }
}
/// Builds a [`ChildGraphOp`] that will insert `child` as a child-graph node when recorded.
///
/// Takes ownership of `child`; no driver call is made here.
pub fn child_graph_op(child: GraphHandle) -> ChildGraphOp {
    ChildGraphOp { child }
}
/// A [`ChildGraphOp`] bundled with a reference-counted token.
pub struct ChildGraphInsertion {
    /// The child-graph operation itself.
    pub op: ChildGraphOp,
    /// Shared token carried alongside the op.
    ///
    /// NOTE(review): the name suggests holders use this `Arc` to track or extend the
    /// lifetime of some resource; nothing in this file reads it — confirm against callers.
    pub keep_alive: Arc<()>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{GraphHandle, MockGraphRecordCtx};

    /// Smoke test: recording a synthetic child into a mock context with a null parent graph
    /// must finish with one of the expected outcomes.
    ///
    /// `Ok`, `Unrecoverable` and `LibraryError` are all accepted; presumably which one is
    /// produced depends on whether a CUDA driver is available on the machine running the test.
    #[test]
    fn child_graph_op_records_into_parent() {
        let op = child_graph_op(GraphHandle::synthetic_for_tests());
        let parent_graph: driver_sys::CUgraph = std::ptr::null_mut();
        let mock = MockGraphRecordCtx::new(parent_graph);
        let ctx: GraphRecordCtx<'_> = mock.as_ctx();

        match op.record(&ctx) {
            Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}