use {
crate::{engine::Engine, error::DnnlError, memory::Memory, stream::Stream},
config::PrimitiveConfig,
descriptor::PrimitiveDescriptor,
onednnl_sys::{
dnnl_exec_arg_t, dnnl_primitive_create, dnnl_primitive_destroy, dnnl_primitive_execute,
dnnl_primitive_t, dnnl_prop_kind_t, dnnl_status_t,
},
std::sync::Arc,
};
pub mod attributes;
pub mod config;
pub mod descriptor;
/// Type-level marker for the direction of a primitive (forward or backward).
///
/// Implemented by the zero-sized [`Forward`] and [`Backward`] marker types so the
/// direction can be carried as a generic parameter; the matching runtime value is
/// available through [`Direction::KIND`].
pub trait Direction {
    /// Runtime value corresponding to this direction marker.
    const KIND: DirectionT;
}

/// Runtime value of a [`Direction`] marker, exposed through [`Direction::KIND`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DirectionT {
    Forward,
    Backward,
}

/// Type-level marker for forward-direction primitives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Forward;

/// Type-level marker for backward-direction primitives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Backward;

impl Direction for Forward {
    const KIND: DirectionT = DirectionT::Forward;
}

impl Direction for Backward {
    const KIND: DirectionT = DirectionT::Backward;
}
/// Tag identifying which kind of operation an [`Operation`] implementation builds.
///
/// Each `Operation` exposes its tag as the associated constant `Operation::TYPE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    Augru,
    BatchNormalization,
    Binary,
    Concat,
    Convolution,
    Deconvolution,
    Eltwise,
    GroupNormalization,
    Gru,
    InnerProduct,
    LayerNormalization,
    LbrAuGru,
    Lrn,
    Lstm,
    MatMul,
    PRelu,
    Reduction,
    Shuffle,
    Softmax,
    VanillaRnn,
}
/// Propagation-kind marker for the training-mode forward pass (`dnnl_forward_training`).
#[derive(Clone, Copy, Debug)]
pub struct PropForwardTraining;

/// Propagation-kind marker for the inference-mode forward pass (`dnnl_forward_inference`).
#[derive(Clone, Copy, Debug)]
pub struct PropForwardInference;

/// Propagation-kind marker for the full backward pass (`dnnl_backward`).
#[derive(Clone, Copy, Debug)]
pub struct PropBackward;

/// Propagation-kind marker named for the backward-bias pass.
/// NOTE(review): confirm which oneDNN prop kind this is meant to map to.
#[derive(Clone, Copy, Debug)]
pub struct PropBackwardBias;

/// Propagation-kind marker for the backward-weights pass (`dnnl_backward_weights`).
#[derive(Clone, Copy, Debug)]
pub struct PropBackwardWeights;

/// Propagation-kind marker for the backward-data pass (`dnnl_backward_data`).
#[derive(Clone, Copy, Debug)]
pub struct PropBackwardData;

