1use crate::lora::LoRALayer;
16use crate::Tensor;
17
18#[derive(Clone)]
20pub struct NamedAdapter {
21 pub name: String,
23 pub layers: Vec<LoRALayer>,
25 pub active: bool,
27}
28
29impl NamedAdapter {
30 pub fn new(name: impl Into<String>, layers: Vec<LoRALayer>) -> Self {
32 Self { name: name.into(), layers, active: true }
33 }
34
35 pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
37 self.layers.iter_mut().flat_map(|l| l.trainable_params()).collect()
38 }
39
40 pub fn param_count(&self) -> usize {
42 self.layers.iter().map(|l| l.lora_a().len() + l.lora_b().len()).sum()
43 }
44
45 pub fn merge_all(&mut self) {
47 for layer in &mut self.layers {
48 layer.merge();
49 }
50 }
51
52 pub fn unmerge_all(&mut self) {
54 for layer in &mut self.layers {
55 layer.unmerge();
56 }
57 }
58}
59
60pub struct MultiAdapterManager {
66 adapters: Vec<NamedAdapter>,
68}
69
70impl MultiAdapterManager {
71 pub fn new() -> Self {
73 Self { adapters: Vec::new() }
74 }
75
76 pub fn add_adapter(&mut self, adapter: NamedAdapter) -> usize {
78 let idx = self.adapters.len();
79 self.adapters.push(adapter);
80 idx
81 }
82
83 pub fn get(&self, idx: usize) -> Option<&NamedAdapter> {
85 self.adapters.get(idx)
86 }
87
88 pub fn get_mut(&mut self, idx: usize) -> Option<&mut NamedAdapter> {
90 self.adapters.get_mut(idx)
91 }
92
93 pub fn find_by_name(&self, name: &str) -> Option<(usize, &NamedAdapter)> {
95 self.adapters.iter().enumerate().find(|(_, a)| a.name == name)
96 }
97
98 pub fn len(&self) -> usize {
100 self.adapters.len()
101 }
102
103 pub fn is_empty(&self) -> bool {
105 self.adapters.is_empty()
106 }
107
108 pub fn active_adapters(&self) -> Vec<(usize, &NamedAdapter)> {
110 self.adapters.iter().enumerate().filter(|(_, a)| a.active).collect()
111 }
112
113 pub fn set_active(&mut self, idx: usize, active: bool) {
115 if let Some(adapter) = self.adapters.get_mut(idx) {
116 adapter.active = active;
117 }
118 }
119
120 pub fn total_trainable_params(&self) -> usize {
122 self.adapters.iter().filter(|a| a.active).map(NamedAdapter::param_count).sum()
123 }
124
125 pub fn summary(&self) -> String {
127 let mut lines = vec![format!("Multi-adapter manager: {} adapters", self.adapters.len())];
128 for (i, adapter) in self.adapters.iter().enumerate() {
129 let status = if adapter.active { "ACTIVE" } else { "INACTIVE" };
130 lines.push(format!(
131 " [{}] {} — {} params, {} layers, {}",
132 i,
133 adapter.name,
134 adapter.param_count(),
135 adapter.layers.len(),
136 status,
137 ));
138 }
139 lines.join("\n")
140 }
141
142 pub fn remove(&mut self, idx: usize) -> Option<NamedAdapter> {
144 if idx < self.adapters.len() {
145 Some(self.adapters.remove(idx))
146 } else {
147 None
148 }
149 }
150
151 pub fn iter(&self) -> impl Iterator<Item = &NamedAdapter> {
153 self.adapters.iter()
154 }
155
156 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut NamedAdapter> {
158 self.adapters.iter_mut()
159 }
160}
161
162impl Default for MultiAdapterManager {
163 fn default() -> Self {
164 Self::new()
165 }
166}
167
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::lora::LoRALayer;
    use proptest::prelude::*;

    /// Builds a rank-`rank` LoRA layer over a `d_out * d_in` base tensor filled with 0.5.
    fn make_lora_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
        LoRALayer::new(base, d_out, d_in, rank, 4.0)
    }

    #[test]
    fn test_ent_lora_013_multi_adapter_creation() {
        let mut mgr = MultiAdapterManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);

        let adapter = NamedAdapter::new("task_a", vec![make_lora_layer(4, 4, 2)]);
        let idx = mgr.add_adapter(adapter);
        assert_eq!(idx, 0);
        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());
    }

    #[test]
    fn test_ent_lora_013_multiple_adapters() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("safety", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("style", vec![make_lora_layer(4, 4, 4)]));

        assert_eq!(mgr.len(), 2);
        assert_eq!(mgr.get(0).unwrap().name, "safety");
        assert_eq!(mgr.get(1).unwrap().name, "style");
    }

    #[test]
    fn test_ent_lora_013_find_by_name() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("alpha", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("beta", vec![make_lora_layer(4, 4, 2)]));

        let (idx, adapter) = mgr.find_by_name("beta").unwrap();
        assert_eq!(idx, 1);
        assert_eq!(adapter.name, "beta");
        assert!(mgr.find_by_name("gamma").is_none());
    }

    #[test]
    fn test_ent_lora_013_active_inactive() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("b", vec![make_lora_layer(4, 4, 2)]));

        assert_eq!(mgr.active_adapters().len(), 2);

        mgr.set_active(0, false);
        assert_eq!(mgr.active_adapters().len(), 1);
        assert_eq!(mgr.active_adapters()[0].1.name, "b");
    }

    #[test]
    fn test_ent_lora_013_param_count() {
        // 48 = 2 layers * rank 2 * (d_in + d_out = 12), assuming the factors are
        // rank x d_in and d_out x rank.
        let adapter = NamedAdapter::new(
            "test",
            vec![make_lora_layer(8, 4, 2), make_lora_layer(4, 8, 2)],
        );
        assert_eq!(adapter.param_count(), 48);
    }

    #[test]
    fn test_ent_lora_013_total_trainable_params() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![make_lora_layer(4, 4, 2)]));
        mgr.add_adapter(NamedAdapter::new("b", vec![make_lora_layer(4, 4, 2)]));

        let total = mgr.total_trainable_params();
        assert!(total > 0);

        mgr.set_active(0, false);
        let reduced = mgr.total_trainable_params();
        assert!(reduced < total);
    }

    #[test]
    fn test_ent_lora_013_summary() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("task_a", vec![make_lora_layer(4, 4, 2)]));
        let summary = mgr.summary();
        assert!(summary.contains("task_a"));
        // "INACTIVE" contains "ACTIVE", so the positive check alone cannot tell the
        // two statuses apart; also rule out the inactive label explicitly.
        assert!(summary.contains("ACTIVE"));
        assert!(!summary.contains("INACTIVE"));

        mgr.set_active(0, false);
        assert!(mgr.summary().contains("INACTIVE"));
    }

    #[test]
    fn test_ent_lora_013_remove_adapter() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("a", vec![]));
        mgr.add_adapter(NamedAdapter::new("b", vec![]));

        let removed = mgr.remove(0).unwrap();
        assert_eq!(removed.name, "a");
        assert_eq!(mgr.len(), 1);
        assert_eq!(mgr.get(0).unwrap().name, "b");
    }

    #[test]
    fn test_ent_lora_013_out_of_range_indices() {
        let mut mgr = MultiAdapterManager::new();
        mgr.add_adapter(NamedAdapter::new("only", vec![]));

        assert!(mgr.get(5).is_none());
        assert!(mgr.get_mut(5).is_none());

        // Out-of-range set_active is a no-op.
        mgr.set_active(5, false);
        assert_eq!(mgr.active_adapters().len(), 1);

        // Out-of-range remove returns None and leaves the manager untouched.
        assert!(mgr.remove(5).is_none());
        assert_eq!(mgr.len(), 1);
    }

    #[test]
    fn test_ent_lora_013_trainable_params_mut() {
        let mut adapter = NamedAdapter::new("test", vec![make_lora_layer(4, 4, 2)]);
        let params = adapter.trainable_params();
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_ent_lora_013_merge_unmerge() {
        let mut adapter = NamedAdapter::new("test", vec![make_lora_layer(4, 4, 2)]);
        assert!(!adapter.layers[0].is_merged());

        adapter.merge_all();
        assert!(adapter.layers[0].is_merged());

        adapter.unmerge_all();
        assert!(!adapter.layers[0].is_merged());
    }

    #[test]
    fn test_ent_lora_013_default() {
        let mgr = MultiAdapterManager::default();
        assert!(mgr.is_empty());
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(30))]

        #[test]
        fn prop_multi_adapter_param_count_additive(
            n_adapters in 1usize..5,
            d in 4usize..8,
            rank in 1usize..3,
        ) {
            let mut mgr = MultiAdapterManager::new();
            let mut expected = 0usize;
            for i in 0..n_adapters {
                let adapter = NamedAdapter::new(format!("a{i}"), vec![make_lora_layer(d, d, rank)]);
                expected += adapter.param_count();
                mgr.add_adapter(adapter);
            }
            prop_assert_eq!(mgr.total_trainable_params(), expected);
        }
    }
}