#![allow(dead_code, clippy::many_single_char_names)]
use std::collections::HashMap;

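/// Per-layer activations captured during the forward pass so a training
/// backward pass can consume them without re-running the forward kernels.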
pub struct LayerActivations {
    pub attn_norm_out: wgpu::Buffer,
    pub attn_output: wgpu::Buffer,
    pub ffn_norm_out: wgpu::Buffer,
    pub silu_gate_output: wgpu::Buffer,
    pub rstd_attn: wgpu::Buffer,
    pub rstd_ffn: wgpu::Buffer,
    pub softmax_logsumexp: wgpu::Buffer,
}

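/// Borrowed LoRA adapter state for the Q/K/V projections: the A/B factor
/// pairs, their dimensions and scale, and the pipeline that applies them.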
pub struct QkvLoRA<'a> {
    pub q_a: &'a wgpu::Buffer,
    pub q_b: &'a wgpu::Buffer,
    pub k_a: &'a wgpu::Buffer,
    pub k_b: &'a wgpu::Buffer,
    pub v_a: &'a wgpu::Buffer,
    pub v_b: &'a wgpu::Buffer,
    pub rank: u32,
    pub scale: f32,
    pub in_dim: u32,
    pub q_dim: u32,
    pub kv_dim: u32,
    pub lora_pipeline: &'a wgpu::ComputePipeline,
    pub lora_bgl: &'a wgpu::BindGroupLayout,
}

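/// WGSL compute-shader forward pass for a decoder-only transformer. Owns the
/// device and queue, all compiled pipelines and bind-group layouts, uploaded
/// weights (f32 and Q4_K), per-layer KV caches, and reusable scratch buffers.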
pub struct WgslForwardPass {
    device: wgpu::Device,
    queue: wgpu::Queue,

    matmul_pipeline: wgpu::ComputePipeline,
    tiled_matmul_pipeline: wgpu::ComputePipeline,
    gemv_pipeline: wgpu::ComputePipeline,
    q4k_gemv_pipeline: wgpu::ComputePipeline,
    attention_pipeline: wgpu::ComputePipeline,
    attention_bgl: wgpu::BindGroupLayout,
    rmsnorm_pipeline: wgpu::ComputePipeline,
    silu_mul_pipeline: wgpu::ComputePipeline,
    rope_pipeline: wgpu::ComputePipeline,
    batch_rope_pipeline: wgpu::ComputePipeline,
    batch_rope_bgl: wgpu::BindGroupLayout,
    residual_pipeline: wgpu::ComputePipeline,

    matmul_bgl: wgpu::BindGroupLayout,
    elementwise_bgl: wgpu::BindGroupLayout,

    weight_buffers: HashMap<String, wgpu::Buffer>,
    q4k_weights: HashMap<String, wgpu::Buffer>,
    cpu_biases: HashMap<String, Vec<f32>>,
    kv_cache_k: Vec<wgpu::Buffer>,
    kv_cache_v: Vec<wgpu::Buffer>,

    hidden_buf: wgpu::Buffer,
    q_buf: wgpu::Buffer,
    k_buf: wgpu::Buffer,
    v_buf: wgpu::Buffer,
    attn_out_buf: wgpu::Buffer,
    ffn_gate_buf: wgpu::Buffer,
    ffn_up_buf: wgpu::Buffer,
    ffn_silu_buf: wgpu::Buffer,
    ffn_out_buf: wgpu::Buffer,
    norm_buf: wgpu::Buffer,
    staging_buf: wgpu::Buffer,

    hidden_dim: u32,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    intermediate_dim: u32,
}

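// RMSNorm over one row per workgroup: 256 threads accumulate strided partial
// sums of squares, a shared-memory tree reduction combines them, then each
// thread normalizes and scales its strided slice of the row.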
const RMSNORM_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

var<workgroup> shared_sum: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
        @builtin(workgroup_id) wg_id: vec3<u32>) {
    let dim = params.x;
    let row = wg_id.y;
    let base = row * dim;
    let tid = lid.x;

    // Compute sum of squares (reduction) for this row
    var local_sum: f32 = 0.0;
    var i = tid;
    while (i < dim) {
        let val = input[base + i];
        local_sum += val * val;
        i += 256u;
    }
    shared_sum[tid] = local_sum;
    workgroupBarrier();

    // Tree reduction
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride) {
            shared_sum[tid] += shared_sum[tid + stride];
        }
        workgroupBarrier();
        stride >>= 1u;
    }

    let rms = sqrt(shared_sum[0] / f32(dim) + 1e-6);

    // Normalize and scale
    i = tid;
    while (i < dim) {
        output[base + i] = (input[base + i] / rms) * weight[i];
        i += 256u;
    }
}
"#;

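// Fused SwiGLU activation: SiLU(gate) * up, one thread per element.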
const SILU_MUL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.x) { return; }
    let g = gate[idx];
    let silu_g = g / (1.0 + exp(-g));
    output[idx] = silu_g * up[idx];
}
"#;

const RESIDUAL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.x) { return; }
    output[idx] = a[idx] + b[idx];
}
"#;

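// Rotary position embeddings applied in one dispatch to a whole
// [seq_len, num_heads, head_dim] tensor. Each element in the first half of a
// head is rotated against its partner in the second half; the 1e6 frequency
// base is what this pass assumes for its target models.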
const BATCH_ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;

struct RopeParams {
    seq_len: u32,
    num_heads: u32,
    head_dim: u32,
    _pad: u32,
}

@group(0) @binding(1) var<uniform> params: RopeParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let total = params.seq_len * params.num_heads * params.head_dim;
    if (idx >= total) { return; }

    let head_dim = params.head_dim;
    let half_hd = head_dim / 2u;

    // Decompose idx into (position, head, pos_in_head)
    let elements_per_pos = params.num_heads * head_dim;
    let position = idx / elements_per_pos;
    let within_pos = idx % elements_per_pos;
    let head_idx = within_pos / head_dim;
    let pos_in_head = within_pos % head_dim;

    // Only process the first half of each head (pairs with second half)
    if (pos_in_head >= half_hd) { return; }

    let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
    let angle = f32(position) * theta;
    let cos_a = cos(angle);
    let sin_a = sin(angle);

    let base = position * elements_per_pos + head_idx * head_dim;
    let i0 = base + pos_in_head;
    let i1 = i0 + half_hd;

    let x0 = qk[i0];
    let x1 = qk[i1];
    qk[i0] = x0 * cos_a - x1 * sin_a;
    qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;

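// Single-position RoPE variant: rotates one token's Q or K heads at the
// position given in the uniform params.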
const ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;
@group(0) @binding(1) var<uniform> params: vec4<u32>; // (dim, position, num_heads, head_dim)

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let dim = params.x;
    let position = params.y;
    let head_dim = params.w;

    if (idx >= dim) { return; }

    let half_hd = head_dim / 2u;
    let head_idx = idx / head_dim;
    let pos_in_head = idx % head_dim;

    if (pos_in_head >= half_hd) { return; }

    let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
    let angle = f32(position) * theta;
    let cos_a = cos(angle);
    let sin_a = sin(angle);

    let i0 = head_idx * head_dim + pos_in_head;
    let i1 = i0 + half_hd;

    let x0 = qk[i0];
    let x1 = qk[i1];
    qk[i0] = x0 * cos_a - x1 * sin_a;
    qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;

impl WgslForwardPass {
    pub fn rmsnorm_shader() -> &'static str {
        RMSNORM_SHADER
    }
    pub fn silu_mul_shader() -> &'static str {
        SILU_MUL_SHADER
    }
    pub fn residual_shader() -> &'static str {
        RESIDUAL_SHADER
    }
    pub fn rope_shader() -> &'static str {
        ROPE_SHADER
    }

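    /// Compiles every shader module, builds the pipelines and bind-group
    /// layouts, and allocates scratch buffers sized for up to 2048 tokens.
    ///
    /// A minimal usage sketch (dimensions are illustrative, not prescribed):
    ///
    /// ```ignore
    /// let pass = WgslForwardPass::new(device, queue, 2048, 16, 8, 128, 8192);
    /// pass.init_kv_cache(24); // hypothetical layer count
    /// ```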
    pub fn new(
        device: wgpu::Device,
        queue: wgpu::Queue,
        hidden_dim: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        intermediate_dim: usize,
    ) -> Self {
        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;

        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("matmul"),
            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
        });
        let rmsnorm_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("rmsnorm"),
            source: wgpu::ShaderSource::Wgsl(RMSNORM_SHADER.into()),
        });
        let silu_mul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("silu_mul"),
            source: wgpu::ShaderSource::Wgsl(SILU_MUL_SHADER.into()),
        });
        let rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("rope"),
            source: wgpu::ShaderSource::Wgsl(ROPE_SHADER.into()),
        });
        let residual_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("residual"),
            source: wgpu::ShaderSource::Wgsl(RESIDUAL_SHADER.into()),
        });

        let matmul_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("matmul_bgl"),
            entries: &[
                bgl_storage(0, true),
                bgl_storage(1, true),
                bgl_storage(2, false),
                bgl_uniform(3),
            ],
        });
        let elementwise_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("ew_bgl"),
            entries: &[
                bgl_storage(0, true),
                bgl_storage(1, true),
                bgl_storage(2, false),
                bgl_uniform(3),
            ],
        });

        let make_pipeline =
            |shader: &wgpu::ShaderModule, bgl: &wgpu::BindGroupLayout, label: &str| {
                let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some(label),
                    bind_group_layouts: &[bgl],
                    push_constant_ranges: &[],
                });
                device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some(label),
                    layout: Some(&pl),
                    module: shader,
                    entry_point: Some("main"),
                    compilation_options: Default::default(),
                    cache: None,
                })
            };

        let matmul_pipeline = make_pipeline(&matmul_shader, &matmul_bgl, "matmul_pipe");

        let tiled_matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("tiled_matmul"),
            source: wgpu::ShaderSource::Wgsl(
                crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
            ),
        });
        let tiled_matmul_pipeline =
            make_pipeline(&tiled_matmul_shader, &matmul_bgl, "tiled_matmul_pipe");

        let attention_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("causal_attention"),
            source: wgpu::ShaderSource::Wgsl(
                crate::backends::gpu::shaders::CAUSAL_ATTENTION_SHADER.into(),
            ),
        });
        let attention_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("attn_bgl"),
            entries: &[
                bgl_storage(0, true),
                bgl_storage(1, true),
                bgl_storage(2, true),
                bgl_storage(3, false),
                bgl_uniform(4),
            ],
        });
        let attention_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("attn_pl"),
            bind_group_layouts: &[&attention_bgl],
            push_constant_ranges: &[],
        });
        let attention_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("attn_pipe"),
            layout: Some(&attention_pl),
            module: &attention_shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("gemv"),
            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
        });
        let gemv_pipeline = make_pipeline(&gemv_shader, &matmul_bgl, "gemv_pipe");

        let q4k_gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("q4k_gemv"),
            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::Q4K_GEMV_SHADER.into()),
        });
        let q4k_gemv_pipeline = make_pipeline(&q4k_gemv_shader, &matmul_bgl, "q4k_gemv_pipe");

        let rmsnorm_pipeline = make_pipeline(&rmsnorm_shader, &elementwise_bgl, "rmsnorm_pipe");
        let silu_mul_pipeline = make_pipeline(&silu_mul_shader, &elementwise_bgl, "silu_pipe");
        let residual_pipeline = make_pipeline(&residual_shader_mod, &elementwise_bgl, "res_pipe");

        let rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("rope_bgl"),
            entries: &[bgl_storage(0, false), bgl_uniform(1)],
        });
        let rope_pipeline = {
            let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("rope_pl"),
                bind_group_layouts: &[&rope_bgl],
                push_constant_ranges: &[],
            });
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("rope_pipe"),
                layout: Some(&pl),
                module: &rope_shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            })
        };

        let batch_rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("batch_rope"),
            source: wgpu::ShaderSource::Wgsl(BATCH_ROPE_SHADER.into()),
        });
        let batch_rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("batch_rope_bgl"),
            entries: &[bgl_storage(0, false), bgl_uniform(1)],
        });
        let batch_rope_pipeline = {
            let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("batch_rope_pl"),
                bind_group_layouts: &[&batch_rope_bgl],
                push_constant_ranges: &[],
            });
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("batch_rope_pipe"),
                layout: Some(&pl),
                module: &batch_rope_shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            })
        };

        let buf = |size: usize, label: &str| -> wgpu::Buffer {
            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: (size * 4) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };

        let max_seq = 2048;
        let hidden_buf = buf(max_seq * hidden_dim, "hidden");
        let q_buf = buf(max_seq * q_dim, "q");
        let k_buf = buf(max_seq * kv_dim, "k");
        let v_buf = buf(max_seq * kv_dim, "v");
        let attn_out_buf = buf(max_seq * hidden_dim, "attn_out");
        let ffn_gate_buf = buf(max_seq * intermediate_dim, "ffn_gate");
        let ffn_up_buf = buf(max_seq * intermediate_dim, "ffn_up");
        let ffn_silu_buf = buf(max_seq * intermediate_dim, "ffn_silu");
        let ffn_out_buf = buf(max_seq * hidden_dim, "ffn_out");
        let norm_buf = buf(max_seq * hidden_dim, "norm");

        let max_out = max_seq * hidden_dim.max(intermediate_dim);
        let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging"),
            size: (max_out * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        Self {
            device,
            queue,
            matmul_pipeline,
            tiled_matmul_pipeline,
            attention_pipeline,
            attention_bgl,
            gemv_pipeline,
            q4k_gemv_pipeline,
            rmsnorm_pipeline,
            silu_mul_pipeline,
            rope_pipeline,
            batch_rope_pipeline,
            batch_rope_bgl,
            residual_pipeline,
            matmul_bgl,
            elementwise_bgl,
            weight_buffers: HashMap::new(),
            q4k_weights: HashMap::new(),
            kv_cache_k: Vec::new(),
            kv_cache_v: Vec::new(),
            cpu_biases: HashMap::new(),
            hidden_buf,
            q_buf,
            k_buf,
            v_buf,
            attn_out_buf,
            ffn_gate_buf,
            ffn_up_buf,
            ffn_silu_buf,
            ffn_out_buf,
            norm_buf,
            staging_buf,
            hidden_dim: hidden_dim as u32,
            num_heads: num_heads as u32,
            num_kv_heads: num_kv_heads as u32,
            head_dim: head_dim as u32,
            intermediate_dim: intermediate_dim as u32,
        }
    }

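    /// Uploads an f32 weight tensor. Bias vectors stay on the CPU (they are
    /// applied host-side), and tensors larger than the device's maximum
    /// storage-buffer binding are skipped so the caller can fall back to CPU.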
    pub fn upload_weight(&mut self, name: &str, data: &[f32]) {
        if name.contains("bias") {
            self.cpu_biases.insert(name.to_string(), data.to_vec());
            return;
        }
        let size_bytes = (data.len() * 4) as u64;
        let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
        if size_bytes > max_binding {
            eprintln!(
                "[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB limit); CPU fallback",
                name,
                size_bytes as f64 / 1e6,
                max_binding as f64 / 1e6
            );
            return;
        }
        use wgpu::util::DeviceExt;
        let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some(name),
            contents: bytemuck::cast_slice(data),
            usage: wgpu::BufferUsages::STORAGE,
        });
        self.weight_buffers.insert(name.to_string(), buffer);
    }

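    /// Uploads a Q4_K-quantized weight blob as raw bytes; it is consumed
    /// directly by the Q4_K GEMV shader.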
    pub fn upload_q4k_weight(&mut self, name: &str, data: &[u8]) {
        use wgpu::util::DeviceExt;
        let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some(name),
            contents: data,
            usage: wgpu::BufferUsages::STORAGE,
        });
        self.q4k_weights.insert(name.to_string(), buffer);
    }

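    /// Allocates one K and one V cache buffer per layer, each sized for up
    /// to 2048 positions of (num_kv_heads * head_dim) f32 values.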
    pub fn init_kv_cache(&mut self, num_layers: usize) {
        let kv_dim = (self.num_kv_heads * self.head_dim) as u64;
        let max_seq = 2048u64;
        for _ in 0..num_layers {
            let k = self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("kv_cache_k"),
                size: max_seq * kv_dim * 4,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
                mapped_at_creation: false,
            });
            let v = self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("kv_cache_v"),
                size: max_seq * kv_dim * 4,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
                mapped_at_creation: false,
            });
            self.kv_cache_k.push(k);
            self.kv_cache_v.push(v);
        }
    }

    pub fn weight_count(&self) -> usize {
        self.weight_buffers.len()
    }

    pub fn weight_buffer(&self, name: &str) -> Option<&wgpu::Buffer> {
        self.weight_buffers.get(name)
    }

    pub fn device_ref(&self) -> &wgpu::Device {
        &self.device
    }

    pub fn queue_ref(&self) -> &wgpu::Queue {
        &self.queue
    }

    pub fn hidden_buffer(&self) -> &wgpu::Buffer {
        &self.hidden_buf
    }

    pub fn q_buffer(&self) -> &wgpu::Buffer {
        &self.q_buf
    }

    pub fn k_buffer(&self) -> &wgpu::Buffer {
        &self.k_buf
    }

    pub fn v_buffer(&self) -> &wgpu::Buffer {
        &self.v_buf
    }

    pub fn gpu_residual_add(
        &self,
        a: &wgpu::Buffer,
        b: &wgpu::Buffer,
        output: &wgpu::Buffer,
        len: u32,
    ) {
        let mut encoder = self.device.create_command_encoder(&Default::default());
        self.encode_residual(&mut encoder, a, b, output, len);
        self.queue.submit(Some(encoder.finish()));
    }

    pub fn gpu_rmsnorm(&self, weight: &wgpu::Buffer, output: &wgpu::Buffer, _seq_len: u32) {
        let mut encoder = self.device.create_command_encoder(&Default::default());
        self.encode_rmsnorm(&mut encoder, &self.hidden_buf, weight, output, self.hidden_dim);
        self.queue.submit(Some(encoder.finish()));
    }

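    /// Blocking readback of the first `len` floats of the hidden buffer via
    /// a MAP_READ staging buffer.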
    pub fn download_hidden(&self, len: usize) -> Vec<f32> {
        let size = (len * 4) as u64;
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("hidden_download"),
            size,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &staging, 0, size);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..size);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            tx.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        rx.recv()
            .expect("GPU map_async callback channel disconnected")
            .expect("GPU buffer mapping failed");

        let data = slice.get_mapped_range();
        let result: Vec<f32> = bytemuck::cast_slice(&data)[..len].to_vec();
        drop(data);
        staging.unmap();
        result
    }

    pub fn total_vram_bytes(&self) -> usize {
        let weight_bytes: usize = self.weight_buffers.values().map(|b| b.size() as usize).sum();
        let intermediate_bytes = (self.hidden_dim as usize * 4) * 4
            + (self.num_heads as usize * self.head_dim as usize * 4)
            + (self.num_kv_heads as usize * self.head_dim as usize * 4) * 2
            + (self.intermediate_dim as usize * 4) * 2;
        weight_bytes + intermediate_bytes
    }

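    /// Full single-token forward pass: embedding lookup, every decoder
    /// layer, the final RMSNorm, and the LM-head matmul. The embedding,
    /// output norm, and LM head run on the CPU; per-layer work goes through
    /// `forward_layer`.
    ///
    /// Sketch of a decode step (names are illustrative):
    ///
    /// ```ignore
    /// let mut kv = Vec::new();
    /// let logits =
    ///     pass.forward_model(tok, pos, 24, &emb, &norm_w, &head_w, vocab, 1e-6, &mut kv)?;
    /// ```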
    #[provable_contracts_macros::contract("wgpu-forward-pass-v1", equation = "rmsnorm_correctness")]
    #[allow(clippy::too_many_arguments)]
    pub fn forward_model(
        &self,
        token_id: u32,
        position: usize,
        num_layers: usize,
        token_embedding: &[f32],
        output_norm_weight: &[f32],
        lm_head_weight: &[f32],
        vocab_size: usize,
        eps: f32,
        kv_caches: &mut Vec<(Vec<f32>, Vec<f32>)>,
    ) -> Result<Vec<f32>, String> {
        let hd = self.hidden_dim as usize;

        // Embedding lookup on the CPU.
        let embed_start = token_id as usize * hd;
        if embed_start + hd > token_embedding.len() {
            return Err(format!(
                "Token {} out of range (embedding size {})",
                token_id,
                token_embedding.len() / hd
            ));
        }
        let mut hidden: Vec<f32> = token_embedding[embed_start..embed_start + hd].to_vec();

        // Grow the per-layer KV caches on demand.
        while kv_caches.len() < num_layers {
            kv_caches.push((Vec::new(), Vec::new()));
        }
        for layer_idx in 0..num_layers {
            let prefix = format!("layer.{layer_idx}");
            let (ref mut k_cache, ref mut v_cache) = kv_caches[layer_idx];
            self.forward_layer(&mut hidden, &prefix, position, k_cache, v_cache)?;
        }

        // Final RMSNorm on the CPU.
        let rms = (hidden.iter().map(|x| x * x).sum::<f32>() / hd as f32 + eps).sqrt();
        for i in 0..hd {
            hidden[i] = (hidden[i] / rms) * output_norm_weight[i];
        }

        // LM head: logits[v] = lm_head_weight[v, :] . hidden.
        let mut logits = vec![0.0f32; vocab_size];
        for v in 0..vocab_size {
            let mut sum = 0.0f32;
            let row_start = v * hd;
            for j in 0..hd {
                sum += lm_head_weight[row_start + j] * hidden[j];
            }
            logits[v] = sum;
        }
        Ok(logits)
    }

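    /// One decoder layer for a single token. Q/K/V projections run on the
    /// GPU, the results are read back so biases, RoPE, and attention over
    /// the growing KV cache can run on the CPU, then the output projection,
    /// FFN, and residual adds run on the GPU again.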
    pub fn forward_layer(
        &self,
        hidden: &mut [f32],
        layer_prefix: &str,
        position: usize,
        kv_cache_k: &mut Vec<f32>,
        kv_cache_v: &mut Vec<f32>,
    ) -> Result<(), String> {
        let hd = self.hidden_dim;

        self.queue.write_buffer(&self.hidden_buf, 0, bytemuck::cast_slice(hidden));

        let mut encoder = self.device.create_command_encoder(&Default::default());

        let norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.attn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
        self.encode_rmsnorm(&mut encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);

        let q_dim = self.num_heads * self.head_dim;
        let kv_dim = self.num_kv_heads * self.head_dim;

        self.encode_matmul(
            &mut encoder,
            &self.norm_buf,
            layer_prefix,
            "q_proj",
            &self.q_buf,
            1,
            hd,
            q_dim,
        );
        self.encode_matmul(
            &mut encoder,
            &self.norm_buf,
            layer_prefix,
            "k_proj",
            &self.k_buf,
            1,
            hd,
            kv_dim,
        );
        self.encode_matmul(
            &mut encoder,
            &self.norm_buf,
            layer_prefix,
            "v_proj",
            &self.v_buf,
            1,
            hd,
            kv_dim,
        );

        // Copy Q/K/V to staging buffers for host-side bias, RoPE, and attention.
        let q_bytes = (q_dim * 4) as u64;
        let kv_bytes = (kv_dim * 4) as u64;

        let q_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("q_stg"),
            size: q_bytes,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let k_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("k_stg"),
            size: kv_bytes,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let v_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("v_stg"),
            size: kv_bytes,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        encoder.copy_buffer_to_buffer(&self.q_buf, 0, &q_staging, 0, q_bytes);
        encoder.copy_buffer_to_buffer(&self.k_buf, 0, &k_staging, 0, kv_bytes);
        encoder.copy_buffer_to_buffer(&self.v_buf, 0, &v_staging, 0, kv_bytes);
        self.queue.submit(Some(encoder.finish()));

878
879 let mut q_data = vec![0.0f32; q_dim as usize];
881 {
882 let slice = q_staging.slice(..q_bytes);
883 let (tx, rx) = std::sync::mpsc::channel();
884 slice.map_async(wgpu::MapMode::Read, move |r| {
885 tx.send(r).ok();
886 });
887 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
888 rx.recv().map_err(|e| format!("q recv: {e}"))?.map_err(|e| format!("q map: {e:?}"))?;
889 let data = slice.get_mapped_range();
890 q_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..q_dim as usize]);
891 }
892 q_staging.unmap();
893
894 let mut k_data = vec![0.0f32; kv_dim as usize];
896 {
897 let slice = k_staging.slice(..kv_bytes);
898 let (tx, rx) = std::sync::mpsc::channel();
899 slice.map_async(wgpu::MapMode::Read, move |r| {
900 tx.send(r).ok();
901 });
902 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
903 rx.recv().map_err(|e| format!("k recv: {e}"))?.map_err(|e| format!("k map: {e:?}"))?;
904 let data = slice.get_mapped_range();
905 k_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..kv_dim as usize]);
906 }
907 k_staging.unmap();
908
909 let mut v_data = vec![0.0f32; kv_dim as usize];
911 {
912 let slice = v_staging.slice(..kv_bytes);
913 let (tx, rx) = std::sync::mpsc::channel();
914 slice.map_async(wgpu::MapMode::Read, move |r| {
915 tx.send(r).ok();
916 });
917 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
918 rx.recv().map_err(|e| format!("v recv: {e}"))?.map_err(|e| format!("v map: {e:?}"))?;
919 let data = slice.get_mapped_range();
920 v_data.copy_from_slice(&bytemuck::cast_slice::<u8, f32>(&data)[..kv_dim as usize]);
921 }
922 v_staging.unmap();
923
        // Apply host-side biases, if present.
        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
            for (q, b) in q_data.iter_mut().zip(q_bias.iter()) {
                *q += *b;
            }
        }
        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
            for (k, b) in k_data.iter_mut().zip(k_bias.iter()) {
                *k += *b;
            }
        }
        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
            for (v, b) in v_data.iter_mut().zip(v_bias.iter()) {
                *v += *b;
            }
        }

        // RoPE on the CPU for this single position.
        let head_dim = self.head_dim as usize;
        let rope_theta = 1_000_000.0f64;

        // Rotate the query heads.
        for h in 0..(self.num_heads as usize) {
            let offset = h * head_dim;
            let half = head_dim / 2;
            for i in 0..half {
                let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
                let angle = position as f64 * theta;
                let cos_a = angle.cos() as f32;
                let sin_a = angle.sin() as f32;
                let x0 = q_data[offset + i];
                let x1 = q_data[offset + i + half];
                q_data[offset + i] = x0 * cos_a - x1 * sin_a;
                q_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
            }
        }

        // Rotate the key heads.
        for h in 0..(self.num_kv_heads as usize) {
            let offset = h * head_dim;
            let half = head_dim / 2;
            for i in 0..half {
                let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
                let angle = position as f64 * theta;
                let cos_a = angle.cos() as f32;
                let sin_a = angle.sin() as f32;
                let x0 = k_data[offset + i];
                let x1 = k_data[offset + i + half];
                k_data[offset + i] = x0 * cos_a - x1 * sin_a;
                k_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
            }
        }

        // CPU attention over the cached keys/values (GQA: query heads share KV groups).
        let num_heads = self.num_heads as usize;
        let num_kv_heads = self.num_kv_heads as usize;
        let kv_dim_usize = kv_dim as usize;

        kv_cache_k.extend_from_slice(&k_data);
        kv_cache_v.extend_from_slice(&v_data);
        let seq_len = kv_cache_k.len() / kv_dim_usize;

        let kv_group = num_heads / num_kv_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mut attn_out = vec![0.0f32; q_dim as usize];

        for h in 0..num_heads {
            let kv_h = h / kv_group;
            let q_offset = h * head_dim;

            // Scaled dot-product scores against every cached key.
            let mut scores = vec![0.0f32; seq_len];
            for s in 0..seq_len {
                let k_offset = s * kv_dim_usize + kv_h * head_dim;
                let mut dot = 0.0f32;
                for d in 0..head_dim {
                    dot += q_data[q_offset + d] * kv_cache_k[k_offset + d];
                }
                scores[s] = dot * scale;
            }

            // Numerically stable softmax.
            let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max_score).exp();
                sum += *s;
            }
            if sum > 0.0 {
                for s in scores.iter_mut() {
                    *s /= sum;
                }
            }

            // Weighted sum of cached values.
            let out_offset = h * head_dim;
            for d in 0..head_dim {
                let mut val = 0.0f32;
                for s in 0..seq_len {
                    let v_offset = s * kv_dim_usize + kv_h * head_dim;
                    val += scores[s] * kv_cache_v[v_offset + d];
                }
                attn_out[out_offset + d] = val;
            }
        }

        // Reuse q_buf to hold the attention output for the o_proj matmul.
        self.queue.write_buffer(&self.q_buf, 0, bytemuck::cast_slice(&attn_out));

        let mut encoder = self.device.create_command_encoder(&Default::default());

        self.encode_matmul(
            &mut encoder,
            &self.q_buf,
            layer_prefix,
            "o_proj",
            &self.attn_out_buf,
            1,
            q_dim,
            hd,
        );

        self.encode_residual(
            &mut encoder,
            &self.hidden_buf,
            &self.attn_out_buf,
            &self.ffn_out_buf,
            hd,
        );

        let ffn_norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.ffn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
        self.encode_rmsnorm(&mut encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);

        let inter = self.intermediate_dim;
        self.encode_matmul(
            &mut encoder,
            &self.norm_buf,
            layer_prefix,
            "gate_proj",
            &self.ffn_gate_buf,
            1,
            hd,
            inter,
        );
        self.encode_matmul(
            &mut encoder,
            &self.norm_buf,
            layer_prefix,
            "up_proj",
            &self.ffn_up_buf,
            1,
            hd,
            inter,
        );

        self.encode_silu_mul(
            &mut encoder,
            &self.ffn_gate_buf,
            &self.ffn_up_buf,
            &self.ffn_silu_buf,
            inter,
        );

        self.encode_matmul(
            &mut encoder,
            &self.ffn_silu_buf,
            layer_prefix,
            "down_proj",
            &self.norm_buf,
            1,
            inter,
            hd,
        );

        self.encode_residual(&mut encoder, &self.ffn_out_buf, &self.norm_buf, &self.hidden_buf, hd);

        encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &self.staging_buf, 0, (hd * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Read the updated hidden state back to the host.
        let slice = self.staging_buf.slice(..(hd as u64 * 4));
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            tx.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;
        {
            let data = slice.get_mapped_range();
            hidden.copy_from_slice(
                &bytemuck::cast_slice::<u8, f32>(&data)[..self.hidden_dim as usize],
            );
        }
        self.staging_buf.unmap();

        Ok(())
    }

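    /// Encodes one layer's forward pass for training into an existing
    /// encoder: activations needed by the backward pass are copied into
    /// `saved`, and optional LoRA deltas are added onto the Q/K/V outputs.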
    pub fn encode_forward_layer_training(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        seq_len: u32,
        layer_prefix: &str,
        saved: &LayerActivations,
        lora: Option<&QkvLoRA<'_>>,
    ) -> Result<(), String> {
        let hd = self.hidden_dim;
        let q_dim = self.num_heads * self.head_dim;
        let kv_dim = self.num_kv_heads * self.head_dim;
        let inter = self.intermediate_dim;
        let s = seq_len as usize;

        let norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.attn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
        self.encode_rmsnorm(encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);

        encoder.copy_buffer_to_buffer(
            &self.norm_buf,
            0,
            &saved.attn_norm_out,
            0,
            (s * hd as usize * 4) as u64,
        );

        self.encode_matmul(
            encoder,
            &self.norm_buf,
            layer_prefix,
            "q_proj",
            &self.q_buf,
            seq_len,
            hd,
            q_dim,
        );
        self.encode_matmul(
            encoder,
            &self.norm_buf,
            layer_prefix,
            "k_proj",
            &self.k_buf,
            seq_len,
            hd,
            kv_dim,
        );
        self.encode_matmul(
            encoder,
            &self.norm_buf,
            layer_prefix,
            "v_proj",
            &self.v_buf,
            seq_len,
            hd,
            kv_dim,
        );

        if let Some(lora) = lora {
            self.encode_lora_addmm(
                encoder,
                &saved.attn_norm_out,
                lora.q_a,
                lora.q_b,
                &self.q_buf,
                seq_len,
                lora.in_dim,
                lora.rank,
                lora.q_dim,
                lora.scale,
                lora.lora_pipeline,
                lora.lora_bgl,
            );
            self.encode_lora_addmm(
                encoder,
                &saved.attn_norm_out,
                lora.k_a,
                lora.k_b,
                &self.k_buf,
                seq_len,
                lora.in_dim,
                lora.rank,
                lora.kv_dim,
                lora.scale,
                lora.lora_pipeline,
                lora.lora_bgl,
            );
            self.encode_lora_addmm(
                encoder,
                &saved.attn_norm_out,
                lora.v_a,
                lora.v_b,
                &self.v_buf,
                seq_len,
                lora.in_dim,
                lora.rank,
                lora.kv_dim,
                lora.scale,
                lora.lora_pipeline,
                lora.lora_bgl,
            );
        }

        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
            self.encode_broadcast_bias(encoder, &self.q_buf, q_bias, seq_len);
        }
        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
            self.encode_broadcast_bias(encoder, &self.k_buf, k_bias, seq_len);
        }
        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
            self.encode_broadcast_bias(encoder, &self.v_buf, v_bias, seq_len);
        }

        self.encode_batch_rope(encoder, &self.q_buf, seq_len, self.num_heads, self.head_dim);
        self.encode_batch_rope(encoder, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim);

        self.encode_attention(encoder, seq_len);

        encoder.copy_buffer_to_buffer(
            &self.attn_out_buf,
            0,
            &saved.attn_output,
            0,
            (s * q_dim as usize * 4) as u64,
        );

        self.encode_matmul(
            encoder,
            &self.attn_out_buf,
            layer_prefix,
            "o_proj",
            &self.q_buf,
            seq_len,
            q_dim,
            hd,
        );

        self.encode_residual(
            encoder,
            &self.hidden_buf,
            &self.q_buf,
            &self.ffn_out_buf,
            hd * seq_len,
        );

        let ffn_norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.ffn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
        self.encode_rmsnorm(encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);

        encoder.copy_buffer_to_buffer(
            &self.norm_buf,
            0,
            &saved.ffn_norm_out,
            0,
            (s * hd as usize * 4) as u64,
        );

        self.encode_matmul(
            encoder,
            &self.norm_buf,
            layer_prefix,
            "gate_proj",
            &self.ffn_gate_buf,
            seq_len,
            hd,
            inter,
        );
        self.encode_matmul(
            encoder,
            &self.norm_buf,
            layer_prefix,
            "up_proj",
            &self.ffn_up_buf,
            seq_len,
            hd,
            inter,
        );

        self.encode_silu_mul(
            encoder,
            &self.ffn_gate_buf,
            &self.ffn_up_buf,
            &self.ffn_silu_buf,
            inter * seq_len,
        );

        encoder.copy_buffer_to_buffer(
            &self.ffn_silu_buf,
            0,
            &saved.silu_gate_output,
            0,
            (s * inter as usize * 4) as u64,
        );

        self.encode_matmul(
            encoder,
            &self.ffn_silu_buf,
            layer_prefix,
            "down_proj",
            &self.norm_buf,
            seq_len,
            inter,
            hd,
        );

        self.encode_residual(
            encoder,
            &self.ffn_out_buf,
            &self.norm_buf,
            &self.hidden_buf,
            hd * seq_len,
        );

        Ok(())
    }

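    /// Profiling variant of `encode_forward_layer_training`: each op is
    /// submitted and waited on individually, and a per-op wall-clock trace
    /// is printed at the end.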
    pub fn forward_layer_traced(
        &self,
        seq_len: u32,
        layer_prefix: &str,
        saved: &LayerActivations,
        lora: Option<&QkvLoRA<'_>>,
    ) -> Result<(), String> {
        let hd = self.hidden_dim;
        let q_dim = self.num_heads * self.head_dim;
        let kv_dim = self.num_kv_heads * self.head_dim;
        let inter = self.intermediate_dim;
        let s = seq_len as usize;

        let norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.attn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;

        // Encode, submit, and block on one op at a time, recording its wall
        // clock. This measures submit-plus-wait latency, not pure GPU time.
        let mut trace = Vec::new();
        let mut run = |name: &str, f: &dyn Fn(&mut wgpu::CommandEncoder)| {
            let mut enc = self.device.create_command_encoder(&Default::default());
            f(&mut enc);
            self.queue.submit(Some(enc.finish()));
            let t = std::time::Instant::now();
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            trace.push((name.to_string(), t.elapsed().as_millis() as u64));
        };

        run("rmsnorm1", &|e| self.encode_rmsnorm(e, &self.hidden_buf, norm_w, &self.norm_buf, hd));
        {
            let mut e = self.device.create_command_encoder(&Default::default());
            e.copy_buffer_to_buffer(
                &self.norm_buf,
                0,
                &saved.attn_norm_out,
                0,
                (s * hd as usize * 4) as u64,
            );
            self.queue.submit(Some(e.finish()));
        }
        run("q_proj", &|e| {
            self.encode_matmul(
                e,
                &self.norm_buf,
                layer_prefix,
                "q_proj",
                &self.q_buf,
                seq_len,
                hd,
                q_dim,
            )
        });
        run("k_proj", &|e| {
            self.encode_matmul(
                e,
                &self.norm_buf,
                layer_prefix,
                "k_proj",
                &self.k_buf,
                seq_len,
                hd,
                kv_dim,
            )
        });
        run("v_proj", &|e| {
            self.encode_matmul(
                e,
                &self.norm_buf,
                layer_prefix,
                "v_proj",
                &self.v_buf,
                seq_len,
                hd,
                kv_dim,
            )
        });
        if let Some(lr) = lora {
            run("lora_qkv", &|e| {
                self.encode_lora_addmm(
                    e,
                    &saved.attn_norm_out,
                    lr.q_a,
                    lr.q_b,
                    &self.q_buf,
                    seq_len,
                    lr.in_dim,
                    lr.rank,
                    lr.q_dim,
                    lr.scale,
                    lr.lora_pipeline,
                    lr.lora_bgl,
                );
                self.encode_lora_addmm(
                    e,
                    &saved.attn_norm_out,
                    lr.k_a,
                    lr.k_b,
                    &self.k_buf,
                    seq_len,
                    lr.in_dim,
                    lr.rank,
                    lr.kv_dim,
                    lr.scale,
                    lr.lora_pipeline,
                    lr.lora_bgl,
                );
                self.encode_lora_addmm(
                    e,
                    &saved.attn_norm_out,
                    lr.v_a,
                    lr.v_b,
                    &self.v_buf,
                    seq_len,
                    lr.in_dim,
                    lr.rank,
                    lr.kv_dim,
                    lr.scale,
                    lr.lora_pipeline,
                    lr.lora_bgl,
                );
            });
        }
        if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
            run("q_bias", &|e| self.encode_broadcast_bias(e, &self.q_buf, q_bias, seq_len));
        }
        if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
            run("k_bias", &|e| self.encode_broadcast_bias(e, &self.k_buf, k_bias, seq_len));
        }
        if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
            run("v_bias", &|e| self.encode_broadcast_bias(e, &self.v_buf, v_bias, seq_len));
        }
        run("rope_q", &|e| {
            self.encode_batch_rope(e, &self.q_buf, seq_len, self.num_heads, self.head_dim)
        });
        run("rope_k", &|e| {
            self.encode_batch_rope(e, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim)
        });
        run("attention", &|e| self.encode_attention(e, seq_len));
        {
            let mut e = self.device.create_command_encoder(&Default::default());
            e.copy_buffer_to_buffer(
                &self.attn_out_buf,
                0,
                &saved.attn_output,
                0,
                (s * q_dim as usize * 4) as u64,
            );
            self.queue.submit(Some(e.finish()));
        }
        run("o_proj", &|e| {
            self.encode_matmul(
                e,
                &self.attn_out_buf,
                layer_prefix,
                "o_proj",
                &self.q_buf,
                seq_len,
                q_dim,
                hd,
            )
        });
        run("residual1", &|e| {
            self.encode_residual(e, &self.hidden_buf, &self.q_buf, &self.ffn_out_buf, hd * seq_len)
        });
        let ffn_norm_w = self
            .weight_buffers
            .get(&format!("{layer_prefix}.ffn_norm"))
            .ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
        run("rmsnorm2", &|e| {
            self.encode_rmsnorm(e, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd)
        });
        {
            let mut e = self.device.create_command_encoder(&Default::default());
            e.copy_buffer_to_buffer(
                &self.norm_buf,
                0,
                &saved.ffn_norm_out,
                0,
                (s * hd as usize * 4) as u64,
            );
            self.queue.submit(Some(e.finish()));
        }
        run("gate_proj", &|e| {
            self.encode_matmul(
                e,
                &self.norm_buf,
                layer_prefix,
                "gate_proj",
                &self.ffn_gate_buf,
                seq_len,
                hd,
                inter,
            )
        });
        run("up_proj", &|e| {
            self.encode_matmul(
                e,
                &self.norm_buf,
                layer_prefix,
                "up_proj",
                &self.ffn_up_buf,
                seq_len,
                hd,
                inter,
            )
        });
        run("silu", &|e| {
            self.encode_silu_mul(
                e,
                &self.ffn_gate_buf,
                &self.ffn_up_buf,
                &self.ffn_silu_buf,
                inter * seq_len,
            )
        });
        {
            let mut e = self.device.create_command_encoder(&Default::default());
            e.copy_buffer_to_buffer(
                &self.ffn_silu_buf,
                0,
                &saved.silu_gate_output,
                0,
                (s * inter as usize * 4) as u64,
            );
            self.queue.submit(Some(e.finish()));
        }
        run("down_proj", &|e| {
            self.encode_matmul(
                e,
                &self.ffn_silu_buf,
                layer_prefix,
                "down_proj",
                &self.norm_buf,
                seq_len,
                inter,
                hd,
            )
        });
        run("residual2", &|e| {
            self.encode_residual(
                e,
                &self.ffn_out_buf,
                &self.norm_buf,
                &self.hidden_buf,
                hd * seq_len,
            )
        });

        let total: u64 = trace.iter().map(|(_, ms)| ms).sum();
        let parts: Vec<String> = trace.iter().map(|(n, ms)| format!("{n}={ms}")).collect();
        eprintln!("[OP-TRACE] layer {} total={}ms: {}", layer_prefix, total, parts.join(" "));
        Ok(())
    }

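    /// Allocates the per-layer activation buffers that the training forward
    /// pass saves for the backward pass, sized for `seq_len` tokens.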
    pub fn alloc_layer_activations(&self, seq_len: u32) -> LayerActivations {
        let s = seq_len as usize;
        let buf = |size: usize, label: &str| -> wgpu::Buffer {
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: (size * 4) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };
        LayerActivations {
            attn_norm_out: buf(s * self.hidden_dim as usize, "saved_attn_norm"),
            attn_output: buf(s * (self.num_heads * self.head_dim) as usize, "saved_attn_out"),
            ffn_norm_out: buf(s * self.hidden_dim as usize, "saved_ffn_norm"),
            silu_gate_output: buf(s * self.intermediate_dim as usize, "saved_silu"),
            rstd_attn: buf(s, "saved_rstd_attn"),
            rstd_ffn: buf(s, "saved_rstd_ffn"),
            softmax_logsumexp: buf(self.num_heads as usize * s, "saved_logsumexp"),
        }
    }

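    /// Convenience wrapper: encodes one layer's training forward pass into
    /// a fresh encoder and submits it immediately.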
    pub fn forward_layer_training(
        &self,
        seq_len: u32,
        layer_prefix: &str,
    ) -> Result<LayerActivations, String> {
        let saved = self.alloc_layer_activations(seq_len);
        let mut encoder = self.device.create_command_encoder(&Default::default());
        self.encode_forward_layer_training(&mut encoder, seq_len, layer_prefix, &saved, None)?;
        self.queue.submit(Some(encoder.finish()));
        Ok(saved)
    }

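    /// Encodes every layer's training forward pass into a single command
    /// buffer and submits it once, returning the saved activations per layer.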
    pub fn forward_all_layers_training(
        &self,
        seq_len: u32,
        num_layers: usize,
    ) -> Result<Vec<LayerActivations>, String> {
        let mut encoder = self.device.create_command_encoder(&Default::default());
        let mut all_saved = Vec::with_capacity(num_layers);

        for layer_idx in 0..num_layers {
            let prefix = format!("layer.{layer_idx}");
            let saved = self.alloc_layer_activations(seq_len);
            self.encode_forward_layer_training(&mut encoder, seq_len, &prefix, &saved, None)?;
            all_saved.push(saved);
        }

        self.queue.submit(Some(encoder.finish()));
        Ok(all_saved)
    }

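    /// Adds a per-row bias to `buf` by materializing the bias broadcast to
    /// `seq_len` rows and reusing the residual-add pipeline through a
    /// temporary buffer.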
    pub fn encode_broadcast_bias(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        buf: &wgpu::Buffer,
        bias: &[f32],
        seq_len: u32,
    ) {
        let dim = bias.len();
        let mut full_bias = Vec::with_capacity(seq_len as usize * dim);
        for _ in 0..seq_len {
            full_bias.extend_from_slice(bias);
        }
        let bias_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("broadcast_bias"),
            size: (full_bias.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&bias_buf, 0, bytemuck::cast_slice(&full_bias));

        // The target buffer cannot be bound as both read-only input and
        // writable output, so add into a temporary and copy back.
        let total = seq_len * dim as u32;
        let tmp = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bias_tmp"),
            size: (total as usize * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.encode_residual(encoder, buf, &bias_buf, &tmp, total);
        encoder.copy_buffer_to_buffer(&tmp, 0, buf, 0, (total as u64) * 4);
    }

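    /// Dispatches the batched RoPE kernel over every (position, head, pair)
    /// element of `qk_buf` in a single pass.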
    fn encode_batch_rope(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        qk_buf: &wgpu::Buffer,
        seq_len: u32,
        num_heads: u32,
        head_dim: u32,
    ) {
        #[repr(C)]
        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
        struct RopeParams {
            seq_len: u32,
            num_heads: u32,
            head_dim: u32,
            _pad: u32,
        }
        let params = RopeParams { seq_len, num_heads, head_dim, _pad: 0 };
        let params_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("batch_rope_params"),
            size: 16,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("batch_rope_bg"),
            layout: &self.batch_rope_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: qk_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: params_buf.as_entire_binding() },
            ],
        });
        let total = seq_len * num_heads * head_dim;
        let wg = total.div_ceil(256);
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.batch_rope_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(wg, 1, 1);
    }

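    /// Encodes the causal-attention shader over the Q/K/V scratch buffers:
    /// one workgroup per (head, query position) pair.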
    fn encode_attention(&self, encoder: &mut wgpu::CommandEncoder, seq_len: u32) {
        let params = [seq_len, self.num_heads, self.num_kv_heads, self.head_dim];
        let params_buf = self.make_uniform(&params);

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.attention_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: self.q_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: self.k_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: self.v_buf.as_entire_binding() },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: self.attn_out_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.attention_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(self.num_heads, seq_len, 1);
    }

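    /// LoRA delta applied as out += scale * (input @ A) @ B: two tiled GEMMs
    /// through scratch buffers plus a residual add and copy-back. The passed
    /// fused pipeline and layout are currently unused.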
    #[allow(clippy::too_many_arguments)]
    fn encode_lora_addmm(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        input: &wgpu::Buffer,
        lora_a: &wgpu::Buffer,
        lora_b: &wgpu::Buffer,
        output: &wgpu::Buffer,
        seq_len: u32,
        in_dim: u32,
        rank: u32,
        out_dim: u32,
        scale: f32,
        _pipeline: &wgpu::ComputePipeline,
        _bgl: &wgpu::BindGroupLayout,
    ) {
        // temp = input @ A, shape [seq_len, rank]
        let temp_size = (seq_len * rank) as u64 * 4;
        let temp = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_temp"),
            size: temp_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.encode_tiled_gemm(encoder, input, lora_a, &temp, seq_len, in_dim, rank, 1.0);

        // delta = scale * (temp @ B), shape [seq_len, out_dim]
        let delta_size = (seq_len * out_dim) as u64 * 4;
        let delta = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_delta"),
            size: delta_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.encode_tiled_gemm(encoder, &temp, lora_b, &delta, seq_len, rank, out_dim, scale);

        // output += delta, via a temporary sum buffer and copy-back.
        let sum_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("lora_sum"),
            size: delta_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.encode_residual(encoder, output, &delta, &sum_buf, seq_len * out_dim);
        encoder.copy_buffer_to_buffer(&sum_buf, 0, output, 0, delta_size);
    }

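    /// C = alpha * (A @ B) using the 64x64-tile GEMM pipeline; alpha is
    /// passed to the shader as its raw f32 bit pattern in the u32 params.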
    fn encode_tiled_gemm(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        a: &wgpu::Buffer,
        b: &wgpu::Buffer,
        c: &wgpu::Buffer,
        m: u32,
        k: u32,
        n: u32,
        alpha: f32,
    ) {
        let params = [m, k, n, alpha.to_bits()];
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.matmul_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: c.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.tiled_matmul_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
    }

    fn encode_rmsnorm(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        input: &wgpu::Buffer,
        weight: &wgpu::Buffer,
        output: &wgpu::Buffer,
        dim: u32,
    ) {
        let params = [dim, 0u32, 0, 0];
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.elementwise_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        // One workgroup per row; the row count is derived from the buffer size.
        let num_rows = (input.size() / (dim as u64 * 4)).max(1) as u32;
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.rmsnorm_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(1, num_rows, 1);
    }

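    /// Projection matmul with pipeline selection: a Q4_K dequantizing GEMV
    /// when a quantized weight exists (m == 1), the f32 GEMV for single
    /// rows, the tiled GEMM for m >= 4, and the naive matmul otherwise.
    /// Projections without an uploaded weight are silently skipped.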
    fn encode_matmul(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        input: &wgpu::Buffer,
        layer_prefix: &str,
        proj_name: &str,
        output: &wgpu::Buffer,
        m: u32,
        k: u32,
        n: u32,
    ) {
        if m == 1 && self.encode_q4k_gemv(encoder, input, output, layer_prefix, proj_name, n, k) {
            return;
        }
        let weight_key = format!("{layer_prefix}.{proj_name}");
        let weight = match self.weight_buffers.get(&weight_key) {
            Some(w) => w,
            None => return, // weight not on the GPU; caller falls back to CPU
        };
        // GEMV params are (n, k); GEMM params are (m, k, n, alpha_bits).
        let params = if m == 1 { [n, k, 0u32, 0u32] } else { [m, k, n, 1.0_f32.to_bits()] };
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.matmul_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        if m == 1 {
            pass.set_pipeline(&self.gemv_pipeline);
            pass.set_bind_group(0, &bg, &[]);
            pass.dispatch_workgroups(n, 1, 1);
        } else if m >= 4 {
            pass.set_pipeline(&self.tiled_matmul_pipeline);
            pass.set_bind_group(0, &bg, &[]);
            pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
        } else {
            pass.set_pipeline(&self.matmul_pipeline);
            pass.set_bind_group(0, &bg, &[]);
            pass.dispatch_workgroups(m.div_ceil(16), n.div_ceil(16), 1);
        }
    }

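    /// Encodes a dequantizing GEMV against a Q4_K weight if one was uploaded
    /// under this name; returns false so the caller can fall back to the
    /// f32 path otherwise.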
    fn encode_q4k_gemv(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        input: &wgpu::Buffer,
        output: &wgpu::Buffer,
        layer_prefix: &str,
        proj_name: &str,
        n: u32,
        k: u32,
    ) -> bool {
        let weight_key = format!("{layer_prefix}.{proj_name}");
        let weight = match self.q4k_weights.get(&weight_key) {
            Some(w) => w,
            None => return false,
        };
        // Q4_K packs values in 256-element superblocks.
        let num_superblocks = k.div_ceil(256);
        let params = [n, k, num_superblocks, 0u32];
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.matmul_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.q4k_gemv_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(n, 1, 1);
        true
    }

    fn encode_silu_mul(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        gate: &wgpu::Buffer,
        up: &wgpu::Buffer,
        output: &wgpu::Buffer,
        dim: u32,
    ) {
        let params = [dim, 0u32, 0, 0];
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.elementwise_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: gate.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: up.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.silu_mul_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
    }

    fn encode_residual(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        a: &wgpu::Buffer,
        b: &wgpu::Buffer,
        output: &wgpu::Buffer,
        dim: u32,
    ) {
        let params = [dim, 0u32, 0, 0];
        let params_buf = self.make_uniform(&params);
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.elementwise_bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
            ],
        });
        let mut pass = encoder.begin_compute_pass(&Default::default());
        pass.set_pipeline(&self.residual_pipeline);
        pass.set_bind_group(0, &bg, &[]);
        pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
    }

    fn make_uniform(&self, data: &[u32; 4]) -> wgpu::Buffer {
        use wgpu::util::DeviceExt;
        self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: None,
            contents: bytemuck::cast_slice(data),
            usage: wgpu::BufferUsages::UNIFORM,
        })
    }
}

fn bgl_storage(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Uniform,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}