/// Memory-footprint model for a single QLoRA linear layer: block-quantized base
/// weights plus two trainable low-rank adapter matrices.
#[derive(Debug, Clone, PartialEq)]
pub struct QLoraLayer {
    /// Input dimension of the layer.
    pub in_features: usize,
    /// Output dimension of the layer.
    pub out_features: usize,
    /// Rank of the LoRA adapter matrices.
    pub rank: usize,
    /// LoRA alpha; the scaling factor is `alpha / rank` (see [`QLoraLayer::scaling`]).
    pub alpha: f32,
    /// Bit width of each quantized base weight (e.g. 4 or 8).
    pub quant_bits: u8,
    /// Number of consecutive base weights that share one quantization scale.
    /// Must be non-zero.
    pub group_size: usize,
    /// Whether the per-group scales are themselves quantized (QLoRA double quantization).
    pub double_quant: bool,
}

impl QLoraLayer {
    /// Bytes per per-group scale when scales are stored unquantized (fp16).
    const SCALE_BYTES: usize = 2;
    /// Bytes per per-group scale under double quantization (8-bit).
    const DQ_SCALE_BYTES: usize = 1;
    /// Bytes per second-level (fp32) constant under double quantization.
    const DQ_SECOND_LEVEL_BYTES: usize = 4;
    /// Number of per-group scales that share one second-level constant.
    const DQ_BLOCK_SIZE: usize = 256;
    /// Bytes per trainable adapter parameter (16-bit).
    const ADAPTER_PARAM_BYTES: usize = 2;
    /// Bytes per weight of the unquantized fp16 reference layer.
    const FP16_BYTES: usize = 2;

    /// Creates a layer whose base weights are quantized to `bits` bits each,
    /// with a group size of 64 and double quantization enabled.
    pub fn new(in_features: usize, out_features: usize, rank: usize, alpha: f32, bits: u8) -> Self {
        Self {
            in_features,
            out_features,
            rank,
            alpha,
            quant_bits: bits,
            group_size: 64,
            double_quant: true,
        }
    }

    /// Creates a layer with 4-bit base weights; see [`QLoraLayer::new`].
    pub fn int4(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        Self::new(in_features, out_features, rank, alpha, 4)
    }

    /// Creates a layer with 8-bit base weights; see [`QLoraLayer::new`].
    pub fn int8(in_features: usize, out_features: usize, rank: usize, alpha: f32) -> Self {
        Self::new(in_features, out_features, rank, alpha, 8)
    }

    /// LoRA scaling factor `alpha / rank`.
    ///
    /// The division is done in `f32`, so a `rank` of 0 yields an infinite value
    /// (or NaN when `alpha` is also 0).
    pub fn scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }

    /// Bytes needed for the packed quantized base weights plus their scales.
    ///
    /// Packed weights are rounded up to a whole byte, and a trailing partial
    /// group still gets its own scale. Without double quantization each group
    /// stores a 2-byte (fp16) scale. With `double_quant` each group stores a
    /// 1-byte scale, and every 256 groups share one 4-byte (fp32) constant.
    ///
    /// # Panics
    ///
    /// Panics if `group_size` is 0.
    pub fn quantized_base_bytes(&self) -> usize {
        let total_elements = self.in_features * self.out_features;
        let base_bytes = Self::ceil_div(total_elements * usize::from(self.quant_bits), 8);
        let num_groups = Self::ceil_div(total_elements, self.group_size);
        let scale_bytes = if self.double_quant {
            num_groups * Self::DQ_SCALE_BYTES
                + Self::ceil_div(num_groups, Self::DQ_BLOCK_SIZE) * Self::DQ_SECOND_LEVEL_BYTES
        } else {
            num_groups * Self::SCALE_BYTES
        };
        base_bytes + scale_bytes
    }

    /// Bytes for the trainable adapters: `rank * in_features` plus
    /// `out_features * rank` parameters, at 2 bytes each.
    pub fn adapter_bytes(&self) -> usize {
        let trainable_params = self.rank * (self.in_features + self.out_features);
        trainable_params * Self::ADAPTER_PARAM_BYTES
    }

    /// Fraction of memory saved compared with storing the base weights in fp16:
    /// `1 - (quantized_base_bytes + adapter_bytes) / fp16_bytes`.
    ///
    /// Returns 0.0 for a layer with no base weights (`in_features` or
    /// `out_features` is 0), where the ratio is undefined. The result is
    /// negative when the adapters and scales outweigh the quantization savings.
    ///
    /// # Panics
    ///
    /// Panics if `group_size` is 0 and the layer is non-empty.
    pub fn vram_savings_ratio(&self) -> f64 {
        let full_fp16 = self.in_features * self.out_features * Self::FP16_BYTES;
        if full_fp16 == 0 {
            return 0.0;
        }
        let quantized = self.quantized_base_bytes() + self.adapter_bytes();
        1.0 - quantized as f64 / full_fp16 as f64
    }

    /// Integer division rounding up; panics if `d` is 0.
    fn ceil_div(n: usize, d: usize) -> usize {
        n / d + usize::from(n % d != 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qlora_int4() {
        let layer = QLoraLayer::int4(4096, 4096, 16, 32.0);
        assert_eq!(layer.quant_bits, 4);
        assert!(layer.vram_savings_ratio() > 0.5);
    }

    #[test]
    fn test_qlora_int8() {
        let layer = QLoraLayer::int8(4096, 4096, 16, 32.0);
        assert_eq!(layer.quant_bits, 8);
        // 8-bit weights still beat fp16, but save less than 4-bit weights do.
        let int4 = QLoraLayer::int4(4096, 4096, 16, 32.0);
        assert!(layer.vram_savings_ratio() > 0.0);
        assert!(layer.vram_savings_ratio() < int4.vram_savings_ratio());
        assert!(layer.quantized_base_bytes() > int4.quantized_base_bytes());
    }

    #[test]
    fn test_defaults_scaling_and_adapter_bytes() {
        let layer = QLoraLayer::int4(4096, 4096, 16, 32.0);
        assert_eq!(layer.group_size, 64);
        assert!(layer.double_quant);
        assert_eq!(layer.scaling(), 2.0);
        // (16 * 4096 + 4096 * 16) params * 2 bytes.
        assert_eq!(layer.adapter_bytes(), 262_144);
    }
}