1use std::collections::HashMap;
30use std::sync::Arc;
31
32use oxillama_gguf::{GgufModel, GgufTensorType, TensorInfo};
33use oxillama_quant::{KernelDispatcher, LoraAdapter, QuantError};
34
35use crate::error::{ArchError, ArchResult};
36
/// Tensor-name suffix that marks the `A` tensor of a LoRA pair.
pub(super) const LORA_A_SUFFIX: &str = ".lora_a";
/// Tensor-name suffix that marks the `B` tensor of a LoRA pair.
pub(super) const LORA_B_SUFFIX: &str = ".lora_b";
39
/// A set of LoRA adapters loaded from a GGUF file, together with the
/// rank/alpha hyper-parameters read from its metadata.
#[derive(Debug)]
pub struct LoadedLora {
    /// Adapters keyed by base tensor name, i.e. the `.lora_a` tensor name with
    /// that suffix removed.
    pub adapters: HashMap<String, Arc<LoraAdapter>>,
    /// Rank read from GGUF metadata (`lora.r`, then `adapter.lora.r`);
    /// 8 when neither key is present.
    pub rank: usize,
    /// Alpha read from GGUF metadata (`lora.alpha`, then `adapter.lora.alpha`);
    /// equal to `rank` when neither key is present.
    pub alpha: f32,
}
56
57impl LoadedLora {
58 pub fn load(path: &str) -> ArchResult<Self> {
65 let model = GgufModel::load(path)?;
66 Self::from_gguf(&model)
67 }
68
69 pub fn from_gguf(model: &GgufModel) -> ArchResult<Self> {
74 let rank = model
76 .file
77 .metadata
78 .get("lora.r")
79 .and_then(|v| v.as_u32())
80 .or_else(|| {
81 model
82 .file
83 .metadata
84 .get("adapter.lora.r")
85 .and_then(|v| v.as_u32())
86 })
87 .map(|v| v as usize)
88 .unwrap_or(8);
89
90 let alpha = model
92 .file
93 .metadata
94 .get("lora.alpha")
95 .and_then(|v| v.as_f32())
96 .or_else(|| {
97 model
98 .file
99 .metadata
100 .get("adapter.lora.alpha")
101 .and_then(|v| v.as_f32())
102 })
103 .unwrap_or(rank as f32);
104
105 let scale = alpha / rank.max(1) as f32;
106 let dispatcher = KernelDispatcher::new();
107
108 let tensor_names: Vec<String> = model.file.tensors.names().cloned().collect();
110 let mut adapters: HashMap<String, Arc<LoraAdapter>> = HashMap::new();
111
112 for name in &tensor_names {
113 if !name.ends_with(LORA_A_SUFFIX) {
114 continue;
115 }
116 let base = &name[..name.len() - LORA_A_SUFFIX.len()];
117 let b_name = format!("{base}{LORA_B_SUFFIX}");
118
119 if !model.file.tensors.contains(&b_name) {
120 tracing::warn!(
121 tensor = %name,
122 "LoRA tensor has no matching .lora_b partner; skipping"
123 );
124 continue;
125 }
126
127 let a_info = model
128 .file
129 .tensors
130 .get(name)
131 .map_err(|_| ArchError::MissingTensor { name: name.clone() })?;
132 let a_data = model.tensor_data(name)?;
133 let a_f32 = dequant_tensor_to_f32(a_info, a_data, &dispatcher)?;
134
135 let (rank_actual, in_features) = shape_to_rank_in(a_info, rank, a_f32.len());
136
137 let b_info = model
138 .file
139 .tensors
140 .get(&b_name)
141 .map_err(|_| ArchError::MissingTensor {
142 name: b_name.clone(),
143 })?;
144 let b_data = model.tensor_data(&b_name)?;
145 let b_f32 = dequant_tensor_to_f32(b_info, b_data, &dispatcher)?;
146
147 let out_features = b_f32.len().checked_div(rank_actual).unwrap_or(0);
148
149 let adapter =
150 LoraAdapter::new(a_f32, b_f32, rank_actual, scale, in_features, out_features)
151 .map_err(ArchError::Quant)?;
152
153 adapters.insert(base.to_string(), Arc::new(adapter));
154 }
155
156 tracing::debug!(
157 rank = rank,
158 alpha = alpha,
159 adapters = adapters.len(),
160 "LoRA adapter loaded from GGUF"
161 );
162
163 Ok(Self {
164 adapters,
165 rank,
166 alpha,
167 })
168 }
169
170 pub fn get(&self, tensor_name: &str) -> Option<Arc<LoraAdapter>> {
174 self.adapters.get(tensor_name).cloned()
175 }
176
177 pub fn num_adapters(&self) -> usize {
179 self.adapters.len()
180 }
181}
182
183fn shape_to_rank_in(info: &TensorInfo, hint_rank: usize, n_elements: usize) -> (usize, usize) {
187 match info.dimensions.as_slice() {
188 [in_f, r] => (*r as usize, *in_f as usize),
189 [total] => {
190 let r = hint_rank.max(1);
191 let in_f = (*total as usize) / r;
192 (r, in_f)
193 }
194 _ => {
195 let r = hint_rank.max(1);
196 (r, n_elements / r)
197 }
198 }
199}
200
201pub(crate) fn dequant_tensor_to_f32(
203 info: &TensorInfo,
204 data: &[u8],
205 dispatcher: &KernelDispatcher,
206) -> ArchResult<Vec<f32>> {
207 let n_elements = info.n_elements() as usize;
208
209 if info.tensor_type == GgufTensorType::F32 {
210 let mut out = vec![0.0f32; n_elements];
211 for (i, chunk) in data.chunks_exact(4).enumerate().take(n_elements) {
212 out[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
213 }
214 return Ok(out);
215 }
216
217 if info.tensor_type == GgufTensorType::F16 {
218 let mut out = vec![0.0f32; n_elements];
219 for (i, chunk) in data.chunks_exact(2).enumerate().take(n_elements) {
220 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
221 out[i] = half::f16::from_bits(bits).to_f32();
222 }
223 return Ok(out);
224 }
225
226 let kernel = dispatcher
227 .get_kernel(info.tensor_type)
228 .map_err(ArchError::Quant)?;
229 let block_size = kernel.block_size();
230 let block_bytes = kernel.block_bytes();
231
232 if block_size == 0 || block_bytes == 0 {
233 return Err(ArchError::Quant(QuantError::UnsupportedType {
234 quant_type: format!("{:?}", info.tensor_type),
235 }));
236 }
237
238 let n_blocks = n_elements.div_ceil(block_size);
239 let mut out = vec![0.0f32; n_elements];
240
241 for b in 0..n_blocks {
242 let block_start = b * block_bytes;
243 let out_start = b * block_size;
244 let block_end = (block_start + block_bytes).min(data.len());
245 let out_end = (out_start + block_size).min(n_elements);
246
247 if block_end <= block_start {
248 break;
249 }
250
251 kernel
252 .dequant_block(&data[block_start..block_end], &mut out[out_start..out_end])
253 .map_err(ArchError::Quant)?;
254 }
255
256 Ok(out)
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LoadedLora` with no adapters and the given hyper-parameters.
    fn empty_lora(rank: usize, alpha: f32) -> LoadedLora {
        LoadedLora {
            adapters: HashMap::new(),
            rank,
            alpha,
        }
    }

    /// Loads the synthetic LoRA GGUF fixture through `from_gguf`.
    fn lora_from_fixture() -> LoadedLora {
        use oxillama_gguf::{test_utils::build_minimal_lora_gguf, GgufModel};
        let model =
            GgufModel::from_bytes(build_minimal_lora_gguf()).expect("test: parse lora gguf");
        LoadedLora::from_gguf(&model).expect("test: load lora from gguf")
    }

    /// Loads the synthetic base-model GGUF fixture through `from_gguf`.
    fn lora_from_base_model() -> LoadedLora {
        use oxillama_gguf::{test_utils::build_minimal_llama_gguf, GgufModel};
        let model = GgufModel::from_bytes(build_minimal_llama_gguf())
            .expect("test: parse base model gguf");
        LoadedLora::from_gguf(&model).expect("test: load from base model")
    }

    #[test]
    fn test_loaded_lora_empty_construction() {
        let lora = empty_lora(8, 8.0);
        assert_eq!(lora.num_adapters(), 0);
        assert_eq!(lora.rank, 8);
        assert!(lora.get("blk.0.attn_q.weight").is_none());
    }

    #[test]
    fn test_shape_to_rank_in_2d() {
        let info = TensorInfo {
            name: "test.lora_a".into(),
            n_dims: 2,
            dimensions: vec![64, 8],
            tensor_type: GgufTensorType::F32,
            offset: 0,
        };
        let (rank, in_features) = shape_to_rank_in(&info, 8, 64 * 8);
        assert_eq!(rank, 8, "rank should be 8 (dims[1])");
        assert_eq!(in_features, 64, "in_features should be 64 (dims[0])");
    }

    #[test]
    fn test_shape_to_rank_in_1d() {
        let info = TensorInfo {
            name: "test.lora_a".into(),
            n_dims: 1,
            dimensions: vec![128],
            tensor_type: GgufTensorType::F32,
            offset: 0,
        };
        let (rank, in_features) = shape_to_rank_in(&info, 8, 128);
        assert_eq!(rank, 8);
        assert_eq!(in_features, 16);
    }

    #[test]
    fn test_get_missing() {
        let lora = empty_lora(4, 4.0);
        assert!(lora.get("blk.99.ffn_gate.weight").is_none());
    }

    #[test]
    fn test_get_present() {
        let adapter =
            Arc::new(LoraAdapter::new(vec![1.0], vec![1.0], 1, 1.0, 1, 1).expect("valid"));
        let lora = LoadedLora {
            adapters: HashMap::from([("blk.0.attn_q.weight".to_string(), adapter)]),
            rank: 1,
            alpha: 1.0,
        };
        assert!(lora.get("blk.0.attn_q.weight").is_some());
    }

    #[test]
    fn test_loaded_lora_from_gguf_succeeds() {
        let lora = lora_from_fixture();
        assert!(lora.rank > 0, "rank must be positive");
        assert!(lora.alpha > 0.0, "alpha must be positive");
        assert!(!lora.adapters.is_empty(), "adapters map must not be empty");
    }

    #[test]
    fn test_loaded_lora_rank_matches_metadata() {
        let lora = lora_from_fixture();
        assert_eq!(lora.rank, 4, "rank should match lora.r=4 in synthetic GGUF");
    }

    #[test]
    fn test_loaded_lora_alpha_matches_metadata() {
        let lora = lora_from_fixture();
        assert!(
            (lora.alpha - 8.0).abs() < 1e-5,
            "alpha should match lora.alpha=8.0, got {}",
            lora.alpha
        );
    }

    #[test]
    fn test_loaded_lora_contains_expected_adapters() {
        let lora = lora_from_fixture();
        assert_eq!(
            lora.adapters.len(),
            3,
            "expected 3 lora adapters (attn_q, attn_v, ffn_gate), got {}",
            lora.adapters.len()
        );
    }

    #[test]
    fn test_loaded_lora_get_returns_adapter() {
        let lora = lora_from_fixture();
        assert!(
            lora.get("blk.0.attn_q.weight").is_some(),
            "expected to find adapter for blk.0.attn_q.weight"
        );
    }

    #[test]
    fn test_loaded_lora_get_missing_returns_none() {
        let lora = lora_from_fixture();
        assert!(
            lora.get("nonexistent_layer").is_none(),
            "nonexistent layer should return None"
        );
    }

    #[test]
    fn test_loaded_lora_from_base_model_has_empty_adapters() {
        let lora = lora_from_base_model();
        assert!(
            lora.adapters.is_empty(),
            "base model (no .lora_a tensors) should yield zero adapters"
        );
    }

    #[test]
    fn test_loaded_lora_default_rank_when_missing() {
        let lora = lora_from_base_model();
        assert_eq!(
            lora.rank, 8,
            "without lora.r key the default rank should be 8"
        );
    }

    #[test]
    fn test_loaded_lora_all_three_adapters_reachable() {
        let lora = lora_from_fixture();
        let expected_layers = [
            "blk.0.attn_q.weight",
            "blk.0.attn_v.weight",
            "blk.0.ffn_gate.weight",
        ];
        for layer_name in expected_layers {
            assert!(
                lora.get(layer_name).is_some(),
                "adapter for '{layer_name}' must be reachable via get()"
            );
        }
    }

    #[test]
    fn test_loaded_lora_num_adapters_matches_len() {
        let lora = lora_from_fixture();
        assert_eq!(
            lora.num_adapters(),
            lora.adapters.len(),
            "num_adapters() must equal adapters.len()"
        );
    }
}