// numrs/backend/cpu/simd_conv.rs

1use crate::array::Array;
2/// SIMD implementation of Conv1D
3pub fn conv1d_simd(
4    input: &Array,
5    weight: &Array,
6    bias: Option<&Array>,
7    stride: usize,
8    padding: usize,
9) -> anyhow::Result<Array> {
10    // Input: [Batch, InChannels, InLength]
11    // Weight: [OutChannels, InChannels, KernelSize]
12    // Output: [Batch, OutChannels, OutLength]
13
14    let batch_size = input.shape[0];
15    let in_channels = input.shape[1];
16    let in_length = input.shape[2];
17
18    let out_channels = weight.shape[0];
19    let kernel_size = weight.shape[2];
20
21    // Output length: (InLength + 2*Padding - KernelSize) / Stride + 1
22    let out_length = (in_length + 2 * padding - kernel_size) / stride + 1;
23
24    let _output_shape = vec![batch_size, out_channels, out_length];
25
26    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
27    {
28        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
29            unsafe {
30                return conv1d_avx2_fma(
31                    input,
32                    weight,
33                    bias,
34                    batch_size,
35                    in_channels,
36                    in_length,
37                    out_channels,
38                    kernel_size,
39                    out_length,
40                    stride,
41                    padding,
42                );
43            }
44        }
45    }
46
47    // Fallback to naive if SIMD not available
48    crate::backend::cpu::conv::conv1d_naive(input, weight, bias, stride, padding)
49}
50
/// AVX2+FMA Conv1D kernel, parallelized over (batch, out_channel) pairs.
///
/// Layouts (row-major, f32):
/// - `input`:  `[batch_size, in_channels, in_length]`
/// - `weight`: `[out_channels, in_channels, kernel_size]`
/// - output:   `[batch_size, out_channels, out_length]`
///
/// Each rayon task computes one `[out_length]` row of the output: 8 output
/// positions at a time via FMA, then a scalar tail loop for the remainder.
///
/// # Safety
/// - The caller must have verified at runtime that the CPU supports AVX2 and
///   FMA (e.g. via `is_x86_feature_detected!`), since this function is
///   compiled with those features enabled.
/// - The dimension arguments must match the actual lengths of `input.data`,
///   `weight.data`, and (if present) `bias.data`: indexing below uses
///   `get_unchecked`, so inconsistent dimensions are undefined behavior.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn conv1d_avx2_fma(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    batch_size: usize,
    in_channels: usize,
    in_length: usize,
    out_channels: usize,
    kernel_size: usize,
    out_length: usize,
    stride: usize,
    padding: usize,
) -> anyhow::Result<Array> {
    use rayon::prelude::*;
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let mut output = Array::zeros(vec![batch_size, out_channels, out_length]);
    // Safety strategy: Cast pointer to usize (integer) to pass through Rayon boundaries.
    // Integers are always Send + Sync. We cast back to pointer inside the thread.
    // Soundness relies on each (b_idx, oc) task writing a disjoint [out_length]
    // slice of `output` (see out_offset below), so no two tasks ever alias.
    let base_ptr_addr = output.data.as_mut_ptr() as usize;

    // Pre-process bias if present: materialize a dense vector so the inner
    // loops can index it unconditionally (zeros when no bias was supplied).
    let bias_data = if let Some(b) = bias {
        b.data.clone()
    } else {
        vec![0.0; out_channels]
    };

    let bias_slice = bias_data.as_slice();

    // Parallelize over batch and output channels
    // Each thread handles one (batch, out_channel) slice outputting [out_length]
    (0..batch_size).into_par_iter().for_each(move |b_idx| {
        (0..out_channels).into_par_iter().for_each(move |oc| {
            let out_ptr = base_ptr_addr as *mut f32;
            // Get pointers relative to this task
            // Output start: b_idx * out_channels * out_length + oc * out_length
            let out_offset = b_idx * out_channels * out_length + oc * out_length;

            // Weight start: oc * in_channels * kernel_size
            let weight_offset_base = oc * in_channels * kernel_size;

            // Input start: b_idx * in_channels * in_length
            let input_offset_base = b_idx * in_channels * in_length;

            // Accumulators start from the bias value (0.0 when bias is None).
            let bias_val = bias_slice[oc];
            let v_bias = _mm256_set1_ps(bias_val);

            // Iterate over output sequence vectorized (8 elements at a time)
            let mut o_idx = 0;
            while o_idx + 8 <= out_length {
                let mut v_acc = v_bias;

                // Convolve over Kernel and InChannels
                for ic in 0..in_channels {
                    let w_start = weight_offset_base + ic * kernel_size;
                    let in_start = input_offset_base + ic * in_length;

                    for k in 0..kernel_size {
                        // SAFETY: caller guarantees weight dims match weight.data.
                        let w_val = *weight.data.get_unchecked(w_start + k);
                        let v_w = _mm256_set1_ps(w_val);

                        // Input indices for the 8 output positions
                        // input_idx = (o_idx + shift) * stride + k - padding
                        // We need to handle padding checks potentially.
                        // If padding is 0 and stride is 1 (common case), it simplifies.

                        // Vectorized load of input is tricky with stride/padding.
                        // For generic case, let's gather or scalar load into vector.

                        // Construct the 8 input values manually for FMA
                        // This is "vectorized intrinsics" but partly scalar loads due to arbitrary stride
                        let mut loaded_inputs = [0.0f32; 8];
                        for lane in 0..8 {
                            let cur_out_idx = o_idx + lane;
                            // Signed arithmetic so positions inside the zero
                            // padding (index < 0) are representable.
                            let input_idx_signed =
                                (cur_out_idx * stride) as isize + k as isize - padding as isize;

                            if input_idx_signed >= 0 && input_idx_signed < in_length as isize {
                                // SAFETY: bounds checked against in_length just above.
                                loaded_inputs[lane] = *input
                                    .data
                                    .get_unchecked(in_start + input_idx_signed as usize);
                            } else {
                                // Position falls in the zero padding.
                                loaded_inputs[lane] = 0.0;
                            }
                        }

                        let v_in = _mm256_loadu_ps(loaded_inputs.as_ptr());
                        v_acc = _mm256_fmadd_ps(v_in, v_w, v_acc);
                    }
                }

                // Store result
                // Safe because we are the only ones writing to this slice of output
                let dst_ptr = out_ptr.add(out_offset + o_idx);
                _mm256_storeu_ps(dst_ptr, v_acc);

                o_idx += 8;
            }

            // Handle remaining output elements (scalar tail, < 8 positions)
            for i in o_idx..out_length {
                let mut acc = bias_val;
                for ic in 0..in_channels {
                    let w_start = weight_offset_base + ic * kernel_size;
                    let in_start = input_offset_base + ic * in_length;

                    for k in 0..kernel_size {
                        let input_idx_signed =
                            (i * stride) as isize + k as isize - padding as isize;
                        if input_idx_signed >= 0 && input_idx_signed < in_length as isize {
                            // SAFETY: bounds checked against in_length just above;
                            // weight index valid per the caller's dimension contract.
                            let val = *input
                                .data
                                .get_unchecked(in_start + input_idx_signed as usize);
                            let w = *weight.data.get_unchecked(w_start + k);
                            acc += val * w;
                        }
                    }
                }
                // SAFETY: this task exclusively owns output[out_offset..out_offset + out_length].
                *out_ptr.add(out_offset + i) = acc;
            }
        });
    });

    Ok(output)
}