Skip to main content

peft_rs/adapters/
boft.rs

1//! BOFT (Butterfly Orthogonal Fine-Tuning) implementation.
2//!
3//! BOFT extends OFT by using butterfly factorization to achieve even more
4//! parameter efficiency. It reduces the parameter complexity from O(d²) to
5//! O(d log d) while maintaining the benefits of orthogonal transformations.
6//!
7//! The butterfly structure is inspired by the Cooley-Tukey FFT algorithm
8//! and enables efficient O(n log n) matrix multiplication.
9//!
10//! Reference: <https://arxiv.org/abs/2311.06243>
11
12#![allow(clippy::uninlined_format_args)]
13
14use std::collections::HashMap;
15
16use candle_core::{Device, IndexOp, Tensor, Var};
17use candle_nn::VarMap;
18use serde::{Deserialize, Serialize};
19
20use crate::error::{PeftError, Result};
21use crate::io::SaveLoad;
22use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
23
/// Configuration for BOFT adapters.
///
/// Exactly one of `boft_block_size` / `boft_block_num` may be non-zero;
/// `validate` rejects configurations where both or neither are set.
///
/// Note for serialized configs: an omitted `boft_block_num` deserializes to 4
/// (not 0), so a config that sets `boft_block_size` must also set
/// `boft_block_num` to 0 explicitly or validation fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoftConfig {
    /// Block size for butterfly factorization.
    /// If 0, computed from `boft_block_num`. Must divide input features.
    /// Defaults to 0 when omitted during deserialization.
    #[serde(default)]
    pub boft_block_size: usize,

    /// Number of blocks for butterfly factorization.
    /// If 0, computed from `boft_block_size`. Must divide input features.
    /// Defaults to 4 when omitted during deserialization.
    #[serde(default = "default_boft_block_num")]
    pub boft_block_num: usize,

    /// Number of butterfly factors to use (1 = no butterfly, higher = more expressive).
    /// Must satisfy: `boft_block_num` must be divisible by `2^(n_butterfly_factor-1)`.
    /// A value of 0 is rejected by `validate`.
    #[serde(default = "default_boft_n_butterfly_factor")]
    pub boft_n_butterfly_factor: usize,

    /// Dropout probability for multiplicative dropout (0.0 = no dropout).
    /// Note: Dropout is currently not implemented in forward pass.
    /// This parameter is reserved for future implementation.
    /// `validate` requires the value to lie in `[0.0, 1.0]`.
    #[serde(default)]
    pub boft_dropout: f64,

    /// Small constant for numerical stability.
    /// Within this file the value is only validated (must be > 0); it is not
    /// otherwise read by the layer code shown here.
    #[serde(default = "default_eps")]
    pub eps: f64,

    /// Target modules to apply BOFT to.
    /// Defaults to `["q_proj", "v_proj"]`.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}
