Skip to main content

provable_contracts/kernels/
linear.rs

1//! Linear projection kernel.
2//!
3//! Matches `linear-projection-v1.yaml`.
4//! `y = xW^T + b` — matrix multiply with optional bias.
5//!
6//! This module provides three backends:
7//! - `fn linear_scalar(...)` -- Pure Rust scalar reference (ground truth)
8//! - `unsafe fn linear_avx2(...)` -- AVX2 entry point (currently delegates to the scalar reference)
9//! - `fn linear_ptx() -> &'static str` -- PTX assembly source string
10
11use super::ops;
12
13// ────────────────────────────────────────────────────────────────────────────
14// Scalar implementation
15// ────────────────────────────────────────────────────────────────────────────
16
17/// Linear projection (scalar reference).
18///
19/// Computes `y = x @ W^T + bias` where:
20/// - `x` is `batch x in_features` (row-major)
21/// - `weight` is `out_features x in_features` (row-major, transposed during multiply)
22/// - `bias` is `out_features` (optional, pass empty slice for no bias)
23/// - `output` is `batch x out_features`
24///
25/// # Panics
26/// Panics if dimensions are inconsistent.
27pub fn linear_scalar(
28    x: &[f32],
29    weight: &[f32],
30    bias: &[f32],
31    batch: usize,
32    in_features: usize,
33    out_features: usize,
34    output: &mut [f32],
35) {
36    assert_eq!(x.len(), batch * in_features, "x dimension mismatch");
37    assert_eq!(
38        weight.len(),
39        out_features * in_features,
40        "weight dimension mismatch"
41    );
42    assert_eq!(
43        output.len(),
44        batch * out_features,
45        "output dimension mismatch"
46    );
47    assert!(
48        bias.is_empty() || bias.len() == out_features,
49        "bias must be empty or out_features={out_features}, got {}",
50        bias.len()
51    );
52
53    // y = x @ W^T: for each row of x, dot with each row of W
54    for b in 0..batch {
55        let x_row = &x[b * in_features..(b + 1) * in_features];
56        for o in 0..out_features {
57            let w_row = &weight[o * in_features..(o + 1) * in_features];
58            let mut val = ops::dot(x_row, w_row);
59            if !bias.is_empty() {
60                val += bias[o];
61            }
62            output[b * out_features + o] = val;
63        }
64    }
65}
66
67// ────────────────────────────────────────────────────────────────────────────
68// AVX2 implementation
69// ────────────────────────────────────────────────────────────────────────────
70
71/// AVX2 linear projection -- delegates to scalar.
72///
73/// # Safety
74/// Requires AVX2 support.
75#[cfg(target_arch = "x86_64")]
76#[target_feature(enable = "avx2")]
77pub unsafe fn linear_avx2(
78    x: &[f32],
79    weight: &[f32],
80    bias: &[f32],
81    batch: usize,
82    in_features: usize,
83    out_features: usize,
84    output: &mut [f32],
85) {
86    linear_scalar(x, weight, bias, batch, in_features, out_features, output);
87}
88
89// ────────────────────────────────────────────────────────────────────────────
90// PTX implementation
91// ────────────────────────────────────────────────────────────────────────────
92
/// PTX assembly source for the linear projection kernel `linear_kernel`.
///
/// Thread mapping: `%ctaid.x` selects the batch row and `%tid.x` selects the
/// output feature, so each thread produces exactly one output element by
/// accumulating the dot product of one `X` row and one `W` row with
/// `fma.rn.f32`, then optionally adding the bias entry for its output feature.
///
/// Threads whose batch index is `>= BATCH` or whose output-feature index is
/// `>= OUT_FEAT` branch straight to the exit without loading or storing.
///
/// Kernel parameters:
/// - `X`, `W`, `BIAS`, `OUT`: 64-bit addresses accessed with `ld.global.f32` /
///   `st.global.f32`
/// - `BATCH`, `IN_FEAT`, `OUT_FEAT`: 32-bit dimensions
/// - `HAS_BIAS`: when `0`, the bias load and add are skipped
pub fn linear_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry linear_kernel(
    .param .u64 X,
    .param .u64 W,
    .param .u64 BIAS,
    .param .u64 OUT,
    .param .u32 BATCH,
    .param .u32 IN_FEAT,
    .param .u32 OUT_FEAT,
    .param .u32 HAS_BIAS
) {
    .reg .u32 %batch, %in_feat, %out_feat, %has_bias;
    .reg .u32 %b_idx, %o_idx, %k, %tmp32;
    .reg .u64 %x_ptr, %w_ptr, %bias_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %acc, %x_val, %w_val, %bias_val;
    .reg .pred %p_k, %p_bias, %p_bound;

    // ctaid.x = batch index, tid.x = output feature index
    mov.u32 %b_idx, %ctaid.x;
    mov.u32 %o_idx, %tid.x;

    ld.param.u32 %batch, [BATCH];
    ld.param.u32 %in_feat, [IN_FEAT];
    ld.param.u32 %out_feat, [OUT_FEAT];
    ld.param.u32 %has_bias, [HAS_BIAS];
    ld.param.u64 %x_ptr, [X];
    ld.param.u64 %w_ptr, [W];
    ld.param.u64 %bias_ptr, [BIAS];
    ld.param.u64 %out_ptr, [OUT];

    // Skip threads that fall outside the batch x out_feat output grid
    setp.ge.u32 %p_bound, %b_idx, %batch;
    @%p_bound bra EXIT;
    setp.ge.u32 %p_bound, %o_idx, %out_feat;
    @%p_bound bra EXIT;

    // acc = dot(x[b_idx], w[o_idx])
    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
DOT_LOOP:
    setp.ge.u32 %p_k, %k, %in_feat;
    @%p_k bra DOT_DONE;

    // x[b_idx * in_feat + k]
    mad.lo.u32 %tmp32, %b_idx, %in_feat, %k;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %x_ptr, %off64;
    ld.global.f32 %x_val, [%addr];

    // w[o_idx * in_feat + k]
    mad.lo.u32 %tmp32, %o_idx, %in_feat, %k;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %w_ptr, %off64;
    ld.global.f32 %w_val, [%addr];

    fma.rn.f32 %acc, %x_val, %w_val, %acc;
    add.u32 %k, %k, 1;
    bra DOT_LOOP;
DOT_DONE:

    // Add bias if present
    setp.eq.u32 %p_bias, %has_bias, 0;
    @%p_bias bra STORE;
    mul.wide.u32 %off64, %o_idx, 4;
    add.u64 %addr, %bias_ptr, %off64;
    ld.global.f32 %bias_val, [%addr];
    add.f32 %acc, %acc, %bias_val;

STORE:
    mad.lo.u32 %tmp32, %b_idx, %out_feat, %o_idx;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %acc;

EXIT:
    ret;
}
"#
}
179
180// ────────────────────────────────────────────────────────────────────────────
181// Tests
182// ────────────────────────────────────────────────────────────────────────────
183
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// Known-answer check: 1x2 input, 2x2 weight, bias present.
    #[test]
    fn test_linear_basic_with_bias() {
        // y = x @ W^T + b
        //   = [1*3 + 2*4 + 10, 1*5 + 2*6 + 20]
        //   = [21, 37]
        let input = [1.0f32, 2.0];
        let weight = [3.0f32, 4.0, 5.0, 6.0]; // 2x2, row-major
        let bias = [10.0f32, 20.0];
        let expected = [21.0f32, 37.0];

        let mut y = [0.0f32; 2];
        linear_scalar(&input, &weight, &bias, 1, 2, 2, &mut y);

        for (got, want) in y.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// An empty bias slice means "no bias": the output is the bare product.
    #[test]
    fn test_linear_no_bias() {
        let input = [1.0f32, 0.0];
        let weight = [1.0f32, 0.0, 0.0, 1.0]; // 2x2 identity
        let expected = [1.0f32, 0.0];

        let mut y = [0.0f32; 2];
        linear_scalar(&input, &weight, &[], 1, 2, 2, &mut y);

        for (got, want) in y.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// With an all-zero input the projection vanishes and only the bias remains.
    #[test]
    fn test_linear_zero_input_returns_bias() {
        let input = [0.0f32; 3];
        let weight = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3
        let bias = [7.0f32, 8.0];

        let mut y = [0.0f32; 2];
        linear_scalar(&input, &weight, &bias, 1, 3, 2, &mut y);

        for (got, want) in y.iter().zip(bias.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// Each batch row is projected independently.
    #[test]
    fn test_linear_batch() {
        // batch=2, in=2, out=1, W=[[1,1]], no bias -> each output is a row sum
        let input = [1.0f32, 2.0, 3.0, 4.0]; // 2x2
        let weight = [1.0f32, 1.0]; // 1x2
        let expected = [3.0f32, 7.0]; // 1+2, 3+4

        let mut y = [0.0f32; 2]; // 2x1
        linear_scalar(&input, &weight, &[], 2, 2, 1, &mut y);

        for (got, want) in y.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    /// Homogeneity without bias: f(2x) = 2*f(x).
    #[test]
    fn test_linear_linearity() {
        let base = [1.0f32, 2.0, 3.0];
        let doubled = base.map(|v| v * 2.0);
        let weight = [0.5f32, 0.3, 0.1, 0.2, 0.4, 0.6]; // 2x3

        let mut out1 = [0.0f32; 2];
        linear_scalar(&base, &weight, &[], 1, 3, 2, &mut out1);

        let mut out2 = [0.0f32; 2];
        linear_scalar(&doubled, &weight, &[], 1, 3, 2, &mut out2);

        for (i, (&scaled, &single)) in out2.iter().zip(out1.iter()).enumerate() {
            assert!(
                (scaled - 2.0 * single).abs() < 1e-5,
                "linearity violated at {i}: f(2x)={} vs 2*f(x)={}",
                scaled,
                2.0 * single
            );
        }
    }

    proptest! {
        // Small ramp-valued inputs must never yield NaN or infinity.
        #[test]
        fn prop_linear_output_finite(
            batch in 1usize..3,
            in_f in 1usize..5,
            out_f in 1usize..5,
        ) {
            // `len` values spaced `step` apart, starting at zero.
            let ramp = |len: usize, step: f32| -> Vec<f32> {
                (0..len).map(|i| (i as f32) * step).collect()
            };

            let x = ramp(batch * in_f, 0.1);
            let w = ramp(out_f * in_f, 0.1);
            let b = ramp(out_f, 0.01);
            let mut output = vec![0.0f32; batch * out_f];

            linear_scalar(&x, &w, &b, batch, in_f, out_f, &mut output);

            for (idx, val) in output.iter().copied().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    /// The PTX source exposes the kernel entry point, an FMA and a return.
    #[test]
    fn test_linear_ptx_structure() {
        let ptx = linear_ptx();
        let required = [".entry linear_kernel", "fma.rn.f32", "ret;"];
        for fragment in required.iter() {
            assert!(ptx.contains(*fragment));
        }
    }

    /// The AVX2 entry point must match the scalar reference bit-for-bit (0 ULP).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_linear_avx2_parity() {
        // Nothing to compare on CPUs without AVX2.
        if is_x86_feature_detected!("avx2") {
            let x = [1.0f32, 2.0, 3.0, 4.0]; // 1x4
            let w = [0.5f32, 0.3, 0.1, 0.2, 0.4, 0.6, 0.7, 0.8]; // 2x4
            let b = [1.0f32, 2.0];

            let mut scalar_out = [0.0f32; 2]; // 1x2
            linear_scalar(&x, &w, &b, 1, 4, 2, &mut scalar_out);

            let mut avx2_out = [0.0f32; 2];
            // SAFETY: AVX2 support was confirmed by the runtime check above.
            unsafe { linear_avx2(&x, &w, &b, 1, 4, 2, &mut avx2_out) };

            assert_ulp_eq(&scalar_out, &avx2_out, 0);
        }
    }
}