ember-infer-core 0.1.1

Core KernelBackend trait for ember-rs embedded INT8 inference engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
#![no_std]
#![deny(missing_docs)]

//! # ember-infer-core
//!
//! Core trait definitions for the ember-rs embedded TinyML inference engine.
//!
//! ember-rs is a `no_std` INT8 inference engine designed to be the
//! "Burn of embedded inference" - providing a pluggable [`KernelBackend`] trait
//! so different hardware backends can be swapped without changing model code.
//!
//! ## Design
//!
//! - Param struct field names mirror TFLite Micro's C structs
//!   (`TfLiteConvParams`, `TfLiteDepthwiseConvParams`, etc.)
//! - The `invoke` phase is covered by this trait; the `prepare` phase
//!   (scratch size calculation, shape inference) is handled at compile time
//!   by `ember-infer-macros`
//! - `*_scratch_size` functions have default implementations returning `0`,
//!   so pure-Rust reference backends don't need to implement them
//!
//! ## Backend implementations
//!
//! - `ember-infer-ref`: pure Rust reference implementation (for testing / non-ESP platforms)
//! - `ember-esp`: official ESP32-S3 backend using Espressif's esp-nn SIMD kernels
//!   (maintained in a separate repository)

// ----------------------------------------------------------------------------
// Enums - mirror TFLite Micro's C enums
// ----------------------------------------------------------------------------

/// Spatial padding strategy of an operator, mirrors `TfLitePadding`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Padding {
    /// Pad the input so that the output size equals `ceil(input_size / stride)`.
    Same,
    /// Apply no padding; the output size equals
    /// `floor((input_size - filter_size) / stride) + 1`.
    Valid,
}

/// Activation fused into the output stage of an operator, mirrors
/// `TfLiteFusedActivation`.
///
/// NOTE(review): in TFLite's C header `kTfLiteActReluN1To1` appears to precede
/// `kTfLiteActRelu6`, so the implicit discriminants of this enum may not match
/// the C values - verify before converting with `as` casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedActivation {
    /// No activation is applied.
    None,
    /// ReLU: `max(0, x)`.
    Relu,
    /// ReLU6: `min(max(0, x), 6)`.
    Relu6,
    /// ReLU restricted to the range `[-1, 1]`.
    ReluN1To1,
    /// Tanh activation.
    Tanh,
    /// Sign bit activation.
    SignBit,
    /// Sigmoid activation.
    Sigmoid,
}

// ----------------------------------------------------------------------------
// Quantization params - mirror TfLiteQuantizationParams
// ----------------------------------------------------------------------------

/// Per-tensor quantization parameters, mirrors `TfLiteQuantizationParams`.
///
/// A quantized value maps back to a real value as
/// `real_value = scale * (quantized_value - zero_point)`.
///
/// `PartialEq` is derived so parameters can be compared directly (e.g. in
/// tests); `Eq` is not available because `scale` is an `f32`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantParam {
    /// The scale factor: `real_value = scale * (quantized_value - zero_point)`.
    pub scale: f32,
    /// The zero point for asymmetric quantization.
    pub zero_point: i32,
}

/// Optional per-channel quantization parameters for weight tensors.
///
/// TFLite commonly uses per-channel scales for convolution-family weights,
/// with the quantized dimension matching the output-channel axis.
///
/// The struct only borrows its data, so it is cheap to copy. `PartialEq` is
/// derived so parameter sets can be compared directly (e.g. in tests); `Eq`
/// is not available because the scales are `f32`.
///
/// NOTE(review): `scales` and `zero_points` are presumably expected to have
/// the same length; nothing in this type enforces it - confirm that backends
/// validate it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PerChannelQuantParam<'a> {
    /// Scale values, one per quantized channel.
    pub scales: &'a [f32],
    /// Zero points, one per quantized channel.
    pub zero_points: &'a [i32],
    /// Tensor axis the per-channel values apply to.
    pub quantized_dimension: usize,
}

