Skip to main content

oxillama_arch/lora/
loader.rs

1//! LoRA adapter loading from GGUF files.
2//!
3//! LoRA adapter GGUF files follow the same binary format as model GGUF files,
4//! but contain only the adapter weight matrices (A and B per adapted layer),
5//! plus metadata keys for rank and alpha:
6//!
7//! | Key                  | Fallback key              | Type  | Description                   |
8//! |----------------------|---------------------------|-------|-------------------------------|
9//! | `lora.r`             | `adapter.lora.r`          | u32   | Low-rank dimension `r`        |
10//! | `lora.alpha`         | `adapter.lora.alpha`      | f32   | LoRA alpha (scale numerator)  |
11//!
12//! ## Tensor naming convention (llama.cpp-compatible)
13//!
14//! Each adapted linear layer provides two tensors:
15//!
16//! ```text
17//! blk.{i}.attn_q.weight.lora_a      blk.{i}.attn_q.weight.lora_b
18//! blk.{i}.attn_k.weight.lora_a      blk.{i}.attn_k.weight.lora_b
19//! blk.{i}.attn_v.weight.lora_a      blk.{i}.attn_v.weight.lora_b
20//! blk.{i}.attn_output.weight.lora_a blk.{i}.attn_output.weight.lora_b
21//! blk.{i}.ffn_gate.weight.lora_a    blk.{i}.ffn_gate.weight.lora_b
22//! blk.{i}.ffn_up.weight.lora_a      blk.{i}.ffn_up.weight.lora_b
23//! blk.{i}.ffn_down.weight.lora_a    blk.{i}.ffn_down.weight.lora_b
24//! ```
25//!
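//! A tensors have logical shape `[rank, in_features]`; B tensors have logical
//! shape `[out_features, rank]`. GGUF records dimensions innermost-first, so the
//! stored dimension lists read `[in_features, rank]` and `[rank, out_features]`.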
28
29use std::collections::HashMap;
30use std::sync::Arc;
31
32use oxillama_gguf::{GgufModel, GgufTensorType, TensorInfo};
33use oxillama_quant::{KernelDispatcher, LoraAdapter, QuantError};
34
35use crate::error::{ArchError, ArchResult};
36
37pub(super) const LORA_A_SUFFIX: &str = ".lora_a";
38pub(super) const LORA_B_SUFFIX: &str = ".lora_b";
39
40/// A fully loaded LoRA adapter, mapping tensor base names to their adapters.
41///
42/// The key in [`adapters`](Self::adapters) is the tensor base name without the
43/// `.lora_a` / `.lora_b` suffix (e.g. `"blk.0.attn_q.weight"`).
44///
45/// After loading, attach adapters to `QuantLinear` layers via
46/// [`LoadedLora::get`] + `QuantLinear::set_lora`.
47#[derive(Debug)]
48pub struct LoadedLora {
49    /// Map from tensor base name → LoRA adapter.
50    pub adapters: HashMap<String, Arc<LoraAdapter>>,
51    /// The LoRA rank used throughout this adapter file.
52    pub rank: usize,
53    /// The alpha value (scale numerator).
54    pub alpha: f32,
55}
56
57impl LoadedLora {
58    /// Load a LoRA adapter from a GGUF file on disk.
59    ///
60    /// # Errors
61    ///
62    /// Returns [`ArchError::Gguf`] if the file cannot be parsed, or
63    /// [`ArchError::Quant`] if dequantization of any adapter tensor fails.
64    pub fn load(path: &str) -> ArchResult<Self> {
65        let model = GgufModel::load(path)?;
66        Self::from_gguf(&model)
67    }
68
69    /// Load a LoRA adapter from an already-loaded [`GgufModel`].
70    ///
71    /// This is the primary construction path, separated from file I/O for
72    /// testability.
73    pub fn from_gguf(model: &GgufModel) -> ArchResult<Self> {
74        // --- Extract rank --------------------------------------------------
75        let rank = model
76            .file
77            .metadata
78            .get("lora.r")
79            .and_then(|v| v.as_u32())
80            .or_else(|| {
81                model
82                    .file
83                    .metadata
84                    .get("adapter.lora.r")
85                    .and_then(|v| v.as_u32())
86            })
87            .map(|v| v as usize)
88            .unwrap_or(8);
89
90        // --- Extract alpha -------------------------------------------------
91        let alpha = model
92            .file
93            .metadata
94            .get("lora.alpha")
95            .and_then(|v| v.as_f32())
96            .or_else(|| {
97                model
98                    .file
99                    .metadata
100                    .get("adapter.lora.alpha")
101                    .and_then(|v| v.as_f32())
102            })
103            .unwrap_or(rank as f32);
104
105        let scale = alpha / rank.max(1) as f32;
106        let dispatcher = KernelDispatcher::new();
107
108        // --- Find and pair .lora_a / .lora_b tensors -----------------------
109        let tensor_names: Vec<String> = model.file.tensors.names().cloned().collect();
110        let mut adapters: HashMap<String, Arc<LoraAdapter>> = HashMap::new();
111
112        for name in &tensor_names {
113            if !name.ends_with(LORA_A_SUFFIX) {
114                continue;
115            }
116            let base = &name[..name.len() - LORA_A_SUFFIX.len()];
117            let b_name = format!("{base}{LORA_B_SUFFIX}");
118
119            if !model.file.tensors.contains(&b_name) {
120                tracing::warn!(
121                    tensor = %name,
122                    "LoRA tensor has no matching .lora_b partner; skipping"
123                );
124                continue;
125            }
126
127            let a_info = model
128                .file
129                .tensors
130                .get(name)
131                .map_err(|_| ArchError::MissingTensor { name: name.clone() })?;
132            let a_data = model.tensor_data(name)?;
133            let a_f32 = dequant_tensor_to_f32(a_info, a_data, &dispatcher)?;
134
135            let (rank_actual, in_features) = shape_to_rank_in(a_info, rank, a_f32.len());
136
137            let b_info = model
138                .file
139                .tensors
140                .get(&b_name)
141                .map_err(|_| ArchError::MissingTensor {
142                    name: b_name.clone(),
143                })?;
144            let b_data = model.tensor_data(&b_name)?;
145            let b_f32 = dequant_tensor_to_f32(b_info, b_data, &dispatcher)?;
146
147            let out_features = b_f32.len().checked_div(rank_actual).unwrap_or(0);
148
149            let adapter =
150                LoraAdapter::new(a_f32, b_f32, rank_actual, scale, in_features, out_features)
151                    .map_err(ArchError::Quant)?;
152
153            adapters.insert(base.to_string(), Arc::new(adapter));
154        }
155
156        tracing::debug!(
157            rank = rank,
158            alpha = alpha,
159            adapters = adapters.len(),
160            "LoRA adapter loaded from GGUF"
161        );
162
163        Ok(Self {
164            adapters,
165            rank,
166            alpha,
167        })
168    }
169
170    /// Look up the adapter for a named linear weight tensor.
171    ///
172    /// Returns `None` if this LoRA file does not patch the named tensor.
173    pub fn get(&self, tensor_name: &str) -> Option<Arc<LoraAdapter>> {
174        self.adapters.get(tensor_name).cloned()
175    }
176
177    /// Number of adapted layers in this adapter file.
178    pub fn num_adapters(&self) -> usize {
179        self.adapters.len()
180    }
181}
182
183// ─── Internal helpers ─────────────────────────────────────────────────────────
184
185/// Infer `(rank_actual, in_features)` from a LoRA-A tensor's shape metadata.
186fn shape_to_rank_in(info: &TensorInfo, hint_rank: usize, n_elements: usize) -> (usize, usize) {
187    match info.dimensions.as_slice() {
188        [in_f, r] => (*r as usize, *in_f as usize),
189        [total] => {
190            let r = hint_rank.max(1);
191            let in_f = (*total as usize) / r;
192            (r, in_f)
193        }
194        _ => {
195            let r = hint_rank.max(1);
196            (r, n_elements / r)
197        }
198    }
199}
200
201/// Dequantize arbitrary-typed tensor data to a Vec<f32>.
202pub(crate) fn dequant_tensor_to_f32(
203    info: &TensorInfo,
204    data: &[u8],
205    dispatcher: &KernelDispatcher,
206) -> ArchResult<Vec<f32>> {
207    let n_elements = info.n_elements() as usize;
208
209    if info.tensor_type == GgufTensorType::F32 {
210        let mut out = vec![0.0f32; n_elements];
211        for (i, chunk) in data.chunks_exact(4).enumerate().take(n_elements) {
212            out[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
213        }
214        return Ok(out);
215    }
216
217    if info.tensor_type == GgufTensorType::F16 {
218        let mut out = vec![0.0f32; n_elements];
219        for (i, chunk) in data.chunks_exact(2).enumerate().take(n_elements) {
220            let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
221            out[i] = half::f16::from_bits(bits).to_f32();
222        }
223        return Ok(out);
224    }
225
226    let kernel = dispatcher
227        .get_kernel(info.tensor_type)
228        .map_err(ArchError::Quant)?;
229    let block_size = kernel.block_size();
230    let block_bytes = kernel.block_bytes();
231
232    if block_size == 0 || block_bytes == 0 {
233        return Err(ArchError::Quant(QuantError::UnsupportedType {
234            quant_type: format!("{:?}", info.tensor_type),
235        }));
236    }
237
238    let n_blocks = n_elements.div_ceil(block_size);
239    let mut out = vec![0.0f32; n_elements];
240
241    for b in 0..n_blocks {
242        let block_start = b * block_bytes;
243        let out_start = b * block_size;
244        let block_end = (block_start + block_bytes).min(data.len());
245        let out_end = (out_start + block_size).min(n_elements);
246
247        if block_end <= block_start {
248            break;
249        }
250
251        kernel
252            .dequant_block(&data[block_start..block_end], &mut out[out_start..out_end])
253            .map_err(ArchError::Quant)?;
254    }
255
256    Ok(out)
257}
258
259// ─── Tests ────────────────────────────────────────────────────────────────────
260
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_gguf::test_utils::{build_minimal_llama_gguf, build_minimal_lora_gguf};

    /// Layers the synthetic LoRA fixture is expected to adapt.
    const FIXTURE_LAYERS: [&str; 3] = [
        "blk.0.attn_q.weight",
        "blk.0.attn_v.weight",
        "blk.0.ffn_gate.weight",
    ];

    /// Parse the synthetic LoRA GGUF fixture and load it as an adapter set.
    fn fixture_lora() -> LoadedLora {
        let model =
            GgufModel::from_bytes(build_minimal_lora_gguf()).expect("test: parse lora gguf");
        LoadedLora::from_gguf(&model).expect("test: load lora from gguf")
    }

    /// Parse the synthetic base-model GGUF fixture and run it through the
    /// LoRA loader.
    fn fixture_base() -> LoadedLora {
        let model = GgufModel::from_bytes(build_minimal_llama_gguf())
            .expect("test: parse base model gguf");
        LoadedLora::from_gguf(&model).expect("test: load from base model")
    }

    /// A `LoadedLora` with no adapters and the given rank/alpha.
    fn empty_lora(rank: usize, alpha: f32) -> LoadedLora {
        LoadedLora {
            adapters: HashMap::new(),
            rank,
            alpha,
        }
    }

    #[test]
    fn test_loaded_lora_empty_construction() {
        let lora = empty_lora(8, 8.0);
        assert_eq!(lora.num_adapters(), 0);
        assert_eq!(lora.rank, 8);
        assert!(lora.get("blk.0.attn_q.weight").is_none());
    }

    #[test]
    fn test_shape_to_rank_in_2d() {
        let info = TensorInfo {
            name: "test.lora_a".into(),
            n_dims: 2,
            dimensions: vec![64, 8],
            tensor_type: GgufTensorType::F32,
            offset: 0,
        };
        let (r, in_f) = shape_to_rank_in(&info, 8, 64 * 8);
        assert_eq!(r, 8, "rank should be 8 (dims[1])");
        assert_eq!(in_f, 64, "in_features should be 64 (dims[0])");
    }

    #[test]
    fn test_shape_to_rank_in_1d() {
        let info = TensorInfo {
            name: "test.lora_a".into(),
            n_dims: 1,
            dimensions: vec![128],
            tensor_type: GgufTensorType::F32,
            offset: 0,
        };
        let (r, in_f) = shape_to_rank_in(&info, 8, 128);
        assert_eq!(r, 8);
        assert_eq!(in_f, 16);
    }

    #[test]
    fn test_get_missing() {
        let lora = empty_lora(4, 4.0);
        assert!(lora.get("blk.99.ffn_gate.weight").is_none());
    }

    #[test]
    fn test_get_present() {
        let adapter =
            Arc::new(LoraAdapter::new(vec![1.0], vec![1.0], 1, 1.0, 1, 1).expect("valid"));
        let lora = LoadedLora {
            adapters: HashMap::from([("blk.0.attn_q.weight".to_string(), adapter)]),
            rank: 1,
            alpha: 1.0,
        };
        assert!(lora.get("blk.0.attn_q.weight").is_some());
    }

    #[test]
    fn test_loaded_lora_from_gguf_succeeds() {
        let lora = fixture_lora();
        assert!(lora.rank > 0, "rank must be positive");
        assert!(lora.alpha > 0.0, "alpha must be positive");
        assert!(!lora.adapters.is_empty(), "adapters map must not be empty");
    }

    #[test]
    fn test_loaded_lora_rank_matches_metadata() {
        assert_eq!(
            fixture_lora().rank,
            4,
            "rank should match lora.r=4 in synthetic GGUF"
        );
    }

    #[test]
    fn test_loaded_lora_alpha_matches_metadata() {
        let alpha = fixture_lora().alpha;
        assert!(
            (alpha - 8.0).abs() < 1e-5,
            "alpha should match lora.alpha=8.0, got {}",
            alpha
        );
    }

    #[test]
    fn test_loaded_lora_contains_expected_adapters() {
        let count = fixture_lora().adapters.len();
        assert_eq!(
            count,
            FIXTURE_LAYERS.len(),
            "expected 3 lora adapters (attn_q, attn_v, ffn_gate), got {}",
            count
        );
    }

    #[test]
    fn test_loaded_lora_get_returns_adapter() {
        assert!(
            fixture_lora().get("blk.0.attn_q.weight").is_some(),
            "expected to find adapter for blk.0.attn_q.weight"
        );
    }

    #[test]
    fn test_loaded_lora_get_missing_returns_none() {
        assert!(
            fixture_lora().get("nonexistent_layer").is_none(),
            "nonexistent layer should return None"
        );
    }

    #[test]
    fn test_loaded_lora_from_base_model_has_empty_adapters() {
        assert!(
            fixture_base().adapters.is_empty(),
            "base model (no .lora_a tensors) should yield zero adapters"
        );
    }

    #[test]
    fn test_loaded_lora_default_rank_when_missing() {
        assert_eq!(
            fixture_base().rank,
            8,
            "without lora.r key the default rank should be 8"
        );
    }

    #[test]
    fn test_loaded_lora_all_three_adapters_reachable() {
        let lora = fixture_lora();
        for layer_name in FIXTURE_LAYERS {
            assert!(
                lora.get(layer_name).is_some(),
                "adapter for '{layer_name}' must be reachable via get()"
            );
        }
    }

    #[test]
    fn test_loaded_lora_num_adapters_matches_len() {
        let lora = fixture_lora();
        assert_eq!(
            lora.num_adapters(),
            lora.adapters.len(),
            "num_adapters() must equal adapters.len()"
        );
    }
}