peft_rs/adapters/loha.rs

//! `LoHa` (Low-Rank Hadamard Product) implementation.
//!
//! `LoHa` takes the Hadamard (element-wise) product of two low-rank matrix
//! products for more expressive weight updates: `ΔW = (A1 @ B1) ⊙ (A2 @ B2)`
//!
//! Reference: <https://arxiv.org/abs/2108.06098> (FedPara; implemented in `LyCORIS`)
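//!
//! A minimal usage sketch (marked `ignore`; assumes the `peft_rs` module
//! paths implied by this file's location):
//!
//! ```ignore
//! use candle_core::{DType, Device, Tensor};
//! use peft_rs::adapters::loha::{LoHaConfig, LoHaLayer};
//! use peft_rs::traits::Adapter;
//!
//! let device = Device::Cpu;
//! let layer = LoHaLayer::new(768, 768, LoHaConfig::default(), &device).unwrap();
//! let hidden = Tensor::zeros((2, 10, 768), DType::F32, &device).unwrap();
//! let adapted = layer.forward(&hidden, None).unwrap(); // shape [2, 10, 768]
//! ```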

#![allow(clippy::doc_markdown)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]

use candle_core::{DType, Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for LoHa adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoHaConfig {
    /// Rank of both low-rank decompositions.
    pub r: usize,

    /// Alpha parameter; the effective scaling factor is `alpha / r`.
    pub alpha: usize,

    /// Target modules to apply LoHa to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Whether to use effective convolution for conv layers.
    #[serde(default)]
    pub use_effective_conv2d: bool,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for LoHaConfig {
    fn default() -> Self {
        Self {
            r: 8,
            alpha: 16,
            target_modules: default_target_modules(),
            use_effective_conv2d: false,
        }
    }
}

impl AdapterConfig for LoHaConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        Ok(())
    }
}

/// LoHa layer implementing Low-Rank Hadamard Product adaptation.
///
/// Computes: `ΔW = (A1 @ B1) ⊙ (A2 @ B2) * scaling`
///
/// Where:
/// - A1, A2: [out_features, r]
/// - B1, B2: [r, in_features]
/// - ⊙ is element-wise (Hadamard) product
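///
/// For `in_features = out_features = 768` with `r = 8`, the four factor
/// matrices hold `2 * (768 * 8 + 8 * 768) = 24_576` parameters, versus
/// `768 * 768 = 589_824` for a full dense update.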
pub struct LoHaLayer {
    /// First decomposition: A1 [out_features, r]
    hada_w1_a: Tensor,
    /// First decomposition: B1 [r, in_features]
    hada_w1_b: Tensor,
    /// Second decomposition: A2 [out_features, r]
    hada_w2_a: Tensor,
    /// Second decomposition: B2 [r, in_features]
    hada_w2_b: Tensor,
    /// Scaling factor = alpha / r
    scaling: f64,
    /// Configuration
    config: LoHaConfig,
    /// Input dimension
    in_features: usize,
    /// Output dimension
    out_features: usize,
    /// Whether gradients are disabled
    frozen: bool,
}

impl LoHaLayer {
    /// Create a new LoHa layer.
    ///
    /// # Arguments
    /// * `in_features` - Input dimension
    /// * `out_features` - Output dimension
    /// * `config` - LoHa configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    /// Returns error if configuration is invalid or tensor initialization fails.
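    ///
    /// # Example
    ///
    /// A minimal construction sketch (marked `ignore`; assumes the types in
    /// this module):
    ///
    /// ```ignore
    /// let config = LoHaConfig { r: 4, alpha: 8, ..Default::default() };
    /// let layer = LoHaLayer::new(768, 768, config, &candle_core::Device::Cpu).unwrap();
    /// assert_eq!(layer.rank(), 4);
    /// assert_eq!(layer.scaling(), 2.0); // alpha / r = 8 / 4
    /// ```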
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoHaConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.r as f64;

        // Kaiming-like initialization for the random factors
        let std = (1.0 / config.r as f64).sqrt() as f32;

        // First low-rank decomposition
        let hada_w1_a = Tensor::randn(0.0f32, std, (out_features, config.r), device)?;
        let hada_w1_b = Tensor::randn(0.0f32, std, (config.r, in_features), device)?;

        // Second low-rank decomposition. B2 starts at zero so that
        // ΔW = (A1 @ B1) ⊙ (A2 @ B2) = 0 at creation and the adapter does
        // not perturb the base model before training (the standard LoHa
        // initialization, as in PEFT/LyCORIS).
        let hada_w2_a = Tensor::randn(0.0f32, std, (out_features, config.r), device)?;
        let hada_w2_b = Tensor::zeros((config.r, in_features), DType::F32, device)?;

        Ok(Self {
            hada_w1_a,
            hada_w1_b,
            hada_w2_a,
            hada_w2_b,
            scaling,
            config,
            in_features,
            out_features,
            frozen: false,
        })
    }

    /// Get the scaling factor.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Get the rank.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// Compute the weight delta: (A1 @ B1) ⊙ (A2 @ B2)
    fn compute_delta_w(&self) -> Result<Tensor> {
        // Compute first term: A1 @ B1 -> [out_features, in_features]
        let term1 = self.hada_w1_a.matmul(&self.hada_w1_b)?;

        // Compute second term: A2 @ B2 -> [out_features, in_features]
        let term2 = self.hada_w2_a.matmul(&self.hada_w2_b)?;

        // Hadamard (element-wise) product
        Ok(term1.mul(&term2)?)
    }
}

impl Adapter for LoHaLayer {
    type Config = LoHaConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Compute delta weight
        let delta_w = self.compute_delta_w()?;

        // Apply scaling
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        // Compute: input @ delta_w^T
        // Note: assumes a 3D input of shape [batch, seq_len, in_features].
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        let loha_out = input_2d.matmul(&delta_w.t()?)?;
        let loha_out = loha_out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Add to base output if provided
        match base_output {
            Some(base) => Ok(base.broadcast_add(&loha_out)?),
            None => Ok(loha_out),
        }
    }

    fn num_parameters(&self) -> usize {
        // 4 matrices: 2 * (out_features * r + r * in_features)
        2 * (self.out_features * self.config.r + self.config.r * self.in_features)
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for LoHaLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for LoHaLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // Note: In the current design, tensors are created directly.
        // For full training support, tensors should be created via VarBuilder
        // during construction, which automatically registers them.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_loha_config_default() {
        let config = LoHaConfig::default();
        assert_eq!(config.r, 8);
        assert_eq!(config.alpha, 16);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_loha_config_invalid_rank() {
        let config = LoHaConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_loha_config_invalid_alpha() {
        let config = LoHaConfig {
            alpha: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
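
    // Added sketch: exercises the `#[serde(default = ...)]` attributes on
    // `LoHaConfig`; a config missing `target_modules` and
    // `use_effective_conv2d` should still deserialize with the defaults.
    // Assumes `serde_json` is available as a dev-dependency.
    #[test]
    fn test_loha_config_serde_defaults() {
        let config: LoHaConfig = serde_json::from_str(r#"{"r": 4, "alpha": 8}"#).unwrap();
        assert_eq!(config.r, 4);
        assert_eq!(config.alpha, 8);
        assert_eq!(config.target_modules, vec!["q_proj", "v_proj"]);
        assert!(!config.use_effective_conv2d);
    }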

    #[test]
    fn test_loha_layer_creation() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_loha_forward_shape() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_loha_forward_with_base_output() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_loha_num_parameters() {
        let config = LoHaConfig {
            r: 8,
            alpha: 16,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        // 2 * (out * r + r * in) = 2 * (768 * 8 + 8 * 768) = 2 * 12288 = 24576
        assert_eq!(layer.num_parameters(), 24576);
    }

    #[test]
    fn test_loha_merge_unmerge() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerged should be close to original
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }
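
    // Added sketch: with `hada_w2_b` zero-initialized in `new`, a freshly
    // created layer contributes nothing (ΔW = 0 before training), so the
    // adapter output on any input should be all zeros.
    #[test]
    fn test_loha_zero_at_init() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(16, 16, config, &device).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (1, 4, 16), &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        let total: f32 = output.abs().unwrap().sum_all().unwrap().to_scalar().unwrap();
        assert!(total < 1e-6);
    }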

    #[test]
    fn test_loha_freeze_unfreeze() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let mut layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }
}