//! LoRA (Low-Rank Adaptation) adapter layers and their DoRA
//! (Weight-Decomposed Low-Rank Adaptation) extension, built on candle.

#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::needless_pass_by_value)]

use candle_core::{DType, Device, Module, Tensor};
use candle_nn::{linear_no_bias, Linear, VarBuilder, VarMap};
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Emit a one-time warning when an adapter is created on the CPU device.
fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "peft-rs: CPU device in use. CUDA is the intended default; enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

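/// Configuration shared by [`LoraLayer`] and [`DoraLayer`].
///
/// Illustrative usage (the values are arbitrary):
///
/// ```ignore
/// let config = LoraConfig {
///     r: 16,
///     alpha: 32,
///     use_rslora: true,
///     ..Default::default()
/// };
/// assert!(config.validate().is_ok());
/// ```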
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank `r` of the low-rank update (must be > 0).
    pub r: usize,

    /// Scaling numerator `alpha`; the effective scale is `alpha / r`
    /// (or `alpha / sqrt(r)` with rsLoRA).
    pub alpha: usize,

    /// Dropout probability (validated to lie in `[0, 1]`).
    #[serde(default)]
    pub dropout: f64,

    /// Names of the modules the adapter targets (defaults to `q_proj` and `v_proj`).
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Initialization scheme for the LoRA weights.
    #[serde(default)]
    pub init_lora_weights: LoraInitialization,

    /// Enable DoRA (weight-decomposed LoRA).
    #[serde(default)]
    pub use_dora: bool,

    /// Use rank-stabilized scaling, `alpha / sqrt(r)`, instead of `alpha / r`.
    #[serde(default)]
    pub use_rslora: bool,

    /// Number of LoftQ iterations; a value > 0 enables LoftQ-style
    /// initialization of both LoRA matrices.
    #[serde(default)]
    pub loftq_iterations: usize,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

/// Weight initialization scheme for the LoRA matrices.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub enum LoraInitialization {
    #[default]
    Standard,
    Gaussian,
}

impl Default for LoraConfig {
    fn default() -> Self {
        Self {
            r: 8,
            alpha: 16,
            dropout: 0.0,
            target_modules: default_target_modules(),
            init_lora_weights: LoraInitialization::Standard,
            use_dora: false,
            use_rslora: false,
            loftq_iterations: 0,
        }
    }
}

impl AdapterConfig for LoraConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(PeftError::InvalidConfig(
                "dropout must be between 0 and 1".into(),
            ));
        }
        Ok(())
    }
}

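/// A low-rank adapter holding `A` (`r x in_features`) and `B`
/// (`out_features x r`); its contribution is `scaling * B * A * x`.
///
/// Illustrative usage (dimensions and device are arbitrary, and `base_out`
/// stands in for the frozen layer's output):
///
/// ```ignore
/// let layer = LoraLayer::new_with_zeros(768, 768, LoraConfig::default(), &Device::Cpu)?;
/// let input = Tensor::zeros((1, 10, 768), DType::F32, &Device::Cpu)?;
/// let delta = layer.forward(&input, None)?;           // adapter output only
/// let full = layer.forward(&input, Some(&base_out))?; // base output + adapter
/// ```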
pub struct LoraLayer {
    /// Low-rank matrix `A` with shape `(r, in_features)`.
    lora_a: Linear,
    /// Low-rank matrix `B` with shape `(out_features, r)`.
    lora_b: Linear,
    /// Scaling factor applied to the low-rank update.
    scaling: f64,
    /// The configuration this layer was built from.
    config: LoraConfig,
    /// Input feature dimension of the adapted layer.
    in_features: usize,
    /// Output feature dimension of the adapted layer.
    out_features: usize,
    /// Whether training updates are currently disabled for this layer.
    frozen: bool,
}

impl LoraLayer {
    /// Create a LoRA layer whose weights are provided by a [`VarBuilder`]
    /// (for example, backed by a `VarMap` or loaded from a checkpoint).
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or if the weights
    /// cannot be created.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoraConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        config.validate()?;

        // rsLoRA scales by alpha / sqrt(r) instead of alpha / r, which keeps
        // the update magnitude stable as the rank grows.
        let scaling = if config.use_rslora {
            config.alpha as f64 / (config.r as f64).sqrt()
        } else {
            config.alpha as f64 / config.r as f64
        };

        // A: (r, in_features), B: (out_features, r); both without bias.
        let lora_a = linear_no_bias(in_features, config.r, vb.pp("lora_a"))?;
        let lora_b = linear_no_bias(config.r, out_features, vb.pp("lora_b"))?;

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            config,
            in_features,
            out_features,
            frozen: false,
        })
    }

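    /// Create a LoRA layer directly on `device` with freshly initialized
    /// weights: `A` is drawn from a scaled Gaussian and `B` starts at zero,
    /// so the adapter initially contributes nothing. When
    /// `loftq_iterations > 0`, both matrices receive small random values
    /// instead.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or tensor creation
    /// fails.
    ///
    /// Illustrative usage (dimensions are arbitrary):
    ///
    /// ```ignore
    /// let config = LoraConfig { r: 8, alpha: 16, ..Default::default() };
    /// let layer = LoraLayer::new_with_zeros(768, 768, config, &Device::Cpu)?;
    /// assert_eq!(layer.rank(), 8);
    /// ```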
    pub fn new_with_zeros(
        in_features: usize,
        out_features: usize,
        config: LoraConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;
        warn_cpu_fallback(device);

        let scaling = if config.use_rslora {
            config.alpha as f64 / (config.r as f64).sqrt()
        } else {
            config.alpha as f64 / config.r as f64
        };
        let dtype = DType::F32;

        let (a_weight, b_weight) = if config.loftq_iterations > 0 {
            // LoftQ-style initialization: give both A and B small random
            // values so the product B * A is non-zero from the start.
            let std = (1.0 / in_features as f64).sqrt() * 0.1;
            let a = Tensor::randn(0.0f32, std as f32, (config.r, in_features), device)?;
            let b = Tensor::randn(0.0f32, std as f32, (out_features, config.r), device)?;
            (a, b)
        } else {
            // Standard LoRA initialization: random A, zero B, so the adapter
            // starts with a zero contribution.
            let std = (1.0 / in_features as f64).sqrt();
            let a = Tensor::randn(0.0f32, std as f32, (config.r, in_features), device)?;
            let b = Tensor::zeros((out_features, config.r), dtype, device)?;
            (a, b)
        };

        let lora_a = Linear::new(a_weight, None);
        let lora_b = Linear::new(b_weight, None);

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            config,
            in_features,
            out_features,
            frozen: false,
        })
    }

    /// The scaling factor applied to the low-rank update.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// The rank `r` of the low-rank decomposition.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// The `A` and `B` weight tensors, in that order.
    #[must_use]
    pub fn weights(&self) -> (&Tensor, &Tensor) {
        (self.lora_a.weight(), self.lora_b.weight())
    }
}

impl Adapter for LoraLayer {
    type Config = LoraConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Low-rank path: x -> A -> B, then scale by alpha / r (or alpha / sqrt(r)).
        let lora_out = self.lora_a.forward(input)?;
        let lora_out = self.lora_b.forward(&lora_out)?;
        let scaling = Tensor::new(self.scaling as f32, lora_out.device())?;
        let lora_out = lora_out.broadcast_mul(&scaling)?;

        // Add to the frozen layer's output when provided; otherwise return the
        // adapter contribution on its own.
        match base_output {
            Some(base) => Ok(base.broadcast_add(&lora_out)?),
            None => Ok(lora_out),
        }
    }

    fn num_parameters(&self) -> usize {
        self.config.r * (self.in_features + self.out_features)
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

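/// Merging folds the adapter into the base weight, `W' = W + scaling * B * A`;
/// unmerging subtracts the same update.
///
/// Illustrative round trip (the tensors are assumed to be shape-compatible):
///
/// ```ignore
/// let merged = layer.merge(&base_weight)?;
/// let restored = layer.unmerge(&merged)?; // numerically close to `base_weight`
/// ```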
impl Mergeable for LoraLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let a_weight = self.lora_a.weight();
        let b_weight = self.lora_b.weight();

        // delta_W = scaling * B * A, with shape (out_features, in_features).
        let delta_w = b_weight.matmul(a_weight)?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let a_weight = self.lora_a.weight();
        let b_weight = self.lora_b.weight();

        let delta_w = b_weight.matmul(a_weight)?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for LoraLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // Nothing to register here: the weights are created either directly as
        // tensors (`new_with_zeros`) or through the `VarBuilder` passed to `new`.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

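/// A DoRA layer: a [`LoraLayer`] plus a learned per-output-channel magnitude
/// vector. The effective weight is the magnitude times the row-normalized
/// direction of `W + scaling * B * A`.
///
/// Illustrative usage (dimensions and device are arbitrary, and `base_out`
/// stands in for the frozen layer's output):
///
/// ```ignore
/// let base_weight = Tensor::randn(0.0f32, 0.02, (768, 768), &Device::Cpu)?;
/// let layer = DoraLayer::new(768, 768, LoraConfig::default(), &Device::Cpu, Some(&base_weight))?;
/// let input = Tensor::zeros((1, 10, 768), DType::F32, &Device::Cpu)?;
/// let out = layer.forward(&input, Some(&base_out))?;
/// ```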
pub struct DoraLayer {
    /// The underlying low-rank adapter.
    lora: LoraLayer,
    /// Per-output-channel magnitude vector of shape `(out_features,)`.
    magnitude: Tensor,
    /// The frozen base weight, if available, of shape `(out_features, in_features)`.
    base_weight: Option<Tensor>,
}

impl DoraLayer {
    /// Create a DoRA layer. When `base_weight` is provided, the magnitude is
    /// initialized to the row-wise L2 norm of the base weight; otherwise it
    /// starts as a vector of ones.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or tensor creation
    /// fails.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoraConfig,
        device: &Device,
        base_weight: Option<&Tensor>,
    ) -> Result<Self> {
        let lora = LoraLayer::new_with_zeros(in_features, out_features, config, device)?;

        let magnitude = if let Some(weight) = base_weight {
            // ||W_i||_2 for each output row i.
            weight.sqr()?.sum(1)?.sqrt()?
        } else {
            Tensor::ones(out_features, DType::F32, device)?
        };

        Ok(Self {
            lora,
            magnitude,
            base_weight: base_weight.cloned(),
        })
    }

    /// The per-output-channel magnitude vector.
    #[must_use]
    pub fn magnitude(&self) -> &Tensor {
        &self.magnitude
    }

    /// The underlying LoRA layer.
    #[must_use]
    pub fn lora_layer(&self) -> &LoraLayer {
        &self.lora
    }

    /// Store (or replace) the frozen base weight used by the DoRA forward pass.
    pub fn set_base_weight(&mut self, weight: Tensor) {
        self.base_weight = Some(weight);
    }

    /// Direction component: `W + scaling * B * A`, normalized row-wise.
    fn compute_direction(&self, base_weight: &Tensor) -> Result<Tensor> {
        let a_weight = self.lora.lora_a.weight();
        let b_weight = self.lora.lora_b.weight();
        let delta_w = b_weight.matmul(a_weight)?;
        let scaling = Tensor::new(self.lora.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        let combined = base_weight.broadcast_add(&delta_w)?;

        // Row-wise L2 norms, kept as a column vector for broadcasting.
        let norms = combined.sqr()?.sum(1)?.sqrt()?;
        let norms = norms.reshape((self.lora.out_features, 1))?;

        // Guard against division by zero.
        let epsilon = Tensor::new(1e-8_f32, norms.device())?;
        let safe_norms = norms.broadcast_add(&epsilon)?;

        Ok(combined.broadcast_div(&safe_norms)?)
    }
}

impl Adapter for DoraLayer {
    type Config = LoraConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        if let (Some(base_weight), Some(_base_out)) = (&self.base_weight, base_output) {
            // Full DoRA path: recompute the output from the input using the
            // normalized direction and the magnitude vector; the precomputed
            // base output is not needed in this branch.
            let direction = self.compute_direction(base_weight)?;

            // Assumes a 3-D `(batch, seq, in_features)` input.
            let input_dims = input.dims();
            let batch_seq = input_dims[0] * input_dims[1];
            let input_2d = input.reshape((batch_seq, self.lora.in_features))?;

            let out = input_2d.matmul(&direction.t()?)?;

            // Scale each output channel by its learned magnitude.
            let mag_2d = self.magnitude.reshape((1, self.lora.out_features))?;
            let out = out.broadcast_mul(&mag_2d)?;

            let out = out.reshape((input_dims[0], input_dims[1], self.lora.out_features))?;

            Ok(out)
        } else {
            // Without a stored base weight, fall back to the plain LoRA path.
            self.lora.forward(input, base_output)
        }
    }

    fn num_parameters(&self) -> usize {
        // LoRA parameters plus one magnitude value per output channel.
        self.lora.num_parameters() + self.lora.out_features
    }

    fn config(&self) -> &Self::Config {
        self.lora.config()
    }
}

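/// DoRA merging produces `W' = m * normalize_rows(W + scaling * B * A)`.
/// Unmerging is exact only when the original base weight is stored on the
/// layer; otherwise it falls back to an approximation.
///
/// Illustrative sketch (the tensors are assumed to be shape-compatible):
///
/// ```ignore
/// let merged = dora.merge(&base_weight)?;
/// let restored = dora.unmerge(&merged)?; // exact when `base_weight` was stored
/// ```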
impl Mergeable for DoraLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // Merged weight: per-row magnitude times the normalized direction.
        let direction = self.compute_direction(base_weight)?;
        let mag = self.magnitude.reshape((self.lora.out_features, 1))?;
        Ok(direction.broadcast_mul(&mag)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        // If the original base weight was stored, return it directly; this is
        // exact, unlike trying to invert the normalization.
        if let Some(base_weight) = &self.base_weight {
            return Ok(base_weight.clone());
        }

        // Otherwise fall back to subtracting the scaled low-rank update. This
        // ignores the magnitude rescaling and is therefore only approximate.
        let a_weight = self.lora.lora_a.weight();
        let b_weight = self.lora.lora_b.weight();
        let delta_w = b_weight.matmul(a_weight)?;
        let scaling = Tensor::new(self.lora.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for DoraLayer {
    fn register_parameters(&self, var_map: &mut VarMap, prefix: &str) -> Result<()> {
        self.lora.register_parameters(var_map, prefix)
    }

    fn freeze(&mut self) {
        self.lora.freeze();
    }

    fn unfreeze(&mut self) {
        self.lora.unfreeze();
    }

    fn is_frozen(&self) -> bool {
        self.lora.is_frozen()
    }
}

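/// Serialization support: the state dict stores the `A` and `B` weights under
/// the keys `"lora_a.weight"` and `"lora_b.weight"`.
///
/// Illustrative round trip through the io helpers (the path is arbitrary):
///
/// ```ignore
/// let path = std::path::Path::new("adapter.safetensors");
/// save_adapter_weights(&layer, path)?;
/// load_adapter_weights(&mut layer, path, &device)?;
/// ```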
impl crate::io::SaveLoad for LoraLayer {
    #[allow(clippy::similar_names)]
    fn state_dict(&self) -> Result<std::collections::HashMap<String, Tensor>> {
        use std::collections::HashMap;

        let mut state_dict = HashMap::new();

        let lora_a_weight = self.lora_a.weight();
        state_dict.insert("lora_a.weight".to_string(), lora_a_weight.clone());

        let lora_b_weight = self.lora_b.weight();
        state_dict.insert("lora_b.weight".to_string(), lora_b_weight.clone());

        Ok(state_dict)
    }

    #[allow(clippy::similar_names)]
    fn load_state_dict(
        &mut self,
        state_dict: std::collections::HashMap<String, Tensor>,
    ) -> Result<()> {
        if !state_dict.contains_key("lora_a.weight") || !state_dict.contains_key("lora_b.weight") {
            return Err(PeftError::WeightLoad(
                "Missing required keys in state_dict".to_string(),
            ));
        }

        let lora_a_shape = state_dict["lora_a.weight"].dims();
        let lora_b_shape = state_dict["lora_b.weight"].dims();

        if lora_a_shape != [self.config.r, self.in_features] {
            return Err(PeftError::ShapeMismatch {
                expected: vec![self.config.r, self.in_features],
                actual: lora_a_shape.to_vec(),
            });
        }

        if lora_b_shape != [self.out_features, self.config.r] {
            return Err(PeftError::ShapeMismatch {
                expected: vec![self.out_features, self.config.r],
                actual: lora_b_shape.to_vec(),
            });
        }

        // Shapes check out: replace the current weights with the loaded ones.
        self.lora_a = Linear::new(state_dict["lora_a.weight"].clone(), None);
        self.lora_b = Linear::new(state_dict["lora_b.weight"].clone(), None);

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_lora_config_default() {
        let config = LoraConfig::default();
        assert_eq!(config.r, 8);
        assert_eq!(config.alpha, 16);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_lora_config_invalid_rank() {
        let config = LoraConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_lora_layer_creation() {
        let config = LoraConfig::default();
        let device = Device::Cpu;
        let layer = LoraLayer::new_with_zeros(768, 768, config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_lora_forward_shape() {
        let config = LoraConfig::default();
        let device = Device::Cpu;
        let layer = LoraLayer::new_with_zeros(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_lora_num_parameters() {
        let config = LoraConfig {
            r: 8,
            alpha: 16,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoraLayer::new_with_zeros(768, 768, config, &device).unwrap();

        // r * (in_features + out_features) = 8 * (768 + 768)
        assert_eq!(layer.num_parameters(), 12288);
    }

    #[test]
    fn test_dora_layer_creation() {
        let config = LoraConfig {
            use_dora: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = DoraLayer::new(768, 768, config, &device, None);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_dora_layer_with_base_weight() {
        let config = LoraConfig {
            use_dora: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let base_weight = Tensor::randn(0.0f32, 0.02, (768, 768), &device).unwrap();
        let layer = DoraLayer::new(768, 768, config, &device, Some(&base_weight));
        assert!(layer.is_ok());

        // One magnitude value per output row of the base weight.
        let layer = layer.unwrap();
        assert_eq!(layer.magnitude().dims(), &[768]);
    }

    #[test]
    fn test_dora_num_parameters() {
        let config = LoraConfig {
            r: 8,
            use_dora: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = DoraLayer::new(768, 768, config, &device, None).unwrap();

        // LoRA parameters plus one magnitude value per output channel.
        assert_eq!(layer.num_parameters(), 12288 + 768);
    }

    #[test]
    fn test_dora_fallback_forward() {
        // Without a stored base weight, DoRA falls back to the plain LoRA path.
        let config = LoraConfig {
            use_dora: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = DoraLayer::new(768, 768, config, &device, None).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_lora_save_load_weights() -> Result<()> {
        use crate::io::{load_adapter_weights, save_adapter_weights, SaveLoad};
        use tempfile::TempDir;

        let device = Device::Cpu;
        let config = LoraConfig::default();
        let layer = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;

        let temp_dir = TempDir::new().map_err(|e| PeftError::Io(e.to_string()))?;
        let weights_path = temp_dir.path().join("lora_weights.safetensors");

        let original_state = layer.state_dict()?;
        assert_eq!(original_state.len(), 2);
        assert!(original_state.contains_key("lora_a.weight"));
        assert!(original_state.contains_key("lora_b.weight"));

        save_adapter_weights(&layer, &weights_path)?;
        assert!(weights_path.exists());

        let mut loaded_layer = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        load_adapter_weights(&mut loaded_layer, &weights_path, &device)?;

        let loaded_state = loaded_layer.state_dict()?;
        assert_eq!(loaded_state.len(), original_state.len());
        assert_eq!(
            loaded_state["lora_a.weight"].dims(),
            original_state["lora_a.weight"].dims()
        );
        assert_eq!(
            loaded_state["lora_b.weight"].dims(),
            original_state["lora_b.weight"].dims()
        );

        Ok(())
    }

    #[test]
    fn test_rslora_scaling() {
        // Standard scaling: alpha / r = 16 / 8 = 2.0.
        let config_standard = LoraConfig {
            r: 8,
            alpha: 16,
            use_rslora: false,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer_standard = LoraLayer::new_with_zeros(768, 768, config_standard, &device).unwrap();
        assert!((layer_standard.scaling() - 2.0).abs() < 1e-10);

        // rsLoRA scaling: alpha / sqrt(r).
        let config_rslora = LoraConfig {
            r: 8,
            alpha: 16,
            use_rslora: true,
            ..Default::default()
        };
        let layer_rslora = LoraLayer::new_with_zeros(768, 768, config_rslora, &device).unwrap();
        let expected_rslora_scaling = 16.0 / 8.0_f64.sqrt();
        assert!((layer_rslora.scaling() - expected_rslora_scaling).abs() < 1e-10);
    }

    #[test]
    fn test_rslora_higher_rank_stability() {
        let device = Device::Cpu;

        for rank in [8, 16, 32, 64, 128] {
            let config_standard = LoraConfig {
                r: rank,
                alpha: 32,
                use_rslora: false,
                ..Default::default()
            };
            let config_rslora = LoraConfig {
                r: rank,
                alpha: 32,
                use_rslora: true,
                ..Default::default()
            };

            let layer_standard =
                LoraLayer::new_with_zeros(768, 768, config_standard, &device).unwrap();
            let layer_rslora = LoraLayer::new_with_zeros(768, 768, config_rslora, &device).unwrap();

            // alpha / sqrt(r) >= alpha / r for every r >= 1.
            assert!(layer_rslora.scaling() >= layer_standard.scaling());
        }
    }

    #[test]
    fn test_loftq_initialization() {
        let config = LoraConfig {
            r: 8,
            alpha: 16,
            loftq_iterations: 4,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoraLayer::new_with_zeros(768, 768, config, &device).unwrap();

        // Unlike standard initialization, B should not be all zeros.
        let b_weight = layer.lora_b.weight();
        let b_sum = b_weight
            .abs()
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap();
        assert!(
            b_sum > 0.0,
            "LoftQ should initialize B with non-zero values"
        );
    }
}