Skip to main content

ferrum_quantization/
lora.rs

1//! LoRA reference utilities.
2//!
3//! G4 keeps production adapter serving startup-scoped. This module provides
4//! the small f32 reference path used by loader and routing tests:
5//!
6//! y = base(x) + (alpha / r) * B(A(x))
7
8use ferrum_types::{FerrumError, Result};
9
/// Reference (unoptimized) f32 linear layer with a LoRA adapter attached.
///
/// All matrices are stored flat in row-major order. The layer evaluates
/// `base(x) + scaling * B(A(x))`.
#[derive(Clone, Debug)]
pub struct LoraLinearRef {
    /// Base projection, `out_features x in_features`.
    base_weight: Vec<f32>,
    /// LoRA `A` (down-projection), `rank x in_features`.
    a_weight: Vec<f32>,
    /// LoRA `B` (up-projection), `out_features x rank`.
    b_weight: Vec<f32>,
    /// Number of values in each input row.
    in_features: usize,
    /// Number of values in each output row.
    out_features: usize,
    /// Inner dimension shared by `A` and `B`.
    rank: usize,
    /// Multiplier applied to the adapter path (`lora_alpha / rank`).
    scaling: f32,
}
20
21impl LoraLinearRef {
22    pub fn new(
23        base_weight: Vec<f32>,
24        a_weight: Vec<f32>,
25        b_weight: Vec<f32>,
26        in_features: usize,
27        out_features: usize,
28        rank: usize,
29        lora_alpha: f32,
30    ) -> Result<Self> {
31        if rank == 0 {
32            return Err(FerrumError::config("LoRA rank must be > 0"));
33        }
34        if base_weight.len() != out_features * in_features {
35            return Err(FerrumError::config(format!(
36                "base weight shape mismatch: got {} elements, expected {}x{}",
37                base_weight.len(),
38                out_features,
39                in_features
40            )));
41        }
42        if a_weight.len() != rank * in_features {
43            return Err(FerrumError::config(format!(
44                "LoRA A shape mismatch: got {} elements, expected {}x{}",
45                a_weight.len(),
46                rank,
47                in_features
48            )));
49        }
50        if b_weight.len() != out_features * rank {
51            return Err(FerrumError::config(format!(
52                "LoRA B shape mismatch: got {} elements, expected {}x{}",
53                b_weight.len(),
54                out_features,
55                rank
56            )));
57        }
58        Ok(Self {
59            base_weight,
60            a_weight,
61            b_weight,
62            in_features,
63            out_features,
64            rank,
65            scaling: lora_alpha / rank as f32,
66        })
67    }
68
69    pub fn forward(&self, input: &[f32], batch: usize) -> Result<Vec<f32>> {
70        if input.len() != batch * self.in_features {
71            return Err(FerrumError::config(format!(
72                "LoRA input shape mismatch: got {} elements, expected {}x{}",
73                input.len(),
74                batch,
75                self.in_features
76            )));
77        }
78        let mut out = vec![0.0f32; batch * self.out_features];
79        let mut low_rank = vec![0.0f32; batch * self.rank];
80
81        for m in 0..batch {
82            for o in 0..self.out_features {
83                let mut acc = 0.0f32;
84                for i in 0..self.in_features {
85                    acc += input[m * self.in_features + i]
86                        * self.base_weight[o * self.in_features + i];
87                }
88                out[m * self.out_features + o] = acc;
89            }
90            for r in 0..self.rank {
91                let mut acc = 0.0f32;
92                for i in 0..self.in_features {
93                    acc +=
94                        input[m * self.in_features + i] * self.a_weight[r * self.in_features + i];
95                }
96                low_rank[m * self.rank + r] = acc;
97            }
98            for o in 0..self.out_features {
99                let mut acc = 0.0f32;
100                for r in 0..self.rank {
101                    acc += low_rank[m * self.rank + r] * self.b_weight[o * self.rank + r];
102                }
103                out[m * self.out_features + o] += self.scaling * acc;
104            }
105        }
106
107        Ok(out)
108    }
109}