use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use peft_rs::{Adapter, LoraConfig, LoraLayer};
use serde::{Deserialize, Serialize};

use crate::error::{QLoraError, Result};
use crate::quantization::{
    dequantize_nf4, quantize_nf4_with_config, ComputeDType, QuantizationConfig, QuantizedTensor,
};

fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "qlora-rs: CPU device in use. CUDA is the intended default; enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

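/// Configuration for QLoRA fine-tuning: a LoRA adapter spec plus NF4
/// quantization settings and the list of module-name substrings to adapt.
///
/// The example below is an illustrative sketch; the `qlora_rs` crate path in
/// the `use` line is an assumption, so the doctest is marked `ignore`.
///
/// ```ignore
/// use qlora_rs::QLoraConfig; // crate path assumed for illustration
///
/// // Defaults: r = 64, alpha = 16, NF4 with double quantization, BF16 compute.
/// let config = QLoraConfig::default();
/// assert_eq!(config.lora.r, 64);
///
/// // Or start from a preset and adjust public fields directly.
/// let mut config = QLoraConfig::preset_qv_bf16(32, 64);
/// config.cache_dequantized = true;
/// ```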
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QLoraConfig {
    pub lora: LoraConfig,
    pub quantization: QuantizationConfig,
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
    #[serde(default)]
    pub cache_dequantized: bool,
}

fn default_target_modules() -> Vec<String> {
    vec![
        "q_proj".into(),
        "k_proj".into(),
        "v_proj".into(),
        "o_proj".into(),
        "gate_proj".into(),
        "up_proj".into(),
        "down_proj".into(),
    ]
}

impl Default for QLoraConfig {
    fn default() -> Self {
        Self {
            lora: LoraConfig {
                r: 64,
                alpha: 16,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16,
                ..Default::default()
            },
            target_modules: default_target_modules(),
            cache_dequantized: false,
        }
    }
}

impl QLoraConfig {
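    /// Preset that targets every attention and MLP projection (`q_proj`,
    /// `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`) with
    /// NF4 double quantization, BF16 compute, and 0.05 LoRA dropout; `r` and
    /// `alpha` set the LoRA rank and scaling numerator.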
    #[must_use]
    pub fn preset_all_bf16(r: usize, alpha: usize) -> Self {
        Self {
            lora: LoraConfig {
                r,
                alpha,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16,
                ..Default::default()
            },
            target_modules: default_target_modules(),
            cache_dequantized: false,
        }
    }

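    /// Preset that adapts only `q_proj` and `v_proj`; quantization and dropout
    /// settings match `preset_all_bf16`.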
    #[must_use]
    pub fn preset_qv_bf16(r: usize, alpha: usize) -> Self {
        Self {
            lora: LoraConfig {
                r,
                alpha,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16,
                ..Default::default()
            },
            target_modules: vec!["q_proj".into(), "v_proj".into()],
            cache_dequantized: false,
        }
    }

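    /// Preset for inference: identical to `preset_all_bf16` except that
    /// `cache_dequantized` is enabled, so the dequantized weight stays
    /// resident and `forward` skips per-call NF4 dequantization.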
    #[must_use]
    pub fn preset_inference(r: usize, alpha: usize) -> Self {
        Self {
            cache_dequantized: true,
            ..Self::preset_all_bf16(r, alpha)
        }
    }

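    /// Returns `true` when `module_name` contains any configured target
    /// substring, e.g. `"model.layers.0.self_attn.q_proj"` matches the
    /// `"q_proj"` target while `"lm_head"` matches none.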
    #[must_use]
    pub fn is_target(&self, module_name: &str) -> bool {
        self.target_modules.iter().any(|t| module_name.contains(t))
    }

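    /// LoRA scaling factor, computed as `alpha / r`
    /// (e.g. `alpha = 16`, `r = 64` gives `0.25`).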
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn scale(&self) -> f64 {
        self.lora.alpha as f64 / self.lora.r as f64
    }

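    /// Validates the configuration before training.
    ///
    /// # Errors
    ///
    /// Returns `QLoraError::InvalidConfig` if the LoRA rank is zero or no
    /// target modules are configured. An FP16 compute dtype only logs a
    /// warning (BF16 is recommended) and does not fail validation.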
    pub fn validate_for_training(&self) -> Result<()> {
        if self.lora.r == 0 {
            return Err(QLoraError::InvalidConfig("LoRA rank must be > 0".into()));
        }
        if self.target_modules.is_empty() {
            return Err(QLoraError::InvalidConfig(
                "At least one target module required".into(),
            ));
        }
        if matches!(self.quantization.compute_dtype, ComputeDType::F16) {
            tracing::warn!(
                "FP16 compute dtype may cause training instability (20% failure rate). \
                 Consider using BF16 instead."
            );
        }
        Ok(())
    }
}

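/// A linear layer whose frozen base weight is stored as an NF4
/// `QuantizedTensor`, combined with a trainable `LoraLayer`. The base weight
/// is dequantized on demand (or served from `cached_weight` when caching is
/// enabled); only the LoRA parameters are counted as trainable.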
pub struct QuantizedLinear {
    quantized_weight: QuantizedTensor,
    cached_weight: Option<Tensor>,
    bias: Option<Tensor>,
    lora: LoraLayer,
    device: Device,
    config: QLoraConfig,
}

impl QuantizedLinear {
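    /// Quantizes an existing 2-D weight of shape `[out_features, in_features]`
    /// to NF4 and attaches a LoRA adapter created with
    /// `LoraLayer::new_with_zeros`.
    ///
    /// Illustrative sketch: the `qlora_rs` paths are assumed and the weight is
    /// random rather than pretrained, so the doctest is marked `ignore`.
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// use qlora_rs::{QLoraConfig, QuantizedLinear}; // paths assumed
    ///
    /// let device = Device::cuda_if_available(0).expect("device");
    /// let config = QLoraConfig::preset_all_bf16(64, 16);
    /// // In practice the weight would come from a pretrained checkpoint.
    /// let weight = Tensor::randn(0f32, 1f32, &[4096, 4096], &device).expect("weight");
    /// let layer = QuantizedLinear::from_weight(&weight, None, &config, &device).expect("layer");
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the weight is not 2-D, or if quantization or
    /// adapter construction fails.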
    pub fn from_weight(
        weight: &Tensor,
        bias: Option<Tensor>,
        config: &QLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        warn_cpu_fallback(device);
        let shape = weight.shape().dims();
        if shape.len() != 2 {
            return Err(QLoraError::InvalidConfig("weight must be 2D".into()));
        }
        let (out_features, in_features) = (shape[0], shape[1]);

        let quantized_weight = quantize_nf4_with_config(weight, &config.quantization)?;

        let cached_weight = if config.cache_dequantized {
            Some(dequantize_nf4(&quantized_weight, device)?)
        } else {
            None
        };

        let lora =
            LoraLayer::new_with_zeros(in_features, out_features, config.lora.clone(), device)?;

        Ok(Self {
            quantized_weight,
            cached_weight,
            bias,
            lora,
            device: device.clone(),
            config: config.clone(),
        })
    }

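    /// Like `from_weight`, but builds the LoRA adapter via `LoraLayer::new`
    /// with the given `candle_nn::VarBuilder` instead of zero initialization;
    /// the device is taken from `weight`.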
    pub fn from_weight_with_varbuilder(
        weight: &Tensor,
        bias: Option<Tensor>,
        config: &QLoraConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        let shape = weight.shape().dims();
        if shape.len() != 2 {
            return Err(QLoraError::InvalidConfig("weight must be 2D".into()));
        }
        let (out_features, in_features) = (shape[0], shape[1]);
        let device = weight.device();
        warn_cpu_fallback(device);

        let quantized_weight = quantize_nf4_with_config(weight, &config.quantization)?;

        let cached_weight = if config.cache_dequantized {
            Some(dequantize_nf4(&quantized_weight, device)?)
        } else {
            None
        };

        let lora = LoraLayer::new(in_features, out_features, config.lora.clone(), vb)?;

        Ok(Self {
            quantized_weight,
            cached_weight,
            bias,
            lora,
            device: device.clone(),
            config: config.clone(),
        })
    }

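    /// Creates a layer from an all-zeros F32 weight of shape
    /// `[out_features, in_features]`; mainly useful for tests and shape
    /// checks, since the base weight carries no pretrained information.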
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: &QLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        let weight = Tensor::zeros(&[out_features, in_features], DType::F32, device)?;
        Self::from_weight(&weight, None, config, device)
    }

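    /// Runs the layer: dequantizes the NF4 base weight (or reuses the cached
    /// copy), multiplies the input by the transposed weight, passes the input
    /// together with that base output to `LoraLayer::forward` for the adapter
    /// contribution, and broadcasts the bias if present. Accepts 2-D
    /// `[batch, in_features]` or 3-D `[batch, seq, in_features]` inputs.
    ///
    /// Illustrative sketch mirroring the unit test below; the `qlora_rs` path
    /// is assumed, so the doctest is marked `ignore`.
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    /// use qlora_rs::{QLoraConfig, QuantizedLinear}; // path assumed
    ///
    /// let device = Device::Cpu;
    /// let layer = QuantizedLinear::new(768, 768, &QLoraConfig::default(), &device).unwrap();
    /// let hidden = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
    /// let output = layer.forward(&hidden).unwrap();
    /// assert_eq!(output.dims(), &[1, 10, 768]);
    /// ```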
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let weight = if let Some(cached) = &self.cached_weight {
            cached.clone()
        } else {
            dequantize_nf4(&self.quantized_weight, &self.device)?
        };
        let weight_t = weight.t()?;

        let base_output = if input.dims().len() == 3 {
            let (batch, seq, in_features) = input.dims3()?;
            let reshaped = input.reshape(&[batch * seq, in_features])?;
            let out = reshaped.matmul(&weight_t)?;
            let out_features = weight_t.dim(1)?;
            out.reshape(&[batch, seq, out_features])?
        } else {
            input.matmul(&weight_t)?
        };

        let output = self.lora.forward(input, Some(&base_output))?;

        match &self.bias {
            Some(bias) => Ok(output.broadcast_add(bias)?),
            None => Ok(output),
        }
    }

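    /// Dequantizes the base weight once and keeps it resident so later
    /// `forward` calls skip NF4 dequantization; a no-op if a cached copy
    /// already exists. `disable_weight_caching` drops the cached copy.
    ///
    /// # Errors
    ///
    /// Returns an error if dequantization fails.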
    pub fn enable_weight_caching(&mut self) -> Result<()> {
        if self.cached_weight.is_none() {
            self.cached_weight = Some(dequantize_nf4(&self.quantized_weight, &self.device)?);
        }
        Ok(())
    }

    pub fn disable_weight_caching(&mut self) {
        self.cached_weight = None;
    }

    #[must_use]
    pub fn is_weight_cached(&self) -> bool {
        self.cached_weight.is_some()
    }

    #[must_use]
    pub fn config(&self) -> &QLoraConfig {
        &self.config
    }

    #[must_use]
    pub fn lora(&self) -> &LoraLayer {
        &self.lora
    }

    pub fn lora_mut(&mut self) -> &mut LoraLayer {
        &mut self.lora
    }

    #[must_use]
    pub fn lora_weights(&self) -> (&Tensor, &Tensor) {
        self.lora.weights()
    }

    #[must_use]
    pub fn num_trainable_parameters(&self) -> usize {
        self.lora.num_parameters()
    }

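    /// Approximate memory footprint in bytes: the quantized weight's
    /// `size_bytes()` plus the LoRA parameters and any bias elements, both
    /// counted at 4 bytes (F32) per element.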
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let quantized_size = self.quantized_weight.size_bytes();
        let lora_size = self.lora.num_parameters() * 4;
        let bias_size = self.bias.as_ref().map_or(0, |b| b.elem_count() * 4);
        quantized_size + lora_size + bias_size
    }
}

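/// Thin wrapper around `QuantizedLinear` that exposes its quantized weight,
/// LoRA A/B matrices, LoRA scale, device, and configuration alongside the
/// forward pass.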
pub struct QLoraLayer {
    linear: QuantizedLinear,
}

impl QLoraLayer {
    #[must_use]
    pub fn new(linear: QuantizedLinear) -> Self {
        Self { linear }
    }

    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.linear.forward(input)
    }

    #[must_use]
    pub fn quantized_weight(&self) -> &QuantizedTensor {
        &self.linear.quantized_weight
    }

    #[must_use]
    pub fn lora_weights(&self) -> (&Tensor, &Tensor) {
        self.linear.lora_weights()
    }

    #[must_use]
    pub fn lora_scale(&self) -> f64 {
        self.linear.config.scale()
    }

    #[must_use]
    pub fn device(&self) -> &Device {
        &self.linear.device
    }

    #[must_use]
    pub fn config(&self) -> &QLoraConfig {
        &self.linear.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qlora_creation() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(768, 768, &config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_qlora_forward_shape() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(768, 768, &config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_qlora_memory_reduction() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(4096, 4096, &config, &device).unwrap();

        let full_size = 4096 * 4096 * 4;
        let actual_size = layer.memory_bytes();

        #[allow(clippy::cast_precision_loss)]
        let ratio = f64::from(full_size) / actual_size as f64;
        assert!(ratio > 2.0, "Expected >2x reduction, got {ratio:.2}x");
    }

    #[test]
    fn test_preset_all_bf16() {
        let config = QLoraConfig::preset_all_bf16(64, 16);

        assert_eq!(config.lora.r, 64);
        assert_eq!(config.lora.alpha, 16);
        assert!((config.lora.dropout - 0.05).abs() < 1e-10);

        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));
        assert!(config.quantization.double_quant);

        assert!(config.target_modules.contains(&"q_proj".to_string()));
        assert!(config.target_modules.contains(&"k_proj".to_string()));
        assert!(config.target_modules.contains(&"v_proj".to_string()));
        assert!(config.target_modules.contains(&"o_proj".to_string()));
        assert!(config.target_modules.contains(&"gate_proj".to_string()));

        assert!(!config.cache_dequantized);
    }

    #[test]
    fn test_preset_qv_bf16() {
        let config = QLoraConfig::preset_qv_bf16(32, 8);

        assert_eq!(config.lora.r, 32);
        assert_eq!(config.lora.alpha, 8);

        assert_eq!(config.target_modules.len(), 2);
        assert!(config.target_modules.contains(&"q_proj".to_string()));
        assert!(config.target_modules.contains(&"v_proj".to_string()));

        assert!(!config.target_modules.contains(&"k_proj".to_string()));
        assert!(!config.target_modules.contains(&"o_proj".to_string()));
    }

    #[test]
    fn test_preset_inference() {
        let config = QLoraConfig::preset_inference(16, 32);

        assert_eq!(config.lora.r, 16);
        assert_eq!(config.lora.alpha, 32);

        assert!(config.cache_dequantized);

        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));
    }

    #[test]
    fn test_is_target() {
        let config = QLoraConfig::preset_all_bf16(8, 16);

        assert!(config.is_target("model.layer.q_proj"));
        assert!(config.is_target("transformer.blocks.0.attn.v_proj"));
        assert!(config.is_target("gate_proj"));

        assert!(!config.is_target("embed_tokens"));
        assert!(!config.is_target("lm_head"));
        assert!(!config.is_target("layer_norm"));
    }

    #[test]
    fn test_scale() {
        let config = QLoraConfig::preset_all_bf16(64, 16);
        let scale = config.scale();

        assert!((scale - 0.25).abs() < 1e-10);

        let config2 = QLoraConfig::preset_all_bf16(8, 32);
        let scale2 = config2.scale();

        assert!((scale2 - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_validate_for_training_success() {
        let config = QLoraConfig::preset_all_bf16(8, 16);
        assert!(config.validate_for_training().is_ok());
    }

    #[test]
    fn test_validate_for_training_zero_rank() {
        let mut config = QLoraConfig::preset_all_bf16(0, 16);
        config.lora.r = 0;

        let result = config.validate_for_training();
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("rank"));
        }
    }

    #[test]
    fn test_validate_for_training_empty_targets() {
        let mut config = QLoraConfig::preset_all_bf16(8, 16);
        config.target_modules.clear();

        let result = config.validate_for_training();
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("target module"));
        }
    }

    #[test]
    fn test_default_config() {
        let config = QLoraConfig::default();

        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));

        assert_eq!(config.lora.r, 64);
        assert_eq!(config.lora.alpha, 16);

        assert!(!config.target_modules.is_empty());

        assert!(!config.cache_dequantized);
    }

    #[test]
    fn test_lora_weights() {
        let config = QLoraConfig::preset_all_bf16(8, 16);
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(64, 128, &config, &device).unwrap();

        let (a_weight, b_weight) = layer.lora_weights();

        assert_eq!(a_weight.dims(), &[8, 64]);

        assert_eq!(b_weight.dims(), &[128, 8]);
    }
}