56
/// Serde fallback for `BoftConfig::boft_block_num` when the field is omitted.
fn default_boft_block_num() -> usize {
    const DEFAULT_BLOCK_NUM: usize = 4;
    DEFAULT_BLOCK_NUM
}
60
/// Serde fallback for `BoftConfig::boft_n_butterfly_factor`: a single factor,
/// i.e. no butterfly structure.
fn default_boft_n_butterfly_factor() -> usize {
    const DEFAULT_N_BUTTERFLY_FACTOR: usize = 1;
    DEFAULT_N_BUTTERFLY_FACTOR
}
64
/// Serde fallback for `BoftConfig::eps`.
fn default_eps() -> f64 {
    const DEFAULT_EPS: f64 = 1e-5;
    DEFAULT_EPS
}
68
/// Serde fallback for `BoftConfig::target_modules`: the query and value projections.
fn default_target_modules() -> Vec<String> {
    ["q_proj", "v_proj"]
        .iter()
        .map(|name| String::from(*name))
        .collect()
}
72
73impl Default for BoftConfig {
74    fn default() -> Self {
75        Self {
76            boft_block_size: 0,
77            boft_block_num: default_boft_block_num(),
78            boft_n_butterfly_factor: default_boft_n_butterfly_factor(),
79            boft_dropout: 0.0,
80            eps: default_eps(),
81            target_modules: default_target_modules(),
82        }
83    }
84}
85
86impl AdapterConfig for BoftConfig {
87    fn validate(&self) -> Result<()> {
88        if self.boft_block_size == 0 && self.boft_block_num == 0 {
89            return Err(PeftError::InvalidConfig(
90                "Either boft_block_size or boft_block_num must be > 0".into(),
91            ));
92        }
93        if self.boft_block_size != 0 && self.boft_block_num != 0 {
94            return Err(PeftError::InvalidConfig(
95                "Only one of boft_block_size or boft_block_num should be specified".into(),
96            ));
97        }
98        if self.boft_n_butterfly_factor == 0 {
99            return Err(PeftError::InvalidConfig(
100                "boft_n_butterfly_factor must be > 0".into(),
101            ));
102        }
103        if self.eps <= 0.0 {
104            return Err(PeftError::InvalidConfig("eps must be > 0".into()));
105        }
106        if !(0.0..=1.0).contains(&self.boft_dropout) {
107            return Err(PeftError::InvalidConfig(
108                "boft_dropout must be in [0.0, 1.0]".into(),
109            ));
110        }
111        Ok(())
112    }
113}
114
/// BOFT layer implementing Butterfly Orthogonal Fine-Tuning.
///
/// Uses butterfly factorization of block-diagonal orthogonal matrices:
/// `W' = W @ R` where R is constructed from butterfly factors.
///
/// Each butterfly factor is: `P @ BlockDiag(O_i) @ P^T`
/// where `P` is a permutation matrix and `O_i` are orthogonal matrices.
///
/// The tensors are created in `new`; `boft_r` and `boft_s` are the two tensors
/// exposed through the state dict and parameter registration, while `boft_p`
/// is derived from the block layout.
pub struct BoftLayer {
    /// Skew-symmetric parameters for Cayley parameterization.
    /// Shape: `[n_butterfly_factor + 1, block_num, block_size, block_size]`
    /// (here `n_butterfly_factor` is the internal, already-decremented value).
    boft_r: Tensor,

    /// Scaling factors for output features.
    /// Shape: `[out_features, 1]`
    boft_s: Tensor,

    /// Precomputed permutation matrices for butterfly structure.
    /// Shape: `[n_butterfly_factor + 1, features, features]`
    boft_p: Tensor,

    /// Configuration
    config: BoftConfig,

    /// Output dimension
    out_features: usize,

    /// Size of each block
    block_size: usize,

    /// Number of blocks
    block_num: usize,

    /// Number of butterfly factors (config value - 1, for internal use)
    n_butterfly_factor: usize,

