1use std::sync::Arc;
4
5use super::adapter::{LoraAdapterTrait, TargetModule};
6use super::loader::LoadedLora;
7use crate::error::ArchResult;
8
9#[derive(Default)]
22pub struct LoraStack {
23 entries: Vec<(Arc<LoadedLora>, f32)>,
25 adapter_list: Vec<Arc<dyn LoraAdapterTrait>>,
27}
28
29impl std::fmt::Debug for LoraStack {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("LoraStack")
32 .field("entries_len", &self.entries.len())
33 .field("adapter_list_len", &self.adapter_list.len())
34 .finish()
35 }
36}
37
38impl Clone for LoraStack {
39 fn clone(&self) -> Self {
40 Self {
41 entries: self.entries.clone(),
42 adapter_list: self.adapter_list.clone(),
44 }
45 }
46}
47
48impl LoraStack {
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn push(&mut self, adapter: Arc<LoadedLora>, scale: f32) {
60 self.entries.push((adapter, scale));
61 }
62
63 pub fn pop(&mut self) -> Option<(Arc<LoadedLora>, f32)> {
67 self.entries.pop()
68 }
69
70 pub fn clear(&mut self) {
72 self.entries.clear();
73 self.adapter_list.clear();
74 }
75
76 pub fn len(&self) -> usize {
78 self.entries.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.entries.is_empty() && self.adapter_list.is_empty()
84 }
85
86 pub fn entries(&self) -> &[(Arc<LoadedLora>, f32)] {
88 &self.entries
89 }
90
91 pub fn adapters(&self) -> &[Arc<dyn LoraAdapterTrait>] {
93 &self.adapter_list
94 }
95
96 pub fn apply(
107 &self,
108 tensor_name: &str,
109 input: &[f32],
110 out_features: usize,
111 ) -> ArchResult<Vec<f32>> {
112 let mut delta = vec![0.0f32; out_features];
113 for (lora, stack_scale) in &self.entries {
114 let Some(adapter) = lora.get(tensor_name) else {
115 continue;
116 };
117 let rank = adapter.rank;
118 let in_f = adapter.in_features;
119 let out_f = adapter.out_features.min(out_features);
120
121 let mut r_vec = vec![0.0f32; rank];
123 for (i, r) in r_vec.iter_mut().enumerate() {
124 let row = &adapter.a[i * in_f..(i + 1) * in_f];
125 *r = row
126 .iter()
127 .zip(input.iter().take(in_f))
128 .map(|(&a, &x)| a * x)
129 .sum();
130 }
131
132 let combined = adapter.scale * stack_scale;
134 for (i, d) in delta.iter_mut().enumerate().take(out_f) {
135 let row = &adapter.b[i * rank..(i + 1) * rank];
136 let v: f32 = row.iter().zip(r_vec.iter()).map(|(&b, &r)| b * r).sum();
137 *d += v * combined;
138 }
139 }
140 Ok(delta)
141 }
142
143 pub fn push_adapter(&mut self, adapter: Arc<dyn LoraAdapterTrait>) {
147 self.adapter_list.push(adapter);
148 }
149
150 pub fn applied_delta(
157 &self,
158 target: TargetModule,
159 layer: usize,
160 input: &[f32],
161 ) -> Option<Vec<f32>> {
162 let mut result: Option<Vec<f32>> = None;
163 for adapter in &self.adapter_list {
164 let scale = adapter.alpha() / adapter.rank().max(1) as f32;
165 if let Some(delta) = adapter.delta(target, layer) {
166 let contribution = delta.apply(input, scale);
167 match result {
168 None => result = Some(contribution),
169 Some(ref mut acc) => {
170 for (a, c) in acc.iter_mut().zip(contribution.iter()) {
171 *a += c;
172 }
173 }
174 }
175 }
176 }
177 result
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lora::adapter::{LoraAdapterTrait, LoraDelta, TargetModule};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Minimal in-memory [`LoraAdapterTrait`] implementation used to drive
    /// `LoraStack::applied_delta`.
    struct TestAdapter {
        rank: usize,
        alpha: f32,
        /// Deltas keyed by `(target_to_u32(target), layer)`.
        deltas: HashMap<(u32, usize), LoraDelta>,
        /// Distinct targets registered so far, in insertion order.
        modules: Vec<TargetModule>,
    }

    impl TestAdapter {
        /// Creates an adapter with the given rank and alpha and no deltas.
        fn new(rank: usize, alpha: f32) -> Self {
            Self {
                rank,
                alpha,
                deltas: HashMap::new(),
                modules: Vec::new(),
            }
        }

        /// Registers `delta` for `(target, layer)`, replacing any previous
        /// delta stored under the same key.
        fn add_delta(&mut self, target: TargetModule, layer: usize, delta: LoraDelta) {
            let key = (target_to_u32(target), layer);
            if !self.modules.contains(&target) {
                self.modules.push(target);
            }
            self.deltas.insert(key, delta);
        }
    }

    /// Maps a target module to a numeric id usable as a hash-map key.
    /// Custom ids are offset by 100 to keep them apart from the built-in ones.
    fn target_to_u32(t: TargetModule) -> u32 {
        match t {
            TargetModule::QueryProj => 0,
            TargetModule::KeyProj => 1,
            TargetModule::ValueProj => 2,
            TargetModule::OutputProj => 3,
            TargetModule::GateProj => 4,
            TargetModule::UpProj => 5,
            TargetModule::DownProj => 6,
            TargetModule::Custom(id) => 100 + id,
        }
    }

    impl LoraAdapterTrait for TestAdapter {
        fn rank(&self) -> usize {
            self.rank
        }
        fn alpha(&self) -> f32 {
            self.alpha
        }
        fn target_modules(&self) -> &[TargetModule] {
            &self.modules
        }
        fn delta(&self, target: TargetModule, layer: usize) -> Option<&LoraDelta> {
            let key = (target_to_u32(target), layer);
            self.deltas.get(&key)
        }
    }

    #[test]
    fn empty_stack_applied_delta_returns_none() {
        let stack = LoraStack::new();
        let result = stack.applied_delta(TargetModule::QueryProj, 0, &[1.0f32, 2.0, 3.0]);
        assert!(result.is_none(), "empty stack must return None");
    }

    #[test]
    fn single_lora_identity_matches_input() {
        let rank = 4;
        let in_dim = 4;
        let out_dim = 4;

        // Both factors are identity matrices and alpha == rank, so the scale
        // is 1 and the delta must reproduce the input.
        let mut a = vec![0.0f32; rank * in_dim];
        let mut b = vec![0.0f32; out_dim * rank];
        for i in 0..rank {
            a[i * in_dim + i] = 1.0;
            b[i * rank + i] = 1.0;
        }
        let delta = LoraDelta::new(a, b, rank, in_dim, out_dim);
        let alpha = rank as f32;

        let mut adapter = TestAdapter::new(rank, alpha);
        adapter.add_delta(TargetModule::QueryProj, 0, delta);

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter));

        let x = vec![1.0f32, 2.0, 3.0, 4.0];
        let result = stack
            .applied_delta(TargetModule::QueryProj, 0, &x)
            .expect("single adapter must produce a result");

        // `zip` stops at the shorter side, so without this check an empty or
        // truncated result would pass the loop below vacuously.
        assert_eq!(
            result.len(),
            x.len(),
            "delta length must match the input length"
        );
        for (r, xi) in result.iter().zip(x.iter()) {
            assert!((r - xi).abs() < 1e-5, "expected {xi} got {r}");
        }
    }

    #[test]
    fn two_loras_compose_additively() {
        let rank = 2;
        let in_dim = 4;
        let out_dim = 4;
        let alpha = 2.0f32;
        let a = vec![
            1.0f32, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0,
        ];
        let b = vec![
            1.0f32, 0.0, //
            0.0, 1.0, //
            0.0, 0.0, //
            0.0, 0.0,
        ];

        let make_delta = || LoraDelta::new(a.clone(), b.clone(), rank, in_dim, out_dim);

        let mut adapter1 = TestAdapter::new(rank, alpha);
        adapter1.add_delta(TargetModule::QueryProj, 0, make_delta());

        let mut adapter2 = TestAdapter::new(rank, alpha);
        adapter2.add_delta(TargetModule::QueryProj, 0, make_delta());

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter1));
        stack.push_adapter(Arc::new(adapter2));

        let x = vec![1.0f32, 2.0, 3.0, 4.0];
        let combined = stack
            .applied_delta(TargetModule::QueryProj, 0, &x)
            .expect("two adapters must produce a result");

        let single = make_delta().apply(&x, alpha / rank as f32);

        // Guard the zip-based comparison against vacuous success.
        assert!(!single.is_empty(), "reference delta must not be empty");
        assert_eq!(
            combined.len(),
            single.len(),
            "combined delta must keep the single-adapter length"
        );
        for (c, s) in combined.iter().zip(single.iter()) {
            let expected = s * 2.0;
            assert!(
                (c - expected).abs() < 1e-5,
                "combined={c} expected={expected}"
            );
        }
    }

    #[test]
    fn adapter_not_covering_target_is_skipped() {
        let mut adapter = TestAdapter::new(2, 2.0);
        adapter.add_delta(
            TargetModule::KeyProj,
            0,
            LoraDelta::new(vec![1.0; 4], vec![1.0; 4], 2, 2, 2),
        );

        let mut stack = LoraStack::new();
        stack.push_adapter(Arc::new(adapter));

        let result = stack.applied_delta(TargetModule::QueryProj, 0, &[1.0f32, 1.0]);
        assert!(result.is_none(), "uncovered target must return None");
    }

    #[test]
    fn lora_stack_stores_adapters() {
        let mut stack = LoraStack::new();
        let mut a1 = TestAdapter::new(4, 4.0);
        a1.add_delta(
            TargetModule::ValueProj,
            0,
            LoraDelta::new(vec![0.0; 16], vec![0.0; 16], 4, 4, 4),
        );
        stack.push_adapter(Arc::new(a1));
        // Go through the public accessor so it is covered as well.
        assert_eq!(stack.adapters().len(), 1, "one adapter pushed");

        let mut a2 = TestAdapter::new(8, 8.0);
        a2.add_delta(
            TargetModule::ValueProj,
            1,
            LoraDelta::new(vec![0.0; 64], vec![0.0; 64], 8, 8, 8),
        );
        stack.push_adapter(Arc::new(a2));
        assert_eq!(stack.adapters().len(), 2, "two adapters pushed");
    }

    /// Builds a [`LoadedLora`] holding one adapter registered under `name`,
    /// with both factor buffers filled with `fill` and an adapter scale of 1.
    fn make_loaded_lora(
        name: &str,
        in_f: usize,
        out_f: usize,
        rank: usize,
        fill: f32,
    ) -> Arc<LoadedLora> {
        use oxillama_quant::LoraAdapter;
        let scale = 1.0_f32;
        let adapter = Arc::new(
            LoraAdapter::new(
                vec![fill; rank * in_f],
                vec![fill; out_f * rank],
                rank,
                scale,
                in_f,
                out_f,
            )
            .expect("valid adapter"),
        );
        let mut adapters = HashMap::new();
        adapters.insert(name.to_string(), adapter);
        Arc::new(LoadedLora {
            adapters,
            rank,
            alpha: rank as f32,
        })
    }

    #[test]
    fn empty_legacy_stack_returns_zeros() {
        let stack = LoraStack::new();
        let result = stack
            .apply("blk.0.attn_q.weight", &[1.0, 2.0, 3.0, 4.0], 4)
            .expect("apply ok");
        assert_eq!(result, vec![0.0f32; 4]);
    }

    #[test]
    fn legacy_stacked_adapters_add_linearly() {
        let in_f = 4;
        let out_f = 4;
        let rank = 2;
        let lora = make_loaded_lora("blk.0.attn_q.weight", in_f, out_f, rank, 0.5);

        let mut stack_double = LoraStack::new();
        stack_double.push(Arc::clone(&lora), 0.5);
        stack_double.push(Arc::clone(&lora), 0.5);

        let mut stack_single = LoraStack::new();
        stack_single.push(Arc::clone(&lora), 1.0);

        let input = vec![1.0f32; in_f];
        let double = stack_double
            .apply("blk.0.attn_q.weight", &input, out_f)
            .expect("apply ok");
        let single = stack_single
            .apply("blk.0.attn_q.weight", &input, out_f)
            .expect("apply ok");

        // If the tensor lookup missed, both deltas would be all zeros and the
        // comparison below would pass without exercising any adapter maths.
        assert!(
            single.iter().any(|v| *v != 0.0),
            "adapter must contribute a non-zero delta"
        );
        assert_eq!(double.len(), single.len(), "delta lengths must agree");
        for (a, b) in double.iter().zip(single.iter()) {
            assert!((a - b).abs() < 1e-5, "double={a} single={b}");
        }
    }
}