1#[allow(clippy::too_many_arguments)]
23pub fn conv1d_scalar(
24 input: &[f32],
25 weight: &[f32],
26 bias: Option<&[f32]>,
27 c_in: usize,
28 c_out: usize,
29 length: usize,
30 kernel_size: usize,
31 stride: usize,
32 padding: usize,
33 output: &mut [f32],
34) {
35 assert_eq!(input.len(), c_in * length, "input length mismatch");
36 assert_eq!(
37 weight.len(),
38 c_out * c_in * kernel_size,
39 "weight length mismatch"
40 );
41 if let Some(b) = bias {
42 assert_eq!(b.len(), c_out, "bias length mismatch");
43 }
44 assert!(stride > 0, "stride must be > 0");
45 let out_length = (length + 2 * padding - kernel_size) / stride + 1;
46 assert_eq!(
47 output.len(),
48 c_out * out_length,
49 "output length mismatch: expected {}",
50 c_out * out_length
51 );
52
53 for oc in 0..c_out {
54 for ol in 0..out_length {
55 let sum = conv1d_output_element(
56 input,
57 weight,
58 c_in,
59 length,
60 kernel_size,
61 stride,
62 padding,
63 oc,
64 ol,
65 );
66 let bias_val = bias.map_or(0.0, |b| b[oc]);
67 output[oc * out_length + ol] = sum + bias_val;
68 }
69 }
70}
71
/// Computes one un-biased output value: the sum over every input channel and
/// kernel tap of `weight * input` for output channel `oc` at output position `ol`.
///
/// Tap `k` of the window reads input position `ol * stride + k - padding`.
/// Positions before the start of the input or at/after `length` are treated
/// as zero, i.e. they are skipped.
///
/// `input` and `weight` are indexed as `ic * length + pos` and
/// `oc * c_in * kernel_size + ic * kernel_size + k` respectively; indexing
/// panics if the slices are shorter than those layouts require.
#[allow(clippy::too_many_arguments)]
fn conv1d_output_element(
    input: &[f32],
    weight: &[f32],
    c_in: usize,
    length: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    oc: usize,
    ol: usize,
) -> f32 {
    // Loop-invariant offsets: start of the window (in padded coordinates) and
    // start of this output channel's weights.
    let window_start = ol * stride;
    let oc_weights = oc * c_in * kernel_size;

    let mut sum = 0.0_f32;
    for ic in 0..c_in {
        let weight_row = oc_weights + ic * kernel_size;
        let input_row = ic * length;
        for k in 0..kernel_size {
            // `checked_sub` yields `None` for taps in the left padding, and the
            // filter drops taps in the right padding. This stays in unsigned
            // arithmetic, so no value is reinterpreted through a signed cast.
            let tap = (window_start + k)
                .checked_sub(padding)
                .filter(|&pos| pos < length);
            if let Some(in_pos) = tap {
                sum += weight[weight_row + k] * input[input_row + in_pos];
            }
        }
    }
    sum
}
99
100#[cfg(target_arch = "x86_64")]
112#[target_feature(enable = "avx2")]
113#[allow(clippy::too_many_arguments)]
114pub unsafe fn conv1d_avx2(
115 input: &[f32],
116 weight: &[f32],
117 bias: Option<&[f32]>,
118 c_in: usize,
119 c_out: usize,
120 length: usize,
121 kernel_size: usize,
122 stride: usize,
123 padding: usize,
124 output: &mut [f32],
125) {
126 conv1d_scalar(
127 input,
128 weight,
129 bias,
130 c_in,
131 c_out,
132 length,
133 kernel_size,
134 stride,
135 padding,
136 output,
137 );
138}
139
/// Returns the PTX source (`.version 8.5`, `.target sm_90`) of the
/// `conv1d_kernel` entry point.
///
/// Kernel contract, as implemented by the PTX below:
///
/// * Block index `%ctaid.x` selects the output channel. The kernel receives no
///   `c_out` parameter and performs no bound check on it, so launch exactly
///   one block per output channel.
/// * Within a block, thread `t` computes output positions `t`, `t + %ntid.x`,
///   `t + 2 * %ntid.x`, ... below `out_length`, so `out_length` may exceed the
///   block size.
/// * `input_ptr`, `weight_ptr` and `output_ptr` are accessed with
///   `ld.global.f32` / `st.global.f32` using the same flat indexing as
///   [`conv1d_scalar`]; `bias_ptr` may be `0`, in which case no bias is added.
/// * All index arithmetic is 32-bit.
pub fn conv1d_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// Conv1D kernel: 1 block per output channel (%ctaid.x). The threads of a
// block share the output positions: each thread starts at %tid.x and advances
// by %ntid.x until it reaches out_length.
// Params: input_ptr, weight_ptr, bias_ptr, output_ptr,
//         c_in, length, kernel_size, stride, padding, out_length
// bias_ptr == 0 means "no bias".
.visible .entry conv1d_kernel(
    .param .u64 input_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 output_ptr,
    .param .u32 c_in,
    .param .u32 length,
    .param .u32 kernel_size,
    .param .u32 stride,
    .param .u32 padding,
    .param .u32 out_length
)
{
    // User registers must not reuse the names of the predefined special
    // registers (%tid, %ntid, %ctaid, ...), hence the _x suffixes.
    .reg .u32 %tid_x, %ntid_x, %oc, %ol, %ic, %k, %c_in, %len, %ks, %str, %pad, %olen;
    .reg .u32 %in_pos, %w_idx, %i_idx, %tmp, %w_base_oc;
    .reg .s32 %in_pos_signed;
    .reg .u64 %in_base, %w_base, %b_base, %out_base, %addr;
    .reg .f32 %sum, %wval, %ival, %bval;
    .reg .pred %p, %p_lo, %p_hi;

    mov.u32 %oc, %ctaid.x;
    mov.u32 %tid_x, %tid.x;
    mov.u32 %ntid_x, %ntid.x;
    ld.param.u32 %olen, [out_length];
    setp.ge.u32 %p, %tid_x, %olen;
    @%p bra DONE;

    mov.u32 %ol, %tid_x;
    ld.param.u64 %in_base, [input_ptr];
    ld.param.u64 %w_base, [weight_ptr];
    ld.param.u64 %b_base, [bias_ptr];
    ld.param.u64 %out_base, [output_ptr];
    ld.param.u32 %c_in, [c_in];
    ld.param.u32 %len, [length];
    ld.param.u32 %ks, [kernel_size];
    ld.param.u32 %str, [stride];
    ld.param.u32 %pad, [padding];

    // weight base for this output channel: oc * c_in * kernel_size
    mul.lo.u32 %w_base_oc, %oc, %c_in;
    mul.lo.u32 %w_base_oc, %w_base_oc, %ks;

OL_LOOP:
    // One iteration per output position owned by this thread.
    mov.f32 %sum, 0f00000000;

    mov.u32 %ic, 0;
IC_LOOP:
    setp.ge.u32 %p, %ic, %c_in;
    @%p bra IC_DONE;

    mov.u32 %k, 0;
K_LOOP:
    setp.ge.u32 %p, %k, %ks;
    @%p bra K_DONE;

    // in_pos_signed = ol * stride + k - padding
    mul.lo.u32 %in_pos, %ol, %str;
    add.u32 %in_pos, %in_pos, %k;
    sub.s32 %in_pos_signed, %in_pos, %pad;
    // Taps in the left padding (negative) or right padding (>= length) add 0.
    setp.lt.s32 %p_lo, %in_pos_signed, 0;
    @%p_lo bra SKIP;
    mov.u32 %in_pos, %in_pos_signed;
    setp.ge.u32 %p_hi, %in_pos, %len;
    @%p_hi bra SKIP;

    // w_idx = w_base_oc + ic * ks + k
    mul.lo.u32 %w_idx, %ic, %ks;
    add.u32 %w_idx, %w_idx, %w_base_oc;
    add.u32 %w_idx, %w_idx, %k;
    // Load weight (byte offset = index * 4)
    cvt.u64.u32 %addr, %w_idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %w_base, %addr;
    ld.global.f32 %wval, [%addr];

    // i_idx = ic * length + in_pos
    mul.lo.u32 %i_idx, %ic, %len;
    add.u32 %i_idx, %i_idx, %in_pos;
    cvt.u64.u32 %addr, %i_idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %in_base, %addr;
    ld.global.f32 %ival, [%addr];

    fma.rn.f32 %sum, %wval, %ival, %sum;

SKIP:
    add.u32 %k, %k, 1;
    bra K_LOOP;
K_DONE:
    add.u32 %ic, %ic, 1;
    bra IC_LOOP;
IC_DONE:

    // Add bias if present (bias_ptr != 0)
    setp.eq.u64 %p, %b_base, 0;
    @%p bra STORE;
    cvt.u64.u32 %addr, %oc;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %b_base, %addr;
    ld.global.f32 %bval, [%addr];
    add.f32 %sum, %sum, %bval;

STORE:
    // output[oc * out_length + ol]
    mul.lo.u32 %tmp, %oc, %olen;
    add.u32 %tmp, %tmp, %ol;
    cvt.u64.u32 %addr, %tmp;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %out_base, %addr;
    st.global.f32 [%addr], %sum;

    // Advance to this thread's next output position, if any.
    add.u32 %ol, %ol, %ntid_x;
    setp.lt.u32 %p, %ol, %olen;
    @%p bra OL_LOOP;

DONE:
    ret;
}
"#
}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates an output buffer of the size implied by the dimensions, runs
    /// [`conv1d_scalar`] on it and returns the filled buffer.
    #[allow(clippy::too_many_arguments)]
    fn run_scalar(
        input: &[f32],
        weight: &[f32],
        bias: Option<&[f32]>,
        c_in: usize,
        c_out: usize,
        length: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
    ) -> Vec<f32> {
        let out_length = (length + 2 * padding - kernel_size) / stride + 1;
        let mut output = vec![0.0_f32; c_out * out_length];
        conv1d_scalar(
            input,
            weight,
            bias,
            c_in,
            c_out,
            length,
            kernel_size,
            stride,
            padding,
            &mut output,
        );
        output
    }

    /// Asserts that the PTX text contains `needle`, failing with `message`.
    fn assert_ptx_contains(needle: &str, message: &str) {
        assert!(conv1d_ptx().contains(needle), "{}", message);
    }

    /// A 1-wide kernel whose weights form the 2x2 identity over channels must
    /// reproduce the input unchanged.
    #[test]
    fn test_conv1d_identity() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let weight = [1.0_f32, 0.0, 0.0, 1.0];

        let output = run_scalar(&input, &weight, None, 2, 2, 4, 1, 1, 0);

        assert_eq!(output, input.to_vec());
    }

    /// Kernel `[1, 0, -1]` over `1..=5` yields a constant difference of -2.
    #[test]
    fn test_conv1d_known_values() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32, 0.0, -1.0];

        let output = run_scalar(&input, &weight, None, 1, 1, 5, 3, 1, 0);

        assert_eq!(output, vec![-2.0, -2.0, -2.0]);
    }

    /// The per-channel bias is added to every output position.
    #[test]
    fn test_conv1d_with_bias() {
        let input = [1.0_f32, 2.0, 3.0];
        let weight = [1.0_f32, 1.0];
        let bias = [10.0_f32];

        let output = run_scalar(&input, &weight, Some(&bias), 1, 1, 3, 2, 1, 0);

        assert_eq!(output, vec![13.0, 15.0]);
    }

    /// With padding 1 the edge windows see one zero on the outside.
    #[test]
    fn test_conv1d_with_padding() {
        let input = [1.0_f32, 2.0, 3.0];
        let weight = [1.0_f32, 1.0, 1.0];

        let output = run_scalar(&input, &weight, None, 1, 1, 3, 3, 1, 1);

        assert_eq!(output, vec![3.0, 6.0, 5.0]);
    }

    /// Stride 2 with a 1-wide unit kernel picks every other sample.
    #[test]
    fn test_conv1d_with_stride() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32];

        let output = run_scalar(&input, &weight, None, 1, 1, 5, 1, 2, 0);

        assert_eq!(output, vec![1.0, 3.0, 5.0]);
    }

    /// Five samples cannot satisfy `c_in = 2, length = 5`.
    #[test]
    #[should_panic(expected = "input length mismatch")]
    fn test_conv1d_input_mismatch() {
        let input = [1.0_f32; 5];
        let weight = [1.0_f32; 3];
        let mut output = [0.0_f32; 3];
        conv1d_scalar(&input, &weight, None, 2, 1, 5, 3, 1, 0, &mut output);
    }

    /// The AVX2 entry point must agree with the scalar reference.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_conv1d_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32, 0.0, -1.0];

        let scalar_out = run_scalar(&input, &weight, None, 1, 1, 5, 3, 1, 0);

        let mut avx2_out = vec![0.0_f32; 3];
        // SAFETY: AVX2 support was confirmed at runtime just above.
        unsafe {
            conv1d_avx2(&input, &weight, None, 1, 1, 5, 3, 1, 0, &mut avx2_out);
        }

        assert_eq!(scalar_out, avx2_out);
    }

    #[test]
    fn test_conv1d_ptx_version() {
        assert_ptx_contains(".version 8.5", "PTX must declare .version 8.5");
    }

    #[test]
    fn test_conv1d_ptx_target() {
        assert_ptx_contains(".target sm_90", "PTX must target sm_90");
    }

    #[test]
    fn test_conv1d_ptx_entry() {
        assert_ptx_contains(".entry conv1d_kernel", "PTX must have .entry");
    }

    #[test]
    fn test_conv1d_ptx_ret() {
        assert_ptx_contains("ret;", "PTX must have ret;");
    }

    /// Every `{` in the PTX text must have a matching `}`.
    #[test]
    fn test_conv1d_ptx_balanced_braces() {
        let ptx = conv1d_ptx();
        let opens = ptx.matches('{').count();
        let closes = ptx.matches('}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }
}