Skip to main content

entrenar/autograd/
wgpu_cross_entropy.rs

1//! WgslCrossEntropy — fused cross-entropy loss on GPU via wgpu (§26 Step 0d.4)
2//!
3//! Computes causal LM loss without materializing the full softmax tensor.
4//! Forward: loss = -logits[label] + logsumexp(logits) per token
//! Backward: grad = (softmax(logits) - one_hot(label)) * scale, with scale = 1 / num_response_tokens, written IN-PLACE into logits
6//!
7//! Memory savings: only [seq_len] logsumexp scalars saved, not [seq_len × vocab] softmax.
//! For vocab=32000, seq=512 (f32): avoids storing 32000 × 512 × 4 B ≈ 62.5 MiB per forward pass.
9//!
10//! Contract: fused-cross-entropy-v1
11//! Zero unsafe, zero FFI.
12
13#[cfg(feature = "gpu")]
14use trueno::backends::gpu::shaders::backward::{
15    CROSS_ENTROPY_BACKWARD_SHADER, CROSS_ENTROPY_FORWARD_SHADER,
16};
17#[cfg(feature = "gpu")]
18use trueno::backends::gpu::wgpu;
19
/// Fused cross-entropy loss computation on GPU.
///
/// Owns the wgpu device/queue it was constructed with, plus the compiled
/// forward and backward compute pipelines and their bind group layouts.
/// Shader compilation happens once in [`WgslCrossEntropy::new`]; the
/// per-call methods only allocate small uniform/bind-group objects.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct WgslCrossEntropy {
    /// Device used to create buffers, bind groups and command encoders.
    device: wgpu::Device,
    /// Queue that all dispatches and copies are submitted to.
    queue: wgpu::Queue,
    /// Forward kernel: writes per-token losses and logsumexp.
    forward_pipeline: wgpu::ComputePipeline,
    /// Backward kernel: overwrites logits with the gradient.
    backward_pipeline: wgpu::ComputePipeline,
    /// Layout: logits(ro), labels(ro), losses(rw), logsumexp(rw), params(uniform).
    forward_bgl: wgpu::BindGroupLayout,
    /// Layout: logits(rw), labels(ro), logsumexp(ro), params(uniform).
    backward_bgl: wgpu::BindGroupLayout,
}
30
/// Uniform parameters for the CE forward kernel (binding 4 of the forward layout).
///
/// Only used by the gpu-gated implementation, so it is gated too; otherwise
/// non-gpu builds warn about an unused struct and still require the
/// `bytemuck` derives.
#[cfg(feature = "gpu")]
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CEForwardParams {
    seq_len: u32,
    vocab_size: u32,
    /// First token position included in the loss.
    loss_start: u32,
    /// One past the last token position included in the loss.
    loss_end: u32,
}

