1#![allow(clippy::uninlined_format_args)]
13
14use std::collections::HashMap;
15
16use candle_core::{Device, IndexOp, Tensor, Var};
17use candle_nn::VarMap;
18use serde::{Deserialize, Serialize};
19
20use crate::error::{PeftError, Result};
21use crate::io::SaveLoad;
22use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
23
/// Hyper-parameters for a BOFT (butterfly orthogonal fine-tuning) adapter.
///
/// Exactly one of `boft_block_size` / `boft_block_num` must be non-zero; the
/// other is derived from the layer's input width in `BoftLayer::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoftConfig {
    /// Size of each orthogonal block. `0` (the default) means "derive it from
    /// `boft_block_num`".
    #[serde(default)]
    pub boft_block_size: usize,

    /// Number of orthogonal blocks. Must be `0` when `boft_block_size` is set.
    #[serde(default = "default_boft_block_num")]
    pub boft_block_num: usize,

    /// Number of butterfly factors, counted from 1 (must be > 0). The layer
    /// stores this value minus one internally.
    #[serde(default = "default_boft_n_butterfly_factor")]
    pub boft_n_butterfly_factor: usize,

    /// Dropout probability, validated to lie in `[0.0, 1.0]`.
    /// NOTE(review): not read by `BoftLayer` in this module — confirm whether it
    /// is applied elsewhere.
    #[serde(default)]
    pub boft_dropout: f64,

    /// Small positive constant; validated to be > 0.
    /// NOTE(review): not read by `BoftLayer` in this module.
    #[serde(default = "default_eps")]
    pub eps: f64,

    /// Names of the modules the adapter is meant to target (presumably matched
    /// against model layer names by the caller).
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}
56
/// Serde default for `BoftConfig::boft_block_num`: four blocks.
fn default_boft_block_num() -> usize {
    const DEFAULT_BLOCK_NUM: usize = 4;
    DEFAULT_BLOCK_NUM
}
60
/// Serde default for `BoftConfig::boft_n_butterfly_factor`: a single factor.
fn default_boft_n_butterfly_factor() -> usize {
    const DEFAULT_FACTOR: usize = 1;
    DEFAULT_FACTOR
}
64
/// Serde default for `BoftConfig::eps`: `1e-5`.
fn default_eps() -> f64 {
    const DEFAULT_EPS: f64 = 0.000_01;
    DEFAULT_EPS
}
68
/// Serde default for `BoftConfig::target_modules`: `q_proj` and `v_proj`.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "v_proj"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
72
73impl Default for BoftConfig {
74 fn default() -> Self {
75 Self {
76 boft_block_size: 0,
77 boft_block_num: default_boft_block_num(),
78 boft_n_butterfly_factor: default_boft_n_butterfly_factor(),
79 boft_dropout: 0.0,
80 eps: default_eps(),
81 target_modules: default_target_modules(),
82 }
83 }
84}
85
86impl AdapterConfig for BoftConfig {
87 fn validate(&self) -> Result<()> {
88 if self.boft_block_size == 0 && self.boft_block_num == 0 {
89 return Err(PeftError::InvalidConfig(
90 "Either boft_block_size or boft_block_num must be > 0".into(),
91 ));
92 }
93 if self.boft_block_size != 0 && self.boft_block_num != 0 {
94 return Err(PeftError::InvalidConfig(
95 "Only one of boft_block_size or boft_block_num should be specified".into(),
96 ));
97 }
98 if self.boft_n_butterfly_factor == 0 {
99 return Err(PeftError::InvalidConfig(
100 "boft_n_butterfly_factor must be > 0".into(),
101 ));
102 }
103 if self.eps <= 0.0 {
104 return Err(PeftError::InvalidConfig("eps must be > 0".into()));
105 }
106 if !(0.0..=1.0).contains(&self.boft_dropout) {
107 return Err(PeftError::InvalidConfig(
108 "boft_dropout must be in [0.0, 1.0]".into(),
109 ));
110 }
111 Ok(())
112 }
113}
114
/// A BOFT adapter for one linear layer.
///
/// The adapter rotates features with a product of butterfly factors. Each factor is a
/// permuted block-diagonal matrix whose blocks come from a Cayley-style map of
/// skew-symmetric parameters. A per-output scale is then applied.
pub struct BoftLayer {
    /// Raw rotation parameters of shape
    /// `(n_butterfly_factor + 1, block_num, block_size, block_size)`. They are made
    /// skew-symmetric before use.
    boft_r: Tensor,

    /// Per-output scale of shape `(out_features, 1)`, initialised to ones.
    boft_s: Tensor,

