mod apply;
pub mod online;
pub mod optimizer;
#[cfg(feature = "safetensors")]
mod safetensors;
pub use apply::apply_lora;
pub use online::{AdaptStepResult, adapt_step};
pub use optimizer::{AdamState, LoraGradients, compute_lora_gradients};
#[cfg(feature = "safetensors")]
pub use safetensors::{load_peft_safetensors, save_peft_safetensors};
use std::collections::HashMap;
use std::path::Path;
/// Hyper-parameters shared by every layer of a LoRA adapter.
#[derive(Debug, Clone)]
pub struct LoraConfig {
    /// Rank of the low-rank update; the divisor used by [`LoraConfig::scale`].
    pub rank: usize,
    /// Numerator of the scaling factor returned by [`LoraConfig::scale`].
    pub alpha: f32,
    /// Module names associated with this adapter (for example `"q_proj"`).
    pub target_modules: Vec<String>,
}

impl LoraConfig {
    /// Returns the LoRA scaling factor `alpha / rank`.
    ///
    /// A rank of zero yields `0.0` instead of performing a division by zero.
    pub fn scale(&self) -> f32 {
        match self.rank {
            0 => 0.0,
            rank => self.alpha / rank as f32,
        }
    }
}
/// Low-rank weight pair for a single `(layer, module)` target.
///
/// How the flattened `a` and `b` buffers are indexed is decided by
/// `apply_lora`, which is not defined in this file.
/// NOTE(review): the unit-test fixtures in this file are consistent with `a`
/// being row-major `[rank, d_in]` and `b` being row-major `[d_out, rank]` —
/// confirm against `apply_lora`.
#[derive(Debug, Clone)]
pub struct LoraLayer {
/// Flattened `A` factor; the test fixtures allocate `rank * d_in` values.
pub a: Vec<f32>,
/// Flattened `B` factor; the test fixtures allocate `d_out * rank` values.
pub b: Vec<f32>,
/// Input dimension; compared against the model's expected input size by
/// `LoraAdapter::validate_against`.
pub d_in: usize,
/// Output dimension; compared against the model's expected output size by
/// `LoraAdapter::validate_against`.
pub d_out: usize,
/// Rank recorded for this layer. Presumably equal to `LoraConfig::rank`;
/// nothing in this file enforces that — verify against the loader.
pub rank: usize,
}
/// A LoRA adapter: one shared configuration plus per-target weight layers.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Shared hyper-parameters; `config.scale()` is handed to `apply_lora`.
    pub config: LoraConfig,
    /// Adapter weights keyed by `(layer index, module name)`.
    pub layers: HashMap<(usize, String), LoraLayer>,
}

impl LoraAdapter {
    /// Loads an adapter from `path` by delegating to
    /// `safetensors::load_peft_safetensors`.
    ///
    /// Only available when the `safetensors` feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the loader returns.
    #[cfg(feature = "safetensors")]
    pub fn from_safetensors(path: &Path) -> crate::error::Result<Self> {
        safetensors::load_peft_safetensors(path)
    }

    /// Writes this adapter to `path` by delegating to
    /// `safetensors::save_peft_safetensors`.
    ///
    /// Only available when the `safetensors` feature is enabled.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the writer returns.
    #[cfg(feature = "safetensors")]
    pub fn save_safetensors(&self, path: &Path) -> crate::error::Result<()> {
        safetensors::save_peft_safetensors(self, path)
    }

    /// Builds an adapter from an existing configuration and layer map.
    ///
    /// No validation is performed here.
    pub fn new(config: LoraConfig, layers: HashMap<(usize, String), LoraLayer>) -> Self {
        Self { config, layers }
    }

    /// Applies the adapter registered for `(layer_idx, module)` to `base_output`.
    ///
    /// When a layer is registered under that key, `apply_lora` is called with
    /// the layer, `self.config.scale()`, `x` and `base_output`. When no layer
    /// is registered, `base_output` is left untouched.
    ///
    /// Each call allocates a temporary `String` to build the lookup key.
    pub fn apply(&self, layer_idx: usize, module: &str, x: &[f32], base_output: &mut [f32]) {
        let key = (layer_idx, module.to_string());
        if let Some(lora_layer) = self.layers.get(&key) {
            apply_lora(lora_layer, self.config.scale(), x, base_output);
        }
    }

