1#[cfg(feature = "gpu")]
14use trueno::backends::gpu::shaders::backward::{
15 CROSS_ENTROPY_BACKWARD_SHADER, CROSS_ENTROPY_FORWARD_SHADER,
16};
17#[cfg(feature = "gpu")]
18use trueno::backends::gpu::wgpu;
19
/// GPU cross-entropy runner built on two WGSL compute pipelines
/// (one forward, one backward).
///
/// The struct owns the device/queue handles it was constructed with, plus the
/// compute pipelines and the bind-group layouts used to build per-call bind
/// groups for each pass.
#[cfg(feature = "gpu")]
pub struct WgslCrossEntropy {
    /// Device used to create buffers, bind groups and command encoders.
    device: wgpu::Device,
    /// Queue that uniform uploads and command buffers are submitted to.
    queue: wgpu::Queue,
    /// Pipeline compiled from `CROSS_ENTROPY_FORWARD_SHADER`.
    forward_pipeline: wgpu::ComputePipeline,
    /// Pipeline compiled from `CROSS_ENTROPY_BACKWARD_SHADER`.
    backward_pipeline: wgpu::ComputePipeline,
    /// Layout of the forward bind group (bindings 0..=4).
    forward_bgl: wgpu::BindGroupLayout,
    /// Layout of the backward bind group (bindings 0..=3).
    backward_bgl: wgpu::BindGroupLayout,
}
30
/// Uniform block uploaded for the forward pass (four `u32`s, 16 bytes).
///
/// `#[repr(C)]` keeps the field order fixed so the bytes can be handed to the
/// GPU verbatim via `bytemuck::bytes_of`. Gated on the `gpu` feature because
/// its only user is the feature-gated `WgslCrossEntropy` impl; without the
/// gate the type is dead code when the feature is disabled.
#[cfg(feature = "gpu")]
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CEForwardParams {
    /// Number of sequence positions (the forward pass dispatches one workgroup per position).
    seq_len: u32,
    /// Vocabulary size. NOTE(review): presumably the number of logits per position — confirm
    /// against the shader.
    vocab_size: u32,
    /// Start of the loss window. NOTE(review): presumably inclusive — confirm against the shader.
    loss_start: u32,
    /// End of the loss window. NOTE(review): presumably exclusive — confirm against the shader.
    loss_end: u32,
}
39
/// Uniform block uploaded for the backward pass (eight 4-byte fields, 32 bytes).
///
/// `#[repr(C)]` keeps the field order fixed so the bytes can be handed to the
/// GPU verbatim via `bytemuck::bytes_of`. Gated on the `gpu` feature because
/// its only user is the feature-gated `WgslCrossEntropy` impl; without the
/// gate the type is dead code when the feature is disabled.
#[cfg(feature = "gpu")]
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CEBackwardParams {
    /// Number of sequence positions.
    seq_len: u32,
    /// Vocabulary size. NOTE(review): presumably the number of logits per position — confirm
    /// against the shader.
    vocab_size: u32,
    /// Start of the loss window. NOTE(review): presumably inclusive — confirm against the shader.
    loss_start: u32,
    /// End of the loss window. NOTE(review): presumably exclusive — confirm against the shader.
    loss_end: u32,
    /// Gradient scale; the host sets it to `1 / max(loss_end - loss_start, 1)`.
    scale: f32,
    /// Explicit padding (always written as 0) that brings the struct to 32 bytes.
    /// NOTE(review): presumably mirrors padding in the shader-side struct — confirm.
    _pad0: u32,
    /// Explicit padding, see `_pad0`.
    _pad1: u32,
    /// Explicit padding, see `_pad0`.
    _pad2: u32,
}
52
#[cfg(feature = "gpu")]
impl WgslCrossEntropy {
    /// Number of elements covered by one backward workgroup.
    /// NOTE(review): presumably equals `@workgroup_size` in the backward shader — confirm.
    const BACKWARD_WORKGROUP_SIZE: u32 = 256;

    /// Largest workgroup count dispatched along one dimension; this is WebGPU's
    /// default `maxComputeWorkgroupsPerDimension`.
    const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;

