1use candle_core::{Device, IndexOp, Tensor};
10use candle_nn::VarMap;
11use serde::{Deserialize, Serialize};
12
13use crate::error::{PeftError, Result};
14use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
15
/// Configuration for an OFT (Orthogonal Fine-Tuning) adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OftConfig {
    /// Number of diagonal blocks; the wrapped layer's feature width must be
    /// divisible by this (checked in `OftLayer::new`).
    pub r: usize,

    /// Constrained-OFT flag. NOTE(review): not read anywhere in this file —
    /// confirm it is consumed elsewhere before relying on it.
    #[serde(default)]
    pub coft: bool,

    /// Numerical epsilon; must be > 0 (enforced by `validate`).
    /// NOTE(review): only validated here, not otherwise read in this file.
    #[serde(default = "default_eps")]
    pub eps: f64,

    /// Share a single block across all positions. NOTE(review): not read
    /// anywhere in this file — confirm intended use.
    #[serde(default)]
    pub block_share: bool,

    /// Module names to wrap (defaults to the attention q/v projections).
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// When true, invert `(I + Q)` with the iterative Newton–Schulz scheme
    /// instead of the truncated-series approximation (see
    /// `OftLayer::compute_orthogonal_matrix`).
    #[serde(default)]
    pub use_exact_cayley: bool,
}
48
/// Serde default for [`OftConfig::eps`].
fn default_eps() -> f64 {
    1.0e-5
}
52
/// Serde default for [`OftConfig::target_modules`]: the attention
/// query/value projections.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "v_proj"].iter().map(|s| (*s).to_string()).collect()
}
56
57impl Default for OftConfig {
58 fn default() -> Self {
59 Self {
60 r: 8,
61 coft: false,
62 eps: default_eps(),
63 block_share: false,
64 target_modules: default_target_modules(),
65 use_exact_cayley: false,
66 }
67 }
68}
69
70impl AdapterConfig for OftConfig {
71 fn validate(&self) -> Result<()> {
72 if self.r == 0 {
73 return Err(PeftError::InvalidConfig(
74 "number of blocks (r) must be > 0".into(),
75 ));
76 }
77 if self.eps <= 0.0 {
78 return Err(PeftError::InvalidConfig("eps must be > 0".into()));
79 }
80 Ok(())
81 }
82}
83
/// An OFT adapter layer: learns per-block parameters that are turned into a
/// block-diagonal orthogonal matrix via a Cayley transform and applied to the
/// layer's activations / weights.
pub struct OftLayer {
    // Raw trainable parameters, shape (num_blocks, block_size, block_size);
    // skew-symmetrised before use (see `make_skew_symmetric`).
    oft_r: Tensor,
    // Configuration this layer was built from.
    config: OftConfig,
    // Total feature width (= num_blocks * block_size).
    features: usize,
    // Side length of each diagonal block.
    block_size: usize,
    // Number of diagonal blocks (= config.r).
    num_blocks: usize,
    // Freeze flag toggled by the `Trainable` impl.
    frozen: bool,
}
106
107impl OftLayer {
108 pub fn new(features: usize, config: OftConfig, device: &Device) -> Result<Self> {
119 config.validate()?;
120
121 if !features.is_multiple_of(config.r) {
122 return Err(PeftError::InvalidConfig(format!(
123 "features ({}) must be divisible by r ({})",
124 features, config.r
125 )));
126 }
127
128 let num_blocks = config.r;
129 let block_size = features / num_blocks;
130
131 let std = 0.01_f32;
134 let oft_r = Tensor::randn(0.0f32, std, (num_blocks, block_size, block_size), device)?;
135
136 Ok(Self {
137 oft_r,
138 config,
139 features,
140 block_size,
141 num_blocks,
142 frozen: false,
143 })
144 }
145
146 #[must_use]
148 pub fn num_blocks(&self) -> usize {
149 self.num_blocks
150 }
151
152 #[must_use]
154 pub fn block_size(&self) -> usize {
155 self.block_size
156 }
157
158 fn make_skew_symmetric(&self) -> Result<Tensor> {
160 let r_t = self.oft_r.transpose(1, 2)?;
161 let diff = self.oft_r.broadcast_sub(&r_t)?;
162 let two = Tensor::new(2.0f32, self.oft_r.device())?;
163 Ok(diff.broadcast_div(&two)?)
164 }
165
166 fn compute_orthogonal_matrix(&self) -> Result<Tensor> {
171 let q = self.make_skew_symmetric()?;
172 let device = q.device();
173
174 let eye = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
176 let eye = eye
177 .unsqueeze(0)?
178 .expand((self.num_blocks, self.block_size, self.block_size))?;
179
180 let i_minus_q = eye.broadcast_sub(&q)?;
182
183 let i_plus_q = eye.broadcast_add(&q)?;
185
186 let mut result_blocks = Vec::with_capacity(self.num_blocks);
187
188 for block_idx in 0..self.num_blocks {
189 let i_minus_q_block = i_minus_q.i(block_idx)?;
190 let i_plus_q_block = i_plus_q.i(block_idx)?;
191 let q_block = q.i(block_idx)?;
192
193 let inv = if self.config.use_exact_cayley {
194 self.compute_exact_inverse(&i_plus_q_block)?
201 } else {
202 let eye_block = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
205 let q_sq = q_block.matmul(&q_block)?;
206 eye_block.broadcast_sub(&q_block)?.broadcast_add(&q_sq)?
207 };
208
209 let r_block = i_minus_q_block.matmul(&inv)?;
211 result_blocks.push(r_block);
212 }
213
214 Ok(Tensor::stack(&result_blocks, 0)?)
216 }
217
218 fn compute_exact_inverse(&self, matrix: &Tensor) -> Result<Tensor> {
231 const NUM_ITERATIONS: usize = 5;
234
235 let device = matrix.device();
236 let eye = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
237 let two = Tensor::new(2.0f32, device)?;
238 let two_eye = eye.broadcast_mul(&two)?;
239
240 let mut x = eye.clone();
242
243 for _ in 0..NUM_ITERATIONS {
244 let ax = matrix.matmul(&x)?;
246 let factor = two_eye.broadcast_sub(&ax)?;
247 x = x.matmul(&factor)?;
248 }
249
250 Ok(x)
251 }
252
253 fn apply_block_diagonal(&self, input: &Tensor, orth_matrix: &Tensor) -> Result<Tensor> {
255 let input_dims = input.dims();
256 let batch_seq = input_dims[0] * input_dims[1];
257
258 let input_blocked = input.reshape((batch_seq, self.num_blocks, self.block_size))?;
260
261 let mut output_blocks = Vec::with_capacity(self.num_blocks);
269
270 for block_idx in 0..self.num_blocks {
271 let input_block = input_blocked.i((.., block_idx, ..))?;
273 let orth_block = orth_matrix.i(block_idx)?;
275
276 let output_block = input_block.matmul(&orth_block)?;
278 output_blocks.push(output_block);
279 }
280
281 let output_stacked = Tensor::stack(&output_blocks, 1)?; Ok(output_stacked.reshape((input_dims[0], input_dims[1], self.features))?)
284 }
285}
286
287impl Adapter for OftLayer {
288 type Config = OftConfig;
289
290 fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
291 let orth_matrix = self.compute_orthogonal_matrix()?;
293
294 let transformed = self.apply_block_diagonal(input, &orth_matrix)?;
296
297 match base_output {
300 Some(base) => {
301 let delta = transformed.broadcast_sub(input)?;
303 Ok(base.broadcast_add(&delta)?)
304 }
305 None => Ok(transformed),
306 }
307 }
308
309 fn num_parameters(&self) -> usize {
310 self.num_blocks * self.block_size * self.block_size
314 }
315
316 fn config(&self) -> &Self::Config {
317 &self.config
318 }
319}
320
321impl Mergeable for OftLayer {
322 fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
323 let orth_matrix = self.compute_orthogonal_matrix()?;
325
326 let full_orth = self.construct_full_matrix(&orth_matrix)?;
328
329 Ok(base_weight.matmul(&full_orth)?)
332 }
333
334 fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
335 let orth_matrix = self.compute_orthogonal_matrix()?;
337 let full_orth = self.construct_full_matrix(&orth_matrix)?;
338
339 let full_orth_t = full_orth.t()?;
341
342 Ok(merged_weight.matmul(&full_orth_t)?)
343 }
344}
345
346impl OftLayer {
347 fn construct_full_matrix(&self, blocks: &Tensor) -> Result<Tensor> {
349 let device = blocks.device();
350 let n = self.features;
351
352 let mut full_data = vec![0.0f32; n * n];
354
355 for block_idx in 0..self.num_blocks {
357 let block = blocks.i(block_idx)?;
358 let block_data: Vec<f32> = block.flatten_all()?.to_vec1()?;
359
360 let start = block_idx * self.block_size;
361
362 for i in 0..self.block_size {
363 for j in 0..self.block_size {
364 let row = start + i;
365 let col = start + j;
366 full_data[row * n + col] = block_data[i * self.block_size + j];
367 }
368 }
369 }
370
371 Ok(Tensor::from_vec(full_data, (n, n), device)?)
372 }
373}
374
impl Trainable for OftLayer {
    /// NOTE(review): this is a no-op — `oft_r` is never registered into the
    /// `VarMap`, so an optimizer driven by the map will not see this layer's
    /// parameters. TODO: confirm whether registration happens elsewhere.
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        Ok(())
    }

    /// Mark the layer as frozen (flag only; nothing else is detached here).
    fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Clear the frozen flag.
    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    /// Whether `freeze` was called more recently than `unfreeze`.
    fn is_frozen(&self) -> bool {
        self.frozen
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, IndexOp};

    // Default config is valid and matches the documented defaults.
    #[test]
    fn test_oft_config_default() {
        let config = OftConfig::default();
        assert_eq!(config.r, 8);
        assert!(!config.coft);
        assert!(config.validate().is_ok());
    }

    // r == 0 must be rejected by validate().
    #[test]
    fn test_oft_config_invalid_r() {
        let config = OftConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    // 64 features with r = 8 gives 8 blocks of size 8.
    #[test]
    fn test_oft_layer_creation() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.num_blocks(), 8);
        assert_eq!(layer.block_size(), 8);
    }

    // features not divisible by r must fail construction.
    #[test]
    fn test_oft_layer_invalid_dimensions() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(65, config, &device);
        assert!(layer.is_err());
    }

    // forward without a base output preserves the input shape.
    #[test]
    fn test_oft_forward_shape() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    // forward with a base output also preserves shape.
    #[test]
    fn test_oft_forward_with_base_output() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    // 8 blocks * 8 * 8 = 512 trainable parameters.
    #[test]
    fn test_oft_num_parameters() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 512);
    }

    // Q = (R - Rᵀ)/2 must satisfy Q + Qᵀ = 0 in every block.
    #[test]
    fn test_oft_skew_symmetric() {
        let config = OftConfig {
            r: 2,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(8, config, &device).unwrap();

        let skew = layer.make_skew_symmetric().unwrap();

        for block_idx in 0..2 {
            let q = skew.i(block_idx).unwrap();
            let q_t = q.t().unwrap();
            let sum = q.broadcast_add(&q_t).unwrap();
            // Reduce |Q + Qᵀ| over both dims to a scalar max.
            let max_val: f32 = sum
                .abs()
                .unwrap()
                .max(0)
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(max_val < 1e-5, "Matrix should be skew-symmetric");
        }
    }

    // freeze/unfreeze toggles the flag and nothing else.
    #[test]
    fn test_oft_freeze_unfreeze() {
        let config = OftConfig::default();
        let device = Device::Cpu;
        let mut layer = OftLayer::new(64, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    // merge followed by unmerge should approximately recover the base
    // weight; the loose 0.1 tolerance reflects the approximate Cayley path.
    #[test]
    fn test_oft_merge_unmerge() {
        let config = OftConfig {
            r: 4,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let base_weight = Tensor::eye(16, DType::F32, &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 0.1, "Max diff: {max_diff}");
    }

    // use_exact_cayley round-trips through config and stays valid.
    #[test]
    fn test_oft_exact_cayley_config() {
        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        assert!(config.use_exact_cayley);
        assert!(config.validate().is_ok());
    }

    // forward works and preserves shape with the iterative inverse.
    #[test]
    fn test_oft_exact_cayley_forward() {
        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 16], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 16]);
    }

    // The iterative inverse should give a tighter merge/unmerge round-trip
    // than the approximate path (0.05 vs 0.1 tolerance above).
    #[test]
    fn test_oft_exact_cayley_merge_unmerge() {
        const EXACT_METHOD_TOLERANCE: f32 = 0.05;

        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let base_weight = Tensor::eye(16, DType::F32, &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(
            max_diff < EXACT_METHOD_TOLERANCE,
            "Max diff with exact Cayley: {max_diff}"
        );
    }

    // Both inverse paths run and produce correctly shaped outputs on the
    // same input (layers have independent random params, so values differ).
    #[test]
    fn test_oft_approx_vs_exact_cayley() {
        let device = Device::Cpu;

        let config_approx = OftConfig {
            r: 4,
            use_exact_cayley: false,
            ..Default::default()
        };
        let layer_approx = OftLayer::new(16, config_approx, &device).unwrap();

        let config_exact = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let layer_exact = OftLayer::new(16, config_exact, &device).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (1, 10, 16), &device).unwrap();

        let output_approx = layer_approx.forward(&input, None).unwrap();
        let output_exact = layer_exact.forward(&input, None).unwrap();

        assert_eq!(output_approx.shape().dims(), &[1, 10, 16]);
        assert_eq!(output_exact.shape().dims(), &[1, 10, 16]);
    }
}