1use super::classification::{SafetySample, TokenizedSample};
16use super::classify_eval_report::ClassifyEvalReport;
17use super::classify_pipeline::ClassifyPipeline;
18use super::distributed::DistributedConfig;
19use crate::optim::LRScheduler;
20use crate::optim::WarmupCosineDecayLR;
21use sha2::{Digest, Sha256};
22use std::path::{Path, PathBuf};
23
24#[derive(Debug, Clone)]
26pub struct TrainingConfig {
27 pub epochs: usize,
29 pub val_split: f32,
31 pub save_every: usize,
33 pub early_stopping_patience: usize,
35 pub checkpoint_dir: PathBuf,
37 pub seed: u64,
39 pub log_interval: usize,
41 pub warmup_fraction: f32,
43 pub lr_min: f32,
45 pub oversample_minority: bool,
48 pub quantize_nf4: bool,
54 pub distributed: Option<DistributedConfig>,
60}
61
62impl Default for TrainingConfig {
63 fn default() -> Self {
64 Self {
65 epochs: 50,
66 val_split: 0.2,
67 save_every: 5,
68 early_stopping_patience: 10,
69 checkpoint_dir: PathBuf::from("checkpoints"),
70 seed: 42,
71 log_interval: 1,
72 warmup_fraction: 0.1,
73 lr_min: 1e-6,
74 oversample_minority: false,
75 quantize_nf4: false,
76 distributed: None,
77 }
78 }
79}
80
/// Per-epoch summary produced by the training loops.
#[derive(Clone, Debug)]
pub struct EpochMetrics {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Sample-weighted mean training loss over the epoch.
    pub train_loss: f32,
    /// Fraction (0..=1) of training samples predicted correctly.
    pub train_accuracy: f32,
    /// Mean validation loss.
    pub val_loss: f32,
    /// Fraction (0..=1) of validation samples predicted correctly.
    pub val_accuracy: f32,
    /// Learning rate reported at the end of the epoch.
    pub learning_rate: f32,
    /// Wall-clock duration of the whole epoch (shuffle, train, validate) in ms.
    pub epoch_time_ms: u64,
    /// Training throughput in samples per second.
    pub samples_per_sec: f32,
}
101
102#[derive(Debug, Clone)]
104pub struct TrainResult {
105 pub epoch_metrics: Vec<EpochMetrics>,
107 pub best_epoch: usize,
109 pub best_val_loss: f32,
111 pub stopped_early: bool,
113 pub total_time_ms: u64,
115}
116
117pub struct ClassifyTrainer {
126 pipeline: ClassifyPipeline,
128 config: TrainingConfig,
130 train_data: Vec<SafetySample>,
132 train_tokens: Vec<TokenizedSample>,
135 val_tokens: Vec<TokenizedSample>,
137 val_data: Vec<SafetySample>,
139 rng_seed: u64,
141 monitor_writer: Option<crate::monitor::tui::TrainingStateWriter>,
143 data_hash: String,
145 train_start: String,
147}
148
149#[allow(clippy::missing_fields_in_debug)]
150impl std::fmt::Debug for ClassifyTrainer {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("ClassifyTrainer")
153 .field("config", &self.config)
154 .field("train_data_len", &self.train_data.len())
155 .field("train_tokens_len", &self.train_tokens.len())
156 .field("val_data_len", &self.val_data.len())
157 .field("val_tokens_len", &self.val_tokens.len())
158 .field("rng_seed", &self.rng_seed)
159 .finish()
160 }
161}
162
163impl ClassifyTrainer {
164 pub fn new(
175 mut pipeline: ClassifyPipeline,
176 corpus: Vec<SafetySample>,
177 config: TrainingConfig,
178 ) -> crate::Result<Self> {
179 if corpus.is_empty() {
180 return Err(crate::Error::ConfigError("SSC-026: corpus must not be empty".to_string()));
181 }
182 if config.val_split <= 0.0 || config.val_split > 0.5 {
183 return Err(crate::Error::ConfigError(format!(
184 "SSC-026: val_split must be in (0.0, 0.5], got {}",
185 config.val_split,
186 )));
187 }
188 if config.epochs == 0 {
189 return Err(crate::Error::ConfigError("SSC-026: epochs must be > 0".to_string()));
190 }
191
192 if !config.oversample_minority {
195 Self::auto_balance_classes(&mut pipeline, &corpus);
196 }
197
198 let (mut train_data, val_data) =
199 Self::split_dataset(&corpus, config.val_split, config.seed);
200
201 if config.oversample_minority {
202 Self::oversample_training_data(&mut train_data, config.seed);
203 }
204
205 if train_data.is_empty() || val_data.is_empty() {
206 return Err(crate::Error::ConfigError(format!(
207 "SSC-026: split produced empty set (train={}, val={}). Need more samples.",
208 train_data.len(),
209 val_data.len(),
210 )));
211 }
212
213 let rng_seed = config.seed;
214
215 let train_tokens = pipeline.pre_tokenize(&train_data);
219 let val_tokens = pipeline.pre_tokenize(&val_data);
220
221 let data_hash = Self::compute_data_hash(&corpus);
223 let train_start = chrono::Utc::now().to_rfc3339();
224
225 Ok(Self {
226 pipeline,
227 config,
228 train_data,
229 train_tokens,
230 val_tokens,
231 val_data,
232 rng_seed,
233 monitor_writer: None,
234 data_hash,
235 train_start,
236 })
237 }
238
239 fn compute_data_hash(corpus: &[SafetySample]) -> String {
243 let mut hasher = Sha256::new();
244 let mut sorted: Vec<(&str, usize)> =
245 corpus.iter().map(|s| (s.input.as_str(), s.label)).collect();
246 sorted.sort_unstable();
247 for (input, label) in &sorted {
248 hasher.update(input.as_bytes());
249 hasher.update([0u8]); hasher.update(label.to_le_bytes());
251 }
252 let result = hasher.finalize();
253 format!("sha256:{result:x}")
254 }
255
256 fn auto_balance_classes(pipeline: &mut ClassifyPipeline, corpus: &[SafetySample]) {
267 use super::classification::{compute_class_weights, corpus_stats, ClassWeightStrategy};
268
269 if pipeline.config.class_weights.is_some() {
271 return;
272 }
273
274 let num_classes = pipeline.config.num_classes;
275 let stats = corpus_stats(corpus, num_classes);
276
277 let min_count = stats.class_counts.iter().copied().min().unwrap_or(0);
279 let max_count = stats.class_counts.iter().copied().max().unwrap_or(1);
280
281 if min_count == 0 {
282 println!(
283 " Warning: class with zero samples detected. \
284 Class weights not applied (would produce Inf)."
285 );
286 return;
287 }
288
289 let imbalance_ratio = max_count as f64 / min_count as f64;
290
291 if imbalance_ratio > 2.0 {
292 let weights =
293 compute_class_weights(&stats, ClassWeightStrategy::SqrtInverse, num_classes);
294 println!(
295 " Auto-detected class imbalance (ratio {imbalance_ratio:.1}:1), \
296 applying sqrt-inverse weights: {weights:?}"
297 );
298 println!(" Class counts: {:?} (total: {})", stats.class_counts, stats.total);
299 pipeline.config.class_weights = Some(weights);
300 } else {
301 println!(" Class balance OK (ratio {imbalance_ratio:.1}:1), using uniform weights");
302 }
303 }
304
305 fn oversample_training_data(train_data: &mut Vec<SafetySample>, seed: u64) {
311 use std::collections::HashMap;
312
313 let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
315 for (i, sample) in train_data.iter().enumerate() {
316 class_indices.entry(sample.label).or_default().push(i);
317 }
318
319 let majority_count = class_indices.values().map(std::vec::Vec::len).max().unwrap_or(0);
320 let before = train_data.len();
321
322 for indices in class_indices.values() {
324 let count = indices.len();
325 if count < majority_count {
326 let deficit = majority_count - count;
327 for i in 0..deficit {
328 let src_idx = indices[i % count];
329 train_data.push(train_data[src_idx].clone());
330 }
331 }
332 }
333
334 let n = train_data.len();
336 let mut rng_state: u64 = seed.wrapping_mul(0x517cc1b727220a95).wrapping_add(1);
337 for i in (1..n).rev() {
338 rng_state =
339 rng_state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
340 let j = (rng_state >> 33) as usize % (i + 1);
341 train_data.swap(i, j);
342 }
343
344 println!(
345 " Oversampled minority classes: {before} \u{2192} {} training samples",
346 train_data.len()
347 );
348 }
349
350 pub fn set_monitor_writer(&mut self, writer: crate::monitor::tui::TrainingStateWriter) {
355 self.monitor_writer = Some(writer);
356 }
357
358 pub fn train(&mut self) -> TrainResult {
369 if self.is_coordinator_mode() {
371 return self.train_as_coordinator();
372 }
373
374 let total_start = std::time::Instant::now();
375 let batch_size = self.pipeline.config.batch_size;
376 let batches_per_epoch = self.train_data.len().div_ceil(batch_size);
377 let total_steps = self.config.epochs * batches_per_epoch;
378 let warmup_steps = (self.config.warmup_fraction * total_steps as f32) as usize;
379 let lr_max = self.pipeline.optimizer_lr();
380
381 let mut scheduler =
382 WarmupCosineDecayLR::new(lr_max, self.config.lr_min, warmup_steps, total_steps);
383
384 if let Some(ref mut writer) = self.monitor_writer {
386 writer.set_epochs(self.config.epochs, batches_per_epoch);
387 let _ = writer.start();
388 }
389
390 let mut epoch_metrics_vec: Vec<EpochMetrics> = Vec::with_capacity(self.config.epochs);
391 let mut best_val_loss = f32::INFINITY;
392 let mut best_epoch: usize = 0;
393 let mut epochs_without_improvement: usize = 0;
394 let mut stopped_early = false;
395 let mut training_failed = false;
396
397 for epoch in 0..self.config.epochs {
398 let epoch_start = std::time::Instant::now();
399
400 self.shuffle_training_data(epoch);
402
403 let (train_loss, train_accuracy) = self.train_epoch(&mut scheduler, epoch);
405
406 let (val_loss, val_accuracy) = self.validate();
408
409 let epoch_time = epoch_start.elapsed();
410 let epoch_time_ms = epoch_time.as_millis() as u64;
411 let samples_per_sec = if epoch_time_ms > 0 {
412 self.train_data.len() as f32 / (epoch_time_ms as f32 / 1000.0)
413 } else {
414 0.0
415 };
416
417 let metrics = EpochMetrics {
418 epoch,
419 train_loss,
420 train_accuracy,
421 val_loss,
422 val_accuracy,
423 learning_rate: scheduler.get_lr(),
424 epoch_time_ms,
425 samples_per_sec,
426 };
427
428 epoch_metrics_vec.push(metrics.clone());
429
430 let is_best = val_loss < best_val_loss;
432 if let Some(ref writer) = self.monitor_writer {
433 writer.emit_epoch_summary(
434 epoch + 1,
435 self.config.epochs,
436 train_loss,
437 train_accuracy,
438 val_loss,
439 val_accuracy,
440 epoch_time.as_secs_f32(),
441 scheduler.get_lr(),
442 is_best,
443 );
444 }
445
446 if val_loss < best_val_loss {
448 best_val_loss = val_loss;
449 best_epoch = epoch;
450 epochs_without_improvement = 0;
451
452 let best_path = self.config.checkpoint_dir.join("best");
454 let _ = self.save_checkpoint(&best_path, epoch, &metrics);
455 } else {
456 epochs_without_improvement += 1;
457 }
458
459 let effective_save_every = if self.config.epochs <= self.config.save_every {
461 1
462 } else {
463 self.config.save_every
464 };
465 if effective_save_every > 0 && (epoch + 1) % effective_save_every == 0 {
466 let epoch_path = self.config.checkpoint_dir.join(format!("epoch-{epoch}"));
467 let _ = self.save_checkpoint(&epoch_path, epoch, &metrics);
468 }
469
470 if !train_loss.is_finite() || !val_loss.is_finite() {
472 if let Some(ref mut writer) = self.monitor_writer {
473 let _ = writer.fail("NaN or Inf loss detected");
474 }
475 training_failed = true;
476 stopped_early = true;
477 break;
478 }
479
480 if epochs_without_improvement >= self.config.early_stopping_patience {
482 stopped_early = true;
483 break;
484 }
485 }
486
487 if !training_failed {
489 if let Some(ref mut writer) = self.monitor_writer {
490 let _ = writer.complete();
491 }
492 }
493
494 let total_time_ms = total_start.elapsed().as_millis() as u64;
495
496 TrainResult {
497 epoch_metrics: epoch_metrics_vec,
498 best_epoch,
499 best_val_loss,
500 stopped_early,
501 total_time_ms,
502 }
503 }
504
505 fn train_as_coordinator(&mut self) -> TrainResult {
515 use super::gradient_server::GradientServer;
516
517 let dist_config = self
518 .config
519 .distributed
520 .clone()
521 .expect("train_as_coordinator requires distributed config");
522
523 let total_start = std::time::Instant::now();
524
525 let mut server = match GradientServer::bind(dist_config) {
527 Ok(s) => s,
528 Err(e) => {
529 eprintln!("[coordinator] Failed to bind: {e}");
530 return TrainResult {
531 epoch_metrics: vec![],
532 best_epoch: 0,
533 best_val_loss: f32::INFINITY,
534 stopped_early: true,
535 total_time_ms: total_start.elapsed().as_millis() as u64,
536 };
537 }
538 };
539
540 if let Err(e) = server.wait_for_workers() {
542 eprintln!("[coordinator] Worker connection failed: {e}");
543 return TrainResult {
544 epoch_metrics: vec![],
545 best_epoch: 0,
546 best_val_loss: f32::INFINITY,
547 stopped_early: true,
548 total_time_ms: total_start.elapsed().as_millis() as u64,
549 };
550 }
551
552 let num_workers = server.worker_count();
553 server.set_total_samples(self.train_data.len());
554
555 eprintln!(
556 "[coordinator] Starting training: {} epochs, {} workers, {} samples",
557 self.config.epochs,
558 num_workers,
559 self.train_data.len(),
560 );
561
562 let mut epoch_metrics_vec: Vec<EpochMetrics> = Vec::with_capacity(self.config.epochs);
563 let mut best_val_loss = f32::INFINITY;
564 let mut best_epoch = 0usize;
565 let mut stopped_early = false;
566
567 for epoch in 0..self.config.epochs {
568 let epoch_start = std::time::Instant::now();
569
570 self.shuffle_training_data(epoch);
571
572 let batch_size = self.pipeline.config.batch_size;
573 let mut total_loss = 0.0f32;
574 let mut total_correct = 0usize;
575 let mut total_samples = 0usize;
576
577 for (step_idx, chunk) in self.train_tokens.chunks(batch_size).enumerate() {
579 let step =
580 epoch as u64 * (self.train_tokens.len() / batch_size) as u64 + step_idx as u64;
581
582 if let Err(e) = server.send_shard_assignments(step) {
584 eprintln!("[coordinator] Shard assignment failed at step {step}: {e}");
585 stopped_early = true;
586 break;
587 }
588
589 let _local = self.pipeline.train_batch_tokenized(chunk);
591
592 match server.collect_and_reduce(step) {
594 Ok(allreduce) => {
595 self.pipeline.apply_lora_gradients(&allreduce.avg_gradients);
597
598 if let Err(e) = server.broadcast_averaged(step, &allreduce) {
600 eprintln!("[coordinator] Broadcast failed at step {step}: {e}");
601 stopped_early = true;
602 break;
603 }
604
605 total_loss += allreduce.global_loss * allreduce.total_samples as f32;
606 total_correct += allreduce.total_correct;
607 total_samples += allreduce.total_samples;
608 }
609 Err(e) => {
610 eprintln!("[coordinator] AllReduce failed at step {step}: {e}");
611 stopped_early = true;
612 break;
613 }
614 }
615 }
616
617 if stopped_early {
618 break;
619 }
620
621 let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
622 let accuracy =
623 if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
624
625 let (val_loss, val_accuracy) = self.validate();
627
628 let epoch_time_ms = epoch_start.elapsed().as_millis() as u64;
629 let samples_per_sec = if epoch_time_ms > 0 {
630 total_samples as f32 / (epoch_time_ms as f32 / 1000.0)
631 } else {
632 0.0
633 };
634
635 let metrics = EpochMetrics {
636 epoch,
637 train_loss: avg_loss,
638 train_accuracy: accuracy,
639 val_loss,
640 val_accuracy,
641 learning_rate: self.pipeline.optimizer_lr(),
642 epoch_time_ms,
643 samples_per_sec,
644 };
645
646 eprintln!(
647 "[coordinator] Epoch {}: loss={:.4}, acc={:.1}%, val_loss={:.4}, val_acc={:.1}%",
648 epoch + 1,
649 avg_loss,
650 accuracy * 100.0,
651 val_loss,
652 val_accuracy * 100.0,
653 );
654
655 if val_loss < best_val_loss {
656 best_val_loss = val_loss;
657 best_epoch = epoch;
658
659 let best_path = self.config.checkpoint_dir.join("best");
660 let _ = self.save_checkpoint(&best_path, epoch, &metrics);
661 }
662
663 epoch_metrics_vec.push(metrics);
664 }
665
666 server.shutdown_workers();
667
668 TrainResult {
669 epoch_metrics: epoch_metrics_vec,
670 best_epoch,
671 best_val_loss,
672 stopped_early,
673 total_time_ms: total_start.elapsed().as_millis() as u64,
674 }
675 }
676
677 fn train_epoch(&mut self, scheduler: &mut WarmupCosineDecayLR, epoch: usize) -> (f32, f32) {
681 let batch_size = self.pipeline.config.batch_size;
682 let mut total_loss = 0.0f32;
683 let mut total_correct = 0usize;
684 let mut total_samples = 0usize;
685
686 let epoch_start = std::time::Instant::now();
687
688 for (batch_idx, chunk) in self.train_tokens.chunks(batch_size).enumerate() {
690 self.pipeline.set_optimizer_lr(scheduler.get_lr());
692
693 let result = self.pipeline.train_batch_tokenized(chunk);
694 total_loss += result.avg_loss * result.total as f32;
695 total_correct += result.correct;
696 total_samples += result.total;
697
698 let running_avg_loss =
699 if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
700 let elapsed_secs = epoch_start.elapsed().as_secs_f32();
701 let samples_per_sec =
702 if elapsed_secs > 0.0 { total_samples as f32 / elapsed_secs } else { 0.0 };
703 let current_lr = scheduler.get_lr();
704
705 let step = batch_idx + 1;
706 let acc =
707 if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
708
709 if let Some(ref mut writer) = self.monitor_writer {
711 let _ = writer.update_step(
712 epoch + 1,
713 step,
714 running_avg_loss,
715 current_lr,
716 result.grad_norm,
717 samples_per_sec,
718 acc,
719 );
720 }
721
722 scheduler.step();
724 }
725
726 let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
727 let accuracy =
728 if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
729
730 (avg_loss, accuracy)
731 }
732
733 fn validate(&mut self) -> (f32, f32) {
738 let mut total_loss = 0.0f32;
739 let mut correct = 0usize;
740 let total = self.val_tokens.len();
741
742 let val_start = std::time::Instant::now();
743
744 for (i, sample) in self.val_tokens.iter().enumerate() {
747 let (loss, predicted) = self.pipeline.forward_only(&sample.token_ids, sample.label);
748 total_loss += loss;
749 if predicted == sample.label {
750 correct += 1;
751 }
752 if (i + 1) % 100 == 0 || i + 1 == total {
754 let elapsed = val_start.elapsed().as_secs_f32();
755 let sam_per_sec = if elapsed > 0.0 { (i + 1) as f32 / elapsed } else { 0.0 };
756 let running_acc = if i > 0 { correct as f32 / (i + 1) as f32 * 100.0 } else { 0.0 };
757 eprint!(
758 "\r Validating: {}/{} ({:.1} sam/s, acc={:.1}%) ",
759 i + 1,
760 total,
761 sam_per_sec,
762 running_acc,
763 );
764 }
765 }
766
767 let val_elapsed = val_start.elapsed();
768 let val_sam_per_sec = if val_elapsed.as_secs_f32() > 0.0 {
769 total as f32 / val_elapsed.as_secs_f32()
770 } else {
771 0.0
772 };
773 eprintln!(
774 "\r Validation complete: {} samples in {:.1}s ({:.1} sam/s) ",
775 total,
776 val_elapsed.as_secs_f32(),
777 val_sam_per_sec,
778 );
779
780 let avg_loss = if total > 0 { total_loss / total as f32 } else { 0.0 };
781 let accuracy = if total > 0 { correct as f32 / total as f32 } else { 0.0 };
782
783 (avg_loss, accuracy)
784 }
785
786 fn shuffle_training_data(&mut self, epoch: usize) {
791 let seed = self.rng_seed.wrapping_add(epoch as u64);
792 let mut rng_state = seed;
793 let n = self.train_data.len();
794
795 for i in (1..n).rev() {
798 rng_state = rng_state
799 .wrapping_mul(6_364_136_223_846_793_005)
800 .wrapping_add(1_442_695_040_888_963_407);
801 let j = (rng_state >> 33) as usize % (i + 1);
802 self.train_data.swap(i, j);
803 self.train_tokens.swap(i, j);
804 }
805 }
806
807 pub fn save_checkpoint(
821 &mut self,
822 path: &Path,
823 epoch: usize,
824 metrics: &EpochMetrics,
825 ) -> crate::Result<()> {
826 contract_pre_save_checkpoint!();
827 #[cfg(feature = "cuda")]
829 self.pipeline.sync_weights_to_cpu();
830 std::fs::create_dir_all(path).map_err(|e| {
831 crate::Error::Io(format!("Failed to create checkpoint dir {}: {e}", path.display()))
832 })?;
833
834 let metadata = serde_json::json!({
836 "epoch": epoch,
837 "train_loss": metrics.train_loss,
838 "train_accuracy": metrics.train_accuracy,
839 "val_loss": metrics.val_loss,
840 "val_accuracy": metrics.val_accuracy,
841 "learning_rate": metrics.learning_rate,
842 "epoch_time_ms": metrics.epoch_time_ms,
843 "samples_per_sec": metrics.samples_per_sec,
844 "class_weights": self.pipeline.config.class_weights,
845 });
846
847 let meta_path = path.join("metadata.json");
848 let meta_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
849 crate::Error::Serialization(format!("Failed to serialize metadata: {e}"))
850 })?;
851 std::fs::write(&meta_path, meta_json)?;
852
853 let params = self.pipeline.classifier.parameters();
855 let st_path = path.join("model.safetensors");
856
857 let tensor_names = ["classifier.weight", "classifier.bias"];
859 let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = params
860 .iter()
861 .zip(tensor_names.iter())
862 .map(|(tensor, name)| {
863 let data = tensor.data();
864 let bytes: Vec<u8> =
865 bytemuck::cast_slice(data.as_slice().expect("contiguous")).to_vec();
866 let shape = vec![tensor.len()];
867 (name.to_string(), bytes, shape)
868 })
869 .collect();
870
871 for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
874 let layer = idx / 2;
875 let proj = if idx % 2 == 0 { "q" } else { "v" };
876
877 let a_data = lora.lora_a().data();
879 let a_bytes: Vec<u8> =
880 bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
881 let a_shape = vec![lora.rank(), lora.d_in()];
882 tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
883
884 let b_data = lora.lora_b().data();
886 let b_bytes: Vec<u8> =
887 bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
888 let b_shape = vec![lora.d_out(), lora.rank()];
889 tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
890 }
891
892 let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
893 .iter()
894 .map(|(name, bytes, shape)| {
895 let view = safetensors::tensor::TensorView::new(
896 safetensors::tensor::Dtype::F32,
897 shape.clone(),
898 bytes,
899 )
900 .expect("valid tensor view");
901 (name.as_str(), view)
902 })
903 .collect();
904
905 let mut st_metadata = std::collections::HashMap::new();
906 st_metadata.insert("epoch".to_string(), epoch.to_string());
907 st_metadata.insert("val_loss".to_string(), format!("{:.6}", metrics.val_loss));
908
909 let safetensor_bytes = safetensors::serialize(views, Some(st_metadata)).map_err(|e| {
910 crate::Error::Serialization(format!("SafeTensors serialization failed: {e}"))
911 })?;
912 std::fs::write(&st_path, safetensor_bytes)?;
913
914 self.save_apr_checkpoint(path, epoch, metrics)?;
916
917 self.save_adapter_apr(path, epoch, metrics)?;
919
920 let model_config = &self.pipeline.model.config;
924 let hf_config = serde_json::json!({
925 "architectures": ["Qwen2ForSequenceClassification"],
926 "model_type": "qwen2",
927 "hidden_size": model_config.hidden_size,
928 "num_attention_heads": model_config.num_attention_heads,
929 "num_key_value_heads": model_config.num_kv_heads,
930 "intermediate_size": model_config.intermediate_size,
931 "num_hidden_layers": model_config.num_hidden_layers,
932 "vocab_size": model_config.vocab_size,
933 "max_position_embeddings": model_config.max_position_embeddings,
934 "rms_norm_eps": model_config.rms_norm_eps,
935 "rope_theta": model_config.rope_theta,
936 "use_cache": true,
937 "torch_dtype": "float32",
938 "num_labels": self.pipeline.config.num_classes,
939 "problem_type": "single_label_classification",
940 });
941 let config_json = serde_json::to_string_pretty(&hf_config).map_err(|e| {
942 crate::Error::Serialization(format!("Failed to serialize config.json: {e}"))
943 })?;
944 std::fs::write(path.join("config.json"), config_json)?;
945
946 let lora_config = crate::lora::LoRAConfig::new(
948 self.pipeline.config.lora_rank,
949 self.pipeline.config.lora_alpha,
950 )
951 .target_qv_projections();
952
953 let base_model = self.pipeline.model_dir().map(|p| p.display().to_string());
954
955 let peft_config =
956 crate::lora::PeftAdapterConfig::from_lora_config(&lora_config, base_model.as_deref())
957 .with_task_type("SEQ_CLS");
958
959 let adapter_json = peft_config.to_json().map_err(|e| {
960 crate::Error::Serialization(format!("Failed to serialize adapter_config.json: {e}"))
961 })?;
962 std::fs::write(path.join("adapter_config.json"), adapter_json)?;
963
964 if let Some(model_dir) = self.pipeline.model_dir() {
966 let src = model_dir.join("tokenizer.json");
967 if src.exists() {
968 std::fs::copy(&src, path.join("tokenizer.json"))
969 .map_err(|e| crate::Error::Io(format!("Failed to copy tokenizer.json: {e}")))?;
970 }
971 }
972
973 contract_post_save_checkpoint!(());
974 Ok(())
975 }
976
977 fn save_apr_checkpoint(
987 &self,
988 path: &Path,
989 epoch: usize,
990 metrics: &EpochMetrics,
991 ) -> crate::Result<()> {
992 use aprender::serialization::apr::AprWriter;
993
994 let mut writer = AprWriter::new();
995
996 writer
998 .set_metadata("__checkpoint__.schema_version".to_string(), serde_json::json!("1.2.0"));
999
1000 writer.set_metadata("model_type".to_string(), serde_json::json!("adapter"));
1002 writer.set_metadata("epoch".to_string(), serde_json::json!(epoch));
1003 writer.set_metadata("val_loss".to_string(), serde_json::json!(metrics.val_loss));
1004 writer.set_metadata("val_accuracy".to_string(), serde_json::json!(metrics.val_accuracy));
1005 writer.set_metadata("train_loss".to_string(), serde_json::json!(metrics.train_loss));
1006 writer
1007 .set_metadata("train_accuracy".to_string(), serde_json::json!(metrics.train_accuracy));
1008 writer.set_metadata("architecture".to_string(), serde_json::json!("qwen2_classify"));
1009 writer.set_metadata(
1010 "num_classes".to_string(),
1011 serde_json::json!(self.pipeline.config.num_classes),
1012 );
1013 writer.set_metadata(
1014 "lora_rank".to_string(),
1015 serde_json::json!(self.pipeline.config.lora_rank),
1016 );
1017 writer.set_metadata(
1018 "lora_alpha".to_string(),
1019 serde_json::json!(self.pipeline.config.lora_alpha),
1020 );
1021 writer.set_metadata(
1022 "hidden_size".to_string(),
1023 serde_json::json!(self.pipeline.model.config.hidden_size),
1024 );
1025 writer.set_metadata(
1026 "num_layers".to_string(),
1027 serde_json::json!(self.pipeline.model.config.num_hidden_layers),
1028 );
1029
1030 writer.set_metadata("data_hash".to_string(), serde_json::json!(self.data_hash));
1032 if let Some(model_dir) = self.pipeline.model_dir() {
1033 writer.set_metadata(
1034 "base_model_source".to_string(),
1035 serde_json::json!(model_dir.display().to_string()),
1036 );
1037 }
1038 writer.set_metadata(
1039 "provenance".to_string(),
1040 serde_json::json!({
1041 "tool": format!("entrenar v{}", env!("CARGO_PKG_VERSION")),
1042 "started_at": self.train_start,
1043 }),
1044 );
1045
1046 let weight = &self.pipeline.classifier.weight;
1048 let weight_data = weight.data();
1049 let weight_slice = weight_data.as_slice().expect("contiguous weight");
1050 writer.add_tensor_f32("classifier.weight", vec![weight.len()], weight_slice);
1051
1052 let bias = &self.pipeline.classifier.bias;
1053 let bias_data = bias.data();
1054 let bias_slice = bias_data.as_slice().expect("contiguous bias");
1055 writer.add_tensor_f32("classifier.bias", vec![bias.len()], bias_slice);
1056
1057 for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
1059 let layer = idx / 2;
1060 let proj = if idx % 2 == 0 { "q" } else { "v" };
1061
1062 let a_data = lora.lora_a().data();
1063 let a_slice = a_data.as_slice().expect("contiguous lora_a");
1064 writer.add_tensor_f32(
1065 format!("lora.{layer}.{proj}_proj.lora_a"),
1066 vec![lora.rank(), lora.d_in()],
1067 a_slice,
1068 );
1069
1070 let b_data = lora.lora_b().data();
1071 let b_slice = b_data.as_slice().expect("contiguous lora_b");
1072 writer.add_tensor_f32(
1073 format!("lora.{layer}.{proj}_proj.lora_b"),
1074 vec![lora.d_out(), lora.rank()],
1075 b_slice,
1076 );
1077 }
1078
1079 let optimizer = self.pipeline.optimizer();
1081
1082 writer.add_tensor_f32(
1084 "__training__.optimizer.step",
1085 vec![1],
1086 &[optimizer.step_count() as f32],
1087 );
1088
1089 for (i, (m_opt, v_opt)) in
1091 optimizer.first_moments().iter().zip(optimizer.second_moments().iter()).enumerate()
1092 {
1093 if let Some(m) = m_opt {
1094 let m_slice = m.as_slice().expect("contiguous moment m");
1095 writer.add_tensor_f32(
1096 format!("__training__.optimizer.m.{i}"),
1097 vec![m.len()],
1098 m_slice,
1099 );
1100 }
1101 if let Some(v) = v_opt {
1102 let v_slice = v.as_slice().expect("contiguous moment v");
1103 writer.add_tensor_f32(
1104 format!("__training__.optimizer.v.{i}"),
1105 vec![v.len()],
1106 v_slice,
1107 );
1108 }
1109 }
1110
1111 writer.add_tensor_f32("__training__.epoch", vec![1], &[epoch as f32]);
1113 writer.add_tensor_f32("__training__.learning_rate", vec![1], &[metrics.learning_rate]);
1114
1115 if !weight_slice.iter().all(|v| v.is_finite()) {
1117 return Err(crate::Error::Serialization(
1118 "F-CKPT-007: classifier.weight contains NaN or Inf".to_string(),
1119 ));
1120 }
1121 if !bias_slice.iter().all(|v| v.is_finite()) {
1122 return Err(crate::Error::Serialization(
1123 "F-CKPT-007: classifier.bias contains NaN or Inf".to_string(),
1124 ));
1125 }
1126 for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
1127 let a = lora.lora_a().data();
1128 let b = lora.lora_b().data();
1129 if !a.iter().all(|v| v.is_finite()) {
1130 return Err(crate::Error::Serialization(format!(
1131 "F-CKPT-007: lora[{idx}].lora_a contains NaN or Inf"
1132 )));
1133 }
1134 if !b.iter().all(|v| v.is_finite()) {
1135 return Err(crate::Error::Serialization(format!(
1136 "F-CKPT-007: lora[{idx}].lora_b contains NaN or Inf"
1137 )));
1138 }
1139 }
1140
1141 let expected_weight_len =
1143 self.pipeline.config.num_classes * self.pipeline.model.config.hidden_size;
1144 if weight_slice.len() != expected_weight_len {
1145 return Err(crate::Error::Serialization(format!(
1146 "F-CKPT-008: classifier.weight shape mismatch: \
1147 expected {} ({}×{}), got {}",
1148 expected_weight_len,
1149 self.pipeline.config.num_classes,
1150 self.pipeline.model.config.hidden_size,
1151 weight_slice.len(),
1152 )));
1153 }
1154 if bias_slice.len() != self.pipeline.config.num_classes {
1155 return Err(crate::Error::Serialization(format!(
1156 "F-CKPT-008: classifier.bias shape mismatch: \
1157 expected {}, got {}",
1158 self.pipeline.config.num_classes,
1159 bias_slice.len(),
1160 )));
1161 }
1162
1163 let apr_path = path.join("model.apr");
1164 writer
1165 .write(&apr_path)
1166 .map_err(|e| crate::Error::Serialization(format!("APR serialization failed: {e}")))?;
1167
1168 Ok(())
1169 }
1170
1171 fn save_adapter_apr(
1176 &self,
1177 path: &Path,
1178 epoch: usize,
1179 metrics: &EpochMetrics,
1180 ) -> crate::Result<()> {
1181 use aprender::serialization::apr::AprWriter;
1182
1183 let mut writer = AprWriter::new();
1184
1185 writer
1186 .set_metadata("__checkpoint__.schema_version".to_string(), serde_json::json!("1.3.0"));
1187 writer.set_metadata("model_type".to_string(), serde_json::json!("adapter"));
1188 writer.set_metadata("epoch".to_string(), serde_json::json!(epoch));
1189 writer.set_metadata("val_loss".to_string(), serde_json::json!(metrics.val_loss));
1190 writer.set_metadata("val_accuracy".to_string(), serde_json::json!(metrics.val_accuracy));
1191 writer.set_metadata("architecture".to_string(), serde_json::json!("qwen2_classify"));
1192 writer.set_metadata(
1193 "num_classes".to_string(),
1194 serde_json::json!(self.pipeline.config.num_classes),
1195 );
1196 writer.set_metadata(
1197 "lora_rank".to_string(),
1198 serde_json::json!(self.pipeline.config.lora_rank),
1199 );
1200 writer.set_metadata(
1201 "lora_alpha".to_string(),
1202 serde_json::json!(self.pipeline.config.lora_alpha),
1203 );
1204 writer.set_metadata(
1205 "hidden_size".to_string(),
1206 serde_json::json!(self.pipeline.model.config.hidden_size),
1207 );
1208 writer.set_metadata("data_hash".to_string(), serde_json::json!(self.data_hash));
1209 writer.set_metadata(
1210 "provenance".to_string(),
1211 serde_json::json!({
1212 "tool": format!("entrenar v{}", env!("CARGO_PKG_VERSION")),
1213 "started_at": self.train_start,
1214 }),
1215 );
1216
1217 let weight = &self.pipeline.classifier.weight;
1219 let weight_data = weight.data();
1220 let weight_slice = weight_data.as_slice().expect("contiguous weight");
1221 writer.add_tensor_f32("classifier.weight", vec![weight.len()], weight_slice);
1222
1223 let bias = &self.pipeline.classifier.bias;
1224 let bias_data = bias.data();
1225 let bias_slice = bias_data.as_slice().expect("contiguous bias");
1226 writer.add_tensor_f32("classifier.bias", vec![bias.len()], bias_slice);
1227
1228 for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
1230 let layer = idx / 2;
1231 let proj = if idx % 2 == 0 { "q" } else { "v" };
1232
1233 let a_data = lora.lora_a().data();
1234 let a_slice = a_data.as_slice().expect("contiguous lora_a");
1235 writer.add_tensor_f32(
1236 format!("lora.{layer}.{proj}_proj.lora_a"),
1237 vec![lora.rank(), lora.d_in()],
1238 a_slice,
1239 );
1240
1241 let b_data = lora.lora_b().data();
1242 let b_slice = b_data.as_slice().expect("contiguous lora_b");
1243 writer.add_tensor_f32(
1244 format!("lora.{layer}.{proj}_proj.lora_b"),
1245 vec![lora.d_out(), lora.rank()],
1246 b_slice,
1247 );
1248 }
1249
1250 let adapter_path = path.join("model.adapter.apr");
1251 writer.write(&adapter_path).map_err(|e| {
1252 crate::Error::Serialization(format!("APR adapter serialization failed: {e}"))
1253 })?;
1254
1255 Ok(())
1256 }
1257
1258 pub fn resume_from_apr_checkpoint(&mut self, apr_path: &Path) -> crate::Result<usize> {
1269 use aprender::serialization::apr::AprReader;
1270
1271 let reader = AprReader::open(apr_path).map_err(|e| {
1272 crate::Error::Serialization(format!("Failed to open APR checkpoint: {e}"))
1273 })?;
1274
1275 if let Some(saved_hash) = reader.get_metadata("data_hash").and_then(|v| v.as_str()) {
1277 if saved_hash != self.data_hash {
1278 return Err(crate::Error::ConfigError(format!(
1279 "F-CKPT-006: training data hash mismatch. \
1280 Checkpoint: {saved_hash}, current: {}. \
1281 Use --allow-data-mismatch to override.",
1282 self.data_hash,
1283 )));
1284 }
1285 }
1286
1287 let expected_weight =
1289 self.pipeline.config.num_classes * self.pipeline.model.config.hidden_size;
1290 reader
1291 .validate_tensor_shape("classifier.weight", expected_weight)
1292 .map_err(crate::Error::Serialization)?;
1293 reader
1294 .validate_tensor_shape("classifier.bias", self.pipeline.config.num_classes)
1295 .map_err(crate::Error::Serialization)?;
1296
1297 let weight_data = reader
1299 .read_tensor_f32_checked("classifier.weight")
1300 .map_err(crate::Error::Serialization)?;
1301 let bias_data = reader
1302 .read_tensor_f32_checked("classifier.bias")
1303 .map_err(crate::Error::Serialization)?;
1304
1305 self.pipeline
1306 .classifier
1307 .weight
1308 .data_mut()
1309 .as_slice_mut()
1310 .expect("contiguous weight")
1311 .copy_from_slice(&weight_data);
1312 self.pipeline
1313 .classifier
1314 .bias
1315 .data_mut()
1316 .as_slice_mut()
1317 .expect("contiguous bias")
1318 .copy_from_slice(&bias_data);
1319
1320 for (idx, lora) in self.pipeline.lora_layers.iter_mut().enumerate() {
1322 let layer = idx / 2;
1323 let proj = if idx % 2 == 0 { "q" } else { "v" };
1324
1325 let a_name = format!("lora.{layer}.{proj}_proj.lora_a");
1326 let b_name = format!("lora.{layer}.{proj}_proj.lora_b");
1327
1328 if let Ok(a_data) = reader.read_tensor_f32(&a_name) {
1329 let a_tensor = lora.lora_a_mut();
1330 let a_buf = a_tensor.data_mut();
1331 a_buf.as_slice_mut().expect("contiguous lora_a").copy_from_slice(&a_data);
1332 }
1333 if let Ok(b_data) = reader.read_tensor_f32(&b_name) {
1334 let b_tensor = lora.lora_b_mut();
1335 let b_buf = b_tensor.data_mut();
1336 b_buf.as_slice_mut().expect("contiguous lora_b").copy_from_slice(&b_data);
1337 }
1338 }
1339
1340 let optimizer = self.pipeline.optimizer_mut();
1342
1343 if let Ok(step_data) = reader.read_tensor_f32("__training__.optimizer.step") {
1345 optimizer.set_step_count(step_data[0] as u64);
1346 }
1347
1348 for i in 0..256 {
1350 let m_name = format!("__training__.optimizer.m.{i}");
1351 let v_name = format!("__training__.optimizer.v.{i}");
1352
1353 let m_exists = reader.read_tensor_f32(&m_name);
1354 let v_exists = reader.read_tensor_f32(&v_name);
1355
1356 match (m_exists, v_exists) {
1357 (Ok(m_data), Ok(v_data)) => {
1358 optimizer.set_first_moment(i, ndarray::Array1::from_vec(m_data));
1359 optimizer.set_second_moment(i, ndarray::Array1::from_vec(v_data));
1360 }
1361 _ => break, }
1363 }
1364
1365 let epoch = if let Ok(epoch_data) = reader.read_tensor_f32("__training__.epoch") {
1367 epoch_data[0] as usize
1368 } else {
1369 reader
1371 .get_metadata("epoch")
1372 .and_then(serde_json::Value::as_u64)
1373 .map_or(0, |e| e as usize)
1374 };
1375
1376 if let Ok(lr_data) = reader.read_tensor_f32("__training__.learning_rate") {
1377 self.pipeline.set_optimizer_lr(lr_data[0]);
1378 }
1379
1380 println!(
1381 " Resumed from APR checkpoint: epoch {epoch}, optimizer step {}",
1382 self.pipeline.optimizer().step_count(),
1383 );
1384
1385 Ok(epoch)
1386 }
1387
1388 pub fn split_dataset(
1398 data: &[SafetySample],
1399 val_ratio: f32,
1400 seed: u64,
1401 ) -> (Vec<SafetySample>, Vec<SafetySample>) {
1402 if data.is_empty() {
1403 return (Vec::new(), Vec::new());
1404 }
1405
1406 let mut indices: Vec<usize> = (0..data.len()).collect();
1407
1408 let mut rng_state = seed;
1410 for i in (1..indices.len()).rev() {
1411 rng_state = rng_state
1412 .wrapping_mul(6_364_136_223_846_793_005)
1413 .wrapping_add(1_442_695_040_888_963_407);
1414 let j = (rng_state >> 33) as usize % (i + 1);
1415 indices.swap(i, j);
1416 }
1417
1418 let val_count = ((data.len() as f32) * val_ratio).ceil() as usize;
1419 let val_count = val_count.min(data.len() - 1).max(1);
1420
1421 let val_indices = &indices[..val_count];
1422 let train_indices = &indices[val_count..];
1423
1424 let val_data: Vec<SafetySample> = val_indices.iter().map(|&i| data[i].clone()).collect();
1425 let train_data: Vec<SafetySample> =
1426 train_indices.iter().map(|&i| data[i].clone()).collect();
1427
1428 (train_data, val_data)
1429 }
1430
1431 #[must_use]
1433 pub fn train_data(&self) -> &[SafetySample] {
1434 &self.train_data
1435 }
1436
1437 #[must_use]
1439 pub fn val_data(&self) -> &[SafetySample] {
1440 &self.val_data
1441 }
1442
1443 #[must_use]
1445 pub fn config(&self) -> &TrainingConfig {
1446 &self.config
1447 }
1448
1449 pub fn pipeline_mut(&mut self) -> &mut ClassifyPipeline {
1451 &mut self.pipeline
1452 }
1453
1454 fn is_coordinator_mode(&self) -> bool {
1456 self.config
1457 .distributed
1458 .as_ref()
1459 .is_some_and(|d| matches!(d.role, super::distributed::NodeRole::Coordinator))
1460 }
1461
1462 pub fn run_worker(&mut self) -> crate::Result<TrainResult> {
1478 let dist_config = self.config.distributed.clone().ok_or_else(|| {
1479 crate::Error::ConfigError("distributed config required for worker mode".into())
1480 })?;
1481
1482 let gpu_count = 1u32; let backend = "cpu"; let client =
1486 super::worker_client::WorkerClient::connect(dist_config, gpu_count, backend)
1487 .map_err(|e| crate::Error::ConfigError(format!("worker connect failed: {e}")))?;
1488
1489 eprintln!(
1490 "[worker {}] Connected (total workers: {})",
1491 client.worker_id(),
1492 client.total_workers(),
1493 );
1494
1495 let total_start = std::time::Instant::now();
1496 let epoch_metrics_vec: Vec<EpochMetrics> = Vec::new();
1497 let best_val_loss = f32::INFINITY;
1498 let best_epoch = 0usize;
1499
1500 let all_samples: Vec<SafetySample> = self.train_data.clone();
1502
1503 loop {
1504 let shard = match client.receive_shard() {
1505 Ok(Some(s)) => s,
1506 Ok(None) => {
1507 eprintln!("[worker {}] Received shutdown", client.worker_id());
1508 break;
1509 }
1510 Err(e) => {
1511 return Err(crate::Error::ConfigError(format!("shard receive failed: {e}")));
1512 }
1513 };
1514
1515 let step = shard.step;
1516 let shard_start = shard.shard_start.min(all_samples.len());
1517 let shard_end = shard.shard_end.min(all_samples.len());
1518 let shard_data = &all_samples[shard_start..shard_end];
1519
1520 let batch_result = self.pipeline.train_batch(shard_data);
1522
1523 let gradients = self.pipeline.collect_lora_gradients();
1525
1526 client
1528 .send_gradients(
1529 step,
1530 gradients,
1531 batch_result.avg_loss,
1532 batch_result.correct,
1533 batch_result.total,
1534 )
1535 .map_err(|e| crate::Error::ConfigError(format!("gradient send failed: {e}")))?;
1536
1537 let averaged = client
1539 .receive_averaged()
1540 .map_err(|e| crate::Error::ConfigError(format!("averaged receive failed: {e}")))?;
1541
1542 self.pipeline.apply_lora_gradients(&averaged.gradients);
1544
1545 eprintln!(
1546 "[worker {}] step {step}: loss={:.4}, global_loss={:.4}",
1547 client.worker_id(),
1548 batch_result.avg_loss,
1549 averaged.global_loss,
1550 );
1551 }
1552
1553 Ok(TrainResult {
1554 epoch_metrics: epoch_metrics_vec,
1555 best_epoch,
1556 best_val_loss,
1557 stopped_early: false,
1558 total_time_ms: total_start.elapsed().as_millis() as u64,
1559 })
1560 }
1561
1562 pub fn evaluate(
1571 &mut self,
1572 data: &[SafetySample],
1573 label_names: &[String],
1574 ) -> ClassifyEvalReport {
1575 let start = std::time::Instant::now();
1576 let num_classes = self.pipeline.config.num_classes;
1577
1578 let mut y_true: Vec<usize> = Vec::with_capacity(data.len());
1579 let mut y_pred: Vec<usize> = Vec::with_capacity(data.len());
1580 let mut all_probs: Vec<Vec<f32>> = Vec::with_capacity(data.len());
1581 let mut total_loss = 0.0f32;
1582
1583 for sample in data {
1584 let ids = self.pipeline.tokenize(&sample.input);
1585 let (loss, predicted, probs) =
1586 self.pipeline.forward_only_with_probs(&ids, sample.label);
1587 total_loss += loss;
1588 y_true.push(sample.label);
1589 y_pred.push(predicted);
1590 all_probs.push(probs);
1591 }
1592
1593 ClassifyEvalReport::from_predictions_with_probs(
1594 &y_pred,
1595 &y_true,
1596 &all_probs,
1597 total_loss,
1598 num_classes,
1599 label_names,
1600 start.elapsed().as_millis() as u64,
1601 )
1602 }
1603}
1604
1605#[cfg(test)]
1606#[allow(clippy::unwrap_used)]
1607#[path = "classify_trainer_tests.rs"]
1608mod tests;