Skip to main content

ember_infer_core/
lib.rs

1#![no_std]
2#![deny(missing_docs)]
3
4//! # ember-infer-core
5//!
6//! Core trait definitions for the ember-rs embedded TinyML inference engine.
7//!
8//! ember-rs is a `no_std` INT8 inference engine designed to be the
9//! "Burn of embedded inference" - providing a pluggable [`KernelBackend`] trait
10//! so different hardware backends can be swapped without changing model code.
11//!
12//! ## Design
13//!
14//! - Param struct field names mirror TFLite Micro's C structs
15//!   (`TfLiteConvParams`, `TfLiteDepthwiseConvParams`, etc.)
16//! - The `invoke` phase is covered by this trait; the `prepare` phase
17//!   (scratch size calculation, shape inference) is handled at compile time
18//!   by `ember-infer-macros`
19//! - `*_scratch_size` functions have default implementations returning `0`,
20//!   so pure-Rust reference backends don't need to override them
21//!
22//! ## Backend implementations
23//!
24//! - `ember-infer-ref`: pure Rust reference implementation (for testing / non-ESP platforms)
25//! - `ember-esp`: official ESP32-S3 backend using Espressif's esp-nn SIMD kernels
26//!   (maintained in a separate repository)
27
28// ----------------------------------------------------------------------------
29// Enums - mirror TFLite Micro's C enums
30// ----------------------------------------------------------------------------
31
/// Padding strategy used by the convolution and pooling parameter structs.
///
/// Mirrors `TfLitePadding` from TFLite Micro.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Padding {
    /// Output size equals `ceil(input_size / stride)`.
    Same,
    /// No padding; output size equals
    /// `floor((input_size - filter_size) / stride) + 1`.
    Valid,
}
40
/// Activation function fused into an operator and applied to its output.
///
/// Mirrors `TfLiteFusedActivation` from TFLite Micro.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FusedActivation {
    /// No activation; the operator output is passed through unchanged.
    None,
    /// ReLU: `max(0, x)`.
    Relu,
    /// ReLU6: `min(max(0, x), 6)`.
    Relu6,
    /// ReLU clamped to the range `[-1, 1]`.
    ReluN1To1,
    /// Tanh activation.
    Tanh,
    /// Sign bit activation.
    SignBit,
    /// Sigmoid activation.
    Sigmoid,
}
59
60// ----------------------------------------------------------------------------
61// Quantization params - mirror TfLiteQuantizationParams
62// ----------------------------------------------------------------------------
63
/// Per-tensor quantization parameters, mirrors `TfLiteQuantizationParams`.
///
/// Two values compare equal when both `scale` and `zero_point` are equal.
/// Because `scale` is an `f32`, only `PartialEq` is provided (a NaN scale
/// never compares equal, not even to itself).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantParam {
    /// The scale factor: `real_value = scale * (quantized_value - zero_point)`.
    pub scale: f32,
    /// The zero point for asymmetric quantization.
    pub zero_point: i32,
}
72
/// Optional per-channel quantization parameters for weight tensors.
///
/// TFLite commonly uses per-channel scales for convolution-family weights,
/// with the quantized dimension matching the output-channel axis.
///
/// The struct only borrows its tables, so it is cheap to copy. Equality
/// compares the slice contents element by element, together with
/// `quantized_dimension`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PerChannelQuantParam<'a> {
    /// Scale values, one per quantized channel.
    pub scales: &'a [f32],
    /// Zero points, one per quantized channel.
    pub zero_points: &'a [i32],
    /// Tensor axis the per-channel values apply to.
    pub quantized_dimension: usize,
}
86
87// ----------------------------------------------------------------------------
88// Error / Status - mirror TfLiteStatus
89// ----------------------------------------------------------------------------
90
/// Errors that a [`KernelBackend`] implementation may return.
///
/// Mirrors the error states of `TfLiteStatus`. The type implements
/// [`core::fmt::Display`] with a short, human-readable message per variant,
/// so it can be logged directly on `no_std` targets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelError {
    /// Input or output tensor shapes are invalid for this operation.
    InvalidShape,
    /// The requested activation function is not supported by this backend.
    UnsupportedActivation,
    /// A buffer passed to the backend does not meet alignment requirements.
    ///
    /// ESP-NN kernels require 16-byte aligned buffers. If the `assume-aligned`
    /// feature is disabled (default), misaligned buffers are automatically
    /// copied to an aligned scratch region. This error is only returned when
    /// `assume-aligned` is enabled and the caller violates the contract.
    AlignmentError,
    /// An internal error occurred in the backend.
    InternalError,
}