// ----------------------------------------------------------------------------
// Error / Status - mirror TfLiteStatus
// ----------------------------------------------------------------------------

/// Errors that a [`KernelBackend`] implementation may return.
/// Mirrors the error states of `TfLiteStatus`.
///
/// Implements [`core::fmt::Display`] so an error can be logged or shown to a
/// user without going through the `Debug` representation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KernelError {
    /// Input or output tensor shapes are invalid for this operation.
    InvalidShape,
    /// The requested activation function is not supported by this backend.
    UnsupportedActivation,
    /// A buffer passed to the backend does not meet alignment requirements.
    ///
    /// ESP-NN kernels require 16-byte aligned buffers. If the `assume-aligned`
    /// feature is disabled (default), misaligned buffers are automatically
    /// copied to an aligned scratch region. This error is only returned when
    /// `assume-aligned` is enabled and the caller violates the contract.
    AlignmentError,
    /// An internal error occurred in the backend.
    InternalError,
}

impl core::fmt::Display for KernelError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a new message.
        let message = match self {
            Self::InvalidShape => "invalid tensor shape",
            Self::UnsupportedActivation => "unsupported fused activation",
            Self::AlignmentError => "buffer alignment requirement violated",
            Self::InternalError => "internal backend error",
        };
        f.write_str(message)
    }
}

/// Outcome of every [`KernelBackend`] operation: `Ok(())` on success, otherwise
/// the [`KernelError`] describing the failure. Mirrors `TfLiteStatus`.
pub type Status = core::result::Result<(), KernelError>;

// ----------------------------------------------------------------------------
// Operator parameter structs
// Field names mirror TFLite Micro's C structs for easy cross-reference.
// ----------------------------------------------------------------------------

/// Parameters for a 2D convolution operation.
///
/// Mirrors `TfLiteConvParams` from TFLite Micro.
/// Tensor layout: NHWC (batch, height, width, channels).
///
/// The struct borrows every tensor for `'a`; `output` and `scratch` are
/// borrowed mutably. `Debug` is derived so a call can be inspected in tests or
/// logs - note that formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct Conv2dParams<'a> {
    /// Input tensor data, quantized as `int8`.
    pub input: &'a [i8],
    /// Input tensor shape `[N, H, W, C_in]`.
    pub input_shape: [usize; 4],
    /// Input quantization parameters.
    pub input_quant: QuantParam,
    /// Weight tensor data, quantized as `int8`. Layout: `[C_out, KH, KW, C_in]`.
    pub weights: &'a [i8],
    /// Weight tensor shape `[C_out, KH, KW, C_in]`.
    pub weights_shape: [usize; 4],
    /// Weight quantization parameters (per-tensor).
    pub weights_quant: QuantParam,
    /// Optional per-channel weight quantization parameters.
    ///
    /// NOTE(review): presumably takes precedence over `weights_quant` when
    /// `Some` - confirm against the backend implementations.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias tensor, stored as `int32`.
    ///
    /// Length must equal `C_out` when `Some`.
    pub bias: Option<&'a [i32]>,
    /// Output tensor buffer, written by the backend.
    pub output: &'a mut [i8],
    /// Output tensor shape `[N, H_out, W_out, C_out]`.
    pub output_shape: [usize; 4],
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Horizontal stride.
    pub stride_w: i32,
    /// Vertical stride.
    pub stride_h: i32,
    /// Horizontal dilation factor (1 = no dilation).
    pub dilation_w_factor: i32,
    /// Vertical dilation factor (1 = no dilation).
    pub dilation_h_factor: i32,
    /// Padding mode.
    pub padding: Padding,
    /// Fused activation function applied to the output.
    pub activation: FusedActivation,
    /// Scratch buffer for intermediate computation.
    ///
    /// Required size is reported by [`KernelBackend::conv2d_scratch_size`].
    /// Pass an empty slice `&mut []` if the backend does not require scratch memory.
    pub scratch: &'a mut [u8],
}

