//! Operator-level kernel interface for 8-bit quantized neural-network inference.
//!
//! This file declares the parameter bundles handed to each operator kernel
//! (2-D convolution, depthwise convolution, fully connected, pooling, softmax
//! and element-wise add) together with the [`KernelBackend`] trait that a
//! concrete backend implements. Tensor data is passed in and out through
//! borrowed slices, and the crate is built without the standard library.
#![no_std]
// Every public item (including fields, variants and trait methods) must carry
// documentation; an undocumented one is a hard compile error.
#![deny(missing_docs)]
/// Padding scheme selected for the convolution and pooling operators.
///
/// The enum only names the two modes; how a backend converts a mode into
/// concrete padding amounts is not defined in this file.
///
/// NOTE(review): the variant names match the conventional `SAME` / `VALID`
/// padding modes used by TensorFlow Lite — confirm against the backend
/// implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Padding {
    /// `SAME` mode — presumably the input is implicitly padded so the output
    /// extent depends only on the input extent and the stride; confirm.
    Same,
    /// `VALID` mode — presumably no implicit padding is applied; confirm.
    Valid,
}
/// Activation function that a kernel fuses into its output stage.
///
/// Used by the `activation` field of the convolution, fully-connected,
/// pooling and element-wise add parameter structs.
///
/// NOTE(review): the variant names mirror TensorFlow Lite's fused-activation
/// kinds. The exact numeric behaviour (in particular the clamping ranges in
/// the quantized domain) is decided by the backend — confirm there.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FusedActivation {
    /// No activation is requested.
    None,
    /// Rectified linear unit — presumably `max(x, 0)`.
    Relu,
    /// ReLU capped at six — presumably clamps to `[0, 6]`.
    Relu6,
    /// ReLU "N1 to 1" — presumably clamps to `[-1, 1]`.
    ReluN1To1,
    /// Hyperbolic tangent.
    Tanh,
    /// Sign-bit activation. NOTE(review): semantics are not visible in this
    /// file — confirm with the backend.
    SignBit,
    /// Logistic sigmoid.
    Sigmoid,
}
/// Per-tensor quantization parameters: one scale and one zero point.
///
/// NOTE(review): presumably describes the affine mapping
/// `real = scale * (quantized - zero_point)` — the formula is not spelled out
/// in this file; confirm with the backend.
///
/// Equality compares both fields and follows `f32` semantics, so a value whose
/// `scale` is NaN never compares equal, not even to itself.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantParam {
    /// Scale factor of the quantized representation.
    pub scale: f32,
    /// Zero point of the quantized representation.
    pub zero_point: i32,
}
/// Per-channel quantization parameters borrowed from caller-owned storage.
///
/// The struct only holds the two slices and a dimension index; it performs no
/// validation. NOTE(review): `scales` and `zero_points` are presumably the
/// same length, with one entry per channel — nothing here enforces it.
///
/// Equality compares the slice contents element-wise (with `f32` semantics for
/// the scales) and the dimension index.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PerChannelQuantParam<'a> {
    /// Scale factors, presumably one per channel.
    pub scales: &'a [f32],
    /// Zero points, presumably one per channel.
    pub zero_points: &'a [i32],
    /// Index of the tensor dimension the channels run along — presumably an
    /// index into the accompanying weight shape array; confirm.
    pub quantized_dimension: usize,
}
/// Reasons a kernel invocation can fail.
///
/// This is the error half of [`Status`]. Which variant is reported in which
/// situation is decided by the backend; the per-variant notes below are
/// inferred from the names only.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelError {
    /// A shape argument was rejected by the kernel.
    InvalidShape,
    /// The requested fused activation is not supported by the kernel.
    UnsupportedActivation,
    /// A buffer did not satisfy an alignment requirement of the kernel.
    AlignmentError,
    /// The kernel failed for a reason not covered by the other variants.
    InternalError,
}

