1#[cfg(feature = "gpu")]
21use crate::{
22 autograd::{wgpu_cross_entropy::WgslCrossEntropy, wgpu_training::WgpuTrainer},
23 finetune::instruct_pipeline::InstructStepResult,
24 tokenizer::HfTokenizer,
25};
26#[cfg(feature = "gpu")]
27use trueno::backends::gpu::{wgpu, WgslForwardPass};
28
/// LoRA adapter state for a single transformer layer.
#[cfg(feature = "gpu")]
pub struct LayerLoRA {
    /// One entry per adapted projection of this layer, in the order they were
    /// created by `WgpuInstructPipeline::new`.
    pub projections: Vec<LoRAProjection>,
}
35
/// GPU-resident LoRA factors and optimizer state for one projection matrix.
#[cfg(feature = "gpu")]
pub struct LoRAProjection {
    /// LoRA `A` factor; initialised with `in_dim * rank` values.
    pub a: wgpu::Buffer,
    /// LoRA `B` factor; initialised with `rank * out_dim` zeros.
    pub b: wgpu::Buffer,
    /// First optimizer state buffer for `a` (zero-initialised, handed to `adamw_step`).
    pub m_a: wgpu::Buffer,
    /// Second optimizer state buffer for `a` (zero-initialised, handed to `adamw_step`).
    pub v_a: wgpu::Buffer,
    /// First optimizer state buffer for `b` (zero-initialised, handed to `adamw_step`).
    pub m_b: wgpu::Buffer,
    /// Second optimizer state buffer for `b` (zero-initialised, handed to `adamw_step`).
    pub v_b: wgpu::Buffer,
    /// Input width of the adapted projection.
    pub in_dim: u32,
    /// Output width of the adapted projection.
    pub out_dim: u32,
    /// Projection name, e.g. `"q_proj"`; matched against the pipeline's target set.
    pub name: String,
}
48
/// Instruction fine-tuning pipeline that runs the forward pass, the loss and the
/// LoRA updates through wgpu compute shaders.
#[cfg(feature = "gpu")]
pub struct WgpuInstructPipeline {
    /// Transformer forward pass; owns the hidden-state buffer written by `train_step`.
    fwd: WgslForwardPass,
    /// Cross-entropy forward/backward kernels.
    cross_entropy: WgslCrossEntropy,
    /// Device/queue owner plus matmul, upload/download and AdamW helpers.
    trainer: WgpuTrainer,
    /// LM-head chunks used for the logits matmul; each entry is `(buffer, chunk width)`.
    lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
    /// LM-head chunks used when back-propagating logit gradients; `(buffer, chunk width)`.
    lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
    /// Compute pipeline for the LoRA add-matmul shader.
    lora_addmm_pipeline: wgpu::ComputePipeline,
    /// Bind group layout matching `lora_addmm_pipeline`.
    lora_addmm_bgl: wgpu::BindGroupLayout,
    /// Compute pipeline for the column-scatter shader.
    scatter_pipeline: wgpu::ComputePipeline,
    /// Compute pipeline for the column-gather shader (shares `scatter_bgl`).
    gather_pipeline: wgpu::ComputePipeline,
    /// Bind group layout shared by the scatter and gather pipelines.
    scatter_bgl: wgpu::BindGroupLayout,
    /// Compute pipeline for the transpose shader.
    transpose_pipeline: wgpu::ComputePipeline,
    /// Bind group layout matching `transpose_pipeline`.
    transpose_bgl: wgpu::BindGroupLayout,
    /// Logits buffer sized `max_seq_len * vocab_size` floats.
    logits_buf: wgpu::Buffer,
    /// Per-position label ids (`max_seq_len` entries).
    labels_buf: wgpu::Buffer,
    /// Per-position losses (`max_seq_len` entries).
    losses_buf: wgpu::Buffer,
    /// Per-position log-sum-exp values shared between CE forward and backward.
    logsumexp_buf: wgpu::Buffer,
    /// LoRA state, one entry per transformer layer.
    lora: Vec<LayerLoRA>,
    /// LoRA rank.
    lora_rank: usize,
    /// LoRA scaling factor (`lora_alpha / lora_rank`).
    lora_scale: f32,
    /// Number of `train_step` calls that reached the optimizer stage.
    lora_step: u32,
    /// Learning rate handed to `adamw_step`.
    learning_rate: f32,
    /// Names of the projections that are trained and exported.
    lora_target_set: Vec<String>,
    /// Number of transformer layers.
    num_layers: usize,
    /// Model hidden width.
    hidden_dim: usize,
    /// Vocabulary size.
    vocab_size: usize,
    /// Sequences longer than this are truncated by `train_step`.
    max_seq_len: usize,
    /// Tokenizer backing `encode`.
    tokenizer: HfTokenizer,
    /// Host-side embedding table, indexed as `token * hidden_dim`.
    embed_weights: Vec<f32>,
    /// Final-norm weights uploaded to the GPU.
    output_norm_gpu: wgpu::Buffer,
    /// Output of the final RMSNorm (`max_seq_len * hidden_dim` floats).
    normed_buf: wgpu::Buffer,
    /// Norm epsilon supplied at construction.
    eps: f32,
}
99
100#[cfg(feature = "gpu")]
101impl WgpuInstructPipeline {
102 pub fn new(
113 fwd: WgslForwardPass,
114 trainer: WgpuTrainer,
115 tokenizer: HfTokenizer,
116 embed_weights: Vec<f32>,
117 output_norm: Vec<f32>,
118 lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
119 lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
120 num_layers: usize,
121 hidden_dim: usize,
122 vocab_size: usize,
123 max_seq_len: usize,
124 num_heads: usize,
125 num_kv_heads: usize,
126 intermediate_dim: usize,
127 lora_rank: usize,
128 lora_alpha: f32,
129 lora_targets: &[&str],
130 eps: f32,
131 learning_rate: f32,
132 ) -> Self {
133 let ce = WgslCrossEntropy::new(trainer.device_ref().clone(), trainer.queue_ref().clone());
134
135 let seq = max_seq_len as u32;
136 let vocab = vocab_size as u32;
137 let make_buf = |size: u64, label: &str| -> wgpu::Buffer {
138 trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
139 label: Some(label),
140 size: size * 4,
141 usage: wgpu::BufferUsages::STORAGE
142 | wgpu::BufferUsages::COPY_SRC
143 | wgpu::BufferUsages::COPY_DST,
144 mapped_at_creation: false,
145 })
146 };
147
148 let scatter_bgl =
150 trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
151 label: Some("scatter_bgl"),
152 entries: &[
153 wgpu::BindGroupLayoutEntry {
154 binding: 0,
155 visibility: wgpu::ShaderStages::COMPUTE,
156 ty: wgpu::BindingType::Buffer {
157 ty: wgpu::BufferBindingType::Storage { read_only: true },
158 has_dynamic_offset: false,
159 min_binding_size: None,
160 },
161 count: None,
162 },
163 wgpu::BindGroupLayoutEntry {
164 binding: 1,
165 visibility: wgpu::ShaderStages::COMPUTE,
166 ty: wgpu::BindingType::Buffer {
167 ty: wgpu::BufferBindingType::Storage { read_only: false },
168 has_dynamic_offset: false,
169 min_binding_size: None,
170 },
171 count: None,
172 },
173 wgpu::BindGroupLayoutEntry {
174 binding: 2,
175 visibility: wgpu::ShaderStages::COMPUTE,
176 ty: wgpu::BindingType::Buffer {
177 ty: wgpu::BufferBindingType::Uniform,
178 has_dynamic_offset: false,
179 min_binding_size: None,
180 },
181 count: None,
182 },
183 ],
184 });
185 let scatter_pl =
186 trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
187 label: Some("scatter_pl"),
188 bind_group_layouts: &[&scatter_bgl],
189 push_constant_ranges: &[],
190 });
191 let scatter_shader =
192 trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
193 label: Some("scatter"),
194 source: wgpu::ShaderSource::Wgsl(
195 trueno::backends::gpu::shaders::COLUMN_SCATTER_SHADER.into(),
196 ),
197 });
198 let scatter_pipeline =
199 trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
200 label: Some("scatter_pipe"),
201 layout: Some(&scatter_pl),
202 module: &scatter_shader,
203 entry_point: Some("main"),
204 compilation_options: Default::default(),
205 cache: None,
206 });
207 let gather_shader =
208 trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
209 label: Some("gather"),
210 source: wgpu::ShaderSource::Wgsl(
211 trueno::backends::gpu::shaders::COLUMN_GATHER_SHADER.into(),
212 ),
213 });
214 let gather_pipeline =
215 trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
216 label: Some("gather_pipe"),
217 layout: Some(&scatter_pl),
218 module: &gather_shader,
219 entry_point: Some("main"),
220 compilation_options: Default::default(),
221 cache: None,
222 });
223
224 let transpose_bgl =
226 trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
227 label: Some("transpose_bgl"),
228 entries: &[
229 wgpu::BindGroupLayoutEntry {
230 binding: 0,
231 visibility: wgpu::ShaderStages::COMPUTE,
232 ty: wgpu::BindingType::Buffer {
233 ty: wgpu::BufferBindingType::Storage { read_only: true },
234 has_dynamic_offset: false,
235 min_binding_size: None,
236 },
237 count: None,
238 },
239 wgpu::BindGroupLayoutEntry {
240 binding: 1,
241 visibility: wgpu::ShaderStages::COMPUTE,
242 ty: wgpu::BindingType::Buffer {
243 ty: wgpu::BufferBindingType::Storage { read_only: false },
244 has_dynamic_offset: false,
245 min_binding_size: None,
246 },
247 count: None,
248 },
249 wgpu::BindGroupLayoutEntry {
250 binding: 2,
251 visibility: wgpu::ShaderStages::COMPUTE,
252 ty: wgpu::BindingType::Buffer {
253 ty: wgpu::BufferBindingType::Uniform,
254 has_dynamic_offset: false,
255 min_binding_size: None,
256 },
257 count: None,
258 },
259 ],
260 });
261 let transpose_pl =
262 trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
263 label: Some("transpose_pl"),
264 bind_group_layouts: &[&transpose_bgl],
265 push_constant_ranges: &[],
266 });
267 let transpose_shader =
268 trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
269 label: Some("transpose"),
270 source: wgpu::ShaderSource::Wgsl(
271 trueno::backends::gpu::shaders::TRANSPOSE_SHADER.into(),
272 ),
273 });
274 let transpose_pipeline =
275 trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
276 label: Some("transpose_pipe"),
277 layout: Some(&transpose_pl),
278 module: &transpose_shader,
279 entry_point: Some("main"),
280 compilation_options: Default::default(),
281 cache: None,
282 });
283
284 let lora_bgl =
286 trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
287 label: Some("lora_bgl"),
288 entries: &[
289 wgpu::BindGroupLayoutEntry {
290 binding: 0,
291 visibility: wgpu::ShaderStages::COMPUTE,
292 ty: wgpu::BindingType::Buffer {
293 ty: wgpu::BufferBindingType::Storage { read_only: true },
294 has_dynamic_offset: false,
295 min_binding_size: None,
296 },
297 count: None,
298 },
299 wgpu::BindGroupLayoutEntry {
300 binding: 1,
301 visibility: wgpu::ShaderStages::COMPUTE,
302 ty: wgpu::BindingType::Buffer {
303 ty: wgpu::BufferBindingType::Storage { read_only: true },
304 has_dynamic_offset: false,
305 min_binding_size: None,
306 },
307 count: None,
308 },
309 wgpu::BindGroupLayoutEntry {
310 binding: 2,
311 visibility: wgpu::ShaderStages::COMPUTE,
312 ty: wgpu::BindingType::Buffer {
313 ty: wgpu::BufferBindingType::Storage { read_only: true },
314 has_dynamic_offset: false,
315 min_binding_size: None,
316 },
317 count: None,
318 },
319 wgpu::BindGroupLayoutEntry {
320 binding: 3,
321 visibility: wgpu::ShaderStages::COMPUTE,
322 ty: wgpu::BindingType::Buffer {
323 ty: wgpu::BufferBindingType::Storage { read_only: false },
324 has_dynamic_offset: false,
325 min_binding_size: None,
326 },
327 count: None,
328 },
329 wgpu::BindGroupLayoutEntry {
330 binding: 4,
331 visibility: wgpu::ShaderStages::COMPUTE,
332 ty: wgpu::BindingType::Buffer {
333 ty: wgpu::BufferBindingType::Uniform,
334 has_dynamic_offset: false,
335 min_binding_size: None,
336 },
337 count: None,
338 },
339 ],
340 });
341 let lora_pl =
342 trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
343 label: Some("lora_pl"),
344 bind_group_layouts: &[&lora_bgl],
345 push_constant_ranges: &[],
346 });
347 let lora_shader = trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
348 label: Some("lora_addmm"),
349 source: wgpu::ShaderSource::Wgsl(
350 trueno::backends::gpu::shaders::LORA_ADDMM_SHADER.into(),
351 ),
352 });
353 let lora_addmm_pipeline =
354 trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
355 label: Some("lora_addmm_pipe"),
356 layout: Some(&lora_pl),
357 module: &lora_shader,
358 entry_point: Some("main"),
359 compilation_options: Default::default(),
360 cache: None,
361 });
362 let lora_addmm_bgl = lora_bgl;
363
364 let r = lora_rank;
367 let scale = lora_alpha / r as f32;
368 let h = hidden_dim;
369 let q_dim = num_heads * (hidden_dim / num_heads);
370 let kv_dim = num_kv_heads * (hidden_dim / num_heads);
371 let inter = intermediate_dim;
372
373 let all_proj_dims: &[(&str, usize, usize)] = &[
374 ("q_proj", h, q_dim),
375 ("k_proj", h, kv_dim),
376 ("v_proj", h, kv_dim),
377 ("o_proj", q_dim, h),
378 ("gate_proj", h, inter),
379 ("up_proj", h, inter),
380 ("down_proj", inter, h),
381 ];
382
383 let _use_all = true; let proj_dims: Vec<(&str, usize, usize)> = all_proj_dims.to_vec();
389 let num_targets = proj_dims.len();
390 let qkv_only = vec!["q_proj".to_string(), "k_proj".to_string(), "v_proj".to_string()];
391 let lora_target_set: Vec<String> =
392 if lora_targets.is_empty() || lora_targets.contains(&"all") {
393 qkv_only
395 } else {
396 lora_targets.iter().map(std::string::ToString::to_string).collect()
397 };
398
399 let mut lora = Vec::with_capacity(num_layers);
400 for layer_idx in 0..num_layers {
401 let mut projections = Vec::with_capacity(num_targets);
402 for &(name, in_d, out_d) in &proj_dims {
403 let std = (2.0 / in_d as f32).sqrt();
405 let a_data: Vec<f32> = (0..in_d * r)
406 .map(|i| ((i as f32 * 0.013 + layer_idx as f32 * 7.0).sin() * std))
407 .collect();
408 let b_data = vec![0.0f32; r * out_d];
410 let zeros_a = vec![0.0f32; in_d * r];
411 let zeros_b = vec![0.0f32; r * out_d];
412
413 projections.push(LoRAProjection {
414 a: trainer.upload(&a_data),
415 b: trainer.upload(&b_data),
416 m_a: trainer.upload(&zeros_a),
417 v_a: trainer.upload(&zeros_a),
418 m_b: trainer.upload(&zeros_b),
419 v_b: trainer.upload(&zeros_b),
420 in_dim: in_d as u32,
421 out_dim: out_d as u32,
422 name: name.to_string(),
423 });
424 }
425 lora.push(LayerLoRA { projections });
426 }
427
428 eprintln!(
429 "[wgpu] LoRA initialized: {num_layers} layers × {num_targets} projections, rank={r}, scale={scale:.2}",
430 );
431
432 let logits_buf = make_buf(u64::from(seq) * u64::from(vocab), "logits");
434 let labels_buf = make_buf(u64::from(seq), "labels");
435 let losses_buf = make_buf(u64::from(seq), "losses");
436 let logsumexp_buf = make_buf(u64::from(seq), "logsumexp");
437 let normed_buf_alloc = make_buf(u64::from(seq) * hidden_dim as u64, "normed");
438 let output_norm_gpu_buf = trainer.upload(&output_norm);
439
440 Self {
441 fwd,
442 cross_entropy: ce,
443 logits_buf,
444 labels_buf,
445 losses_buf,
446 logsumexp_buf,
447 lm_head_t_chunks,
448 lm_head_chunks,
449 lora_addmm_pipeline,
450 lora_addmm_bgl,
451 scatter_pipeline,
452 gather_pipeline,
453 transpose_pipeline,
454 transpose_bgl,
455 scatter_bgl,
456 trainer,
457 lora,
458 lora_rank: r,
459 lora_scale: scale,
460 lora_step: 0,
461 learning_rate,
462 lora_target_set: lora_target_set,
463 num_layers,
464 hidden_dim,
465 vocab_size,
466 max_seq_len,
467 tokenizer,
468 embed_weights,
469 output_norm_gpu: output_norm_gpu_buf,
470 normed_buf: normed_buf_alloc,
471 eps,
472 }
473 }
474
475 fn dispatch_lora_addmm(
478 &self,
479 input: &wgpu::Buffer,
480 lora_a: &wgpu::Buffer,
481 lora_b: &wgpu::Buffer,
482 output: &wgpu::Buffer,
483 seq_len: u32,
484 in_dim: u32,
485 rank: u32,
486 out_dim: u32,
487 ) {
488 #[repr(C)]
489 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
490 struct P {
491 seq_len: u32,
492 in_dim: u32,
493 rank: u32,
494 out_dim: u32,
495 scale: f32,
496 _p0: u32,
497 _p1: u32,
498 _p2: u32,
499 }
500 let params =
501 P { seq_len, in_dim, rank, out_dim, scale: self.lora_scale, _p0: 0, _p1: 0, _p2: 0 };
502 let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
503 label: None,
504 size: 32,
505 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
506 mapped_at_creation: false,
507 });
508 self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
509 let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
510 label: None,
511 layout: &self.lora_addmm_bgl,
512 entries: &[
513 wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
514 wgpu::BindGroupEntry { binding: 1, resource: lora_a.as_entire_binding() },
515 wgpu::BindGroupEntry { binding: 2, resource: lora_b.as_entire_binding() },
516 wgpu::BindGroupEntry { binding: 3, resource: output.as_entire_binding() },
517 wgpu::BindGroupEntry { binding: 4, resource: pbuf.as_entire_binding() },
518 ],
519 });
520 let total = seq_len * out_dim;
521 let wg = total.div_ceil(256);
522 let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
523 let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
524 {
525 let mut pass = encoder.begin_compute_pass(&Default::default());
526 pass.set_pipeline(&self.lora_addmm_pipeline);
527 pass.set_bind_group(0, &bg, &[]);
528 pass.dispatch_workgroups(x, y, 1);
529 }
530 self.trainer.queue_ref().submit(Some(encoder.finish()));
531 }
532
533 fn dispatch_scatter(
535 &self,
536 src: &wgpu::Buffer,
537 dst: &wgpu::Buffer,
538 seq_len: u32,
539 chunk_n: u32,
540 full_n: u32,
541 col_offset: u32,
542 ) {
543 #[repr(C)]
544 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
545 struct P {
546 seq_len: u32,
547 chunk_n: u32,
548 full_n: u32,
549 col_offset: u32,
550 }
551 let params = P { seq_len, chunk_n, full_n, col_offset };
552 let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
553 label: None,
554 size: 16,
555 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
556 mapped_at_creation: false,
557 });
558 self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
559 let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
560 label: None,
561 layout: &self.scatter_bgl,
562 entries: &[
563 wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
564 wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
565 wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
566 ],
567 });
568 let total = seq_len * chunk_n;
569 let wg = total.div_ceil(256);
570 let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
571 let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
572 {
573 let mut pass = encoder.begin_compute_pass(&Default::default());
574 pass.set_pipeline(&self.scatter_pipeline);
575 pass.set_bind_group(0, &bg, &[]);
576 pass.dispatch_workgroups(x, y, 1);
577 }
578 self.trainer.queue_ref().submit(Some(encoder.finish()));
579 }
580
581 fn dispatch_gather(
583 &self,
584 src: &wgpu::Buffer,
585 dst: &wgpu::Buffer,
586 seq_len: u32,
587 chunk_n: u32,
588 full_n: u32,
589 col_offset: u32,
590 ) {
591 #[repr(C)]
592 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
593 struct P {
594 seq_len: u32,
595 chunk_n: u32,
596 full_n: u32,
597 col_offset: u32,
598 }
599 let params = P { seq_len, chunk_n, full_n, col_offset };
600 let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
601 label: None,
602 size: 16,
603 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
604 mapped_at_creation: false,
605 });
606 self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
607 let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
608 label: None,
609 layout: &self.scatter_bgl,
610 entries: &[
611 wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
612 wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
613 wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
614 ],
615 });
616 let total = seq_len * chunk_n;
617 let wg = total.div_ceil(256);
618 let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
619 let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
620 {
621 let mut pass = encoder.begin_compute_pass(&Default::default());
622 pass.set_pipeline(&self.gather_pipeline);
623 pass.set_bind_group(0, &bg, &[]);
624 pass.dispatch_workgroups(x, y, 1);
625 }
626 self.trainer.queue_ref().submit(Some(encoder.finish()));
627 }
628
629 fn dispatch_transpose(
632 &self,
633 src: &wgpu::Buffer,
634 dst: &wgpu::Buffer,
635 m: u32,
636 n: u32,
637 scale: f32,
638 ) {
639 #[repr(C)]
640 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
641 struct P {
642 m: u32,
643 n: u32,
644 scale: f32,
645 _pad: u32,
646 }
647 let params = P { m, n, scale, _pad: 0 };
648 let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
649 label: None,
650 size: 16,
651 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
652 mapped_at_creation: false,
653 });
654 self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
655 let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
656 label: None,
657 layout: &self.transpose_bgl,
658 entries: &[
659 wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
660 wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
661 wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
662 ],
663 });
664 let total = m * n;
665 let wg = total.div_ceil(256);
666 let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
667 let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
668 {
669 let mut pass = encoder.begin_compute_pass(&Default::default());
670 pass.set_pipeline(&self.transpose_pipeline);
671 pass.set_bind_group(0, &bg, &[]);
672 pass.dispatch_workgroups(x, y, 1);
673 }
674 self.trainer.queue_ref().submit(Some(encoder.finish()));
675 }
676
677 pub fn encode(&self, text: &str) -> Vec<u32> {
678 self.tokenizer.encode(text)
679 }
680
681 pub fn export_adapter(
685 &self,
686 output_path: &std::path::Path,
687 lora_alpha: f32,
688 ) -> Result<(), String> {
689 use safetensors::tensor::{serialize_to_file, Dtype, TensorView};
690
691 let mut tensors: Vec<(String, Vec<f32>, Vec<usize>)> = Vec::new();
693
694 for (layer_idx, layer_lora) in self.lora.iter().enumerate() {
695 for proj in &layer_lora.projections {
696 if !self.lora_target_set.iter().any(|t| t == &proj.name) {
697 continue;
698 }
699 let a = self.trainer.download(&proj.a);
700 let b = self.trainer.download(&proj.b);
701 let base = format!("layer.{layer_idx}.{}", proj.name);
703 tensors.push((
704 format!("{base}.lora_a"),
705 a,
706 vec![proj.in_dim as usize, self.lora_rank],
707 ));
708 tensors.push((
709 format!("{base}.lora_b"),
710 b,
711 vec![self.lora_rank, proj.out_dim as usize],
712 ));
713 }
714 }
715
716 let byte_tensors: Vec<(String, Vec<u8>, Vec<usize>)> = tensors
718 .into_iter()
719 .map(|(name, data, shape)| (name, bytemuck::cast_slice(&data).to_vec(), shape))
720 .collect();
721
722 let views: Vec<(&str, TensorView<'_>)> = byte_tensors
723 .iter()
724 .map(|(name, bytes, shape)| {
725 let view =
726 TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid F32 tensor");
727 (name.as_str(), view)
728 })
729 .collect();
730
731 if let Some(parent) = output_path.parent() {
733 std::fs::create_dir_all(parent).map_err(|e| format!("mkdir: {e}"))?;
734 }
735
736 let st_path = if output_path.extension().is_some() {
738 output_path.to_path_buf()
739 } else {
740 output_path.join("adapter.safetensors")
741 };
742
743 let metadata: Option<std::collections::HashMap<String, String>> =
744 Some(std::collections::HashMap::from([
745 ("lora_rank".to_string(), self.lora_rank.to_string()),
746 ("lora_alpha".to_string(), lora_alpha.to_string()),
747 ]));
748 serialize_to_file(views, metadata, &st_path)
749 .map_err(|e| format!("safetensors write: {e}"))?;
750
751 eprintln!(
752 "[wgpu] {} LoRA tensors saved ({} layers × 7 projections × A/B)",
753 byte_tensors.len(),
754 self.num_layers
755 );
756 Ok(())
757 }
758
759 pub fn train_step(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> InstructStepResult {
763 let t0 = std::time::Instant::now();
764
765 let full_ids: Vec<u32> = prompt_ids.iter().chain(response_ids).copied().collect();
766 let seq_len = full_ids.len().min(self.max_seq_len);
767 let full_ids = &full_ids[..seq_len];
768 let prompt_len = prompt_ids.len().min(seq_len);
769
770 let loss_start = prompt_len.saturating_sub(1);
771 let loss_end = seq_len - 1;
772 let num_loss_tokens = loss_end.saturating_sub(loss_start);
773
774 if num_loss_tokens == 0 {
775 return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
776 }
777
778 let mut hidden = Vec::with_capacity(seq_len * self.hidden_dim);
780 for &tok in full_ids {
781 let offset = (tok as usize) * self.hidden_dim;
782 let end = offset + self.hidden_dim;
783 if end <= self.embed_weights.len() {
784 hidden.extend_from_slice(&self.embed_weights[offset..end]);
785 } else {
786 hidden.extend(std::iter::repeat_n(0.0f32, self.hidden_dim));
787 }
788 }
789
790 let t1 = std::time::Instant::now();
791
792 if self.lora_step == 0 {
794 let h_norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
795 let h_mean: f32 = hidden.iter().sum::<f32>() / hidden.len() as f32;
796 eprintln!(
797 "[DIAG-509] embed: norm={h_norm:.4}, mean={h_mean:.6}, len={}, seq={seq_len}",
798 hidden.len()
799 );
800 let tok264_offset = 264 * self.hidden_dim;
802 if tok264_offset + 5 < self.embed_weights.len() {
803 let tok264: Vec<f32> =
804 self.embed_weights[tok264_offset..tok264_offset + 5].to_vec();
805 eprintln!("[DIAG-509] embed[264,:5]={tok264:?} (PyTorch: [-0.0295, 0.0035, 0.0193, 0.0020, 0.0049])");
806 }
807 }
808
809 self.fwd.queue_ref().write_buffer(
813 self.fwd.hidden_buffer(),
814 0,
815 bytemuck::cast_slice(&hidden),
816 );
817
818 let mut _saved_activations = Vec::with_capacity(self.num_layers);
819 for layer_idx in 0..self.num_layers {
820 let prefix = format!("layer.{layer_idx}");
821 let qkv_lora = if layer_idx < self.lora.len() {
823 let lp = &self.lora[layer_idx].projections;
824 let q = lp.iter().find(|p| p.name == "q_proj");
826 let k = lp.iter().find(|p| p.name == "k_proj");
827 let v = lp.iter().find(|p| p.name == "v_proj");
828 match (q, k, v) {
829 (Some(qp), Some(kp), Some(vp)) => Some(trueno::backends::gpu::QkvLoRA {
830 q_a: &qp.a,
831 q_b: &qp.b,
832 k_a: &kp.a,
833 k_b: &kp.b,
834 v_a: &vp.a,
835 v_b: &vp.b,
836 rank: self.lora_rank as u32,
837 scale: self.lora_scale,
838 in_dim: qp.in_dim,
839 q_dim: qp.out_dim,
840 kv_dim: kp.out_dim,
841 lora_pipeline: &self.lora_addmm_pipeline,
842 lora_bgl: &self.lora_addmm_bgl,
843 }),
844 _ => None,
845 }
846 } else {
847 None
848 };
849
850 let saved = self.fwd.alloc_layer_activations(seq_len as u32);
852
853 if self.lora_step == 0 && layer_idx == 0 {
855 if let Err(e) = self.fwd.forward_layer_traced(
856 seq_len as u32,
857 &prefix,
858 &saved,
859 qkv_lora.as_ref(),
860 ) {
861 eprintln!("[wgpu] traced forward failed: {e}");
862 }
863 } else {
864 let mut encoder = self.fwd.device_ref().create_command_encoder(&Default::default());
865 if let Err(e) = self.fwd.encode_forward_layer_training(
866 &mut encoder,
867 seq_len as u32,
868 &prefix,
869 &saved,
870 qkv_lora.as_ref(),
871 ) {
872 eprintln!("[wgpu] GPU forward layer {layer_idx} failed: {e}");
873 return InstructStepResult {
874 loss: 100.0,
875 num_response_tokens: num_loss_tokens,
876 perplexity: 1e6,
877 };
878 }
879 self.fwd.queue_ref().submit(Some(encoder.finish()));
880 }
881 _saved_activations.push(saved);
882
883 if self.lora_step == 0
888 && (layer_idx == 0 || layer_idx == 1 || layer_idx == self.num_layers - 1)
889 {
890 let n_floats = seq_len * self.hidden_dim;
891 let h = self.fwd.download_hidden(n_floats);
892 let norm: f32 = h.iter().map(|x| x * x).sum::<f32>().sqrt();
893 let nan_c = h.iter().filter(|x| x.is_nan()).count();
894 let first5: Vec<f32> = h.iter().take(5).copied().collect();
895 eprintln!("[DIAG-509] after layer {layer_idx}: norm={norm:.4}, nan={nan_c}, first5={first5:?}");
896 }
897 }
898
899 let t2 = std::time::Instant::now();
900
901 if self.lora_step == 0 {
903 let n_floats = seq_len * self.hidden_dim;
904 let h_data = self.fwd.download_hidden(n_floats);
905 let h_norm: f32 = h_data.iter().map(|x| x * x).sum::<f32>().sqrt();
906 let h_mean: f32 = h_data.iter().sum::<f32>() / h_data.len() as f32;
907 let nan_count = h_data.iter().filter(|x| x.is_nan()).count();
908 let inf_count = h_data.iter().filter(|x| x.is_infinite()).count();
909 let first5: Vec<f32> = h_data.iter().take(5).copied().collect();
910 eprintln!("[DIAG-509] post-layers: norm={h_norm:.4}, mean={h_mean:.6}, nan={nan_count}, inf={inf_count}, first5={first5:?}");
911 }
912
913 let _t2a = std::time::Instant::now();
915 self.fwd.gpu_rmsnorm(&self.output_norm_gpu, &self.normed_buf, seq_len as u32);
916 let t2b = std::time::Instant::now();
917 let _t2c = t2b;
918 let labels: Vec<u32> = (0..seq_len)
920 .map(|i| if i + 1 < full_ids.len() { full_ids[i + 1] } else { 0 })
921 .collect();
922
923 let mut col_offset = 0u64;
924 for (chunk_buf, chunk_n) in &self.lm_head_t_chunks {
925 let cn = u64::from(*chunk_n);
926 let c_chunk = self.trainer.zeros((seq_len as u64 * cn) as usize);
927 self.trainer.matmul_forward(
928 &self.normed_buf,
929 chunk_buf,
930 &c_chunk,
931 seq_len as u32,
932 self.hidden_dim as u32,
933 *chunk_n,
934 );
935 self.dispatch_scatter(
937 &c_chunk,
938 &self.logits_buf,
939 seq_len as u32,
940 *chunk_n,
941 self.vocab_size as u32,
942 col_offset as u32,
943 );
944 col_offset += cn;
945 }
946
947 let t3 = std::time::Instant::now();
948
949 if self.lora_step == 0 {
951 let logits_data = self.trainer.download(&self.logits_buf);
952 let l_norm: f32 =
953 logits_data.iter().take(self.vocab_size).map(|x| x * x).sum::<f32>().sqrt();
954 let l_max =
955 logits_data.iter().take(self.vocab_size).cloned().fold(f32::NEG_INFINITY, f32::max);
956 let l_min =
957 logits_data.iter().take(self.vocab_size).cloned().fold(f32::INFINITY, f32::min);
958 let nan_count = logits_data.iter().take(self.vocab_size).filter(|x| x.is_nan()).count();
959 let zero_count =
961 logits_data.iter().take(self.vocab_size).filter(|x| **x == 0.0).count();
962 eprintln!("[DIAG-509] logits[0]: norm={l_norm:.4}, min={l_min:.4}, max={l_max:.4}, nan={nan_count}, zeros={zero_count}/{}", self.vocab_size);
963 let argmax = logits_data
965 .iter()
966 .take(self.vocab_size)
967 .enumerate()
968 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
969 .map(|(i, v)| (i, *v));
970 let target = labels[0];
971 let target_logit = if (target as usize) < self.vocab_size {
972 logits_data[target as usize]
973 } else {
974 f32::NAN
975 };
976 eprintln!("[DIAG-509] pos0: argmax={argmax:?}, target={target}, target_logit={target_logit:.4}, loss_range=[{loss_start},{loss_end})");
977 }
978
979 self.trainer.queue_ref().write_buffer(&self.labels_buf, 0, bytemuck::cast_slice(&labels));
981
982 let t3a = std::time::Instant::now();
986 self.cross_entropy.forward_async(
987 &self.logits_buf,
988 &self.labels_buf,
989 &self.losses_buf,
990 &self.logsumexp_buf,
991 seq_len as u32,
992 self.vocab_size as u32,
993 loss_start as u32,
994 loss_end as u32,
995 );
996
997 let t3b = std::time::Instant::now();
998 self.cross_entropy.backward(
1000 &self.logits_buf,
1001 &self.labels_buf,
1002 &self.logsumexp_buf,
1003 seq_len as u32,
1004 self.vocab_size as u32,
1005 loss_start as u32,
1006 loss_end as u32,
1007 );
1008
1009 let t3c = std::time::Instant::now();
1010
1011 let grad_hidden_buf = self.trainer.zeros(seq_len * self.hidden_dim);
1015 let mut row_offset = 0u64;
1016 for (chunk_buf, chunk_k) in &self.lm_head_chunks {
1017 let ck = u64::from(*chunk_k);
1018 let gl_chunk = self.trainer.zeros((seq_len as u64 * ck) as usize);
1019 self.dispatch_gather(
1020 &self.logits_buf,
1021 &gl_chunk,
1022 seq_len as u32,
1023 *chunk_k,
1024 self.vocab_size as u32,
1025 row_offset as u32,
1026 );
1027 let gh_chunk = self.trainer.zeros(seq_len * self.hidden_dim);
1029 self.trainer.matmul_forward(
1030 &gl_chunk,
1031 chunk_buf,
1032 &gh_chunk,
1033 seq_len as u32,
1034 *chunk_k,
1035 self.hidden_dim as u32,
1036 );
1037 let sum_buf = self.trainer.zeros(seq_len * self.hidden_dim);
1039 self.fwd.gpu_residual_add(
1040 &grad_hidden_buf,
1041 &gh_chunk,
1042 &sum_buf,
1043 (seq_len * self.hidden_dim) as u32,
1044 );
1045 let mut enc = self.fwd.device_ref().create_command_encoder(&Default::default());
1047 enc.copy_buffer_to_buffer(
1048 &sum_buf,
1049 0,
1050 &grad_hidden_buf,
1051 0,
1052 (seq_len * self.hidden_dim * 4) as u64,
1053 );
1054 self.fwd.queue_ref().submit(Some(enc.finish()));
1055 row_offset += ck;
1056 }
1057
1058 let t4 = std::time::Instant::now();
1059
1060 if self.lora_step == 1 {
1062 let b0 = self.trainer.download(&self.lora[0].projections[0].b);
1063 let b_norm: f32 = b0.iter().map(|x| x * x).sum::<f32>().sqrt();
1064 eprintln!("[FALSIFY] step=1 B[0].q_proj norm={b_norm:.6}");
1065 }
1066
1067 self.lora_step += 1;
1074 let lr = self.learning_rate;
1077 let s = seq_len as u32;
1078 let rank = self.lora_rank as u32;
1079
1080 let h = self.hidden_dim as u32;
1094 let mut grad_buf = grad_hidden_buf;
1095
1096 for layer_idx in (0..self.lora.len()).rev() {
1097 let layer_lora = &self.lora[layer_idx];
1098 let saved = &_saved_activations[layer_idx];
1099
1100 for proj in &layer_lora.projections {
1101 if !self.lora_target_set.iter().any(|t| t == &proj.name) {
1102 continue;
1103 }
1104 let input_buf = match proj.name.as_str() {
1106 "q_proj" | "k_proj" | "v_proj" => &saved.attn_norm_out,
1107 "o_proj" => &saved.attn_output,
1108 "gate_proj" | "up_proj" => &saved.ffn_norm_out,
1109 "down_proj" => &saved.silu_gate_output,
1110 _ => continue,
1111 };
1112
1113 let scale = self.lora_scale;
1116
1117 let xa = self.trainer.zeros((s * rank) as usize);
1119 self.trainer.matmul_forward(input_buf, &proj.a, &xa, s, proj.in_dim, rank);
1120
1121 let xa_t = self.trainer.zeros((s * rank) as usize);
1123 self.dispatch_transpose(&xa, &xa_t, s, rank, scale);
1124 let db = self.trainer.zeros((rank * proj.out_dim) as usize);
1125 self.trainer.matmul_forward(&xa_t, &grad_buf, &db, rank, s, proj.out_dim);
1126
1127 let da = if self.lora_step <= 1 {
1129 self.trainer.zeros((proj.in_dim * rank) as usize)
1130 } else {
1131 let bt = self.trainer.zeros((rank * proj.out_dim) as usize);
1132 self.dispatch_transpose(&proj.b, &bt, rank, proj.out_dim, 1.0);
1133 let d_xa = self.trainer.zeros((s * rank) as usize);
1134 self.trainer.matmul_forward(&grad_buf, &bt, &d_xa, s, proj.out_dim, rank);
1135 let xt = self.trainer.zeros((s * proj.in_dim) as usize);
1136 self.dispatch_transpose(input_buf, &xt, s, proj.in_dim, scale);
1137 let da_buf = self.trainer.zeros((proj.in_dim * rank) as usize);
1138 self.trainer.matmul_forward(&xt, &d_xa, &da_buf, proj.in_dim, s, rank);
1139 da_buf
1140 };
1141
1142 self.trainer
1145 .adamw_step(&proj.a, &da, &proj.m_a, &proj.v_a, lr, 0.9, 0.999, 1e-8, 0.01);
1146 self.trainer
1147 .adamw_step(&proj.b, &db, &proj.m_b, &proj.v_b, lr, 0.9, 0.999, 1e-8, 0.01);
1148 }
1149
1150 }
1156
1157 let t5 = std::time::Instant::now();
1158
1159 eprintln!(
1160 "[PROFILE] step: {:.0}ms (embed={:.0} fwd={:.0} lm={:.0} ce={:.0}[fwd={:.0} bwd={:.0}] lm_bwd={:.0} lora_bwd={:.0})",
1161 t5.duration_since(t0).as_millis(),
1162 t1.duration_since(t0).as_millis(),
1163 t2.duration_since(t1).as_millis(),
1164 t3.duration_since(t2).as_millis(),
1165 t3c.duration_since(t3).as_millis(),
1166 t3b.duration_since(t3a).as_millis(),
1167 t3c.duration_since(t3b).as_millis(),
1168 t4.duration_since(t3c).as_millis(),
1169 t5.duration_since(t4).as_millis(),
1170 );
1171
1172 let avg_loss = self.cross_entropy.read_loss(
1175 &self.losses_buf,
1176 seq_len as u32,
1177 loss_start as u32,
1178 loss_end as u32,
1179 );
1180
1181 InstructStepResult {
1182 loss: if avg_loss.is_finite() { avg_loss } else { 100.0 },
1183 num_response_tokens: num_loss_tokens,
1184 perplexity: if avg_loss.is_finite() { avg_loss.exp().min(1e6) } else { 1e6 },
1185 }
1186 }
1187}
1188
#[cfg(feature = "gpu")]
impl WgpuInstructPipeline {
    /// Computes a DPO-style preference loss for one (prompt, chosen, rejected) triple.
    ///
    /// With `delta = logp(chosen) - logp(rejected)`, the returned value is
    /// `-ln(max(sigmoid(beta * delta), 1e-7))`, so it is non-negative and capped
    /// at about 16.1.
    ///
    /// NOTE(review): both log-probabilities are obtained through `train_step`,
    /// which applies an AdamW update on the cross-entropy gradient each time it
    /// is called. A single `dpo_step` therefore performs two supervised updates
    /// (one on the chosen and one on the rejected response), and the loss
    /// returned here is not itself back-propagated. No reference-model term is
    /// included. Confirm this is the intended behaviour.
    pub fn dpo_step(
        &mut self,
        prompt_ids: &[u32],
        chosen_ids: &[u32],
        rejected_ids: &[u32],
        beta: f32,
    ) -> f32 {
        let chosen_logprob = self.compute_sequence_logprob(prompt_ids, chosen_ids);
        let rejected_logprob = self.compute_sequence_logprob(prompt_ids, rejected_ids);

        let delta = chosen_logprob - rejected_logprob;
        let sigmoid_arg = beta * delta;
        let sigmoid_val = 1.0 / (1.0 + (-sigmoid_arg).exp());
        // Clamp before the log so a saturated sigmoid cannot produce +inf.
        let loss = -(sigmoid_val.max(1e-7)).ln();

        debug_assert!(loss >= 0.0, "DPO loss must be non-negative: {loss}");

        loss
    }

    /// Estimates the summed log-probability of `response_ids` given `prompt_ids`
    /// as `-(mean loss) * (number of scored tokens)`.
    ///
    /// The token count is the one `train_step` reports, not `response_ids.len()`:
    /// the two differ when the sequence is truncated to `max_seq_len`, when the
    /// prompt is empty, and when the step exits early with nothing to score.
    /// This assumes the reported loss is a mean over exactly those tokens —
    /// TODO confirm against `WgslCrossEntropy::read_loss`.
    ///
    /// Side effect: runs a full `train_step`, including its optimizer update.
    fn compute_sequence_logprob(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> f32 {
        let result = self.train_step(prompt_ids, response_ids);
        let scored_tokens = result.num_response_tokens as f32;
        -result.loss * scored_tokens
    }
}