Skip to main content

provable_contracts/kernels/
batchnorm_ptx.rs

/// PTX assembly for the BatchNorm kernel (training mode).
///
/// Returns the complete PTX module text defining the entry point
/// `batchnorm_kernel`.
///
/// # Kernel contract
///
/// * Launch with one block per channel (`gridDim.x == channels`) and exactly
///   256 threads per block: the batch loops advance by a fixed stride of 256
///   and the cross-warp reduction combines exactly 8 warp partials.
/// * Input/output layout is `[N, C]` row-major `f32`; element `(n, ch)` lives
///   at index `n * C + ch`. Element offsets are computed in 64 bits.
/// * `gamma`, `beta`, `running_mean` and `running_var` are per-channel `f32`
///   arrays indexed by channel.
/// * Each block computes the batch mean and the population variance
///   (sum of squared deviations divided by `N`), writes
///   `gamma * (x - mean) * rsqrt(var + eps) + beta` to the output, and blends
///   the running statistics as `(1 - momentum) * old + momentum * batch`.
/// * A block whose index is `>= channels`, or any block when
///   `batch_size == 0`, returns without touching memory.
pub fn batchnorm_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// BatchNorm kernel (training): 1 block per channel, 256 threads per block.
// Each block computes mean/var for its channel, normalizes, and updates running stats.
// Input layout: [N, C] row-major, element (n, ch) = input[n * C + ch].
//
// Shared memory map (f32 slots of smem):
//   smem[0..7] : per-warp partial sums (reused by the mean and variance passes)
//   smem[8]    : broadcast slot for the batch mean      (byte offset 32)
//   smem[9]    : broadcast slot for rsqrt(var + eps)    (byte offset 36)
// The broadcast slots are separate from the partial-sum slots so that a warp
// starting the variance reduction can never overwrite the mean while a slower
// warp is still reading it.
.visible .entry batchnorm_kernel(
    .param .u64 input_ptr,
    .param .u64 gamma_ptr,
    .param .u64 beta_ptr,
    .param .u64 output_ptr,
    .param .u64 running_mean_ptr,
    .param .u64 running_var_ptr,
    .param .u32 batch_size,
    .param .u32 channels,
    .param .f32 eps,
    .param .f32 momentum
)
{
    // %thr holds %tid.x; it is deliberately not named %tid, which is a
    // predefined special register.
    .reg .u32 %thr, %ch, %n_batch, %n_ch, %i;
    .reg .u32 %lane, %warp_id, %mask;
    .reg .u64 %in_base, %g_base, %b_base, %out_base;
    .reg .u64 %rm_base, %rv_base, %addr;
    .reg .u64 %ch64, %ch_off, %elem_off;
    .reg .u64 %smem_base, %smem_off, %warp_slot, %thr_slot;
    .reg .f32 %val, %diff, %sq;
    .reg .f32 %sum_local, %sum_warp, %batch_mean;
    .reg .f32 %var_local, %var_warp, %batch_var, %var_eps;
    .reg .f32 %inv_std, %eps, %momentum, %nf;
    .reg .f32 %gamma_val, %beta_val, %normed, %result;
    .reg .f32 %old_rm, %old_rv, %new_rm, %new_rv, %one_minus_m;
    .reg .pred %p;
    .shared .align 4 .f32 smem[32];

    ld.param.u64 %in_base, [input_ptr];
    ld.param.u64 %g_base, [gamma_ptr];
    ld.param.u64 %b_base, [beta_ptr];
    ld.param.u64 %out_base, [output_ptr];
    ld.param.u64 %rm_base, [running_mean_ptr];
    ld.param.u64 %rv_base, [running_var_ptr];
    ld.param.u32 %n_batch, [batch_size];
    ld.param.u32 %n_ch, [channels];
    ld.param.f32 %eps, [eps];
    ld.param.f32 %momentum, [momentum];

    // Kernel pointer arguments are generic addresses; convert them once so
    // the ld.global / st.global accesses below use global-space addresses.
    cvta.to.global.u64 %in_base, %in_base;
    cvta.to.global.u64 %g_base, %g_base;
    cvta.to.global.u64 %b_base, %b_base;
    cvta.to.global.u64 %out_base, %out_base;
    cvta.to.global.u64 %rm_base, %rm_base;
    cvta.to.global.u64 %rv_base, %rv_base;

    mov.u32 %thr, %tid.x;
    mov.u32 %ch, %ctaid.x;  // 1 block per channel
    mov.u32 %mask, 0xFFFFFFFF;

    // Block-uniform early exits, taken before any barrier is reached:
    //   - block index beyond the channel count (over-sized grid)
    //   - empty batch (would divide by zero and store NaN running stats)
    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra kernel_exit;
    setp.eq.u32 %p, %n_batch, 0;
    @%p bra kernel_exit;

    // Values that stay constant for the rest of the kernel.
    and.b32 %lane, %thr, 31;
    shr.b32 %warp_id, %thr, 5;
    cvt.u64.u32 %ch64, %ch;
    shl.b64 %ch_off, %ch64, 2;          // byte offset of this channel in per-channel arrays
    cvt.rn.f32.u32 %nf, %n_batch;

    // PTX address operands cannot scale a register, so the shared-memory
    // slot addresses are computed explicitly.
    mov.u64 %smem_base, smem;
    mul.wide.u32 %smem_off, %warp_id, 4;
    add.u64 %warp_slot, %smem_base, %smem_off;  // &smem[warp_id]
    mul.wide.u32 %smem_off, %thr, 4;
    add.u64 %thr_slot, %smem_base, %smem_off;   // &smem[thr], only dereferenced when thr < 8

    // --- Pass 1: compute sum for mean ---
    mov.f32 %sum_local, 0f00000000;
    mov.u32 %i, %thr;
mean_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra mean_done;
    // byte offset = (i * channels + ch) * 4, computed in 64 bits
    mul.wide.u32 %elem_off, %i, %n_ch;
    add.u64 %elem_off, %elem_off, %ch64;
    shl.b64 %elem_off, %elem_off, 2;
    add.u64 %addr, %in_base, %elem_off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum_local, %sum_local, %val;
    add.u32 %i, %i, 256;
    bra mean_loop;
mean_done:

    // Warp-level sum reduction; lane 0 of each warp ends up with the warp total.
    shfl.sync.down.b32 %sum_warp, %sum_local, 16, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 8, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 4, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 2, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 1, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;

    setp.eq.u32 %p, %lane, 0;
    @%p st.shared.f32 [%warp_slot], %sum_local;
    bar.sync 0;

    // Threads 0..7 pick up the 8 warp partials; everyone else contributes 0.
    // Only thread 0's final value is meaningful.
    setp.lt.u32 %p, %thr, 8;
    mov.f32 %sum_local, 0f00000000;
    @%p ld.shared.f32 %sum_local, [%thr_slot];
    shfl.sync.down.b32 %sum_warp, %sum_local, 4, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 2, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 1, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;

    // mean = sum / N, broadcast from thread 0 through smem[8]
    div.approx.f32 %batch_mean, %sum_local, %nf;
    setp.eq.u32 %p, %thr, 0;
    @%p st.shared.f32 [smem+32], %batch_mean;
    bar.sync 0;
    ld.shared.f32 %batch_mean, [smem+32];

    // --- Pass 2: compute variance ---
    mov.f32 %var_local, 0f00000000;
    mov.u32 %i, %thr;
var_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra var_done;
    mul.wide.u32 %elem_off, %i, %n_ch;
    add.u64 %elem_off, %elem_off, %ch64;
    shl.b64 %elem_off, %elem_off, 2;
    add.u64 %addr, %in_base, %elem_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %batch_mean;
    mul.f32 %sq, %diff, %diff;
    add.f32 %var_local, %var_local, %sq;
    add.u32 %i, %i, 256;
    bra var_loop;
var_done:

    // Warp-level variance reduction
    shfl.sync.down.b32 %var_warp, %var_local, 16, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 8, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 4, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 2, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 1, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;

    // smem[0..7] is safe to reuse here: warp 0 finished reading the pass-1
    // partials before it reached the barrier that published the mean.
    setp.eq.u32 %p, %lane, 0;
    @%p st.shared.f32 [%warp_slot], %var_local;
    bar.sync 0;

    setp.lt.u32 %p, %thr, 8;
    mov.f32 %var_local, 0f00000000;
    @%p ld.shared.f32 %var_local, [%thr_slot];
    shfl.sync.down.b32 %var_warp, %var_local, 4, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 2, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 1, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;

    // var = var_sum / N, inv_std = rsqrt(var + eps).
    // %batch_var keeps the eps-free variance for the running-stat update.
    div.approx.f32 %batch_var, %var_local, %nf;
    add.f32 %var_eps, %batch_var, %eps;
    rsqrt.approx.f32 %inv_std, %var_eps;

    // Thread 0 only: publish inv_std and update the running statistics.
    // A guard predicate covers a single instruction, so the multi-instruction
    // section is skipped with a branch.
    setp.eq.u32 %p, %thr, 0;
    @!%p bra stats_done;
    st.shared.f32 [smem+36], %inv_std;

    // running_mean = (1-m)*running_mean + m*batch_mean
    add.u64 %addr, %rm_base, %ch_off;
    ld.global.f32 %old_rm, [%addr];
    mov.f32 %one_minus_m, 0f3F800000;
    sub.f32 %one_minus_m, %one_minus_m, %momentum;
    mul.f32 %new_rm, %one_minus_m, %old_rm;
    fma.rn.f32 %new_rm, %momentum, %batch_mean, %new_rm;
    st.global.f32 [%addr], %new_rm;

    // running_var = (1-m)*running_var + m*batch_var (variance without eps)
    add.u64 %addr, %rv_base, %ch_off;
    ld.global.f32 %old_rv, [%addr];
    mul.f32 %new_rv, %one_minus_m, %old_rv;
    fma.rn.f32 %new_rv, %momentum, %batch_var, %new_rv;
    st.global.f32 [%addr], %new_rv;
stats_done:

    bar.sync 0;
    ld.shared.f32 %inv_std, [smem+36];

    // Load gamma and beta for this channel
    add.u64 %addr, %g_base, %ch_off;
    ld.global.f32 %gamma_val, [%addr];
    add.u64 %addr, %b_base, %ch_off;
    ld.global.f32 %beta_val, [%addr];

    // --- Pass 3: normalize + affine ---
    mov.u32 %i, %thr;
norm_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra kernel_exit;
    mul.wide.u32 %elem_off, %i, %n_ch;
    add.u64 %elem_off, %elem_off, %ch64;
    shl.b64 %elem_off, %elem_off, 2;
    add.u64 %addr, %in_base, %elem_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %batch_mean;
    mul.f32 %normed, %diff, %inv_std;
    fma.rn.f32 %result, %gamma_val, %normed, %beta_val;
    add.u64 %addr, %out_base, %elem_off;
    st.global.f32 [%addr], %result;
    add.u32 %i, %i, 256;
    bra norm_loop;

kernel_exit:
    ret;
}
"#
}