// entrenar/autograd/wgpu_block.rs

//! WgpuBlock — GPU-resident transformer block weights via wgpu (zero unsafe)
//!
//! Replaces `CudaNf4TransformerBlock` for the WgpuTrainingPipeline.
//! Each block holds 7 projection weights (pre-dequantized F32 on GPU),
//! 2 norm weights, and optional LoRA A/B adapters.
//!
//! # Architecture (§26 Step 0d.1)
//!
//! ```text
//! WgpuBlock (per transformer layer)
//! ├── norm weights: input_norm [hidden], post_attn_norm [hidden]
//! ├── projections (row-major, [out × in]):
//! │     q [q_dim × hidden], k, v [kv_dim × hidden], o [hidden × q_dim],
//! │     gate, up [inter × hidden], down [hidden × inter]
//! ├── LoRA A/B: trainable adapters per projection, A [in × rank] / B [rank × out]
//! └── all wgpu::Buffer, all F32, zero unsafe
//! ```
16
17#[cfg(feature = "gpu")]
18use trueno::backends::gpu::wgpu;
19
/// GPU-resident weights of one transformer layer, held as `wgpu` buffers.
///
/// Projection matrices are pre-dequantized F32 and stored row-major as
/// `[out_dim, in_dim]`.
#[cfg(feature = "gpu")]
pub struct WgpuBlock {
    /// Index of the transformer layer these weights belong to.
    pub layer_idx: usize,

    /// Input norm weight (F32, `hidden_size` floats).
    pub input_norm: wgpu::Buffer,
    /// Post-attention norm weight (F32, `hidden_size` floats).
    pub post_attn_norm: wgpu::Buffer,

    /// Query projection, `[q_dim, hidden]`.
    pub w_q: wgpu::Buffer,
    /// Key projection, `[kv_dim, hidden]`.
    pub w_k: wgpu::Buffer,
    /// Value projection, `[kv_dim, hidden]`.
    pub w_v: wgpu::Buffer,
    /// Attention output projection, `[hidden, q_dim]`.
    pub w_o: wgpu::Buffer,
    /// FFN gate projection, `[inter, hidden]`.
    pub w_gate: wgpu::Buffer,
    /// FFN up projection, `[inter, hidden]`.
    pub w_up: wgpu::Buffer,
    /// FFN down projection, `[hidden, inter]`.
    pub w_down: wgpu::Buffer,

    /// Trainable F32 LoRA adapters; `None` when the layer carries no adapters.
    pub lora: Option<WgpuLoraAdapters>,
}
42
/// LoRA adapters (plus their AdamW state) for the 7 projections of one layer.
///
/// Every projection owns a pair of buffers: `A` shaped `[in_dim, rank]` and
/// `B` shaped `[rank, out_dim]`.
#[cfg(feature = "gpu")]
pub struct WgpuLoraAdapters {
    /// LoRA rank shared by all adapter pairs in this layer.
    pub rank: u32,
    /// Scaling applied to the adapter output (alpha / rank).
    pub scale: f32,

    /// `A` for the query projection.
    pub a_q: wgpu::Buffer,
    /// `B` for the query projection.
    pub b_q: wgpu::Buffer,
    /// `A` for the key projection.
    pub a_k: wgpu::Buffer,
    /// `B` for the key projection.
    pub b_k: wgpu::Buffer,
    /// `A` for the value projection.
    pub a_v: wgpu::Buffer,
    /// `B` for the value projection.
    pub b_v: wgpu::Buffer,
    /// `A` for the attention output projection.
    pub a_o: wgpu::Buffer,
    /// `B` for the attention output projection.
    pub b_o: wgpu::Buffer,
    /// `A` for the FFN gate projection.
    pub a_gate: wgpu::Buffer,
    /// `B` for the FFN gate projection.
    pub b_gate: wgpu::Buffer,
    /// `A` for the FFN up projection.
    pub a_up: wgpu::Buffer,
    /// `B` for the FFN up projection.
    pub b_up: wgpu::Buffer,
    /// `A` for the FFN down projection.
    pub a_down: wgpu::Buffer,
    /// `B` for the FFN down projection.
    pub b_down: wgpu::Buffer,