    /// Whether gradients are disabled
    /// (a plain flag toggled by `freeze` / `unfreeze`; no other code in this
    /// file reads it).
    frozen: bool,
}
153
154impl BoftLayer {
155    /// Create a new BOFT layer.
156    ///
157    /// # Arguments
158    /// * `in_features` - Input dimension
159    /// * `out_features` - Output dimension  
160    /// * `config` - BOFT configuration
161    /// * `device` - Device to create tensors on
162    ///
163    /// # Errors
164    /// Returns error if configuration is invalid or tensor initialization fails.
165    pub fn new(
166        in_features: usize,
167        out_features: usize,
168        config: BoftConfig,
169        device: &Device,
170    ) -> Result<Self> {
171        config.validate()?;
172
173        // Compute block_size and block_num based on config
174        let (block_size, block_num) = if config.boft_block_size == 0 {
175            // Compute block_size from block_num
176            if !in_features.is_multiple_of(config.boft_block_num) {
177                return Err(PeftError::InvalidConfig(format!(
178                    "in_features ({}) must be divisible by boft_block_num ({})",
179                    in_features, config.boft_block_num
180                )));
181            }
182            (in_features / config.boft_block_num, config.boft_block_num)
183        } else {
184            // Compute block_num from block_size
185            if !in_features.is_multiple_of(config.boft_block_size) {
186                return Err(PeftError::InvalidConfig(format!(
187                    "in_features ({}) must be divisible by boft_block_size ({})",
188                    in_features, config.boft_block_size
189                )));
190            }
191            (config.boft_block_size, in_features / config.boft_block_size)
192        };
193
194        // Butterfly factor validation (internally we use n-1)
195        let n_butterfly_factor = config.boft_n_butterfly_factor.saturating_sub(1);
196
197        if n_butterfly_factor > 0 {
198            // Check block_num divisibility
199            #[allow(clippy::cast_possible_truncation)]
200            let divisor = 2_usize.pow(n_butterfly_factor as u32);
201            if block_num % divisor != 0 {
202                return Err(PeftError::InvalidConfig(format!(
203                    "boft_block_num ({}) must be divisible by 2^{} = {}",
204                    block_num, n_butterfly_factor, divisor
205                )));
206            }
207
208            // Check that we have enough features
209            if in_features < block_size * divisor {
210                return Err(PeftError::InvalidConfig(format!(
211                    "in_features ({}) must be >= block_size * 2^{} = {}",
212                    in_features,
213                    n_butterfly_factor,
214                    block_size * divisor
215                )));
216            }
217
218            // Block size and block num must be even for butterfly
219            if block_num % 2 != 0 {
220                return Err(PeftError::InvalidConfig(format!(
221                    "boft_block_num ({}) must be even for butterfly factorization",
222                    block_num
223                )));
224            }
225            if block_size % 2 != 0 {
226                return Err(PeftError::InvalidConfig(format!(
227                    "boft_block_size ({}) must be even for butterfly factorization",
228                    block_size
229                )));
230            }
231        }
232
233        // Initialize skew-symmetric parameters
234        // Shape: [n_butterfly_factor+1, block_num, block_size, block_size]
235        let std = 0.1_f32;
236        let boft_r = Tensor::randn(
237            0.0f32,
238            std,
239            (n_butterfly_factor + 1, block_num, block_size, block_size),
240            device,
241        )?;
242
243        // Initialize scaling factors to ones
244        // Shape: [out_features, 1]
245        let boft_s = Tensor::ones((out_features, 1), candle_core::DType::F32, device)?;
246
247        // Precompute permutation matrices
248        let boft_p = Self::compute_permutation_matrices(
249            in_features,
250            block_num,
251            block_size,
252            n_butterfly_factor,
253            device,
254        )?;
255
256        Ok(Self {
257            boft_r,
258            boft_s,
259            boft_p,
260            config,
261            out_features,
262            block_size,
263            block_num,
264            n_butterfly_factor,
265            frozen: false,
266        })
267    }
268
269    /// Compute all permutation matrices for butterfly structure.
270    fn compute_permutation_matrices(
271        n: usize,
272        block_num: usize,
273        block_size: usize,
274        n_butterfly_factor: usize,
275        device: &Device,
276    ) -> Result<Tensor> {
277        let mut permutation_matrices = Vec::new();
278
279        for i in 0..=n_butterfly_factor {
280            #[allow(clippy::cast_possible_truncation)]
281            let current_block_num = block_num / (2_usize.pow(i as u32));
282            #[allow(clippy::cast_possible_truncation)]
283            let current_block_size = block_size * (2_usize.pow(i as u32));
284
285            let perm_indices = Self::block_butterfly_perm(
286                n,
287                current_block_num,
288                current_block_size / 2,
289                n_butterfly_factor,
290            )?;
291
292            let perm_matrix = Self::perm_to_matrix(&perm_indices, n, device)?;
293            permutation_matrices.push(perm_matrix);
294        }
295
296        // Stack into single tensor [n_butterfly_factor+1, n, n]
297        Ok(Tensor::stack(&permutation_matrices, 0)?)
298    }
299
300    /// Generate block butterfly permutation indices.
301    ///
302    /// This creates a permutation that reorders blocks in a butterfly pattern,
303    /// separating even and odd positioned blocks.
304    fn block_butterfly_perm(
305        n: usize,
306        b: usize,
307        r: usize,
308        n_butterfly_factor: usize,
309    ) -> Result<Vec<usize>> {
310        // If no butterfly factor, return identity permutation
311        if n_butterfly_factor == 0 {
312            return Ok((0..n).collect());
313        }
314
315        // Validate parameters
316        if b * r * 2 > n {
317            return Err(PeftError::InvalidConfig(
318                "Invalid number of blocks for butterfly permutation".into(),
319            ));
320        }
321
322        let block_size = n / b;
323        let mut indices: Vec<usize> = (0..n).collect();
324
325        // Sort blocks by separating even and odd positions
326        let sorted_order = Self::sort_block(block_size, r);
327
328        // Apply sorting to each block
329        for i in (0..n).step_by(block_size) {
330            let block_end = i + block_size;
331            let tmp_indices: Vec<usize> = indices[i..block_end].to_vec();
332            for (j, &idx) in sorted_order.iter().enumerate() {
333                indices[i + j] = tmp_indices[idx];
334            }
335        }
336
337        Ok(indices)
338    }
339
340    /// Sort a single block by separating even and odd positions.
341    fn sort_block(block_size: usize, r: usize) -> Vec<usize> {
342        let step = block_size / r;
343        let mut sorted_order = vec![0; block_size];
344
345        // Collect even positions
346        let mut evens: Vec<usize> = (0..step).step_by(2).collect();
347        // Collect odd positions
348        let mut odds: Vec<usize> = (1..step).step_by(2).collect();
349
350        evens.append(&mut odds);
351        let sorted_seq = evens;
352
353        for (i, &pos) in sorted_seq.iter().enumerate() {
354            for j in 0..r {
355                sorted_order[i * r + j] = pos * r + j;
356            }
357        }
358
359        sorted_order
360    }
361
362    /// Convert permutation indices to permutation matrix.
363    fn perm_to_matrix(indices: &[usize], n: usize, device: &Device) -> Result<Tensor> {
364        let mut data = vec![0.0f32; n * n];
365
366        for (i, &idx) in indices.iter().enumerate() {
367            data[i * n + idx] = 1.0;
368        }
369
370        Ok(Tensor::from_vec(data, (n, n), device)?)
371    }
372
373    /// Make parameter matrices skew-symmetric: Q = (R - R^T) / 2
374    fn make_skew_symmetric(&self) -> Result<Tensor> {
375        // boft_r shape: [N, D, H, H]
376        let r_t = self.boft_r.transpose(2, 3)?;
377        let diff = self.boft_r.broadcast_sub(&r_t)?;
378        let two = Tensor::new(2.0f32, self.boft_r.device())?;
379        Ok(diff.broadcast_div(&two)?)
380    }
381
382    /// Apply Cayley transform to skew-symmetric matrices.
383    ///
384    /// For a skew-symmetric matrix Q: `R = (I - Q) @ (I + Q)^{-1}`
385    /// This produces an orthogonal matrix R.
386    fn cayley_batch(skew_mat: &Tensor) -> Result<Tensor> {
387        let device = skew_mat.device();
388        let shape = skew_mat.dims();
389        let batch_size = shape[0];
390        let mat_size = shape[1];
391
392        // Create identity matrix
393        let eye = Tensor::eye(mat_size, candle_core::DType::F32, device)?;
394        let eye = eye.unsqueeze(0)?.expand((batch_size, mat_size, mat_size))?;
395
396        // I - Q
397        let i_minus_q = eye.broadcast_sub(skew_mat)?;
398
399        // I + Q (computed for potential future exact inverse implementation)
400        let _i_plus_q = eye.broadcast_add(skew_mat)?;
401
402        // Solve (I + Q) @ R = (I - Q) for R
403        // This is equivalent to R = (I - Q) @ (I + Q)^{-1}
404        let mut result_blocks = Vec::with_capacity(batch_size);
405
406        for batch_idx in 0..batch_size {
407            let i_minus_q_block = i_minus_q.i(batch_idx)?;
408
409            // Use Neumann series approximation: (I + Q)^{-1} ≈ I - Q + Q²
410            // This approximation is valid when ||Q|| is small (typically < 0.5).
411            // Since Q is skew-symmetric and initialized with small std (0.1),
412            // this approximation is accurate for most practical cases.
413            let q_block = skew_mat.i(batch_idx)?;
414            let q_sq = q_block.matmul(&q_block)?;
415            let inv_approx = eye
416                .i(batch_idx)?
417                .broadcast_sub(&q_block)?
418                .broadcast_add(&q_sq)?;
419
420            let result = i_minus_q_block.matmul(&inv_approx)?;
421            result_blocks.push(result);
422        }
423
424        Ok(Tensor::stack(&result_blocks, 0)?)
425    }
426
427    /// Construct block diagonal matrix from blocks.
428    ///
429    /// Given blocks of shape `[D, H, H]`, creates a block diagonal matrix
430    /// of shape `[D*H, D*H]`.
431    fn block_diag(blocks: &Tensor) -> Result<Tensor> {
432        let device = blocks.device();
433        let shape = blocks.dims();
434        let num_blocks = shape[0];
435        let block_size = shape[1];
436        let total_size = num_blocks * block_size;
437
438        // Create zero matrix
439        let mut data = vec![0.0f32; total_size * total_size];
440
441        // Fill in blocks
442        for block_idx in 0..num_blocks {
443            let block = blocks.i(block_idx)?;
444            let block_data: Vec<f32> = block.flatten_all()?.to_vec1()?;
445
446            let offset = block_idx * block_size;
447            for i in 0..block_size {
448                for j in 0..block_size {
449                    let row = offset + i;
450                    let col = offset + j;
451                    data[row * total_size + col] = block_data[i * block_size + j];
452                }
453            }
454        }
455
456        Ok(Tensor::from_vec(data, (total_size, total_size), device)?)
457    }
458
459    /// Compute the full butterfly OFT matrix.
460    ///
461    /// Applies the butterfly factorization: product of `P @ BlockDiag @ P^T`
462    /// across all butterfly factors.
463    fn compute_butterfly_oft_matrix(&self) -> Result<Tensor> {
464        // Get skew-symmetric matrices
465        let q = self.make_skew_symmetric()?;
466
467        // q shape: [N, D, H, H] where N = n_butterfly_factor + 1
468        let mut butterfly_matrices = Vec::new();
469
470        for factor_idx in 0..=self.n_butterfly_factor {
471            // Extract blocks for this factor
472            let q_factor = q.i(factor_idx)?; // Shape: [D, H, H]
473
474            // Reshape for batch Cayley
475            let shape = q_factor.dims();
476            let d = shape[0];
477            let h = shape[1];
478            let q_reshaped = q_factor.reshape((d, h, h))?;
479
480            // Apply Cayley transform to get orthogonal blocks
481            let orth_blocks = Self::cayley_batch(&q_reshaped)?;
482
483            // Construct block diagonal matrix
484            let block_diag_mat = Self::block_diag(&orth_blocks)?;
485
486            // Get permutation matrix for this factor
487            let perm = self.boft_p.i(factor_idx)?;
488            let perm_t = perm.t()?;
489
490            // Compute P @ BlockDiag @ P^T
491            let tmp = block_diag_mat.matmul(&perm_t)?;
492            let butterfly_mat = perm.matmul(&tmp)?;
493
494            butterfly_matrices.push(butterfly_mat);
495        }
496
497        // Multiply all butterfly factors together
498        let mut result = butterfly_matrices[0].clone();
499        for butterfly_mat in butterfly_matrices.iter().skip(1) {
500            result = butterfly_mat.matmul(&result)?;
501        }
502
503        Ok(result)
504    }
505
    /// Get the number of blocks.
    ///
    /// This is the value resolved at construction time, whether it came
    /// directly from the config or was derived from the block size.
    #[must_use]
    pub fn block_num(&self) -> usize {
        self.block_num
    }
511
    /// Get the block size.
    ///
    /// This is the value resolved at construction time, whether it came
    /// directly from the config or was derived from the block count.
    #[must_use]
    pub fn block_size(&self) -> usize {
        self.block_size
    }
517
    /// Get the number of butterfly factors.
    ///
    /// The layer stores the config value minus one internally; this accessor
    /// adds the one back so callers see the configured count.
    #[must_use]
    pub fn n_butterfly_factor(&self) -> usize {
        self.n_butterfly_factor + 1 // Return the config value
    }
523}
524
525impl Adapter for BoftLayer {
526    type Config = BoftConfig;
527
528    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
529        // Compute the butterfly OFT transformation matrix
530        let butterfly_oft = self.compute_butterfly_oft_matrix()?;
531
532        // Handle 3D input by reshaping
533        let input_shape = input.dims();
534        let is_3d = input_shape.len() == 3;
535
536        let input_2d = if is_3d {
537            // Reshape [batch, seq, features] -> [batch*seq, features]
538            input.reshape((input_shape[0] * input_shape[1], input_shape[2]))?
539        } else {
540            input.clone()
541        };
542
543        // Apply transformation: output = input @ butterfly_oft^T
544        let transformed = input_2d.matmul(&butterfly_oft.t()?)?;
545
546        // Reshape back if needed
547        let transformed = if is_3d {
548            transformed.reshape(input_shape)?
549        } else {
550            transformed
551        };
552
553        // Apply scaling: output = transformed * boft_s
554        let scaled = transformed.broadcast_mul(&self.boft_s.t()?)?;
555
556        // Add base output if provided
557        if let Some(base) = base_output {
558            Ok(scaled.broadcast_add(base)?)
559        } else {
560            Ok(scaled)
561        }
562    }
563
564    fn num_parameters(&self) -> usize {
565        // Parameters in boft_r: (n_butterfly_factor+1) * block_num * block_size^2
566        let r_params =
567            (self.n_butterfly_factor + 1) * self.block_num * self.block_size * self.block_size;
568
569        // Parameters in boft_s: out_features
570        let s_params = self.out_features;
571
572        r_params + s_params
573    }
574
575    fn config(&self) -> &Self::Config {
576        &self.config
577    }
578}
579
impl Mergeable for BoftLayer {
    /// Fold the adapter into `base_weight` and return the merged weight.
    ///
    /// Computes `(butterfly_oft @ W^T)^T * boft_s`, i.e. the weight is
    /// multiplied on the right by `butterfly_oft^T` and then scaled by
    /// `boft_s` via broadcasting.
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // Get butterfly OFT matrix
        let butterfly_oft = self.compute_butterfly_oft_matrix()?;