    /// Size in bytes of one entry of the `losses` buffer (`f32`).
    const LOSS_ELEM_BYTES: u64 = 4;

    /// Compiles both shaders and builds the pipelines and bind-group layouts.
    ///
    /// Forward bindings: 0 and 1 read-only storage, 2 and 3 read-write storage,
    /// 4 uniform. Backward bindings: 0 read-write storage, 1 and 2 read-only
    /// storage, 3 uniform.
    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
        let storage_ro = wgpu::BufferBindingType::Storage { read_only: true };
        let storage_rw = wgpu::BufferBindingType::Storage { read_only: false };
        let uniform = wgpu::BufferBindingType::Uniform;

        let forward_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("ce_fwd_bgl"),
            entries: &[
                Self::buffer_entry(0, storage_ro),
                Self::buffer_entry(1, storage_ro),
                Self::buffer_entry(2, storage_rw),
                Self::buffer_entry(3, storage_rw),
                Self::buffer_entry(4, uniform),
            ],
        });
        let fwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("ce_forward"),
            source: wgpu::ShaderSource::Wgsl(CROSS_ENTROPY_FORWARD_SHADER.into()),
        });
        let forward_pipeline =
            Self::compute_pipeline(&device, &forward_bgl, &fwd_shader, "ce_fwd_pl", "ce_fwd_pipe");

        let backward_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("ce_bwd_bgl"),
            entries: &[
                Self::buffer_entry(0, storage_rw),
                Self::buffer_entry(1, storage_ro),
                Self::buffer_entry(2, storage_ro),
                Self::buffer_entry(3, uniform),
            ],
        });
        let bwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("ce_backward"),
            source: wgpu::ShaderSource::Wgsl(CROSS_ENTROPY_BACKWARD_SHADER.into()),
        });
        let backward_pipeline =
            Self::compute_pipeline(&device, &backward_bgl, &bwd_shader, "ce_bwd_pl", "ce_bwd_pipe");

        Self { device, queue, forward_pipeline, backward_pipeline, forward_bgl, backward_bgl }
    }

    /// Encodes and submits the forward pass without waiting for the GPU.
    ///
    /// Binds `logits` and `labels` read-only, `losses` and `logsumexp`
    /// read-write, and dispatches one workgroup per sequence position.
    /// Use [`Self::read_loss`] afterwards to fetch the averaged loss.
    ///
    /// NOTE(review): `seq_len` is used directly as the X workgroup count, so
    /// values above the device's per-dimension workgroup limit (65535 by
    /// default in WebGPU) will presumably fail validation — confirm whether
    /// callers can reach that.
    // The argument list mirrors the shader's buffers plus its uniform fields.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_async(
        &self,
        logits: &wgpu::Buffer,
        labels: &wgpu::Buffer,
        losses: &wgpu::Buffer,
        logsumexp: &wgpu::Buffer,
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) {
        let params = CEForwardParams { seq_len, vocab_size, loss_start, loss_end };
        let params_buf = self.make_uniform(&params);
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.forward_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: logits.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: labels.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: losses.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: logsumexp.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
            ],
        });
        self.submit_pass(&self.forward_pipeline, &bind_group, seq_len, 1);
    }

    /// Blocks until the GPU is idle, reads back the first `seq_len` `f32`
    /// entries of `losses`, and returns their sum divided by
    /// `loss_end - loss_start`.
    ///
    /// Returns `0.0` when the loss window is empty or inverted, or when
    /// `seq_len` is zero (nothing to read back).
    ///
    /// NOTE(review): every one of the `seq_len` entries is summed, not only
    /// those inside the window, so this presumably relies on the shader
    /// writing zero outside `[loss_start, loss_end)` — confirm.
    ///
    /// # Panics
    ///
    /// Panics if the staging buffer cannot be mapped for reading.
    pub fn read_loss(
        &self,
        losses: &wgpu::Buffer,
        seq_len: u32,
        loss_start: u32,
        loss_end: u32,
    ) -> f32 {
        self.wait_idle();
        if seq_len == 0 {
            // A zero-byte staging buffer cannot be usefully mapped; the sum is 0 anyway.
            return 0.0;
        }

        let byte_len = u64::from(seq_len) * Self::LOSS_ELEM_BYTES;
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ce_loss_readback"),
            size: byte_len,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(losses, 0, &staging, 0, byte_len);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            // The receiver only disappears if this function already panicked.
            tx.send(result).ok();
        });
        // The map callback fires while the device is being polled.
        self.wait_idle();
        rx.recv()
            .expect("map_async callback was dropped without reporting a result")
            .expect("failed to map the loss staging buffer for reading");

        let total: f32 = {
            // Scope the mapped view so it is released before `unmap`.
            let view = slice.get_mapped_range();
            let per_position: &[f32] = bytemuck::cast_slice(&view);
            per_position.iter().sum()
        };
        staging.unmap();

        // `saturating_sub` treats an inverted window as empty instead of underflowing.
        let num_tokens = loss_end.saturating_sub(loss_start);
        if num_tokens > 0 {
            total / num_tokens as f32
        } else {
            0.0
        }
    }

    /// Runs the forward pass and synchronously returns the averaged loss.
    ///
    /// Equivalent to [`Self::forward_async`] followed by [`Self::read_loss`].
    // The argument list mirrors the shader's buffers plus its uniform fields.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        logits: &wgpu::Buffer,
        labels: &wgpu::Buffer,
        losses: &wgpu::Buffer,
        logsumexp: &wgpu::Buffer,
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) -> f32 {
        self.forward_async(
            logits, labels, losses, logsumexp, seq_len, vocab_size, loss_start, loss_end,
        );
        self.read_loss(losses, seq_len, loss_start, loss_end)
    }

    /// Encodes and submits the backward pass without waiting for the GPU.
    ///
    /// `logits` is bound read-write at binding 0, `labels` and `logsumexp`
    /// read-only. The uniform carries `scale = 1 / max(loss_end - loss_start, 1)`.
    ///
    /// NOTE(review): since `logits` is the only writable binding, the gradient
    /// is presumably written in place over the logits — confirm against the shader.
    ///
    /// The dispatch covers `seq_len * vocab_size` elements, 256 per workgroup.
    /// Up to 65535 workgroups go on X alone; beyond that X is pinned to 65535
    /// and the remainder spills into Y.
    ///
    /// # Panics
    ///
    /// Panics if `seq_len * vocab_size` does not fit in a `u32`; a wrapped
    /// product would silently dispatch too few workgroups.
    // The argument list mirrors the shader's buffers plus its uniform fields.
    #[allow(clippy::too_many_arguments)]
    pub fn backward(
        &self,
        logits: &wgpu::Buffer,
        labels: &wgpu::Buffer,
        logsumexp: &wgpu::Buffer,
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) {
        // An inverted or empty window is clamped to one token so `scale` stays finite.
        let num_tokens = loss_end.saturating_sub(loss_start).max(1);
        let scale = 1.0 / num_tokens as f32;

        let params = CEBackwardParams {
            seq_len,
            vocab_size,
            loss_start,
            loss_end,
            scale,
            _pad0: 0,
            _pad1: 0,
            _pad2: 0,
        };
        let params_buf = self.make_uniform(&params);

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.backward_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: logits.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: labels.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: logsumexp.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });

        let total = seq_len
            .checked_mul(vocab_size)
            .expect("seq_len * vocab_size overflows u32; cannot size the backward dispatch");
        let workgroups = total.div_ceil(Self::BACKWARD_WORKGROUP_SIZE);
        let (x, y) = if workgroups <= Self::MAX_WORKGROUPS_PER_DIM {
            (workgroups, 1)
        } else {
            (Self::MAX_WORKGROUPS_PER_DIM, workgroups.div_ceil(Self::MAX_WORKGROUPS_PER_DIM))
        };
        self.submit_pass(&self.backward_pipeline, &bind_group, x, y);
    }

    /// Creates a uniform buffer exactly `size_of::<T>()` bytes long and queues
    /// a write of `data` into it.
    fn make_uniform<T: bytemuck::Pod>(&self, data: &T) -> wgpu::Buffer {
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: std::mem::size_of::<T>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buf, 0, bytemuck::bytes_of(data));
        buf
    }

    /// Builds a compute-stage buffer binding entry of the given kind, with no
    /// dynamic offset and no minimum binding size.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer { ty, has_dynamic_offset: false, min_binding_size: None },
            count: None,
        }
    }

    /// Creates a single-bind-group pipeline layout and the compute pipeline
    /// that runs `shader`'s `main` entry point with it.
    fn compute_pipeline(
        device: &wgpu::Device,
        bgl: &wgpu::BindGroupLayout,
        shader: &wgpu::ShaderModule,
        layout_label: &str,
        pipeline_label: &str,
    ) -> wgpu::ComputePipeline {
        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some(layout_label),
            bind_group_layouts: &[bgl],
            push_constant_ranges: &[],
        });
        device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(pipeline_label),
            layout: Some(&layout),
            module: shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        })
    }

    /// Records one compute pass dispatching `x * y * 1` workgroups with
    /// `bind_group` at index 0, then submits it to the queue.
    fn submit_pass(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bind_group: &wgpu::BindGroup,
        x: u32,
        y: u32,
    ) {
        let mut encoder = self.device.create_command_encoder(&Default::default());
        {
            // The pass borrows the encoder; it must end before `finish`.
            let mut pass = encoder.begin_compute_pass(&Default::default());
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, bind_group, &[]);
            pass.dispatch_workgroups(x, y, 1);
        }
        self.queue.submit(Some(encoder.finish()));
    }

    /// Blocks until the device reports all submitted work as complete.
    /// A poll error is deliberately ignored (best effort).
    fn wait_idle(&self) {
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
    }
}
303
#[cfg(test)]
#[cfg(feature = "gpu")]
mod tests {
    use super::*;

