use crate::lora::{LoRALayer, LoRAScaling};
use crate::Tensor;
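
/// DoRA (Weight-Decomposed Low-Rank Adaptation) layer, after Liu et al. (2024).
///
/// DoRA re-parameterizes the adapted weight as a learnable per-row magnitude
/// times a direction derived from the LoRA-adapted base weight:
///
/// `W' = m * (W + scale * B·A) / ||W + scale * B·A||_row`
///
/// Only the magnitude `m` and the LoRA factors `A`/`B` are trained; the base
/// weight `W` stays frozen.
///
/// A minimal usage sketch (mirroring the constructors used in the tests
/// below; not compiled as a doctest):
///
/// ```ignore
/// let base = Tensor::from_vec(vec![1.0; 12], false);
/// let dora = DoRALayer::new(base, 3, 4, 2, 4.0, LoRAScaling::Standard);
/// let y = dora.forward(&Tensor::from_vec(vec![0.5; 4], true));
/// assert_eq!(y.len(), 3);
/// ```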
pub struct DoRALayer {
magnitude: Tensor,
lora: LoRALayer,
d_out: usize,
d_in: usize,
}
impl DoRALayer {
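    /// Builds a DoRA layer over a frozen `d_out x d_in` (row-major) base
    /// weight. The magnitude vector is initialized to the L2 norm of each
    /// base-weight row, so that, with the standard LoRA zero-initialization
    /// of `B`, the freshly constructed layer reproduces the base layer's
    /// output exactly.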
pub fn new(
base_weight: Tensor,
d_out: usize,
d_in: usize,
rank: usize,
alpha: f32,
scaling: LoRAScaling,
) -> Self {
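        // Initialize m_i to ||W_i||_2 over each base-weight row, floored at
        // 1e-8 to match the division guard used in the forward pass.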
let magnitude_data: Vec<f32> = (0..d_out)
.map(|row| {
let row_start = row * d_in;
let row_end = row_start + d_in;
let row_norm_sq: f32 = base_weight
.data()
.slice(ndarray::s![row_start..row_end])
.iter()
.map(|x| x * x)
.sum();
row_norm_sq.sqrt().max(1e-8)
})
.collect();
let magnitude = Tensor::from_vec(magnitude_data, true);
let lora = LoRALayer::new_with_scaling(base_weight, d_out, d_in, rank, alpha, scaling);
Self { magnitude, lora, d_out, d_in }
}
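
    /// Forward pass: runs the LoRA forward `(W + scale * B·A) x`, then
    /// rescales each output element by `m_i / ||W_i + scale * (B·A)_i||`,
    /// replacing the effective row norm with the learned magnitude.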
pub fn forward(&self, x: &Tensor) -> Tensor {
assert_eq!(x.len(), self.d_in, "Input size must match d_in");
let lora_output = self.lora.forward(x);
let row_norms = self.compute_effective_row_norms();
let mut result = lora_output.data().to_owned();
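        // Divide out the effective row norm and scale by the learned magnitude.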
for (i, val) in result.iter_mut().enumerate() {
let norm = row_norms[i].max(1e-8);
*val = self.magnitude.data()[i] * (*val / norm);
}
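        // The output tensor inherits the magnitude's requires_grad flag.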
Tensor::new(result, self.magnitude.requires_grad())
}
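
    /// Row-wise L2 norms of the effective weight `W + scale * B·A`,
    /// materializing each `(B·A)[row][col]` entry on the fly rather than
    /// forming the full product matrix.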
fn compute_effective_row_norms(&self) -> Vec<f32> {
let base = self.lora.base_weight().data();
let scale = self.lora.scale();
let a_data = self.lora.lora_a().data();
let b_data = self.lora.lora_b().data();
let rank = self.lora.rank();
let mut norms = vec![0.0f32; self.d_out];
for row in 0..self.d_out {
let mut row_norm_sq = 0.0f32;
for col in 0..self.d_in {
let base_val = base[row * self.d_in + col];
let mut ba_val = 0.0f32;
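                // (B·A)[row][col]: dot product over the rank dimension.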
for r in 0..rank {
ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
}
let effective = base_val + scale * ba_val;
row_norm_sq += effective * effective;
}
norms[row] = row_norm_sq.sqrt();
}
norms
}
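
    /// Merges the layer into a dense `d_out x d_in` weight matrix
    /// `W' = m * (W + scale * B·A) / ||W + scale * B·A||_row`, suitable for
    /// adapter-free inference.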
pub fn merge_to_f32(&self) -> Vec<f32> {
let row_norms = self.compute_effective_row_norms();
let base = self.lora.base_weight().data();
let scale = self.lora.scale();
let a_data = self.lora.lora_a().data();
let b_data = self.lora.lora_b().data();
let rank = self.lora.rank();
let mut merged = vec![0.0f32; self.d_out * self.d_in];
for row in 0..self.d_out {
let m = self.magnitude.data()[row];
let norm = row_norms[row].max(1e-8);
for col in 0..self.d_in {
let base_val = base[row * self.d_in + col];
let mut ba_val = 0.0f32;
for r in 0..rank {
ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
}
merged[row * self.d_in + col] = m * (base_val + scale * ba_val) / norm;
}
}
merged
}
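
    /// Mutable references to all trainable tensors: the magnitude vector
    /// followed by the LoRA factors.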
pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
let mut params = vec![&mut self.magnitude];
params.extend(self.lora.trainable_params());
params
}
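
    /// The learned per-row magnitude vector (length `d_out`).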
pub fn magnitude(&self) -> &Tensor {
&self.magnitude
}
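
    /// The wrapped LoRA layer (frozen base weight plus low-rank factors).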
pub fn lora(&self) -> &LoRALayer {
&self.lora
}
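
    /// Trainable scalar count: `d_out` (magnitude) + `rank * d_in` (A)
    /// + `d_out * rank` (B).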
pub fn trainable_param_count(&self) -> usize {
self.d_out + self.lora.rank() * self.d_in + self.d_out * self.lora.rank()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use proptest::prelude::*;
#[test]
fn test_ent_lora_011_dora_creation() {
let base = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
let dora = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
assert_eq!(dora.d_out, 2);
assert_eq!(dora.d_in, 2);
assert_eq!(dora.magnitude().len(), 2);
}
#[test]
fn test_ent_lora_011_dora_magnitude_init() {
let base = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
let dora = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
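        // Rows of the 2x2 identity have unit L2 norm, so m starts at [1, 1].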
assert_abs_diff_eq!(dora.magnitude().data()[0], 1.0, epsilon = 1e-6);
assert_abs_diff_eq!(dora.magnitude().data()[1], 1.0, epsilon = 1e-6);
}
#[test]
fn test_ent_lora_011_dora_forward_dimensions() {
let base = Tensor::from_vec(vec![1.0; 12], false);
let dora = DoRALayer::new(base, 3, 4, 2, 4.0, LoRAScaling::RsLoRA);
let x = Tensor::from_vec(vec![0.5; 4], true);
let out = dora.forward(&x);
assert_eq!(out.len(), 3);
}
#[test]
fn test_ent_lora_011_dora_trainable_count() {
let base = Tensor::from_vec(vec![1.0; 16], false);
let dora = DoRALayer::new(base, 4, 4, 2, 4.0, LoRAScaling::Standard);
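        // magnitude (4) + A (2 * 4) + B (4 * 2) = 20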
assert_eq!(dora.trainable_param_count(), 20);
}
#[test]
fn test_ent_lora_011_dora_merge_dimensions() {
let base = Tensor::from_vec(vec![1.0; 12], false);
let dora = DoRALayer::new(base, 3, 4, 2, 4.0, LoRAScaling::Standard);
let merged = dora.merge_to_f32();
assert_eq!(merged.len(), 12);
}
#[test]
fn test_ent_lora_011_dora_trainable_params() {
let base = Tensor::from_vec(vec![1.0; 16], false);
let mut dora = DoRALayer::new(base, 4, 4, 2, 4.0, LoRAScaling::Standard);
let params = dora.trainable_params();
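        // One magnitude tensor plus the two LoRA factor matrices.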
assert_eq!(params.len(), 3);
}
proptest! {
#![proptest_config(proptest::test_runner::Config::with_cases(50))]
#[test]
fn prop_dora_forward_finite(
d_out in 2usize..8,
d_in in 2usize..8,
rank in 1usize..4,
) {
let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
let dora = DoRALayer::new(base, d_out, d_in, rank, 4.0, LoRAScaling::Standard);
let x = Tensor::from_vec(vec![0.1; d_in], true);
let out = dora.forward(&x);
prop_assert_eq!(out.len(), d_out);
for val in out.data() {
prop_assert!(val.is_finite(), "Output must be finite, got {val}");
}
}
#[test]
fn prop_dora_merge_finite(
d_out in 2usize..8,
d_in in 2usize..8,
rank in 1usize..4,
) {
let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
let dora = DoRALayer::new(base, d_out, d_in, rank, 4.0, LoRAScaling::Standard);
let merged = dora.merge_to_f32();
prop_assert_eq!(merged.len(), d_out * d_in);
for val in &merged {
prop_assert!(val.is_finite());
}
}
}
}