#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use crate::ops::conv_common::{Conv1dParams, Conv2dParams};
/// NEON-accelerated 1-D convolution over `f32` data.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, length]`
/// - `weight`: `[c_out, c_in / groups, kernel_size]`
/// - `output`: `[batch, c_out, output_length]`
///
/// Four input channels are accumulated per NEON vector; channels left over
/// after the vectorized chunks are handled by a scalar tail loop. Positions
/// that fall inside the left/right padding contribute implicit zeros.
///
/// # Safety
/// `input`, `weight`, and `output` must point to valid, properly sized
/// buffers matching the dimensions in `params`; `bias`, when `Some`, must
/// hold at least `params.c_out` elements. The caller must only invoke this
/// where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn conv1d_f32(
    input: *const f32,
    weight: *const f32,
    bias: Option<*const f32>,
    output: *mut f32,
    params: Conv1dParams,
) {
    let lanes = 4;
    let c_in_per_group = params.c_in / params.groups;
    let c_out_per_group = params.c_out / params.groups;
    let chunks = c_in_per_group / lanes;
    for b in 0..params.batch {
        for g in 0..params.groups {
            let c_in_start = g * c_in_per_group;
            let c_out_start = g * c_out_per_group;
            for oc in 0..c_out_per_group {
                let out_c = c_out_start + oc;
                let bias_val = if let Some(b) = bias {
                    *b.add(out_c)
                } else {
                    0.0
                };
                for ol in 0..params.output_length {
                    // BUG FIX: the vector accumulator must start at zero.
                    // Previously every lane was seeded with `bias_val`, and
                    // the final `vaddvq_f32` horizontal reduction therefore
                    // added the bias 4 times. The bias is now carried exactly
                    // once, in the scalar accumulator.
                    let mut acc = vdupq_n_f32(0.0);
                    let mut scalar_acc = bias_val;
                    for k in 0..params.kernel_size {
                        // Input index for this output position + kernel tap;
                        // skip taps that land in the padding region.
                        let il_signed = (ol * params.stride) as isize
                            + (k * params.dilation) as isize
                            - params.pad_left as isize;
                        if il_signed < 0 || (il_signed as usize) >= params.length {
                            continue;
                        }
                        let il = il_signed as usize;
                        // Vectorized channels: gather 4 samples spaced
                        // `length` apart (one per channel) and the matching
                        // 4 weights spaced `kernel_size` apart, then FMA.
                        for chunk in 0..chunks {
                            let ic_base = chunk * lanes;
                            let in_idx = b * params.c_in * params.length
                                + (c_in_start + ic_base) * params.length
                                + il;
                            let in_arr = [
                                *input.add(in_idx),
                                *input.add(in_idx + params.length),
                                *input.add(in_idx + 2 * params.length),
                                *input.add(in_idx + 3 * params.length),
                            ];
                            let v_in = vld1q_f32(in_arr.as_ptr());
                            let w_idx = out_c * c_in_per_group * params.kernel_size
                                + ic_base * params.kernel_size
                                + k;
                            let w_arr = [
                                *weight.add(w_idx),
                                *weight.add(w_idx + params.kernel_size),
                                *weight.add(w_idx + 2 * params.kernel_size),
                                *weight.add(w_idx + 3 * params.kernel_size),
                            ];
                            let v_w = vld1q_f32(w_arr.as_ptr());
                            acc = vfmaq_f32(acc, v_in, v_w);
                        }
                        // Scalar tail for channels not covered by full vectors.
                        for ic in (chunks * lanes)..c_in_per_group {
                            let in_idx = b * params.c_in * params.length
                                + (c_in_start + ic) * params.length
                                + il;
                            let w_idx = out_c * c_in_per_group * params.kernel_size
                                + ic * params.kernel_size
                                + k;
                            scalar_acc += *input.add(in_idx) * *weight.add(w_idx);
                        }
                    }
                    // Horizontal reduce the vector lanes, then add the scalar
                    // part (tail channels + bias).
                    let sum = vaddvq_f32(acc) + scalar_acc;
                    let out_idx =
                        b * params.c_out * params.output_length + out_c * params.output_length + ol;
                    *output.add(out_idx) = sum;
                }
            }
        }
    }
}
/// NEON-accelerated 1-D convolution over `f64` data.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, length]`
/// - `weight`: `[c_out, c_in / groups, kernel_size]`
/// - `output`: `[batch, c_out, output_length]`
///
/// Two input channels are accumulated per NEON vector (f64x2); leftover
/// channels use a scalar tail loop. Padding positions contribute implicit
/// zeros.
///
/// # Safety
/// `input`, `weight`, and `output` must point to valid, properly sized
/// buffers matching the dimensions in `params`; `bias`, when `Some`, must
/// hold at least `params.c_out` elements. The caller must only invoke this
/// where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn conv1d_f64(
    input: *const f64,
    weight: *const f64,
    bias: Option<*const f64>,
    output: *mut f64,
    params: Conv1dParams,
) {
    let lanes = 2;
    let c_in_per_group = params.c_in / params.groups;
    let c_out_per_group = params.c_out / params.groups;
    let chunks = c_in_per_group / lanes;
    for b in 0..params.batch {
        for g in 0..params.groups {
            let c_in_start = g * c_in_per_group;
            let c_out_start = g * c_out_per_group;
            for oc in 0..c_out_per_group {
                let out_c = c_out_start + oc;
                let bias_val = if let Some(b) = bias {
                    *b.add(out_c)
                } else {
                    0.0
                };
                for ol in 0..params.output_length {
                    // BUG FIX: the vector accumulator must start at zero.
                    // Previously both lanes were seeded with `bias_val`, and
                    // the final `vaddvq_f64` horizontal reduction therefore
                    // added the bias twice. The bias is now carried exactly
                    // once, in the scalar accumulator.
                    let mut acc = vdupq_n_f64(0.0);
                    let mut scalar_acc = bias_val;
                    for k in 0..params.kernel_size {
                        // Input index for this output position + kernel tap;
                        // skip taps that land in the padding region.
                        let il_signed = (ol * params.stride) as isize
                            + (k * params.dilation) as isize
                            - params.pad_left as isize;
                        if il_signed < 0 || (il_signed as usize) >= params.length {
                            continue;
                        }
                        let il = il_signed as usize;
                        // Vectorized channels: gather 2 samples spaced
                        // `length` apart and the matching 2 weights spaced
                        // `kernel_size` apart, then FMA.
                        for chunk in 0..chunks {
                            let ic_base = chunk * lanes;
                            let in_idx = b * params.c_in * params.length
                                + (c_in_start + ic_base) * params.length
                                + il;
                            let in_arr = [*input.add(in_idx), *input.add(in_idx + params.length)];
                            let v_in = vld1q_f64(in_arr.as_ptr());
                            let w_idx = out_c * c_in_per_group * params.kernel_size
                                + ic_base * params.kernel_size
                                + k;
                            let w_arr =
                                [*weight.add(w_idx), *weight.add(w_idx + params.kernel_size)];
                            let v_w = vld1q_f64(w_arr.as_ptr());
                            acc = vfmaq_f64(acc, v_in, v_w);
                        }
                        // Scalar tail for channels not covered by full vectors.
                        for ic in (chunks * lanes)..c_in_per_group {
                            let in_idx = b * params.c_in * params.length
                                + (c_in_start + ic) * params.length
                                + il;
                            let w_idx = out_c * c_in_per_group * params.kernel_size
                                + ic * params.kernel_size
                                + k;
                            scalar_acc += *input.add(in_idx) * *weight.add(w_idx);
                        }
                    }
                    // Horizontal reduce the vector lanes, then add the scalar
                    // part (tail channels + bias).
                    let sum = vaddvq_f64(acc) + scalar_acc;
                    let out_idx =
                        b * params.c_out * params.output_length + out_c * params.output_length + ol;
                    *output.add(out_idx) = sum;
                }
            }
        }
    }
}
/// NEON-accelerated 2-D convolution over `f32` data.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, height, width]`
/// - `weight`: `[c_out, c_in / groups, kernel_h, kernel_w]`
/// - `output`: `[batch, c_out, output_h, output_w]`
///
/// Four input channels are accumulated per NEON vector; leftover channels
/// use a scalar tail loop. Kernel taps falling into the padding region
/// contribute implicit zeros.
///
/// # Safety
/// `input`, `weight`, and `output` must point to valid, properly sized
/// buffers matching the dimensions in `params`; `bias`, when `Some`, must
/// hold at least `params.c_out` elements. The caller must only invoke this
/// where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn conv2d_f32(
    input: *const f32,
    weight: *const f32,
    bias: Option<*const f32>,
    output: *mut f32,
    params: Conv2dParams,
) {
    let lanes = 4;
    let c_in_per_group = params.c_in / params.groups;
    let c_out_per_group = params.c_out / params.groups;
    let chunks = c_in_per_group / lanes;
    let in_spatial = params.height * params.width;
    let out_spatial = params.output_h * params.output_w;
    let kernel_spatial = params.kernel_h * params.kernel_w;
    for b in 0..params.batch {
        for g in 0..params.groups {
            let c_in_start = g * c_in_per_group;
            let c_out_start = g * c_out_per_group;
            for oc in 0..c_out_per_group {
                let out_c = c_out_start + oc;
                let bias_val = if let Some(bias_ptr) = bias {
                    *bias_ptr.add(out_c)
                } else {
                    0.0
                };
                for oh in 0..params.output_h {
                    for ow in 0..params.output_w {
                        // BUG FIX: the vector accumulator must start at zero.
                        // Previously every lane was seeded with `bias_val`,
                        // and the final `vaddvq_f32` horizontal reduction
                        // therefore added the bias 4 times. The bias is now
                        // carried exactly once, in the scalar accumulator.
                        let mut acc = vdupq_n_f32(0.0);
                        let mut scalar_acc = bias_val;
                        for kh in 0..params.kernel_h {
                            for kw in 0..params.kernel_w {
                                // Input coordinates for this output position
                                // and kernel tap; skip taps in the padding.
                                let ih_signed = (oh * params.stride_h) as isize
                                    + (kh * params.dilation_h) as isize
                                    - params.pad_top as isize;
                                let iw_signed = (ow * params.stride_w) as isize
                                    + (kw * params.dilation_w) as isize
                                    - params.pad_left as isize;
                                if ih_signed < 0
                                    || (ih_signed as usize) >= params.height
                                    || iw_signed < 0
                                    || (iw_signed as usize) >= params.width
                                {
                                    continue;
                                }
                                let ih = ih_signed as usize;
                                let iw = iw_signed as usize;
                                // Vectorized channels: gather 4 samples and
                                // their matching weights (one per channel),
                                // then FMA into the accumulator.
                                for chunk in 0..chunks {
                                    let ic_base = chunk * lanes;
                                    let mut in_arr = [0.0f32; 4];
                                    for lane in 0..lanes {
                                        let ic = ic_base + lane;
                                        let in_idx = b * params.c_in * in_spatial
                                            + (c_in_start + ic) * in_spatial
                                            + ih * params.width
                                            + iw;
                                        in_arr[lane] = *input.add(in_idx);
                                    }
                                    let v_in = vld1q_f32(in_arr.as_ptr());
                                    let mut w_arr = [0.0f32; 4];
                                    for lane in 0..lanes {
                                        let ic = ic_base + lane;
                                        let w_idx = out_c * c_in_per_group * kernel_spatial
                                            + ic * kernel_spatial
                                            + kh * params.kernel_w
                                            + kw;
                                        w_arr[lane] = *weight.add(w_idx);
                                    }
                                    let v_w = vld1q_f32(w_arr.as_ptr());
                                    acc = vfmaq_f32(acc, v_in, v_w);
                                }
                                // Scalar tail for leftover channels.
                                for ic in (chunks * lanes)..c_in_per_group {
                                    let in_idx = b * params.c_in * in_spatial
                                        + (c_in_start + ic) * in_spatial
                                        + ih * params.width
                                        + iw;
                                    let w_idx = out_c * c_in_per_group * kernel_spatial
                                        + ic * kernel_spatial
                                        + kh * params.kernel_w
                                        + kw;
                                    scalar_acc += *input.add(in_idx) * *weight.add(w_idx);
                                }
                            }
                        }
                        // Horizontal reduce the vector lanes, then add the
                        // scalar part (tail channels + bias).
                        let sum = vaddvq_f32(acc) + scalar_acc;
                        let out_idx = b * params.c_out * out_spatial
                            + out_c * out_spatial
                            + oh * params.output_w
                            + ow;
                        *output.add(out_idx) = sum;
                    }
                }
            }
        }
    }
}
/// NEON-accelerated 2-D convolution over `f64` data.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, height, width]`
/// - `weight`: `[c_out, c_in / groups, kernel_h, kernel_w]`
/// - `output`: `[batch, c_out, output_h, output_w]`
///
/// Two input channels are accumulated per NEON vector (f64x2); leftover
/// channels use a scalar tail loop. Kernel taps falling into the padding
/// region contribute implicit zeros.
///
/// # Safety
/// `input`, `weight`, and `output` must point to valid, properly sized
/// buffers matching the dimensions in `params`; `bias`, when `Some`, must
/// hold at least `params.c_out` elements. The caller must only invoke this
/// where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn conv2d_f64(
    input: *const f64,
    weight: *const f64,
    bias: Option<*const f64>,
    output: *mut f64,
    params: Conv2dParams,
) {
    let lanes = 2;
    let c_in_per_group = params.c_in / params.groups;
    let c_out_per_group = params.c_out / params.groups;
    let chunks = c_in_per_group / lanes;
    let in_spatial = params.height * params.width;
    let out_spatial = params.output_h * params.output_w;
    let kernel_spatial = params.kernel_h * params.kernel_w;
    for b in 0..params.batch {
        for g in 0..params.groups {
            let c_in_start = g * c_in_per_group;
            let c_out_start = g * c_out_per_group;
            for oc in 0..c_out_per_group {
                let out_c = c_out_start + oc;
                let bias_val = if let Some(bias_ptr) = bias {
                    *bias_ptr.add(out_c)
                } else {
                    0.0
                };
                for oh in 0..params.output_h {
                    for ow in 0..params.output_w {
                        // BUG FIX: the vector accumulator must start at zero.
                        // Previously both lanes were seeded with `bias_val`,
                        // and the final `vaddvq_f64` horizontal reduction
                        // therefore added the bias twice. The bias is now
                        // carried exactly once, in the scalar accumulator.
                        let mut acc = vdupq_n_f64(0.0);
                        let mut scalar_acc = bias_val;
                        for kh in 0..params.kernel_h {
                            for kw in 0..params.kernel_w {
                                // Input coordinates for this output position
                                // and kernel tap; skip taps in the padding.
                                let ih_signed = (oh * params.stride_h) as isize
                                    + (kh * params.dilation_h) as isize
                                    - params.pad_top as isize;
                                let iw_signed = (ow * params.stride_w) as isize
                                    + (kw * params.dilation_w) as isize
                                    - params.pad_left as isize;
                                if ih_signed < 0
                                    || (ih_signed as usize) >= params.height
                                    || iw_signed < 0
                                    || (iw_signed as usize) >= params.width
                                {
                                    continue;
                                }
                                let ih = ih_signed as usize;
                                let iw = iw_signed as usize;
                                // Vectorized channels: gather 2 samples and
                                // their matching weights (one per channel),
                                // then FMA into the accumulator.
                                for chunk in 0..chunks {
                                    let ic_base = chunk * lanes;
                                    let mut in_arr = [0.0f64; 2];
                                    for lane in 0..lanes {
                                        let ic = ic_base + lane;
                                        let in_idx = b * params.c_in * in_spatial
                                            + (c_in_start + ic) * in_spatial
                                            + ih * params.width
                                            + iw;
                                        in_arr[lane] = *input.add(in_idx);
                                    }
                                    let v_in = vld1q_f64(in_arr.as_ptr());
                                    let mut w_arr = [0.0f64; 2];
                                    for lane in 0..lanes {
                                        let ic = ic_base + lane;
                                        let w_idx = out_c * c_in_per_group * kernel_spatial
                                            + ic * kernel_spatial
                                            + kh * params.kernel_w
                                            + kw;
                                        w_arr[lane] = *weight.add(w_idx);
                                    }
                                    let v_w = vld1q_f64(w_arr.as_ptr());
                                    acc = vfmaq_f64(acc, v_in, v_w);
                                }
                                // Scalar tail for leftover channels.
                                for ic in (chunks * lanes)..c_in_per_group {
                                    let in_idx = b * params.c_in * in_spatial
                                        + (c_in_start + ic) * in_spatial
                                        + ih * params.width
                                        + iw;
                                    let w_idx = out_c * c_in_per_group * kernel_spatial
                                        + ic * kernel_spatial
                                        + kh * params.kernel_w
                                        + kw;
                                    scalar_acc += *input.add(in_idx) * *weight.add(w_idx);
                                }
                            }
                        }
                        // Horizontal reduce the vector lanes, then add the
                        // scalar part (tail channels + bias).
                        let sum = vaddvq_f64(acc) + scalar_acc;
                        let out_idx = b * params.c_out * out_spatial
                            + out_c * out_spatial
                            + oh * params.output_w
                            + ow;
                        *output.add(out_idx) = sum;
                    }
                }
            }
        }
    }
}
/// NEON-accelerated depthwise 2-D convolution over `f32` data: output
/// channel `c` is produced only from input channel `c` with its own
/// `kernel_h x kernel_w` filter.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, height, width]`
/// - `weight`: `[c_out, kernel_h, kernel_w]`
/// - `output`: `[batch, c_out, output_h, output_w]`
///
/// Vectorizes across the output width: each NEON vector holds 4 adjacent
/// output columns (so seeding the accumulator with the bias is correct here
/// — every lane is a distinct output element). Leftover columns use a
/// scalar loop.
///
/// NOTE(review): input is indexed at channel `c` for `c in 0..c_out`, so
/// this assumes `params.c_in == params.c_out` (channel multiplier 1) —
/// confirm callers never pass a multiplier > 1.
///
/// # Safety
/// All pointers must reference valid buffers sized to match `params`;
/// `bias`, when `Some`, must hold at least `params.c_out` elements. Only
/// call where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn depthwise_conv2d_f32(
    input: *const f32,
    weight: *const f32,
    bias: Option<*const f32>,
    output: *mut f32,
    params: Conv2dParams,
) {
    let lanes = 4;
    let out_w_chunks = params.output_w / lanes;
    let in_spatial = params.height * params.width;
    let out_spatial = params.output_h * params.output_w;
    let kernel_spatial = params.kernel_h * params.kernel_w;
    for b in 0..params.batch {
        for c in 0..params.c_out {
            let bias_val = if let Some(bias_ptr) = bias {
                *bias_ptr.add(c)
            } else {
                0.0
            };
            let v_bias = vdupq_n_f32(bias_val);
            for oh in 0..params.output_h {
                // Vectorized pass: 4 adjacent output columns per iteration.
                for ow_chunk in 0..out_w_chunks {
                    let ow_base = ow_chunk * lanes;
                    let mut acc = v_bias;
                    for kh in 0..params.kernel_h {
                        for kw in 0..params.kernel_w {
                            // One weight broadcast to all 4 output columns.
                            let w_val = *weight.add(c * kernel_spatial + kh * params.kernel_w + kw);
                            let v_w = vdupq_n_f32(w_val);
                            let ih_signed = (oh * params.stride_h) as isize
                                + (kh * params.dilation_h) as isize
                                - params.pad_top as isize;
                            // Row out of bounds: this tap is zero for all lanes.
                            if ih_signed < 0 || (ih_signed as usize) >= params.height {
                                continue;
                            }
                            let ih = ih_signed as usize;
                            // Gather per-lane input samples; lanes whose
                            // column falls in the padding keep the 0.0 default.
                            let mut in_arr = [0.0f32; 4];
                            for lane in 0..lanes {
                                let ow = ow_base + lane;
                                let iw_signed = (ow * params.stride_w) as isize
                                    + (kw * params.dilation_w) as isize
                                    - params.pad_left as isize;
                                if iw_signed >= 0 && (iw_signed as usize) < params.width {
                                    let iw = iw_signed as usize;
                                    let in_idx = b * params.c_in * in_spatial
                                        + c * in_spatial
                                        + ih * params.width
                                        + iw;
                                    in_arr[lane] = *input.add(in_idx);
                                }
                            }
                            let v_in = vld1q_f32(in_arr.as_ptr());
                            acc = vfmaq_f32(acc, v_in, v_w);
                        }
                    }
                    // Store 4 contiguous output columns at once.
                    let out_idx = b * params.c_out * out_spatial
                        + c * out_spatial
                        + oh * params.output_w
                        + ow_base;
                    vst1q_f32(output.add(out_idx), acc);
                }
                // Scalar tail for output columns beyond the last full chunk.
                for ow in (out_w_chunks * lanes)..params.output_w {
                    let mut sum = bias_val;
                    for kh in 0..params.kernel_h {
                        for kw in 0..params.kernel_w {
                            let ih_signed = (oh * params.stride_h) as isize
                                + (kh * params.dilation_h) as isize
                                - params.pad_top as isize;
                            let iw_signed = (ow * params.stride_w) as isize
                                + (kw * params.dilation_w) as isize
                                - params.pad_left as isize;
                            if ih_signed < 0
                                || (ih_signed as usize) >= params.height
                                || iw_signed < 0
                                || (iw_signed as usize) >= params.width
                            {
                                continue;
                            }
                            let ih = ih_signed as usize;
                            let iw = iw_signed as usize;
                            let in_idx = b * params.c_in * in_spatial
                                + c * in_spatial
                                + ih * params.width
                                + iw;
                            let w_idx = c * kernel_spatial + kh * params.kernel_w + kw;
                            sum += *input.add(in_idx) * *weight.add(w_idx);
                        }
                    }
                    let out_idx = b * params.c_out * out_spatial
                        + c * out_spatial
                        + oh * params.output_w
                        + ow;
                    *output.add(out_idx) = sum;
                }
            }
        }
    }
}
/// NEON-accelerated depthwise 2-D convolution over `f64` data: output
/// channel `c` is produced only from input channel `c` with its own
/// `kernel_h x kernel_w` filter.
///
/// Layouts (contiguous, row-major):
/// - `input`:  `[batch, c_in, height, width]`
/// - `weight`: `[c_out, kernel_h, kernel_w]`
/// - `output`: `[batch, c_out, output_h, output_w]`
///
/// Vectorizes across the output width: each NEON vector (f64x2) holds 2
/// adjacent output columns (so seeding the accumulator with the bias is
/// correct here — every lane is a distinct output element). Leftover
/// columns use a scalar loop.
///
/// NOTE(review): input is indexed at channel `c` for `c in 0..c_out`, so
/// this assumes `params.c_in == params.c_out` (channel multiplier 1) —
/// confirm callers never pass a multiplier > 1.
///
/// # Safety
/// All pointers must reference valid buffers sized to match `params`;
/// `bias`, when `Some`, must hold at least `params.c_out` elements. Only
/// call where NEON is available (always true on aarch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn depthwise_conv2d_f64(
    input: *const f64,
    weight: *const f64,
    bias: Option<*const f64>,
    output: *mut f64,
    params: Conv2dParams,
) {
    let lanes = 2;
    let out_w_chunks = params.output_w / lanes;
    let in_spatial = params.height * params.width;
    let out_spatial = params.output_h * params.output_w;
    let kernel_spatial = params.kernel_h * params.kernel_w;
    for b in 0..params.batch {
        for c in 0..params.c_out {
            let bias_val = if let Some(bias_ptr) = bias {
                *bias_ptr.add(c)
            } else {
                0.0
            };
            let v_bias = vdupq_n_f64(bias_val);
            for oh in 0..params.output_h {
                // Vectorized pass: 2 adjacent output columns per iteration.
                for ow_chunk in 0..out_w_chunks {
                    let ow_base = ow_chunk * lanes;
                    let mut acc = v_bias;
                    for kh in 0..params.kernel_h {
                        for kw in 0..params.kernel_w {
                            // One weight broadcast to both output columns.
                            let w_val = *weight.add(c * kernel_spatial + kh * params.kernel_w + kw);
                            let v_w = vdupq_n_f64(w_val);
                            let ih_signed = (oh * params.stride_h) as isize
                                + (kh * params.dilation_h) as isize
                                - params.pad_top as isize;
                            // Row out of bounds: this tap is zero for all lanes.
                            if ih_signed < 0 || (ih_signed as usize) >= params.height {
                                continue;
                            }
                            let ih = ih_signed as usize;
                            // Gather per-lane input samples; lanes whose
                            // column falls in the padding keep the 0.0 default.
                            let mut in_arr = [0.0f64; 2];
                            for lane in 0..lanes {
                                let ow = ow_base + lane;
                                let iw_signed = (ow * params.stride_w) as isize
                                    + (kw * params.dilation_w) as isize
                                    - params.pad_left as isize;
                                if iw_signed >= 0 && (iw_signed as usize) < params.width {
                                    let iw = iw_signed as usize;
                                    let in_idx = b * params.c_in * in_spatial
                                        + c * in_spatial
                                        + ih * params.width
                                        + iw;
                                    in_arr[lane] = *input.add(in_idx);
                                }
                            }
                            let v_in = vld1q_f64(in_arr.as_ptr());
                            acc = vfmaq_f64(acc, v_in, v_w);
                        }
                    }
                    // Store 2 contiguous output columns at once.
                    let out_idx = b * params.c_out * out_spatial
                        + c * out_spatial
                        + oh * params.output_w
                        + ow_base;
                    vst1q_f64(output.add(out_idx), acc);
                }
                // Scalar tail for output columns beyond the last full chunk.
                for ow in (out_w_chunks * lanes)..params.output_w {
                    let mut sum = bias_val;
                    for kh in 0..params.kernel_h {
                        for kw in 0..params.kernel_w {
                            let ih_signed = (oh * params.stride_h) as isize
                                + (kh * params.dilation_h) as isize
                                - params.pad_top as isize;
                            let iw_signed = (ow * params.stride_w) as isize
                                + (kw * params.dilation_w) as isize
                                - params.pad_left as isize;
                            if ih_signed < 0
                                || (ih_signed as usize) >= params.height
                                || iw_signed < 0
                                || (iw_signed as usize) >= params.width
                            {
                                continue;
                            }
                            let ih = ih_signed as usize;
                            let iw = iw_signed as usize;
                            let in_idx = b * params.c_in * in_spatial
                                + c * in_spatial
                                + ih * params.width
                                + iw;
                            let w_idx = c * kernel_spatial + kh * params.kernel_w + kw;
                            sum += *input.add(in_idx) * *weight.add(w_idx);
                        }
                    }
                    let out_idx = b * params.c_out * out_spatial
                        + c * out_spatial
                        + oh * params.output_w
                        + ow;
                    *output.add(out_idx) = sum;
                }
            }
        }
    }
}