use crate::runtime::{Graph, Runtime};
use crate::tensor::Tensor;
/// A runtime graph bundled with the tensors that were handed over alongside it.
///
/// The struct owns an `R::Graph`, two lists of tensors (inputs and outputs)
/// and an optional extra "arena" tensor. Owning them here means they all live
/// at least as long as the `CapturedGraph` value itself.
///
/// Rust drops struct fields in declaration order, so `graph` is dropped first
/// and `arena` last.
/// NOTE(review): the test module contains a drop-ordering test, so the field
/// order looks intentional — confirm before reordering fields.
pub struct CapturedGraph<R: Runtime> {
/// Runtime-specific graph object; `launch` forwards to it.
graph: R::Graph,
/// Tensors supplied as `inputs` at construction; exposed via `inputs()`.
inputs: Vec<Tensor<R>>,
/// Tensors supplied as `outputs` at construction; exposed via `outputs()`.
outputs: Vec<Tensor<R>>,
/// Optional tensor supplied through `new_with_arena`; `None` when built with
/// `new`. No method in this file reads its contents (only `Debug` reports
/// whether it is present).
/// NOTE(review): presumably held only to keep backing memory alive while the
/// graph exists — confirm against the code that performs the capture.
arena: Option<Tensor<R>>,
}
impl<R: Runtime> CapturedGraph<R> {
pub fn new(graph: R::Graph, inputs: Vec<Tensor<R>>, outputs: Vec<Tensor<R>>) -> Self {
Self {
graph,
inputs,
outputs,
arena: None,
}
}
pub fn new_with_arena(
graph: R::Graph,
inputs: Vec<Tensor<R>>,
outputs: Vec<Tensor<R>>,
arena: Tensor<R>,
) -> Self {
Self {
graph,
inputs,
outputs,
arena: Some(arena),
}
}
pub fn launch(&self) -> crate::error::Result<()> {
self.graph.launch()
}
pub fn graph(&self) -> &R::Graph {
&self.graph
}
pub fn inputs(&self) -> &[Tensor<R>] {
&self.inputs
}
pub fn outputs(&self) -> &[Tensor<R>] {
&self.outputs
}
}
// SAFETY: every field of `CapturedGraph` is either an `R::Graph` or a
// `Tensor<R>` (held in a `Vec` or an `Option`), and this impl only exists when
// both of those types are `Send`, so no non-`Send` data can cross threads.
// NOTE(review): with these bounds the auto trait would normally apply on its
// own — confirm whether the manual impl is still required.
unsafe impl<R> Send for CapturedGraph<R>
where
    R: Runtime,
    R::Graph: Send,
    Tensor<R>: Send,
{
}
// SAFETY: every field of `CapturedGraph` is either an `R::Graph` or a
// `Tensor<R>` (held in a `Vec` or an `Option`), and this impl only exists when
// both of those types are `Sync`, so shared references expose only `Sync` data.
// NOTE(review): with these bounds the auto trait would normally apply on its
// own — confirm whether the manual impl is still required.
unsafe impl<R> Sync for CapturedGraph<R>
where
    R: Runtime,
    R::Graph: Sync,
    Tensor<R>: Sync,
{
}
/// Compact diagnostic output: the graph itself plus summary information about
/// the tensors (counts and arena presence) rather than the tensors' contents.
impl<R: Runtime> std::fmt::Debug for CapturedGraph<R>
where
    R::Graph: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let input_count = self.inputs.len();
        let output_count = self.outputs.len();
        let arena_present = self.arena.is_some();

        let mut out = f.debug_struct("CapturedGraph");
        out.field("graph", &self.graph);
        out.field("inputs_len", &input_count);
        out.field("outputs_len", &output_count);
        out.field("has_arena", &arena_present);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::cpu::CpuRuntime;
    use crate::runtime::NoOpGraph;
    use crate::tensor::Tensor;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    /// Launching a captured no-op graph with no tensors reports success.
    #[test]
    fn test_captured_graph_noop_launch() {
        let captured = CapturedGraph::<CpuRuntime>::new(NoOpGraph, Vec::new(), Vec::new());
        assert!(captured.launch().is_ok());
    }

    /// Compile-time check that the CPU instantiation is `Send + Sync`.
    #[test]
    fn test_captured_graph_send_sync() {
        assert_send_sync::<CapturedGraph<CpuRuntime>>();
    }

    /// Compile-time check that the CUDA instantiation is `Send + Sync`.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_captured_graph_send_sync_cuda() {
        assert_send_sync::<CapturedGraph<crate::runtime::cuda::CudaRuntime>>();
    }

    /// Dropping a graph that holds several clones of one tensor must complete
    /// without panicking.
    #[test]
    fn test_captured_graph_drop_ordering() {
        let device = CpuRuntime::default_device();
        let original = Tensor::<CpuRuntime>::zeros(&[4], crate::dtype::DType::F32, &device);
        let second_input = original.clone();
        let output = original.clone();

        let captured =
            CapturedGraph::<CpuRuntime>::new(NoOpGraph, vec![original, second_input], vec![output]);
        drop(captured);
    }
}