Skip to main content

entrenar/train/transformer_trainer/
trainer.rs

1//! Transformer trainer implementation
2
3use crate::autograd::{checkpoint, GradScaler};
4use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
5use crate::lora::LoRALayer;
6use crate::optim::{AdamW, Optimizer};
7use crate::train::{CausalLMLoss, LossFn, MetricsTracker};
8use crate::transformer::Transformer;
9use crate::Tensor;
10use std::path::Path;
11
12use super::batch::LMBatch;
13use super::config::TransformerTrainConfig;
14
/// Transformer training state
///
/// Owns the model together with the loss function, optimizer, gradient
/// scaler, configuration and the bookkeeping needed for gradient
/// accumulation and optional LoRA fine-tuning.
pub struct TransformerTrainer {
    /// Model being trained
    model: Transformer,
    /// Loss function (causal language-modelling loss)
    loss_fn: CausalLMLoss,
    /// Optimizer
    optimizer: AdamW,
    /// Gradient scaler for mixed precision (built from `config.precision_config`)
    grad_scaler: GradScaler,
    /// Configuration
    config: TransformerTrainConfig,
    /// Metrics tracker; one loss entry is pushed per optimizer step
    pub metrics: MetricsTracker,
    /// Number of optimizer steps taken so far
    step: usize,
    /// Sum of the per-batch losses in the current accumulation window
    /// (reset to 0.0 after every optimizer step)
    accumulated_loss: f32,
    /// Number of batches accumulated since the last optimizer step
    accumulated_batches: usize,
    /// LoRA layers (ENT-LoRA-001), in the order `build()` emits them: for each
    /// transformer block, the enabled targets among
    /// q, k, v, o, gate, up, down (so `[Q_0, V_0, Q_1, V_1, ...]` with the
    /// default `q_proj`/`v_proj` targets).
    /// None = full fine-tuning, Some = LoRA fine-tuning
    lora_layers: Option<Vec<LoRALayer>>,
}
39
40impl TransformerTrainer {
41    /// Create a new transformer trainer
42    pub fn new(config: TransformerTrainConfig) -> Self {
43        // GATE-TRAIN-006 / INV-TRAIN-006: honor config.seed before weight init
44        // AND hold the init-seed lock for the full Transformer::new call so
45        // concurrent callers (parallel tests, concurrent harnesses) cannot
46        // clobber INIT_SEED between set and read. Previously only the YAML
47        // loader set this; direct TransformerTrainer::new callers silently
48        // inherited the global default (42), breaking seed reproducibility
49        // for any non-default seed.
50        let seed_guard = crate::transformer::init::lock_init_seed(config.seed);
51        let model = Transformer::new(&config.model_config);
52        drop(seed_guard);
53        Self::build(model, config)
54    }
55
56    /// Create trainer from existing model
57    pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> Self {
58        Self::build(model, config)
59    }
60
61    /// Internal builder: initializes LoRA layers when config has LoRA enabled
62    fn build(model: Transformer, config: TransformerTrainConfig) -> Self {
63        let loss_fn = CausalLMLoss::new(config.model_config.vocab_size);
64        let optimizer = AdamW::default_params(config.lr);
65        let grad_scaler = GradScaler::from_config(&config.precision_config);
66
67        // ENT-LoRA-001: Create LoRA layers when config has LoRA rank
68        let lora_layers = if let Some(rank) = config.lora_rank {
69            let alpha = config.lora_alpha.unwrap_or(rank as f32 * 2.0);
70            let default_targets = vec!["q_proj".to_string(), "v_proj".to_string()];
71            // ENT-LoRA-005: Expand shorthand targets ("all_linear", "attention", etc.)
72            let raw_targets = config.lora_target_modules.as_deref().unwrap_or(&default_targets);
73            let expanded = crate::lora::LoRAConfig::expand_shorthand(raw_targets);
74            let target_modules = expanded.as_slice();
75
76            let mut layers = Vec::new();
77            let hidden_size = config.model_config.hidden_size;
78            let num_kv_heads = config.model_config.num_kv_heads;
79            let head_dim = config.model_config.head_dim();
80            let q_dim = config.model_config.q_dim();
81            let kv_hidden_size = num_kv_heads * head_dim;
82
83            let intermediate = config.model_config.intermediate_size;
84
85            for block in &model.layers {
86                // Attention projections (ENT-LoRA-005: flexible targets)
87                if target_modules.iter().any(|m| m == "q_proj") {
88                    layers.push(LoRALayer::new(
89                        block.self_attn.w_q.clone(),
90                        q_dim,
91                        hidden_size,
92                        rank,
93                        alpha,
94                    ));
95                }
96                if target_modules.iter().any(|m| m == "k_proj") {
97                    layers.push(LoRALayer::new(
98                        block.self_attn.w_k.clone(),
99                        kv_hidden_size,
100                        hidden_size,
101                        rank,
102                        alpha,
103                    ));
104                }
105                if target_modules.iter().any(|m| m == "v_proj") {
106                    layers.push(LoRALayer::new(
107                        block.self_attn.w_v.clone(),
108                        kv_hidden_size,
109                        hidden_size,
110                        rank,
111                        alpha,
112                    ));
113                }
114                if target_modules.iter().any(|m| m == "o_proj") {
115                    layers.push(LoRALayer::new(
116                        block.self_attn.w_o.clone(),
117                        hidden_size,
118                        q_dim,
119                        rank,
120                        alpha,
121                    ));
122                }
123                // MLP projections (ENT-LoRA-005)
124                if target_modules.iter().any(|m| m == "gate_proj") {
125                    layers.push(LoRALayer::new(
126                        block.ffn.w_gate.clone(),
127                        intermediate,
128                        hidden_size,
129                        rank,
130                        alpha,
131                    ));
132                }
133                if target_modules.iter().any(|m| m == "up_proj") {
134                    layers.push(LoRALayer::new(
135                        block.ffn.w_up.clone(),
136                        intermediate,
137                        hidden_size,
138                        rank,
139                        alpha,
140                    ));
141                }
142                if target_modules.iter().any(|m| m == "down_proj") {
143                    layers.push(LoRALayer::new(
144                        block.ffn.w_down.clone(),
145                        hidden_size,
146                        intermediate,
147                        rank,
148                        alpha,
149                    ));
150                }
151            }
152
153            let lora_param_count: usize =
154                layers.iter().map(|l| l.rank() * (l.d_in() + l.d_out())).sum();
155            let total_params: usize = model.parameters().iter().map(|p| p.len()).sum();
156            println!(
157                "  LoRA enabled: rank={rank}, alpha={alpha}, \
158                 {lora_param_count} trainable params ({:.2}% of {total_params})",
159                100.0 * lora_param_count as f64 / total_params as f64
160            );
161
162            Some(layers)
163        } else {
164            None
165        };
166
167        Self {
168            model,
169            loss_fn,
170            optimizer,
171            grad_scaler,
172            config,
173            metrics: MetricsTracker::new(),
174            step: 0,
175            accumulated_loss: 0.0,
176            accumulated_batches: 0,
177            lora_layers,
178        }
179    }
180
181    /// Forward pass on a single batch item
182    ///
183    /// Returns (loss_value, loss_tensor, logits)
184    /// When LoRA is active, routes through `forward_with_lora` so only
185    /// LoRA adapter gradients are accumulated.
186    pub fn forward_single(&self, input_ids: &[u32], target_ids: &[u32]) -> (f32, Tensor, Tensor) {
187        // Forward through transformer (LoRA or full)
188        let logits = if let Some(ref lora) = self.lora_layers {
189            // ENT-LoRA-001: Use LoRA forward path
190            self.model.forward_with_lora(input_ids, lora)
191        } else if self.config.checkpoint_config.enabled {
192            checkpoint(|_| self.model.forward(input_ids), &Tensor::zeros(1, false))
193        } else {
194            self.model.forward(input_ids)
195        };
196
197        // Compute loss
198        let targets = Tensor::from_vec(target_ids.iter().map(|&id| id as f32).collect(), false);
199        let loss = self.loss_fn.forward(&logits, &targets);
200        let loss_val = loss.data()[0];
201
202        (loss_val, loss, logits)
203    }
204
205    /// Compute forward + backward for all items in a batch, returning average loss.
206    fn compute_batch_gradients(&self, batch: &LMBatch) -> f32 {
207        let mut total_loss = 0.0;
208
209        for i in 0..batch.batch_size {
210            let Some(input_ids) = batch.get_input(i) else {
211                continue;
212            };
213            let Some(target_ids) = batch.get_target(i) else {
214                continue;
215            };
216
217            let (loss_val, loss, _logits) = self.forward_single(input_ids, target_ids);
218
219            if let Some(backward_op) = loss.backward_op() {
220                backward_op.backward();
221            }
222
223            total_loss += loss_val / self.config.accumulation_steps as f32;
224        }
225
226        total_loss / batch.batch_size as f32
227    }
228
229    /// Apply gradient clipping and run the optimizer step, then reset accumulation.
230    fn clip_and_step(&mut self) {
231        if let Some(max_norm) = self.config.base.max_grad_norm {
232            let params = if let Some(ref lora) = self.lora_layers {
233                lora.iter().flat_map(|l| vec![l.lora_a(), l.lora_b()]).collect::<Vec<_>>()
234            } else {
235                self.model.parameters()
236            };
237            let total_norm: f32 = params
238                .iter()
239                .filter_map(|p| p.grad())
240                .map(|g| g.iter().map(|x| x * x).sum::<f32>())
241                .sum::<f32>()
242                .sqrt();
243
244            if total_norm > max_norm {
245                let scale = max_norm / (total_norm + 1e-6);
246                let _ = scale;
247            }
248        }
249
250        // ENT-LoRA-002: Only update trainable params (LoRA A/B + norms when active)
251        if let Some(ref mut lora) = self.lora_layers {
252            // ENT-LoRA-006: LoRA+ gradient scaling for B matrices
253            let ratio = self.config.lora_plus_ratio;
254            if ratio != 1.0 {
255                for layer in lora.iter_mut() {
256                    if let Some(grad) = layer.lora_b_mut().grad() {
257                        let scaled = grad.mapv(|g| g * ratio);
258                        layer.lora_b_mut().set_grad(scaled);
259                    }
260                }
261            }
262
263            let mut params: Vec<&mut Tensor> =
264                lora.iter_mut().flat_map(|l| l.trainable_params()).collect();
265            // Also include norm weights (small, critical for adaptation)
266            for layer in &mut self.model.layers {
267                params.push(&mut layer.input_norm.weight);
268                params.push(&mut layer.post_attn_norm.weight);
269            }
270            params.push(&mut self.model.norm.weight);
271            self.optimizer.step_refs(&mut params);
272        } else {
273            let mut params = self.model.parameters_mut();
274            self.optimizer.step_refs(&mut params);
275        }
276
277        self.step += 1;
278        self.metrics.losses.push(self.accumulated_loss);
279        self.metrics.increment_step();
280
281        self.accumulated_loss = 0.0;
282        self.accumulated_batches = 0;
283    }
284
285    /// Process a batch (forward + backward + optimizer step)
286    ///
287    /// Returns average loss for the batch
288    pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
289        if batch.batch_size == 0 {
290            return 0.0;
291        }
292
293        if self.accumulated_batches == 0 {
294            // ENT-LoRA-002: Zero grad only on trainable params (LoRA A/B + norms)
295            if let Some(ref mut lora) = self.lora_layers {
296                let mut params: Vec<&mut Tensor> =
297                    lora.iter_mut().flat_map(|l| l.trainable_params()).collect();
298                for layer in &mut self.model.layers {
299                    params.push(&mut layer.input_norm.weight);
300                    params.push(&mut layer.post_attn_norm.weight);
301                }
302                params.push(&mut self.model.norm.weight);
303                self.optimizer.zero_grad_refs(&mut params);
304            } else {
305                let mut params = self.model.parameters_mut();
306                self.optimizer.zero_grad_refs(&mut params);
307            }
308        }
309
310        let avg_loss = self.compute_batch_gradients(batch);
311
312        self.accumulated_loss += avg_loss;
313        self.accumulated_batches += 1;
314
315        if self.accumulated_batches >= self.config.accumulation_steps {
316            self.clip_and_step();
317        }
318
319        avg_loss
320    }
321
322    /// Train for one epoch over batches
323    pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
324        self.train_epoch_with_callback(batches, |_, _, _| {})
325    }
326
327    /// Train for one epoch with a per-step callback.
328    ///
329    /// The callback receives (batch_index, batch_loss, &self) after each batch.
330    /// Use this for progress logging, checkpointing, or early stopping.
331    ///
332    /// Stops early if `max_steps` is set and the step count reaches it.
333    /// Returns `(avg_loss, reached_max_steps)`.
334    pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
335    where
336        F: FnMut(usize, f32, &Self),
337    {
338        if batches.is_empty() {
339            return 0.0;
340        }
341
342        let mut total_loss = 0.0;
343        let mut batches_processed = 0;
344
345        for (i, batch) in batches.iter().enumerate() {
346            // Check max_steps before processing
347            if let Some(max) = self.config.max_steps {
348                if self.step >= max {
349                    break;
350                }
351            }
352
353            let batch_loss = self.train_batch(batch);
354            total_loss += batch_loss;
355            batches_processed += 1;
356            on_batch(i, batch_loss, self);
357        }
358
359        total_loss / batches_processed.max(1) as f32
360    }
361
362    /// Returns true if max_steps has been reached.
363    pub fn reached_max_steps(&self) -> bool {
364        self.config.max_steps.is_some_and(|max| self.step >= max)
365    }
366
367    /// Get current step count
368    pub fn step(&self) -> usize {
369        self.step
370    }
371
372    /// Get reference to model
373    pub fn model(&self) -> &Transformer {
374        &self.model
375    }
376
377    /// Get mutable reference to model
378    pub fn model_mut(&mut self) -> &mut Transformer {
379        &mut self.model
380    }
381
382    /// Get current learning rate (with warmup applied)
383    pub fn current_lr(&self) -> f32 {
384        let base_lr = self.config.lr;
385
386        if self.step < self.config.warmup_steps {
387            // Linear warmup
388            base_lr * (self.step as f32 / self.config.warmup_steps as f32)
389        } else {
390            base_lr
391        }
392    }
393
394    /// Get gradient scaler stats
395    pub fn grad_scaler_stats(&self) -> (f32, usize, usize) {
396        (
397            self.grad_scaler.scale(),
398            self.grad_scaler.overflow_count(),
399            self.grad_scaler.successful_steps(),
400        )
401    }
402
403    /// Check if using mixed precision
404    pub fn is_mixed_precision(&self) -> bool {
405        self.config.precision_config.is_mixed()
406    }
407
408    /// Check if using gradient checkpointing
409    pub fn is_checkpointing(&self) -> bool {
410        self.config.checkpoint_config.enabled
411    }
412
413    /// Check if LoRA training is active
414    pub fn is_lora(&self) -> bool {
415        self.lora_layers.is_some()
416    }
417
418    /// Get reference to LoRA layers (for checkpoint saving)
419    pub fn lora_layers(&self) -> Option<&[LoRALayer]> {
420        self.lora_layers.as_deref()
421    }
422
423    /// Get mutable reference to LoRA layers
424    pub fn lora_layers_mut(&mut self) -> Option<&mut Vec<LoRALayer>> {
425        self.lora_layers.as_mut()
426    }
427
428    /// Save LoRA adapter in PEFT-compatible format (ENT-LoRA-003)
429    ///
430    /// Saves only LoRA A/B weights as `adapter_model.safetensors` + `adapter_config.json`.
431    /// Adapter checkpoint is typically <1% of full model size.
432    ///
433    /// # Arguments
434    /// * `output_dir` - Directory to save adapter files
435    /// * `base_model_name` - Optional HuggingFace model ID for adapter_config.json
436    ///
437    /// # Errors
438    /// Returns error if not in LoRA mode or I/O fails.
439    pub fn save_lora_adapter(
440        &self,
441        output_dir: impl AsRef<Path>,
442        base_model_name: Option<&str>,
443    ) -> crate::Result<()> {
444        let lora = self.lora_layers.as_ref().ok_or_else(|| {
445            crate::error::Error::ConfigError("Cannot save adapter: LoRA not enabled".into())
446        })?;
447
448        let rank = self.config.lora_rank.unwrap_or(8);
449        let alpha = self.config.lora_alpha.unwrap_or(rank as f32 * 2.0);
450        let target_modules = self
451            .config
452            .lora_target_modules
453            .clone()
454            .unwrap_or_else(|| vec!["q_proj".to_string(), "v_proj".to_string()]);
455
456        // ENT-LoRA-005: Expand shorthand targets for correct naming
457        let expanded = crate::lora::LoRAConfig::expand_shorthand(&target_modules);
458        let lora_config = crate::lora::LoRAConfig::new(rank, alpha)
459            .target_modules(&expanded.iter().map(String::as_str).collect::<Vec<_>>());
460
461        // ENT-LoRA-007: Build named adapter pairs with correct PEFT naming
462        // Layers are ordered per build(): [q, k, v, o, gate, up, down] per block
463        let num_layers = self.model.layers.len();
464
465        // Map target module names to their layer path prefix
466        let module_paths: Vec<(&str, &str)> = [
467            ("q_proj", "self_attn.q_proj"),
468            ("k_proj", "self_attn.k_proj"),
469            ("v_proj", "self_attn.v_proj"),
470            ("o_proj", "self_attn.o_proj"),
471            ("gate_proj", "mlp.gate_proj"),
472            ("up_proj", "mlp.up_proj"),
473            ("down_proj", "mlp.down_proj"),
474        ]
475        .iter()
476        .filter(|(name, _)| expanded.iter().any(|t| t == *name))
477        .copied()
478        .collect();
479
480        // Generate full path names for each (block, module) pair
481        let all_names: Vec<String> = (0..num_layers)
482            .flat_map(|i| {
483                module_paths.iter().map(move |(_, path)| format!("model.layers.{i}.{path}"))
484            })
485            .collect();
486
487        let mut adapters: Vec<(&str, &LoRALayer)> = Vec::new();
488        for (idx, layer) in lora.iter().enumerate() {
489            if idx < all_names.len() {
490                adapters.push((&all_names[idx], layer));
491            }
492        }
493
494        crate::lora::save_adapter_peft(&adapters, &lora_config, base_model_name, output_dir)
495            .map_err(|e| crate::error::Error::Io(e.to_string()))
496    }
497
498    /// Save model weights to a SafeTensors file
499    ///
500    /// This persists the trained transformer weights to disk.
501    /// Call this after training completes to preserve the learned parameters.
502    ///
503    /// # Arguments
504    ///
505    /// * `path` - Output file path (should end in .safetensors)
506    /// * `name` - Model name for metadata
507    /// * `architecture` - Model architecture description (e.g., "Qwen2ForCausalLM")
508    ///
509    /// # Errors
510    ///
511    /// Returns an error if the file cannot be written.
512    pub fn save(
513        &self,
514        path: impl AsRef<Path>,
515        name: &str,
516        architecture: &str,
517    ) -> crate::Result<()> {
518        // Use named_parameters() for correct name mapping (handles attention biases etc.)
519        let params: Vec<(String, Tensor)> = self
520            .model
521            .named_parameters()
522            .into_iter()
523            .map(|(name, tensor)| (name, tensor.clone()))
524            .collect();
525
526        let metadata = ModelMetadata::new(name, architecture);
527        let model = Model::new(metadata, params);
528        let config = SaveConfig::new(ModelFormat::SafeTensors);
529
530        save_model(&model, path, &config)
531    }
532
533    /// Save model weights in the sovereign APR format.
534    ///
535    /// Mirror of `CudaTransformerTrainer::save_apr` for the CPU path.
536    /// APR is the row-major atomic single-file format shared across
537    /// training and inference (per aprender-train CLAUDE.md LAYOUT-002
538    /// mandate), so training checkpoints emitted by `PretrainLoop`
539    /// load directly in realizar / `apr run` with no re-transpose.
540    ///
541    /// # Arguments
542    ///
543    /// * `path` - Output file path (should end in `.apr`)
544    /// * `name` - Model name for metadata
545    /// * `architecture` - Model architecture description
546    ///   (e.g., `"LlamaForCausalLM"`)
547    pub fn save_apr(
548        &self,
549        path: impl AsRef<Path>,
550        name: &str,
551        architecture: &str,
552    ) -> crate::Result<()> {
553        let params: Vec<(String, Tensor)> =
554            self.model.named_parameters().into_iter().map(|(n, t)| (n, t.clone())).collect();
555        let metadata = ModelMetadata::new(name, architecture);
556        let model = Model::new(metadata, params);
557        let config = SaveConfig::new(ModelFormat::Apr);
558        save_model(&model, path, &config)
559    }
560
561    /// sha256 over the AdamW optimizer state bytes (INV-TRAIN-003).
562    ///
563    /// Hashes `(t, m_buffers, v_buffers)` in fixed order so two runs
564    /// with matching hyperparameters, seed, and batch order produce
565    /// the same digest (GATE-TRAIN-006 reproducibility).
566    ///
567    /// Uninitialized buffers (before the first step) hash to the
568    /// tag `"none"` so they still participate deterministically in
569    /// the digest — missing `m[i]` is semantically distinct from
570    /// an all-zeros `m[i]`.
571    #[must_use]
572    pub fn optimizer_state_sha256(&self) -> String {
573        use sha2::{Digest, Sha256};
574        let mut hasher = Sha256::new();
575        hasher.update(b"aprender-train:adamw:optstate:v1");
576        hasher.update(self.optimizer.step_count().to_le_bytes());
577        let moment_streams: [(&[u8], &[Option<ndarray::Array1<f32>>]); 2] =
578            [(b"m", self.optimizer.first_moments()), (b"v", self.optimizer.second_moments())];
579        for (tag, buffers) in moment_streams {
580            hasher.update(tag);
581            hasher.update((buffers.len() as u64).to_le_bytes());
582            for slot in buffers {
583                match slot {
584                    Some(arr) => {
585                        hasher.update(b"some");
586                        hasher.update((arr.len() as u64).to_le_bytes());
587                        let bytes: &[u8] = bytemuck::cast_slice(
588                            arr.as_slice().expect("AdamW buffers are contiguous"),
589                        );
590                        hasher.update(bytes);
591                    }
592                    None => hasher.update(b"none"),
593                }
594            }
595        }
596        format!("{:x}", hasher.finalize())
597    }
598}