use crate::error::{QuantError, QuantResult};
/// Low-rank adapter holding two dense `f32` factor buffers plus a scale.
///
/// `LoraAdapter::new` checks that `a` holds `rank * in_features` values and
/// `b` holds `out_features * rank` values. `LoraAdapter::apply` reads both
/// buffers row by row and adds `scale * (B * (A * input))` onto its output.
///
/// All fields are public, so the length relationships checked by `new` are
/// not enforced for values built or mutated directly.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
/// Flat buffer of `rank * in_features` values; `apply` reads row `i` at
/// `a[i * in_features..(i + 1) * in_features]`.
pub a: Vec<f32>,
/// Flat buffer of `out_features * rank` values; `apply` reads row `i` at
/// `b[i * rank..(i + 1) * rank]`.
pub b: Vec<f32>,
/// Number of rows in `a` and number of values per row in `b`.
pub rank: usize,
/// Factor multiplied into every per-output delta before it is accumulated.
pub scale: f32,
/// Number of leading input elements consumed by `apply`.
pub in_features: usize,
/// Number of leading output elements updated by `apply`.
pub out_features: usize,
}
impl LoraAdapter {
    /// Builds an adapter after validating the factor buffer lengths.
    ///
    /// `a` must hold exactly `rank * in_features` values and `b` exactly
    /// `out_features * rank` values. `a` is checked first.
    ///
    /// # Errors
    ///
    /// Returns [`QuantError::DimensionMismatch`] when either buffer has the
    /// wrong length, including the case where the expected element count does
    /// not fit in `usize` (reported with `expected == usize::MAX`).
    pub fn new(
        a: Vec<f32>,
        b: Vec<f32>,
        rank: usize,
        scale: f32,
        in_features: usize,
        out_features: usize,
    ) -> QuantResult<Self> {
        Self::check_len(a.len(), rank, in_features)?;
        Self::check_len(b.len(), out_features, rank)?;
        Ok(Self {
            a,
            b,
            rank,
            scale,
            in_features,
            out_features,
        })
    }

    /// Accumulates the low-rank update into `output`:
    /// `output[i] += scale * dot(b_row_i, r)` where `r[j] = dot(a_row_j, input)`.
    ///
    /// Only the first `in_features` elements of `input` are read and only the
    /// first `out_features` elements of `output` are modified; any extra
    /// elements are ignored / left untouched. Existing values in `output` are
    /// added to, not overwritten.
    ///
    /// # Errors
    ///
    /// Returns [`QuantError::DimensionMismatch`] when `input` is shorter than
    /// `in_features`, when `output` is shorter than `out_features`, or when
    /// the `a` / `b` buffers no longer match the adapter's dimensions.
    pub fn apply(&self, input: &[f32], output: &mut [f32]) -> QuantResult<()> {
        if input.len() < self.in_features {
            return Err(QuantError::DimensionMismatch {
                expected: self.in_features,
                got: input.len(),
            });
        }
        if output.len() < self.out_features {
            return Err(QuantError::DimensionMismatch {
                expected: self.out_features,
                got: output.len(),
            });
        }
        // The fields are public, so an adapter may have been built or mutated
        // without going through `new`. Re-validate (O(1)) so the row slicing
        // below reports an error instead of panicking.
        Self::check_len(self.a.len(), self.rank, self.in_features)?;
        Self::check_len(self.b.len(), self.out_features, self.rank)?;

        let x = &input[..self.in_features];

        // First stage: project the input through each row of `a`.
        let mut r_vec = vec![0.0f32; self.rank];
        for (i, r) in r_vec.iter_mut().enumerate() {
            let row_start = i * self.in_features;
            let row = &self.a[row_start..row_start + self.in_features];
            *r = row.iter().zip(x.iter()).map(|(&w, &v)| w * v).sum();
        }

        // Second stage: expand through each row of `b`, scale, and accumulate.
        let s = self.scale;
        for (i, out) in output.iter_mut().enumerate().take(self.out_features) {
            let row_start = i * self.rank;
            let row = &self.b[row_start..row_start + self.rank];
            let delta = row
                .iter()
                .zip(r_vec.iter())
                .map(|(&w, &r)| w * r)
                .sum::<f32>()
                * s;
            *out += delta;
        }
        Ok(())
    }

