Skip to main content

entrenar/finetune/
classify_trainer.rs

1//! Production training loop for classification fine-tuning (SSC-026)
2//!
3//! `ClassifyTrainer` wraps `ClassifyPipeline` with epoch management,
4//! validation, checkpointing, LR scheduling, and early stopping.
5//!
6//! # Contract Invariants
7//!
8//! - F-LOOP-001: EMA loss decreasing over training (alpha=0.1, 5-epoch window)
9//! - F-LOOP-002: Validation computed every epoch
10//! - F-LOOP-007: Data shuffled per epoch (different order)
11//! - F-LOOP-008: Val split disjoint (zero overlap with training set)
12//! - F-LOOP-009: Val set frozen (same composition across epochs)
13//! - F-LOOP-010: Early stopping respects patience
14
15use super::classification::{SafetySample, TokenizedSample};
16use super::classify_eval_report::ClassifyEvalReport;
17use super::classify_pipeline::ClassifyPipeline;
18use super::distributed::DistributedConfig;
19use crate::optim::LRScheduler;
20use crate::optim::WarmupCosineDecayLR;
21use sha2::{Digest, Sha256};
22use std::path::{Path, PathBuf};
23
24/// Training configuration for the classification trainer.
25#[derive(Debug, Clone)]
26pub struct TrainingConfig {
27    /// Number of training epochs (default: 50)
28    pub epochs: usize,
29    /// Fraction of data reserved for validation (default: 0.2)
30    pub val_split: f32,
31    /// Save checkpoint every N epochs (default: 5)
32    pub save_every: usize,
33    /// Early stopping patience in epochs (default: 10)
34    pub early_stopping_patience: usize,
35    /// Directory for checkpoint files
36    pub checkpoint_dir: PathBuf,
37    /// Random seed for reproducibility (default: 42)
38    pub seed: u64,
39    /// Log metrics every N epochs (default: 1)
40    pub log_interval: usize,
41    /// Warmup steps as fraction of total steps (default: 0.1)
42    pub warmup_fraction: f32,
43    /// Minimum learning rate for cosine decay (default: 1e-6)
44    pub lr_min: f32,
45    /// Oversample minority classes to match majority count (default: false).
46    /// When enabled, duplicates minority-class samples and skips auto class weights.
47    pub oversample_minority: bool,
48    /// Quantize frozen weights to NF4 (4-bit) for QLoRA training (default: false).
49    ///
50    /// When enabled, transformer blocks use `CudaNf4TransformerBlock` instead of
51    /// `CudaTransformerBlock`, achieving ~8x VRAM compression on frozen weights.
52    /// Only LoRA adapters remain trainable in fp32.
53    pub quantize_nf4: bool,
54    /// Distributed training configuration (multi-node TCP gradient AllReduce).
55    ///
56    /// When set, the trainer operates in either coordinator or worker mode:
57    /// - Coordinator: manages epochs, shards data, AllReduces gradients (F-DP-001)
58    /// - Worker: receives shards, computes forward/backward, sends gradients
59    pub distributed: Option<DistributedConfig>,
60}
61
62impl Default for TrainingConfig {
63    fn default() -> Self {
64        Self {
65            epochs: 50,
66            val_split: 0.2,
67            save_every: 5,
68            early_stopping_patience: 10,
69            checkpoint_dir: PathBuf::from("checkpoints"),
70            seed: 42,
71            log_interval: 1,
72            warmup_fraction: 0.1,
73            lr_min: 1e-6,
74            oversample_minority: false,
75            quantize_nf4: false,
76            distributed: None,
77        }
78    }
79}
80
/// Snapshot of training and validation statistics recorded after one epoch.
#[derive(Debug, Clone)]
pub struct EpochMetrics {
    /// Zero-based index of the epoch these numbers belong to.
    pub epoch: usize,
    /// Sample-weighted mean training loss over the epoch.
    pub train_loss: f32,
    /// Fraction of training samples classified correctly, in [0.0, 1.0].
    pub train_accuracy: f32,
    /// Mean loss over the validation split.
    pub val_loss: f32,
    /// Fraction of validation samples classified correctly, in [0.0, 1.0].
    pub val_accuracy: f32,
    /// Learning rate reported at the end of the epoch.
    pub learning_rate: f32,
    /// Wall-clock duration of the epoch, in milliseconds.
    pub epoch_time_ms: u64,
    /// Training throughput for the epoch, in samples per second.
    pub samples_per_sec: f32,
}
101
102/// Result of the full training run.
103#[derive(Debug, Clone)]
104pub struct TrainResult {
105    /// Per-epoch metrics
106    pub epoch_metrics: Vec<EpochMetrics>,
107    /// Epoch with lowest validation loss
108    pub best_epoch: usize,
109    /// Lowest validation loss achieved
110    pub best_val_loss: f32,
111    /// Whether training stopped early
112    pub stopped_early: bool,
113    /// Total wall-clock training time in milliseconds
114    pub total_time_ms: u64,
115}
116
117/// Production training loop for classification fine-tuning.
118///
119/// Wraps `ClassifyPipeline` with:
120/// - Epoch management with per-epoch shuffling
121/// - Validation on a disjoint, frozen split
122/// - Warmup + cosine decay LR scheduling
123/// - Periodic checkpointing (SafeTensors + metadata JSON)
124/// - Early stopping with configurable patience
125pub struct ClassifyTrainer {
126    /// The classification pipeline (model + optimizer)
127    pipeline: ClassifyPipeline,
128    /// Training configuration
129    config: TrainingConfig,
130    /// Training data (shuffled per epoch)
131    train_data: Vec<SafetySample>,
132    /// Pre-tokenized training data — indices parallel `train_data` (KAIZEN-028).
133    /// Token IDs computed once at construction; shuffled in sync with `train_data`.
134    train_tokens: Vec<TokenizedSample>,
135    /// Pre-tokenized validation data (frozen, KAIZEN-028).
136    val_tokens: Vec<TokenizedSample>,
137    /// Validation data (frozen, never shuffled)
138    val_data: Vec<SafetySample>,
139    /// Base random seed
140    rng_seed: u64,
141    /// Optional monitor writer for live TUI updates
142    monitor_writer: Option<crate::monitor::tui::TrainingStateWriter>,
143    /// SHA-256 hash of training data for provenance (F-CKPT-017)
144    data_hash: String,
145    /// Training start timestamp (ISO 8601) for provenance
146    train_start: String,
147}
148
149#[allow(clippy::missing_fields_in_debug)]
150impl std::fmt::Debug for ClassifyTrainer {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("ClassifyTrainer")
153            .field("config", &self.config)
154            .field("train_data_len", &self.train_data.len())
155            .field("train_tokens_len", &self.train_tokens.len())
156            .field("val_data_len", &self.val_data.len())
157            .field("val_tokens_len", &self.val_tokens.len())
158            .field("rng_seed", &self.rng_seed)
159            .finish()
160    }
161}
162
163impl ClassifyTrainer {
164    /// Create a new trainer by splitting corpus into train/val sets.
165    ///
166    /// # Arguments
167    /// * `pipeline` - Initialized `ClassifyPipeline`
168    /// * `corpus` - Full dataset of labeled samples
169    /// * `config` - Training configuration
170    ///
171    /// # Errors
172    /// Returns error if corpus is empty, val_split is out of (0.0, 0.5],
173    /// or epochs is 0.
174    pub fn new(
175        mut pipeline: ClassifyPipeline,
176        corpus: Vec<SafetySample>,
177        config: TrainingConfig,
178    ) -> crate::Result<Self> {
179        if corpus.is_empty() {
180            return Err(crate::Error::ConfigError("SSC-026: corpus must not be empty".to_string()));
181        }
182        if config.val_split <= 0.0 || config.val_split > 0.5 {
183            return Err(crate::Error::ConfigError(format!(
184                "SSC-026: val_split must be in (0.0, 0.5], got {}",
185                config.val_split,
186            )));
187        }
188        if config.epochs == 0 {
189            return Err(crate::Error::ConfigError("SSC-026: epochs must be > 0".to_string()));
190        }
191
192        // ── Auto-detect class imbalance and apply weights ────────────────
193        // Skip when oversampling: data will be balanced, so weights are unnecessary.
194        if !config.oversample_minority {
195            Self::auto_balance_classes(&mut pipeline, &corpus);
196        }
197
198        let (mut train_data, val_data) =
199            Self::split_dataset(&corpus, config.val_split, config.seed);
200
201        if config.oversample_minority {
202            Self::oversample_training_data(&mut train_data, config.seed);
203        }
204
205        if train_data.is_empty() || val_data.is_empty() {
206            return Err(crate::Error::ConfigError(format!(
207                "SSC-026: split produced empty set (train={}, val={}). Need more samples.",
208                train_data.len(),
209                val_data.len(),
210            )));
211        }
212
213        let rng_seed = config.seed;
214
215        // KAIZEN-028: Pre-tokenize all samples once at construction time.
216        // Eliminates redundant BPE encoding across epochs and batches.
217        // For 17,942 SSC samples × 50 epochs = 897,100 tokenizations reduced to 17,942.
218        let train_tokens = pipeline.pre_tokenize(&train_data);
219        let val_tokens = pipeline.pre_tokenize(&val_data);
220
221        // F-CKPT-017: Compute data hash for provenance
222        let data_hash = Self::compute_data_hash(&corpus);
223        let train_start = chrono::Utc::now().to_rfc3339();
224
225        Ok(Self {
226            pipeline,
227            config,
228            train_data,
229            train_tokens,
230            val_tokens,
231            val_data,
232            rng_seed,
233            monitor_writer: None,
234            data_hash,
235            train_start,
236        })
237    }
238
239    /// Compute SHA-256 hash of training corpus for provenance (F-CKPT-017).
240    ///
241    /// Hash is computed over sorted (input, label) pairs for determinism.
242    fn compute_data_hash(corpus: &[SafetySample]) -> String {
243        let mut hasher = Sha256::new();
244        let mut sorted: Vec<(&str, usize)> =
245            corpus.iter().map(|s| (s.input.as_str(), s.label)).collect();
246        sorted.sort_unstable();
247        for (input, label) in &sorted {
248            hasher.update(input.as_bytes());
249            hasher.update([0u8]); // separator
250            hasher.update(label.to_le_bytes());
251        }
252        let result = hasher.finalize();
253        format!("sha256:{result:x}")
254    }
255
256    /// Auto-detect class imbalance and apply sqrt-inverse weights when no
257    /// explicit weights are configured.
258    ///
259    /// World-class training frameworks (sklearn, HuggingFace Trainer) auto-balance
260    /// by default. A training run on imbalanced data with uniform weights silently
261    /// optimizes for majority-class accuracy — the model learns to never predict
262    /// minority classes.
263    ///
264    /// Threshold: if max_count / min_count > 2.0, imbalance is detected.
265    /// Strategy: `SqrtInverse` (moderate rebalancing, avoids overadjust).
266    fn auto_balance_classes(pipeline: &mut ClassifyPipeline, corpus: &[SafetySample]) {
267        use super::classification::{compute_class_weights, corpus_stats, ClassWeightStrategy};
268
269        // Skip if user explicitly configured weights
270        if pipeline.config.class_weights.is_some() {
271            return;
272        }
273
274        let num_classes = pipeline.config.num_classes;
275        let stats = corpus_stats(corpus, num_classes);
276
277        // Check if any class is missing entirely
278        let min_count = stats.class_counts.iter().copied().min().unwrap_or(0);
279        let max_count = stats.class_counts.iter().copied().max().unwrap_or(1);
280
281        if min_count == 0 {
282            println!(
283                "  Warning: class with zero samples detected. \
284                 Class weights not applied (would produce Inf)."
285            );
286            return;
287        }
288
289        let imbalance_ratio = max_count as f64 / min_count as f64;
290
291        if imbalance_ratio > 2.0 {
292            let weights =
293                compute_class_weights(&stats, ClassWeightStrategy::SqrtInverse, num_classes);
294            println!(
295                "  Auto-detected class imbalance (ratio {imbalance_ratio:.1}:1), \
296                 applying sqrt-inverse weights: {weights:?}"
297            );
298            println!("  Class counts: {:?} (total: {})", stats.class_counts, stats.total);
299            pipeline.config.class_weights = Some(weights);
300        } else {
301            println!("  Class balance OK (ratio {imbalance_ratio:.1}:1), using uniform weights");
302        }
303    }
304
305    /// Oversample minority classes by duplicating samples until each class
306    /// matches the majority count.
307    ///
308    /// This is a simple, effective strategy for moderate imbalance (e.g. 93/7 splits).
309    /// After oversampling the training set is shuffled deterministically.
310    fn oversample_training_data(train_data: &mut Vec<SafetySample>, seed: u64) {
311        use std::collections::HashMap;
312
313        // Count per-class
314        let mut class_indices: HashMap<usize, Vec<usize>> = HashMap::new();
315        for (i, sample) in train_data.iter().enumerate() {
316            class_indices.entry(sample.label).or_default().push(i);
317        }
318
319        let majority_count = class_indices.values().map(std::vec::Vec::len).max().unwrap_or(0);
320        let before = train_data.len();
321
322        // Duplicate minority samples (cycling) to match majority
323        for indices in class_indices.values() {
324            let count = indices.len();
325            if count < majority_count {
326                let deficit = majority_count - count;
327                for i in 0..deficit {
328                    let src_idx = indices[i % count];
329                    train_data.push(train_data[src_idx].clone());
330                }
331            }
332        }
333
334        // Deterministic shuffle via Fisher-Yates with simple LCG
335        let n = train_data.len();
336        let mut rng_state: u64 = seed.wrapping_mul(0x517cc1b727220a95).wrapping_add(1);
337        for i in (1..n).rev() {
338            rng_state =
339                rng_state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
340            let j = (rng_state >> 33) as usize % (i + 1);
341            train_data.swap(i, j);
342        }
343
344        println!(
345            "  Oversampled minority classes: {before} \u{2192} {} training samples",
346            train_data.len()
347        );
348    }
349
350    /// Attach a monitor writer for live TUI updates.
351    ///
352    /// When set, training emits per-batch metrics to the experiment directory
353    /// via atomic JSON writes, enabling `apr monitor <dir>` from another shell.
354    pub fn set_monitor_writer(&mut self, writer: crate::monitor::tui::TrainingStateWriter) {
355        self.monitor_writer = Some(writer);
356    }
357
358    /// Run the full training loop.
359    ///
360    /// For each epoch:
361    /// 1. Shuffle training data (deterministic, seed varies per epoch)
362    /// 2. Process batches via `pipeline.train_batch()`
363    /// 3. Compute validation metrics (forward-only)
364    /// 4. Step LR scheduler
365    /// 5. Record metrics
366    /// 6. Save checkpoint if `save_every` or new best val_loss
367    /// 7. Check early stopping
368    pub fn train(&mut self) -> TrainResult {
369        // Dispatch to coordinator-mode training if distributed config is set
370        if self.is_coordinator_mode() {
371            return self.train_as_coordinator();
372        }
373
374        let total_start = std::time::Instant::now();
375        let batch_size = self.pipeline.config.batch_size;
376        let batches_per_epoch = self.train_data.len().div_ceil(batch_size);
377        let total_steps = self.config.epochs * batches_per_epoch;
378        let warmup_steps = (self.config.warmup_fraction * total_steps as f32) as usize;
379        let lr_max = self.pipeline.optimizer_lr();
380
381        let mut scheduler =
382            WarmupCosineDecayLR::new(lr_max, self.config.lr_min, warmup_steps, total_steps);
383
384        // Initialize monitor writer if attached
385        if let Some(ref mut writer) = self.monitor_writer {
386            writer.set_epochs(self.config.epochs, batches_per_epoch);
387            let _ = writer.start();
388        }
389
390        let mut epoch_metrics_vec: Vec<EpochMetrics> = Vec::with_capacity(self.config.epochs);
391        let mut best_val_loss = f32::INFINITY;
392        let mut best_epoch: usize = 0;
393        let mut epochs_without_improvement: usize = 0;
394        let mut stopped_early = false;
395        let mut training_failed = false;
396
397        for epoch in 0..self.config.epochs {
398            let epoch_start = std::time::Instant::now();
399
400            // F-LOOP-007: Shuffle training data with epoch-specific seed
401            self.shuffle_training_data(epoch);
402
403            // Train one epoch
404            let (train_loss, train_accuracy) = self.train_epoch(&mut scheduler, epoch);
405
406            // F-LOOP-002: Validate every epoch
407            let (val_loss, val_accuracy) = self.validate();
408
409            let epoch_time = epoch_start.elapsed();
410            let epoch_time_ms = epoch_time.as_millis() as u64;
411            let samples_per_sec = if epoch_time_ms > 0 {
412                self.train_data.len() as f32 / (epoch_time_ms as f32 / 1000.0)
413            } else {
414                0.0
415            };
416
417            let metrics = EpochMetrics {
418                epoch,
419                train_loss,
420                train_accuracy,
421                val_loss,
422                val_accuracy,
423                learning_rate: scheduler.get_lr(),
424                epoch_time_ms,
425                samples_per_sec,
426            };
427
428            epoch_metrics_vec.push(metrics.clone());
429
430            // Epoch summary via monitoring framework
431            let is_best = val_loss < best_val_loss;
432            if let Some(ref writer) = self.monitor_writer {
433                writer.emit_epoch_summary(
434                    epoch + 1,
435                    self.config.epochs,
436                    train_loss,
437                    train_accuracy,
438                    val_loss,
439                    val_accuracy,
440                    epoch_time.as_secs_f32(),
441                    scheduler.get_lr(),
442                    is_best,
443                );
444            }
445
446            // Track best validation loss
447            if val_loss < best_val_loss {
448                best_val_loss = val_loss;
449                best_epoch = epoch;
450                epochs_without_improvement = 0;
451
452                // Save best checkpoint
453                let best_path = self.config.checkpoint_dir.join("best");
454                let _ = self.save_checkpoint(&best_path, epoch, &metrics);
455            } else {
456                epochs_without_improvement += 1;
457            }
458
459            // Periodic checkpoint — when epochs <= save_every, save every epoch
460            let effective_save_every = if self.config.epochs <= self.config.save_every {
461                1
462            } else {
463                self.config.save_every
464            };
465            if effective_save_every > 0 && (epoch + 1) % effective_save_every == 0 {
466                let epoch_path = self.config.checkpoint_dir.join(format!("epoch-{epoch}"));
467                let _ = self.save_checkpoint(&epoch_path, epoch, &metrics);
468            }
469
470            // Detect NaN/Inf loss — signal failure to monitor
471            if !train_loss.is_finite() || !val_loss.is_finite() {
472                if let Some(ref mut writer) = self.monitor_writer {
473                    let _ = writer.fail("NaN or Inf loss detected");
474                }
475                training_failed = true;
476                stopped_early = true;
477                break;
478            }
479
480            // F-LOOP-010: Early stopping
481            if epochs_without_improvement >= self.config.early_stopping_patience {
482                stopped_early = true;
483                break;
484            }
485        }
486
487        // Signal training completion to monitor (skip if already failed)
488        if !training_failed {
489            if let Some(ref mut writer) = self.monitor_writer {
490                let _ = writer.complete();
491            }
492        }
493
494        let total_time_ms = total_start.elapsed().as_millis() as u64;
495
496        TrainResult {
497            epoch_metrics: epoch_metrics_vec,
498            best_epoch,
499            best_val_loss,
500            stopped_early,
501            total_time_ms,
502        }
503    }
504
505    /// Run training as the distributed coordinator.
506    ///
507    /// Starts a `GradientServer`, waits for workers, then runs the full
508    /// training loop with distributed AllReduce gradient averaging.
509    ///
510    /// # Contract: F-DP-001 (Weight Consistency)
511    ///
512    /// After each AllReduce step, all workers receive identical averaged
513    /// gradients and apply the same optimizer step.
514    fn train_as_coordinator(&mut self) -> TrainResult {
515        use super::gradient_server::GradientServer;
516
517        let dist_config = self
518            .config
519            .distributed
520            .clone()
521            .expect("train_as_coordinator requires distributed config");
522
523        let total_start = std::time::Instant::now();
524
525        // Bind gradient server
526        let mut server = match GradientServer::bind(dist_config) {
527            Ok(s) => s,
528            Err(e) => {
529                eprintln!("[coordinator] Failed to bind: {e}");
530                return TrainResult {
531                    epoch_metrics: vec![],
532                    best_epoch: 0,
533                    best_val_loss: f32::INFINITY,
534                    stopped_early: true,
535                    total_time_ms: total_start.elapsed().as_millis() as u64,
536                };
537            }
538        };
539
540        // Wait for all workers to connect
541        if let Err(e) = server.wait_for_workers() {
542            eprintln!("[coordinator] Worker connection failed: {e}");
543            return TrainResult {
544                epoch_metrics: vec![],
545                best_epoch: 0,
546                best_val_loss: f32::INFINITY,
547                stopped_early: true,
548                total_time_ms: total_start.elapsed().as_millis() as u64,
549            };
550        }
551
552        let num_workers = server.worker_count();
553        server.set_total_samples(self.train_data.len());
554
555        eprintln!(
556            "[coordinator] Starting training: {} epochs, {} workers, {} samples",
557            self.config.epochs,
558            num_workers,
559            self.train_data.len(),
560        );
561
562        let mut epoch_metrics_vec: Vec<EpochMetrics> = Vec::with_capacity(self.config.epochs);
563        let mut best_val_loss = f32::INFINITY;
564        let mut best_epoch = 0usize;
565        let mut stopped_early = false;
566
567        for epoch in 0..self.config.epochs {
568            let epoch_start = std::time::Instant::now();
569
570            self.shuffle_training_data(epoch);
571
572            let batch_size = self.pipeline.config.batch_size;
573            let mut total_loss = 0.0f32;
574            let mut total_correct = 0usize;
575            let mut total_samples = 0usize;
576
577            // KAIZEN-032: Borrow pre-tokenized data directly — no per-epoch clone.
578            for (step_idx, chunk) in self.train_tokens.chunks(batch_size).enumerate() {
579                let step =
580                    epoch as u64 * (self.train_tokens.len() / batch_size) as u64 + step_idx as u64;
581
582                // Send shard assignments to workers
583                if let Err(e) = server.send_shard_assignments(step) {
584                    eprintln!("[coordinator] Shard assignment failed at step {step}: {e}");
585                    stopped_early = true;
586                    break;
587                }
588
589                // Coordinator also computes its own shard (local forward/backward)
590                let _local = self.pipeline.train_batch_tokenized(chunk);
591
592                // Collect and average gradients from all workers (F-DP-001)
593                match server.collect_and_reduce(step) {
594                    Ok(allreduce) => {
595                        // Apply averaged gradients locally
596                        self.pipeline.apply_lora_gradients(&allreduce.avg_gradients);
597
598                        // Broadcast to workers
599                        if let Err(e) = server.broadcast_averaged(step, &allreduce) {
600                            eprintln!("[coordinator] Broadcast failed at step {step}: {e}");
601                            stopped_early = true;
602                            break;
603                        }
604
605                        total_loss += allreduce.global_loss * allreduce.total_samples as f32;
606                        total_correct += allreduce.total_correct;
607                        total_samples += allreduce.total_samples;
608                    }
609                    Err(e) => {
610                        eprintln!("[coordinator] AllReduce failed at step {step}: {e}");
611                        stopped_early = true;
612                        break;
613                    }
614                }
615            }
616
617            if stopped_early {
618                break;
619            }
620
621            let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
622            let accuracy =
623                if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
624
625            // Validate on coordinator's local val set
626            let (val_loss, val_accuracy) = self.validate();
627
628            let epoch_time_ms = epoch_start.elapsed().as_millis() as u64;
629            let samples_per_sec = if epoch_time_ms > 0 {
630                total_samples as f32 / (epoch_time_ms as f32 / 1000.0)
631            } else {
632                0.0
633            };
634
635            let metrics = EpochMetrics {
636                epoch,
637                train_loss: avg_loss,
638                train_accuracy: accuracy,
639                val_loss,
640                val_accuracy,
641                learning_rate: self.pipeline.optimizer_lr(),
642                epoch_time_ms,
643                samples_per_sec,
644            };
645
646            eprintln!(
647                "[coordinator] Epoch {}: loss={:.4}, acc={:.1}%, val_loss={:.4}, val_acc={:.1}%",
648                epoch + 1,
649                avg_loss,
650                accuracy * 100.0,
651                val_loss,
652                val_accuracy * 100.0,
653            );
654
655            if val_loss < best_val_loss {
656                best_val_loss = val_loss;
657                best_epoch = epoch;
658
659                let best_path = self.config.checkpoint_dir.join("best");
660                let _ = self.save_checkpoint(&best_path, epoch, &metrics);
661            }
662
663            epoch_metrics_vec.push(metrics);
664        }
665
666        server.shutdown_workers();
667
668        TrainResult {
669            epoch_metrics: epoch_metrics_vec,
670            best_epoch,
671            best_val_loss,
672            stopped_early,
673            total_time_ms: total_start.elapsed().as_millis() as u64,
674        }
675    }
676
677    /// Train for one epoch, processing all training data in batches.
678    ///
679    /// Returns `(avg_loss, accuracy)` for the epoch.
680    fn train_epoch(&mut self, scheduler: &mut WarmupCosineDecayLR, epoch: usize) -> (f32, f32) {
681        let batch_size = self.pipeline.config.batch_size;
682        let mut total_loss = 0.0f32;
683        let mut total_correct = 0usize;
684        let mut total_samples = 0usize;
685
686        let epoch_start = std::time::Instant::now();
687
688        // KAIZEN-032: Borrow pre-tokenized data directly — no per-epoch clone.
689        for (batch_idx, chunk) in self.train_tokens.chunks(batch_size).enumerate() {
690            // Apply current LR from scheduler
691            self.pipeline.set_optimizer_lr(scheduler.get_lr());
692
693            let result = self.pipeline.train_batch_tokenized(chunk);
694            total_loss += result.avg_loss * result.total as f32;
695            total_correct += result.correct;
696            total_samples += result.total;
697
698            let running_avg_loss =
699                if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
700            let elapsed_secs = epoch_start.elapsed().as_secs_f32();
701            let samples_per_sec =
702                if elapsed_secs > 0.0 { total_samples as f32 / elapsed_secs } else { 0.0 };
703            let current_lr = scheduler.get_lr();
704
705            let step = batch_idx + 1;
706            let acc =
707                if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
708
709            // Emit per-batch metrics to monitor (JSON state file + optional console)
710            if let Some(ref mut writer) = self.monitor_writer {
711                let _ = writer.update_step(
712                    epoch + 1,
713                    step,
714                    running_avg_loss,
715                    current_lr,
716                    result.grad_norm,
717                    samples_per_sec,
718                    acc,
719                );
720            }
721
722            // Step scheduler per batch
723            scheduler.step();
724        }
725
726        let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
727        let accuracy =
728            if total_samples > 0 { total_correct as f32 / total_samples as f32 } else { 0.0 };
729
730        (avg_loss, accuracy)
731    }
732
733    /// Compute validation metrics (forward-only, no gradient updates).
734    ///
735    /// F-LOOP-009: Val set is frozen — same samples every epoch.
736    /// KAIZEN-013: Uses pre-tokenized cache and reports progress.
737    fn validate(&mut self) -> (f32, f32) {
738        let mut total_loss = 0.0f32;
739        let mut correct = 0usize;
740        let total = self.val_tokens.len();
741
742        let val_start = std::time::Instant::now();
743
744        // KAIZEN-028: Use pre-tokenized validation data — no BPE re-encoding.
745        // KAIZEN-013: Progress reporting with timing and running accuracy.
746        for (i, sample) in self.val_tokens.iter().enumerate() {
747            let (loss, predicted) = self.pipeline.forward_only(&sample.token_ids, sample.label);
748            total_loss += loss;
749            if predicted == sample.label {
750                correct += 1;
751            }
752            // Progress reporting every 100 samples
753            if (i + 1) % 100 == 0 || i + 1 == total {
754                let elapsed = val_start.elapsed().as_secs_f32();
755                let sam_per_sec = if elapsed > 0.0 { (i + 1) as f32 / elapsed } else { 0.0 };
756                let running_acc = if i > 0 { correct as f32 / (i + 1) as f32 * 100.0 } else { 0.0 };
757                eprint!(
758                    "\r  Validating: {}/{} ({:.1} sam/s, acc={:.1}%)   ",
759                    i + 1,
760                    total,
761                    sam_per_sec,
762                    running_acc,
763                );
764            }
765        }
766
767        let val_elapsed = val_start.elapsed();
768        let val_sam_per_sec = if val_elapsed.as_secs_f32() > 0.0 {
769            total as f32 / val_elapsed.as_secs_f32()
770        } else {
771            0.0
772        };
773        eprintln!(
774            "\r  Validation complete: {} samples in {:.1}s ({:.1} sam/s)              ",
775            total,
776            val_elapsed.as_secs_f32(),
777            val_sam_per_sec,
778        );
779
780        let avg_loss = if total > 0 { total_loss / total as f32 } else { 0.0 };
781        let accuracy = if total > 0 { correct as f32 / total as f32 } else { 0.0 };
782
783        (avg_loss, accuracy)
784    }
785
786    /// Shuffle training data using Fisher-Yates with epoch-dependent seed.
787    ///
788    /// F-LOOP-007: `seed = base_seed + epoch` ensures different order per epoch
789    /// but deterministic across runs.
790    fn shuffle_training_data(&mut self, epoch: usize) {
791        let seed = self.rng_seed.wrapping_add(epoch as u64);
792        let mut rng_state = seed;
793        let n = self.train_data.len();
794
795        // Fisher-Yates shuffle with LCG PRNG
796        // KAIZEN-028: Shuffle train_tokens in sync with train_data
797        for i in (1..n).rev() {
798            rng_state = rng_state
799                .wrapping_mul(6_364_136_223_846_793_005)
800                .wrapping_add(1_442_695_040_888_963_407);
801            let j = (rng_state >> 33) as usize % (i + 1);
802            self.train_data.swap(i, j);
803            self.train_tokens.swap(i, j);
804        }
805    }
806
807    /// Save checkpoint with metadata JSON and SafeTensors model weights.
808    ///
809    /// When GPU training is active, downloads GPU-updated transformer weights
810    /// to CPU before saving so checkpoints include all trained parameters.
811    ///
812    /// Creates: `{path}/metadata.json` and `{path}/model.safetensors`
813    ///
814    /// # Contract (C-CKPT-001)
815    ///
816    /// - **Precondition**: `path` is a writable directory (or will be created)
817    /// - **Postcondition**: Checkpoint contains all trained parameters including
818    ///   GPU-updated transformer block weights (if GPU training active)
819    /// - **Invariant**: CPU model state is consistent with GPU state after save
820    pub fn save_checkpoint(
821        &mut self,
822        path: &Path,
823        epoch: usize,
824        metrics: &EpochMetrics,
825    ) -> crate::Result<()> {
826        contract_pre_save_checkpoint!();
827        // Sync GPU weights to CPU before saving (no-op if GPU training inactive)
828        #[cfg(feature = "cuda")]
829        self.pipeline.sync_weights_to_cpu();
830        std::fs::create_dir_all(path).map_err(|e| {
831            crate::Error::Io(format!("Failed to create checkpoint dir {}: {e}", path.display()))
832        })?;
833
834        // Save metadata.json
835        let metadata = serde_json::json!({
836            "epoch": epoch,
837            "train_loss": metrics.train_loss,
838            "train_accuracy": metrics.train_accuracy,
839            "val_loss": metrics.val_loss,
840            "val_accuracy": metrics.val_accuracy,
841            "learning_rate": metrics.learning_rate,
842            "epoch_time_ms": metrics.epoch_time_ms,
843            "samples_per_sec": metrics.samples_per_sec,
844            "class_weights": self.pipeline.config.class_weights,
845        });
846
847        let meta_path = path.join("metadata.json");
848        let meta_json = serde_json::to_string_pretty(&metadata).map_err(|e| {
849            crate::Error::Serialization(format!("Failed to serialize metadata: {e}"))
850        })?;
851        std::fs::write(&meta_path, meta_json)?;
852
853        // Save model weights as SafeTensors (classifier head + LoRA adapters)
854        let params = self.pipeline.classifier.parameters();
855        let st_path = path.join("model.safetensors");
856
857        // Collect classifier head tensor data
858        let tensor_names = ["classifier.weight", "classifier.bias"];
859        let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = params
860            .iter()
861            .zip(tensor_names.iter())
862            .map(|(tensor, name)| {
863                let data = tensor.data();
864                let bytes: Vec<u8> =
865                    bytemuck::cast_slice(data.as_slice().expect("contiguous")).to_vec();
866                let shape = vec![tensor.len()];
867                (name.to_string(), bytes, shape)
868            })
869            .collect();
870
871        // Collect LoRA adapter weights (F-CLASS-008: Q/V projections)
872        // Convention: 2 adapters per layer (Q=even, V=odd)
873        for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
874            let layer = idx / 2;
875            let proj = if idx % 2 == 0 { "q" } else { "v" };
876
877            // LoRA A: [rank, d_in]
878            let a_data = lora.lora_a().data();
879            let a_bytes: Vec<u8> =
880                bytemuck::cast_slice(a_data.as_slice().expect("contiguous lora_a")).to_vec();
881            let a_shape = vec![lora.rank(), lora.d_in()];
882            tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_a"), a_bytes, a_shape));
883
884            // LoRA B: [d_out, rank]
885            let b_data = lora.lora_b().data();
886            let b_bytes: Vec<u8> =
887                bytemuck::cast_slice(b_data.as_slice().expect("contiguous lora_b")).to_vec();
888            let b_shape = vec![lora.d_out(), lora.rank()];
889            tensor_data.push((format!("lora.{layer}.{proj}_proj.lora_b"), b_bytes, b_shape));
890        }
891
892        let views: Vec<(&str, safetensors::tensor::TensorView<'_>)> = tensor_data
893            .iter()
894            .map(|(name, bytes, shape)| {
895                let view = safetensors::tensor::TensorView::new(
896                    safetensors::tensor::Dtype::F32,
897                    shape.clone(),
898                    bytes,
899                )
900                .expect("valid tensor view");
901                (name.as_str(), view)
902            })
903            .collect();
904
905        let mut st_metadata = std::collections::HashMap::new();
906        st_metadata.insert("epoch".to_string(), epoch.to_string());
907        st_metadata.insert("val_loss".to_string(), format!("{:.6}", metrics.val_loss));
908
909        let safetensor_bytes = safetensors::serialize(views, Some(st_metadata)).map_err(|e| {
910            crate::Error::Serialization(format!("SafeTensors serialization failed: {e}"))
911        })?;
912        std::fs::write(&st_path, safetensor_bytes)?;
913
914        // Save APR format (full training state)
915        self.save_apr_checkpoint(path, epoch, metrics)?;
916
917        // Save adapter-only APR (F-CKPT-003: no __training__.* tensors)
918        self.save_adapter_apr(path, epoch, metrics)?;
919
920        // ── HuggingFace-compatible metadata (config.json, adapter_config.json, tokenizer.json) ──
921
922        // config.json: HF model architecture config
923        let model_config = &self.pipeline.model.config;
924        let hf_config = serde_json::json!({
925            "architectures": ["Qwen2ForSequenceClassification"],
926            "model_type": "qwen2",
927            "hidden_size": model_config.hidden_size,
928            "num_attention_heads": model_config.num_attention_heads,
929            "num_key_value_heads": model_config.num_kv_heads,
930            "intermediate_size": model_config.intermediate_size,
931            "num_hidden_layers": model_config.num_hidden_layers,
932            "vocab_size": model_config.vocab_size,
933            "max_position_embeddings": model_config.max_position_embeddings,
934            "rms_norm_eps": model_config.rms_norm_eps,
935            "rope_theta": model_config.rope_theta,
936            "use_cache": true,
937            "torch_dtype": "float32",
938            "num_labels": self.pipeline.config.num_classes,
939            "problem_type": "single_label_classification",
940        });
941        let config_json = serde_json::to_string_pretty(&hf_config).map_err(|e| {
942            crate::Error::Serialization(format!("Failed to serialize config.json: {e}"))
943        })?;
944        std::fs::write(path.join("config.json"), config_json)?;
945
946        // adapter_config.json: PEFT adapter configuration
947        let lora_config = crate::lora::LoRAConfig::new(
948            self.pipeline.config.lora_rank,
949            self.pipeline.config.lora_alpha,
950        )
951        .target_qv_projections();
952
953        let base_model = self.pipeline.model_dir().map(|p| p.display().to_string());
954
955        let peft_config =
956            crate::lora::PeftAdapterConfig::from_lora_config(&lora_config, base_model.as_deref())
957                .with_task_type("SEQ_CLS");
958
959        let adapter_json = peft_config.to_json().map_err(|e| {
960            crate::Error::Serialization(format!("Failed to serialize adapter_config.json: {e}"))
961        })?;
962        std::fs::write(path.join("adapter_config.json"), adapter_json)?;
963
964        // tokenizer.json: copy from base model directory (if available)
965        if let Some(model_dir) = self.pipeline.model_dir() {
966            let src = model_dir.join("tokenizer.json");
967            if src.exists() {
968                std::fs::copy(&src, path.join("tokenizer.json"))
969                    .map_err(|e| crate::Error::Io(format!("Failed to copy tokenizer.json: {e}")))?;
970            }
971        }
972
973        contract_post_save_checkpoint!(());
974        Ok(())
975    }
976
    /// Save model in APR format with full training state.
    ///
    /// Writes `{path}/model.apr` with the following contents:
    /// - checkpoint metadata;
    /// - the classifier head (stored flattened as `[len]`);
    /// - every LoRA A (`[rank, d_in]`) and B (`[d_out, rank]`) matrix;
    /// - `__training__.*` tensors: AdamW step counter, first/second moments,
    ///   epoch, and learning rate.
    ///
    /// # Contract (F-CKPT-001, F-CKPT-004, F-CKPT-005)
    ///
    /// - **F-CKPT-001**: All adapter tensors (classifier + LoRA A/B)
    /// - **F-CKPT-004**: Optimizer state (`__training__.optimizer.*`)
    /// - **F-CKPT-005**: Training metadata (epoch, LR, step count)
    ///
    /// Inference readers skip `__training__.*` via `AprReader::open_filtered()`.
    ///
    /// # Errors
    ///
    /// - F-CKPT-007: classifier weight/bias, or any LoRA A/B tensor, contains
    ///   NaN or Inf.
    /// - F-CKPT-008: classifier weight length is not `num_classes * hidden_size`,
    ///   or bias length is not `num_classes`.
    /// - APR serialization / write failure.
    ///
    /// The validation runs after tensors are staged in the in-memory writer but
    /// before `writer.write`, so a failed check does not write `model.apr`.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if a classifier, LoRA, or optimizer-moment buffer
    /// is not contiguous.
    fn save_apr_checkpoint(
        &self,
        path: &Path,
        epoch: usize,
        metrics: &EpochMetrics,
    ) -> crate::Result<()> {
        use aprender::serialization::apr::AprWriter;

        let mut writer = AprWriter::new();

        // ── Schema version (F-CKPT-002) ─────────────────────────────────
        writer
            .set_metadata("__checkpoint__.schema_version".to_string(), serde_json::json!("1.2.0"));

        // ── Rich metadata ────────────────────────────────────────────────
        writer.set_metadata("model_type".to_string(), serde_json::json!("adapter"));
        writer.set_metadata("epoch".to_string(), serde_json::json!(epoch));
        writer.set_metadata("val_loss".to_string(), serde_json::json!(metrics.val_loss));
        writer.set_metadata("val_accuracy".to_string(), serde_json::json!(metrics.val_accuracy));
        writer.set_metadata("train_loss".to_string(), serde_json::json!(metrics.train_loss));
        writer
            .set_metadata("train_accuracy".to_string(), serde_json::json!(metrics.train_accuracy));
        writer.set_metadata("architecture".to_string(), serde_json::json!("qwen2_classify"));
        writer.set_metadata(
            "num_classes".to_string(),
            serde_json::json!(self.pipeline.config.num_classes),
        );
        writer.set_metadata(
            "lora_rank".to_string(),
            serde_json::json!(self.pipeline.config.lora_rank),
        );
        writer.set_metadata(
            "lora_alpha".to_string(),
            serde_json::json!(self.pipeline.config.lora_alpha),
        );
        writer.set_metadata(
            "hidden_size".to_string(),
            serde_json::json!(self.pipeline.model.config.hidden_size),
        );
        writer.set_metadata(
            "num_layers".to_string(),
            serde_json::json!(self.pipeline.model.config.num_hidden_layers),
        );

        // ── Provenance chain (F-CKPT-017) ───────────────────────────────
        // `data_hash` is compared against the current data by
        // `resume_from_apr_checkpoint` (F-CKPT-006).
        writer.set_metadata("data_hash".to_string(), serde_json::json!(self.data_hash));
        if let Some(model_dir) = self.pipeline.model_dir() {
            writer.set_metadata(
                "base_model_source".to_string(),
                serde_json::json!(model_dir.display().to_string()),
            );
        }
        writer.set_metadata(
            "provenance".to_string(),
            serde_json::json!({
                "tool": format!("entrenar v{}", env!("CARGO_PKG_VERSION")),
                "started_at": self.train_start,
            }),
        );

        // ── Classifier head tensors ──────────────────────────────────────
        // Stored flattened (shape `[len]`); the resume path validates by
        // element count only.
        let weight = &self.pipeline.classifier.weight;
        let weight_data = weight.data();
        let weight_slice = weight_data.as_slice().expect("contiguous weight");
        writer.add_tensor_f32("classifier.weight", vec![weight.len()], weight_slice);

        let bias = &self.pipeline.classifier.bias;
        let bias_data = bias.data();
        let bias_slice = bias_data.as_slice().expect("contiguous bias");
        writer.add_tensor_f32("classifier.bias", vec![bias.len()], bias_slice);

        // ── LoRA adapter tensors (F-CKPT-001: adapter completeness) ──────
        // Naming convention: adapter index `idx` maps to layer `idx / 2`,
        // with even indices = Q projection and odd = V projection.
        for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
            let layer = idx / 2;
            let proj = if idx % 2 == 0 { "q" } else { "v" };

            let a_data = lora.lora_a().data();
            let a_slice = a_data.as_slice().expect("contiguous lora_a");
            writer.add_tensor_f32(
                format!("lora.{layer}.{proj}_proj.lora_a"),
                vec![lora.rank(), lora.d_in()],
                a_slice,
            );

            let b_data = lora.lora_b().data();
            let b_slice = b_data.as_slice().expect("contiguous lora_b");
            writer.add_tensor_f32(
                format!("lora.{layer}.{proj}_proj.lora_b"),
                vec![lora.d_out(), lora.rank()],
                b_slice,
            );
        }

        // ── Training state (F-CKPT-004: optimizer moments) ──────────────
        let optimizer = self.pipeline.optimizer();

        // Save AdamW step counter as 1-element tensor
        // NOTE(review): stored as f32, which represents integers exactly only
        // up to 2^24; a larger step count would be rounded on resume.
        writer.add_tensor_f32(
            "__training__.optimizer.step",
            vec![1],
            &[optimizer.step_count() as f32],
        );

        // Save first moments (m) and second moments (v)
        // Slots whose moment is `None` are skipped, so the saved indices `i`
        // can be sparse. NOTE(review): the resume path stops at the first
        // missing index — confirm moments are allocated densely.
        for (i, (m_opt, v_opt)) in
            optimizer.first_moments().iter().zip(optimizer.second_moments().iter()).enumerate()
        {
            if let Some(m) = m_opt {
                let m_slice = m.as_slice().expect("contiguous moment m");
                writer.add_tensor_f32(
                    format!("__training__.optimizer.m.{i}"),
                    vec![m.len()],
                    m_slice,
                );
            }
            if let Some(v) = v_opt {
                let v_slice = v.as_slice().expect("contiguous moment v");
                writer.add_tensor_f32(
                    format!("__training__.optimizer.v.{i}"),
                    vec![v.len()],
                    v_slice,
                );
            }
        }

        // ── Training metadata (F-CKPT-005) ──────────────────────────────
        writer.add_tensor_f32("__training__.epoch", vec![1], &[epoch as f32]);
        writer.add_tensor_f32("__training__.learning_rate", vec![1], &[metrics.learning_rate]);

        // ── NaN/Inf check (F-CKPT-007) ──────────────────────────────────
        // Runs before `writer.write`, so a non-finite model is never persisted.
        if !weight_slice.iter().all(|v| v.is_finite()) {
            return Err(crate::Error::Serialization(
                "F-CKPT-007: classifier.weight contains NaN or Inf".to_string(),
            ));
        }
        if !bias_slice.iter().all(|v| v.is_finite()) {
            return Err(crate::Error::Serialization(
                "F-CKPT-007: classifier.bias contains NaN or Inf".to_string(),
            ));
        }
        for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
            let a = lora.lora_a().data();
            let b = lora.lora_b().data();
            if !a.iter().all(|v| v.is_finite()) {
                return Err(crate::Error::Serialization(format!(
                    "F-CKPT-007: lora[{idx}].lora_a contains NaN or Inf"
                )));
            }
            if !b.iter().all(|v| v.is_finite()) {
                return Err(crate::Error::Serialization(format!(
                    "F-CKPT-007: lora[{idx}].lora_b contains NaN or Inf"
                )));
            }
        }

        // ── Shape validation (F-CKPT-008) ────────────────────────────────
        // The classifier weight must hold num_classes × hidden_size elements.
        let expected_weight_len =
            self.pipeline.config.num_classes * self.pipeline.model.config.hidden_size;
        if weight_slice.len() != expected_weight_len {
            return Err(crate::Error::Serialization(format!(
                "F-CKPT-008: classifier.weight shape mismatch: \
                 expected {} ({}×{}), got {}",
                expected_weight_len,
                self.pipeline.config.num_classes,
                self.pipeline.model.config.hidden_size,
                weight_slice.len(),
            )));
        }
        if bias_slice.len() != self.pipeline.config.num_classes {
            return Err(crate::Error::Serialization(format!(
                "F-CKPT-008: classifier.bias shape mismatch: \
                 expected {}, got {}",
                self.pipeline.config.num_classes,
                bias_slice.len(),
            )));
        }

        let apr_path = path.join("model.apr");
        writer
            .write(&apr_path)
            .map_err(|e| crate::Error::Serialization(format!("APR serialization failed: {e}")))?;

        Ok(())
    }
