1use super::cuda_tensor::{CudaTensorError, Result};
25
26#[cfg(not(feature = "gpu"))]
28pub struct WgpuTrainer;
29
30#[cfg(not(feature = "gpu"))]
31impl WgpuTrainer {
32 pub fn new() -> Result<Self> {
33 Err(CudaTensorError::CudaNotAvailable("Compiled without GPU support".into()))
34 }
35}
36
37#[cfg(feature = "gpu")]
38use trueno::backends::gpu::wgpu;
39
/// WGSL source of the tiled GEMM kernel, taken from trueno's GPU backend rather than
/// duplicated here.
///
/// `WgpuTrainer` binds it as: 0 = A and 1 = B (read-only storage), 2 = C (read-write
/// storage), 3 = `GemmDims` uniform. The uniform layout expected by the kernel is defined in
/// trueno and is not visible in this file. NOTE(review): confirm it matches `GemmDims`.
#[cfg(feature = "gpu")]
const GEMM_SHADER: &str = trueno::backends::gpu::shaders::TILED_GEMM_SHADER;
44
/// WGSL kernel performing one AdamW update per element, in place.
///
/// Bindings: 0 = params (read-write), 1 = grads (read-only), 2 = first moment `m_state`,
/// 3 = second moment `v_state` (both read-write), 4 = `AdamWParams` uniform.
///
/// The uniform's field order and types (seven f32 followed by a u32, 32 bytes) must match
/// the Rust `AdamWConfig` struct. The bias-correction terms `1 - beta^t` are computed on the
/// host, so the kernel itself is step-agnostic.
///
/// One invocation handles one element (`@workgroup_size(256)`). The `i >= cfg.n` guard
/// discards the surplus invocations of the last workgroup.
#[cfg(feature = "gpu")]
const ADAMW_SHADER: &str = r"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m_state: array<f32>;
@group(0) @binding(3) var<storage, read_write> v_state: array<f32>;

struct AdamWParams {
 lr: f32,
 beta1: f32,
 beta2: f32,
 eps: f32,
 weight_decay: f32,
 bias_correction1: f32, // 1 - beta1^t
 bias_correction2: f32, // 1 - beta2^t
 n: u32,
}

@group(0) @binding(4) var<uniform> cfg: AdamWParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let i = gid.x;
 if (i >= cfg.n) { return; }

 let g = grads[i];
 var m = cfg.beta1 * m_state[i] + (1.0 - cfg.beta1) * g;
 var v = cfg.beta2 * v_state[i] + (1.0 - cfg.beta2) * g * g;
 m_state[i] = m;
 v_state[i] = v;

 let m_hat = m / cfg.bias_correction1;
 let v_hat = v / cfg.bias_correction2;

 // Decoupled weight decay (AdamW, not Adam with L2)
 params[i] = params[i] - cfg.lr * (m_hat / (sqrt(v_hat) + cfg.eps) + cfg.weight_decay * params[i]);
}
";
84
/// WGSL kernel that multiplies every element of `grads` by `cfg.scale`, in place.
///
/// Bindings: 0 = gradients (read-write storage), 1 = `ClipParams` uniform. The uniform's
/// layout must match the Rust `ClipConfig` struct: f32 scale, u32 count and two u32 pads,
/// 16 bytes in total.
///
/// The `i >= cfg.n` guard discards the surplus invocations of the last 256-wide workgroup.
/// Gated on the `gpu` feature like the other shader constants, because only the GPU trainer
/// consumes it.
#[cfg(feature = "gpu")]
const GRAD_CLIP_SHADER: &str = r"
@group(0) @binding(0) var<storage, read_write> grads: array<f32>;

struct ClipParams {
 scale: f32,
 n: u32,
 _pad0: u32,
 _pad1: u32,
}

