use std::collections::HashSet;
/// Configuration for applying LoRA (Low-Rank Adaptation) adapters.
///
/// Built fluently, e.g. `LoRAConfig::new(8, 16.0).target_qv_projections()`.
#[derive(Clone, Debug)]
pub struct LoRAConfig {
    /// Rank of the low-rank decomposition matrices.
    pub rank: usize,
    /// Scaling factor applied to the adapter output.
    pub alpha: f32,
    /// Names of the modules the adapter should wrap.
    pub target_modules: HashSet<String>,
    /// Optional whitelist of layer indices; `None` means every layer.
    pub layers: Option<Vec<usize>>,
    /// When set, match any module name ending in "proj" or "linear"
    /// instead of consulting `target_modules`.
    pub all_linear: bool,
}

impl LoRAConfig {
    /// Creates a config with the given `rank` and `alpha`, no target
    /// modules, no layer filter, and `all_linear` disabled.
    pub fn new(rank: usize, alpha: f32) -> Self {
        Self { rank, alpha, target_modules: HashSet::new(), layers: None, all_linear: false }
    }

    /// Builder: replace the target-module set with the given names.
    pub fn target_modules(mut self, modules: &[&str]) -> Self {
        self.target_modules = modules.iter().map(ToString::to_string).collect();
        self
    }

    /// Builder: target all four attention projections (q, k, v, o).
    pub fn target_attention_projections(self) -> Self {
        // Route through `target_modules` so set construction lives in one place.
        self.target_modules(&["q_proj", "k_proj", "v_proj", "o_proj"])
    }

    /// Builder: target only the query and value projections.
    pub fn target_qv_projections(self) -> Self {
        self.target_modules(&["q_proj", "v_proj"])
    }

    /// Builder: target the query, key, and value projections.
    pub fn target_qkv_projections(self) -> Self {
        self.target_modules(&["q_proj", "k_proj", "v_proj"])
    }

    /// Builder: restrict application to the given layer indices.
    pub fn target_layers(mut self, layer_indices: &[usize]) -> Self {
        self.layers = Some(layer_indices.to_vec());
        self
    }

    /// Builder: match every linear-like module (suffix "proj"/"linear")
    /// regardless of `target_modules`.
    pub fn all_linear_layers(mut self) -> Self {
        self.all_linear = true;
        self
    }

    /// Expands a single shorthand name ("all_linear", "attention", "qv",
    /// or "mlp") into the corresponding module list. Any other input —
    /// including multi-element lists — is returned unchanged.
    pub fn expand_shorthand(modules: &[String]) -> Vec<String> {
        let expansion: &[&str] = match modules {
            [only] => match only.as_str() {
                "all_linear" => &[
                    "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",
                ],
                "attention" => &["q_proj", "k_proj", "v_proj", "o_proj"],
                "qv" => &["q_proj", "v_proj"],
                "mlp" => &["gate_proj", "up_proj", "down_proj"],
                // Not a recognized shorthand: pass the list through verbatim.
                _ => return modules.to_vec(),
            },
            _ => return modules.to_vec(),
        };
        expansion.iter().map(ToString::to_string).collect()
    }

    /// Returns whether LoRA should be applied to `module_name`.
    ///
    /// The layer whitelist is only consulted when both `self.layers` and
    /// `layer_idx` are present; callers passing `None` for `layer_idx`
    /// bypass the layer filter entirely.
    pub fn should_apply(&self, module_name: &str, layer_idx: Option<usize>) -> bool {
        if let (Some(layers), Some(idx)) = (&self.layers, layer_idx) {
            if !layers.contains(&idx) {
                return false;
            }
        }
        if self.all_linear {
            module_name.ends_with("proj") || module_name.ends_with("linear")
        } else {
            self.target_modules.contains(module_name)
        }
    }

    /// Number of explicitly targeted modules (ignores `all_linear` mode).
    pub fn num_target_modules(&self) -> usize {
        self.target_modules.len()
    }

    /// Whether the all-linear suffix-matching mode is enabled.
    pub fn is_all_linear(&self) -> bool {
        self.all_linear
    }

    /// Borrowed view of the target module names (unordered).
    pub fn get_target_modules(&self) -> Vec<&str> {
        self.target_modules.iter().map(String::as_str).collect()
    }
}