1170
1171    /// Save adapter-only APR (no training state) (F-CKPT-003).
1172    ///
1173    /// Produces a `.adapter.apr` with zero `__training__.*` tensors.
1174    /// Used for publishing and inference deployment.
1175    fn save_adapter_apr(
1176        &self,
1177        path: &Path,
1178        epoch: usize,
1179        metrics: &EpochMetrics,
1180    ) -> crate::Result<()> {
1181        use aprender::serialization::apr::AprWriter;
1182
1183        let mut writer = AprWriter::new();
1184
1185        writer
1186            .set_metadata("__checkpoint__.schema_version".to_string(), serde_json::json!("1.3.0"));
1187        writer.set_metadata("model_type".to_string(), serde_json::json!("adapter"));
1188        writer.set_metadata("epoch".to_string(), serde_json::json!(epoch));
1189        writer.set_metadata("val_loss".to_string(), serde_json::json!(metrics.val_loss));
1190        writer.set_metadata("val_accuracy".to_string(), serde_json::json!(metrics.val_accuracy));
1191        writer.set_metadata("architecture".to_string(), serde_json::json!("qwen2_classify"));
1192        writer.set_metadata(
1193            "num_classes".to_string(),
1194            serde_json::json!(self.pipeline.config.num_classes),
1195        );
1196        writer.set_metadata(
1197            "lora_rank".to_string(),
1198            serde_json::json!(self.pipeline.config.lora_rank),
1199        );
1200        writer.set_metadata(
1201            "lora_alpha".to_string(),
1202            serde_json::json!(self.pipeline.config.lora_alpha),
1203        );
1204        writer.set_metadata(
1205            "hidden_size".to_string(),
1206            serde_json::json!(self.pipeline.model.config.hidden_size),
1207        );
1208        writer.set_metadata("data_hash".to_string(), serde_json::json!(self.data_hash));
1209        writer.set_metadata(
1210            "provenance".to_string(),
1211            serde_json::json!({
1212                "tool": format!("entrenar v{}", env!("CARGO_PKG_VERSION")),
1213                "started_at": self.train_start,
1214            }),
1215        );
1216
1217        // Classifier head
1218        let weight = &self.pipeline.classifier.weight;
1219        let weight_data = weight.data();
1220        let weight_slice = weight_data.as_slice().expect("contiguous weight");
1221        writer.add_tensor_f32("classifier.weight", vec![weight.len()], weight_slice);
1222
1223        let bias = &self.pipeline.classifier.bias;
1224        let bias_data = bias.data();
1225        let bias_slice = bias_data.as_slice().expect("contiguous bias");
1226        writer.add_tensor_f32("classifier.bias", vec![bias.len()], bias_slice);
1227
1228        // LoRA adapters (NO __training__.* tensors — F-CKPT-003)
1229        for (idx, lora) in self.pipeline.lora_layers.iter().enumerate() {
1230            let layer = idx / 2;
1231            let proj = if idx % 2 == 0 { "q" } else { "v" };
1232
1233            let a_data = lora.lora_a().data();
1234            let a_slice = a_data.as_slice().expect("contiguous lora_a");
1235            writer.add_tensor_f32(
1236                format!("lora.{layer}.{proj}_proj.lora_a"),
1237                vec![lora.rank(), lora.d_in()],
1238                a_slice,
1239            );
1240
1241            let b_data = lora.lora_b().data();
1242            let b_slice = b_data.as_slice().expect("contiguous lora_b");
1243            writer.add_tensor_f32(
1244                format!("lora.{layer}.{proj}_proj.lora_b"),
1245                vec![lora.d_out(), lora.rank()],
1246                b_slice,
1247            );
1248        }
1249
1250        let adapter_path = path.join("model.adapter.apr");
1251        writer.write(&adapter_path).map_err(|e| {
1252            crate::Error::Serialization(format!("APR adapter serialization failed: {e}"))
1253        })?;
1254
1255        Ok(())
1256    }
1257
1258    /// Resume training state from an APR checkpoint (F-CKPT-006).
1259    ///
1260    /// Loads model weights (classifier + LoRA) and optimizer state
1261    /// (`__training__.*` tensors) from a `.ckpt.apr` or `model.apr` file.
1262    ///
1263    /// Returns the epoch number stored in the checkpoint so the training
1264    /// loop can resume from the next epoch.
1265    ///
1266    /// # Errors
1267    /// Returns error if checkpoint is invalid or tensors are missing.
1268    pub fn resume_from_apr_checkpoint(&mut self, apr_path: &Path) -> crate::Result<usize> {
1269        use aprender::serialization::apr::AprReader;
1270
1271        let reader = AprReader::open(apr_path).map_err(|e| {
1272            crate::Error::Serialization(format!("Failed to open APR checkpoint: {e}"))
1273        })?;
1274
1275        // ── Data hash verification (F-CKPT-006) ─────────────────────────
1276        if let Some(saved_hash) = reader.get_metadata("data_hash").and_then(|v| v.as_str()) {
1277            if saved_hash != self.data_hash {
1278                return Err(crate::Error::ConfigError(format!(
1279                    "F-CKPT-006: training data hash mismatch. \
1280                     Checkpoint: {saved_hash}, current: {}. \
1281                     Use --allow-data-mismatch to override.",
1282                    self.data_hash,
1283                )));
1284            }
1285        }
1286
1287        // ── Shape-config validation (F-CKPT-014) ────────────────────────
1288        let expected_weight =
1289            self.pipeline.config.num_classes * self.pipeline.model.config.hidden_size;
1290        reader
1291            .validate_tensor_shape("classifier.weight", expected_weight)
1292            .map_err(crate::Error::Serialization)?;
1293        reader
1294            .validate_tensor_shape("classifier.bias", self.pipeline.config.num_classes)
1295            .map_err(crate::Error::Serialization)?;
1296
1297        // ── Restore classifier head (F-CKPT-013: NaN scan) ──────────────
1298        let weight_data = reader
1299            .read_tensor_f32_checked("classifier.weight")
1300            .map_err(crate::Error::Serialization)?;
1301        let bias_data = reader
1302            .read_tensor_f32_checked("classifier.bias")
1303            .map_err(crate::Error::Serialization)?;
1304
1305        self.pipeline
1306            .classifier
1307            .weight
1308            .data_mut()
1309            .as_slice_mut()
1310            .expect("contiguous weight")
1311            .copy_from_slice(&weight_data);
1312        self.pipeline
1313            .classifier
1314            .bias
1315            .data_mut()
1316            .as_slice_mut()
1317            .expect("contiguous bias")
1318            .copy_from_slice(&bias_data);
1319
1320        // ── Restore LoRA adapters ───────────────────────────────────────
1321        for (idx, lora) in self.pipeline.lora_layers.iter_mut().enumerate() {
1322            let layer = idx / 2;
1323            let proj = if idx % 2 == 0 { "q" } else { "v" };
1324
1325            let a_name = format!("lora.{layer}.{proj}_proj.lora_a");
1326            let b_name = format!("lora.{layer}.{proj}_proj.lora_b");
1327
1328            if let Ok(a_data) = reader.read_tensor_f32(&a_name) {
1329                let a_tensor = lora.lora_a_mut();
1330                let a_buf = a_tensor.data_mut();
1331                a_buf.as_slice_mut().expect("contiguous lora_a").copy_from_slice(&a_data);
1332            }
1333            if let Ok(b_data) = reader.read_tensor_f32(&b_name) {
1334                let b_tensor = lora.lora_b_mut();
1335                let b_buf = b_tensor.data_mut();
1336                b_buf.as_slice_mut().expect("contiguous lora_b").copy_from_slice(&b_data);
1337            }
1338        }
1339
1340        // ── Restore optimizer state (F-CKPT-004) ────────────────────────
1341        let optimizer = self.pipeline.optimizer_mut();
1342
1343        // Restore step counter
1344        if let Ok(step_data) = reader.read_tensor_f32("__training__.optimizer.step") {
1345            optimizer.set_step_count(step_data[0] as u64);
1346        }
1347
1348        // Restore first and second moments
1349        for i in 0..256 {
1350            let m_name = format!("__training__.optimizer.m.{i}");
1351            let v_name = format!("__training__.optimizer.v.{i}");
1352
1353            let m_exists = reader.read_tensor_f32(&m_name);
1354            let v_exists = reader.read_tensor_f32(&v_name);
1355
1356            match (m_exists, v_exists) {
1357                (Ok(m_data), Ok(v_data)) => {
1358                    optimizer.set_first_moment(i, ndarray::Array1::from_vec(m_data));
1359                    optimizer.set_second_moment(i, ndarray::Array1::from_vec(v_data));
1360                }
1361                _ => break, // No more moment buffers
1362            }
1363        }
1364
1365        // ── Restore training metadata (F-CKPT-005) ─────────────────────
1366        let epoch = if let Ok(epoch_data) = reader.read_tensor_f32("__training__.epoch") {
1367            epoch_data[0] as usize
1368        } else {
1369            // Fall back to metadata
1370            reader
1371                .get_metadata("epoch")
1372                .and_then(serde_json::Value::as_u64)
1373                .map_or(0, |e| e as usize)
1374        };
1375
1376        if let Ok(lr_data) = reader.read_tensor_f32("__training__.learning_rate") {
1377            self.pipeline.set_optimizer_lr(lr_data[0]);
1378        }
1379
1380        println!(
1381            "  Resumed from APR checkpoint: epoch {epoch}, optimizer step {}",
1382            self.pipeline.optimizer().step_count(),
1383        );
1384
1385        Ok(epoch)
1386    }
1387
1388    /// Split dataset into disjoint train/val sets.
1389    ///
1390    /// F-LOOP-008: Guarantees zero overlap between train and val.
1391    /// F-LOOP-009: Val set is deterministic given the same seed.
1392    ///
1393    /// # Arguments
1394    /// * `data` - Full dataset
1395    /// * `val_ratio` - Fraction for validation (0.0, 0.5]
1396    /// * `seed` - Random seed for deterministic shuffling
1397    pub fn split_dataset(
1398        data: &[SafetySample],
1399        val_ratio: f32,
1400        seed: u64,
1401    ) -> (Vec<SafetySample>, Vec<SafetySample>) {
1402        if data.is_empty() {
1403            return (Vec::new(), Vec::new());
1404        }
1405
1406        let mut indices: Vec<usize> = (0..data.len()).collect();
1407
1408        // Fisher-Yates shuffle with LCG PRNG for determinism
1409        let mut rng_state = seed;
1410        for i in (1..indices.len()).rev() {
1411            rng_state = rng_state
1412                .wrapping_mul(6_364_136_223_846_793_005)
1413                .wrapping_add(1_442_695_040_888_963_407);
1414            let j = (rng_state >> 33) as usize % (i + 1);
1415            indices.swap(i, j);
1416        }
1417
1418        let val_count = ((data.len() as f32) * val_ratio).ceil() as usize;
1419        let val_count = val_count.min(data.len() - 1).max(1);
1420
1421        let val_indices = &indices[..val_count];
1422        let train_indices = &indices[val_count..];
1423
1424        let val_data: Vec<SafetySample> = val_indices.iter().map(|&i| data[i].clone()).collect();
1425        let train_data: Vec<SafetySample> =
1426            train_indices.iter().map(|&i| data[i].clone()).collect();
1427
1428        (train_data, val_data)
1429    }
1430
1431    /// Get a reference to the training data.
1432    #[must_use]
1433    pub fn train_data(&self) -> &[SafetySample] {
1434        &self.train_data
1435    }
1436
1437    /// Get a reference to the validation data.
1438    #[must_use]
1439    pub fn val_data(&self) -> &[SafetySample] {
1440        &self.val_data
1441    }
1442
1443    /// Get a reference to the training config.
1444    #[must_use]
1445    pub fn config(&self) -> &TrainingConfig {
1446        &self.config
1447    }
1448
1449    /// Get a mutable reference to the underlying pipeline.
1450    pub fn pipeline_mut(&mut self) -> &mut ClassifyPipeline {
1451        &mut self.pipeline
1452    }
1453
1454    /// Check if distributed coordinator mode is configured.
1455    fn is_coordinator_mode(&self) -> bool {
1456        self.config
1457            .distributed
1458            .as_ref()
1459            .is_some_and(|d| matches!(d.role, super::distributed::NodeRole::Coordinator))
1460    }
1461
1462    /// Run as a distributed worker node.
1463    ///
1464    /// Connects to the coordinator, then enters a loop:
1465    /// 1. Receive shard assignment (or shutdown)
1466    /// 2. Compute forward/backward on assigned shard
1467    /// 3. Collect LoRA gradients and send to coordinator
1468    /// 4. Receive averaged gradients and apply optimizer step
1469    ///
1470    /// # Contract: F-DP-001 (Weight Consistency)
1471    ///
1472    /// After applying averaged gradients, worker weights match coordinator weights.
1473    ///
1474    /// # Errors
1475    ///
1476    /// Returns error on connection failure or protocol violation.
1477    pub fn run_worker(&mut self) -> crate::Result<TrainResult> {
1478        let dist_config = self.config.distributed.clone().ok_or_else(|| {
1479            crate::Error::ConfigError("distributed config required for worker mode".into())
1480        })?;
1481
1482        let gpu_count = 1u32; // single GPU per worker for now
1483        let backend = "cpu"; // will be wgpu/cuda when GPU training wired
1484
1485        let client =
1486            super::worker_client::WorkerClient::connect(dist_config, gpu_count, backend)
1487                .map_err(|e| crate::Error::ConfigError(format!("worker connect failed: {e}")))?;
1488
1489        eprintln!(
1490            "[worker {}] Connected (total workers: {})",
1491            client.worker_id(),
1492            client.total_workers(),
1493        );
1494
1495        let total_start = std::time::Instant::now();
1496        let epoch_metrics_vec: Vec<EpochMetrics> = Vec::new();
1497        let best_val_loss = f32::INFINITY;
1498        let best_epoch = 0usize;
1499
1500        // Clone training data so we can index into it by shard range
1501        let all_samples: Vec<SafetySample> = self.train_data.clone();
1502
1503        loop {
1504            let shard = match client.receive_shard() {
1505                Ok(Some(s)) => s,
1506                Ok(None) => {
1507                    eprintln!("[worker {}] Received shutdown", client.worker_id());
1508                    break;
1509                }
1510                Err(e) => {
1511                    return Err(crate::Error::ConfigError(format!("shard receive failed: {e}")));
1512                }
1513            };
1514
1515            let step = shard.step;
1516            let shard_start = shard.shard_start.min(all_samples.len());
1517            let shard_end = shard.shard_end.min(all_samples.len());
1518            let shard_data = &all_samples[shard_start..shard_end];
1519
1520            // Forward + backward on our shard
1521            let batch_result = self.pipeline.train_batch(shard_data);
1522
1523            // Collect LoRA gradients
1524            let gradients = self.pipeline.collect_lora_gradients();
1525
1526            // Send gradients to coordinator
1527            client
1528                .send_gradients(
1529                    step,
1530                    gradients,
1531                    batch_result.avg_loss,
1532                    batch_result.correct,
1533                    batch_result.total,
1534                )
1535                .map_err(|e| crate::Error::ConfigError(format!("gradient send failed: {e}")))?;
1536
1537            // Receive averaged gradients
1538            let averaged = client
1539                .receive_averaged()
1540                .map_err(|e| crate::Error::ConfigError(format!("averaged receive failed: {e}")))?;
1541
1542            // Apply averaged gradients via optimizer step
1543            self.pipeline.apply_lora_gradients(&averaged.gradients);
1544
1545            eprintln!(
1546                "[worker {}] step {step}: loss={:.4}, global_loss={:.4}",
1547                client.worker_id(),
1548                batch_result.avg_loss,
1549                averaged.global_loss,
1550            );
1551        }
1552
1553        Ok(TrainResult {
1554            epoch_metrics: epoch_metrics_vec,
1555            best_epoch,
1556            best_val_loss,
1557            stopped_early: false,
1558            total_time_ms: total_start.elapsed().as_millis() as u64,
1559        })
1560    }
1561
1562    /// Evaluate the model on a dataset, returning structured per-class metrics.
1563    ///
1564    /// Runs forward-only on every sample, collects predictions, and computes
1565    /// precision/recall/F1/confusion matrix via `ConfusionMatrix` and `MultiClassMetrics`.
1566    ///
1567    /// # Arguments
1568    /// * `data` - Labeled samples to evaluate on
1569    /// * `label_names` - Human-readable class names (length must match num_classes)
1570    pub fn evaluate(
1571        &mut self,
1572        data: &[SafetySample],
1573        label_names: &[String],
1574    ) -> ClassifyEvalReport {
1575        let start = std::time::Instant::now();
1576        let num_classes = self.pipeline.config.num_classes;
1577
1578        let mut y_true: Vec<usize> = Vec::with_capacity(data.len());
1579        let mut y_pred: Vec<usize> = Vec::with_capacity(data.len());
1580        let mut all_probs: Vec<Vec<f32>> = Vec::with_capacity(data.len());
1581        let mut total_loss = 0.0f32;
1582
1583        for sample in data {
1584            let ids = self.pipeline.tokenize(&sample.input);
1585            let (loss, predicted, probs) =
1586                self.pipeline.forward_only_with_probs(&ids, sample.label);
1587            total_loss += loss;
1588            y_true.push(sample.label);
1589            y_pred.push(predicted);
1590            all_probs.push(probs);
1591        }
1592
1593        ClassifyEvalReport::from_predictions_with_probs(
1594            &y_pred,
1595            &y_true,
1596            &all_probs,
1597            total_loss,
1598            num_classes,
1599            label_names,
1600            start.elapsed().as_millis() as u64,
1601        )
1602    }
1603}
1604
// Unit tests live in the sibling file `classify_trainer_tests.rs` and are only
// compiled under `cfg(test)`; `unwrap()` is permitted there via the clippy allow.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[path = "classify_trainer_tests.rs"]
mod tests;