
peft_rs/adapters/adalora.rs

//! `AdaLoRA` (Adaptive Low-Rank Adaptation) implementation.
//!
//! `AdaLoRA` dynamically allocates rank budget during training using SVD-based
//! importance scores. It uses a three-phase training schedule:
//! 1. Initial warmup phase (tinit steps)
//! 2. Rank reduction phase (between tinit and `total_step` - tfinal)
//! 3. Final fine-tuning phase (tfinal steps)
//!
//! Reference: <https://arxiv.org/abs/2303.10512>
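//!
//! # Example
//!
//! A minimal sketch of configuring and constructing a layer; the import path
//! below is assumed from this file's location and may differ in the actual
//! crate, and the dimensions are illustrative:
//!
//! ```ignore
//! use candle_core::Device;
//! use peft_rs::adapters::adalora::{AdaLoraConfig, AdaLoraLayer};
//!
//! let config = AdaLoraConfig {
//!     target_r: 8,
//!     init_r: 12,
//!     total_step: 10_000, // set to the real number of training steps
//!     ..Default::default()
//! };
//! let layer = AdaLoraLayer::new(768, 768, config, &Device::Cpu).expect("layer");
//! assert_eq!(layer.current_rank(), 12);
//! ```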

#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::uninlined_format_args)]

use candle_core::{DType, Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for AdaLoRA adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaLoraConfig {
    /// Target average rank after pruning.
    pub target_r: usize,

    /// Initial rank for each incremental matrix (before pruning).
    pub init_r: usize,

    /// LoRA alpha; the effective scaling factor is `alpha / init_r`.
    pub alpha: usize,

    /// Dropout probability applied to adapter outputs.
    #[serde(default)]
    pub dropout: f64,

    /// Target modules to apply AdaLoRA to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Steps of initial warmup (no rank reduction).
    #[serde(default)]
    pub tinit: usize,

    /// Steps of final fine-tuning (no rank reduction).
    #[serde(default)]
    pub tfinal: usize,

    /// Step interval between budget (rank) reallocations.
    #[serde(default = "default_delta_t")]
    pub delta_t: usize,

    /// EMA coefficient for sensitivity smoothing.
    #[serde(default = "default_beta")]
    pub beta1: f64,

    /// EMA coefficient for uncertainty quantification.
    #[serde(default = "default_beta")]
    pub beta2: f64,

    /// Weight of the orthogonal regularization term.
    #[serde(default = "default_orth_reg")]
    pub orth_reg_weight: f64,

    /// Total number of training steps (required for AdaLoRA's schedule).
    pub total_step: usize,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

fn default_delta_t() -> usize {
    1
}

fn default_beta() -> f64 {
    0.85
}

fn default_orth_reg() -> f64 {
    0.5
}

impl Default for AdaLoraConfig {
    fn default() -> Self {
        Self {
            target_r: 8,
            init_r: 12,
            alpha: 16,
            dropout: 0.0,
            target_modules: default_target_modules(),
            tinit: 0,
            tfinal: 0,
            delta_t: default_delta_t(),
            beta1: default_beta(),
            beta2: default_beta(),
            orth_reg_weight: default_orth_reg(),
            total_step: 1000, // Placeholder; override with the actual training length
        }
    }
}

impl AdapterConfig for AdaLoraConfig {
    fn validate(&self) -> Result<()> {
        if self.init_r == 0 {
            return Err(PeftError::InvalidConfig("init_r must be > 0".into()));
        }
        if self.target_r == 0 {
            return Err(PeftError::InvalidConfig("target_r must be > 0".into()));
        }
        if self.target_r > self.init_r {
            return Err(PeftError::InvalidConfig(
                "target_r must be <= init_r".into(),
            ));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(PeftError::InvalidConfig(
                "dropout must be between 0 and 1".into(),
            ));
        }
        if self.total_step == 0 {
            return Err(PeftError::InvalidConfig("total_step must be > 0".into()));
        }
        if self.tinit >= self.total_step.saturating_sub(self.tfinal) {
            return Err(PeftError::InvalidConfig(
                "tinit must be < (total_step - tfinal) for budgeting phase".into(),
            ));
        }
        if !(0.0..=1.0).contains(&self.beta1) || !(0.0..=1.0).contains(&self.beta2) {
            return Err(PeftError::InvalidConfig(
                "beta1 and beta2 must be between 0 and 1".into(),
            ));
        }
        Ok(())
    }
}
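
// Schedule note (illustrative, not part of the API): `validate` enforces
// `tinit < total_step - tfinal`, which yields the three phases from the
// module docs:
//
//   warmup (no pruning):   steps [0, tinit)
//   budget reallocation:   steps [tinit, total_step - tfinal)
//   final fine-tuning:     steps [total_step - tfinal, total_step)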

/// AdaLoRA layer using SVD-based parameterization.
///
/// Uses `W = W0 + P * Λ * Q` where:
/// - P: Left singular vectors (out_features × r)
/// - Λ: Diagonal singular values (r)
/// - Q: Right singular vectors (r × in_features)
///
/// This allows for dynamic rank allocation by zeroing out singular values.
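///
/// In this implementation P, Λ, and Q are stored as `lora_a`, `lora_e`, and
/// `lora_b`, so the effective weight update (what `merge` adds to the base
/// weight) is `ΔW = scaling · lora_a · diag(lora_e ⊙ rank_mask) · lora_b`.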
pub struct AdaLoraLayer {
    /// Left singular vectors: [out_features, init_r]
    lora_a: Tensor,
    /// Singular values: [init_r]
    lora_e: Tensor,
    /// Right singular vectors: [init_r, in_features]
    lora_b: Tensor,
    /// Scaling factor
    scaling: f64,
    /// Configuration
    config: AdaLoraConfig,
    /// Input dimension
    in_features: usize,
    /// Output dimension
    out_features: usize,
    /// Current rank (may be reduced during training)
    current_rank: usize,
    /// Mask for pruned singular values
    rank_mask: Tensor,
    /// Whether gradients are disabled
    frozen: bool,
}

impl AdaLoraLayer {
    /// Create a new AdaLoRA layer.
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `config` - AdaLoRA configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    /// Returns error if configuration is invalid or tensor initialization fails.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: AdaLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.init_r as f64;
        let dtype = DType::F32;

        // Initialize A (left singular vectors) with orthogonal-like initialization
        let std_a = (1.0 / out_features as f64).sqrt();
        let lora_a = Tensor::randn(0.0f32, std_a as f32, (out_features, config.init_r), device)?;

        // Initialize E (singular values) to small values
        let lora_e = Tensor::ones(config.init_r, dtype, device)?;
        let lora_e = lora_e.broadcast_mul(&Tensor::new(0.01f32, device)?)?;

        // Initialize B (right singular vectors) with orthogonal-like initialization
        let std_b = (1.0 / in_features as f64).sqrt();
        let lora_b = Tensor::randn(0.0f32, std_b as f32, (config.init_r, in_features), device)?;

        // Initialize rank mask to all ones (all ranks active)
        let rank_mask = Tensor::ones(config.init_r, dtype, device)?;

        let init_r = config.init_r;

        Ok(Self {
            lora_a,
            lora_e,
            lora_b,
            scaling,
            config,
            in_features,
            out_features,
            current_rank: init_r,
            rank_mask,
            frozen: false,
        })
    }

    /// Get the current active rank.
    #[must_use]
    pub fn current_rank(&self) -> usize {
        self.current_rank
    }

    /// Get the target rank.
    #[must_use]
    pub fn target_rank(&self) -> usize {
        self.config.target_r
    }

    /// Get the initial rank.
    #[must_use]
    pub fn init_rank(&self) -> usize {
        self.config.init_r
    }

    /// Get the scaling factor.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Update the rank mask based on importance scores.
    ///
    /// # Arguments
    /// * `importance_scores` - Importance score for each rank [init_r]
    /// * `budget` - Number of ranks to keep
    ///
    /// # Errors
    /// Returns error if tensor operations fail.
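    ///
    /// # Example
    ///
    /// A sketch of pruning with the layer's own scores (`layer` is assumed to
    /// be a mutable `AdaLoraLayer`); the threshold-based selection below means
    /// the resulting rank is approximate rather than exactly `budget`:
    ///
    /// ```ignore
    /// let scores = layer.get_importance_scores()?;
    /// layer.update_rank_mask(&scores, 8)?;
    /// assert!(layer.current_rank() <= layer.init_rank());
    /// ```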
    pub fn update_rank_mask(&mut self, importance_scores: &Tensor, budget: usize) -> Result<()> {
        // Get the indices of top-k importance scores
        // For simplicity, we'll create a mask based on a threshold
        // In practice, this would involve sorting and selecting top-k

        if budget >= self.config.init_r {
            // Keep all ranks
            self.rank_mask =
                Tensor::ones(self.config.init_r, DType::F32, importance_scores.device())?;
            self.current_rank = self.config.init_r;
        } else if budget == 0 {
            // Zero out all ranks
            self.rank_mask =
                Tensor::zeros(self.config.init_r, DType::F32, importance_scores.device())?;
            self.current_rank = 0;
        } else {
            // Sort importance scores and keep top budget
            // Note: This is a simplified version - in practice would use argsort
            let scores = importance_scores.flatten_all()?;
            let mean_score = scores.mean_all()?;
            let mean: f32 = mean_score.to_scalar()?;

            // Simple threshold-based approach: keep ranks whose importance is
            // at least the mean score. Comparing against the scalar directly
            // lets `ge` broadcast it to the shape of `importance_scores`.
            let mask = importance_scores.ge(mean)?;
            self.rank_mask = mask.to_dtype(DType::F32)?;

            // Update current rank (count non-zero elements)
            let sum: f32 = self.rank_mask.sum_all()?.to_scalar()?;
            self.current_rank = sum as usize;
        }

        Ok(())
    }

    /// Compute the orthogonal regularization loss.
    ///
    /// Encourages P^T P ≈ I and Q Q^T ≈ I.
    ///
    /// # Errors
    /// Returns error if tensor operations fail.
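    ///
    /// Concretely, the returned scalar is `‖PᵀP − I‖_F² + ‖QQᵀ − I‖_F²` with
    /// P = `lora_a` and Q = `lora_b`; the `orth_reg_weight` coefficient is not
    /// applied here.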
    pub fn orthogonal_regularization(&self) -> Result<Tensor> {
        // P^T P - I
        let pta = self.lora_a.t()?.matmul(&self.lora_a)?;
        let eye_a = Tensor::eye(self.config.init_r, DType::F32, self.lora_a.device())?;
        let orth_loss_a = pta.broadcast_sub(&eye_a)?.sqr()?.sum_all()?;

        // Q Q^T - I
        let bbt = self.lora_b.matmul(&self.lora_b.t()?)?;
        let eye_b = Tensor::eye(self.config.init_r, DType::F32, self.lora_b.device())?;
        let orth_loss_b = bbt.broadcast_sub(&eye_b)?.sqr()?.sum_all()?;

        Ok(orth_loss_a.broadcast_add(&orth_loss_b)?)
    }

    /// Get the importance scores for rank allocation.
    ///
    /// The importance here is simply the magnitude of the singular values.
    /// (The full AdaLoRA recipe uses gradient-based sensitivity smoothed with
    /// `beta1`/`beta2`; this simplified variant does not.)
    ///
    /// # Errors
    /// Returns error if tensor operations fail.
    pub fn get_importance_scores(&self) -> Result<Tensor> {
        // Simple importance: absolute value of singular values
        Ok(self.lora_e.abs()?)
    }
}
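// Forward-pass sketch (names and shapes are illustrative): `hidden` has shape
// [batch, seq, in_features]; when the wrapped layer's own output is available
// it is passed as `base_output` so the low-rank update is added on top:
//
//     let adapted = layer.forward(&hidden, Some(&base_out))?;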
impl Adapter for AdaLoraLayer {
    type Config = AdaLoraConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // AdaLoRA forward: x @ B^T @ diag(E * mask) @ A^T * scaling
        // Input shape: [batch, seq, in_features]
        // B shape: [init_r, in_features], B^T shape: [in_features, init_r]
        // A shape: [out_features, init_r], A^T shape: [init_r, out_features]

        // For batched matmul, we need to handle the 3D tensor
        // Use the last dimension for matmul
        let input_dims = input.dims();

        // First: input @ B^T -> [batch, seq, init_r]
        // Reshape input to [batch*seq, in_features] for matmul
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        // Compute input_2d @ B^T = [batch*seq, in_features] @ [in_features, init_r] = [batch*seq, init_r]
        let out = input_2d.matmul(&self.lora_b.t()?)?;

        // Apply singular values with mask: [batch*seq, init_r] * [init_r]
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;
        let masked_e = masked_e.reshape((1, self.config.init_r))?;
        let out = out.broadcast_mul(&masked_e)?;

        // Then: @ A^T -> [batch*seq, out_features]
        // A^T shape: [init_r, out_features]
        let out = out.matmul(&self.lora_a.t()?)?;

        // Reshape back to [batch, seq, out_features]
        let out = out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Apply scaling
        let scaling = Tensor::new(self.scaling as f32, out.device())?;
        let out = out.broadcast_mul(&scaling)?;

        // Add to base output if provided
        match base_output {
            Some(base) => Ok(base.broadcast_add(&out)?),
            None => Ok(out),
        }
    }

    fn num_parameters(&self) -> usize {
        // A: out_features × init_r
        // E: init_r
        // B: init_r × in_features
        self.out_features * self.config.init_r
            + self.config.init_r
            + self.config.init_r * self.in_features
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
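// Merge/unmerge usage sketch (illustrative; `base_weight` is assumed to be the
// wrapped layer's [out_features, in_features] weight): merging folds the
// adapter into the base weight so inference needs no extra matmuls, and
// unmerging reverses it.
//
//     let merged = layer.merge(&base_weight)?;
//     let restored = layer.unmerge(&merged)?; // ≈ base_weight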
impl Mergeable for AdaLoraLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // ΔW = A @ diag(E * mask) @ B * scaling
        // Apply mask to singular values
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;

        // Compute A @ diag(E): scale each column of A by its singular value
        let masked_e_col = masked_e.reshape((self.config.init_r, 1))?;
        let ae = self.lora_a.broadcast_mul(&masked_e_col.t()?)?;

        // Then @ B
        let delta_w = ae.matmul(&self.lora_b)?;

        // Apply scaling
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let masked_e = self.lora_e.broadcast_mul(&self.rank_mask)?;

        let masked_e_col = masked_e.reshape((self.config.init_r, 1))?;
        let ae = self.lora_a.broadcast_mul(&masked_e_col.t()?)?;
        let delta_w = ae.matmul(&self.lora_b)?;

        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for AdaLoraLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // Note: In the current design, tensors are created directly.
        // For full training support, tensors should be created via VarBuilder
        // during construction, which automatically registers them.
        // This is a simplified implementation suitable for inference.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adalora_config_default() {
        let config = AdaLoraConfig::default();
        assert_eq!(config.target_r, 8);
        assert_eq!(config.init_r, 12);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_adalora_config_invalid_rank() {
        let config = AdaLoraConfig {
            target_r: 16,
            init_r: 8, // target > init is invalid
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_adalora_config_invalid_schedule() {
        let config = AdaLoraConfig {
            tinit: 500,
            tfinal: 600,
            total_step: 1000, // tinit >= total_step - tfinal
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_adalora_layer_creation() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.init_rank(), 12);
        assert_eq!(layer.target_rank(), 8);
        assert_eq!(layer.current_rank(), 12);
    }

    #[test]
    fn test_adalora_forward_shape() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_adalora_num_parameters() {
        let config = AdaLoraConfig {
            init_r: 12,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        // A: 768 × 12 = 9216
        // E: 12
        // B: 12 × 768 = 9216
        // Total: 18444
        assert_eq!(layer.num_parameters(), 768 * 12 + 12 + 12 * 768);
    }

    #[test]
    fn test_adalora_importance_scores() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(768, 768, config, &device).unwrap();

        let scores = layer.get_importance_scores().unwrap();
        assert_eq!(scores.dims(), &[12]);
    }

    #[test]
    fn test_adalora_orthogonal_regularization() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        let orth_loss = layer.orthogonal_regularization().unwrap();
        // Should be a scalar tensor (0-dimensional)
        assert!(orth_loss.dims().is_empty());
    }
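
    // The two tests below are additional sketches of the rank-mask and
    // merge/unmerge behaviour; thresholds and shapes are illustrative.

    #[test]
    fn test_adalora_update_rank_mask() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let mut layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        // Threshold-based selection, so the resulting rank is approximate;
        // it can only stay at or below init_r.
        let scores = layer.get_importance_scores().unwrap();
        layer.update_rank_mask(&scores, 8).unwrap();
        assert!(layer.current_rank() <= layer.init_rank());
    }

    #[test]
    fn test_adalora_merge_unmerge_roundtrip() {
        let config = AdaLoraConfig::default();
        let device = Device::Cpu;
        let layer = AdaLoraLayer::new(64, 64, config, &device).unwrap();

        let base = Tensor::randn(0.0f32, 1.0f32, (64, 64), &device).unwrap();
        let merged = layer.merge(&base).unwrap();
        let restored = layer.unmerge(&merged).unwrap();

        // Merging then unmerging should recover the base weight up to
        // floating-point error.
        let diff: f32 = restored
            .broadcast_sub(&base)
            .unwrap()
            .abs()
            .unwrap()
            .mean_all()
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(diff < 1e-4);
    }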
}