//! AdaLoRA: an SVD-style low-rank adapter (`A * diag(E) * B`) whose effective
//! rank is adapted during training by masking entries of `E`.

#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::uninlined_format_args)]

use candle_core::{DType, Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for an AdaLoRA adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaLoraConfig {
    /// Target average rank after budget allocation.
    pub target_r: usize,

    /// Initial rank of each adapted layer before any pruning.
    pub init_r: usize,

    /// LoRA alpha used for scaling (`scaling = alpha / init_r`).
    pub alpha: usize,

    /// Dropout probability for the adapter path.
    #[serde(default)]
    pub dropout: f64,

    /// Names of the modules to adapt.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Number of warmup steps before rank pruning begins.
    #[serde(default)]
    pub tinit: usize,

    /// Number of final steps during which the rank pattern stays fixed.
    #[serde(default)]
    pub tfinal: usize,

    /// Interval, in steps, between budget reallocations.
    #[serde(default = "default_delta_t")]
    pub delta_t: usize,

    /// Exponential moving average coefficient for sensitivity smoothing.
    #[serde(default = "default_beta")]
    pub beta1: f64,

    /// Exponential moving average coefficient for uncertainty smoothing.
    #[serde(default = "default_beta")]
    pub beta2: f64,

    /// Weight of the orthogonality regularization term.
    #[serde(default = "default_orth_reg")]
    pub orth_reg_weight: f64,

    /// Total number of training steps, used by the rank schedule.
    pub total_step: usize,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

fn default_delta_t() -> usize {
    1
}

fn default_beta() -> f64 {
    0.85
}

fn default_orth_reg() -> f64 {
    0.5
}

impl Default for AdaLoraConfig {
    fn default() -> Self {
        Self {
            target_r: 8,
            init_r: 12,
            alpha: 16,
            dropout: 0.0,
            target_modules: default_target_modules(),
            tinit: 0,
            tfinal: 0,
            delta_t: default_delta_t(),
            beta1: default_beta(),
            beta2: default_beta(),
            orth_reg_weight: default_orth_reg(),
            total_step: 1000,
        }
    }
}

impl AdapterConfig for AdaLoraConfig {
    fn validate(&self) -> Result<()> {
        if self.init_r == 0 {
            return Err(PeftError::InvalidConfig("init_r must be > 0".into()));
        }
        if self.target_r == 0 {
            return Err(PeftError::InvalidConfig("target_r must be > 0".into()));
        }
        if self.target_r > self.init_r {
            return Err(PeftError::InvalidConfig(
                "target_r must be <= init_r".into(),
            ));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(PeftError::InvalidConfig(
                "dropout must be between 0 and 1".into(),
            ));
        }
        if self.total_step == 0 {
            return Err(PeftError::InvalidConfig("total_step must be > 0".into()));
        }
        if self.tinit >= self.total_step.saturating_sub(self.tfinal) {
            return Err(PeftError::InvalidConfig(
                "tinit must be < (total_step - tfinal) for budgeting phase".into(),
            ));
        }
        if !(0.0..=1.0).contains(&self.beta1) || !(0.0..=1.0).contains(&self.beta2) {
            return Err(PeftError::InvalidConfig(
                "beta1 and beta2 must be between 0 and 1".into(),
            ));
        }
        Ok(())
    }
}

/// An AdaLoRA layer: an SVD-style update `A * diag(E) * B` whose effective
/// rank is adapted during training by masking entries of `E`.
pub struct AdaLoraLayer {
    /// Left factor, shape `(out_features, init_r)`.
    lora_a: Tensor,
    /// Singular values, shape `(init_r,)`.
    lora_e: Tensor,
    /// Right factor, shape `(init_r, in_features)`.
    lora_b: Tensor,
    /// Scaling factor `alpha / init_r`.
    scaling: f64,
    config: AdaLoraConfig,
    in_features: usize,
    out_features: usize,
    /// Number of ranks currently kept by the mask.
    current_rank: usize,
    /// Binary mask over the `init_r` ranks, shape `(init_r,)`.
    rank_mask: Tensor,
    frozen: bool,
}

impl AdaLoraLayer {
    /// Creates a new AdaLoRA layer with `init_r` ranks, all initially active.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or tensor creation fails.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: AdaLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.init_r as f64;
        let dtype = DType::F32;

        // A: (out_features, init_r), Gaussian initialization.
        let std_a = (1.0 / out_features as f64).sqrt();
        let lora_a = Tensor::randn(0.0f32, std_a as f32, (out_features, config.init_r), device)?;

        // E: (init_r,), singular values initialized to a small constant.
        let lora_e = Tensor::ones(config.init_r, dtype, device)?;
        let lora_e = lora_e.broadcast_mul(&Tensor::new(0.01f32, device)?)?;

        // B: (init_r, in_features), Gaussian initialization.
        let std_b = (1.0 / in_features as f64).sqrt();
        let lora_b = Tensor::randn(0.0f32, std_b as f32, (config.init_r, in_features), device)?;

        // All ranks start out active.
        let rank_mask = Tensor::ones(config.init_r, dtype, device)?;

        let init_r = config.init_r;

        Ok(Self {
            lora_a,
            lora_e,
            lora_b,
            scaling,
            config,
            in_features,
            out_features,
            current_rank: init_r,
            rank_mask,
            frozen: false,
        })
    }

    /// The number of ranks currently active.
    #[must_use]
    pub fn current_rank(&self) -> usize {
        self.current_rank
    }

    /// The target rank after budget allocation.
    #[must_use]
    pub fn target_rank(&self) -> usize {
        self.config.target_r
    }

    /// The initial rank before any pruning.
    #[must_use]
    pub fn init_rank(&self) -> usize {
        self.config.init_r
    }

    /// The scaling factor applied to the adapter output.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Updates the rank mask from per-rank importance scores and a rank budget.
    ///
    /// A budget of zero disables every rank and a budget of at least `init_r`
    /// keeps every rank; otherwise, ranks whose importance is at or above the
    /// mean score are kept.
    pub fn update_rank_mask(&mut self, importance_scores: &Tensor, budget: usize) -> Result<()> {
        if budget >= self.config.init_r {
            self.rank_mask =
                Tensor::ones(self.config.init_r, DType::F32, importance_scores.device())?;
            self.current_rank = self.config.init_r;
        } else if budget == 0 {
            self.rank_mask =
                Tensor::zeros(self.config.init_r, DType::F32, importance_scores.device())?;
            self.current_rank = 0;
        } else {
            let scores = importance_scores.flatten_all()?;
            let mean: f32 = scores.mean_all()?.to_scalar()?;

            // `ge` accepts a plain scalar and broadcasts it over the scores;
            // comparing against a rank-0 tensor would trip the shape check.
            let mask = importance_scores.ge(mean)?;
            self.rank_mask = mask.to_dtype(DType::F32)?;

            let sum: f32 = self.rank_mask.sum_all()?.to_scalar()?;
            self.current_rank = sum as usize;
        }

        Ok(())
    }

    /// Orthogonality regularization loss `||AᵀA - I||² + ||BBᵀ - I||²`,
    /// which encourages the factors to stay close to orthogonal.
    pub fn orthogonal_regularization(&self) -> Result<Tensor> {
        // ||AᵀA - I||²
        let ata = self.lora_a.t()?.matmul(&self.lora_a)?;
        let eye_a = Tensor::eye(self.config.init_r, DType::F32, self.lora_a.device())?;
        let orth_loss_a = ata.broadcast_sub(&eye_a)?.sqr()?.sum_all()?;

        // ||BBᵀ - I||²
        let bbt = self.lora_b.matmul(&self.lora_b.t()?)?;
        let eye_b = Tensor::eye(self.config.init_r, DType::F32, self.lora_b.device())?;
        let orth_loss_b = bbt.broadcast_sub(&eye_b)?.sqr()?.sum_all()?;

        Ok(orth_loss_a.broadcast_add(&orth_loss_b)?)
    }

    /// Per-rank importance scores, here the magnitudes of the singular values.
    pub fn get_importance_scores(&self) -> Result<Tensor> {
        Ok(self.lora_e.abs()?)
    }
}

impl Adapter for AdaLoraLayer {
    type Config = AdaLoraConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Expects `input` of shape (batch, seq_len, in_features).
        let input_dims = input.dims();

        // Flatten to (batch * seq_len, in_features) for the matmuls.
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        // x Bᵀ: (batch * seq_len, init_r)
        let out = input_2d.matmul(&self.lora_b.t()?)?;

        // Scale each rank by its masked singular value.
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;
        let masked_e = masked_e.reshape((1, self.config.init_r))?;
        let out = out.broadcast_mul(&masked_e)?;

        // (x Bᵀ diag(E)) Aᵀ: (batch * seq_len, out_features)
        let out = out.matmul(&self.lora_a.t()?)?;

        // Restore the (batch, seq_len, out_features) shape.
        let out = out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        let scaling = Tensor::new(self.scaling as f32, out.device())?;
        let out = out.broadcast_mul(&scaling)?;

        // Add to the base layer output if one was provided.
        match base_output {
            Some(base) => Ok(base.broadcast_add(&out)?),
            None => Ok(out),
        }
    }

    fn num_parameters(&self) -> usize {
        // A: (out_features, init_r), E: (init_r,), B: (init_r, in_features)
        self.out_features * self.config.init_r
            + self.config.init_r
            + self.config.init_r * self.in_features
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for AdaLoraLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // Only the ranks kept by the mask contribute to the merged delta.
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;

        // A diag(E): scale each column of A by its singular value.
        let masked_e_col = masked_e.reshape((self.config.init_r, 1))?;
        let ae = self.lora_a.broadcast_mul(&masked_e_col.t()?)?;

        // ΔW = A diag(E) B, shape (out_features, in_features).
        let delta_w = ae.matmul(&self.lora_b)?;

        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        // Recompute the same delta as `merge` and subtract it.
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;

        let masked_e_col = masked_e.reshape((self.config.init_r, 1))?;
        let ae = self.lora_a.broadcast_mul(&masked_e_col.t()?)?;
        let delta_w = ae.matmul(&self.lora_b)?;

        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for AdaLoraLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // The low-rank factors are stored as plain `Tensor`s rather than `Var`s,
        // so there is nothing to register in the `VarMap` here.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adalora_config_default() {
        let config = AdaLoraConfig::default();
        assert_eq!(config.target_r, 8);
        assert_eq!(config.init_r, 12);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_adalora_config_invalid_rank() {
        let config = AdaLoraConfig {
            target_r: 16,
            init_r: 8,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_adalora_config_invalid_schedule() {
        let config = AdaLoraConfig {
            tinit: 500,
            tfinal: 600,
            total_step: 1000,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_adalora_layer_creation() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.init_rank(), 12);
        assert_eq!(layer.target_rank(), 8);
        assert_eq!(layer.current_rank(), 12);
    }

    #[test]
    fn test_adalora_forward_shape() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_adalora_num_parameters() {
        let config = AdaLoraConfig {
            init_r: 12,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 768 * 12 + 12 + 12 * 768);
    }
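
    // A round-trip sanity check (sketch): merging the adapter into a base weight
    // and unmerging it again should recover the base weight up to f32 rounding,
    // since both paths compute the same masked A * diag(E) * B delta.
    #[test]
    fn test_adalora_merge_unmerge_roundtrip() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        let base = Tensor::randn(0.0f32, 1.0f32, (64, 64), &device).unwrap();
        let merged = layer.merge(&base).unwrap();
        let restored = layer.unmerge(&merged).unwrap();

        // The sum of squared differences should be negligible.
        let diff: f32 = restored
            .broadcast_sub(&base)
            .unwrap()
            .sqr()
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(diff < 1e-6);
    }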

    #[test]
    fn test_adalora_importance_scores() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        let scores = layer.get_importance_scores().unwrap();
        assert_eq!(scores.dims(), &[12]);
    }
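
    // A sketch of the three budget regimes handled by `update_rank_mask`: a budget
    // of at least `init_r` keeps every rank, a zero budget drops them all, and
    // anything in between keeps the ranks scoring at or above the mean importance.
    #[test]
    fn test_adalora_update_rank_mask() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let mut layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        // Six scores above the mean (0.5) and six below, so six ranks survive.
        let scores = Tensor::new(
            &[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            &device,
        )
        .unwrap();
        layer.update_rank_mask(&scores, 6).unwrap();
        assert_eq!(layer.current_rank(), 6);

        layer.update_rank_mask(&scores, 12).unwrap();
        assert_eq!(layer.current_rank(), 12);

        layer.update_rank_mask(&scores, 0).unwrap();
        assert_eq!(layer.current_rank(), 0);
    }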

    #[test]
    fn test_adalora_orthogonal_regularization() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        let orth_loss = layer.orthogonal_regularization().unwrap();
        assert!(orth_loss.dims().is_empty());
    }
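
    // A sketch checking the residual path of `forward`: with an all-zero input the
    // adapter contributes nothing, so the output should equal the provided base output.
    #[test]
    fn test_adalora_forward_with_base_output() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 4, 64], DType::F32, &device).unwrap();
        let base = Tensor::ones(&[1, 4, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base)).unwrap();

        // Every element of the base output is 1.0, so the sum is 1 * 4 * 64.
        let sum: f32 = output.sum_all().unwrap().to_scalar().unwrap();
        assert!((sum - 256.0).abs() < 1e-3);
    }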
}