peft_rs/adapters/lokr.rs

//! `LoKr` (Low-Rank Kronecker Product) implementation.
//!
//! `LoKr` uses Kronecker product decomposition for efficient weight updates.
//! The weight update is factorized as `ΔW = kron(A, B)`, where the Kronecker
//! product allows for structured, parameter-efficient representations.
//!
//! Reference: <https://arxiv.org/abs/2108.06098> (`LyCORIS`)
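//!
//! As a shape example: `kron(A, B)` with `A: [m, n]` and `B: [p, q]` yields an
//! `[m * p, n * q]` matrix, so a `[64, 64]` update can be represented by two
//! `[8, 8]` factors, i.e. 128 stored values instead of 4096.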

#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]

use candle_core::{Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for LoKr adapters.
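///
/// # Examples
///
/// A minimal construction sketch (marked `ignore` because the doctest cannot
/// assume the final crate path; `peft_rs` is used here illustratively):
///
/// ```ignore
/// use peft_rs::adapters::lokr::LoKrConfig;
/// use peft_rs::traits::AdapterConfig;
///
/// let config = LoKrConfig { r: 4, alpha: 8, ..Default::default() };
/// assert!(config.validate().is_ok());
/// ```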
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoKrConfig {
    /// Rank for the low-rank decomposition of the second Kronecker factor.
    pub r: usize,

    /// Scaling factor (the effective scale is `alpha / r`).
    pub alpha: usize,

    /// Factor dimension (splits each weight dimension into factor x remaining).
    /// If `None`, a divisor close to the square root is chosen automatically.
    #[serde(default)]
    pub factor: Option<usize>,

    /// Whether to decompose both Kronecker factors.
    /// Currently unused: only the second factor is decomposed.
    #[serde(default)]
    pub decompose_both: bool,

    /// Target modules to apply LoKr to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for LoKrConfig {
    fn default() -> Self {
        Self {
            r: 8,
            alpha: 16,
            factor: None,
            decompose_both: false,
            target_modules: default_target_modules(),
        }
    }
}

impl AdapterConfig for LoKrConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        Ok(())
    }
}

/// LoKr layer implementing Low-Rank Kronecker Product adaptation.
///
/// The weight update is computed as `ΔW = kron(w1, w2_a @ w2_b)`: the first
/// Kronecker factor is kept as a full matrix, while the second factor is
/// represented by a low-rank product:
/// - `lokr_w1`: first Kronecker factor, `[factor_out, factor_in]`
/// - `lokr_w2_a`, `lokr_w2_b`: low-rank decomposition of the second factor,
///   `[remaining_out, r]` and `[r, remaining_in]`
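///
/// For a `64 x 64` layer with `factor = 8` and `r = 4`, this stores
/// `w1: [8, 8]` (64 values), `w2_a: [8, 4]` (32) and `w2_b: [4, 8]` (32),
/// i.e. 128 parameters instead of the 4096 of a dense update.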
pub struct LoKrLayer {
    /// First Kronecker factor: [factor_out, factor_in]
    lokr_w1: Tensor,
    /// Second factor A (low-rank): [remaining_out, r]
    lokr_w2_a: Tensor,
    /// Second factor B (low-rank): [r, remaining_in]
    lokr_w2_b: Tensor,
    /// Scaling factor = alpha / r
    scaling: f64,
    /// Configuration
    config: LoKrConfig,
    /// Input dimension
    in_features: usize,
    /// Output dimension
    out_features: usize,
    /// Factor for output dimension
    factor_out: usize,
    /// Factor for input dimension
    factor_in: usize,
    /// Whether gradients are disabled
    frozen: bool,
}

impl LoKrLayer {
    /// Create a new LoKr layer.
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `config` - LoKr configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    /// Returns error if configuration is invalid or tensor initialization fails.
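    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore`; the `peft_rs` paths are assumptions
    /// about the crate layout):
    ///
    /// ```ignore
    /// use candle_core::Device;
    /// use peft_rs::adapters::lokr::{LoKrConfig, LoKrLayer};
    ///
    /// let layer = LoKrLayer::new(64, 64, LoKrConfig::default(), &Device::Cpu).unwrap();
    /// assert_eq!(layer.rank(), 8);
    /// ```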
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoKrConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.r as f64;

        // Determine factorization
        // For a weight [out, in], we factorize as kron([f_out, f_in], [r_out, r_in])
        // where out = f_out * r_out and in = f_in * r_in
        let factor = config.factor.unwrap_or_else(|| {
            // Find a reasonable factor (try to find a divisor close to sqrt)
            let target = (out_features as f64).sqrt() as usize;
            for f in (1..=target).rev() {
                if out_features.is_multiple_of(f) && in_features.is_multiple_of(f) {
                    return f;
                }
            }
            1
        });
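        // Example: for a 64 x 64 layer the search starts at floor(sqrt(64)) = 8,
        // which divides both dimensions, so the automatic factor is 8.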

        let factor_out = factor.min(out_features);
        let factor_in = factor.min(in_features);
        // A user-supplied factor must divide both dims, or kron(w1, w2) cannot
        // reproduce the [out_features, in_features] shape.
        if !out_features.is_multiple_of(factor_out) || !in_features.is_multiple_of(factor_in) {
            return Err(PeftError::InvalidConfig("factor must divide the layer dimensions".into()));
        }
        let remaining_out = out_features / factor_out;
        let remaining_in = in_features / factor_in;

        // Initialize weights
        let std = (1.0 / config.r as f64).sqrt() as f32;

        // First Kronecker factor (full matrix)
        let lokr_w1 = Tensor::randn(0.0f32, std, (factor_out, factor_in), device)?;

        // Second factor as low-rank: A @ B
        // Note: all factors are randomly initialized, so ΔW is nonzero at init.
        let lokr_w2_a = Tensor::randn(0.0f32, std, (remaining_out, config.r), device)?;
        let lokr_w2_b = Tensor::randn(0.0f32, std, (config.r, remaining_in), device)?;

        Ok(Self {
            lokr_w1,
            lokr_w2_a,
            lokr_w2_b,
            scaling,
            config,
            in_features,
            out_features,
            factor_out,
            factor_in,
            frozen: false,
        })
    }

    /// Get the scaling factor.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Get the rank.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// Compute the Kronecker product of two 2D tensors.
    /// `kron(A, B)` where `A` is `[m, n]` and `B` is `[p, q]` produces `[m*p, n*q]`.
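    ///
    /// For example, with `A = [[1, 2], [3, 4]]`, `kron(A, B)` is the 2x2 block
    /// matrix `[[1*B, 2*B], [3*B, 4*B]]`.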
    #[allow(clippy::many_single_char_names)]
    fn kronecker_product(a: &Tensor, b: &Tensor) -> Result<Tensor> {
        let a_shape = a.dims();
        let b_shape = b.dims();

        let m = a_shape[0];
        let n = a_shape[1];
        let p = b_shape[0];
        let q = b_shape[1];

        // Result shape: [m*p, n*q]
        let mut result_data = Vec::with_capacity(m * p * n * q);

        // Get data as host vectors (simple, but copies device tensors to the CPU)
        let a_data: Vec<f32> = a.flatten_all()?.to_vec1()?;
        let b_data: Vec<f32> = b.flatten_all()?.to_vec1()?;

        // Compute Kronecker product: entry (i*p + k, j*q + l) = A[i, j] * B[k, l]
        for i in 0..m {
            for k in 0..p {
                for j in 0..n {
                    for l in 0..q {
                        let a_val = a_data[i * n + j];
                        let b_val = b_data[k * q + l];
                        result_data.push(a_val * b_val);
                    }
                }
            }
        }

        Ok(Tensor::from_vec(result_data, (m * p, n * q), a.device())?)
    }

    /// Compute the weight delta using the Kronecker product.
    fn compute_delta_w(&self) -> Result<Tensor> {
        // Compute the second factor: w2_a @ w2_b -> [remaining_out, remaining_in]
        let w2 = self.lokr_w2_a.matmul(&self.lokr_w2_b)?;

        // kron(w1, w2): [factor_out * remaining_out, factor_in * remaining_in]
        // = [out_features, in_features]
        Self::kronecker_product(&self.lokr_w1, &w2)
    }
}

impl Adapter for LoKrLayer {
    type Config = LoKrConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Compute the delta weight (materialized in full on every call; simple,
        // though not the cheapest formulation)
        let delta_w = self.compute_delta_w()?;

        // Apply scaling
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        // Compute: input @ delta_w^T, with input expected as [batch, seq, in_features]
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        let lokr_out = input_2d.matmul(&delta_w.t()?)?;
        let lokr_out = lokr_out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Add to base output if provided
        match base_output {
            Some(base) => Ok(base.broadcast_add(&lokr_out)?),
            None => Ok(lokr_out),
        }
    }

    fn num_parameters(&self) -> usize {
        let remaining_out = self.out_features / self.factor_out;
        let remaining_in = self.in_features / self.factor_in;

        // w1: factor_out * factor_in
        // w2_a: remaining_out * r
        // w2_b: r * remaining_in
        self.factor_out * self.factor_in
            + remaining_out * self.config.r
            + self.config.r * remaining_in
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for LoKrLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for LoKrLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // No-op: the factors are stored as plain tensors rather than `Var`s,
        // so there is nothing to register in the `VarMap` yet.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_lokr_config_default() {
        let config = LoKrConfig::default();
        assert_eq!(config.r, 8);
        assert_eq!(config.alpha, 16);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_lokr_config_invalid_rank() {
        let config = LoKrConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_lokr_layer_creation() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        // Use dimensions that are easily factorizable
        let layer = LoKrLayer::new(64, 64, config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_lokr_layer_with_factor() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.factor_out, 8);
        assert_eq!(layer.factor_in, 8);
    }

    #[test]
    fn test_lokr_forward_shape() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_lokr_forward_with_base_output() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_lokr_num_parameters() {
        let config = LoKrConfig {
            r: 4,
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        // w1: 8 * 8 = 64
        // remaining: 64/8 = 8
        // w2_a: 8 * 4 = 32
        // w2_b: 4 * 8 = 32
        // Total: 64 + 32 + 32 = 128
        assert_eq!(layer.num_parameters(), 128);
    }

    #[test]
    fn test_lokr_merge_unmerge() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerged should be close to original
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }

    #[test]
    fn test_lokr_freeze_unfreeze() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        let mut layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    #[test]
    fn test_kronecker_product() {
        let device = Device::Cpu;
        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device).unwrap();
        let b = Tensor::new(&[[0.0f32, 5.0], [6.0, 7.0]], &device).unwrap();

        let result = LoKrLayer::kronecker_product(&a, &b).unwrap();
        assert_eq!(result.dims(), &[4, 4]);
    }
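
    // Value-level check of the 2x2 case above: kron(A, B) must equal the block
    // matrix [[a11*B, a12*B], [a21*B, a22*B]], read off row-major.
    #[test]
    fn test_kronecker_product_values() {
        let device = Device::Cpu;
        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device).unwrap();
        let b = Tensor::new(&[[0.0f32, 5.0], [6.0, 7.0]], &device).unwrap();

        let result = LoKrLayer::kronecker_product(&a, &b).unwrap();
        let values: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
        assert_eq!(
            values,
            vec![
                0.0, 5.0, 0.0, 10.0,
                6.0, 7.0, 12.0, 14.0,
                0.0, 15.0, 0.0, 20.0,
                18.0, 21.0, 24.0, 28.0,
            ]
        );
    }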
}