1use candle_core::{Device, Tensor};
10use candle_nn::VarMap;
11use serde::{Deserialize, Serialize};
12
13use crate::error::{PeftError, Result};
14use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
15
/// Configuration for a VeRA (Vector-based Random Matrix Adaptation) adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VeraConfig {
    /// Rank of the projection matrices: `vera_a` is `(r, in_features)`,
    /// `vera_b` is `(out_features, r)`, and the scaling vector `vera_d` has
    /// length `r`.
    pub r: usize,

    /// Initial fill value for the trainable scaling vector `vera_d`
    /// (serde default: 0.1).
    #[serde(default = "default_d_initial")]
    pub d_initial: f64,

    /// PRNG seed for the shared projections (serde default: 0).
    /// NOTE(review): not read anywhere in this file — confirm it is consumed
    /// by the code that (re)creates the shared random projections.
    #[serde(default)]
    pub projection_prng_key: u64,

    /// Whether the projection matrices should be persisted with the adapter
    /// (serde default: false). NOTE(review): not read in this file.
    #[serde(default)]
    pub save_projection: bool,

    /// Module names to adapt (serde default: `["q_proj", "v_proj"]`).
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}
38
/// Serde default for [`VeraConfig::d_initial`].
fn default_d_initial() -> f64 {
    0.1
}
42
/// Serde default for [`VeraConfig::target_modules`]: the attention query and
/// value projections.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "v_proj"].iter().map(|m| (*m).to_string()).collect()
}
46
47impl Default for VeraConfig {
48 fn default() -> Self {
49 Self {
50 r: 256,
51 d_initial: default_d_initial(),
52 projection_prng_key: 0,
53 save_projection: false,
54 target_modules: default_target_modules(),
55 }
56 }
57}
58
59impl AdapterConfig for VeraConfig {
60 fn validate(&self) -> Result<()> {
61 if self.r == 0 {
62 return Err(PeftError::InvalidConfig("rank must be > 0".into()));
63 }
64 Ok(())
65 }
66}
67
/// A single VeRA adapter layer producing `ΔW = B · diag(d) · A`
/// (see `compute_delta_w`).
pub struct VeraLayer {
    /// Projection of shape `(r, in_features)`, drawn from N(0, 1/in_features)
    /// in `new`.
    vera_a: Tensor,
    /// Projection of shape `(out_features, r)`; zero-initialized in `new`,
    /// so ΔW starts at zero.
    vera_b: Tensor,
    /// Scaling vector of length `r`, filled with `config.d_initial`.
    vera_d: Tensor,
    /// Optional output bias of length `out_features`; set only by
    /// `new_with_bias`.
    vera_b_bias: Option<Tensor>,
    /// The configuration this layer was built from.
    config: VeraConfig,
    /// Input feature dimension of the adapted linear map.
    in_features: usize,
    /// Output feature dimension of the adapted linear map.
    out_features: usize,
    /// Flag toggled via the `Trainable` impl; no tensor state depends on it.
    frozen: bool,
}
96
97impl VeraLayer {
98 pub fn new(
110 in_features: usize,
111 out_features: usize,
112 config: VeraConfig,
113 device: &Device,
114 ) -> Result<Self> {
115 config.validate()?;
116
117 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
119 let std_a = (1.0 / in_features as f64).sqrt() as f32;
120 let vera_a = Tensor::randn(0.0f32, std_a, (config.r, in_features), device)?;
121
122 let vera_b = Tensor::zeros((out_features, config.r), candle_core::DType::F32, device)?;
125
126 #[allow(clippy::cast_possible_truncation)]
128 let vera_d = Tensor::full(config.d_initial as f32, config.r, device)?;
129
130 Ok(Self {
131 vera_a,
132 vera_b,
133 vera_d,
134 vera_b_bias: None,
135 config,
136 in_features,
137 out_features,
138 frozen: false,
139 })
140 }
141
142 pub fn new_with_bias(
154 in_features: usize,
155 out_features: usize,
156 config: VeraConfig,
157 device: &Device,
158 ) -> Result<Self> {
159 let mut layer = Self::new(in_features, out_features, config, device)?;
160 layer.vera_b_bias = Some(Tensor::zeros(
161 out_features,
162 candle_core::DType::F32,
163 device,
164 )?);
165 Ok(layer)
166 }
167
168 #[must_use]
170 pub fn scaling_vector(&self) -> &Tensor {
171 &self.vera_d
172 }
173
174 #[must_use]
176 pub fn rank(&self) -> usize {
177 self.config.r
178 }
179
180 fn compute_delta_w(&self) -> Result<Tensor> {
182 let d_col = self.vera_d.reshape((self.config.r, 1))?;
186 let da = self.vera_a.broadcast_mul(&d_col)?;
187
188 Ok(self.vera_b.matmul(&da)?)
192 }
193}
194
195impl Adapter for VeraLayer {
196 type Config = VeraConfig;
197
198 fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
199 let delta_w = self.compute_delta_w()?;
201
202 let input_dims = input.dims();
204 let batch_seq = input_dims[0] * input_dims[1];
205 let input_2d = input.reshape((batch_seq, self.in_features))?;
206
207 let mut vera_out = input_2d.matmul(&delta_w.t()?)?;
208
209 if let Some(bias) = &self.vera_b_bias {
211 let bias_expanded = bias.reshape((1, self.out_features))?;
212 vera_out = vera_out.broadcast_add(&bias_expanded)?;
213 }
214
215 let vera_out = vera_out.reshape((input_dims[0], input_dims[1], self.out_features))?;
216
217 match base_output {
219 Some(base) => Ok(base.broadcast_add(&vera_out)?),
220 None => Ok(vera_out),
221 }
222 }
223
224 fn num_parameters(&self) -> usize {
225 let mut params = self.config.r;
228 if self.vera_b_bias.is_some() {
229 params += self.out_features;
230 }
231 params
232 }
233
234 fn config(&self) -> &Self::Config {
235 &self.config
236 }
237}
238
239impl Mergeable for VeraLayer {
240 fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
241 let delta_w = self.compute_delta_w()?;
242 Ok(base_weight.broadcast_add(&delta_w)?)
243 }
244
245 fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
246 let delta_w = self.compute_delta_w()?;
247 Ok(merged_weight.broadcast_sub(&delta_w)?)
248 }
249}
250
impl Trainable for VeraLayer {
    /// Registers trainable parameters with the given `VarMap`.
    ///
    /// NOTE(review): currently a no-op — neither `vera_d` nor the optional
    /// bias is registered, so an optimizer driven by this `VarMap` will not
    /// update them. Confirm whether registration happens elsewhere or this
    /// is unfinished.
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        Ok(())
    }

    /// Sets the frozen flag (flag only; no tensor state changes).
    fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Clears the frozen flag.
    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    /// Whether `freeze` was called more recently than `unfreeze`.
    fn is_frozen(&self) -> bool {
        self.frozen
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    // Defaults match the serde defaults and pass validation.
    #[test]
    fn test_vera_config_default() {
        let config = VeraConfig::default();
        assert_eq!(config.r, 256);
        assert!((config.d_initial - 0.1).abs() < 1e-6);
        assert!(config.validate().is_ok());
    }

    // Rank 0 is the rejected configuration.
    #[test]
    fn test_vera_config_invalid_rank() {
        let config = VeraConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    // Construction succeeds for a square 768x768 layer.
    #[test]
    fn test_vera_layer_creation() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());
    }

    // new_with_bias populates the optional bias tensor.
    #[test]
    fn test_vera_layer_with_bias() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new_with_bias(768, 768, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert!(layer.vera_b_bias.is_some());
    }

    // forward preserves (batch, seq) and maps the feature dim in -> out.
    #[test]
    fn test_vera_forward_shape() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    // Supplying base_output keeps the same output shape.
    #[test]
    fn test_vera_forward_with_base_output() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    // Without a bias, only the length-r scaling vector is counted.
    #[test]
    fn test_vera_num_parameters() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 64);
    }

    // With a bias, out_features is added on top of r.
    #[test]
    fn test_vera_num_parameters_with_bias() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new_with_bias(768, 768, config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 64 + 768);
    }

    // Merging then unmerging must restore the base weight.
    // NOTE(review): `vera_b` is zero-initialized in `new`, so delta_w is all
    // zeros here and the round-trip is trivially exact — consider perturbing
    // `vera_b` (or `vera_d`) to make this test exercise a non-zero delta.
    #[test]
    fn test_vera_merge_unmerge() {
        let config = VeraConfig {
            r: 32,
            d_initial: 0.01,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Reduce |unmerged - base| over both axes to a scalar max difference.
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }

    // freeze/unfreeze only toggle the flag, and do so symmetrically.
    #[test]
    fn test_vera_freeze_unfreeze() {
        let config = VeraConfig::default();
        let device = Device::Cpu;
        let mut layer = VeraLayer::new(768, 768, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    // VeRA's trainable-parameter count is far below a rank-64 LoRA's
    // r * (in + out) for the same layer.
    #[test]
    fn test_vera_ultra_efficient() {
        let config = VeraConfig {
            r: 64,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = VeraLayer::new(768, 768, config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 64);

        let lora_equivalent_params = 64 * (768 + 768);
        assert!(layer.num_parameters() < lora_equivalent_params / 1000);
    }
}