#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
pub unsafe fn dot_product_int8_neon(a: &[i8], b: &[i8]) -> i32 {
debug_assert_eq!(a.len(), b.len());
let len = a.len();
let chunks = len / 16;
let mut acc = vdupq_n_s32(0);
for i in 0..chunks {
let offset = i * 16;
let va_low = vld1_s8(a.as_ptr().add(offset));
let va_high = vld1_s8(a.as_ptr().add(offset + 8));
let vb_low = vld1_s8(b.as_ptr().add(offset));
let vb_high = vld1_s8(b.as_ptr().add(offset + 8));
let prod16_low = vmull_s8(va_low, vb_low);
let prod16_high = vmull_s8(va_high, vb_high);
acc = vpadalq_s16(acc, prod16_low);
acc = vpadalq_s16(acc, prod16_high);
}
let sum_pair = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc));
let sum_final = vpadd_s32(sum_pair, sum_pair);
let mut result = vget_lane_s32(sum_final, 0);
for i in (chunks * 16)..len {
result += (a[i] as i32) * (b[i] as i32);
}
result
}
#[cfg(target_arch = "aarch64")]
pub unsafe fn conv2d_int8_neon(
input: &[u8],
input_zero_point: i32,
kernel: &[i8],
bias_i32: &[i32],
output: &mut [i32],
in_h: usize,
in_w: usize,
in_c: usize,
out_c: usize,
stride: usize,
padding: usize,
) {
let out_h = (in_h + 2 * padding - 3) / stride + 1;
let out_w = (in_w + 2 * padding - 3) / stride + 1;
let out_c_chunks = out_c / 4; let kernel_size = 3;
let mut weight_sums = vec![0i32; out_c];
for oc in 0..out_c {
let mut sum = 0i32;
for ic in 0..in_c {
for kh in 0..3 {
for kw in 0..3 {
let idx = (oc * in_c + ic) * 9 + kh * 3 + kw;
sum += kernel[idx] as i32;
}
}
}
weight_sums[oc] = sum;
}
for oh in 0..out_h {
for ow in 0..out_w {
let out_spatial_idx = oh * out_w + ow;
for oc_chunk in 0..out_c_chunks {
let oc_base = oc_chunk * 4;
let mut acc = [0i32; 4];
for i in 0..4 {
let oc = oc_base + i;
acc[i] = bias_i32[oc] - input_zero_point * weight_sums[oc];
}
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let ih = (oh * stride + kh) as isize - padding as isize;
let iw = (ow * stride + kw) as isize - padding as isize;
if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
let ih = ih as usize;
let iw = iw as usize;
let ic_chunks = in_c / 16;
for ic_chunk in 0..ic_chunks {
let ic_base = ic_chunk * 16;
let input_base = (ih * in_w + iw) * in_c + ic_base;
let input_u8 = vld1q_u8(input.as_ptr().add(input_base));
let offset = vdupq_n_u8(128);
let input_shifted = vsubq_u8(input_u8, offset);
let input_i8_low = vreinterpret_s8_u8(vget_low_u8(input_shifted));
let input_i8_high = vreinterpret_s8_u8(vget_high_u8(input_shifted));
for i in 0..4 {
let oc = oc_base + i;
let mut w_buf = [0i8; 16];
for j in 0..16 {
let k_idx = (oc * in_c + ic_base + j) * 9 + kh * 3 + kw;
w_buf[j] = kernel[k_idx];
}
let kernel_i8_low = vld1_s8(w_buf.as_ptr());
let kernel_i8_high = vld1_s8(w_buf.as_ptr().add(8));
let prod16_low = vmull_s8(input_i8_low, kernel_i8_low);
let prod16_high = vmull_s8(input_i8_high, kernel_i8_high);
let sum_low = vpaddlq_s16(prod16_low);
let sum_high = vpaddlq_s16(prod16_high);
let total = vaddq_s32(sum_low, sum_high);
let sum_pair = vpadd_s32(vget_low_s32(total), vget_high_s32(total));
let sum_final = vpadd_s32(sum_pair, sum_pair);
acc[i] += vget_lane_s32(sum_final, 0);
}
}
for ic in (ic_chunks * 16)..in_c {
let input_idx = (ih * in_w + iw) * in_c + ic;
let input_val = input[input_idx] as i32;
for i in 0..4 {
let oc = oc_base + i;
let k_idx = (oc * in_c + ic) * 9 + kh * 3 + kw;
let w_val = kernel[k_idx] as i32;
acc[i] += input_val * w_val;
}
}
}
}
}
for i in 0..4 {
output[out_spatial_idx * out_c + oc_base + i] = acc[i];
}
}
for oc in (out_c_chunks * 4)..out_c {
let mut acc = bias_i32[oc] - input_zero_point * weight_sums[oc];
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let ih = (oh * stride + kh) as isize - padding as isize;
let iw = (ow * stride + kw) as isize - padding as isize;
if ih >= 0 && ih < in_h as isize && iw >= 0 && iw < in_w as isize {
let ih = ih as usize;
let iw = iw as usize;
for ic in 0..in_c {
let input_idx = (ih * in_w + iw) * in_c + ic;
let k_idx = (oc * in_c + ic) * 9 + kh * 3 + kw;
acc += (input[input_idx] as i32) * (kernel[k_idx] as i32);
}
}
}
}
output[out_spatial_idx * out_c + oc] = acc;
}
}
}
}
#[cfg(target_arch = "aarch64")]
pub unsafe fn depthwise_conv2d_int8_neon(
input: &[u8],
input_zero_point: i32,
kernel: &[i8],
bias_i32: &[i32],
output: &mut [i32],
h: usize,
w: usize,
c: usize,
stride: usize,
padding: usize,
) {
let out_h = (h + 2 * padding - 3) / stride + 1;
let out_w = (w + 2 * padding - 3) / stride + 1;
let c_chunks = c / 4;
let mut weight_sums = vec![0i32; c];
for ch in 0..c {
let mut sum = 0i32;
for k in 0..9 {
sum += kernel[ch * 9 + k] as i32;
}
weight_sums[ch] = sum;
}
for oh in 0..out_h {
for ow in 0..out_w {
for c_chunk in 0..c_chunks {
let c_base = c_chunk * 4;
let mut acc = [0i32; 4];
for i in 0..4 {
let ch = c_base + i;
acc[i] = bias_i32[ch] - input_zero_point * weight_sums[ch];
}
for kh in 0..3 {
for kw in 0..3 {
let ih = (oh * stride + kh) as isize - padding as isize;
let iw = (ow * stride + kw) as isize - padding as isize;
if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
let ih = ih as usize;
let iw = iw as usize;
let input_base = (ih * w + iw) * c + c_base;
let kernel_base = c_base * 9 + kh * 3 + kw;
for i in 0..4 {
let input_val = input[input_base + i] as i32;
let kernel_val = kernel[kernel_base + i * 9] as i32;
acc[i] += input_val * kernel_val;
}
}
}
}
for i in 0..4 {
output[(oh * out_w + ow) * c + c_base + i] = acc[i];
}
}
for ch in (c_chunks * 4)..c {
let mut acc = bias_i32[ch] - input_zero_point * weight_sums[ch];
for kh in 0..3 {
for kw in 0..3 {
let ih = (oh * stride + kh) as isize - padding as isize;
let iw = (ow * stride + kw) as isize - padding as isize;
if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
let ih = ih as usize;
let iw = iw as usize;
let input_idx = (ih * w + iw) * c + ch;
let kernel_idx = ch * 9 + kh * 3 + kw;
acc += (input[input_idx] as i32) * (kernel[kernel_idx] as i32);
}
}
}
output[(oh * out_w + ow) * c + ch] = acc;
}
}
}
}
#[cfg(target_arch = "aarch64")]
pub unsafe fn matmul_int8_neon(
a: &[i8],
b: &[i8],
output: &mut [i32],
m: usize,
k: usize,
n: usize,
) {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
debug_assert_eq!(output.len(), m * n);
let n_chunks = n / 4;
for row in 0..m {
let a_row = &a[row * k..(row + 1) * k];
for n_chunk in 0..n_chunks {
let n_base = n_chunk * 4;
let mut acc = [0i32; 4];
let k_chunks = k / 16;
for k_chunk in 0..k_chunks {
let k_base = k_chunk * 16;
let va_low = vld1_s8(a_row.as_ptr().add(k_base));
let va_high = vld1_s8(a_row.as_ptr().add(k_base + 8));
for i in 0..4 {
let col = n_base + i;
let mut w_buf = [0i8; 16];
for j in 0..16 {
w_buf[j] = b[(k_base + j) * n + col];
}
let vb_low = vld1_s8(w_buf.as_ptr());
let vb_high = vld1_s8(w_buf.as_ptr().add(8));
let prod16_low = vmull_s8(va_low, vb_low);
let prod16_high = vmull_s8(va_high, vb_high);
let sum_low = vpaddlq_s16(prod16_low);
let sum_high = vpaddlq_s16(prod16_high);
let total = vaddq_s32(sum_low, sum_high);
let sum_pair = vpadd_s32(vget_low_s32(total), vget_high_s32(total));
let sum_final = vpadd_s32(sum_pair, sum_pair);
acc[i] += vget_lane_s32(sum_final, 0);
}
}
for k_idx in (k_chunks * 16)..k {
let a_val = a_row[k_idx] as i32;
for i in 0..4 {
let col = n_base + i;
let b_val = b[k_idx * n + col] as i32;
acc[i] += a_val * b_val;
}
}
for i in 0..4 {
output[row * n + n_base + i] = acc[i];
}
}
for col in (n_chunks * 4)..n {
let mut sum = 0i32;
for k_idx in 0..k {
sum += (a_row[k_idx] as i32) * (b[k_idx * n + col] as i32);
}
output[row * n + col] = sum;
}
}
}
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn dot_product_int8_neon(_a: &[i8], _b: &[i8]) -> i32 {
0
}
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn conv2d_int8_neon(
_input: &[u8],
_input_zero_point: i32,
_kernel: &[i8],
_bias_i32: &[i32],
_output: &mut [i32],
_in_h: usize,
_in_w: usize,
_in_c: usize,
_out_c: usize,
_stride: usize,
_padding: usize,
) {
}
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn depthwise_conv2d_int8_neon(
_input: &[u8],
_input_zero_point: i32,
_kernel: &[i8],
_bias_i32: &[i32],
_output: &mut [i32],
_h: usize,
_w: usize,
_c: usize,
_stride: usize,
_padding: usize,
) {
}
#[cfg(not(target_arch = "aarch64"))]
pub unsafe fn matmul_int8_neon(
_a: &[i8],
_b: &[i8],
_output: &mut [i32],
_m: usize,
_k: usize,
_n: usize,
) {
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dot_product_int8_neon() {
let a: Vec<i8> = (0..32).map(|i| (i % 64) as i8).collect();
let b: Vec<i8> = (0..32).map(|i| ((i % 64) as i8)).collect();
#[cfg(target_arch = "aarch64")]
{
unsafe {
let result = dot_product_int8_neon(&a, &b);
let mut expected = 0i32;
for i in 0..32 {
expected += (a[i] as i32) * (b[i] as i32);
}
assert_eq!(result, expected);
}
}
}
}