use super::LoraLayer;
/// Adds the scaled low-rank update `scale * B * (A * x)` onto `output` in place.
///
/// `lora.a` is read as `rank` consecutive rows of `d_in` values and `lora.b` as
/// `d_out` consecutive rows of `rank` values. Only the first `d_out` entries of
/// `output` are touched, and each of them is incremented rather than overwritten;
/// any entries past `d_out` are left as they were.
///
/// # Panics
///
/// In debug builds, panics if `x.len() != lora.d_in` or if
/// `output.len() < lora.d_out`. In any build, panics if `lora.a`, `lora.b` or
/// `output` is too short for the row slices taken below.
pub fn apply_lora(lora: &LoraLayer, scale: f32, x: &[f32], output: &mut [f32]) {
    debug_assert_eq!(x.len(), lora.d_in, "x length must equal d_in");
    debug_assert!(output.len() >= lora.d_out, "output length must be >= d_out");

    let (rank, d_in, d_out) = (lora.rank, lora.d_in, lora.d_out);

    // Down-projection: one dot product of `x` with each of the `rank` rows of A.
    let projected: Vec<f32> = (0..rank)
        .map(|row_idx| {
            let start = row_idx * d_in;
            lora.a[start..start + d_in]
                .iter()
                .zip(x)
                .map(|(weight, value)| weight * value)
                .sum::<f32>()
        })
        .collect();

    // Up-projection: each of the first `d_out` outputs receives the scaled dot
    // product of its row of B with the projected vector. The output window is
    // sliced once, before any element is written.
    let target = &mut output[..d_out];
    for (row_idx, slot) in target.iter_mut().enumerate() {
        let start = row_idx * rank;
        let dot = lora.b[start..start + rank]
            .iter()
            .zip(&projected)
            .map(|(weight, value)| weight * value)
            .sum::<f32>();
        *slot += scale * dot;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` results.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Rank-1 layer whose A row selects `x[0]` and whose B column feeds only
    /// `output[0]`: the first output gains `2.0 * 5.0`, the second is unchanged.
    #[test]
    fn test_apply_lora_identity_like() {
        let lora = LoraLayer {
            a: vec![1.0, 0.0, 0.0],
            b: vec![1.0, 0.0],
            d_in: 3,
            d_out: 2,
            rank: 1,
        };
        let x = [5.0, 3.0, 1.0];
        let mut output = [100.0, 200.0];

        apply_lora(&lora, 2.0, &x, &mut output);

        assert!(close(output[0], 110.0));
        assert!(close(output[1], 200.0));
    }

    /// Rank-2 layer with an identity A: the outputs are the row sums of B
    /// (3.0 and 7.0) multiplied by the 0.5 scale.
    #[test]
    fn test_apply_lora_rank2() {
        let lora = LoraLayer {
            a: vec![1.0, 0.0, 0.0, 1.0],
            b: vec![1.0, 2.0, 3.0, 4.0],
            d_in: 2,
            d_out: 2,
            rank: 2,
        };
        let x = [1.0, 1.0];
        let mut output = [0.0, 0.0];

        apply_lora(&lora, 0.5, &x, &mut output);

        assert!(close(output[0], 1.5));
        assert!(close(output[1], 3.5));
    }

    /// A scale of zero must leave the existing output values as they were.
    #[test]
    fn test_apply_lora_zero_scale() {
        let lora = LoraLayer {
            a: vec![1.0, 2.0],
            b: vec![3.0, 4.0],
            d_in: 2,
            d_out: 2,
            rank: 1,
        };
        let x = [1.0, 1.0];
        let mut output = [10.0, 20.0];

        apply_lora(&lora, 0.0, &x, &mut output);

        assert!(close(output[0], 10.0));
        assert!(close(output[1], 20.0));
    }
}