ember_infer_core/lib.rs
#![no_std]
// Every public item must carry documentation; undocumented items fail the build.
#![deny(missing_docs)]

//! # ember-infer-core
//!
//! Core trait definitions for the ember-rs embedded TinyML inference engine.
//!
//! ember-rs is a `no_std` INT8 inference engine designed to be the
//! "Burn of embedded inference": it exposes a pluggable [`KernelBackend`] trait
//! so that different hardware backends can be swapped without changing model code.
//!
//! ## Design
//!
//! - Parameter struct field names mirror TFLite Micro's C structs
//!   (`TfLiteConvParams`, `TfLiteDepthwiseConvParams`, etc.).
//! - Only the `invoke` phase is covered by [`KernelBackend`]; the `prepare` phase
//!   (scratch size calculation, shape inference) is handled at compile time
//!   by `ember-infer-macros`.
//! - The `*_scratch_size` associated functions have default implementations
//!   returning `0`, so pure-Rust reference backends don't need to implement them.
//!
//! ## Backend implementations
//!
//! - `ember-infer-ref`: pure-Rust reference implementation (for testing and non-ESP platforms).
//! - `ember-esp`: official ESP32-S3 backend using Espressif's esp-nn SIMD kernels
//!   (maintained in a separate repository).
27
// ----------------------------------------------------------------------------
// Enums - mirror TFLite Micro's C enums
// ----------------------------------------------------------------------------
31
/// Padding strategy, mirrors `TfLitePadding`.
///
/// In the output-size formulas below, `effective_filter_size` accounts for
/// dilation: `(filter_size - 1) * dilation + 1`. It reduces to `filter_size`
/// when the dilation factor is `1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Padding {
    /// Pad the input as needed so that output size equals `ceil(input_size / stride)`.
    Same,
    /// No padding; output size equals
    /// `floor((input_size - effective_filter_size) / stride) + 1`.
    Valid,
}
40
/// Fused activation function applied after an operator, mirrors `TfLiteFusedActivation`.
///
/// Only the variant *names* mirror the C enum. The declaration order, and
/// therefore the implicit discriminants, differs: the C enum places
/// `ReluN1To1` before `Relu6`. Do not use an `as` cast to obtain a TFLite enum
/// value; map variants by name instead.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FusedActivation {
    /// No activation.
    None,
    /// ReLU: `max(0, x)`.
    Relu,
    /// ReLU6: `min(max(0, x), 6)`.
    Relu6,
    /// ReLU clamped to `[-1, 1]`: `min(max(-1, x), 1)`.
    ReluN1To1,
    /// Hyperbolic tangent: `tanh(x)`.
    Tanh,
    /// Sign-bit activation.
    SignBit,
    /// Logistic sigmoid: `1 / (1 + exp(-x))`.
    Sigmoid,
}
59
// ----------------------------------------------------------------------------
// Quantization params - mirror TfLiteQuantizationParams
// ----------------------------------------------------------------------------
63
/// Per-tensor quantization parameters, mirrors `TfLiteQuantizationParams`.
///
/// Dequantization follows the affine mapping
/// `real_value = scale * (quantized_value - zero_point)`.
// `Eq`/`Hash` are intentionally not derived: `f32` implements neither.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantParam {
    /// The scale factor: `real_value = scale * (quantized_value - zero_point)`.
    pub scale: f32,
    /// The zero point for asymmetric quantization, i.e. the quantized value
    /// that maps to a real value of `0.0`.
    pub zero_point: i32,
}
72
/// Optional per-channel quantization parameters for weight tensors.
///
/// TFLite commonly uses per-channel scales for convolution-family weights,
/// with the quantized dimension matching the output-channel axis.
///
/// `scales` and `zero_points` are expected to hold one entry per element along
/// `quantized_dimension`. This plain data struct does not check that invariant.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PerChannelQuantParam<'a> {
    /// Scale values, one per quantized channel.
    pub scales: &'a [f32],
    /// Zero points, one per quantized channel.
    pub zero_points: &'a [i32],
    /// Tensor axis the per-channel values apply to.
    pub quantized_dimension: usize,
}
86
// ----------------------------------------------------------------------------
// Error / Status - mirror TfLiteStatus
// ----------------------------------------------------------------------------
90
/// Errors that a [`KernelBackend`] implementation may return.
/// Mirrors the error states of `TfLiteStatus`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelError {
    /// Input or output tensor shapes are invalid for this operation.
    InvalidShape,
    /// The requested activation function is not supported by this backend.
    UnsupportedActivation,
    /// A buffer passed to the backend does not meet its alignment requirements.
    ///
    /// Alignment requirements are backend-specific. ESP-NN kernels, for example,
    /// require 16-byte aligned buffers: with the `assume-aligned` feature disabled
    /// (the default), misaligned buffers are copied to an aligned scratch region,
    /// and this error is only returned when `assume-aligned` is enabled and the
    /// caller violates the contract.
    // NOTE(review): `assume-aligned` is not a feature of this crate; it presumably
    // belongs to the ESP backend. Confirm and reference it from here.
    AlignmentError,
    /// An internal error occurred in the backend.
    InternalError,
}