// Pin the host-side layout uploaded as the uniform: 4 × u32 = 16 bytes. Any
// change here must be mirrored in the WGSL struct (not visible in this file).
#[cfg(feature = "gpu")]
const _: () = assert!(std::mem::size_of::<CEForwardParams>() == 16);
39
/// Uniform parameters for the CE backward kernel (binding 3 of the backward layout).
///
/// Gated like the rest of the GPU code: it is only constructed by the
/// gpu-gated implementation.
#[cfg(feature = "gpu")]
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct CEBackwardParams {
    seq_len: u32,
    vocab_size: u32,
    /// First token position included in the loss.
    loss_start: u32,
    /// One past the last token position included in the loss.
    loss_end: u32,
    /// Gradient multiplier, 1 / num_response_tokens.
    scale: f32,
    // Explicit padding up to 32 bytes (a multiple of 16) for the uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}

// Pin the host-side layout uploaded as the uniform: 8 × 4 bytes = 32 bytes.
// Any change here must be mirrored in the WGSL struct (not visible in this file).
#[cfg(feature = "gpu")]
const _: () = assert!(std::mem::size_of::<CEBackwardParams>() == 32);
52
#[cfg(feature = "gpu")]
impl WgslCrossEntropy {
    /// Elements handled by one backward workgroup. The grid math assumes this
    /// matches the backward shader's `@workgroup_size` — TODO confirm.
    const BACKWARD_ELEMS_PER_WORKGROUP: u32 = 256;
    /// wgpu's default `max_compute_workgroups_per_dimension` limit.
    const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;

    /// Builds the forward and backward cross-entropy compute pipelines.
    ///
    /// Both kernels are compiled from trueno's WGSL sources
    /// (`CROSS_ENTROPY_FORWARD_SHADER` / `CROSS_ENTROPY_BACKWARD_SHADER`) with
    /// entry point `main` and a single bind group at index 0.
    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
        // Every binding is a compute-visible, non-dynamic buffer; only the
        // buffer kind (read-only storage / read-write storage / uniform) varies.
        let buffer_entry = |binding: u32, ty: wgpu::BufferBindingType| {
            wgpu::BindGroupLayoutEntry {
                binding,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            }
        };
        let storage_ro = |binding: u32| {
            buffer_entry(binding, wgpu::BufferBindingType::Storage { read_only: true })
        };
        let storage_rw = |binding: u32| {
            buffer_entry(binding, wgpu::BufferBindingType::Storage { read_only: false })
        };
        let uniform = |binding: u32| buffer_entry(binding, wgpu::BufferBindingType::Uniform);

        // Forward: logits(ro), labels(ro), losses(rw), logsumexp(rw), params(uniform)
        let forward_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("ce_fwd_bgl"),
            entries: &[storage_ro(0), storage_ro(1), storage_rw(2), storage_rw(3), uniform(4)],
        });
        let fwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("ce_forward"),
            source: wgpu::ShaderSource::Wgsl(CROSS_ENTROPY_FORWARD_SHADER.into()),
        });
        let fwd_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("ce_fwd_pl"),
            bind_group_layouts: &[&forward_bgl],
            push_constant_ranges: &[],
        });
        let forward_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("ce_fwd_pipe"),
            layout: Some(&fwd_pl),
            module: &fwd_shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Backward: logits(rw), labels(ro), logsumexp(ro), params(uniform)
        let backward_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("ce_bwd_bgl"),
            entries: &[storage_rw(0), storage_ro(1), storage_ro(2), uniform(3)],
        });
        let bwd_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("ce_backward"),
            source: wgpu::ShaderSource::Wgsl(CROSS_ENTROPY_BACKWARD_SHADER.into()),
        });
        let bwd_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("ce_bwd_pl"),
            bind_group_layouts: &[&backward_bgl],
            push_constant_ranges: &[],
        });
        let backward_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("ce_bwd_pipe"),
            layout: Some(&bwd_pl),
            module: &bwd_shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        Self { device, queue, forward_pipeline, backward_pipeline, forward_bgl, backward_bgl }
    }

    /// Dispatches the CE forward kernel without any GPU sync or loss download.
    ///
    /// One workgroup is dispatched per token position (`seq_len` workgroups).
    /// The kernel writes per-token losses into `losses` and the logsumexp
    /// values into `logsumexp`. The latter is the input [`Self::backward`]
    /// expects. Call [`Self::read_loss`] later (ideally after backward) to
    /// get the averaged loss value.
    ///
    /// KAIZEN: the old blocking `forward()` waited ~10s for ALL prior GPU work.
    // Flat buffer/shape arguments mirror the shader bindings; kept for API stability.
    #[allow(clippy::too_many_arguments)]
    pub fn forward_async(
        &self,
        logits: &wgpu::Buffer,
        labels: &wgpu::Buffer,
        losses: &wgpu::Buffer,
        logsumexp: &wgpu::Buffer,
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) {
        let params =
            self.make_uniform(&CEForwardParams { seq_len, vocab_size, loss_start, loss_end });
        // Buffer order == binding index in the forward layout built by `new`.
        self.submit_pass(
            &self.forward_pipeline,
            &self.forward_bgl,
            &[logits, labels, losses, logsumexp, &params],
            (seq_len, 1, 1),
        );
    }

    /// Reads back the per-token losses and returns their sum divided by the
    /// number of response tokens (`loss_end - loss_start`).
    ///
    /// Blocks until all previously submitted GPU work completes. Call this
    /// AFTER backward + LoRA updates to avoid stalling the pipeline. Returns
    /// `0.0` when the response range is empty (or inverted) or `seq_len` is 0.
    ///
    /// # Panics
    /// If the staging buffer cannot be mapped for reading.
    pub fn read_loss(
        &self,
        losses: &wgpu::Buffer,
        seq_len: u32,
        loss_start: u32,
        loss_end: u32,
    ) -> f32 {
        // PMAT-498 fix: label the staging buffer, and poll before the copy so the
        // CE forward kernel has completed before its losses are read back.
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();

        // saturating_sub: an inverted range means "no response tokens" instead
        // of a u32 underflow (panic in debug, ~0 loss in release).
        let num_tokens = loss_end.saturating_sub(loss_start);
        if seq_len == 0 || num_tokens == 0 {
            // Nothing to average; also avoids mapping a zero-sized staging buffer.
            return 0.0;
        }

        let byte_len = u64::from(seq_len) * std::mem::size_of::<f32>() as u64;
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ce_loss_readback"),
            size: byte_len,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(losses, 0, &staging, 0, byte_len);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            // A send error only means the receiver is gone; nothing to report.
            tx.send(result).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        rx.recv()
            .expect("CE loss readback: map_async callback dropped without running")
            .expect("CE loss readback: failed to map staging buffer");

        let total = {
            let mapped = slice.get_mapped_range();
            // NOTE(review): sums all `seq_len` entries. This assumes the forward
            // shader writes 0 for positions outside [loss_start, loss_end) —
            // TODO confirm against CROSS_ENTROPY_FORWARD_SHADER.
            bytemuck::cast_slice::<u8, f32>(&mapped).iter().sum::<f32>()
        };
        // The mapped view is dropped above; the buffer can only be unmapped after that.
        staging.unmap();
        total / num_tokens as f32
    }

    /// Synchronous forward (legacy — blocks on GPU). Prefer forward_async + read_loss.
    ///
    /// Equivalent to [`Self::forward_async`] followed by [`Self::read_loss`].
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        logits: &wgpu::Buffer,
        labels: &wgpu::Buffer,
        losses: &wgpu::Buffer,
        logsumexp: &wgpu::Buffer,
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) -> f32 {
        self.forward_async(
            logits, labels, losses, logsumexp, seq_len, vocab_size, loss_start, loss_end,
        );
        self.read_loss(losses, seq_len, loss_start, loss_end)
    }

    /// Computes the backward cross-entropy gradient IN-PLACE into `logits`.
    ///
    /// After this call, logits[i] = (softmax(logits)[i] - one_hot(label)[i]) * scale,
    /// where scale = 1 / max(loss_end - loss_start, 1). `logsumexp` must hold
    /// the values written by the forward kernel for the same logits. Submits
    /// the work and returns without waiting.
    ///
    /// # Panics
    /// If `seq_len * vocab_size` does not fit in `u32`.
    #[allow(clippy::too_many_arguments)]
    pub fn backward(
        &self,
        logits: &wgpu::Buffer,    // [seq_len, vocab_size] — overwritten with gradient
        labels: &wgpu::Buffer,    // [seq_len] u32
        logsumexp: &wgpu::Buffer, // [seq_len] from forward
        seq_len: u32,
        vocab_size: u32,
        loss_start: u32,
        loss_end: u32,
    ) {
        // The dispatch grid is computed from a u32 element count, so refuse
        // tensors whose size would silently wrap in release builds.
        let total = seq_len
            .checked_mul(vocab_size)
            .expect("CE backward: seq_len * vocab_size overflows u32");

        // Mean over response tokens. An empty/inverted range falls back to 1 so
        // the scale stays finite (and no u32 underflow occurs).
        let num_tokens = loss_end.saturating_sub(loss_start).max(1);
        let params = self.make_uniform(&CEBackwardParams {
            seq_len,
            vocab_size,
            loss_start,
            loss_end,
            scale: 1.0 / num_tokens as f32,
            _pad0: 0,
            _pad1: 0,
            _pad2: 0,
        });

        // Buffer order == binding index in the backward layout built by `new`.
        self.submit_pass(
            &self.backward_pipeline,
            &self.backward_bgl,
            &[logits, labels, logsumexp, &params],
            Self::backward_grid(total),
        );
    }

    /// Workgroup grid covering `total` elements. It is 1D while the workgroup
    /// count fits in one dimension; otherwise it is folded into
    /// `(65535, ceil(n / 65535), 1)`.
    ///
    /// Both shapes may over-cover `total`; the shader presumably bounds-checks
    /// its flat index — TODO confirm.
    fn backward_grid(total: u32) -> (u32, u32, u32) {
        let workgroups = total.div_ceil(Self::BACKWARD_ELEMS_PER_WORKGROUP);
        if workgroups <= Self::MAX_WORKGROUPS_PER_DIM {
            (workgroups, 1, 1)
        } else {
            let x = Self::MAX_WORKGROUPS_PER_DIM;
            (x, workgroups.div_ceil(x), 1)
        }
    }

    /// Binds `buffers` to group 0 of `layout` (binding `i` = `buffers[i]`),
    /// records one compute pass of `pipeline` over `workgroups`, and submits
    /// it without waiting for completion.
    fn submit_pass(
        &self,
        pipeline: &wgpu::ComputePipeline,
        layout: &wgpu::BindGroupLayout,
        buffers: &[&wgpu::Buffer],
        workgroups: (u32, u32, u32),
    ) {
        let entries: Vec<wgpu::BindGroupEntry<'_>> = buffers
            .iter()
            .zip(0u32..)
            .map(|(buffer, binding)| wgpu::BindGroupEntry {
                binding,
                resource: buffer.as_entire_binding(),
            })
            .collect();
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout,
            entries: &entries,
        });

        let mut encoder = self.device.create_command_encoder(&Default::default());
        {
            let mut pass = encoder.begin_compute_pass(&Default::default());
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let (x, y, z) = workgroups;
            pass.dispatch_workgroups(x, y, z);
        }
        self.queue.submit(Some(encoder.finish()));
    }

    /// Allocates a uniform buffer sized exactly for `T` and queues a write of `data`.
    fn make_uniform<T: bytemuck::Pod>(&self, data: &T) -> wgpu::Buffer {
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: std::mem::size_of::<T>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buf, 0, bytemuck::bytes_of(data));
        buf
    }
}
303
#[cfg(test)]
#[cfg(feature = "gpu")]
mod tests {
    use super::*;