    /// AdamW first-moment buffers, one per trainable tensor (14 in total).
    ///
    /// `WgpuBlockManager::upload_layer` fills this interleaved per projection:
    /// `[A_q, B_q, A_k, B_k, A_v, B_v, A_o, B_o, A_gate, B_gate, A_up, B_up,
    /// A_down, B_down]`.
    pub m_states: Vec<wgpu::Buffer>,
    /// AdamW second-moment buffers (14), in the same order as `m_states`.
    pub v_states: Vec<wgpu::Buffer>,
}
69
/// Owns the wgpu device/queue, every uploaded transformer block, and the
/// scratch buffers that are shared (reused) across layers.
#[cfg(feature = "gpu")]
pub struct WgpuBlockManager {
    /// Device used for all buffer allocations.
    pub device: wgpu::Device,
    /// Queue used for all host-to-GPU writes.
    pub queue: wgpu::Queue,
    /// Uploaded layers, in the order `upload_layer` was called.
    pub blocks: Vec<WgpuBlock>,

    // ---- Shared activation buffers (reused across layers) ----
    /// `[max_seq, hidden]` — input/output per layer.
    pub hidden_buf: wgpu::Buffer,
    /// `[max_seq, hidden]` — residual stream.
    pub hidden_buf2: wgpu::Buffer,
    /// `[max_seq, hidden]` — attention output.
    pub attn_out_buf: wgpu::Buffer,
    /// `[max_seq, inter]` — FFN gate projection.
    pub ffn_gate_buf: wgpu::Buffer,
    /// `[max_seq, inter]` — FFN up projection.
    pub ffn_up_buf: wgpu::Buffer,
    /// `[max_seq, inter]` — SiLU(gate) × up.
    pub ffn_silu_buf: wgpu::Buffer,
    /// `[max_seq, hidden]` — RMSNorm output.
    pub norm_buf: wgpu::Buffer,
    /// `[max_seq, q_dim]` — query activations.
    pub q_buf: wgpu::Buffer,
    /// `[max_seq, kv_dim]` — key activations.
    pub k_buf: wgpu::Buffer,
    /// `[max_seq, kv_dim]` — value activations.
    pub v_buf: wgpu::Buffer,

    // ---- Embedding + lm_head ----
    /// `[vocab, hidden]` — token embedding.
    pub embed_weight: wgpu::Buffer,
    /// `[vocab, hidden]` — output projection (may be tied).
    pub lm_head_weight: wgpu::Buffer,
    /// `[max_seq, vocab]` — logits output.
    pub logits_buf: wgpu::Buffer,

    // ---- Gradient buffers (backward pass) ----
    /// `[max_seq, hidden]` — gradient w.r.t. the hidden state.
    pub grad_hidden_buf: wgpu::Buffer,
    /// `[max_seq, vocab]` — gradient w.r.t. the logits (may reuse `logits_buf`).
    pub grad_logits_buf: wgpu::Buffer,

    // ---- Model configuration ----
    /// Hidden (model) dimension.
    pub hidden_size: u32,
    /// FFN intermediate dimension.
    pub intermediate_size: u32,
    /// Number of attention (query) heads.
    pub num_heads: u32,
    /// Number of key/value heads.
    pub num_kv_heads: u32,
    /// Per-head dimension.
    pub head_dim: u32,
    /// Maximum sequence length the shared buffers are sized for.
    pub max_seq_len: u32,
    /// Vocabulary size.
    pub vocab_size: u32,
    /// Configured number of transformer layers.
    pub num_layers: u32,
}
108
#[cfg(feature = "gpu")]
impl WgpuBlockManager {
    /// Create a block manager and allocate every shared GPU buffer.
    ///
    /// No weights are uploaded here: `blocks` starts empty (with capacity for
    /// `num_layers`), and the embedding / lm_head buffers are allocated but not
    /// written. Feed layers one at a time through [`Self::upload_layer`] and the
    /// embeddings through [`Self::upload_embeddings`]; uploading per layer lets
    /// the caller avoid holding all layers in CPU memory simultaneously.
    ///
    /// Buffer sizes are computed in `u64`, so products such as
    /// `max_seq_len * vocab_size` cannot wrap around `u32`.
    ///
    /// `_lora_rank` and `_lora_alpha` are currently unused; LoRA is configured
    /// per layer via the `lora_rank` / `lora_scale` arguments of `upload_layer`.
    pub fn new(
        device: wgpu::Device,
        queue: wgpu::Queue,
        hidden_size: u32,
        intermediate_size: u32,
        num_heads: u32,
        num_kv_heads: u32,
        head_dim: u32,
        num_layers: u32,
        vocab_size: u32,
        max_seq_len: u32,
        _lora_rank: Option<u32>,
        _lora_alpha: Option<f32>,
    ) -> Self {
        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;

        // Allocate an uninitialised `[rows, cols]` F32 buffer. The element count
        // is widened to u64 *before* multiplying so large shapes do not overflow.
        let alloc = |rows: u32, cols: u32, label: &str| -> wgpu::Buffer {
            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: u64::from(rows) * u64::from(cols) * 4,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };

        Self {
            blocks: Vec::with_capacity(num_layers as usize),
            hidden_buf: alloc(max_seq_len, hidden_size, "hidden"),
            hidden_buf2: alloc(max_seq_len, hidden_size, "hidden2"),
            attn_out_buf: alloc(max_seq_len, hidden_size, "attn_out"),
            ffn_gate_buf: alloc(max_seq_len, intermediate_size, "ffn_gate"),
            ffn_up_buf: alloc(max_seq_len, intermediate_size, "ffn_up"),
            ffn_silu_buf: alloc(max_seq_len, intermediate_size, "ffn_silu"),
            norm_buf: alloc(max_seq_len, hidden_size, "norm"),
            q_buf: alloc(max_seq_len, q_dim, "q"),
            k_buf: alloc(max_seq_len, kv_dim, "k"),
            v_buf: alloc(max_seq_len, kv_dim, "v"),
            embed_weight: alloc(vocab_size, hidden_size, "embed"),
            lm_head_weight: alloc(vocab_size, hidden_size, "lm_head"),
            logits_buf: alloc(max_seq_len, vocab_size, "logits"),
            grad_hidden_buf: alloc(max_seq_len, hidden_size, "grad_hidden"),
            grad_logits_buf: alloc(max_seq_len, vocab_size, "grad_logits"),
            hidden_size,
            intermediate_size,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
            vocab_size,
            num_layers,
            // Moved last: `alloc` borrows `device` for every allocation above.
            device,
            queue,
        }
    }

