#![allow(clippy::doc_markdown)]
//! LoKr adapter layers: low-rank adaptation using a Kronecker-product
//! factorization of the weight update.
//!
//! The update is factored as `delta_W = w1 ⊗ (w2_a · w2_b)` and scaled by
//! `alpha / r`, keeping the number of trainable parameters far below a full
//! `out_features × in_features` matrix.
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]

use candle_core::{Device, Tensor};
use candle_nn::VarMap;
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig, Mergeable, Trainable};
/// Configuration for a LoKr adapter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoKrConfig {
    /// Rank of the low-rank decomposition of the second Kronecker factor.
    pub r: usize,

    /// Scaling numerator; the effective scaling is `alpha / r`.
    pub alpha: usize,

    /// Kronecker factor dimension. When `None`, the largest common divisor of
    /// the input and output dimensions that does not exceed
    /// `sqrt(out_features)` is used.
    #[serde(default)]
    pub factor: Option<usize>,

    /// Whether to decompose both Kronecker factors (not just the second) into
    /// low-rank products.
    #[serde(default)]
    pub decompose_both: bool,

    /// Names of the modules the adapter is applied to.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
}

fn default_target_modules() -> Vec<String> {
    vec!["q_proj".into(), "v_proj".into()]
}

impl Default for LoKrConfig {
    fn default() -> Self {
        Self {
            r: 8,
            alpha: 16,
            factor: None,
            decompose_both: false,
            target_modules: default_target_modules(),
        }
    }
}

impl AdapterConfig for LoKrConfig {
    fn validate(&self) -> Result<()> {
        if self.r == 0 {
            return Err(PeftError::InvalidConfig("rank must be > 0".into()));
        }
        if self.alpha == 0 {
            return Err(PeftError::InvalidConfig("alpha must be > 0".into()));
        }
        Ok(())
    }
}

/// A LoKr adapter layer whose weight update is
/// `delta_W = lokr_w1 ⊗ (lokr_w2_a · lokr_w2_b)`, scaled by `alpha / r`.
pub struct LoKrLayer {
    /// First Kronecker factor, shape `(factor_out, factor_in)`.
    lokr_w1: Tensor,
    /// Left low-rank factor of the second Kronecker factor.
    lokr_w2_a: Tensor,
    /// Right low-rank factor of the second Kronecker factor.
    lokr_w2_b: Tensor,
    /// Scaling applied to the update, `alpha / r`.
    scaling: f64,
    /// Configuration used to build this layer.
    config: LoKrConfig,
    /// Input dimension of the adapted weight.
    in_features: usize,
    /// Output dimension of the adapted weight.
    out_features: usize,
    /// Kronecker factor dimension on the output side.
    factor_out: usize,
    /// Kronecker factor dimension on the input side.
    factor_in: usize,
    /// Whether the adapter is currently frozen.
    frozen: bool,
}

impl LoKrLayer {
    /// Creates a new LoKr layer for a linear weight of shape
    /// `(out_features, in_features)`.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid or if tensor creation
    /// fails.
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: LoKrConfig,
        device: &Device,
    ) -> Result<Self> {
        config.validate()?;

        let scaling = config.alpha as f64 / config.r as f64;

        // Pick the Kronecker factor: either the configured value, or the
        // largest common divisor of the two dimensions that does not exceed
        // sqrt(out_features).
        let factor = config.factor.unwrap_or_else(|| {
            let target = (out_features as f64).sqrt() as usize;
            for f in (1..=target).rev() {
                if out_features.is_multiple_of(f) && in_features.is_multiple_of(f) {
                    return f;
                }
            }
            1
        });

        let factor_out = factor.min(out_features);
        let factor_in = factor.min(in_features);
        let remaining_out = out_features / factor_out;
        let remaining_in = in_features / factor_in;

        let std = (1.0 / config.r as f64).sqrt() as f32;

        // First Kronecker factor, kept at full rank.
        let lokr_w1 = Tensor::randn(0.0f32, std, (factor_out, factor_in), device)?;

        // Second Kronecker factor, stored as a rank-`r` product `w2_a · w2_b`.
        let lokr_w2_a = Tensor::randn(0.0f32, std, (remaining_out, config.r), device)?;
        let lokr_w2_b = Tensor::randn(0.0f32, std, (config.r, remaining_in), device)?;

        Ok(Self {
            lokr_w1,
            lokr_w2_a,
            lokr_w2_b,
            scaling,
            config,
            in_features,
            out_features,
            factor_out,
            factor_in,
            frozen: false,
        })
    }

    /// Returns the scaling applied to the update (`alpha / r`).
    #[must_use]
    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    /// Returns the rank `r` of the low-rank factors.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.config.r
    }

    /// Computes the Kronecker product of two 2-D tensors: for `a` of shape
    /// `(m, n)` and `b` of shape `(p, q)`, the result has shape
    /// `(m * p, n * q)` with `result[i * p + k, j * q + l] = a[i, j] * b[k, l]`.
    #[allow(clippy::many_single_char_names)]
    fn kronecker_product(a: &Tensor, b: &Tensor) -> Result<Tensor> {
        let a_shape = a.dims();
        let b_shape = b.dims();

        let m = a_shape[0];
        let n = a_shape[1];
        let p = b_shape[0];
        let q = b_shape[1];

        // Fill the (m * p, n * q) result in row-major order.
        let mut result_data = Vec::with_capacity(m * p * n * q);

        // Materialize both operands on the host for the explicit loops below.
        let a_data: Vec<f32> = a.flatten_all()?.to_vec1()?;
        let b_data: Vec<f32> = b.flatten_all()?.to_vec1()?;

        for i in 0..m {
            for k in 0..p {
                for j in 0..n {
                    for l in 0..q {
                        let a_val = a_data[i * n + j];
                        let b_val = b_data[k * q + l];
                        result_data.push(a_val * b_val);
                    }
                }
            }
        }

        Ok(Tensor::from_vec(result_data, (m * p, n * q), a.device())?)
    }

    /// Computes the full update `delta_W = w1 ⊗ (w2_a · w2_b)` with shape
    /// `(out_features, in_features)`.
    fn compute_delta_w(&self) -> Result<Tensor> {
        // Reconstruct the second Kronecker factor from its low-rank pieces.
        let w2 = self.lokr_w2_a.matmul(&self.lokr_w2_b)?;

        Self::kronecker_product(&self.lokr_w1, &w2)
    }
}

impl Adapter for LoKrLayer {
    type Config = LoKrConfig;

    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
        // Build the scaled update delta_W.
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        // Flatten (batch, seq, features) to 2-D for the matmul, then restore
        // the original layout.
        let input_dims = input.dims();
        let batch_seq = input_dims[0] * input_dims[1];
        let input_2d = input.reshape((batch_seq, self.in_features))?;

        let lokr_out = input_2d.matmul(&delta_w.t()?)?;
        let lokr_out = lokr_out.reshape((input_dims[0], input_dims[1], self.out_features))?;

        // Add the adapter output to the base layer's output when provided.
        match base_output {
            Some(base) => Ok(base.broadcast_add(&lokr_out)?),
            None => Ok(lokr_out),
        }
    }

    fn num_parameters(&self) -> usize {
        let remaining_out = self.out_features / self.factor_out;
        let remaining_in = self.in_features / self.factor_in;

        // w1 (factor_out × factor_in) + w2_a (remaining_out × r)
        // + w2_b (r × remaining_in).
        self.factor_out * self.factor_in
            + remaining_out * self.config.r
            + self.config.r * remaining_in
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Mergeable for LoKrLayer {
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(base_weight.broadcast_add(&delta_w)?)
    }

    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor> {
        let delta_w = self.compute_delta_w()?;
        let scaling = Tensor::new(self.scaling as f32, delta_w.device())?;
        let delta_w = delta_w.broadcast_mul(&scaling)?;

        Ok(merged_weight.broadcast_sub(&delta_w)?)
    }
}

impl Trainable for LoKrLayer {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> Result<()> {
        // The LoKr factors are stored as plain `Tensor`s rather than `Var`s,
        // so there is nothing to add to the `VarMap` here.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_lokr_config_default() {
        let config = LoKrConfig::default();
        assert_eq!(config.r, 8);
        assert_eq!(config.alpha, 16);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_lokr_config_invalid_rank() {
        let config = LoKrConfig {
            r: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }
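
    // Sanity-check sketch covering the other `validate` branch: a zero
    // `alpha` should be rejected just like a zero rank.
    #[test]
    fn test_lokr_config_invalid_alpha() {
        let config = LoKrConfig {
            alpha: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }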

    #[test]
    fn test_lokr_layer_creation() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device);
        assert!(layer.is_ok());
    }
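
    // Getter sketch: with the default config (r = 8, alpha = 16) the layer
    // should report rank 8 and a scaling of alpha / r = 2.0.
    #[test]
    fn test_lokr_scaling_and_rank() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        assert_eq!(layer.rank(), 8);
        assert!((layer.scaling() - 2.0).abs() < f64::EPSILON);
    }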

    #[test]
    fn test_lokr_layer_with_factor() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device);
        assert!(layer.is_ok());

        let layer = layer.unwrap();
        assert_eq!(layer.factor_out, 8);
        assert_eq!(layer.factor_in, 8);
    }
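
    // Factor-selection sketch: with `factor: None`, `new` picks the largest
    // common divisor of the two dimensions that is at most
    // sqrt(out_features); for a 64x64 layer that is 8.
    #[test]
    fn test_lokr_auto_factor() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        assert_eq!(layer.factor_out, 8);
        assert_eq!(layer.factor_in, 8);
    }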

    #[test]
    fn test_lokr_forward_shape() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, None).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_lokr_forward_with_base_output() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 64], DType::F32, &device).unwrap();
        let base_output = Tensor::ones(&[1, 10, 64], DType::F32, &device).unwrap();
        let output = layer.forward(&input, Some(&base_output)).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 64]);
    }

    #[test]
    fn test_lokr_num_parameters() {
        let config = LoKrConfig {
            r: 4,
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        // w1: 8 * 8 = 64, w2_a: 8 * 4 = 32, w2_b: 4 * 8 = 32, total 128
        // (versus 64 * 64 = 4096 for the full weight matrix).
        assert_eq!(layer.num_parameters(), 128);
    }

    #[test]
    fn test_lokr_merge_unmerge() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let merged = layer.merge(&base_weight).unwrap();
        let unmerged = layer.unmerge(&merged).unwrap();

        // Unmerging should recover the original base weight up to float error.
        let diff = unmerged.broadcast_sub(&base_weight).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(0)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-5);
    }
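
    // Consistency sketch: merging the adapter into the base weight should give
    // the same result as adding the adapter output to the base output, i.e.
    // x @ (W + s * dW)^T == x @ W^T + s * x @ dW^T. Assumes the
    // (batch, seq, features) layout used by `forward`.
    #[test]
    fn test_lokr_merge_matches_forward() {
        let config = LoKrConfig {
            factor: Some(8),
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        let base_weight = Tensor::randn(0.0f32, 0.02, (64, 64), &device).unwrap();
        let input = Tensor::randn(0.0f32, 1.0f32, (1, 4, 64), &device).unwrap();
        let input_2d = input.reshape((4, 64)).unwrap();

        // Base output computed from the unmerged weight.
        let base_out = input_2d
            .matmul(&base_weight.t().unwrap())
            .unwrap()
            .reshape((1, 4, 64))
            .unwrap();
        // Base output plus the adapter contribution.
        let adapter_out = layer.forward(&input, Some(&base_out)).unwrap();

        // Output computed directly from the merged weight.
        let merged = layer.merge(&base_weight).unwrap();
        let merged_out = input_2d
            .matmul(&merged.t().unwrap())
            .unwrap()
            .reshape((1, 4, 64))
            .unwrap();

        let diff = adapter_out.broadcast_sub(&merged_out).unwrap();
        let max_diff: f32 = diff
            .abs()
            .unwrap()
            .max(2)
            .unwrap()
            .max(1)
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar()
            .unwrap();
        assert!(max_diff < 1e-4);
    }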

    #[test]
    fn test_lokr_freeze_unfreeze() {
        let config = LoKrConfig::default();
        let device = Device::Cpu;
        let mut layer = LoKrLayer::new(64, 64, config, &device).unwrap();

        assert!(!layer.is_frozen());
        layer.freeze();
        assert!(layer.is_frozen());
        layer.unfreeze();
        assert!(!layer.is_frozen());
    }

    #[test]
    fn test_kronecker_product() {
        let device = Device::Cpu;
        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device).unwrap();
        let b = Tensor::new(&[[0.0f32, 5.0], [6.0, 7.0]], &device).unwrap();

        let result = LoKrLayer::kronecker_product(&a, &b).unwrap();
        assert_eq!(result.dims(), &[4, 4]);
    }
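
    // Value-level sketch for the Kronecker product: with the same 2x2 inputs
    // as above, the top-left block is 1 * B and the bottom-right block 4 * B,
    // following result[i * p + k, j * q + l] = a[i, j] * b[k, l].
    #[test]
    fn test_kronecker_product_values() {
        let device = Device::Cpu;
        let a = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device).unwrap();
        let b = Tensor::new(&[[0.0f32, 5.0], [6.0, 7.0]], &device).unwrap();

        let result = LoKrLayer::kronecker_product(&a, &b).unwrap();
        let values: Vec<Vec<f32>> = result.to_vec2().unwrap();

        let expected = vec![
            vec![0.0f32, 5.0, 0.0, 10.0],
            vec![6.0, 7.0, 12.0, 14.0],
            vec![0.0, 15.0, 0.0, 20.0],
            vec![18.0, 21.0, 24.0, 28.0],
        ];
        assert_eq!(values, expected);
    }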
}