use ferrum_quantization::LoraLinearRef;
/// Checks `LoraLinearRef::forward` against a hand-computed reference on a
/// two-row input.
///
/// The expected values below were derived as `y = W·x + scale · B·(A·x)`,
/// reading every flat buffer as a row-major matrix.
///
/// NOTE(review): this assumes the numeric constructor arguments are
/// `(in_features = 3, out_features = 2, rank = 2, alpha = 4.0)`. It also
/// assumes `scale = alpha / rank = 2.0`. The expected outputs are consistent
/// with that reading; confirm against `LoraLinearRef::new`.
#[test]
fn lora_linear_ref_matches_manual_f32() {
    // W (out x in = 2x3): [[1, 2, -1], [0.5, -0.5, 1]]
    let base = vec![
        1.0, 2.0, -1.0, //
        0.5, -0.5, 1.0, //
    ];
    // A (rank x in = 2x3): [[1, 0, 1], [0, 1, -1]]
    let a = vec![
        1.0, 0.0, 1.0, //
        0.0, 1.0, -1.0, //
    ];
    // B (out x rank = 2x2): [[2, -1], [0.5, 1.5]]
    let b = vec![
        2.0, -1.0, //
        0.5, 1.5, //
    ];
    let linear = LoraLinearRef::new(base, a, b, 3, 2, 2, 4.0).expect("linear");

    // Two input rows of width 3. The second `forward` argument is presumably
    // the row count — TODO confirm.
    let input = vec![
        1.0, 2.0, 3.0, //
        -1.0, 0.5, 2.0, //
    ];
    let out = linear.forward(&input, 2).expect("forward");

    // Row 0, x = [1, 2, 3]:
    //   W·x = [2, 2.5], A·x = [4, -1], B·(A·x) = [9, 0.5]
    //   y = [2 + 2*9, 2.5 + 2*0.5] = [20, 3.5]
    // Row 1, x = [-1, 0.5, 2]:
    //   W·x = [-2, 1.25], A·x = [1, -1.5], B·(A·x) = [3.5, -1.75]
    //   y = [-2 + 2*3.5, 1.25 + 2*(-1.75)] = [5, -2.25]
    let expected = [20.0, 3.5, 5.0, -2.25];

    // `zip` stops at the shorter side. Without this check, a short or empty
    // `out` would pass the element-wise comparison below vacuously.
    assert_eq!(out.len(), expected.len(), "out={out:?}");

    for (idx, (got, want)) in out.iter().zip(expected.iter()).enumerate() {
        assert!(
            (got - want).abs() <= 1e-5,
            "idx={idx} got={got} want={want} out={out:?}"
        );
    }
}
/// `LoraLinearRef::new` must reject a base weight buffer whose length does not
/// match the declared dimensions.
///
/// NOTE(review): this assumes the numeric arguments are
/// `(in_features = 2, out_features = 1, rank = 1)`. Under that reading:
/// - `base` should hold 2 values but holds 1;
/// - `a` (1x2) and `b` (1x1) are correctly sized.
///
/// So the base-weight check is the only one that should fire.
#[test]
fn lora_linear_ref_rejects_shape_mismatch() {
    let err = LoraLinearRef::new(vec![1.0], vec![1.0, 2.0], vec![1.0], 2, 1, 1, 1.0)
        .expect_err("shape mismatch should fail");
    let msg = err.to_string();
    // Include the actual message so a failure shows which check fired.
    assert!(
        msg.contains("base weight shape mismatch"),
        "unexpected error message: {msg}"
    );
}