impl core::fmt::Display for KernelError {
    /// Writes a short, fixed, lowercase description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive on purpose: a new variant must get a message here.
        let message = match *self {
            KernelError::InvalidShape => "invalid tensor shape",
            KernelError::UnsupportedActivation => "unsupported fused activation",
            KernelError::AlignmentError => "buffer alignment error",
            KernelError::InternalError => "internal kernel error",
        };
        f.write_str(message)
    }
}
/// Outcome of a kernel invocation: `Ok(())` on success, otherwise the
/// [`KernelError`] describing the failure.
pub type Status = Result<(), KernelError>;
/// Arguments for one invocation of [`KernelBackend::conv2d`].
///
/// Tensor data is 8-bit signed (`i8`) and the optional bias is `i32`. All
/// slices share the lifetime `'a`. Because `output` and `scratch` are
/// exclusive borrows, the struct cannot be `Clone` or `Copy`; it is handed to
/// the backend by value.
///
/// NOTE(review): the shape arrays are presumably laid out as
/// `[batch, height, width, channels]` for activations and
/// `[out_channels, filter_height, filter_width, in_channels]` for weights (the
/// TensorFlow Lite convention). Nothing in this file enforces that — confirm.
///
/// The derived `Debug` output prints the full contents of every slice.
#[derive(Debug)]
pub struct Conv2dParams<'a> {
    /// Quantized input activations.
    pub input: &'a [i8],
    /// Extent of `input` along each of its four dimensions.
    pub input_shape: [usize; 4],
    /// Quantization parameters of `input`.
    pub input_quant: QuantParam,
    /// Quantized filter weights.
    pub weights: &'a [i8],
    /// Extent of `weights` along each of its four dimensions.
    pub weights_shape: [usize; 4],
    /// Per-tensor quantization parameters of `weights`.
    pub weights_quant: QuantParam,
    /// Optional per-channel quantization of `weights`. NOTE(review):
    /// presumably takes precedence over `weights_quant` when present — confirm.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias values; `None` when no bias is supplied. Presumably one
    /// entry per output channel.
    pub bias: Option<&'a [i32]>,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Extent of `output` along each of its four dimensions.
    pub output_shape: [usize; 4],
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Stride along the width axis.
    pub stride_w: i32,
    /// Stride along the height axis.
    pub stride_h: i32,
    /// Dilation factor along the width axis (presumably `1` means none).
    pub dilation_w_factor: i32,
    /// Dilation factor along the height axis (presumably `1` means none).
    pub dilation_h_factor: i32,
    /// Padding scheme to apply.
    pub padding: Padding,
    /// Activation fused into the output stage.
    pub activation: FusedActivation,
    /// Caller-provided scratch bytes for the backend's temporary use.
    /// NOTE(review): presumably sized via
    /// [`KernelBackend::conv2d_scratch_size`] — confirm.
    pub scratch: &'a mut [u8],
}
/// Arguments for one invocation of [`KernelBackend::depthwise_conv2d`].
///
/// Carries the same fields as [`Conv2dParams`] plus `depth_multiplier`. All
/// slices share the lifetime `'a`. Because `output` and `scratch` are
/// exclusive borrows, the struct cannot be `Clone` or `Copy`.
///
/// NOTE(review): the layout of the four-element shape arrays is not defined
/// in this file (presumably the TensorFlow Lite depthwise convention) —
/// confirm with the backend.
///
/// The derived `Debug` output prints the full contents of every slice.
#[derive(Debug)]
pub struct DepthwiseConv2dParams<'a> {
    /// Quantized input activations.
    pub input: &'a [i8],
    /// Extent of `input` along each of its four dimensions.
    pub input_shape: [usize; 4],
    /// Quantization parameters of `input`.
    pub input_quant: QuantParam,
    /// Quantized filter weights.
    pub weights: &'a [i8],
    /// Extent of `weights` along each of its four dimensions.
    pub weights_shape: [usize; 4],
    /// Per-tensor quantization parameters of `weights`.
    pub weights_quant: QuantParam,
    /// Optional per-channel quantization of `weights`. NOTE(review):
    /// presumably takes precedence over `weights_quant` when present — confirm.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias values; `None` when no bias is supplied. Presumably one
    /// entry per output channel.
    pub bias: Option<&'a [i32]>,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Extent of `output` along each of its four dimensions.
    pub output_shape: [usize; 4],
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Stride along the width axis.
    pub stride_w: i32,
    /// Stride along the height axis.
    pub stride_h: i32,
    /// Dilation factor along the width axis (presumably `1` means none).
    pub dilation_w_factor: i32,
    /// Dilation factor along the height axis (presumably `1` means none).
    pub dilation_h_factor: i32,
    /// Depth multiplier — presumably the number of output channels produced
    /// per input channel; confirm.
    pub depth_multiplier: i32,
    /// Padding scheme to apply.
    pub padding: Padding,
    /// Activation fused into the output stage.
    pub activation: FusedActivation,
    /// Caller-provided scratch bytes for the backend's temporary use.
    /// NOTE(review): presumably sized via
    /// [`KernelBackend::depthwise_conv2d_scratch_size`] — confirm.
    pub scratch: &'a mut [u8],
}
/// Arguments for one invocation of [`KernelBackend::fully_connected`].
///
/// Unlike the convolution parameter structs there is no input shape and no
/// scratch buffer; only the weight shape and `output_depth` are given. All
/// slices share the lifetime `'a`, and the exclusive borrow of `output`
/// prevents the struct from being `Clone` or `Copy`.
///
/// The derived `Debug` output prints the full contents of every slice.
#[derive(Debug)]
pub struct FullyConnectedParams<'a> {
    /// Quantized input activations.
    pub input: &'a [i8],
    /// Quantization parameters of `input`.
    pub input_quant: QuantParam,
    /// Quantized weight matrix.
    pub weights: &'a [i8],
    /// Extent of `weights` along each of its two dimensions. NOTE(review):
    /// presumably `[output_depth, accumulation_depth]` — confirm.
    pub weights_shape: [usize; 2],
    /// Per-tensor quantization parameters of `weights`.
    pub weights_quant: QuantParam,
    /// Optional per-channel quantization of `weights`. NOTE(review):
    /// presumably takes precedence over `weights_quant` when present — confirm.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias values; `None` when no bias is supplied. Presumably one
    /// entry per output unit.
    pub bias: Option<&'a [i32]>,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Number of output units — presumably per batch row; confirm.
    pub output_depth: usize,
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Activation fused into the output stage.
    pub activation: FusedActivation,
}
/// Arguments shared by [`KernelBackend::avg_pool`] and
/// [`KernelBackend::max_pool`].
///
/// Both slices share the lifetime `'a`; the exclusive borrow of `output`
/// prevents the struct from being `Clone` or `Copy`.
///
/// NOTE(review): the shape arrays are presumably
/// `[batch, height, width, channels]` — confirm with the backend.
///
/// The derived `Debug` output prints the full contents of every slice.
#[derive(Debug)]
pub struct PoolParams<'a> {
    /// Quantized input activations.
    pub input: &'a [i8],
    /// Extent of `input` along each of its four dimensions.
    pub input_shape: [usize; 4],
    /// Quantization parameters of `input`.
    pub input_quant: QuantParam,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Extent of `output` along each of its four dimensions.
    pub output_shape: [usize; 4],
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Stride along the width axis.
    pub stride_w: i32,
    /// Stride along the height axis.
    pub stride_h: i32,
    /// Width of the pooling window.
    pub filter_w: i32,
    /// Height of the pooling window.
    pub filter_h: i32,
    /// Padding scheme to apply.
    pub padding: Padding,
    /// Activation fused into the output stage.
    pub activation: FusedActivation,
}
/// Arguments for one invocation of [`KernelBackend::softmax`].
///
/// No output shape is given, only a two-dimensional input shape. All slices
/// share the lifetime `'a`; the exclusive borrows of `output` and `scratch`
/// prevent the struct from being `Clone` or `Copy`.
///
/// The derived `Debug` output prints the full contents of every slice.
#[derive(Debug)]
pub struct SoftmaxParams<'a> {
    /// Quantized input values.
    pub input: &'a [i8],
    /// Extent of `input` along each of its two dimensions. NOTE(review):
    /// presumably `[rows, num_classes]` — confirm.
    pub input_shape: [usize; 2],
    /// Quantization parameters of `input`.
    pub input_quant: QuantParam,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Softmax `beta` coefficient — presumably a multiplier applied to the
    /// inputs before exponentiation; confirm.
    pub beta: f32,
    /// Caller-provided scratch bytes for the backend's temporary use.
    /// NOTE(review): presumably sized via
    /// [`KernelBackend::softmax_scratch_size`] — confirm.
    pub scratch: &'a mut [u8],
}
/// Arguments for one invocation of [`KernelBackend::add`].
///
/// No shapes are carried, only the three buffers and their quantization.
/// NOTE(review): the operands and the output are presumably all the same
/// length (no broadcasting information is present) — confirm.
///
/// All slices share the lifetime `'a`; the exclusive borrow of `output`
/// prevents the struct from being `Clone` or `Copy`. The derived `Debug`
/// output prints the full contents of every slice.
#[derive(Debug)]
pub struct ElementwiseAddParams<'a> {
    /// First quantized operand.
    pub input1: &'a [i8],
    /// Quantization parameters of `input1`.
    pub input1_quant: QuantParam,
    /// Second quantized operand.
    pub input2: &'a [i8],
    /// Quantization parameters of `input2`.
    pub input2_quant: QuantParam,
    /// Destination buffer for the quantized result.
    pub output: &'a mut [i8],
    /// Quantization parameters of `output`.
    pub output_quant: QuantParam,
    /// Activation fused into the output stage.
    pub activation: FusedActivation,
}
/// Set of operator kernels that an inference backend provides.
///
/// Each operator method receives its parameter struct by value and reports
/// success or failure through [`Status`]; results are expected to be written
/// into the `output` buffer carried by the parameters. Methods take
/// `&mut self`, so an implementation may keep mutable state between calls.
///
/// The `*_scratch_size` helpers are associated functions (no `self`) with a
/// default implementation that returns `0`. They are bounded by `Self: Sized`,
/// which keeps the trait usable as `dyn KernelBackend` but means the helpers
/// can only be called on a concrete backend type.
pub trait KernelBackend {
    /// Runs a 2-D convolution described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status;