/// Propagation-kind marker named `PropAny`.
/// NOTE(review): its intended oneDNN prop kind is not visible here — TODO confirm.
#[derive(Clone, Copy, Debug)]
pub struct PropAny;
/// A oneDNN operation that can be instantiated as a [`Primitive`] for direction `D`
/// and propagation kind `P`.
///
/// `Primitive::new` accepts any `O: Operation` whose `OperationConfig` matches the
/// primitive's configuration type.
pub trait Operation<'a, D, P>
where
    D: Direction,
    P: PropType<D>,
{
    /// Configuration type used to build this operation's primitive descriptor.
    type OperationConfig: PrimitiveConfig<'a, D, P>;

    /// Tag identifying which operation this is.
    const TYPE: OperationType;
}
pub trait PropType<D> {
const KIND: dnnl_prop_kind_t::Type;
}
impl PropType<Forward> for PropForwardInference {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_forward_inference;
}
impl PropType<Forward> for PropForwardTraining {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_forward_training;
}
impl PropType<Backward> for PropBackward {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_backward;
}
impl PropType<Backward> for PropBackwardWeights {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_backward_weights;
}
impl PropType<Backward> for PropBackwardData {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_backward_data;
}
/// Owned wrapper around a created oneDNN primitive.
///
/// The raw handle is released with `dnnl_primitive_destroy` when the wrapper is
/// dropped.
pub struct Primitive<'a, D, P, C>
where
    D: Direction,
    P: PropType<D>,
    C: PrimitiveConfig<'a, D, P>,
{
    /// Raw primitive handle obtained from `dnnl_primitive_create`.
    pub handle: dnnl_primitive_t,
    /// Descriptor the primitive was created from; moved out (left `None`) by
    /// the first successful `execute` call.
    pub desc: Option<PrimitiveDescriptor<'a, D, P, C>>,
    /// Engine supplied at construction — presumably the one the descriptor was
    /// created for; confirm at call sites of `from_descriptor`.
    pub engine: Arc<Engine>,
}
impl<'a, D: Direction, P: PropType<D>, C: PrimitiveConfig<'a, D, P>> Primitive<'a, D, P, C> {
    /// Creates a primitive for operation `O` from its configuration.
    ///
    /// The configuration is first turned into a primitive descriptor on `engine`,
    /// which is then handed to [`Primitive::from_descriptor`].
    ///
    /// # Errors
    ///
    /// Returns an error if creating the descriptor or the primitive fails.
    pub fn new<O: Operation<'a, D, P, OperationConfig = C>>(
        config: O::OperationConfig,
        engine: Arc<Engine>,
    ) -> Result<Primitive<'a, D, P, C>, DnnlError> {
        let desc = config.create_primitive_desc(Arc::clone(&engine))?;
        Self::from_descriptor(desc, engine)
    }

    /// Creates a primitive from an already-built primitive descriptor.
    ///
    /// On success the descriptor is kept in `self.desc`.
    ///
    /// # Errors
    ///
    /// Returns the status reported by `dnnl_primitive_create` (converted into a
    /// [`DnnlError`]) if it is not `dnnl_success`; the descriptor is dropped then.
    pub fn from_descriptor(
        desc: PrimitiveDescriptor<'a, D, P, C>,
        engine: Arc<Engine>,
    ) -> Result<Primitive<'a, D, P, C>, DnnlError> {
        let mut handle = std::ptr::null_mut();
        // SAFETY: `handle` is a valid out-pointer for the duration of the call, and
        // `desc.handle` is presumably a live descriptor handle kept alive by `desc`,
        // which is still owned here.
        let status = unsafe { dnnl_primitive_create(&mut handle, desc.handle) };
        if status != dnnl_status_t::dnnl_success {
            return Err(status.into());
        }
        Ok(Primitive::<'a, D, P, C> {
            handle,
            desc: Some(desc),
            engine,
        })
    }

    /// Executes the primitive on `stream` with the given arguments.
    ///
    /// Each [`ExecArg`] pairs a oneDNN argument index with the memory bound to it.
    ///
    /// On success, the first call moves the stored descriptor out of `self.desc`
    /// and returns it; every later successful call returns `None`.
    ///
    /// NOTE(review): the `args` borrows end when this call returns. If the stream
    /// executes asynchronously, callers presumably must keep the memory alive until
    /// the stream has been waited on — confirm against the `Stream` API.
    ///
    /// # Errors
    ///
    /// Returns an invalid-arguments error if `args` has more entries than fit in
    /// the C `int` argument count. Otherwise returns the status reported by
    /// `dnnl_primitive_execute` if it is not `dnnl_success`. In both cases
    /// `self.desc` is left untouched.
    pub fn execute<T>(
        &mut self,
        stream: &Stream,
        args: Vec<ExecArg<'_, T>>,
    ) -> Result<Option<PrimitiveDescriptor<'a, D, P, C>>, DnnlError> {
        use std::convert::TryFrom;

        let c_args: Vec<dnnl_exec_arg_t> = args
            .iter()
            .map(|arg| dnnl_exec_arg_t {
                arg: arg.index,
                memory: arg.mem.handle,
            })
            .collect();

        // Checked rather than `as i32`, which would silently truncate the count and
        // make oneDNN read fewer arguments than were supplied.
        let nargs = match i32::try_from(c_args.len()) {
            Ok(n) => n,
            Err(_) => return Err(dnnl_status_t::dnnl_invalid_arguments.into()),
        };

        // SAFETY: `self.handle` was produced by a successful `dnnl_primitive_create`
        // (unless replaced through the public field). `c_args` holds exactly `nargs`
        // initialized entries and outlives the call. Every memory handle in it is
        // borrowed from a `Memory` that is alive for the duration of the call.
        let status = unsafe {
            dnnl_primitive_execute(self.handle, stream.handle, nargs, c_args.as_ptr())
        };
        if status != dnnl_status_t::dnnl_success {
            return Err(status.into());
        }
        // Hand the descriptor to the caller; subsequent calls yield `None`.
        Ok(self.desc.take())
    }
}
impl<'a, D, P, C> Drop for Primitive<'a, D, P, C>
where
    D: Direction,
    P: PropType<D>,
    C: PrimitiveConfig<'a, D, P>,
{
    /// Releases the underlying oneDNN primitive handle.
    ///
    /// The status returned by `dnnl_primitive_destroy` is discarded, because `drop`
    /// has no way to report a failure.
    fn drop(&mut self) {
        let raw = self.handle;
        // SAFETY: `raw` came from `dnnl_primitive_create` in `from_descriptor`
        // (unless replaced through the public field). It is released exactly once,
        // here, and the wrapper is never used again after `drop`.
        let _ = unsafe { dnnl_primitive_destroy(raw) };
    }
}
/// One argument passed to [`Primitive::execute`]: a oneDNN argument index paired
/// with the memory object bound to it.
pub struct ExecArg<'a, T> {
    /// Argument index, forwarded as the `arg` field of `dnnl_exec_arg_t`.
    pub index: i32,
    /// Memory bound to this argument; its `handle` is forwarded as the `memory`
    /// field of `dnnl_exec_arg_t`.
    pub mem: &'a Memory<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, but an index plus a shared reference is always
// trivially copyable whatever the element type. Being `Copy` lets callers reuse an
// argument list across `execute` calls, which consume their `Vec` by value.
impl<T> Clone for ExecArg<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ExecArg<'_, T> {}