    /// Returns `true` when a layer is registered under `(layer_idx, module)`.
    pub fn has_adapter(&self, layer_idx: usize, module: &str) -> bool {
        self.layers.contains_key(&(layer_idx, module.to_string()))
    }

    /// Number of `(layer, module)` entries held by this adapter.
    pub fn num_adapted_layers(&self) -> usize {
        self.layers.len()
    }

    /// Total number of `f32` values stored across all `a` and `b` buffers.
    pub fn num_parameters(&self) -> usize {
        self.layers
            .values()
            .map(|layer| layer.a.len() + layer.b.len())
            .sum()
    }

    /// Returns every `(layer index, module name)` key whose module name is not
    /// listed in `known`.
    ///
    /// The result is sorted ascending by layer index, then module name, so the
    /// output is deterministic regardless of `HashMap` iteration order. An
    /// empty vector means every module name was recognised.
    pub fn validate_modules(&self, known: &[&str]) -> Vec<(usize, String)> {
        let mut unknown: Vec<(usize, String)> = self
            .layers
            .keys()
            .filter(|(_, module)| !known.iter().any(|k| *k == module.as_str()))
            .cloned()
            .collect();
        // Keys are unique, so an unstable sort yields a single well-defined order.
        unknown.sort_unstable();
        unknown
    }
}
#[cfg(feature = "inference-hook")]
impl LoraAdapter {
    /// Verifies that every adapted layer fits the given Qwen3.5 configuration.
    ///
    /// For each `(layer index, module)` entry the following is checked, in order:
    ///
    /// 1. the layer index is below `config.num_hidden_layers`;
    /// 2. the module name is accepted for that layer: `q_proj`, `k_proj`,
    ///    `v_proj` and `o_proj` only where `config.is_full_attention` is true;
    ///    `in_proj_qkv`, `in_proj_z` and `out_proj` only where it is false;
    ///    `gate_proj`, `up_proj` and `down_proj` on any layer;
    /// 3. the layer's `d_in` / `d_out` equal the dimensions derived from `config`.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::Validation` for the first failing entry encountered.
    /// Entries are visited in `HashMap` iteration order, so when several
    /// entries are invalid, which one gets reported is unspecified.
    pub fn validate_against(
        &self,
        config: &lattice_inference::model::qwen35_config::Qwen35Config,
    ) -> crate::error::Result<()> {
        use crate::error::TuneError;

        for (key, layer) in self.layers.iter() {
            let layer_idx = key.0;
            let module = key.1.as_str();

            if layer_idx >= config.num_hidden_layers {
                return Err(TuneError::Validation(format!(
                    "LoRA layer index {layer_idx} >= model num_hidden_layers {} (module: {module})",
                    config.num_hidden_layers
                )));
            }

            let hidden = config.hidden_size;
            let full_attention = config.is_full_attention(layer_idx);

            // Attention projections are only valid on the matching layer type;
            // a known name on the wrong layer type falls through to the error arm.
            let (expected_d_in, expected_d_out) = match module {
                "q_proj" if full_attention => (hidden, 2 * config.full_q_dim()),
                "k_proj" | "v_proj" if full_attention => (hidden, config.full_kv_dim()),
                "o_proj" if full_attention => (config.full_q_dim(), hidden),
                "in_proj_qkv" if !full_attention => (hidden, config.linear_qkv_dim()),
                "in_proj_z" if !full_attention => (hidden, config.linear_output_dim()),
                "out_proj" if !full_attention => (config.linear_output_dim(), hidden),
                "gate_proj" | "up_proj" => (hidden, config.intermediate_size),
                "down_proj" => (config.intermediate_size, hidden),
                m => {
                    return Err(TuneError::Validation(format!(
                        "LoRA module '{m}' (layer {layer_idx}) is not a recognised Qwen3.5 projection"
                    )));
                }
            };

            let dims_match = layer.d_in == expected_d_in && layer.d_out == expected_d_out;
            if !dims_match {
                return Err(TuneError::Validation(format!(
                    "LoRA adapter dims mismatch for layer {layer_idx} module '{module}': \
                     adapter has (d_in={}, d_out={}) but model expects (d_in={expected_d_in}, d_out={expected_d_out})",
                    layer.d_in, layer.d_out
                )));
            }
        }

        Ok(())
    }
}
// Exposes the adapter through the inference crate's `LoraHook` trait.
// Compiled only when the `inference-hook` feature is enabled.
#[cfg(feature = "inference-hook")]
impl lattice_inference::lora_hook::LoraHook for LoraAdapter {
/// Forwards to the adapter's own `apply` method with the same arguments;
/// `output` is the buffer that method receives as `base_output`.
fn apply(&self, layer_idx: usize, module: &str, x: &[f32], output: &mut [f32]) {
// The explicit `LoraAdapter::apply` path makes it clear that this forwards
// to the inherent method rather than recursing into this trait method.
LoraAdapter::apply(self, layer_idx, module, x, output);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` results.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within [`TOL`].
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < TOL,
                "element {i}: got {got}, want {want}"
            );
        }
    }

    /// Builds an adapter with a single rank-2 layer registered at `(0, "q_proj")`.
    ///
    /// With `alpha = 4.0` and `rank = 2` the configured scale is `2.0`.
    fn make_test_adapter() -> LoraAdapter {
        let config = LoraConfig {
            rank: 2,
            alpha: 4.0,
            target_modules: vec!["q_proj".into(), "v_proj".into()],
        };
        let mut layers = HashMap::new();
        layers.insert(
            (0, "q_proj".into()),
            LoraLayer {
                // Eight values each; the expectations in `test_adapter_apply`
                // define how `apply_lora` is expected to combine them.
                a: vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
                b: vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
                d_in: 4,
                d_out: 4,
                rank: 2,
            },
        );
        LoraAdapter::new(config, layers)
    }

    /// `scale()` is `alpha / rank`, and `0.0` for rank zero.
    #[test]
    fn test_config_scale() {
        let config = LoraConfig {
            rank: 8,
            alpha: 16.0,
            target_modules: vec![],
        };
        assert!((config.scale() - 2.0).abs() < TOL);

        let config_zero = LoraConfig {
            rank: 0,
            alpha: 1.0,
            target_modules: vec![],
        };
        assert_eq!(config_zero.scale(), 0.0);
    }

    /// Applying the registered layer writes the scaled low-rank update.
    #[test]
    fn test_adapter_apply() {
        let adapter = make_test_adapter();
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 4];
        adapter.apply(0, "q_proj", &x, &mut output);
        assert_all_close(&output, &[2.0, 4.0, 0.0, 0.0]);
    }

    /// A module with no registered layer must leave the whole buffer untouched.
    #[test]
    fn test_adapter_noop_for_missing_module() {
        let adapter = make_test_adapter();
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [10.0f32, 20.0, 30.0, 40.0];
        adapter.apply(0, "v_proj", &x, &mut output);
        // Compare every element: a no-op must not touch any part of the buffer.
        assert_eq!(output, [10.0, 20.0, 30.0, 40.0]);
    }

    /// A layer index with no registered layer must leave the whole buffer untouched.
    #[test]
    fn test_adapter_noop_for_wrong_layer() {
        let adapter = make_test_adapter();
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = [10.0f32, 20.0, 30.0, 40.0];
        adapter.apply(1, "q_proj", &x, &mut output);
        assert_eq!(output, [10.0, 20.0, 30.0, 40.0]);
    }

    /// `has_adapter` matches on both layer index and module name.
    #[test]
    fn test_has_adapter() {
        let adapter = make_test_adapter();
        assert!(adapter.has_adapter(0, "q_proj"));
        assert!(!adapter.has_adapter(0, "v_proj"));
        assert!(!adapter.has_adapter(1, "q_proj"));
    }

    /// Parameter count is the combined length of all `a` and `b` buffers.
    #[test]
    fn test_num_parameters() {
        let adapter = make_test_adapter();
        assert_eq!(adapter.num_parameters(), 16);
    }

    /// One entry in the layer map means one adapted layer.
    #[test]
    fn test_num_adapted_layers() {
        let adapter = make_test_adapter();
        assert_eq!(adapter.num_adapted_layers(), 1);
    }

    /// No keys are reported when every module name is in the known list.
    #[test]
    fn test_validate_modules_all_known() {
        let adapter = make_test_adapter();
        let unknown = adapter.validate_modules(&["q_proj", "v_proj", "k_proj"]);
        assert!(unknown.is_empty());
    }

    /// A misspelled module name is reported together with its layer index.
    #[test]
    fn test_validate_modules_typo() {
        let config = LoraConfig {
            rank: 2,
            alpha: 4.0,
            target_modules: vec!["q_porj".into()],
        };
        let mut layers = HashMap::new();
        layers.insert(
            (0, "q_porj".into()),
            LoraLayer {
                a: vec![1.0; 8],
                b: vec![1.0; 8],
                d_in: 4,
                d_out: 4,
                rank: 2,
            },
        );
        let adapter = LoraAdapter::new(config, layers);

        let unknown = adapter.validate_modules(&["q_proj", "v_proj"]);
        assert_eq!(unknown.len(), 1);
        assert_eq!(unknown[0], (0, "q_porj".to_string()));
    }

    /// An adapter without layers has nothing to report.
    #[test]
    fn test_validate_modules_empty_adapter() {
        let config = LoraConfig {
            rank: 2,
            alpha: 4.0,
            target_modules: vec![],
        };
        let adapter = LoraAdapter::new(config, HashMap::new());
        let unknown = adapter.validate_modules(&["q_proj"]);
        assert!(unknown.is_empty());
    }

    /// Tests for `LoraAdapter::validate_against`; they need the inference
    /// crate's model configuration and so only build with `inference-hook`.
    #[cfg(feature = "inference-hook")]
    mod validate_against_tests {
        use super::*;
        use lattice_inference::model::qwen35_config::Qwen35Config;

        /// Builds a zero-filled, rank-4 adapter with one layer at
        /// `(layer_idx, module)` and the given dimensions.
        fn make_adapter_for_layer(
            layer_idx: usize,
            module: &str,
            d_in: usize,
            d_out: usize,
        ) -> LoraAdapter {
            let rank = 4;
            let mut layers = HashMap::new();
            layers.insert(
                (layer_idx, module.to_string()),
                LoraLayer {
                    a: vec![0.0; rank * d_in],
                    b: vec![0.0; d_out * rank],
                    d_in,
                    d_out,
                    rank,
                },
            );
            LoraAdapter::new(
                LoraConfig {
                    rank,
                    alpha: rank as f32,
                    target_modules: vec![module.to_string()],
                },
                layers,
            )
        }

        /// A layer index beyond the model's layer count is rejected.
        #[test]
        fn test_validate_against_layer_out_of_bounds() {
            let cfg = Qwen35Config::qwen35_0_8b();
            let adapter = make_adapter_for_layer(999, "q_proj", 1024, 4096);
            assert!(adapter.validate_against(&cfg).is_err());
        }

        /// Dimensions that differ from the model's are rejected.
        #[test]
        fn test_validate_against_dim_mismatch() {
            let cfg = Qwen35Config::qwen35_0_8b();
            let adapter = make_adapter_for_layer(3, "q_proj", 2048, 8192);
            assert!(adapter.validate_against(&cfg).is_err());
        }

        /// Dimensions matching the model's `q_proj` on layer 3 are accepted.
        #[test]
        fn test_validate_against_correct_dims_passes() {
            let cfg = Qwen35Config::qwen35_0_8b();
            let adapter = make_adapter_for_layer(3, "q_proj", 1024, 4096);
            assert!(adapter.validate_against(&cfg).is_ok());
        }

        /// MLP projections are accepted when their dimensions match.
        #[test]
        fn test_validate_against_mlp_correct() {
            let cfg = Qwen35Config::qwen35_0_8b();
            let adapter = make_adapter_for_layer(0, "gate_proj", 1024, 3584);
            assert!(adapter.validate_against(&cfg).is_ok());
        }

        /// An unrecognised module name yields the "not a recognised" error.
        #[test]
        fn test_validate_against_unknown_module_errors() {
            let cfg = Qwen35Config::qwen35_0_8b();
            let adapter = make_adapter_for_layer(3, "xq_proj_typo", 1024, 4096);
            let err = adapter.validate_against(&cfg).unwrap_err();
            assert!(err.to_string().contains("not a recognised"));
        }
    }
}