    /// Upload one transformer layer's F32 weights to the GPU and append it to
    /// `blocks`.
    ///
    /// Each slice gets its own buffer sized to exactly the slice length; slice
    /// lengths are not validated against the configured dimensions. The block
    /// is appended in call order: `layer_idx` is stored on the block and used
    /// for buffer labels, the LoRA init, and the progress log, but not for
    /// positioning inside `blocks`.
    ///
    /// When `lora_rank` is `Some(rank)`, LoRA adapters are created for all 7
    /// projections together with zeroed AdamW `m`/`v` state. `A` matrices get a
    /// deterministic sine-based fill scaled by `sqrt(2 / fan_in)`, and `B`
    /// matrices start at zero. `lora_scale` defaults to `1.0` when `None` and
    /// is ignored when `lora_rank` is `None`.
    ///
    /// A progress line is written to stderr after the upload.
    pub fn upload_layer(
        &mut self,
        layer_idx: usize,
        input_norm: &[f32],
        post_attn_norm: &[f32],
        w_q: &[f32],
        w_k: &[f32],
        w_v: &[f32],
        w_o: &[f32],
        w_gate: &[f32],
        w_up: &[f32],
        w_down: &[f32],
        lora_rank: Option<u32>,
        lora_scale: Option<f32>,
    ) {
        // Create a buffer exactly as large as `data` and queue the host->GPU copy.
        let upload = |data: &[f32], label: &str| -> wgpu::Buffer {
            let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: std::mem::size_of_val(data) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            self.queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
            buffer
        };

        let prefix = format!("L{layer_idx}");

        let lora = lora_rank.map(|rank| {
            let scale = lora_scale.unwrap_or(1.0);
            let h = self.hidden_size as usize;
            let q = (self.num_heads * self.head_dim) as usize;
            let kv = (self.num_kv_heads * self.head_dim) as usize;
            let inter = self.intermediate_size as usize;
            let r = rank as usize;

            // Kaiming-scaled init for A. The values come from a deterministic
            // sine sequence seeded only by `layer_idx`, so every projection in
            // a layer starts from a prefix of the same sequence.
            let kaiming = |fan_in: usize, len: usize| -> Vec<f32> {
                let std_dev = (2.0 / fan_in as f32).sqrt();
                (0..len)
                    .map(|i| (i as f32 * 0.013 + layer_idx as f32).sin() * std_dev)
                    .collect()
            };
            let zeros = |len: usize| vec![0.0f32; len];

            // (in_dim, out_dim, name) for each adapted projection, in field order.
            let projections: [(usize, usize, &str); 7] = [
                (h, q, "q"),
                (h, kv, "k"),
                (h, kv, "v"),
                (q, h, "o"),
                (h, inter, "gate"),
                (h, inter, "up"),
                (inter, h, "down"),
            ];

            // All three vectors are interleaved per projection: [A, B, A, B, ...].
            let mut adapters = Vec::with_capacity(14);
            let mut m_states = Vec::with_capacity(14);
            let mut v_states = Vec::with_capacity(14);

            for (in_d, out_d, name) in projections {
                let a_len = in_d * r;
                let b_len = r * out_d;
                adapters.push(upload(&kaiming(in_d, a_len), &format!("{prefix}.lora_a_{name}")));
                adapters.push(upload(&zeros(b_len), &format!("{prefix}.lora_b_{name}")));
                m_states.push(upload(&zeros(a_len), &format!("{prefix}.m_a_{name}")));
                m_states.push(upload(&zeros(b_len), &format!("{prefix}.m_b_{name}")));
                v_states.push(upload(&zeros(a_len), &format!("{prefix}.v_a_{name}")));
                v_states.push(upload(&zeros(b_len), &format!("{prefix}.v_b_{name}")));
            }

            // Struct fields are initialised in source order, so pulling from the
            // iterator hands out buffers in exactly the order they were pushed.
            let mut adapters = adapters.into_iter();
            let mut take = || {
                adapters
                    .next()
                    .expect("an A and a B buffer were pushed for each of the 7 projections")
            };

            WgpuLoraAdapters {
                rank,
                scale,
                a_q: take(),
                b_q: take(),
                a_k: take(),
                b_k: take(),
                a_v: take(),
                b_v: take(),
                a_o: take(),
                b_o: take(),
                a_gate: take(),
                b_gate: take(),
                a_up: take(),
                b_up: take(),
                a_down: take(),
                b_down: take(),
                m_states,
                v_states,
            }
        });

        // Remember this before `lora` is moved into the block.
        let has_lora = lora.is_some();

        self.blocks.push(WgpuBlock {
            layer_idx,
            input_norm: upload(input_norm, &format!("{prefix}.input_norm")),
            post_attn_norm: upload(post_attn_norm, &format!("{prefix}.post_attn_norm")),
            w_q: upload(w_q, &format!("{prefix}.q_proj")),
            w_k: upload(w_k, &format!("{prefix}.k_proj")),
            w_v: upload(w_v, &format!("{prefix}.v_proj")),
            w_o: upload(w_o, &format!("{prefix}.o_proj")),
            w_gate: upload(w_gate, &format!("{prefix}.gate_proj")),
            w_up: upload(w_up, &format!("{prefix}.up_proj")),
            w_down: upload(w_down, &format!("{prefix}.down_proj")),
            lora,
        });

        eprintln!(
            "[wgpu] Uploaded layer {}/{} ({})",
            layer_idx + 1,
            self.num_layers,
            if has_lora { "with LoRA" } else { "frozen" }
        );
    }

