Skip to main content

oxillama_arch/lora/
stack.rs

//! `LoraStack` — ordered composition of multiple LoRA adapters.
2
3use std::sync::Arc;
4
5use super::adapter::{LoraAdapterTrait, TargetModule};
6use super::loader::LoadedLora;
7use crate::error::ArchResult;
8
9/// A stack of LoRA adapters applied in order.
10///
11/// Adapters are applied additively:
12/// `output += scale_0 · LoRA_0(x) + scale_1 · LoRA_1(x) + ...`
13///
14/// Each adapter's intrinsic `alpha/rank` scale is multiplied by the per-entry `stack_scale`.
15///
16/// The stack supports two independent adapter lists:
17/// - **`entries`**: legacy [`LoadedLora`] adapters (GGUF-loaded, keyed by tensor name).
18/// - **`adapter_list`**: trait-object adapters implementing [`LoraAdapterTrait`].
19///
20/// Both lists are consulted by the respective dispatch methods.
21#[derive(Default)]
22pub struct LoraStack {
23    /// Ordered list of `(GGUF-loaded adapter set, per-entry scale multiplier)`.
24    entries: Vec<(Arc<LoadedLora>, f32)>,
25    /// Ordered list of trait-object adapters (new public API).
26    adapter_list: Vec<Arc<dyn LoraAdapterTrait>>,
27}
28
29impl std::fmt::Debug for LoraStack {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("LoraStack")
32            .field("entries_len", &self.entries.len())
33            .field("adapter_list_len", &self.adapter_list.len())
34            .finish()
35    }
36}
37
38impl Clone for LoraStack {
39    fn clone(&self) -> Self {
40        Self {
41            entries: self.entries.clone(),
42            // Arc clones are cheap reference-count bumps.
43            adapter_list: self.adapter_list.clone(),
44        }
45    }
46}
47
48impl LoraStack {
49    /// Create an empty stack.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    // ─── Legacy `LoadedLora` API ──────────────────────────────────────────────
55
56    /// Push a new GGUF-loaded adapter onto the stack with a given scale multiplier.
57    ///
58    /// `scale` multiplies the adapter's intrinsic `alpha/rank` scale factor.
59    pub fn push(&mut self, adapter: Arc<LoadedLora>, scale: f32) {
60        self.entries.push((adapter, scale));
61    }
62
63    /// Remove the last GGUF-loaded adapter from the stack.
64    ///
65    /// Returns `None` if the entries list is empty.
66    pub fn pop(&mut self) -> Option<(Arc<LoadedLora>, f32)> {
67        self.entries.pop()
68    }
69
70    /// Remove all adapters (both lists).
71    pub fn clear(&mut self) {
72        self.entries.clear();
73        self.adapter_list.clear();
74    }
75
76    /// Number of stacked legacy adapters.
77    pub fn len(&self) -> usize {
78        self.entries.len()
79    }
80
81    /// Returns `true` if no adapters are stacked in either list.
82    pub fn is_empty(&self) -> bool {
83        self.entries.is_empty() && self.adapter_list.is_empty()
84    }
85
86    /// Immutable view of the `(LoadedLora, scale)` entries.
87    pub fn entries(&self) -> &[(Arc<LoadedLora>, f32)] {
88        &self.entries
89    }
90
91    /// Immutable view of the trait-object adapter list.
92    pub fn adapters(&self) -> &[Arc<dyn LoraAdapterTrait>] {
93        &self.adapter_list
94    }
95
96    /// Apply the entire legacy stack to compute the additive LoRA delta for
97    /// `tensor_name`.
98    ///
99    /// Returns `Ok(delta)` where `delta` is a `Vec<f32>` of length
100    /// `out_features` representing `Σ_i stack_scale_i · (alpha_i/rank_i) · B_i @ A_i @ input`.
101    /// Adapters that do not define `tensor_name` are silently skipped.
102    /// The caller is responsible for adding `delta` to the base linear output.
103    ///
104    /// # Errors
105    /// Returns [`crate::error::ArchError::Quant`] if a matching adapter's dimension checks fail.
106    pub fn apply(
107        &self,
108        tensor_name: &str,
109        input: &[f32],
110        out_features: usize,
111    ) -> ArchResult<Vec<f32>> {
112        let mut delta = vec![0.0f32; out_features];
113        for (lora, stack_scale) in &self.entries {
114            let Some(adapter) = lora.get(tensor_name) else {
115                continue;
116            };
117            let rank = adapter.rank;
118            let in_f = adapter.in_features;
119            let out_f = adapter.out_features.min(out_features);
120
121            // Step 1: r = A @ input  (length = rank)
122            let mut r_vec = vec![0.0f32; rank];
123            for (i, r) in r_vec.iter_mut().enumerate() {
124                let row = &adapter.a[i * in_f..(i + 1) * in_f];
125                *r = row
126                    .iter()
127                    .zip(input.iter().take(in_f))
128                    .map(|(&a, &x)| a * x)
129                    .sum();
130            }
131
132            // Step 2: delta += B @ r * (intrinsic_scale * stack_scale)
133            let combined = adapter.scale * stack_scale;
134            for (i, d) in delta.iter_mut().enumerate().take(out_f) {
135                let row = &adapter.b[i * rank..(i + 1) * rank];
136                let v: f32 = row.iter().zip(r_vec.iter()).map(|(&b, &r)| b * r).sum();
137                *d += v * combined;
138            }
139        }
140        Ok(delta)
141    }
142
143    // ─── New trait-object adapter API ─────────────────────────────────────────
144
145    /// Push a trait-object adapter onto the new adapter list.
146    pub fn push_adapter(&mut self, adapter: Arc<dyn LoraAdapterTrait>) {
147        self.adapter_list.push(adapter);
148    }
149
150    /// Compute the combined delta for `(target, layer)` across all
151    /// trait-object adapters, applied to input vector `input`.
152    ///
153    /// Returns `None` if no adapter in the list covers this `(target, layer)`.
154    ///
155    /// Formula: `Δ_total = Σᵢ (αᵢ / rᵢ) * Bᵢ @ Aᵢ @ input`.
156    pub fn applied_delta(
157        &self,
158        target: TargetModule,
159        layer: usize,
160        input: &[f32],
161    ) -> Option<Vec<f32>> {
162        let mut result: Option<Vec<f32>> = None;
163        for adapter in &self.adapter_list {
164            let scale = adapter.alpha() / adapter.rank().max(1) as f32;
165            if let Some(delta) = adapter.delta(target, layer) {
166                let contribution = delta.apply(input, scale);
167                match result {
168                    None => result = Some(contribution),
169                    Some(ref mut acc) => {
170                        for (a, c) in acc.iter_mut().zip(contribution.iter()) {
171                            *a += c;
172                        }
173                    }
174                }
175            }
176        }
177        result
178    }
179}
180
// ─── Tests ────────────────────────────────────────────────────────────────────
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lora::adapter::{LoraAdapterTrait, LoraDelta, TargetModule};
    use std::collections::HashMap;
    use std::sync::Arc;

    // ─── Helper: simple in-memory LoraAdapterTrait impl ──────────────────────

    /// A simple test adapter backed by a map from `(target_to_u32(target), layer)`
    /// to its `LoraDelta`.
    struct TestAdapter {
        rank: usize,
        alpha: f32,
        deltas: HashMap<(u32, usize), LoraDelta>,
        modules: Vec<TargetModule>,
    }

    impl TestAdapter {
        fn new(rank: usize, alpha: f32) -> Self {
            Self {
                rank,
                alpha,
                deltas: HashMap::new(),
                modules: Vec::new(),
            }
        }

        /// Register `delta` for `(target, layer)` and record `target` as covered.
        fn add_delta(&mut self, target: TargetModule, layer: usize, delta: LoraDelta) {
            let key = (target_to_u32(target), layer);
            if !self.modules.contains(&target) {
                self.modules.push(target);
            }
            self.deltas.insert(key, delta);
        }
    }

    /// Stable integer key for a `TargetModule`, used as part of the map key.
    fn target_to_u32(t: TargetModule) -> u32 {
        match t {
            TargetModule::QueryProj => 0,
            TargetModule::KeyProj => 1,
            TargetModule::ValueProj => 2,
            TargetModule::OutputProj => 3,
            TargetModule::GateProj => 4,
            TargetModule::UpProj => 5,
            TargetModule::DownProj => 6,
            TargetModule::Custom(id) => 100 + id,
        }
    }

    impl LoraAdapterTrait for TestAdapter {
        fn rank(&self) -> usize {
            self.rank
        }
        fn alpha(&self) -> f32 {
            self.alpha
        }
        fn target_modules(&self) -> &[TargetModule] {
            &self.modules
        }
        fn delta(&self, target: TargetModule, layer: usize) -> Option<&LoraDelta> {
            let key = (target_to_u32(target), layer);
            self.deltas.get(&key)
        }
    }

    // ─── Tests ────────────────────────────────────────────────────────────────

    /// Empty adapter list → applied_delta returns None.
    #[test]
    fn empty_stack_applied_delta_returns_none() {
        let stack = LoraStack::new();
        let result = stack.applied_delta(TargetModule::QueryProj, 0, &[1.0f32, 2.0, 3.0]);
        assert!(result.is_none(), "empty stack must return None");
    }

    /// Single adapter with an identity delta: result matches input.
    #[test]
    fn single_lora_identity_matches_input() {
        let rank = 4;
        let in_dim = 4;
        let out_dim = 4;

        // Build identity A and B.
        let mut a = vec![0.0f32; rank * in_dim];
        let mut b = vec![0.0f32; out_dim * rank];
        for i in 0..rank {
            a[i * in_dim + i] = 1.0;
            b[i * rank + i] = 1.0;
        }
        let delta = LoraDelta::new(a, b, rank, in_dim, out_dim);
        let alpha = rank as f32;

        let mut adapter = TestAdapter::new(rank, alpha);
        adapter.add_delta(TargetModule::QueryProj, 0, delta);

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter));

        let x = vec![1.0f32, 2.0, 3.0, 4.0];
        let result = stack
            .applied_delta(TargetModule::QueryProj, 0, &x)
            .expect("single adapter must produce a result");

        // scale = alpha/rank = 1.0, identity delta passes x through.
        for (r, xi) in result.iter().zip(x.iter()) {
            assert!((r - xi).abs() < 1e-5, "expected {xi} got {r}");
        }
    }

    /// Two adapters with the same delta compose additively.
    #[test]
    fn two_loras_compose_additively() {
        let rank = 2;
        let in_dim = 4;
        let out_dim = 4;
        let alpha = 2.0f32; // scale = alpha/rank = 1.0

        // A=[1,0,0,0; 0,1,0,0], B=[1,0; 0,1; 0,0; 0,0]
        let a = vec![
            1.0f32, 0.0, 0.0, 0.0, // row 0
            0.0, 1.0, 0.0, 0.0, // row 1
        ];
        let b = vec![
            1.0f32, 0.0, // row 0
            0.0, 1.0, // row 1
            0.0, 0.0, // row 2
            0.0, 0.0, // row 3
        ];

        let make_delta = || LoraDelta::new(a.clone(), b.clone(), rank, in_dim, out_dim);

        let mut adapter1 = TestAdapter::new(rank, alpha);
        adapter1.add_delta(TargetModule::QueryProj, 0, make_delta());

        let mut adapter2 = TestAdapter::new(rank, alpha);
        adapter2.add_delta(TargetModule::QueryProj, 0, make_delta());

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter1));
        stack.push_adapter(Arc::new(adapter2));

        let x = vec![1.0f32, 2.0, 3.0, 4.0];
        let combined = stack
            .applied_delta(TargetModule::QueryProj, 0, &x)
            .expect("two adapters must produce a result");

        // Compute the single-adapter expectation, then double it. `a` and `b`
        // are no longer borrowed by `make_delta`, so they can be moved here.
        let single = LoraDelta::new(a, b, rank, in_dim, out_dim).apply(&x, alpha / rank as f32);
        for (c, s) in combined.iter().zip(single.iter()) {
            let expected = s * 2.0;
            assert!(
                (c - expected).abs() < 1e-5,
                "combined={c} expected={expected}"
            );
        }
    }

    /// An adapter that does not cover the requested target contributes nothing;
    /// with no other adapters the result is `None`.
    #[test]
    fn adapter_not_covering_target_is_skipped() {
        let mut adapter = TestAdapter::new(2, 2.0);
        // only covers KeyProj layer 0
        adapter.add_delta(
            TargetModule::KeyProj,
            0,
            LoraDelta::new(vec![1.0; 4], vec![1.0; 4], 2, 2, 2),
        );

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter));

        // Ask for QueryProj — not covered.
        let result = stack.applied_delta(TargetModule::QueryProj, 0, &[1.0f32, 1.0]);
        assert!(result.is_none(), "uncovered target must return None");
    }

    /// Pushing trait-object adapters grows the adapter list one entry at a time.
    #[test]
    fn lora_stack_stores_adapters() {
        let mut stack = LoraStack::new();
        let mut a1 = TestAdapter::new(4, 4.0);
        a1.add_delta(
            TargetModule::ValueProj,
            0,
            LoraDelta::new(vec![0.0; 16], vec![0.0; 16], 4, 4, 4),
        );
        stack.push_adapter(Arc::new(a1));
        assert_eq!(stack.adapters().len(), 1, "one adapter pushed");

        let mut a2 = TestAdapter::new(8, 8.0);
        a2.add_delta(
            TargetModule::ValueProj,
            1,
            LoraDelta::new(vec![0.0; 64], vec![0.0; 64], 8, 8, 8),
        );
        stack.push_adapter(Arc::new(a2));
        assert_eq!(stack.adapters().len(), 2, "two adapters pushed");
    }

    /// `len` counts legacy entries only, while `is_empty` and `clear` cover both lists.
    #[test]
    fn len_counts_legacy_only_but_is_empty_and_clear_cover_both_lists() {
        let mut stack = LoraStack::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);

        stack.push_adapter(Arc::new(TestAdapter::new(2, 2.0)));
        assert_eq!(stack.len(), 0, "len counts legacy entries only");
        assert!(!stack.is_empty(), "is_empty considers the trait-object list");

        stack.push(make_loaded_lora("blk.0.attn_q.weight", 2, 2, 1, 1.0), 0.25);
        assert_eq!(stack.len(), 1);
        assert_eq!(stack.entries().len(), 1);
        assert_eq!(stack.adapters().len(), 1);

        stack.clear();
        assert!(stack.is_empty());
        assert!(stack.entries().is_empty());
        assert!(stack.adapters().is_empty());
    }

    /// `pop` removes the most recently pushed legacy entry together with its scale.
    #[test]
    fn pop_returns_last_legacy_entry_with_its_scale() {
        let mut stack = LoraStack::new();
        assert!(stack.pop().is_none(), "pop on an empty stack yields None");

        stack.push(make_loaded_lora("blk.0.attn_q.weight", 2, 2, 1, 1.0), 0.25);
        stack.push(make_loaded_lora("blk.0.attn_q.weight", 2, 2, 1, 1.0), 0.75);

        let (_, scale) = stack.pop().expect("stack has two entries");
        assert!((scale - 0.75).abs() < f32::EPSILON, "got scale {scale}");
        assert_eq!(stack.len(), 1);
    }

    /// The two lists are independent: `apply` ignores trait-object adapters and
    /// `applied_delta` ignores legacy entries.
    #[test]
    fn dispatch_methods_consult_only_their_own_list() {
        let mut adapter = TestAdapter::new(2, 2.0);
        adapter.add_delta(
            TargetModule::QueryProj,
            0,
            LoraDelta::new(vec![1.0; 4], vec![1.0; 4], 2, 2, 2),
        );
        let mut trait_only = LoraStack::new();
        trait_only.push_adapter(Arc::new(adapter));
        let legacy_delta = trait_only
            .apply("blk.0.attn_q.weight", &[1.0f32, 1.0], 2)
            .expect("apply ok");
        assert_eq!(legacy_delta, vec![0.0f32; 2]);

        let mut legacy_only = LoraStack::new();
        legacy_only.push(make_loaded_lora("blk.0.attn_q.weight", 2, 2, 1, 1.0), 1.0);
        assert!(legacy_only
            .applied_delta(TargetModule::QueryProj, 0, &[1.0f32, 1.0])
            .is_none());
    }

    // ─── Legacy apply() tests (preserved from mod.rs) ────────────────────────

    fn make_loaded_lora(
        name: &str,
        in_f: usize,
        out_f: usize,
        rank: usize,
        fill: f32,
    ) -> Arc<LoadedLora> {
        use oxillama_quant::LoraAdapter;
        let scale = 1.0_f32;
        let adapter = Arc::new(
            LoraAdapter::new(
                vec![fill; rank * in_f],
                vec![fill; out_f * rank],
                rank,
                scale,
                in_f,
                out_f,
            )
            .expect("valid adapter"),
        );
        let mut adapters = std::collections::HashMap::new();
        adapters.insert(name.to_string(), adapter);
        Arc::new(LoadedLora {
            adapters,
            rank,
            alpha: rank as f32,
        })
    }

    #[test]
    fn empty_legacy_stack_returns_zeros() {
        let stack = LoraStack::new();
        let result = stack
            .apply("blk.0.attn_q.weight", &[1.0, 2.0, 3.0, 4.0], 4)
            .expect("apply ok");
        assert_eq!(result, vec![0.0f32; 4]);
    }

    /// Legacy entries that do not define the requested tensor are skipped.
    #[test]
    fn legacy_apply_skips_adapter_without_tensor() {
        let mut stack = LoraStack::new();
        stack.push(make_loaded_lora("blk.0.attn_q.weight", 4, 4, 2, 0.5), 1.0);
        let result = stack
            .apply("blk.0.attn_k.weight", &[1.0f32; 4], 4)
            .expect("apply ok");
        assert_eq!(result, vec![0.0f32; 4]);
    }

    /// The returned delta always has `out_features` entries: wider adapters are
    /// truncated and narrower ones leave the tail at zero.
    #[test]
    fn legacy_apply_respects_requested_out_features() {
        let mut stack = LoraStack::new();
        stack.push(make_loaded_lora("blk.0.attn_q.weight", 4, 4, 2, 0.5), 1.0);
        let input = vec![1.0f32; 4];

        let full = stack
            .apply("blk.0.attn_q.weight", &input, 4)
            .expect("apply ok");
        let narrow = stack
            .apply("blk.0.attn_q.weight", &input, 2)
            .expect("apply ok");
        let wide = stack
            .apply("blk.0.attn_q.weight", &input, 6)
            .expect("apply ok");

        assert_eq!(narrow.len(), 2);
        assert_eq!(narrow.as_slice(), &full[..2]);
        assert_eq!(wide.len(), 6);
        assert_eq!(&wide[..4], full.as_slice());
        assert_eq!(&wide[4..], &[0.0f32, 0.0]);
    }

    #[test]
    fn legacy_stacked_adapters_add_linearly() {
        let in_f = 4;
        let out_f = 4;
        let rank = 2;
        let lora = make_loaded_lora("blk.0.attn_q.weight", in_f, out_f, rank, 0.5);

        let mut stack_double = LoraStack::new();
        stack_double.push(Arc::clone(&lora), 0.5);
        stack_double.push(Arc::clone(&lora), 0.5);

        let mut stack_single = LoraStack::new();
        stack_single.push(Arc::clone(&lora), 1.0);

        let input = vec![1.0f32; in_f];
        let double = stack_double
            .apply("blk.0.attn_q.weight", &input, out_f)
            .expect("apply ok");
        let single = stack_single
            .apply("blk.0.attn_q.weight", &input, out_f)
            .expect("apply ok");

        for (a, b) in double.iter().zip(single.iter()) {
            assert!((a - b).abs() < 1e-5, "double={a} single={b}");
        }
    }
}