use {
crate::{
error::DnnlError,
graph::{
partition::OneDNNGraphPartition,
tensor::{logical::LogicalTensor, tensor::Tensor},
},
stream::Stream,
},
onednnl_sys::{
dnnl_graph_compiled_partition_create, dnnl_graph_compiled_partition_destroy,
dnnl_graph_compiled_partition_execute, dnnl_graph_compiled_partition_query_logical_tensor,
dnnl_graph_compiled_partition_t, dnnl_status_t,
},
};
/// Owns a oneDNN Graph compiled-partition handle together with the
/// [`OneDNNGraphPartition`] it was created from.
///
/// The raw handle is released by this type's `Drop` impl. The stored partition
/// is dropped afterwards, following normal field-drop order.
pub struct CompiledPartition {
    /// Raw oneDNN compiled-partition handle passed to the
    /// `dnnl_graph_compiled_partition_*` C functions.
    pub(crate) handle: dnnl_graph_compiled_partition_t,
    /// The source partition; stored here so it lives as long as this value.
    pub(crate) partition: OneDNNGraphPartition,
}
impl CompiledPartition {
    /// Creates a compiled-partition handle for `partition` and takes ownership of it.
    ///
    /// NOTE(review): in the oneDNN Graph C API this call only creates the
    /// compiled-partition object. Compilation itself is a separate call
    /// (`dnnl_graph_partition_compile`), presumably made elsewhere. Confirm.
    ///
    /// # Errors
    ///
    /// Returns the status from `dnnl_graph_compiled_partition_create`, converted
    /// into a [`DnnlError`], when it is not `dnnl_success`. `partition` is dropped
    /// in that case.
    pub fn create(partition: OneDNNGraphPartition) -> Result<Self, DnnlError> {
        let mut raw = std::ptr::null_mut();
        // SAFETY: `raw` is a valid, writable out-pointer for the duration of the
        // call. `partition` is still alive, so its handle has not been released by
        // this wrapper (assumes `OneDNNGraphPartition` keeps a live handle).
        let status = unsafe { dnnl_graph_compiled_partition_create(&mut raw, partition.handle) };

        if status == dnnl_status_t::dnnl_success {
            Ok(CompiledPartition {
                handle: raw,
                partition,
            })
        } else {
            Err(status.into())
        }
    }

    /// Runs the compiled partition on `stream` with the given input and output tensors.
    ///
    /// The tensors' raw handles are gathered into temporary arrays. Those arrays are
    /// passed to `dnnl_graph_compiled_partition_execute` together with their lengths.
    ///
    /// # Errors
    ///
    /// Returns the converted C status when execution does not report `dnnl_success`.
    pub fn execute(
        &self,
        stream: &Stream,
        inputs: &[Tensor],
        outputs: &[&mut Tensor],
    ) -> Result<(), DnnlError> {
        // The C API expects arrays of const tensor handles, so collect them into
        // contiguous buffers that stay alive across the FFI call below.
        let mut input_ptrs = inputs
            .iter()
            .map(|tensor| tensor.handle as *const _)
            .collect::<Vec<_>>();
        let mut output_ptrs = outputs
            .iter()
            .map(|tensor| tensor.handle as *const _)
            .collect::<Vec<_>>();

        // SAFETY: both pointer buffers live until the end of this function, and each
        // is passed with its exact element count. `self` and `stream` are borrowed
        // for the whole call, so their handles are not dropped meanwhile.
        let status = unsafe {
            dnnl_graph_compiled_partition_execute(
                self.handle,
                stream.handle,
                input_ptrs.len(),
                input_ptrs.as_mut_ptr(),
                output_ptrs.len(),
                output_ptrs.as_mut_ptr(),
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(())
        } else {
            Err(status.into())
        }
    }

    /// Queries a logical tensor of this compiled partition.
    ///
    /// `index` is forwarded unchanged to `dnnl_graph_compiled_partition_query_logical_tensor`.
    /// NOTE(review): oneDNN documents that argument as a logical-tensor *id*
    /// (`tid`), not a positional index into the inputs/outputs. Confirm with callers.
    ///
    /// # Errors
    ///
    /// Returns the converted C status when the query does not report `dnnl_success`.
    pub fn query_logical_tensor(&self, index: usize) -> Result<LogicalTensor, DnnlError> {
        let mut slot = std::mem::MaybeUninit::uninit();

        // SAFETY: `slot` is writable storage for exactly one logical-tensor struct
        // and outlives the call. `self.handle` stays alive for the borrow of `self`.
        let status = unsafe {
            dnnl_graph_compiled_partition_query_logical_tensor(
                self.handle,
                index,
                slot.as_mut_ptr(),
            )
        };

        if status == dnnl_status_t::dnnl_success {
            // SAFETY: relies on the C API's contract that a successful query
            // fully initialises the output logical tensor.
            let handle = unsafe { slot.assume_init() };
            Ok(LogicalTensor { handle })
        } else {
            Err(status.into())
        }
    }
}
impl Drop for CompiledPartition {
    /// Destroys the underlying oneDNN compiled-partition handle.
    ///
    /// The status returned by the destroy call is discarded, because `drop`
    /// cannot report errors. The `partition` field is dropped after this returns,
    /// following normal field-drop order.
    fn drop(&mut self) {
        // SAFETY: assumes `self.handle` is a live handle obtained from
        // `dnnl_graph_compiled_partition_create` and not destroyed anywhere else.
        // `drop` runs at most once, so it is released exactly once here.
        let _ = unsafe { dnnl_graph_compiled_partition_destroy(self.handle) };
    }
}