peft_rs/adapters/vera.rs

//! `VeRA` (Vector-based Random Matrix Adaptation) implementation.
//!
//! `VeRA` uses frozen random matrices with trainable scaling vectors for
//! ultra-efficient adaptation. It achieves similar performance to `LoRA`
//! with significantly fewer trainable parameters.
//!
//! Reference: <https://arxiv.org/abs/2310.11454>
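//!
//! # Example
//!
//! A minimal usage sketch (not a doctest; the `peft_rs::adapters::vera` and
//! `peft_rs::traits` paths are assumed from this file's location):
//!
//! ```ignore
//! use candle_core::Device;
//! use peft_rs::adapters::vera::{VeraConfig, VeraLayer};
//! use peft_rs::traits::Adapter;
//!
//! let config = VeraConfig { r: 64, ..Default::default() };
//! let layer = VeraLayer::new(768, 768, config, &Device::Cpu).unwrap();
//! // Only the scaling vector `d` is trainable.
//! assert_eq!(layer.num_parameters(), 64);
//! ```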

use candle_core::{Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for `VeRA` adapters.
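///
/// # Example
///
/// A sketch of loading a config through the derived `Deserialize` impl
/// (assumes `serde_json` is available; omitted fields fall back to their
/// serde defaults):
///
/// ```ignore
/// let json = r#"{ "r": 64, "d_initial": 0.05 }"#;
/// let config: VeraConfig = serde_json::from_str(json).unwrap();
/// assert_eq!(config.r, 64);
/// assert_eq!(config.target_modules, vec!["q_proj", "v_proj"]);
/// ```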
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VeraConfig {
    /// Rank of the random projection matrices.
    pub r: usize,

    /// Initial value for the scaling vector d.
    #[serde(default = "default_d_initial")]
    pub d_initial: f64,

    /// Seed for random number generation (for reproducible projections).
    #[serde(default)]
    pub projection_prng_key: u64,

    /// Whether to save the projection matrices (for exact reproducibility).
    #[serde(default)]
    pub save_projection: bool,

    /// Target modules to apply `VeRA` to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}

fn default_d_initial() -> f64 {
    0.1
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for VeraConfig {
    fn default() -> Self {
        Self {
            r: 256,
            d_initial: default_d_initial(),
            projection_prng_key: 0,
            save_projection: false,
            target_modules: default_target_modules(),
        }
    }
}

impl AdapterConfig for VeraConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        Ok(())
    }
}

/// `VeRA` layer implementing Vector-based Random Matrix Adaptation.
///
/// Uses frozen projection matrices A and B with a trainable scaling vector:
/// `ΔW = B @ diag(d) @ A`
///
/// Where:
/// - A: Frozen random matrix [r, `in_features`] (Kaiming-scaled normal initialization)
/// - B: Frozen matrix [`out_features`, r] (zero-initialized here so the adapter starts as a no-op)
/// - d: Trainable scaling vector [r]
/// - b: Optional trainable bias vector [`out_features`]
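///
/// For example, with `in_features = out_features = 768` and `r = 256`, only
/// the 256 entries of `d` (plus 768 bias entries, if enabled) are trained,
/// whereas a `LoRA` adapter of the same rank trains
/// `r * (in_features + out_features) = 393_216` parameters.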
pub struct VeraLayer {
    /// Frozen random projection A: [r, `in_features`]
    vera_a: Tensor,
    /// Frozen projection B: [`out_features`, r] (zero-initialized)
    vera_b: Tensor,
    /// Trainable scaling vector d: [r]
    vera_d: Tensor,
    /// Optional trainable bias b: [`out_features`]
    vera_b_bias: Option<Tensor>,
    /// Configuration
    config: VeraConfig,
    /// Input dimension
    in_features: usize,
    /// Output dimension
    out_features: usize,
    /// Whether gradients are disabled
    frozen: bool,
}

impl VeraLayer {
    /// Create a new `VeRA` layer.
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `config` - `VeRA` configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    ///
    /// Returns an error if configuration validation fails or layer construction fails.
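    ///
    /// # Example
    ///
    /// A minimal sketch (CPU device, default rank of 256; not a doctest):
    ///
    /// ```ignore
    /// let layer = VeraLayer::new(768, 768, VeraConfig::default(), &Device::Cpu)?;
    /// assert_eq!(layer.rank(), 256);
    /// ```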
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: VeraConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        // Initialize frozen random projection A with Kaiming-scaled normal noise
        // (std = 1/sqrt(in_features))
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        let std_a = (1.0 / in_features as f64).sqrt() as f32;
        let vera_a = Tensor::randn(0.0f32, std_a, (config.r, in_features), device)?;

        // Initialize frozen projection B with zeros so the adapter starts as a no-op.
        // (The original paper keeps B as a frozen random matrix and instead
        // zero-initializes the scaling vector b.)
        let vera_b = Tensor::zeros((out_features, config.r), candle_core::DType::F32, device)?;

        // Initialize trainable scaling vector d to d_initial
        #[allow(clippy::cast_possible_truncation)]
        let vera_d = Tensor::full(config.d_initial as f32, config.r, device)?;

        Ok(Self {
            vera_a,
            vera_b,
            vera_d,
            vera_b_bias: None,
            config,
            in_features,
            out_features,
            frozen: false,
        })
    }

    /// Create a new `VeRA` layer with trainable bias.
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `config` - `VeRA` configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    ///
    /// Returns an error if configuration validation fails or layer construction fails.
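    ///
    /// # Example
    ///
    /// A minimal sketch; the bias adds `out_features` trainable parameters:
    ///
    /// ```ignore
    /// let config = VeraConfig { r: 64, ..Default::default() };
    /// let layer = VeraLayer::new_with_bias(768, 768, config, &Device::Cpu)?;
    /// assert_eq!(layer.num_parameters(), 64 + 768);
    /// ```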
    pub fn new_with_bias(
        in_features: usize,
        out_features: usize,
        config: VeraConfig,
        device: &Device,
    ) -> Result<Self> {
        let mut layer = Self::new(in_features, out_features, config, device)?;
        layer.vera_b_bias = Some(Tensor::zeros(
            out_features,
            candle_core::DType::F32,
            device,
        )?);
        Ok(layer)
    }

    /// Get the scaling vector d.
    #[must_use]
    pub fn scaling_vector(&self) -> &Tensor {
        &self.vera_d
    }

    /// Get the rank.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// Compute the weight delta: B @ diag(d) @ A
    fn compute_delta_w(&self) -> Result<Tensor> {
        // diag(d) @ A: scale each row of A by the corresponding element of d
        // d: [r], A: [r, in_features]
        // Result: [r, in_features]
        let d_col = self.vera_d.reshape((self.config.r, 1))?;
        let da = self.vera_a.broadcast_mul(&d_col)?;

        // B @ (diag(d) @ A)
        // B: [out_features, r], da: [r, in_features]
        // Result: [out_features, in_features]
        Ok(self.vera_b.matmul(&da)?)
    }
}

impl Adapter for VeraLayer {
    type Config = VeraConfig;

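    /// Apply the `VeRA` update `input @ ΔW^T` and add it to `base_output`
    /// when one is provided. The input is expected to be 3-D
    /// `[batch, seq, in_features]`.
    ///
    /// A minimal shape sketch (not a doctest; `layer` and `base` are assumed
    /// to exist):
    ///
    /// ```ignore
    /// let x = Tensor::zeros(&[2, 16, 768], DType::F32, &Device::Cpu)?;
    /// let y = layer.forward(&x, None)?;        // [2, 16, out_features]
    /// let y = layer.forward(&x, Some(&base))?; // base + VeRA delta
    /// ```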
    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Compute delta weight
        let delta_w = self.compute_delta_w()?;

        // Compute: input @ delta_w^T
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        let mut vera_out = input_2d.matmul(&delta_w.t()?)?;

        // Add bias if present
        if let Some(bias) = &self.vera_b_bias {
            let bias_expanded = bias.reshape((1, self.out_features))?;
            vera_out = vera_out.broadcast_add(&bias_expanded)?;
        }

        let vera_out = vera_out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Add to base output if provided
        match base_output {
            Some(base) => Ok(base.broadcast_add(&vera_out)?),
            None => Ok(vera_out),
        }
    }

    fn num_parameters(&self) -> usize {
        // Only the scaling vector d is trainable
        // Optionally, bias b is also trainable
        let mut params = self.config.r;
        if self.vera_b_bias.is_some() {
            params += self.out_features;
        }
        params
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for VeraLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for VeraLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // Note: In the current design, tensors are created directly.
        // For full training support, only vera_d (and optionally vera_b_bias)
        // should be registered as trainable. vera_a and vera_b are frozen.
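        //
        // A commented-out sketch of what that could look like (assumes
        // `candle_nn::VarMap::get` with an `Init::Const` initializer; treat
        // the exact call as an unverified assumption):
        //
        // let d = _var_map.get(
        //     self.config.r,
        //     &format!("{_prefix}.vera_d"),
        //     candle_nn::Init::Const(self.config.d_initial),
        //     candle_core::DType::F32,
        //     self.vera_d.device(),
        // )?;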
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_vera_config_default() {
        let config = VeraConfig::default();
        assert_eq!(config.r, 256);
        assert!((config.d_initial - 0.1).abs() < 1e-6);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_vera_config_invalid_rank() {
        let config = VeraConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_vera_layer_creation() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_vera_layer_with_bias() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new_with_bias(768, 768, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(layer.vera_b_bias.is_some());
    }

    #[test]
    fn test_vera_forward_shape() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_vera_forward_with_base_output() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_vera_num_parameters() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        // Only d vector is trainable: 64 parameters
        assert_eq!(layer.num_parameters(), 64);
    }

    #[test]
    fn test_vera_num_parameters_with_bias() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new_with_bias(768, 768, config, &device).unwrap();

        // d vector + bias: 64 + 768 = 832
        assert_eq!(layer.num_parameters(), 64 + 768);
    }

    #[test]
    fn test_vera_merge_unmerge() {
        let config = VeraConfig {
            r: 32,
            d_initial: 0.01,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerged should be close to original
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }

    #[test]
    fn test_vera_freeze_unfreeze() {
        let config = VeraConfig::default();
        let device = Device::Cpu;
        let mut layer = VeraLayer::new(768, 768, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    #[test]
    fn test_vera_ultra_efficient() {
        // VeRA should have far fewer parameters than LoRA for the same rank
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        // VeRA: only 64 trainable params (the d vector)
        // LoRA with r=64: 64 * 768 + 64 * 768 = 98,304 params
        assert_eq!(layer.num_parameters(), 64);

        // That's ~1500x fewer parameters than the equivalent LoRA
        let lora_equivalent_params = 64 * (768 + 768);
        assert!(layer.num_parameters() < lora_equivalent_params / 1000);
    }
}