Skip to main content

oxicuda_seq/
ptx_kernels.rs

1//! GPU PTX kernels for Sequence Models & Structured Prediction.
2//!
3//! Each kernel is emitted as a self-contained PTX module string, parameterised
4//! on the SM version.  PTX ISA selection by SM:
5//!     SM≥100 → 8.7 (Blackwell), SM≥90 → 8.4 (Hopper),
6//!     SM≥80  → 8.0 (Ampere),    else → 7.5 (Turing).
7//!
//! IMPORTANT: PTX kernel bodies are assembled by **string concatenation** (NOT
//! `format!()`).  The bodies contain literal `{` / `}` braces alongside the
//! `%rd`, `%r`, `%f` register names; `format!` would parse those braces as
//! placeholders unless every one of them were escaped as `{{` / `}}`.
12
/// Produce the three-directive PTX module preamble for compute capability `sm`.
///
/// The `.version` directive is picked by tier (`sm >= 100` → 8.7,
/// `sm >= 90` → 8.4, `sm >= 80` → 8.0, anything lower → 7.5); the `.target`
/// is always `sm_<sm>` and the address size is fixed at 64 bits.  The returned
/// string ends with a blank line so a kernel body can be appended directly.
fn ptx_header(sm: u32) -> String {
    let isa_version = if sm >= 100 {
        "8.7"
    } else if sm >= 90 {
        "8.4"
    } else if sm >= 80 {
        "8.0"
    } else {
        "7.5"
    };
    format!(".version {isa_version}\n.target sm_{sm}\n.address_size 64\n\n")
}
23
24/// HMM forward-pass kernel (log-space).
25///
26/// Signature: `forward_pass_kernel(alpha_prev, alpha_next, log_a, log_b_o, n_states)`
27/// One thread per destination state `j`; reads `alpha_prev[i] + log_a[i*S+j]`
28/// over all `i`, applies max+log-sum-exp, then adds `log_b_o[j]`.
29#[must_use]
30pub fn forward_pass_ptx(sm: u32) -> String {
31    let hdr = ptx_header(sm);
32    let body = ".visible .entry forward_pass_kernel(\n\
33        .param .u64 p_alpha_prev,\n\
34        .param .u64 p_alpha_next,\n\
35        .param .u64 p_log_a,\n\
36        .param .u64 p_log_b_o,\n\
37        .param .u32 p_n_states\n\
38    )\n\
39    {\n\
40        .reg .u64  %rd<10>;\n\
41        .reg .u32  %r<10>;\n\
42        .reg .f32  %f<10>;\n\
43        .reg .pred %p0;\n\
44    \n\
45        ld.param.u64  %rd0, [p_alpha_prev];\n\
46        ld.param.u64  %rd1, [p_alpha_next];\n\
47        ld.param.u64  %rd2, [p_log_a];\n\
48        ld.param.u64  %rd3, [p_log_b_o];\n\
49        ld.param.u32  %r0,  [p_n_states];\n\
50    \n\
51        // j = global thread id\n\
52        mov.u32       %r1, %ntid.x;\n\
53        mov.u32       %r2, %ctaid.x;\n\
54        mov.u32       %r3, %tid.x;\n\
55        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
56        setp.ge.u32   %p0, %r4, %r0;\n\
57        @%p0 bra $FP_DONE;\n\
58    \n\
59        // First pass: find max of (alpha_prev[i] + log_a[i*S + j])\n\
60        mov.f32       %f0, 0fFF800000;   // -inf\n\
61        mov.u32       %r5, 0;\n\
62    $FP_MAX:\n\
63        setp.ge.u32   %p0, %r5, %r0;\n\
64        @%p0 bra $FP_SUM_INIT;\n\
65        // alpha_prev[i]\n\
66        mul.wide.u32  %rd4, %r5, 4;\n\
67        add.u64       %rd5, %rd0, %rd4;\n\
68        ld.global.f32 %f1, [%rd5];\n\
69        // log_a[i*S + j]\n\
70        mul.lo.u32    %r6, %r5, %r0;\n\
71        add.u32       %r6, %r6, %r4;\n\
72        mul.wide.u32  %rd6, %r6, 4;\n\
73        add.u64       %rd7, %rd2, %rd6;\n\
74        ld.global.f32 %f2, [%rd7];\n\
75        add.f32       %f3, %f1, %f2;\n\
76        max.f32       %f0, %f0, %f3;\n\
77        add.u32       %r5, %r5, 1;\n\
78        bra $FP_MAX;\n\
79    \n\
80    $FP_SUM_INIT:\n\
81        // Second pass: accumulate exp((alpha_prev[i] + log_a[i*S+j]) - max)\n\
82        mov.f32       %f4, 0f00000000;\n\
83        mov.u32       %r5, 0;\n\
84    $FP_SUM:\n\
85        setp.ge.u32   %p0, %r5, %r0;\n\
86        @%p0 bra $FP_WRITE;\n\
87        mul.wide.u32  %rd4, %r5, 4;\n\
88        add.u64       %rd5, %rd0, %rd4;\n\
89        ld.global.f32 %f1, [%rd5];\n\
90        mul.lo.u32    %r6, %r5, %r0;\n\
91        add.u32       %r6, %r6, %r4;\n\
92        mul.wide.u32  %rd6, %r6, 4;\n\
93        add.u64       %rd7, %rd2, %rd6;\n\
94        ld.global.f32 %f2, [%rd7];\n\
95        add.f32       %f3, %f1, %f2;\n\
96        sub.f32       %f3, %f3, %f0;\n\
97        mul.f32       %f3, %f3, 0f3FB8AA3B;   // * log2(e): exp(x)=ex2(x*log2e)\n\
98        ex2.approx.f32 %f3, %f3;\n\
99        add.f32       %f4, %f4, %f3;\n\
100        add.u32       %r5, %r5, 1;\n\
101        bra $FP_SUM;\n\
102    \n\
103    $FP_WRITE:\n\
104        // result = max + log(sum) + log_b_o[j]\n\
105        lg2.approx.f32 %f4, %f4;\n\
106        mul.f32       %f4, %f4, 0f3F317218;   // * ln(2): ln(x)=lg2(x)*ln2\n\
107        add.f32       %f4, %f4, %f0;\n\
108        mul.wide.u32  %rd4, %r4, 4;\n\
109        add.u64       %rd5, %rd3, %rd4;\n\
110        ld.global.f32 %f5, [%rd5];\n\
111        add.f32       %f4, %f4, %f5;\n\
112        add.u64       %rd6, %rd1, %rd4;\n\
113        st.global.f32 [%rd6], %f4;\n\
114    \n\
115    $FP_DONE:\n\
116        ret;\n\
117    }\n";
118    hdr + body
119}
120
121/// Viterbi step kernel (log-space, with argmax storage).
122///
123/// Signature: `viterbi_step_kernel(delta_prev, delta_next, log_a, log_b_o, psi, n_states)`
124/// Each thread computes one destination `j`: δ_t(j) = max_i(δ_{t-1}(i) + log A_ij) + log B_j(o_t).
125/// Argmax index is stored in `psi[j]` as `s32`.
126#[must_use]
127pub fn viterbi_step_ptx(sm: u32) -> String {
128    let hdr = ptx_header(sm);
129    let body = ".visible .entry viterbi_step_kernel(\n\
130        .param .u64 p_delta_prev,\n\
131        .param .u64 p_delta_next,\n\
132        .param .u64 p_log_a,\n\
133        .param .u64 p_log_b_o,\n\
134        .param .u64 p_psi,\n\
135        .param .u32 p_n_states\n\
136    )\n\
137    {\n\
138        .reg .u64  %rd<10>;\n\
139        .reg .u32  %r<10>;\n\
140        .reg .s32  %sr<4>;\n\
141        .reg .f32  %f<8>;\n\
142        .reg .pred %p0, %p1;\n\
143    \n\
144        ld.param.u64  %rd0, [p_delta_prev];\n\
145        ld.param.u64  %rd1, [p_delta_next];\n\
146        ld.param.u64  %rd2, [p_log_a];\n\
147        ld.param.u64  %rd3, [p_log_b_o];\n\
148        ld.param.u64  %rd4, [p_psi];\n\
149        ld.param.u32  %r0,  [p_n_states];\n\
150    \n\
151        mov.u32       %r1, %ntid.x;\n\
152        mov.u32       %r2, %ctaid.x;\n\
153        mov.u32       %r3, %tid.x;\n\
154        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
155        setp.ge.u32   %p0, %r4, %r0;\n\
156        @%p0 bra $VS_DONE;\n\
157    \n\
158        mov.f32       %f0, 0fFF800000;   // best = -inf\n\
159        mov.s32       %sr0, -1;          // argmax\n\
160        mov.u32       %r5, 0;\n\
161    $VS_LOOP:\n\
162        setp.ge.u32   %p0, %r5, %r0;\n\
163        @%p0 bra $VS_WRITE;\n\
164        // delta_prev[i]\n\
165        mul.wide.u32  %rd5, %r5, 4;\n\
166        add.u64       %rd6, %rd0, %rd5;\n\
167        ld.global.f32 %f1, [%rd6];\n\
168        // log_a[i*S + j]\n\
169        mul.lo.u32    %r6, %r5, %r0;\n\
170        add.u32       %r6, %r6, %r4;\n\
171        mul.wide.u32  %rd7, %r6, 4;\n\
172        add.u64       %rd8, %rd2, %rd7;\n\
173        ld.global.f32 %f2, [%rd8];\n\
174        add.f32       %f3, %f1, %f2;\n\
175        setp.gt.f32   %p1, %f3, %f0;\n\
176        @%p1 mov.f32  %f0, %f3;\n\
177        @%p1 cvt.s32.u32 %sr0, %r5;\n\
178        add.u32       %r5, %r5, 1;\n\
179        bra $VS_LOOP;\n\
180    \n\
181    $VS_WRITE:\n\
182        // delta_next[j] = best + log_b_o[j]\n\
183        mul.wide.u32  %rd5, %r4, 4;\n\
184        add.u64       %rd6, %rd3, %rd5;\n\
185        ld.global.f32 %f4, [%rd6];\n\
186        add.f32       %f0, %f0, %f4;\n\
187        add.u64       %rd7, %rd1, %rd5;\n\
188        st.global.f32 [%rd7], %f0;\n\
189        // psi[j] = argmax\n\
190        add.u64       %rd8, %rd4, %rd5;\n\
191        st.global.s32 [%rd8], %sr0;\n\
192    \n\
193    $VS_DONE:\n\
194        ret;\n\
195    }\n";
196    hdr + body
197}
198
199/// CRF feature-score kernel.
200///
201/// Signature: `crf_features_kernel(emit, trans, x_feat, score, t, n_labels, n_features)`
202/// Computes per-(label,prev_label) score at time `t`: `emit[y]·x_feat[t] + trans[prev,y]`.
203#[must_use]
204pub fn crf_features_ptx(sm: u32) -> String {
205    let hdr = ptx_header(sm);
206    let body = ".visible .entry crf_features_kernel(\n\
207        .param .u64 p_emit,\n\
208        .param .u64 p_trans,\n\
209        .param .u64 p_x_feat,\n\
210        .param .u64 p_score,\n\
211        .param .u32 p_t,\n\
212        .param .u32 p_n_labels,\n\
213        .param .u32 p_n_features\n\
214    )\n\
215    {\n\
216        .reg .u64  %rd<10>;\n\
217        .reg .u32  %r<14>;\n\
218        .reg .f32  %f<8>;\n\
219        .reg .pred %p0;\n\
220    \n\
221        ld.param.u64  %rd0, [p_emit];\n\
222        ld.param.u64  %rd1, [p_trans];\n\
223        ld.param.u64  %rd2, [p_x_feat];\n\
224        ld.param.u64  %rd3, [p_score];\n\
225        ld.param.u32  %r0,  [p_t];\n\
226        ld.param.u32  %r1,  [p_n_labels];\n\
227        ld.param.u32  %r2,  [p_n_features];\n\
228    \n\
229        // y_prev = blockIdx.y * blockDim.y + threadIdx.y\n\
230        mov.u32       %r3, %ntid.y;\n\
231        mov.u32       %r4, %ctaid.y;\n\
232        mov.u32       %r5, %tid.y;\n\
233        mad.lo.u32    %r6, %r3, %r4, %r5;\n\
234        // y_cur = blockIdx.x * blockDim.x + threadIdx.x\n\
235        mov.u32       %r7, %ntid.x;\n\
236        mov.u32       %r8, %ctaid.x;\n\
237        mov.u32       %r9, %tid.x;\n\
238        mad.lo.u32    %r10, %r7, %r8, %r9;\n\
239    \n\
240        setp.ge.u32   %p0, %r6, %r1;\n\
241        @%p0 bra $CF_DONE;\n\
242        setp.ge.u32   %p0, %r10, %r1;\n\
243        @%p0 bra $CF_DONE;\n\
244    \n\
245        // Emission score: dot(emit[y_cur,:], x_feat[t,:])\n\
246        mov.f32       %f0, 0f00000000;\n\
247        mov.u32       %r11, 0;\n\
248    $CF_EMIT:\n\
249        setp.ge.u32   %p0, %r11, %r2;\n\
250        @%p0 bra $CF_TRANS;\n\
251        // emit[y_cur * n_features + k]\n\
252        mul.lo.u32    %r12, %r10, %r2;\n\
253        add.u32       %r12, %r12, %r11;\n\
254        mul.wide.u32  %rd4, %r12, 4;\n\
255        add.u64       %rd5, %rd0, %rd4;\n\
256        ld.global.f32 %f1, [%rd5];\n\
257        // x_feat[t * n_features + k]\n\
258        mul.lo.u32    %r13, %r0, %r2;\n\
259        add.u32       %r13, %r13, %r11;\n\
260        mul.wide.u32  %rd6, %r13, 4;\n\
261        add.u64       %rd7, %rd2, %rd6;\n\
262        ld.global.f32 %f2, [%rd7];\n\
263        fma.rn.f32    %f0, %f1, %f2, %f0;\n\
264        add.u32       %r11, %r11, 1;\n\
265        bra $CF_EMIT;\n\
266    \n\
267    $CF_TRANS:\n\
268        // Transition score: trans[y_prev * n_labels + y_cur]\n\
269        mul.lo.u32    %r12, %r6, %r1;\n\
270        add.u32       %r12, %r12, %r10;\n\
271        mul.wide.u32  %rd4, %r12, 4;\n\
272        add.u64       %rd5, %rd1, %rd4;\n\
273        ld.global.f32 %f3, [%rd5];\n\
274        add.f32       %f0, %f0, %f3;\n\
275    \n\
276        // score[y_prev * n_labels + y_cur] = f0\n\
277        add.u64       %rd6, %rd3, %rd4;\n\
278        st.global.f32 [%rd6], %f0;\n\
279    \n\
280    $CF_DONE:\n\
281        ret;\n\
282    }\n";
283    hdr + body
284}
285
286/// Beam top-k partial-sort kernel (one-pass rank approximation).
287///
288/// Signature: `beam_topk_kernel(scores, rank, n, k)`
289/// Each thread computes how many other scores are strictly greater (its rank);
290/// threads whose rank < k are marked surviving (`rank[tid] = rank`); others get `-1`.
291#[must_use]
292pub fn beam_topk_ptx(sm: u32) -> String {
293    let hdr = ptx_header(sm);
294    let body = ".visible .entry beam_topk_kernel(\n\
295        .param .u64 p_scores,\n\
296        .param .u64 p_rank,\n\
297        .param .u32 p_n,\n\
298        .param .u32 p_k\n\
299    )\n\
300    {\n\
301        .reg .u64  %rd<8>;\n\
302        .reg .u32  %r<10>;\n\
303        .reg .s32  %sr<4>;\n\
304        .reg .f32  %f<4>;\n\
305        .reg .pred %p0, %p1;\n\
306    \n\
307        ld.param.u64  %rd0, [p_scores];\n\
308        ld.param.u64  %rd1, [p_rank];\n\
309        ld.param.u32  %r0,  [p_n];\n\
310        ld.param.u32  %r1,  [p_k];\n\
311    \n\
312        mov.u32       %r2, %ntid.x;\n\
313        mov.u32       %r3, %ctaid.x;\n\
314        mov.u32       %r4, %tid.x;\n\
315        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
316        setp.ge.u32   %p0, %r5, %r0;\n\
317        @%p0 bra $BK_DONE;\n\
318    \n\
319        // my score\n\
320        mul.wide.u32  %rd2, %r5, 4;\n\
321        add.u64       %rd3, %rd0, %rd2;\n\
322        ld.global.f32 %f0, [%rd3];\n\
323    \n\
324        mov.u32       %r6, 0;     // rank counter\n\
325        mov.u32       %r7, 0;     // loop index\n\
326    $BK_LOOP:\n\
327        setp.ge.u32   %p0, %r7, %r0;\n\
328        @%p0 bra $BK_WRITE;\n\
329        mul.wide.u32  %rd4, %r7, 4;\n\
330        add.u64       %rd5, %rd0, %rd4;\n\
331        ld.global.f32 %f1, [%rd5];\n\
332        setp.gt.f32   %p1, %f1, %f0;\n\
333        @%p1 add.u32  %r6, %r6, 1;\n\
334        add.u32       %r7, %r7, 1;\n\
335        bra $BK_LOOP;\n\
336    \n\
337    $BK_WRITE:\n\
338        setp.ge.u32   %p0, %r6, %r1;\n\
339        mov.s32       %sr0, -1;\n\
340        @!%p0 cvt.s32.u32 %sr0, %r6;\n\
341        add.u64       %rd6, %rd1, %rd2;\n\
342        st.global.s32 [%rd6], %sr0;\n\
343    \n\
344    $BK_DONE:\n\
345        ret;\n\
346    }\n";
347    hdr + body
348}
349
350/// Edit-distance anti-diagonal cell update kernel.
351///
352/// Signature: `edit_dist_kernel(dp, a_chars, b_chars, n_a, n_b, diag)`
353/// On anti-diagonal `diag`, each thread updates one cell `dp[i*(n_b+1)+j]`
354/// using the standard Levenshtein recurrence.
355#[must_use]
356pub fn edit_dist_ptx(sm: u32) -> String {
357    let hdr = ptx_header(sm);
358    let body = ".visible .entry edit_dist_kernel(\n\
359        .param .u64 p_dp,\n\
360        .param .u64 p_a,\n\
361        .param .u64 p_b,\n\
362        .param .u32 p_n_a,\n\
363        .param .u32 p_n_b,\n\
364        .param .u32 p_diag\n\
365    )\n\
366    {\n\
367        .reg .u64  %rd<10>;\n\
368        .reg .u32  %r<14>;\n\
369        .reg .s32  %sr<6>;\n\
370        .reg .pred %p0, %p1;\n\
371    \n\
372        ld.param.u64  %rd0, [p_dp];\n\
373        ld.param.u64  %rd1, [p_a];\n\
374        ld.param.u64  %rd2, [p_b];\n\
375        ld.param.u32  %r0,  [p_n_a];\n\
376        ld.param.u32  %r1,  [p_n_b];\n\
377        ld.param.u32  %r2,  [p_diag];\n\
378    \n\
379        // i = tid + 1; j = diag - i\n\
380        mov.u32       %r3, %ntid.x;\n\
381        mov.u32       %r4, %ctaid.x;\n\
382        mov.u32       %r5, %tid.x;\n\
383        mad.lo.u32    %r6, %r3, %r4, %r5;\n\
384        add.u32       %r7, %r6, 1;            // i\n\
385        sub.u32       %r8, %r2, %r7;          // j\n\
386        setp.gt.u32   %p0, %r7, %r0;\n\
387        @%p0 bra $ED_DONE;\n\
388        setp.eq.u32   %p0, %r8, 0;\n\
389        @%p0 bra $ED_DONE;\n\
390        setp.gt.u32   %p0, %r8, %r1;\n\
391        @%p0 bra $ED_DONE;\n\
392    \n\
393        // load a[i-1], b[j-1]\n\
394        sub.u32       %r9, %r7, 1;\n\
395        mul.wide.u32  %rd3, %r9, 4;\n\
396        add.u64       %rd4, %rd1, %rd3;\n\
397        ld.global.s32 %sr0, [%rd4];\n\
398        sub.u32       %r10, %r8, 1;\n\
399        mul.wide.u32  %rd5, %r10, 4;\n\
400        add.u64       %rd6, %rd2, %rd5;\n\
401        ld.global.s32 %sr1, [%rd6];\n\
402    \n\
403        // cost: 0 if eq, else 1\n\
404        setp.eq.s32   %p1, %sr0, %sr1;\n\
405        mov.s32       %sr2, 1;\n\
406        @%p1 mov.s32  %sr2, 0;\n\
407    \n\
408        // Read 3 neighbours from dp\n\
409        // dp[(i-1)*(n_b+1) + j]\n\
410        add.u32       %r11, %r1, 1;\n\
411        mul.lo.u32    %r12, %r9, %r11;\n\
412        add.u32       %r12, %r12, %r8;\n\
413        mul.wide.u32  %rd7, %r12, 4;\n\
414        add.u64       %rd8, %rd0, %rd7;\n\
415        ld.global.s32 %sr3, [%rd8];\n\
416        // dp[i*(n_b+1) + (j-1)]\n\
417        mul.lo.u32    %r12, %r7, %r11;\n\
418        add.u32       %r12, %r12, %r10;\n\
419        mul.wide.u32  %rd7, %r12, 4;\n\
420        add.u64       %rd8, %rd0, %rd7;\n\
421        ld.global.s32 %sr4, [%rd8];\n\
422        // dp[(i-1)*(n_b+1) + (j-1)]\n\
423        mul.lo.u32    %r12, %r9, %r11;\n\
424        add.u32       %r12, %r12, %r10;\n\
425        mul.wide.u32  %rd7, %r12, 4;\n\
426        add.u64       %rd8, %rd0, %rd7;\n\
427        ld.global.s32 %sr5, [%rd8];\n\
428    \n\
429        // best = min(sr3+1, sr4+1, sr5+sr2)\n\
430        add.s32       %sr3, %sr3, 1;\n\
431        add.s32       %sr4, %sr4, 1;\n\
432        add.s32       %sr5, %sr5, %sr2;\n\
433        min.s32       %sr3, %sr3, %sr4;\n\
434        min.s32       %sr3, %sr3, %sr5;\n\
435    \n\
436        // write dp[i*(n_b+1)+j]\n\
437        mul.lo.u32    %r12, %r7, %r11;\n\
438        add.u32       %r12, %r12, %r8;\n\
439        mul.wide.u32  %rd7, %r12, 4;\n\
440        add.u64       %rd8, %rd0, %rd7;\n\
441        st.global.s32 [%rd8], %sr3;\n\
442    \n\
443    $ED_DONE:\n\
444        ret;\n\
445    }\n";
446    hdr + body
447}
448
449/// Kalman predict step kernel (matrix-vector + covariance update).
450///
451/// Signature: `kalman_predict_kernel(x, x_pred, A, P, P_pred, Q, n)`
452/// Computes x_pred = A·x and P_pred = A·P·Aᵀ + Q.  Each thread handles one row.
453#[must_use]
454pub fn kalman_predict_ptx(sm: u32) -> String {
455    let hdr = ptx_header(sm);
456    let body = ".visible .entry kalman_predict_kernel(\n\
457        .param .u64 p_x,\n\
458        .param .u64 p_x_pred,\n\
459        .param .u64 p_a,\n\
460        .param .u64 p_p,\n\
461        .param .u64 p_p_pred,\n\
462        .param .u64 p_q,\n\
463        .param .u32 p_n\n\
464    )\n\
465    {\n\
466        .reg .u64  %rd<14>;\n\
467        .reg .u32  %r<14>;\n\
468        .reg .f32  %f<10>;\n\
469        .reg .pred %p0;\n\
470    \n\
471        ld.param.u64  %rd0, [p_x];\n\
472        ld.param.u64  %rd1, [p_x_pred];\n\
473        ld.param.u64  %rd2, [p_a];\n\
474        ld.param.u64  %rd3, [p_p];\n\
475        ld.param.u64  %rd4, [p_p_pred];\n\
476        ld.param.u64  %rd5, [p_q];\n\
477        ld.param.u32  %r0,  [p_n];\n\
478    \n\
479        mov.u32       %r1, %ntid.x;\n\
480        mov.u32       %r2, %ctaid.x;\n\
481        mov.u32       %r3, %tid.x;\n\
482        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
483        setp.ge.u32   %p0, %r4, %r0;\n\
484        @%p0 bra $KP_DONE;\n\
485    \n\
486        // x_pred[i] = sum_k A[i,k] * x[k]\n\
487        mov.f32       %f0, 0f00000000;\n\
488        mov.u32       %r5, 0;\n\
489    $KP_VEC:\n\
490        setp.ge.u32   %p0, %r5, %r0;\n\
491        @%p0 bra $KP_VEC_WR;\n\
492        // A[i*n + k]\n\
493        mul.lo.u32    %r6, %r4, %r0;\n\
494        add.u32       %r6, %r6, %r5;\n\
495        mul.wide.u32  %rd6, %r6, 4;\n\
496        add.u64       %rd7, %rd2, %rd6;\n\
497        ld.global.f32 %f1, [%rd7];\n\
498        // x[k]\n\
499        mul.wide.u32  %rd8, %r5, 4;\n\
500        add.u64       %rd9, %rd0, %rd8;\n\
501        ld.global.f32 %f2, [%rd9];\n\
502        fma.rn.f32    %f0, %f1, %f2, %f0;\n\
503        add.u32       %r5, %r5, 1;\n\
504        bra $KP_VEC;\n\
505    \n\
506    $KP_VEC_WR:\n\
507        mul.wide.u32  %rd6, %r4, 4;\n\
508        add.u64       %rd7, %rd1, %rd6;\n\
509        st.global.f32 [%rd7], %f0;\n\
510    \n\
511        // P_pred[i,j] = sum_{k,l} A[i,k] P[k,l] A[j,l] + Q[i,j]\n\
512        // One thread handles row i, all j.\n\
513        mov.u32       %r7, 0;\n\
514    $KP_J:\n\
515        setp.ge.u32   %p0, %r7, %r0;\n\
516        @%p0 bra $KP_DONE;\n\
517        mov.f32       %f3, 0f00000000;\n\
518        mov.u32       %r8, 0;\n\
519    $KP_K:\n\
520        setp.ge.u32   %p0, %r8, %r0;\n\
521        @%p0 bra $KP_K_DONE;\n\
522        mov.f32       %f4, 0f00000000;\n\
523        mov.u32       %r9, 0;\n\
524    $KP_L:\n\
525        setp.ge.u32   %p0, %r9, %r0;\n\
526        @%p0 bra $KP_L_DONE;\n\
527        // P[k*n + l]\n\
528        mul.lo.u32    %r10, %r8, %r0;\n\
529        add.u32       %r10, %r10, %r9;\n\
530        mul.wide.u32  %rd6, %r10, 4;\n\
531        add.u64       %rd7, %rd3, %rd6;\n\
532        ld.global.f32 %f5, [%rd7];\n\
533        // A[j*n + l]\n\
534        mul.lo.u32    %r11, %r7, %r0;\n\
535        add.u32       %r11, %r11, %r9;\n\
536        mul.wide.u32  %rd8, %r11, 4;\n\
537        add.u64       %rd9, %rd2, %rd8;\n\
538        ld.global.f32 %f6, [%rd9];\n\
539        fma.rn.f32    %f4, %f5, %f6, %f4;\n\
540        add.u32       %r9, %r9, 1;\n\
541        bra $KP_L;\n\
542    \n\
543    $KP_L_DONE:\n\
544        // A[i*n + k]\n\
545        mul.lo.u32    %r10, %r4, %r0;\n\
546        add.u32       %r10, %r10, %r8;\n\
547        mul.wide.u32  %rd6, %r10, 4;\n\
548        add.u64       %rd7, %rd2, %rd6;\n\
549        ld.global.f32 %f7, [%rd7];\n\
550        fma.rn.f32    %f3, %f7, %f4, %f3;\n\
551        add.u32       %r8, %r8, 1;\n\
552        bra $KP_K;\n\
553    \n\
554    $KP_K_DONE:\n\
555        // P_pred[i*n+j] = f3 + Q[i*n+j]\n\
556        mul.lo.u32    %r10, %r4, %r0;\n\
557        add.u32       %r10, %r10, %r7;\n\
558        mul.wide.u32  %rd6, %r10, 4;\n\
559        add.u64       %rd7, %rd5, %rd6;\n\
560        ld.global.f32 %f8, [%rd7];\n\
561        add.f32       %f3, %f3, %f8;\n\
562        add.u64       %rd9, %rd4, %rd6;\n\
563        st.global.f32 [%rd9], %f3;\n\
564        add.u32       %r7, %r7, 1;\n\
565        bra $KP_J;\n\
566    \n\
567    $KP_DONE:\n\
568        ret;\n\
569    }\n";
570    hdr + body
571}
572
573/// Gibbs-sampling Ising MRF kernel: per-site conditional resample given neighbours.
574///
575/// Signature: `mrf_gibbs_kernel(spins, h, j, n_rows, n_cols, seed)`
576/// Each thread samples its spin given a uniform draw from the inline LCG seeded
577/// by `seed ^ tid`.
578#[must_use]
579pub fn mrf_gibbs_ptx(sm: u32) -> String {
580    let hdr = ptx_header(sm);
581    let body = ".visible .entry mrf_gibbs_kernel(\n\
582        .param .u64 p_spins,\n\
583        .param .f32 p_h,\n\
584        .param .f32 p_j,\n\
585        .param .u32 p_n_rows,\n\
586        .param .u32 p_n_cols,\n\
587        .param .u64 p_seed\n\
588    )\n\
589    {\n\
590        .reg .u64  %rd<10>;\n\
591        .reg .u32  %r<16>;\n\
592        .reg .s32  %sr<8>;\n\
593        .reg .f32  %f<10>;\n\
594        .reg .pred %p0, %p1;\n\
595    \n\
596        ld.param.u64  %rd0, [p_spins];\n\
597        ld.param.f32  %f0,  [p_h];\n\
598        ld.param.f32  %f1,  [p_j];\n\
599        ld.param.u32  %r0,  [p_n_rows];\n\
600        ld.param.u32  %r1,  [p_n_cols];\n\
601        ld.param.u64  %rd1, [p_seed];\n\
602    \n\
603        // i = blockIdx.y * blockDim.y + threadIdx.y\n\
604        mov.u32       %r2, %ntid.y;\n\
605        mov.u32       %r3, %ctaid.y;\n\
606        mov.u32       %r4, %tid.y;\n\
607        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
608        // j = blockIdx.x * blockDim.x + threadIdx.x\n\
609        mov.u32       %r6, %ntid.x;\n\
610        mov.u32       %r7, %ctaid.x;\n\
611        mov.u32       %r8, %tid.x;\n\
612        mad.lo.u32    %r9, %r6, %r7, %r8;\n\
613    \n\
614        setp.ge.u32   %p0, %r5, %r0;\n\
615        @%p0 bra $MG_DONE;\n\
616        setp.ge.u32   %p0, %r9, %r1;\n\
617        @%p0 bra $MG_DONE;\n\
618    \n\
619        // Sum of neighbour spins (with bounds checks)\n\
620        mov.s32       %sr0, 0;\n\
621        // up\n\
622        setp.eq.u32   %p1, %r5, 0;\n\
623        @%p1 bra $MG_NB1;\n\
624        sub.u32       %r10, %r5, 1;\n\
625        mul.lo.u32    %r11, %r10, %r1;\n\
626        add.u32       %r11, %r11, %r9;\n\
627        mul.wide.u32  %rd2, %r11, 4;\n\
628        add.u64       %rd3, %rd0, %rd2;\n\
629        ld.global.s32 %sr1, [%rd3];\n\
630        add.s32       %sr0, %sr0, %sr1;\n\
631    $MG_NB1:\n\
632        // down\n\
633        add.u32       %r10, %r5, 1;\n\
634        setp.ge.u32   %p1, %r10, %r0;\n\
635        @%p1 bra $MG_NB2;\n\
636        mul.lo.u32    %r11, %r10, %r1;\n\
637        add.u32       %r11, %r11, %r9;\n\
638        mul.wide.u32  %rd2, %r11, 4;\n\
639        add.u64       %rd3, %rd0, %rd2;\n\
640        ld.global.s32 %sr1, [%rd3];\n\
641        add.s32       %sr0, %sr0, %sr1;\n\
642    $MG_NB2:\n\
643        // left\n\
644        setp.eq.u32   %p1, %r9, 0;\n\
645        @%p1 bra $MG_NB3;\n\
646        sub.u32       %r10, %r9, 1;\n\
647        mul.lo.u32    %r11, %r5, %r1;\n\
648        add.u32       %r11, %r11, %r10;\n\
649        mul.wide.u32  %rd2, %r11, 4;\n\
650        add.u64       %rd3, %rd0, %rd2;\n\
651        ld.global.s32 %sr1, [%rd3];\n\
652        add.s32       %sr0, %sr0, %sr1;\n\
653    $MG_NB3:\n\
654        // right\n\
655        add.u32       %r10, %r9, 1;\n\
656        setp.ge.u32   %p1, %r10, %r1;\n\
657        @%p1 bra $MG_FIELD;\n\
658        mul.lo.u32    %r11, %r5, %r1;\n\
659        add.u32       %r11, %r11, %r10;\n\
660        mul.wide.u32  %rd2, %r11, 4;\n\
661        add.u64       %rd3, %rd0, %rd2;\n\
662        ld.global.s32 %sr1, [%rd3];\n\
663        add.s32       %sr0, %sr0, %sr1;\n\
664    \n\
665    $MG_FIELD:\n\
666        // field = j * sum + h\n\
667        cvt.rn.f32.s32 %f2, %sr0;\n\
668        mul.f32       %f3, %f1, %f2;\n\
669        add.f32       %f3, %f3, %f0;\n\
670    \n\
671        // p_up = 1 / (1 + exp(-2*field))\n\
672        mov.f32       %f4, 0fC0000000;       // -2\n\
673        mul.f32       %f5, %f4, %f3;\n\
674        mul.f32       %f5, %f5, 0f3FB8AA3B;  // * log2(e): exp(x)=ex2(x*log2e)\n\
675        ex2.approx.f32 %f5, %f5;             // exp(-2*field)\n\
676        mov.f32       %f6, 0f3F800000;       // 1.0\n\
677        add.f32       %f5, %f5, %f6;\n\
678        div.rn.f32    %f7, %f6, %f5;\n\
679    \n\
680        // Inline LCG: seed ^ (row * n_cols + col)\n\
681        mul.lo.u32    %r12, %r5, %r1;\n\
682        add.u32       %r12, %r12, %r9;\n\
683        cvt.u64.u32   %rd4, %r12;\n\
684        xor.b64       %rd5, %rd4, %rd1;\n\
685        mov.u64       %rd6, 6364136223846793005;\n\
686        mul.lo.u64    %rd5, %rd5, %rd6;\n\
687        mov.u64       %rd6, 1442695040888963407;\n\
688        add.u64       %rd5, %rd5, %rd6;\n\
689        shr.u64       %rd7, %rd5, 32;\n\
690        cvt.u32.u64   %r13, %rd7;\n\
691        // u = (high32 >> 8) / 2^24\n\
692        shr.u32       %r14, %r13, 8;\n\
693        cvt.rn.f32.u32 %f8, %r14;\n\
694        mov.f32       %f9, 0f33800000;       // 1 / 2^24\n\
695        mul.f32       %f8, %f8, %f9;\n\
696    \n\
697        // s = (u < p_up) ? +1 : -1\n\
698        setp.lt.f32   %p1, %f8, %f7;\n\
699        mov.s32       %sr2, -1;\n\
700        @%p1 mov.s32  %sr2, 1;\n\
701    \n\
702        // store\n\
703        mul.lo.u32    %r15, %r5, %r1;\n\
704        add.u32       %r15, %r15, %r9;\n\
705        mul.wide.u32  %rd8, %r15, 4;\n\
706        add.u64       %rd9, %rd0, %rd8;\n\
707        st.global.s32 [%rd9], %sr2;\n\
708    \n\
709    $MG_DONE:\n\
710        ret;\n\
711    }\n";
712    hdr + body
713}
714
#[cfg(test)]
mod tests {
    use super::*;

