1#![no_std]
2
3extern crate alloc;
4
5use alloc::vec;
6use alloc::vec::Vec;
7use ember_infer_core::{
8 Conv2dParams, DepthwiseConv2dParams, ElementwiseAddParams, FullyConnectedParams,
9 FusedActivation, KernelBackend, KernelError, Padding, PerChannelQuantParam, PoolParams,
10 QuantParam, SoftmaxParams, Status,
11};
12
/// Stateless backend that implements [`KernelBackend`] with plain scalar loops.
///
/// The type has no fields, so values are free to copy and a default value is
/// always available.
#[derive(Debug, Clone, Copy, Default)]
pub struct RefBackend;
18
19impl KernelBackend for RefBackend {
20 fn conv2d(&mut self, params: Conv2dParams<'_>) -> Status {
21 validate_len(params.input, product(¶ms.input_shape))?;
22 validate_len(params.weights, product(¶ms.weights_shape))?;
23 validate_len(params.output, product(¶ms.output_shape))?;
24
25 let [batches, input_h, input_w, input_c] = params.input_shape;
26 let [output_c, filter_h, filter_w, filter_input_c] = params.weights_shape;
27 let [output_batches, output_h, output_w, output_shape_c] = params.output_shape;
28
29 if batches != output_batches || input_c != filter_input_c || output_c != output_shape_c {
30 return Err(KernelError::InvalidShape);
31 }
32 validate_bias(params.bias, output_c)?;
33
34 let stride_h = positive_i32_to_usize(params.stride_h)?;
35 let stride_w = positive_i32_to_usize(params.stride_w)?;
36 let dilation_h = positive_i32_to_usize(params.dilation_h_factor)?;
37 let dilation_w = positive_i32_to_usize(params.dilation_w_factor)?;
38 let effective_filter_h = effective_filter_size(filter_h, dilation_h);
39 let effective_filter_w = effective_filter_size(filter_w, dilation_w);
40 let pad_h = compute_padding(input_h, effective_filter_h, stride_h, params.padding);
41 let pad_w = compute_padding(input_w, effective_filter_w, stride_w, params.padding);
42 for batch in 0..batches {
43 for out_y in 0..output_h {
44 for out_x in 0..output_w {
45 for out_channel in 0..output_c {
46 let (multiplier, shift) = output_channel_multiplier_shift(
47 params.input_quant,
48 params.weights_quant,
49 params.weights_per_channel_quant,
50 params.output_quant,
51 out_channel,
52 );
53 let mut acc = params
54 .bias
55 .map(|bias| bias[out_channel])
56 .unwrap_or_default();
57
58 for filter_y in 0..filter_h {
59 let in_y = out_y * stride_h + filter_y * dilation_h;
60 if in_y < pad_h || in_y >= input_h + pad_h {
61 continue;
62 }
63 let in_y = in_y - pad_h;
64
65 for filter_x in 0..filter_w {
66 let in_x = out_x * stride_w + filter_x * dilation_w;
67 if in_x < pad_w || in_x >= input_w + pad_w {
68 continue;
69 }
70 let in_x = in_x - pad_w;
71
72 for in_channel in 0..input_c {
73 let input = params.input[nhwc_index(
74 batch, in_y, in_x, in_channel, input_h, input_w, input_c,
75 )] as i32
76 - params.input_quant.zero_point;
77 let weight = params.weights[conv_weight_index(
78 out_channel,
79 filter_y,
80 filter_x,
81 in_channel,
82 filter_h,
83 filter_w,
84 input_c,
85 )] as i32
86 - params.weights_quant.zero_point;
87 acc = acc.saturating_add(input.saturating_mul(weight));
88 }
89 }
90 }
91
92 let scaled = requantize(acc, multiplier, shift, params.output_quant);
93 params.output[nhwc_index(
94 batch,
95 out_y,
96 out_x,
97 out_channel,
98 output_h,
99 output_w,
100 output_c,
101 )] = apply_activation(scaled, params.activation, params.output_quant);
102 }
103 }
104 }
105 }
106
107 Ok(())
108 }
109
110 fn depthwise_conv2d(&mut self, params: DepthwiseConv2dParams<'_>) -> Status {
111 validate_len(params.input, product(¶ms.input_shape))?;
112 validate_len(params.weights, product(¶ms.weights_shape))?;
113 validate_len(params.output, product(¶ms.output_shape))?;
114
115 let [batches, input_h, input_w, input_c] = params.input_shape;
116 let depth_multiplier = positive_i32_to_usize(params.depth_multiplier)?;
117 let depthwise_dims =
118 depthwise_filter_dims(params.weights_shape, input_c, depth_multiplier)?;
119 let [output_batches, output_h, output_w, output_c] = params.output_shape;
120
121 if batches != output_batches
122 || input_c != depthwise_dims.input_channels
123 || depth_multiplier != depthwise_dims.depth_multiplier
124 || output_c != input_c * depth_multiplier
125 {
126 return Err(KernelError::InvalidShape);
127 }
128 validate_bias(params.bias, output_c)?;
129
130 let stride_h = positive_i32_to_usize(params.stride_h)?;
131 let stride_w = positive_i32_to_usize(params.stride_w)?;
132 let dilation_h = positive_i32_to_usize(params.dilation_h_factor)?;
133 let dilation_w = positive_i32_to_usize(params.dilation_w_factor)?;
134 let effective_filter_h = effective_filter_size(depthwise_dims.filter_h, dilation_h);
135 let effective_filter_w = effective_filter_size(depthwise_dims.filter_w, dilation_w);
136 let pad_h = compute_padding(input_h, effective_filter_h, stride_h, params.padding);
137 let pad_w = compute_padding(input_w, effective_filter_w, stride_w, params.padding);
138 for batch in 0..batches {
139 for out_y in 0..output_h {
140 for out_x in 0..output_w {
141 for in_channel in 0..input_c {
142 for channel_multiplier in 0..depth_multiplier {
143 let out_channel = in_channel * depth_multiplier + channel_multiplier;
144 let (multiplier, shift) = output_channel_multiplier_shift(
145 params.input_quant,
146 params.weights_quant,
147 params.weights_per_channel_quant,
148 params.output_quant,
149 out_channel,
150 );
151 let mut acc = params
152 .bias
153 .map(|bias| bias[out_channel])
154 .unwrap_or_default();
155
156 for filter_y in 0..depthwise_dims.filter_h {
157 let in_y = out_y * stride_h + filter_y * dilation_h;
158 if in_y < pad_h || in_y >= input_h + pad_h {
159 continue;
160 }
161 let in_y = in_y - pad_h;
162
163 for filter_x in 0..depthwise_dims.filter_w {
164 let in_x = out_x * stride_w + filter_x * dilation_w;
165 if in_x < pad_w || in_x >= input_w + pad_w {
166 continue;
167 }
168 let in_x = in_x - pad_w;
169
170 let input = params.input[nhwc_index(
171 batch, in_y, in_x, in_channel, input_h, input_w, input_c,
172 )] as i32
173 - params.input_quant.zero_point;
174 let weight = params.weights[depthwise_weight_index(
175 filter_y,
176 filter_x,
177 in_channel,
178 channel_multiplier,
179 depthwise_dims,
180 )] as i32
181 - params.weights_quant.zero_point;
182 acc = acc.saturating_add(input.saturating_mul(weight));
183 }
184 }
185
186 let scaled = requantize(acc, multiplier, shift, params.output_quant);
187 params.output[nhwc_index(
188 batch,
189 out_y,
190 out_x,
191 out_channel,
192 output_h,
193 output_w,
194 output_c,
195 )] = apply_activation(scaled, params.activation, params.output_quant);
196 }
197 }
198 }
199 }
200 }
201
202 Ok(())
203 }
204
205 fn fully_connected(&mut self, params: FullyConnectedParams<'_>) -> Status {
206 validate_len(params.output, params.output_depth)?;
207 let [output_depth, input_depth] = params.weights_shape;
208 if params.output_depth != output_depth
209 || params.weights.len() != output_depth * input_depth
210 || params.input.len() != input_depth
211 {
212 return Err(KernelError::InvalidShape);
213 }
214 validate_bias(params.bias, output_depth)?;
215
216 for out_channel in 0..output_depth {
217 let (multiplier, shift) = output_channel_multiplier_shift(
218 params.input_quant,
219 params.weights_quant,
220 params.weights_per_channel_quant,
221 params.output_quant,
222 out_channel,
223 );
224 let mut acc = params
225 .bias
226 .map(|bias| bias[out_channel])
227 .unwrap_or_default();
228 for in_channel in 0..input_depth {
229 let input = params.input[in_channel] as i32 - params.input_quant.zero_point;
230 let weight = params.weights[out_channel * input_depth + in_channel] as i32
231 - params.weights_quant.zero_point;
232 acc = acc.saturating_add(input.saturating_mul(weight));
233 }
234
235 let scaled = requantize(acc, multiplier, shift, params.output_quant);
236 params.output[out_channel] =
237 apply_activation(scaled, params.activation, params.output_quant);
238 }
239
240 Ok(())
241 }
242
243 fn avg_pool(&mut self, params: PoolParams<'_>) -> Status {
244 pool(params, PoolKind::Average)
245 }
246
247 fn max_pool(&mut self, params: PoolParams<'_>) -> Status {
248 pool(params, PoolKind::Max)
249 }
250
251 fn softmax(&mut self, params: SoftmaxParams<'_>) -> Status {
252 let [batches, classes] = params.input_shape;
253 if params.input.len() != batches * classes || params.output.len() != batches * classes {
254 return Err(KernelError::InvalidShape);
255 }
256
257 let mut exps: Vec<f32> = vec![0.0; classes];
258 for batch in 0..batches {
259 let offset = batch * classes;
260 let mut max_input = i8::MIN;
261 for class in 0..classes {
262 max_input = max_input.max(params.input[offset + class]);
263 }
264
265 let mut sum = 0.0f32;
266 for (class, exp) in exps.iter_mut().enumerate() {
267 let centered = (params.input[offset + class] as i32 - max_input as i32) as f32;
268 let real = centered * params.input_quant.scale * params.beta;
269 *exp = libm::expf(real);
270 sum += *exp;
271 }
272
273 if sum == 0.0 {
274 return Err(KernelError::InternalError);
275 }
276
277 for (class, exp) in exps.iter().enumerate() {
278 let probability = *exp / sum;
279 let quantized = round_f32_to_i32(probability / params.output_quant.scale)
280 + params.output_quant.zero_point;
281 params.output[offset + class] = clamp_i8(quantized);
282 }
283 }
284
285 Ok(())
286 }
287
288 fn add(&mut self, params: ElementwiseAddParams<'_>) -> Status {
289 if params.input1.len() != params.input2.len() || params.output.len() != params.input1.len()
290 {
291 return Err(KernelError::InvalidShape);
292 }
293
294 for index in 0..params.output.len() {
295 let lhs = (params.input1[index] as i32 - params.input1_quant.zero_point) as f32
296 * params.input1_quant.scale;
297 let rhs = (params.input2[index] as i32 - params.input2_quant.zero_point) as f32
298 * params.input2_quant.scale;
299 let quantized = round_f32_to_i32((lhs + rhs) / params.output_quant.scale)
300 + params.output_quant.zero_point;
301 params.output[index] =
302 apply_activation(quantized, params.activation, params.output_quant);
303 }
304
305 Ok(())
306 }
307}
308
/// Reduction applied to each pooling window by `pool`.
#[derive(Copy, Clone)]
enum PoolKind {
    /// Rounded mean of the window elements that lie inside the input.
    Average,
    /// Largest window element that lies inside the input.
    Max,
}
314
315fn pool(params: PoolParams<'_>, kind: PoolKind) -> Status {
316 validate_len(params.input, product(¶ms.input_shape))?;
317 validate_len(params.output, product(¶ms.output_shape))?;
318
319 let [batches, input_h, input_w, channels] = params.input_shape;
320 let [output_batches, output_h, output_w, output_channels] = params.output_shape;
321 if batches != output_batches || channels != output_channels {
322 return Err(KernelError::InvalidShape);
323 }
324
325 let stride_h = positive_i32_to_usize(params.stride_h)?;
326 let stride_w = positive_i32_to_usize(params.stride_w)?;
327 let filter_h = positive_i32_to_usize(params.filter_h)?;
328 let filter_w = positive_i32_to_usize(params.filter_w)?;
329 let pad_h = compute_padding(input_h, filter_h, stride_h, params.padding);
330 let pad_w = compute_padding(input_w, filter_w, stride_w, params.padding);
331 let (multiplier, shift) =
332 quantize_multiplier((params.input_quant.scale / params.output_quant.scale) as f64);
333
334 for batch in 0..batches {
335 for out_y in 0..output_h {
336 for out_x in 0..output_w {
337 for channel in 0..channels {
338 let mut acc = 0i32;
339 let mut count = 0i32;
340 let mut max_value = i8::MIN;
341
342 for filter_y in 0..filter_h {
343 let in_y = out_y * stride_h + filter_y;
344 if in_y < pad_h || in_y >= input_h + pad_h {
345 continue;
346 }
347 let in_y = in_y - pad_h;
348
349 for filter_x in 0..filter_w {
350 let in_x = out_x * stride_w + filter_x;
351 if in_x < pad_w || in_x >= input_w + pad_w {
352 continue;
353 }
354 let in_x = in_x - pad_w;
355 let input = params.input[nhwc_index(
356 batch, in_y, in_x, channel, input_h, input_w, channels,
357 )];
358 acc += input as i32 - params.input_quant.zero_point;
359 count += 1;
360 max_value = max_value.max(input);
361 }
362 }
363
364 if count == 0 {
365 return Err(KernelError::InvalidShape);
366 }
367
368 let quantized = match kind {
369 PoolKind::Average => {
370 let average = round_divide(acc, count);
371 requantize(average, multiplier, shift, params.output_quant)
372 }
373 PoolKind::Max => {
374 let centered = max_value as i32 - params.input_quant.zero_point;
375 requantize(centered, multiplier, shift, params.output_quant)
376 }
377 };
378 params.output
379 [nhwc_index(batch, out_y, out_x, channel, output_h, output_w, channels)] =
380 apply_activation(quantized, params.activation, params.output_quant);
381 }
382 }
383 }
384 }
385
386 Ok(())
387}
388
389fn validate_len<T>(slice: &[T], expected: usize) -> Status {
390 if slice.len() == expected {
391 Ok(())
392 } else {
393 Err(KernelError::InvalidShape)
394 }
395}
396
397fn validate_bias(bias: Option<&[i32]>, expected: usize) -> Status {
398 match bias {
399 Some(bias) => validate_len(bias, expected),
400 None => Ok(()),
401 }
402}
403
/// Number of elements described by `shape` (1 for a rank-0 shape).
///
/// The multiplication saturates at `usize::MAX` instead of overflowing, so a
/// hostile shape cannot wrap around to a small count that happens to match a
/// real buffer length.
fn product<const N: usize>(shape: &[usize; N]) -> usize {
    shape
        .iter()
        .fold(1usize, |count, &dim| count.saturating_mul(dim))
}
407
408fn positive_i32_to_usize(value: i32) -> Result<usize, KernelError> {
409 if value > 0 {
410 Ok(value as usize)
411 } else {
412 Err(KernelError::InvalidShape)
413 }
414}
415
/// Extent covered by a filter of `filter_size` taps spaced `dilation` apart:
/// `(filter_size - 1) * dilation + 1`.
///
/// An empty filter covers nothing and yields 0 rather than underflowing, and
/// oversized inputs saturate at `usize::MAX` instead of overflowing.
fn effective_filter_size(filter_size: usize, dilation: usize) -> usize {
    match filter_size.checked_sub(1) {
        Some(gaps) => gaps.saturating_mul(dilation).saturating_add(1),
        None => 0,
    }
}
419
/// Flat offset of element `(batch, y, x, channel)` in a row-major tensor
/// whose three trailing dimensions are `height`, `width` and `channels`.
fn nhwc_index(
    batch: usize,
    y: usize,
    x: usize,
    channel: usize,
    height: usize,
    width: usize,
    channels: usize,
) -> usize {
    let row = batch * height + y;
    let pixel = row * width + x;
    pixel * channels + channel
}
431
/// Flat offset of a convolution weight stored row-major as
/// `[output_channel, filter_y, filter_x, input_channel]`.
fn conv_weight_index(
    output_channel: usize,
    filter_y: usize,
    filter_x: usize,
    input_channel: usize,
    filter_h: usize,
    filter_w: usize,
    input_channels: usize,
) -> usize {
    let filter_row = output_channel * filter_h + filter_y;
    let tap = filter_row * filter_w + filter_x;
    tap * input_channels + input_channel
}
443
444fn depthwise_weight_index(
445 filter_y: usize,
446 filter_x: usize,
447 input_channel: usize,
448 channel_multiplier: usize,
449 dims: DepthwiseDims,
450) -> usize {
451 let output_channel = input_channel * dims.depth_multiplier + channel_multiplier;
452 if dims.tflite_layout {
453 (filter_y * dims.filter_w + filter_x) * (dims.input_channels * dims.depth_multiplier)
454 + output_channel
455 } else {
456 ((filter_y * dims.filter_w + filter_x) * dims.input_channels + input_channel)
457 * dims.depth_multiplier
458 + channel_multiplier
459 }
460}
461
/// Depthwise filter geometry as resolved by `depthwise_filter_dims`.
#[derive(Copy, Clone)]
struct DepthwiseDims {
    /// Filter height in taps.
    filter_h: usize,
    /// Filter width in taps.
    filter_w: usize,
    /// Number of input channels the filter applies to.
    input_channels: usize,
    /// Output channels produced per input channel.
    depth_multiplier: usize,
    /// `true` when the shape was read as `[1, H, W, C * M]`, `false` when it
    /// was read as `[H, W, C, M]`.
    tflite_layout: bool,
}
470
471fn depthwise_filter_dims(
472 weights_shape: [usize; 4],
473 input_channels: usize,
474 depth_multiplier: usize,
475) -> Result<DepthwiseDims, KernelError> {
476 if weights_shape[0] == 1 {
477 if input_channels == 0 {
478 return Err(KernelError::InvalidShape);
479 }
480 Ok(DepthwiseDims {
481 filter_h: weights_shape[1],
482 filter_w: weights_shape[2],
483 input_channels,
484 depth_multiplier: weights_shape[3] / input_channels,
485 tflite_layout: true,
486 })
487 } else {
488 Ok(DepthwiseDims {
489 filter_h: weights_shape[0],
490 filter_w: weights_shape[1],
491 input_channels: weights_shape[2],
492 depth_multiplier: weights_shape[3],
493 tflite_layout: false,
494 })
495 }
496 .and_then(|dims| {
497 if dims.input_channels == input_channels && dims.depth_multiplier == depth_multiplier {
498 Ok(dims)
499 } else {
500 Err(KernelError::InvalidShape)
501 }
502 })
503}
504
505fn multiply_by_quantized_multiplier(x: i32, multiplier: i32, shift: i32) -> i32 {
506 let total_shift = 31 - shift;
507 if total_shift <= 0 {
508 return saturating_left_shift(x.saturating_mul(multiplier), (-total_shift) as u32);
509 }
510 let round = 1i64 << (total_shift - 1);
511 (((x as i64 * multiplier as i64) + round) >> total_shift) as i32
512}
513
/// Shifts `value` left by `shift` bits, pinning the result to `i32::MIN` or
/// `i32::MAX` when it no longer fits. Zero stays zero for any shift.
fn saturating_left_shift(value: i32, shift: u32) -> i32 {
    match (value, shift) {
        (0, _) => 0,
        // Any non-zero value shifted by 31 or more bits is at or beyond a limit.
        (v, s) if s >= 31 => {
            if v < 0 {
                i32::MIN
            } else {
                i32::MAX
            }
        }
        (v, s) => {
            let widened = i64::from(v) << s;
            widened.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        }
    }
}
529
530fn quantize_multiplier(scale: f64) -> (i32, i32) {
531 if scale <= 0.0 {
532 return (0, 0);
533 }
534
535 let mut significand = scale;
536 let mut shift = 0i32;
537
538 while significand < 0.5 {
539 significand *= 2.0;
540 shift -= 1;
541 }
542 while significand >= 1.0 {
543 significand /= 2.0;
544 shift += 1;
545 }
546
547 let mut q = libm::round(significand * (1i64 << 31) as f64) as i64;
548 if q == 1i64 << 31 {
549 q /= 2;
550 shift += 1;
551 }
552
553 (q as i32, shift)
554}
555
556fn output_channel_multiplier_shift(
557 input_quant: QuantParam,
558 weights_quant: QuantParam,
559 weights_per_channel_quant: Option<PerChannelQuantParam<'_>>,
560 output_quant: QuantParam,
561 output_channel: usize,
562) -> (i32, i32) {
563 let weight_scale = weights_per_channel_quant
564 .and_then(|per_channel| per_channel.scales.get(output_channel).copied())
565 .unwrap_or(weights_quant.scale);
566 quantize_multiplier((input_quant.scale * weight_scale / output_quant.scale) as f64)
567}
568
569fn requantize(acc: i32, multiplier: i32, shift: i32, output_quant: QuantParam) -> i32 {
570 multiply_by_quantized_multiplier(acc, multiplier, shift) + output_quant.zero_point
571}
572
573fn apply_activation(val: i32, activation: FusedActivation, output_quant: QuantParam) -> i8 {
574 let min = match activation {
575 FusedActivation::None | FusedActivation::Sigmoid | FusedActivation::SignBit => {
576 i8::MIN as i32
577 }
578 FusedActivation::Relu | FusedActivation::Relu6 => {
579 (i8::MIN as i32).max(output_quant.zero_point)
580 }
581 FusedActivation::ReluN1To1 | FusedActivation::Tanh => (i8::MIN as i32)
582 .max(output_quant.zero_point + round_f32_to_i32(-1.0 / output_quant.scale)),
583 };
584 let max = match activation {
585 FusedActivation::Relu6 => (i8::MAX as i32)
586 .min(output_quant.zero_point + round_f32_to_i32(6.0 / output_quant.scale)),
587 FusedActivation::ReluN1To1 | FusedActivation::Tanh | FusedActivation::Sigmoid => (i8::MAX
588 as i32)
589 .min(output_quant.zero_point + round_f32_to_i32(1.0 / output_quant.scale)),
590 FusedActivation::None | FusedActivation::Relu | FusedActivation::SignBit => i8::MAX as i32,
591 };
592
593 clamp_i8(val.clamp(min, max))
594}
595
/// Narrows an `i32` to `i8`, pinning out-of-range values to the nearest
/// `i8` limit.
fn clamp_i8(value: i32) -> i8 {
    let lowest = i32::from(i8::MIN);
    let highest = i32::from(i8::MAX);
    value.max(lowest).min(highest) as i8
}
599
600fn compute_padding(
601 input_size: usize,
602 filter_size: usize,
603 stride: usize,
604 padding: Padding,
605) -> usize {
606 match padding {
607 Padding::Valid => 0,
608 Padding::Same => {
609 let out_size = input_size.div_ceil(stride);
610 let pad = ((out_size - 1) * stride + filter_size).saturating_sub(input_size);
611 pad / 2
612 }
613 }
614}
615
616fn round_f32_to_i32(value: f32) -> i32 {
617 libm::roundf(value) as i32
618}
619
/// Integer division that biases the numerator by half the denominator, away
/// from zero, before truncating — i.e. rounds to nearest for the positive
/// denominators used by the callers in this file.
fn round_divide(numerator: i32, denominator: i32) -> i32 {
    let half = denominator / 2;
    let biased = if numerator < 0 {
        numerator - half
    } else {
        numerator + half
    };
    biased / denominator
}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity quantization: scale 1 and zero point 0, so quantized values
    /// equal the real values they stand for.
    const UNIT_QUANT: QuantParam = QuantParam {
        zero_point: 0,
        scale: 1.0,
    };

    /// Two outputs, two inputs, with a bias, under identity quantization.
    #[test]
    fn fully_connected_identity_scale() {
        let activations = [2, -3];
        let weight_matrix = [4, 5, -1, 6];
        let mut result = [0; 2];
        let mut kernels = RefBackend;

        kernels
            .fully_connected(FullyConnectedParams {
                input: &activations,
                input_quant: UNIT_QUANT,
                weights: &weight_matrix,
                weights_shape: [2, 2],
                weights_quant: UNIT_QUANT,
                weights_per_channel_quant: None,
                bias: Some(&[1, -2]),
                output: &mut result,
                output_depth: 2,
                output_quant: UNIT_QUANT,
                activation: FusedActivation::None,
            })
            .unwrap();

        assert_eq!(result, [-6, -22]);
    }

    /// Element-wise sum under identity quantization.
    #[test]
    fn add_identity_scale() {
        let lhs = [1, -2, 3];
        let rhs = [4, 5, -6];
        let mut sums = [0; 3];
        let mut kernels = RefBackend;

        kernels
            .add(ElementwiseAddParams {
                input1: &lhs,
                input1_quant: UNIT_QUANT,
                input2: &rhs,
                input2_quant: UNIT_QUANT,
                output: &mut sums,
                output_quant: UNIT_QUANT,
                activation: FusedActivation::None,
            })
            .unwrap();

        assert_eq!(sums, [5, 3, -3]);
    }

    /// A single 2x2 window over a 2x2 input averages all four values.
    #[test]
    fn avg_pool_valid() {
        let image = [1, 3, 5, 7];
        let mut pooled = [0; 1];
        let mut kernels = RefBackend;

        kernels
            .avg_pool(PoolParams {
                input: &image,
                input_shape: [1, 2, 2, 1],
                input_quant: UNIT_QUANT,
                output: &mut pooled,
                output_shape: [1, 1, 1, 1],
                output_quant: UNIT_QUANT,
                stride_w: 1,
                stride_h: 1,
                filter_w: 2,
                filter_h: 2,
                padding: Padding::Valid,
                activation: FusedActivation::None,
            })
            .unwrap();

        assert_eq!(pooled, [4]);
    }

    /// A 2x2 filter selecting the two diagonal input values.
    #[test]
    fn conv2d_single_filter_valid() {
        let image = [1, 2, 3, 4];
        let filter = [1, 0, 0, 1];
        let mut response = [0; 1];
        let mut kernels = RefBackend;

        kernels
            .conv2d(Conv2dParams {
                input: &image,
                input_shape: [1, 2, 2, 1],
                input_quant: UNIT_QUANT,
                weights: &filter,
                weights_shape: [1, 2, 2, 1],
                weights_quant: UNIT_QUANT,
                weights_per_channel_quant: None,
                bias: None,
                output: &mut response,
                output_shape: [1, 1, 1, 1],
                output_quant: UNIT_QUANT,
                stride_w: 1,
                stride_h: 1,
                dilation_w_factor: 1,
                dilation_h_factor: 1,
                padding: Padding::Valid,
                activation: FusedActivation::None,
                scratch: &mut [],
            })
            .unwrap();

        assert_eq!(response, [5]);
    }

    /// A `[1, H, W, C * M]` filter shape is accepted by the depthwise kernel.
    #[test]
    fn depthwise_accepts_tflite_filter_layout() {
        let image = [1, 2, 3, 4];
        let filter = [1, 0, 0, 1];
        let mut response = [0; 1];
        let mut kernels = RefBackend;

        kernels
            .depthwise_conv2d(DepthwiseConv2dParams {
                input: &image,
                input_shape: [1, 2, 2, 1],
                input_quant: UNIT_QUANT,
                weights: &filter,
                weights_shape: [1, 2, 2, 1],
                weights_quant: UNIT_QUANT,
                weights_per_channel_quant: None,
                bias: None,
                output: &mut response,
                output_shape: [1, 1, 1, 1],
                output_quant: UNIT_QUANT,
                stride_w: 1,
                stride_h: 1,
                dilation_w_factor: 1,
                dilation_h_factor: 1,
                depth_multiplier: 1,
                padding: Padding::Valid,
                activation: FusedActivation::None,
                scratch: &mut [],
            })
            .unwrap();

        assert_eq!(response, [5]);
    }

    /// Two equal logits give probability 0.5 each, which quantizes to 0 with
    /// scale 1/256 and zero point -128.
    #[test]
    fn softmax_outputs_probability_distribution() {
        let logits = [0, 0];
        let mut probabilities = [0; 2];
        let mut kernels = RefBackend;

        kernels
            .softmax(SoftmaxParams {
                input: &logits,
                input_shape: [1, 2],
                input_quant: QuantParam {
                    scale: 1.0,
                    zero_point: 0,
                },
                output: &mut probabilities,
                output_quant: QuantParam {
                    scale: 1.0 / 256.0,
                    zero_point: -128,
                },
                beta: 1.0,
                scratch: &mut [],
            })
            .unwrap();

        assert_eq!(probabilities, [0, 0]);
    }
}