1use std::collections::HashSet;
7
/// Selection rules deciding which modules (and optionally which layers)
/// receive a LoRA adapter.
///
/// Values are assembled with a consuming builder: start from
/// [`LoRAConfig::new`] and chain the `target_*` methods or
/// [`LoRAConfig::all_linear_layers`].
#[derive(Clone, Debug)]
pub struct LoRAConfig {
    /// Rank value, stored exactly as passed to [`LoRAConfig::new`] (not validated here).
    pub rank: usize,
    /// Alpha value, stored exactly as passed to [`LoRAConfig::new`] (not validated here).
    pub alpha: f32,
    /// Exact module names to adapt; consulted only while `all_linear` is `false`.
    pub target_modules: HashSet<String>,
    /// Optional allow-list of layer indices; `None` means no layer restriction.
    pub layers: Option<Vec<usize>>,
    /// When `true`, modules are matched by name suffix and `target_modules` is ignored.
    pub all_linear: bool,
}

impl LoRAConfig {
    /// Module names selected by the "attention" preset.
    const ATTENTION_MODULES: &'static [&'static str] = &["q_proj", "k_proj", "v_proj", "o_proj"];
    /// Module names selected by the "qv" preset.
    const QV_MODULES: &'static [&'static str] = &["q_proj", "v_proj"];
    /// Module names selected by the query/key/value preset.
    const QKV_MODULES: &'static [&'static str] = &["q_proj", "k_proj", "v_proj"];
    /// Module names selected by the "mlp" shorthand.
    const MLP_MODULES: &'static [&'static str] = &["gate_proj", "up_proj", "down_proj"];
    /// Module names selected by the "all_linear" shorthand (attention first, then MLP).
    const ALL_LINEAR_MODULES: &'static [&'static str] = &[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ];

    /// Creates a configuration with the given `rank` and `alpha`, no target
    /// modules, no layer restriction, and all-linear mode disabled.
    #[must_use]
    pub fn new(rank: usize, alpha: f32) -> Self {
        Self { rank, alpha, target_modules: HashSet::new(), layers: None, all_linear: false }
    }

    /// Replaces the set of targeted module names with `modules`.
    ///
    /// Duplicate names collapse because the names are stored in a set.
    #[must_use]
    pub fn target_modules(mut self, modules: &[&str]) -> Self {
        self.target_modules = modules.iter().copied().map(String::from).collect();
        self
    }

    /// Replaces the targeted modules with `q_proj`, `k_proj`, `v_proj` and `o_proj`.
    #[must_use]
    pub fn target_attention_projections(self) -> Self {
        self.target_modules(Self::ATTENTION_MODULES)
    }

    /// Replaces the targeted modules with `q_proj` and `v_proj`.
    #[must_use]
    pub fn target_qv_projections(self) -> Self {
        self.target_modules(Self::QV_MODULES)
    }

    /// Replaces the targeted modules with `q_proj`, `k_proj` and `v_proj`.
    #[must_use]
    pub fn target_qkv_projections(self) -> Self {
        self.target_modules(Self::QKV_MODULES)
    }

    /// Restricts adapters to the given layer indices.
    ///
    /// The restriction only takes effect when [`LoRAConfig::should_apply`] is
    /// called with a concrete layer index.
    #[must_use]
    pub fn target_layers(mut self, layer_indices: &[usize]) -> Self {
        self.layers = Some(layer_indices.to_vec());
        self
    }

    /// Enables all-linear mode, in which modules are matched by name suffix
    /// instead of by membership in `target_modules`.
    #[must_use]
    pub fn all_linear_layers(mut self) -> Self {
        self.all_linear = true;
        self
    }

    /// Expands a shorthand module list into explicit module names.
    ///
    /// A shorthand is recognised only when it is the sole entry of `modules`:
    /// `"all_linear"`, `"attention"`, `"qv"` or `"mlp"`. Any other input,
    /// including an empty list or a list with several entries, is returned
    /// unchanged as an owned copy.
    #[must_use]
    pub fn expand_shorthand(modules: &[String]) -> Vec<String> {
        let preset = match modules {
            [only] => match only.as_str() {
                "all_linear" => Some(Self::ALL_LINEAR_MODULES),
                "attention" => Some(Self::ATTENTION_MODULES),
                "qv" => Some(Self::QV_MODULES),
                "mlp" => Some(Self::MLP_MODULES),
                _ => None,
            },
            _ => None,
        };

        preset.map_or_else(
            || modules.to_vec(),
            |names| names.iter().copied().map(String::from).collect(),
        )
    }

    /// Reports whether an adapter should be attached to `module_name`.
    ///
    /// If a layer allow-list is configured and `layer_idx` is `Some`, the index
    /// must be in the list; a `None` index skips the layer check entirely.
    /// After that, all-linear mode accepts any name ending in `"proj"` or
    /// `"linear"`, while the default mode requires an exact match in
    /// `target_modules`.
    #[must_use]
    pub fn should_apply(&self, module_name: &str, layer_idx: Option<usize>) -> bool {
        // Only filter by layer when both a restriction and a concrete index exist.
        if let (Some(allowed), Some(idx)) = (&self.layers, layer_idx) {
            if !allowed.contains(&idx) {
                return false;
            }
        }

        if self.all_linear {
            module_name.ends_with("proj") || module_name.ends_with("linear")
        } else {
            self.target_modules.contains(module_name)
        }
    }

    /// Returns the number of explicitly targeted module names.
    ///
    /// All-linear mode does not contribute to this count.
    #[must_use]
    pub fn num_target_modules(&self) -> usize {
        self.target_modules.len()
    }

    /// Returns `true` when all-linear mode is enabled.
    #[must_use]
    pub fn is_all_linear(&self) -> bool {
        self.all_linear
    }

    /// Returns the explicitly targeted module names as borrowed strings.
    ///
    /// The order is unspecified because the names live in a `HashSet`.
    #[must_use]
    pub fn get_target_modules(&self) -> Vec<&str> {
        self.target_modules.iter().map(String::as_str).collect()
    }
}

