Skip to main content

entrenar/lora/
config.rs

1//! LoRA configuration for target module selection
2//!
3//! Allows selective application of LoRA adapters to specific modules/layers,
4//! commonly used for transformer attention projections (q/k/v/o_proj).
5
6use std::collections::HashSet;
7
8/// Configuration for LoRA adapter targeting
9#[derive(Clone, Debug)]
10pub struct LoRAConfig {
11    /// LoRA rank
12    pub rank: usize,
13    /// LoRA alpha (scaling parameter)
14    pub alpha: f32,
15    /// Target module names (e.g., "q_proj", "k_proj", "v_proj", "o_proj")
16    pub target_modules: HashSet<String>,
17    /// Layer indices to apply LoRA to (None = all layers)
18    pub layers: Option<Vec<usize>>,
19    /// Whether to apply LoRA to all linear layers
20    pub all_linear: bool,
21}
22
23impl LoRAConfig {
24    /// Create a new LoRA configuration
25    ///
26    /// # Arguments
27    /// * `rank` - LoRA rank (typically 4, 8, 16, 32, or 64)
28    /// * `alpha` - LoRA alpha scaling parameter (often same as rank)
29    pub fn new(rank: usize, alpha: f32) -> Self {
30        Self { rank, alpha, target_modules: HashSet::new(), layers: None, all_linear: false }
31    }
32
33    /// Target specific modules by name
34    ///
35    /// # Example
36    /// ```ignore
37    /// config.target_modules(&["q_proj", "v_proj"]);
38    /// ```
39    pub fn target_modules(mut self, modules: &[&str]) -> Self {
40        self.target_modules = modules.iter().map(ToString::to_string).collect();
41        self
42    }
43
44    /// Target attention projection modules (q, k, v, o)
45    pub fn target_attention_projections(mut self) -> Self {
46        self.target_modules = vec![
47            "q_proj".to_string(),
48            "k_proj".to_string(),
49            "v_proj".to_string(),
50            "o_proj".to_string(),
51        ]
52        .into_iter()
53        .collect();
54        self
55    }
56
57    /// Target query and value projections only (common for efficient fine-tuning)
58    pub fn target_qv_projections(mut self) -> Self {
59        self.target_modules =
60            vec!["q_proj".to_string(), "v_proj".to_string()].into_iter().collect();
61        self
62    }
63
64    /// Target all attention projections except output (q, k, v only)
65    pub fn target_qkv_projections(mut self) -> Self {
66        self.target_modules =
67            vec!["q_proj".to_string(), "k_proj".to_string(), "v_proj".to_string()]
68                .into_iter()
69                .collect();
70        self
71    }
72
73    /// Target specific layer indices
74    ///
75    /// # Example
76    /// ```ignore
77    /// config.target_layers(&[0, 1, 2]); // Only first 3 layers
78    /// ```
79    pub fn target_layers(mut self, layer_indices: &[usize]) -> Self {
80        self.layers = Some(layer_indices.to_vec());
81        self
82    }
83
84    /// Apply LoRA to all linear layers
85    pub fn all_linear_layers(mut self) -> Self {
86        self.all_linear = true;
87        self
88    }
89
90    /// Expand shorthand target module names to concrete module lists (ENT-LoRA-005)
91    ///
92    /// Supports:
93    /// - `"all_linear"` → all projection modules (q/k/v/o/gate/up/down)
94    /// - `"attention"` → q/k/v/o projections
95    /// - `"qv"` → q/v projections (default, original paper)
96    /// - `"mlp"` → gate/up/down projections
97    /// - Explicit module names passed through unchanged
98    pub fn expand_shorthand(modules: &[String]) -> Vec<String> {
99        if modules.len() == 1 {
100            match modules[0].as_str() {
101                "all_linear" => {
102                    return vec![
103                        "q_proj",
104                        "k_proj",
105                        "v_proj",
106                        "o_proj",
107                        "gate_proj",
108                        "up_proj",
109                        "down_proj",
110                    ]
111                    .into_iter()
112                    .map(String::from)
113                    .collect()
114                }
115                "attention" => {
116                    return vec!["q_proj", "k_proj", "v_proj", "o_proj"]
117                        .into_iter()
118                        .map(String::from)
119                        .collect()
120                }
121                "qv" => return vec!["q_proj", "v_proj"].into_iter().map(String::from).collect(),
122                "mlp" => {
123                    return vec!["gate_proj", "up_proj", "down_proj"]
124                        .into_iter()
125                        .map(String::from)
126                        .collect()
127                }
128                _ => {}
129            }
130        }
131        modules.to_vec()
132    }
133
134    /// Check if a module should have LoRA applied
135    ///
136    /// # Arguments
137    /// * `module_name` - Name of the module (e.g., "q_proj", "k_proj")
138    /// * `layer_idx` - Layer index (if applicable)
139    pub fn should_apply(&self, module_name: &str, layer_idx: Option<usize>) -> bool {
140        // Check layer index filter
141        if let Some(layers) = &self.layers {
142            if let Some(idx) = layer_idx {
143                if !layers.contains(&idx) {
144                    return false;
145                }
146            }
147        }
148
149        // Check module name filter
150        if self.all_linear {
151            // Apply to all linear layers (assuming module names ending in "proj" or "linear")
152            module_name.ends_with("proj") || module_name.ends_with("linear")
153        } else {
154            self.target_modules.contains(module_name)
155        }
156    }
157
158    /// Get the number of target modules
159    pub fn num_target_modules(&self) -> usize {
160        self.target_modules.len()
161    }
162
163    /// Check if targeting all linear layers
164    pub fn is_all_linear(&self) -> bool {
165        self.all_linear
166    }
167
168    /// Get target module names
169    pub fn get_target_modules(&self) -> Vec<&str> {
170        self.target_modules.iter().map(String::as_str).collect()
171    }
172}
173
174impl Default for LoRAConfig {
175    /// Default configuration: rank=8, alpha=8, target q_proj and v_proj
176    fn default() -> Self {
177        Self::new(8, 8.0).target_qv_projections()
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// The four attention projections, in q/k/v/o order.
    const ATTN: [&str; 4] = ["q_proj", "k_proj", "v_proj", "o_proj"];

    /// Checks `should_apply(module, layer) == expected` for every row of `cases`,
    /// reporting the offending row on failure.
    fn assert_cases(cfg: &LoRAConfig, cases: &[(&str, Option<usize>, bool)]) {
        for &(module, layer, expected) in cases {
            assert_eq!(
                cfg.should_apply(module, layer),
                expected,
                "module={module} layer={layer:?}"
            );
        }
    }

    /// Expands the one-element list `[shorthand]` and checks that the result has exactly
    /// `expected.len()` entries and contains every name in `expected`.
    fn assert_expands_to(shorthand: &str, expected: &[&str]) {
        let expanded = LoRAConfig::expand_shorthand(&[shorthand.to_string()]);
        assert_eq!(expanded.len(), expected.len());
        for name in expected {
            assert!(
                expanded.iter().any(|m| m == name),
                "`{shorthand}` should expand to include `{name}`"
            );
        }
    }

    // ------------------------------------------------------------------------
    // Property tests: invariants of the builder API and `should_apply`
    // ------------------------------------------------------------------------

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(200))]

        /// `should_apply` reports exactly the modules handed to `target_modules`.
        #[test]
        fn prop_should_apply_consistent_with_modules(
            rank in 1usize..64,
            alpha in 1.0f32..64.0,
            include_q in proptest::bool::ANY,
            include_k in proptest::bool::ANY,
            include_v in proptest::bool::ANY,
            include_o in proptest::bool::ANY,
        ) {
            let flags = [include_q, include_k, include_v, include_o];
            let chosen: Vec<&str> = ATTN
                .iter()
                .zip(flags)
                .filter(|&(_, keep)| keep)
                .map(|(&name, _)| name)
                .collect();

            let cfg = LoRAConfig::new(rank, alpha).target_modules(&chosen);

            for (&module, expected) in ATTN.iter().zip(flags) {
                prop_assert_eq!(cfg.should_apply(module, None), expected);
            }
            prop_assert_eq!(cfg.num_target_modules(), chosen.len());
        }

        /// With a layer filter, an explicit layer index passes iff it is in the list.
        #[test]
        fn prop_layer_filtering_respects_indices(
            layers in prop::collection::vec(0usize..32, 1..8),
            test_layer in 0usize..32,
        ) {
            let cfg = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);

            let listed = layers.iter().any(|&layer| layer == test_layer);
            prop_assert_eq!(cfg.should_apply("q_proj", Some(test_layer)), listed);
        }

        /// In all-linear mode `*_proj` and `*_linear` match while `*_norm` does not.
        #[test]
        fn prop_all_linear_matches_suffixes(
            prefix in "[a-z]{1,8}",
        ) {
            let cfg = LoRAConfig::new(8, 8.0).all_linear_layers();

            for suffix in ["proj", "linear"] {
                prop_assert!(cfg.should_apply(&format!("{prefix}_{suffix}"), None));
            }
            prop_assert!(!cfg.should_apply(&format!("{prefix}_norm"), None));
        }

        /// Rank and alpha survive a chain of builder calls untouched.
        #[test]
        fn prop_config_params_preserved(
            rank in 1usize..128,
            alpha in 0.1f32..128.0,
        ) {
            let cfg = LoRAConfig::new(rank, alpha)
                .target_attention_projections()
                .target_layers(&[0, 1, 2]);

            prop_assert_eq!(cfg.rank, rank);
            prop_assert!((cfg.alpha - alpha).abs() < 1e-6);
            prop_assert_eq!(cfg.num_target_modules(), ATTN.len());
        }

        /// A `None` layer index is never rejected by the layer filter.
        #[test]
        fn prop_none_layer_bypasses_filter(
            layers in prop::collection::vec(0usize..16, 1..4),
        ) {
            let cfg = LoRAConfig::new(8, 8.0)
                .target_modules(&["q_proj"])
                .target_layers(&layers);

            prop_assert!(cfg.should_apply("q_proj", None));
        }
    }

    // ------------------------------------------------------------------------
    // Unit tests: construction and builder presets
    // ------------------------------------------------------------------------

    #[test]
    fn test_lora_config_creation() {
        let cfg = LoRAConfig::new(16, 16.0);
        assert_eq!((cfg.rank, cfg.alpha), (16, 16.0));
        assert_eq!(cfg.num_target_modules(), 0);
        assert!(!cfg.is_all_linear());
    }

    #[test]
    fn test_target_modules() {
        let cfg = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "k_proj"]);
        assert_cases(
            &cfg,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("v_proj", None, false),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(cfg.num_target_modules(), 2);
    }

    #[test]
    fn test_target_attention_projections() {
        let cfg = LoRAConfig::new(8, 8.0).target_attention_projections();
        for module in ATTN {
            assert!(cfg.should_apply(module, None), "{module}");
        }
        assert!(!cfg.should_apply("mlp_proj", None));
        assert_eq!(cfg.num_target_modules(), 4);
    }

    #[test]
    fn test_target_qv_projections() {
        let cfg = LoRAConfig::new(8, 8.0).target_qv_projections();
        assert_cases(
            &cfg,
            &[
                ("q_proj", None, true),
                ("v_proj", None, true),
                ("k_proj", None, false),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(cfg.num_target_modules(), 2);
    }

    #[test]
    fn test_target_qkv_projections() {
        let cfg = LoRAConfig::new(8, 8.0).target_qkv_projections();
        assert_cases(
            &cfg,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("v_proj", None, true),
                ("o_proj", None, false),
            ],
        );
        assert_eq!(cfg.num_target_modules(), 3);
    }

    #[test]
    fn test_target_layers() {
        let cfg = LoRAConfig::new(8, 8.0).target_modules(&["q_proj"]).target_layers(&[0, 2, 4]);

        // Layers 0, 2 and 4 are selected; the odd layers between them are not.
        for layer in 0..=4 {
            assert_eq!(cfg.should_apply("q_proj", Some(layer)), layer % 2 == 0, "layer {layer}");
        }
    }

    #[test]
    fn test_all_linear_layers() {
        let cfg = LoRAConfig::new(8, 8.0).all_linear_layers();

        assert!(cfg.is_all_linear());
        assert_cases(
            &cfg,
            &[
                ("q_proj", None, true),
                ("k_proj", None, true),
                ("mlp_proj", None, true),
                ("fc_linear", None, true),
                ("layer_norm", None, false),
            ],
        );
    }

    #[test]
    fn test_default_config() {
        let cfg = LoRAConfig::default();

        assert_eq!(cfg.rank, 8);
        assert_eq!(cfg.alpha, 8.0);
        assert_cases(
            &cfg,
            &[("q_proj", None, true), ("v_proj", None, true), ("k_proj", None, false)],
        );
        assert_eq!(cfg.num_target_modules(), 2);
    }

    #[test]
    fn test_layer_filtering_with_modules() {
        let cfg = LoRAConfig::new(4, 4.0).target_attention_projections().target_layers(&[1, 3]);

        // Only layers 1 and 3 are selected; any attention projection passes there.
        assert_cases(
            &cfg,
            &[
                ("q_proj", Some(0), false),
                ("q_proj", Some(1), true),
                ("v_proj", Some(1), true),
                ("q_proj", Some(2), false),
                ("k_proj", Some(3), true),
                ("o_proj", Some(3), true),
            ],
        );
    }

    // ------------------------------------------------------------------------
    // ENT-LoRA-005: shorthand expansion
    // ------------------------------------------------------------------------

    #[test]
    fn test_ent_lora_005_expand_all_linear() {
        assert_expands_to(
            "all_linear",
            &["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        );
    }

    #[test]
    fn test_ent_lora_005_expand_attention() {
        assert_expands_to("attention", &ATTN);
    }

    #[test]
    fn test_ent_lora_005_expand_qv() {
        assert_expands_to("qv", &["q_proj", "v_proj"]);
    }

    #[test]
    fn test_ent_lora_005_expand_mlp() {
        assert_expands_to("mlp", &["gate_proj", "up_proj", "down_proj"]);
    }

    #[test]
    fn test_ent_lora_005_expand_explicit_passthrough() {
        let explicit: Vec<String> = ["q_proj", "v_proj"].iter().copied().map(String::from).collect();
        assert_eq!(LoRAConfig::expand_shorthand(&explicit), explicit);
    }

    #[test]
    fn test_ent_lora_005_expand_unknown_single() {
        let custom = [String::from("custom_proj")];
        assert_eq!(LoRAConfig::expand_shorthand(&custom), custom);
    }

    #[test]
    fn test_get_target_modules() {
        let cfg = LoRAConfig::new(8, 8.0).target_modules(&["q_proj", "v_proj"]);

        let mut names = cfg.get_target_modules();
        names.sort_unstable(); // HashSet iteration order is unspecified

        assert_eq!(names, ["q_proj", "v_proj"]);
    }
}