entrenar/train/transformer_trainer/
wgpu_runner.rs1#[cfg(feature = "gpu")]
17use super::wgpu_trainer::{WgpuModelState, WgpuTransformerTrainer};
18
19#[cfg(feature = "gpu")]
/// Settings for one run of [`run_wgpu_training`].
///
/// The per-field notes below describe how the training entry point in this file uses each
/// value; other consumers may interpret them differently.
#[derive(Debug, Clone)]
pub struct WgpuTrainConfig {
    /// Model directory. `tokenizer.json` is read from it and it is handed to the model loader.
    pub model_dir: std::path::PathBuf,
    /// Training data file, read as JSONL: the string `text` field of each line is one example.
    /// Lines that are not valid JSON or have no string `text` field are skipped.
    pub data_path: std::path::PathBuf,
    /// Number of passes over the examples.
    pub epochs: usize,
    /// Base learning rate; it is divided by `accumulation_steps` before reaching the trainer.
    pub lr: f32,
    /// LoRA rank, forwarded to the model loader and to every checkpoint write.
    pub lora_rank: u32,
    /// LoRA alpha, forwarded to the model loader and to every checkpoint write.
    pub lora_alpha: f32,
    /// Maximum number of tokens kept per example (longer examples are truncated).
    pub seq_len: usize,
    /// Batch size.
    /// NOTE(review): not read by the training loop in this file, which processes one
    /// example per step — confirm whether it is consumed elsewhere.
    pub batch_size: usize,
    /// A progress line is printed whenever the step counter is a multiple of this value.
    pub log_every: usize,
    /// A checkpoint is written whenever the step counter is a multiple of this value;
    /// `0` turns periodic checkpoints off (the final checkpoint is still written).
    pub save_every: usize,
    /// Directory handed to the checkpoint writer.
    pub output_dir: std::path::PathBuf,
    /// Accumulation step count. In this file it only scales `lr` down; `0` is treated as `1`.
    pub accumulation_steps: usize,
}
36
37#[cfg(feature = "gpu")]
39fn tokenize_example(
40 tokenizer: &crate::tokenizer::HfTokenizer,
41 text: &str,
42 max_len: usize,
43) -> (Vec<u32>, Vec<u32>) {
44 let tokens = tokenizer.encode(text);
45 let len = tokens.len().min(max_len);
46 let input_ids = tokens[..len].to_vec();
47 let target_ids: Vec<u32> = if len > 1 { tokens[1..len].to_vec() } else { vec![0] };
49 (input_ids, target_ids)
50}
51
52#[cfg(feature = "gpu")]
56pub fn run_wgpu_training(config: &WgpuTrainConfig) -> Result<(), String> {
57 use crate::tokenizer::HfTokenizer;
58 use crate::transformer::TransformerConfig;
59
60 eprintln!("=== WGPU Training: Qwen3-4B QLoRA on AMD GPU ===\n");
61
62 let tokenizer_path = config.model_dir.join("tokenizer.json");
64 let tokenizer =
65 HfTokenizer::from_file(&tokenizer_path).map_err(|e| format!("Tokenizer: {e}"))?;
66 eprintln!("Tokenizer loaded: {}", tokenizer_path.display());
67
68 let mut model =
70 WgpuModelState::load_qwen3_4b(&config.model_dir, config.lora_rank, config.lora_alpha)?;
71 eprintln!("Model: {} trainable params\n", model.trainable_params());
72
73 let data_str = std::fs::read_to_string(&config.data_path).map_err(|e| format!("Data: {e}"))?;
75 let examples: Vec<String> = data_str
76 .lines()
77 .filter_map(|line| {
78 serde_json::from_str::<serde_json::Value>(line)
79 .ok()
80 .and_then(|v| v["text"].as_str().map(std::string::ToString::to_string))
81 })
82 .collect();
83 eprintln!("Data: {} examples from {}\n", examples.len(), config.data_path.display());
84
85 let mut tc = TransformerConfig::llama2_7b();
87 tc.hidden_size = model.hidden_size;
88 tc.intermediate_size = model.intermediate_size;
89 tc.num_hidden_layers = model.num_layers;
90 tc.num_attention_heads = model.num_heads;
91 tc.num_kv_heads = model.num_kv_heads;
92 tc.vocab_size = model.vocab_size;
93
94 let effective_lr = config.lr / config.accumulation_steps.max(1) as f32;
96 let mut trainer = WgpuTransformerTrainer::new(&tc, effective_lr)?;
97 eprintln!(
98 "Effective lr: {effective_lr} (lr={} / accum={})\n",
99 config.lr, config.accumulation_steps
100 );
101
102 let mut total_loss = 0.0f32;
104 let mut step = 0usize;
105
106 let mut best_loss = f32::INFINITY;
107 for epoch in 0..config.epochs {
108 let mut indices: Vec<usize> = (0..examples.len()).collect();
110 if epoch > 0 {
111 use std::collections::hash_map::DefaultHasher;
112 use std::hash::{Hash, Hasher};
113 let mut hasher = DefaultHasher::new();
114 epoch.hash(&mut hasher);
115 let seed = hasher.finish();
116 for i in (1..indices.len()).rev() {
118 let j =
119 ((seed.wrapping_mul(i as u64 + 1).wrapping_add(7)) % (i as u64 + 1)) as usize;
120 indices.swap(i, j);
121 }
122 }
123 eprintln!("--- Epoch {}/{} ({} examples) ---", epoch + 1, config.epochs, examples.len());
124
125 for (idx, &ei) in indices.iter().enumerate() {
126 let text = &examples[ei];
127 let (input_ids, target_ids) = tokenize_example(&tokenizer, text, config.seq_len);
128 if input_ids.len() < 2 {
129 continue;
130 }
131
132 let seq_len = target_ids.len() as u32;
134 let h = model.hidden_size;
135 let mut hidden = vec![0.0f32; seq_len as usize * h];
136 for (si, &tid) in input_ids[..target_ids.len()].iter().enumerate() {
137 let tid = (tid as usize).min(model.vocab_size - 1);
138 for hi in 0..h {
139 hidden[si * h + hi] = model.lm_head[tid * h + hi];
140 }
141 }
142
143 let (loss, gnorm) = trainer.full_train_step(&hidden, &target_ids, &mut model)?;
145
146 total_loss += loss;
147 step += 1;
148
149 if step.is_multiple_of(config.log_every) {
150 let avg_loss = total_loss / step as f32;
151 eprintln!(
152 " step={step} loss={loss:.3} avg_loss={avg_loss:.3} gnorm={gnorm:.2e} [{}/{}]",
153 idx + 1,
154 examples.len()
155 );
156 }
157
158 if config.save_every > 0 && step.is_multiple_of(config.save_every) {
159 model.save_checkpoint(
160 &config.output_dir,
161 step as u32,
162 loss,
163 config.lora_rank,
164 config.lora_alpha,
165 )?;
166 }
167 if loss < best_loss {
168 best_loss = loss;
169 }
170 }
171 }
172
173 let final_avg = total_loss / step.max(1) as f32;
174
175 model.save_checkpoint(
177 &config.output_dir,
178 step as u32,
179 final_avg,
180 config.lora_rank,
181 config.lora_alpha,
182 )?;
183
184 eprintln!("\n=== Training complete: {step} steps, avg_loss={final_avg:.3} ===");
185 Ok(())
186}
187
188#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Smoke test: tokenizes the first example of a local dataset and runs a single
    /// training step on a locally stored model, checking that the loss is finite and
    /// positive.
    ///
    /// The model and data paths are hard-coded; when either is missing the test prints a
    /// notice and returns without asserting anything.
    #[test]
    fn test_wgpu_training_smoke() {
        let model_dir = std::path::Path::new("/home/noah/src/models/qwen3-4b");
        let data_path =
            std::path::Path::new("/home/noah/src/bashrs/training/conversations_v4.jsonl");

        if !model_dir.exists() || !data_path.exists() {
            eprintln!("Skipping: model or data not found");
            return;
        }

        let tokenizer = crate::tokenizer::HfTokenizer::from_file(model_dir.join("tokenizer.json"))
            .expect("tokenizer");

        // Only the first JSONL record is needed for a one-step smoke test.
        let data = std::fs::read_to_string(data_path).expect("read data");
        let first_line = data.lines().next().expect("first line");
        let text: serde_json::Value = serde_json::from_str(first_line).expect("parse json");
        let text = text["text"].as_str().expect("text field");

        let (input_ids, target_ids) = tokenize_example(&tokenizer, text, 32);
        eprintln!(
            "Tokenized: {} tokens, first 5: {:?}",
            input_ids.len(),
            &input_ids[..5.min(input_ids.len())]
        );

        assert!(input_ids.len() >= 2, "Need at least 2 tokens");
        assert_eq!(target_ids.len(), input_ids.len() - 1);

        let mut model = WgpuModelState::load_qwen3_4b(model_dir, 16, 32.0).expect("model");

        // Take every dimension from the loaded model, exactly as `run_wgpu_training` does,
        // so the trainer is configured the same way as in a real run (including the head
        // counts, which would otherwise stay at the preset's values).
        let mut config = crate::transformer::TransformerConfig::llama2_7b();
        config.hidden_size = model.hidden_size;
        config.intermediate_size = model.intermediate_size;
        config.num_hidden_layers = model.num_layers;
        config.num_attention_heads = model.num_heads;
        config.num_kv_heads = model.num_kv_heads;
        config.vocab_size = model.vocab_size;

        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-4).expect("trainer");

        // Build the input buffer: one `lm_head` row per input token that has a target.
        let seq_len = target_ids.len();
        let h = model.hidden_size;
        let last_row = model.vocab_size - 1;
        let mut hidden = vec![0.0f32; seq_len * h];
        for (si, &tid) in input_ids[..seq_len].iter().enumerate() {
            let tid = (tid as usize).min(last_row);
            for hi in 0..h {
                hidden[si * h + hi] = model.lm_head[tid * h + hi];
            }
        }

        let start = std::time::Instant::now();
        let (loss, gnorm) =
            trainer.full_train_step(&hidden, &target_ids, &mut model).expect("train step");
        let elapsed = start.elapsed();

        eprintln!(
            "WGPU smoke: loss={loss:.3}, gnorm={gnorm:.2e}, time={:.1}s",
            elapsed.as_secs_f64()
        );
        assert!(loss.is_finite(), "Loss must be finite");
        assert!(loss > 0.0, "Loss must be positive");
    }
}