/// Human-readable, allocation-free descriptions suitable for logging on
/// embedded targets.
impl core::fmt::Display for KernelError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            Self::InvalidShape => "invalid tensor shape for this operation",
            Self::UnsupportedActivation => "fused activation not supported by this backend",
            Self::AlignmentError => "buffer does not meet backend alignment requirements",
            Self::InternalError => "internal backend error",
        };
        f.write_str(msg)
    }
}
109
110/// Result type for all [`KernelBackend`] operations, mirrors `TfLiteStatus`.
111pub type Status = Result<(), KernelError>;
112
// ----------------------------------------------------------------------------
// Operator parameter structs
// Field names mirror TFLite Micro's C structs for easy cross-reference.
// ----------------------------------------------------------------------------
117
118/// Parameters for a 2D convolution operation.
119///
120/// Mirrors `TfLiteConvParams` from TFLite Micro.
121/// Tensor layout: NHWC (batch, height, width, channels).
122pub struct Conv2dParams<'a> {
123 /// Input tensor data, quantized as `int8`.
124 pub input: &'a [i8],
125 /// Input tensor shape `[N, H, W, C_in]`.
126 pub input_shape: [usize; 4],
127 /// Input quantization parameters.
128 pub input_quant: QuantParam,
129 /// Weight tensor data, quantized as `int8`. Layout: `[C_out, KH, KW, C_in]`.
130 pub weights: &'a [i8],
131 /// Weight tensor shape `[C_out, KH, KW, C_in]`.
132 pub weights_shape: [usize; 4],
133 /// Weight quantization parameters (per-tensor).
134 pub weights_quant: QuantParam,
135 /// Optional per-channel weight quantization parameters.
136 pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
137 /// Optional bias tensor, stored as `int32`.
138 ///
139 /// Length must equal `C_out` when `Some`.
140 pub bias: Option<&'a [i32]>,
141 /// Output tensor buffer, written by the backend.
142 pub output: &'a mut [i8],
143 /// Output tensor shape `[N, H_out, W_out, C_out]`.
144 pub output_shape: [usize; 4],
145 /// Output quantization parameters.
146 pub output_quant: QuantParam,
147 /// Horizontal stride.
148 pub stride_w: i32,
149 /// Vertical stride.
150 pub stride_h: i32,
151 /// Horizontal dilation factor (1 = no dilation).
152 pub dilation_w_factor: i32,
153 /// Vertical dilation factor (1 = no dilation).
154 pub dilation_h_factor: i32,
155 /// Padding mode.
156 pub padding: Padding,
157 /// Fused activation function applied to the output.
158 pub activation: FusedActivation,
159 /// Scratch buffer for intermediate computation.
160 ///
161 /// Required size is reported by [`KernelBackend::conv2d_scratch_size`].
162 /// Pass an empty slice `&mut []` if the backend does not require scratch memory.
163 pub scratch: &'a mut [u8],
164}
165
166/// Parameters for a depthwise 2D convolution operation.
167///
168/// Mirrors `TfLiteDepthwiseConvParams` from TFLite Micro.
169pub struct DepthwiseConv2dParams<'a> {
170 /// Input tensor data, quantized as `int8`.
171 pub input: &'a [i8],
172 /// Input tensor shape `[N, H, W, C_in]`.
173 pub input_shape: [usize; 4],
174 /// Input quantization parameters.
175 pub input_quant: QuantParam,
176 /// Weight tensor data, quantized as `int8`.
177 pub weights: &'a [i8],
178 /// Weight tensor shape.
179 pub weights_shape: [usize; 4],
180 /// Weight quantization parameters (per-tensor).
181 pub weights_quant: QuantParam,
182 /// Optional per-channel weight quantization parameters.
183 pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
184 /// Optional bias tensor, stored as `int32`.
185 pub bias: Option<&'a [i32]>,
186 /// Output tensor buffer, written by the backend.
187 pub output: &'a mut [i8],
188 /// Output tensor shape `[N, H_out, W_out, C_out]`.
189 pub output_shape: [usize; 4],
190 /// Output quantization parameters.
191 pub output_quant: QuantParam,
192 /// Horizontal stride.
193 pub stride_w: i32,
194 /// Vertical stride.
195 pub stride_h: i32,
196 /// Horizontal dilation factor (1 = no dilation).
197 pub dilation_w_factor: i32,
198 /// Vertical dilation factor (1 = no dilation).
199 pub dilation_h_factor: i32,
200 /// Depth multiplier - the number of output channels per input channel.
201 ///
202 /// Specific to depthwise convolution; mirrors
203 /// `TfLiteDepthwiseConvParams::depth_multiplier`.
204 pub depth_multiplier: i32,
205 /// Padding mode.
206 pub padding: Padding,
207 /// Fused activation function applied to the output.
208 pub activation: FusedActivation,
209 /// Scratch buffer for intermediate computation.
210 pub scratch: &'a mut [u8],
211}
212
213/// Parameters for a fully-connected (dense) layer.
214///
215/// Mirrors `TfLiteFullyConnectedParams` from TFLite Micro.
216pub struct FullyConnectedParams<'a> {
217 /// Input tensor data, quantized as `int8`.
218 pub input: &'a [i8],
219 /// Input quantization parameters.
220 pub input_quant: QuantParam,
221 /// Weight tensor data. Layout: `[output_depth, input_depth]`.
222 pub weights: &'a [i8],
223 /// Weight tensor shape `[output_depth, input_depth]`.
224 pub weights_shape: [usize; 2],
225 /// Weight quantization parameters (per-tensor).
226 pub weights_quant: QuantParam,
227 /// Optional per-channel weight quantization parameters.
228 pub weights_per_channel_quant: Option<PerChannelQuantParam<'a>>,
229 /// Optional bias tensor, stored as `int32`.
230 pub bias: Option<&'a [i32]>,
231 /// Output tensor buffer, written by the backend.
232 pub output: &'a mut [i8],
233 /// Number of output neurons.
234 pub output_depth: usize,
235 /// Output quantization parameters.
236 pub output_quant: QuantParam,
237 /// Fused activation function applied to the output.
238 pub activation: FusedActivation,
239}
240
241/// Parameters for a pooling operation (average or max).
242///
243/// Mirrors `TfLitePoolParams` from TFLite Micro.
244pub struct PoolParams<'a> {
245 /// Input tensor data, quantized as `int8`.
246 pub input: &'a [i8],
247 /// Input tensor shape `[N, H, W, C]`.
248 pub input_shape: [usize; 4],
249 /// Input quantization parameters.
250 pub input_quant: QuantParam,
251 /// Output tensor buffer, written by the backend.
252 pub output: &'a mut [i8],
253 /// Output tensor shape `[N, H_out, W_out, C]`.
254 pub output_shape: [usize; 4],
255 /// Output quantization parameters.
256 pub output_quant: QuantParam,
257 /// Horizontal stride.
258 pub stride_w: i32,
259 /// Vertical stride.
260 pub stride_h: i32,
261 /// Pooling filter width.
262 pub filter_w: i32,
263 /// Pooling filter height.
264 pub filter_h: i32,
265 /// Padding mode.
266 pub padding: Padding,
267 /// Fused activation function applied to the output.
268 pub activation: FusedActivation,
269}
270
271/// Parameters for the softmax operation.
272///
273/// Mirrors `TfLiteSoftmaxParams` from TFLite Micro.
274pub struct SoftmaxParams<'a> {
275 /// Input tensor data, quantized as `int8`.
276 pub input: &'a [i8],
277 /// Input shape `[batch, num_classes]`.
278 pub input_shape: [usize; 2],
279 /// Input quantization parameters.
280 pub input_quant: QuantParam,
281 /// Output tensor buffer, written by the backend.
282 pub output: &'a mut [i8],
283 /// Output quantization parameters.
284 pub output_quant: QuantParam,
285 /// Softmax beta parameter (typically `1.0`).
286 ///
287 /// Mirrors `TfLiteSoftmaxParams::beta`.
288 pub beta: f32,
289 /// Scratch buffer for intermediate computation.
290 pub scratch: &'a mut [u8],
291}
292
293/// Parameters for element-wise addition.
294///
295/// Mirrors `TfLiteAddParams` from TFLite Micro.
296pub struct ElementwiseAddParams<'a> {
297 /// First input tensor data, quantized as `int8`.
298 pub input1: &'a [i8],
299 /// First input quantization parameters.
300 pub input1_quant: QuantParam,
301 /// Second input tensor data, quantized as `int8`.
302 pub input2: &'a [i8],
303 /// Second input quantization parameters.
304 pub input2_quant: QuantParam,
305 /// Output tensor buffer, written by the backend.
306 pub output: &'a mut [i8],
307 /// Output quantization parameters.
308 pub output_quant: QuantParam,
309 /// Fused activation function applied to the output.
310 pub activation: FusedActivation,
311}
312
// ----------------------------------------------------------------------------
// KernelBackend - the central trait
// ----------------------------------------------------------------------------
316
317/// The core abstraction for ember-rs: a hardware-specific INT8 inference backend.
318///
319/// # Design
320///
321/// This trait covers the **invoke phase** only. The **prepare phase**
322/// (scratch buffer sizing, shape inference) is performed at compile time by
323/// `ember-infer-macros` via the `conv2d_scratch_size` / `softmax_scratch_size`
324/// associated functions, which have default implementations returning `0`.
325///
326/// Implementations map directly onto TFLite Micro kernel `invoke` functions,
327/// which means porting an existing TFLite Micro optimized kernel (e.g., CMSIS-NN,
328/// esp-nn) to ember-rs requires minimal glue code.
329///
330/// # Implementing a backend
331///
332/// ```rust,ignore
333/// use ember_infer_core::{KernelBackend, Conv2dParams, Status};
334///
335/// pub struct MyBackend;
336///
337/// impl KernelBackend for MyBackend {
338/// fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status {
339/// // call your hardware-accelerated kernel here
340/// todo!()
341/// }
342/// // ... implement remaining required methods
343/// }
344/// ```
345///
346/// # Scratch buffers
347///
348/// Backends that require scratch memory (e.g., esp-nn, CMSIS-NN) must override
349/// the `*_scratch_size` associated functions. The `ember-infer-macros` proc macro calls
350/// these at compile time to allocate correctly-sized scratch arrays in the
351/// generated inference function.
352pub trait KernelBackend {
353 /// Execute a 2D convolution.
354 ///
355 /// Corresponds to the `invoke` function of the `CONV_2D` kernel in TFLite Micro.
356 fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status;
357
358 /// Execute a depthwise 2D convolution.
359 ///
360 /// Corresponds to the `invoke` function of the `DEPTHWISE_CONV_2D` kernel
361 /// in TFLite Micro.
362 fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status;
363
364 /// Execute a fully-connected layer.
365 ///
366 /// Corresponds to the `invoke` function of the `FULLY_CONNECTED` kernel
367 /// in TFLite Micro.
368 fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status;
369
370 /// Execute average pooling.
371 ///
372 /// Corresponds to the `invoke` function of the `AVERAGE_POOL_2D` kernel
373 /// in TFLite Micro.
374 fn avg_pool(&mut self, params: PoolParams<'_>) -> Status;
375
376 /// Execute max pooling.
377 ///
378 /// Corresponds to the `invoke` function of the `MAX_POOL_2D` kernel
379 /// in TFLite Micro.
380 fn max_pool(&mut self, params: PoolParams<'_>) -> Status;
381
382 /// Execute softmax.
383 ///
384 /// Corresponds to the `invoke` function of the `SOFTMAX` kernel
385 /// in TFLite Micro.
386 fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status;
387
388 /// Execute element-wise addition.
389 ///
390 /// Corresponds to the `invoke` function of the `ADD` kernel in TFLite Micro.
391 fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status;
392
393 /// Returns the scratch buffer size in bytes required by [`Self::conv2d`].
394 ///
395 /// Called by `ember-infer-macros` at **compile time** to allocate scratch arrays
396 /// in the generated inference function. Corresponds to
397 /// `esp_nn_get_conv_scratch_size` / CMSIS-NN equivalents.
398 ///
399 /// The default implementation returns `0` (no scratch required), which is
400 /// correct for pure-Rust reference backends.
401 fn conv2d_scratch_size(
402 input_shape: [usize; 4],
403 weights_shape: [usize; 4],
404 output_shape: [usize; 4],
405 ) -> usize
406 where
407 Self: Sized,
408 {
409 let _ = (input_shape, weights_shape, output_shape);
410 0
411 }
412
413 /// Returns the scratch buffer size in bytes required by [`Self::depthwise_conv2d`].
414 fn depthwise_conv2d_scratch_size(
415 input_shape: [usize; 4],
416 weights_shape: [usize; 4],
417 output_shape: [usize; 4],
418 ) -> usize
419 where
420 Self: Sized,
421 {
422 let _ = (input_shape, weights_shape, output_shape);
423 0
424 }
425
426 /// Returns the scratch buffer size in bytes required by [`Self::softmax`].
427 fn softmax_scratch_size(num_classes: usize) -> usize
428 where
429 Self: Sized,
430 {
431 let _ = num_classes;
432 0
433 }
434}