onednnl 0.0.1

high-level bindings to oneDNN Deep Learning library
Documentation
use {
    crate::{
        engine::Engine,
        error::DnnlError,
        memory::descriptor::MemoryDescriptor,
        primitive::{
            attributes::PrimitiveAttributes, config::PrimitiveConfig,
            descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType,
            PropForwardTraining, PropType,
        },
    },
    onednnl_sys::{
        dnnl_augru_backward_primitive_desc_create, dnnl_augru_forward_primitive_desc_create,
        dnnl_rnn_direction_t, dnnl_status_t,
    },
    std::{ffi::c_uint, marker::PhantomData, sync::Arc},
};

pub struct ForwardAuGruConfig {
    direction: dnnl_rnn_direction_t::Type,
    src_layer_desc: MemoryDescriptor,
    src_iter_desc: MemoryDescriptor,
    attention_desc: MemoryDescriptor,
    weights_layer_desc: MemoryDescriptor,
    weights_iter_desc: MemoryDescriptor,
    bias_desc: MemoryDescriptor,
    dst_layer_desc: MemoryDescriptor,
    dst_iter_desc: MemoryDescriptor,
    flags: c_uint,
    attr: PrimitiveAttributes,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardAuGruConfig {
    fn create_primitive_desc(
        self,
        engine: Arc<Engine>,
    ) -> Result<PrimitiveDescriptor<'a, Forward, P, ForwardAuGruConfig>, DnnlError> {
        let mut handle = std::ptr::null_mut();
        let status = unsafe {
            dnnl_augru_forward_primitive_desc_create(
                &mut handle,
                engine.handle,
                P::KIND,
                self.direction,
                self.src_layer_desc.handle,
                self.src_iter_desc.handle,
                self.attention_desc.handle,
                self.weights_layer_desc.handle,
                self.weights_iter_desc.handle,
                self.bias_desc.handle,
                self.dst_layer_desc.handle,
                self.dst_iter_desc.handle,
                self.flags,
                self.attr.handle,
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(PrimitiveDescriptor::<'a, Forward, P, ForwardAuGruConfig> {
                handle,
                config: self,

                _marker_a: PhantomData,
                _marker_d: PhantomData,
                _marker_p: PhantomData,
            })
        } else {
            Err(status.into())
        }
    }
}

pub struct BackwardAuGruConfig<'a> {
    direction: dnnl_rnn_direction_t::Type,
    src_layer_desc: MemoryDescriptor,
    src_iter_desc: MemoryDescriptor,
    attention_desc: MemoryDescriptor,
    weights_layer_desc: MemoryDescriptor,
    weights_iter_desc: MemoryDescriptor,
    bias_desc: MemoryDescriptor,
    dst_layer_desc: MemoryDescriptor,
    dst_iter_desc: MemoryDescriptor,
    diff_src_layer_desc: MemoryDescriptor,
    diff_src_iter_desc: MemoryDescriptor,
    diff_attention_desc: MemoryDescriptor,
    diff_weights_layer_desc: MemoryDescriptor,
    diff_weights_iter_desc: MemoryDescriptor,
    diff_bias_desc: MemoryDescriptor,
    diff_dst_layer_desc: MemoryDescriptor,
    diff_dst_iter_desc: MemoryDescriptor,
    flags: c_uint,
    hint_fwd_pd: &'a PrimitiveDescriptor<'a, Forward, PropForwardTraining, ForwardAuGruConfig>,
    attr: PrimitiveAttributes,
}

impl<'a, P: PropType<Backward>> PrimitiveConfig<'a, Backward, P> for BackwardAuGruConfig<'a> {
    fn create_primitive_desc(
        self,
        engine: Arc<Engine>,
    ) -> Result<PrimitiveDescriptor<'a, Backward, P, BackwardAuGruConfig<'a>>, DnnlError> {
        let mut handle = std::ptr::null_mut();
        let status = unsafe {
            dnnl_augru_backward_primitive_desc_create(
                &mut handle,
                engine.handle,
                P::KIND,
                self.direction,
                self.src_layer_desc.handle,
                self.src_iter_desc.handle,
                self.attention_desc.handle,
                self.weights_layer_desc.handle,
                self.weights_iter_desc.handle,
                self.bias_desc.handle,
                self.dst_layer_desc.handle,
                self.dst_iter_desc.handle,
                self.diff_src_layer_desc.handle,
                self.diff_src_iter_desc.handle,
                self.diff_attention_desc.handle,
                self.diff_weights_layer_desc.handle,
                self.diff_weights_iter_desc.handle,
                self.diff_bias_desc.handle,
                self.diff_dst_layer_desc.handle,
                self.diff_dst_iter_desc.handle,
                self.flags,
                self.hint_fwd_pd.handle,
                self.attr.handle,
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(
                PrimitiveDescriptor::<'a, Backward, P, BackwardAuGruConfig<'a>> {
                    handle,
                    config: self,
                    _marker_a: PhantomData,
                    _marker_d: PhantomData,
                    _marker_p: PhantomData,
                },
            )
        } else {
            Err(status.into())
        }
    }
}

pub struct ForwardAuGru<P: PropType<Forward>> {
    pub prop_type: P,
}

impl<P: PropType<Forward>> Operation<'_, Forward, P> for ForwardAuGru<P> {
    const TYPE: OperationType = OperationType::Augru;
    type OperationConfig = ForwardAuGruConfig;
}

pub struct BackwardAuGru<P: PropType<Backward>> {
    pub prop_type: P,
}

impl<'a, P: PropType<Backward>> Operation<'a, Backward, P> for BackwardAuGru<P> {
    const TYPE: OperationType = OperationType::Augru;
    type OperationConfig = BackwardAuGruConfig<'a>;
}