impl Default for LoRAConfig {
    /// Rank 8, alpha 8.0, targeting the q/v projections — the common
    /// LoRA baseline.
    fn default() -> Self {
        Self::new(8, 8.0).target_qv_projections()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Cap each property at 200 random cases to keep the suite fast.
        #![proptest_config(proptest::test_runner::Config::with_cases(200))]

        // `should_apply` must agree exactly with whatever subset of the four
        // attention projections was handed to `target_modules`, and
        // `num_target_modules` must report that subset's size.
        #[test]
        fn prop_should_apply_consistent_with_modules(
            rank in 1usize..64,
            alpha in 1.0f32..64.0,
            include_q in proptest::bool::ANY,
            include_k in proptest::bool::ANY,
            include_v in proptest::bool::ANY,
            include_o in proptest::bool::ANY,
        ) {
            let mut modules = vec![];
            if include_q { modules.push("q_proj"); }
            if include_k { modules.push("k_proj"); }
            if include_v { modules.push("v_proj"); }
            if include_o { modules.push("o_proj"); }
            let config = LoRAConfig::new(rank, alpha).target_modules(&modules);
            prop_assert_eq!(config.should_apply("q_proj", None), include_q);
            prop_assert_eq!(config.should_apply("k_proj", None), include_k);
            prop_assert_eq!(config.should_apply("v_proj", None), include_v);
            prop_assert_eq!(config.should_apply("o_proj", None), include_o);
            prop_assert_eq!(config.num_target_modules(), modules.len());
        }

        // With a layer whitelist set, `should_apply(..., Some(idx))` must
        // return true iff `idx` is in the whitelist.
        #[test]
        fn prop_layer_filtering_respects_indices(
            layers in prop::collection::vec(0usize..32, 1..8),
            test_layer in 0usize..32,
        ) {
            let config = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);
            let in_list = layers.contains(&test_layer);
            prop_assert_eq!(config.should_apply("q_proj", Some(test_layer)), in_list);
        }

        // In all-linear mode, matching is purely by suffix: "_proj" and
        // "_linear" names match for any prefix; other suffixes never do.
        #[test]
        fn prop_all_linear_matches_suffixes(
            prefix in "[a-z]{1,8}",
        ) {
            let config = LoRAConfig::new(8, 8.0).all_linear_layers();
            let proj_name = format!("{prefix}_proj");
            let linear_name = format!("{prefix}_linear");
            let other_name = format!("{prefix}_norm");
            prop_assert!(config.should_apply(&proj_name, None));
            prop_assert!(config.should_apply(&linear_name, None));
            prop_assert!(!config.should_apply(&other_name, None));
        }

        // Builder chaining must not disturb `rank`/`alpha`; alpha is
        // compared with a tolerance since it is an f32.
        #[test]
        fn prop_config_params_preserved(
            rank in 1usize..128,
            alpha in 0.1f32..128.0,
        ) {
            let config = LoRAConfig::new(rank, alpha)
                .target_attention_projections()
                .target_layers(&[0, 1, 2]);
            prop_assert_eq!(config.rank, rank);
            prop_assert!((config.alpha - alpha).abs() < 1e-6);
            prop_assert_eq!(config.num_target_modules(), 4);
        }

