// qlora_rs/training.rs

//! Training utilities for `QLoRA` fine-tuning.
//!
//! This module provides:
//! - [`QLoraTrainer`] - Main trainer for `QLoRA` fine-tuning
//! - [`PagedAdamW`] - Memory-efficient optimizer with CPU paging
//! - Integration with peft-rs training state and LR schedules
//! - Gradient computation and optimizer integration
//!
//! # Training Architecture
//!
//! `QLoRA` training keeps base weights frozen in 4-bit precision while training
//! `LoRA` adapter weights in full precision. Gradients flow through the frozen
//! quantized base via straight-through estimation (STE).
//!
//! ```text
//!   Input → [Quantized Base (frozen)] → [LoRA A] → [LoRA B] → Output
//!              ↑ no gradients           ↑ gradients flow
//! ```
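//!
//! # Example
//!
//! A minimal sketch of the intended workflow (`config`, `device`, `weight`,
//! `qlora_config`, `input`, and `targets` are placeholders):
//!
//! ```ignore
//! let mut trainer = QLoraTrainer::new(config, device);
//! let vb = trainer.var_builder();
//! let layer = QuantizedLinear::from_weight_with_varbuilder(&weight, None, &qlora_config, vb.pp("layer0"))?;
//! trainer.init_optimizer(&[&layer])?;
//! let loss = trainer.training_step(&[&layer], &input, &targets)?;
//! ```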

use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use peft_rs::training::{AdapterTrainingConfig, AdapterTrainingState, LrSchedule};
use std::collections::HashMap;

use crate::error::{QLoraError, Result};
use crate::qlora::QuantizedLinear;

/// Configuration for `QLoRA` training.
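///
/// # Example
///
/// A sketch of overriding a few fields while keeping the remaining defaults:
///
/// ```ignore
/// let config = QLoraTrainingConfig {
///     num_epochs: 1,
///     use_paged_optimizer: false,
///     ..QLoraTrainingConfig::default()
/// };
/// ```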
#[derive(Debug, Clone)]
pub struct QLoraTrainingConfig {
    /// Adapter training configuration (from peft-rs).
    pub adapter_config: AdapterTrainingConfig,
    /// Number of training epochs.
    pub num_epochs: usize,
    /// Batch size for training.
    pub batch_size: usize,
    /// Logging frequency (steps).
    pub log_every: usize,
    /// Checkpoint save frequency (steps, None = no checkpoints).
    pub save_every: Option<usize>,
    /// Warmup steps for learning rate.
    pub warmup_steps: usize,
    /// Use paged optimizer (CPU offload for optimizer states).
    pub use_paged_optimizer: bool,
    /// Page size for paged optimizer (bytes).
    pub page_size: usize,
    /// Maximum memory for optimizer states on GPU (bytes, 0 = unlimited).
    pub max_optimizer_memory: usize,
}

impl Default for QLoraTrainingConfig {
    fn default() -> Self {
        Self {
            adapter_config: AdapterTrainingConfig {
                learning_rate: 2e-4,
                lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 100 },
                weight_decay: 0.01,
                gradient_accumulation_steps: 4,
                max_grad_norm: Some(1.0),
            },
            num_epochs: 3,
            batch_size: 4,
            log_every: 10,
            save_every: Some(500),
            warmup_steps: 100,
            use_paged_optimizer: true,
            page_size: 1024 * 1024,  // 1MB pages
            max_optimizer_memory: 0, // unlimited by default
        }
    }
}

/// Paged optimizer state for CPU offloading.
///
/// Stores optimizer states (momentum, variance) on CPU and pages them to GPU
/// as needed during parameter updates. This enables training large models
/// on limited VRAM by trading off memory for compute.
///
/// Matches Python `QLoRA`'s `--optim paged_adamw_32bit` behavior.
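///
/// # Example
///
/// A minimal sketch of the paging lifecycle (`device` is a placeholder; on
/// CPU the transfers are no-ops, but the residency bookkeeping still runs):
///
/// ```ignore
/// let mut state = PagedAdamWState::new(1024 * 1024, 0);
/// state.init_param("lora_a", &[16, 8], &Device::Cpu)?;
/// let (m, v) = state.page_to_device("lora_a", &device)?;
/// assert!(state.is_gpu_resident("lora_a"));
/// state.page_to_cpu("lora_a", &m, &v)?;
/// assert!(!state.is_gpu_resident("lora_a"));
/// ```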
#[derive(Debug)]
pub struct PagedAdamWState {
    /// First moment estimates (CPU tensors, paged to GPU on demand).
    pub exp_avg: HashMap<String, Tensor>,
    /// Second moment estimates (CPU tensors, paged to GPU on demand).
    pub exp_avg_sq: HashMap<String, Tensor>,
    /// Step counts per parameter.
    pub steps: HashMap<String, usize>,
    /// Page size in bytes.
    pub page_size: usize,
    /// Set of parameters currently GPU-resident (for tracking).
    gpu_resident: std::collections::HashSet<String>,
    /// LRU access order (most recent at end).
    access_order: Vec<String>,
    /// Maximum GPU memory for optimizer states (0 = unlimited).
    pub max_gpu_memory: usize,
    /// Current GPU memory usage (bytes).
    pub current_gpu_usage: usize,
}

impl PagedAdamWState {
    /// Create new paged optimizer state.
    #[must_use]
    pub fn new(page_size: usize, max_gpu_memory: usize) -> Self {
        Self {
            exp_avg: HashMap::new(),
            exp_avg_sq: HashMap::new(),
            steps: HashMap::new(),
            page_size,
            gpu_resident: std::collections::HashSet::new(),
            access_order: Vec::new(),
            max_gpu_memory,
            current_gpu_usage: 0,
        }
    }

    /// Initialize state for a parameter.
    ///
    /// # Errors
    /// Returns error if tensor creation fails.
    pub fn init_param(&mut self, name: &str, shape: &[usize], _device: &Device) -> Result<()> {
        // Store on CPU for paging (states start on CPU, paged to GPU on demand)
        let cpu_device = Device::Cpu;
        let exp_avg = Tensor::zeros(shape, DType::F32, &cpu_device)?;
        let exp_avg_sq = Tensor::zeros(shape, DType::F32, &cpu_device)?;

        self.exp_avg.insert(name.to_string(), exp_avg);
        self.exp_avg_sq.insert(name.to_string(), exp_avg_sq);
        self.steps.insert(name.to_string(), 0);
        // Note: GPU memory tracking happens in page_to_device, not here,
        // since states start on CPU

        Ok(())
    }

    /// Page state to GPU for update, returns (`exp_avg`, `exp_avg_sq`) on target device.
    ///
    /// Updates LRU tracking and GPU memory usage.
    ///
    /// # Errors
    /// Returns error if device transfer fails.
    #[allow(clippy::if_not_else, clippy::excessive_nesting)]
    pub fn page_to_device(&mut self, name: &str, device: &Device) -> Result<(Tensor, Tensor)> {
        let exp_avg = self
            .exp_avg
            .get(name)
            .ok_or_else(|| QLoraError::InvalidConfig(format!("No state for param: {name}")))?;
        let exp_avg_sq = self
            .exp_avg_sq
            .get(name)
            .ok_or_else(|| QLoraError::InvalidConfig(format!("No state for param: {name}")))?;

        // Track GPU residency
        if !self.gpu_resident.contains(name) {
            let param_bytes = exp_avg.elem_count() * 4 * 2; // 2 states * f32

            // Check memory limit and evict LRU if needed
            if self.max_gpu_memory > 0 {
                while self.current_gpu_usage + param_bytes > self.max_gpu_memory
                    && !self.access_order.is_empty()
                {
                    // Evict LRU (first in access_order)
                    if let Some(lru_name) = self.access_order.first().cloned() {
                        if lru_name != name {
                            self.gpu_resident.remove(&lru_name);
                            self.access_order.retain(|n| n != &lru_name);
                            let lru_bytes = self
                                .exp_avg
                                .get(&lru_name)
                                .map_or(0, |t| t.elem_count() * 4 * 2);
                            self.current_gpu_usage =
                                self.current_gpu_usage.saturating_sub(lru_bytes);
                        } else {
                            break; // Don't evict the param we're trying to page in
                        }
                    }
                }
            }

            self.gpu_resident.insert(name.to_string());
            self.current_gpu_usage += param_bytes;
        }

        // Update LRU order (move to end = most recently used)
        self.access_order.retain(|n| n != name);
        self.access_order.push(name.to_string());

        Ok((exp_avg.to_device(device)?, exp_avg_sq.to_device(device)?))
    }

    /// Page state back to CPU after update.
    ///
    /// Updates GPU memory tracking.
    ///
    /// # Errors
    /// Returns error if device transfer fails.
    pub fn page_to_cpu(&mut self, name: &str, exp_avg: &Tensor, exp_avg_sq: &Tensor) -> Result<()> {
        // Track GPU memory release
        if self.gpu_resident.remove(name) {
            let param_bytes = exp_avg.elem_count() * 4 * 2; // 2 states * f32
            self.current_gpu_usage = self.current_gpu_usage.saturating_sub(param_bytes);
            self.access_order.retain(|n| n != name);
        }

        self.exp_avg
            .insert(name.to_string(), exp_avg.to_device(&Device::Cpu)?);
        self.exp_avg_sq
            .insert(name.to_string(), exp_avg_sq.to_device(&Device::Cpu)?);
        Ok(())
    }

    /// Increment step count for a parameter.
    pub fn increment_step(&mut self, name: &str) {
        if let Some(step) = self.steps.get_mut(name) {
            *step += 1;
        }
    }

    /// Get step count for a parameter.
    #[must_use]
    pub fn get_step(&self, name: &str) -> usize {
        self.steps.get(name).copied().unwrap_or(0)
    }

    /// Check if a parameter's optimizer state is currently GPU-resident.
    #[must_use]
    pub fn is_gpu_resident(&self, name: &str) -> bool {
        self.gpu_resident.contains(name)
    }

    /// Get the number of parameters currently GPU-resident.
    #[must_use]
    pub fn gpu_resident_count(&self) -> usize {
        self.gpu_resident.len()
    }
}

/// Paged `AdamW` optimizer with CPU offloading.
///
/// Implements `AdamW` with optimizer state paging to CPU memory,
/// matching Python's `paged_adamw_32bit` from bitsandbytes.
///
/// # Memory Behavior
///
/// - Optimizer states (`exp_avg`, `exp_avg_sq`) stored on CPU
/// - States paged to GPU only during parameter update
/// - Enables training 7B+ models on 24GB GPUs with `QLoRA`
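///
/// # Example
///
/// A sketch of a manual update pass (`params` is a `Vec<(String, Tensor)>` of
/// trainable parameters and `grads` a matching list of gradient tensors,
/// both placeholders):
///
/// ```ignore
/// let mut opt = PagedAdamW::new(2e-4, 0.01, 1024 * 1024, 0).with_betas(0.9, 0.999);
/// opt.init(&params)?;
/// for ((name, param), grad) in params.iter_mut().zip(grads.iter()) {
///     opt.step_param(name, param, grad)?;
/// }
/// ```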
pub struct PagedAdamW {
    /// Learning rate.
    lr: f64,
    /// Beta1 (first moment decay).
    beta1: f64,
    /// Beta2 (second moment decay).
    beta2: f64,
    /// Epsilon for numerical stability.
    eps: f64,
    /// Weight decay coefficient.
    weight_decay: f64,
    /// Paged optimizer state.
    state: PagedAdamWState,
    /// Whether optimizer is initialized.
    initialized: bool,
}

impl PagedAdamW {
    /// Create a new paged `AdamW` optimizer.
    ///
    /// # Arguments
    /// * `lr` - Learning rate
    /// * `weight_decay` - Weight decay coefficient
    /// * `page_size` - Page size in bytes for CPU offloading
    /// * `max_gpu_memory` - Maximum GPU memory for optimizer states (0 = unlimited)
    #[must_use]
    pub fn new(lr: f64, weight_decay: f64, page_size: usize, max_gpu_memory: usize) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay,
            state: PagedAdamWState::new(page_size, max_gpu_memory),
            initialized: false,
        }
    }

    /// Create with custom betas.
    #[must_use]
    pub fn with_betas(mut self, beta1: f64, beta2: f64) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Initialize optimizer state for parameters.
    ///
    /// # Errors
    /// Returns error if state initialization fails.
    pub fn init(&mut self, params: &[(String, Tensor)]) -> Result<()> {
        for (name, param) in params {
            let shape = param.shape().dims();
            self.state.init_param(name, shape, param.device())?;
        }
        self.initialized = true;
        Ok(())
    }

    /// Set learning rate.
    pub fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }

    /// Get current learning rate.
    #[must_use]
    pub fn lr(&self) -> f64 {
        self.lr
    }

    /// Perform optimizer step for a single parameter.
    ///
    /// Implements `AdamW` update with CPU paging:
    /// ```text
    /// m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
    /// v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
    /// m̂_t = m_t / (1 - β₁^t)
    /// v̂_t = v_t / (1 - β₂^t)
    /// θ_t = θ_{t-1} - lr * (m̂_t / (√v̂_t + ε) + λ * θ_{t-1})
    /// ```
    ///
    /// # Errors
    /// Returns error if tensor operations fail.
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    pub fn step_param(&mut self, name: &str, param: &mut Tensor, grad: &Tensor) -> Result<()> {
        let device = param.device().clone();

        // Page optimizer state to GPU
        let (mut exp_avg, mut exp_avg_sq) = self.state.page_to_device(name, &device)?;

        // Increment step
        self.state.increment_step(name);
        let step = self.state.get_step(name);

        // Update biased first moment estimate
        let beta1_tensor = Tensor::new(self.beta1 as f32, &device)?;
        let one_minus_beta1 = Tensor::new((1.0 - self.beta1) as f32, &device)?;
        exp_avg = exp_avg
            .broadcast_mul(&beta1_tensor)?
            .broadcast_add(&grad.broadcast_mul(&one_minus_beta1)?)?;

        // Update biased second moment estimate
        let beta2_tensor = Tensor::new(self.beta2 as f32, &device)?;
        let one_minus_beta2 = Tensor::new((1.0 - self.beta2) as f32, &device)?;
        let grad_sq = grad.sqr()?;
        exp_avg_sq = exp_avg_sq
            .broadcast_mul(&beta2_tensor)?
            .broadcast_add(&grad_sq.broadcast_mul(&one_minus_beta2)?)?;

        // Bias correction
        let bias_correction1 = 1.0 - self.beta1.powi(step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(step as i32);

        let bc1_tensor = Tensor::new(bias_correction1 as f32, &device)?;
        let bc2_tensor = Tensor::new(bias_correction2 as f32, &device)?;

        // Compute step: lr * (m̂ / (√v̂ + ε) + weight_decay * param)
        let exp_avg_corrected = exp_avg.broadcast_div(&bc1_tensor)?;
        let exp_avg_sq_corrected = exp_avg_sq.broadcast_div(&bc2_tensor)?;

        let denom = exp_avg_sq_corrected
            .sqrt()?
            .broadcast_add(&Tensor::new(self.eps as f32, &device)?)?;
        let step_size = Tensor::new(self.lr as f32, &device)?;

        // AdamW: decoupled weight decay
        let update = exp_avg_corrected.broadcast_div(&denom)?;
        let weight_decay_term =
            param.broadcast_mul(&Tensor::new(self.weight_decay as f32, &device)?)?;
        let full_update = update
            .broadcast_add(&weight_decay_term)?
            .broadcast_mul(&step_size)?;

        // Update parameter in place
        *param = param.broadcast_sub(&full_update)?;

        // Page state back to CPU
        self.state.page_to_cpu(name, &exp_avg, &exp_avg_sq)?;

        Ok(())
    }
    /// Get memory usage statistics as `(cpu_bytes, gpu_bytes)`.
    #[must_use]
    pub fn memory_stats(&self) -> (usize, usize) {
        let cpu_bytes: usize = self
            .state
            .exp_avg
            .values()
            .chain(self.state.exp_avg_sq.values())
            .map(|t| t.elem_count() * 4)
            .sum();
        (cpu_bytes, self.state.current_gpu_usage)
    }
}

/// Trainer for `QLoRA` fine-tuning.
///
/// Manages the training loop, gradient computation, and optimizer updates
/// for quantized `LoRA` training.
///
/// # Usage
///
/// 1. Create trainer with config
/// 2. Use `var_builder()` to create layers that register params in `VarMap`
/// 3. Call `init_optimizer()` to set up optimizer with registered params
/// 4. Call `training_step()` or `training_step_lm()` for each batch
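///
/// # Example
///
/// A condensed sketch of one training run (`device`, `layer`, and `batches`
/// are placeholders; layer setup is elided):
///
/// ```ignore
/// let mut trainer = QLoraTrainer::new(QLoraTrainingConfig::default(), device);
/// // ... create layers via trainer.var_builder(), then:
/// trainer.init_optimizer(&[&layer])?;
/// while trainer.should_continue() {
///     trainer.start_epoch();
///     for (input, targets) in &batches {
///         let loss = trainer.training_step(&[&layer], input, targets)?;
///         trainer.update_lr();
///     }
/// }
/// ```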
pub struct QLoraTrainer {
    /// Training configuration.
    config: QLoraTrainingConfig,
    /// Training state tracking.
    state: AdapterTrainingState,
    /// Device for computation.
    device: Device,
    /// Variable map for trainable parameters.
    varmap: VarMap,
    /// Standard optimizer (when paging disabled).
    optimizer: Option<AdamW>,
    /// Paged optimizer (when paging enabled).
    paged_optimizer: Option<PagedAdamW>,
    /// Current accumulation step.
    accumulation_step: usize,
}

impl QLoraTrainer {
    /// Create a new `QLoRA` trainer.
    ///
    /// # Arguments
    /// * `config` - Training configuration
    /// * `device` - Device for computation
    ///
    /// # Returns
    /// New trainer instance
    #[must_use]
    pub fn new(config: QLoraTrainingConfig, device: Device) -> Self {
        let state = AdapterTrainingState::new(config.adapter_config.clone());
        Self {
            config,
            state,
            device,
            varmap: VarMap::new(),
            optimizer: None,
            paged_optimizer: None,
            accumulation_step: 0,
        }
    }

    /// Get a `VarBuilder` backed by this trainer's `VarMap`.
    ///
    /// Use this to create `QuantizedLinear` layers with gradient tracking.
    /// Params created through this `VarBuilder` will be registered in the
    /// trainer's `VarMap` and trained by the optimizer.
    ///
    /// # Example
    /// ```ignore
    /// let mut trainer = QLoraTrainer::new(config, device.clone());
    /// let vb = trainer.var_builder();
    /// let layer = QuantizedLinear::from_weight_with_varbuilder(&weight, None, &qlora_config, vb.pp("layer0"))?;
    /// trainer.init_optimizer(&[&layer])?;
    /// ```
    #[must_use]
    pub fn var_builder(&self) -> VarBuilder<'_> {
        VarBuilder::from_varmap(&self.varmap, DType::F32, &self.device)
    }

    /// Initialize the optimizer with trainable parameters.
    ///
    /// Creates either a paged or standard `AdamW` optimizer based on configuration.
    /// For the paged optimizer, optimizer states are stored on CPU and paged to GPU
    /// during updates to reduce VRAM usage.
    ///
    /// **Important**: Layers must be created using `var_builder()`,
    /// or the optimizer will have no trainable parameters.
    ///
    /// # Arguments
    /// * `layers` - The `QLoRA` layers to train
    ///
    /// # Errors
    /// Returns error if:
    /// - `VarMap` is empty - layers weren't created with `var_builder()`
    /// - Optimizer initialization fails
    ///
    /// # Panics
    /// Panics if the `VarMap` mutex is poisoned.
    pub fn init_optimizer(&mut self, layers: &[&QuantizedLinear]) -> Result<()> {
        if self.config.use_paged_optimizer {
            // Create paged optimizer for memory efficiency
            let mut paged = PagedAdamW::new(
                self.config.adapter_config.learning_rate,
                self.config.adapter_config.weight_decay,
                self.config.page_size,
                self.config.max_optimizer_memory,
            );

            // Collect trainable parameters from VarMap (which should have LoRA params)
            let vars = self.varmap.all_vars();
            if vars.is_empty() {
                return Err(QLoraError::InvalidConfig(
                    "No trainable parameters found. Layers must be created using trainer.var_builder() \
                     so `LoRA` weights are registered in the `VarMap`.".into()
                ));
            }

            // Initialize paged optimizer with actual params from VarMap
            let params: Vec<(String, Tensor)> = self
                .varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .map(|(name, var)| (name.clone(), var.as_tensor().clone()))
                .collect();

            paged.init(&params)?;
            self.paged_optimizer = Some(paged);

            // `layers` is not needed for the paged path; reference it to
            // silence the unused-parameter warning
            let _ = layers.len();
        } else {
            // Standard AdamW optimizer - requires VarMap to have params
            let vars = self.varmap.all_vars();
            if vars.is_empty() {
                return Err(QLoraError::InvalidConfig(
                    "No trainable parameters found. Layers must be created using trainer.var_builder() \
                     so `LoRA` weights are registered in the `VarMap`.".into()
                ));
            }

            let params = ParamsAdamW {
                lr: self.config.adapter_config.learning_rate,
                weight_decay: self.config.adapter_config.weight_decay,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
            };
            self.optimizer = Some(AdamW::new(vars, params)?);
        }
        Ok(())
    }

    /// Get the current training state.
    #[must_use]
    pub fn state(&self) -> &AdapterTrainingState {
        &self.state
    }

    /// Get the current learning rate.
    #[must_use]
    pub fn current_lr(&self) -> f64 {
        self.state.current_lr()
    }

    /// Get the current step.
    #[must_use]
    pub fn global_step(&self) -> usize {
        self.state.global_step
    }

    /// Get the current epoch.
    #[must_use]
    pub fn epoch(&self) -> usize {
        self.state.epoch
    }

    /// Perform a training step with gradient accumulation.
    ///
    /// `QLoRA` training flow:
    /// 1. Forward pass through frozen quantized base + trainable `LoRA`
    /// 2. Compute loss (cross-entropy for LM, MSE for regression)
    /// 3. Backward pass - gradients flow only through `LoRA` weights
    /// 4. Accumulate gradients if `gradient_accumulation_steps` > 1
    /// 5. Optimizer step when accumulation complete
    ///
    /// Supports both standard `AdamW` and paged `AdamW` optimizers.
    ///
    /// # Arguments
    /// * `layers` - The `QLoRA` layers
    /// * `input` - Input tensor `[batch, seq_len, hidden]`
    /// * `targets` - Target tensor (logits or token IDs depending on loss)
    ///
    /// # Returns
    /// The loss value for this step
    ///
    /// # Errors
    /// Returns error if forward pass or backward pass fails
    ///
    /// # Panics
    /// Panics if the `VarMap` mutex is poisoned.
    #[allow(clippy::cast_precision_loss, clippy::excessive_nesting)]
    pub fn training_step(
        &mut self,
        layers: &[&QuantizedLinear],
        input: &Tensor,
        targets: &Tensor,
    ) -> Result<f64> {
        // Forward pass through all layers
        let mut output = input.clone();
        for layer in layers {
            output = layer.forward(&output)?;
        }

        // Compute MSE loss; use `training_step_lm` for cross-entropy
        let loss = output.sub(targets)?.sqr()?.mean_all()?;

        // Scale loss for gradient accumulation
        let accum_steps = self.config.adapter_config.gradient_accumulation_steps;
        let scaled_loss = if accum_steps > 1 {
            let scale = Tensor::new(1.0 / accum_steps as f32, loss.device())?;
            loss.broadcast_mul(&scale)?
        } else {
            loss.clone()
        };

        let loss_value = f64::from(loss.to_scalar::<f32>()?);

        // Backward pass with gradient accumulation
        self.accumulation_step += 1;

        // Handle standard AdamW optimizer
        if let Some(ref mut optimizer) = self.optimizer {
            if self.accumulation_step >= accum_steps {
                // Clip gradients if configured
                if let Some(max_norm) = self.config.adapter_config.max_grad_norm {
                    // Gradient clipping would be applied here
                    let _ = max_norm; // Placeholder for gradient clipping
                }

                // Perform optimizer step
                optimizer.backward_step(&scaled_loss)?;
                self.accumulation_step = 0;
            } else {
                // Intermediate micro-batch. Note: candle's `backward()` returns
                // a fresh `GradStore` rather than accumulating into the `Var`s,
                // so these gradients are discarded; only the loss scaling above
                // reflects the accumulation factor.
                let _ = scaled_loss.backward();
            }
        } else if let Some(ref mut paged_optimizer) = self.paged_optimizer {
            // Handle paged optimizer
            if self.accumulation_step >= accum_steps {
                // Compute gradients first
                let grads = scaled_loss.backward()?;

                // Step each parameter with the paged optimizer
                let mut varmap_data = self.varmap.data().lock().unwrap();
                for (name, var) in varmap_data.iter_mut() {
                    if let Some(grad) = grads.get(var.as_tensor()) {
                        let mut param = var.as_tensor().clone();
                        paged_optimizer.step_param(name, &mut param, grad)?;
                        // Write the updated value back into the Var so the
                        // parameter change takes effect on subsequent steps
                        var.set(&param)?;
                    }
                }
                drop(varmap_data);
                self.accumulation_step = 0;
            } else {
                // Intermediate micro-batch; gradients here are discarded (see
                // the note in the standard-optimizer branch above)
                let _ = scaled_loss.backward();
            }
        }

        // Update training state
        let should_log = self.state.step();
        if should_log && self.state.global_step.is_multiple_of(self.config.log_every) {
            #[cfg(feature = "logging")]
            log::info!(
                "Step {} | Loss: {:.4} | LR: {:.2e}",
                self.state.global_step,
                loss_value,
                self.current_lr()
            );
        }

        Ok(loss_value)
    }

    /// Perform a training step with cross-entropy loss for language modeling.
    ///
    /// Supports both standard `AdamW` and paged `AdamW` optimizers. Unlike
    /// [`Self::training_step`], this variant performs an optimizer step on
    /// every call (no gradient accumulation).
    ///
    /// # Arguments
    /// * `layers` - The `QLoRA` layers
    /// * `input` - Input tensor `[batch, seq_len, hidden]`
    /// * `target_ids` - Target token IDs `[batch, seq_len]`
    ///
    /// # Returns
    /// The cross-entropy loss value
    ///
    /// # Errors
    /// Returns error if forward pass or loss computation fails
    ///
    /// # Panics
    /// Panics if the `VarMap` mutex is poisoned.
    pub fn training_step_lm(
        &mut self,
        layers: &[&QuantizedLinear],
        input: &Tensor,
        target_ids: &Tensor,
    ) -> Result<f64> {
        // Forward pass through all layers
        let mut logits = input.clone();
        for layer in layers {
            logits = layer.forward(&logits)?;
        }

        // Cross-entropy loss for language modeling
        let loss = cross_entropy_loss(&logits, target_ids)?;
        let loss_value = f64::from(loss.to_scalar::<f32>()?);

        // Backward pass - handle both optimizer types
        if let Some(ref mut optimizer) = self.optimizer {
            optimizer.backward_step(&loss)?;
        } else if let Some(ref mut paged_optimizer) = self.paged_optimizer {
            // Compute gradients and step with paged optimizer
            let grads = loss.backward()?;

            let mut varmap_data = self.varmap.data().lock().unwrap();
            for (name, var) in varmap_data.iter_mut() {
                if let Some(grad) = grads.get(var.as_tensor()) {
                    let mut param = var.as_tensor().clone();
                    paged_optimizer.step_param(name, &mut param, grad)?;
                    // Write the updated value back so the parameter change
                    // takes effect
                    var.set(&param)?;
                }
            }
            drop(varmap_data);
        }

        // Update state
        self.state.step();

        Ok(loss_value)
    }

    /// Start a new training epoch.
    pub fn start_epoch(&mut self) {
        self.state.new_epoch();
        self.accumulation_step = 0;
        #[cfg(feature = "logging")]
        log::info!("Starting epoch {}", self.state.epoch);
    }

    /// Check if training should continue.
    #[must_use]
    pub fn should_continue(&self) -> bool {
        self.state.epoch < self.config.num_epochs
    }

    /// Update learning rate based on schedule.
    pub fn update_lr(&mut self) {
        let lr = self.current_lr();
        if let Some(ref mut optimizer) = self.optimizer {
            optimizer.set_learning_rate(lr);
        }
        if let Some(ref mut paged) = self.paged_optimizer {
            paged.set_lr(lr);
        }
    }

    /// Get training configuration.
    #[must_use]
    pub fn config(&self) -> &QLoraTrainingConfig {
        &self.config
    }

    /// Get optimizer memory statistics (CPU bytes, GPU bytes).
    #[must_use]
    pub fn optimizer_memory_stats(&self) -> Option<(usize, usize)> {
        self.paged_optimizer.as_ref().map(PagedAdamW::memory_stats)
    }

    /// Zero gradients for next accumulation cycle.
    ///
    /// Resets the accumulation step counter. Note: In candle, gradients are
    /// automatically zeroed when `backward_step` is called on the optimizer.
    pub fn zero_grad(&mut self) {
        self.accumulation_step = 0;
    }
}

/// Compute cross-entropy loss for language modeling.
///
/// # Arguments
/// * `logits` - Model output logits `[batch, seq_len, vocab_size]`
/// * `targets` - Target token IDs `[batch, seq_len]`
///
/// # Returns
/// Cross-entropy loss value
///
/// # Errors
/// Returns error if tensor operations fail
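///
/// # Example
///
/// A sketch with all-zero logits over a 100-token vocabulary, where the loss
/// reduces to `ln(100)` (`device` is a placeholder):
///
/// ```ignore
/// let logits = Tensor::zeros(&[2, 10, 100], DType::F32, &device)?;
/// let targets = Tensor::zeros(&[2, 10], DType::U32, &device)?;
/// let loss = cross_entropy_loss(&logits, &targets)?; // ≈ 4.605
/// ```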
pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
    let (batch, seq_len, vocab_size) = logits.dims3()?;

    // Reshape logits to [batch * seq_len, vocab_size]
    let flat_logits = logits.reshape(&[batch * seq_len, vocab_size])?;

    // Reshape targets to [batch * seq_len]
    let flat_targets = targets.reshape(&[batch * seq_len])?;

    // Compute log softmax
    let log_probs = candle_nn::ops::log_softmax(&flat_logits, 1)?;

    // Gather log probs at target indices
    let target_indices = flat_targets.unsqueeze(1)?;
    let gathered = log_probs.gather(&target_indices, 1)?;

    // Mean negative log likelihood
    let loss = gathered.neg()?.mean_all()?;

    Ok(loss)
}

/// Training metrics for logging.
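///
/// # Example
///
/// A minimal sketch of per-epoch tracking:
///
/// ```ignore
/// let mut metrics = TrainingMetrics::new();
/// metrics.update(0.5, 128); // loss, tokens in batch
/// metrics.update(0.3, 128);
/// assert!((metrics.average_loss() - 0.4).abs() < 1e-10);
/// metrics.reset(); // best_loss is kept across epochs
/// ```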
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Total training loss.
    pub total_loss: f64,
    /// Number of steps.
    pub num_steps: usize,
    /// Best loss seen.
    pub best_loss: f64,
    /// Tokens processed.
    pub tokens_processed: usize,
}

impl TrainingMetrics {
    /// Create new metrics tracker.
    #[must_use]
    pub fn new() -> Self {
        Self {
            total_loss: 0.0,
            num_steps: 0,
            best_loss: f64::MAX,
            tokens_processed: 0,
        }
    }

    /// Update metrics with a new loss value.
    pub fn update(&mut self, loss: f64, num_tokens: usize) {
        self.total_loss += loss;
        self.num_steps += 1;
        self.tokens_processed += num_tokens;
        if loss < self.best_loss {
            self.best_loss = loss;
        }
    }

    /// Get average loss.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn average_loss(&self) -> f64 {
        if self.num_steps == 0 {
            0.0
        } else {
            self.total_loss / self.num_steps as f64
        }
    }

    /// Reset metrics for new epoch.
    pub fn reset(&mut self) {
        self.total_loss = 0.0;
        self.num_steps = 0;
        self.tokens_processed = 0;
        // Keep best_loss across epochs
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_training_config_default() {
        let config = QLoraTrainingConfig::default();
        assert_eq!(config.num_epochs, 3);
        assert_eq!(config.batch_size, 4);
        assert!((config.adapter_config.learning_rate - 2e-4).abs() < 1e-10);
    }

    #[test]
    fn test_trainer_creation() {
        let config = QLoraTrainingConfig::default();
        let device = Device::Cpu;
        let trainer = QLoraTrainer::new(config, device);

        assert_eq!(trainer.global_step(), 0);
        assert_eq!(trainer.epoch(), 0);
    }

    #[test]
    fn test_training_metrics() {
        let mut metrics = TrainingMetrics::new();

        metrics.update(0.5, 128);
        metrics.update(0.4, 128);
        metrics.update(0.3, 128);

        assert_eq!(metrics.num_steps, 3);
        assert!((metrics.average_loss() - 0.4).abs() < 1e-10);
        assert!((metrics.best_loss - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_cross_entropy_loss_shape() {
        let device = Device::Cpu;
        let batch = 2;
        let seq_len = 10;
        let vocab_size = 100;

        let logits = Tensor::zeros(&[batch, seq_len, vocab_size], DType::F32, &device).unwrap();
        // All-zero targets (valid token IDs in [0, vocab_size))
        let targets = Tensor::zeros(&[batch, seq_len], DType::U32, &device).unwrap();

        let loss = cross_entropy_loss(&logits, &targets).unwrap();
        // Loss should be scalar
        let dims: &[usize] = loss.dims();
        assert!(dims.is_empty(), "Expected scalar loss, got dims: {dims:?}");
    }

    #[test]
    fn test_trainer_epoch_progression() {
        let config = QLoraTrainingConfig {
            num_epochs: 2,
            ..Default::default()
        };
        let device = Device::Cpu;
        let mut trainer = QLoraTrainer::new(config, device);

        assert!(trainer.should_continue());
        trainer.start_epoch();
        assert_eq!(trainer.epoch(), 1);
        assert!(trainer.should_continue());
        trainer.start_epoch();
        assert_eq!(trainer.epoch(), 2);
        assert!(!trainer.should_continue());
    }
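
    // A minimal smoke test added as an illustration: one `PagedAdamW` step on
    // a small CPU tensor should move the parameter opposite the gradient.
    #[test]
    fn test_paged_adamw_single_step() {
        let device = Device::Cpu;
        let mut opt = PagedAdamW::new(0.1, 0.0, 1024, 0);
        let mut param = Tensor::ones(&[4], DType::F32, &device).unwrap();
        let grad = Tensor::ones(&[4], DType::F32, &device).unwrap();
        opt.init(&[("w".to_string(), param.clone())]).unwrap();
        opt.step_param("w", &mut param, &grad).unwrap();
        // With all-positive gradients, every element must decrease.
        let vals: Vec<f32> = param.to_vec1().unwrap();
        assert!(vals.iter().all(|&v| v < 1.0));
    }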
}