use super::*;
use crate::lora::LoRALayer;
use crate::Tensor;
use approx::assert_abs_diff_eq;
use proptest::prelude::*;
use std::fs;
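// Property-based tests: each property below is checked against 100 randomized
// combinations of layer dimensions, rank, and alpha.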
proptest! {
#![proptest_config(proptest::test_runner::Config::with_cases(100))]
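// Property: saving an adapter and loading it back preserves the layer's
// dimensions, rank, and scale (alpha / rank).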
#[test]
fn prop_round_trip_preserves_data(
d_out in 2usize..16,
d_in in 2usize..16,
rank in 1usize..4,
alpha in 1.0f32..32.0,
) {
let size = d_out * d_in;
let base_weight = Tensor::from_vec(vec![1.0; size], false);
let layer = LoRALayer::new(base_weight.clone(), d_out, d_in, rank, alpha);
let adapter = LoRAAdapter::from_layer(&layer, rank, alpha);
let path = format!("/tmp/prop_test_adapter_{d_out}_{d_in}_{rank}.json");
adapter.save(&path).expect("save should succeed");
let loaded = LoRAAdapter::load(&path).expect("load should succeed");
let loaded_layer = loaded.to_layer(base_weight).expect("conversion should succeed");
prop_assert_eq!(loaded_layer.d_out(), d_out);
prop_assert_eq!(loaded_layer.d_in(), d_in);
prop_assert_eq!(loaded_layer.rank(), rank);
prop_assert!((loaded_layer.scale() - alpha / rank as f32).abs() < 1e-5);
fs::remove_file(&path).ok();
}
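// Property: adapter metadata faithfully reflects the source layer's
// dimensions, rank, alpha, parameter count, and format version.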
#[test]
fn prop_metadata_correct(
d_out in 2usize..32,
d_in in 2usize..32,
rank in 1usize..8,
alpha in 1.0f32..64.0,
) {
let size = d_out * d_in;
let base_weight = Tensor::from_vec(vec![1.0; size], false);
let layer = LoRALayer::new(base_weight, d_out, d_in, rank, alpha);
let adapter = LoRAAdapter::from_layer(&layer, rank, alpha);
let metadata = adapter.metadata();
prop_assert_eq!(metadata.d_out, d_out);
prop_assert_eq!(metadata.d_in, d_in);
prop_assert_eq!(metadata.rank, rank);
prop_assert!((metadata.alpha - alpha).abs() < 1e-6);
prop_assert_eq!(metadata.num_params, rank * d_in + d_out * rank);
prop_assert_eq!(metadata.version, "1.0");
}
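// Property: the forward pass produces numerically identical outputs before
// and after a save/load round trip.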
#[test]
fn prop_forward_invariant_after_save_load(
d in 4usize..12,
rank in 1usize..3,
) {
let size = d * d;
let base_weight = Tensor::from_vec(vec![0.5; size], false);
let layer = LoRALayer::new(base_weight.clone(), d, d, rank, 4.0);
let x = Tensor::from_vec(vec![0.1; d], true);
let original_output = layer.forward(&x);
let path = format!("/tmp/prop_forward_test_{d}_{rank}.json");
save_adapter(&layer, rank, 4.0, &path).expect("save should succeed");
let loaded_layer = load_adapter(base_weight, &path).expect("load should succeed");
let loaded_output = loaded_layer.forward(&x);
prop_assert_eq!(original_output.len(), loaded_output.len());
for i in 0..original_output.len() {
prop_assert!(
(original_output.data()[i] - loaded_output.data()[i]).abs() < 1e-5,
"Forward output mismatch at index {}", i
);
}
fs::remove_file(&path).ok();
}
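// Property: converting an adapter onto a base weight of the wrong size
// is rejected with an error.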
#[test]
fn prop_dimension_validation_catches_mismatches(
d_out in 2usize..8,
d_in in 2usize..8,
rank in 1usize..3,
) {
let size = d_out * d_in;
let base_weight = Tensor::from_vec(vec![1.0; size], false);
let layer = LoRALayer::new(base_weight, d_out, d_in, rank, 4.0);
let adapter = LoRAAdapter::from_layer(&layer, rank, 4.0);
let wrong_size = size + 1;
let wrong_base = Tensor::from_vec(vec![1.0; wrong_size], false);
let result = adapter.to_layer(wrong_base);
prop_assert!(result.is_err());
}
}
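// Round-trips a small adapter with hand-set LoRA weights through save/load
// and checks that weights and hyperparameters survive intact.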
#[test]
fn test_adapter_serialization_round_trip() {
let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
let mut layer = LoRALayer::new(base_weight.clone(), 2, 2, 2, 4.0);
*layer.lora_a_mut().data_mut() = ndarray::arr1(&[0.1, 0.2, 0.3, 0.4]);
*layer.lora_b_mut().data_mut() = ndarray::arr1(&[0.5, 0.6, 0.7, 0.8]);
let path = "/tmp/test_adapter.json";
let adapter = LoRAAdapter::from_layer(&layer, 2, 4.0);
adapter.save(path).expect("save should succeed");
let loaded_adapter = LoRAAdapter::load(path).expect("load should succeed");
let loaded_layer = loaded_adapter.to_layer(base_weight).expect("conversion should succeed");
for (&orig, &loaded) in layer.lora_a().data().iter().zip(loaded_layer.lora_a().data().iter()) {
assert_abs_diff_eq!(orig, loaded, epsilon = 1e-6);
}
for (&orig, &loaded) in layer.lora_b().data().iter().zip(loaded_layer.lora_b().data().iter()) {
assert_abs_diff_eq!(orig, loaded, epsilon = 1e-6);
}
assert_eq!(loaded_layer.rank(), 2);
assert_eq!(loaded_layer.d_out(), 2);
assert_eq!(loaded_layer.d_in(), 2);
assert_abs_diff_eq!(loaded_layer.scale(), 2.0, epsilon = 1e-6);
fs::remove_file(path).expect("cleanup should succeed");
}
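// Checks that a loaded layer computes the same forward output as the
// original layer on a concrete input.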
#[test]
fn test_adapter_forward_consistency() {
let base_weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
let mut layer = LoRALayer::new(base_weight.clone(), 2, 2, 1, 1.0);
*layer.lora_a_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
*layer.lora_b_mut().data_mut() = ndarray::arr1(&[0.3, 0.3]);
let x = Tensor::from_vec(vec![2.0, 3.0], true);
let output_original = layer.forward(&x);
let path = "/tmp/test_adapter_forward.json";
save_adapter(&layer, 1, 1.0, path).expect("save should succeed");
let loaded_layer = load_adapter(base_weight, path).expect("load should succeed");
let output_loaded = loaded_layer.forward(&x);
assert_eq!(output_original.len(), output_loaded.len());
for i in 0..output_original.len() {
assert_abs_diff_eq!(output_original.data()[i], output_loaded.data()[i], epsilon = 1e-5);
}
fs::remove_file(path).expect("cleanup should succeed");
}
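// Converting an adapter onto a base weight with the wrong element count
// must fail with a DimensionMismatch error.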
#[test]
fn test_adapter_dimension_validation() {
let base_weight_2x2 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
let layer = LoRALayer::new(base_weight_2x2, 2, 2, 2, 4.0);
let adapter = LoRAAdapter::from_layer(&layer, 2, 4.0);
// A 6-element weight cannot back the adapter's expected 2x2 (4-element) base.
let wrong_base = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
let result = adapter.to_layer(wrong_base);
assert!(result.is_err());
match result {
Err(AdapterError::DimensionMismatch { .. }) => {}
_ => panic!("Expected DimensionMismatch error"),
}
}
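// Builds a 3x2 layer (rank 2, alpha 8) and checks every metadata field.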
#[test]
fn test_adapter_metadata() {
let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
let layer = LoRALayer::new(base_weight, 3, 2, 2, 8.0);
let adapter = LoRAAdapter::from_layer(&layer, 2, 8.0);
let metadata = adapter.metadata();
assert_eq!(metadata.rank, 2);
assert_abs_diff_eq!(metadata.alpha, 8.0, epsilon = 1e-6);
assert_eq!(metadata.d_out, 3);
assert_eq!(metadata.d_in, 2);
// scale = alpha / rank = 8 / 2; num_params = rank * d_in + d_out * rank = 4 + 6
assert_abs_diff_eq!(metadata.scale, 4.0, epsilon = 1e-6);
assert_eq!(metadata.num_params, 4 + 6);
assert_eq!(metadata.version, "1.0");
}
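// Two adapters saved from the same base weight load back independently,
// each with its own LoRA weights.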
#[test]
fn test_multiple_adapters_same_base() {
let base_weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
let mut layer1 = LoRALayer::new(base_weight.clone(), 2, 2, 1, 1.0);
*layer1.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 1.0]);
*layer1.lora_b_mut().data_mut() = ndarray::arr1(&[1.0, 1.0]);
save_adapter(&layer1, 1, 1.0, "/tmp/adapter1.json").expect("save should succeed");
let mut layer2 = LoRALayer::new(base_weight.clone(), 2, 2, 1, 1.0);
*layer2.lora_a_mut().data_mut() = ndarray::arr1(&[2.0, 2.0]);
*layer2.lora_b_mut().data_mut() = ndarray::arr1(&[2.0, 2.0]);
save_adapter(&layer2, 1, 1.0, "/tmp/adapter2.json").expect("save should succeed");
let loaded1 =
load_adapter(base_weight.clone(), "/tmp/adapter1.json").expect("load should succeed");
let loaded2 = load_adapter(base_weight, "/tmp/adapter2.json").expect("load should succeed");
assert_abs_diff_eq!(loaded1.lora_a().data()[0], 1.0, epsilon = 1e-6);
assert_abs_diff_eq!(loaded2.lora_a().data()[0], 2.0, epsilon = 1e-6);
fs::remove_file("/tmp/adapter1.json").expect("operation should succeed");
fs::remove_file("/tmp/adapter2.json").expect("operation should succeed");
}
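// The on-disk format is human-readable JSON containing the expected
// top-level keys.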
#[test]
fn test_adapter_file_format_readable() {
let base_weight = Tensor::from_vec(vec![1.0, 2.0], false);
let layer = LoRALayer::new(base_weight, 2, 1, 1, 2.0);
let path = "/tmp/test_readable.json";
save_adapter(&layer, 1, 2.0, path).expect("save should succeed");
let content = fs::read_to_string(path).expect("file read should succeed");
assert!(content.contains("\"version\""));
assert!(content.contains("\"rank\""));
assert!(content.contains("\"alpha\""));
assert!(content.contains("\"lora_a\""));
assert!(content.contains("\"lora_b\""));
fs::remove_file(path).expect("cleanup should succeed");
}