Skip to main content

entrenar/finetune/
instruct_trainer.rs

1//! Production training loop for instruction fine-tuning (GH-371)
2//!
3//! `InstructTrainer` wraps `InstructPipeline` with epoch management,
4//! validation, checkpointing, LR scheduling, and early stopping.
5//!
6//! # Contract Invariants
7//!
8//! - F-INST-003: Perplexity reported per epoch
9//! - F-LOOP-002: Validation computed every epoch
10//! - F-LOOP-007: Data shuffled per epoch
11//! - F-LOOP-008: Val split disjoint
12//! - F-LOOP-010: Early stopping respects patience
13
14use super::instruct_corpus::{format_chat_prompt, InstructSample};
15use super::instruct_pipeline::InstructPipeline;
16use sha2::{Digest, Sha256};
17use std::path::PathBuf;
18
/// Hyperparameters and bookkeeping settings consumed by [`InstructTrainer`].
///
/// Build one with struct-update syntax on top of [`Default`], e.g.
/// `InstructTrainingConfig { epochs: 2, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct InstructTrainingConfig {
    /// How many full passes over the training split to run (must be > 0).
    pub epochs: usize,
    /// Share of the corpus held out for validation; `InstructTrainer::new` accepts (0.0, 0.5].
    pub val_split: f32,
    /// Periodic checkpoint cadence in epochs. `0` disables periodic checkpoints; a value
    /// `>= epochs` makes the trainer checkpoint after every epoch.
    pub save_every: usize,
    /// Number of consecutive non-improving validation epochs tolerated before stopping.
    pub early_stopping_patience: usize,
    /// Root directory under which `best/` and `epoch-{N}/` checkpoints are written.
    pub checkpoint_dir: PathBuf,
    /// Seed for the train/val split and for the per-epoch training shuffle.
    pub seed: u64,
    /// Intended logging cadence in epochs.
    /// NOTE(review): not read by `InstructTrainer::train` in this file; per-epoch
    /// stderr output is currently unconditional.
    pub log_interval: usize,
    /// Portion of all optimizer steps (epochs x train samples) spent warming up the LR.
    pub warmup_fraction: f32,
    /// Floor the cosine-decayed learning rate settles toward.
    pub lr_min: f32,
}

impl Default for InstructTrainingConfig {
    /// 3 epochs, 20% validation, a checkpoint every epoch under `checkpoints/`,
    /// patience 5, seed 42, 10% warmup and an LR floor of `1e-6`.
    fn default() -> Self {
        let checkpoint_dir = PathBuf::from("checkpoints");
        Self {
            epochs: 3,
            val_split: 0.2,
            save_every: 1,
            early_stopping_patience: 5,
            checkpoint_dir,
            seed: 42,
            log_interval: 1,
            warmup_fraction: 0.1,
            lr_min: 1e-6,
        }
    }
}
57
/// Snapshot of what happened during one epoch of [`InstructTrainer::train`].
#[derive(Debug, Clone)]
pub struct InstructEpochMetrics {
    /// Zero-based index of the epoch these numbers belong to.
    pub epoch: usize,
    /// Token-weighted mean training loss over response tokens.
    pub train_loss: f32,
    /// `exp(train_loss)`, capped at 1e6 by the trainer.
    pub train_perplexity: f32,
    /// Mean loss on the held-out validation split.
    pub val_loss: f32,
    /// Perplexity on the held-out validation split.
    pub val_perplexity: f32,
    /// Learning rate in effect when the epoch finished.
    pub learning_rate: f32,
    /// Wall-clock duration of the epoch (training + validation), in milliseconds.
    pub epoch_time_ms: u64,
    /// Training samples processed per second of epoch wall-clock time.
    pub samples_per_sec: f32,
}

/// Summary returned once [`InstructTrainer::train`] finishes or stops early.
#[derive(Debug, Clone)]
pub struct InstructTrainResult {
    /// One entry per epoch that actually ran, in order.
    pub epoch_metrics: Vec<InstructEpochMetrics>,
    /// Zero-based epoch that achieved `best_val_loss`.
    pub best_epoch: usize,
    /// Smallest validation loss observed during the run.
    pub best_val_loss: f32,
    /// `true` when early stopping cut the run short.
    pub stopped_early: bool,
    /// Wall-clock duration of the whole run, in milliseconds.
    pub total_time_ms: u64,
}

