1use candle_core::{DType, Device, Tensor};
21use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
22use peft_rs::training::{AdapterTrainingConfig, AdapterTrainingState, LrSchedule};
23use std::collections::HashMap;
24
25use crate::error::{QLoraError, Result};
26use crate::qlora::QuantizedLinear;
27
/// Configuration for a QLoRA fine-tuning run.
#[derive(Debug, Clone)]
pub struct QLoraTrainingConfig {
    /// Adapter hyperparameters shared with `peft_rs` (learning rate,
    /// LR schedule, weight decay, gradient accumulation, grad clipping).
    pub adapter_config: AdapterTrainingConfig,
    /// Number of passes over the training data.
    pub num_epochs: usize,
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Emit a log line every N global steps (requires the `logging` feature).
    pub log_every: usize,
    /// Checkpoint interval in steps; `None` disables periodic saving.
    /// NOTE(review): not consumed anywhere in this module — verify callers use it.
    pub save_every: Option<usize>,
    /// Warmup step count.
    /// NOTE(review): duplicates `adapter_config.lr_schedule`'s warmup and is
    /// not read in this module — confirm which one is authoritative.
    pub warmup_steps: usize,
    /// When true, use the CPU-paged `PagedAdamW` instead of candle's `AdamW`.
    pub use_paged_optimizer: bool,
    /// Page size in bytes for the paged optimizer (bookkeeping granularity).
    pub page_size: usize,
    /// Cap on GPU-resident optimizer state in bytes; 0 means unlimited.
    pub max_optimizer_memory: usize,
}
50
51impl Default for QLoraTrainingConfig {
52 fn default() -> Self {
53 Self {
54 adapter_config: AdapterTrainingConfig {
55 learning_rate: 2e-4,
56 lr_schedule: LrSchedule::LinearWarmup { warmup_steps: 100 },
57 weight_decay: 0.01,
58 gradient_accumulation_steps: 4,
59 max_grad_norm: Some(1.0),
60 },
61 num_epochs: 3,
62 batch_size: 4,
63 log_every: 10,
64 save_every: Some(500),
65 warmup_steps: 100,
66 use_paged_optimizer: true,
67 page_size: 1024 * 1024, max_optimizer_memory: 0, }
70 }
71}
72
/// CPU-offloaded AdamW moment state with LRU-tracked GPU residency accounting.
#[derive(Debug)]
pub struct PagedAdamWState {
    /// First moment (EMA of gradients) per parameter name, stored on CPU.
    pub exp_avg: HashMap<String, Tensor>,
    /// Second moment (EMA of squared gradients) per parameter name, on CPU.
    pub exp_avg_sq: HashMap<String, Tensor>,
    /// Per-parameter step counters used for bias correction.
    pub steps: HashMap<String, usize>,
    /// Page size in bytes (bookkeeping only; transfers move whole tensors).
    pub page_size: usize,
    /// Names whose state is currently accounted as GPU-resident.
    gpu_resident: std::collections::HashSet<String>,
    /// LRU order: front = least recently used, back = most recently used.
    access_order: Vec<String>,
    /// Soft cap on GPU bytes for optimizer state; 0 disables eviction.
    pub max_gpu_memory: usize,
    /// Currently accounted GPU usage in bytes.
    pub current_gpu_usage: usize,
}
99
100impl PagedAdamWState {
101 #[must_use]
103 pub fn new(page_size: usize, max_gpu_memory: usize) -> Self {
104 Self {
105 exp_avg: HashMap::new(),
106 exp_avg_sq: HashMap::new(),
107 steps: HashMap::new(),
108 page_size,
109 gpu_resident: std::collections::HashSet::new(),
110 access_order: Vec::new(),
111 max_gpu_memory,
112 current_gpu_usage: 0,
113 }
114 }
115
116 pub fn init_param(&mut self, name: &str, shape: &[usize], _device: &Device) -> Result<()> {
121 let cpu_device = Device::Cpu;
123 let exp_avg = Tensor::zeros(shape, DType::F32, &cpu_device)?;
124 let exp_avg_sq = Tensor::zeros(shape, DType::F32, &cpu_device)?;
125
126 self.exp_avg.insert(name.to_string(), exp_avg);
127 self.exp_avg_sq.insert(name.to_string(), exp_avg_sq);
128 self.steps.insert(name.to_string(), 0);
129 Ok(())
133 }
134
135 #[allow(clippy::if_not_else, clippy::excessive_nesting)]
142 pub fn page_to_device(&mut self, name: &str, device: &Device) -> Result<(Tensor, Tensor)> {
143 let exp_avg = self
144 .exp_avg
145 .get(name)
146 .ok_or_else(|| QLoraError::InvalidConfig(format!("No state for param: {name}")))?;
147 let exp_avg_sq = self
148 .exp_avg_sq
149 .get(name)
150 .ok_or_else(|| QLoraError::InvalidConfig(format!("No state for param: {name}")))?;
151
152 if !self.gpu_resident.contains(name) {
154 let param_bytes = exp_avg.elem_count() * 4 * 2; if self.max_gpu_memory > 0 {
158 while self.current_gpu_usage + param_bytes > self.max_gpu_memory
159 && !self.access_order.is_empty()
160 {
161 if let Some(lru_name) = self.access_order.first().cloned() {
163 if lru_name != name {
164 self.gpu_resident.remove(&lru_name);
165 self.access_order.retain(|n| n != &lru_name);
166 let lru_bytes = self
167 .exp_avg
168 .get(&lru_name)
169 .map_or(0, |t| t.elem_count() * 4 * 2);
170 self.current_gpu_usage =
171 self.current_gpu_usage.saturating_sub(lru_bytes);
172 } else {
173 break; }
175 }
176 }
177 }
178
179 self.gpu_resident.insert(name.to_string());
180 self.current_gpu_usage += param_bytes;
181 }
182
183 self.access_order.retain(|n| n != name);
185 self.access_order.push(name.to_string());
186
187 Ok((exp_avg.to_device(device)?, exp_avg_sq.to_device(device)?))
188 }
189
190 pub fn page_to_cpu(&mut self, name: &str, exp_avg: &Tensor, exp_avg_sq: &Tensor) -> Result<()> {
197 if self.gpu_resident.remove(name) {
199 let param_bytes = exp_avg.elem_count() * 4 * 2; self.current_gpu_usage = self.current_gpu_usage.saturating_sub(param_bytes);
201 self.access_order.retain(|n| n != name);
202 }
203
204 self.exp_avg
205 .insert(name.to_string(), exp_avg.to_device(&Device::Cpu)?);
206 self.exp_avg_sq
207 .insert(name.to_string(), exp_avg_sq.to_device(&Device::Cpu)?);
208 Ok(())
209 }
210
211 pub fn increment_step(&mut self, name: &str) {
213 if let Some(step) = self.steps.get_mut(name) {
214 *step += 1;
215 }
216 }
217
218 #[must_use]
220 pub fn get_step(&self, name: &str) -> usize {
221 self.steps.get(name).copied().unwrap_or(0)
222 }
223
224 #[must_use]
226 pub fn is_gpu_resident(&self, name: &str) -> bool {
227 self.gpu_resident.contains(name)
228 }
229
230 #[must_use]
232 pub fn gpu_resident_count(&self) -> usize {
233 self.gpu_resident.len()
234 }
235}
236
/// AdamW optimizer whose moment tensors live on the CPU and are paged to the
/// compute device one parameter at a time (QLoRA-style paged optimizer).
pub struct PagedAdamW {
    /// Learning rate.
    lr: f64,
    /// Exponential decay for the first moment (default 0.9).
    beta1: f64,
    /// Exponential decay for the second moment (default 0.999).
    beta2: f64,
    /// Numerical-stability epsilon added to the denominator (default 1e-8).
    eps: f64,
    /// Decoupled weight-decay coefficient.
    weight_decay: f64,
    /// Paged per-parameter moment state.
    state: PagedAdamWState,
    /// Set by `init`. NOTE(review): never read — `step_param` does not check
    /// it; consider enforcing initialization or removing the flag.
    initialized: bool,
}
263
impl PagedAdamW {
    /// Creates a paged AdamW with standard betas (0.9 / 0.999) and eps 1e-8.
    ///
    /// `max_gpu_memory` of 0 disables eviction of paged state.
    #[must_use]
    pub fn new(lr: f64, weight_decay: f64, page_size: usize, max_gpu_memory: usize) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay,
            state: PagedAdamWState::new(page_size, max_gpu_memory),
            initialized: false,
        }
    }

    /// Builder-style override of the beta coefficients.
    #[must_use]
    pub fn with_betas(mut self, beta1: f64, beta2: f64) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Allocates zeroed CPU moment state for every `(name, tensor)` pair.
    ///
    /// # Errors
    /// Propagates state-allocation failures.
    pub fn init(&mut self, params: &[(String, Tensor)]) -> Result<()> {
        for (name, param) in params {
            let shape = param.shape().dims();
            self.state.init_param(name, shape, param.device())?;
        }
        self.initialized = true;
        Ok(())
    }

    /// Sets the learning rate (used by LR schedulers between steps).
    pub fn set_lr(&mut self, lr: f64) {
        self.lr = lr;
    }

    /// Current learning rate.
    #[must_use]
    pub fn lr(&self) -> f64 {
        self.lr
    }

    /// Applies one AdamW update to `param` in place, paging this parameter's
    /// moment state to `param`'s device and back to CPU afterwards.
    ///
    /// Update rule (decoupled weight decay):
    ///   m <- b1*m + (1-b1)*g;  v <- b2*v + (1-b2)*g^2
    ///   param <- param - lr * (m_hat / (sqrt(v_hat) + eps) + wd*param)
    /// where m_hat, v_hat are bias-corrected by (1 - b^t).
    ///
    /// NOTE(review): this only rebinds the passed `Tensor` handle — the
    /// caller must persist the result (e.g. `Var::set`) for weights to
    /// actually change.
    ///
    /// # Errors
    /// Returns an error if no state exists for `name`, or on tensor-op failure.
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    pub fn step_param(&mut self, name: &str, param: &mut Tensor, grad: &Tensor) -> Result<()> {
        let device = param.device().clone();

        // Bring this parameter's moment state onto the compute device.
        let (mut exp_avg, mut exp_avg_sq) = self.state.page_to_device(name, &device)?;

        self.state.increment_step(name);
        let step = self.state.get_step(name);

        // First moment: m = b1*m + (1-b1)*g
        let beta1_tensor = Tensor::new(self.beta1 as f32, &device)?;
        let one_minus_beta1 = Tensor::new((1.0 - self.beta1) as f32, &device)?;
        exp_avg = exp_avg
            .broadcast_mul(&beta1_tensor)?
            .broadcast_add(&grad.broadcast_mul(&one_minus_beta1)?)?;

        // Second moment: v = b2*v + (1-b2)*g^2
        let beta2_tensor = Tensor::new(self.beta2 as f32, &device)?;
        let one_minus_beta2 = Tensor::new((1.0 - self.beta2) as f32, &device)?;
        let grad_sq = grad.sqr()?;
        exp_avg_sq = exp_avg_sq
            .broadcast_mul(&beta2_tensor)?
            .broadcast_add(&grad_sq.broadcast_mul(&one_minus_beta2)?)?;

        // Bias correction: 1 - b^t (per-parameter step count).
        let bias_correction1 = 1.0 - self.beta1.powi(step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(step as i32);

        let bc1_tensor = Tensor::new(bias_correction1 as f32, &device)?;
        let bc2_tensor = Tensor::new(bias_correction2 as f32, &device)?;

        let exp_avg_corrected = exp_avg.broadcast_div(&bc1_tensor)?;
        let exp_avg_sq_corrected = exp_avg_sq.broadcast_div(&bc2_tensor)?;

        // Denominator: sqrt(v_hat) + eps.
        let denom = exp_avg_sq_corrected
            .sqrt()?
            .broadcast_add(&Tensor::new(self.eps as f32, &device)?)?;
        let step_size = Tensor::new(self.lr as f32, &device)?;

        // Decoupled weight decay: lr * (update + wd * param).
        let update = exp_avg_corrected.broadcast_div(&denom)?;
        let weight_decay_term =
            param.broadcast_mul(&Tensor::new(self.weight_decay as f32, &device)?)?;
        let full_update = update
            .broadcast_add(&weight_decay_term)?
            .broadcast_mul(&step_size)?;

        *param = param.broadcast_sub(&full_update)?;

        // Move the updated moment state back to CPU storage.
        self.state.page_to_cpu(name, &exp_avg, &exp_avg_sq)?;

        Ok(())
    }

    /// Returns `(cpu_bytes, gpu_bytes)` currently used by optimizer state
    /// (CPU side counts both moment maps at 4 bytes per F32 element).
    #[must_use]
    pub fn memory_stats(&self) -> (usize, usize) {
        let cpu_bytes: usize = self
            .state
            .exp_avg
            .values()
            .chain(self.state.exp_avg_sq.values())
            .map(|t| t.elem_count() * 4)
            .sum();
        (cpu_bytes, self.state.current_gpu_usage)
    }
}
402
/// Drives QLoRA fine-tuning: owns the `VarMap` of trainable weights, the
/// `peft_rs` training-state tracker, and one of two optimizers.
pub struct QLoraTrainer {
    /// Training hyperparameters.
    config: QLoraTrainingConfig,
    /// Step/epoch/LR-schedule tracker from `peft_rs`.
    state: AdapterTrainingState,
    /// Device used for model computation.
    device: Device,
    /// Registry of trainable variables; layers must register their weights
    /// here via `var_builder` or the optimizer will find no parameters.
    varmap: VarMap,
    /// Standard candle AdamW (used when `use_paged_optimizer` is false).
    optimizer: Option<AdamW>,
    /// CPU-paged AdamW (used when `use_paged_optimizer` is true).
    paged_optimizer: Option<PagedAdamW>,
    /// Micro-batches seen since the last optimizer step.
    accumulation_step: usize,
}
430
431impl QLoraTrainer {
432 #[must_use]
441 pub fn new(config: QLoraTrainingConfig, device: Device) -> Self {
442 let state = AdapterTrainingState::new(config.adapter_config.clone());
443 Self {
444 config,
445 state,
446 device,
447 varmap: VarMap::new(),
448 optimizer: None,
449 paged_optimizer: None,
450 accumulation_step: 0,
451 }
452 }
453
454 #[must_use]
468 pub fn var_builder(&self) -> VarBuilder<'_> {
469 VarBuilder::from_varmap(&self.varmap, DType::F32, &self.device)
470 }
471
472 pub fn init_optimizer(&mut self, layers: &[&QuantizedLinear]) -> Result<()> {
492 if self.config.use_paged_optimizer {
493 let mut paged = PagedAdamW::new(
495 self.config.adapter_config.learning_rate,
496 self.config.adapter_config.weight_decay,
497 self.config.page_size,
498 self.config.max_optimizer_memory,
499 );
500
501 let vars = self.varmap.all_vars();
503 if vars.is_empty() {
504 return Err(QLoraError::InvalidConfig(
505 "No trainable parameters found. Layers must be created using trainer.var_builder() \
506 so `LoRA` weights are registered in the `VarMap`.".into()
507 ));
508 }
509
510 let params: Vec<(String, Tensor)> = self
512 .varmap
513 .data()
514 .lock()
515 .unwrap()
516 .iter()
517 .map(|(name, var)| (name.clone(), var.as_tensor().clone()))
518 .collect();
519
520 paged.init(¶ms)?;
521 self.paged_optimizer = Some(paged);
522
523 let _ = layers.len();
525 } else {
526 let vars = self.varmap.all_vars();
528 if vars.is_empty() {
529 return Err(QLoraError::InvalidConfig(
530 "No trainable parameters found. Layers must be created using trainer.var_builder() \
531 so `LoRA` weights are registered in the `VarMap`.".into()
532 ));
533 }
534
535 let params = ParamsAdamW {
536 lr: self.config.adapter_config.learning_rate,
537 weight_decay: self.config.adapter_config.weight_decay,
538 beta1: 0.9,
539 beta2: 0.999,
540 eps: 1e-8,
541 };
542 self.optimizer = Some(AdamW::new(vars, params)?);
543 }
544 Ok(())
545 }
546
547 #[must_use]
549 pub fn state(&self) -> &AdapterTrainingState {
550 &self.state
551 }
552
553 #[must_use]
555 pub fn current_lr(&self) -> f64 {
556 self.state.current_lr()
557 }
558
559 #[must_use]
561 pub fn global_step(&self) -> usize {
562 self.state.global_step
563 }
564
565 #[must_use]
567 pub fn epoch(&self) -> usize {
568 self.state.epoch
569 }
570
571 #[allow(clippy::cast_precision_loss, clippy::excessive_nesting)]
596 pub fn training_step(
597 &mut self,
598 layers: &[&QuantizedLinear],
599 input: &Tensor,
600 targets: &Tensor,
601 ) -> Result<f64> {
602 let mut output = input.clone();
604 for layer in layers {
605 output = layer.forward(&output)?;
606 }
607
608 let loss = output.sub(targets)?.sqr()?.mean_all()?;
610
611 let accum_steps = self.config.adapter_config.gradient_accumulation_steps;
613 let scaled_loss = if accum_steps > 1 {
614 let scale = Tensor::new(1.0 / accum_steps as f32, loss.device())?;
615 loss.broadcast_mul(&scale)?
616 } else {
617 loss.clone()
618 };
619
620 let loss_value = f64::from(loss.to_scalar::<f32>()?);
621
622 self.accumulation_step += 1;
624
625 if let Some(ref mut optimizer) = self.optimizer {
627 if self.accumulation_step >= accum_steps {
628 if let Some(max_norm) = self.config.adapter_config.max_grad_norm {
630 let _ = max_norm; }
633
634 optimizer.backward_step(&scaled_loss)?;
636 self.accumulation_step = 0;
637 } else {
638 let _ = scaled_loss.backward();
641 }
642 } else if let Some(ref mut paged_optimizer) = self.paged_optimizer {
643 if self.accumulation_step >= accum_steps {
645 let grads = scaled_loss.backward()?;
647
648 let mut varmap_data = self.varmap.data().lock().unwrap();
650 for (name, var) in varmap_data.iter_mut() {
651 if let Some(grad) = grads.get(var.as_tensor()) {
652 let mut param = var.as_tensor().clone();
653 paged_optimizer.step_param(name, &mut param, grad)?;
654 }
657 }
658 drop(varmap_data);
659 self.accumulation_step = 0;
660 } else {
661 let _ = scaled_loss.backward();
663 }
664 }
665
666 let should_log = self.state.step();
668 if should_log && self.state.global_step.is_multiple_of(self.config.log_every) {
669 #[cfg(feature = "logging")]
670 log::info!(
671 "Step {} | Loss: {:.4} | LR: {:.2e}",
672 self.state.global_step,
673 loss_value,
674 self.current_lr()
675 );
676 }
677
678 Ok(loss_value)
679 }
680
681 pub fn training_step_lm(
699 &mut self,
700 layers: &[&QuantizedLinear],
701 input: &Tensor,
702 target_ids: &Tensor,
703 ) -> Result<f64> {
704 let mut logits = input.clone();
706 for layer in layers {
707 logits = layer.forward(&logits)?;
708 }
709
710 let loss = cross_entropy_loss(&logits, target_ids)?;
712 let loss_value = f64::from(loss.to_scalar::<f32>()?);
713
714 if let Some(ref mut optimizer) = self.optimizer {
716 optimizer.backward_step(&loss)?;
717 } else if let Some(ref mut paged_optimizer) = self.paged_optimizer {
718 let grads = loss.backward()?;
720
721 let mut varmap_data = self.varmap.data().lock().unwrap();
722 for (name, var) in varmap_data.iter_mut() {
723 if let Some(grad) = grads.get(var.as_tensor()) {
724 let mut param = var.as_tensor().clone();
725 paged_optimizer.step_param(name, &mut param, grad)?;
726 }
727 }
728 drop(varmap_data);
729 }
730
731 self.state.step();
733
734 Ok(loss_value)
735 }
736
737 pub fn start_epoch(&mut self) {
739 self.state.new_epoch();
740 self.accumulation_step = 0;
741 #[cfg(feature = "logging")]
742 log::info!("Starting epoch {}", self.state.epoch);
743 }
744
745 #[must_use]
747 pub fn should_continue(&self) -> bool {
748 self.state.epoch < self.config.num_epochs
749 }
750
751 pub fn update_lr(&mut self) {
753 let lr = self.current_lr();
754 if let Some(ref mut optimizer) = self.optimizer {
755 optimizer.set_learning_rate(lr);
756 }
757 if let Some(ref mut paged) = self.paged_optimizer {
758 paged.set_lr(lr);
759 }
760 }
761
762 #[must_use]
764 pub fn config(&self) -> &QLoraTrainingConfig {
765 &self.config
766 }
767
768 #[must_use]
770 pub fn optimizer_memory_stats(&self) -> Option<(usize, usize)> {
771 self.paged_optimizer.as_ref().map(PagedAdamW::memory_stats)
772 }
773
774 pub fn zero_grad(&mut self) {
779 self.accumulation_step = 0;
780 }
781}
782
783pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
795 let (batch, seq_len, vocab_size) = logits.dims3()?;
796
797 let flat_logits = logits.reshape(&[batch * seq_len, vocab_size])?;
799
800 let flat_targets = targets.reshape(&[batch * seq_len])?;
802
803 let log_probs = candle_nn::ops::log_softmax(&flat_logits, 1)?;
805
806 let target_indices = flat_targets.unsqueeze(1)?;
808 let gathered = log_probs.gather(&target_indices, 1)?;
809
810 let loss = gathered.neg()?.mean_all()?;
812
813 Ok(loss)
814}
815
/// Running aggregates collected over training steps.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Sum of all recorded per-step losses.
    pub total_loss: f64,
    /// Number of steps recorded since the last reset.
    pub num_steps: usize,
    /// Lowest single-step loss seen so far (starts at `f64::MAX`).
    pub best_loss: f64,
    /// Total number of tokens processed.
    pub tokens_processed: usize,
}

