Skip to main content

provable_contracts/kernels/
conv1d.rs

1//! 1D Convolution kernel.
2//!
3//! Matches `conv1d-kernel-v1.yaml`.
4//!
5//! Standard 1D convolution with configurable stride and padding.
6//! Input layout: c_in x length (row-major).
7//! Weight layout: c_out x c_in x kernel_size (row-major).
8//! Output layout: c_out x out_length (row-major).
9
10/// Scalar reference implementation of 1D convolution.
11///
12/// Computes the convolution of `input` with `weight` and optional `bias`.
13///
14/// - `input`: flattened `c_in x length`
15/// - `weight`: flattened `c_out x c_in x kernel_size`
16/// - `bias`: optional, length `c_out`
17/// - `output`: flattened `c_out x out_length` where `out_length = (length + 2*padding - kernel_size) / stride + 1`
18///
19/// # Panics
20///
21/// Panics if dimensions are inconsistent or output buffer has wrong length.
22#[allow(clippy::too_many_arguments)]
23pub fn conv1d_scalar(
24    input: &[f32],
25    weight: &[f32],
26    bias: Option<&[f32]>,
27    c_in: usize,
28    c_out: usize,
29    length: usize,
30    kernel_size: usize,
31    stride: usize,
32    padding: usize,
33    output: &mut [f32],
34) {
35    assert_eq!(input.len(), c_in * length, "input length mismatch");
36    assert_eq!(
37        weight.len(),
38        c_out * c_in * kernel_size,
39        "weight length mismatch"
40    );
41    if let Some(b) = bias {
42        assert_eq!(b.len(), c_out, "bias length mismatch");
43    }
44    assert!(stride > 0, "stride must be > 0");
45    let out_length = (length + 2 * padding - kernel_size) / stride + 1;
46    assert_eq!(
47        output.len(),
48        c_out * out_length,
49        "output length mismatch: expected {}",
50        c_out * out_length
51    );
52
53    for oc in 0..c_out {
54        for ol in 0..out_length {
55            let sum = conv1d_output_element(
56                input,
57                weight,
58                c_in,
59                length,
60                kernel_size,
61                stride,
62                padding,
63                oc,
64                ol,
65            );
66            let bias_val = bias.map_or(0.0, |b| b[oc]);
67            output[oc * out_length + ol] = sum + bias_val;
68        }
69    }
70}
71
/// Compute a single output element of the convolution.
///
/// Returns the sum over every input channel and kernel tap of
/// `weight[oc, ic, k] * input[ic, ol * stride + k - padding]` for output
/// channel `oc` and output position `ol`. Taps whose input position is
/// negative or not below `length` fall in the padding and are skipped.
/// No bias is added here.
#[allow(clippy::too_many_arguments)]
fn conv1d_output_element(
    input: &[f32],
    weight: &[f32],
    c_in: usize,
    length: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    oc: usize,
    ol: usize,
) -> f32 {
    // Start of this output channel's `c_in x kernel_size` filter in `weight`.
    let filter_base = oc * c_in * kernel_size;

    let mut acc = 0.0_f32;
    for channel in 0..c_in {
        let weight_row = filter_base + channel * kernel_size;
        let input_row = channel * length;

        for tap in 0..kernel_size {
            // Signed arithmetic: left padding pushes the position below zero.
            let pos = (ol * stride + tap) as isize - padding as isize;
            if pos < 0 {
                continue;
            }
            let pos = pos as usize;
            if pos >= length {
                continue;
            }
            acc += weight[weight_row + tap] * input[input_row + pos];
        }
    }
    acc
}
99
100/// AVX2 implementation of 1D convolution.
101///
102/// Delegates to scalar due to irregular memory access patterns in convolution.
103///
104/// # Safety
105///
106/// Requires AVX2 support on the target CPU.
107///
108/// # Panics
109///
110/// Same as [`conv1d_scalar`].
111#[cfg(target_arch = "x86_64")]
112#[target_feature(enable = "avx2")]
113#[allow(clippy::too_many_arguments)]
114pub unsafe fn conv1d_avx2(
115    input: &[f32],
116    weight: &[f32],
117    bias: Option<&[f32]>,
118    c_in: usize,
119    c_out: usize,
120    length: usize,
121    kernel_size: usize,
122    stride: usize,
123    padding: usize,
124    output: &mut [f32],
125) {
126    conv1d_scalar(
127        input,
128        weight,
129        bias,
130        c_in,
131        c_out,
132        length,
133        kernel_size,
134        stride,
135        padding,
136        output,
137    );
138}
139
/// PTX assembly for the 1D convolution kernel.
///
/// One block per output channel, threads along output length.
/// Each thread computes one output position by summing over input
/// channels and kernel positions.
///
/// The returned text defines a single entry point, `conv1d_kernel`, with
/// parameters `input_ptr`, `weight_ptr`, `bias_ptr`, `output_ptr` (u64) and
/// `c_in`, `length`, `kernel_size`, `stride`, `padding`, `out_length` (u32).
/// Inside the kernel:
///
/// - `%ctaid.x` selects the output channel and `%tid.x` the output position;
///   a thread whose index is `>= out_length` returns without storing, and
///   each thread handles exactly one position.
/// - Taps whose input position is negative or `>= length` are skipped.
/// - A `bias_ptr` of 0 skips the bias add.
///
/// The thread-index register is named `%tid_x` rather than `%tid`, because
/// `%tid` is one of PTX's predefined special-register names.
pub fn conv1d_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// Conv1D kernel: 1 block per output channel, threads along output length.
// Params: input_ptr, weight_ptr, bias_ptr, output_ptr,
//         c_in, length, kernel_size, stride, padding, out_length
.visible .entry conv1d_kernel(
    .param .u64 input_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 output_ptr,
    .param .u32 c_in,
    .param .u32 length,
    .param .u32 kernel_size,
    .param .u32 stride,
    .param .u32 padding,
    .param .u32 out_length
)
{
    .reg .u32 %tid_x, %oc, %ol, %ic, %k, %c_in, %len, %ks, %str, %pad, %olen;
    .reg .u32 %in_pos, %w_idx, %i_idx, %tmp, %w_base_oc;
    .reg .s32 %in_pos_signed;
    .reg .u64 %in_base, %w_base, %b_base, %out_base, %addr;
    .reg .f32 %sum, %wval, %ival, %bval;
    .reg .pred %p, %p_lo, %p_hi;

    mov.u32 %oc, %ctaid.x;
    mov.u32 %tid_x, %tid.x;
    ld.param.u32 %olen, [out_length];
    setp.ge.u32 %p, %tid_x, %olen;
    @%p bra DONE;

    mov.u32 %ol, %tid_x;
    ld.param.u64 %in_base, [input_ptr];
    ld.param.u64 %w_base, [weight_ptr];
    ld.param.u64 %b_base, [bias_ptr];
    ld.param.u64 %out_base, [output_ptr];
    ld.param.u32 %c_in, [c_in];
    ld.param.u32 %len, [length];
    ld.param.u32 %ks, [kernel_size];
    ld.param.u32 %str, [stride];
    ld.param.u32 %pad, [padding];

    mov.f32 %sum, 0f00000000;

    // weight base for this output channel: oc * c_in * kernel_size
    mul.lo.u32 %w_base_oc, %oc, %c_in;
    mul.lo.u32 %w_base_oc, %w_base_oc, %ks;

    mov.u32 %ic, 0;
IC_LOOP:
    setp.ge.u32 %p, %ic, %c_in;
    @%p bra IC_DONE;

    mov.u32 %k, 0;
K_LOOP:
    setp.ge.u32 %p, %k, %ks;
    @%p bra K_DONE;

    // in_pos_signed = ol * stride + k - padding
    mul.lo.u32 %in_pos, %ol, %str;
    add.u32 %in_pos, %in_pos, %k;
    sub.s32 %in_pos_signed, %in_pos, %pad;
    setp.lt.s32 %p_lo, %in_pos_signed, 0;
    @%p_lo bra SKIP;
    mov.u32 %in_pos, %in_pos_signed;
    setp.ge.u32 %p_hi, %in_pos, %len;
    @%p_hi bra SKIP;

    // w_idx = w_base_oc + ic * ks + k
    mul.lo.u32 %w_idx, %ic, %ks;
    add.u32 %w_idx, %w_idx, %w_base_oc;
    add.u32 %w_idx, %w_idx, %k;
    // Load weight
    cvt.u64.u32 %addr, %w_idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %w_base, %addr;
    ld.global.f32 %wval, [%addr];

    // i_idx = ic * length + in_pos
    mul.lo.u32 %i_idx, %ic, %len;
    add.u32 %i_idx, %i_idx, %in_pos;
    cvt.u64.u32 %addr, %i_idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %in_base, %addr;
    ld.global.f32 %ival, [%addr];

    fma.rn.f32 %sum, %wval, %ival, %sum;

SKIP:
    add.u32 %k, %k, 1;
    bra K_LOOP;
K_DONE:
    add.u32 %ic, %ic, 1;
    bra IC_LOOP;
IC_DONE:

    // Add bias if present (bias_ptr != 0)
    setp.eq.u64 %p, %b_base, 0;
    @%p bra STORE;
    cvt.u64.u32 %addr, %oc;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %b_base, %addr;
    ld.global.f32 %bval, [%addr];
    add.f32 %sum, %sum, %bval;

STORE:
    // output[oc * out_length + ol]
    mul.lo.u32 %tmp, %oc, %olen;
    add.u32 %tmp, %tmp, %ol;
    cvt.u64.u32 %addr, %tmp;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %out_base, %addr;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
"#
}
267
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // Scalar tests
    // ---------------------------------------------------------------

    /// A width-1 kernel whose weights form the identity matrix must pass the
    /// input through unchanged.
    #[test]
    fn test_conv1d_identity() {
        let (c_in, c_out, length) = (2, 2, 4);
        let (kernel_size, stride, padding) = (1, 1, 0);
        let out_length = (length + 2 * padding - kernel_size) / stride + 1;

        // Two channels of four samples: [[1,2,3,4],[5,6,7,8]].
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // 2x2x1 identity: [[1,0],[0,1]].
        let weight = [1.0_f32, 0.0, 0.0, 1.0];
        let mut output = vec![0.0_f32; c_out * out_length];

        conv1d_scalar(
            &input,
            &weight,
            None,
            c_in,
            c_out,
            length,
            kernel_size,
            stride,
            padding,
            &mut output,
        );

        assert_eq!(output, input);
    }

    /// Single channel in and out: input = [1, 2, 3, 4, 5], kernel = [1, 0, -1],
    /// stride = 1, padding = 0.
    #[test]
    fn test_conv1d_known_values() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32, 0.0, -1.0];
        // Each window [a, b, c] yields a - c:
        //   [1,2,3] -> -2, [2,3,4] -> -2, [3,4,5] -> -2
        let expected = [-2.0_f32, -2.0, -2.0];
        let mut output = vec![0.0_f32; expected.len()];

        conv1d_scalar(&input, &weight, None, 1, 1, 5, 3, 1, 0, &mut output);

        assert_eq!(output, expected);
    }

    /// The bias is added once to every output element.
    #[test]
    fn test_conv1d_with_bias() {
        let input = [1.0_f32, 2.0, 3.0];
        let weight = [1.0_f32, 1.0];
        let bias = [10.0_f32];
        // (1 + 2) + 10 = 13, (2 + 3) + 10 = 15
        let expected = [13.0_f32, 15.0];
        let mut output = vec![0.0_f32; expected.len()];

        conv1d_scalar(&input, &weight, Some(&bias), 1, 1, 3, 2, 1, 0, &mut output);

        assert_eq!(output, expected);
    }

    /// input = [1, 2, 3], kernel = [1, 1, 1], padding = 1, stride = 1.
    #[test]
    fn test_conv1d_with_padding() {
        let input = [1.0_f32, 2.0, 3.0];
        let weight = [1.0_f32, 1.0, 1.0];
        // Windows over the zero-padded signal [0, 1, 2, 3, 0]:
        //   [0,1,2] -> 3, [1,2,3] -> 6, [2,3,0] -> 5
        let expected = [3.0_f32, 6.0, 5.0];
        let mut output = vec![0.0_f32; expected.len()];

        conv1d_scalar(&input, &weight, None, 1, 1, 3, 3, 1, 1, &mut output);

        assert_eq!(output, expected);
    }

    /// A width-1 unit kernel with stride 2 picks every other sample.
    #[test]
    fn test_conv1d_with_stride() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32];
        let expected = [1.0_f32, 3.0, 5.0];
        let mut output = vec![0.0_f32; expected.len()];

        conv1d_scalar(&input, &weight, None, 1, 1, 5, 1, 2, 0, &mut output);

        assert_eq!(output, expected);
    }

    /// Five input samples cannot satisfy `c_in = 2, length = 5`.
    #[test]
    #[should_panic(expected = "input length mismatch")]
    fn test_conv1d_input_mismatch() {
        let input = [1.0_f32; 5];
        let weight = [1.0_f32; 3];
        let mut output = [0.0_f32; 3];
        conv1d_scalar(&input, &weight, None, 2, 1, 5, 3, 1, 0, &mut output);
    }

    // ---------------------------------------------------------------
    // AVX2 tests
    // ---------------------------------------------------------------

    /// The AVX2 entry point must agree with the scalar reference.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_conv1d_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let weight = [1.0_f32, 0.0, -1.0];
        let mut scalar_out = vec![0.0_f32; 3];
        let mut avx2_out = vec![0.0_f32; 3];

        conv1d_scalar(&input, &weight, None, 1, 1, 5, 3, 1, 0, &mut scalar_out);
        // SAFETY: AVX2 availability was verified at the top of this test.
        unsafe {
            conv1d_avx2(&input, &weight, None, 1, 1, 5, 3, 1, 0, &mut avx2_out);
        }

        assert_eq!(scalar_out, avx2_out);
    }

    // ---------------------------------------------------------------
    // PTX structural tests
    // ---------------------------------------------------------------

    #[test]
    fn test_conv1d_ptx_version() {
        assert!(
            conv1d_ptx().contains(".version 8.5"),
            "PTX must declare .version 8.5"
        );
    }

    #[test]
    fn test_conv1d_ptx_target() {
        assert!(
            conv1d_ptx().contains(".target sm_90"),
            "PTX must target sm_90"
        );
    }

    #[test]
    fn test_conv1d_ptx_entry() {
        assert!(
            conv1d_ptx().contains(".entry conv1d_kernel"),
            "PTX must have .entry"
        );
    }

    #[test]
    fn test_conv1d_ptx_ret() {
        assert!(conv1d_ptx().contains("ret;"), "PTX must have ret;");
    }

    #[test]
    fn test_conv1d_ptx_balanced_braces() {
        let ptx = conv1d_ptx();
        let opens = ptx.matches('{').count();
        let closes = ptx.matches('}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }
}
443}