impl core::fmt::Display for KernelError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message here.
        let message = match self {
            Self::InvalidShape => "invalid tensor shape for this operation",
            Self::UnsupportedActivation => "activation function not supported by this backend",
            Self::AlignmentError => "buffer does not meet backend alignment requirements",
            Self::InternalError => "internal backend error",
        };
        f.write_str(message)
    }
}
109
110/// Result type for all [`KernelBackend`] operations, mirrors `TfLiteStatus`.
111pub type Status = Result<(), KernelError>;
112
113// ----------------------------------------------------------------------------
114// Operator parameter structs
115// Field names mirror TFLite Micro's C structs for easy cross-reference.
116// ----------------------------------------------------------------------------
117
118/// Parameters for a 2D convolution operation.
119///
120/// Mirrors `TfLiteConvParams` from TFLite Micro.
121/// Tensor layout: NHWC (batch, height, width, channels).
122pub struct Conv2dParams<'a> {
123    /// Input tensor data, quantized as `int8`.
124    pub input: &'a [i8],
125    /// Input tensor shape `[N, H, W, C_in]`.
126    pub input_shape: [usize; 4],
127    /// Input quantization parameters.
128    pub input_quant: QuantParam,
129    /// Weight tensor data, quantized as `int8`. Layout: `[C_out, KH, KW, C_in]`.
130    pub weights: &'a [i8],
131    /// Weight tensor shape `[C_out, KH, KW, C_in]`.
132    pub weights_shape: [usize; 4],
133    /// Weight quantization parameters (per-tensor).
134    pub weights_quant: QuantParam,
135    /// Optional per-channel weight quantization parameters.
136    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
137    /// Optional bias tensor, stored as `int32`.
138    ///
139    /// Length must equal `C_out` when `Some`.
140    pub bias: Option<&'a [i32]>,
141    /// Output tensor buffer, written by the backend.
142    pub output: &'a mut [i8],
143    /// Output tensor shape `[N, H_out, W_out, C_out]`.
144    pub output_shape: [usize; 4],
145    /// Output quantization parameters.
146    pub output_quant: QuantParam,
147    /// Horizontal stride.
148    pub stride_w: i32,
149    /// Vertical stride.
150    pub stride_h: i32,
151    /// Horizontal dilation factor (1 = no dilation).
152    pub dilation_w_factor: i32,
153    /// Vertical dilation factor (1 = no dilation).
154    pub dilation_h_factor: i32,
155    /// Padding mode.
156    pub padding: Padding,
157    /// Fused activation function applied to the output.
158    pub activation: FusedActivation,
159    /// Scratch buffer for intermediate computation.
160    ///
161    /// Required size is reported by [`KernelBackend::conv2d_scratch_size`].
162    /// Pass an empty slice `&mut []` if the backend does not require scratch memory.
163    pub scratch: &'a mut [u8],
164}
165
166/// Parameters for a depthwise 2D convolution operation.
167///
168/// Mirrors `TfLiteDepthwiseConvParams` from TFLite Micro.
169pub struct DepthwiseConv2dParams<'a> {
170    /// Input tensor data, quantized as `int8`.
171    pub input: &'a [i8],
172    /// Input tensor shape `[N, H, W, C_in]`.
173    pub input_shape: [usize; 4],
174    /// Input quantization parameters.
175    pub input_quant: QuantParam,
176    /// Weight tensor data, quantized as `int8`.
177    pub weights: &'a [i8],
178    /// Weight tensor shape.
179    pub weights_shape: [usize; 4],
180    /// Weight quantization parameters (per-tensor).
181    pub weights_quant: QuantParam,
182    /// Optional per-channel weight quantization parameters.
183    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
184    /// Optional bias tensor, stored as `int32`.
185    pub bias: Option<&'a [i32]>,
186    /// Output tensor buffer, written by the backend.
187    pub output: &'a mut [i8],
188    /// Output tensor shape `[N, H_out, W_out, C_out]`.
189    pub output_shape: [usize; 4],
190    /// Output quantization parameters.
191    pub output_quant: QuantParam,
192    /// Horizontal stride.
193    pub stride_w: i32,
194    /// Vertical stride.
195    pub stride_h: i32,
196    /// Horizontal dilation factor (1 = no dilation).
197    pub dilation_w_factor: i32,
198    /// Vertical dilation factor (1 = no dilation).
199    pub dilation_h_factor: i32,
200    /// Depth multiplier - the number of output channels per input channel.
201    ///
202    /// Specific to depthwise convolution; mirrors
203    /// `TfLiteDepthwiseConvParams::depth_multiplier`.
204    pub depth_multiplier: i32,
205    /// Padding mode.
206    pub padding: Padding,
207    /// Fused activation function applied to the output.
208    pub activation: FusedActivation,
209    /// Scratch buffer for intermediate computation.
210    pub scratch: &'a mut [u8],
211}
212
213/// Parameters for a fully-connected (dense) layer.
214///
215/// Mirrors `TfLiteFullyConnectedParams` from TFLite Micro.
216pub struct FullyConnectedParams<'a> {
217    /// Input tensor data, quantized as `int8`.
218    pub input: &'a [i8],
219    /// Input quantization parameters.
220    pub input_quant: QuantParam,
221    /// Weight tensor data. Layout: `[output_depth, input_depth]`.
222    pub weights: &'a [i8],
223    /// Weight tensor shape `[output_depth, input_depth]`.
224    pub weights_shape: [usize; 2],
225    /// Weight quantization parameters (per-tensor).
226    pub weights_quant: QuantParam,
227    /// Optional per-channel weight quantization parameters.
228    pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
229    /// Optional bias tensor, stored as `int32`.
230    pub bias: Option<&'a [i32]>,
231    /// Output tensor buffer, written by the backend.
232    pub output: &'a mut [i8],
233    /// Number of output neurons.
234    pub output_depth: usize,
235    /// Output quantization parameters.
236    pub output_quant: QuantParam,
237    /// Fused activation function applied to the output.
238    pub activation: FusedActivation,
239}
240
241/// Parameters for a pooling operation (average or max).
242///
243/// Mirrors `TfLitePoolParams` from TFLite Micro.
244pub struct PoolParams<'a> {
245    /// Input tensor data, quantized as `int8`.
246    pub input: &'a [i8],
247    /// Input tensor shape `[N, H, W, C]`.
248    pub input_shape: [usize; 4],
249    /// Input quantization parameters.
250    pub input_quant: QuantParam,
251    /// Output tensor buffer, written by the backend.
252    pub output: &'a mut [i8],
253    /// Output tensor shape `[N, H_out, W_out, C]`.
254    pub output_shape: [usize; 4],
255    /// Output quantization parameters.
256    pub output_quant: QuantParam,
257    /// Horizontal stride.
258    pub stride_w: i32,
259    /// Vertical stride.
260    pub stride_h: i32,
261    /// Pooling filter width.
262    pub filter_w: i32,
263    /// Pooling filter height.
264    pub filter_h: i32,
265    /// Padding mode.
266    pub padding: Padding,
267    /// Fused activation function applied to the output.
268    pub activation: FusedActivation,
269}
270
271/// Parameters for the softmax operation.
272///
273/// Mirrors `TfLiteSoftmaxParams` from TFLite Micro.
274pub struct SoftmaxParams<'a> {
275    /// Input tensor data, quantized as `int8`.
276    pub input: &'a [i8],
277    /// Input shape `[batch, num_classes]`.
278    pub input_shape: [usize; 2],
279    /// Input quantization parameters.
280    pub input_quant: QuantParam,
281    /// Output tensor buffer, written by the backend.
282    pub output: &'a mut [i8],
283    /// Output quantization parameters.
284    pub output_quant: QuantParam,
285    /// Softmax beta parameter (typically `1.0`).
286    ///
287    /// Mirrors `TfLiteSoftmaxParams::beta`.
288    pub beta: f32,
289    /// Scratch buffer for intermediate computation.
290    pub scratch: &'a mut [u8],
291}
292
293/// Parameters for element-wise addition.
294///
295/// Mirrors `TfLiteAddParams` from TFLite Micro.
296pub struct ElementwiseAddParams<'a> {
297    /// First input tensor data, quantized as `int8`.
298    pub input1: &'a [i8],
299    /// First input quantization parameters.
300    pub input1_quant: QuantParam,
301    /// Second input tensor data, quantized as `int8`.
302    pub input2: &'a [i8],
303    /// Second input quantization parameters.
304    pub input2_quant: QuantParam,
305    /// Output tensor buffer, written by the backend.
306    pub output: &'a mut [i8],
307    /// Output quantization parameters.
308    pub output_quant: QuantParam,
309    /// Fused activation function applied to the output.
310    pub activation: FusedActivation,
311}
312
313// ----------------------------------------------------------------------------
314// KernelBackend - the central trait
315// ----------------------------------------------------------------------------
316
317/// The core abstraction for ember-rs: a hardware-specific INT8 inference backend.
318///
319/// # Design
320///
321/// This trait covers the **invoke phase** only. The **prepare phase**
322/// (scratch buffer sizing, shape inference) is performed at compile time by
323/// `ember-infer-macros` via the `conv2d_scratch_size` / `softmax_scratch_size`
324/// associated functions, which have default implementations returning `0`.
325///
326/// Implementations map directly onto TFLite Micro kernel `invoke` functions,
327/// which means porting an existing TFLite Micro optimized kernel (e.g., CMSIS-NN,
328/// esp-nn) to ember-rs requires minimal glue code.
329///
330/// # Implementing a backend
331///
332/// ```rust,ignore
333/// use ember_infer_core::{KernelBackend, Conv2dParams, Status};
334///
335/// pub struct MyBackend;
336///
337/// impl KernelBackend for MyBackend {
338///     fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status {
339///         // call your hardware-accelerated kernel here
340///         todo!()
341///     }
342///     // ... implement remaining required methods
343/// }
344/// ```
345///
346/// # Scratch buffers
347///
348/// Backends that require scratch memory (e.g., esp-nn, CMSIS-NN) must override
349/// the `*_scratch_size` associated functions. The `ember-infer-macros` proc macro calls
350/// these at compile time to allocate correctly-sized scratch arrays in the
351/// generated inference function.
352pub trait KernelBackend {
353    /// Execute a 2D convolution.
354    ///
355    /// Corresponds to the `invoke` function of the `CONV_2D` kernel in TFLite Micro.
356    fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status;
357
358    /// Execute a depthwise 2D convolution.
359    ///
360    /// Corresponds to the `invoke` function of the `DEPTHWISE_CONV_2D` kernel
361    /// in TFLite Micro.
362    fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status;
363
364    /// Execute a fully-connected layer.
365    ///
366    /// Corresponds to the `invoke` function of the `FULLY_CONNECTED` kernel
367    /// in TFLite Micro.
368    fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status;
369
370    /// Execute average pooling.
371    ///
372    /// Corresponds to the `invoke` function of the `AVERAGE_POOL_2D` kernel
373    /// in TFLite Micro.
374    fn avg_pool(&mut self, params: PoolParams<'_>) -> Status;
375
376    /// Execute max pooling.
377    ///
378    /// Corresponds to the `invoke` function of the `MAX_POOL_2D` kernel
379    /// in TFLite Micro.
380    fn max_pool(&mut self, params: PoolParams<'_>) -> Status;
381
382    /// Execute softmax.
383    ///
384    /// Corresponds to the `invoke` function of the `SOFTMAX` kernel
385    /// in TFLite Micro.
386    fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status;
387
388    /// Execute element-wise addition.
389    ///
390    /// Corresponds to the `invoke` function of the `ADD` kernel in TFLite Micro.
391    fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status;
392
393    /// Returns the scratch buffer size in bytes required by [`Self::conv2d`].
394    ///
395    /// Called by `ember-infer-macros` at **compile time** to allocate scratch arrays
396    /// in the generated inference function. Corresponds to
397    /// `esp_nn_get_conv_scratch_size` / CMSIS-NN equivalents.
398    ///
399    /// The default implementation returns `0` (no scratch required), which is
400    /// correct for pure-Rust reference backends.
401    fn conv2d_scratch_size(
402        input_shape: [usize; 4],
403        weights_shape: [usize; 4],
404        output_shape: [usize; 4],
405    ) -> usize
406    where
407        Self: Sized,
408    {
409        let _ = (input_shape, weights_shape, output_shape);
410        0
411    }
412
413    /// Returns the scratch buffer size in bytes required by [`Self::depthwise_conv2d`].
414    fn depthwise_conv2d_scratch_size(
415        input_shape: [usize; 4],
416        weights_shape: [usize; 4],
417        output_shape: [usize; 4],
418    ) -> usize
419    where
420        Self: Sized,
421    {
422        let _ = (input_shape, weights_shape, output_shape);
423        0
424    }
425
426    /// Returns the scratch buffer size in bytes required by [`Self::softmax`].
427    fn softmax_scratch_size(num_classes: usize) -> usize
428    where
429        Self: Sized,
430    {
431        let _ = num_classes;
432        0
433    }
434}