    const SEQ_LEN: u32 = 4;
    const VOCAB: u32 = 8;
    /// Target token per position.
    const LABELS: [u32; SEQ_LEN as usize] = [2, 5, 1, 7];

    /// Deterministic pseudo-random logits, laid out [SEQ_LEN, VOCAB] row-major.
    fn fixture_logits() -> Vec<f32> {
        (0..SEQ_LEN * VOCAB).map(|i| ((i as f32) * 0.3).sin()).collect()
    }

    /// Default adapter + device, or `None` when no GPU is available (tests then skip).
    fn gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let adapter = trueno::backends::gpu::runtime::block_on(
            instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
        )
        .ok()?;
        trueno::backends::gpu::runtime::block_on(
            adapter.request_device(&wgpu::DeviceDescriptor::default()),
        )
        .ok()
    }

    /// Storage buffer initialised with `bytes` (also usable as copy source/destination).
    fn storage_buffer(device: &wgpu::Device, queue: &wgpu::Queue, bytes: &[u8]) -> wgpu::Buffer {
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: bytes.len() as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        queue.write_buffer(&buffer, 0, bytes);
        buffer
    }

    /// Copies the first `len` f32 values of `src` to the host (blocks on the GPU).
    fn read_f32(
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        src: &wgpu::Buffer,
        len: usize,
    ) -> Vec<f32> {
        let size = (len * std::mem::size_of::<f32>()) as u64;
        let staging = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ce_test_readback"),
            size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(src, 0, &staging, 0, size);
        queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            tx.send(r).ok();
        });
        device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        rx.recv().unwrap().unwrap();
        let values = bytemuck::cast_slice::<u8, f32>(&slice.get_mapped_range()).to_vec();
        staging.unmap();
        values
    }

    /// CPU reference: numerically stable log-sum-exp of one logits row.
    fn cpu_logsumexp(row: &[f32]) -> f32 {
        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let sum_exp: f32 = row.iter().map(|x| (x - max_val).exp()).sum();
        max_val + sum_exp.ln()
    }

    /// GPU buffers for one fused-CE run over the fixture.
    struct Buffers {
        logits: wgpu::Buffer,
        labels: wgpu::Buffer,
        losses: wgpu::Buffer,
        logsumexp: wgpu::Buffer,
    }

    fn upload(device: &wgpu::Device, queue: &wgpu::Queue, logits: &[f32]) -> Buffers {
        let zeros = [0.0f32; SEQ_LEN as usize];
        Buffers {
            logits: storage_buffer(device, queue, bytemuck::cast_slice(logits)),
            labels: storage_buffer(device, queue, bytemuck::cast_slice(&LABELS[..])),
            losses: storage_buffer(device, queue, bytemuck::cast_slice(&zeros[..])),
            logsumexp: storage_buffer(device, queue, bytemuck::cast_slice(&zeros[..])),
        }
    }

    /// FALSIFY-FCE-001: Fused CE matches naive cross-entropy
    #[test]
    fn test_fused_ce_matches_naive() {
        let Some((device, queue)) = gpu() else {
            eprintln!("[SKIP] test_fused_ce_matches_naive: no wgpu adapter/device");
            return;
        };
        let ce = WgslCrossEntropy::new(device.clone(), queue.clone());
        let logits_data = fixture_logits();
        let bufs = upload(&device, &queue, &logits_data);

        // All tokens are response (loss_start=0, loss_end=SEQ_LEN)
        let gpu_loss = ce.forward(
            &bufs.logits,
            &bufs.labels,
            &bufs.losses,
            &bufs.logsumexp,
            SEQ_LEN,
            VOCAB,
            0,
            SEQ_LEN,
        );

        // CPU reference: mean over positions of logsumexp(row) - row[label]
        let cpu_loss = logits_data
            .chunks_exact(VOCAB as usize)
            .zip(LABELS.iter())
            .map(|(row, &label)| -row[label as usize] + cpu_logsumexp(row))
            .sum::<f32>()
            / SEQ_LEN as f32;

        let err = (gpu_loss - cpu_loss).abs();
        eprintln!("[PARITY] Fused CE: gpu={gpu_loss:.6}, cpu={cpu_loss:.6}, err={err:.6}");
        assert!(err < 1e-4, "Fused CE parity failed: err={err}");
    }

    /// FALSIFY-FCE-002: Fused CE backward writes (softmax - one_hot) * scale in place.
    #[test]
    fn test_fused_ce_backward_matches_naive() {
        let Some((device, queue)) = gpu() else {
            eprintln!("[SKIP] test_fused_ce_backward_matches_naive: no wgpu adapter/device");
            return;
        };
        let ce = WgslCrossEntropy::new(device.clone(), queue.clone());
        let logits_data = fixture_logits();
        let bufs = upload(&device, &queue, &logits_data);

        // Forward fills logsumexp, which backward consumes.
        ce.forward(
            &bufs.logits,
            &bufs.labels,
            &bufs.losses,
            &bufs.logsumexp,
            SEQ_LEN,
            VOCAB,
            0,
            SEQ_LEN,
        );
        ce.backward(&bufs.logits, &bufs.labels, &bufs.logsumexp, SEQ_LEN, VOCAB, 0, SEQ_LEN);
        let grad = read_f32(&device, &queue, &bufs.logits, logits_data.len());

        let scale = 1.0 / SEQ_LEN as f32;
        let rows = logits_data.chunks_exact(VOCAB as usize);
        let grad_rows = grad.chunks_exact(VOCAB as usize);
        for (pos, (row, grad_row)) in rows.zip(grad_rows).enumerate() {
            let lse = cpu_logsumexp(row);
            for (tok, (&x, &g)) in row.iter().zip(grad_row).enumerate() {
                let one_hot = if tok == LABELS[pos] as usize { 1.0 } else { 0.0 };
                let expected = ((x - lse).exp() - one_hot) * scale;
                let err = (g - expected).abs();
                assert!(
                    err < 1e-5,
                    "Fused CE backward parity failed at pos={pos} tok={tok}: \
                     gpu={g}, cpu={expected}, err={err}"
                );
            }
        }
    }
}