mod apply;
#[cfg(feature = "safetensors")]
mod safetensors;
pub use apply::apply_lora;
use std::collections::HashMap;
use std::path::Path;
/// Adapter-wide LoRA hyperparameters.
///
/// `rank` and `alpha` together determine the multiplier returned by
/// `LoraConfig::scale` (`alpha / rank`).
#[derive(Debug, Clone)]
pub struct LoraConfig {
/// Divisor of the scaling factor; a value of zero makes `scale()` return `0.0`.
pub rank: usize,
/// Numerator of the scaling factor.
pub alpha: f32,
/// Module names associated with this adapter. Not read by any code in this
/// file — presumably the modules the adapter is meant to cover (TODO confirm
/// against the loader).
pub target_modules: Vec<String>,
}
impl LoraConfig {
    /// Returns the multiplier `alpha / rank` used when applying an adapter.
    ///
    /// A rank of zero yields `0.0` instead of performing the division, so the
    /// result is never infinite or NaN because of a zero divisor.
    pub fn scale(&self) -> f32 {
        match self.rank {
            0 => 0.0,
            rank => self.alpha / rank as f32,
        }
    }
}
/// Weights and dimensions of a single LoRA layer.
///
/// The memory layout of `a` and `b` is defined by `apply_lora`, which is not
/// part of this file; the field notes below are assumptions to confirm there.
#[derive(Debug, Clone)]
pub struct LoraLayer {
/// Flattened `A` weights — presumably `rank * d_in` values (TODO confirm layout).
pub a: Vec<f32>,
/// Flattened `B` weights — presumably `d_out * rank` values (TODO confirm layout).
pub b: Vec<f32>,
/// Input dimension (presumably the length of the activation vector `x`).
pub d_in: usize,
/// Output dimension (presumably the length of the output buffer).
pub d_out: usize,
/// Rank of this layer. Note that scaling uses `LoraConfig::rank`, not this field.
pub rank: usize,
}
/// A set of LoRA layers plus the configuration shared by all of them.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
/// Shared hyperparameters; `config.scale()` is used for every layer.
pub config: LoraConfig,
/// Layers keyed by `(layer index, module name)`, the same pair that
/// `apply` and `has_adapter` look up.
pub layers: HashMap<(usize, String), LoraLayer>,
}
impl LoraAdapter {
#[cfg(feature = "safetensors")]
pub fn from_safetensors(path: &Path) -> crate::error::Result<Self> {
safetensors::load_peft_safetensors(path)
}
pub fn new(config: LoraConfig, layers: HashMap<(usize, String), LoraLayer>) -> Self {
Self { config, layers }
}
pub fn apply(&self, layer_idx: usize, module: &str, x: &[f32], base_output: &mut [f32]) {
let key = (layer_idx, module.to_string());
if let Some(lora_layer) = self.layers.get(&key) {
let scale = self.config.scale();
apply_lora(lora_layer, scale, x, base_output);
}
}
pub fn has_adapter(&self, layer_idx: usize, module: &str) -> bool {
self.layers.contains_key(&(layer_idx, module.to_string()))
}
pub fn num_adapted_layers(&self) -> usize {
self.layers.len()
}
pub fn num_parameters(&self) -> usize {
self.layers.values().map(|l| l.a.len() + l.b.len()).sum()
}
}
/// Exposes `LoraAdapter` through the `LoraHook` trait from `lattice_inference`.
///
/// Only compiled when the `inference-hook` feature is enabled.
#[cfg(feature = "inference-hook")]
impl lattice_inference::lora_hook::LoraHook for LoraAdapter {
/// Forwards all arguments unchanged to the inherent `LoraAdapter::apply`.
fn apply(&self, layer_idx: usize, module: &str, x: &[f32], output: &mut [f32]) {
// Path-qualified call so it is explicit that the inherent method is the
// target, not this trait method (which has the same name).
LoraAdapter::apply(self, layer_idx, module, x, output);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input vector shared by the `apply` tests.
    const INPUT: [f32; 4] = [1.0, 2.0, 3.0, 4.0];

    /// Pre-filled output buffer used to detect unwanted writes in the no-op tests.
    const BASE: [f32; 4] = [10.0, 20.0, 30.0, 40.0];

    /// Asserts that both slices have the same length and agree element-wise
    /// to within `1e-6`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "index {}: got {}, want {}",
                idx,
                got,
                want
            );
        }
    }

    /// Builds an adapter with a single rank-2 layer registered under
    /// `(0, "q_proj")` and a scale of `4.0 / 2 = 2.0`.
    ///
    /// `target_modules` also lists `"v_proj"`, but no layer is registered for
    /// it; the no-op tests rely on that.
    fn make_test_adapter() -> LoraAdapter {
        let config = LoraConfig {
            rank: 2,
            alpha: 4.0,
            target_modules: vec!["q_proj".into(), "v_proj".into()],
        };

        let q_proj = LoraLayer {
            // 8 values = rank (2) * d_in (4).
            a: vec![
                1.0, 0.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, 0.0,
            ],
            // 8 values = d_out (4) * rank (2).
            b: vec![
                1.0, 0.0, //
                0.0, 1.0, //
                0.0, 0.0, //
                0.0, 0.0,
            ],
            d_in: 4,
            d_out: 4,
            rank: 2,
        };

        let mut layers = HashMap::new();
        layers.insert((0, "q_proj".into()), q_proj);
        LoraAdapter::new(config, layers)
    }

    #[test]
    fn test_config_scale() {
        let config = LoraConfig {
            rank: 8,
            alpha: 16.0,
            target_modules: vec![],
        };
        assert!((config.scale() - 2.0).abs() < 1e-6);

        // Zero rank must not divide by zero.
        let config_zero = LoraConfig {
            rank: 0,
            alpha: 1.0,
            target_modules: vec![],
        };
        assert_eq!(config_zero.scale(), 0.0);
    }

    #[test]
    fn test_adapter_apply() {
        let adapter = make_test_adapter();
        let mut output = [0.0f32; 4];
        adapter.apply(0, "q_proj", &INPUT, &mut output);
        assert_close(&output, &[2.0, 4.0, 0.0, 0.0]);
    }

    #[test]
    fn test_adapter_noop_for_missing_module() {
        let adapter = make_test_adapter();
        let mut output = BASE;
        adapter.apply(0, "v_proj", &INPUT, &mut output);
        // No layer is registered, so every element must be bit-identical,
        // not just the leading ones.
        assert_eq!(output, BASE);
    }

    #[test]
    fn test_adapter_noop_for_wrong_layer() {
        let adapter = make_test_adapter();
        let mut output = BASE;
        adapter.apply(1, "q_proj", &INPUT, &mut output);
        // Same module name but a different layer index: the whole buffer
        // must be left untouched.
        assert_eq!(output, BASE);
    }

    #[test]
    fn test_has_adapter() {
        let adapter = make_test_adapter();
        assert!(adapter.has_adapter(0, "q_proj"));
        assert!(!adapter.has_adapter(0, "v_proj"));
        assert!(!adapter.has_adapter(1, "q_proj"));
    }

    #[test]
    fn test_num_parameters() {
        let adapter = make_test_adapter();
        // 8 values in `a` plus 8 values in `b`.
        assert_eq!(adapter.num_parameters(), 16);
    }

    #[test]
    fn test_num_adapted_layers() {
        let adapter = make_test_adapter();
        assert_eq!(adapter.num_adapted_layers(), 1);
    }
}