1use crate::autograd::Tensor;
20use crate::lora::LoRALayer;
21use crate::transformer::config::TransformerConfig;
22use crate::transformer::model::Transformer;
23use std::cell::RefCell;
24use std::sync::Arc;
25use trueno::backends::gpu::{wgpu, GpuCommandBatch, GpuDevice, PipelineCache};
26
27struct GpuResidentFfnWeights {
33 w_gate: Arc<wgpu::Buffer>,
35 w_up: Arc<wgpu::Buffer>,
37 w_down: Arc<wgpu::Buffer>,
39 gate_up_elements: usize,
41 down_elements: usize,
43}
44
45pub struct WgpuForwardPass {
54 device: GpuDevice,
55 config: TransformerConfig,
56 num_layers: usize,
58 ffn_weights: Vec<GpuResidentFfnWeights>,
62 pipeline_cache: RefCell<PipelineCache>,
67}
68
69impl WgpuForwardPass {
70 pub fn new(config: &TransformerConfig, adapter_index: u32) -> Result<Self, String> {
79 let device = GpuDevice::new_with_adapter_index(adapter_index)?;
80
81 Ok(Self {
82 device,
83 config: config.clone(),
84 num_layers: config.num_hidden_layers,
85 ffn_weights: Vec::new(),
86 pipeline_cache: RefCell::new(PipelineCache::new()),
87 })
88 }
89
90 pub fn new_default(config: &TransformerConfig) -> Result<Self, String> {
92 let device = GpuDevice::new()?;
93
94 Ok(Self {
95 device,
96 config: config.clone(),
97 num_layers: config.num_hidden_layers,
98 ffn_weights: Vec::new(),
99 pipeline_cache: RefCell::new(PipelineCache::new()),
100 })
101 }
102
103 pub fn with_resident_weights(model: &Transformer) -> Result<Self, String> {
121 contract_pre_with_resident_weights!();
122 let device = GpuDevice::new()?;
123 let config = model.config.clone();
124 let num_layers = config.num_hidden_layers;
125 let hidden_size = config.hidden_size;
126 let intermediate_size = config.intermediate_size;
127 let gate_up_elements = hidden_size * intermediate_size;
128 let down_elements = intermediate_size * hidden_size;
129
130 let mut ffn_weights = Vec::with_capacity(num_layers);
131 let mut total_bytes: usize = 0;
132
133 for (i, layer) in model.layers.iter().enumerate() {
134 let gate_data = layer.ffn.w_gate.data();
135 let gate_slice = gate_data
136 .as_slice()
137 .ok_or_else(|| format!("Layer {i}: gate weight not contiguous"))?;
138 let up_data = layer.ffn.w_up.data();
139 let up_slice =
140 up_data.as_slice().ok_or_else(|| format!("Layer {i}: up weight not contiguous"))?;
141 let down_data = layer.ffn.w_down.data();
142 let down_slice = down_data
143 .as_slice()
144 .ok_or_else(|| format!("Layer {i}: down weight not contiguous"))?;
145
146 let w_gate = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
147 label: Some(&format!("ffn_gate_L{i}")),
148 size: (gate_slice.len() * 4) as u64,
149 usage: wgpu::BufferUsages::STORAGE
150 | wgpu::BufferUsages::COPY_SRC
151 | wgpu::BufferUsages::COPY_DST,
152 mapped_at_creation: false,
153 }));
154 device.queue.write_buffer(&w_gate, 0, bytemuck::cast_slice(gate_slice));
155
156 let w_up = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
157 label: Some(&format!("ffn_up_L{i}")),
158 size: (up_slice.len() * 4) as u64,
159 usage: wgpu::BufferUsages::STORAGE
160 | wgpu::BufferUsages::COPY_SRC
161 | wgpu::BufferUsages::COPY_DST,
162 mapped_at_creation: false,
163 }));
164 device.queue.write_buffer(&w_up, 0, bytemuck::cast_slice(up_slice));
165
166 let w_down = Arc::new(device.device.create_buffer(&wgpu::BufferDescriptor {
167 label: Some(&format!("ffn_down_L{i}")),
168 size: (down_slice.len() * 4) as u64,
169 usage: wgpu::BufferUsages::STORAGE
170 | wgpu::BufferUsages::COPY_SRC
171 | wgpu::BufferUsages::COPY_DST,
172 mapped_at_creation: false,
173 }));
174 device.queue.write_buffer(&w_down, 0, bytemuck::cast_slice(down_slice));
175
176 let layer_bytes = (gate_slice.len() + up_slice.len() + down_slice.len()) * 4;
177 total_bytes += layer_bytes;
178
179 ffn_weights.push(GpuResidentFfnWeights {
180 w_gate,
181 w_up,
182 w_down,
183 gate_up_elements,
184 down_elements,
185 });
186 }
187
188 let total_mb = total_bytes as f64 / (1024.0 * 1024.0);
189 eprintln!("[wgpu] GPU-resident FFN weights: {num_layers} layers, {total_mb:.1} MB");
190
191 Ok(Self {
192 device,
193 config,
194 num_layers,
195 ffn_weights,
196 pipeline_cache: RefCell::new(PipelineCache::new()),
197 })
198 }
199
200 pub fn forward_hidden(&self, model: &Transformer, token_ids: &[u32]) -> Result<Tensor, String> {
212 let seq_len = token_ids.len();
213 let hidden_size = self.config.hidden_size;
214 let intermediate_size = self.config.intermediate_size;
215
216 let mut hidden = model.embed_tokens.forward(token_ids);
218
219 crate::autograd::suppress_per_op_wgpu();
227 for (layer_idx, layer) in model.layers.iter().enumerate() {
228 let norm1 = layer.input_norm.forward_batched(&hidden, seq_len, hidden_size);
230 let attn_out = layer.self_attn.forward(&norm1, seq_len);
231 let residual1 = crate::autograd::add(&hidden, &attn_out);
232
233 let norm2 = layer.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
235
236 let resident = self.ffn_weights.get(layer_idx);
238
239 let ffn_out = self.forward_ffn_gpu(
240 &norm2,
241 &layer.ffn.w_gate,
242 &layer.ffn.w_up,
243 &layer.ffn.w_down,
244 seq_len,
245 hidden_size,
246 intermediate_size,
247 resident,
248 )?;
249
250 hidden = crate::autograd::add(&residual1, &ffn_out);
252 }
253 crate::autograd::unsuppress_per_op_wgpu();
255
256 let normalized = model.norm.forward_batched(&hidden, seq_len, hidden_size);
258
259 Ok(normalized)
260 }
261
262 fn forward_ffn_gpu(
274 &self,
275 input: &Tensor,
276 w_gate: &Tensor,
277 w_up: &Tensor,
278 w_down: &Tensor,
279 seq_len: usize,
280 hidden_size: usize,
281 intermediate_size: usize,
282 resident_weights: Option<&GpuResidentFfnWeights>,
283 ) -> Result<Tensor, String> {
284 use trueno::backends::gpu::runtime;
285
286 runtime::block_on(async {
287 let mut batch = GpuCommandBatch::new(self.device.clone());
288
289 let input_data = input.data();
291 let input_slice = input_data.as_slice().ok_or("Input tensor not contiguous")?;
292 let buf_input = batch.upload(input_slice);
293
294 let (buf_gate, buf_up, buf_down) = if let Some(rw) = resident_weights {
296 let g = batch.import_buffer(Arc::clone(&rw.w_gate), rw.gate_up_elements);
298 let u = batch.import_buffer(Arc::clone(&rw.w_up), rw.gate_up_elements);
299 let d = batch.import_buffer(Arc::clone(&rw.w_down), rw.down_elements);
300 (g, u, d)
301 } else {
302 let gate_data = w_gate.data();
304 let gate_slice = gate_data.as_slice().ok_or("Gate weight not contiguous")?;
305 let up_data = w_up.data();
306 let up_slice = up_data.as_slice().ok_or("Up weight not contiguous")?;
307 let down_data = w_down.data();
308 let down_slice = down_data.as_slice().ok_or("Down weight not contiguous")?;
309 let g = batch.upload(gate_slice);
310 let u = batch.upload(up_slice);
311 let d = batch.upload(down_slice);
312 (g, u, d)
313 };
314
315 let gate_out = batch.matmul(
317 buf_input,
318 buf_gate,
319 seq_len as u32,
320 hidden_size as u32,
321 intermediate_size as u32,
322 );
323
324 let up_out = batch.matmul(
326 buf_input,
327 buf_up,
328 seq_len as u32,
329 hidden_size as u32,
330 intermediate_size as u32,
331 );
332
333 let gate_activated = batch.swish(gate_out);
335 let swiglu_out = batch.mul(gate_activated, up_out);
336
337 let ffn_out = batch.matmul(
339 swiglu_out,
340 buf_down,
341 seq_len as u32,
342 intermediate_size as u32,
343 hidden_size as u32,
344 );
345
346 batch.execute_with_cache(&mut self.pipeline_cache.borrow_mut()).await?;
348
349 let result_data = batch.read(ffn_out).await?;
351
352 Ok(Tensor::from_vec(result_data, false))
353 })
354 }
355
356 pub fn forward_hidden_batch(
372 &self,
373 model: &Transformer,
374 batch_token_ids: &[Vec<u32>],
375 lora_layers: Option<&[LoRALayer]>,
376 ) -> Result<Vec<Tensor>, String> {
377 let hidden_size = self.config.hidden_size;
378 let intermediate_size = self.config.intermediate_size;
379 let n = batch_token_ids.len();
380
381 let mut hiddens: Vec<Tensor> =
383 batch_token_ids.iter().map(|ids| model.embed_tokens.forward(ids)).collect();
384
385 let total_tokens: usize = batch_token_ids.iter().map(std::vec::Vec::len).sum();
390
391 crate::autograd::suppress_per_op_wgpu();
392 for (layer_idx, layer) in model.layers.iter().enumerate() {
393 let mut ffn_input_tensors: Vec<Tensor> = Vec::with_capacity(n);
397 let mut residuals: Vec<Tensor> = Vec::with_capacity(n);
398 for (i, hidden) in hiddens.iter().enumerate() {
399 let seq_len = batch_token_ids[i].len();
400 let norm1 = layer.input_norm.forward_batched(hidden, seq_len, hidden_size);
401
402 let attn_out = match lora_layers {
404 Some(loras) => {
405 let q_idx = layer_idx * 2;
406 let v_idx = layer_idx * 2 + 1;
407 if v_idx < loras.len() {
408 layer.self_attn.forward_with_lora(
409 &norm1,
410 seq_len,
411 loras[q_idx].lora_a(),
412 loras[q_idx].lora_b(),
413 loras[v_idx].lora_a(),
414 loras[v_idx].lora_b(),
415 loras[q_idx].rank(),
416 loras[q_idx].scale(),
417 )
418 } else {
419 layer.self_attn.forward(&norm1, seq_len)
420 }
421 }
422 None => layer.self_attn.forward(&norm1, seq_len),
423 };
424
425 let residual1 = crate::autograd::add(hidden, &attn_out);
426 let norm2 = layer.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
427 ffn_input_tensors.push(norm2);
428 residuals.push(residual1);
429 }
430
431 let mut concat_input = Vec::with_capacity(total_tokens * hidden_size);
434 for norm2 in &ffn_input_tensors {
435 let data = norm2.data();
436 concat_input.extend_from_slice(data.as_slice().expect("norm2 contiguous"));
437 }
438 let concat_tensor = Tensor::from_vec(concat_input, false);
439
440 let resident = self.ffn_weights.get(layer_idx);
442
443 let ffn_out = self.forward_ffn_gpu(
444 &concat_tensor,
445 &layer.ffn.w_gate,
446 &layer.ffn.w_up,
447 &layer.ffn.w_down,
448 total_tokens,
449 hidden_size,
450 intermediate_size,
451 resident,
452 )?;
453
454 let ffn_data = ffn_out.data();
456 let ffn_slice = ffn_data.as_slice().expect("ffn contiguous");
457 let mut offset = 0;
458 hiddens = residuals
459 .into_iter()
460 .enumerate()
461 .map(|(i, r)| {
462 let len = batch_token_ids[i].len() * hidden_size;
463 let sample_ffn =
464 Tensor::from_vec(ffn_slice[offset..offset + len].to_vec(), false);
465 offset += len;
466 crate::autograd::add(&r, &sample_ffn)
467 })
468 .collect();
469 }
470 crate::autograd::unsuppress_per_op_wgpu();
471
472 let results: Vec<Tensor> = hiddens
474 .into_iter()
475 .enumerate()
476 .map(|(i, h)| {
477 let seq_len = batch_token_ids[i].len();
478 model.norm.forward_batched(&h, seq_len, hidden_size)
479 })
480 .collect();
481
482 Ok(results)
483 }
484
485 pub fn adapter_info(&self) -> String {
487 format!(
488 "wgpu device ({}x{} model, {} layers)",
489 self.config.hidden_size, self.config.intermediate_size, self.num_layers
490 )
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Shrinks the llama2-7b preset to a toy model so GPU tests stay fast.
    fn tiny_config(
        hidden_size: usize,
        num_layers: usize,
        num_heads: usize,
        intermediate_size: usize,
        vocab_size: usize,
    ) -> TransformerConfig {
        let mut config = TransformerConfig::llama2_7b();
        config.hidden_size = hidden_size;
        config.num_hidden_layers = num_layers;
        config.num_attention_heads = num_heads;
        config.num_kv_heads = num_heads;
        config.intermediate_size = intermediate_size;
        config.vocab_size = vocab_size;
        config
    }

    #[test]
    fn test_wgpu_forward_pass_creation() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let config = tiny_config(64, 2, 4, 128, 100);
        let pass = WgpuForwardPass::new_default(&config);
        assert!(pass.is_ok(), "WgpuForwardPass creation failed: {:?}", pass.err());
    }

    #[test]
    fn test_wgpu_ffn_numerical_correctness() {
        const HIDDEN: usize = 8;
        const INTERMEDIATE: usize = 16;
        const WEIGHT: f32 = 0.1;

        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let config = tiny_config(HIDDEN, 1, 2, INTERMEDIATE, 32);
        let pass =
            WgpuForwardPass::new_default(&config).expect("GPU available but creation failed");

        let input = Tensor::from_vec(vec![1.0; HIDDEN], false);
        let w_gate = Tensor::from_vec(vec![WEIGHT; HIDDEN * INTERMEDIATE], false);
        let w_up = Tensor::from_vec(vec![WEIGHT; HIDDEN * INTERMEDIATE], false);
        let w_down = Tensor::from_vec(vec![WEIGHT; INTERMEDIATE * HIDDEN], false);

        let gpu_result = pass.forward_ffn_gpu(
            &input,
            &w_gate,
            &w_up,
            &w_down,
            1,
            HIDDEN,
            INTERMEDIATE,
            None,
        );

        assert!(gpu_result.is_ok(), "GPU FFN failed: {:?}", gpu_result.err());

        let gpu_data = gpu_result.expect("checked above");
        assert_eq!(gpu_data.len(), HIDDEN, "Output should be 1 × 8");

        // Closed-form CPU reference: with an all-ones input and constant weights, every
        // gate/up activation is WEIGHT * HIDDEN, and every output element sums
        // INTERMEDIATE identical terms WEIGHT * silu(gate) * up (SiLU = x * sigmoid(x)).
        let projected = WEIGHT * HIDDEN as f32;
        let silu = projected / (1.0 + (-projected).exp());
        let expected = WEIGHT * INTERMEDIATE as f32 * silu * projected;

        for (i, &val) in gpu_data.data().iter().enumerate() {
            assert!(val.is_finite(), "NaN/Inf at index {i}: {val}");
            assert!(
                (val - expected).abs() <= 1e-3,
                "index {i}: GPU {val} differs from CPU reference {expected}"
            );
        }
    }
}