    /// Checks that the GPU forward pass agrees with a straightforward CPU
    /// log-sum-exp cross-entropy on a tiny 4x8 problem.
    ///
    /// The test returns early (effectively a skip) when no adapter or device
    /// can be obtained.
    #[test]
    fn test_fused_ce_matches_naive() {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let Ok(adapter) = trueno::backends::gpu::runtime::block_on(
            instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
        ) else {
            return;
        };
        let Ok((device, queue)) = trueno::backends::gpu::runtime::block_on(
            adapter.request_device(&wgpu::DeviceDescriptor::default()),
        ) else {
            return;
        };

        let ce = WgslCrossEntropy::new(device.clone(), queue.clone());

        let seq_len = 4u32;
        let vocab = 8u32;

        let logits_data: Vec<f32> =
            (0..seq_len * vocab).map(|i| ((i as f32) * 0.3).sin()).collect();
        let labels_data: Vec<u32> = vec![2, 5, 1, 7];

        // Every buffer gets the same usage flags; STORAGE covers both the
        // read-only and the read-write bindings.
        let upload = |data: &[u8]| -> wgpu::Buffer {
            let buffer = device.create_buffer(&wgpu::BufferDescriptor {
                label: None,
                size: data.len() as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            queue.write_buffer(&buffer, 0, data);
            buffer
        };

        let zeroed = vec![0u8; seq_len as usize * 4];
        let logits = upload(bytemuck::cast_slice(&logits_data));
        let labels = upload(bytemuck::cast_slice(&labels_data));
        let losses = upload(&zeroed);
        let logsumexp_buf = upload(&zeroed);

        let gpu_loss =
            ce.forward(&logits, &labels, &losses, &logsumexp_buf, seq_len, vocab, 0, seq_len);

        // CPU reference: per-row numerically stable log-sum-exp minus the label logit.
        let cpu_total: f32 = logits_data
            .chunks_exact(vocab as usize)
            .zip(&labels_data)
            .map(|(row, &label)| {
                let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let sum_exp: f32 = row.iter().map(|x| (x - max_val).exp()).sum();
                let lse = max_val + sum_exp.ln();
                lse - row[label as usize]
            })
            .sum();
        let cpu_loss = cpu_total / seq_len as f32;

        let err = (gpu_loss - cpu_loss).abs();
        eprintln!("[PARITY] Fused CE: gpu={gpu_loss:.6}, cpu={cpu_loss:.6}, err={err:.6}");
        assert!(err < 1e-4, "Fused CE parity failed: err={err}");
    }
}