    /// Write the token-embedding and lm_head weights into the buffers that
    /// `new` pre-allocated, starting at offset 0, and log the configured shapes
    /// to stderr.
    ///
    /// Slice lengths are not checked here, and the log prints the configured
    /// `vocab_size` / `hidden_size`, not the slice lengths.
    /// NOTE(review): presumably both slices hold `vocab_size * hidden_size`
    /// floats to match the allocation in `new` — confirm against callers.
    pub fn upload_embeddings(&mut self, embed: &[f32], lm_head: &[f32]) {
        self.queue.write_buffer(&self.embed_weight, 0, bytemuck::cast_slice(embed));
        self.queue.write_buffer(&self.lm_head_weight, 0, bytemuck::cast_slice(lm_head));
        eprintln!(
            "[wgpu] Uploaded embeddings: embed=[{}×{}], lm_head=[{}×{}]",
            self.vocab_size, self.hidden_size, self.vocab_size, self.hidden_size
        );
    }

    /// Approximate total GPU memory used, in bytes.
    ///
    /// The estimate is derived from the configured dimensions, not from the
    /// actual buffer sizes, and is the sum of:
    /// * base weights (2 norms + 7 projections) for all `num_layers` configured
    ///   layers, whether or not they have been uploaded yet;
    /// * every shared buffer allocated by `new` (activations, logits,
    ///   gradients, embedding and lm_head);
    /// * LoRA adapters plus their AdamW `m`/`v` state for each uploaded block
    ///   that carries adapters.
    pub fn gpu_memory_bytes(&self) -> u64 {
        const F32_BYTES: u64 = 4;

        let h = u64::from(self.hidden_size);
        let inter = u64::from(self.intermediate_size);
        // Widen before multiplying so the head products cannot wrap in u32.
        let q = u64::from(self.num_heads) * u64::from(self.head_dim);
        let kv = u64::from(self.num_kv_heads) * u64::from(self.head_dim);
        let vocab = u64::from(self.vocab_size);
        let seq = u64::from(self.max_seq_len);
        let layers = u64::from(self.num_layers);

        // Per layer: 2 norms, q, k + v, o, gate + up, down.
        let per_layer_elems = 2 * h + q * h + 2 * kv * h + h * q + 2 * inter * h + h * inter;

        let shared_elems = 5 * seq * h // hidden, hidden2, attn_out, norm, grad_hidden
            + 3 * seq * inter // ffn_gate, ffn_up, ffn_silu
            + seq * q // q
            + 2 * seq * kv // k, v
            + 2 * seq * vocab // logits, grad_logits
            + 2 * vocab * h; // embed, lm_head

        // Per unit of rank, each projection contributes in_dim (A) + out_dim (B).
        let lora_elems_per_rank =
            (h + q) + 2 * (h + kv) + (q + h) + 2 * (h + inter) + (inter + h);
        // x3: the adapter weights themselves plus AdamW m and v of equal size.
        let lora_elems: u64 = self
            .blocks
            .iter()
            .filter_map(|block| block.lora.as_ref())
            .map(|lora| 3 * u64::from(lora.rank) * lora_elems_per_rank)
            .sum();

        (per_layer_elems * layers + shared_elems + lora_elems) * F32_BYTES
    }

