Skip to main content

entrenar/lora/
multi_adapter.rs

1//! Multi-adapter training (ENT-LoRA-013)
2//!
3//! Supports N adapters sharing one frozen base model, each with independent
4//! optimizer states. This enables:
5//! - Multi-task fine-tuning (one adapter per task)
6//! - A/B testing different LoRA configurations
7//! - Adapter composition (combine adapters at inference)
8//!
9//! Architecture:
10//! - One shared base model (frozen, loaded once)
11//! - N LoRA adapter sets, each with own A/B matrices
12//! - Independent forward passes per adapter
13//! - Independent optimizer states per adapter
14
15use crate::lora::LoRALayer;
16use crate::Tensor;
17
18/// Named adapter wrapping a set of LoRA layers
19#[derive(Clone)]
20pub struct NamedAdapter {
21    /// Human-readable name (e.g., "task_a", "safety_classifier")
22    pub name: String,
23    /// LoRA layers for this adapter
24    pub layers: Vec<LoRALayer>,
25    /// Whether this adapter is active (receives gradients)
26    pub active: bool,
27}
28
29impl NamedAdapter {
30    /// Create a new named adapter
31    pub fn new(name: impl Into<String>, layers: Vec<LoRALayer>) -> Self {
32        Self { name: name.into(), layers, active: true }
33    }
34
35    /// Get trainable parameters across all layers
36    pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
37        self.layers.iter_mut().flat_map(|l| l.trainable_params()).collect()
38    }
39
40    /// Total trainable parameter count
41    pub fn param_count(&self) -> usize {
42        self.layers.iter().map(|l| l.lora_a().len() + l.lora_b().len()).sum()
43    }
44
45    /// Merge all layers for inference
46    pub fn merge_all(&mut self) {
47        for layer in &mut self.layers {
48            layer.merge();
49        }
50    }
51
52    /// Unmerge all layers (return to training mode)
53    pub fn unmerge_all(&mut self) {
54        for layer in &mut self.layers {
55            layer.unmerge();
56        }
57    }
58}
59
60/// Multi-adapter manager
61///
62/// Manages multiple LoRA adapters sharing a frozen base model.
63/// Each adapter can be independently trained, activated/deactivated,
64/// and merged.
65pub struct MultiAdapterManager {
66    /// Named adapters indexed by position
67    adapters: Vec<NamedAdapter>,
68}
69
70impl MultiAdapterManager {
71    /// Create an empty multi-adapter manager
72    pub fn new() -> Self {
73        Self { adapters: Vec::new() }
74    }
75
76    /// Add an adapter, returns its index
77    pub fn add_adapter(&mut self, adapter: NamedAdapter) -> usize {
78        let idx = self.adapters.len();
79        self.adapters.push(adapter);
80        idx
81    }
82
83    /// Get adapter by index
84    pub fn get(&self, idx: usize) -> Option<&NamedAdapter> {
85        self.adapters.get(idx)
86    }
87
88    /// Get mutable adapter by index
89    pub fn get_mut(&mut self, idx: usize) -> Option<&mut NamedAdapter> {
90        self.adapters.get_mut(idx)
91    }
92
93    /// Find adapter by name
94    pub fn find_by_name(&self, name: &str) -> Option<(usize, &NamedAdapter)> {
95        self.adapters.iter().enumerate().find(|(_, a)| a.name == name)
96    }
97
98    /// Number of adapters
99    pub fn len(&self) -> usize {
100        self.adapters.len()
101    }
102
103    /// Check if empty
104    pub fn is_empty(&self) -> bool {
105        self.adapters.is_empty()
106    }
107
108    /// Get all active adapters
109    pub fn active_adapters(&self) -> Vec<(usize, &NamedAdapter)> {
110        self.adapters.iter().enumerate().filter(|(_, a)| a.active).collect()
111    }
112
113    /// Set adapter active/inactive
114    pub fn set_active(&mut self, idx: usize, active: bool) {
115        if let Some(adapter) = self.adapters.get_mut(idx) {
116            adapter.active = active;
117        }
118    }
119
120    /// Total trainable parameters across all active adapters
121    pub fn total_trainable_params(&self) -> usize {
122        self.adapters.iter().filter(|a| a.active).map(NamedAdapter::param_count).sum()
123    }
124
125    /// Summary of all adapters
126    pub fn summary(&self) -> String {
127        let mut lines = vec![format!("Multi-adapter manager: {} adapters", self.adapters.len())];
128        for (i, adapter) in self.adapters.iter().enumerate() {
129            let status = if adapter.active { "ACTIVE" } else { "INACTIVE" };
130            lines.push(format!(
131                "  [{}] {} — {} params, {} layers, {}",
132                i,
133                adapter.name,
134                adapter.param_count(),
135                adapter.layers.len(),
136                status,
137            ));
138        }
139        lines.join("\n")
140    }
141
142    /// Remove adapter by index (returns the removed adapter)
143    pub fn remove(&mut self, idx: usize) -> Option<NamedAdapter> {
144        if idx < self.adapters.len() {
145            Some(self.adapters.remove(idx))
146        } else {
147            None
148        }
149    }
150
151    /// Iterator over all adapters
152    pub fn iter(&self) -> impl Iterator<Item = &NamedAdapter> {
153        self.adapters.iter()
154    }
155
156    /// Mutable iterator
157    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut NamedAdapter> {
158        self.adapters.iter_mut()
159    }
160}
161
162impl Default for MultiAdapterManager {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
#[cfg(test)]
// `unwrap` is acceptable in tests: a panic is exactly how a failed expectation is reported.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::lora::LoRALayer;
    use proptest::prelude::*;

    /// Builds a LoRA layer over a `d_out * d_in` base tensor filled with 0.5.
    fn make_lora_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
        LoRALayer::new(base, d_out, d_in, rank, 4.0)
    }

    #[test]
    fn test_ent_lora_013_multi_adapter_creation() {
        let mut mgr = MultiAdapterManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);

        let adapter = NamedAdapter::new("task_a", vec![make_lora_layer(4, 4, 2)]);
        let idx = mgr.add_adapter(adapter);
        assert_eq!(idx, 0);
        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());
    }

    #[test]
    fn test_ent_lora_013_multiple_adapters() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("safety", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("style", vec![make_lora_layer(4, 4, 4)]));

        assert_eq!(mgr.len(), 2);
        assert_eq!(mgr.get(0).unwrap().name, "safety");
        assert_eq!(mgr.get(1).unwrap().name, "style");
    }

    #[test]
    fn test_ent_lora_013_find_by_name() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("alpha", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("beta", vec![make_lora_layer(4, 4, 2)]));

        let (idx, adapter) = mgr.find_by_name("beta").unwrap();
        assert_eq!(idx, 1);
        assert_eq!(adapter.name, "beta");
        assert!(mgr.find_by_name("gamma").is_none());
    }

    #[test]
    fn test_ent_lora_013_active_inactive() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("b", vec![make_lora_layer(4, 4, 2)]));

        assert_eq!(mgr.active_adapters().len(), 2);

        mgr.set_active(0, false);
        assert_eq!(mgr.active_adapters().len(), 1);
        assert_eq!(mgr.active_adapters()[0].1.name, "b");

        // An out-of-range index must be ignored rather than panic or change state.
        mgr.set_active(99, false);
        assert_eq!(mgr.active_adapters().len(), 1);
    }

    #[test]
    fn test_ent_lora_013_param_count() {
        let adapter = NamedAdapter::new(
            "test",
            vec![
                make_lora_layer(8, 4, 2), // A: 2*4=8, B: 8*2=16 = 24
                make_lora_layer(4, 8, 2), // A: 2*8=16, B: 4*2=8 = 24
            ],
        );
        assert_eq!(adapter.param_count(), 48);
    }

    #[test]
    fn test_ent_lora_013_total_trainable_params() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("b", vec![make_lora_layer(4, 4, 2)]));

        let total = mgr.total_trainable_params();
        assert!(total > 0);

        mgr.set_active(0, false);
        let reduced = mgr.total_trainable_params();
        assert!(reduced < total);
    }

    #[test]
    fn test_ent_lora_013_summary() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("task_a", vec![make_lora_layer(4, 4, 2)]));

        let summary = mgr.summary();
        assert!(summary.contains("task_a"));
        // "INACTIVE" contains "ACTIVE" as a substring, so a bare
        // `contains("ACTIVE")` cannot tell the two states apart. Match on the
        // ", " separator that precedes the status field instead.
        assert!(summary.contains(", ACTIVE"));
        assert!(!summary.contains("INACTIVE"));

        mgr.set_active(0, false);
        let summary = mgr.summary();
        assert!(summary.contains(", INACTIVE"));
        assert!(!summary.contains(", ACTIVE"));
    }

    #[test]
    fn test_ent_lora_013_remove_adapter() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![]));
        mgr.add_adapter(NamedAdapter::new("b", vec![]));

        let removed = mgr.remove(0).unwrap();
        assert_eq!(removed.name, "a");
        assert_eq!(mgr.len(), 1);
        assert_eq!(mgr.get(0).unwrap().name, "b");

        // Removing past the end yields `None` and leaves the manager untouched.
        assert!(mgr.remove(5).is_none());
        assert_eq!(mgr.len(), 1);
    }

    #[test]
    fn test_ent_lora_013_trainable_params_mut() {
        let mut adapter = NamedAdapter::new("test", vec![make_lora_layer(4, 4, 2)]);
        let params = adapter.trainable_params();
        // Each LoRA layer has 2 params: A and B
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_ent_lora_013_merge_unmerge() {
        let mut adapter = NamedAdapter::new("test", vec![make_lora_layer(4, 4, 2)]);
        assert!(!adapter.layers[0].is_merged());

        adapter.merge_all();
        assert!(adapter.layers[0].is_merged());

        adapter.unmerge_all();
        assert!(!adapter.layers[0].is_merged());
    }

    #[test]
    fn test_ent_lora_013_default() {
        let mgr = MultiAdapterManager::default();
        assert!(mgr.is_empty());
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(30))]

        // With every adapter active, the manager total must equal the sum of
        // the individual adapter counts.
        #[test]
        fn prop_multi_adapter_param_count_additive(
            n_adapters in 1usize..5,
            d in 4usize..8,
            rank in 1usize..3,
        ) {
            let mut mgr = MultiAdapterManager::new();
            let mut expected = 0usize;
            for i in 0..n_adapters {
                let adapter = NamedAdapter::new(format!("a{i}"), vec![make_lora_layer(d, d, rank)]);
                expected += adapter.param_count();
                mgr.add_adapter(adapter);
            }
            prop_assert_eq!(mgr.total_trainable_params(), expected);
        }
    }
}