@group(0) @binding(1) var<uniform> cfg: ClipParams;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let i = gid.x;
 if (i >= cfg.n) { return; }
 grads[i] = grads[i] * cfg.scale;
}
";
105
/// GPU training context.
///
/// Holds one wgpu device/queue pair plus the compiled compute pipelines for the tiled GEMM,
/// AdamW and gradient-clipping kernels.
pub struct WgpuTrainer {
    /// Device used to create buffers, bind groups, pipelines and command encoders.
    device: wgpu::Device,
    /// Queue that receives buffer writes and command submissions.
    queue: wgpu::Queue,
    /// Pipeline running `GEMM_SHADER`.
    matmul_pipeline: wgpu::ComputePipeline,
    /// GEMM layout: A and B (read-only storage), C (read-write storage), dims uniform.
    matmul_bgl: wgpu::BindGroupLayout,
    /// Pipeline running `ADAMW_SHADER`.
    adamw_pipeline: wgpu::ComputePipeline,
    /// AdamW layout: params, grads (read-only), m_state, v_state, config uniform.
    adamw_bgl: wgpu::BindGroupLayout,
    /// Pipeline running `GRAD_CLIP_SHADER`.
    clip_pipeline: wgpu::ComputePipeline,
    /// Clip layout: gradients (read-write storage), config uniform.
    clip_bgl: wgpu::BindGroupLayout,
    /// Number of AdamW steps taken so far; drives the `1 - beta^t` bias correction.
    step: u32,
}
118
/// Uniform block for the GEMM kernel: matrix dimensions plus the scale factor.
///
/// `#[repr(C)]` plus `Pod` lets it be uploaded byte-for-byte with `bytemuck::bytes_of`.
/// Four u32 fields give a 16-byte block. `alpha` is passed as its raw f32 bit pattern
/// (`alpha.to_bits()`) so every field is a u32.
/// NOTE(review): field order must match the uniform struct in trueno's `TILED_GEMM_SHADER`
/// (not visible here) — confirm.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GemmDims {
    // Rows of A and C.
    m: u32,
    // Shared inner dimension (columns of A, rows of B).
    k: u32,
    // Columns of B and C.
    n: u32,
    // `f32::to_bits` of the output scale factor.
    alpha_bits: u32,
}
127
/// Host-side mirror of the `AdamWParams` uniform in `ADAMW_SHADER`.
///
/// Field order and types must match the WGSL struct exactly: seven f32 followed by one u32,
/// 32 bytes. The struct is uploaded byte-for-byte via `bytemuck::bytes_of`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct AdamWConfig {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    // 1 - beta1^t, precomputed on the host.
    bias_correction1: f32,
    // 1 - beta2^t, precomputed on the host.
    bias_correction2: f32,
    // Number of elements the kernel should update (bounds the invocation index).
    n: u32,
}
140
/// Host-side mirror of the `ClipParams` uniform in `GRAD_CLIP_SHADER`.
///
/// The two explicit padding words round the block up to 16 bytes. Field order must match
/// the WGSL struct.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct ClipConfig {
    // Factor every gradient element is multiplied by.
    scale: f32,
    // Number of elements to scale (bounds the invocation index).
    n: u32,
    _pad0: u32,
    _pad1: u32,
}
149
150impl WgpuTrainer {
151 pub fn new() -> Result<Self> {
153 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
154 backends: wgpu::Backends::VULKAN | wgpu::Backends::METAL,
155 ..Default::default()
156 });
157
158 let adapter = trueno::backends::gpu::runtime::block_on(instance.request_adapter(
159 &wgpu::RequestAdapterOptions {
160 power_preference: wgpu::PowerPreference::HighPerformance,
161 ..Default::default()
162 },
163 ))
164 .map_err(|e| CudaTensorError::CudaNotAvailable(format!("No wgpu adapter: {e}")))?;
165
166 let (device, queue) = trueno::backends::gpu::runtime::block_on(adapter.request_device(
167 &wgpu::DeviceDescriptor {
168 label: Some("WgpuTrainer"),
169 required_features: wgpu::Features::empty(),
170 required_limits: wgpu::Limits {
171 max_storage_buffer_binding_size:
172 adapter.limits().max_storage_buffer_binding_size,
173 max_buffer_size: adapter.limits().max_buffer_size,
174 ..Default::default()
175 },
176 memory_hints: wgpu::MemoryHints::Performance,
177 experimental_features: Default::default(),
178 trace: Default::default(),
179 },
180 ))
181 .map_err(|e| CudaTensorError::CudaNotAvailable(format!("wgpu device: {e}")))?;
182
183 let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
185 label: Some("tiled_gemm"),
186 source: wgpu::ShaderSource::Wgsl(GEMM_SHADER.into()),
187 });
188 let matmul_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
189 label: Some("gemm_bgl"),
190 entries: &[
191 storage_entry(0, true),
192 storage_entry(1, true),
193 storage_entry(2, false),
194 uniform_entry(3),
195 ],
196 });
197 let matmul_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
198 label: Some("gemm_pl"),
199 bind_group_layouts: &[&matmul_bgl],
200 push_constant_ranges: &[],
201 });
202 let matmul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
203 label: Some("tiled_gemm_pipe"),
204 layout: Some(&matmul_pl),
205 module: &matmul_shader,
206 entry_point: Some("main"),
207 compilation_options: Default::default(),
208 cache: None,
209 });
210
211 let adamw_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
213 label: Some("adamw"),
214 source: wgpu::ShaderSource::Wgsl(ADAMW_SHADER.into()),
215 });
216 let adamw_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
217 label: Some("adamw_bgl"),
218 entries: &[
219 storage_entry(0, false), storage_entry(1, true), storage_entry(2, false), storage_entry(3, false), uniform_entry(4), ],
225 });
226 let adamw_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
227 label: Some("adamw_pl"),
228 bind_group_layouts: &[&adamw_bgl],
229 push_constant_ranges: &[],
230 });
231 let adamw_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
232 label: Some("adamw_pipe"),
233 layout: Some(&adamw_pl),
234 module: &adamw_shader,
235 entry_point: Some("main"),
236 compilation_options: Default::default(),
237 cache: None,
238 });
239
240 let clip_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
242 label: Some("grad_clip"),
243 source: wgpu::ShaderSource::Wgsl(GRAD_CLIP_SHADER.into()),
244 });
245 let clip_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
246 label: Some("clip_bgl"),
247 entries: &[storage_entry(0, false), uniform_entry(1)],
248 });
249 let clip_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
250 label: Some("clip_pl"),
251 bind_group_layouts: &[&clip_bgl],
252 push_constant_ranges: &[],
253 });
254 let clip_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
255 label: Some("clip_pipe"),
256 layout: Some(&clip_pl),
257 module: &clip_shader,
258 entry_point: Some("main"),
259 compilation_options: Default::default(),
260 cache: None,
261 });
262
263 Ok(Self {
264 device,
265 queue,
266 matmul_pipeline,
267 matmul_bgl,
268 adamw_pipeline,
269 adamw_bgl,
270 clip_pipeline,
271 clip_bgl,
272 step: 0,
273 })
274 }
275
276 pub fn upload(&self, data: &[f32]) -> wgpu::Buffer {
278 let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
279 label: Some("upload_data"),
280 size: (data.len() * 4) as u64,
281 usage: wgpu::BufferUsages::STORAGE
282 | wgpu::BufferUsages::COPY_SRC
283 | wgpu::BufferUsages::COPY_DST,
284 mapped_at_creation: false,
285 });
286 self.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
287 buf
288 }
289
290 pub fn zeros(&self, len: usize) -> wgpu::Buffer {
292 self.upload(&vec![0.0f32; len])
293 }
294
295 pub fn download(&self, buffer: &wgpu::Buffer) -> Vec<f32> {
297 let size = buffer.size();
298 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
300 label: Some("download_staging"),
301 size,
302 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
303 mapped_at_creation: false,
304 });
305 let mut encoder = self.device.create_command_encoder(&Default::default());
306 encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size);
307 self.queue.submit(Some(encoder.finish()));
308
309 let slice = staging.slice(..);
310 let (tx, rx) = std::sync::mpsc::channel();
311 slice.map_async(wgpu::MapMode::Read, move |r| {
312 tx.send(r).ok();
313 });
314 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
315 rx.recv().unwrap().unwrap();
316
317 let data = slice.get_mapped_range();
318 let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
319 drop(data);
320 staging.unmap();
321 result
322 }
323
324 pub fn matmul_forward(
326 &self,
327 a: &wgpu::Buffer,
328 b: &wgpu::Buffer,
329 c: &wgpu::Buffer,
330 m: u32,
331 k: u32,
332 n: u32,
333 ) {
334 self.dispatch_gemm(a, b, c, m, k, n, 1.0);
335 }
336
337 pub fn matmul_backward(
344 &self,
345 a: &wgpu::Buffer, b: &wgpu::Buffer, grad_c: &wgpu::Buffer, grad_a: &wgpu::Buffer, grad_b: &wgpu::Buffer, m: u32,
351 k: u32,
352 n: u32,
353 ) {
354 debug_assert!(
356 m > 0 && k > 0 && n > 0,
357 "Contract matmul_backward: dimensions must be positive"
358 );
359 let b_data = self.download(b);
384 let mut bt_data = vec![0.0f32; (k * n) as usize];
385 for i in 0..k as usize {
386 for j in 0..n as usize {
387 bt_data[j * k as usize + i] = b_data[i * n as usize + j];
388 }
389 }
390 let bt = self.upload(&bt_data);
391 self.dispatch_gemm(grad_c, &bt, grad_a, m, n, k, 1.0);
392
393 let a_data = self.download(a);
395 let mut at_data = vec![0.0f32; (m * k) as usize];
396 for i in 0..m as usize {
397 for j in 0..k as usize {
398 at_data[j * m as usize + i] = a_data[i * k as usize + j];
399 }
400 }
401 let at = self.upload(&at_data);
402 self.dispatch_gemm(&at, grad_c, grad_b, k, m, n, 1.0);
403 }
404
405 pub fn adamw_step(
407 &mut self,
408 params: &wgpu::Buffer,
409 grads: &wgpu::Buffer,
410 m_state: &wgpu::Buffer,
411 v_state: &wgpu::Buffer,
412 lr: f32,
413 beta1: f32,
414 beta2: f32,
415 eps: f32,
416 weight_decay: f32,
417 ) {
418 self.step += 1;
419 let n = (params.size() / 4) as u32;
420 let bc1 = 1.0 - beta1.powi(self.step as i32);
421 let bc2 = 1.0 - beta2.powi(self.step as i32);
422
423 let cfg = AdamWConfig {
424 lr,
425 beta1,
426 beta2,
427 eps,
428 weight_decay,
429 bias_correction1: bc1,
430 bias_correction2: bc2,
431 n,
432 };
433 let cfg_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
434 label: None,
435 size: std::mem::size_of::<AdamWConfig>() as u64,
436 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
437 mapped_at_creation: false,
438 });
439 self.queue.write_buffer(&cfg_buf, 0, bytemuck::bytes_of(&cfg));
440
441 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
442 label: None,
443 layout: &self.adamw_bgl,
444 entries: &[
445 wgpu::BindGroupEntry { binding: 0, resource: params.as_entire_binding() },
446 wgpu::BindGroupEntry { binding: 1, resource: grads.as_entire_binding() },
447 wgpu::BindGroupEntry { binding: 2, resource: m_state.as_entire_binding() },
448 wgpu::BindGroupEntry { binding: 3, resource: v_state.as_entire_binding() },
449 wgpu::BindGroupEntry { binding: 4, resource: cfg_buf.as_entire_binding() },
450 ],
451 });
452
453 let mut encoder = self.device.create_command_encoder(&Default::default());
454 {
455 let mut pass = encoder.begin_compute_pass(&Default::default());
456 pass.set_pipeline(&self.adamw_pipeline);
457 pass.set_bind_group(0, &bg, &[]);
458 pass.dispatch_workgroups(n.div_ceil(256), 1, 1);
459 }
460 self.queue.submit(Some(encoder.finish()));
461 }
462
463 pub fn clip_gradients(&self, grads: &wgpu::Buffer, max_norm: f32) {
465 let grad_data = self.download(grads);
466 let grad_norm: f32 = grad_data.iter().map(|x| x * x).sum::<f32>().sqrt();
467 let scale = if grad_norm > max_norm {
468 max_norm / grad_norm
469 } else {
470 return; };
472
473 let n = grad_data.len() as u32;
474 let cfg = ClipConfig { scale, n, _pad0: 0, _pad1: 0 };
475 let cfg_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
476 label: None,
477 size: 16,
478 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
479 mapped_at_creation: false,
480 });
481 self.queue.write_buffer(&cfg_buf, 0, bytemuck::bytes_of(&cfg));
482
483 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
484 label: None,
485 layout: &self.clip_bgl,
486 entries: &[
487 wgpu::BindGroupEntry { binding: 0, resource: grads.as_entire_binding() },
488 wgpu::BindGroupEntry { binding: 1, resource: cfg_buf.as_entire_binding() },
489 ],
490 });
491
492 let mut encoder = self.device.create_command_encoder(&Default::default());
493 {
494 let mut pass = encoder.begin_compute_pass(&Default::default());
495 pass.set_pipeline(&self.clip_pipeline);
496 pass.set_bind_group(0, &bg, &[]);
497 pass.dispatch_workgroups(n.div_ceil(256), 1, 1);
498 }
499 self.queue.submit(Some(encoder.finish()));
500 }
501
502 pub fn step_count(&self) -> u32 {
504 self.step
505 }
506
507 pub fn reset_step(&mut self) {
509 self.step = 0;
510 }
511
512 pub fn queue_ref(&self) -> &wgpu::Queue {
514 &self.queue
515 }
516
517 pub fn device_ref(&self) -> &wgpu::Device {
519 &self.device
520 }
521
522 pub fn from_device(device: wgpu::Device, queue: wgpu::Queue) -> Result<Self> {
525 let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
526 label: Some("tiled_gemm"),
527 source: wgpu::ShaderSource::Wgsl(GEMM_SHADER.into()),
528 });
529 let matmul_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
530 label: Some("gemm_bgl"),
531 entries: &[
532 storage_entry(0, true),
533 storage_entry(1, true),
534 storage_entry(2, false),
535 uniform_entry(3),
536 ],
537 });
538 let matmul_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
539 label: Some("gemm_pl"),
540 bind_group_layouts: &[&matmul_bgl],
541 push_constant_ranges: &[],
542 });
543 let matmul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
544 label: Some("tiled_gemm_pipe"),
545 layout: Some(&matmul_pl),
546 module: &matmul_shader,
547 entry_point: Some("main"),
548 compilation_options: Default::default(),
549 cache: None,
550 });
551
552 let adamw_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
553 label: Some("adamw"),
554 source: wgpu::ShaderSource::Wgsl(ADAMW_SHADER.into()),
555 });
556 let adamw_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
557 label: Some("adamw_bgl"),
558 entries: &[
559 storage_entry(0, false),
560 storage_entry(1, true),
561 storage_entry(2, false),
562 storage_entry(3, false),
563 uniform_entry(4),
564 ],
565 });
566 let adamw_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
567 label: Some("adamw_pl"),
568 bind_group_layouts: &[&adamw_bgl],
569 push_constant_ranges: &[],
570 });
571 let adamw_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
572 label: Some("adamw_pipe"),
573 layout: Some(&adamw_pl),
574 module: &adamw_shader,
575 entry_point: Some("main"),
576 compilation_options: Default::default(),
577 cache: None,
578 });
579
580 let clip_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
581 label: Some("grad_clip"),
582 source: wgpu::ShaderSource::Wgsl(GRAD_CLIP_SHADER.into()),
583 });
584 let clip_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
585 label: Some("clip_bgl"),
586 entries: &[storage_entry(0, false), uniform_entry(1)],
587 });
588 let clip_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
589 label: Some("clip_pl"),
590 bind_group_layouts: &[&clip_bgl],
591 push_constant_ranges: &[],
592 });
593 let clip_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
594 label: Some("clip_pipe"),
595 layout: Some(&clip_pl),
596 module: &clip_shader,
597 entry_point: Some("main"),
598 compilation_options: Default::default(),
599 cache: None,
600 });
601
602 Ok(Self {
603 device,
604 queue,
605 matmul_pipeline,
606 matmul_bgl,
607 adamw_pipeline,
608 adamw_bgl,
609 clip_pipeline,
610 clip_bgl,
611 step: 0,
612 })
613 }
614
615 fn dispatch_gemm(
618 &self,
619 a: &wgpu::Buffer,
620 b: &wgpu::Buffer,
621 c: &wgpu::Buffer,
622 m: u32,
623 k: u32,
624 n: u32,
625 alpha: f32,
626 ) {
627 let max_binding = u64::from(self.device.limits().max_storage_buffer_binding_size);
630 let b_bytes = u64::from(k) * u64::from(n) * 4;
631 if b_bytes > max_binding {
632 let max_n_chunk = (max_binding / 4 / u64::from(k)) as u32;
633 let max_n_chunk = max_n_chunk.max(1);
634
635 let b_data = self.download(b);
642 let mut n_start = 0u32;
643 while n_start < n {
644 let chunk_n = (n - n_start).min(max_n_chunk);
645 let mut b_chunk = vec![0.0f32; (k * chunk_n) as usize];
646 for row in 0..k as usize {
647 let src_start = row * n as usize + n_start as usize;
648 let dst_start = row * chunk_n as usize;
649 b_chunk[dst_start..dst_start + chunk_n as usize]
650 .copy_from_slice(&b_data[src_start..src_start + chunk_n as usize]);
651 }
652 let b_chunk_buf = self.upload(&b_chunk);
653 let c_chunk_buf = self.zeros((m * chunk_n) as usize);
654 self.dispatch_gemm(a, &b_chunk_buf, &c_chunk_buf, m, k, chunk_n, alpha);
655 let c_chunk_data = self.download(&c_chunk_buf);
658 let mut c_data =
660 if n_start == 0 { vec![0.0f32; (m * n) as usize] } else { self.download(c) };
661 for row in 0..m as usize {
662 let dst_start = row * n as usize + n_start as usize;
663 let src_start = row * chunk_n as usize;
664 c_data[dst_start..dst_start + chunk_n as usize]
665 .copy_from_slice(&c_chunk_data[src_start..src_start + chunk_n as usize]);
666 }
667 self.queue.write_buffer(c, 0, bytemuck::cast_slice(&c_data));
668 n_start += chunk_n;
669 }
670 return;
671 }
672
673 let dims = GemmDims { m, k, n, alpha_bits: alpha.to_bits() };
674 let dims_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
675 label: None,
676 size: 16,
677 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
678 mapped_at_creation: false,
679 });
680 self.queue.write_buffer(&dims_buf, 0, bytemuck::bytes_of(&dims));
681
682 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
683 label: None,
684 layout: &self.matmul_bgl,
685 entries: &[
686 wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
687 wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
688 wgpu::BindGroupEntry { binding: 2, resource: c.as_entire_binding() },
689 wgpu::BindGroupEntry { binding: 3, resource: dims_buf.as_entire_binding() },
690 ],
691 });
692
693 let mut encoder = self.device.create_command_encoder(&Default::default());
694 {
695 let mut pass = encoder.begin_compute_pass(&Default::default());
696 pass.set_pipeline(&self.matmul_pipeline);
697 pass.set_bind_group(0, &bg, &[]);
698 pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
699 }
700 self.queue.submit(Some(encoder.finish()));
701 }
702}
703
704fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
705 wgpu::BindGroupLayoutEntry {
706 binding,
707 visibility: wgpu::ShaderStages::COMPUTE,
708 ty: wgpu::BindingType::Buffer {
709 ty: wgpu::BufferBindingType::Storage { read_only },
710 has_dynamic_offset: false,
711 min_binding_size: None,
712 },
713 count: None,
714 }
715}
716
717fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
718 wgpu::BindGroupLayoutEntry {
719 binding,
720 visibility: wgpu::ShaderStages::COMPUTE,
721 ty: wgpu::BindingType::Buffer {
722 ty: wgpu::BufferBindingType::Uniform,
723 has_dynamic_offset: false,
724 min_binding_size: None,
725 },
726 count: None,
727 }
728}
729
// Unit tests live in a separate file, `wgpu_training_tests.rs`, located via `#[path]`.
#[cfg(test)]
#[path = "wgpu_training_tests.rs"]
mod tests;