/// Requantizes `i32` accumulators to `i8` using per-channel or per-tensor scales.
///
/// Element `i` belongs to channel `i % out_channels` (channels are the innermost,
/// fastest-varying dimension) and is mapped to
/// `clamp(round(acc * scale + zero_point), -128, 127)`, where `round` rounds
/// halfway cases away from zero.
///
/// * `input` / `output` — accumulator values and their requantized results; they
///   must have the same length (checked in debug builds only).
/// * `scales` — one scale per channel, or a single scale that is broadcast to
///   every channel (per-tensor quantization). Must not be empty.
/// * `zero_point` — offset added after scaling, before rounding.
/// * `out_channels` — size of the channel dimension.
///
/// # Panics
/// Panics if `out_channels` is zero while `input` is non-empty, or if a
/// per-channel `scales` slice has no entry for a channel present in `input`.
#[inline]
pub fn requantize_scalar(
    input: &[i32],
    output: &mut [i8],
    scales: &[f32],
    zero_point: i8,
    out_channels: usize,
) {
    debug_assert_eq!(input.len(), output.len());
    debug_assert!(!scales.is_empty());
    assert!(
        out_channels > 0 || input.is_empty(),
        "requantize_scalar: out_channels must be non-zero for non-empty input"
    );
    let zp = f32::from(zero_point);
    // A single scale means per-tensor quantization: use it for every channel
    // instead of indexing past the end of `scales`.
    let per_tensor = scales.len() == 1;
    for (i, (&acc, out)) in input.iter().zip(output.iter_mut()).enumerate() {
        let scale = if per_tensor {
            scales[0]
        } else {
            scales[i % out_channels]
        };
        // Note: `i32 -> f32` is exact only for |acc| <= 2^24.
        let scaled = acc as f32 * scale + zp;
        // The float-to-int `as` cast saturates (NaN becomes 0); the clamp keeps
        // the i8 range explicit.
        *out = scaled.round().clamp(-128.0, 127.0) as i8;
    }
}
/// Naive direct 2-D convolution on `i8` data with `i32` accumulation.
///
/// Layouts (row-major, last dimension fastest):
/// * `input`   — `[batch, in_h, in_w, in_c]`
/// * `weights` — `[out_c, kh, kw, in_c]`
/// * `bias`    — `[out_c]`, used as the initial accumulator of each output
/// * `output`  — `[batch, out_h, out_w, out_c]`
///
/// where `out_h = (in_h + 2*padding - dilation*(kh-1) - 1) / stride + 1` and
/// likewise for `out_w`. Taps that fall into the padding region are skipped,
/// which is equivalent to zero padding. Products are summed with ordinary
/// `i32` addition (overflow panics in debug builds).
///
/// # Panics
/// Panics if `stride` is zero, if `kh` or `kw` is zero, or if the dilated
/// kernel is larger than the padded input along either axis. Slice lengths
/// are only verified in debug builds.
// The flat argument list mirrors the raw tensor dimensions and is part of the public API.
#[allow(clippy::too_many_arguments)]
#[inline]
pub fn conv2d_int8_scalar(
    input: &[i8],
    weights: &[i8],
    bias: &[i32],
    output: &mut [i32],
    batch: usize,
    in_h: usize,
    in_w: usize,
    in_c: usize,
    out_c: usize,
    kh: usize,
    kw: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) {
    /// Output length along one spatial axis, or `None` when the kernel is empty
    /// or the dilated kernel does not fit inside the padded input.
    fn out_dim(
        size: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        let padded = size.checked_add(padding.checked_mul(2)?)?;
        let span = dilation.checked_mul(kernel.checked_sub(1)?)?.checked_add(1)?;
        padded.checked_sub(span).map(|slack| slack / stride + 1)
    }

    assert!(stride > 0, "conv2d_int8_scalar: stride must be non-zero");
    let out_h = out_dim(in_h, kh, stride, padding, dilation)
        .expect("conv2d_int8_scalar: kernel height is zero or exceeds the padded input height");
    let out_w = out_dim(in_w, kw, stride, padding, dilation)
        .expect("conv2d_int8_scalar: kernel width is zero or exceeds the padded input width");
    debug_assert_eq!(input.len(), batch * in_h * in_w * in_c);
    debug_assert_eq!(weights.len(), out_c * kh * kw * in_c);
    debug_assert_eq!(bias.len(), out_c);
    debug_assert_eq!(output.len(), batch * out_h * out_w * out_c);

    for b in 0..batch {
        for oh in 0..out_h {
            for ow in 0..out_w {
                for oc in 0..out_c {
                    let mut acc = bias[oc];
                    for kh_idx in 0..kh {
                        // A whole kernel row in the padding contributes nothing.
                        let ih = match (oh * stride + kh_idx * dilation).checked_sub(padding) {
                            Some(h) if h < in_h => h,
                            _ => continue,
                        };
                        for kw_idx in 0..kw {
                            let iw = match (ow * stride + kw_idx * dilation).checked_sub(padding)
                            {
                                Some(w) if w < in_w => w,
                                _ => continue,
                            };
                            // Both the input pixel and the filter tap store their
                            // `in_c` channels contiguously, so take them as slices.
                            let in_base = ((b * in_h + ih) * in_w + iw) * in_c;
                            let w_base = ((oc * kh + kh_idx) * kw + kw_idx) * in_c;
                            let pixel = &input[in_base..in_base + in_c];
                            let taps = &weights[w_base..w_base + in_c];
                            for (&x, &w) in pixel.iter().zip(taps) {
                                acc += i32::from(x) * i32::from(w);
                            }
                        }
                    }
                    output[((b * out_h + oh) * out_w + ow) * out_c + oc] = acc;
                }
            }
        }
    }
}
/// Naive depthwise 2-D convolution on `i8` data with `i32` accumulation.
///
/// Each channel is convolved only with its own `kh × kw` filter.
///
/// Layouts (row-major, last dimension fastest):
/// * `input`   — `[batch, in_h, in_w, channels]`
/// * `weights` — `[channels, kh, kw]`
/// * `bias`    — `[channels]`, used as the initial accumulator of each output
/// * `output`  — `[batch, out_h, out_w, channels]`
///
/// where `out_h = (in_h + 2*padding - dilation*(kh-1) - 1) / stride + 1` and
/// likewise for `out_w`. Taps that fall into the padding region are skipped,
/// which is equivalent to zero padding. Products are summed with ordinary
/// `i32` addition (overflow panics in debug builds).
///
/// # Panics
/// Panics if `stride` is zero, if `kh` or `kw` is zero, or if the dilated
/// kernel is larger than the padded input along either axis. Slice lengths
/// are only verified in debug builds.
// The flat argument list mirrors the raw tensor dimensions and is part of the public API.
#[allow(clippy::too_many_arguments)]
#[inline]
pub fn depthwise_conv2d_int8_scalar(
    input: &[i8],
    weights: &[i8],
    bias: &[i32],
    output: &mut [i32],
    batch: usize,
    in_h: usize,
    in_w: usize,
    channels: usize,
    kh: usize,
    kw: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) {
    /// Output length along one spatial axis, or `None` when the kernel is empty
    /// or the dilated kernel does not fit inside the padded input.
    fn out_dim(
        size: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        let padded = size.checked_add(padding.checked_mul(2)?)?;
        let span = dilation.checked_mul(kernel.checked_sub(1)?)?.checked_add(1)?;
        padded.checked_sub(span).map(|slack| slack / stride + 1)
    }

    assert!(stride > 0, "depthwise_conv2d_int8_scalar: stride must be non-zero");
    let out_h = out_dim(in_h, kh, stride, padding, dilation).expect(
        "depthwise_conv2d_int8_scalar: kernel height is zero or exceeds the padded input height",
    );
    let out_w = out_dim(in_w, kw, stride, padding, dilation).expect(
        "depthwise_conv2d_int8_scalar: kernel width is zero or exceeds the padded input width",
    );
    debug_assert_eq!(input.len(), batch * in_h * in_w * channels);
    debug_assert_eq!(weights.len(), channels * kh * kw);
    debug_assert_eq!(bias.len(), channels);
    debug_assert_eq!(output.len(), batch * out_h * out_w * channels);

    for b in 0..batch {
        for oh in 0..out_h {
            for ow in 0..out_w {
                for c in 0..channels {
                    let mut acc = bias[c];
                    for kh_idx in 0..kh {
                        // A whole kernel row in the padding contributes nothing.
                        let ih = match (oh * stride + kh_idx * dilation).checked_sub(padding) {
                            Some(h) if h < in_h => h,
                            _ => continue,
                        };
                        for kw_idx in 0..kw {
                            let iw = match (ow * stride + kw_idx * dilation).checked_sub(padding)
                            {
                                Some(w) if w < in_w => w,
                                _ => continue,
                            };
                            let pixel = input[((b * in_h + ih) * in_w + iw) * channels + c];
                            let tap = weights[(c * kh + kh_idx) * kw + kw_idx];
                            acc += i32::from(pixel) * i32::from(tap);
                        }
                    }
                    output[((b * out_h + oh) * out_w + ow) * channels + c] = acc;
                }
            }
        }
    }
}
/// Integer matrix multiply `output = a · b + bias` with `i8` operands and `i32`
/// accumulation.
///
/// All matrices are row-major: `a` is `m × k`, `b` is `k × n` and `output` is
/// `m × n`. `bias` holds one value per output column and is broadcast across
/// rows. Each dot product starts from its column's bias and adds the products
/// in increasing `k` order using ordinary `i32` addition.
#[inline]
pub fn matmul_int8_scalar(
    a: &[i8],
    b: &[i8],
    bias: &[i32],
    output: &mut [i32],
    m: usize,
    k: usize,
    n: usize,
) {
    debug_assert_eq!(a.len(), m * k);
    debug_assert_eq!(b.len(), k * n);
    debug_assert_eq!(bias.len(), n);
    debug_assert_eq!(output.len(), m * n);
    for row in 0..m {
        let row_start = row * k;
        for col in 0..n {
            // Walk down column `col` of `b` while walking along row `row` of `a`.
            output[row * n + col] = (0..k).fold(bias[col], |sum, depth| {
                sum + i32::from(a[row_start + depth]) * i32::from(b[depth * n + col])
            });
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_requantize_scalar() {
        let input = vec![1000, 2000, 3000, 4000];
        let mut output = vec![0i8; 4];
        let scales = vec![0.01, 0.02];
        let zero_point = 0;
        requantize_scalar(&input, &mut output, &scales, zero_point, 2);
        // Channels alternate: even indices use 0.01, odd indices use 0.02.
        assert_eq!(output, vec![10, 40, 30, 80]);
    }

    #[test]
    fn test_requantize_clamping() {
        let input = vec![20000, -20000];
        let mut output = vec![0i8; 2];
        let scales = vec![1.0];
        let zero_point = 0;
        requantize_scalar(&input, &mut output, &scales, zero_point, 1);
        assert_eq!(output, vec![127, -128]);
    }

    #[test]
    fn test_conv2d_int8_scalar_3x3_identity() {
        let input = vec![1i8, 2, 3, 4, 5, 6, 7, 8, 9];
        // Only the centre tap is set, so with "same" padding the output equals the input.
        let weights = vec![0i8, 0, 0, 0, 1, 0, 0, 0, 0];
        let bias = vec![0i32];
        let mut output = vec![0i32; 9];
        conv2d_int8_scalar(&input, &weights, &bias, &mut output, 1, 3, 3, 1, 1, 3, 3, 1, 1, 1);
        assert_eq!(output, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_conv2d_int8_scalar_no_overflow() {
        // 8x8 single-channel input and one 3x3 filter: out_c*kh*kw*in_c = 9 weights.
        let input = vec![127i8; 64];
        let weights = vec![127i8; 9];
        let bias = vec![0i32];
        let mut output = vec![0i32; 64];
        conv2d_int8_scalar(&input, &weights, &bias, &mut output, 1, 8, 8, 1, 1, 3, 3, 1, 1, 1);
        let full = 9 * 127 * 127; // 145161: all nine taps inside the image
        for &val in &output {
            assert!(val > 0);
            assert!(val <= full);
        }
        assert_eq!(output[0], 4 * 127 * 127); // corner: only 2x2 taps in bounds
        assert_eq!(output[9], full); // position (1, 1): every tap in bounds
    }

    #[test]
    fn test_depthwise_conv2d_int8_scalar() {
        // 3x3x2 input in HWC order: channel 0 holds 1,3,..,17 and channel 1 holds 2,4,..,18.
        let input: Vec<i8> = (1..=18).collect();
        // Channel 0: main-diagonal kernel; channel 1: anti-diagonal kernel.
        let weights = vec![1i8, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0];
        let bias = vec![0i32, 0i32];
        let mut output = vec![0i32; 18];
        depthwise_conv2d_int8_scalar(
            &input, &weights, &bias, &mut output, 1, 3, 3, 2, 3, 3, 1, 1, 1,
        );
        // Centre pixel, channel 0: 1 + 9 + 17.
        assert_eq!(output[8], 27);
        // Centre pixel, channel 1: 6 + 10 + 14.
        assert_eq!(output[9], 30);
    }

    #[test]
    fn test_matmul_int8_scalar_2x2() {
        let a = vec![1i8, 2, 3, 4, 5, 6];
        let b = vec![1i8, 2, 3, 4, 5, 6];
        let bias = vec![0i32, 0i32];
        let mut output = vec![0i32; 4];
        matmul_int8_scalar(&a, &b, &bias, &mut output, 2, 3, 2);
        assert_eq!(output[0], 22);
        assert_eq!(output[1], 28);
        assert_eq!(output[2], 49);
        assert_eq!(output[3], 64);
    }

    #[test]
    fn test_matmul_int8_scalar_with_bias() {
        let a = vec![1i8, 2];
        let b = vec![3i8, 4];
        let bias = vec![10i32];
        let mut output = vec![0i32];
        matmul_int8_scalar(&a, &b, &bias, &mut output, 1, 2, 1);
        assert_eq!(output[0], 21);
    }

    #[test]
    fn test_conv2d_stride_2() {
        let input = vec![1i8; 16];
        let weights = vec![1i8; 9];
        let bias = vec![0i32];
        let mut output = vec![0i32; 4];
        conv2d_int8_scalar(&input, &weights, &bias, &mut output, 1, 4, 4, 1, 1, 3, 3, 2, 1, 1);
        // Each output counts the 3x3 taps that land inside the 4x4 image.
        assert_eq!(output, vec![4, 6, 6, 9]);
    }

    #[test]
    fn test_conv2d_dilation_2() {
        // Padding 2 with dilation 2 is "same" padding for a 3x3 kernel: 5x5 -> 5x5.
        let input = vec![1i8; 25];
        let weights = vec![1i8; 9];
        let bias = vec![0i32];
        let mut output = vec![0i32; 25];
        conv2d_int8_scalar(&input, &weights, &bias, &mut output, 1, 5, 5, 1, 1, 3, 3, 1, 2, 2);
        // In-bounds taps per axis at positions 0..5 are 2, 2, 3, 2, 2.
        assert_eq!(output[0], 4);
        assert_eq!(output[2], 6);
        assert_eq!(output[12], 9);
    }

    #[test]
    fn test_requantize_preserves_range() {
        let input = vec![i32::MAX, i32::MIN, 0, 1000, -1000];
        let mut output = vec![0i8; 5];
        let scales = vec![0.001];
        requantize_scalar(&input, &mut output, &scales, 0, 1);
        assert_eq!(output, vec![127, -128, 0, 1, -1]);
    }

    #[test]
    fn test_conv2d_commutative_channels() {
        // With a kernel the same size as the input and no padding, conv2d reduces to
        // a single dot product, so swapping input and weights must not change it.
        let x = vec![1i8, 2, 3, 4];
        let y = vec![1i8, 1, 1, 1];
        let bias = vec![0i32];
        let mut output1 = vec![0i32; 1];
        let mut output2 = vec![0i32; 1];
        conv2d_int8_scalar(&x, &y, &bias, &mut output1, 1, 2, 2, 1, 1, 2, 2, 1, 0, 1);
        conv2d_int8_scalar(&y, &x, &bias, &mut output2, 1, 2, 2, 1, 1, 2, 2, 1, 0, 1);
        assert_eq!(output1, vec![10]);
        assert_eq!(output1, output2);
    }

    #[test]
    fn test_depthwise_per_channel_independence() {
        // 2x2x2 input, 2x2 kernels, no padding -> a single 1x1x2 output.
        let input = vec![1i8, 10, 2, 20, 3, 30, 4, 40];
        let weights = vec![1i8, 1, 1, 1, 2, 2, 2, 2];
        let bias = vec![0i32, 0i32];
        let mut output = vec![0i32; 2];
        depthwise_conv2d_int8_scalar(
            &input, &weights, &bias, &mut output, 1, 2, 2, 2, 2, 2, 1, 0, 1,
        );
        // Channel 0: 1+2+3+4; channel 1: 2*(10+20+30+40).
        assert_eq!(output, vec![10, 200]);
    }

    #[test]
    fn test_matmul_associative_bias() {
        let a = vec![1i8, 2];
        let b = vec![3i8, 4];
        let bias1 = vec![5i32];
        let bias2 = vec![10i32];
        let mut output1 = vec![0i32];
        let mut output2 = vec![0i32];
        matmul_int8_scalar(&a, &b, &bias1, &mut output1, 1, 2, 1);
        matmul_int8_scalar(&a, &b, &bias2, &mut output2, 1, 2, 1);
        assert_eq!(output2[0] - output1[0], 5);
    }
}