Skip to main content

provable_contracts/kernels/
ssm.rs

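//! State-Space Model (SSM) scan kernel.
//!
//! Matches `ssm-kernel-v1.yaml`.
//!
//! With a diagonal discretized state matrix `A_bar` and an input projection
//! `B_bar` that has one column per timestep, each step computes
//! (elementwise over the state, starting from `h_{-1} = 0`):
//!
//! ```text
//! h_t = A_bar * h_{t-1} + B_bar[:, t] * x_t
//! y_t = C . h_t
//! ```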
6
/// Scalar reference implementation of the SSM sequential scan.
///
/// Runs the diagonal state-space recurrence over a 1D input sequence,
/// starting from a zero state (`h = 0`). For every timestep `t`:
///
/// ```text
/// h[i]      = a_bar[i] * h[i] + b_bar[i * seq_len + t] * x[t]   (for each i)
/// output[t] = sum_i c[i] * h[i]                                 (summed in increasing i)
/// ```
///
/// - `a_bar`: diagonal of the discretized state matrix, length `state_dim`
/// - `b_bar`: input projection, flattened row-major `state_dim x seq_len`
///   (row `i` holds channel `i`'s coefficient for every timestep)
/// - `c`: output projection, length `state_dim`
/// - `x`: input sequence, length `seq_len`
/// - `output`: output sequence, length `seq_len`; every element is overwritten
///
/// # Panics
///
/// Panics if any slice length is inconsistent with `state_dim` / `seq_len`,
/// or if `state_dim * seq_len` overflows `usize`.
pub fn ssm_scan_scalar(
    a_bar: &[f32],
    b_bar: &[f32],
    c: &[f32],
    x: &[f32],
    state_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    assert_eq!(a_bar.len(), state_dim, "a_bar length mismatch");
    // Checked so huge dimensions fail loudly instead of wrapping in release
    // builds, which could let a wrongly-sized `b_bar` pass the length check.
    let b_len = state_dim
        .checked_mul(seq_len)
        .expect("state_dim * seq_len overflows usize");
    assert_eq!(b_bar.len(), b_len, "b_bar length mismatch");
    assert_eq!(c.len(), state_dim, "c length mismatch");
    assert_eq!(x.len(), seq_len, "x length mismatch");
    assert_eq!(output.len(), seq_len, "output length mismatch");

    let mut h = vec![0.0_f32; state_dim];

    for (t, (&x_t, y_t)) in x.iter().zip(output.iter_mut()).enumerate() {
        // State update; channel i's input coefficient for step t lives in row i.
        for (i, (h_i, &a_i)) in h.iter_mut().zip(a_bar).enumerate() {
            *h_i = a_i * *h_i + b_bar[i * seq_len + t] * x_t;
        }
        // Readout, accumulated left-to-right starting from +0.0 so rounding is
        // identical to a plain `y += c[i] * h[i]` loop and `state_dim == 0`
        // yields exactly 0.0.
        *y_t = c
            .iter()
            .zip(&h)
            .fold(0.0_f32, |acc, (&c_i, &h_i)| acc + c_i * h_i);
    }
}
50
51/// AVX2 implementation of the SSM sequential scan.
52///
53/// Delegates to scalar. The sequential time dependency makes SIMD
54/// vectorization across time impossible; vectorizing across `state_dim`
55/// provides limited benefit for typical small state dimensions.
56///
57/// # Safety
58///
59/// Requires AVX2 support on the target CPU.
60///
61/// # Panics
62///
63/// Same as [`ssm_scan_scalar`].
64#[cfg(target_arch = "x86_64")]
65#[target_feature(enable = "avx2")]
66pub unsafe fn ssm_scan_avx2(
67    a_bar: &[f32],
68    b_bar: &[f32],
69    c: &[f32],
70    x: &[f32],
71    state_dim: usize,
72    seq_len: usize,
73    output: &mut [f32],
74) {
75    ssm_scan_scalar(a_bar, b_bar, c, x, state_dim, seq_len, output);
76}
77
/// PTX assembly for the SSM scan kernel (`ssm_scan_kernel`).
///
/// This is a single-threaded reference kernel. Only the thread whose global
/// index (`ctaid.x * ntid.x + tid.x`) is 0 does any work; every other thread
/// returns immediately. It implements the same recurrence and buffer layout as
/// [`ssm_scan_scalar`] (`b_bar` is row-major `state_dim x seq_len`).
///
/// The state matrix is diagonal, so every state channel evolves
/// independently. The kernel therefore loops over channels on the outside and
/// time on the inside, keeping one channel's state in a single register.
///
/// `output` is zeroed first and then used as the readout accumulator, so the
/// kernel both reads and writes it. Channels are accumulated in increasing
/// order. Every multiply/add uses explicit `.rn` rounding, which ptxas may not
/// contract into FMA, so the operation order follows the scalar reference.
///
/// Kernel parameters:
/// - `a_bar_ptr, b_bar_ptr, c_ptr, x_ptr, output_ptr`: pointers to `f32`
///   buffers in global memory.
/// - `state_dim, seq_len`: `u32` dimensions.
///
/// The flattened `b_bar` index is computed in 64 bits.
pub fn ssm_scan_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// SSM scan kernel: global thread 0 performs the whole scan; other threads exit.
// A_bar is diagonal, so channels are independent: the outer loop runs over state
// channel i, the inner loop over time t, with one register (%h) holding the state.
// output[] is zeroed first, then output[t] += c[i] * h_t[i] for i = 0, 1, ...
// Explicit .rn rounding on mul/add prevents FMA contraction so rounding follows
// the scalar reference.
// Params: a_bar_ptr, b_bar_ptr, c_ptr, x_ptr, output_ptr, state_dim, seq_len
.visible .entry ssm_scan_kernel(
    .param .u64 a_bar_ptr,
    .param .u64 b_bar_ptr,
    .param .u64 c_ptr,
    .param .u64 x_ptr,
    .param .u64 output_ptr,
    .param .u32 state_dim,
    .param .u32 seq_len
)
{
    .reg .u32 %tid, %ntid, %ctaid, %idx, %sd, %sl, %t, %i;
    .reg .u64 %a_base, %b_base, %c_base, %x_base, %o_base;
    .reg .u64 %row, %off, %boff, %addr;
    .reg .f32 %h, %a, %bval, %cval, %xval, %y, %prod;
    .reg .pred %p_t, %p_i;

    mov.u32 %tid, %tid.x;
    mov.u32 %ntid, %ntid.x;
    mov.u32 %ctaid, %ctaid.x;
    mad.lo.u32 %idx, %ctaid, %ntid, %tid;

    // Single-threaded kernel: only global thread 0 does any work.
    setp.ne.u32 %p_t, %idx, 0;
    @%p_t bra DONE;

    ld.param.u64 %a_base, [a_bar_ptr];
    ld.param.u64 %b_base, [b_bar_ptr];
    ld.param.u64 %c_base, [c_ptr];
    ld.param.u64 %x_base, [x_ptr];
    ld.param.u64 %o_base, [output_ptr];
    ld.param.u32 %sd, [state_dim];
    ld.param.u32 %sl, [seq_len];

    // Kernel params are generic addresses; convert them for ld/st.global.
    cvta.to.global.u64 %a_base, %a_base;
    cvta.to.global.u64 %b_base, %b_base;
    cvta.to.global.u64 %c_base, %c_base;
    cvta.to.global.u64 %x_base, %x_base;
    cvta.to.global.u64 %o_base, %o_base;

    // Pass 1: output[t] = 0 for every t (also the result when state_dim == 0).
    mov.f32 %y, 0f00000000;
    mov.u32 %t, 0;
ZERO_LOOP:
    setp.ge.u32 %p_t, %t, %sl;
    @%p_t bra ZERO_DONE;
    mul.wide.u32 %off, %t, 4;
    add.u64 %addr, %o_base, %off;
    st.global.f32 [%addr], %y;
    add.u32 %t, %t, 1;
    bra ZERO_LOOP;
ZERO_DONE:

    // Pass 2: outer loop over state channels.
    mov.u32 %i, 0;
STATE_LOOP:
    setp.ge.u32 %p_i, %i, %sd;
    @%p_i bra DONE;

    // Load a_bar[i] and c[i].
    mul.wide.u32 %off, %i, 4;
    add.u64 %addr, %a_base, %off;
    ld.global.f32 %a, [%addr];
    add.u64 %addr, %c_base, %off;
    ld.global.f32 %cval, [%addr];

    // Each channel starts from h = 0; %row = i * seq_len in 64 bits.
    mov.f32 %h, 0f00000000;
    mul.wide.u32 %row, %i, %sl;

    // Inner loop over time.
    mov.u32 %t, 0;
TIME_LOOP:
    setp.ge.u32 %p_t, %t, %sl;
    @%p_t bra TIME_DONE;

    // Load x[t]; %off = 4 * t is reused for output[t] below.
    mul.wide.u32 %off, %t, 4;
    add.u64 %addr, %x_base, %off;
    ld.global.f32 %xval, [%addr];

    // Load b_bar[i * seq_len + t].
    cvt.u64.u32 %boff, %t;
    add.u64 %boff, %row, %boff;
    shl.b64 %boff, %boff, 2;
    add.u64 %addr, %b_base, %boff;
    ld.global.f32 %bval, [%addr];

    // h = a * h + b * x
    mul.rn.f32 %h, %a, %h;
    mul.rn.f32 %prod, %bval, %xval;
    add.rn.f32 %h, %h, %prod;

    // output[t] += c * h
    add.u64 %addr, %o_base, %off;
    ld.global.f32 %y, [%addr];
    mul.rn.f32 %prod, %cval, %h;
    add.rn.f32 %y, %y, %prod;
    st.global.f32 [%addr], %y;

    add.u32 %t, %t, 1;
    bra TIME_LOOP;
TIME_DONE:

    add.u32 %i, %i, 1;
    bra STATE_LOOP;

DONE:
    ret;
}
"#
}
186
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // Scalar tests
    // ---------------------------------------------------------------

    /// Verify zero input produces zero output for the SSM scan
    #[test]
    fn test_ssm_zero_input() {
        let state_dim = 3;
        let seq_len = 4;
        let a_bar = [0.9_f32, 0.8, 0.7];
        let b_bar = vec![1.0_f32; state_dim * seq_len];
        let c = [1.0_f32, 1.0, 1.0];
        let x = [0.0_f32; 4];
        let mut output = [0.0_f32; 4];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        for (t, &o) in output.iter().enumerate() {
            assert!(
                o.abs() < 1e-7,
                "zero input should produce zero output, got output[{t}] = {o}"
            );
        }
    }

    /// Verify SSM single-timestep output matches hand-computed h = B*x, y = C*h
    #[test]
    fn test_ssm_single_timestep() {
        // With h_0 = 0:
        //   h_1 = A*0 + B*x_0 = B*x_0
        //   y_0 = C * h_1
        let state_dim = 2;
        let seq_len = 1;
        let a_bar = [0.5_f32, 0.5];
        // b_bar: 2x1
        let b_bar = [2.0_f32, 3.0];
        let c = [1.0_f32, 1.0];
        let x = [1.0_f32];
        let mut output = [0.0_f32; 1];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        // h = [2.0*1.0, 3.0*1.0] = [2.0, 3.0]
        // y = 1.0*2.0 + 1.0*3.0 = 5.0
        assert!(
            (output[0] - 5.0).abs() < 1e-6,
            "expected 5.0, got {}",
            output[0]
        );
    }

    /// Verify SSM recurrence over two timesteps with state decay
    #[test]
    fn test_ssm_two_timesteps() {
        let state_dim = 1;
        let seq_len = 2;
        let a_bar = [0.5_f32];
        let b_bar = [1.0_f32, 1.0]; // 1 x 2
        let c = [2.0_f32];
        let x = [1.0_f32, 1.0];
        let mut output = [0.0_f32; 2];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        // t=0: h = 0.5*0 + 1.0*1.0 = 1.0; y = 2.0*1.0 = 2.0
        // t=1: h = 0.5*1.0 + 1.0*1.0 = 1.5; y = 2.0*1.5 = 3.0
        assert!(
            (output[0] - 2.0).abs() < 1e-6,
            "t=0: expected 2.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 3.0).abs() < 1e-6,
            "t=1: expected 3.0, got {}",
            output[1]
        );
    }

    /// Verify a two-channel, two-timestep scan uses row-major b_bar per channel
    #[test]
    fn test_ssm_multi_channel() {
        // b_bar rows: channel 0 = [1, 2], channel 1 = [3, 4]
        // t=0: h = [1, 3];           y = 1*1 + 2*3    = 7
        // t=1: h = [0.5+4, 0.75+8];  y = 4.5 + 2*8.75 = 22
        let a_bar = [0.5_f32, 0.25];
        let b_bar = [1.0_f32, 2.0, 3.0, 4.0];
        let c = [1.0_f32, 2.0];
        let x = [1.0_f32, 2.0];
        let mut output = [0.0_f32; 2];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, 2, 2, &mut output);

        // Every intermediate value is exactly representable, so compare exactly.
        assert_eq!(output, [7.0, 22.0]);
    }

    /// Verify an empty state (state_dim = 0) overwrites the output with zeros
    #[test]
    fn test_ssm_zero_state_dim() {
        let mut output = [5.0_f32; 2];
        ssm_scan_scalar(&[], &[], &[], &[1.0, 2.0], 0, 2, &mut output);
        assert_eq!(output, [0.0, 0.0]);
    }

    /// Verify an empty sequence (seq_len = 0) is accepted and does nothing
    #[test]
    fn test_ssm_empty_sequence() {
        let mut output: [f32; 0] = [];
        ssm_scan_scalar(&[0.5, 0.5], &[], &[1.0, 1.0], &[], 2, 0, &mut output);
    }

    /// Verify SSM panics on a_bar length mismatch
    #[test]
    #[should_panic(expected = "a_bar length mismatch")]
    fn test_ssm_abar_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(
            &[0.5],
            &[1.0; 4],
            &[1.0, 1.0],
            &[1.0, 1.0],
            2,
            2,
            &mut output,
        );
    }

    /// Verify SSM panics on b_bar length mismatch
    #[test]
    #[should_panic(expected = "b_bar length mismatch")]
    fn test_ssm_bbar_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(&[0.5], &[1.0], &[1.0], &[1.0, 1.0], 1, 2, &mut output);
    }

    /// Verify SSM panics on c length mismatch
    #[test]
    #[should_panic(expected = "c length mismatch")]
    fn test_ssm_c_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(&[0.5], &[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0], 1, 2, &mut output);
    }

    /// Verify SSM panics on x length mismatch
    #[test]
    #[should_panic(expected = "x length mismatch")]
    fn test_ssm_x_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(&[0.5], &[1.0, 1.0], &[1.0], &[1.0], 1, 2, &mut output);
    }

    /// Verify SSM panics on output length mismatch
    #[test]
    #[should_panic(expected = "output length mismatch")]
    fn test_ssm_output_mismatch() {
        let mut output = [0.0_f32; 3];
        ssm_scan_scalar(&[0.5], &[1.0, 1.0], &[1.0], &[1.0, 1.0], 1, 2, &mut output);
    }

    /// Verify SSM panics when state_dim * seq_len cannot be represented
    #[test]
    #[should_panic]
    fn test_ssm_dimension_overflow() {
        let mut output: [f32; 0] = [];
        ssm_scan_scalar(&[1.0, 1.0], &[], &[1.0, 1.0], &[], 2, usize::MAX, &mut output);
    }

    // ---------------------------------------------------------------
    // AVX2 tests
    // ---------------------------------------------------------------

    /// Verify AVX2 SSM scan produces identical results to scalar
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_ssm_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            // Nothing to compare on CPUs without AVX2.
            return;
        }
        let state_dim = 4;
        let seq_len = 8;
        let a_bar: Vec<f32> = (0..state_dim).map(|i| 0.5 + 0.1 * i as f32).collect();
        let b_bar: Vec<f32> = (0..state_dim * seq_len)
            .map(|i| ((i as f32) * 0.1).sin())
            .collect();
        let c: Vec<f32> = (0..state_dim).map(|i| 1.0 / (i as f32 + 1.0)).collect();
        let x: Vec<f32> = (0..seq_len).map(|i| (i as f32 + 1.0) * 0.5).collect();
        let mut scalar_out = vec![0.0_f32; seq_len];
        let mut avx2_out = vec![0.0_f32; seq_len];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe {
            ssm_scan_avx2(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut avx2_out);
        }

        assert_eq!(scalar_out, avx2_out);
    }

    // ---------------------------------------------------------------
    // PTX structural tests
    // ---------------------------------------------------------------

    /// Verify SSM PTX declares version 8.5
    #[test]
    fn test_ssm_ptx_version() {
        let ptx = ssm_scan_ptx();
        assert!(
            ptx.contains(".version 8.5"),
            "PTX must declare .version 8.5"
        );
    }

    /// Verify SSM PTX targets sm_90
    #[test]
    fn test_ssm_ptx_target() {
        let ptx = ssm_scan_ptx();
        assert!(ptx.contains(".target sm_90"), "PTX must target sm_90");
    }

    /// Verify SSM PTX contains the kernel entry point
    #[test]
    fn test_ssm_ptx_entry() {
        let ptx = ssm_scan_ptx();
        assert!(
            ptx.contains(".entry ssm_scan_kernel"),
            "PTX must have .entry"
        );
    }

    /// Verify SSM PTX contains a ret instruction
    #[test]
    fn test_ssm_ptx_ret() {
        let ptx = ssm_scan_ptx();
        assert!(ptx.contains("ret;"), "PTX must have ret;");
    }

    /// Verify SSM PTX has balanced curly braces
    #[test]
    fn test_ssm_ptx_balanced_braces() {
        let ptx = ssm_scan_ptx();
        let opens = ptx.chars().filter(|&c| c == '{').count();
        let closes = ptx.chars().filter(|&c| c == '}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }

    /// Verify every branch target in the SSM PTX is defined as a label
    #[test]
    fn test_ssm_ptx_branch_targets_defined() {
        let ptx = ssm_scan_ptx();
        for line in ptx.lines() {
            if let Some((_, target)) = line.split_once("bra ") {
                let label = target.trim().trim_end_matches(';');
                let wanted = format!("{label}:");
                assert!(
                    ptx.lines().any(|l| l.trim() == wanted),
                    "branch target {label} has no label"
                );
            }
        }
    }
}