        // Merge: W' = (butterfly_oft @ W^T)^T * boft_s
        // = (W @ butterfly_oft^T) * boft_s
        let weight_t = base_weight.t()?;
        let merged_t = butterfly_oft.matmul(&weight_t)?;
        let merged = merged_t.t()?;

        // Apply scaling
        Ok(merged.broadcast_mul(&self.boft_s)?)
    }

    /// Undo `merge`: divide out `boft_s`, then apply `butterfly_oft^T`.
    ///
    /// This uses the transpose of `butterfly_oft` as its inverse, which is
    /// exact only when the matrix is orthogonal. Because `cayley_batch` uses a
    /// truncated series, the recovered weight is an approximation of the
    /// original.
    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        // Get butterfly OFT matrix
        let butterfly_oft = self.compute_butterfly_oft_matrix()?;

        // Unmerge: W = (butterfly_oft^T @ (W' / boft_s)^T)^T
        let unscaled = merged_weight.broadcast_div(&self.boft_s)?;
        let unscaled_t = unscaled.t()?;
        let butterfly_oft_t = butterfly_oft.t()?;
        let unmerged_t = butterfly_oft_t.matmul(&unscaled_t)?;

        Ok(unmerged_t.t()?)
    }
}
608
609impl Trainable for BoftLayer {
610    #[allow(clippy::similar_names)]
611    fn register_parameters(&self, var_map: &mut VarMap, prefix: &str) -> Result<()> {
612        let boft_r_name = format!("{prefix}.boft_r");
613        let boft_s_name = format!("{prefix}.boft_s");
614
615        var_map
616            .data()
617            .lock()
618            .unwrap()
619            .insert(boft_r_name, Var::from_tensor(&self.boft_r)?);
620        var_map
621            .data()
622            .lock()
623            .unwrap()
624            .insert(boft_s_name, Var::from_tensor(&self.boft_s)?);
625
626        Ok(())
627    }
628
629    fn freeze(&mut self) {
630        self.frozen = true;
631    }
632
633    fn unfreeze(&mut self) {
634        self.frozen = false;
635    }
636
637    fn is_frozen(&self) -> bool {
638        self.frozen
639    }
640}
641
642impl SaveLoad for BoftLayer {
643    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
644        let mut state_dict = HashMap::new();
645        state_dict.insert("boft_r".to_string(), self.boft_r.clone());
646        state_dict.insert("boft_s".to_string(), self.boft_s.clone());
647        Ok(state_dict)
648    }
649
650    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
651        if let Some(boft_r) = state_dict.get("boft_r") {
652            self.boft_r = boft_r.clone();
653        }
654        if let Some(boft_s) = state_dict.get("boft_s") {
655            self.boft_s = boft_s.clone();
656        }
657        Ok(())
658    }
659}
660
// Unit tests for the BOFT config and layer. All layer tests run on the CPU
// device with 64 input and 64 output features.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// The default config uses 4 blocks, a single butterfly factor and no dropout.
    #[test]
    fn test_boft_config_default() {
        let config = BoftConfig::default();
        assert_eq!(config.boft_block_num, 4);
        assert_eq!(config.boft_n_butterfly_factor, 1);
        assert!((config.boft_dropout - 0.0).abs() < f64::EPSILON);
    }

    /// `validate` accepts the default config and rejects the invalid
    /// block-setting combinations and a zero butterfly factor.
    #[test]
    fn test_boft_config_validation() {
        let mut config = BoftConfig::default();

        // Valid config
        assert!(config.validate().is_ok());

        // Both block_size and block_num set
        config.boft_block_size = 8;
        config.boft_block_num = 4;
        assert!(config.validate().is_err());

        // Neither set
        config.boft_block_size = 0;
        config.boft_block_num = 0;
        assert!(config.validate().is_err());

        // Invalid butterfly factor
        config.boft_block_num = 4;
        config.boft_n_butterfly_factor = 0;
        assert!(config.validate().is_err());
    }

    /// With 64 features and 4 blocks the derived block size is 16.
    #[test]
    fn test_boft_layer_creation() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        assert_eq!(layer.block_num(), 4);
        assert_eq!(layer.block_size(), 16);
        assert_eq!(layer.n_butterfly_factor(), 1);

        Ok(())
    }

    /// A 3D input keeps its shape through `forward`.
    #[test]
    fn test_boft_layer_forward() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        let input = Tensor::randn(0.0f32, 1.0f32, (2, 10, 64), &device)?;
        let output = layer.forward(&input, None)?;

        assert_eq!(output.dims(), &[2, 10, 64]);

        Ok(())
    }

    /// `num_parameters` counts the entries of `boft_r` plus `boft_s`.
    #[test]
    fn test_boft_parameter_count() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;

        // With n_butterfly_factor=1 (internally 0), we have:
        // 1 * 4 * 16 * 16 = 1024 parameters in boft_r
        // 64 parameters in boft_s
        // Total: 1088
        assert_eq!(layer.num_parameters(), 1088);

        Ok(())
    }

    /// The permutation helper returns the identity when the internal factor
    /// count is 0, and a vector of length `n` otherwise. Only the length is
    /// checked for the second case, not the ordering.
    #[test]
    fn test_boft_block_butterfly_perm() -> Result<()> {
        // Test identity permutation (no butterfly)
        let perm = BoftLayer::block_butterfly_perm(8, 4, 1, 0)?;
        assert_eq!(perm, vec![0, 1, 2, 3, 4, 5, 6, 7]);

        // Test actual butterfly permutation
        let perm = BoftLayer::block_butterfly_perm(8, 4, 1, 1)?;
        // Should separate even and odd positions within blocks
        assert_eq!(perm.len(), 8);

        Ok(())
    }

    /// `merge` and `unmerge` preserve the weight's shape. The values of the
    /// round trip are not compared.
    #[test]
    fn test_boft_merge_unmerge() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 4,
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let layer = BoftLayer::new(64, 64, config, &device)?;
        let base_weight = Tensor::randn(0.0f32, 1.0f32, (64, 64), &device)?;

        // Merge
        let merged = layer.merge(&base_weight)?;
        assert_eq!(merged.dims(), base_weight.dims());

        // Unmerge
        let unmerged = layer.unmerge(&merged)?;
        assert_eq!(unmerged.dims(), base_weight.dims());

        Ok(())
    }

    /// Construction fails when the block count does not divide the features.
    #[test]
    fn test_boft_invalid_features() {
        let device = Device::Cpu;
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 5, // 64 is not divisible by 5
            boft_n_butterfly_factor: 1,
            ..Default::default()
        };

        let result = BoftLayer::new(64, 64, config, &device);
        assert!(result.is_err());
    }

    /// Construction fails for this butterfly configuration. Note that 64 is
    /// not divisible by 3 either, so the feature-divisibility check in `new`
    /// is reached before the butterfly divisibility check.
    #[test]
    fn test_boft_butterfly_factor_validation() {
        let device = Device::Cpu;

        // With butterfly factor 2, block_num must be divisible by 2^1 = 2
        let config = BoftConfig {
            boft_block_size: 0,
            boft_block_num: 3, // Not divisible by 2
            boft_n_butterfly_factor: 2,
            ..Default::default()
        };

        let result = BoftLayer::new(64, 64, config, &device);
        assert!(result.is_err());
    }

    /// `freeze` / `unfreeze` toggle the flag reported by `is_frozen`.
    #[test]
    fn test_boft_freeze_unfreeze() -> Result<()> {
        let device = Device::Cpu;
        let config = BoftConfig::default();
        let mut layer = BoftLayer::new(64, 64, config, &device)?;

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());

        Ok(())
    }
}
837}