    /// Each probed SM value must yield the `.version` directive of its tier.
    #[test]
    fn ptx_header_versions() {
        let expectations = [
            (75u32, ".version 7.5"),
            (80, ".version 8.0"),
            (89, ".version 8.0"),
            (90, ".version 8.4"),
            (100, ".version 8.7"),
        ];
        for (sm, directive) in expectations {
            assert!(ptx_header(sm).contains(directive));
        }
    }

    /// Every generator must emit a non-empty module that declares a visible
    /// entry point, for each SM version in the probe list.
    #[test]
    fn all_kernels_non_empty() {
        type KernelFn = fn(u32) -> String;
        let generators: &[(&str, KernelFn)] = &[
            ("forward_pass", forward_pass_ptx),
            ("viterbi_step", viterbi_step_ptx),
            ("crf_features", crf_features_ptx),
            ("beam_topk", beam_topk_ptx),
            ("edit_dist", edit_dist_ptx),
            ("kalman_predict", kalman_predict_ptx),
            ("mrf_gibbs", mrf_gibbs_ptx),
        ];
        for sm in [75u32, 80, 86, 89, 90, 100] {
            for (name, generate) in generators.iter().copied() {
                let module = generate(sm);
                assert!(!module.is_empty(), "{name} sm{sm} empty");
                assert!(
                    module.contains(".visible .entry"),
                    "{name} sm{sm} missing .visible .entry"
                );
            }
        }
    }
}