/// Parameters for a depthwise 2D convolution operation.
///
/// Mirrors `TfLiteDepthwiseConvParams` from TFLite Micro.
///
/// The struct borrows every tensor for `'a`; `output` and `scratch` are
/// borrowed mutably. `Debug` is derived so a call can be inspected in tests or
/// logs - note that formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct DepthwiseConv2dParams<'a> {
    /// Input tensor data, quantized as `int8`.
    pub input: &'a [i8],
    /// Input tensor shape `[N, H, W, C_in]`.
    pub input_shape: [usize; 4],
    /// Input quantization parameters.
    pub input_quant: QuantParam,
    /// Weight tensor data, quantized as `int8`.
    pub weights: &'a [i8],
    /// Weight tensor shape.
    ///
    /// NOTE(review): the axis order is not specified here; TFLite depthwise
    /// filters are usually `[1, KH, KW, C_in * depth_multiplier]` - confirm.
    pub weights_shape: [usize; 4],
    /// Weight quantization parameters (per-tensor).
    pub weights_quant: QuantParam,
    /// Optional per-channel weight quantization parameters.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias tensor, stored as `int32`.
    ///
    /// NOTE(review): presumably one entry per output channel when `Some`,
    /// as for `Conv2dParams::bias` - confirm.
    pub bias: Option<&'a [i32]>,
    /// Output tensor buffer, written by the backend.
    pub output: &'a mut [i8],
    /// Output tensor shape `[N, H_out, W_out, C_out]`.
    pub output_shape: [usize; 4],
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Horizontal stride.
    pub stride_w: i32,
    /// Vertical stride.
    pub stride_h: i32,
    /// Horizontal dilation factor (1 = no dilation).
    pub dilation_w_factor: i32,
    /// Vertical dilation factor (1 = no dilation).
    pub dilation_h_factor: i32,
    /// Depth multiplier - the number of output channels per input channel.
    ///
    /// Specific to depthwise convolution; mirrors
    /// `TfLiteDepthwiseConvParams::depth_multiplier`.
    pub depth_multiplier: i32,
    /// Padding mode.
    pub padding: Padding,
    /// Fused activation function applied to the output.
    pub activation: FusedActivation,
    /// Scratch buffer for intermediate computation.
    ///
    /// Required size is reported by
    /// [`KernelBackend::depthwise_conv2d_scratch_size`].
    pub scratch: &'a mut [u8],
}

/// Parameters for a fully-connected (dense) layer.
///
/// Mirrors `TfLiteFullyConnectedParams` from TFLite Micro.
///
/// The struct borrows every tensor for `'a`; `output` is borrowed mutably.
/// `Debug` is derived so a call can be inspected in tests or logs - note that
/// formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct FullyConnectedParams<'a> {
    /// Input tensor data, quantized as `int8`.
    pub input: &'a [i8],
    /// Input quantization parameters.
    pub input_quant: QuantParam,
    /// Weight tensor data. Layout: `[output_depth, input_depth]`.
    pub weights: &'a [i8],
    /// Weight tensor shape `[output_depth, input_depth]`.
    pub weights_shape: [usize; 2],
    /// Weight quantization parameters (per-tensor).
    pub weights_quant: QuantParam,
    /// Optional per-channel weight quantization parameters.
    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
    /// Optional bias tensor, stored as `int32`.
    pub bias: Option<&'a [i32]>,
    /// Output tensor buffer, written by the backend.
    pub output: &'a mut [i8],
    /// Number of output neurons.
    ///
    /// NOTE(review): presumably equal to `weights_shape[0]`; nothing here
    /// enforces that - confirm that backends validate it.
    pub output_depth: usize,
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Fused activation function applied to the output.
    pub activation: FusedActivation,
}

