Skip to main content

entrenar/train/transformer_trainer/
wgpu_runner.rs

1//! WGPU Training Runner — end-to-end Qwen3-4B QLoRA training on AMD GPUs
2//!
3//! Connects: tokenizer → data loading → 36-layer forward/backward → AdamW
4//!
5//! # Contract: wgpu-transformer-trainer-v1.yaml (C-WGPU-TRAIN-001)
6//!
7//! # Usage
8//!
9//! ```bash
10//! cargo run --features gpu --release --example wgpu_train -- \
11//!   --model /home/noah/src/models/qwen3-4b \
12//!   --data /home/noah/src/bashrs/training/conversations_v4.jsonl \
13//!   --epochs 1 --lr 5e-4 --lora-rank 16 --seq-len 128
14//! ```
15
16#[cfg(feature = "gpu")]
17use super::wgpu_trainer::{WgpuModelState, WgpuTransformerTrainer};
18
/// Training configuration for [`run_wgpu_training`].
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct WgpuTrainConfig {
    /// Model directory: passed to the Qwen3-4B loader and expected to contain `tokenizer.json`.
    pub model_dir: std::path::PathBuf,
    /// JSONL training file; each line is expected to be an object with a string `"text"` field.
    pub data_path: std::path::PathBuf,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Base learning rate (divided by `accumulation_steps.max(1)` before it reaches the trainer).
    pub lr: f32,
    /// LoRA adapter rank (passed to the model loader and recorded in checkpoints).
    pub lora_rank: u32,
    /// LoRA alpha (passed to the model loader and recorded in checkpoints).
    pub lora_alpha: f32,
    /// Maximum number of tokens kept per example; longer examples are truncated.
    pub seq_len: usize,
    /// NOTE(review): not read by `run_wgpu_training`, which trains on one example per step.
    pub batch_size: usize,
    /// Emit a progress line every N steps; `0` disables progress logging.
    pub log_every: usize,
    /// Write a checkpoint every N steps; `0` disables periodic checkpoints
    /// (`run_wgpu_training` still writes a final one).
    pub save_every: usize,
    /// Directory that checkpoints are written to.
    pub output_dir: std::path::PathBuf,
    /// Gradient accumulation steps. `run_wgpu_training` only uses this to scale the
    /// learning rate (`lr / accumulation_steps.max(1)`); it does not itself accumulate
    /// gradients across examples.
    pub accumulation_steps: usize,
}
36
/// Tokenize one training example into a next-token-prediction pair.
///
/// The encoded text is truncated to at most `max_len` tokens. That window is returned
/// as `input_ids`, and `target_ids` is the same window shifted left by one, so that
/// `target_ids[i] == input_ids[i + 1]` and
/// `target_ids.len() == input_ids.len().saturating_sub(1)` always hold.
///
/// Inputs shorter than two tokens yield an empty `target_ids` rather than a
/// placeholder id that could be mistaken for a real target; callers should skip them.
#[cfg(feature = "gpu")]
fn tokenize_example(
    tokenizer: &crate::tokenizer::HfTokenizer,
    text: &str,
    max_len: usize,
) -> (Vec<u32>, Vec<u32>) {
    let tokens = tokenizer.encode(text);
    let window = &tokens[..tokens.len().min(max_len)];
    let input_ids = window.to_vec();
    // Next-token targets: every token in the window except the first.
    let target_ids = window.iter().skip(1).copied().collect();
    (input_ids, target_ids)
}
51
/// Read a JSONL file and collect the string `"text"` field of every line that has one.
///
/// Returns the examples plus the number of non-blank lines that were skipped because
/// they were not valid JSON or had no string `"text"` field.
#[cfg(feature = "gpu")]
fn load_text_examples(path: &std::path::Path) -> Result<(Vec<String>, usize), String> {
    let data = std::fs::read_to_string(path).map_err(|e| format!("Data: {e}"))?;
    let mut examples = Vec::new();
    let mut skipped = 0usize;
    for line in data.lines().filter(|l| !l.trim().is_empty()) {
        let text = serde_json::from_str::<serde_json::Value>(line)
            .ok()
            .and_then(|v| v.get("text").and_then(serde_json::Value::as_str).map(str::to_owned));
        match text {
            Some(t) => examples.push(t),
            None => skipped += 1,
        }
    }
    Ok((examples, skipped))
}

/// SplitMix64 step: advance `state` and return the next 64-bit pseudo-random value.
///
/// Used instead of `DefaultHasher`, whose algorithm is unspecified across Rust
/// releases, so the data order stays reproducible across toolchains.
#[cfg(feature = "gpu")]
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Order in which the `len` examples are visited during `epoch`.
///
/// Epoch 0 uses file order. Later epochs use a Fisher-Yates shuffle driven by a
/// SplitMix64 stream seeded with the epoch number, so each run is reproducible.
#[cfg(feature = "gpu")]
fn epoch_order(len: usize, epoch: usize) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..len).collect();
    if epoch > 0 {
        let mut state = epoch as u64;
        for i in (1..len).rev() {
            // Fresh draw for every position: j in 0..=i. Modulo bias is negligible for
            // dataset sizes far below 2^64.
            let j = (splitmix64(&mut state) % (i as u64 + 1)) as usize;
            indices.swap(i, j);
        }
    }
    indices
}