    /// Runs a depthwise 2-D convolution described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status;

    /// Runs a fully-connected (dense) layer described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status;

    /// Runs average pooling described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn avg_pool(&mut self, params: PoolParams<'_>) -> Status;

    /// Runs max pooling described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn max_pool(&mut self, params: PoolParams<'_>) -> Status;

    /// Runs softmax described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status;

    /// Runs an element-wise addition described by `params`.
    ///
    /// # Errors
    ///
    /// Returns a [`KernelError`] chosen by the backend when the operation
    /// cannot be carried out.
    fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status;

    /// Scratch size the backend wants for [`KernelBackend::conv2d`] with the
    /// given shapes — presumably in bytes, matching the `u8` scratch slice.
    ///
    /// The default implementation ignores the shapes and returns `0`.
    fn conv2d_scratch_size(
        input_shape: [usize; 4],
        weights_shape: [usize; 4],
        output_shape: [usize; 4],
    ) -> usize
    where
        Self: Sized,
    {
        // Bind the arguments only to mark them as intentionally unused.
        let _ = (input_shape, weights_shape, output_shape);
        0
    }

    /// Scratch size the backend wants for
    /// [`KernelBackend::depthwise_conv2d`] with the given shapes — presumably
    /// in bytes, matching the `u8` scratch slice.
    ///
    /// The default implementation ignores the shapes and returns `0`.
    fn depthwise_conv2d_scratch_size(
        input_shape: [usize; 4],
        weights_shape: [usize; 4],
        output_shape: [usize; 4],
    ) -> usize
    where
        Self: Sized,
    {
        // Bind the arguments only to mark them as intentionally unused.
        let _ = (input_shape, weights_shape, output_shape);
        0
    }

    /// Scratch size the backend wants for [`KernelBackend::softmax`] with
    /// `num_classes` classes — presumably in bytes, matching the `u8` scratch
    /// slice.
    ///
    /// The default implementation ignores `num_classes` and returns `0`.
    fn softmax_scratch_size(num_classes: usize) -> usize
    where
        Self: Sized,
    {
        // Bind the argument only to mark it as intentionally unused.
        let _ = num_classes;
        0
    }
}