/// Parameters for a pooling operation (average or max).
///
/// Mirrors `TfLitePoolParams` from TFLite Micro. The same struct is passed to
/// both [`KernelBackend::avg_pool`] and [`KernelBackend::max_pool`].
///
/// The struct borrows its tensors for `'a`; `output` is borrowed mutably.
/// `Debug` is derived so a call can be inspected in tests or logs - note that
/// formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct PoolParams<'a> {
    /// Input tensor data, quantized as `int8`.
    pub input: &'a [i8],
    /// Input tensor shape `[N, H, W, C]`.
    pub input_shape: [usize; 4],
    /// Input quantization parameters.
    pub input_quant: QuantParam,
    /// Output tensor buffer, written by the backend.
    pub output: &'a mut [i8],
    /// Output tensor shape `[N, H_out, W_out, C]`.
    pub output_shape: [usize; 4],
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Horizontal stride.
    pub stride_w: i32,
    /// Vertical stride.
    pub stride_h: i32,
    /// Pooling filter width.
    pub filter_w: i32,
    /// Pooling filter height.
    pub filter_h: i32,
    /// Padding mode.
    pub padding: Padding,
    /// Fused activation function applied to the output.
    pub activation: FusedActivation,
}

/// Parameters for the softmax operation.
///
/// Mirrors `TfLiteSoftmaxParams` from TFLite Micro.
///
/// The struct borrows its tensors for `'a`; `output` and `scratch` are
/// borrowed mutably. `Debug` is derived so a call can be inspected in tests or
/// logs - note that formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct SoftmaxParams<'a> {
    /// Input tensor data, quantized as `int8`.
    pub input: &'a [i8],
    /// Input shape `[batch, num_classes]`.
    pub input_shape: [usize; 2],
    /// Input quantization parameters.
    pub input_quant: QuantParam,
    /// Output tensor buffer, written by the backend.
    ///
    /// NOTE(review): no output shape is carried; presumably it matches
    /// `input_shape` - confirm.
    pub output: &'a mut [i8],
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Softmax beta parameter (typically `1.0`).
    ///
    /// Mirrors `TfLiteSoftmaxParams::beta`.
    pub beta: f32,
    /// Scratch buffer for intermediate computation.
    ///
    /// Required size is reported by [`KernelBackend::softmax_scratch_size`].
    pub scratch: &'a mut [u8],
}

/// Parameters for element-wise addition.
///
/// Mirrors `TfLiteAddParams` from TFLite Micro.
///
/// No tensor shapes are carried by this struct.
/// NOTE(review): presumably `input1`, `input2` and `output` must all have the
/// same length (no broadcasting) - confirm against the backends.
///
/// `Debug` is derived so a call can be inspected in tests or logs - note that
/// formatting prints the complete tensor buffers.
#[derive(Debug)]
pub struct ElementwiseAddParams<'a> {
    /// First input tensor data, quantized as `int8`.
    pub input1: &'a [i8],
    /// First input quantization parameters.
    pub input1_quant: QuantParam,
    /// Second input tensor data, quantized as `int8`.
    pub input2: &'a [i8],
    /// Second input quantization parameters.
    pub input2_quant: QuantParam,
    /// Output tensor buffer, written by the backend.
    pub output: &'a mut [i8],
    /// Output quantization parameters.
    pub output_quant: QuantParam,
    /// Fused activation function applied to the output.
    pub activation: FusedActivation,
}

// ----------------------------------------------------------------------------
// KernelBackend - the central trait
// ----------------------------------------------------------------------------