impl Default for TrainingMetrics {
    /// Matches `TrainingMetrics::new`. The previously derived `Default` set
    /// `best_loss` to `0.0`, which a positive loss could never beat, so a
    /// default-constructed metrics object never tracked a best loss.
    fn default() -> Self {
        Self {
            total_loss: 0.0,
            num_steps: 0,
            best_loss: f64::MAX,
            tokens_processed: 0,
        }
    }
}
828
829impl TrainingMetrics {
830 #[must_use]
832 pub fn new() -> Self {
833 Self {
834 total_loss: 0.0,
835 num_steps: 0,
836 best_loss: f64::MAX,
837 tokens_processed: 0,
838 }
839 }
840
841 pub fn update(&mut self, loss: f64, num_tokens: usize) {
843 self.total_loss += loss;
844 self.num_steps += 1;
845 self.tokens_processed += num_tokens;
846 if loss < self.best_loss {
847 self.best_loss = loss;
848 }
849 }
850
851 #[must_use]
853 #[allow(clippy::cast_precision_loss)]
854 pub fn average_loss(&self) -> f64 {
855 if self.num_steps == 0 {
856 0.0
857 } else {
858 self.total_loss / self.num_steps as f64
859 }
860 }
861
862 pub fn reset(&mut self) {
864 self.total_loss = 0.0;
865 self.num_steps = 0;
866 self.tokens_processed = 0;
867 }
869}
870
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    // Default config carries the documented QLoRA defaults.
    #[test]
    fn test_training_config_default() {
        let config = QLoraTrainingConfig::default();
        assert_eq!(config.num_epochs, 3);
        assert_eq!(config.batch_size, 4);
        assert!((config.adapter_config.learning_rate - 2e-4).abs() < 1e-10);
    }

    // A fresh trainer starts at step 0, epoch 0.
    #[test]
    fn test_trainer_creation() {
        let config = QLoraTrainingConfig::default();
        let device = Device::Cpu;
        let trainer = QLoraTrainer::new(config, device);

        assert_eq!(trainer.global_step(), 0);
        assert_eq!(trainer.epoch(), 0);
    }

    // Metrics accumulate step count, average, and best loss.
    #[test]
    fn test_training_metrics() {
        let mut metrics = TrainingMetrics::new();

        metrics.update(0.5, 128);
        metrics.update(0.4, 128);
        metrics.update(0.3, 128);

        assert_eq!(metrics.num_steps, 3);
        assert!((metrics.average_loss() - 0.4).abs() < 1e-10);
        assert!((metrics.best_loss - 0.3).abs() < 1e-10);
    }

    // Cross-entropy over (batch, seq, vocab) logits reduces to a scalar.
    #[test]
    fn test_cross_entropy_loss_shape() {
        let device = Device::Cpu;
        let batch = 2;
        let seq_len = 10;
        let vocab_size = 100;

        let logits = Tensor::zeros(&[batch, seq_len, vocab_size], DType::F32, &device).unwrap();
        let targets = Tensor::zeros(&[batch, seq_len], DType::U32, &device).unwrap();

        let loss = cross_entropy_loss(&logits, &targets).unwrap();
        let dims: &[usize] = loss.dims();
        assert!(dims.is_empty(), "Expected scalar loss, got dims: {dims:?}");
    }

    // should_continue flips to false once config.num_epochs is reached.
    #[test]
    fn test_trainer_epoch_progression() {
        let config = QLoraTrainingConfig {
            num_epochs: 2,
            ..Default::default()
        };
        let device = Device::Cpu;
        let mut trainer = QLoraTrainer::new(config, device);

        assert!(trainer.should_continue());
        trainer.start_epoch();
        assert_eq!(trainer.epoch(), 1);
        assert!(trainer.should_continue());
        trainer.start_epoch();
        assert_eq!(trainer.epoch(), 2);
        assert!(!trainer.should_continue());
    }
}