use super::*;
use crate::autograd::matmul;
use crate::Tensor;
use approx::assert_abs_diff_eq;
use proptest::prelude::*;
proptest! {
    #![proptest_config(proptest::test_runner::Config::with_cases(200))]

    #[test]
    fn prop_zero_b_gives_base_output(
        d_out in 2usize..10,
        d_in in 2usize..10,
        rank in 1usize..5,
    ) {
        // With B zero-initialized the adapter delta vanishes, so the layer
        // must reduce to a plain matmul against the frozen base weight.
        let n = d_out * d_in;
        let weight_vals: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin()).collect();
        let frozen = Tensor::from_vec(weight_vals, false);
        let layer = LoRALayer::new(frozen.clone(), d_out, d_in, rank, 1.0);

        let input_vals: Vec<f32> = (0..d_in).map(|i| i as f32 * 0.5).collect();
        let input = Tensor::from_vec(input_vals.clone(), true);

        let got = layer.forward(&input);
        let want = matmul(&frozen, &Tensor::from_vec(input_vals, false), d_out, d_in, 1);

        for idx in 0..d_out {
            prop_assert!(
                (got.data()[idx] - want.data()[idx]).abs() < 1e-4,
                "Zero B should give base output at index {}", idx
            );
        }
    }

    #[test]
    fn prop_merge_preserves_forward_output(
        d_out in 2usize..8,
        d_in in 2usize..8,
        rank in 1usize..4,
    ) {
        let n = d_out * d_in;
        let weight_vals: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).cos()).collect();
        let mut layer =
            LoRALayer::new(Tensor::from_vec(weight_vals, false), d_out, d_in, rank, 2.0);

        // Install non-trivial adapter weights so the merge actually moves the base.
        let a_vals: Vec<f32> = (0..rank * d_in).map(|i| (i as f32 * 0.2).sin() * 0.1).collect();
        let b_vals: Vec<f32> = (0..d_out * rank).map(|i| (i as f32 * 0.3).cos() * 0.1).collect();
        *layer.lora_a_mut().data_mut() = ndarray::Array1::from_vec(a_vals);
        *layer.lora_b_mut().data_mut() = ndarray::Array1::from_vec(b_vals);

        let input_vals: Vec<f32> = (0..d_in).map(|i| i as f32 + 1.0).collect();
        let before = layer.forward(&Tensor::from_vec(input_vals.clone(), true));

        layer.merge();
        prop_assert!(layer.is_merged());
        let after = layer.forward(&Tensor::from_vec(input_vals, true));

        for idx in 0..d_out {
            prop_assert!(
                (before.data()[idx] - after.data()[idx]).abs() < 1e-3,
                "Merge should preserve output at index {}: before={} after={}",
                idx, before.data()[idx], after.data()[idx]
            );
        }
    }

    #[test]
    fn prop_unmerge_restores_weights(
        d_out in 2usize..8,
        d_in in 2usize..8,
        rank in 1usize..4,
    ) {
        let n = d_out * d_in;
        let weight_vals: Vec<f32> = (0..n).map(|i| i as f32 * 0.5).collect();
        let mut layer =
            LoRALayer::new(Tensor::from_vec(weight_vals.clone(), false), d_out, d_in, rank, 1.0);

        let a_vals: Vec<f32> = (0..rank * d_in).map(|i| i as f32 * 0.01).collect();
        let b_vals: Vec<f32> = (0..d_out * rank).map(|i| i as f32 * 0.02).collect();
        *layer.lora_a_mut().data_mut() = ndarray::Array1::from_vec(a_vals);
        *layer.lora_b_mut().data_mut() = ndarray::Array1::from_vec(b_vals);

        // A merge followed by an unmerge must round-trip the base weights.
        layer.merge();
        layer.unmerge();

        for idx in 0..n {
            prop_assert!(
                (layer.base_weight().data()[idx] - weight_vals[idx]).abs() < 1e-4,
                "Unmerge should restore weight at index {}", idx
            );
        }
    }

    #[test]
    fn prop_scale_factor_correct(
        rank in 1usize..32,
        alpha in 1.0f32..64.0,
    ) {
        // Default (standard) scaling is alpha / rank for any (alpha, rank) pair.
        let layer = LoRALayer::new(Tensor::from_vec(vec![1.0], false), 1, 1, rank, alpha);
        let want = alpha / rank as f32;
        prop_assert!(
            (layer.scale() - want).abs() < 1e-6,
            "Scale should be alpha/rank: expected {} got {}", want, layer.scale()
        );
    }

    #[test]
    fn prop_lora_dimensions_correct(
        d_out in 2usize..20,
        d_in in 2usize..20,
        rank in 1usize..10,
    ) {
        // A is (rank x d_in), B is (d_out x rank), base is (d_out x d_in).
        let layer = LoRALayer::new(
            Tensor::from_vec(vec![0.0; d_out * d_in], false),
            d_out,
            d_in,
            rank,
            1.0,
        );
        prop_assert_eq!(layer.d_out(), d_out);
        prop_assert_eq!(layer.d_in(), d_in);
        prop_assert_eq!(layer.rank(), rank);
        prop_assert_eq!(layer.lora_a().len(), rank * d_in);
        prop_assert_eq!(layer.lora_b().len(), d_out * rank);
        prop_assert_eq!(layer.base_weight().len(), d_out * d_in);
    }
}
#[test]
fn test_lora_layer_creation() {
    // A freshly built layer starts unmerged with scale = alpha / rank = 2/2 = 1.
    let layer = LoRALayer::new(
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false),
        3,
        2,
        2,
        2.0,
    );
    assert_eq!(layer.rank(), 2);
    assert_eq!(layer.d_out(), 3);
    assert_eq!(layer.d_in(), 2);
    assert_abs_diff_eq!(layer.scale(), 1.0, epsilon = 1e-6);
    assert!(!layer.is_merged());
    assert_eq!(layer.lora_a().len(), 2 * 2);
    assert_eq!(layer.lora_b().len(), 3 * 2);
}
#[test]
fn test_lora_forward_unmerged() {
    // Identity base weight, rank-1 adapter with A=[1,2], B=[3,4], scale=1:
    // A·x = 1*1 + 2*2 = 5, delta = B*5 = [15,20], output = x + delta = [16,22].
    let mut layer =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 1, 1.0);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 2.0]);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[3.0, 4.0]);

    let out = layer.forward(&Tensor::from_vec(vec![1.0, 2.0], true));
    assert_eq!(out.len(), 2);
    assert_abs_diff_eq!(out.data()[0], 16.0, epsilon = 1e-4);
    assert_abs_diff_eq!(out.data()[1], 22.0, epsilon = 1e-4);
}
#[test]
fn test_lora_merge_unmerge() {
    let mut layer =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 1, 1.0);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 2.0]);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
    let snapshot = layer.base_weight().data().to_owned();

    // Merging folds scale * B·A = [[0.5, 1.0], [0.5, 1.0]] into the base weight.
    layer.merge();
    assert!(layer.is_merged());
    let expected_merged = [1.5, 1.0, 0.5, 2.0];
    for (idx, &want) in expected_merged.iter().enumerate() {
        assert_abs_diff_eq!(layer.base_weight().data()[idx], want, epsilon = 1e-4);
    }

    // Unmerging must restore the pre-merge base weight.
    layer.unmerge();
    assert!(!layer.is_merged());
    for idx in 0..4 {
        assert_abs_diff_eq!(layer.base_weight().data()[idx], snapshot[idx], epsilon = 1e-4);
    }
}
#[test]
fn test_lora_forward_merged() {
    // The merged fast path must produce the same outputs as the unmerged path.
    let mut layer =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 1, 1.0);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 1.0]);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[1.0, 1.0]);

    let x = Tensor::from_vec(vec![1.0, 1.0], true);
    let unmerged = layer.forward(&x);
    layer.merge();
    let merged = layer.forward(&x);

    assert_eq!(unmerged.len(), merged.len());
    for idx in 0..unmerged.len() {
        assert_abs_diff_eq!(unmerged.data()[idx], merged.data()[idx], epsilon = 1e-4);
    }
}
#[test]
fn test_lora_trainable_params() {
    // Only the two adapter matrices (A and B) are trainable; the base stays frozen.
    let mut layer =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false), 2, 2, 2, 4.0);
    let params = layer.trainable_params();
    assert_eq!(params.len(), 2);
    assert_eq!(params[0].len(), 2 * 2);
    assert_eq!(params[1].len(), 2 * 2);
    assert!(params[0].requires_grad());
    assert!(params[1].requires_grad());
}
#[test]
fn test_lora_zero_initialization() {
    // B starts at zero, so a fresh layer behaves exactly like its base weight
    // (identity here): the input passes through unchanged.
    let layer = LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 2, 2.0);
    let out = layer.forward(&Tensor::from_vec(vec![2.0, 3.0], true));
    assert_abs_diff_eq!(out.data()[0], 2.0, epsilon = 1e-4);
    assert_abs_diff_eq!(out.data()[1], 3.0, epsilon = 1e-4);
}
#[test]
fn test_lora_rank_scaling() {
    // Standard scaling is alpha / rank: doubling the rank halves the scale.
    let base = Tensor::from_vec(vec![1.0], false);
    let rank4 = LoRALayer::new(base.clone(), 1, 1, 4, 8.0);
    let rank8 = LoRALayer::new(base, 1, 1, 8, 8.0);
    assert_abs_diff_eq!(rank4.scale(), 2.0, epsilon = 1e-6);
    assert_abs_diff_eq!(rank8.scale(), 1.0, epsilon = 1e-6);
}
#[test]
fn test_rslora_scaling_compute() {
    // Standard scaling: alpha / rank.
    assert_abs_diff_eq!(LoRAScaling::Standard.compute(32.0, 16), 2.0, epsilon = 1e-6);
    assert_abs_diff_eq!(LoRAScaling::Standard.compute(32.0, 64), 0.5, epsilon = 1e-6);
    // rsLoRA scaling: alpha / sqrt(rank).
    assert_abs_diff_eq!(LoRAScaling::RsLoRA.compute(32.0, 16), 8.0, epsilon = 1e-6);
    assert_abs_diff_eq!(LoRAScaling::RsLoRA.compute(32.0, 64), 4.0, epsilon = 1e-6);
    assert_abs_diff_eq!(LoRAScaling::RsLoRA.compute(8.0, 4), 4.0, epsilon = 1e-6);
}
#[test]
fn test_rslora_scaling_all_ranks() {
    let alpha = 32.0;
    for &rank in &[4usize, 8, 16, 32, 64, 128] {
        let std_scale = LoRAScaling::Standard.compute(alpha, rank);
        let rs_scale = LoRAScaling::RsLoRA.compute(alpha, rank);
        assert_abs_diff_eq!(std_scale, alpha / rank as f32, epsilon = 1e-6);
        assert_abs_diff_eq!(rs_scale, alpha / (rank as f32).sqrt(), epsilon = 1e-6);
        // sqrt(rank) <= rank for rank >= 1, so rsLoRA never scales below standard.
        assert!(rs_scale >= std_scale, "rsLoRA should be >= standard for rank={rank}");
    }
}
#[test]
fn test_lora_layer_with_rslora() {
    // rank = 4, alpha = 8 => rsLoRA scale = 8 / sqrt(4) = 4.
    let layer = LoRALayer::new_with_scaling(
        Tensor::from_vec(vec![1.0; 4], false),
        2,
        2,
        4,
        8.0,
        LoRAScaling::RsLoRA,
    );
    assert_abs_diff_eq!(layer.scale(), 4.0, epsilon = 1e-6);
}
#[test]
fn test_lora_layer_standard_scaling_matches_new() {
    // `new` defaults to standard scaling, so it must agree with an explicit
    // `new_with_scaling(.., LoRAScaling::Standard)`.
    let base = Tensor::from_vec(vec![1.0; 4], false);
    let implicit = LoRALayer::new(base.clone(), 2, 2, 4, 8.0);
    let explicit = LoRALayer::new_with_scaling(base, 2, 2, 4, 8.0, LoRAScaling::Standard);
    assert_abs_diff_eq!(implicit.scale(), explicit.scale(), epsilon = 1e-10);
}
#[test]
fn test_merge_already_merged_is_noop() {
    let mut layer =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 1, 1.0);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 2.0]);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);

    layer.merge();
    assert!(layer.is_merged());
    let first = layer.base_weight().data().to_owned();

    // A second merge on an already-merged layer must not fold the delta twice.
    layer.merge();
    assert!(layer.is_merged());
    let second = layer.base_weight().data().to_owned();

    for (a, b) in first.iter().zip(second.iter()) {
        assert_abs_diff_eq!(a, b, epsilon = 1e-10);
    }
}
#[test]
fn test_unmerge_not_merged_is_noop() {
    let original = vec![1.0, 2.0, 3.0, 4.0];
    let mut layer = LoRALayer::new(Tensor::from_vec(original.clone(), false), 2, 2, 1, 1.0);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 2.0]);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);

    // Unmerging a never-merged layer must leave the base weight untouched.
    assert!(!layer.is_merged());
    layer.unmerge();
    assert!(!layer.is_merged());
    for (idx, &want) in original.iter().enumerate() {
        assert_abs_diff_eq!(layer.base_weight().data()[idx], want, epsilon = 1e-10);
    }
}
#[test]
fn test_lora_layer_clone() {
    let mut original =
        LoRALayer::new(Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false), 2, 2, 1, 2.0);
    *original.lora_a_mut().data_mut() = ndarray::arr1(&[0.3, 0.7]);
    *original.lora_b_mut().data_mut() = ndarray::arr1(&[0.5, 0.9]);

    let copy = original.clone();

    // Metadata must carry over.
    assert_eq!(copy.d_out(), original.d_out());
    assert_eq!(copy.d_in(), original.d_in());
    assert_eq!(copy.rank(), original.rank());
    assert_abs_diff_eq!(copy.scale(), original.scale(), epsilon = 1e-10);
    assert_eq!(copy.is_merged(), original.is_merged());

    // All three weight buffers must match element-wise.
    for (a, b) in copy.lora_a().data().iter().zip(original.lora_a().data().iter()) {
        assert_abs_diff_eq!(a, b, epsilon = 1e-10);
    }
    for (a, b) in copy.lora_b().data().iter().zip(original.lora_b().data().iter()) {
        assert_abs_diff_eq!(a, b, epsilon = 1e-10);
    }
    for (a, b) in copy.base_weight().data().iter().zip(original.base_weight().data().iter()) {
        assert_abs_diff_eq!(a, b, epsilon = 1e-10);
    }

    // And so must the forward pass on identical input.
    let x = Tensor::from_vec(vec![1.0, 1.0], true);
    let from_original = original.forward(&x);
    let from_copy = copy.forward(&x);
    for idx in 0..2 {
        assert_abs_diff_eq!(from_original.data()[idx], from_copy.data()[idx], epsilon = 1e-6);
    }
}
#[test]
fn test_lora_a_mut_modifies_weights() {
    let mut layer = LoRALayer::new(Tensor::from_vec(vec![1.0; 4], false), 2, 2, 1, 1.0);
    let before: Vec<f32> = layer.lora_a().data().to_vec();

    // Overwrite A through the mutable accessor and confirm the write landed.
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&[99.0, 99.0]);
    assert_abs_diff_eq!(layer.lora_a().data()[0], 99.0, epsilon = 1e-6);
    assert_abs_diff_eq!(layer.lora_a().data()[1], 99.0, epsilon = 1e-6);
    assert!(
        (layer.lora_a().data()[0] - before[0]).abs() > 1.0,
        "lora_a_mut should have changed the values"
    );
}
#[test]
fn test_lora_b_mut_modifies_weights() {
    let mut layer = LoRALayer::new(Tensor::from_vec(vec![1.0; 4], false), 2, 2, 1, 1.0);
    // B is zero-initialized by construction (see test_lora_zero_initialization).
    assert_abs_diff_eq!(layer.lora_b().data()[0], 0.0, epsilon = 1e-10);

    // Overwrite B through the mutable accessor and confirm the write landed.
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&[42.0, 42.0]);
    assert_abs_diff_eq!(layer.lora_b().data()[0], 42.0, epsilon = 1e-6);
    assert_abs_diff_eq!(layer.lora_b().data()[1], 42.0, epsilon = 1e-6);
}
#[test]
fn test_lora_scaling_clone_and_eq() {
    // LoRAScaling is a plain Copy enum: duplicates compare equal, variants differ.
    let standard = LoRAScaling::Standard;
    let rslora = LoRAScaling::RsLoRA;
    let duplicate = standard;
    assert_eq!(standard, duplicate);
    assert_ne!(standard, rslora);
    assert_eq!(rslora, LoRAScaling::RsLoRA);
}
#[test]
fn test_lora_scaling_debug() {
    // The derived Debug output should name the variant.
    let standard_repr = format!("{:?}", LoRAScaling::Standard);
    assert!(standard_repr.contains("Standard"));
    let rslora_repr = format!("{:?}", LoRAScaling::RsLoRA);
    assert!(rslora_repr.contains("RsLoRA"));
}