onednnl 0.0.1

High-level bindings to the oneDNN deep learning library
Documentation
use {
    crate::{
        memory::descriptor::MemoryDescriptor,
        primitive::{
            attributes::PrimitiveAttributes, config::PrimitiveConfig,
            descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType,
            PropBackward, PropForwardTraining, PropType,
        },
    },
    onednnl_sys::{
        dnnl_alg_kind_t, dnnl_eltwise_backward_primitive_desc_create,
        dnnl_eltwise_forward_primitive_desc_create, dnnl_status_t,
    },
    std::marker::PhantomData,
};

/// Configuration for creating an eltwise forward primitive descriptor.
///
/// The `PrimitiveConfig` implementation in this file consumes the value and
/// passes each field to `dnnl_eltwise_forward_primitive_desc_create`.
pub struct ForwardEltwiseConfig {
    /// Algorithm kind to apply; the `Unary` associated constants have this type.
    pub alg_kind: dnnl_alg_kind_t::Type,
    /// Memory descriptor whose handle is passed as the source descriptor.
    pub src_desc: MemoryDescriptor,
    /// Memory descriptor whose handle is passed as the destination descriptor.
    pub dst_desc: MemoryDescriptor,
    /// Passed through unchanged as the `alpha` argument; its meaning presumably
    /// depends on `alg_kind` — confirm against the oneDNN eltwise documentation.
    pub alpha: f32,
    /// Passed through unchanged as the `beta` argument; its meaning presumably
    /// depends on `alg_kind` — confirm against the oneDNN eltwise documentation.
    pub beta: f32,
    /// Primitive attributes whose handle is passed to the create call.
    pub attr: PrimitiveAttributes,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardEltwiseConfig {
    fn create_primitive_desc(
        self,
        engine: std::sync::Arc<crate::engine::Engine>,
    ) -> Result<
        crate::primitive::descriptor::PrimitiveDescriptor<'a, Forward, P, ForwardEltwiseConfig>,
        crate::error::DnnlError,
    > {
        let mut handle = std::ptr::null_mut();
        let status = unsafe {
            dnnl_eltwise_forward_primitive_desc_create(
                &mut handle,
                engine.handle,
                P::KIND,
                self.alg_kind,
                self.src_desc.handle,
                self.dst_desc.handle,
                self.alpha,
                self.beta,
                self.attr.handle,
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(
                PrimitiveDescriptor::<'a, Forward, P, ForwardEltwiseConfig> {
                    handle,
                    config: self,

                    _marker_a: PhantomData,
                    _marker_d: PhantomData,
                    _marker_p: PhantomData,
                },
            )
        } else {
            Err(status.into())
        }
    }
}

/// Namespace type for the eltwise algorithm-kind constants.
///
/// It has no fields; it only carries the associated constants below.
pub struct Unary;

/// Short aliases for the `dnnl_alg_kind_t::dnnl_eltwise_*` values.
///
/// Every constant has type `dnnl_alg_kind_t::Type`, the same type as the
/// `alg_kind` field of the eltwise config structs in this file.
///
/// The `*_USE_DST_FOR_BWD` constants alias the matching
/// `dnnl_eltwise_*_use_dst_for_bwd` kinds. Judging by the name these select a
/// backward pass computed from the destination rather than the source —
/// confirm against the oneDNN eltwise documentation.
impl Unary {
    pub const TANH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_tanh;
    pub const TANH_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_tanh_use_dst_for_bwd;
    pub const ABS: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_abs;
    pub const CLIP: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_clip;
    pub const CLIP_V2: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_clip_v2;
    pub const CLIP_V2_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_clip_v2_use_dst_for_bwd;
    pub const ELU: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_elu;
    pub const ELU_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_elu_use_dst_for_bwd;
    pub const EXP: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_exp;
    pub const EXP_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_exp_use_dst_for_bwd;
    pub const GELU_ERF: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_gelu_erf;
    pub const GELU_TANH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_gelu_tanh;
    pub const HARDSIGMOID: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_hardsigmoid;
    pub const HARDSWISH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_hardswish;
    pub const LINEAR: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_linear;
    pub const LOG: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_log;
    pub const LOGISTIC: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_logistic;
    pub const LOGISTIC_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_logistic_use_dst_for_bwd;
    pub const MISH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_mish;
    pub const POW: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_pow;
    pub const RELU: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_relu;
    pub const RELU_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_relu_use_dst_for_bwd;
    pub const ROUND: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_round;
    pub const SOFT_RELU: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_soft_relu;
    pub const SQRT: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_sqrt;
    pub const SQRT_USE_DST_FOR_BWD: dnnl_alg_kind_t::Type =
        dnnl_alg_kind_t::dnnl_eltwise_sqrt_use_dst_for_bwd;
    pub const SQUARE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_square;
    pub const SWISH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_swish;
}

/// Configuration for creating an eltwise backward primitive descriptor.
///
/// The `PrimitiveConfig` implementation in this file consumes the value and
/// passes each field to `dnnl_eltwise_backward_primitive_desc_create`.
/// The lifetime `'a` is that of the borrowed forward hint descriptor.
pub struct BackwardEltwiseConfig<'a> {
    /// Algorithm kind to apply; the `Unary` associated constants have this type.
    pub alg_kind: dnnl_alg_kind_t::Type,
    /// Memory descriptor whose handle is passed as the diff-source descriptor.
    pub diff_src_desc: MemoryDescriptor,
    /// Memory descriptor whose handle is passed as the diff-destination descriptor.
    pub diff_dest_desc: MemoryDescriptor,
    /// Memory descriptor whose handle is passed as the data descriptor.
    pub data_desc: MemoryDescriptor,
    /// Passed through unchanged as the `alpha` argument; its meaning presumably
    /// depends on `alg_kind` — confirm against the oneDNN eltwise documentation.
    pub alpha: f32,
    /// Passed through unchanged as the `beta` argument; its meaning presumably
    /// depends on `alg_kind` — confirm against the oneDNN eltwise documentation.
    pub beta: f32,
    /// Borrowed forward-training eltwise primitive descriptor; its handle is
    /// passed to the create call as the forward hint.
    pub forward_hint_desc:
        &'a PrimitiveDescriptor<'a, Forward, PropForwardTraining, ForwardEltwiseConfig>,
    /// Primitive attributes whose handle is passed to the create call.
    pub attr: PrimitiveAttributes,
}