/// Tokenized form of one sample. Prompt and response IDs are kept apart so the pipeline
/// can tell the two segments apart.
struct PreparedSample {
    prompt_ids: Vec<u32>,
    response_ids: Vec<u32>,
}
99
100/// Production training loop for instruction fine-tuning.
101pub struct InstructTrainer {
102    /// The instruction pipeline (model + optimizer)
103    pipeline: InstructPipeline,
104    /// Training configuration
105    config: InstructTrainingConfig,
106    /// Training data (shuffled per epoch)
107    train_data: Vec<InstructSample>,
108    /// Validation data (frozen, never shuffled)
109    val_data: Vec<InstructSample>,
110    /// Base random seed
111    rng_seed: u64,
112    /// SHA-256 hash of training data for provenance
113    data_hash: String,
114}
115
116impl InstructTrainer {
117    /// Create a new trainer by splitting corpus into train/val sets.
118    ///
119    /// # Errors
120    /// Returns error if corpus is empty, val_split is out of range, or epochs is 0.
121    pub fn new(
122        pipeline: InstructPipeline,
123        corpus: Vec<InstructSample>,
124        config: InstructTrainingConfig,
125    ) -> crate::Result<Self> {
126        if corpus.is_empty() {
127            return Err(crate::Error::ConfigError("GH-371: corpus must not be empty".to_string()));
128        }
129        if config.val_split <= 0.0 || config.val_split > 0.5 {
130            return Err(crate::Error::ConfigError(format!(
131                "GH-371: val_split must be in (0.0, 0.5], got {}",
132                config.val_split,
133            )));
134        }
135        if config.epochs == 0 {
136            return Err(crate::Error::ConfigError("GH-371: epochs must be > 0".to_string()));
137        }
138
139        let (train_data, val_data) = Self::split_dataset(&corpus, config.val_split, config.seed);
140
141        if train_data.is_empty() || val_data.is_empty() {
142            return Err(crate::Error::ConfigError(format!(
143                "GH-371: split produced empty set (train={}, val={}). Need more samples.",
144                train_data.len(),
145                val_data.len(),
146            )));
147        }
148
149        let rng_seed = config.seed;
150        let data_hash = Self::compute_data_hash(&corpus);
151
152        Ok(Self { pipeline, config, train_data, val_data, rng_seed, data_hash })
153    }
154
155    /// Run the full training loop.
156    pub fn train(&mut self) -> InstructTrainResult {
157        use crate::optim::{LRScheduler, WarmupCosineDecayLR};
158
159        let total_start = std::time::Instant::now();
160        let base_lr = self.pipeline.learning_rate();
161        let total_steps = self.config.epochs * self.train_data.len();
162        let warmup_steps = (total_steps as f32 * self.config.warmup_fraction) as usize;
163
164        let mut scheduler =
165            WarmupCosineDecayLR::new(base_lr, self.config.lr_min, warmup_steps, total_steps);
166
167        let mut epoch_metrics = Vec::new();
168        let mut best_val_loss = f32::INFINITY;
169        let mut best_epoch = 0usize;
170        let mut patience_counter = 0usize;
171        let mut stopped_early = false;
172
173        // Pre-tokenize validation data (frozen across epochs)
174        // KAIZEN-046: Removed unnecessary .clone() — both borrows are immutable.
175        let val_prepared = self.prepare_samples(&self.val_data);
176
177        // KAIZEN-046: Pre-collect validation token IDs once (was: re-cloned every epoch).
178        let val_prompts: Vec<Vec<u32>> =
179            val_prepared.iter().map(|s| s.prompt_ids.clone()).collect();
180        let val_responses: Vec<Vec<u32>> =
181            val_prepared.iter().map(|s| s.response_ids.clone()).collect();
182
183        for epoch in 0..self.config.epochs {
184            let epoch_start = std::time::Instant::now();
185
186            // Shuffle training data
187            self.shuffle_train(epoch as u64);
188
189            // Pre-tokenize training data for this epoch (after shuffle)
190            // KAIZEN-046: Removed unnecessary .clone() — both borrows are immutable.
191            let train_prepared = self.prepare_samples(&self.train_data);
192
193            // ── Train ──
194            let mut epoch_loss = 0.0f32;
195            let mut epoch_tokens = 0usize;
196
197            for sample in &train_prepared {
198                let lr = scheduler.get_lr();
199                self.pipeline.set_learning_rate(lr);
200
201                let result = self.pipeline.train_step(&sample.prompt_ids, &sample.response_ids);
202                epoch_loss += result.loss * result.num_response_tokens as f32;
203                epoch_tokens += result.num_response_tokens;
204                scheduler.step();
205            }
206
207            let train_loss = if epoch_tokens > 0 { epoch_loss / epoch_tokens as f32 } else { 0.0 };
208
209            // PMAT-512: Emit per-epoch loss to stderr so canary parser can extract it.
210            // Format matches WGPU path (finetune.rs:678) for parser compatibility.
211            eprintln!(
212                "  Epoch {} complete: avg_loss={:.4} tokens={} samples={} lr={:.2e}",
213                epoch + 1,
214                train_loss,
215                epoch_tokens,
216                train_prepared.len(),
217                self.pipeline.learning_rate(),
218            );
219
220            // ── Validate ──
221            // KAIZEN-046: val_prompts/val_responses hoisted outside epoch loop.
222            let val_result = self.pipeline.evaluate(&val_prompts, &val_responses);
223
224            let epoch_time_ms = epoch_start.elapsed().as_millis() as u64;
225            let samples_per_sec = if epoch_time_ms > 0 {
226                train_prepared.len() as f32 / (epoch_time_ms as f32 / 1000.0)
227            } else {
228                0.0
229            };
230
231            let metrics = InstructEpochMetrics {
232                epoch,
233                train_loss,
234                train_perplexity: train_loss.exp().min(1e6),
235                val_loss: val_result.avg_loss,
236                val_perplexity: val_result.perplexity,
237                learning_rate: self.pipeline.learning_rate(),
238                epoch_time_ms,
239                samples_per_sec,
240            };
241
242            // ── Checkpointing ──
243            if val_result.avg_loss < best_val_loss {
244                best_val_loss = val_result.avg_loss;
245                best_epoch = epoch;
246                patience_counter = 0;
247
248                // Save best checkpoint
249                let best_path = self.config.checkpoint_dir.join("best");
250                let _ = self.save_checkpoint(&best_path, epoch, &metrics);
251            } else {
252                patience_counter += 1;
253            }
254
255            // Periodic checkpoint
256            let effective_save_every = if self.config.epochs <= self.config.save_every {
257                1
258            } else {
259                self.config.save_every
260            };
261            if effective_save_every > 0 && (epoch + 1) % effective_save_every == 0 {
262                let epoch_path = self.config.checkpoint_dir.join(format!("epoch-{epoch}"));
263                let _ = self.save_checkpoint(&epoch_path, epoch, &metrics);
264            }
265
266            epoch_metrics.push(metrics);
267
268            // ── Early stopping ──
269            if patience_counter >= self.config.early_stopping_patience {
270                stopped_early = true;
271                break;
272            }
273        }
274
275        // PMAT-512: Print final training summary to stderr for canary parser.
276        if let Some(last) = epoch_metrics.last() {
277            eprintln!(
278                "[training] Training complete: final_loss={:.4} best_val_loss={:.4} best_epoch={} epochs={} time={}s{}",
279                last.train_loss,
280                best_val_loss,
281                best_epoch + 1,
282                epoch_metrics.len(),
283                total_start.elapsed().as_secs(),
284                if stopped_early { " (early stopped)" } else { "" },
285            );
286        }
287
288        // PMAT-483: Print profiler report at end of training (text + JSON)
289        if self.pipeline.profiler.is_enabled() {
290            self.pipeline.profiler.print_report();
291            self.pipeline.profiler.print_json_report();
292        }
293
294        InstructTrainResult {
295            epoch_metrics,
296            best_epoch,
297            best_val_loss,
298            stopped_early,
299            total_time_ms: total_start.elapsed().as_millis() as u64,
300        }
301    }
302
303    /// Prepare samples by tokenizing prompt and response.
304    fn prepare_samples(&self, samples: &[InstructSample]) -> Vec<PreparedSample> {
305        samples
306            .iter()
307            .map(|sample| {
308                let (prompt_text, response_text) = format_chat_prompt(sample);
309                PreparedSample {
310                    prompt_ids: self.pipeline.tokenize(&prompt_text),
311                    response_ids: self.pipeline.tokenize(&response_text),
312                }
313            })
314            .collect()
315    }
316
317    /// Split dataset into train/val with deterministic shuffling.
318    fn split_dataset(
319        corpus: &[InstructSample],
320        val_split: f32,
321        seed: u64,
322    ) -> (Vec<InstructSample>, Vec<InstructSample>) {
323        use std::collections::hash_map::DefaultHasher;
324        use std::hash::{Hash, Hasher};
325
326        let mut indices: Vec<usize> = (0..corpus.len()).collect();
327
328        // Fisher-Yates shuffle with deterministic seed
329        for i in (1..indices.len()).rev() {
330            let mut hasher = DefaultHasher::new();
331            seed.hash(&mut hasher);
332            i.hash(&mut hasher);
333            let j = (hasher.finish() as usize) % (i + 1);
334            indices.swap(i, j);
335        }
336
337        let val_size = (corpus.len() as f32 * val_split).ceil() as usize;
338        let val_size = val_size.max(1).min(corpus.len() - 1);
339
340        let val_data: Vec<InstructSample> =
341            indices[..val_size].iter().map(|&i| corpus[i].clone()).collect();
342        let train_data: Vec<InstructSample> =
343            indices[val_size..].iter().map(|&i| corpus[i].clone()).collect();
344
345        (train_data, val_data)
346    }
347
348    /// Shuffle training data with epoch-specific seed.
349    fn shuffle_train(&mut self, epoch: u64) {
350        use std::collections::hash_map::DefaultHasher;
351        use std::hash::{Hash, Hasher};
352
353        let n = self.train_data.len();
354        for i in (1..n).rev() {
355            let mut hasher = DefaultHasher::new();
356            self.rng_seed.hash(&mut hasher);
357            epoch.hash(&mut hasher);
358            i.hash(&mut hasher);
359            let j = (hasher.finish() as usize) % (i + 1);
360            self.train_data.swap(i, j);
361        }
362    }
363
364    /// Compute SHA-256 hash of corpus for provenance.
365    fn compute_data_hash(corpus: &[InstructSample]) -> String {
366        let mut hasher = Sha256::new();
367        for s in corpus {
368            hasher.update(s.instruction.as_bytes());
369            hasher.update([0u8]);
370            hasher.update(s.response.as_bytes());
371            hasher.update([0u8]);
372        }
373        format!("sha256:{:x}", hasher.finalize())
374    }
375
376    /// Get data hash for provenance tracking.
377    #[must_use]
378    pub fn data_hash(&self) -> &str {
379        &self.data_hash
380    }
381
382    /// Get training data size.
383    #[must_use]
384    pub fn train_size(&self) -> usize {
385        self.train_data.len()
386    }
387
388    /// Get validation data size.
389    #[must_use]
390    pub fn val_size(&self) -> usize {
391        self.val_data.len()
392    }
393
394    /// Save a checkpoint with LoRA adapter weights and training metadata.
395    ///
396    /// Creates a directory at `path` containing:
397    /// - `metadata.json`: training metrics for this checkpoint
398    /// - `model.safetensors`: LoRA adapter weights (Q/V projections per layer)
399    pub fn save_checkpoint(
400        &mut self,
401        path: &std::path::Path,
402        epoch: usize,
403        metrics: &InstructEpochMetrics,
404    ) -> crate::Result<()> {
405        contract_pre_save_checkpoint!();
406        // Sync GPU LoRA weights to CPU before saving
407        #[cfg(feature = "cuda")]
408        self.pipeline.sync_lora_to_cpu();
409
410        std::fs::create_dir_all(path).map_err(|e| {
411            crate::Error::Io(format!("Failed to create checkpoint dir {}: {e}", path.display()))
412        })?;
413
414        // Save metadata.json
415        let metadata = serde_json::json!({
416            "task": "instruct",
417            "epoch": epoch,
418            "train_loss": metrics.train_loss,
419            "val_loss": metrics.val_loss,
420            "train_perplexity": metrics.train_perplexity,
421            "val_perplexity": metrics.val_perplexity,
422            "learning_rate": metrics.learning_rate,
423            "epoch_time_ms": metrics.epoch_time_ms,
424            "samples_per_sec": metrics.samples_per_sec,
425            "lora_rank": self.pipeline.config.lora_rank,
426            "lora_alpha": self.pipeline.config.lora_alpha,
427            "data_hash": self.data_hash,
428        });
429
430        let meta_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
431            crate::Error::Serialization(format!("Failed to serialize metadata: {e}"))
432        })?;
433        std::fs::write(path.join("metadata.json"), meta_json)?;
434
435        // Save LoRA adapter weights as SafeTensors
436        let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
437
438        for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
439            let layer = idx / 2;
440            let proj = if idx % 2 == 0 { "q" } else { "v" };
441
442            // LoRA A: [rank, d_in]
443            let a_data = lora.lora_a().data();
444            let a_bytes: Vec<u8> =
445                bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
446            let a_shape = vec![lora.rank(), lora.d_in()];
447            tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
448
449            // LoRA B: [d_out, rank]
450            let b_data = lora.lora_b().data();
451            let b_bytes: Vec<u8> =
452                bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
453            let b_shape = vec![lora.d_out(), lora.rank()];
454            tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
455        }
456
457        let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
458            .iter()
459            .map(|(name, bytes, shape)| {
460                let view = safetensors::tensor::TensorView::new(
461                    safetensors::tensor::Dtype::F32,
462                    shape.clone(),
463                    bytes,
464                )
465                .expect("valid tensor view");
466                (name.as_str(), view)
467            })
468            .collect();
469
470        let mut st_metadata = std::collections::HashMap::new();
471        st_metadata.insert("epoch".to_string(), epoch.to_string());
472        st_metadata.insert("val_loss".to_string(), format!("{:.6}", metrics.val_loss));
473
474        let safetensor_bytes = safetensors::serialize(views, Some(st_metadata)).map_err(|e| {
475            crate::Error::Serialization(format!("SafeTensors serialization failed: {e}"))
476        })?;
477        std::fs::write(path.join("model.safetensors"), safetensor_bytes)?;
478
479        contract_post_save_checkpoint!(());
480        Ok(())
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::finetune::instruct_pipeline::InstructConfig;
    use crate::transformer::TransformerConfig;

    /// Build `count` distinct synthetic instruction/response samples.
    fn make_corpus(count: usize) -> Vec<InstructSample> {
        let mut samples = Vec::with_capacity(count);
        for idx in 0..count {
            samples.push(InstructSample {
                instruction: format!("Write function {idx}"),
                response: format!("def func_{idx}():\n    return {idx}"),
                system: None,
                metadata: None,
            });
        }
        samples
    }

    /// Wrap the tiny transformer config in a pipeline with the given instruct settings.
    fn tiny_pipeline(instruct_config: InstructConfig) -> InstructPipeline {
        let model_config = TransformerConfig::tiny();
        InstructPipeline::new(&model_config, instruct_config)
    }

    /// Small LoRA settings that keep the end-to-end tests fast.
    fn small_lora_config() -> InstructConfig {
        InstructConfig { lora_rank: 4, max_seq_len: 32, ..InstructConfig::default() }
    }

    #[test]
    fn test_trainer_creation() {
        let config = InstructTrainingConfig { epochs: 2, ..Default::default() };
        let trainer =
            InstructTrainer::new(tiny_pipeline(small_lora_config()), make_corpus(20), config)
                .expect("20 samples with the default split should build a trainer");

        assert!(trainer.train_size() > 0);
        assert!(trainer.val_size() > 0);
    }

    #[test]
    fn test_trainer_empty_corpus() {
        let outcome = InstructTrainer::new(
            tiny_pipeline(InstructConfig::default()),
            Vec::new(),
            InstructTrainingConfig::default(),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_trainer_train() {
        let config = InstructTrainingConfig { epochs: 2, save_every: 1, ..Default::default() };
        let mut trainer =
            InstructTrainer::new(tiny_pipeline(small_lora_config()), make_corpus(10), config)
                .expect("10 samples should build a trainer");

        let summary = trainer.train();

        assert_eq!(summary.epoch_metrics.len(), 2);
        assert!(summary.best_val_loss >= 0.0);
        assert!(summary.total_time_ms > 0);
    }

    #[test]
    fn test_data_hash_deterministic() {
        let corpus = make_corpus(5);
        let first = InstructTrainer::compute_data_hash(&corpus);
        let second = InstructTrainer::compute_data_hash(&corpus);

        assert_eq!(first, second);
        assert!(first.starts_with("sha256:"));
    }

    #[test]
    fn test_split_disjoint() {
        let (train, val) = InstructTrainer::split_dataset(&make_corpus(20), 0.2, 42);

        assert_eq!(train.len() + val.len(), 20);
        assert!(!train.is_empty());
        assert!(!val.is_empty());
    }
}