// numrs/backend/cpu/simd_conv.rs
use crate::array::Array;
2pub fn conv1d_simd(
4 input: &Array,
5 weight: &Array,
6 bias: Option<&Array>,
7 stride: usize,
8 padding: usize,
9) -> anyhow::Result<Array> {
10 let batch_size = input.shape[0];
15 let in_channels = input.shape[1];
16 let in_length = input.shape[2];
17
18 let out_channels = weight.shape[0];
19 let kernel_size = weight.shape[2];
20
21 let out_length = (in_length + 2 * padding - kernel_size) / stride + 1;
23
24 let _output_shape = vec![batch_size, out_channels, out_length];
25
26 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
27 {
28 if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
29 unsafe {
30 return conv1d_avx2_fma(
31 input,
32 weight,
33 bias,
34 batch_size,
35 in_channels,
36 in_length,
37 out_channels,
38 kernel_size,
39 out_length,
40 stride,
41 padding,
42 );
43 }
44 }
45 }
46
47 crate::backend::cpu::conv::conv1d_naive(input, weight, bias, stride, padding)
49}
50
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
/// AVX2+FMA 1-D convolution kernel, parallelized over (batch, out_channel)
/// pairs with rayon. Computes 8 output positions per iteration of the main
/// loop, with a scalar tail for the remaining `out_length % 8` positions.
///
/// # Safety
/// - The caller must have verified at runtime that the CPU supports AVX2
///   and FMA (this function is `#[target_feature]`-gated).
/// - The dimension arguments must describe the actual `input`/`weight`
///   buffers: the inner loops index both with `get_unchecked`, so any
///   mismatch is undefined behavior. NOTE(review): bias length is assumed
///   to be `out_channels` when `bias` is `Some` — confirm callers validate.
unsafe fn conv1d_avx2_fma(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    batch_size: usize,
    in_channels: usize,
    in_length: usize,
    out_channels: usize,
    kernel_size: usize,
    out_length: usize,
    stride: usize,
    padding: usize,
) -> anyhow::Result<Array> {
    use rayon::prelude::*;
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut output = Array::zeros(vec![batch_size, out_channels, out_length]);
    // Smuggle the output pointer across rayon's `Send` boundary as a plain
    // integer (raw pointers are not `Send`). This is sound only because each
    // (b_idx, oc) task writes a disjoint `out_length`-long region of the
    // output buffer, so no two threads ever alias.
    let base_ptr_addr = output.data.as_mut_ptr() as usize;

    // Materialize a dense bias vector (zeros when absent) so the hot loops
    // are branch-free on `Option`.
    let bias_data = if let Some(b) = bias {
        b.data.clone()
    } else {
        vec![0.0; out_channels]
    };

    let bias_slice = bias_data.as_slice();

    // Nested parallelism: one rayon task per (batch, output-channel) pair.
    (0..batch_size).into_par_iter().for_each(move |b_idx| {
        (0..out_channels).into_par_iter().for_each(move |oc| {
            let out_ptr = base_ptr_addr as *mut f32;
            // Flat offset of this task's private output row.
            let out_offset = b_idx * out_channels * out_length + oc * out_length;

            // Flat base offsets into the [oc][ic][k] weight layout and the
            // [b][ic][t] input layout.
            let weight_offset_base = oc * in_channels * kernel_size;

            let input_offset_base = b_idx * in_channels * in_length;

            // Accumulators start from the bias, broadcast across 8 lanes.
            let bias_val = bias_slice[oc];
            let v_bias = _mm256_set1_ps(bias_val);

            // Main loop: 8 consecutive output positions per iteration.
            let mut o_idx = 0;
            while o_idx + 8 <= out_length {
                let mut v_acc = v_bias;

                for ic in 0..in_channels {
                    let w_start = weight_offset_base + ic * kernel_size;
                    let in_start = input_offset_base + ic * in_length;

                    for k in 0..kernel_size {
                        // One weight tap broadcast to all 8 lanes.
                        let w_val = *weight.data.get_unchecked(w_start + k);
                        let v_w = _mm256_set1_ps(w_val);

                        // Gather the 8 input samples scalar-wise: with
                        // stride/padding the taps are not contiguous, and
                        // positions falling in the (virtual, zero-valued)
                        // padding region contribute 0.
                        let mut loaded_inputs = [0.0f32; 8];
                        for lane in 0..8 {
                            let cur_out_idx = o_idx + lane;
                            // Signed arithmetic so negative (left-padding)
                            // indices are representable and rejected below.
                            let input_idx_signed =
                                (cur_out_idx * stride) as isize + k as isize - padding as isize;

                            if input_idx_signed >= 0 && input_idx_signed < in_length as isize {
                                loaded_inputs[lane] = *input
                                    .data
                                    .get_unchecked(in_start + input_idx_signed as usize);
                            } else {
                                loaded_inputs[lane] = 0.0;
                            }
                        }

                        // acc[lane] += in[lane] * w, fused multiply-add.
                        let v_in = _mm256_loadu_ps(loaded_inputs.as_ptr());
                        v_acc = _mm256_fmadd_ps(v_in, v_w, v_acc);
                    }
                }

                // Store 8 finished outputs; unaligned store since the row
                // offset is not guaranteed 32-byte aligned.
                let dst_ptr = out_ptr.add(out_offset + o_idx);
                _mm256_storeu_ps(dst_ptr, v_acc);

                o_idx += 8;
            }

            // Scalar tail: remaining out_length % 8 positions, same math.
            for i in o_idx..out_length {
                let mut acc = bias_val;
                for ic in 0..in_channels {
                    let w_start = weight_offset_base + ic * kernel_size;
                    let in_start = input_offset_base + ic * in_length;

                    for k in 0..kernel_size {
                        let input_idx_signed =
                            (i * stride) as isize + k as isize - padding as isize;
                        if input_idx_signed >= 0 && input_idx_signed < in_length as isize {
                            let val = *input
                                .data
                                .get_unchecked(in_start + input_idx_signed as usize);
                            let w = *weight.data.get_unchecked(w_start + k);
                            acc += val * w;
                        }
                    }
                }
                *out_ptr.add(out_offset + i) = acc;
            }
        });
    });

    Ok(output)
}