    /// Precomputed butterfly permutation matrices, stacked into a
    /// `(n_butterfly_factor + 1, in_features, in_features)` tensor.
    boft_p: Tensor,

    /// The configuration this layer was built from.
    config: BoftConfig,

    /// Output width of the adapted layer (length of `boft_s`).
    out_features: usize,

    /// Side length of each orthogonal block.
    block_size: usize,

    /// Number of orthogonal blocks per butterfly factor.
    block_num: usize,

    /// Number of butterfly factors *beyond the first*, i.e. the config value minus one.
    n_butterfly_factor: usize,

    /// Set by `freeze` / `unfreeze`.
    /// NOTE(review): not consulted by forward/merge in this module.
    frozen: bool,
}
153
154impl BoftLayer {
155 pub fn new(
166 in_features: usize,
167 out_features: usize,
168 config: BoftConfig,
169 device: &Device,
170 ) -> Result<Self> {
171 config.validate()?;
172
173 let (block_size, block_num) = if config.boft_block_size == 0 {
175 if !in_features.is_multiple_of(config.boft_block_num) {
177 return Err(PeftError::InvalidConfig(format!(
178 "in_features ({}) must be divisible by boft_block_num ({})",
179 in_features, config.boft_block_num
180 )));
181 }
182 (in_features / config.boft_block_num, config.boft_block_num)
183 } else {
184 if !in_features.is_multiple_of(config.boft_block_size) {
186 return Err(PeftError::InvalidConfig(format!(
187 "in_features ({}) must be divisible by boft_block_size ({})",
188 in_features, config.boft_block_size
189 )));
190 }
191 (config.boft_block_size, in_features / config.boft_block_size)
192 };
193
194 let n_butterfly_factor = config.boft_n_butterfly_factor.saturating_sub(1);
196
197 if n_butterfly_factor > 0 {
198 #[allow(clippy::cast_possible_truncation)]
200 let divisor = 2_usize.pow(n_butterfly_factor as u32);
201 if block_num % divisor != 0 {
202 return Err(PeftError::InvalidConfig(format!(
203 "boft_block_num ({}) must be divisible by 2^{} = {}",
204 block_num, n_butterfly_factor, divisor
205 )));
206 }
207
208 if in_features < block_size * divisor {
210 return Err(PeftError::InvalidConfig(format!(
211 "in_features ({}) must be >= block_size * 2^{} = {}",
212 in_features,
213 n_butterfly_factor,
214 block_size * divisor
215 )));
216 }
217
218 if block_num % 2 != 0 {
220 return Err(PeftError::InvalidConfig(format!(
221 "boft_block_num ({}) must be even for butterfly factorization",
222 block_num
223 )));
224 }
225 if block_size % 2 != 0 {
226 return Err(PeftError::InvalidConfig(format!(
227 "boft_block_size ({}) must be even for butterfly factorization",
228 block_size
229 )));
230 }
231 }
232
233 let std = 0.1_f32;
236 let boft_r = Tensor::randn(
237 0.0f32,
238 std,
239 (n_butterfly_factor + 1, block_num, block_size, block_size),
240 device,
241 )?;
242
243 let boft_s = Tensor::ones((out_features, 1), candle_core::DType::F32, device)?;
246
247 let boft_p = Self::compute_permutation_matrices(
249 in_features,
250 block_num,
251 block_size,
252 n_butterfly_factor,
253 device,
254 )?;
255
256 Ok(Self {
257 boft_r,
258 boft_s,
259 boft_p,
260 config,
261 out_features,
262 block_size,
263 block_num,
264 n_butterfly_factor,
265 frozen: false,
266 })
267 }
268
269 fn compute_permutation_matrices(
271 n: usize,
272 block_num: usize,
273 block_size: usize,
274 n_butterfly_factor: usize,
275 device: &Device,
276 ) -> Result<Tensor> {
277 let mut permutation_matrices = Vec::new();
278
279 for i in 0..=n_butterfly_factor {
280 #[allow(clippy::cast_possible_truncation)]
281 let current_block_num = block_num / (2_usize.pow(i as u32));
282 #[allow(clippy::cast_possible_truncation)]
283 let current_block_size = block_size * (2_usize.pow(i as u32));
284
285 let perm_indices = Self::block_butterfly_perm(
286 n,
287 current_block_num,
288 current_block_size / 2,
289 n_butterfly_factor,
290 )?;
291
292 let perm_matrix = Self::perm_to_matrix(&perm_indices, n, device)?;
293 permutation_matrices.push(perm_matrix);
294 }
295
296 Ok(Tensor::stack(&permutation_matrices, 0)?)
298 }
299
300 fn block_butterfly_perm(
305 n: usize,
306 b: usize,
307 r: usize,
308 n_butterfly_factor: usize,
309 ) -> Result<Vec<usize>> {
310 if n_butterfly_factor == 0 {
312 return Ok((0..n).collect());
313 }
314
315 if b * r * 2 > n {
317 return Err(PeftError::InvalidConfig(
318 "Invalid number of blocks for butterfly permutation".into(),
319 ));
320 }
321
322 let block_size = n / b;
323 let mut indices: Vec<usize> = (0..n).collect();
324
325 let sorted_order = Self::sort_block(block_size, r);
327
328 for i in (0..n).step_by(block_size) {
330 let block_end = i + block_size;
331 let tmp_indices: Vec<usize> = indices[i..block_end].to_vec();
332 for (j, &idx) in sorted_order.iter().enumerate() {
333 indices[i + j] = tmp_indices[idx];
334 }
335 }
336
337 Ok(indices)
338 }
339
340 fn sort_block(block_size: usize, r: usize) -> Vec<usize> {
342 let step = block_size / r;
343 let mut sorted_order = vec![0; block_size];
344
345 let mut evens: Vec<usize> = (0..step).step_by(2).collect();
347 let mut odds: Vec<usize> = (1..step).step_by(2).collect();
349
350 evens.append(&mut odds);
351 let sorted_seq = evens;
352
353 for (i, &pos) in sorted_seq.iter().enumerate() {
354 for j in 0..r {
355 sorted_order[i * r + j] = pos * r + j;
356 }
357 }
358
359 sorted_order
360 }
361
362 fn perm_to_matrix(indices: &[usize], n: usize, device: &Device) -> Result<Tensor> {
364 let mut data = vec![0.0f32; n * n];
365
366 for (i, &idx) in indices.iter().enumerate() {
367 data[i * n + idx] = 1.0;
368 }
369
370 Ok(Tensor::from_vec(data, (n, n), device)?)
371 }
372
373 fn make_skew_symmetric(&self) -> Result<Tensor> {
375 let r_t = self.boft_r.transpose(2, 3)?;
377 let diff = self.boft_r.broadcast_sub(&r_t)?;
378 let two = Tensor::new(2.0f32, self.boft_r.device())?;
379 Ok(diff.broadcast_div(&two)?)
380 }
381
382 fn cayley_batch(skew_mat: &Tensor) -> Result<Tensor> {
387 let device = skew_mat.device();
388 let shape = skew_mat.dims();
389 let batch_size = shape[0];
390 let mat_size = shape[1];
391
392 let eye = Tensor::eye(mat_size, candle_core::DType::F32, device)?;
394 let eye = eye.unsqueeze(0)?.expand((batch_size, mat_size, mat_size))?;
395
396 let i_minus_q = eye.broadcast_sub(skew_mat)?;
398
399 let _i_plus_q = eye.broadcast_add(skew_mat)?;
401
402 let mut result_blocks = Vec::with_capacity(batch_size);
405
406 for batch_idx in 0..batch_size {
407 let i_minus_q_block = i_minus_q.i(batch_idx)?;
408
409 let q_block = skew_mat.i(batch_idx)?;
414 let q_sq = q_block.matmul(&q_block)?;
415 let inv_approx = eye
416 .i(batch_idx)?
417 .broadcast_sub(&q_block)?
418 .broadcast_add(&q_sq)?;
419
420 let result = i_minus_q_block.matmul(&inv_approx)?;
421 result_blocks.push(result);
422 }
423
424 Ok(Tensor::stack(&result_blocks, 0)?)
425 }
426
427 fn block_diag(blocks: &Tensor) -> Result<Tensor> {
432 let device = blocks.device();
433 let shape = blocks.dims();
434 let num_blocks = shape[0];
435 let block_size = shape[1];
436 let total_size = num_blocks * block_size;
437
438 let mut data = vec![0.0f32; total_size * total_size];
440
441 for block_idx in 0..num_blocks {
443 let block = blocks.i(block_idx)?;
444 let block_data: Vec<f32> = block.flatten_all()?.to_vec1()?;
445
446 let offset = block_idx * block_size;
447 for i in 0..block_size {
448 for j in 0..block_size {
449 let row = offset + i;
450 let col = offset + j;
451 data[row * total_size + col] = block_data[i * block_size + j];
452 }
453 }
454 }
455
456 Ok(Tensor::from_vec(data, (total_size, total_size), device)?)
457 }
458
459 fn compute_butterfly_oft_matrix(&self) -> Result<Tensor> {
464 let q = self.make_skew_symmetric()?;
466
467 let mut butterfly_matrices = Vec::new();
469
470 for factor_idx in 0..=self.n_butterfly_factor {
471 let q_factor = q.i(factor_idx)?; let shape = q_factor.dims();
476 let d = shape[0];
477 let h = shape[1];
478 let q_reshaped = q_factor.reshape((d, h, h))?;
479
480 let orth_blocks = Self::cayley_batch(&q_reshaped)?;
482
483 let block_diag_mat = Self::block_diag(&orth_blocks)?;
485
486 let perm = self.boft_p.i(factor_idx)?;
488 let perm_t = perm.t()?;
489
490 let tmp = block_diag_mat.matmul(&perm_t)?;
492 let butterfly_mat = perm.matmul(&tmp)?;
493
494 butterfly_matrices.push(butterfly_mat);
495 }
496
497 let mut result = butterfly_matrices[0].clone();
499 for butterfly_mat in butterfly_matrices.iter().skip(1) {
500 result = butterfly_mat.matmul(&result)?;
501 }
502
503 Ok(result)
504 }
505
506 #[must_use]
508 pub fn block_num(&self) -> usize {
509 self.block_num
510 }
511
512 #[must_use]
514 pub fn block_size(&self) -> usize {
515 self.block_size
516 }
517
518 #[must_use]
520 pub fn n_butterfly_factor(&self) -> usize {
521 self.n_butterfly_factor + 1 }
523}
524
525impl Adapter for BoftLayer {
526 type Config = BoftConfig;
527
528 fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
529 let butterfly_oft = self.compute_butterfly_oft_matrix()?;
531
532 let input_shape = input.dims();
534 let is_3d = input_shape.len() == 3;
535
536 let input_2d = if is_3d {
537 input.reshape((input_shape[0] * input_shape[1], input_shape[2]))?
539 } else {
540 input.clone()
541 };
542
543 let transformed = input_2d.matmul(&butterfly_oft.t()?)?;
545
546 let transformed = if is_3d {
548 transformed.reshape(input_shape)?
549 } else {
550 transformed
551 };
552
553 let scaled = transformed.broadcast_mul(&self.boft_s.t()?)?;
555
556 if let Some(base) = base_output {
558 Ok(scaled.broadcast_add(base)?)
559 } else {
560 Ok(scaled)
561 }
562 }
563
564 fn num_parameters(&self) -> usize {
565 let r_params =
567 (self.n_butterfly_factor + 1) * self.block_num * self.block_size * self.block_size;
568
569 let s_params = self.out_features;
571
572 r_params + s_params
573 }
574
575 fn config(&self) -> &Self::Config {
576 &self.config
577 }
578}
579
580impl Mergeable for BoftLayer {
581 fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
582 let butterfly_oft = self.compute_butterfly_oft_matrix()?;
584
585 let weight_t = base_weight.t()?;
588 let merged_t = butterfly_oft.matmul(&weight_t)?;
589 let merged = merged_t.t()?;
590
591 Ok(merged.broadcast_mul(&self.boft_s)?)
593 }
594
595 fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
596 let butterfly_oft = self.compute_butterfly_oft_matrix()?;
598
599 let unscaled = merged_weight.broadcast_div(&self.boft_s)?;
601 let unscaled_t = unscaled.t()?;
602 let butterfly_oft_t = butterfly_oft.t()?;
603 let unmerged_t = butterfly_oft_t.matmul(&unscaled_t)?;
604
605 Ok(unmerged_t.t()?)
606 }
607}
608
609impl Trainable for BoftLayer {
610 #[allow(clippy::similar_names)]
611 fn register_parameters(&self, var_map: &mut VarMap, prefix: &str) -> Result<()> {
612 let boft_r_name = format!("{prefix}.boft_r");
613 let boft_s_name = format!("{prefix}.boft_s");
614
615 var_map
616 .data()
617 .lock()
618 .unwrap()
619 .insert(boft_r_name, Var::from_tensor(&self.boft_r)?);
620 var_map
621 .data()
622 .lock()
623 .unwrap()
624 .insert(boft_s_name, Var::from_tensor(&self.boft_s)?);
625
626 Ok(())
627 }
628
629 fn freeze(&mut self) {
630 self.frozen = true;
631 }
632
633 fn unfreeze(&mut self) {
634 self.frozen = false;
635 }
636
637 fn is_frozen(&self) -> bool {
638 self.frozen
639 }
640}
641
642impl SaveLoad for BoftLayer {
643 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
644 let mut state_dict = HashMap::new();
645 state_dict.insert("boft_r".to_string(), self.boft_r.clone());
646 state_dict.insert("boft_s".to_string(), self.boft_s.clone());
647 Ok(state_dict)
648 }
649
650 fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
651 if let Some(boft_r) = state_dict.get("boft_r") {
652 self.boft_r = boft_r.clone();
653 }
654 if let Some(boft_s) = state_dict.get("boft_s") {
655 self.boft_s = boft_s.clone();
656 }
657 Ok(())
658 }
659}
660
// Unit tests for BOFT configuration validation, layer construction, and the
// forward / merge / save-load surface. All layers use 64 features on the CPU.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    // The default config uses the block-count form (4 blocks) with one butterfly
    // factor and no dropout.
    #[test]
    fn test_boft_config_default() {
        let config = BoftConfig::default();
        assert_eq!(config.boft_block_num, 4);
        assert_eq!(config.boft_n_butterfly_factor, 1);
        assert!((config.boft_dropout - 0.0).abs() < f64::EPSILON);
    }

    // Covers the mutually exclusive block size/count rule and the non-zero
    // butterfly factor rule.
    #[test]
    fn test_boft_config_validation() {
        let mut config = BoftConfig::default();

        assert!(config.validate().is_ok());

        // Both block size and block count set.
        config.boft_block_size = 8;
        config.boft_block_num = 4;
        assert!(config.validate().is_err());

        // Neither set.
        config.boft_block_size = 0;
        config.boft_block_num = 0;
        assert!(config.validate().is_err());

        // Valid blocks, but zero butterfly factors.
        config.boft_block_num = 4;
        config.boft_n_butterfly_factor = 0;
        assert!(config.validate().is_err());
    }

    // 64 features / 4 blocks = block size 16; the accessor reports the
    // config's 1-based butterfly factor.
    #[test]
    fn test_boft_layer_creation() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        assert_eq!(layer.block_num(), 4);
        assert_eq!(layer.block_size(), 16);
        assert_eq!(layer.n_butterfly_factor(), 1);

        Ok(())
    }

    // A 3-D (batch, seq, features) input keeps its shape through forward.
    #[test]
    fn test_boft_layer_forward() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        let input = Tensor::randn(0.0f32, 1.0f32, (2, 10, 64), &device)?;
        let output = layer.forward(&input, None)?;

        assert_eq!(output.dims(), &[2, 10, 64]);

        Ok(())
    }

    // 1 factor * 4 blocks * 16 * 16 = 1024 rotation parameters, plus 64 scales.
    #[test]
    fn test_boft_parameter_count() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;

        assert_eq!(layer.num_parameters(), 1088);

        Ok(())
    }

    // With no extra butterfly factors the permutation is the identity; with one
    // factor it is still a length-n permutation.
    #[test]
    fn test_boft_block_butterfly_perm() -> Result<()> {
        let perm = BoftLayer::block_butterfly_perm(8, 4, 1, 0)?;
        assert_eq!(perm, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        let perm = BoftLayer::block_butterfly_perm(8, 4, 1, 1)?;
        assert_eq!(perm.len(), 8);

        Ok(())
    }

    // Merge and unmerge preserve the weight shape. Values are not compared because
    // the rotation is only approximately orthogonal.
    #[test]
    fn test_boft_merge_unmerge() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        let base_weight = Tensor::randn(0.0f32, 1.0f32, (64, 64), &device)?;

        let merged = layer.merge(&base_weight)?;
        assert_eq!(merged.dims(), base_weight.dims());

        let unmerged = layer.unmerge(&merged)?;
        assert_eq!(unmerged.dims(), base_weight.dims());

        Ok(())
    }

    // 64 features are not divisible into 5 blocks.
    #[test]
    fn test_boft_invalid_features() {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 5,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let result = BoftLayer::new(64, 64, config, &device);
        assert!(result.is_err());
    }

    // 3 blocks does not divide 64 (and is odd), so construction fails before any
    // butterfly-specific check could pass.
    #[test]
    fn test_boft_butterfly_factor_validation() {
        let device = Device::Cpu;

        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 3,
            boft_n_butterfly_factor: 2,
            ..Default::default()
        };

        let result = BoftLayer::new(64, 64, config, &device);
        assert!(result.is_err());
    }

    // The frozen flag starts cleared and toggles with freeze/unfreeze.
    #[test]
    fn test_boft_freeze_unfreeze() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig::default();
        let mut layer = BoftLayer::new(64, 64, config, &device)?;

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());

        Ok(())
    }
}