impl<'a> PrimitiveConfig<'a, Backward, PropBackward> for BackwardEltwiseConfig<'a> {
    fn create_primitive_desc(
        self,
        engine: std::sync::Arc<crate::engine::Engine>,
    ) -> Result<
        crate::primitive::descriptor::PrimitiveDescriptor<
            'a,
            Backward,
            PropBackward,
            BackwardEltwiseConfig<'a>,
        >,
        crate::error::DnnlError,
    > {
        let mut handle = std::ptr::null_mut();
        let status = unsafe {
            dnnl_eltwise_backward_primitive_desc_create(
                &mut handle,
                engine.handle,
                self.alg_kind,
                self.diff_src_desc.handle,
                self.diff_dest_desc.handle,
                self.data_desc.handle,
                self.alpha,
                self.beta,
                self.forward_hint_desc.handle,
                self.attr.handle,
            )
        };

        if status == dnnl_status_t::dnnl_success {
            Ok(
                PrimitiveDescriptor::<'a, Backward, PropBackward, BackwardEltwiseConfig<'a>> {
                    handle,
                    config: self,

                    _marker_a: PhantomData,
                    _marker_d: PhantomData,
                    _marker_p: PhantomData,
                },
            )
        } else {
            Err(status.into())
        }
    }
}

/// Eltwise operation in the forward direction, parameterised by the forward
/// propagation type `P`.
pub struct ForwardEltwise<P>
where
    P: PropType<Forward>,
{
    /// Value of the propagation type this operation is instantiated with.
    pub prop_type: P,
}

/// Registers [`ForwardEltwise`] as an eltwise operation for every forward
/// propagation type `P`.
impl<P> Operation<'_, Forward, P> for ForwardEltwise<P>
where
    P: PropType<Forward>,
{
    /// Configuration consumed when building the primitive descriptor.
    type OperationConfig = ForwardEltwiseConfig;

    /// Operation category reported for this type.
    const TYPE: OperationType = OperationType::Eltwise;
}

/// Eltwise operation in the backward direction, parameterised by the backward
/// propagation type `T`.
pub struct BackwardEltwise<T>
where
    T: PropType<Backward>,
{
    /// Value of the propagation type this operation is instantiated with.
    pub prop_type: T,
}

/// Registers [`BackwardEltwise`] as an eltwise operation for the
/// `PropBackward` propagation type.
impl<'a> Operation<'a, Backward, PropBackward> for BackwardEltwise<PropBackward> {
    /// Configuration consumed when building the primitive descriptor; it
    /// borrows the forward hint for `'a`.
    type OperationConfig = BackwardEltwiseConfig<'a>;

    /// Operation category reported for this type.
    const TYPE: OperationType = OperationType::Eltwise;
}