/// Runs a diagonal linear state-space scan on the CPU, writing one output per time step.
///
/// Starting from a zero hidden state `h` of length `state_dim`, each time step `t` computes
///
/// ```text
/// h[i]      = a_bar[i] * h[i] + b_bar[i * seq_len + t] * x[t]   for every i
/// output[t] = sum over i (ascending) of c[i] * h[i]
/// ```
///
/// # Arguments
///
/// * `a_bar` - per-state transition coefficients (length `state_dim`).
/// * `b_bar` - input coefficients, row-major `[state_dim][seq_len]`: the entry for state `i`
///   at step `t` is `b_bar[i * seq_len + t]` (length `state_dim * seq_len`).
/// * `c` - per-state readout weights (length `state_dim`).
/// * `x` - scalar input sequence (length `seq_len`).
/// * `state_dim`, `seq_len` - the dimensions the slice lengths are checked against.
/// * `output` - receives the readout for every step (length `seq_len`); fully overwritten.
///
/// # Panics
///
/// Panics if any slice length does not match the dimensions above, including when
/// `state_dim * seq_len` overflows `usize` (reported as a `b_bar` length mismatch).
pub fn ssm_scan_scalar(
    a_bar: &[f32],
    b_bar: &[f32],
    c: &[f32],
    x: &[f32],
    state_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    assert_eq!(a_bar.len(), state_dim, "a_bar length mismatch");
    // `checked_mul` so an overflowing shape is reported as a length mismatch instead of a
    // multiply-overflow panic (debug) or a silently wrapped expected length (release).
    assert_eq!(
        Some(b_bar.len()),
        state_dim.checked_mul(seq_len),
        "b_bar length mismatch"
    );
    assert_eq!(c.len(), state_dim, "c length mismatch");
    assert_eq!(x.len(), seq_len, "x length mismatch");
    assert_eq!(output.len(), seq_len, "output length mismatch");

    if seq_len == 0 {
        // Nothing to produce; this also keeps `step_by` below from getting a zero step.
        return;
    }

    let mut h = vec![0.0_f32; state_dim];

    for (t, (&x_t, out_t)) in x.iter().zip(output.iter_mut()).enumerate() {
        // Column `t` of the row-major b_bar: b_bar[t], b_bar[t + seq_len], ...
        let b_col = b_bar.iter().skip(t).step_by(seq_len);
        let mut y = 0.0_f32;
        // Update each state and accumulate its readout in a single pass. Each h[i] is
        // updated before it is read and y still sums in ascending i, so the result is
        // bit-identical to a full update loop followed by a separate readout loop.
        for (((h_i, &a_i), &c_i), &b_it) in h.iter_mut().zip(a_bar).zip(c).zip(b_col) {
            *h_i = a_i * *h_i + b_it * x_t;
            y += c_i * *h_i;
        }
        *out_t = y;
    }
}
50
51#[cfg(target_arch = "x86_64")]
65#[target_feature(enable = "avx2")]
66pub unsafe fn ssm_scan_avx2(
67 a_bar: &[f32],
68 b_bar: &[f32],
69 c: &[f32],
70 x: &[f32],
71 state_dim: usize,
72 seq_len: usize,
73 output: &mut [f32],
74) {
75 ssm_scan_scalar(a_bar, b_bar, c, x, state_dim, seq_len, output);
76}
77
/// Returns the PTX source (ISA 8.5, `sm_90`) of `ssm_scan_kernel`, a GPU port of
/// [`ssm_scan_scalar`].
///
/// Kernel parameters, in order: `a_bar_ptr`, `b_bar_ptr`, `c_ptr`, `x_ptr`, `output_ptr`
/// (pointers to `f32` buffers), then `state_dim` and `seq_len` (`u32`). Buffer layouts match
/// the scalar function, including the row-major `[state_dim][seq_len]` layout of `b_bar`.
///
/// Only the thread with global x-index 0 does work, running the whole scan serially. All
/// other threads return immediately, so a single-thread launch is sufficient. The kernel
/// overwrites all `seq_len` entries of `output`.
///
/// The loops run state-outer / time-inner. Each state dimension's recurrence therefore
/// needs only one register, and `output[t]` accumulates `c[i] * h_i[t]` in ascending `i`,
/// the same order as the scalar code. Explicitly rounded `mul.rn`/`add.rn` are used instead
/// of `fma` so that each rounding step matches the scalar reference.
///
/// The returned text is not assembled or validated here.
pub fn ssm_scan_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// SSM scan kernel: a GPU port of ssm_scan_scalar.
// Only the thread with global index 0 does work; it runs the whole scan serially.
// Params: a_bar_ptr, b_bar_ptr, c_ptr, x_ptr, output_ptr, state_dim, seq_len
// b_bar is row-major [state_dim][seq_len]: element (i, t) is at i * seq_len + t.
//
// Each state dimension evolves independently, h_i[t] = a[i] * h_i[t-1] + b[i, t] * x[t],
// so the loops run state-outer / time-inner with one register per h_i, and output[t]
// accumulates c[i] * h_i[t] in ascending i: the same summation order as the scalar code.
// mul.rn / add.rn carry an explicit rounding mode and are never contracted into fma,
// so every rounding step also matches the scalar reference.
.visible .entry ssm_scan_kernel(
    .param .u64 a_bar_ptr,
    .param .u64 b_bar_ptr,
    .param .u64 c_ptr,
    .param .u64 x_ptr,
    .param .u64 output_ptr,
    .param .u32 state_dim,
    .param .u32 seq_len
)
{
    .reg .u32 %tx, %bdim, %bid, %idx, %sd, %sl, %t, %i;
    .reg .u64 %a_base, %b_base, %c_base, %x_base, %o_base, %b_row, %off, %addr;
    .reg .f32 %h, %a, %bval, %cval, %xval, %y, %prod;
    .reg .pred %p_t, %p_i;

    mov.u32 %tx, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %bid, %ctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tx;

    // Single-threaded reference kernel: every thread except index 0 exits.
    setp.ne.u32 %p_t, %idx, 0;
    @%p_t bra DONE;

    ld.param.u64 %a_base, [a_bar_ptr];
    ld.param.u64 %b_base, [b_bar_ptr];
    ld.param.u64 %c_base, [c_ptr];
    ld.param.u64 %x_base, [x_ptr];
    ld.param.u64 %o_base, [output_ptr];
    ld.param.u32 %sd, [state_dim];
    ld.param.u32 %sl, [seq_len];

    // Kernel pointer params are generic addresses; convert them for ld/st.global.
    cvta.to.global.u64 %a_base, %a_base;
    cvta.to.global.u64 %b_base, %b_base;
    cvta.to.global.u64 %c_base, %c_base;
    cvta.to.global.u64 %x_base, %x_base;
    cvta.to.global.u64 %o_base, %o_base;

    // Pass 1: output[t] = 0 for every t, since output is the readout accumulator.
    mov.f32 %y, 0f00000000;
    mov.u32 %t, 0;
ZERO_LOOP:
    setp.ge.u32 %p_t, %t, %sl;
    @%p_t bra ZERO_DONE;
    mul.wide.u32 %off, %t, 4;
    add.u64 %addr, %o_base, %off;
    st.global.f32 [%addr], %y;
    add.u32 %t, %t, 1;
    bra ZERO_LOOP;
ZERO_DONE:

    // Pass 2: for each state i, run its recurrence over time and add its readout.
    mov.u32 %i, 0;
STATE_LOOP:
    setp.ge.u32 %p_i, %i, %sd;
    @%p_i bra DONE;

    // a_bar[i] and c[i] do not change over time: load them once per state.
    mul.wide.u32 %off, %i, 4;
    add.u64 %addr, %a_base, %off;
    ld.global.f32 %a, [%addr];
    add.u64 %addr, %c_base, %off;
    ld.global.f32 %cval, [%addr];

    // Start of row i of b_bar; 64-bit product so i * seq_len cannot wrap in 32 bits.
    mul.wide.u32 %b_row, %i, %sl;
    shl.b64 %b_row, %b_row, 2;
    add.u64 %b_row, %b_base, %b_row;

    // Every state starts from zero.
    mov.f32 %h, 0f00000000;
    mov.u32 %t, 0;
TIME_LOOP:
    setp.ge.u32 %p_t, %t, %sl;
    @%p_t bra TIME_DONE;

    mul.wide.u32 %off, %t, 4;
    add.u64 %addr, %x_base, %off;
    ld.global.f32 %xval, [%addr];
    add.u64 %addr, %b_row, %off;
    ld.global.f32 %bval, [%addr];

    // h = a * h + b * x
    mul.rn.f32 %h, %a, %h;
    mul.rn.f32 %prod, %bval, %xval;
    add.rn.f32 %h, %h, %prod;

    // output[t] += c * h
    add.u64 %addr, %o_base, %off;
    ld.global.f32 %y, [%addr];
    mul.rn.f32 %prod, %cval, %h;
    add.rn.f32 %y, %y, %prod;
    st.global.f32 [%addr], %y;

    add.u32 %t, %t, 1;
    bra TIME_LOOP;
TIME_DONE:
    add.u32 %i, %i, 1;
    bra STATE_LOOP;

DONE:
    ret;
}
"#
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero input keeps every hidden state at zero, so every output must be zero.
    #[test]
    fn test_ssm_zero_input() {
        let state_dim = 3;
        let seq_len = 4;
        let a_bar = [0.9_f32, 0.8, 0.7];
        let b_bar = vec![1.0_f32; state_dim * seq_len];
        let c = [1.0_f32, 1.0, 1.0];
        let x = [0.0_f32; 4];
        let mut output = [0.0_f32; 4];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        for (t, &o) in output.iter().enumerate() {
            assert!(
                o.abs() < 1e-7,
                "zero input should produce zero output, got output[{t}] = {o}"
            );
        }
    }

    /// One step from h = 0: output = sum_i c[i] * b_bar[i] * x[0] = 2 + 3 = 5.
    #[test]
    fn test_ssm_single_timestep() {
        let state_dim = 2;
        let seq_len = 1;
        let a_bar = [0.5_f32, 0.5];
        let b_bar = [2.0_f32, 3.0];
        let c = [1.0_f32, 1.0];
        let x = [1.0_f32];
        let mut output = [0.0_f32; 1];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        assert!(
            (output[0] - 5.0).abs() < 1e-6,
            "expected 5.0, got {}",
            output[0]
        );
    }

    /// Single state: h = [1, 0.5 * 1 + 1] = [1, 1.5], so output = 2 * h = [2, 3].
    #[test]
    fn test_ssm_two_timesteps() {
        let state_dim = 1;
        let seq_len = 2;
        let a_bar = [0.5_f32];
        let b_bar = [1.0_f32, 1.0];
        let c = [2.0_f32];
        let x = [1.0_f32, 1.0];
        let mut output = [0.0_f32; 2];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        assert!(
            (output[0] - 2.0).abs() < 1e-6,
            "t=0: expected 2.0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 3.0).abs() < 1e-6,
            "t=1: expected 3.0, got {}",
            output[1]
        );
    }

    /// Two states x two steps with distinct b_bar entries pins the row-major
    /// `[state_dim][seq_len]` layout; a `[seq_len][state_dim]` read would give 5.0 at t=0.
    #[test]
    fn test_ssm_b_bar_layout() {
        let state_dim = 2;
        let seq_len = 2;
        let a_bar = [0.5_f32, 0.25];
        // Row i holds state i over time: b_bar[i * seq_len + t].
        let b_bar = [1.0_f32, 2.0, 3.0, 4.0];
        let c = [1.0_f32, 2.0];
        let x = [1.0_f32, 1.0];
        let mut output = [0.0_f32; 2];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut output);

        // t=0: h = [1, 3]      -> 1 * 1   + 2 * 3    = 7
        // t=1: h = [2.5, 4.75] -> 1 * 2.5 + 2 * 4.75 = 12
        assert_eq!(output, [7.0, 12.0]);
    }

    /// An empty sequence is valid input and produces no output.
    #[test]
    fn test_ssm_empty_sequence() {
        let mut output: [f32; 0] = [];
        ssm_scan_scalar(&[0.5, 0.5], &[], &[1.0, 1.0], &[], 2, 0, &mut output);
    }

    #[test]
    #[should_panic(expected = "a_bar length mismatch")]
    fn test_ssm_abar_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(
            &[0.5],
            &[1.0; 4],
            &[1.0, 1.0],
            &[1.0, 1.0],
            2,
            2,
            &mut output,
        );
    }

    #[test]
    #[should_panic(expected = "b_bar length mismatch")]
    fn test_ssm_bbar_mismatch() {
        let mut output = [0.0_f32; 2];
        ssm_scan_scalar(
            &[0.5, 0.5],
            &[1.0; 3],
            &[1.0, 1.0],
            &[1.0, 1.0],
            2,
            2,
            &mut output,
        );
    }

    /// The AVX2 entry point must agree bit-for-bit with the scalar reference.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_ssm_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let state_dim = 4;
        let seq_len = 8;
        let a_bar: Vec<f32> = (0..state_dim).map(|i| 0.5 + 0.1 * i as f32).collect();
        let b_bar: Vec<f32> = (0..state_dim * seq_len)
            .map(|i| ((i as f32) * 0.1).sin())
            .collect();
        let c: Vec<f32> = (0..state_dim).map(|i| 1.0 / (i as f32 + 1.0)).collect();
        let x: Vec<f32> = (0..seq_len).map(|i| (i as f32 + 1.0) * 0.5).collect();
        let mut scalar_out = vec![0.0_f32; seq_len];
        let mut avx2_out = vec![0.0_f32; seq_len];

        ssm_scan_scalar(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime above.
        unsafe {
            ssm_scan_avx2(&a_bar, &b_bar, &c, &x, state_dim, seq_len, &mut avx2_out);
        }

        assert_eq!(scalar_out, avx2_out);
    }

    #[test]
    fn test_ssm_ptx_version() {
        let ptx = ssm_scan_ptx();
        assert!(
            ptx.contains(".version 8.5"),
            "PTX must declare .version 8.5"
        );
    }

    #[test]
    fn test_ssm_ptx_target() {
        let ptx = ssm_scan_ptx();
        assert!(ptx.contains(".target sm_90"), "PTX must target sm_90");
    }

    #[test]
    fn test_ssm_ptx_entry() {
        let ptx = ssm_scan_ptx();
        assert!(
            ptx.contains(".entry ssm_scan_kernel"),
            "PTX must have .entry"
        );
    }

    #[test]
    fn test_ssm_ptx_ret() {
        let ptx = ssm_scan_ptx();
        assert!(ptx.contains("ret;"), "PTX must have ret;");
    }

    #[test]
    fn test_ssm_ptx_balanced_braces() {
        let ptx = ssm_scan_ptx();
        let opens = ptx.chars().filter(|&c| c == '{').count();
        let closes = ptx.chars().filter(|&c| c == '}').count();
        assert_eq!(
            opens, closes,
            "PTX must have balanced braces: {opens} opens vs {closes} closes"
        );
    }
}