Skip to main content

oxicuda_graphalg/
ptx_kernels.rs

1//! GPU PTX kernels for graph algorithm operations.
2//!
3//! Each kernel is emitted as a self-contained PTX module string, parameterised on SM version.
4//! PTX ISA is selected by SM:
5//!     SM>=100 -> 8.7 (Blackwell), SM>=90 -> 8.4 (Hopper),
6//!     SM>=80  -> 8.0 (Ampere),    else  -> 7.5 (Turing).
7//!
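//! IMPORTANT: PTX kernel bodies are plain string literals joined to the header by
//! **string concatenation** (NOT built with `format!()`). PTX function bodies are delimited by
//! literal `{` / `}` braces, which `format!` would parse as placeholders unless every one were
//! escaped as `{{` / `}}`; register names such as `%rd`, `%r` and `%f` are not special to `format!`.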
11
/// Build the PTX preamble (`.version`, `.target`, `.address_size`) for an SM version.
///
/// The PTX ISA version follows the SM generation: `sm >= 100` -> 8.7, `sm >= 90` -> 8.4,
/// `sm >= 80` -> 8.0, anything older -> 7.5. The target is always `sm_{sm}`, and the string
/// ends with a blank line so a kernel body can be appended directly.
fn ptx_header(sm: u32) -> String {
    let isa = if sm >= 100 {
        "8.7"
    } else if sm >= 90 {
        "8.4"
    } else if sm >= 80 {
        "8.0"
    } else {
        "7.5"
    };
    format!(".version {isa}\n.target sm_{sm}\n.address_size 64\n\n")
}
22
23/// Frontier-based BFS single-step level update.
24///
25/// Signature: `bfs_level_kernel(row_ptr, col_idx, level, frontier_in, frontier_out, n, depth)`
26/// For each vertex `u` in `frontier_in`, set `level[v] = depth+1` for every neighbor `v`
27/// with `level[v] == -1`, and append `v` to `frontier_out`.
28#[must_use]
29pub fn bfs_level_ptx(sm: u32) -> String {
30    let hdr = ptx_header(sm);
31    let body = ".visible .entry bfs_level_kernel(\n\
32        .param .u64 p_row_ptr,\n\
33        .param .u64 p_col_idx,\n\
34        .param .u64 p_level,\n\
35        .param .u64 p_front_in,\n\
36        .param .u64 p_front_out,\n\
37        .param .u32 p_n,\n\
38        .param .u32 p_depth\n\
39    )\n\
40    {\n\
41        .reg .u64  %rd<16>;\n\
42        .reg .u32  %r<24>;\n\
43        .reg .pred %p0;\n\
44    \n\
45        ld.param.u64  %rd0, [p_row_ptr];\n\
46        ld.param.u64  %rd1, [p_col_idx];\n\
47        ld.param.u64  %rd2, [p_level];\n\
48        ld.param.u64  %rd3, [p_front_in];\n\
49        ld.param.u64  %rd4, [p_front_out];\n\
50        ld.param.u32  %r0,  [p_n];\n\
51        ld.param.u32  %r1,  [p_depth];\n\
52    \n\
53        mov.u32       %r2, %ntid.x;\n\
54        mov.u32       %r3, %ctaid.x;\n\
55        mov.u32       %r4, %tid.x;\n\
56        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
57    \n\
58        setp.ge.u32   %p0, %r5, %r0;\n\
59        @%p0 bra $BFS_DONE;\n\
60    \n\
61        // load frontier_in[gid] -> u (1 = in frontier)\n\
62        mul.wide.u32  %rd5, %r5, 4;\n\
63        add.u64       %rd6, %rd3, %rd5;\n\
64        ld.global.u32 %r6, [%rd6];\n\
65        setp.eq.u32   %p0, %r6, 0;\n\
66        @%p0 bra $BFS_DONE;\n\
67    \n\
68        // u = gid; row_ptr[u], row_ptr[u+1]\n\
69        mul.wide.u32  %rd7, %r5, 4;\n\
70        add.u64       %rd8, %rd0, %rd7;\n\
71        ld.global.u32 %r7, [%rd8];\n\
72        add.u64       %rd9, %rd8, 4;\n\
73        ld.global.u32 %r8, [%rd9];\n\
74    \n\
75        // depth_next = depth + 1\n\
76        add.u32       %r9, %r1, 1;\n\
77    \n\
78        // for j = row_ptr[u]; j < row_ptr[u+1]; j++\n\
79        mov.u32       %r10, %r7;\n\
80    $BFS_LOOP:\n\
81        setp.ge.u32   %p0, %r10, %r8;\n\
82        @%p0 bra $BFS_DONE;\n\
83    \n\
84        // v = col_idx[j]\n\
85        mul.wide.u32  %rd10, %r10, 4;\n\
86        add.u64       %rd11, %rd1, %rd10;\n\
87        ld.global.u32 %r11, [%rd11];\n\
88    \n\
89        // if level[v] == -1\n\
90        mul.wide.u32  %rd12, %r11, 4;\n\
91        add.u64       %rd13, %rd2, %rd12;\n\
92        ld.global.u32 %r12, [%rd13];\n\
93        setp.ne.u32   %p0, %r12, 0xFFFFFFFF;\n\
94        @%p0 bra $BFS_SKIP;\n\
95    \n\
96        st.global.u32 [%rd13], %r9;\n\
97        add.u64       %rd14, %rd4, %rd12;\n\
98        mov.u32       %r13, 1;\n\
99        st.global.u32 [%rd14], %r13;\n\
100    \n\
101    $BFS_SKIP:\n\
102        add.u32       %r10, %r10, 1;\n\
103        bra $BFS_LOOP;\n\
104    \n\
105    $BFS_DONE:\n\
106        ret;\n\
107    }\n";
108    hdr + body
109}
110
111/// Dijkstra single edge relaxation kernel.
112///
113/// Signature: `dijkstra_relax_kernel(row_ptr, col_idx, weights, dist, frontier, n, u)`
114/// For the source vertex `u`, relax `dist[v] = min(dist[v], dist[u] + w(u,v))`.
115#[must_use]
116pub fn dijkstra_relax_ptx(sm: u32) -> String {
117    let hdr = ptx_header(sm);
118    let body = ".visible .entry dijkstra_relax_kernel(\n\
119        .param .u64 p_row_ptr,\n\
120        .param .u64 p_col_idx,\n\
121        .param .u64 p_weights,\n\
122        .param .u64 p_dist,\n\
123        .param .u64 p_frontier,\n\
124        .param .u32 p_n,\n\
125        .param .u32 p_u\n\
126    )\n\
127    {\n\
128        .reg .u64  %rd<16>;\n\
129        .reg .u32  %r<20>;\n\
130        .reg .f32  %f<8>;\n\
131        .reg .pred %p0;\n\
132    \n\
133        ld.param.u64  %rd0, [p_row_ptr];\n\
134        ld.param.u64  %rd1, [p_col_idx];\n\
135        ld.param.u64  %rd2, [p_weights];\n\
136        ld.param.u64  %rd3, [p_dist];\n\
137        ld.param.u64  %rd4, [p_frontier];\n\
138        ld.param.u32  %r0,  [p_n];\n\
139        ld.param.u32  %r1,  [p_u];\n\
140    \n\
141        // row_ptr[u], row_ptr[u+1]\n\
142        mul.wide.u32  %rd5, %r1, 4;\n\
143        add.u64       %rd6, %rd0, %rd5;\n\
144        ld.global.u32 %r2, [%rd6];\n\
145        add.u64       %rd7, %rd6, 4;\n\
146        ld.global.u32 %r3, [%rd7];\n\
147    \n\
148        // gid = tid within [row_ptr[u], row_ptr[u+1])\n\
149        mov.u32       %r4, %ntid.x;\n\
150        mov.u32       %r5, %ctaid.x;\n\
151        mov.u32       %r6, %tid.x;\n\
152        mad.lo.u32    %r7, %r4, %r5, %r6;\n\
153        add.u32       %r8, %r2, %r7;\n\
154        setp.ge.u32   %p0, %r8, %r3;\n\
155        @%p0 bra $DR_DONE;\n\
156    \n\
157        // dist[u]\n\
158        mul.wide.u32  %rd8, %r1, 4;\n\
159        add.u64       %rd9, %rd3, %rd8;\n\
160        ld.global.f32 %f0, [%rd9];\n\
161    \n\
162        // v = col_idx[j]; w = weights[j]\n\
163        mul.wide.u32  %rd10, %r8, 4;\n\
164        add.u64       %rd11, %rd1, %rd10;\n\
165        ld.global.u32 %r9, [%rd11];\n\
166        add.u64       %rd12, %rd2, %rd10;\n\
167        ld.global.f32 %f1, [%rd12];\n\
168    \n\
169        // candidate = dist[u] + w\n\
170        add.f32       %f2, %f0, %f1;\n\
171    \n\
172        // dist[v]\n\
173        mul.wide.u32  %rd13, %r9, 4;\n\
174        add.u64       %rd14, %rd3, %rd13;\n\
175        ld.global.f32 %f3, [%rd14];\n\
176    \n\
177        setp.ge.f32   %p0, %f2, %f3;\n\
178        @%p0 bra $DR_DONE;\n\
179        st.global.f32 [%rd14], %f2;\n\
180        add.u64       %rd15, %rd4, %rd13;\n\
181        mov.u32       %r10, 1;\n\
182        st.global.u32 [%rd15], %r10;\n\
183    \n\
184    $DR_DONE:\n\
185        ret;\n\
186    }\n";
187    hdr + body
188}
189
190/// PageRank single power-iteration step.
191///
192/// Signature: `pagerank_step_kernel(row_ptr_t, col_idx_t, out_degree, rank_in, rank_out, n, damping)`
193/// `rank_out[v] = (1-damping)/n + damping * sum_{u in in_neighbors(v)} rank_in[u] / out_degree[u]`.
194#[must_use]
195pub fn pagerank_step_ptx(sm: u32) -> String {
196    let hdr = ptx_header(sm);
197    let body = ".visible .entry pagerank_step_kernel(\n\
198        .param .u64 p_row_ptr_t,\n\
199        .param .u64 p_col_idx_t,\n\
200        .param .u64 p_out_degree,\n\
201        .param .u64 p_rank_in,\n\
202        .param .u64 p_rank_out,\n\
203        .param .u32 p_n,\n\
204        .param .f32 p_damping\n\
205    )\n\
206    {\n\
207        .reg .u64  %rd<16>;\n\
208        .reg .u32  %r<20>;\n\
209        .reg .f32  %f<10>;\n\
210        .reg .pred %p0;\n\
211    \n\
212        ld.param.u64  %rd0, [p_row_ptr_t];\n\
213        ld.param.u64  %rd1, [p_col_idx_t];\n\
214        ld.param.u64  %rd2, [p_out_degree];\n\
215        ld.param.u64  %rd3, [p_rank_in];\n\
216        ld.param.u64  %rd4, [p_rank_out];\n\
217        ld.param.u32  %r0,  [p_n];\n\
218        ld.param.f32  %f0,  [p_damping];\n\
219    \n\
220        mov.u32       %r1, %ntid.x;\n\
221        mov.u32       %r2, %ctaid.x;\n\
222        mov.u32       %r3, %tid.x;\n\
223        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
224    \n\
225        setp.ge.u32   %p0, %r4, %r0;\n\
226        @%p0 bra $PR_DONE;\n\
227    \n\
228        // teleport = (1 - damping) / n\n\
229        mov.f32       %f1, 0f3F800000;\n\
230        sub.f32       %f2, %f1, %f0;\n\
231        cvt.rn.f32.u32 %f3, %r0;\n\
232        div.rn.f32    %f4, %f2, %f3;\n\
233    \n\
234        // start, end = row_ptr_t[v], row_ptr_t[v+1]\n\
235        mul.wide.u32  %rd5, %r4, 4;\n\
236        add.u64       %rd6, %rd0, %rd5;\n\
237        ld.global.u32 %r5, [%rd6];\n\
238        add.u64       %rd7, %rd6, 4;\n\
239        ld.global.u32 %r6, [%rd7];\n\
240    \n\
241        // sum = 0\n\
242        mov.f32       %f5, 0f00000000;\n\
243        mov.u32       %r7, %r5;\n\
244    $PR_LOOP:\n\
245        setp.ge.u32   %p0, %r7, %r6;\n\
246        @%p0 bra $PR_WRITE;\n\
247    \n\
248        // u = col_idx_t[j]\n\
249        mul.wide.u32  %rd8, %r7, 4;\n\
250        add.u64       %rd9, %rd1, %rd8;\n\
251        ld.global.u32 %r8, [%rd9];\n\
252    \n\
253        // r_in = rank_in[u]\n\
254        mul.wide.u32  %rd10, %r8, 4;\n\
255        add.u64       %rd11, %rd3, %rd10;\n\
256        ld.global.f32 %f6, [%rd11];\n\
257    \n\
258        // out_d = out_degree[u]\n\
259        add.u64       %rd12, %rd2, %rd10;\n\
260        ld.global.u32 %r9, [%rd12];\n\
261        cvt.rn.f32.u32 %f7, %r9;\n\
262    \n\
263        div.rn.f32    %f8, %f6, %f7;\n\
264        add.f32       %f5, %f5, %f8;\n\
265    \n\
266        add.u32       %r7, %r7, 1;\n\
267        bra $PR_LOOP;\n\
268    \n\
269    $PR_WRITE:\n\
270        mul.f32       %f9, %f0, %f5;\n\
271        add.f32       %f9, %f9, %f4;\n\
272        mul.wide.u32  %rd13, %r4, 4;\n\
273        add.u64       %rd14, %rd4, %rd13;\n\
274        st.global.f32 [%rd14], %f9;\n\
275    \n\
276    $PR_DONE:\n\
277        ret;\n\
278    }\n";
279    hdr + body
280}
281
282/// Floyd-Warshall inner DP update.
283///
284/// Signature: `fw_inner_kernel(dist, n, k)`
285/// `dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j])`.
286#[must_use]
287pub fn fw_inner_ptx(sm: u32) -> String {
288    let hdr = ptx_header(sm);
289    let body = ".visible .entry fw_inner_kernel(\n\
290        .param .u64 p_dist,\n\
291        .param .u32 p_n,\n\
292        .param .u32 p_k\n\
293    )\n\
294    {\n\
295        .reg .u64  %rd<12>;\n\
296        .reg .u32  %r<24>;\n\
297        .reg .f32  %f<8>;\n\
298        .reg .pred %p0;\n\
299    \n\
300        ld.param.u64  %rd0, [p_dist];\n\
301        ld.param.u32  %r0,  [p_n];\n\
302        ld.param.u32  %r1,  [p_k];\n\
303    \n\
304        // i = blockIdx.y * blockDim.y + threadIdx.y\n\
305        mov.u32       %r2, %ntid.y;\n\
306        mov.u32       %r3, %ctaid.y;\n\
307        mov.u32       %r4, %tid.y;\n\
308        mad.lo.u32    %r5, %r2, %r3, %r4;\n\
309    \n\
310        // j = blockIdx.x * blockDim.x + threadIdx.x\n\
311        mov.u32       %r6, %ntid.x;\n\
312        mov.u32       %r7, %ctaid.x;\n\
313        mov.u32       %r8, %tid.x;\n\
314        mad.lo.u32    %r9, %r6, %r7, %r8;\n\
315    \n\
316        setp.ge.u32   %p0, %r5, %r0;\n\
317        @%p0 bra $FW_DONE;\n\
318        setp.ge.u32   %p0, %r9, %r0;\n\
319        @%p0 bra $FW_DONE;\n\
320    \n\
321        // d_ij = dist[i*n + j]\n\
322        mul.lo.u32    %r10, %r5, %r0;\n\
323        add.u32       %r10, %r10, %r9;\n\
324        mul.wide.u32  %rd2, %r10, 4;\n\
325        add.u64       %rd3, %rd0, %rd2;\n\
326        ld.global.f32 %f0, [%rd3];\n\
327    \n\
328        // d_ik = dist[i*n + k]\n\
329        mul.lo.u32    %r11, %r5, %r0;\n\
330        add.u32       %r11, %r11, %r1;\n\
331        mul.wide.u32  %rd4, %r11, 4;\n\
332        add.u64       %rd5, %rd0, %rd4;\n\
333        ld.global.f32 %f1, [%rd5];\n\
334    \n\
335        // d_kj = dist[k*n + j]\n\
336        mul.lo.u32    %r12, %r1, %r0;\n\
337        add.u32       %r12, %r12, %r9;\n\
338        mul.wide.u32  %rd6, %r12, 4;\n\
339        add.u64       %rd7, %rd0, %rd6;\n\
340        ld.global.f32 %f2, [%rd7];\n\
341    \n\
342        // candidate = d_ik + d_kj\n\
343        add.f32       %f3, %f1, %f2;\n\
344        min.f32       %f4, %f0, %f3;\n\
345        st.global.f32 [%rd3], %f4;\n\
346    \n\
347    $FW_DONE:\n\
348        ret;\n\
349    }\n";
350    hdr + body
351}
352
353/// Triangle counting per row.
354///
355/// Signature: `triangle_count_kernel(row_ptr, col_idx, count, n)`
356/// For each vertex `u`, count triples `(u, v, w)` with `u<v<w` and all three edges present.
357#[must_use]
358pub fn triangle_count_ptx(sm: u32) -> String {
359    let hdr = ptx_header(sm);
360    let body = ".visible .entry triangle_count_kernel(\n\
361        .param .u64 p_row_ptr,\n\
362        .param .u64 p_col_idx,\n\
363        .param .u64 p_count,\n\
364        .param .u32 p_n\n\
365    )\n\
366    {\n\
367        .reg .u64  %rd<16>;\n\
368        .reg .u32  %r<32>;\n\
369        .reg .pred %p0;\n\
370    \n\
371        ld.param.u64  %rd0, [p_row_ptr];\n\
372        ld.param.u64  %rd1, [p_col_idx];\n\
373        ld.param.u64  %rd2, [p_count];\n\
374        ld.param.u32  %r0,  [p_n];\n\
375    \n\
376        mov.u32       %r1, %ntid.x;\n\
377        mov.u32       %r2, %ctaid.x;\n\
378        mov.u32       %r3, %tid.x;\n\
379        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
380    \n\
381        setp.ge.u32   %p0, %r4, %r0;\n\
382        @%p0 bra $TC_DONE;\n\
383    \n\
384        // row_ptr[u], row_ptr[u+1]\n\
385        mul.wide.u32  %rd3, %r4, 4;\n\
386        add.u64       %rd4, %rd0, %rd3;\n\
387        ld.global.u32 %r5, [%rd4];\n\
388        add.u64       %rd5, %rd4, 4;\n\
389        ld.global.u32 %r6, [%rd5];\n\
390    \n\
391        mov.u32       %r7, 0;\n\
392        mov.u32       %r8, %r5;\n\
393    \n\
394    $TC_OUTER:\n\
395        setp.ge.u32   %p0, %r8, %r6;\n\
396        @%p0 bra $TC_WRITE;\n\
397    \n\
398        // v = col_idx[j]\n\
399        mul.wide.u32  %rd6, %r8, 4;\n\
400        add.u64       %rd7, %rd1, %rd6;\n\
401        ld.global.u32 %r9, [%rd7];\n\
402    \n\
403        // only u < v\n\
404        setp.le.u32   %p0, %r9, %r4;\n\
405        @%p0 bra $TC_OUTER_END;\n\
406    \n\
407        add.u32       %r10, %r8, 1;\n\
408    $TC_INNER:\n\
409        setp.ge.u32   %p0, %r10, %r6;\n\
410        @%p0 bra $TC_OUTER_END;\n\
411    \n\
412        // w = col_idx[k]\n\
413        mul.wide.u32  %rd8, %r10, 4;\n\
414        add.u64       %rd9, %rd1, %rd8;\n\
415        ld.global.u32 %r11, [%rd9];\n\
416    \n\
417        setp.le.u32   %p0, %r11, %r9;\n\
418        @%p0 bra $TC_INNER_END;\n\
419    \n\
420        // check v-w edge: scan row v for w\n\
421        mul.wide.u32  %rd10, %r9, 4;\n\
422        add.u64       %rd11, %rd0, %rd10;\n\
423        ld.global.u32 %r12, [%rd11];\n\
424        add.u64       %rd12, %rd11, 4;\n\
425        ld.global.u32 %r13, [%rd12];\n\
426    \n\
427        mov.u32       %r14, %r12;\n\
428    $TC_SCAN:\n\
429        setp.ge.u32   %p0, %r14, %r13;\n\
430        @%p0 bra $TC_INNER_END;\n\
431        mul.wide.u32  %rd13, %r14, 4;\n\
432        add.u64       %rd14, %rd1, %rd13;\n\
433        ld.global.u32 %r15, [%rd14];\n\
434        setp.eq.u32   %p0, %r15, %r11;\n\
435        @%p0 bra $TC_HIT;\n\
436        add.u32       %r14, %r14, 1;\n\
437        bra $TC_SCAN;\n\
438    \n\
439    $TC_HIT:\n\
440        add.u32       %r7, %r7, 1;\n\
441    \n\
442    $TC_INNER_END:\n\
443        add.u32       %r10, %r10, 1;\n\
444        bra $TC_INNER;\n\
445    \n\
446    $TC_OUTER_END:\n\
447        add.u32       %r8, %r8, 1;\n\
448        bra $TC_OUTER;\n\
449    \n\
450    $TC_WRITE:\n\
451        mul.wide.u32  %rd15, %r4, 4;\n\
452        add.u64       %rd5, %rd2, %rd15;\n\
453        st.global.u32 [%rd5], %r7;\n\
454    \n\
455    $TC_DONE:\n\
456        ret;\n\
457    }\n";
458    hdr + body
459}
460
461/// Boolean CSR sparse mat-vec (for matrix-form BFS).
462///
463/// Signature: `csr_spmv_bool_kernel(row_ptr, col_idx, x, y, n)`
464/// `y[i] = OR over j in row_ptr[i] .. row_ptr[i+1] of x[col_idx[j]]`.
465#[must_use]
466pub fn csr_spmv_bool_ptx(sm: u32) -> String {
467    let hdr = ptx_header(sm);
468    let body = ".visible .entry csr_spmv_bool_kernel(\n\
469        .param .u64 p_row_ptr,\n\
470        .param .u64 p_col_idx,\n\
471        .param .u64 p_x,\n\
472        .param .u64 p_y,\n\
473        .param .u32 p_n\n\
474    )\n\
475    {\n\
476        .reg .u64  %rd<12>;\n\
477        .reg .u32  %r<16>;\n\
478        .reg .pred %p0;\n\
479    \n\
480        ld.param.u64  %rd0, [p_row_ptr];\n\
481        ld.param.u64  %rd1, [p_col_idx];\n\
482        ld.param.u64  %rd2, [p_x];\n\
483        ld.param.u64  %rd3, [p_y];\n\
484        ld.param.u32  %r0,  [p_n];\n\
485    \n\
486        mov.u32       %r1, %ntid.x;\n\
487        mov.u32       %r2, %ctaid.x;\n\
488        mov.u32       %r3, %tid.x;\n\
489        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
490    \n\
491        setp.ge.u32   %p0, %r4, %r0;\n\
492        @%p0 bra $SP_DONE;\n\
493    \n\
494        // start, end = row_ptr[i], row_ptr[i+1]\n\
495        mul.wide.u32  %rd4, %r4, 4;\n\
496        add.u64       %rd5, %rd0, %rd4;\n\
497        ld.global.u32 %r5, [%rd5];\n\
498        add.u64       %rd6, %rd5, 4;\n\
499        ld.global.u32 %r6, [%rd6];\n\
500    \n\
501        mov.u32       %r7, 0;\n\
502        mov.u32       %r8, %r5;\n\
503    \n\
504    $SP_LOOP:\n\
505        setp.ge.u32   %p0, %r8, %r6;\n\
506        @%p0 bra $SP_WRITE;\n\
507    \n\
508        mul.wide.u32  %rd7, %r8, 4;\n\
509        add.u64       %rd8, %rd1, %rd7;\n\
510        ld.global.u32 %r9, [%rd8];\n\
511        mul.wide.u32  %rd9, %r9, 4;\n\
512        add.u64       %rd10, %rd2, %rd9;\n\
513        ld.global.u32 %r10, [%rd10];\n\
514        or.b32        %r7, %r7, %r10;\n\
515    \n\
516        add.u32       %r8, %r8, 1;\n\
517        bra $SP_LOOP;\n\
518    \n\
519    $SP_WRITE:\n\
520        add.u64       %rd11, %rd3, %rd4;\n\
521        st.global.u32 [%rd11], %r7;\n\
522    \n\
523    $SP_DONE:\n\
524        ret;\n\
525    }\n";
526    hdr + body
527}
528
529/// Label propagation single step.
530///
531/// Signature: `community_label_kernel(row_ptr, col_idx, label_in, label_out, n)`
532/// Update `label[u]` = most-frequent label among neighbors of u (ties broken by min label).
533#[must_use]
534pub fn community_label_ptx(sm: u32) -> String {
535    let hdr = ptx_header(sm);
536    let body = ".visible .entry community_label_kernel(\n\
537        .param .u64 p_row_ptr,\n\
538        .param .u64 p_col_idx,\n\
539        .param .u64 p_label_in,\n\
540        .param .u64 p_label_out,\n\
541        .param .u32 p_n\n\
542    )\n\
543    {\n\
544        .reg .u64  %rd<16>;\n\
545        .reg .u32  %r<24>;\n\
546        .reg .pred %p0;\n\
547    \n\
548        ld.param.u64  %rd0, [p_row_ptr];\n\
549        ld.param.u64  %rd1, [p_col_idx];\n\
550        ld.param.u64  %rd2, [p_label_in];\n\
551        ld.param.u64  %rd3, [p_label_out];\n\
552        ld.param.u32  %r0,  [p_n];\n\
553    \n\
554        mov.u32       %r1, %ntid.x;\n\
555        mov.u32       %r2, %ctaid.x;\n\
556        mov.u32       %r3, %tid.x;\n\
557        mad.lo.u32    %r4, %r1, %r2, %r3;\n\
558    \n\
559        setp.ge.u32   %p0, %r4, %r0;\n\
560        @%p0 bra $LP_DONE;\n\
561    \n\
562        // start = row_ptr[u], end = row_ptr[u+1]\n\
563        mul.wide.u32  %rd4, %r4, 4;\n\
564        add.u64       %rd5, %rd0, %rd4;\n\
565        ld.global.u32 %r5, [%rd5];\n\
566        add.u64       %rd6, %rd5, 4;\n\
567        ld.global.u32 %r6, [%rd6];\n\
568    \n\
569        // default: keep own label\n\
570        add.u64       %rd7, %rd2, %rd4;\n\
571        ld.global.u32 %r7, [%rd7];\n\
572        mov.u32       %r8, %r7;\n\
573    \n\
574        // simple min-label rule for ties (deterministic; aggregation done CPU-side)\n\
575        mov.u32       %r9, %r5;\n\
576    $LP_LOOP:\n\
577        setp.ge.u32   %p0, %r9, %r6;\n\
578        @%p0 bra $LP_WRITE;\n\
579    \n\
580        mul.wide.u32  %rd8, %r9, 4;\n\
581        add.u64       %rd9, %rd1, %rd8;\n\
582        ld.global.u32 %r10, [%rd9];\n\
583        mul.wide.u32  %rd10, %r10, 4;\n\
584        add.u64       %rd11, %rd2, %rd10;\n\
585        ld.global.u32 %r11, [%rd11];\n\
586    \n\
587        setp.ge.u32   %p0, %r11, %r8;\n\
588        @%p0 bra $LP_NEXT;\n\
589        mov.u32       %r8, %r11;\n\
590    \n\
591    $LP_NEXT:\n\
592        add.u32       %r9, %r9, 1;\n\
593        bra $LP_LOOP;\n\
594    \n\
595    $LP_WRITE:\n\
596        add.u64       %rd12, %rd3, %rd4;\n\
597        st.global.u32 [%rd12], %r8;\n\
598    \n\
599    $LP_DONE:\n\
600        ret;\n\
601    }\n";
602    hdr + body
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    type KernelFn = fn(u32) -> String;

    /// SM versions exercised by the per-kernel tests.
    const SMS: [u32; 6] = [75, 80, 86, 89, 90, 100];

    /// Every public kernel generator, keyed by the prefix of its `.entry` name.
    fn all_kernels() -> Vec<(&'static str, KernelFn)> {
        vec![
            ("bfs_level", bfs_level_ptx),
            ("dijkstra_relax", dijkstra_relax_ptx),
            ("pagerank_step", pagerank_step_ptx),
            ("fw_inner", fw_inner_ptx),
            ("triangle_count", triangle_count_ptx),
            ("csr_spmv_bool", csr_spmv_bool_ptx),
            ("community_label", community_label_ptx),
        ]
    }

    #[test]
    fn ptx_header_versions() {
        assert!(ptx_header(75).contains("7.5"));
        assert!(ptx_header(80).contains("8.0"));
        assert!(ptx_header(86).contains("8.0"));
        assert!(ptx_header(89).contains("8.0"));
        assert!(ptx_header(90).contains("8.4"));
        assert!(ptx_header(100).contains("8.7"));
        assert!(ptx_header(80).ends_with(".address_size 64\n\n"));
    }

    #[test]
    fn ptx_all_kernels_non_empty_all_sm() {
        for sm in SMS {
            for (name, f) in all_kernels() {
                let s = f(sm);
                assert!(!s.is_empty(), "kernel {name} sm={sm} produced empty string");
                assert!(
                    s.starts_with(&ptx_header(sm)),
                    "kernel {name} sm={sm} missing PTX header"
                );
                assert!(
                    s.contains(&format!(".visible .entry {name}_kernel(")),
                    "kernel {name} sm={sm} missing entry"
                );
                assert!(s.contains("ret;"), "kernel {name} sm={sm} missing ret");
            }
        }
    }

    #[test]
    fn ptx_braces_balanced() {
        for (name, f) in all_kernels() {
            let s = f(80);
            let open = s.matches('{').count();
            let close = s.matches('}').count();
            assert!(open > 0, "kernel {name} has no body");
            assert_eq!(open, close, "kernel {name} has unbalanced braces");
        }
    }

    #[test]
    fn ptx_branch_targets_defined() {
        // Every `bra $LABEL;` must have a matching `$LABEL:` definition in the same kernel.
        for (name, f) in all_kernels() {
            let s = f(80);
            for line in s.lines() {
                let Some(pos) = line.find("bra ") else { continue };
                let target = line[pos + 4..].trim().trim_end_matches(';');
                assert!(
                    s.contains(&format!("\n{target}:")),
                    "kernel {name} branches to undefined label {target}"
                );
            }
        }
    }

    #[test]
    fn ptx_target_matches_sm() {
        for sm in SMS {
            for (name, f) in all_kernels() {
                assert!(
                    f(sm).contains(&format!(".target sm_{sm}\n")),
                    "kernel {name} sm={sm} has wrong .target"
                );
            }
        }
    }
}