Skip to main content

oxicuda_seq/
ptx_kernels.rs

1//! GPU PTX kernels for Sequence Models & Structured Prediction.
2//!
3//! Each kernel is emitted as a self-contained PTX module string, parameterised
4//! on the SM version.  PTX ISA selection by SM:
5//!     SM≥100 → 8.7 (Blackwell), SM≥90 → 8.4 (Hopper),
6//!     SM≥80  → 8.0 (Ampere),    else → 7.5 (Turing).
7//!
//! IMPORTANT: PTX kernel bodies are plain string literals appended to the
//! header by **string concatenation** (NOT `format!()`).  The bodies contain
//! literal `{` / `}` braces, which `format!` would parse as placeholders
//! unless every one were escaped as `{{` / `}}`; register names such as
//! `%rd`, `%r`, `%f` are not special to `format!` and need no escaping.
12
/// Assemble the PTX module preamble (`.version`, `.target`, `.address_size`) for `sm`.
///
/// The PTX ISA version is chosen from the SM number: 8.7 for `sm >= 100`,
/// 8.4 for `sm >= 90`, 8.0 for `sm >= 80`, and 7.5 for anything lower.  The
/// target is always `sm_<sm>` and the address size is fixed at 64.  The
/// returned text ends with a blank line so a kernel body can be appended
/// directly.
fn ptx_header(sm: u32) -> String {
    let isa_version = if sm >= 100 {
        "8.7"
    } else if sm >= 90 {
        "8.4"
    } else if sm >= 80 {
        "8.0"
    } else {
        "7.5"
    };

    let mut header = String::new();
    header.push_str(".version ");
    header.push_str(isa_version);
    header.push_str("\n.target sm_");
    header.push_str(&sm.to_string());
    header.push_str("\n.address_size 64\n\n");
    header
}
23
24/// HMM forward-pass kernel (log-space).
25///
26/// Signature: `forward_pass_kernel(alpha_prev, alpha_next, log_a, log_b_o, n_states)`
27/// One thread per destination state `j`; reads `alpha_prev[i] + log_a[i*S+j]`
28/// over all `i`, applies max+log-sum-exp, then adds `log_b_o[j]`.
29#[must_use]
30pub fn forward_pass_ptx(sm: u32) -> String {
31    let hdr = ptx_header(sm);
32    let body = ".visible .entry forward_pass_kernel(\n\
33        .param .u64 p_alpha_prev,\n\
34        .param .u64 p_alpha_next,\n\
35        .param .u64 p_log_a,\n\
36        .param .u64 p_log_b_o,\n\
37        .param .u32 p_n_states\n\
38    )\n\
39    {\n\
40        .reg .u64  %rd<10>;\n\
41        .reg .u32  %r<10>;\n\
42        .reg .f32  %f<10>;\n\
43        .reg .pred %p0;\n\
44    \n\
45        ld.param.u64  %rd0, [p_alpha_prev];\n\
46        ld.param.u64  %rd1, [p_alpha_next];\n\
47        ld.param.u64  %rd2, [p_log_a];\n\
48        ld.param.u64  %rd3, [p_log_b_o];\n\
49        ld.param.u32  %r0,  [p_n_states];\n\
50    \n\
51        // j = global thread id\n\
52        mov.u32       %r1, %ntid.x;\n\
53        mov.u32       %r2, %ctaid.x;\n\
54        mov.u32       %r3, %tid.x;\n\
55        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
56        setp.ge.u32   %p0, %r4, %r0;\n\
57        @%p0 bra $FP_DONE;\n\
58    \n\
59        // First pass: find max of (alpha_prev[i] + log_a[i*S + j])\n\
60        mov.f32       %f0, 0fFF800000;   // -inf\n\
61        mov.u32       %r5, 0;\n\
62    $FP_MAX:\n\
63        setp.ge.u32   %p0, %r5, %r0;\n\
64        @%p0 bra $FP_SUM_INIT;\n\
65        // alpha_prev[i]\n\
66        mul.wide.u32  %rd4, %r5, 4;\n\
67        add.u64       %rd5, %rd0, %rd4;\n\
68        ld.global.f32 %f1, [%rd5];\n\
69        // log_a[i*S + j]\n\
70        mul.lo.u32    %r6, %r5, %r0;\n\
71        add.u32       %r6, %r6, %r4;\n\
72        mul.wide.u32  %rd6, %r6, 4;\n\
73        add.u64       %rd7, %rd2, %rd6;\n\
74        ld.global.f32 %f2, [%rd7];\n\
75        add.f32       %f3, %f1, %f2;\n\
76        max.f32       %f0, %f0, %f3;\n\
77        add.u32       %r5, %r5, 1;\n\
78        bra $FP_MAX;\n\
79    \n\
80    $FP_SUM_INIT:\n\
81        // Second pass: accumulate exp((alpha_prev[i] + log_a[i*S+j]) - max)\n\
82        mov.f32       %f4, 0f00000000;\n\
83        mov.u32       %r5, 0;\n\
84    $FP_SUM:\n\
85        setp.ge.u32   %p0, %r5, %r0;\n\
86        @%p0 bra $FP_WRITE;\n\
87        mul.wide.u32  %rd4, %r5, 4;\n\
88        add.u64       %rd5, %rd0, %rd4;\n\
89        ld.global.f32 %f1, [%rd5];\n\
90        mul.lo.u32    %r6, %r5, %r0;\n\
91        add.u32       %r6, %r6, %r4;\n\
92        mul.wide.u32  %rd6, %r6, 4;\n\
93        add.u64       %rd7, %rd2, %rd6;\n\
94        ld.global.f32 %f2, [%rd7];\n\
95        add.f32       %f3, %f1, %f2;\n\
96        sub.f32       %f3, %f3, %f0;\n\
97        ex2.approx.f32 %f3, %f3;\n\
98        add.f32       %f4, %f4, %f3;\n\
99        add.u32       %r5, %r5, 1;\n\
100        bra $FP_SUM;\n\
101    \n\
102    $FP_WRITE:\n\
103        // result = max + log(sum) + log_b_o[j]\n\
104        lg2.approx.f32 %f4, %f4;\n\
105        add.f32       %f4, %f4, %f0;\n\
106        mul.wide.u32  %rd4, %r4, 4;\n\
107        add.u64       %rd5, %rd3, %rd4;\n\
108        ld.global.f32 %f5, [%rd5];\n\
109        add.f32       %f4, %f4, %f5;\n\
110        add.u64       %rd6, %rd1, %rd4;\n\
111        st.global.f32 [%rd6], %f4;\n\
112    \n\
113    $FP_DONE:\n\
114        ret;\n\
115    }\n";
116    hdr + body
117}
118
119/// Viterbi step kernel (log-space, with argmax storage).
120///
121/// Signature: `viterbi_step_kernel(delta_prev, delta_next, log_a, log_b_o, psi, n_states)`
122/// Each thread computes one destination `j`: δ_t(j) = max_i(δ_{t-1}(i) + log A_ij) + log B_j(o_t).
123/// Argmax index is stored in `psi[j]` as `s32`.
124#[must_use]
125pub fn viterbi_step_ptx(sm: u32) -> String {
126    let hdr = ptx_header(sm);
127    let body = ".visible .entry viterbi_step_kernel(\n\
128        .param .u64 p_delta_prev,\n\
129        .param .u64 p_delta_next,\n\
130        .param .u64 p_log_a,\n\
131        .param .u64 p_log_b_o,\n\
132        .param .u64 p_psi,\n\
133        .param .u32 p_n_states\n\
134    )\n\
135    {\n\
136        .reg .u64  %rd<10>;\n\
137        .reg .u32  %r<10>;\n\
138        .reg .s32  %sr<4>;\n\
139        .reg .f32  %f<8>;\n\
140        .reg .pred %p0, %p1;\n\
141    \n\
142        ld.param.u64  %rd0, [p_delta_prev];\n\
143        ld.param.u64  %rd1, [p_delta_next];\n\
144        ld.param.u64  %rd2, [p_log_a];\n\
145        ld.param.u64  %rd3, [p_log_b_o];\n\
146        ld.param.u64  %rd4, [p_psi];\n\
147        ld.param.u32  %r0,  [p_n_states];\n\
148    \n\
149        mov.u32       %r1, %ntid.x;\n\
150        mov.u32       %r2, %ctaid.x;\n\
151        mov.u32       %r3, %tid.x;\n\
152        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
153        setp.ge.u32   %p0, %r4, %r0;\n\
154        @%p0 bra $VS_DONE;\n\
155    \n\
156        mov.f32       %f0, 0fFF800000;   // best = -inf\n\
157        mov.s32       %sr0, -1;          // argmax\n\
158        mov.u32       %r5, 0;\n\
159    $VS_LOOP:\n\
160        setp.ge.u32   %p0, %r5, %r0;\n\
161        @%p0 bra $VS_WRITE;\n\
162        // delta_prev[i]\n\
163        mul.wide.u32  %rd5, %r5, 4;\n\
164        add.u64       %rd6, %rd0, %rd5;\n\
165        ld.global.f32 %f1, [%rd6];\n\
166        // log_a[i*S + j]\n\
167        mul.lo.u32    %r6, %r5, %r0;\n\
168        add.u32       %r6, %r6, %r4;\n\
169        mul.wide.u32  %rd7, %r6, 4;\n\
170        add.u64       %rd8, %rd2, %rd7;\n\
171        ld.global.f32 %f2, [%rd8];\n\
172        add.f32       %f3, %f1, %f2;\n\
173        setp.gt.f32   %p1, %f3, %f0;\n\
174        @%p1 mov.f32  %f0, %f3;\n\
175        @%p1 cvt.s32.u32 %sr0, %r5;\n\
176        add.u32       %r5, %r5, 1;\n\
177        bra $VS_LOOP;\n\
178    \n\
179    $VS_WRITE:\n\
180        // delta_next[j] = best + log_b_o[j]\n\
181        mul.wide.u32  %rd5, %r4, 4;\n\
182        add.u64       %rd6, %rd3, %rd5;\n\
183        ld.global.f32 %f4, [%rd6];\n\
184        add.f32       %f0, %f0, %f4;\n\
185        add.u64       %rd7, %rd1, %rd5;\n\
186        st.global.f32 [%rd7], %f0;\n\
187        // psi[j] = argmax\n\
188        add.u64       %rd8, %rd4, %rd5;\n\
189        st.global.s32 [%rd8], %sr0;\n\
190    \n\
191    $VS_DONE:\n\
192        ret;\n\
193    }\n";
194    hdr + body
195}
196
197/// CRF feature-score kernel.
198///
199/// Signature: `crf_features_kernel(emit, trans, x_feat, score, t, n_labels, n_features)`
200/// Computes per-(label,prev_label) score at time `t`: `emit[y]·x_feat[t] + trans[prev,y]`.
201#[must_use]
202pub fn crf_features_ptx(sm: u32) -> String {
203    let hdr = ptx_header(sm);
204    let body = ".visible .entry crf_features_kernel(\n\
205        .param .u64 p_emit,\n\
206        .param .u64 p_trans,\n\
207        .param .u64 p_x_feat,\n\
208        .param .u64 p_score,\n\
209        .param .u32 p_t,\n\
210        .param .u32 p_n_labels,\n\
211        .param .u32 p_n_features\n\
212    )\n\
213    {\n\
214        .reg .u64  %rd<10>;\n\
215        .reg .u32  %r<14>;\n\
216        .reg .f32  %f<8>;\n\
217        .reg .pred %p0;\n\
218    \n\
219        ld.param.u64  %rd0, [p_emit];\n\
220        ld.param.u64  %rd1, [p_trans];\n\
221        ld.param.u64  %rd2, [p_x_feat];\n\
222        ld.param.u64  %rd3, [p_score];\n\
223        ld.param.u32  %r0,  [p_t];\n\
224        ld.param.u32  %r1,  [p_n_labels];\n\
225        ld.param.u32  %r2,  [p_n_features];\n\
226    \n\
227        // y_prev = blockIdx.y * blockDim.y + threadIdx.y\n\
228        mov.u32       %r3, %ntid.y;\n\
229        mov.u32       %r4, %ctaid.y;\n\
230        mov.u32       %r5, %tid.y;\n\
231        mad.lo.u32    %r6, %r3, %r4, %r5;\n\
232        // y_cur = blockIdx.x * blockDim.x + threadIdx.x\n\
233        mov.u32       %r7, %ntid.x;\n\
234        mov.u32       %r8, %ctaid.x;\n\
235        mov.u32       %r9, %tid.x;\n\
236        mad.lo.u32    %r10, %r7, %r8, %r9;\n\
237    \n\
238        setp.ge.u32   %p0, %r6, %r1;\n\
239        @%p0 bra $CF_DONE;\n\
240        setp.ge.u32   %p0, %r10, %r1;\n\
241        @%p0 bra $CF_DONE;\n\
242    \n\
243        // Emission score: dot(emit[y_cur,:], x_feat[t,:])\n\
244        mov.f32       %f0, 0f00000000;\n\
245        mov.u32       %r11, 0;\n\
246    $CF_EMIT:\n\
247        setp.ge.u32   %p0, %r11, %r2;\n\
248        @%p0 bra $CF_TRANS;\n\
249        // emit[y_cur * n_features + k]\n\
250        mul.lo.u32    %r12, %r10, %r2;\n\
251        add.u32       %r12, %r12, %r11;\n\
252        mul.wide.u32  %rd4, %r12, 4;\n\
253        add.u64       %rd5, %rd0, %rd4;\n\
254        ld.global.f32 %f1, [%rd5];\n\
255        // x_feat[t * n_features + k]\n\
256        mul.lo.u32    %r13, %r0, %r2;\n\
257        add.u32       %r13, %r13, %r11;\n\
258        mul.wide.u32  %rd6, %r13, 4;\n\
259        add.u64       %rd7, %rd2, %rd6;\n\
260        ld.global.f32 %f2, [%rd7];\n\
261        fma.rn.f32    %f0, %f1, %f2, %f0;\n\
262        add.u32       %r11, %r11, 1;\n\
263        bra $CF_EMIT;\n\
264    \n\
265    $CF_TRANS:\n\
266        // Transition score: trans[y_prev * n_labels + y_cur]\n\
267        mul.lo.u32    %r12, %r6, %r1;\n\
268        add.u32       %r12, %r12, %r10;\n\
269        mul.wide.u32  %rd4, %r12, 4;\n\
270        add.u64       %rd5, %rd1, %rd4;\n\
271        ld.global.f32 %f3, [%rd5];\n\
272        add.f32       %f0, %f0, %f3;\n\
273    \n\
274        // score[y_prev * n_labels + y_cur] = f0\n\
275        add.u64       %rd6, %rd3, %rd4;\n\
276        st.global.f32 [%rd6], %f0;\n\
277    \n\
278    $CF_DONE:\n\
279        ret;\n\
280    }\n";
281    hdr + body
282}
283
284/// Beam top-k partial-sort kernel (one-pass rank approximation).
285///
286/// Signature: `beam_topk_kernel(scores, rank, n, k)`
287/// Each thread computes how many other scores are strictly greater (its rank);
288/// threads whose rank < k are marked surviving (`rank[tid] = rank`); others get `-1`.
289#[must_use]
290pub fn beam_topk_ptx(sm: u32) -> String {
291    let hdr = ptx_header(sm);
292    let body = ".visible .entry beam_topk_kernel(\n\
293        .param .u64 p_scores,\n\
294        .param .u64 p_rank,\n\
295        .param .u32 p_n,\n\
296        .param .u32 p_k\n\
297    )\n\
298    {\n\
299        .reg .u64  %rd<8>;\n\
300        .reg .u32  %r<10>;\n\
301        .reg .s32  %sr<4>;\n\
302        .reg .f32  %f<4>;\n\
303        .reg .pred %p0, %p1;\n\
304    \n\
305        ld.param.u64  %rd0, [p_scores];\n\
306        ld.param.u64  %rd1, [p_rank];\n\
307        ld.param.u32  %r0,  [p_n];\n\
308        ld.param.u32  %r1,  [p_k];\n\
309    \n\
310        mov.u32       %r2, %ntid.x;\n\
311        mov.u32       %r3, %ctaid.x;\n\
312        mov.u32       %r4, %tid.x;\n\
313        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
314        setp.ge.u32   %p0, %r5, %r0;\n\
315        @%p0 bra $BK_DONE;\n\
316    \n\
317        // my score\n\
318        mul.wide.u32  %rd2, %r5, 4;\n\
319        add.u64       %rd3, %rd0, %rd2;\n\
320        ld.global.f32 %f0, [%rd3];\n\
321    \n\
322        mov.u32       %r6, 0;     // rank counter\n\
323        mov.u32       %r7, 0;     // loop index\n\
324    $BK_LOOP:\n\
325        setp.ge.u32   %p0, %r7, %r0;\n\
326        @%p0 bra $BK_WRITE;\n\
327        mul.wide.u32  %rd4, %r7, 4;\n\
328        add.u64       %rd5, %rd0, %rd4;\n\
329        ld.global.f32 %f1, [%rd5];\n\
330        setp.gt.f32   %p1, %f1, %f0;\n\
331        @%p1 add.u32  %r6, %r6, 1;\n\
332        add.u32       %r7, %r7, 1;\n\
333        bra $BK_LOOP;\n\
334    \n\
335    $BK_WRITE:\n\
336        setp.ge.u32   %p0, %r6, %r1;\n\
337        mov.s32       %sr0, -1;\n\
338        @!%p0 cvt.s32.u32 %sr0, %r6;\n\
339        add.u64       %rd6, %rd1, %rd2;\n\
340        st.global.s32 [%rd6], %sr0;\n\
341    \n\
342    $BK_DONE:\n\
343        ret;\n\
344    }\n";
345    hdr + body
346}
347
348/// Edit-distance anti-diagonal cell update kernel.
349///
350/// Signature: `edit_dist_kernel(dp, a_chars, b_chars, n_a, n_b, diag)`
351/// On anti-diagonal `diag`, each thread updates one cell `dp[i*(n_b+1)+j]`
352/// using the standard Levenshtein recurrence.
353#[must_use]
354pub fn edit_dist_ptx(sm: u32) -> String {
355    let hdr = ptx_header(sm);
356    let body = ".visible .entry edit_dist_kernel(\n\
357        .param .u64 p_dp,\n\
358        .param .u64 p_a,\n\
359        .param .u64 p_b,\n\
360        .param .u32 p_n_a,\n\
361        .param .u32 p_n_b,\n\
362        .param .u32 p_diag\n\
363    )\n\
364    {\n\
365        .reg .u64  %rd<10>;\n\
366        .reg .u32  %r<14>;\n\
367        .reg .s32  %sr<6>;\n\
368        .reg .pred %p0, %p1;\n\
369    \n\
370        ld.param.u64  %rd0, [p_dp];\n\
371        ld.param.u64  %rd1, [p_a];\n\
372        ld.param.u64  %rd2, [p_b];\n\
373        ld.param.u32  %r0,  [p_n_a];\n\
374        ld.param.u32  %r1,  [p_n_b];\n\
375        ld.param.u32  %r2,  [p_diag];\n\
376    \n\
377        // i = tid + 1; j = diag - i\n\
378        mov.u32       %r3, %ntid.x;\n\
379        mov.u32       %r4, %ctaid.x;\n\
380        mov.u32       %r5, %tid.x;\n\
381        mad.lo.u32    %r6, %r3, %r4, %r5;\n\
382        add.u32       %r7, %r6, 1;            // i\n\
383        sub.u32       %r8, %r2, %r7;          // j\n\
384        setp.gt.u32   %p0, %r7, %r0;\n\
385        @%p0 bra $ED_DONE;\n\
386        setp.eq.u32   %p0, %r8, 0;\n\
387        @%p0 bra $ED_DONE;\n\
388        setp.gt.u32   %p0, %r8, %r1;\n\
389        @%p0 bra $ED_DONE;\n\
390    \n\
391        // load a[i-1], b[j-1]\n\
392        sub.u32       %r9, %r7, 1;\n\
393        mul.wide.u32  %rd3, %r9, 4;\n\
394        add.u64       %rd4, %rd1, %rd3;\n\
395        ld.global.s32 %sr0, [%rd4];\n\
396        sub.u32       %r10, %r8, 1;\n\
397        mul.wide.u32  %rd5, %r10, 4;\n\
398        add.u64       %rd6, %rd2, %rd5;\n\
399        ld.global.s32 %sr1, [%rd6];\n\
400    \n\
401        // cost: 0 if eq, else 1\n\
402        setp.eq.s32   %p1, %sr0, %sr1;\n\
403        mov.s32       %sr2, 1;\n\
404        @%p1 mov.s32  %sr2, 0;\n\
405    \n\
406        // Read 3 neighbours from dp\n\
407        // dp[(i-1)*(n_b+1) + j]\n\
408        add.u32       %r11, %r1, 1;\n\
409        mul.lo.u32    %r12, %r9, %r11;\n\
410        add.u32       %r12, %r12, %r8;\n\
411        mul.wide.u32  %rd7, %r12, 4;\n\
412        add.u64       %rd8, %rd0, %rd7;\n\
413        ld.global.s32 %sr3, [%rd8];\n\
414        // dp[i*(n_b+1) + (j-1)]\n\
415        mul.lo.u32    %r12, %r7, %r11;\n\
416        add.u32       %r12, %r12, %r10;\n\
417        mul.wide.u32  %rd7, %r12, 4;\n\
418        add.u64       %rd8, %rd0, %rd7;\n\
419        ld.global.s32 %sr4, [%rd8];\n\
420        // dp[(i-1)*(n_b+1) + (j-1)]\n\
421        mul.lo.u32    %r12, %r9, %r11;\n\
422        add.u32       %r12, %r12, %r10;\n\
423        mul.wide.u32  %rd7, %r12, 4;\n\
424        add.u64       %rd8, %rd0, %rd7;\n\
425        ld.global.s32 %sr5, [%rd8];\n\
426    \n\
427        // best = min(sr3+1, sr4+1, sr5+sr2)\n\
428        add.s32       %sr3, %sr3, 1;\n\
429        add.s32       %sr4, %sr4, 1;\n\
430        add.s32       %sr5, %sr5, %sr2;\n\
431        min.s32       %sr3, %sr3, %sr4;\n\
432        min.s32       %sr3, %sr3, %sr5;\n\
433    \n\
434        // write dp[i*(n_b+1)+j]\n\
435        mul.lo.u32    %r12, %r7, %r11;\n\
436        add.u32       %r12, %r12, %r8;\n\
437        mul.wide.u32  %rd7, %r12, 4;\n\
438        add.u64       %rd8, %rd0, %rd7;\n\
439        st.global.s32 [%rd8], %sr3;\n\
440    \n\
441    $ED_DONE:\n\
442        ret;\n\
443    }\n";
444    hdr + body
445}
446
447/// Kalman predict step kernel (matrix-vector + covariance update).
448///
449/// Signature: `kalman_predict_kernel(x, x_pred, A, P, P_pred, Q, n)`
450/// Computes x_pred = A·x and P_pred = A·P·Aᵀ + Q.  Each thread handles one row.
451#[must_use]
452pub fn kalman_predict_ptx(sm: u32) -> String {
453    let hdr = ptx_header(sm);
454    let body = ".visible .entry kalman_predict_kernel(\n\
455        .param .u64 p_x,\n\
456        .param .u64 p_x_pred,\n\
457        .param .u64 p_a,\n\
458        .param .u64 p_p,\n\
459        .param .u64 p_p_pred,\n\
460        .param .u64 p_q,\n\
461        .param .u32 p_n\n\
462    )\n\
463    {\n\
464        .reg .u64  %rd<14>;\n\
465        .reg .u32  %r<14>;\n\
466        .reg .f32  %f<10>;\n\
467        .reg .pred %p0;\n\
468    \n\
469        ld.param.u64  %rd0, [p_x];\n\
470        ld.param.u64  %rd1, [p_x_pred];\n\
471        ld.param.u64  %rd2, [p_a];\n\
472        ld.param.u64  %rd3, [p_p];\n\
473        ld.param.u64  %rd4, [p_p_pred];\n\
474        ld.param.u64  %rd5, [p_q];\n\
475        ld.param.u32  %r0,  [p_n];\n\
476    \n\
477        mov.u32       %r1, %ntid.x;\n\
478        mov.u32       %r2, %ctaid.x;\n\
479        mov.u32       %r3, %tid.x;\n\
480        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
481        setp.ge.u32   %p0, %r4, %r0;\n\
482        @%p0 bra $KP_DONE;\n\
483    \n\
484        // x_pred[i] = sum_k A[i,k] * x[k]\n\
485        mov.f32       %f0, 0f00000000;\n\
486        mov.u32       %r5, 0;\n\
487    $KP_VEC:\n\
488        setp.ge.u32   %p0, %r5, %r0;\n\
489        @%p0 bra $KP_VEC_WR;\n\
490        // A[i*n + k]\n\
491        mul.lo.u32    %r6, %r4, %r0;\n\
492        add.u32       %r6, %r6, %r5;\n\
493        mul.wide.u32  %rd6, %r6, 4;\n\
494        add.u64       %rd7, %rd2, %rd6;\n\
495        ld.global.f32 %f1, [%rd7];\n\
496        // x[k]\n\
497        mul.wide.u32  %rd8, %r5, 4;\n\
498        add.u64       %rd9, %rd0, %rd8;\n\
499        ld.global.f32 %f2, [%rd9];\n\
500        fma.rn.f32    %f0, %f1, %f2, %f0;\n\
501        add.u32       %r5, %r5, 1;\n\
502        bra $KP_VEC;\n\
503    \n\
504    $KP_VEC_WR:\n\
505        mul.wide.u32  %rd6, %r4, 4;\n\
506        add.u64       %rd7, %rd1, %rd6;\n\
507        st.global.f32 [%rd7], %f0;\n\
508    \n\
509        // P_pred[i,j] = sum_{k,l} A[i,k] P[k,l] A[j,l] + Q[i,j]\n\
510        // One thread handles row i, all j.\n\
511        mov.u32       %r7, 0;\n\
512    $KP_J:\n\
513        setp.ge.u32   %p0, %r7, %r0;\n\
514        @%p0 bra $KP_DONE;\n\
515        mov.f32       %f3, 0f00000000;\n\
516        mov.u32       %r8, 0;\n\
517    $KP_K:\n\
518        setp.ge.u32   %p0, %r8, %r0;\n\
519        @%p0 bra $KP_K_DONE;\n\
520        mov.f32       %f4, 0f00000000;\n\
521        mov.u32       %r9, 0;\n\
522    $KP_L:\n\
523        setp.ge.u32   %p0, %r9, %r0;\n\
524        @%p0 bra $KP_L_DONE;\n\
525        // P[k*n + l]\n\
526        mul.lo.u32    %r10, %r8, %r0;\n\
527        add.u32       %r10, %r10, %r9;\n\
528        mul.wide.u32  %rd6, %r10, 4;\n\
529        add.u64       %rd7, %rd3, %rd6;\n\
530        ld.global.f32 %f5, [%rd7];\n\
531        // A[j*n + l]\n\
532        mul.lo.u32    %r11, %r7, %r0;\n\
533        add.u32       %r11, %r11, %r9;\n\
534        mul.wide.u32  %rd8, %r11, 4;\n\
535        add.u64       %rd9, %rd2, %rd8;\n\
536        ld.global.f32 %f6, [%rd9];\n\
537        fma.rn.f32    %f4, %f5, %f6, %f4;\n\
538        add.u32       %r9, %r9, 1;\n\
539        bra $KP_L;\n\
540    \n\
541    $KP_L_DONE:\n\
542        // A[i*n + k]\n\
543        mul.lo.u32    %r10, %r4, %r0;\n\
544        add.u32       %r10, %r10, %r8;\n\
545        mul.wide.u32  %rd6, %r10, 4;\n\
546        add.u64       %rd7, %rd2, %rd6;\n\
547        ld.global.f32 %f7, [%rd7];\n\
548        fma.rn.f32    %f3, %f7, %f4, %f3;\n\
549        add.u32       %r8, %r8, 1;\n\
550        bra $KP_K;\n\
551    \n\
552    $KP_K_DONE:\n\
553        // P_pred[i*n+j] = f3 + Q[i*n+j]\n\
554        mul.lo.u32    %r10, %r4, %r0;\n\
555        add.u32       %r10, %r10, %r7;\n\
556        mul.wide.u32  %rd6, %r10, 4;\n\
557        add.u64       %rd7, %rd5, %rd6;\n\
558        ld.global.f32 %f8, [%rd7];\n\
559        add.f32       %f3, %f3, %f8;\n\
560        add.u64       %rd9, %rd4, %rd6;\n\
561        st.global.f32 [%rd9], %f3;\n\
562        add.u32       %r7, %r7, 1;\n\
563        bra $KP_J;\n\
564    \n\
565    $KP_DONE:\n\
566        ret;\n\
567    }\n";
568    hdr + body
569}
570
571/// Gibbs-sampling Ising MRF kernel: per-site conditional resample given neighbours.
572///
573/// Signature: `mrf_gibbs_kernel(spins, h, j, n_rows, n_cols, seed)`
574/// Each thread samples its spin given a uniform draw from the inline LCG seeded
575/// by `seed ^ tid`.
576#[must_use]
577pub fn mrf_gibbs_ptx(sm: u32) -> String {
578    let hdr = ptx_header(sm);
579    let body = ".visible .entry mrf_gibbs_kernel(\n\
580        .param .u64 p_spins,\n\
581        .param .f32 p_h,\n\
582        .param .f32 p_j,\n\
583        .param .u32 p_n_rows,\n\
584        .param .u32 p_n_cols,\n\
585        .param .u64 p_seed\n\
586    )\n\
587    {\n\
588        .reg .u64  %rd<10>;\n\
589        .reg .u32  %r<16>;\n\
590        .reg .s32  %sr<8>;\n\
591        .reg .f32  %f<10>;\n\
592        .reg .pred %p0, %p1;\n\
593    \n\
594        ld.param.u64  %rd0, [p_spins];\n\
595        ld.param.f32  %f0,  [p_h];\n\
596        ld.param.f32  %f1,  [p_j];\n\
597        ld.param.u32  %r0,  [p_n_rows];\n\
598        ld.param.u32  %r1,  [p_n_cols];\n\
599        ld.param.u64  %rd1, [p_seed];\n\
600    \n\
601        // i = blockIdx.y * blockDim.y + threadIdx.y\n\
602        mov.u32       %r2, %ntid.y;\n\
603        mov.u32       %r3, %ctaid.y;\n\
604        mov.u32       %r4, %tid.y;\n\
605        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
606        // j = blockIdx.x * blockDim.x + threadIdx.x\n\
607        mov.u32       %r6, %ntid.x;\n\
608        mov.u32       %r7, %ctaid.x;\n\
609        mov.u32       %r8, %tid.x;\n\
610        mad.lo.u32    %r9, %r6, %r7, %r8;\n\
611    \n\
612        setp.ge.u32   %p0, %r5, %r0;\n\
613        @%p0 bra $MG_DONE;\n\
614        setp.ge.u32   %p0, %r9, %r1;\n\
615        @%p0 bra $MG_DONE;\n\
616    \n\
617        // Sum of neighbour spins (with bounds checks)\n\
618        mov.s32       %sr0, 0;\n\
619        // up\n\
620        setp.eq.u32   %p1, %r5, 0;\n\
621        @%p1 bra $MG_NB1;\n\
622        sub.u32       %r10, %r5, 1;\n\
623        mul.lo.u32    %r11, %r10, %r1;\n\
624        add.u32       %r11, %r11, %r9;\n\
625        mul.wide.u32  %rd2, %r11, 4;\n\
626        add.u64       %rd3, %rd0, %rd2;\n\
627        ld.global.s32 %sr1, [%rd3];\n\
628        add.s32       %sr0, %sr0, %sr1;\n\
629    $MG_NB1:\n\
630        // down\n\
631        add.u32       %r10, %r5, 1;\n\
632        setp.ge.u32   %p1, %r10, %r0;\n\
633        @%p1 bra $MG_NB2;\n\
634        mul.lo.u32    %r11, %r10, %r1;\n\
635        add.u32       %r11, %r11, %r9;\n\
636        mul.wide.u32  %rd2, %r11, 4;\n\
637        add.u64       %rd3, %rd0, %rd2;\n\
638        ld.global.s32 %sr1, [%rd3];\n\
639        add.s32       %sr0, %sr0, %sr1;\n\
640    $MG_NB2:\n\
641        // left\n\
642        setp.eq.u32   %p1, %r9, 0;\n\
643        @%p1 bra $MG_NB3;\n\
644        sub.u32       %r10, %r9, 1;\n\
645        mul.lo.u32    %r11, %r5, %r1;\n\
646        add.u32       %r11, %r11, %r10;\n\
647        mul.wide.u32  %rd2, %r11, 4;\n\
648        add.u64       %rd3, %rd0, %rd2;\n\
649        ld.global.s32 %sr1, [%rd3];\n\
650        add.s32       %sr0, %sr0, %sr1;\n\
651    $MG_NB3:\n\
652        // right\n\
653        add.u32       %r10, %r9, 1;\n\
654        setp.ge.u32   %p1, %r10, %r1;\n\
655        @%p1 bra $MG_FIELD;\n\
656        mul.lo.u32    %r11, %r5, %r1;\n\
657        add.u32       %r11, %r11, %r10;\n\
658        mul.wide.u32  %rd2, %r11, 4;\n\
659        add.u64       %rd3, %rd0, %rd2;\n\
660        ld.global.s32 %sr1, [%rd3];\n\
661        add.s32       %sr0, %sr0, %sr1;\n\
662    \n\
663    $MG_FIELD:\n\
664        // field = j * sum + h\n\
665        cvt.rn.f32.s32 %f2, %sr0;\n\
666        mul.f32       %f3, %f1, %f2;\n\
667        add.f32       %f3, %f3, %f0;\n\
668    \n\
669        // p_up = 1 / (1 + exp(-2*field))\n\
670        mov.f32       %f4, 0fC0000000;       // -2\n\
671        mul.f32       %f5, %f4, %f3;\n\
672        ex2.approx.f32 %f5, %f5;             // exp2(...)\n\
673        mov.f32       %f6, 0f3F800000;       // 1.0\n\
674        add.f32       %f5, %f5, %f6;\n\
675        div.rn.f32    %f7, %f6, %f5;\n\
676    \n\
677        // Inline LCG: seed ^ (row * n_cols + col)\n\
678        mul.lo.u32    %r12, %r5, %r1;\n\
679        add.u32       %r12, %r12, %r9;\n\
680        cvt.u64.u32   %rd4, %r12;\n\
681        xor.b64       %rd5, %rd4, %rd1;\n\
682        mov.u64       %rd6, 6364136223846793005;\n\
683        mul.lo.u64    %rd5, %rd5, %rd6;\n\
684        mov.u64       %rd6, 1442695040888963407;\n\
685        add.u64       %rd5, %rd5, %rd6;\n\
686        shr.u64       %rd7, %rd5, 32;\n\
687        cvt.u32.u64   %r13, %rd7;\n\
688        // u = (high32 >> 8) / 2^24\n\
689        shr.u32       %r14, %r13, 8;\n\
690        cvt.rn.f32.u32 %f8, %r14;\n\
691        mov.f32       %f9, 0f33800000;       // 1 / 2^24\n\
692        mul.f32       %f8, %f8, %f9;\n\
693    \n\
694        // s = (u < p_up) ? +1 : -1\n\
695        setp.lt.f32   %p1, %f8, %f7;\n\
696        mov.s32       %sr2, -1;\n\
697        @%p1 mov.s32  %sr2, 1;\n\
698    \n\
699        // store\n\
700        mul.lo.u32    %r15, %r5, %r1;\n\
701        add.u32       %r15, %r15, %r9;\n\
702        mul.wide.u32  %rd8, %r15, 4;\n\
703        add.u64       %rd9, %rd0, %rd8;\n\
704        st.global.s32 [%rd9], %sr2;\n\
705    \n\
706    $MG_DONE:\n\
707        ret;\n\
708    }\n";
709    hdr + body
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    /// The header picks the PTX ISA version expected at each SM boundary.
    #[test]
    fn ptx_header_versions() {
        let expected = [
            (75u32, ".version 7.5"),
            (80, ".version 8.0"),
            (89, ".version 8.0"),
            (90, ".version 8.4"),
            (100, ".version 8.7"),
        ];
        for (sm, version_line) in expected {
            assert!(ptx_header(sm).contains(version_line));
        }
    }

    /// Every kernel builder yields non-empty PTX with a visible entry point
    /// for each SM version exercised here.
    #[test]
    fn all_kernels_non_empty() {
        type KernelFn = fn(u32) -> String;
        let builders: [(&str, KernelFn); 7] = [
            ("forward_pass", forward_pass_ptx),
            ("viterbi_step", viterbi_step_ptx),
            ("crf_features", crf_features_ptx),
            ("beam_topk", beam_topk_ptx),
            ("edit_dist", edit_dist_ptx),
            ("kalman_predict", kalman_predict_ptx),
            ("mrf_gibbs", mrf_gibbs_ptx),
        ];
        for sm in [75u32, 80, 86, 89, 90, 100] {
            for (name, build) in builders {
                let ptx = build(sm);
                assert!(!ptx.is_empty(), "{name} sm{sm} empty");
                assert!(
                    ptx.contains(".visible .entry"),
                    "{name} sm{sm} missing .visible .entry"
                );
            }
        }
    }
}