/// Returns the PTX source of the training-mode BatchNorm kernel `batchnorm_kernel`.
///
/// The kernel treats `input` as a row-major `[batch_size, channels]` f32 matrix and
/// uses one thread block per channel (`%ctaid.x` selects the channel). For its
/// channel, each block:
///
/// 1. sums the column and derives the batch mean,
/// 2. sums squared deviations and derives the (biased, divide-by-N) batch variance,
/// 3. lets thread 0 blend mean/variance into `running_mean` / `running_var` as
///    `(1 - momentum) * old + momentum * batch_stat`,
/// 4. writes `gamma * (x - mean) * rsqrt(var + eps) + beta` to `output`.
///
/// Launch contract: `grid.x == channels`, `block.x == 256`. The loop stride and the
/// 8-warp shared-memory reduction are hard-coded for that block size.
///
/// A launch with `batch_size == 0` returns immediately and leaves the running
/// statistics untouched.
pub fn batchnorm_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// BatchNorm kernel (training): 1 block per channel, 256 threads per block.
// Each block computes mean/var for its channel, normalizes, and updates running stats.
// Input layout: [N, C] row-major, element (n, ch) = input[n * C + ch].
//
// Shared memory layout (32 floats):
//   smem[0..7]  per-warp partial sums (reused by the mean and variance reductions)
//   smem[8]     broadcast slot for the batch mean        (byte offset 32)
//   smem[9]     broadcast slot for rsqrt(var + eps)      (byte offset 36)
// The broadcast slots are kept apart from the partial-sum slots so that a fast
// warp starting the next reduction can never overwrite a value that a slower
// warp has not read yet.
.visible .entry batchnorm_kernel(
    .param .u64 input_ptr,
    .param .u64 gamma_ptr,
    .param .u64 beta_ptr,
    .param .u64 output_ptr,
    .param .u64 running_mean_ptr,
    .param .u64 running_var_ptr,
    .param .u32 batch_size,
    .param .u32 channels,
    .param .f32 eps,
    .param .f32 momentum
)
{
    .reg .u32 %thr_id, %ch, %n_batch, %n_ch, %i;
    .reg .u32 %lane, %warp_id, %mask;
    .reg .u32 %sbase, %wslot, %tslot;
    .reg .u64 %in_base, %g_base, %b_base, %out_base;
    .reg .u64 %rm_base, %rv_base, %addr, %off, %ch64, %ch_off;
    .reg .f32 %val, %diff, %sq;
    .reg .f32 %sum_local, %sum_warp, %batch_mean;
    .reg .f32 %var_local, %var_warp, %batch_var, %var_eps;
    .reg .f32 %inv_std, %eps, %momentum, %nf;
    .reg .f32 %gamma_val, %beta_val, %normed, %result;
    .reg .f32 %old_rm, %old_rv, %new_rm, %new_rv, %one_minus_m;
    .reg .pred %p, %is_lane0, %is_first8, %is_t0;
    .shared .f32 smem[32];

    ld.param.u64 %in_base, [input_ptr];
    ld.param.u64 %g_base, [gamma_ptr];
    ld.param.u64 %b_base, [beta_ptr];
    ld.param.u64 %out_base, [output_ptr];
    ld.param.u64 %rm_base, [running_mean_ptr];
    ld.param.u64 %rv_base, [running_var_ptr];
    ld.param.u32 %n_batch, [batch_size];
    ld.param.u32 %n_ch, [channels];
    ld.param.f32 %eps, [eps];
    ld.param.f32 %momentum, [momentum];

    // Empty batch: there is nothing to normalize, and 0/0 would write NaN into
    // the running statistics. The condition is uniform across the block, so
    // every thread leaves before reaching any barrier.
    setp.eq.u32 %p, %n_batch, 0;
    @%p ret;

    mov.u32 %thr_id, %tid.x;
    mov.u32 %ch, %ctaid.x;              // 1 block per channel
    mov.u32 %mask, 0xFFFFFFFF;

    // Loop-invariant thread roles.
    and.b32 %lane, %thr_id, 31;
    shr.u32 %warp_id, %thr_id, 5;
    setp.eq.u32 %is_lane0, %lane, 0;    // writes its warp's partial sum
    setp.lt.u32 %is_first8, %thr_id, 8; // reads one of the 8 partial sums
    setp.eq.u32 %is_t0, %thr_id, 0;     // owns the final reduced value

    // Channel index as a 64-bit element offset and as a byte offset.
    cvt.u64.u32 %ch64, %ch;
    shl.b64 %ch_off, %ch64, 2;

    // PTX address operands cannot scale a register, so the shared-memory slot
    // addresses are materialized in registers up front.
    mov.u32 %sbase, smem;
    mad.lo.u32 %wslot, %warp_id, 4, %sbase;   // &smem[warp_id]
    mad.lo.u32 %tslot, %thr_id, 4, %sbase;    // &smem[thr_id], used only when thr_id < 8

    // --- Pass 1: compute sum for mean ---
    mov.f32 %sum_local, 0f00000000;
    mov.u32 %i, %thr_id;
mean_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra mean_done;
    // off = (i * channels + ch) * 4, computed in 64 bits so large tensors do not wrap
    mad.wide.u32 %off, %i, %n_ch, %ch64;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in_base, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum_local, %sum_local, %val;
    add.u32 %i, %i, 256;
    bra mean_loop;
mean_done:

    // Warp-level sum reduction
    shfl.sync.down.b32 %sum_warp, %sum_local, 16, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 8, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 4, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 2, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 1, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;

    // Lane 0 of every warp publishes its partial sum.
    @%is_lane0 st.shared.f32 [%wslot], %sum_local;
    bar.sync 0;

    // Threads 0..7 fold the 8 partial sums; every other lane contributes zero.
    @%is_first8 ld.shared.f32 %sum_local, [%tslot];
    @!%is_first8 mov.f32 %sum_local, 0f00000000;
    shfl.sync.down.b32 %sum_warp, %sum_local, 4, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 2, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;
    shfl.sync.down.b32 %sum_warp, %sum_local, 1, 31, %mask;
    add.f32 %sum_local, %sum_local, %sum_warp;

    // mean = sum / N (only thread 0 holds the full sum; it broadcasts via smem[8])
    cvt.rn.f32.u32 %nf, %n_batch;
    div.approx.f32 %batch_mean, %sum_local, %nf;
    @%is_t0 st.shared.f32 [smem+32], %batch_mean;
    bar.sync 0;
    ld.shared.f32 %batch_mean, [smem+32];

    // --- Pass 2: compute variance ---
    mov.f32 %var_local, 0f00000000;
    mov.u32 %i, %thr_id;
var_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra var_done;
    mad.wide.u32 %off, %i, %n_ch, %ch64;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in_base, %off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %batch_mean;
    mul.f32 %sq, %diff, %diff;
    add.f32 %var_local, %var_local, %sq;
    add.u32 %i, %i, 256;
    bra var_loop;
var_done:

    // Warp-level variance reduction
    shfl.sync.down.b32 %var_warp, %var_local, 16, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 8, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 4, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 2, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 1, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;

    // smem[0..7] is free again: its last readers finished before the barrier
    // that preceded the mean broadcast.
    @%is_lane0 st.shared.f32 [%wslot], %var_local;
    bar.sync 0;

    @%is_first8 ld.shared.f32 %var_local, [%tslot];
    @!%is_first8 mov.f32 %var_local, 0f00000000;
    shfl.sync.down.b32 %var_warp, %var_local, 4, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 2, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;
    shfl.sync.down.b32 %var_warp, %var_local, 1, 31, %mask;
    add.f32 %var_local, %var_local, %var_warp;

    // var = var_sum / N, inv_std = rsqrt(var + eps).
    // var and var + eps live in separate registers so the running-variance
    // update below sees the exact batch variance.
    div.approx.f32 %batch_var, %var_local, %nf;
    add.f32 %var_eps, %batch_var, %eps;
    rsqrt.approx.f32 %inv_std, %var_eps;

    // Everything up to stats_done is thread-0 only. A guard predicate covers a
    // single instruction, so the section is skipped with a branch.
    @!%is_t0 bra stats_done;
    st.shared.f32 [smem+36], %inv_std;

    // running_mean = (1-m)*running_mean + m*batch_mean
    add.u64 %addr, %rm_base, %ch_off;
    ld.global.f32 %old_rm, [%addr];
    mov.f32 %one_minus_m, 0f3F800000;
    sub.f32 %one_minus_m, %one_minus_m, %momentum;
    mul.f32 %new_rm, %one_minus_m, %old_rm;
    fma.rn.f32 %new_rm, %momentum, %batch_mean, %new_rm;
    st.global.f32 [%addr], %new_rm;

    // running_var = (1-m)*running_var + m*batch_var (eps not included)
    add.u64 %addr, %rv_base, %ch_off;
    ld.global.f32 %old_rv, [%addr];
    mul.f32 %new_rv, %one_minus_m, %old_rv;
    fma.rn.f32 %new_rv, %momentum, %batch_var, %new_rv;
    st.global.f32 [%addr], %new_rv;
stats_done:

    bar.sync 0;
    ld.shared.f32 %inv_std, [smem+36];

    // Load gamma and beta for this channel
    add.u64 %addr, %g_base, %ch_off;
    ld.global.f32 %gamma_val, [%addr];
    add.u64 %addr, %b_base, %ch_off;
    ld.global.f32 %beta_val, [%addr];

    // --- Pass 3: normalize + affine ---
    mov.u32 %i, %thr_id;
norm_loop:
    setp.ge.u32 %p, %i, %n_batch;
    @%p bra norm_done;
    mad.wide.u32 %off, %i, %n_ch, %ch64;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in_base, %off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %batch_mean;
    mul.f32 %normed, %diff, %inv_std;
    fma.rn.f32 %result, %gamma_val, %normed, %beta_val;
    // The output uses the same layout as the input, so the byte offset is reused.
    add.u64 %addr, %out_base, %off;
    st.global.f32 [%addr], %result;
    add.u32 %i, %i, 256;
    bra norm_loop;
norm_done:
    ret;
}
"#
}