    /// Number of layers uploaded so far (length of `blocks`).
    pub fn layer_count(&self) -> usize {
        self.blocks.len()
    }
}
325
#[cfg(test)]
#[cfg(feature = "gpu")]
mod tests {
    use super::*;

    /// Smoke test: build a tiny manager, upload two LoRA-enabled layers, and
    /// check the layer count and that the memory estimate is non-zero.
    /// Returns early (passing) when no adapter or device can be obtained.
    #[test]
    fn test_wgpu_block_manager_creation() {
        use trueno::backends::gpu::runtime::block_on;

        // Tiny model configuration used throughout the test.
        const HIDDEN: u32 = 64;
        const INTER: u32 = 128;
        const HEADS: u32 = 4;
        const KV_HEADS: u32 = 4;
        const HEAD_DIM: u32 = 16;
        const LAYERS: u32 = 2;
        const VOCAB: u32 = 100;
        const MAX_SEQ: u32 = 32;
        const RANK: u32 = 8;
        const ALPHA: f32 = 2.0;

        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let Ok(adapter) =
            block_on(instance.request_adapter(&wgpu::RequestAdapterOptions::default()))
        else {
            return; // No GPU
        };
        let Ok((device, queue)) =
            block_on(adapter.request_device(&wgpu::DeviceDescriptor::default()))
        else {
            return;
        };

        let mut mgr = WgpuBlockManager::new(
            device,
            queue,
            HIDDEN,
            INTER,
            HEADS,
            KV_HEADS,
            HEAD_DIM,
            LAYERS,
            VOCAB,
            MAX_SEQ,
            Some(RANK),
            Some(ALPHA),
        );

        let h = HIDDEN as usize;
        let inter = INTER as usize;
        let q_dim = (HEADS * HEAD_DIM) as usize;
        let kv_dim = (KV_HEADS * HEAD_DIM) as usize;

        // The same constant-filled weights are uploaded for every layer.
        let norm = vec![1.0_f32; h];
        let q_like = vec![0.01_f32; q_dim * h]; // w_q and w_o share an element count
        let kv_like = vec![0.01_f32; kv_dim * h]; // w_k and w_v
        let ffn_like = vec![0.01_f32; inter * h]; // w_gate, w_up and w_down

        // Upload 2 layers
        for layer in 0..LAYERS as usize {
            mgr.upload_layer(
                layer,
                &norm,     // input_norm
                &norm,     // post_attn_norm
                &q_like,   // w_q
                &kv_like,  // w_k
                &kv_like,  // w_v
                &q_like,   // w_o
                &ffn_like, // w_gate
                &ffn_like, // w_up
                &ffn_like, // w_down
                Some(RANK),
                Some(ALPHA / RANK as f32),
            );
        }

        assert_eq!(mgr.layer_count(), 2);
        assert!(mgr.gpu_memory_bytes() > 0);
    }
}
387}