ferrum_quantization/
lora.rs1use ferrum_types::{FerrumError, Result};
9
/// A dense linear layer with a LoRA (low-rank adaptation) update, stored as
/// plain `f32` buffers.
///
/// All three weight buffers are flat and row-major. The layer evaluates
/// `y = W·x + scaling · B·(A·x)` for every input row.
#[derive(Clone, Debug)]
pub struct LoraLinearRef {
    /// Base weight `W`, laid out as `out_features` rows of `in_features` values.
    base_weight: Vec<f32>,
    /// LoRA down-projection `A`, laid out as `rank` rows of `in_features` values.
    a_weight: Vec<f32>,
    /// LoRA up-projection `B`, laid out as `out_features` rows of `rank` values.
    b_weight: Vec<f32>,
    /// Number of values in one input row.
    in_features: usize,
    /// Number of values in one output row.
    out_features: usize,
    /// Inner dimension shared by `A` and `B`; the constructor rejects zero.
    rank: usize,
    /// Multiplier applied to the low-rank term; set to `lora_alpha / rank` by
    /// the constructor.
    scaling: f32,
}
20
21impl LoraLinearRef {
22 pub fn new(
23 base_weight: Vec<f32>,
24 a_weight: Vec<f32>,
25 b_weight: Vec<f32>,
26 in_features: usize,
27 out_features: usize,
28 rank: usize,
29 lora_alpha: f32,
30 ) -> Result<Self> {
31 if rank == 0 {
32 return Err(FerrumError::config("LoRA rank must be > 0"));
33 }
34 if base_weight.len() != out_features * in_features {
35 return Err(FerrumError::config(format!(
36 "base weight shape mismatch: got {} elements, expected {}x{}",
37 base_weight.len(),
38 out_features,
39 in_features
40 )));
41 }
42 if a_weight.len() != rank * in_features {
43 return Err(FerrumError::config(format!(
44 "LoRA A shape mismatch: got {} elements, expected {}x{}",
45 a_weight.len(),
46 rank,
47 in_features
48 )));
49 }
50 if b_weight.len() != out_features * rank {
51 return Err(FerrumError::config(format!(
52 "LoRA B shape mismatch: got {} elements, expected {}x{}",
53 b_weight.len(),
54 out_features,
55 rank
56 )));
57 }
58 Ok(Self {
59 base_weight,
60 a_weight,
61 b_weight,
62 in_features,
63 out_features,
64 rank,
65 scaling: lora_alpha / rank as f32,
66 })
67 }
68
69 pub fn forward(&self, input: &[f32], batch: usize) -> Result<Vec<f32>> {
70 if input.len() != batch * self.in_features {
71 return Err(FerrumError::config(format!(
72 "LoRA input shape mismatch: got {} elements, expected {}x{}",
73 input.len(),
74 batch,
75 self.in_features
76 )));
77 }
78 let mut out = vec![0.0f32; batch * self.out_features];
79 let mut low_rank = vec![0.0f32; batch * self.rank];
80
81 for m in 0..batch {
82 for o in 0..self.out_features {
83 let mut acc = 0.0f32;
84 for i in 0..self.in_features {
85 acc += input[m * self.in_features + i]
86 * self.base_weight[o * self.in_features + i];
87 }
88 out[m * self.out_features + o] = acc;
89 }
90 for r in 0..self.rank {
91 let mut acc = 0.0f32;
92 for i in 0..self.in_features {
93 acc +=
94 input[m * self.in_features + i] * self.a_weight[r * self.in_features + i];
95 }
96 low_rank[m * self.rank + r] = acc;
97 }
98 for o in 0..self.out_features {
99 let mut acc = 0.0f32;
100 for r in 0..self.rank {
101 acc += low_rank[m * self.rank + r] * self.b_weight[o * self.rank + r];
102 }
103 out[m * self.out_features + o] += self.scaling * acc;
104 }
105 }
106
107 Ok(out)
108 }
109}