Skip to main content

entrenar/finetune/classify_pipeline/
mod.rs

1//! Classification fine-tuning pipeline
2//!
3//! Wires Transformer + LoRA + ClassificationHead for sequence classification.
4//!
5//! # Architecture
6//!
7//! ```text
8//! token_ids -> Transformer.forward_hidden() -> [seq_len, hidden_size]
9//!           -> ClassificationHead.forward()  -> [num_classes]
10//!           -> cross_entropy_loss(target)    -> scalar loss
11//! ```
12//!
13//! # Contract
14//!
15//! See `aprender/contracts/classification-finetune-v1.yaml`
16
17use super::classification::{
18    load_multi_label_corpus, load_safety_corpus, ClassificationHead, MultiLabelSafetySample,
19    SafetySample, TokenizedSample,
20};
21use crate::autograd::matmul;
22use crate::lora::LoRAConfig;
23use crate::lora::LoRALayer;
24use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
25use crate::tokenizer::HfTokenizer;
26use crate::transformer::Transformer;
27use crate::transformer::TransformerConfig;
28use crate::Tensor;
29use std::path::{Path, PathBuf};
30
31#[cfg(feature = "cuda")]
32use crate::autograd::cuda_backward::pre_warm_lora_backward_kernels as pre_warm_backward_cache_kernels;
33#[cfg(feature = "cuda")]
34use crate::autograd::cuda_forward::{pre_warm_forward_kernels, pre_warm_lora_backward_kernels};
35#[cfg(feature = "cuda")]
36use crate::autograd::cuda_optim::pre_warm_lora_adamw_kernels;
37#[cfg(feature = "cuda")]
38use crate::autograd::cuda_training::{cuda_training_available, CudaTrainer};
39#[cfg(feature = "cuda")]
40use crate::gpu::guard::VramGuard;
41#[cfg(feature = "cuda")]
42use crate::transformer::{
43    CudaBlock, CudaBlockScratch, CudaGradWorkspace, CudaLoraGradWorkspace, CudaTransformerBlock,
44    GpuBlockOptimizerState, GpuLoraOptimizerState,
45};
46#[cfg(feature = "cuda")]
47use std::sync::Arc;
48#[cfg(feature = "cuda")]
49use trueno_gpu::driver::GpuBuffer;
50
/// Settings that drive the classification fine-tuning pipeline.
#[derive(Clone, Debug)]
pub struct ClassifyConfig {
    /// How many classes the classification head predicts.
    pub num_classes: usize,
    /// Rank of the LoRA adapters.
    pub lora_rank: usize,
    /// LoRA alpha value.
    pub lora_alpha: f32,
    /// Optimizer learning rate.
    pub learning_rate: f32,
    /// How many epochs to train for.
    pub epochs: usize,
    /// Upper bound on sequence length.
    pub max_seq_len: usize,
    /// Logging cadence: one log entry every `log_interval` steps.
    pub log_interval: usize,
    /// Size of a mini-batch for `train_batch()`.
    ///
    /// Forward and backward still run sample by sample; gradients are
    /// accumulated and the optimizer steps once per batch.
    pub batch_size: usize,
    /// How many gradient accumulation steps to take.
    ///
    /// The effective batch size becomes `batch_size * accumulation_steps`,
    /// while peak memory stays at that of a single micro-batch forward pass.
    pub accumulation_steps: usize,
    /// Gradient-clipping threshold.
    ///
    /// `Some(max_norm)` clips gradients to that L2 norm ahead of the
    /// optimizer step; `None` turns gradient clipping off.
    pub gradient_clip_norm: Option<f32>,
    /// Optional per-class loss weights, intended for imbalanced datasets.
    ///
    /// With `Some(weights)`, the cross-entropy loss of label `c` is scaled by
    /// `weights[c]`; the weights should sum to `num_classes` so the loss scale
    /// is preserved. With `None`, every class has weight 1.0.
    ///
    /// Weight computation strategies are described in SPEC-TUNE-2026-001 §9.
    pub class_weights: Option<Vec<f32>>,
    /// Whether frozen weights are quantized to NF4 (4-bit) for QLoRA training
    /// (default: false).
    ///
    /// When set, `CudaNf4TransformerBlock` (~8x VRAM compression) is used in
    /// place of `CudaTransformerBlock`, and the GPU backward pass is disabled
    /// (LoRA-only training).
    pub quantize_nf4: bool,
}
97
98impl Default for ClassifyConfig {
99    fn default() -> Self {
100        Self {
101            num_classes: 5,
102            lora_rank: 16,
103            lora_alpha: 16.0,
104            learning_rate: 1e-4,
105            epochs: 3,
106            max_seq_len: 512,
107            log_interval: 100,
108            batch_size: 32,
109            accumulation_steps: 1,
110            gradient_clip_norm: Some(1.0),
111            class_weights: None,
112            quantize_nf4: false,
113        }
114    }
115}
116
117/// Hyperparameter diagnostic from contract validation.
118///
119/// Contract: qlora-hyperparameters-v1.yaml (C-HP-001..008)
120#[derive(Debug, Clone)]
121pub struct HyperparamDiagnostic {
122    pub contract_id: &'static str,
123    pub severity: DiagSeverity,
124    pub message: String,
125    pub recommendation: String,
126}
127
/// How serious a hyperparameter diagnostic is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagSeverity {
    /// Informational: the config is acceptable, just not optimal.
    Info,
    /// Warning: the config violates a research-grounded contract.
    Warn,
    /// Error: the config is mathematically invalid (e.g., lr=0).
    Error,
}
138
139/// Collection of hyperparameter diagnostics.
140#[derive(Debug, Clone, Default)]
141pub struct HyperparamDiagnostics {
142    pub items: Vec<HyperparamDiagnostic>,
143}
144
145impl HyperparamDiagnostics {
146    /// Check if any diagnostic matches a contract ID.
147    pub fn has_warning(&self, contract_id: &str) -> bool {
148        self.items.iter().any(|d| {
149            d.contract_id == contract_id
150                && matches!(d.severity, DiagSeverity::Warn | DiagSeverity::Error)
151        })
152    }
153
154    /// Check if there are any blocking errors.
155    pub fn has_errors(&self) -> bool {
156        self.items.iter().any(|d| matches!(d.severity, DiagSeverity::Error))
157    }
158
159    /// Print all diagnostics to stderr.
160    pub fn print_all(&self) {
161        for d in &self.items {
162            let prefix = match d.severity {
163                DiagSeverity::Info => "[HP-INFO]",
164                DiagSeverity::Warn => "[HP-WARN]",
165                DiagSeverity::Error => "[HP-ERROR]",
166            };
167            eprintln!("{prefix} {}: {} → {}", d.contract_id, d.message, d.recommendation);
168        }
169    }
170}
171
/// Data distribution statistics for data-driven hyperparameter validation.
///
/// Consumed by `ClassifyConfig::validate_with_data`.
///
/// Contract: C-HP-004 (seq_len from data), C-HP-008 (epochs for imbalance)
#[derive(Debug, Clone, PartialEq)]
pub struct DataStats {
    /// 99th percentile of BPE token lengths in training data.
    ///
    /// A value of 0 disables the C-HP-004 sequence-length check.
    pub p99_token_length: usize,
    /// Class imbalance ratio (majority_count / minority_count).
    pub imbalance_ratio: f32,
    /// Number of samples in minority class.
    pub minority_count: usize,
}
183
184impl ClassifyConfig {
185    /// Create a QLoRA config with research-grounded defaults.
186    ///
187    /// Every parameter traces to a published source:
188    ///
189    /// | Parameter | Value | Source |
190    /// |-----------|-------|--------|
191    /// | `learning_rate` | 2e-4 (≤13B) / 1e-4 (>13B) | Dettmers 2023 Table 9 |
192    /// | `lora_alpha` | 2 × rank | Lightning AI experiments |
193    /// | `batch_size` | 4 | RTX VRAM budget |
194    /// | `accumulation_steps` | 4 | effective=16, Dettmers 2023 |
195    /// | `gradient_clip_norm` | 1.0 | Standard practice |
196    /// | `epochs` | 3 | Imbalanced classification |
197    ///
198    /// Contract: provable-contracts/contracts/entrenar/qlora-hyperparameters-v1.yaml
199    pub fn qlora_default(model_params: u64) -> Self {
200        // C-HP-001: lr scales with model size (Dettmers 2023 Table 9)
201        let learning_rate = if model_params <= 13_000_000_000 { 2e-4 } else { 1e-4 };
202        let lora_rank = 16;
203        Self {
204            num_classes: 2,
205            lora_rank,
206            // C-HP-003: alpha = 2 * rank (Lightning AI)
207            lora_alpha: (2 * lora_rank) as f32,
208            learning_rate,
209            // C-HP-008: >= 2 epochs for imbalanced classification
210            epochs: 3,
211            // C-HP-004: set from data, 256 is SSC v3 default
212            max_seq_len: 256,
213            log_interval: 100,
214            // C-HP-002: effective=16 (Dettmers 2023). Micro-batch=16 to saturate
215            // GPU occupancy (batch=4 leaves RTX 4090 at 1.5% MFU).
216            batch_size: 16,
217            accumulation_steps: 1,
218            // C-HP-006: gradient clipping (standard, SSC v2.2 precedent)
219            gradient_clip_norm: Some(1.0),
220            class_weights: None,
221            quantize_nf4: true,
222        }
223    }
224
225    /// Validate hyperparameters against research-grounded contracts.
226    ///
227    /// Returns diagnostics (warnings/errors) for each violated contract.
228    /// Does NOT block — caller decides whether to abort or proceed.
229    ///
230    /// Contract: qlora-hyperparameters-v1.yaml (FALSIFY-HP-001..008)
231    pub fn validate_hyperparameters(&self, model_params: u64) -> HyperparamDiagnostics {
232        let mut diags = HyperparamDiagnostics::default();
233
234        // C-HP-001: Learning rate scaling
235        if self.quantize_nf4 && model_params <= 13_000_000_000 && self.learning_rate < 1.5e-4 {
236            diags.items.push(HyperparamDiagnostic {
237                contract_id: "C-HP-001",
238                severity: DiagSeverity::Warn,
239                message: format!(
240                    "lr={:.0e} too low for {}B model (Dettmers 2023: use 2e-4 for ≤13B)",
241                    self.learning_rate,
242                    model_params / 1_000_000_000
243                ),
244                recommendation: "learning_rate: 0.0002".to_string(),
245            });
246        }
247
248        // C-HP-002: Effective batch size
249        let eff_batch = self.batch_size * self.accumulation_steps;
250        if eff_batch != 16 {
251            diags.items.push(HyperparamDiagnostic {
252                contract_id: "C-HP-002",
253                severity: DiagSeverity::Warn,
254                message: format!(
255                    "effective_batch={eff_batch} ({}×{}), Dettmers 2023 recommends 16 for ≤13B",
256                    self.batch_size, self.accumulation_steps
257                ),
258                recommendation: format!(
259                    "batch_size: {}, accumulation_steps: {}",
260                    self.batch_size,
261                    16 / self.batch_size.max(1)
262                ),
263            });
264        }
265
266        // C-HP-003: Alpha/rank ratio
267        let expected_alpha = 2.0 * self.lora_rank as f32;
268        if (self.lora_alpha - expected_alpha).abs() > 0.5 {
269            diags.items.push(HyperparamDiagnostic {
270                contract_id: "C-HP-003",
271                severity: DiagSeverity::Warn,
272                message: format!(
273                    "lora_alpha={} with rank={} (ratio={:.1}), Lightning AI: alpha=2×rank={} optimal",
274                    self.lora_alpha, self.lora_rank,
275                    self.lora_alpha / self.lora_rank as f32,
276                    expected_alpha
277                ),
278                recommendation: format!("lora_alpha: {expected_alpha}"),
279            });
280        }
281
282        // C-HP-006: Gradient clipping
283        if self.gradient_clip_norm.is_none() {
284            diags.items.push(HyperparamDiagnostic {
285                contract_id: "C-HP-006",
286                severity: DiagSeverity::Warn,
287                message: "No gradient clipping — SSC v2.2 saw grad norms up to 115.1".to_string(),
288                recommendation: "gradient_clip_norm: 1.0".to_string(),
289            });
290        }
291
292        // Blocking errors
293        if self.learning_rate <= 0.0 {
294            diags.items.push(HyperparamDiagnostic {
295                contract_id: "C-HP-001",
296                severity: DiagSeverity::Error,
297                message: "learning_rate must be > 0".to_string(),
298                recommendation: "learning_rate: 0.0002".to_string(),
299            });
300        }
301        if self.batch_size == 0 {
302            diags.items.push(HyperparamDiagnostic {
303                contract_id: "C-HP-002",
304                severity: DiagSeverity::Error,
305                message: "batch_size must be > 0".to_string(),
306                recommendation: "batch_size: 4".to_string(),
307            });
308        }
309
310        diags
311    }
312
313    /// Validate hyperparameters that depend on training data distribution.
314    ///
315    /// Requires measuring the actual data (genchi genbutsu — go and see).
316    ///
317    /// Contract: C-HP-004 (seq_len), C-HP-008 (epochs)
318    pub fn validate_with_data(&self, stats: &DataStats) -> HyperparamDiagnostics {
319        let mut diags = HyperparamDiagnostics::default();
320
321        // C-HP-004: Sequence length from data distribution
322        if self.max_seq_len > 2 * stats.p99_token_length && stats.p99_token_length > 0 {
323            diags.items.push(HyperparamDiagnostic {
324                contract_id: "C-HP-004",
325                severity: DiagSeverity::Warn,
326                message: format!(
327                    "max_seq_len={} but p99(tokens)={} — attention is O(n²), wasting {:.0}× compute",
328                    self.max_seq_len,
329                    stats.p99_token_length,
330                    (self.max_seq_len as f64 / stats.p99_token_length as f64).powi(2)
331                ),
332                recommendation: format!(
333                    "max_seq_len: {} (next_pow2 of p99)",
334                    stats.p99_token_length.next_power_of_two()
335                ),
336            });
337        }
338
339        // C-HP-008: Epochs for imbalanced classification
340        if stats.imbalance_ratio > 5.0 && self.epochs < 2 {
341            let eff_batch = self.batch_size * self.accumulation_steps;
342            let updates_per_epoch = stats.minority_count / eff_batch.max(1);
343            diags.items.push(HyperparamDiagnostic {
344                contract_id: "C-HP-008",
345                severity: DiagSeverity::Warn,
346                message: format!(
347                    "epochs={} with {:.1}:1 imbalance — minority gets only {} gradient updates",
348                    self.epochs,
349                    stats.imbalance_ratio,
350                    updates_per_epoch * self.epochs
351                ),
352                recommendation: format!(
353                    "epochs: 3 (minority gets {} updates)",
354                    updates_per_epoch * 3
355                ),
356            });
357        }
358
359        diags
360    }
361}
362
/// Outcome of one mini-batch processed by [`ClassifyPipeline::train_batch`].
#[derive(Clone, Debug)]
pub struct BatchResult {
    /// Cross-entropy loss averaged over the batch.
    pub avg_loss: f32,
    /// How many samples were classified correctly.
    pub correct: usize,
    /// How many samples the batch contained.
    pub total: usize,
    /// Global gradient norm measured before clipping; 0.0 if clipping disabled.
    pub grad_norm: f32,
}
375
376impl BatchResult {
377    /// Compute classification accuracy as `correct / total`.
378    ///
379    /// Returns 0.0 for an empty batch (total == 0).
380    #[must_use]
381    pub fn accuracy(&self) -> f32 {
382        contract_pre_accuracy!();
383        self.correct as f32 / self.total.max(1) as f32
384    }
385}
386
/// Device-side state needed to run the full-finetune backward pass on GPU.
///
/// Bundles the per-layer activation snapshots, the optimizer state and the
/// scratch buffers used when back-propagating through every transformer block
/// on GPU.
///
/// # Contract (C-GPUTRAIN-001)
///
/// - **Precondition**: `init()` is called once the CUDA blocks have been created
/// - **Postcondition**: every layer input is saved during forward; optimizer states start zeroed
/// - **Invariant**: `layer_inputs.len() == num_layers`; buffers are never reallocated after init
#[cfg(feature = "cuda")]
struct GpuTrainingState {
    /// Input of each block captured during forward
    /// [num_layers][max_seq_len * hidden_size]
    layer_inputs: Vec<GpuBuffer<f32>>,
    /// Weight of the final RMSNorm, uploaded to GPU [hidden_size]
    final_norm_weight: GpuBuffer<f32>,
    /// Output of the block stack, kept on GPU for the final-norm backward
    /// [max_seq_len * hidden_size]
    blocks_output: GpuBuffer<f32>,
    /// First gradient scratch buffer [max_seq_len * hidden_size]
    grad_buf_a: GpuBuffer<f32>,
    /// Second gradient scratch buffer [max_seq_len * hidden_size]
    grad_buf_b: GpuBuffer<f32>,
    /// Gradient of the final RMSNorm weight [hidden_size]
    grad_final_norm_weight: GpuBuffer<f32>,
    /// AdamW optimizer state, one entry per block
    optimizer_states: Vec<GpuBlockOptimizerState>,
    /// Optimizer step counter shared by all blocks
    step: u32,
    /// KAIZEN-045: scratch buffer allocated up front for activation
    /// checkpointing in backward [max_seq_len * hidden_size].
    /// Eliminates per-backward cuMemAlloc/cuMemFree.
    output_scratch: GpuBuffer<f32>,
    /// KAIZEN-045: upload buffer allocated up front for the gradient H2D
    /// transfer in backward [max_seq_len * hidden_size].
    /// Eliminates per-backward cuMemAlloc/cuMemFree.
    grad_upload_buf: GpuBuffer<f32>,
    /// KAIZEN-060: first of the two pre-allocated forward ping-pong buffers
    /// [max_seq_len * hidden_size]. Together they eliminate
    /// 2 × cuMemAlloc/Free per forward pass.
    fwd_scratch_a: GpuBuffer<f32>,
    /// KAIZEN-060: second forward ping-pong buffer [max_seq_len * hidden_size].
    fwd_scratch_b: GpuBuffer<f32>,
    /// KAIZEN-061: CPU staging buffer allocated up front for the backward
    /// mean-pool gradient, sized to max_seq_len * hidden_size. Eliminates a
    /// ~1.25MB heap alloc per sample in both backward_gpu_blocks and
    /// backward_nf4_gpu_blocks (~17.5GB/epoch).
    backward_cpu_staging: Vec<f32>,
}
430
431/// Classification fine-tuning pipeline.
432///
433/// Owns the transformer, LoRA adapters, and classification head.
434/// Provides `train_step()` for single-step training and `train()` for full loop.
435///
436/// When compiled with `feature = "cuda"` and a GPU is available, the forward pass
437/// runs on CUDA via `CudaTransformerBlock`s for ~10-50x speedup (F-CUDA-007).
438/// When `gpu_training` is set, the backward pass also runs on GPU with full-finetune
439/// of all transformer weights (F-CUDA-014).
440pub struct ClassifyPipeline {
441    /// Base transformer model (weights frozen)
442    pub model: Transformer,
443    /// Classification head (trainable)
444    pub classifier: ClassificationHead,
445    /// LoRA adapters applied to attention projections
446    pub lora_layers: Vec<LoRALayer>,
447    /// Pipeline configuration
448    pub config: ClassifyConfig,
449    /// AdamW optimizer for trainable parameters
450    optimizer: AdamW,
451    /// Optional BPE tokenizer (None = byte-level fallback)
452    tokenizer: Option<HfTokenizer>,
453    /// Path to base model directory (set by `from_pretrained`, None for random init)
454    model_dir: Option<PathBuf>,
455    /// CUDA trainer for GPU memory management (F-CUDA-002)
456    #[cfg(feature = "cuda")]
457    cuda_trainer: Option<CudaTrainer>,
458    /// CUDA-accelerated transformer blocks — one per layer (F-CUDA-006)
459    #[cfg(feature = "cuda")]
460    cuda_blocks: Option<Vec<CudaBlock>>,
461    /// Shared scratch buffers for NF4 forward pass (C-SCRATCH-001).
462    /// One allocation shared across all 36 layers — saves 7.5 GB for Qwen3-4B.
463    #[cfg(feature = "cuda")]
464    shared_scratch: Option<CudaBlockScratch>,
465    /// Count of GPU forward passes that produced NaN/Inf and fell back to CPU.
466    /// Used to decide when to permanently disable CUDA (threshold: 100).
467    #[cfg(feature = "cuda")]
468    cuda_nan_count: usize,
469    /// GPU training state for full-finetune backward pass (F-CUDA-014).
470    /// When `Some`, backward pass runs on GPU updating all transformer weights.
471    #[cfg(feature = "cuda")]
472    gpu_training: Option<GpuTrainingState>,
473    /// Shared gradient workspace — one set of weight-gradient buffers for all layers.
474    /// Contract: C-GRADWS-001. Saves (L-1) * 372 MB for Qwen3-4B (13.0 GB).
475    #[cfg(feature = "cuda")]
476    cuda_grad_workspace: Option<CudaGradWorkspace>,
477    /// Shared LoRA gradient workspace for NF4 QLoRA backward (ENT-153).
478    /// One set of LoRA gradient buffers shared across all layers.
479    #[cfg(feature = "cuda")]
480    cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
481    /// Per-layer LoRA optimizer states for NF4 QLoRA training (ENT-153).
482    #[cfg(feature = "cuda")]
483    cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
484    /// KAIZEN-014: Per-layer gradient accumulators for batch-correct LoRA optimizer.
485    #[cfg(feature = "cuda")]
486    cuda_lora_grad_accum: Option<Vec<CudaLoraGradWorkspace>>,
487    /// NF4 LoRA optimizer step counter (separate from fp32 GpuTrainingState.step).
488    #[cfg(feature = "cuda")]
489    nf4_lora_step: u32,
490    /// wgpu-accelerated forward pass (GPU feature, non-CUDA)
491    #[cfg(feature = "gpu")]
492    wgpu_forward_pass: Option<crate::transformer::WgpuForwardPass>,
493    /// VRAM reservation guard (GPU-SHARE-002). Releases ledger entry on Drop.
494    /// Held for RAII — released when pipeline is dropped.
495    #[cfg(feature = "cuda")]
496    #[allow(dead_code)]
497    vram_guard: Option<VramGuard>,
498}
499
500impl ClassifyPipeline {
501    /// Create a new classification pipeline with random weights and byte-level tokenization.
502    ///
503    /// # Arguments
504    /// * `model_config` - Transformer configuration (e.g., `TransformerConfig::qwen2_0_5b()`)
505    /// * `classify_config` - Classification pipeline configuration
506    pub fn new(model_config: &TransformerConfig, classify_config: ClassifyConfig) -> Self {
507        let model = Transformer::new(model_config);
508        let classifier =
509            ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
510        let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
511
512        // Ensure LoRA A/B matrices have requires_grad=true
513        for lora in &mut lora_layers {
514            for param in lora.trainable_params() {
515                param.set_requires_grad(true);
516            }
517        }
518
519        let optimizer = AdamW::default_params(classify_config.learning_rate);
520
521        // ── CUDA initialization (F-CUDA-001..006, GPU-SHARE-002) ─────────
522        #[cfg(feature = "cuda")]
523        let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
524            Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
525
526        // ── GPU training state (F-CUDA-014) ────────────────────────────
527        #[cfg(feature = "cuda")]
528        let gpu_training = Self::try_init_gpu_training(
529            &model,
530            model_config,
531            classify_config.max_seq_len,
532            cuda_trainer.as_ref(),
533            cuda_blocks.as_ref(),
534        );
535
536        // ── Shared gradient workspace (C-GRADWS-001) ────────────────────
537        #[cfg(feature = "cuda")]
538        let cuda_grad_workspace = if classify_config.quantize_nf4 {
539            None
540        } else {
541            cuda_trainer.as_ref().and_then(|t| {
542                CudaGradWorkspace::new(t.context(), model_config)
543                    .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
544                    .ok()
545            })
546        };
547
548        // ── NF4 LoRA training state (ENT-153) ──────────────────────────
549        #[cfg(feature = "cuda")]
550        let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
551            if classify_config.quantize_nf4 {
552                Self::try_init_nf4_lora_training(
553                    cuda_trainer.as_ref(),
554                    cuda_blocks.as_ref(),
555                    model_config,
556                    &classify_config,
557                )
558            } else {
559                (None, None, None)
560            };
561
562        // ── wgpu initialization (when CUDA unavailable) ──────────────────
563        #[cfg(feature = "gpu")]
564        let wgpu_forward_pass = {
565            #[cfg(feature = "cuda")]
566            let has_cuda = cuda_trainer.is_some();
567            #[cfg(not(feature = "cuda"))]
568            let has_cuda = false;
569
570            if has_cuda {
571                None // CUDA takes priority
572            } else {
573                // KAIZEN-015: Pre-upload FFN weights to GPU (zero H2D per forward pass)
574                match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
575                    Ok(pass) => {
576                        eprintln!("[wgpu] GPU forward pass initialized (resident weights)");
577                        Some(pass)
578                    }
579                    Err(e) => {
580                        eprintln!("[wgpu] GPU resident init failed, trying default: {e}");
581                        match crate::transformer::WgpuForwardPass::new_default(model_config) {
582                            Ok(pass) => {
583                                eprintln!("[wgpu] GPU forward pass initialized (upload per call)");
584                                Some(pass)
585                            }
586                            Err(e2) => {
587                                eprintln!("[wgpu] GPU initialization failed, using CPU: {e2}");
588                                None
589                            }
590                        }
591                    }
592                }
593            }
594        };
595
596        Self {
597            model,
598            classifier,
599            lora_layers,
600            config: classify_config,
601            optimizer,
602            tokenizer: None,
603            model_dir: None,
604            #[cfg(feature = "cuda")]
605            cuda_trainer,
606            #[cfg(feature = "cuda")]
607            cuda_blocks,
608            #[cfg(feature = "cuda")]
609            shared_scratch,
610            #[cfg(feature = "cuda")]
611            cuda_nan_count: 0,
612            #[cfg(feature = "cuda")]
613            gpu_training,
614            #[cfg(feature = "cuda")]
615            cuda_grad_workspace,
616            #[cfg(feature = "cuda")]
617            cuda_lora_grad_workspace,
618            #[cfg(feature = "cuda")]
619            cuda_lora_optimizer_states,
620            #[cfg(feature = "cuda")]
621            cuda_lora_grad_accum,
622            #[cfg(feature = "cuda")]
623            nf4_lora_step: 0,
624            #[cfg(feature = "gpu")]
625            wgpu_forward_pass,
626            #[cfg(feature = "cuda")]
627            vram_guard,
628        }
629    }
630
631    /// Create a classification pipeline from pretrained weights.
632    ///
633    /// Loads a transformer from SafeTensors weights and optionally a BPE tokenizer
634    /// from `tokenizer.json` in the model directory.
635    ///
636    /// # Arguments
637    /// * `model_dir` - Directory containing SafeTensors weights (and optionally `tokenizer.json`)
638    /// * `model_config` - Transformer configuration matching the pretrained weights
639    /// * `classify_config` - Classification pipeline configuration
640    ///
641    /// # Errors
642    /// Returns error if the model directory doesn't exist or weights fail to load.
643    pub fn from_pretrained(
644        model_dir: impl AsRef<Path>,
645        model_config: &TransformerConfig,
646        classify_config: ClassifyConfig,
647    ) -> crate::Result<Self> {
648        let model_dir = model_dir.as_ref();
649
650        let model = Transformer::from_safetensors(model_dir, model_config)?;
651        let classifier =
652            ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
653        let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
654
655        for lora in &mut lora_layers {
656            for param in lora.trainable_params() {
657                param.set_requires_grad(true);
658            }
659        }
660
661        // CONTRACT: Training requires a BPE tokenizer — byte-fallback is not acceptable.
662        let tokenizer_path = model_dir.join("tokenizer.json");
663        let tokenizer = if tokenizer_path.exists() {
664            Some(
665                HfTokenizer::from_file(&tokenizer_path)
666                    .map_err(|e| crate::Error::Io(format!("Failed to load tokenizer: {e}")))?,
667            )
668        } else {
669            return Err(crate::Error::ConfigError(format!(
670                "No tokenizer.json found in '{}'. Training requires a BPE tokenizer.",
671                model_dir.display(),
672            )));
673        };
674
675        let optimizer = AdamW::default_params(classify_config.learning_rate);
676
677        // ── CUDA initialization (F-CUDA-001..006, GPU-SHARE-002) ─────────
678        #[cfg(feature = "cuda")]
679        let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
680            Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
681
682        // ── GPU training state (F-CUDA-014) ────────────────────────────
683        #[cfg(feature = "cuda")]
684        let gpu_training = Self::try_init_gpu_training(
685            &model,
686            model_config,
687            classify_config.max_seq_len,
688            cuda_trainer.as_ref(),
689            cuda_blocks.as_ref(),
690        );
691
692        // ── Shared gradient workspace (C-GRADWS-001) ────────────────────
693        #[cfg(feature = "cuda")]
694        let cuda_grad_workspace = if classify_config.quantize_nf4 {
695            None // No grad workspace needed for frozen NF4 weights
696        } else {
697            cuda_trainer.as_ref().and_then(|t| {
698                CudaGradWorkspace::new(t.context(), model_config)
699                    .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
700                    .ok()
701            })
702        };
703
704        // ── NF4 LoRA training state (ENT-153) ──────────────────────────
705        #[cfg(feature = "cuda")]
706        let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
707            if classify_config.quantize_nf4 {
708                Self::try_init_nf4_lora_training(
709                    cuda_trainer.as_ref(),
710                    cuda_blocks.as_ref(),
711                    model_config,
712                    &classify_config,
713                )
714            } else {
715                (None, None, None)
716            };
717
718        // ── wgpu initialization (when CUDA unavailable) ──────────────────
719        #[cfg(feature = "gpu")]
720        let wgpu_forward_pass = {
721            #[cfg(feature = "cuda")]
722            let has_cuda = cuda_trainer.is_some();
723            #[cfg(not(feature = "cuda"))]
724            let has_cuda = false;
725
726            if has_cuda {
727                None
728            } else {
729                // KAIZEN-015: Pre-upload FFN weights to GPU
730                match crate::transformer::WgpuForwardPass::with_resident_weights(&model) {
731                    Ok(pass) => {
732                        eprintln!(
733                            "[wgpu] Batched forward pass initialized ({} layers, resident weights)",
734                            model_config.num_hidden_layers
735                        );
736                        Some(pass)
737                    }
738                    Err(e) => {
739                        eprintln!("[wgpu] Resident init failed, trying default: {e}");
740                        match crate::transformer::WgpuForwardPass::new_default(model_config) {
741                            Ok(pass) => {
742                                eprintln!("[wgpu] Batched forward pass initialized ({} layers, upload per call)", model_config.num_hidden_layers);
743                                Some(pass)
744                            }
745                            Err(e2) => {
746                                eprintln!("[wgpu] GPU init failed, using CPU: {e2}");
747                                None
748                            }
749                        }
750                    }
751                }
752            }
753        };
754
755        Ok(Self {
756            model,
757            classifier,
758            lora_layers,
759            config: classify_config,
760            optimizer,
761            tokenizer,
762            model_dir: Some(model_dir.to_path_buf()),
763            #[cfg(feature = "cuda")]
764            cuda_trainer,
765            #[cfg(feature = "cuda")]
766            cuda_blocks,
767            #[cfg(feature = "cuda")]
768            shared_scratch,
769            #[cfg(feature = "cuda")]
770            cuda_nan_count: 0,
771            #[cfg(feature = "cuda")]
772            gpu_training,
773            #[cfg(feature = "cuda")]
774            cuda_grad_workspace,
775            #[cfg(feature = "cuda")]
776            cuda_lora_grad_workspace,
777            #[cfg(feature = "cuda")]
778            cuda_lora_optimizer_states,
779            #[cfg(feature = "cuda")]
780            cuda_lora_grad_accum,
781            #[cfg(feature = "cuda")]
782            nf4_lora_step: 0,
783            #[cfg(feature = "gpu")]
784            wgpu_forward_pass,
785            #[cfg(feature = "cuda")]
786            vram_guard,
787        })
788    }
789
790    /// Create pipeline from APR model file (.apr format).
791    ///
792    /// Loads transformer weights from the APR binary, dequantizing from any
793    /// stored dtype (F16, Q4K, etc.) to F32. Loads sibling tokenizer if present
794    /// (e.g., `model.tokenizer.json` next to `model.apr`).
795    ///
796    /// # Errors
797    /// Returns error if APR file cannot be loaded or weights are invalid.
798    pub fn from_apr(
799        apr_path: &Path,
800        model_config: &TransformerConfig,
801        classify_config: ClassifyConfig,
802    ) -> crate::Result<Self> {
803        let model = Transformer::from_apr(apr_path, model_config)?;
804        let classifier =
805            ClassificationHead::new(model_config.hidden_size, classify_config.num_classes);
806        let mut lora_layers = Self::build_lora_layers(&model, model_config, &classify_config);
807
808        for lora in &mut lora_layers {
809            for param in lora.trainable_params() {
810                param.set_requires_grad(true);
811            }
812        }
813
814        // CONTRACT: Training requires a BPE tokenizer — byte-fallback is not acceptable.
815        let tokenizer = {
816            let sibling = apr_path.file_stem().and_then(|stem| {
817                apr_path
818                    .parent()
819                    .map(|p| p.join(format!("{}.tokenizer.json", stem.to_str().unwrap_or(""))))
820            });
821
822            match sibling {
823                Some(ref path) if path.exists() => {
824                    let tok = HfTokenizer::from_file(path).map_err(|e| {
825                        crate::Error::ConfigError(format!(
826                            "Failed to load tokenizer from '{}': {e}. \
827                             Training requires a BPE tokenizer.",
828                            path.display(),
829                        ))
830                    })?;
831                    Some(tok)
832                }
833                _ => {
834                    return Err(crate::Error::ConfigError(format!(
835                        "No sibling tokenizer found for '{}'. Expected \
836                         '{}.tokenizer.json' next to the .apr file. Training \
837                         requires a BPE tokenizer.",
838                        apr_path.display(),
839                        apr_path.file_stem().unwrap_or_default().to_str().unwrap_or(""),
840                    )));
841                }
842            }
843        };
844
845        let optimizer = AdamW::default_params(classify_config.learning_rate);
846
847        #[cfg(feature = "cuda")]
848        let (cuda_trainer, cuda_blocks, shared_scratch, vram_guard) =
849            Self::try_init_cuda(&model, model_config, &classify_config, &lora_layers);
850
851        #[cfg(feature = "cuda")]
852        let gpu_training = Self::try_init_gpu_training(
853            &model,
854            model_config,
855            classify_config.max_seq_len,
856            cuda_trainer.as_ref(),
857            cuda_blocks.as_ref(),
858        );
859
860        #[cfg(feature = "cuda")]
861        let cuda_grad_workspace = if classify_config.quantize_nf4 {
862            None
863        } else {
864            cuda_trainer.as_ref().and_then(|t| {
865                CudaGradWorkspace::new(t.context(), model_config)
866                    .map_err(|e| eprintln!("[CUDA] Failed to allocate grad workspace: {e}"))
867                    .ok()
868            })
869        };
870
871        #[cfg(feature = "cuda")]
872        let (cuda_lora_grad_workspace, cuda_lora_optimizer_states, cuda_lora_grad_accum) =
873            if classify_config.quantize_nf4 {
874                Self::try_init_nf4_lora_training(
875                    cuda_trainer.as_ref(),
876                    cuda_blocks.as_ref(),
877                    model_config,
878                    &classify_config,
879                )
880            } else {
881                (None, None, None)
882            };
883
884        // ── wgpu initialization ──────────────────────────────────────────
885        #[cfg(feature = "gpu")]
886        let wgpu_forward_pass = {
887            #[cfg(feature = "cuda")]
888            let has_cuda = cuda_trainer.is_some();
889            #[cfg(not(feature = "cuda"))]
890            let has_cuda = false;
891
892            if has_cuda {
893                None
894            } else {
895                // KAIZEN-015: Pre-upload FFN weights to GPU
896                crate::transformer::WgpuForwardPass::with_resident_weights(&model)
897                    .or_else(|e| {
898                        eprintln!("[wgpu] Resident init failed: {e}, trying default");
899                        crate::transformer::WgpuForwardPass::new_default(model_config)
900                    })
901                    .map_err(|e| eprintln!("[wgpu] GPU init failed: {e}"))
902                    .ok()
903            }
904        };
905
906        Ok(Self {
907            model,
908            classifier,
909            lora_layers,
910            config: classify_config,
911            optimizer,
912            tokenizer,
913            model_dir: Some(apr_path.to_path_buf()),
914            #[cfg(feature = "cuda")]
915            cuda_trainer,
916            #[cfg(feature = "cuda")]
917            cuda_blocks,
918            #[cfg(feature = "cuda")]
919            shared_scratch,
920            #[cfg(feature = "cuda")]
921            cuda_nan_count: 0,
922            #[cfg(feature = "cuda")]
923            gpu_training,
924            #[cfg(feature = "cuda")]
925            cuda_grad_workspace,
926            #[cfg(feature = "cuda")]
927            cuda_lora_grad_workspace,
928            #[cfg(feature = "cuda")]
929            cuda_lora_optimizer_states,
930            #[cfg(feature = "cuda")]
931            cuda_lora_grad_accum,
932            #[cfg(feature = "cuda")]
933            nf4_lora_step: 0,
934            #[cfg(feature = "gpu")]
935            wgpu_forward_pass,
936            #[cfg(feature = "cuda")]
937            vram_guard,
938        })
939    }
940
941    /// Tokenize input text using BPE tokenizer.
942    ///
943    /// Truncates to `config.max_seq_len` and ensures at least one token.
944    ///
945    /// # Panics
946    /// Panics if no BPE tokenizer is loaded. Training pipelines MUST have a
947    /// tokenizer — byte-level fallback is a silent corruption path.
948    pub(crate) fn tokenize(&self, text: &str) -> Vec<u32> {
949        let mut ids = match self.tokenizer.as_ref() {
950            Some(tok) => tok.encode(text),
951            None => {
952                // Byte-level fallback when no BPE tokenizer is loaded
953                text.bytes().map(u32::from).collect()
954            }
955        };
956        ids.truncate(self.config.max_seq_len);
957        if ids.is_empty() {
958            ids.push(0);
959        }
960        ids
961    }
962
963    /// Pre-tokenize a batch of samples for efficient training (KAIZEN-028).
964    ///
965    /// Tokenizes each sample once and stores the token IDs alongside the label.
966    /// This eliminates redundant BPE encoding across epochs and batches.
967    ///
968    /// # Contract (C-PRETOK-001)
969    ///
970    /// - **Precondition**: All samples have non-empty `input`
971    /// - **Postcondition**: Each `TokenizedSample` has `token_ids.len() in 1..=max_seq_len`
972    /// - **Invariant**: Tokenization is deterministic — same input always produces same IDs
973    pub fn pre_tokenize(&self, samples: &[SafetySample]) -> Vec<TokenizedSample> {
974        let has_tokenizer = self.tokenizer.is_some();
975        samples
976            .iter()
977            .map(|s| {
978                let token_ids = if has_tokenizer {
979                    self.tokenize(&s.input)
980                } else {
981                    // Byte-level fallback for tests without BPE tokenizer
982                    let mut ids = s.input_ids();
983                    ids.truncate(self.config.max_seq_len);
984                    if ids.is_empty() {
985                        ids.push(0);
986                    }
987                    ids
988                };
989                TokenizedSample { token_ids, label: s.label }
990            })
991            .collect()
992    }
993
994    /// Train on a batch of pre-tokenized samples (KAIZEN-028).
995    ///
996    /// Identical to [`train_batch`] but skips tokenization — token IDs are
997    /// pre-computed at dataset construction time.
998    pub fn train_batch_tokenized(&mut self, samples: &[TokenizedSample]) -> BatchResult {
999        if samples.is_empty() {
1000            return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
1001        }
1002
1003        let batch_size = samples.len();
1004
1005        // ── 1. Zero gradients ──────────────────────────────────────────
1006        self.zero_all_gradients();
1007
1008        // ── 2. Accumulate gradients over all samples ───────────────────
1009        #[cfg(feature = "gpu")]
1010        let (total_loss, correct) = self
1011            .try_train_batch_wgpu_tokenized(samples)
1012            .unwrap_or_else(|| self.train_batch_per_sample_tokenized(samples));
1013
1014        #[cfg(not(feature = "gpu"))]
1015        let (total_loss, correct) = self.train_batch_per_sample_tokenized(samples);
1016
1017        // ── 3. Normalize gradients by batch size ───────────────────────
1018        self.scale_all_gradients(1.0 / batch_size as f32);
1019
1020        // ── 4. Gradient clipping (captures pre-clip norm) ────────────
1021        let grad_norm = if let Some(max_norm) = self.config.gradient_clip_norm {
1022            let mut params = self.trainable_parameters_mut();
1023            clip_grad_norm_refs(&mut params, max_norm)
1024        } else {
1025            self.compute_grad_norm()
1026        };
1027
1028        // ── 5. Optimizer step (once for the whole batch) ───────────────
1029        #[cfg(feature = "cuda")]
1030        {
1031            if self.gpu_training.is_some() && !self.config.quantize_nf4 {
1032                let lr = self.optimizer.lr();
1033                self.gpu_optimizer_step(lr);
1034            }
1035        }
1036
1037        #[cfg(feature = "cuda")]
1038        {
1039            if self.gpu_training.is_some() && self.config.quantize_nf4 {
1040                self.nf4_lora_batch_optimizer_step(batch_size);
1041            }
1042        }
1043
1044        let mut params: Vec<&mut Tensor> = Vec::new();
1045        if !self.config.quantize_nf4 {
1046            for lora in &mut self.lora_layers {
1047                params.extend(lora.trainable_params());
1048            }
1049        }
1050        params.extend(self.classifier.parameters_mut());
1051        self.optimizer.step_refs(&mut params);
1052
1053        BatchResult {
1054            avg_loss: total_loss / batch_size as f32,
1055            correct,
1056            total: batch_size,
1057            grad_norm,
1058        }
1059    }
1060
1061    /// Per-sample forward + backward with pre-tokenized IDs (KAIZEN-028).
1062    fn train_batch_per_sample_tokenized(&mut self, samples: &[TokenizedSample]) -> (f32, usize) {
1063        let mut total_loss = 0.0f32;
1064        let mut correct = 0usize;
1065        for sample in samples {
1066            let (loss, predicted) = self.forward_backward_single(&sample.token_ids, sample.label);
1067            total_loss += loss;
1068            if predicted == sample.label {
1069                correct += 1;
1070            }
1071        }
1072        (total_loss, correct)
1073    }
1074
/// Batched wgpu forward pass followed by per-sample classification backward,
/// using pre-tokenized IDs (KAIZEN-028).
///
/// Returns `Some((summed loss, correct predictions))` on success. Returns
/// `None` — telling the caller to use the per-sample path — when:
/// - no wgpu forward pass is available,
/// - the batched forward call fails (the error is logged), or
/// - the backend returns a number of hidden states different from the number
///   of samples (logged).
///
/// Every `None` is returned before any backward pass has run, so the
/// caller's fallback starts from untouched gradients.
#[cfg(feature = "gpu")]
fn try_train_batch_wgpu_tokenized(
    &mut self,
    samples: &[TokenizedSample],
) -> Option<(f32, usize)> {
    // Bind once: no second lookup, no `expect`.
    let pass = self.wgpu_forward_pass.as_ref()?;

    // The batched forward call is handed a reference to a `Vec<Vec<u32>>`,
    // so owned copies of the token IDs are collected here.
    let batch_token_ids: Vec<Vec<u32>> = samples.iter().map(|s| s.token_ids.clone()).collect();

    let lora_ref =
        if self.lora_layers.is_empty() { None } else { Some(self.lora_layers.as_slice()) };

    let hiddens = pass
        .forward_hidden_batch(&self.model, &batch_token_ids, lora_ref)
        .map_err(|e| {
            eprintln!("[wgpu] Batched forward failed, falling back to per-sample: {e}");
        })
        .ok()?;

    // Too many entries used to panic on indexing; too few silently dropped
    // samples while the caller still divides by the full batch size.
    if hiddens.len() != samples.len() {
        eprintln!(
            "[wgpu] Batched forward returned {} hidden states for {} samples, \
             falling back to per-sample",
            hiddens.len(),
            samples.len(),
        );
        return None;
    }

    let mut total_loss = 0.0f32;
    let mut correct = 0usize;
    for (hidden, sample) in hiddens.iter().zip(samples) {
        let (loss, predicted) =
            self.classify_backward_from_hidden(hidden, sample.token_ids.len(), sample.label);
        total_loss += loss;
        if predicted == sample.label {
            correct += 1;
        }
    }
    Some((total_loss, correct))
}
1113
1114    /// Accumulate gradients with pre-tokenized samples (KAIZEN-028).
1115    ///
1116    /// Identical to [`accumulate_gradients`] but uses pre-tokenized IDs.
1117    pub fn accumulate_gradients_tokenized(
1118        &mut self,
1119        micro_batch: &[TokenizedSample],
1120    ) -> BatchResult {
1121        if micro_batch.is_empty() {
1122            return BatchResult { avg_loss: 0.0, correct: 0, total: 0, grad_norm: 0.0 };
1123        }
1124
1125        let mut total_loss = 0.0f32;
1126        let mut correct = 0usize;
1127
1128        for sample in micro_batch {
1129            let (loss, predicted) = self.forward_backward_single(&sample.token_ids, sample.label);
1130            total_loss += loss;
1131            if predicted == sample.label {
1132                correct += 1;
1133            }
1134        }
1135
1136        BatchResult {
1137            avg_loss: total_loss / micro_batch.len() as f32,
1138            correct,
1139            total: micro_batch.len(),
1140            grad_norm: 0.0,
1141        }
1142    }
1143
1144    /// Forward-only inference with pre-tokenized sample (KAIZEN-028).
1145    ///
1146    /// Used for validation where we need loss + prediction without backward pass.
1147    pub fn forward_only_tokenized(&mut self, token_ids: &[u32], label: usize) -> (f32, usize) {
1148        self.forward_only(token_ids, label)
1149    }
1150}
1151
1152// GPU initialization, VRAM guards, and CUDA pipeline methods
1153include!("gpu.rs");
1154
1155// Training, inference, gradient, and optimizer methods
1156include!("training.rs");
1157
1158#[cfg(test)]
1159#[allow(clippy::unwrap_used)]
1160mod tests;