#![allow(clippy::doc_markdown)]
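//! LoHa (Low-Rank Hadamard Product) adapter layers.
//!
//! A LoHa adapter learns two low-rank factor pairs and combines their products
//! element-wise (a Hadamard product), scaled by `alpha / r`:
//!
//! `delta_w = (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b) * (alpha / r)`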
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]

use candle_core::{Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};

/// Configuration for a LoHa (Low-Rank Hadamard Product) adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoHaConfig {
    /// Rank of the low-rank factors.
    pub r: usize,

    /// Scaling numerator; the effective scaling is `alpha / r`.
    pub alpha: usize,

    /// Names of the modules the adapter should be applied to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,

    /// Whether to use the effective Conv2d decomposition (currently unused by `LoHaLayer`).
    #[serde(default)]
    pub use_effective_conv2d: bool,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for LoHaConfig {
    fn default() -> Self {
        Self {
            r: 8,
            alpha: 16,
            target_modules: default_target_modules(),
            use_effective_conv2d: false,
        }
    }
}

impl AdapterConfig for LoHaConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        Ok(())
    }
}

/// A LoHa adapter layer for a linear projection.
pub struct LoHaLayer {
    /// First factor pair, `A` side, shape `(out_features, r)`.
    hada_w1_a: Tensor,
    /// First factor pair, `B` side, shape `(r, in_features)`.
    hada_w1_b: Tensor,
    /// Second factor pair, `A` side, shape `(out_features, r)`.
    hada_w2_a: Tensor,
    /// Second factor pair, `B` side, shape `(r, in_features)`.
    hada_w2_b: Tensor,
    /// Scaling factor `alpha / r`.
    scaling: f64,
    /// Configuration used to build the layer.
    config: LoHaConfig,
    /// Input feature dimension.
    in_features: usize,
    /// Output feature dimension.
    out_features: usize,
    /// Whether the layer is currently frozen.
    frozen: bool,
}

impl LoHaLayer {
    /// Creates a new LoHa layer with randomly initialized factor pairs.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` is invalid or a tensor cannot be created
    /// on `device`.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoHaConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.r as f64;

        let std = (1.0 / config.r as f64).sqrt() as f32;

        // First low-rank factor pair.
        let hada_w1_a = Tensor::randn(0.0f32, std, (out_features, config.r), device)?;
        let hada_w1_b = Tensor::randn(0.0f32, std, (config.r, in_features), device)?;

        // Second low-rank factor pair.
        let hada_w2_a = Tensor::randn(0.0f32, std, (out_features, config.r), device)?;
        let hada_w2_b = Tensor::randn(0.0f32, std, (config.r, in_features), device)?;

        Ok(Self {
            hada_w1_a,
            hada_w1_b,
            hada_w2_a,
            hada_w2_b,
            scaling,
            config,
            in_features,
            out_features,
            frozen: false,
        })
    }

    /// Returns the scaling factor `alpha / r`.
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Returns the rank `r`.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// Computes the unscaled update `(w1_a @ w1_b) * (w2_a @ w2_b)`.
    fn compute_delta_w(&self) -> Result<Tensor> {
        // First low-rank product, shape (out_features, in_features).
        let term1 = self.hada_w1_a.matmul(&self.hada_w1_b)?;

        // Second low-rank product, shape (out_features, in_features).
        let term2 = self.hada_w2_a.matmul(&self.hada_w2_b)?;

        // Element-wise (Hadamard) product of the two terms.
        Ok(term1.mul(&term2)?)
    }
}

impl Adapter for LoHaLayer {
    type Config = LoHaConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;

        // Apply the alpha / r scaling.
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        // Flatten the expected (batch, seq, in_features) input to 2D for the matmul.
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        let loha_out = input_2d.matmul(&delta_w.t()?)?;
        let loha_out = loha_out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Add the LoHa update to the base layer output if one was provided.
        match base_output {
            Some(base) => Ok(base.broadcast_add(&loha_out)?),
            None => Ok(loha_out),
        }
    }

    fn num_parameters(&self) -> usize {
        // Two (A, B) factor pairs.
        2 * (self.out_features * self.config.r + self.config.r * self.in_features)
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for LoHaLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for LoHaLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // No-op: the LoHa factors are stored as plain `Tensor`s rather than
        // `Var`s, so there is nothing to register with the `VarMap` yet.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_loha_config_default() {
        let config = LoHaConfig::default();
        assert_eq!(config.r, 8);
        assert_eq!(config.alpha, 16);
        assert!(config.validate().is_ok());
    }
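
    // Additional sketch: the serde default for `target_modules` should match
    // what `default_target_modules()` returns.
    #[test]
    fn test_loha_config_default_target_modules() {
        let config = LoHaConfig::default();
        assert_eq!(config.target_modules, vec!["q_proj", "v_proj"]);
    }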

    #[test]
    fn test_loha_config_invalid_rank() {
        let config = LoHaConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_loha_config_invalid_alpha() {
        let config = LoHaConfig {
            alpha: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_loha_layer_creation() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device);
        assert!(layer.is_ok());
    }
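
    // Additional sketch: with the default config (r = 8, alpha = 16) the
    // accessors should report rank 8 and scaling 16 / 8 = 2.0.
    #[test]
    fn test_loha_scaling_and_rank() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        assert_eq!(layer.rank(), 8);
        assert!((layer.scaling() - 2.0).abs() < f64::EPSILON);
    }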

    #[test]
    fn test_loha_forward_shape() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }
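
    // A further sketch: forward should also handle a layer whose input and
    // output dimensions differ, projecting (batch, seq, in) to (batch, seq, out).
    #[test]
    fn test_loha_forward_non_square() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(32, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[2, 5, 32], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[2, 5, 64]);
    }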

    #[test]
    fn test_loha_forward_with_base_output() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_loha_num_parameters() {
        let config = LoHaConfig {
            r: 8,
            alpha: 16,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        // 2 * (768 * 8 + 8 * 768) = 24_576 parameters.
        assert_eq!(layer.num_parameters(), 24576);
    }
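
    // Additional sketch: the unscaled update computed by `compute_delta_w`
    // (reachable here because tests are a child module) should have shape
    // (out_features, in_features).
    #[test]
    fn test_loha_delta_w_shape() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(32, 64, config, &device).unwrap();

        let delta_w = layer.compute_delta_w().unwrap();
        assert_eq!(delta_w.shape().dims(), &[64, 32]);
    }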

    #[test]
    fn test_loha_merge_unmerge() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerging a merged weight should recover the original base weight
        // up to floating-point error.
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }
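
    // A further sketch of the merge semantics: merging into an all-zero base
    // weight should yield exactly `scaling * delta_w`, since `merge` only adds
    // the scaled Hadamard-product update.
    #[test]
    fn test_loha_merge_into_zero_base() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let layer = LoHaLayer::new(64, 64, config, &device).unwrap();

        let zero_base = Tensor::zeros(&[64, 64], DType::F32, &device).unwrap();
        let merged = layer.merge(&zero_base).unwrap();

        let scaling = Tensor::new(layer.scaling() as f32, &device).unwrap();
        let expected = layer
            .compute_delta_w()
            .unwrap()
            .broadcast_mul(&scaling)
            .unwrap();

        let diff = merged.broadcast_sub(&expected).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-6);
    }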

    #[test]
    fn test_loha_freeze_unfreeze() {
        let config = LoHaConfig::default();
        let device = Device::Cpu;
        let mut layer = LoHaLayer::new(768, 768, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }
}