1fn ptx_header(sm: u32) -> String {
15 let (ptx_ver, target) = match sm {
16 v if v >= 100 => ("8.7", format!("sm_{v}")),
17 v if v >= 90 => ("8.4", format!("sm_{v}")),
18 v if v >= 80 => ("8.0", format!("sm_{v}")),
19 v => ("7.5", format!("sm_{v}")),
20 };
21 format!(".version {ptx_ver}\n.target {target}\n.address_size 64\n\n")
22}
23
24#[must_use]
30pub fn forward_pass_ptx(sm: u32) -> String {
31 let hdr = ptx_header(sm);
32 let body = ".visible .entry forward_pass_kernel(\n\
33 .param .u64 p_alpha_prev,\n\
34 .param .u64 p_alpha_next,\n\
35 .param .u64 p_log_a,\n\
36 .param .u64 p_log_b_o,\n\
37 .param .u32 p_n_states\n\
38 )\n\
39 {\n\
40 .reg .u64 %rd<10>;\n\
41 .reg .u32 %r<10>;\n\
42 .reg .f32 %f<10>;\n\
43 .reg .pred %p0;\n\
44 \n\
45 ld.param.u64 %rd0, [p_alpha_prev];\n\
46 ld.param.u64 %rd1, [p_alpha_next];\n\
47 ld.param.u64 %rd2, [p_log_a];\n\
48 ld.param.u64 %rd3, [p_log_b_o];\n\
49 ld.param.u32 %r0, [p_n_states];\n\
50 \n\
51 // j = global thread id\n\
52 mov.u32 %r1, %ntid.x;\n\
53 mov.u32 %r2, %ctaid.x;\n\
54 mov.u32 %r3, %tid.x;\n\
55 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
56 setp.ge.u32 %p0, %r4, %r0;\n\
57 @%p0 bra $FP_DONE;\n\
58 \n\
59 // First pass: find max of (alpha_prev[i] + log_a[i*S + j])\n\
60 mov.f32 %f0, 0fFF800000; // -inf\n\
61 mov.u32 %r5, 0;\n\
62 $FP_MAX:\n\
63 setp.ge.u32 %p0, %r5, %r0;\n\
64 @%p0 bra $FP_SUM_INIT;\n\
65 // alpha_prev[i]\n\
66 mul.wide.u32 %rd4, %r5, 4;\n\
67 add.u64 %rd5, %rd0, %rd4;\n\
68 ld.global.f32 %f1, [%rd5];\n\
69 // log_a[i*S + j]\n\
70 mul.lo.u32 %r6, %r5, %r0;\n\
71 add.u32 %r6, %r6, %r4;\n\
72 mul.wide.u32 %rd6, %r6, 4;\n\
73 add.u64 %rd7, %rd2, %rd6;\n\
74 ld.global.f32 %f2, [%rd7];\n\
75 add.f32 %f3, %f1, %f2;\n\
76 max.f32 %f0, %f0, %f3;\n\
77 add.u32 %r5, %r5, 1;\n\
78 bra $FP_MAX;\n\
79 \n\
80 $FP_SUM_INIT:\n\
81 // Second pass: accumulate exp((alpha_prev[i] + log_a[i*S+j]) - max)\n\
82 mov.f32 %f4, 0f00000000;\n\
83 mov.u32 %r5, 0;\n\
84 $FP_SUM:\n\
85 setp.ge.u32 %p0, %r5, %r0;\n\
86 @%p0 bra $FP_WRITE;\n\
87 mul.wide.u32 %rd4, %r5, 4;\n\
88 add.u64 %rd5, %rd0, %rd4;\n\
89 ld.global.f32 %f1, [%rd5];\n\
90 mul.lo.u32 %r6, %r5, %r0;\n\
91 add.u32 %r6, %r6, %r4;\n\
92 mul.wide.u32 %rd6, %r6, 4;\n\
93 add.u64 %rd7, %rd2, %rd6;\n\
94 ld.global.f32 %f2, [%rd7];\n\
95 add.f32 %f3, %f1, %f2;\n\
96 sub.f32 %f3, %f3, %f0;\n\
97 ex2.approx.f32 %f3, %f3;\n\
98 add.f32 %f4, %f4, %f3;\n\
99 add.u32 %r5, %r5, 1;\n\
100 bra $FP_SUM;\n\
101 \n\
102 $FP_WRITE:\n\
103 // result = max + log(sum) + log_b_o[j]\n\
104 lg2.approx.f32 %f4, %f4;\n\
105 add.f32 %f4, %f4, %f0;\n\
106 mul.wide.u32 %rd4, %r4, 4;\n\
107 add.u64 %rd5, %rd3, %rd4;\n\
108 ld.global.f32 %f5, [%rd5];\n\
109 add.f32 %f4, %f4, %f5;\n\
110 add.u64 %rd6, %rd1, %rd4;\n\
111 st.global.f32 [%rd6], %f4;\n\
112 \n\
113 $FP_DONE:\n\
114 ret;\n\
115 }\n";
116 hdr + body
117}
118
119#[must_use]
125pub fn viterbi_step_ptx(sm: u32) -> String {
126 let hdr = ptx_header(sm);
127 let body = ".visible .entry viterbi_step_kernel(\n\
128 .param .u64 p_delta_prev,\n\
129 .param .u64 p_delta_next,\n\
130 .param .u64 p_log_a,\n\
131 .param .u64 p_log_b_o,\n\
132 .param .u64 p_psi,\n\
133 .param .u32 p_n_states\n\
134 )\n\
135 {\n\
136 .reg .u64 %rd<10>;\n\
137 .reg .u32 %r<10>;\n\
138 .reg .s32 %sr<4>;\n\
139 .reg .f32 %f<8>;\n\
140 .reg .pred %p0, %p1;\n\
141 \n\
142 ld.param.u64 %rd0, [p_delta_prev];\n\
143 ld.param.u64 %rd1, [p_delta_next];\n\
144 ld.param.u64 %rd2, [p_log_a];\n\
145 ld.param.u64 %rd3, [p_log_b_o];\n\
146 ld.param.u64 %rd4, [p_psi];\n\
147 ld.param.u32 %r0, [p_n_states];\n\
148 \n\
149 mov.u32 %r1, %ntid.x;\n\
150 mov.u32 %r2, %ctaid.x;\n\
151 mov.u32 %r3, %tid.x;\n\
152 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
153 setp.ge.u32 %p0, %r4, %r0;\n\
154 @%p0 bra $VS_DONE;\n\
155 \n\
156 mov.f32 %f0, 0fFF800000; // best = -inf\n\
157 mov.s32 %sr0, -1; // argmax\n\
158 mov.u32 %r5, 0;\n\
159 $VS_LOOP:\n\
160 setp.ge.u32 %p0, %r5, %r0;\n\
161 @%p0 bra $VS_WRITE;\n\
162 // delta_prev[i]\n\
163 mul.wide.u32 %rd5, %r5, 4;\n\
164 add.u64 %rd6, %rd0, %rd5;\n\
165 ld.global.f32 %f1, [%rd6];\n\
166 // log_a[i*S + j]\n\
167 mul.lo.u32 %r6, %r5, %r0;\n\
168 add.u32 %r6, %r6, %r4;\n\
169 mul.wide.u32 %rd7, %r6, 4;\n\
170 add.u64 %rd8, %rd2, %rd7;\n\
171 ld.global.f32 %f2, [%rd8];\n\
172 add.f32 %f3, %f1, %f2;\n\
173 setp.gt.f32 %p1, %f3, %f0;\n\
174 @%p1 mov.f32 %f0, %f3;\n\
175 @%p1 cvt.s32.u32 %sr0, %r5;\n\
176 add.u32 %r5, %r5, 1;\n\
177 bra $VS_LOOP;\n\
178 \n\
179 $VS_WRITE:\n\
180 // delta_next[j] = best + log_b_o[j]\n\
181 mul.wide.u32 %rd5, %r4, 4;\n\
182 add.u64 %rd6, %rd3, %rd5;\n\
183 ld.global.f32 %f4, [%rd6];\n\
184 add.f32 %f0, %f0, %f4;\n\
185 add.u64 %rd7, %rd1, %rd5;\n\
186 st.global.f32 [%rd7], %f0;\n\
187 // psi[j] = argmax\n\
188 add.u64 %rd8, %rd4, %rd5;\n\
189 st.global.s32 [%rd8], %sr0;\n\
190 \n\
191 $VS_DONE:\n\
192 ret;\n\
193 }\n";
194 hdr + body
195}
196
197#[must_use]
202pub fn crf_features_ptx(sm: u32) -> String {
203 let hdr = ptx_header(sm);
204 let body = ".visible .entry crf_features_kernel(\n\
205 .param .u64 p_emit,\n\
206 .param .u64 p_trans,\n\
207 .param .u64 p_x_feat,\n\
208 .param .u64 p_score,\n\
209 .param .u32 p_t,\n\
210 .param .u32 p_n_labels,\n\
211 .param .u32 p_n_features\n\
212 )\n\
213 {\n\
214 .reg .u64 %rd<10>;\n\
215 .reg .u32 %r<14>;\n\
216 .reg .f32 %f<8>;\n\
217 .reg .pred %p0;\n\
218 \n\
219 ld.param.u64 %rd0, [p_emit];\n\
220 ld.param.u64 %rd1, [p_trans];\n\
221 ld.param.u64 %rd2, [p_x_feat];\n\
222 ld.param.u64 %rd3, [p_score];\n\
223 ld.param.u32 %r0, [p_t];\n\
224 ld.param.u32 %r1, [p_n_labels];\n\
225 ld.param.u32 %r2, [p_n_features];\n\
226 \n\
227 // y_prev = blockIdx.y * blockDim.y + threadIdx.y\n\
228 mov.u32 %r3, %ntid.y;\n\
229 mov.u32 %r4, %ctaid.y;\n\
230 mov.u32 %r5, %tid.y;\n\
231 mad.lo.u32 %r6, %r3, %r4, %r5;\n\
232 // y_cur = blockIdx.x * blockDim.x + threadIdx.x\n\
233 mov.u32 %r7, %ntid.x;\n\
234 mov.u32 %r8, %ctaid.x;\n\
235 mov.u32 %r9, %tid.x;\n\
236 mad.lo.u32 %r10, %r7, %r8, %r9;\n\
237 \n\
238 setp.ge.u32 %p0, %r6, %r1;\n\
239 @%p0 bra $CF_DONE;\n\
240 setp.ge.u32 %p0, %r10, %r1;\n\
241 @%p0 bra $CF_DONE;\n\
242 \n\
243 // Emission score: dot(emit[y_cur,:], x_feat[t,:])\n\
244 mov.f32 %f0, 0f00000000;\n\
245 mov.u32 %r11, 0;\n\
246 $CF_EMIT:\n\
247 setp.ge.u32 %p0, %r11, %r2;\n\
248 @%p0 bra $CF_TRANS;\n\
249 // emit[y_cur * n_features + k]\n\
250 mul.lo.u32 %r12, %r10, %r2;\n\
251 add.u32 %r12, %r12, %r11;\n\
252 mul.wide.u32 %rd4, %r12, 4;\n\
253 add.u64 %rd5, %rd0, %rd4;\n\
254 ld.global.f32 %f1, [%rd5];\n\
255 // x_feat[t * n_features + k]\n\
256 mul.lo.u32 %r13, %r0, %r2;\n\
257 add.u32 %r13, %r13, %r11;\n\
258 mul.wide.u32 %rd6, %r13, 4;\n\
259 add.u64 %rd7, %rd2, %rd6;\n\
260 ld.global.f32 %f2, [%rd7];\n\
261 fma.rn.f32 %f0, %f1, %f2, %f0;\n\
262 add.u32 %r11, %r11, 1;\n\
263 bra $CF_EMIT;\n\
264 \n\
265 $CF_TRANS:\n\
266 // Transition score: trans[y_prev * n_labels + y_cur]\n\
267 mul.lo.u32 %r12, %r6, %r1;\n\
268 add.u32 %r12, %r12, %r10;\n\
269 mul.wide.u32 %rd4, %r12, 4;\n\
270 add.u64 %rd5, %rd1, %rd4;\n\
271 ld.global.f32 %f3, [%rd5];\n\
272 add.f32 %f0, %f0, %f3;\n\
273 \n\
274 // score[y_prev * n_labels + y_cur] = f0\n\
275 add.u64 %rd6, %rd3, %rd4;\n\
276 st.global.f32 [%rd6], %f0;\n\
277 \n\
278 $CF_DONE:\n\
279 ret;\n\
280 }\n";
281 hdr + body
282}
283
284#[must_use]
290pub fn beam_topk_ptx(sm: u32) -> String {
291 let hdr = ptx_header(sm);
292 let body = ".visible .entry beam_topk_kernel(\n\
293 .param .u64 p_scores,\n\
294 .param .u64 p_rank,\n\
295 .param .u32 p_n,\n\
296 .param .u32 p_k\n\
297 )\n\
298 {\n\
299 .reg .u64 %rd<8>;\n\
300 .reg .u32 %r<10>;\n\
301 .reg .s32 %sr<4>;\n\
302 .reg .f32 %f<4>;\n\
303 .reg .pred %p0, %p1;\n\
304 \n\
305 ld.param.u64 %rd0, [p_scores];\n\
306 ld.param.u64 %rd1, [p_rank];\n\
307 ld.param.u32 %r0, [p_n];\n\
308 ld.param.u32 %r1, [p_k];\n\
309 \n\
310 mov.u32 %r2, %ntid.x;\n\
311 mov.u32 %r3, %ctaid.x;\n\
312 mov.u32 %r4, %tid.x;\n\
313 mad.lo.u32 %r5, %r2, %r3, %r4;\n\
314 setp.ge.u32 %p0, %r5, %r0;\n\
315 @%p0 bra $BK_DONE;\n\
316 \n\
317 // my score\n\
318 mul.wide.u32 %rd2, %r5, 4;\n\
319 add.u64 %rd3, %rd0, %rd2;\n\
320 ld.global.f32 %f0, [%rd3];\n\
321 \n\
322 mov.u32 %r6, 0; // rank counter\n\
323 mov.u32 %r7, 0; // loop index\n\
324 $BK_LOOP:\n\
325 setp.ge.u32 %p0, %r7, %r0;\n\
326 @%p0 bra $BK_WRITE;\n\
327 mul.wide.u32 %rd4, %r7, 4;\n\
328 add.u64 %rd5, %rd0, %rd4;\n\
329 ld.global.f32 %f1, [%rd5];\n\
330 setp.gt.f32 %p1, %f1, %f0;\n\
331 @%p1 add.u32 %r6, %r6, 1;\n\
332 add.u32 %r7, %r7, 1;\n\
333 bra $BK_LOOP;\n\
334 \n\
335 $BK_WRITE:\n\
336 setp.ge.u32 %p0, %r6, %r1;\n\
337 mov.s32 %sr0, -1;\n\
338 @!%p0 cvt.s32.u32 %sr0, %r6;\n\
339 add.u64 %rd6, %rd1, %rd2;\n\
340 st.global.s32 [%rd6], %sr0;\n\
341 \n\
342 $BK_DONE:\n\
343 ret;\n\
344 }\n";
345 hdr + body
346}
347
348#[must_use]
354pub fn edit_dist_ptx(sm: u32) -> String {
355 let hdr = ptx_header(sm);
356 let body = ".visible .entry edit_dist_kernel(\n\
357 .param .u64 p_dp,\n\
358 .param .u64 p_a,\n\
359 .param .u64 p_b,\n\
360 .param .u32 p_n_a,\n\
361 .param .u32 p_n_b,\n\
362 .param .u32 p_diag\n\
363 )\n\
364 {\n\
365 .reg .u64 %rd<10>;\n\
366 .reg .u32 %r<14>;\n\
367 .reg .s32 %sr<6>;\n\
368 .reg .pred %p0, %p1;\n\
369 \n\
370 ld.param.u64 %rd0, [p_dp];\n\
371 ld.param.u64 %rd1, [p_a];\n\
372 ld.param.u64 %rd2, [p_b];\n\
373 ld.param.u32 %r0, [p_n_a];\n\
374 ld.param.u32 %r1, [p_n_b];\n\
375 ld.param.u32 %r2, [p_diag];\n\
376 \n\
377 // i = tid + 1; j = diag - i\n\
378 mov.u32 %r3, %ntid.x;\n\
379 mov.u32 %r4, %ctaid.x;\n\
380 mov.u32 %r5, %tid.x;\n\
381 mad.lo.u32 %r6, %r3, %r4, %r5;\n\
382 add.u32 %r7, %r6, 1; // i\n\
383 sub.u32 %r8, %r2, %r7; // j\n\
384 setp.gt.u32 %p0, %r7, %r0;\n\
385 @%p0 bra $ED_DONE;\n\
386 setp.eq.u32 %p0, %r8, 0;\n\
387 @%p0 bra $ED_DONE;\n\
388 setp.gt.u32 %p0, %r8, %r1;\n\
389 @%p0 bra $ED_DONE;\n\
390 \n\
391 // load a[i-1], b[j-1]\n\
392 sub.u32 %r9, %r7, 1;\n\
393 mul.wide.u32 %rd3, %r9, 4;\n\
394 add.u64 %rd4, %rd1, %rd3;\n\
395 ld.global.s32 %sr0, [%rd4];\n\
396 sub.u32 %r10, %r8, 1;\n\
397 mul.wide.u32 %rd5, %r10, 4;\n\
398 add.u64 %rd6, %rd2, %rd5;\n\
399 ld.global.s32 %sr1, [%rd6];\n\
400 \n\
401 // cost: 0 if eq, else 1\n\
402 setp.eq.s32 %p1, %sr0, %sr1;\n\
403 mov.s32 %sr2, 1;\n\
404 @%p1 mov.s32 %sr2, 0;\n\
405 \n\
406 // Read 3 neighbours from dp\n\
407 // dp[(i-1)*(n_b+1) + j]\n\
408 add.u32 %r11, %r1, 1;\n\
409 mul.lo.u32 %r12, %r9, %r11;\n\
410 add.u32 %r12, %r12, %r8;\n\
411 mul.wide.u32 %rd7, %r12, 4;\n\
412 add.u64 %rd8, %rd0, %rd7;\n\
413 ld.global.s32 %sr3, [%rd8];\n\
414 // dp[i*(n_b+1) + (j-1)]\n\
415 mul.lo.u32 %r12, %r7, %r11;\n\
416 add.u32 %r12, %r12, %r10;\n\
417 mul.wide.u32 %rd7, %r12, 4;\n\
418 add.u64 %rd8, %rd0, %rd7;\n\
419 ld.global.s32 %sr4, [%rd8];\n\
420 // dp[(i-1)*(n_b+1) + (j-1)]\n\
421 mul.lo.u32 %r12, %r9, %r11;\n\
422 add.u32 %r12, %r12, %r10;\n\
423 mul.wide.u32 %rd7, %r12, 4;\n\
424 add.u64 %rd8, %rd0, %rd7;\n\
425 ld.global.s32 %sr5, [%rd8];\n\
426 \n\
427 // best = min(sr3+1, sr4+1, sr5+sr2)\n\
428 add.s32 %sr3, %sr3, 1;\n\
429 add.s32 %sr4, %sr4, 1;\n\
430 add.s32 %sr5, %sr5, %sr2;\n\
431 min.s32 %sr3, %sr3, %sr4;\n\
432 min.s32 %sr3, %sr3, %sr5;\n\
433 \n\
434 // write dp[i*(n_b+1)+j]\n\
435 mul.lo.u32 %r12, %r7, %r11;\n\
436 add.u32 %r12, %r12, %r8;\n\
437 mul.wide.u32 %rd7, %r12, 4;\n\
438 add.u64 %rd8, %rd0, %rd7;\n\
439 st.global.s32 [%rd8], %sr3;\n\
440 \n\
441 $ED_DONE:\n\
442 ret;\n\
443 }\n";
444 hdr + body
445}
446
447#[must_use]
452pub fn kalman_predict_ptx(sm: u32) -> String {
453 let hdr = ptx_header(sm);
454 let body = ".visible .entry kalman_predict_kernel(\n\
455 .param .u64 p_x,\n\
456 .param .u64 p_x_pred,\n\
457 .param .u64 p_a,\n\
458 .param .u64 p_p,\n\
459 .param .u64 p_p_pred,\n\
460 .param .u64 p_q,\n\
461 .param .u32 p_n\n\
462 )\n\
463 {\n\
464 .reg .u64 %rd<14>;\n\
465 .reg .u32 %r<14>;\n\
466 .reg .f32 %f<10>;\n\
467 .reg .pred %p0;\n\
468 \n\
469 ld.param.u64 %rd0, [p_x];\n\
470 ld.param.u64 %rd1, [p_x_pred];\n\
471 ld.param.u64 %rd2, [p_a];\n\
472 ld.param.u64 %rd3, [p_p];\n\
473 ld.param.u64 %rd4, [p_p_pred];\n\
474 ld.param.u64 %rd5, [p_q];\n\
475 ld.param.u32 %r0, [p_n];\n\
476 \n\
477 mov.u32 %r1, %ntid.x;\n\
478 mov.u32 %r2, %ctaid.x;\n\
479 mov.u32 %r3, %tid.x;\n\
480 mad.lo.u32 %r4, %r1, %r2, %r3;\n\
481 setp.ge.u32 %p0, %r4, %r0;\n\
482 @%p0 bra $KP_DONE;\n\
483 \n\
484 // x_pred[i] = sum_k A[i,k] * x[k]\n\
485 mov.f32 %f0, 0f00000000;\n\
486 mov.u32 %r5, 0;\n\
487 $KP_VEC:\n\
488 setp.ge.u32 %p0, %r5, %r0;\n\
489 @%p0 bra $KP_VEC_WR;\n\
490 // A[i*n + k]\n\
491 mul.lo.u32 %r6, %r4, %r0;\n\
492 add.u32 %r6, %r6, %r5;\n\
493 mul.wide.u32 %rd6, %r6, 4;\n\
494 add.u64 %rd7, %rd2, %rd6;\n\
495 ld.global.f32 %f1, [%rd7];\n\
496 // x[k]\n\
497 mul.wide.u32 %rd8, %r5, 4;\n\
498 add.u64 %rd9, %rd0, %rd8;\n\
499 ld.global.f32 %f2, [%rd9];\n\
500 fma.rn.f32 %f0, %f1, %f2, %f0;\n\
501 add.u32 %r5, %r5, 1;\n\
502 bra $KP_VEC;\n\
503 \n\
504 $KP_VEC_WR:\n\
505 mul.wide.u32 %rd6, %r4, 4;\n\
506 add.u64 %rd7, %rd1, %rd6;\n\
507 st.global.f32 [%rd7], %f0;\n\
508 \n\
509 // P_pred[i,j] = sum_{k,l} A[i,k] P[k,l] A[j,l] + Q[i,j]\n\
510 // One thread handles row i, all j.\n\
511 mov.u32 %r7, 0;\n\
512 $KP_J:\n\
513 setp.ge.u32 %p0, %r7, %r0;\n\
514 @%p0 bra $KP_DONE;\n\
515 mov.f32 %f3, 0f00000000;\n\
516 mov.u32 %r8, 0;\n\
517 $KP_K:\n\
518 setp.ge.u32 %p0, %r8, %r0;\n\
519 @%p0 bra $KP_K_DONE;\n\
520 mov.f32 %f4, 0f00000000;\n\
521 mov.u32 %r9, 0;\n\
522 $KP_L:\n\
523 setp.ge.u32 %p0, %r9, %r0;\n\
524 @%p0 bra $KP_L_DONE;\n\
525 // P[k*n + l]\n\
526 mul.lo.u32 %r10, %r8, %r0;\n\
527 add.u32 %r10, %r10, %r9;\n\
528 mul.wide.u32 %rd6, %r10, 4;\n\
529 add.u64 %rd7, %rd3, %rd6;\n\
530 ld.global.f32 %f5, [%rd7];\n\
531 // A[j*n + l]\n\
532 mul.lo.u32 %r11, %r7, %r0;\n\
533 add.u32 %r11, %r11, %r9;\n\
534 mul.wide.u32 %rd8, %r11, 4;\n\
535 add.u64 %rd9, %rd2, %rd8;\n\
536 ld.global.f32 %f6, [%rd9];\n\
537 fma.rn.f32 %f4, %f5, %f6, %f4;\n\
538 add.u32 %r9, %r9, 1;\n\
539 bra $KP_L;\n\
540 \n\
541 $KP_L_DONE:\n\
542 // A[i*n + k]\n\
543 mul.lo.u32 %r10, %r4, %r0;\n\
544 add.u32 %r10, %r10, %r8;\n\
545 mul.wide.u32 %rd6, %r10, 4;\n\
546 add.u64 %rd7, %rd2, %rd6;\n\
547 ld.global.f32 %f7, [%rd7];\n\
548 fma.rn.f32 %f3, %f7, %f4, %f3;\n\
549 add.u32 %r8, %r8, 1;\n\
550 bra $KP_K;\n\
551 \n\
552 $KP_K_DONE:\n\
553 // P_pred[i*n+j] = f3 + Q[i*n+j]\n\
554 mul.lo.u32 %r10, %r4, %r0;\n\
555 add.u32 %r10, %r10, %r7;\n\
556 mul.wide.u32 %rd6, %r10, 4;\n\
557 add.u64 %rd7, %rd5, %rd6;\n\
558 ld.global.f32 %f8, [%rd7];\n\
559 add.f32 %f3, %f3, %f8;\n\
560 add.u64 %rd9, %rd4, %rd6;\n\
561 st.global.f32 [%rd9], %f3;\n\
562 add.u32 %r7, %r7, 1;\n\
563 bra $KP_J;\n\
564 \n\
565 $KP_DONE:\n\
566 ret;\n\
567 }\n";
568 hdr + body
569}
570
571#[must_use]
577pub fn mrf_gibbs_ptx(sm: u32) -> String {
578 let hdr = ptx_header(sm);
579 let body = ".visible .entry mrf_gibbs_kernel(\n\
580 .param .u64 p_spins,\n\
581 .param .f32 p_h,\n\
582 .param .f32 p_j,\n\
583 .param .u32 p_n_rows,\n\
584 .param .u32 p_n_cols,\n\
585 .param .u64 p_seed\n\
586 )\n\
587 {\n\
588 .reg .u64 %rd<10>;\n\
589 .reg .u32 %r<16>;\n\
590 .reg .s32 %sr<8>;\n\
591 .reg .f32 %f<10>;\n\
592 .reg .pred %p0, %p1;\n\
593 \n\
594 ld.param.u64 %rd0, [p_spins];\n\
595 ld.param.f32 %f0, [p_h];\n\
596 ld.param.f32 %f1, [p_j];\n\
597 ld.param.u32 %r0, [p_n_rows];\n\
598 ld.param.u32 %r1, [p_n_cols];\n\
599 ld.param.u64 %rd1, [p_seed];\n\
600 \n\
601 // i = blockIdx.y * blockDim.y + threadIdx.y\n\
602 mov.u32 %r2, %ntid.y;\n\
603 mov.u32 %r3, %ctaid.y;\n\
604 mov.u32 %r4, %tid.y;\n\
605 mad.lo.u32 %r5, %r2, %r3, %r4;\n\
606 // j = blockIdx.x * blockDim.x + threadIdx.x\n\
607 mov.u32 %r6, %ntid.x;\n\
608 mov.u32 %r7, %ctaid.x;\n\
609 mov.u32 %r8, %tid.x;\n\
610 mad.lo.u32 %r9, %r6, %r7, %r8;\n\
611 \n\
612 setp.ge.u32 %p0, %r5, %r0;\n\
613 @%p0 bra $MG_DONE;\n\
614 setp.ge.u32 %p0, %r9, %r1;\n\
615 @%p0 bra $MG_DONE;\n\
616 \n\
617 // Sum of neighbour spins (with bounds checks)\n\
618 mov.s32 %sr0, 0;\n\
619 // up\n\
620 setp.eq.u32 %p1, %r5, 0;\n\
621 @%p1 bra $MG_NB1;\n\
622 sub.u32 %r10, %r5, 1;\n\
623 mul.lo.u32 %r11, %r10, %r1;\n\
624 add.u32 %r11, %r11, %r9;\n\
625 mul.wide.u32 %rd2, %r11, 4;\n\
626 add.u64 %rd3, %rd0, %rd2;\n\
627 ld.global.s32 %sr1, [%rd3];\n\
628 add.s32 %sr0, %sr0, %sr1;\n\
629 $MG_NB1:\n\
630 // down\n\
631 add.u32 %r10, %r5, 1;\n\
632 setp.ge.u32 %p1, %r10, %r0;\n\
633 @%p1 bra $MG_NB2;\n\
634 mul.lo.u32 %r11, %r10, %r1;\n\
635 add.u32 %r11, %r11, %r9;\n\
636 mul.wide.u32 %rd2, %r11, 4;\n\
637 add.u64 %rd3, %rd0, %rd2;\n\
638 ld.global.s32 %sr1, [%rd3];\n\
639 add.s32 %sr0, %sr0, %sr1;\n\
640 $MG_NB2:\n\
641 // left\n\
642 setp.eq.u32 %p1, %r9, 0;\n\
643 @%p1 bra $MG_NB3;\n\
644 sub.u32 %r10, %r9, 1;\n\
645 mul.lo.u32 %r11, %r5, %r1;\n\
646 add.u32 %r11, %r11, %r10;\n\
647 mul.wide.u32 %rd2, %r11, 4;\n\
648 add.u64 %rd3, %rd0, %rd2;\n\
649 ld.global.s32 %sr1, [%rd3];\n\
650 add.s32 %sr0, %sr0, %sr1;\n\
651 $MG_NB3:\n\
652 // right\n\
653 add.u32 %r10, %r9, 1;\n\
654 setp.ge.u32 %p1, %r10, %r1;\n\
655 @%p1 bra $MG_FIELD;\n\
656 mul.lo.u32 %r11, %r5, %r1;\n\
657 add.u32 %r11, %r11, %r10;\n\
658 mul.wide.u32 %rd2, %r11, 4;\n\
659 add.u64 %rd3, %rd0, %rd2;\n\
660 ld.global.s32 %sr1, [%rd3];\n\
661 add.s32 %sr0, %sr0, %sr1;\n\
662 \n\
663 $MG_FIELD:\n\
664 // field = j * sum + h\n\
665 cvt.rn.f32.s32 %f2, %sr0;\n\
666 mul.f32 %f3, %f1, %f2;\n\
667 add.f32 %f3, %f3, %f0;\n\
668 \n\
669 // p_up = 1 / (1 + exp(-2*field))\n\
670 mov.f32 %f4, 0fC0000000; // -2\n\
671 mul.f32 %f5, %f4, %f3;\n\
672 ex2.approx.f32 %f5, %f5; // exp2(...)\n\
673 mov.f32 %f6, 0f3F800000; // 1.0\n\
674 add.f32 %f5, %f5, %f6;\n\
675 div.rn.f32 %f7, %f6, %f5;\n\
676 \n\
677 // Inline LCG: seed ^ (row * n_cols + col)\n\
678 mul.lo.u32 %r12, %r5, %r1;\n\
679 add.u32 %r12, %r12, %r9;\n\
680 cvt.u64.u32 %rd4, %r12;\n\
681 xor.b64 %rd5, %rd4, %rd1;\n\
682 mov.u64 %rd6, 6364136223846793005;\n\
683 mul.lo.u64 %rd5, %rd5, %rd6;\n\
684 mov.u64 %rd6, 1442695040888963407;\n\
685 add.u64 %rd5, %rd5, %rd6;\n\
686 shr.u64 %rd7, %rd5, 32;\n\
687 cvt.u32.u64 %r13, %rd7;\n\
688 // u = (high32 >> 8) / 2^24\n\
689 shr.u32 %r14, %r13, 8;\n\
690 cvt.rn.f32.u32 %f8, %r14;\n\
691 mov.f32 %f9, 0f33800000; // 1 / 2^24\n\
692 mul.f32 %f8, %f8, %f9;\n\
693 \n\
694 // s = (u < p_up) ? +1 : -1\n\
695 setp.lt.f32 %p1, %f8, %f7;\n\
696 mov.s32 %sr2, -1;\n\
697 @%p1 mov.s32 %sr2, 1;\n\
698 \n\
699 // store\n\
700 mul.lo.u32 %r15, %r5, %r1;\n\
701 add.u32 %r15, %r15, %r9;\n\
702 mul.wide.u32 %rd8, %r15, 4;\n\
703 add.u64 %rd9, %rd0, %rd8;\n\
704 st.global.s32 [%rd9], %sr2;\n\
705 \n\
706 $MG_DONE:\n\
707 ret;\n\
708 }\n";
709 hdr + body
710}
711
712#[cfg(test)]
713mod tests {
714 use super::*;
715
716 #[test]
717 fn ptx_header_versions() {
718 assert!(ptx_header(75).contains(".version 7.5"));
719 assert!(ptx_header(80).contains(".version 8.0"));
720 assert!(ptx_header(89).contains(".version 8.0"));
721 assert!(ptx_header(90).contains(".version 8.4"));
722 assert!(ptx_header(100).contains(".version 8.7"));
723 }
724
725 #[test]
726 fn all_kernels_non_empty() {
727 type KernelFn = fn(u32) -> String;
728 let kernels: &[(&str, KernelFn)] = &[
729 ("forward_pass", forward_pass_ptx),
730 ("viterbi_step", viterbi_step_ptx),
731 ("crf_features", crf_features_ptx),
732 ("beam_topk", beam_topk_ptx),
733 ("edit_dist", edit_dist_ptx),
734 ("kalman_predict", kalman_predict_ptx),
735 ("mrf_gibbs", mrf_gibbs_ptx),
736 ];
737 let sms = [75u32, 80, 86, 89, 90, 100];
738 for &sm in &sms {
739 for &(name, f) in kernels {
740 let s = f(sm);
741 assert!(!s.is_empty(), "{name} sm{sm} empty");
742 assert!(
743 s.contains(".visible .entry"),
744 "{name} sm{sm} missing .visible .entry"
745 );
746 }
747 }
748 }
749}