1fn ptx_header(sm: u32) -> String {
15 let (ptx_ver, target) = match sm {
16 v if v >= 100 => ("8.7", format!("sm_{v}")),
17 v if v >= 90 => ("8.4", format!("sm_{v}")),
18 v if v >= 80 => ("8.0", format!("sm_{v}")),
19 v => ("7.5", format!("sm_{v}")),
20 };
21 format!(".version {ptx_ver}\n.target {target}\n.address_size 64\n\n")
22}
23
24#[must_use]
30pub fn forward_pass_ptx(sm: u32) -> String {
31 let hdr = ptx_header(sm);
32 let body = ".visible .entry forward_pass_kernel(\n\
33 .param .u64 p_alpha_prev,\n\
34 .param .u64 p_alpha_next,\n\
35 .param .u64 p_log_a,\n\
36 .param .u64 p_log_b_o,\n\
37 .param .u32 p_n_states\n\
38 )\n\
39 {\n\
40 .reg .u64 %rd<10>;\n\
41 .reg .u32 %r<10>;\n\
42 .reg .f32 %f<10>;\n\
43 .reg .pred %p0;\n\
44 \n\
45 ld.param.u64 %rd0, [p_alpha_prev];\n\
46 ld.param.u64 %rd1, [p_alpha_next];\n\
47 ld.param.u64 %rd2, [p_log_a];\n\
48 ld.param.u64 %rd3, [p_log_b_o];\n\
49 ld.param.u32 %r0, [p_n_states];\n\
50 \n\
51 // j = global thread id\n\
52 mov.u32 %r1, %ntid.x;\n\
53 mov.u32 %r2, %ctaid.x;\n\
54 mov.u32 %r3, %tid.x;\n\
55 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
56 setp.ge.u32 %p0, %r4, %r0;\n\
57 @%p0 bra $FP_DONE;\n\
58 \n\
59 // First pass: find max of (alpha_prev[i] + log_a[i*S + j])\n\
60 mov.f32 %f0, 0fFF800000; // -inf\n\
61 mov.u32 %r5, 0;\n\
62 $FP_MAX:\n\
63 setp.ge.u32 %p0, %r5, %r0;\n\
64 @%p0 bra $FP_SUM_INIT;\n\
65 // alpha_prev[i]\n\
66 mul.wide.u32 %rd4, %r5, 4;\n\
67 add.u64 %rd5, %rd0, %rd4;\n\
68 ld.global.f32 %f1, [%rd5];\n\
69 // log_a[i*S + j]\n\
70 mul.lo.u32 %r6, %r5, %r0;\n\
71 add.u32 %r6, %r6, %r4;\n\
72 mul.wide.u32 %rd6, %r6, 4;\n\
73 add.u64 %rd7, %rd2, %rd6;\n\
74 ld.global.f32 %f2, [%rd7];\n\
75 add.f32 %f3, %f1, %f2;\n\
76 max.f32 %f0, %f0, %f3;\n\
77 add.u32 %r5, %r5, 1;\n\
78 bra $FP_MAX;\n\
79 \n\
80 $FP_SUM_INIT:\n\
81 // Second pass: accumulate exp((alpha_prev[i] + log_a[i*S+j]) - max)\n\
82 mov.f32 %f4, 0f00000000;\n\
83 mov.u32 %r5, 0;\n\
84 $FP_SUM:\n\
85 setp.ge.u32 %p0, %r5, %r0;\n\
86 @%p0 bra $FP_WRITE;\n\
87 mul.wide.u32 %rd4, %r5, 4;\n\
88 add.u64 %rd5, %rd0, %rd4;\n\
89 ld.global.f32 %f1, [%rd5];\n\
90 mul.lo.u32 %r6, %r5, %r0;\n\
91 add.u32 %r6, %r6, %r4;\n\
92 mul.wide.u32 %rd6, %r6, 4;\n\
93 add.u64 %rd7, %rd2, %rd6;\n\
94 ld.global.f32 %f2, [%rd7];\n\
95 add.f32 %f3, %f1, %f2;\n\
96 sub.f32 %f3, %f3, %f0;\n\
97 mul.f32 %f3, %f3, 0f3FB8AA3B; // * log2(e): exp(x)=ex2(x*log2e)\n\
98 ex2.approx.f32 %f3, %f3;\n\
99 add.f32 %f4, %f4, %f3;\n\
100 add.u32 %r5, %r5, 1;\n\
101 bra $FP_SUM;\n\
102 \n\
103 $FP_WRITE:\n\
104 // result = max + log(sum) + log_b_o[j]\n\
105 lg2.approx.f32 %f4, %f4;\n\
106 mul.f32 %f4, %f4, 0f3F317218; // * ln(2): ln(x)=lg2(x)*ln2\n\
107 add.f32 %f4, %f4, %f0;\n\
108 mul.wide.u32 %rd4, %r4, 4;\n\
109 add.u64 %rd5, %rd3, %rd4;\n\
110 ld.global.f32 %f5, [%rd5];\n\
111 add.f32 %f4, %f4, %f5;\n\
112 add.u64 %rd6, %rd1, %rd4;\n\
113 st.global.f32 [%rd6], %f4;\n\
114 \n\
115 $FP_DONE:\n\
116 ret;\n\
117 }\n";
118 hdr + body
119}
120
121#[must_use]
127pub fn viterbi_step_ptx(sm: u32) -> String {
128 let hdr = ptx_header(sm);
129 let body = ".visible .entry viterbi_step_kernel(\n\
130 .param .u64 p_delta_prev,\n\
131 .param .u64 p_delta_next,\n\
132 .param .u64 p_log_a,\n\
133 .param .u64 p_log_b_o,\n\
134 .param .u64 p_psi,\n\
135 .param .u32 p_n_states\n\
136 )\n\
137 {\n\
138 .reg .u64 %rd<10>;\n\
139 .reg .u32 %r<10>;\n\
140 .reg .s32 %sr<4>;\n\
141 .reg .f32 %f<8>;\n\
142 .reg .pred %p0, %p1;\n\
143 \n\
144 ld.param.u64 %rd0, [p_delta_prev];\n\
145 ld.param.u64 %rd1, [p_delta_next];\n\
146 ld.param.u64 %rd2, [p_log_a];\n\
147 ld.param.u64 %rd3, [p_log_b_o];\n\
148 ld.param.u64 %rd4, [p_psi];\n\
149 ld.param.u32 %r0, [p_n_states];\n\
150 \n\
151 mov.u32 %r1, %ntid.x;\n\
152 mov.u32 %r2, %ctaid.x;\n\
153 mov.u32 %r3, %tid.x;\n\
154 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
155 setp.ge.u32 %p0, %r4, %r0;\n\
156 @%p0 bra $VS_DONE;\n\
157 \n\
158 mov.f32 %f0, 0fFF800000; // best = -inf\n\
159 mov.s32 %sr0, -1; // argmax\n\
160 mov.u32 %r5, 0;\n\
161 $VS_LOOP:\n\
162 setp.ge.u32 %p0, %r5, %r0;\n\
163 @%p0 bra $VS_WRITE;\n\
164 // delta_prev[i]\n\
165 mul.wide.u32 %rd5, %r5, 4;\n\
166 add.u64 %rd6, %rd0, %rd5;\n\
167 ld.global.f32 %f1, [%rd6];\n\
168 // log_a[i*S + j]\n\
169 mul.lo.u32 %r6, %r5, %r0;\n\
170 add.u32 %r6, %r6, %r4;\n\
171 mul.wide.u32 %rd7, %r6, 4;\n\
172 add.u64 %rd8, %rd2, %rd7;\n\
173 ld.global.f32 %f2, [%rd8];\n\
174 add.f32 %f3, %f1, %f2;\n\
175 setp.gt.f32 %p1, %f3, %f0;\n\
176 @%p1 mov.f32 %f0, %f3;\n\
177 @%p1 cvt.s32.u32 %sr0, %r5;\n\
178 add.u32 %r5, %r5, 1;\n\
179 bra $VS_LOOP;\n\
180 \n\
181 $VS_WRITE:\n\
182 // delta_next[j] = best + log_b_o[j]\n\
183 mul.wide.u32 %rd5, %r4, 4;\n\
184 add.u64 %rd6, %rd3, %rd5;\n\
185 ld.global.f32 %f4, [%rd6];\n\
186 add.f32 %f0, %f0, %f4;\n\
187 add.u64 %rd7, %rd1, %rd5;\n\
188 st.global.f32 [%rd7], %f0;\n\
189 // psi[j] = argmax\n\
190 add.u64 %rd8, %rd4, %rd5;\n\
191 st.global.s32 [%rd8], %sr0;\n\
192 \n\
193 $VS_DONE:\n\
194 ret;\n\
195 }\n";
196 hdr + body
197}
198
199#[must_use]
204pub fn crf_features_ptx(sm: u32) -> String {
205 let hdr = ptx_header(sm);
206 let body = ".visible .entry crf_features_kernel(\n\
207 .param .u64 p_emit,\n\
208 .param .u64 p_trans,\n\
209 .param .u64 p_x_feat,\n\
210 .param .u64 p_score,\n\
211 .param .u32 p_t,\n\
212 .param .u32 p_n_labels,\n\
213 .param .u32 p_n_features\n\
214 )\n\
215 {\n\
216 .reg .u64 %rd<10>;\n\
217 .reg .u32 %r<14>;\n\
218 .reg .f32 %f<8>;\n\
219 .reg .pred %p0;\n\
220 \n\
221 ld.param.u64 %rd0, [p_emit];\n\
222 ld.param.u64 %rd1, [p_trans];\n\
223 ld.param.u64 %rd2, [p_x_feat];\n\
224 ld.param.u64 %rd3, [p_score];\n\
225 ld.param.u32 %r0, [p_t];\n\
226 ld.param.u32 %r1, [p_n_labels];\n\
227 ld.param.u32 %r2, [p_n_features];\n\
228 \n\
229 // y_prev = blockIdx.y * blockDim.y + threadIdx.y\n\
230 mov.u32 %r3, %ntid.y;\n\
231 mov.u32 %r4, %ctaid.y;\n\
232 mov.u32 %r5, %tid.y;\n\
233 mad.lo.u32 %r6, %r3, %r4, %r5;\n\
234 // y_cur = blockIdx.x * blockDim.x + threadIdx.x\n\
235 mov.u32 %r7, %ntid.x;\n\
236 mov.u32 %r8, %ctaid.x;\n\
237 mov.u32 %r9, %tid.x;\n\
238 mad.lo.u32 %r10, %r7, %r8, %r9;\n\
239 \n\
240 setp.ge.u32 %p0, %r6, %r1;\n\
241 @%p0 bra $CF_DONE;\n\
242 setp.ge.u32 %p0, %r10, %r1;\n\
243 @%p0 bra $CF_DONE;\n\
244 \n\
245 // Emission score: dot(emit[y_cur,:], x_feat[t,:])\n\
246 mov.f32 %f0, 0f00000000;\n\
247 mov.u32 %r11, 0;\n\
248 $CF_EMIT:\n\
249 setp.ge.u32 %p0, %r11, %r2;\n\
250 @%p0 bra $CF_TRANS;\n\
251 // emit[y_cur * n_features + k]\n\
252 mul.lo.u32 %r12, %r10, %r2;\n\
253 add.u32 %r12, %r12, %r11;\n\
254 mul.wide.u32 %rd4, %r12, 4;\n\
255 add.u64 %rd5, %rd0, %rd4;\n\
256 ld.global.f32 %f1, [%rd5];\n\
257 // x_feat[t * n_features + k]\n\
258 mul.lo.u32 %r13, %r0, %r2;\n\
259 add.u32 %r13, %r13, %r11;\n\
260 mul.wide.u32 %rd6, %r13, 4;\n\
261 add.u64 %rd7, %rd2, %rd6;\n\
262 ld.global.f32 %f2, [%rd7];\n\
263 fma.rn.f32 %f0, %f1, %f2, %f0;\n\
264 add.u32 %r11, %r11, 1;\n\
265 bra $CF_EMIT;\n\
266 \n\
267 $CF_TRANS:\n\
268 // Transition score: trans[y_prev * n_labels + y_cur]\n\
269 mul.lo.u32 %r12, %r6, %r1;\n\
270 add.u32 %r12, %r12, %r10;\n\
271 mul.wide.u32 %rd4, %r12, 4;\n\
272 add.u64 %rd5, %rd1, %rd4;\n\
273 ld.global.f32 %f3, [%rd5];\n\
274 add.f32 %f0, %f0, %f3;\n\
275 \n\
276 // score[y_prev * n_labels + y_cur] = f0\n\
277 add.u64 %rd6, %rd3, %rd4;\n\
278 st.global.f32 [%rd6], %f0;\n\
279 \n\
280 $CF_DONE:\n\
281 ret;\n\
282 }\n";
283 hdr + body
284}
285
286#[must_use]
292pub fn beam_topk_ptx(sm: u32) -> String {
293 let hdr = ptx_header(sm);
294 let body = ".visible .entry beam_topk_kernel(\n\
295 .param .u64 p_scores,\n\
296 .param .u64 p_rank,\n\
297 .param .u32 p_n,\n\
298 .param .u32 p_k\n\
299 )\n\
300 {\n\
301 .reg .u64 %rd<8>;\n\
302 .reg .u32 %r<10>;\n\
303 .reg .s32 %sr<4>;\n\
304 .reg .f32 %f<4>;\n\
305 .reg .pred %p0, %p1;\n\
306 \n\
307 ld.param.u64 %rd0, [p_scores];\n\
308 ld.param.u64 %rd1, [p_rank];\n\
309 ld.param.u32 %r0, [p_n];\n\
310 ld.param.u32 %r1, [p_k];\n\
311 \n\
312 mov.u32 %r2, %ntid.x;\n\
313 mov.u32 %r3, %ctaid.x;\n\
314 mov.u32 %r4, %tid.x;\n\
315 mad.lo.u32 %r5, %r2, %r3, %r4;\n\
316 setp.ge.u32 %p0, %r5, %r0;\n\
317 @%p0 bra $BK_DONE;\n\
318 \n\
319 // my score\n\
320 mul.wide.u32 %rd2, %r5, 4;\n\
321 add.u64 %rd3, %rd0, %rd2;\n\
322 ld.global.f32 %f0, [%rd3];\n\
323 \n\
324 mov.u32 %r6, 0; // rank counter\n\
325 mov.u32 %r7, 0; // loop index\n\
326 $BK_LOOP:\n\
327 setp.ge.u32 %p0, %r7, %r0;\n\
328 @%p0 bra $BK_WRITE;\n\
329 mul.wide.u32 %rd4, %r7, 4;\n\
330 add.u64 %rd5, %rd0, %rd4;\n\
331 ld.global.f32 %f1, [%rd5];\n\
332 setp.gt.f32 %p1, %f1, %f0;\n\
333 @%p1 add.u32 %r6, %r6, 1;\n\
334 add.u32 %r7, %r7, 1;\n\
335 bra $BK_LOOP;\n\
336 \n\
337 $BK_WRITE:\n\
338 setp.ge.u32 %p0, %r6, %r1;\n\
339 mov.s32 %sr0, -1;\n\
340 @!%p0 cvt.s32.u32 %sr0, %r6;\n\
341 add.u64 %rd6, %rd1, %rd2;\n\
342 st.global.s32 [%rd6], %sr0;\n\
343 \n\
344 $BK_DONE:\n\
345 ret;\n\
346 }\n";
347 hdr + body
348}
349
350#[must_use]
356pub fn edit_dist_ptx(sm: u32) -> String {
357 let hdr = ptx_header(sm);
358 let body = ".visible .entry edit_dist_kernel(\n\
359 .param .u64 p_dp,\n\
360 .param .u64 p_a,\n\
361 .param .u64 p_b,\n\
362 .param .u32 p_n_a,\n\
363 .param .u32 p_n_b,\n\
364 .param .u32 p_diag\n\
365 )\n\
366 {\n\
367 .reg .u64 %rd<10>;\n\
368 .reg .u32 %r<14>;\n\
369 .reg .s32 %sr<6>;\n\
370 .reg .pred %p0, %p1;\n\
371 \n\
372 ld.param.u64 %rd0, [p_dp];\n\
373 ld.param.u64 %rd1, [p_a];\n\
374 ld.param.u64 %rd2, [p_b];\n\
375 ld.param.u32 %r0, [p_n_a];\n\
376 ld.param.u32 %r1, [p_n_b];\n\
377 ld.param.u32 %r2, [p_diag];\n\
378 \n\
379 // i = tid + 1; j = diag - i\n\
380 mov.u32 %r3, %ntid.x;\n\
381 mov.u32 %r4, %ctaid.x;\n\
382 mov.u32 %r5, %tid.x;\n\
383 mad.lo.u32 %r6, %r3, %r4, %r5;\n\
384 add.u32 %r7, %r6, 1; // i\n\
385 sub.u32 %r8, %r2, %r7; // j\n\
386 setp.gt.u32 %p0, %r7, %r0;\n\
387 @%p0 bra $ED_DONE;\n\
388 setp.eq.u32 %p0, %r8, 0;\n\
389 @%p0 bra $ED_DONE;\n\
390 setp.gt.u32 %p0, %r8, %r1;\n\
391 @%p0 bra $ED_DONE;\n\
392 \n\
393 // load a[i-1], b[j-1]\n\
394 sub.u32 %r9, %r7, 1;\n\
395 mul.wide.u32 %rd3, %r9, 4;\n\
396 add.u64 %rd4, %rd1, %rd3;\n\
397 ld.global.s32 %sr0, [%rd4];\n\
398 sub.u32 %r10, %r8, 1;\n\
399 mul.wide.u32 %rd5, %r10, 4;\n\
400 add.u64 %rd6, %rd2, %rd5;\n\
401 ld.global.s32 %sr1, [%rd6];\n\
402 \n\
403 // cost: 0 if eq, else 1\n\
404 setp.eq.s32 %p1, %sr0, %sr1;\n\
405 mov.s32 %sr2, 1;\n\
406 @%p1 mov.s32 %sr2, 0;\n\
407 \n\
408 // Read 3 neighbours from dp\n\
409 // dp[(i-1)*(n_b+1) + j]\n\
410 add.u32 %r11, %r1, 1;\n\
411 mul.lo.u32 %r12, %r9, %r11;\n\
412 add.u32 %r12, %r12, %r8;\n\
413 mul.wide.u32 %rd7, %r12, 4;\n\
414 add.u64 %rd8, %rd0, %rd7;\n\
415 ld.global.s32 %sr3, [%rd8];\n\
416 // dp[i*(n_b+1) + (j-1)]\n\
417 mul.lo.u32 %r12, %r7, %r11;\n\
418 add.u32 %r12, %r12, %r10;\n\
419 mul.wide.u32 %rd7, %r12, 4;\n\
420 add.u64 %rd8, %rd0, %rd7;\n\
421 ld.global.s32 %sr4, [%rd8];\n\
422 // dp[(i-1)*(n_b+1) + (j-1)]\n\
423 mul.lo.u32 %r12, %r9, %r11;\n\
424 add.u32 %r12, %r12, %r10;\n\
425 mul.wide.u32 %rd7, %r12, 4;\n\
426 add.u64 %rd8, %rd0, %rd7;\n\
427 ld.global.s32 %sr5, [%rd8];\n\
428 \n\
429 // best = min(sr3+1, sr4+1, sr5+sr2)\n\
430 add.s32 %sr3, %sr3, 1;\n\
431 add.s32 %sr4, %sr4, 1;\n\
432 add.s32 %sr5, %sr5, %sr2;\n\
433 min.s32 %sr3, %sr3, %sr4;\n\
434 min.s32 %sr3, %sr3, %sr5;\n\
435 \n\
436 // write dp[i*(n_b+1)+j]\n\
437 mul.lo.u32 %r12, %r7, %r11;\n\
438 add.u32 %r12, %r12, %r8;\n\
439 mul.wide.u32 %rd7, %r12, 4;\n\
440 add.u64 %rd8, %rd0, %rd7;\n\
441 st.global.s32 [%rd8], %sr3;\n\
442 \n\
443 $ED_DONE:\n\
444 ret;\n\
445 }\n";
446 hdr + body
447}
448
449#[must_use]
454pub fn kalman_predict_ptx(sm: u32) -> String {
455 let hdr = ptx_header(sm);
456 let body = ".visible .entry kalman_predict_kernel(\n\
457 .param .u64 p_x,\n\
458 .param .u64 p_x_pred,\n\
459 .param .u64 p_a,\n\
460 .param .u64 p_p,\n\
461 .param .u64 p_p_pred,\n\
462 .param .u64 p_q,\n\
463 .param .u32 p_n\n\
464 )\n\
465 {\n\
466 .reg .u64 %rd<14>;\n\
467 .reg .u32 %r<14>;\n\
468 .reg .f32 %f<10>;\n\
469 .reg .pred %p0;\n\
470 \n\
471 ld.param.u64 %rd0, [p_x];\n\
472 ld.param.u64 %rd1, [p_x_pred];\n\
473 ld.param.u64 %rd2, [p_a];\n\
474 ld.param.u64 %rd3, [p_p];\n\
475 ld.param.u64 %rd4, [p_p_pred];\n\
476 ld.param.u64 %rd5, [p_q];\n\
477 ld.param.u32 %r0, [p_n];\n\
478 \n\
479 mov.u32 %r1, %ntid.x;\n\
480 mov.u32 %r2, %ctaid.x;\n\
481 mov.u32 %r3, %tid.x;\n\
482 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
483 setp.ge.u32 %p0, %r4, %r0;\n\
484 @%p0 bra $KP_DONE;\n\
485 \n\
486 // x_pred[i] = sum_k A[i,k] * x[k]\n\
487 mov.f32 %f0, 0f00000000;\n\
488 mov.u32 %r5, 0;\n\
489 $KP_VEC:\n\
490 setp.ge.u32 %p0, %r5, %r0;\n\
491 @%p0 bra $KP_VEC_WR;\n\
492 // A[i*n + k]\n\
493 mul.lo.u32 %r6, %r4, %r0;\n\
494 add.u32 %r6, %r6, %r5;\n\
495 mul.wide.u32 %rd6, %r6, 4;\n\
496 add.u64 %rd7, %rd2, %rd6;\n\
497 ld.global.f32 %f1, [%rd7];\n\
498 // x[k]\n\
499 mul.wide.u32 %rd8, %r5, 4;\n\
500 add.u64 %rd9, %rd0, %rd8;\n\
501 ld.global.f32 %f2, [%rd9];\n\
502 fma.rn.f32 %f0, %f1, %f2, %f0;\n\
503 add.u32 %r5, %r5, 1;\n\
504 bra $KP_VEC;\n\
505 \n\
506 $KP_VEC_WR:\n\
507 mul.wide.u32 %rd6, %r4, 4;\n\
508 add.u64 %rd7, %rd1, %rd6;\n\
509 st.global.f32 [%rd7], %f0;\n\
510 \n\
511 // P_pred[i,j] = sum_{k,l} A[i,k] P[k,l] A[j,l] + Q[i,j]\n\
512 // One thread handles row i, all j.\n\
513 mov.u32 %r7, 0;\n\
514 $KP_J:\n\
515 setp.ge.u32 %p0, %r7, %r0;\n\
516 @%p0 bra $KP_DONE;\n\
517 mov.f32 %f3, 0f00000000;\n\
518 mov.u32 %r8, 0;\n\
519 $KP_K:\n\
520 setp.ge.u32 %p0, %r8, %r0;\n\
521 @%p0 bra $KP_K_DONE;\n\
522 mov.f32 %f4, 0f00000000;\n\
523 mov.u32 %r9, 0;\n\
524 $KP_L:\n\
525 setp.ge.u32 %p0, %r9, %r0;\n\
526 @%p0 bra $KP_L_DONE;\n\
527 // P[k*n + l]\n\
528 mul.lo.u32 %r10, %r8, %r0;\n\
529 add.u32 %r10, %r10, %r9;\n\
530 mul.wide.u32 %rd6, %r10, 4;\n\
531 add.u64 %rd7, %rd3, %rd6;\n\
532 ld.global.f32 %f5, [%rd7];\n\
533 // A[j*n + l]\n\
534 mul.lo.u32 %r11, %r7, %r0;\n\
535 add.u32 %r11, %r11, %r9;\n\
536 mul.wide.u32 %rd8, %r11, 4;\n\
537 add.u64 %rd9, %rd2, %rd8;\n\
538 ld.global.f32 %f6, [%rd9];\n\
539 fma.rn.f32 %f4, %f5, %f6, %f4;\n\
540 add.u32 %r9, %r9, 1;\n\
541 bra $KP_L;\n\
542 \n\
543 $KP_L_DONE:\n\
544 // A[i*n + k]\n\
545 mul.lo.u32 %r10, %r4, %r0;\n\
546 add.u32 %r10, %r10, %r8;\n\
547 mul.wide.u32 %rd6, %r10, 4;\n\
548 add.u64 %rd7, %rd2, %rd6;\n\
549 ld.global.f32 %f7, [%rd7];\n\
550 fma.rn.f32 %f3, %f7, %f4, %f3;\n\
551 add.u32 %r8, %r8, 1;\n\
552 bra $KP_K;\n\
553 \n\
554 $KP_K_DONE:\n\
555 // P_pred[i*n+j] = f3 + Q[i*n+j]\n\
556 mul.lo.u32 %r10, %r4, %r0;\n\
557 add.u32 %r10, %r10, %r7;\n\
558 mul.wide.u32 %rd6, %r10, 4;\n\
559 add.u64 %rd7, %rd5, %rd6;\n\
560 ld.global.f32 %f8, [%rd7];\n\
561 add.f32 %f3, %f3, %f8;\n\
562 add.u64 %rd9, %rd4, %rd6;\n\
563 st.global.f32 [%rd9], %f3;\n\
564 add.u32 %r7, %r7, 1;\n\
565 bra $KP_J;\n\
566 \n\
567 $KP_DONE:\n\
568 ret;\n\
569 }\n";
570 hdr + body
571}
572
573#[must_use]
579pub fn mrf_gibbs_ptx(sm: u32) -> String {
580 let hdr = ptx_header(sm);
581 let body = ".visible .entry mrf_gibbs_kernel(\n\
582 .param .u64 p_spins,\n\
583 .param .f32 p_h,\n\
584 .param .f32 p_j,\n\
585 .param .u32 p_n_rows,\n\
586 .param .u32 p_n_cols,\n\
587 .param .u64 p_seed\n\
588 )\n\
589 {\n\
590 .reg .u64 %rd<10>;\n\
591 .reg .u32 %r<16>;\n\
592 .reg .s32 %sr<8>;\n\
593 .reg .f32 %f<10>;\n\
594 .reg .pred %p0, %p1;\n\
595 \n\
596 ld.param.u64 %rd0, [p_spins];\n\
597 ld.param.f32 %f0, [p_h];\n\
598 ld.param.f32 %f1, [p_j];\n\
599 ld.param.u32 %r0, [p_n_rows];\n\
600 ld.param.u32 %r1, [p_n_cols];\n\
601 ld.param.u64 %rd1, [p_seed];\n\
602 \n\
603 // i = blockIdx.y * blockDim.y + threadIdx.y\n\
604 mov.u32 %r2, %ntid.y;\n\
605 mov.u32 %r3, %ctaid.y;\n\
606 mov.u32 %r4, %tid.y;\n\
607 mad.lo.u32 %r5, %r2, %r3, %r4;\n\
608 // j = blockIdx.x * blockDim.x + threadIdx.x\n\
609 mov.u32 %r6, %ntid.x;\n\
610 mov.u32 %r7, %ctaid.x;\n\
611 mov.u32 %r8, %tid.x;\n\
612 mad.lo.u32 %r9, %r6, %r7, %r8;\n\
613 \n\
614 setp.ge.u32 %p0, %r5, %r0;\n\
615 @%p0 bra $MG_DONE;\n\
616 setp.ge.u32 %p0, %r9, %r1;\n\
617 @%p0 bra $MG_DONE;\n\
618 \n\
619 // Sum of neighbour spins (with bounds checks)\n\
620 mov.s32 %sr0, 0;\n\
621 // up\n\
622 setp.eq.u32 %p1, %r5, 0;\n\
623 @%p1 bra $MG_NB1;\n\
624 sub.u32 %r10, %r5, 1;\n\
625 mul.lo.u32 %r11, %r10, %r1;\n\
626 add.u32 %r11, %r11, %r9;\n\
627 mul.wide.u32 %rd2, %r11, 4;\n\
628 add.u64 %rd3, %rd0, %rd2;\n\
629 ld.global.s32 %sr1, [%rd3];\n\
630 add.s32 %sr0, %sr0, %sr1;\n\
631 $MG_NB1:\n\
632 // down\n\
633 add.u32 %r10, %r5, 1;\n\
634 setp.ge.u32 %p1, %r10, %r0;\n\
635 @%p1 bra $MG_NB2;\n\
636 mul.lo.u32 %r11, %r10, %r1;\n\
637 add.u32 %r11, %r11, %r9;\n\
638 mul.wide.u32 %rd2, %r11, 4;\n\
639 add.u64 %rd3, %rd0, %rd2;\n\
640 ld.global.s32 %sr1, [%rd3];\n\
641 add.s32 %sr0, %sr0, %sr1;\n\
642 $MG_NB2:\n\
643 // left\n\
644 setp.eq.u32 %p1, %r9, 0;\n\
645 @%p1 bra $MG_NB3;\n\
646 sub.u32 %r10, %r9, 1;\n\
647 mul.lo.u32 %r11, %r5, %r1;\n\
648 add.u32 %r11, %r11, %r10;\n\
649 mul.wide.u32 %rd2, %r11, 4;\n\
650 add.u64 %rd3, %rd0, %rd2;\n\
651 ld.global.s32 %sr1, [%rd3];\n\
652 add.s32 %sr0, %sr0, %sr1;\n\
653 $MG_NB3:\n\
654 // right\n\
655 add.u32 %r10, %r9, 1;\n\
656 setp.ge.u32 %p1, %r10, %r1;\n\
657 @%p1 bra $MG_FIELD;\n\
658 mul.lo.u32 %r11, %r5, %r1;\n\
659 add.u32 %r11, %r11, %r10;\n\
660 mul.wide.u32 %rd2, %r11, 4;\n\
661 add.u64 %rd3, %rd0, %rd2;\n\
662 ld.global.s32 %sr1, [%rd3];\n\
663 add.s32 %sr0, %sr0, %sr1;\n\
664 \n\
665 $MG_FIELD:\n\
666 // field = j * sum + h\n\
667 cvt.rn.f32.s32 %f2, %sr0;\n\
668 mul.f32 %f3, %f1, %f2;\n\
669 add.f32 %f3, %f3, %f0;\n\
670 \n\
671 // p_up = 1 / (1 + exp(-2*field))\n\
672 mov.f32 %f4, 0fC0000000; // -2\n\
673 mul.f32 %f5, %f4, %f3;\n\
674 mul.f32 %f5, %f5, 0f3FB8AA3B; // * log2(e): exp(x)=ex2(x*log2e)\n\
675 ex2.approx.f32 %f5, %f5; // exp(-2*field)\n\
676 mov.f32 %f6, 0f3F800000; // 1.0\n\
677 add.f32 %f5, %f5, %f6;\n\
678 div.rn.f32 %f7, %f6, %f5;\n\
679 \n\
680 // Inline LCG: seed ^ (row * n_cols + col)\n\
681 mul.lo.u32 %r12, %r5, %r1;\n\
682 add.u32 %r12, %r12, %r9;\n\
683 cvt.u64.u32 %rd4, %r12;\n\
684 xor.b64 %rd5, %rd4, %rd1;\n\
685 mov.u64 %rd6, 6364136223846793005;\n\
686 mul.lo.u64 %rd5, %rd5, %rd6;\n\
687 mov.u64 %rd6, 1442695040888963407;\n\
688 add.u64 %rd5, %rd5, %rd6;\n\
689 shr.u64 %rd7, %rd5, 32;\n\
690 cvt.u32.u64 %r13, %rd7;\n\
691 // u = (high32 >> 8) / 2^24\n\
692 shr.u32 %r14, %r13, 8;\n\
693 cvt.rn.f32.u32 %f8, %r14;\n\
694 mov.f32 %f9, 0f33800000; // 1 / 2^24\n\
695 mul.f32 %f8, %f8, %f9;\n\
696 \n\
697 // s = (u < p_up) ? +1 : -1\n\
698 setp.lt.f32 %p1, %f8, %f7;\n\
699 mov.s32 %sr2, -1;\n\
700 @%p1 mov.s32 %sr2, 1;\n\
701 \n\
702 // store\n\
703 mul.lo.u32 %r15, %r5, %r1;\n\
704 add.u32 %r15, %r15, %r9;\n\
705 mul.wide.u32 %rd8, %r15, 4;\n\
706 add.u64 %rd9, %rd0, %rd8;\n\
707 st.global.s32 [%rd9], %sr2;\n\
708 \n\
709 $MG_DONE:\n\
710 ret;\n\
711 }\n";
712 hdr + body
713}
714
715#[cfg(test)]
716mod tests {
717 use super::*;
718
719 #[test]
720 fn ptx_header_versions() {
721 assert!(ptx_header(75).contains(".version 7.5"));
722 assert!(ptx_header(80).contains(".version 8.0"));
723 assert!(ptx_header(89).contains(".version 8.0"));
724 assert!(ptx_header(90).contains(".version 8.4"));
725 assert!(ptx_header(100).contains(".version 8.7"));
726 }
727
728 #[test]
729 fn all_kernels_non_empty() {
730 type KernelFn = fn(u32) -> String;
731 let kernels: &[(&str, KernelFn)] = &[
732 ("forward_pass", forward_pass_ptx),
733 ("viterbi_step", viterbi_step_ptx),
734 ("crf_features", crf_features_ptx),
735 ("beam_topk", beam_topk_ptx),
736 ("edit_dist", edit_dist_ptx),
737 ("kalman_predict", kalman_predict_ptx),
738 ("mrf_gibbs", mrf_gibbs_ptx),
739 ];
740 let sms = [75u32, 80, 86, 89, 90, 100];
741 for &sm in &sms {
742 for &(name, f) in kernels {
743 let s = f(sm);
744 assert!(!s.is_empty(), "{name} sm{sm} empty");
745 assert!(
746 s.contains(".visible .entry"),
747 "{name} sm{sm} missing .visible .entry"
748 );
749 }
750 }
751 }
752}