//! A minimal LoRA (Low-Rank Adaptation) linear layer: a frozen base weight is
//! augmented with a trainable low-rank update `B * A`, scaled by `alpha / rank`
//! (or `alpha / sqrt(rank)` for rsLoRA).

use crate::autograd::matmul;
use crate::Tensor;

/// Strategy for computing the scaling factor applied to the low-rank update.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LoRAScaling {
    /// Classic LoRA scaling: `alpha / rank`.
    Standard,
    /// Rank-stabilized LoRA (rsLoRA) scaling: `alpha / sqrt(rank)`.
    RsLoRA,
}

impl LoRAScaling {
    /// Returns the scaling factor for the given `alpha` and `rank`.
    ///
    /// Panics if `rank == 0`.
    pub fn compute(self, alpha: f32, rank: usize) -> f32 {
        assert!(rank > 0, "LoRA rank must be > 0");
        match self {
            Self::Standard => alpha / rank as f32,
            Self::RsLoRA => alpha / (rank as f32).sqrt(),
        }
    }
}
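
// A small sanity check for the two scaling formulas above; this sketch only
// exercises `LoRAScaling::compute`, so it does not depend on the `Tensor` API.
#[cfg(test)]
mod scaling_tests {
    use super::*;

    #[test]
    fn standard_and_rslora_scaling() {
        // Standard: alpha / rank = 16 / 8 = 2.0
        assert!((LoRAScaling::Standard.compute(16.0, 8) - 2.0).abs() < 1e-6);
        // rsLoRA: alpha / sqrt(rank) = 16 / sqrt(8)
        let expected = 16.0 / (8.0_f32).sqrt();
        assert!((LoRAScaling::RsLoRA.compute(16.0, 8) - expected).abs() < 1e-6);
    }
}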

/// A linear layer with a frozen base weight and a trainable low-rank LoRA update.
///
/// The effective weight is `W + scale * B * A`, where `A` has shape `rank x d_in`,
/// `B` has shape `d_out x rank`, and `scale` is derived from `alpha` and `rank`.
#[derive(Clone)]
pub struct LoRALayer {
    /// Frozen base weight with shape `d_out x d_in` (flattened).
    base_weight: Tensor,
    /// Trainable low-rank factor `A`, shape `rank x d_in` (flattened).
    lora_a: Tensor,
    /// Trainable low-rank factor `B`, shape `d_out x rank` (flattened).
    lora_b: Tensor,
    d_out: usize,
    d_in: usize,
    rank: usize,
    /// Scaling factor applied to the `B * A` contribution.
    scale: f32,
    /// Whether the scaled `B * A` update has been folded into `base_weight`.
    merged: bool,
}

impl LoRALayer {
    /// Creates a LoRA layer around `base_weight` using the standard
    /// `alpha / rank` scaling.
    ///
    /// `lora_a` is filled with small deterministic pseudo-random values and
    /// `lora_b` starts at zero, so the layer initially behaves exactly like
    /// the base weight alone.
    pub fn new(base_weight: Tensor, d_out: usize, d_in: usize, rank: usize, alpha: f32) -> Self {
        assert!(rank > 0, "LoRA rank must be > 0");
        assert_eq!(
            base_weight.len(),
            d_out * d_in,
            "Base weight size must match d_out * d_in"
        );
        // Deterministic small init for A (a stand-in for a random Gaussian init).
        let lora_a_data: Vec<f32> = (0..rank * d_in)
            .map(|i| (i as f32 * 0.1).sin() * 0.01)
            .collect();
        let lora_a = Tensor::from_vec(lora_a_data, true);
        // B starts at zero, so B * A is zero at initialization.
        let lora_b = Tensor::zeros(d_out * rank, true);
        let scale = alpha / rank as f32;
        Self { base_weight, lora_a, lora_b, d_out, d_in, rank, scale, merged: false }
    }

    /// Creates a LoRA layer with an explicit scaling strategy
    /// (standard or rsLoRA).
    pub fn new_with_scaling(
        base_weight: Tensor,
        d_out: usize,
        d_in: usize,
        rank: usize,
        alpha: f32,
        scaling: LoRAScaling,
    ) -> Self {
        let mut layer = Self::new(base_weight, d_out, d_in, rank, alpha);
        layer.scale = scaling.compute(alpha, rank);
        layer
    }

    /// Forward pass for a single input vector of length `d_in`.
    ///
    /// Computes `W x` when merged, and `W x + scale * B (A x)` otherwise.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        assert_eq!(x.len(), self.d_in, "Input size must match d_in");
        let base_output = matmul(&self.base_weight, x, self.d_out, self.d_in, 1);
        if self.merged {
            // The scaled B * A update already lives in base_weight.
            base_output
        } else {
            // LoRA path: A x (rank x 1), then B (A x) (d_out x 1).
            let lora_out_a = matmul(&self.lora_a, x, self.rank, self.d_in, 1);
            let lora_out_b = matmul(&self.lora_b, &lora_out_a, self.d_out, self.rank, 1);
            // Scale the LoRA contribution and add it to the base output.
            // Note: this combination is done on raw buffers, outside the autograd ops.
            let mut scaled_lora_data = lora_out_b.data().to_owned();
            for val in &mut scaled_lora_data {
                *val *= self.scale;
            }
            let scaled_lora = Tensor::new(scaled_lora_data, false);
            let mut result_data = base_output.data().to_owned();
            for (i, val) in result_data.iter_mut().enumerate() {
                *val += scaled_lora.data()[i];
            }
            Tensor::new(result_data, base_output.requires_grad())
        }
    }

    /// Folds the scaled `B * A` update into `base_weight` so that the forward
    /// pass is a single matmul. No-op if already merged.
    pub fn merge(&mut self) {
        if self.merged {
            return;
        }
        // B (d_out x rank) * A (rank x d_in) -> d_out x d_in.
        let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
        for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
            *val += self.scale * ba.data()[i];
        }
        self.merged = true;
    }

    /// Subtracts the scaled `B * A` update from `base_weight`, restoring the
    /// unmerged state. No-op if not merged.
    pub fn unmerge(&mut self) {
        if !self.merged {
            return;
        }
        let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
        for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
            *val -= self.scale * ba.data()[i];
        }
        self.merged = false;
    }

    /// The (possibly merged) base weight.
    pub fn base_weight(&self) -> &Tensor {
        &self.base_weight
    }

    pub fn lora_a(&self) -> &Tensor {
        &self.lora_a
    }

    pub fn lora_a_mut(&mut self) -> &mut Tensor {
        &mut self.lora_a
    }

    pub fn lora_b(&self) -> &Tensor {
        &self.lora_b
    }

    pub fn lora_b_mut(&mut self) -> &mut Tensor {
        &mut self.lora_b
    }

    /// The trainable parameters: only the LoRA factors `A` and `B`;
    /// the base weight stays frozen.
    pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.lora_a, &mut self.lora_b]
    }

    pub fn is_merged(&self) -> bool {
        self.merged
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn scale(&self) -> f32 {
        self.scale
    }

    pub fn d_out(&self) -> usize {
        self.d_out
    }

    pub fn d_in(&self) -> usize {
        self.d_in
    }
}
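
// A minimal usage sketch, written as tests. It assumes the `Tensor`
// constructors and accessors used above (`from_vec`, `len`, `data`, `data_mut`)
// behave as their names suggest; if the real `Tensor` API differs, adjust accordingly.
#[cfg(test)]
mod tests {
    use super::*;

    fn identity_weight(n: usize) -> Tensor {
        // n x n identity matrix, flattened.
        let mut data = vec![0.0_f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_vec(data, false)
    }

    #[test]
    fn fresh_layer_matches_base_weight() {
        // With B initialized to zero, the LoRA path contributes nothing,
        // so forward(x) should equal W x (here W = I, so forward(x) == x).
        let layer = LoRALayer::new(identity_weight(4), 4, 4, 2, 8.0);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
        let y = layer.forward(&x);
        assert_eq!(y.len(), 4);
        for (yi, xi) in y.data().iter().zip(x.data().iter()) {
            assert!((yi - xi).abs() < 1e-6);
        }
    }

    #[test]
    fn merge_then_unmerge_is_a_round_trip() {
        let mut layer = LoRALayer::new(identity_weight(4), 4, 4, 2, 8.0);
        // Make B nonzero so the merge actually changes the base weight.
        for val in layer.lora_b_mut().data_mut().iter_mut() {
            *val = 0.5;
        }
        let before = layer.base_weight().data().to_owned();
        layer.merge();
        assert!(layer.is_merged());
        layer.unmerge();
        assert!(!layer.is_merged());
        for (a, b) in layer.base_weight().data().iter().zip(before.iter()) {
            assert!((a - b).abs() < 1e-5);
        }
    }
}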