Skip to main content

ember_infer_ref/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use alloc::vec;
6use alloc::vec::Vec;
7use ember_infer_core::{
8    Conv2dParams, DepthwiseConv2dParams, ElementwiseAddParams, FullyConnectedParams,
9    FusedActivation, KernelBackend, KernelError, Padding, PerChannelQuantParam, PoolParams,
10    QuantParam, SoftmaxParams, Status,
11};
12
/// Pure Rust reference implementation of [`KernelBackend`].
///
/// Used for CI testing and platforms where no hardware-accelerated backend is
/// available.
///
/// The type is a field-less unit struct, so it can be created with either
/// `RefBackend` or `RefBackend::default()` and copied freely.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct RefBackend;
18
19impl KernelBackend for RefBackend {
20    fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status {
21        validate_len(params.input, product(&params.input_shape))?;
22        validate_len(params.weights, product(&params.weights_shape))?;
23        validate_len(params.output, product(&params.output_shape))?;
24
25        let [batches, input_h, input_w, input_c] = params.input_shape;
26        let [output_c, filter_h, filter_w, filter_input_c] = params.weights_shape;
27        let [output_batches, output_h, output_w, output_shape_c] = params.output_shape;
28
29        if batches != output_batches || input_c != filter_input_c || output_c != output_shape_c {
30            return Err(KernelError::InvalidShape);
31        }
32        validate_bias(params.bias, output_c)?;
33
34        let stride_h = positive_i32_to_usize(params.stride_h)?;
35        let stride_w = positive_i32_to_usize(params.stride_w)?;
36        let dilation_h = positive_i32_to_usize(params.dilation_h_factor)?;
37        let dilation_w = positive_i32_to_usize(params.dilation_w_factor)?;
38        let effective_filter_h = effective_filter_size(filter_h, dilation_h);
39        let effective_filter_w = effective_filter_size(filter_w, dilation_w);
40        let pad_h = compute_padding(input_h, effective_filter_h, stride_h, params.padding);
41        let pad_w = compute_padding(input_w, effective_filter_w, stride_w, params.padding);
42        for batch in 0..batches {
43            for out_y in 0..output_h {
44                for out_x in 0..output_w {
45                    for out_channel in 0..output_c {
46                        let (multiplier, shift) = output_channel_multiplier_shift(
47                            params.input_quant,
48                            params.weights_quant,
49                            params.weights_per_channel_quant,
50                            params.output_quant,
51                            out_channel,
52                        );
53                        let mut acc = params
54                            .bias
55                            .map(|bias| bias[out_channel])
56                            .unwrap_or_default();
57
58                        for filter_y in 0..filter_h {
59                            let in_y = out_y * stride_h + filter_y * dilation_h;
60                            if in_y < pad_h || in_y >= input_h + pad_h {
61                                continue;
62                            }
63                            let in_y = in_y - pad_h;
64
65                            for filter_x in 0..filter_w {
66                                let in_x = out_x * stride_w + filter_x * dilation_w;
67                                if in_x < pad_w || in_x >= input_w + pad_w {
68                                    continue;
69                                }
70                                let in_x = in_x - pad_w;
71
72                                for in_channel in 0..input_c {
73                                    let input = params.input[nhwc_index(
74                                        batch, in_y, in_x, in_channel, input_h, input_w, input_c,
75                                    )] as i32
76                                        - params.input_quant.zero_point;
77                                    let weight = params.weights[conv_weight_index(
78                                        out_channel,
79                                        filter_y,
80                                        filter_x,
81                                        in_channel,
82                                        filter_h,
83                                        filter_w,
84                                        input_c,
85                                    )] as i32
86                                        - params.weights_quant.zero_point;
87                                    acc = acc.saturating_add(input.saturating_mul(weight));
88                                }
89                            }
90                        }
91
92                        let scaled = requantize(acc, multiplier, shift, params.output_quant);
93                        params.output[nhwc_index(
94                            batch,
95                            out_y,
96                            out_x,
97                            out_channel,
98                            output_h,
99                            output_w,
100                            output_c,
101                        )] = apply_activation(scaled, params.activation, params.output_quant);
102                    }
103                }
104            }
105        }
106
107        Ok(())
108    }
109
110    fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status {
111        validate_len(params.input, product(&params.input_shape))?;
112        validate_len(params.weights, product(&params.weights_shape))?;
113        validate_len(params.output, product(&params.output_shape))?;
114
115        let [batches, input_h, input_w, input_c] = params.input_shape;
116        let depth_multiplier = positive_i32_to_usize(params.depth_multiplier)?;
117        let depthwise_dims =
118            depthwise_filter_dims(params.weights_shape, input_c, depth_multiplier)?;
119        let [output_batches, output_h, output_w, output_c] = params.output_shape;
120
121        if batches != output_batches
122            || input_c != depthwise_dims.input_channels
123            || depth_multiplier != depthwise_dims.depth_multiplier
124            || output_c != input_c * depth_multiplier
125        {
126            return Err(KernelError::InvalidShape);
127        }
128        validate_bias(params.bias, output_c)?;
129
130        let stride_h = positive_i32_to_usize(params.stride_h)?;
131        let stride_w = positive_i32_to_usize(params.stride_w)?;
132        let dilation_h = positive_i32_to_usize(params.dilation_h_factor)?;
133        let dilation_w = positive_i32_to_usize(params.dilation_w_factor)?;
134        let effective_filter_h = effective_filter_size(depthwise_dims.filter_h, dilation_h);
135        let effective_filter_w = effective_filter_size(depthwise_dims.filter_w, dilation_w);
136        let pad_h = compute_padding(input_h, effective_filter_h, stride_h, params.padding);
137        let pad_w = compute_padding(input_w, effective_filter_w, stride_w, params.padding);
138        for batch in 0..batches {
139            for out_y in 0..output_h {
140                for out_x in 0..output_w {
141                    for in_channel in 0..input_c {
142                        for channel_multiplier in 0..depth_multiplier {
143                            let out_channel = in_channel * depth_multiplier + channel_multiplier;
144                            let (multiplier, shift) = output_channel_multiplier_shift(
145                                params.input_quant,
146                                params.weights_quant,
147                                params.weights_per_channel_quant,
148                                params.output_quant,
149                                out_channel,
150                            );
151                            let mut acc = params
152                                .bias
153                                .map(|bias| bias[out_channel])
154                                .unwrap_or_default();
155
156                            for filter_y in 0..depthwise_dims.filter_h {
157                                let in_y = out_y * stride_h + filter_y * dilation_h;
158                                if in_y < pad_h || in_y >= input_h + pad_h {
159                                    continue;
160                                }
161                                let in_y = in_y - pad_h;
162
163                                for filter_x in 0..depthwise_dims.filter_w {
164                                    let in_x = out_x * stride_w + filter_x * dilation_w;
165                                    if in_x < pad_w || in_x >= input_w + pad_w {
166                                        continue;
167                                    }
168                                    let in_x = in_x - pad_w;
169
170                                    let input = params.input[nhwc_index(
171                                        batch, in_y, in_x, in_channel, input_h, input_w, input_c,
172                                    )] as i32
173                                        - params.input_quant.zero_point;
174                                    let weight = params.weights[depthwise_weight_index(
175                                        filter_y,
176                                        filter_x,
177                                        in_channel,
178                                        channel_multiplier,
179                                        depthwise_dims,
180                                    )] as i32
181                                        - params.weights_quant.zero_point;
182                                    acc = acc.saturating_add(input.saturating_mul(weight));
183                                }
184                            }
185
186                            let scaled = requantize(acc, multiplier, shift, params.output_quant);
187                            params.output[nhwc_index(
188                                batch,
189                                out_y,
190                                out_x,
191                                out_channel,
192                                output_h,
193                                output_w,
194                                output_c,
195                            )] = apply_activation(scaled, params.activation, params.output_quant);
196                        }
197                    }
198                }
199            }
200        }
201
202        Ok(())
203    }
204
205    fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status {
206        validate_len(params.output, params.output_depth)?;
207        let [output_depth, input_depth] = params.weights_shape;
208        if params.output_depth != output_depth
209            || params.weights.len() != output_depth * input_depth
210            || params.input.len() != input_depth
211        {
212            return Err(KernelError::InvalidShape);
213        }
214        validate_bias(params.bias, output_depth)?;
215
216        for out_channel in 0..output_depth {
217            let (multiplier, shift) = output_channel_multiplier_shift(
218                params.input_quant,
219                params.weights_quant,
220                params.weights_per_channel_quant,
221                params.output_quant,
222                out_channel,
223            );
224            let mut acc = params
225                .bias
226                .map(|bias| bias[out_channel])
227                .unwrap_or_default();
228            for in_channel in 0..input_depth {
229                let input = params.input[in_channel] as i32 - params.input_quant.zero_point;
230                let weight = params.weights[out_channel * input_depth + in_channel] as i32
231                    - params.weights_quant.zero_point;
232                acc = acc.saturating_add(input.saturating_mul(weight));
233            }
234
235            let scaled = requantize(acc, multiplier, shift, params.output_quant);
236            params.output[out_channel] =
237                apply_activation(scaled, params.activation, params.output_quant);
238        }
239
240        Ok(())
241    }
242
243    fn avg_pool(&mut self, params: PoolParams<'_>) -> Status {
244        pool(params, PoolKind::Average)
245    }
246
247    fn max_pool(&mut self, params: PoolParams<'_>) -> Status {
248        pool(params, PoolKind::Max)
249    }
250
251    fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status {
252        let [batches, classes] = params.input_shape;
253        if params.input.len() != batches * classes || params.output.len() != batches * classes {
254            return Err(KernelError::InvalidShape);
255        }
256
257        let mut exps: Vec<f32> = vec![0.0; classes];
258        for batch in 0..batches {
259            let offset = batch * classes;
260            let mut max_input = i8::MIN;
261            for class in 0..classes {
262                max_input = max_input.max(params.input[offset + class]);
263            }
264
265            let mut sum = 0.0f32;
266            for (class, exp) in exps.iter_mut().enumerate() {
267                let centered = (params.input[offset + class] as i32 - max_input as i32) as f32;
268                let real = centered * params.input_quant.scale * params.beta;
269                *exp = libm::expf(real);
270                sum += *exp;
271            }
272
273            if sum == 0.0 {
274                return Err(KernelError::InternalError);
275            }
276
277            for (class, exp) in exps.iter().enumerate() {
278                let probability = *exp / sum;
279                let quantized = round_f32_to_i32(probability / params.output_quant.scale)
280                    + params.output_quant.zero_point;
281                params.output[offset + class] = clamp_i8(quantized);
282            }
283        }
284
285        Ok(())
286    }
287
288    fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status {
289        if params.input1.len() != params.input2.len() || params.output.len() != params.input1.len()
290        {
291            return Err(KernelError::InvalidShape);
292        }
293
294        for index in 0..params.output.len() {
295            let lhs = (params.input1[index] as i32 - params.input1_quant.zero_point) as f32
296                * params.input1_quant.scale;
297            let rhs = (params.input2[index] as i32 - params.input2_quant.zero_point) as f32
298                * params.input2_quant.scale;
299            let quantized = round_f32_to_i32((lhs + rhs) / params.output_quant.scale)
300                + params.output_quant.zero_point;
301            params.output[index] =
302                apply_activation(quantized, params.activation, params.output_quant);
303        }
304
305        Ok(())
306    }
307}
308
/// Selects which reduction `pool` applies to each window.
#[derive(Copy, Clone)]
enum PoolKind {
    /// Rounded mean of the in-bounds window elements.
    Average,
    /// Largest in-bounds window element.
    Max,
}
314
315fn pool(params: PoolParams<'_>, kind: PoolKind) -> Status {
316    validate_len(params.input, product(&params.input_shape))?;
317    validate_len(params.output, product(&params.output_shape))?;
318
319    let [batches, input_h, input_w, channels] = params.input_shape;
320    let [output_batches, output_h, output_w, output_channels] = params.output_shape;
321    if batches != output_batches || channels != output_channels {
322        return Err(KernelError::InvalidShape);
323    }
324
325    let stride_h = positive_i32_to_usize(params.stride_h)?;
326    let stride_w = positive_i32_to_usize(params.stride_w)?;
327    let filter_h = positive_i32_to_usize(params.filter_h)?;
328    let filter_w = positive_i32_to_usize(params.filter_w)?;
329    let pad_h = compute_padding(input_h, filter_h, stride_h, params.padding);
330    let pad_w = compute_padding(input_w, filter_w, stride_w, params.padding);
331    let (multiplier, shift) =
332        quantize_multiplier((params.input_quant.scale / params.output_quant.scale) as f64);
333
334    for batch in 0..batches {
335        for out_y in 0..output_h {
336            for out_x in 0..output_w {
337                for channel in 0..channels {
338                    let mut acc = 0i32;
339                    let mut count = 0i32;
340                    let mut max_value = i8::MIN;
341
342                    for filter_y in 0..filter_h {
343                        let in_y = out_y * stride_h + filter_y;
344                        if in_y < pad_h || in_y >= input_h + pad_h {
345                            continue;
346                        }
347                        let in_y = in_y - pad_h;
348
349                        for filter_x in 0..filter_w {
350                            let in_x = out_x * stride_w + filter_x;
351                            if in_x < pad_w || in_x >= input_w + pad_w {
352                                continue;
353                            }
354                            let in_x = in_x - pad_w;
355                            let input = params.input[nhwc_index(
356                                batch, in_y, in_x, channel, input_h, input_w, channels,
357                            )];
358                            acc += input as i32 - params.input_quant.zero_point;
359                            count += 1;
360                            max_value = max_value.max(input);
361                        }
362                    }
363
364                    if count == 0 {
365                        return Err(KernelError::InvalidShape);
366                    }
367
368                    let quantized = match kind {
369                        PoolKind::Average => {
370                            let average = round_divide(acc, count);
371                            requantize(average, multiplier, shift, params.output_quant)
372                        }
373                        PoolKind::Max => {
374                            let centered = max_value as i32 - params.input_quant.zero_point;
375                            requantize(centered, multiplier, shift, params.output_quant)
376                        }
377                    };
378                    params.output
379                        [nhwc_index(batch, out_y, out_x, channel, output_h, output_w, channels)] =
380                        apply_activation(quantized, params.activation, params.output_quant);
381                }
382            }
383        }
384    }
385
386    Ok(())
387}
388
389fn validate_len<T>(slice: &[T], expected: usize) -> Status {
390    if slice.len() == expected {
391        Ok(())
392    } else {
393        Err(KernelError::InvalidShape)
394    }
395}
396
397fn validate_bias(bias: Option<&[i32]>, expected: usize) -> Status {
398    match bias {
399        Some(bias) => validate_len(bias, expected),
400        None => Ok(()),
401    }
402}
403
/// Number of elements described by `shape`, saturating at `usize::MAX`.
///
/// A plain product panics on overflow in debug builds and wraps in release
/// builds, where a wrapped value could match a real buffer length by accident.
/// A saturated result can never equal a slice length, so length validation
/// rejects overflowing shapes. A zero dimension still yields zero.
fn product<const N: usize>(shape: &[usize; N]) -> usize {
    shape
        .iter()
        .fold(1usize, |elements, &dim| elements.saturating_mul(dim))
}
407
408fn positive_i32_to_usize(value: i32) -> Result<usize, KernelError> {
409    if value > 0 {
410        Ok(value as usize)
411    } else {
412        Err(KernelError::InvalidShape)
413    }
414}
415
/// Extent covered by a filter of `filter_size` taps spaced `dilation` apart.
///
/// A non-empty filter spans `(filter_size - 1) * dilation + 1` positions. An
/// empty filter spans nothing and yields zero; the subtraction is not
/// evaluated for it, so a zero-sized weight dimension cannot underflow.
fn effective_filter_size(filter_size: usize, dilation: usize) -> usize {
    match filter_size.checked_sub(1) {
        Some(gaps) => gaps.saturating_mul(dilation).saturating_add(1),
        None => 0,
    }
}
419
/// Flat offset of element `(batch, y, x, channel)` in a row-major tensor whose
/// three innermost extents are `height`, `width` and `channels`.
fn nhwc_index(
    batch: usize,
    y: usize,
    x: usize,
    channel: usize,
    height: usize,
    width: usize,
    channels: usize,
) -> usize {
    let row = batch * height + y;
    let pixel = row * width + x;
    pixel * channels + channel
}
431
/// Flat offset of a convolution weight stored row-major as
/// `[output_channel, filter_y, filter_x, input_channel]`.
fn conv_weight_index(
    output_channel: usize,
    filter_y: usize,
    filter_x: usize,
    input_channel: usize,
    filter_h: usize,
    filter_w: usize,
    input_channels: usize,
) -> usize {
    let filter_row = output_channel * filter_h + filter_y;
    let tap = filter_row * filter_w + filter_x;
    tap * input_channels + input_channel
}
443
444fn depthwise_weight_index(
445    filter_y: usize,
446    filter_x: usize,
447    input_channel: usize,
448    channel_multiplier: usize,
449    dims: DepthwiseDims,
450) -> usize {
451    let output_channel = input_channel * dims.depth_multiplier + channel_multiplier;
452    if dims.tflite_layout {
453        (filter_y * dims.filter_w + filter_x) * (dims.input_channels * dims.depth_multiplier)
454            + output_channel
455    } else {
456        ((filter_y * dims.filter_w + filter_x) * dims.input_channels + input_channel)
457            * dims.depth_multiplier
458            + channel_multiplier
459    }
460}
461
/// Depthwise filter geometry as resolved by `depthwise_filter_dims`.
#[derive(Copy, Clone)]
struct DepthwiseDims {
    /// Filter height in taps.
    filter_h: usize,
    /// Filter width in taps.
    filter_w: usize,
    /// Number of input channels the filter covers.
    input_channels: usize,
    /// Output channels produced per input channel.
    depth_multiplier: usize,
    /// `true` when the weights shape had a leading dimension of 1.
    tflite_layout: bool,
}
470
471fn depthwise_filter_dims(
472    weights_shape: [usize; 4],
473    input_channels: usize,
474    depth_multiplier: usize,
475) -> Result<DepthwiseDims, KernelError> {
476    if weights_shape[0] == 1 {
477        if input_channels == 0 {
478            return Err(KernelError::InvalidShape);
479        }
480        Ok(DepthwiseDims {
481            filter_h: weights_shape[1],
482            filter_w: weights_shape[2],
483            input_channels,
484            depth_multiplier: weights_shape[3] / input_channels,
485            tflite_layout: true,
486        })
487    } else {
488        Ok(DepthwiseDims {
489            filter_h: weights_shape[0],
490            filter_w: weights_shape[1],
491            input_channels: weights_shape[2],
492            depth_multiplier: weights_shape[3],
493            tflite_layout: false,
494        })
495    }
496    .and_then(|dims| {
497        if dims.input_channels == input_channels && dims.depth_multiplier == depth_multiplier {
498            Ok(dims)
499        } else {
500            Err(KernelError::InvalidShape)
501        }
502    })
503}
504
505fn multiply_by_quantized_multiplier(x: i32, multiplier: i32, shift: i32) -> i32 {
506    let total_shift = 31 - shift;
507    if total_shift <= 0 {
508        return saturating_left_shift(x.saturating_mul(multiplier), (-total_shift) as u32);
509    }
510    let round = 1i64 << (total_shift - 1);
511    (((x as i64 * multiplier as i64) + round) >> total_shift) as i32
512}
513
/// Shifts `value` left by `shift` bits, saturating at the `i32` limits.
///
/// Zero stays zero for any shift. A non-zero value shifted by 31 bits or more
/// saturates towards its own sign.
fn saturating_left_shift(value: i32, shift: u32) -> i32 {
    match (value, shift) {
        (0, _) => 0,
        (v, s) if s >= 31 => {
            if v >= 0 {
                i32::MAX
            } else {
                i32::MIN
            }
        }
        (v, s) => {
            let widened = i64::from(v) << s;
            widened.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        }
    }
}
529
530fn quantize_multiplier(scale: f64) -> (i32, i32) {
531    if scale <= 0.0 {
532        return (0, 0);
533    }
534
535    let mut significand = scale;
536    let mut shift = 0i32;
537
538    while significand < 0.5 {
539        significand *= 2.0;
540        shift -= 1;
541    }
542    while significand >= 1.0 {
543        significand /= 2.0;
544        shift += 1;
545    }
546
547    let mut q = libm::round(significand * (1i64 << 31) as f64) as i64;
548    if q == 1i64 << 31 {
549        q /= 2;
550        shift += 1;
551    }
552
553    (q as i32, shift)
554}
555
556fn output_channel_multiplier_shift(
557    input_quant: QuantParam,
558    weights_quant: QuantParam,
559    weights_per_channel_quant: Option<PerChannelQuantParam<'_>>,
560    output_quant: QuantParam,
561    output_channel: usize,
562) -> (i32, i32) {
563    let weight_scale = weights_per_channel_quant
564        .and_then(|per_channel| per_channel.scales.get(output_channel).copied())
565        .unwrap_or(weights_quant.scale);
566    quantize_multiplier((input_quant.scale * weight_scale / output_quant.scale) as f64)
567}
568
569fn requantize(acc: i32, multiplier: i32, shift: i32, output_quant: QuantParam) -> i32 {
570    multiply_by_quantized_multiplier(acc, multiplier, shift) + output_quant.zero_point
571}
572
573fn apply_activation(val: i32, activation: FusedActivation, output_quant: QuantParam) -> i8 {
574    let min = match activation {
575        FusedActivation::None | FusedActivation::Sigmoid | FusedActivation::SignBit => {
576            i8::MIN as i32
577        }
578        FusedActivation::Relu | FusedActivation::Relu6 => {
579            (i8::MIN as i32).max(output_quant.zero_point)
580        }
581        FusedActivation::ReluN1To1 | FusedActivation::Tanh => (i8::MIN as i32)
582            .max(output_quant.zero_point + round_f32_to_i32(-1.0 / output_quant.scale)),
583    };
584    let max = match activation {
585        FusedActivation::Relu6 => (i8::MAX as i32)
586            .min(output_quant.zero_point + round_f32_to_i32(6.0 / output_quant.scale)),
587        FusedActivation::ReluN1To1 | FusedActivation::Tanh | FusedActivation::Sigmoid => (i8::MAX
588            as i32)
589            .min(output_quant.zero_point + round_f32_to_i32(1.0 / output_quant.scale)),
590        FusedActivation::None | FusedActivation::Relu | FusedActivation::SignBit => i8::MAX as i32,
591    };
592
593    clamp_i8(val.clamp(min, max))
594}
595
/// Narrows `value` to `i8`, saturating at `i8::MIN` and `i8::MAX`.
fn clamp_i8(value: i32) -> i8 {
    match i8::try_from(value) {
        Ok(narrow) => narrow,
        Err(_) if value < 0 => i8::MIN,
        Err(_) => i8::MAX,
    }
}
599
600fn compute_padding(
601    input_size: usize,
602    filter_size: usize,
603    stride: usize,
604    padding: Padding,
605) -> usize {
606    match padding {
607        Padding::Valid => 0,
608        Padding::Same => {
609            let out_size = input_size.div_ceil(stride);
610            let pad = ((out_size - 1) * stride + filter_size).saturating_sub(input_size);
611            pad / 2
612        }
613    }
614}
615
616fn round_f32_to_i32(value: f32) -> i32 {
617    libm::roundf(value) as i32
618}
619
/// Integer division that biases the numerator by half the denominator, in
/// the direction of the numerator's sign, before the truncating divide.
fn round_divide(numerator: i32, denominator: i32) -> i32 {
    let half = denominator / 2;
    let biased = if numerator < 0 {
        numerator - half
    } else {
        numerator + half
    };
    biased / denominator
}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity quantization: scale 1 and zero point 0.
    const UNIT_QUANT: QuantParam = QuantParam {
        zero_point: 0,
        scale: 1.0,
    };

    /// 2x2 weight matrix with bias under identity quantization.
    #[test]
    fn fully_connected_identity_scale() {
        let activations = [2, -3];
        let matrix = [4, 5, -1, 6];
        let mut result = [0; 2];

        let status = RefBackend.fully_connected(FullyConnectedParams {
            input: &activations,
            input_quant: UNIT_QUANT,
            weights: &matrix,
            weights_shape: [2, 2],
            weights_quant: UNIT_QUANT,
            weights_per_channel_quant: None,
            bias: Some(&[1, -2]),
            output: &mut result,
            output_depth: 2,
            output_quant: UNIT_QUANT,
            activation: FusedActivation::None,
        });

        assert!(status.is_ok());
        assert_eq!(result, [-6, -22]);
    }

    /// Element-wise sum under identity quantization.
    #[test]
    fn add_identity_scale() {
        let lhs = [1, -2, 3];
        let rhs = [4, 5, -6];
        let mut result = [0; 3];

        let status = RefBackend.add(ElementwiseAddParams {
            input1: &lhs,
            input1_quant: UNIT_QUANT,
            input2: &rhs,
            input2_quant: UNIT_QUANT,
            output: &mut result,
            output_quant: UNIT_QUANT,
            activation: FusedActivation::None,
        });

        assert!(status.is_ok());
        assert_eq!(result, [5, 3, -3]);
    }

    /// A single 2x2 window averaged without padding.
    #[test]
    fn avg_pool_valid() {
        let pixels = [1, 3, 5, 7];
        let mut result = [0; 1];

        let status = RefBackend.avg_pool(PoolParams {
            input: &pixels,
            input_shape: [1, 2, 2, 1],
            input_quant: UNIT_QUANT,
            output: &mut result,
            output_shape: [1, 1, 1, 1],
            output_quant: UNIT_QUANT,
            stride_w: 1,
            stride_h: 1,
            filter_w: 2,
            filter_h: 2,
            padding: Padding::Valid,
            activation: FusedActivation::None,
        });

        assert!(status.is_ok());
        assert_eq!(result, [4]);
    }

    /// One 2x2 filter applied to a 2x2 single-channel image.
    #[test]
    fn conv2d_single_filter_valid() {
        let pixels = [1, 2, 3, 4];
        let kernel = [1, 0, 0, 1];
        let mut result = [0; 1];

        let status = RefBackend.conv2d(Conv2dParams {
            input: &pixels,
            input_shape: [1, 2, 2, 1],
            input_quant: UNIT_QUANT,
            weights: &kernel,
            weights_shape: [1, 2, 2, 1],
            weights_quant: UNIT_QUANT,
            weights_per_channel_quant: None,
            bias: None,
            output: &mut result,
            output_shape: [1, 1, 1, 1],
            output_quant: UNIT_QUANT,
            stride_w: 1,
            stride_h: 1,
            dilation_w_factor: 1,
            dilation_h_factor: 1,
            padding: Padding::Valid,
            activation: FusedActivation::None,
            scratch: &mut [],
        });

        assert!(status.is_ok());
        assert_eq!(result, [5]);
    }

    /// Depthwise filter whose weights shape has a leading dimension of 1.
    #[test]
    fn depthwise_accepts_tflite_filter_layout() {
        let pixels = [1, 2, 3, 4];
        let kernel = [1, 0, 0, 1];
        let mut result = [0; 1];

        let status = RefBackend.depthwise_conv2d(DepthwiseConv2dParams {
            input: &pixels,
            input_shape: [1, 2, 2, 1],
            input_quant: UNIT_QUANT,
            weights: &kernel,
            weights_shape: [1, 2, 2, 1],
            weights_quant: UNIT_QUANT,
            weights_per_channel_quant: None,
            bias: None,
            output: &mut result,
            output_shape: [1, 1, 1, 1],
            output_quant: UNIT_QUANT,
            stride_w: 1,
            stride_h: 1,
            dilation_w_factor: 1,
            dilation_h_factor: 1,
            depth_multiplier: 1,
            padding: Padding::Valid,
            activation: FusedActivation::None,
            scratch: &mut [],
        });

        assert!(status.is_ok());
        assert_eq!(result, [5]);
    }

    /// Two equal logits: each probability is 0.5, which quantizes to
    /// 0.5 * 256 - 128 = 0.
    #[test]
    fn softmax_outputs_probability_distribution() {
        let logits = [0, 0];
        let mut result = [0; 2];

        let status = RefBackend.softmax(SoftmaxParams {
            input: &logits,
            input_shape: [1, 2],
            input_quant: UNIT_QUANT,
            output: &mut result,
            output_quant: QuantParam {
                scale: 1.0 / 256.0,
                zero_point: -128,
            },
            beta: 1.0,
            scratch: &mut [],
        });

        assert!(status.is_ok());
        assert_eq!(result, [0, 0]);
    }
}