// peft_rs/adapters/oft.rs

//! OFT (Orthogonal Fine-Tuning) implementation.
//!
//! OFT adapts a model by applying orthogonal transformations to its weights,
//! preserving pretrained knowledge. Block-diagonal orthogonal matrices keep
//! the transformation efficient.
//!
//! Reference: <https://arxiv.org/abs/2306.07280>

use candle_core::{Device, IndexOp, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for OFT adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OftConfig {
    /// Number of OFT blocks (determines expressiveness vs efficiency).
    pub r: usize,

    /// Whether to use constrained OFT (COFT), which enforces strict orthogonality.
    #[serde(default)]
    pub coft: bool,

    /// Small constant for numerical stability.
    #[serde(default = "default_eps")]
    pub eps: f64,

    /// Whether to share OFT blocks across layers.
    #[serde(default)]
    pub block_share: bool,

    /// Target modules to apply OFT to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Whether to use exact Cayley transform computation.
    ///
    /// When `false` (default), uses the Neumann series approximation
    /// `(I + Q)^{-1} ≈ I - Q + Q^2`, which is efficient but less accurate
    /// for larger Q values.
    ///
    /// When `true`, computes the exact inverse using Newton-Schulz iteration,
    /// providing higher accuracy at the cost of additional computation.
    #[serde(default)]
    pub use_exact_cayley: bool,
}

fn default_eps() -> f64 {
    1e-5
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for OftConfig {
    fn default() -> Self {
        Self {
            r: 8,
            coft: false,
            eps: default_eps(),
            block_share: false,
            target_modules: default_target_modules(),
            use_exact_cayley: false,
        }
    }
}

impl AdapterConfig for OftConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig(
                "number of blocks (r) must be > 0".into(),
            ));
        }
        if self.eps <= 0.0 {
            return Err(PeftError::InvalidConfig("eps must be > 0".into()));
        }
        Ok(())
    }
}

/// OFT layer implementing Orthogonal Fine-Tuning.
///
/// Uses block-diagonal orthogonal matrices to transform weights:
/// `W' = W @ R` where R is a block-diagonal orthogonal matrix.
///
/// The orthogonal matrix R is parameterized via the Cayley transform:
/// `R = (I - Q) @ (I + Q)^{-1}` where Q is skew-symmetric.
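///
/// Orthogonality follows from skew-symmetry: since `Q^T = -Q`,
/// `(I - Q)^T = I + Q`, and therefore
/// `R^T @ R = (I - Q)^{-1} @ (I + Q) @ (I - Q) @ (I + Q)^{-1} = I`,
/// because `(I + Q)` and `(I - Q)` commute.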
pub struct OftLayer {
    /// Skew-symmetric parameters for Cayley parameterization.
    /// Shape: [`num_blocks`, `block_size`, `block_size`]
    oft_r: Tensor,
    /// Configuration
    config: OftConfig,
    /// Input/output dimension (OFT requires a square transformation)
    features: usize,
    /// Size of each block
    block_size: usize,
    /// Number of blocks
    num_blocks: usize,
    /// Whether gradients are disabled
    frozen: bool,
}

impl OftLayer {
    /// Create a new OFT layer.
    ///
    /// # Arguments
    /// * `features` - Dimension of the weight matrix (must be divisible by r)
    /// * `config` - OFT configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    ///
    /// Returns an error if configuration validation fails or if layer construction fails.
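    ///
    /// # Example
    ///
    /// A minimal construction sketch (marked `ignore`, since it assumes this
    /// module is exported as `peft_rs::adapters::oft`):
    ///
    /// ```ignore
    /// use candle_core::Device;
    /// use peft_rs::adapters::oft::{OftConfig, OftLayer};
    ///
    /// let config = OftConfig { r: 8, ..Default::default() };
    /// // 64 features split into 8 orthogonal blocks of size 8
    /// let layer = OftLayer::new(64, config, &Device::Cpu).unwrap();
    /// assert_eq!(layer.num_blocks(), 8);
    /// assert_eq!(layer.block_size(), 8);
    /// ```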
    pub fn new(features: usize, config: OftConfig, device: &Device) -> Result<Self> {
        config.validate()?;

        if !features.is_multiple_of(config.r) {
            return Err(PeftError::InvalidConfig(format!(
                "features ({}) must be divisible by r ({})",
                features, config.r
            )));
        }

        let num_blocks = config.r;
        let block_size = features / num_blocks;

        // Initialize skew-symmetric parameters to small values
        // This makes the initial orthogonal matrix close to identity
        let std = 0.01_f32;
        let oft_r = Tensor::randn(0.0f32, std, (num_blocks, block_size, block_size), device)?;

        Ok(Self {
            oft_r,
            config,
            features,
            block_size,
            num_blocks,
            frozen: false,
        })
    }

    /// Get the number of blocks.
    #[must_use]
    pub fn num_blocks(&self) -> usize {
        self.num_blocks
    }

    /// Get the block size.
    #[must_use]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Make the parameter matrix skew-symmetric: `Q = (R - R^T) / 2`.
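    ///
    /// Any square matrix splits into a symmetric and a skew-symmetric part;
    /// keeping only `(R - R^T) / 2` guarantees `Q^T = -Q` by construction,
    /// which is what the Cayley transform requires.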
    fn make_skew_symmetric(&self) -> Result<Tensor> {
        let r_t = self.oft_r.transpose(1, 2)?;
        let diff = self.oft_r.broadcast_sub(&r_t)?;
        let two = Tensor::new(2.0f32, self.oft_r.device())?;
        Ok(diff.broadcast_div(&two)?)
    }

    /// Compute the orthogonal matrix using the Cayley transform:
    /// `R = (I - Q) @ (I + Q)^{-1}`
    ///
    /// Uses either exact computation or a Neumann series approximation,
    /// depending on the config.
    fn compute_orthogonal_matrix(&self) -> Result<Tensor> {
        let q = self.make_skew_symmetric()?;
        let device = q.device();

        // Identity replicated per block: [num_blocks, block_size, block_size]
        let eye = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
        let eye = eye
            .unsqueeze(0)?
            .expand((self.num_blocks, self.block_size, self.block_size))?;

        // I - Q
        let i_minus_q = eye.broadcast_sub(&q)?;

        // I + Q
        let i_plus_q = eye.broadcast_add(&q)?;

        let mut result_blocks = Vec::with_capacity(self.num_blocks);

        for block_idx in 0..self.num_blocks {
            let i_minus_q_block = i_minus_q.i(block_idx)?;
            let i_plus_q_block = i_plus_q.i(block_idx)?;
            let q_block = q.i(block_idx)?;

            let inv = if self.config.use_exact_cayley {
                // Exact method: compute (I + Q)^{-1} by iterative refinement.
                // Since Q is small (initialized with std=0.01), Newton-Schulz
                // iteration converges quickly for matrices close to identity.
                //
                // Newton-Schulz iteration: X_{k+1} = X_k @ (2I - (I+Q) @ X_k),
                // starting from X_0 = I (a good guess since I+Q ≈ I).
                self.compute_exact_inverse(&i_plus_q_block)?
            } else {
                // Approximation: Neumann series (I + Q)^{-1} ≈ I - Q + Q^2.
                // Efficient, but less accurate for larger Q values.
                let eye_block = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
                let q_sq = q_block.matmul(&q_block)?;
                eye_block.broadcast_sub(&q_block)?.broadcast_add(&q_sq)?
            };

            // R_block = (I - Q) @ (I + Q)^{-1}
            let r_block = i_minus_q_block.matmul(&inv)?;
            result_blocks.push(r_block);
        }

        // Stack blocks: [num_blocks, block_size, block_size]
        Ok(Tensor::stack(&result_blocks, 0)?)
    }

    /// Compute the exact inverse using Newton-Schulz iteration.
    ///
    /// Newton-Schulz iteration: `X_{k+1} = X_k @ (2I - A @ X_k)`.
    /// It converges for matrices A with ||I - A|| < 1, which is satisfied
    /// here since (I + Q) is close to identity for small Q (initialized with
    /// std=0.01).
    ///
    /// # Iteration Count
    /// Uses 5 iterations, which drives the iteration error below `f32`
    /// round-off for well-conditioned matrices close to identity. This is
    /// sufficient since:
    /// - Q is initialized with small values (std=0.01)
    /// - (I + Q) is thus very close to I, ensuring fast quadratic convergence
    /// - Each iteration roughly squares the error:
    ///   `||X_k - A^{-1}|| ≈ ||X_0 - A^{-1}||^{2^k}`
    fn compute_exact_inverse(&self, matrix: &Tensor) -> Result<Tensor> {
        // 5 Newton-Schulz iterations reach f32 round-off for matrices close
        // to identity, which is the case here since Q is initialized with
        // small values.
        const NUM_ITERATIONS: usize = 5;

        let device = matrix.device();
        let eye = Tensor::eye(self.block_size, candle_core::DType::F32, device)?;
        let two = Tensor::new(2.0f32, device)?;
        let two_eye = eye.broadcast_mul(&two)?;

        // Start with identity as the initial guess (good for matrices close to I)
        let mut x = eye.clone();

        for _ in 0..NUM_ITERATIONS {
            // X_{k+1} = X_k @ (2I - A @ X_k)
            let ax = matrix.matmul(&x)?;
            let factor = two_eye.broadcast_sub(&ax)?;
            x = x.matmul(&factor)?;
        }

        Ok(x)
    }

    /// Apply the block-diagonal orthogonal transformation to an input of
    /// shape `[batch, seq, features]`.
    fn apply_block_diagonal(&self, input: &Tensor, orth_matrix: &Tensor) -> Result<Tensor> {
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];

        // Reshape input to [batch*seq, num_blocks, block_size]
        let input_blocked = input.reshape((batch_seq, self.num_blocks, self.block_size))?;

        // Apply the orthogonal transformation to each block:
        // input_blocked: [batch*seq, num_blocks, block_size]
        // orth_matrix:   [num_blocks, block_size, block_size]
        //
        // For each block: output[b, n, :] = input[b, n, :] @ R[n, :, :],
        // i.e. a batch matrix multiply.
        let mut output_blocks = Vec::with_capacity(self.num_blocks);

        for block_idx in 0..self.num_blocks {
            // input_block: [batch*seq, block_size]
            let input_block = input_blocked.i((.., block_idx, ..))?;
            // orth_block: [block_size, block_size]
            let orth_block = orth_matrix.i(block_idx)?;

            // output_block: [batch*seq, block_size]
            let output_block = input_block.matmul(&orth_block)?;
            output_blocks.push(output_block);
        }

        // Stack and reshape back
        let output_stacked = Tensor::stack(&output_blocks, 1)?; // [batch*seq, num_blocks, block_size]
        Ok(output_stacked.reshape((input_dims[0], input_dims[1], self.features))?)
    }
}

impl Adapter for OftLayer {
    type Config = OftConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Compute the orthogonal transformation matrix
        let orth_matrix = self.compute_orthogonal_matrix()?;

        // Apply the block-diagonal orthogonal transformation
        let transformed = self.apply_block_diagonal(input, &orth_matrix)?;

        // For OFT, the transformation replaces the base output. If a base
        // output is provided, contribute only the difference (delta).
        match base_output {
            Some(base) => {
                // Return: base + (transformed - input) = base + delta
                let delta = transformed.broadcast_sub(input)?;
                Ok(base.broadcast_add(&delta)?)
            }
            None => Ok(transformed),
        }
    }

    fn num_parameters(&self) -> usize {
        // Skew-symmetric blocks: num_blocks * block_size * block_size.
        // Only one triangle is independent (num_blocks * block_size *
        // (block_size - 1) / 2), but for simplicity we count all entries.
        self.num_blocks * self.block_size * self.block_size
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for OftLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        // W' = W @ R (right multiply by the orthogonal matrix)
        let orth_matrix = self.compute_orthogonal_matrix()?;

        // Construct the full block-diagonal matrix from its blocks
        let full_orth = self.construct_full_matrix(&orth_matrix)?;

        // base_weight: [out_features, in_features]
        // full_orth:   [in_features, in_features]
        Ok(base_weight.matmul(&full_orth)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        // W = W' @ R^T (R is orthogonal, so R^{-1} = R^T)
        let orth_matrix = self.compute_orthogonal_matrix()?;
        let full_orth = self.construct_full_matrix(&orth_matrix)?;

        // R^T
        let full_orth_t = full_orth.t()?;

        Ok(merged_weight.matmul(&full_orth_t)?)
    }
}

impl OftLayer {
    /// Construct the full block-diagonal matrix from blocks.
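    ///
    /// Note: the matrix is assembled on the host as an O(`features`^2)
    /// buffer and materialized in a single `Tensor::from_vec` call. That is
    /// fine for merge/unmerge, which run once, but would be slow in a hot
    /// loop.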
    fn construct_full_matrix(&self, blocks: &Tensor) -> Result<Tensor> {
        let device = blocks.device();
        let n = self.features;

        // Start with zeros
        let mut full_data = vec![0.0f32; n * n];

        // Fill in blocks along the diagonal
        for block_idx in 0..self.num_blocks {
            let block = blocks.i(block_idx)?;
            let block_data: Vec<f32> = block.flatten_all()?.to_vec1()?;

            let start = block_idx * self.block_size;

            for i in 0..self.block_size {
                for j in 0..self.block_size {
                    let row = start + i;
                    let col = start + j;
                    full_data[row * n + col] = block_data[i * self.block_size + j];
                }
            }
        }

        Ok(Tensor::from_vec(full_data, (n, n), device)?)
    }
}

impl Trainable for OftLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // `oft_r` is stored directly on the layer rather than in a `VarMap`,
        // so there is nothing to register here.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, IndexOp};

    #[test]
    fn test_oft_config_default() {
        let config = OftConfig::default();
        assert_eq!(config.r, 8);
        assert!(!config.coft);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_oft_config_invalid_r() {
        let config = OftConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
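
    // A sketch of how `#[serde(default)]` fills in missing fields during
    // deserialization. It assumes `serde_json` is available as a
    // dev-dependency; any serde format would demonstrate the same behavior.
    #[test]
    fn test_oft_config_serde_defaults() {
        let config: OftConfig = serde_json::from_str(r#"{ "r": 4 }"#).unwrap();
        assert_eq!(config.r, 4);
        assert!(!config.coft);
        assert!((config.eps - 1e-5).abs() < f64::EPSILON);
        assert!(!config.use_exact_cayley);
        assert_eq!(config.target_modules, vec!["q_proj", "v_proj"]);
    }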

    #[test]
    fn test_oft_layer_creation() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        // 64 is divisible by 8
        let layer = OftLayer::new(64, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.num_blocks(), 8);
        assert_eq!(layer.block_size(), 8);
    }

    #[test]
    fn test_oft_layer_invalid_dimensions() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        // 65 is not divisible by 8
        let layer = OftLayer::new(65, config, &device);
        assert!(layer.is_err());
    }

    #[test]
    fn test_oft_forward_shape() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_oft_forward_with_base_output() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_oft_num_parameters() {
        let config = OftConfig {
            r: 8,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(64, config, &device).unwrap();

        // 8 blocks of 8x8 = 8 * 64 = 512
        assert_eq!(layer.num_parameters(), 512);
    }

    #[test]
    fn test_oft_skew_symmetric() {
        let config = OftConfig {
            r: 2,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(8, config, &device).unwrap();

        let skew = layer.make_skew_symmetric().unwrap();

        // Check Q = -Q^T for each block
        for block_idx in 0..2 {
            let q = skew.i(block_idx).unwrap();
            let q_t = q.t().unwrap();
            let sum = q.broadcast_add(&q_t).unwrap();
            let max_val: f32 = sum
                .abs()
                .unwrap()
                .max(0)
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(max_val < 1e-5, "Matrix should be skew-symmetric");
        }
    }
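
    // A sanity-check sketch: each Cayley block should satisfy R^T @ R ≈ I.
    // Uses the exact (Newton-Schulz) path so the remaining error is f32
    // round-off; the 1e-4 tolerance is an assumed budget, not a derived bound.
    #[test]
    fn test_oft_orthogonality() {
        let config = OftConfig {
            r: 2,
            use_exact_cayley: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(8, config, &device).unwrap();

        let orth = layer.compute_orthogonal_matrix().unwrap();
        let eye = Tensor::eye(layer.block_size(), DType::F32, &device).unwrap();

        for block_idx in 0..2 {
            let r = orth.i(block_idx).unwrap();
            let rtr = r.t().unwrap().matmul(&r).unwrap();
            let diff = rtr.broadcast_sub(&eye).unwrap();
            let max_diff: f32 = diff
                .abs()
                .unwrap()
                .max(0)
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(max_diff < 1e-4, "R^T R should be ≈ I, max diff: {max_diff}");
        }
    }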

    #[test]
    fn test_oft_freeze_unfreeze() {
        let config = OftConfig::default();
        let device = Device::Cpu;
        let mut layer = OftLayer::new(64, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    #[test]
    fn test_oft_merge_unmerge() {
        let config = OftConfig {
            r: 4,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let base_weight = Tensor::eye(16, DType::F32, &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerged should be close to original
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 0.1, "Max diff: {max_diff}"); // Allow some numerical error
    }
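
    // A consistency sketch: applying the per-block transform should match a
    // plain matmul with the assembled full block-diagonal matrix. Everything
    // used here is defined in this file; the 1e-5 tolerance is an assumption.
    #[test]
    fn test_oft_block_diagonal_matches_full_matrix() {
        let config = OftConfig {
            r: 4,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let orth = layer.compute_orthogonal_matrix().unwrap();
        let input = Tensor::randn(0.0f32, 1.0, (1, 10, 16), &device).unwrap();

        // Block-wise path, as used by forward()
        let blocked = layer.apply_block_diagonal(&input, &orth).unwrap();

        // Full-matrix path, as used by merge()/unmerge()
        let full = layer.construct_full_matrix(&orth).unwrap();
        let flat = input.reshape((10, 16)).unwrap();
        let expected = flat.matmul(&full).unwrap().reshape((1, 10, 16)).unwrap();

        let diff = blocked.broadcast_sub(&expected).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .flatten_all()
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5, "Block path should match full matrix: {max_diff}");
    }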

    #[test]
    fn test_oft_exact_cayley_config() {
        // Test that the exact Cayley option is properly configured
        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        assert!(config.use_exact_cayley);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_oft_exact_cayley_forward() {
        // Test the forward pass with the exact Cayley transform
        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 16], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 16]);
    }

    #[test]
    fn test_oft_exact_cayley_merge_unmerge() {
        // Test merge/unmerge with the exact Cayley transform, which should be
        // more accurate: the exact method uses Newton-Schulz iteration, so
        // the 0.05 tolerance is stricter than the 0.1 used for the
        // approximation method in test_oft_merge_unmerge.
        const EXACT_METHOD_TOLERANCE: f32 = 0.05;

        let config = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(16, config, &device).unwrap();

        let base_weight = Tensor::eye(16, DType::F32, &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(
            max_diff < EXACT_METHOD_TOLERANCE,
            "Max diff with exact Cayley: {max_diff}"
        );
    }
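
    // A focused sketch of the Newton-Schulz step: for A = I + Q with small Q,
    // compute_exact_inverse(A) should satisfy A @ A^{-1} ≈ I. The 1e-5
    // tolerance is an assumed f32 round-off budget, not a derived bound.
    #[test]
    fn test_oft_newton_schulz_inverse() {
        let config = OftConfig {
            r: 2,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = OftLayer::new(8, config, &device).unwrap();

        let eye = Tensor::eye(layer.block_size(), DType::F32, &device).unwrap();
        let q = layer.make_skew_symmetric().unwrap();
        let a = eye.broadcast_add(&q.i(0).unwrap()).unwrap();

        let inv = layer.compute_exact_inverse(&a).unwrap();
        let product = a.matmul(&inv).unwrap();

        let diff = product.broadcast_sub(&eye).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5, "A @ A^{{-1}} should be ≈ I, max diff: {max_diff}");
    }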

    #[test]
    fn test_oft_approx_vs_exact_cayley() {
        // Compare the approximation and exact methods
        let device = Device::Cpu;

        // Approximation method
        let config_approx = OftConfig {
            r: 4,
            use_exact_cayley: false,
            ..Default::default()
        };
        let layer_approx = OftLayer::new(16, config_approx, &device).unwrap();

        // Exact method (each layer gets an independent random init, so the
        // two outputs cannot be compared numerically)
        let config_exact = OftConfig {
            r: 4,
            use_exact_cayley: true,
            ..Default::default()
        };
        let layer_exact = OftLayer::new(16, config_exact, &device).unwrap();

        // Both should produce valid outputs with the correct shape
        let input = Tensor::randn(0.0f32, 1.0, (1, 10, 16), &device).unwrap();

        let output_approx = layer_approx.forward(&input, None).unwrap();
        let output_exact = layer_exact.forward(&input, None).unwrap();

        assert_eq!(output_approx.shape().dims(), &[1, 10, 16]);
        assert_eq!(output_exact.shape().dims(), &[1, 10, 16]);
    }
}
651}