        // Passing `layer_idx = None` bypasses the layer whitelist even
        // when one is configured.
        #[test]
        fn prop_none_layer_bypasses_filter(
            layers in prop::collection::vec(0usize..16, 1..4),
        ) {
            let config = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);
            prop_assert!(config.should_apply("q_proj", None));
        }
    }

    // A fresh config targets nothing and is not in all-linear mode.
    #[test]
    fn test_lora_config_creation() {
        let config = LoRAConfig::new(16, 16.0);
        assert_eq!(config.rank, 16);
        assert_eq!(config.alpha, 16.0);
        assert_eq!(config.num_target_modules(), 0);
        assert!(!config.is_all_linear());
    }

    // Explicit module list: only listed names match.
    #[test]
    fn test_target_modules() {
        let config = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "k_proj"]);
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("k_proj", None));
        assert!(!config.should_apply("v_proj", None));
        assert!(!config.should_apply("o_proj", None));
        assert_eq!(config.num_target_modules(), 2);
    }

    // Attention preset covers q/k/v/o and nothing else.
    #[test]
    fn test_target_attention_projections() {
        let config = LoRAConfig::new(8, 8.0).target_attention_projections();
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("k_proj", None));
        assert!(config.should_apply("v_proj", None));
        assert!(config.should_apply("o_proj", None));
        assert!(!config.should_apply("mlp_proj", None));
        assert_eq!(config.num_target_modules(), 4);
    }

    // QV preset covers only the query and value projections.
    #[test]
    fn test_target_qv_projections() {
        let config = LoRAConfig::new(8, 8.0).target_qv_projections();
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("v_proj", None));
        assert!(!config.should_apply("k_proj", None));
        assert!(!config.should_apply("o_proj", None));
        assert_eq!(config.num_target_modules(), 2);
    }

    // QKV preset covers q/k/v but excludes the output projection.
    #[test]
    fn test_target_qkv_projections() {
        let config = LoRAConfig::new(8, 8.0).target_qkv_projections();
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("k_proj", None));
        assert!(config.should_apply("v_proj", None));
        assert!(!config.should_apply("o_proj", None));
        assert_eq!(config.num_target_modules(), 3);
    }

    // Layer whitelist admits exactly the listed indices.
    #[test]
    fn test_target_layers() {
        let config = LoRAConfig::new(8, 8.0).target_modules(&["q_proj"]).target_layers(&[0, 2, 4]);
        assert!(config.should_apply("q_proj", Some(0)));
        assert!(!config.should_apply("q_proj", Some(1)));
        assert!(config.should_apply("q_proj", Some(2)));
        assert!(!config.should_apply("q_proj", Some(3)));
        assert!(config.should_apply("q_proj", Some(4)));
    }

    // All-linear mode matches by suffix ("proj"/"linear") only.
    #[test]
    fn test_all_linear_layers() {
        let config = LoRAConfig::new(8, 8.0).all_linear_layers();
        assert!(config.is_all_linear());
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("k_proj", None));
        assert!(config.should_apply("mlp_proj", None));
        assert!(config.should_apply("fc_linear", None));
        assert!(!config.should_apply("layer_norm", None));
    }

    // Default is rank 8, alpha 8.0, targeting q/v projections.
    #[test]
    fn test_default_config() {
        let config = LoRAConfig::default();
        assert_eq!(config.rank, 8);
        assert_eq!(config.alpha, 8.0);
        assert!(config.should_apply("q_proj", None));
        assert!(config.should_apply("v_proj", None));
        assert!(!config.should_apply("k_proj", None));
        assert_eq!(config.num_target_modules(), 2);
    }

    // Module targeting and layer filtering compose: both must pass.
    #[test]
    fn test_layer_filtering_with_modules() {
        let config = LoRAConfig::new(4, 4.0).target_attention_projections().target_layers(&[1, 3]);
        assert!(!config.should_apply("q_proj", Some(0)));
        assert!(config.should_apply("q_proj", Some(1)));
        assert!(config.should_apply("v_proj", Some(1)));
        assert!(!config.should_apply("q_proj", Some(2)));
        assert!(config.should_apply("k_proj", Some(3)));
        assert!(config.should_apply("o_proj", Some(3)));
    }

    // "all_linear" shorthand expands to all seven projection names.
    #[test]
    fn test_ent_lora_005_expand_all_linear() {
        let expanded = LoRAConfig::expand_shorthand(&["all_linear".to_string()]);
        assert_eq!(expanded.len(), 7);
        assert!(expanded.contains(&"q_proj".to_string()));
        assert!(expanded.contains(&"k_proj".to_string()));
        assert!(expanded.contains(&"v_proj".to_string()));
        assert!(expanded.contains(&"o_proj".to_string()));
        assert!(expanded.contains(&"gate_proj".to_string()));
        assert!(expanded.contains(&"up_proj".to_string()));
        assert!(expanded.contains(&"down_proj".to_string()));
    }

    // "attention" shorthand expands to the four attention projections.
    #[test]
    fn test_ent_lora_005_expand_attention() {
        let expanded = LoRAConfig::expand_shorthand(&["attention".to_string()]);
        assert_eq!(expanded.len(), 4);
        assert!(expanded.contains(&"q_proj".to_string()));
        assert!(expanded.contains(&"k_proj".to_string()));
        assert!(expanded.contains(&"v_proj".to_string()));
        assert!(expanded.contains(&"o_proj".to_string()));
    }

    // "qv" shorthand expands to query and value projections.
    #[test]
    fn test_ent_lora_005_expand_qv() {
        let expanded = LoRAConfig::expand_shorthand(&["qv".to_string()]);
        assert_eq!(expanded.len(), 2);
        assert!(expanded.contains(&"q_proj".to_string()));
        assert!(expanded.contains(&"v_proj".to_string()));
    }

    // "mlp" shorthand expands to the three MLP projections.
    #[test]
    fn test_ent_lora_005_expand_mlp() {
        let expanded = LoRAConfig::expand_shorthand(&["mlp".to_string()]);
        assert_eq!(expanded.len(), 3);
        assert!(expanded.contains(&"gate_proj".to_string()));
        assert!(expanded.contains(&"up_proj".to_string()));
        assert!(expanded.contains(&"down_proj".to_string()));
    }

    // A multi-element explicit list is passed through unchanged.
    #[test]
    fn test_ent_lora_005_expand_explicit_passthrough() {
        let explicit = vec!["q_proj".to_string(), "v_proj".to_string()];
        let expanded = LoRAConfig::expand_shorthand(&explicit);
        assert_eq!(expanded, explicit);
    }

    // A single non-shorthand name is also passed through unchanged.
    #[test]
    fn test_ent_lora_005_expand_unknown_single() {
        let modules = vec!["custom_proj".to_string()];
        let expanded = LoRAConfig::expand_shorthand(&modules);
        assert_eq!(expanded, modules);
    }

    // `get_target_modules` returns borrowed names; sort before comparing
    // since HashSet iteration order is unspecified.
    #[test]
    fn test_get_target_modules() {
        let config = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "v_proj"]);
        let mut modules = config.get_target_modules();
        modules.sort_unstable();
        assert_eq!(modules.len(), 2);
        assert!(modules.contains(&"q_proj"));
        assert!(modules.contains(&"v_proj"));
    }
}