impl Default for LoRAConfig {
    /// Rank 8, alpha 8.0, targeting `q_proj` and `v_proj`.
    fn default() -> Self {
        Self::new(8, 8.0).target_qv_projections()
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Checks `should_apply` against a table of `(module, layer, expected)` rows.
    fn assert_decisions(config: &LoRAConfig, rows: &[(&str, Option<usize>, bool)]) {
        for &(module, layer, expected) in rows {
            assert_eq!(
                config.should_apply(module, layer),
                expected,
                "module {module:?}, layer {layer:?}"
            );
        }
    }

    /// Expands a single shorthand and checks the result has exactly
    /// `expected.len()` entries and contains every expected name.
    fn assert_expands_to(shorthand: &str, expected: &[&str]) {
        let expanded = LoRAConfig::expand_shorthand(&[shorthand.to_string()]);
        assert_eq!(expanded.len(), expected.len());
        for name in expected {
            assert!(expanded.contains(&(*name).to_string()), "missing {name:?}");
        }
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(200))]

        /// Explicit module targeting must agree with `should_apply` name by name.
        #[test]
        fn prop_should_apply_consistent_with_modules(
            rank in 1usize..64,
            alpha in 1.0f32..64.0,
            include_q in proptest::bool::ANY,
            include_k in proptest::bool::ANY,
            include_v in proptest::bool::ANY,
            include_o in proptest::bool::ANY,
        ) {
            let toggles = [
                ("q_proj", include_q),
                ("k_proj", include_k),
                ("v_proj", include_v),
                ("o_proj", include_o),
            ];
            let selected: Vec<&str> = toggles
                .iter()
                .filter(|(_, enabled)| *enabled)
                .map(|(name, _)| *name)
                .collect();

            let config = LoRAConfig::new(rank, alpha).target_modules(&selected);

            for (name, expected) in toggles {
                prop_assert_eq!(config.should_apply(name, None), expected);
            }
            prop_assert_eq!(config.num_target_modules(), selected.len());
        }

        /// A concrete layer index is accepted exactly when it is in the allow-list.
        #[test]
        fn prop_layer_filtering_respects_indices(
            layers in prop::collection::vec(0usize..32, 1..8),
            test_layer in 0usize..32,
        ) {
            let config = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);

