#[cfg(feature = "gpu")]
use crate::{
autograd::{wgpu_cross_entropy::WgslCrossEntropy, wgpu_training::WgpuTrainer},
finetune::instruct_pipeline::InstructStepResult,
tokenizer::HfTokenizer,
};
#[cfg(feature = "gpu")]
use trueno::backends::gpu::{wgpu, WgslForwardPass};
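/// All LoRA adapters for one transformer layer, one entry per adapted
/// projection.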
#[cfg(feature = "gpu")]
pub struct LayerLoRA {
pub projections: Vec<LoRAProjection>,
}
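/// A single low-rank adapter pair plus its AdamW state. The effective weight
/// is `W + (alpha / rank) * A @ B`, with `a` of shape `[in_dim, rank]` and
/// `b` of shape `[rank, out_dim]`; `m_*`/`v_*` hold the first and second
/// AdamW moment estimates for each factor.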
#[cfg(feature = "gpu")]
pub struct LoRAProjection {
pub a: wgpu::Buffer,
pub b: wgpu::Buffer,
pub m_a: wgpu::Buffer,
pub v_a: wgpu::Buffer,
pub m_b: wgpu::Buffer,
pub v_b: wgpu::Buffer,
pub in_dim: u32,
pub out_dim: u32,
pub name: String,
}
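/// GPU-resident LoRA instruction-tuning pipeline: a frozen base forward pass,
/// chunked lm_head projection, fused cross-entropy loss, and AdamW updates
/// applied only to the adapter factors.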
#[cfg(feature = "gpu")]
pub struct WgpuInstructPipeline {
fwd: WgslForwardPass,
cross_entropy: WgslCrossEntropy,
trainer: WgpuTrainer,
lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
lora_addmm_pipeline: wgpu::ComputePipeline,
lora_addmm_bgl: wgpu::BindGroupLayout,
scatter_pipeline: wgpu::ComputePipeline,
gather_pipeline: wgpu::ComputePipeline,
scatter_bgl: wgpu::BindGroupLayout,
transpose_pipeline: wgpu::ComputePipeline,
transpose_bgl: wgpu::BindGroupLayout,
logits_buf: wgpu::Buffer,
labels_buf: wgpu::Buffer,
losses_buf: wgpu::Buffer,
logsumexp_buf: wgpu::Buffer,
lora: Vec<LayerLoRA>,
lora_rank: usize,
lora_scale: f32,
lora_step: u32,
learning_rate: f32,
lora_target_set: Vec<String>,
num_layers: usize,
hidden_dim: usize,
vocab_size: usize,
max_seq_len: usize,
tokenizer: HfTokenizer,
embed_weights: Vec<f32>,
output_norm_gpu: wgpu::Buffer,
normed_buf: wgpu::Buffer,
eps: f32,
}
#[cfg(feature = "gpu")]
impl WgpuInstructPipeline {
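/// Builds the pipeline: compiles the scatter/gather/transpose and fused LoRA
/// addmm compute pipelines, seeds the per-layer adapter buffers, and
/// preallocates logits/labels/losses scratch buffers sized for `max_seq_len`.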
pub fn new(
fwd: WgslForwardPass,
trainer: WgpuTrainer,
tokenizer: HfTokenizer,
embed_weights: Vec<f32>,
output_norm: Vec<f32>,
lm_head_t_chunks: Vec<(wgpu::Buffer, u32)>,
lm_head_chunks: Vec<(wgpu::Buffer, u32)>,
num_layers: usize,
hidden_dim: usize,
vocab_size: usize,
max_seq_len: usize,
num_heads: usize,
num_kv_heads: usize,
intermediate_dim: usize,
lora_rank: usize,
lora_alpha: f32,
lora_targets: &[&str],
eps: f32,
learning_rate: f32,
) -> Self {
let ce = WgslCrossEntropy::new(trainer.device_ref().clone(), trainer.queue_ref().clone());
let seq = max_seq_len as u32;
let vocab = vocab_size as u32;
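// `size` is a count of f32 elements; the descriptor converts to bytes (* 4).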
let make_buf = |size: u64, label: &str| -> wgpu::Buffer {
trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: size * 4,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
})
};
let scatter_bgl =
trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("scatter_bgl"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let scatter_pl =
trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("scatter_pl"),
bind_group_layouts: &[&scatter_bgl],
push_constant_ranges: &[],
});
let scatter_shader =
trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("scatter"),
source: wgpu::ShaderSource::Wgsl(
trueno::backends::gpu::shaders::COLUMN_SCATTER_SHADER.into(),
),
});
let scatter_pipeline =
trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("scatter_pipe"),
layout: Some(&scatter_pl),
module: &scatter_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let gather_shader =
trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("gather"),
source: wgpu::ShaderSource::Wgsl(
trueno::backends::gpu::shaders::COLUMN_GATHER_SHADER.into(),
),
});
let gather_pipeline =
trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("gather_pipe"),
layout: Some(&scatter_pl),
module: &gather_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let transpose_bgl =
trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("transpose_bgl"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let transpose_pl =
trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("transpose_pl"),
bind_group_layouts: &[&transpose_bgl],
push_constant_ranges: &[],
});
let transpose_shader =
trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("transpose"),
source: wgpu::ShaderSource::Wgsl(
trueno::backends::gpu::shaders::TRANSPOSE_SHADER.into(),
),
});
let transpose_pipeline =
trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("transpose_pipe"),
layout: Some(&transpose_pl),
module: &transpose_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let lora_bgl =
trainer.device_ref().create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("lora_bgl"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let lora_pl =
trainer.device_ref().create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("lora_pl"),
bind_group_layouts: &[&lora_bgl],
push_constant_ranges: &[],
});
let lora_shader = trainer.device_ref().create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("lora_addmm"),
source: wgpu::ShaderSource::Wgsl(
trueno::backends::gpu::shaders::LORA_ADDMM_SHADER.into(),
),
});
let lora_addmm_pipeline =
trainer.device_ref().create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("lora_addmm_pipe"),
layout: Some(&lora_pl),
module: &lora_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let lora_addmm_bgl = lora_bgl;
let r = lora_rank;
let scale = lora_alpha / r as f32;
let h = hidden_dim;
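// Grouped-query attention dims: q_dim spans all query heads, kv_dim the
// (possibly fewer) key/value heads, both with head_dim = hidden_dim / num_heads.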
let q_dim = num_heads * (hidden_dim / num_heads);
let kv_dim = num_kv_heads * (hidden_dim / num_heads);
let inter = intermediate_dim;
let all_proj_dims: &[(&str, usize, usize)] = &[
("q_proj", h, q_dim),
("k_proj", h, kv_dim),
("v_proj", h, kv_dim),
("o_proj", q_dim, h),
("gate_proj", h, inter),
("up_proj", h, inter),
("down_proj", inter, h),
];
// Buffers are allocated for all seven projections; training and export only
// touch those listed in `lora_target_set` below.
let proj_dims: Vec<(&str, usize, usize)> = all_proj_dims.to_vec();
let num_targets = proj_dims.len();
// NOTE: an empty target list (and, currently, the "all" keyword) falls back
// to adapting the attention q/k/v projections only.
let qkv_only = vec!["q_proj".to_string(), "k_proj".to_string(), "v_proj".to_string()];
let lora_target_set: Vec<String> =
    if lora_targets.is_empty() || lora_targets.contains(&"all") {
        qkv_only
    } else {
        lora_targets.iter().map(std::string::ToString::to_string).collect()
    };
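// A gets a deterministic, He-scaled (sqrt(2 / in_dim)) sin-based init that
// varies with index and layer; B starts at zero, so each adapter's initial
// contribution scale * (x @ A) @ B is exactly zero.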
let mut lora = Vec::with_capacity(num_layers);
for layer_idx in 0..num_layers {
let mut projections = Vec::with_capacity(num_targets);
for &(name, in_d, out_d) in &proj_dims {
let std = (2.0 / in_d as f32).sqrt();
let a_data: Vec<f32> = (0..in_d * r)
.map(|i| ((i as f32 * 0.013 + layer_idx as f32 * 7.0).sin() * std))
.collect();
let b_data = vec![0.0f32; r * out_d];
let zeros_a = vec![0.0f32; in_d * r];
let zeros_b = vec![0.0f32; r * out_d];
projections.push(LoRAProjection {
a: trainer.upload(&a_data),
b: trainer.upload(&b_data),
m_a: trainer.upload(&zeros_a),
v_a: trainer.upload(&zeros_a),
m_b: trainer.upload(&zeros_b),
v_b: trainer.upload(&zeros_b),
in_dim: in_d as u32,
out_dim: out_d as u32,
name: name.to_string(),
});
}
lora.push(LayerLoRA { projections });
}
eprintln!(
"[wgpu] LoRA initialized: {num_layers} layers × {num_targets} projections, rank={r}, scale={scale:.2}",
);
let logits_buf = make_buf(u64::from(seq) * u64::from(vocab), "logits");
let labels_buf = make_buf(u64::from(seq), "labels");
let losses_buf = make_buf(u64::from(seq), "losses");
let logsumexp_buf = make_buf(u64::from(seq), "logsumexp");
let normed_buf_alloc = make_buf(u64::from(seq) * hidden_dim as u64, "normed");
let output_norm_gpu_buf = trainer.upload(&output_norm);
Self {
fwd,
cross_entropy: ce,
logits_buf,
labels_buf,
losses_buf,
logsumexp_buf,
lm_head_t_chunks,
lm_head_chunks,
lora_addmm_pipeline,
lora_addmm_bgl,
scatter_pipeline,
gather_pipeline,
transpose_pipeline,
transpose_bgl,
scatter_bgl,
trainer,
lora,
lora_rank: r,
lora_scale: scale,
lora_step: 0,
learning_rate,
lora_target_set,
num_layers,
hidden_dim,
vocab_size,
max_seq_len,
tokenizer,
embed_weights,
output_norm_gpu: output_norm_gpu_buf,
normed_buf: normed_buf_alloc,
eps,
}
}
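/// Dispatches the fused LoRA addmm shader, which computes the low-rank term
/// `scale * (input @ lora_a) @ lora_b` over a `[seq_len, in_dim]` input and
/// accumulates it into `output` (hence "addmm"). The workgroup count is
/// split into an (x, y) grid because a single dispatch dimension is capped
/// at 65_535.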
fn dispatch_lora_addmm(
&self,
input: &wgpu::Buffer,
lora_a: &wgpu::Buffer,
lora_b: &wgpu::Buffer,
output: &wgpu::Buffer,
seq_len: u32,
in_dim: u32,
rank: u32,
out_dim: u32,
) {
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct P {
seq_len: u32,
in_dim: u32,
rank: u32,
out_dim: u32,
scale: f32,
_p0: u32,
_p1: u32,
_p2: u32,
}
let params =
P { seq_len, in_dim, rank, out_dim, scale: self.lora_scale, _p0: 0, _p1: 0, _p2: 0 };
let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 32,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.lora_addmm_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: lora_a.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: lora_b.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 4, resource: pbuf.as_entire_binding() },
],
});
let total = seq_len * out_dim;
let wg = total.div_ceil(256);
let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.lora_addmm_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(x, y, 1);
}
self.trainer.queue_ref().submit(Some(encoder.finish()));
}
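/// Writes a `[seq_len, chunk_n]` column block into the full
/// `[seq_len, full_n]` matrix starting at column `col_offset`; used to
/// assemble chunked logits.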
fn dispatch_scatter(
&self,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
seq_len: u32,
chunk_n: u32,
full_n: u32,
col_offset: u32,
) {
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct P {
seq_len: u32,
chunk_n: u32,
full_n: u32,
col_offset: u32,
}
let params = P { seq_len, chunk_n, full_n, col_offset };
let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.scatter_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
],
});
let total = seq_len * chunk_n;
let wg = total.div_ceil(256);
let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.scatter_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(x, y, 1);
}
self.trainer.queue_ref().submit(Some(encoder.finish()));
}
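/// Inverse of `dispatch_scatter`: extracts the `[seq_len, chunk_n]` column
/// block starting at `col_offset` from a full `[seq_len, full_n]` matrix.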
fn dispatch_gather(
&self,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
seq_len: u32,
chunk_n: u32,
full_n: u32,
col_offset: u32,
) {
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct P {
seq_len: u32,
chunk_n: u32,
full_n: u32,
col_offset: u32,
}
let params = P { seq_len, chunk_n, full_n, col_offset };
let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.scatter_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
],
});
let total = seq_len * chunk_n;
let wg = total.div_ceil(256);
let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.gather_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(x, y, 1);
}
self.trainer.queue_ref().submit(Some(encoder.finish()));
}
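/// Writes `dst = scale * transpose(src)` for an `[m, n]` source; the `scale`
/// factor lets callers fold the LoRA `alpha / rank` scaling into the
/// transpose.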
fn dispatch_transpose(
&self,
src: &wgpu::Buffer,
dst: &wgpu::Buffer,
m: u32,
n: u32,
scale: f32,
) {
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct P {
m: u32,
n: u32,
scale: f32,
_pad: u32,
}
let params = P { m, n, scale, _pad: 0 };
let pbuf = self.trainer.device_ref().create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.trainer.queue_ref().write_buffer(&pbuf, 0, bytemuck::bytes_of(¶ms));
let bg = self.trainer.device_ref().create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.transpose_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: src.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: dst.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: pbuf.as_entire_binding() },
],
});
let total = m * n;
let wg = total.div_ceil(256);
let (x, y) = if wg <= 65535 { (wg, 1) } else { (65535, wg.div_ceil(65535)) };
let mut encoder = self.trainer.device_ref().create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.transpose_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(x, y, 1);
}
self.trainer.queue_ref().submit(Some(encoder.finish()));
}
pub fn encode(&self, text: &str) -> Vec<u32> {
self.tokenizer.encode(text)
}
pub fn export_adapter(
&self,
output_path: &std::path::Path,
lora_alpha: f32,
) -> Result<(), String> {
use safetensors::tensor::{serialize_to_file, Dtype, TensorView};
let mut tensors: Vec<(String, Vec<f32>, Vec<usize>)> = Vec::new();
for (layer_idx, layer_lora) in self.lora.iter().enumerate() {
for proj in &layer_lora.projections {
if !self.lora_target_set.iter().any(|t| t == &proj.name) {
continue;
}
let a = self.trainer.download(&proj.a);
let b = self.trainer.download(&proj.b);
let base = format!("layer.{layer_idx}.{}", proj.name);
tensors.push((
format!("{base}.lora_a"),
a,
vec![proj.in_dim as usize, self.lora_rank],
));
tensors.push((
format!("{base}.lora_b"),
b,
vec![self.lora_rank, proj.out_dim as usize],
));
}
}
let byte_tensors: Vec<(String, Vec<u8>, Vec<usize>)> = tensors
.into_iter()
.map(|(name, data, shape)| (name, bytemuck::cast_slice(&data).to_vec(), shape))
.collect();
let views: Vec<(&str, TensorView<'_>)> = byte_tensors
.iter()
.map(|(name, bytes, shape)| {
let view =
TensorView::new(Dtype::F32, shape.clone(), bytes).expect("valid F32 tensor");
(name.as_str(), view)
})
.collect();
if let Some(parent) = output_path.parent() {
std::fs::create_dir_all(parent).map_err(|e| format!("mkdir: {e}"))?;
}
let st_path = if output_path.extension().is_some() {
output_path.to_path_buf()
} else {
output_path.join("adapter.safetensors")
};
let metadata: Option<std::collections::HashMap<String, String>> =
Some(std::collections::HashMap::from([
("lora_rank".to_string(), self.lora_rank.to_string()),
("lora_alpha".to_string(), lora_alpha.to_string()),
]));
serialize_to_file(views, metadata, &st_path)
.map_err(|e| format!("safetensors write: {e}"))?;
eprintln!(
"[wgpu] {} LoRA tensors saved ({} layers × {} target projections × A/B)",
byte_tensors.len(),
self.num_layers,
self.lora_target_set.len()
);
Ok(())
}
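/// Runs one supervised fine-tuning step on `prompt || response`: embeds the
/// tokens on the CPU, runs the transformer layers on the GPU (LoRA deltas are
/// applied to the q/k/v projections in the forward pass), computes next-token
/// cross-entropy over the response span only (positions
/// `[prompt_len - 1, seq_len - 1)`; the prompt is masked out of the loss),
/// and updates the adapter factors with AdamW.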
pub fn train_step(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> InstructStepResult {
let t0 = std::time::Instant::now();
let full_ids: Vec<u32> = prompt_ids.iter().chain(response_ids).copied().collect();
let seq_len = full_ids.len().min(self.max_seq_len);
let full_ids = &full_ids[..seq_len];
let prompt_len = prompt_ids.len().min(seq_len);
let loss_start = prompt_len.saturating_sub(1);
let loss_end = seq_len.saturating_sub(1);
let num_loss_tokens = loss_end.saturating_sub(loss_start);
if num_loss_tokens == 0 {
return InstructStepResult { loss: 0.0, num_response_tokens: 0, perplexity: 1.0 };
}
let mut hidden = Vec::with_capacity(seq_len * self.hidden_dim);
for &tok in full_ids {
let offset = (tok as usize) * self.hidden_dim;
let end = offset + self.hidden_dim;
if end <= self.embed_weights.len() {
hidden.extend_from_slice(&self.embed_weights[offset..end]);
} else {
hidden.extend(std::iter::repeat_n(0.0f32, self.hidden_dim));
}
}
let t1 = std::time::Instant::now();
if self.lora_step == 0 {
let h_norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
let h_mean: f32 = hidden.iter().sum::<f32>() / hidden.len() as f32;
eprintln!(
"[DIAG-509] embed: norm={h_norm:.4}, mean={h_mean:.6}, len={}, seq={seq_len}",
hidden.len()
);
let tok264_offset = 264 * self.hidden_dim;
if tok264_offset + 5 < self.embed_weights.len() {
let tok264: Vec<f32> =
self.embed_weights[tok264_offset..tok264_offset + 5].to_vec();
eprintln!("[DIAG-509] embed[264,:5]={tok264:?} (PyTorch: [-0.0295, 0.0035, 0.0193, 0.0020, 0.0049])");
}
}
self.fwd.queue_ref().write_buffer(
self.fwd.hidden_buffer(),
0,
bytemuck::cast_slice(&hidden),
);
let mut saved_activations = Vec::with_capacity(self.num_layers);
for layer_idx in 0..self.num_layers {
let prefix = format!("layer.{layer_idx}");
let qkv_lora = if layer_idx < self.lora.len() {
let lp = &self.lora[layer_idx].projections;
let q = lp.iter().find(|p| p.name == "q_proj");
let k = lp.iter().find(|p| p.name == "k_proj");
let v = lp.iter().find(|p| p.name == "v_proj");
match (q, k, v) {
(Some(qp), Some(kp), Some(vp)) => Some(trueno::backends::gpu::QkvLoRA {
q_a: &qp.a,
q_b: &qp.b,
k_a: &kp.a,
k_b: &kp.b,
v_a: &vp.a,
v_b: &vp.b,
rank: self.lora_rank as u32,
scale: self.lora_scale,
in_dim: qp.in_dim,
q_dim: qp.out_dim,
kv_dim: kp.out_dim,
lora_pipeline: &self.lora_addmm_pipeline,
lora_bgl: &self.lora_addmm_bgl,
}),
_ => None,
}
} else {
None
};
let saved = self.fwd.alloc_layer_activations(seq_len as u32);
if self.lora_step == 0 && layer_idx == 0 {
if let Err(e) = self.fwd.forward_layer_traced(
seq_len as u32,
&prefix,
&saved,
qkv_lora.as_ref(),
) {
eprintln!("[wgpu] traced forward failed: {e}");
}
} else {
let mut encoder = self.fwd.device_ref().create_command_encoder(&Default::default());
if let Err(e) = self.fwd.encode_forward_layer_training(
&mut encoder,
seq_len as u32,
&prefix,
&saved,
qkv_lora.as_ref(),
) {
eprintln!("[wgpu] GPU forward layer {layer_idx} failed: {e}");
return InstructStepResult {
loss: 100.0,
num_response_tokens: num_loss_tokens,
perplexity: 1e6,
};
}
self.fwd.queue_ref().submit(Some(encoder.finish()));
}
saved_activations.push(saved);
if self.lora_step == 0
&& (layer_idx == 0 || layer_idx == 1 || layer_idx == self.num_layers - 1)
{
let n_floats = seq_len * self.hidden_dim;
let h = self.fwd.download_hidden(n_floats);
let norm: f32 = h.iter().map(|x| x * x).sum::<f32>().sqrt();
let nan_c = h.iter().filter(|x| x.is_nan()).count();
let first5: Vec<f32> = h.iter().take(5).copied().collect();
eprintln!("[DIAG-509] after layer {layer_idx}: norm={norm:.4}, nan={nan_c}, first5={first5:?}");
}
}
let t2 = std::time::Instant::now();
if self.lora_step == 0 {
let n_floats = seq_len * self.hidden_dim;
let h_data = self.fwd.download_hidden(n_floats);
let h_norm: f32 = h_data.iter().map(|x| x * x).sum::<f32>().sqrt();
let h_mean: f32 = h_data.iter().sum::<f32>() / h_data.len() as f32;
let nan_count = h_data.iter().filter(|x| x.is_nan()).count();
let inf_count = h_data.iter().filter(|x| x.is_infinite()).count();
let first5: Vec<f32> = h_data.iter().take(5).copied().collect();
eprintln!("[DIAG-509] post-layers: norm={h_norm:.4}, mean={h_mean:.6}, nan={nan_count}, inf={inf_count}, first5={first5:?}");
}
self.fwd.gpu_rmsnorm(&self.output_norm_gpu, &self.normed_buf, seq_len as u32);
let labels: Vec<u32> = (0..seq_len)
.map(|i| if i + 1 < full_ids.len() { full_ids[i + 1] } else { 0 })
.collect();
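// lm_head forward in vocab chunks: each chunk yields a [seq_len, chunk_n]
// logits slice that is scattered into the full [seq_len, vocab] buffer at
// its column offset, keeping each intermediate allocation small.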
let mut col_offset = 0u64;
for (chunk_buf, chunk_n) in &self.lm_head_t_chunks {
let cn = u64::from(*chunk_n);
let c_chunk = self.trainer.zeros((seq_len as u64 * cn) as usize);
self.trainer.matmul_forward(
&self.normed_buf,
chunk_buf,
&c_chunk,
seq_len as u32,
self.hidden_dim as u32,
*chunk_n,
);
self.dispatch_scatter(
&c_chunk,
&self.logits_buf,
seq_len as u32,
*chunk_n,
self.vocab_size as u32,
col_offset as u32,
);
col_offset += cn;
}
let t3 = std::time::Instant::now();
if self.lora_step == 0 {
let logits_data = self.trainer.download(&self.logits_buf);
let l_norm: f32 =
logits_data.iter().take(self.vocab_size).map(|x| x * x).sum::<f32>().sqrt();
let l_max =
logits_data.iter().take(self.vocab_size).cloned().fold(f32::NEG_INFINITY, f32::max);
let l_min =
logits_data.iter().take(self.vocab_size).cloned().fold(f32::INFINITY, f32::min);
let nan_count = logits_data.iter().take(self.vocab_size).filter(|x| x.is_nan()).count();
let zero_count =
logits_data.iter().take(self.vocab_size).filter(|x| **x == 0.0).count();
eprintln!("[DIAG-509] logits[0]: norm={l_norm:.4}, min={l_min:.4}, max={l_max:.4}, nan={nan_count}, zeros={zero_count}/{}", self.vocab_size);
let argmax = logits_data
.iter()
.take(self.vocab_size)
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, v)| (i, *v));
let target = labels[0];
let target_logit = if (target as usize) < self.vocab_size {
logits_data[target as usize]
} else {
f32::NAN
};
eprintln!("[DIAG-509] pos0: argmax={argmax:?}, target={target}, target_logit={target_logit:.4}, loss_range=[{loss_start},{loss_end})");
}
self.trainer.queue_ref().write_buffer(&self.labels_buf, 0, bytemuck::cast_slice(&labels));
let t3a = std::time::Instant::now();
self.cross_entropy.forward_async(
&self.logits_buf,
&self.labels_buf,
&self.losses_buf,
&self.logsumexp_buf,
seq_len as u32,
self.vocab_size as u32,
loss_start as u32,
loss_end as u32,
);
let t3b = std::time::Instant::now();
self.cross_entropy.backward(
&self.logits_buf,
&self.labels_buf,
&self.logsumexp_buf,
seq_len as u32,
self.vocab_size as u32,
loss_start as u32,
loss_end as u32,
);
let t3c = std::time::Instant::now();
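// lm_head backward: after `cross_entropy.backward`, `logits_buf` holds
// dlogits in place. grad_hidden = dlogits @ W_lm is computed chunk by chunk
// from the untransposed `[vocab, hidden]` lm_head chunks and accumulated
// into `grad_hidden_buf` via residual-add.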
let grad_hidden_buf = self.trainer.zeros(seq_len * self.hidden_dim);
let mut row_offset = 0u64;
for (chunk_buf, chunk_k) in &self.lm_head_chunks {
let ck = u64::from(*chunk_k);
let gl_chunk = self.trainer.zeros((seq_len as u64 * ck) as usize);
self.dispatch_gather(
&self.logits_buf,
&gl_chunk,
seq_len as u32,
*chunk_k,
self.vocab_size as u32,
row_offset as u32,
);
let gh_chunk = self.trainer.zeros(seq_len * self.hidden_dim);
self.trainer.matmul_forward(
&gl_chunk,
chunk_buf,
&gh_chunk,
seq_len as u32,
*chunk_k,
self.hidden_dim as u32,
);
let sum_buf = self.trainer.zeros(seq_len * self.hidden_dim);
self.fwd.gpu_residual_add(
&grad_hidden_buf,
&gh_chunk,
&sum_buf,
(seq_len * self.hidden_dim) as u32,
);
let mut enc = self.fwd.device_ref().create_command_encoder(&Default::default());
enc.copy_buffer_to_buffer(
&sum_buf,
0,
&grad_hidden_buf,
0,
(seq_len * self.hidden_dim * 4) as u64,
);
self.fwd.queue_ref().submit(Some(enc.finish()));
row_offset += ck;
}
let t4 = std::time::Instant::now();
if self.lora_step == 1 {
let b0 = self.trainer.download(&self.lora[0].projections[0].b);
let b_norm: f32 = b0.iter().map(|x| x * x).sum::<f32>().sqrt();
eprintln!("[FALSIFY] step=1 B[0].q_proj norm={b_norm:.6}");
}
self.lora_step += 1;
let lr = self.learning_rate;
let s = seq_len as u32;
let rank = self.lora_rank as u32;
let grad_buf = grad_hidden_buf;
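// Adapter backward. With y = base(x) + scale * (x @ A) @ B and upstream
// gradient dY:
//   dB = scale * (x @ A)^T @ dY
//   dA = scale * x^T @ (dY @ B^T)
// As an approximation, grad_buf (the gradient at the final hidden state) is
// used as dY for every projection in every layer; gradients are not
// propagated back through the transformer blocks themselves.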
for layer_idx in (0..self.lora.len()).rev() {
let layer_lora = &self.lora[layer_idx];
let saved = &saved_activations[layer_idx];
for proj in &layer_lora.projections {
if !self.lora_target_set.iter().any(|t| t == &proj.name) {
continue;
}
let input_buf = match proj.name.as_str() {
"q_proj" | "k_proj" | "v_proj" => &saved.attn_norm_out,
"o_proj" => &saved.attn_output,
"gate_proj" | "up_proj" => &saved.ffn_norm_out,
"down_proj" => &saved.silu_gate_output,
_ => continue,
};
let scale = self.lora_scale;
let xa = self.trainer.zeros((s * rank) as usize);
self.trainer.matmul_forward(input_buf, &proj.a, &xa, s, proj.in_dim, rank);
let xa_t = self.trainer.zeros((s * rank) as usize);
self.dispatch_transpose(&xa, &xa_t, s, rank, scale);
let db = self.trainer.zeros((rank * proj.out_dim) as usize);
self.trainer.matmul_forward(&xa_t, &grad_buf, &db, rank, s, proj.out_dim);
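// On the first optimizer step B is still all zeros, so dY @ B^T vanishes;
// skip the dA matmuls and feed AdamW an explicit zero gradient instead.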
let da = if self.lora_step <= 1 {
self.trainer.zeros((proj.in_dim * rank) as usize)
} else {
let bt = self.trainer.zeros((rank * proj.out_dim) as usize);
self.dispatch_transpose(&proj.b, &bt, rank, proj.out_dim, 1.0);
let d_xa = self.trainer.zeros((s * rank) as usize);
self.trainer.matmul_forward(&grad_buf, &bt, &d_xa, s, proj.out_dim, rank);
let xt = self.trainer.zeros((s * proj.in_dim) as usize);
self.dispatch_transpose(input_buf, &xt, s, proj.in_dim, scale);
let da_buf = self.trainer.zeros((proj.in_dim * rank) as usize);
self.trainer.matmul_forward(&xt, &d_xa, &da_buf, proj.in_dim, s, rank);
da_buf
};
self.trainer
.adamw_step(&proj.a, &da, &proj.m_a, &proj.v_a, lr, 0.9, 0.999, 1e-8, 0.01);
self.trainer
.adamw_step(&proj.b, &db, &proj.m_b, &proj.v_b, lr, 0.9, 0.999, 1e-8, 0.01);
}
}
let t5 = std::time::Instant::now();
eprintln!(
"[PROFILE] step: {:.0}ms (embed={:.0} fwd={:.0} lm={:.0} ce={:.0}[fwd={:.0} bwd={:.0}] lm_bwd={:.0} lora_bwd={:.0})",
t5.duration_since(t0).as_millis(),
t1.duration_since(t0).as_millis(),
t2.duration_since(t1).as_millis(),
t3.duration_since(t2).as_millis(),
t3c.duration_since(t3).as_millis(),
t3b.duration_since(t3a).as_millis(),
t3c.duration_since(t3b).as_millis(),
t4.duration_since(t3c).as_millis(),
t5.duration_since(t4).as_millis(),
);
let avg_loss = self.cross_entropy.read_loss(
&self.losses_buf,
seq_len as u32,
loss_start as u32,
loss_end as u32,
);
InstructStepResult {
loss: if avg_loss.is_finite() { avg_loss } else { 100.0 },
num_response_tokens: num_loss_tokens,
perplexity: if avg_loss.is_finite() { avg_loss.exp().min(1e6) } else { 1e6 },
}
}
}
#[cfg(feature = "gpu")]
impl WgpuInstructPipeline {
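/// One Direct Preference Optimization step:
/// `loss = -log(sigmoid(beta * (logp_chosen - logp_rejected)))`.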
pub fn dpo_step(
&mut self,
prompt_ids: &[u32],
chosen_ids: &[u32],
rejected_ids: &[u32],
beta: f32,
) -> f32 {
let chosen_logprob = self.compute_sequence_logprob(prompt_ids, chosen_ids);
let rejected_logprob = self.compute_sequence_logprob(prompt_ids, rejected_ids);
let delta = chosen_logprob - rejected_logprob;
let sigmoid_arg = beta * delta;
let sigmoid_val = 1.0 / (1.0 + (-sigmoid_arg).exp());
let loss = -(sigmoid_val.max(1e-7)).ln();
debug_assert!(loss >= 0.0, "DPO loss must be non-negative: {loss}");
loss
}
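/// Approximates the total sequence log-probability as `-avg_loss * num_tokens`.
/// Note: this delegates to `train_step`, so each call also applies an AdamW
/// update to the adapters as a side effect.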
fn compute_sequence_logprob(&mut self, prompt_ids: &[u32], response_ids: &[u32]) -> f32 {
let result = self.train_step(prompt_ids, response_ids);
let num_tokens = response_ids.len() as f32;
-result.loss * num_tokens
}
}