1use super::instruct_corpus::{format_chat_prompt, InstructSample};
15use super::instruct_pipeline::InstructPipeline;
16use sha2::{Digest, Sha256};
17use std::path::PathBuf;
18
/// Hyper-parameters controlling an [`InstructTrainer`] run.
#[derive(Clone, Debug)]
pub struct InstructTrainingConfig {
    /// Number of passes over the training split. `InstructTrainer::new` rejects `0`.
    pub epochs: usize,
    /// Fraction of the corpus held out for validation.
    /// `InstructTrainer::new` requires a value in `(0.0, 0.5]`.
    pub val_split: f32,
    /// Write a periodic checkpoint every `save_every` epochs.
    /// When `epochs <= save_every` the trainer saves after every epoch instead;
    /// otherwise `0` disables periodic checkpoints.
    pub save_every: usize,
    /// Number of consecutive epochs without a validation-loss improvement
    /// after which training stops early.
    pub early_stopping_patience: usize,
    /// Directory under which the `best` and `epoch-N` checkpoint folders are created.
    pub checkpoint_dir: PathBuf,
    /// Seed for the train/validation split and for the per-epoch shuffle.
    pub seed: u64,
    /// Logging interval.
    /// NOTE(review): not read by the trainer code in this file — confirm where it is consumed.
    pub log_interval: usize,
    /// Fraction of the total step count used as learning-rate warmup.
    pub warmup_fraction: f32,
    /// Minimum learning rate handed to the warmup + cosine-decay scheduler.
    pub lr_min: f32,
}
41
42impl Default for InstructTrainingConfig {
43 fn default() -> Self {
44 Self {
45 epochs: 3,
46 val_split: 0.2,
47 save_every: 1,
48 early_stopping_patience: 5,
49 checkpoint_dir: PathBuf::from("checkpoints"),
50 seed: 42,
51 log_interval: 1,
52 warmup_fraction: 0.1,
53 lr_min: 1e-6,
54 }
55 }
56}
57
/// Metrics recorded by the trainer at the end of one epoch.
#[derive(Clone, Debug)]
pub struct InstructEpochMetrics {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Training loss averaged over response tokens (token-weighted mean of the per-step losses).
    pub train_loss: f32,
    /// `exp(train_loss)`, capped at `1e6`.
    pub train_perplexity: f32,
    /// Average loss reported by the validation pass.
    pub val_loss: f32,
    /// Perplexity reported by the validation pass.
    pub val_perplexity: f32,
    /// Learning rate held by the pipeline when the epoch finished.
    pub learning_rate: f32,
    /// Wall-clock duration of the epoch in milliseconds (training plus validation).
    pub epoch_time_ms: u64,
    /// Training samples processed per second, derived from `epoch_time_ms`.
    pub samples_per_sec: f32,
}
78
79#[derive(Debug, Clone)]
81pub struct InstructTrainResult {
82 pub epoch_metrics: Vec<InstructEpochMetrics>,
84 pub best_epoch: usize,
86 pub best_val_loss: f32,
88 pub stopped_early: bool,
90 pub total_time_ms: u64,
92}
93
/// One instruction sample after chat formatting and tokenization.
struct PreparedSample {
    /// Token ids of the formatted prompt part.
    prompt_ids: Vec<u32>,
    /// Token ids of the formatted response part.
    response_ids: Vec<u32>,
}
99
100pub struct InstructTrainer {
102 pipeline: InstructPipeline,
104 config: InstructTrainingConfig,
106 train_data: Vec<InstructSample>,
108 val_data: Vec<InstructSample>,
110 rng_seed: u64,
112 data_hash: String,
114}
115
116impl InstructTrainer {
117 pub fn new(
122 pipeline: InstructPipeline,
123 corpus: Vec<InstructSample>,
124 config: InstructTrainingConfig,
125 ) -> crate::Result<Self> {
126 if corpus.is_empty() {
127 return Err(crate::Error::ConfigError("GH-371: corpus must not be empty".to_string()));
128 }
129 if config.val_split <= 0.0 || config.val_split > 0.5 {
130 return Err(crate::Error::ConfigError(format!(
131 "GH-371: val_split must be in (0.0, 0.5], got {}",
132 config.val_split,
133 )));
134 }
135 if config.epochs == 0 {
136 return Err(crate::Error::ConfigError("GH-371: epochs must be > 0".to_string()));
137 }
138
139 let (train_data, val_data) = Self::split_dataset(&corpus, config.val_split, config.seed);
140
141 if train_data.is_empty() || val_data.is_empty() {
142 return Err(crate::Error::ConfigError(format!(
143 "GH-371: split produced empty set (train={}, val={}). Need more samples.",
144 train_data.len(),
145 val_data.len(),
146 )));
147 }
148
149 let rng_seed = config.seed;
150 let data_hash = Self::compute_data_hash(&corpus);
151
152 Ok(Self { pipeline, config, train_data, val_data, rng_seed, data_hash })
153 }
154
155 pub fn train(&mut self) -> InstructTrainResult {
157 use crate::optim::{LRScheduler, WarmupCosineDecayLR};
158
159 let total_start = std::time::Instant::now();
160 let base_lr = self.pipeline.learning_rate();
161 let total_steps = self.config.epochs * self.train_data.len();
162 let warmup_steps = (total_steps as f32 * self.config.warmup_fraction) as usize;
163
164 let mut scheduler =
165 WarmupCosineDecayLR::new(base_lr, self.config.lr_min, warmup_steps, total_steps);
166
167 let mut epoch_metrics = Vec::new();
168 let mut best_val_loss = f32::INFINITY;
169 let mut best_epoch = 0usize;
170 let mut patience_counter = 0usize;
171 let mut stopped_early = false;
172
173 let val_prepared = self.prepare_samples(&self.val_data);
176
177 let val_prompts: Vec<Vec<u32>> =
179 val_prepared.iter().map(|s| s.prompt_ids.clone()).collect();
180 let val_responses: Vec<Vec<u32>> =
181 val_prepared.iter().map(|s| s.response_ids.clone()).collect();
182
183 for epoch in 0..self.config.epochs {
184 let epoch_start = std::time::Instant::now();
185
186 self.shuffle_train(epoch as u64);
188
189 let train_prepared = self.prepare_samples(&self.train_data);
192
193 let mut epoch_loss = 0.0f32;
195 let mut epoch_tokens = 0usize;
196
197 for sample in &train_prepared {
198 let lr = scheduler.get_lr();
199 self.pipeline.set_learning_rate(lr);
200
201 let result = self.pipeline.train_step(&sample.prompt_ids, &sample.response_ids);
202 epoch_loss += result.loss * result.num_response_tokens as f32;
203 epoch_tokens += result.num_response_tokens;
204 scheduler.step();
205 }
206
207 let train_loss = if epoch_tokens > 0 { epoch_loss / epoch_tokens as f32 } else { 0.0 };
208
209 eprintln!(
212 " Epoch {} complete: avg_loss={:.4} tokens={} samples={} lr={:.2e}",
213 epoch + 1,
214 train_loss,
215 epoch_tokens,
216 train_prepared.len(),
217 self.pipeline.learning_rate(),
218 );
219
220 let val_result = self.pipeline.evaluate(&val_prompts, &val_responses);
223
224 let epoch_time_ms = epoch_start.elapsed().as_millis() as u64;
225 let samples_per_sec = if epoch_time_ms > 0 {
226 train_prepared.len() as f32 / (epoch_time_ms as f32 / 1000.0)
227 } else {
228 0.0
229 };
230
231 let metrics = InstructEpochMetrics {
232 epoch,
233 train_loss,
234 train_perplexity: train_loss.exp().min(1e6),
235 val_loss: val_result.avg_loss,
236 val_perplexity: val_result.perplexity,
237 learning_rate: self.pipeline.learning_rate(),
238 epoch_time_ms,
239 samples_per_sec,
240 };
241
242 if val_result.avg_loss < best_val_loss {
244 best_val_loss = val_result.avg_loss;
245 best_epoch = epoch;
246 patience_counter = 0;
247
248 let best_path = self.config.checkpoint_dir.join("best");
250 let _ = self.save_checkpoint(&best_path, epoch, &metrics);
251 } else {
252 patience_counter += 1;
253 }
254
255 let effective_save_every = if self.config.epochs <= self.config.save_every {
257 1
258 } else {
259 self.config.save_every
260 };
261 if effective_save_every > 0 && (epoch + 1) % effective_save_every == 0 {
262 let epoch_path = self.config.checkpoint_dir.join(format!("epoch-{epoch}"));
263 let _ = self.save_checkpoint(&epoch_path, epoch, &metrics);
264 }
265
266 epoch_metrics.push(metrics);
267
268 if patience_counter >= self.config.early_stopping_patience {
270 stopped_early = true;
271 break;
272 }
273 }
274
275 if let Some(last) = epoch_metrics.last() {
277 eprintln!(
278 "[training] Training complete: final_loss={:.4} best_val_loss={:.4} best_epoch={} epochs={} time={}s{}",
279 last.train_loss,
280 best_val_loss,
281 best_epoch + 1,
282 epoch_metrics.len(),
283 total_start.elapsed().as_secs(),
284 if stopped_early { " (early stopped)" } else { "" },
285 );
286 }
287
288 if self.pipeline.profiler.is_enabled() {
290 self.pipeline.profiler.print_report();
291 self.pipeline.profiler.print_json_report();
292 }
293
294 InstructTrainResult {
295 epoch_metrics,
296 best_epoch,
297 best_val_loss,
298 stopped_early,
299 total_time_ms: total_start.elapsed().as_millis() as u64,
300 }
301 }
302
303 fn prepare_samples(&self, samples: &[InstructSample]) -> Vec<PreparedSample> {
305 samples
306 .iter()
307 .map(|sample| {
308 let (prompt_text, response_text) = format_chat_prompt(sample);
309 PreparedSample {
310 prompt_ids: self.pipeline.tokenize(&prompt_text),
311 response_ids: self.pipeline.tokenize(&response_text),
312 }
313 })
314 .collect()
315 }
316
317 fn split_dataset(
319 corpus: &[InstructSample],
320 val_split: f32,
321 seed: u64,
322 ) -> (Vec<InstructSample>, Vec<InstructSample>) {
323 use std::collections::hash_map::DefaultHasher;
324 use std::hash::{Hash, Hasher};
325
326 let mut indices: Vec<usize> = (0..corpus.len()).collect();
327
328 for i in (1..indices.len()).rev() {
330 let mut hasher = DefaultHasher::new();
331 seed.hash(&mut hasher);
332 i.hash(&mut hasher);
333 let j = (hasher.finish() as usize) % (i + 1);
334 indices.swap(i, j);
335 }
336
337 let val_size = (corpus.len() as f32 * val_split).ceil() as usize;
338 let val_size = val_size.max(1).min(corpus.len() - 1);
339
340 let val_data: Vec<InstructSample> =
341 indices[..val_size].iter().map(|&i| corpus[i].clone()).collect();
342 let train_data: Vec<InstructSample> =
343 indices[val_size..].iter().map(|&i| corpus[i].clone()).collect();
344
345 (train_data, val_data)
346 }
347
348 fn shuffle_train(&mut self, epoch: u64) {
350 use std::collections::hash_map::DefaultHasher;
351 use std::hash::{Hash, Hasher};
352
353 let n = self.train_data.len();
354 for i in (1..n).rev() {
355 let mut hasher = DefaultHasher::new();
356 self.rng_seed.hash(&mut hasher);
357 epoch.hash(&mut hasher);
358 i.hash(&mut hasher);
359 let j = (hasher.finish() as usize) % (i + 1);
360 self.train_data.swap(i, j);
361 }
362 }
363
364 fn compute_data_hash(corpus: &[InstructSample]) -> String {
366 let mut hasher = Sha256::new();
367 for s in corpus {
368 hasher.update(s.instruction.as_bytes());
369 hasher.update([0u8]);
370 hasher.update(s.response.as_bytes());
371 hasher.update([0u8]);
372 }
373 format!("sha256:{:x}", hasher.finalize())
374 }
375
376 #[must_use]
378 pub fn data_hash(&self) -> &str {
379 &self.data_hash
380 }
381
382 #[must_use]
384 pub fn train_size(&self) -> usize {
385 self.train_data.len()
386 }
387
388 #[must_use]
390 pub fn val_size(&self) -> usize {
391 self.val_data.len()
392 }
393
394 pub fn save_checkpoint(
400 &mut self,
401 path: &std::path::Path,
402 epoch: usize,
403 metrics: &InstructEpochMetrics,
404 ) -> crate::Result<()> {
405 contract_pre_save_checkpoint!();
406 #[cfg(feature = "cuda")]
408 self.pipeline.sync_lora_to_cpu();
409
410 std::fs::create_dir_all(path).map_err(|e| {
411 crate::Error::Io(format!("Failed to create checkpoint dir {}: {e}", path.display()))
412 })?;
413
414 let metadata = serde_json::json!({
416 "task": "instruct",
417 "epoch": epoch,
418 "train_loss": metrics.train_loss,
419 "val_loss": metrics.val_loss,
420 "train_perplexity": metrics.train_perplexity,
421 "val_perplexity": metrics.val_perplexity,
422 "learning_rate": metrics.learning_rate,
423 "epoch_time_ms": metrics.epoch_time_ms,
424 "samples_per_sec": metrics.samples_per_sec,
425 "lora_rank": self.pipeline.config.lora_rank,
426 "lora_alpha": self.pipeline.config.lora_alpha,
427 "data_hash": self.data_hash,
428 });
429
430 let meta_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
431 crate::Error::Serialization(format!("Failed to serialize metadata: {e}"))
432 })?;
433 std::fs::write(path.join("metadata.json"), meta_json)?;
434
435 let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
437
438 for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
439 let layer = idx / 2;
440 let proj = if idx % 2 == 0 { "q" } else { "v" };
441
442 let a_data = lora.lora_a().data();
444 let a_bytes: Vec<u8> =
445 bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
446 let a_shape = vec![lora.rank(), lora.d_in()];
447 tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
448
449 let b_data = lora.lora_b().data();
451 let b_bytes: Vec<u8> =
452 bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
453 let b_shape = vec![lora.d_out(), lora.rank()];
454 tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
455 }
456
457 let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
458 .iter()
459 .map(|(name, bytes, shape)| {
460 let view = safetensors::tensor::TensorView::new(
461 safetensors::tensor::Dtype::F32,
462 shape.clone(),
463 bytes,
464 )
465 .expect("valid tensor view");
466 (name.as_str(), view)
467 })
468 .collect();
469
470 let mut st_metadata = std::collections::HashMap::new();
471 st_metadata.insert("epoch".to_string(), epoch.to_string());
472 st_metadata.insert("val_loss".to_string(), format!("{:.6}", metrics.val_loss));
473
474 let safetensor_bytes = safetensors::serialize(views, Some(st_metadata)).map_err(|e| {
475 crate::Error::Serialization(format!("SafeTensors serialization failed: {e}"))
476 })?;
477 std::fs::write(path.join("model.safetensors"), safetensor_bytes)?;
478
479 contract_post_save_checkpoint!(());
480 Ok(())
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::finetune::instruct_pipeline::InstructConfig;
    use crate::transformer::TransformerConfig;

    /// Builds `n` samples with unique instruction/response text.
    fn make_corpus(n: usize) -> Vec<InstructSample> {
        (0..n)
            .map(|i| InstructSample {
                instruction: format!("Write function {i}"),
                response: format!("def func_{i}():\n return {i}"),
                system: None,
                metadata: None,
            })
            .collect()
    }

    #[test]
    fn test_trainer_creation() {
        let model_config = TransformerConfig::tiny();
        let instruct_config =
            InstructConfig { lora_rank: 4, max_seq_len: 32, ..InstructConfig::default() };
        let pipeline = InstructPipeline::new(&model_config, instruct_config);
        let corpus = make_corpus(20);
        let config = InstructTrainingConfig { epochs: 2, ..Default::default() };

        let trainer = InstructTrainer::new(pipeline, corpus, config);
        assert!(trainer.is_ok());

        let trainer = trainer.unwrap();
        assert!(trainer.train_size() > 0);
        assert!(trainer.val_size() > 0);
    }

    #[test]
    fn test_trainer_empty_corpus() {
        let model_config = TransformerConfig::tiny();
        let instruct_config = InstructConfig::default();
        let pipeline = InstructPipeline::new(&model_config, instruct_config);
        let config = InstructTrainingConfig::default();

        let result = InstructTrainer::new(pipeline, vec![], config);
        assert!(result.is_err());
    }

    #[test]
    fn test_trainer_train() {
        let model_config = TransformerConfig::tiny();
        let instruct_config =
            InstructConfig { lora_rank: 4, max_seq_len: 32, ..InstructConfig::default() };
        let pipeline = InstructPipeline::new(&model_config, instruct_config);
        let corpus = make_corpus(10);

        // Write checkpoints to a per-process temp directory rather than
        // `./checkpoints` in the working directory.
        let checkpoint_dir = std::env::temp_dir()
            .join(format!("instruct_trainer_test_train_{}", std::process::id()));
        let config = InstructTrainingConfig {
            epochs: 2,
            save_every: 1,
            checkpoint_dir: checkpoint_dir.clone(),
            ..Default::default()
        };

        let mut trainer = InstructTrainer::new(pipeline, corpus, config).unwrap();
        let result = trainer.train();

        // Best-effort cleanup; a leftover temp directory must not fail the test.
        let _ = std::fs::remove_dir_all(&checkpoint_dir);

        assert_eq!(result.epoch_metrics.len(), 2);
        assert!(result.best_val_loss >= 0.0);
        assert!(result.total_time_ms > 0);
    }

    #[test]
    fn test_data_hash_deterministic() {
        let corpus = make_corpus(5);
        let hash1 = InstructTrainer::compute_data_hash(&corpus);
        let hash2 = InstructTrainer::compute_data_hash(&corpus);
        assert_eq!(hash1, hash2);
        assert!(hash1.starts_with("sha256:"));
    }

    #[test]
    fn test_split_disjoint() {
        use std::collections::HashSet;

        let corpus = make_corpus(20);
        let (train, val) = InstructTrainer::split_dataset(&corpus, 0.2, 42);
        assert_eq!(train.len() + val.len(), 20);
        assert!(!train.is_empty());
        assert!(!val.is_empty());

        // `make_corpus` produces unique instructions, so they identify samples.
        let train_keys: HashSet<&str> = train.iter().map(|s| s.instruction.as_str()).collect();
        assert_eq!(train_keys.len(), train.len());
        assert!(val.iter().all(|s| !train_keys.contains(s.instruction.as_str())));
    }
}