            let expected = layers.contains(&test_layer);
            let actual = config.should_apply("q_proj", Some(test_layer));
            prop_assert_eq!(actual, expected);
        }

        /// All-linear mode accepts `_proj` / `_linear` names and rejects `_norm`.
        #[test]
        fn prop_all_linear_matches_suffixes(
            prefix in "[a-z]{1,8}",
        ) {
            let config = LoRAConfig::new(8, 8.0).all_linear_layers();

            for (suffix, expected) in [("proj", true), ("linear", true), ("norm", false)] {
                let name = format!("{prefix}_{suffix}");
                prop_assert_eq!(config.should_apply(&name, None), expected);
            }
        }

        /// Builder calls must not disturb `rank` or `alpha`.
        #[test]
        fn prop_config_params_preserved(
            rank in 1usize..128,
            alpha in 0.1f32..128.0,
        ) {
            let config = LoRAConfig::new(rank, alpha)
                .target_attention_projections()
                .target_layers(&[0, 1, 2]);

            let alpha_drift = (config.alpha - alpha).abs();

            prop_assert_eq!(config.rank, rank);
            prop_assert!(alpha_drift < 1e-6);
            prop_assert_eq!(config.num_target_modules(), 4);
        }

        /// Passing no layer index skips the layer allow-list.
        #[test]
        fn prop_none_layer_bypasses_filter(
            layers in prop::collection::vec(0usize..16, 1..4),
        ) {
            let config = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);

            let applies = config.should_apply("q_proj", None);
            prop_assert!(applies);
        }
    }

    #[test]
    fn test_lora_config_creation() {
        let config = LoRAConfig::new(16, 16.0);

        assert_eq!((config.rank, config.alpha), (16, 16.0));
        assert_eq!(config.num_target_modules(), 0);
        assert!(!config.is_all_linear());
    }

    #[test]
    fn test_target_modules() {
        let config = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "k_proj"]);

        assert_decisions(
            &config,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("v_proj", None, false),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(config.num_target_modules(), 2);
    }

    #[test]
    fn test_target_attention_projections() {
        let config = LoRAConfig::new(8, 8.0).target_attention_projections();

        assert_decisions(
            &config,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("v_proj", None, true),
                ("o_proj", None, true),
                ("mlp_proj", None, false),
            ],
        );
        assert_eq!(config.num_target_modules(), 4);
    }

    #[test]
    fn test_target_qv_projections() {
        let config = LoRAConfig::new(8, 8.0).target_qv_projections();

        assert_decisions(
            &config,
            &[
                ("q_proj", None, true),
                ("v_proj", None, true),
                ("k_proj", None, false),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(config.num_target_modules(), 2);
    }

    #[test]
    fn test_target_qkv_projections() {
        let config = LoRAConfig::new(8, 8.0).target_qkv_projections();

        assert_decisions(
            &config,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("v_proj", None, true),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(config.num_target_modules(), 3);
    }

    #[test]
    fn test_target_layers() {
        let config =
            LoRAConfig::new(8, 8.0).target_modules(&["q_proj"]).target_layers(&[0, 2, 4]);

        // Even layers are allowed, odd layers are filtered out.
        assert_decisions(
            &config,
            &[
                ("q_proj", Some(0), true),
                ("q_proj", Some(1), false),
                ("q_proj", Some(2), true),
                ("q_proj", Some(3), false),
                ("q_proj", Some(4), true),
            ],
        );
    }

    #[test]
    fn test_all_linear_layers() {
        let config = LoRAConfig::new(8, 8.0).all_linear_layers();

        assert!(config.is_all_linear());
        assert_decisions(
            &config,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("mlp_proj", None, true),
                ("fc_linear", None, true),
                ("layer_norm", None, false),
            ],
        );
    }

    #[test]
    fn test_default_config() {
        let config = LoRAConfig::default();

        assert_eq!((config.rank, config.alpha), (8, 8.0));
        assert_decisions(
            &config,
            &[("q_proj", None, true), ("v_proj", None, true), ("k_proj", None, false)],
        );
        assert_eq!(config.num_target_modules(), 2);
    }

    #[test]
    fn test_layer_filtering_with_modules() {
        let config =
            LoRAConfig::new(4, 4.0).target_attention_projections().target_layers(&[1, 3]);

        assert_decisions(
            &config,
            &[
                ("q_proj", Some(0), false),
                ("q_proj", Some(1), true),
                ("v_proj", Some(1), true),
                ("q_proj", Some(2), false),
                ("k_proj", Some(3), true),
                ("o_proj", Some(3), true),
            ],
        );
    }

    #[test]
    fn test_ent_lora_005_expand_all_linear() {
        assert_expands_to(
            "all_linear",
            &["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        );
    }

    #[test]
    fn test_ent_lora_005_expand_attention() {
        assert_expands_to("attention", &["q_proj", "k_proj", "v_proj", "o_proj"]);
    }

    #[test]
    fn test_ent_lora_005_expand_qv() {
        assert_expands_to("qv", &["q_proj", "v_proj"]);
    }

    #[test]
    fn test_ent_lora_005_expand_mlp() {
        assert_expands_to("mlp", &["gate_proj", "up_proj", "down_proj"]);
    }

    #[test]
    fn test_ent_lora_005_expand_explicit_passthrough() {
        let explicit: Vec<String> = ["q_proj", "v_proj"].iter().map(|s| (*s).to_string()).collect();

        assert_eq!(LoRAConfig::expand_shorthand(&explicit), explicit);
    }

    #[test]
    fn test_ent_lora_005_expand_unknown_single() {
        let modules = vec![String::from("custom_proj")];

        assert_eq!(LoRAConfig::expand_shorthand(&modules), modules);
    }

    #[test]
    fn test_get_target_modules() {
        let config = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "v_proj"]);

        // Set iteration order is unspecified, so sort before comparing.
        let mut names = config.get_target_modules();
        names.sort_unstable();

        assert_eq!(names, ["q_proj", "v_proj"]);
    }
}