/// Build the `ids.len() * hidden_size` input activations by copying one row of
/// `model.lm_head` per token id.
///
/// This is a simplification that reuses the lm_head weights as the embedding table.
/// Ids beyond the vocabulary are clamped to the last row.
#[cfg(feature = "gpu")]
fn embed_with_lm_head(model: &WgpuModelState, ids: &[u32]) -> Vec<f32> {
    let h = model.hidden_size;
    let max_id = model.vocab_size.saturating_sub(1);
    let mut hidden = vec![0.0f32; ids.len() * h];
    for (si, &tid) in ids.iter().enumerate() {
        let row = (tid as usize).min(max_id) * h;
        hidden[si * h..(si + 1) * h].copy_from_slice(&model.lm_head[row..row + h]);
    }
    hidden
}

/// Run WGPU training
///
/// Steps:
/// 1. Load the tokenizer from `config.model_dir`.
/// 2. Read `{"text": ...}` JSONL examples from `config.data_path`.
/// 3. Load the Qwen3-4B weights.
/// 4. Run `config.epochs` passes of next-token training, one example per step.
///
/// Progress is logged every `config.log_every` steps. LoRA checkpoints are written
/// every `config.save_every` steps, plus once at the end.
///
/// # Contract (C-WGPU-TRAIN-001)
///
/// # Errors
///
/// Returns `Err` in any of these cases:
/// - the tokenizer, data file, or model cannot be loaded;
/// - the data file contains no usable examples;
/// - trainer creation, a training step, or a checkpoint save fails.
#[cfg(feature = "gpu")]
pub fn run_wgpu_training(config: &WgpuTrainConfig) -> Result<(), String> {
    use crate::tokenizer::HfTokenizer;
    use crate::transformer::TransformerConfig;

    eprintln!("=== WGPU Training: Qwen3-4B QLoRA on AMD GPU ===\n");

    // 1. Load tokenizer
    let tokenizer_path = config.model_dir.join("tokenizer.json");
    let tokenizer =
        HfTokenizer::from_file(&tokenizer_path).map_err(|e| format!("Tokenizer: {e}"))?;
    eprintln!("Tokenizer loaded: {}", tokenizer_path.display());

    // 2. Load data before the (expensive) model so bad input fails fast.
    let (examples, skipped) = load_text_examples(&config.data_path)?;
    if skipped > 0 {
        eprintln!("Warning: skipped {skipped} lines (invalid JSON or no string \"text\" field)");
    }
    eprintln!("Data: {} examples from {}\n", examples.len(), config.data_path.display());
    if examples.is_empty() {
        return Err(format!("Data: no usable examples in {}", config.data_path.display()));
    }

    // 3. Load model
    let mut model =
        WgpuModelState::load_qwen3_4b(&config.model_dir, config.lora_rank, config.lora_alpha)?;
    eprintln!("Model: {} trainable params\n", model.trainable_params());

    // 4. Create trainer, with architecture fields copied from the loaded model.
    let mut tc = TransformerConfig::llama2_7b();
    tc.hidden_size = model.hidden_size;
    tc.intermediate_size = model.intermediate_size;
    tc.num_hidden_layers = model.num_layers;
    tc.num_attention_heads = model.num_heads;
    tc.num_kv_heads = model.num_kv_heads;
    tc.vocab_size = model.vocab_size;

    // Scale lr by 1/accumulation_steps for gradient accumulation equivalence
    let effective_lr = config.lr / config.accumulation_steps.max(1) as f32;
    let mut trainer = WgpuTransformerTrainer::new(&tc, effective_lr)?;
    eprintln!(
        "Effective lr: {effective_lr} (lr={} / accum={})\n",
        config.lr, config.accumulation_steps
    );

    // 5. Training loop
    let mut total_loss = 0.0f32;
    let mut best_loss = f32::INFINITY;
    let mut step = 0usize;

    for epoch in 0..config.epochs {
        let order = epoch_order(examples.len(), epoch);
        eprintln!("--- Epoch {}/{} ({} examples) ---", epoch + 1, config.epochs, examples.len());

        for (idx, &ei) in order.iter().enumerate() {
            let (input_ids, target_ids) =
                tokenize_example(&tokenizer, &examples[ei], config.seq_len);
            if input_ids.len() < 2 {
                continue;
            }

            // Input position i predicts target_ids[i] (input token i + 1), so only the
            // first target_ids.len() inputs are embedded.
            let hidden = embed_with_lm_head(&model, &input_ids[..target_ids.len()]);
            let (loss, gnorm) = trainer.full_train_step(&hidden, &target_ids, &mut model)?;

            total_loss += loss;
            best_loss = best_loss.min(loss);
            step += 1;

            if step.is_multiple_of(config.log_every) {
                let avg_loss = total_loss / step as f32;
                eprintln!(
                    "  step={step} loss={loss:.3} avg_loss={avg_loss:.3} gnorm={gnorm:.2e} [{}/{}]",
                    idx + 1,
                    examples.len()
                );
            }

            if config.save_every > 0 && step.is_multiple_of(config.save_every) {
                model.save_checkpoint(
                    &config.output_dir,
                    step as u32,
                    loss,
                    config.lora_rank,
                    config.lora_alpha,
                )?;
            }
        }
    }

    let final_avg = total_loss / step.max(1) as f32;

    // Save final checkpoint (recorded with the run's average loss).
    model.save_checkpoint(
        &config.output_dir,
        step as u32,
        final_avg,
        config.lora_rank,
        config.lora_alpha,
    )?;

    eprintln!(
        "\n=== Training complete: {step} steps, avg_loss={final_avg:.3}, best_loss={best_loss:.3} ==="
    );
    Ok(())
}
187
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Resolve a fixture path from environment variable `var`, else use `default`.
    fn fixture_path(var: &str, default: &str) -> std::path::PathBuf {
        std::env::var_os(var).map_or_else(|| default.into(), std::path::PathBuf::from)
    }

    /// Smoke test: tokenize + 1 training step with real data.
    ///
    /// Fixture paths can be overridden with `WGPU_TRAIN_MODEL_DIR` and
    /// `WGPU_TRAIN_DATA`. The test is skipped when either path is missing.
    #[test]
    fn test_wgpu_training_smoke() {
        let model_dir = fixture_path("WGPU_TRAIN_MODEL_DIR", "/home/noah/src/models/qwen3-4b");
        let data_path = fixture_path(
            "WGPU_TRAIN_DATA",
            "/home/noah/src/bashrs/training/conversations_v4.jsonl",
        );

        if !model_dir.exists() || !data_path.exists() {
            eprintln!("Skipping: model or data not found");
            return;
        }

        // Load tokenizer
        let tokenizer = crate::tokenizer::HfTokenizer::from_file(model_dir.join("tokenizer.json"))
            .expect("tokenizer");

        // Tokenize first example
        let data = std::fs::read_to_string(&data_path).expect("read data");
        let first_line = data.lines().next().expect("first line");
        let text: serde_json::Value = serde_json::from_str(first_line).expect("parse json");
        let text = text["text"].as_str().expect("text field");

        let (input_ids, target_ids) = tokenize_example(&tokenizer, text, 32);
        eprintln!(
            "Tokenized: {} tokens, first 5: {:?}",
            input_ids.len(),
            &input_ids[..5.min(input_ids.len())]
        );

        assert!(input_ids.len() >= 2, "Need at least 2 tokens");
        assert_eq!(target_ids.len(), input_ids.len() - 1);

        // Load model and run 1 step
        let mut model = WgpuModelState::load_qwen3_4b(&model_dir, 16, 32.0).expect("model");

        // Mirror the config run_wgpu_training builds, so the smoke test exercises the
        // same architecture (including head / KV-head counts) as real training.
        let mut config = crate::transformer::TransformerConfig::llama2_7b();
        config.hidden_size = model.hidden_size;
        config.intermediate_size = model.intermediate_size;
        config.num_hidden_layers = model.num_layers;
        config.num_attention_heads = model.num_heads;
        config.num_kv_heads = model.num_kv_heads;
        config.vocab_size = model.vocab_size;

        let mut trainer = WgpuTransformerTrainer::new(&config, 5e-4).expect("trainer");

        // Embedding lookup: copy one lm_head row per input token (ids clamped to vocab).
        let seq_len = target_ids.len();
        let h = model.hidden_size;
        let max_id = model.vocab_size - 1;
        let mut hidden = vec![0.0f32; seq_len * h];
        for (si, &tid) in input_ids[..seq_len].iter().enumerate() {
            let row = (tid as usize).min(max_id) * h;
            hidden[si * h..(si + 1) * h].copy_from_slice(&model.lm_head[row..row + h]);
        }

        let start = std::time::Instant::now();
        let (loss, gnorm) =
            trainer.full_train_step(&hidden, &target_ids, &mut model).expect("train step");
        let elapsed = start.elapsed();

        eprintln!(
            "WGPU smoke: loss={loss:.3}, gnorm={gnorm:.2e}, time={:.1}s",
            elapsed.as_secs_f64()
        );
        assert!(loss.is_finite(), "Loss must be finite");
        assert!(loss > 0.0, "Loss must be positive");
    }
}