onednnl 0.0.1

high-level bindings to oneDNN Deep Learning library
Documentation
use {
    crate::{
        memory::descriptor::MemoryDescriptor,
        primitive::{
            attributes::PrimitiveAttributes, config::PrimitiveConfig,
            descriptor::PrimitiveDescriptor, Forward, Operation, OperationType, PropType,
        },
    },
    onednnl_sys::{dnnl_batch_normalization_forward_primitive_desc_create, dnnl_status_t},
    std::{ffi::c_uint, marker::PhantomData},
};

pub struct ForwardBatchNormConfig {
    src_desc: MemoryDescriptor,
    dst_desc: MemoryDescriptor,
    epsilon: f32,
    flags: c_uint,
    attr: PrimitiveAttributes,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardBatchNormConfig {
    fn create_primitive_desc(
        self,
        engine: std::sync::Arc<crate::engine::Engine>,
    ) -> Result<
        crate::primitive::descriptor::PrimitiveDescriptor<'a, Forward, P, ForwardBatchNormConfig>,
        crate::error::DnnlError,
    > {
        let mut handle = std::ptr::null_mut();

        let status = unsafe {
            dnnl_batch_normalization_forward_primitive_desc_create(
                &mut handle,
                engine.handle,
                P::KIND,
                self.src_desc.handle,
                self.dst_desc.handle,
                self.epsilon,
                self.flags,
                self.attr.handle,
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(
                PrimitiveDescriptor::<'a, Forward, P, ForwardBatchNormConfig> {
                    handle,
                    config: self,
                    _marker_a: PhantomData,
                    _marker_d: PhantomData,
                    _marker_p: PhantomData,
                },
            )
        } else {
            Err(status.into())
        }
    }
}

pub struct ForwardBatchNorm<P: PropType<Forward>> {
    pub prop_type: P,
}

impl<P: PropType<Forward>> Operation<'_, Forward, P> for ForwardBatchNorm<P> {
    const TYPE: OperationType = OperationType::BatchNormalization;

    type OperationConfig = ForwardBatchNormConfig;
}