/// A hardware-specific INT8 inference backend - the central abstraction of
/// ember-rs.
///
/// # Design
///
/// The operator methods of this trait cover the **invoke phase** only. The
/// **prepare phase** (scratch buffer sizing, shape inference) is described as
/// being performed at compile time by `ember-infer-macros`, using the
/// `conv2d_scratch_size`, `depthwise_conv2d_scratch_size` and
/// `softmax_scratch_size` associated functions. Each of those has a default
/// implementation returning `0`.
///
/// Every operator method maps onto the `invoke` function of the corresponding
/// TFLite Micro kernel, so an existing optimized kernel (e.g., CMSIS-NN,
/// esp-nn) can be ported to ember-rs with minimal glue code.
///
/// # Implementing a backend
///
/// ```rust,ignore
/// use ember_infer_core::{KernelBackend, Conv2dParams, Status};
///
/// pub struct MyBackend;
///
/// impl KernelBackend for MyBackend {
///     fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status {
///         // call your hardware-accelerated kernel here
///         todo!()
///     }
///     // ... implement remaining required methods
/// }
/// ```
///
/// # Scratch buffers
///
/// A backend that needs scratch memory (e.g., esp-nn, CMSIS-NN) must override
/// the `*_scratch_size` associated functions so that correctly-sized scratch
/// arrays can be allocated in the generated inference function.
///
/// NOTE(review): the `*_scratch_size` functions are ordinary (non-`const`)
/// functions, so they cannot be evaluated in a `const` context such as an
/// array length - confirm how `ember-infer-macros` invokes them.
pub trait KernelBackend {
    /// Runs a 2D convolution.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's `CONV_2D` kernel.
    fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status;

    /// Runs a depthwise 2D convolution.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's
    /// `DEPTHWISE_CONV_2D` kernel.
    fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status;

    /// Runs a fully-connected layer.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's
    /// `FULLY_CONNECTED` kernel.
    fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status;

    /// Runs average pooling.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's
    /// `AVERAGE_POOL_2D` kernel.
    fn avg_pool(&mut self, params: PoolParams<'_>) -> Status;

    /// Runs max pooling.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's
    /// `MAX_POOL_2D` kernel.
    fn max_pool(&mut self, params: PoolParams<'_>) -> Status;

    /// Runs softmax.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's `SOFTMAX` kernel.
    fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status;

    /// Runs element-wise addition.
    ///
    /// Counterpart of the `invoke` function of TFLite Micro's `ADD` kernel.
    fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status;

    /// Scratch buffer size, in bytes, that [`Self::conv2d`] needs for the
    /// given tensor shapes.
    ///
    /// Used by `ember-infer-macros` to size the scratch arrays of the generated
    /// inference function. Corresponds to `esp_nn_get_conv_scratch_size` /
    /// CMSIS-NN equivalents.
    ///
    /// The default implementation ignores the shapes and returns `0` (no
    /// scratch required), which is correct for pure-Rust reference backends.
    fn conv2d_scratch_size(
        input_shape: [usize; 4],
        weights_shape: [usize; 4],
        output_shape: [usize; 4],
    ) -> usize
    where
        Self: Sized,
    {
        // The default needs no scratch memory, so the shapes are irrelevant.
        let _ = input_shape;
        let _ = weights_shape;
        let _ = output_shape;
        0
    }

    /// Scratch buffer size, in bytes, that [`Self::depthwise_conv2d`] needs for
    /// the given tensor shapes.
    ///
    /// The default implementation ignores the shapes and returns `0`.
    fn depthwise_conv2d_scratch_size(
        input_shape: [usize; 4],
        weights_shape: [usize; 4],
        output_shape: [usize; 4],
    ) -> usize
    where
        Self: Sized,
    {
        // The default needs no scratch memory, so the shapes are irrelevant.
        let _ = input_shape;
        let _ = weights_shape;
        let _ = output_shape;
        0
    }

    /// Scratch buffer size, in bytes, that [`Self::softmax`] needs for
    /// `num_classes` classes.
    ///
    /// The default implementation ignores `num_classes` and returns `0`.
    fn softmax_scratch_size(num_classes: usize) -> usize
    where
        Self: Sized,
    {
        // The default needs no scratch memory, so the class count is irrelevant.
        let _: usize = num_classes;
        0
    }
}