    /// Checks that a buffer of length `got` holds exactly `rows * cols` values.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing. A
    /// `Vec<f32>` can never be that long, so an unrepresentable element count
    /// is always reported as a mismatch rather than wrapping to a small value.
    fn check_len(got: usize, rows: usize, cols: usize) -> QuantResult<()> {
        let expected = rows.saturating_mul(cols);
        if got == expected {
            Ok(())
        } else {
            Err(QuantError::DimensionMismatch { expected, got })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated outputs to expected values.
    const TOL: f32 = 1e-6;

    /// Rank-2 adapter over 2 inputs / 2 outputs whose `a` and `b` are both the
    /// 2x2 identity, with unit scale.
    fn identity_rank2() -> LoraAdapter {
        let eye = vec![1.0f32, 0.0, 0.0, 1.0];
        LoraAdapter::new(eye.clone(), eye, 2, 1.0, 2, 2).expect("valid adapter")
    }

    /// Fails the test unless `result` is a `DimensionMismatch` error.
    fn assert_dimension_mismatch<T: std::fmt::Debug>(result: QuantResult<T>) {
        assert!(
            matches!(result, Err(QuantError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {:?}",
            result
        );
    }

    /// Identity `a` with a doubling `b` maps `[3, 4]` to `[6, 8]`.
    #[test]
    fn test_apply_rank2_correctness() {
        let down = vec![1.0f32, 0.0, 0.0, 1.0];
        let up = vec![2.0f32, 0.0, 0.0, 2.0];
        let adapter = LoraAdapter::new(down, up, 2, 1.0, 2, 2).expect("valid adapter");

        let x = [3.0f32, 4.0];
        let mut y = [0.0f32, 0.0];
        adapter.apply(&x, &mut y).expect("apply ok");

        let first_err = (y[0] - 6.0).abs();
        assert!(first_err < TOL, "output[0] should be 6, got {}", y[0]);
        let second_err = (y[1] - 8.0).abs();
        assert!(second_err < TOL, "output[1] should be 8, got {}", y[1]);
    }

    /// The scale factor multiplies the delta before accumulation: 10 + 0.5 * 4.
    #[test]
    fn test_apply_scale() {
        let adapter =
            LoraAdapter::new(vec![1.0f32], vec![1.0f32], 1, 0.5, 1, 1).expect("valid adapter");

        let x = [4.0f32];
        let mut y = [10.0f32];
        adapter.apply(&x, &mut y).expect("apply ok");

        let err = (y[0] - 12.0).abs();
        assert!(err < TOL, "output[0] should be 12, got {}", y[0]);
    }

    /// `apply` adds onto whatever the output buffer already contains.
    #[test]
    fn test_apply_accumulates() {
        let adapter =
            LoraAdapter::new(vec![1.0f32], vec![1.0f32], 1, 1.0, 1, 1).expect("valid adapter");

        let x = [1.0f32];
        let mut y = [5.0f32];
        adapter.apply(&x, &mut y).expect("apply ok");

        let err = (y[0] - 6.0).abs();
        assert!(err < TOL, "output should accumulate: 5 + 1 = 6, got {}", y[0]);
    }

    /// `a` holds 1 value where rank * in_features = 2 are required.
    #[test]
    fn test_new_dimension_mismatch_a() {
        let short_a = vec![1.0f32];
        let full_b = vec![1.0f32, 0.0, 0.0, 1.0];
        assert_dimension_mismatch(LoraAdapter::new(short_a, full_b, 2, 1.0, 1, 2));
    }

    /// `b` holds 1 value where out_features * rank = 2 are required.
    #[test]
    fn test_new_dimension_mismatch_b() {
        let full_a = vec![1.0f32, 0.0];
        let short_b = vec![1.0f32];
        assert_dimension_mismatch(LoraAdapter::new(full_a, short_b, 1, 1.0, 2, 2));
    }

    /// An input shorter than `in_features` is rejected.
    #[test]
    fn test_apply_input_too_short() {
        let adapter = identity_rank2();
        let x = [1.0f32];
        let mut y = [0.0f32, 0.0];
        assert_dimension_mismatch(adapter.apply(&x, &mut y));
    }

    /// An output shorter than `out_features` is rejected.
    #[test]
    fn test_apply_output_too_short() {
        let adapter = identity_rank2();
        let x = [1.0f32, 2.0];
        let mut y = [0.0f32];
        assert_dimension_mismatch(adapter.apply(&x, &mut y));
    }

    /// A rank-0 adapter with empty factor buffers is valid and leaves a zeroed
    /// output at zero.
    #[test]
    fn test_zero_rank_adapter() {
        let adapter =
            LoraAdapter::new(Vec::new(), Vec::new(), 0, 1.0, 4, 4).expect("rank=0 valid");

        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut y = [0.0f32; 4];
        adapter.apply(&x, &mut y).expect("apply ok");

        y.iter().for_each(|&v| {
            assert!(
                v.abs() < 1e-9,
                "zero-rank adapter should not modify output, got {v}"
            );
        });
    }
}