use alloc::string::String;
use alloc::vec::Vec;
/// Description of one QLoRA-compressed tensor: everything needed to validate, serialize, and
/// (through a provider) decompress it.
#[derive(Debug, Clone)]
pub struct QLoraMetadata {
    /// Size in bytes of the original, unquantized tensor data.
    pub original_size: u64,
    /// Quantized bit width; `is_valid` accepts only 4 or 8.
    pub bits: u8,
    /// LoRA rank; `is_valid` accepts `1..=1024`.
    pub rank: u16,
    /// Tensor shape. When non-empty, the product of the dimensions times the dtype's byte width
    /// must equal `original_size` for `is_valid` to pass.
    pub dimensions: Vec<u32>,
    /// Element type of the original tensor.
    pub dtype: TensorDtype,
    /// Quantization scale factors (presumably one per quantization group — TODO confirm);
    /// must be non-empty for `is_valid` to pass.
    pub scale_factors: Vec<f32>,
    /// Optional zero points (presumably only used by asymmetric quantization — TODO confirm).
    pub zero_points: Option<Vec<i8>>,
    /// Checksum of the original data; the algorithm is chosen by the provider, not this type.
    pub original_checksum: u64,
    /// Timestamp recorded by the producer; units/epoch are not defined here — TODO confirm.
    pub timestamp: u64,
    /// Free-form extra information; `to_bytes` does not serialize it.
    pub extra: Option<String>,
}
impl QLoraMetadata {
pub fn new(original_size: u64, bits: u8, rank: u16) -> Self {
Self {
original_size,
bits,
rank,
dimensions: Vec::new(),
dtype: TensorDtype::Float32,
scale_factors: Vec::new(),
zero_points: None,
original_checksum: 0,
timestamp: 0,
extra: None,
}
}
pub fn compression_ratio(&self, compressed_size: usize) -> f64 {
if compressed_size == 0 {
return 0.0;
}
self.original_size as f64 / compressed_size as f64
}
pub fn expected_ratio(&self) -> f64 {
let original_bits = self.dtype.bits() as f64;
let quantized_bits = self.bits as f64;
let bit_ratio = original_bits / quantized_bits;
bit_ratio * 0.85
}
pub fn is_valid(&self) -> bool {
if self.original_size == 0 {
return false;
}
if self.bits != 4 && self.bits != 8 {
return false;
}
if self.rank == 0 || self.rank > 1024 {
return false;
}
if self.scale_factors.is_empty() {
return false;
}
let total_elements: u64 = self.dimensions.iter().map(|&d| d as u64).product();
let expected_size = total_elements * (self.dtype.bits() as u64 / 8);
if expected_size != self.original_size && !self.dimensions.is_empty() {
return false;
}
true
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(64 + self.scale_factors.len() * 4);
bytes.extend_from_slice(b"QLOR");
bytes.push(1);
bytes.extend_from_slice(&self.original_size.to_le_bytes());
bytes.push(self.bits);
bytes.extend_from_slice(&self.rank.to_le_bytes());
bytes.push(self.dtype as u8);
bytes.extend_from_slice(&(self.dimensions.len() as u16).to_le_bytes());
for &dim in &self.dimensions {
bytes.extend_from_slice(&dim.to_le_bytes());
}
bytes.extend_from_slice(&(self.scale_factors.len() as u32).to_le_bytes());
for &scale in &self.scale_factors {
bytes.extend_from_slice(&scale.to_le_bytes());
}
if let Some(ref zp) = self.zero_points {
bytes.push(1);
bytes.extend_from_slice(&(zp.len() as u32).to_le_bytes());
for &z in zp {
bytes.push(z as u8);
}
} else {
bytes.push(0);
}
bytes.extend_from_slice(&self.original_checksum.to_le_bytes());
bytes.extend_from_slice(&self.timestamp.to_le_bytes());
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() < 20 {
return None;
}
if &bytes[0..4] != b"QLOR" {
return None;
}
let version = bytes[4];
if version != 1 {
return None;
}
let mut offset = 5;
let original_size = u64::from_le_bytes(bytes[offset..offset + 8].try_into().ok()?);
offset += 8;
let bits = bytes[offset];
offset += 1;
let rank = u16::from_le_bytes(bytes[offset..offset + 2].try_into().ok()?);
offset += 2;
let dtype = TensorDtype::from_u8(bytes[offset])?;
offset += 1;
let dim_count = u16::from_le_bytes(bytes[offset..offset + 2].try_into().ok()?) as usize;
offset += 2;
let mut dimensions = Vec::with_capacity(dim_count);
for _ in 0..dim_count {
dimensions.push(u32::from_le_bytes(
bytes[offset..offset + 4].try_into().ok()?,
));
offset += 4;
}
let scale_count = u32::from_le_bytes(bytes[offset..offset + 4].try_into().ok()?) as usize;
offset += 4;
let mut scale_factors = Vec::with_capacity(scale_count);
for _ in 0..scale_count {
scale_factors.push(f32::from_le_bytes(
bytes[offset..offset + 4].try_into().ok()?,
));
offset += 4;
}
let has_zp = bytes[offset] != 0;
offset += 1;
let zero_points = if has_zp {
let zp_count = u32::from_le_bytes(bytes[offset..offset + 4].try_into().ok()?) as usize;
offset += 4;
let mut zp = Vec::with_capacity(zp_count);
for i in 0..zp_count {
zp.push(bytes[offset + i] as i8);
}
offset += zp_count;
Some(zp)
} else {
None
};
let original_checksum = u64::from_le_bytes(bytes[offset..offset + 8].try_into().ok()?);
offset += 8;
let timestamp = u64::from_le_bytes(bytes[offset..offset + 8].try_into().ok()?);
Some(Self {
original_size,
bits,
rank,
dimensions,
dtype,
scale_factors,
zero_points,
original_checksum,
timestamp,
extra: None,
})
}
}
impl Default for QLoraMetadata {
fn default() -> Self {
Self::new(0, 4, 16)
}
}
/// Element type of an unquantized tensor. The explicit discriminants are the one-byte tags
/// used on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum TensorDtype {
    /// IEEE 754 single precision (tag 0).
    Float32 = 0,
    /// IEEE 754 half precision (tag 1).
    Float16 = 1,
    /// bfloat16 (tag 2).
    BFloat16 = 2,
    /// IEEE 754 double precision (tag 3).
    Float64 = 3,
    /// 8-bit float, E4M3 layout (tag 4).
    Float8E4M3 = 4,
    /// 8-bit float, E5M2 layout (tag 5).
    Float8E5M2 = 5,
}

impl TensorDtype {
    /// Width of one element in bits (always a whole number of bytes).
    pub const fn bits(&self) -> u8 {
        self.bytes() * 8
    }

    /// Width of one element in bytes.
    pub const fn bytes(&self) -> u8 {
        match self {
            Self::Float64 => 8,
            Self::Float32 => 4,
            Self::Float16 | Self::BFloat16 => 2,
            Self::Float8E4M3 | Self::Float8E5M2 => 1,
        }
    }

    /// Decodes a wire tag; unknown tags yield `None`.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Indexed by tag, so the order must follow the discriminants above.
        const BY_TAG: [TensorDtype; 6] = [
            TensorDtype::Float32,
            TensorDtype::Float16,
            TensorDtype::BFloat16,
            TensorDtype::Float64,
            TensorDtype::Float8E4M3,
            TensorDtype::Float8E5M2,
        ];
        BY_TAG.get(usize::from(value)).copied()
    }

    /// Lower-case, NumPy/PyTorch-style name of the dtype.
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Float32 => "float32",
            Self::Float16 => "float16",
            Self::BFloat16 => "bfloat16",
            Self::Float64 => "float64",
            Self::Float8E4M3 => "float8_e4m3",
            Self::Float8E5M2 => "float8_e5m2",
        }
    }
}
/// Tuning knobs for a QLoRA compression pass.
#[derive(Debug, Clone)]
pub struct QLoraConfig {
    /// Quantized bit width; [`QLoraConfig::is_valid`] accepts only 4 or 8.
    pub bits: u8,
    /// LoRA rank; valid range is `1..=1024`.
    pub rank: u16,
    /// Quantization group size (presumably elements sharing one scale — TODO confirm);
    /// valid range is `8..=256`.
    pub group_size: u32,
    /// Whether symmetric quantization is requested.
    pub symmetric: bool,
    /// Smallest input considered for compression (units presumably bytes — TODO confirm);
    /// must be non-zero.
    pub min_size: usize,
    /// Largest input considered for compression; must be at least `min_size`.
    pub max_size: usize,
    /// Quality threshold in `(0.0, 1.0]`; the metric it applies to is provider-defined.
    pub quality_threshold: f32,
}

impl QLoraConfig {
    /// Smallest output: 4-bit, rank 8, groups of 128, quality threshold 0.95.
    ///
    /// `const` so presets can initialise `const`/`static` configurations.
    pub const fn aggressive() -> Self {
        Self {
            bits: 4,
            rank: 8,
            group_size: 128,
            symmetric: true,
            min_size: 4096,
            max_size: 1024 * 1024 * 1024,
            quality_threshold: 0.95,
        }
    }

    /// Highest fidelity: 8-bit, rank 64, groups of 32, asymmetric, quality threshold 0.99.
    pub const fn quality() -> Self {
        Self {
            bits: 8,
            rank: 64,
            group_size: 32,
            symmetric: false,
            min_size: 4096,
            max_size: 1024 * 1024 * 1024,
            quality_threshold: 0.99,
        }
    }

    /// Middle ground (and the [`Default`]): 4-bit, rank 32, groups of 64, quality threshold 0.97.
    pub const fn balanced() -> Self {
        Self {
            bits: 4,
            rank: 32,
            group_size: 64,
            symmetric: true,
            min_size: 4096,
            max_size: 1024 * 1024 * 1024,
            quality_threshold: 0.97,
        }
    }

    /// Returns `true` when every field is within its documented range.
    ///
    /// A NaN `quality_threshold` fails both comparisons and is therefore rejected.
    pub fn is_valid(&self) -> bool {
        matches!(self.bits, 4 | 8)
            && (1..=1024).contains(&self.rank)
            && (8..=256).contains(&self.group_size)
            && self.min_size > 0
            && self.max_size >= self.min_size
            && self.quality_threshold > 0.0
            && self.quality_threshold <= 1.0
    }
}

impl Default for QLoraConfig {
    /// Same as [`QLoraConfig::balanced`].
    fn default() -> Self {
        Self::balanced()
    }
}
/// Pluggable QLoRA compression backend.
///
/// The `Send + Sync` bounds let a provider be shared through the global `static` registry.
/// The no-op fallback signals "not handled" by returning `false`/`None`. NOTE(review): confirm
/// that real backends follow the same convention instead of panicking.
pub trait QLoraProvider: Send + Sync {
    /// Returns whether `data` is something this provider is willing to compress.
    fn is_qlora_candidate(&self, data: &[u8]) -> bool;
    /// Compresses `data` using `config`. Returns the compressed bytes plus the metadata needed to
    /// decompress them, or `None` if the provider declines or fails.
    fn compress(&self, data: &[u8], config: &QLoraConfig) -> Option<(Vec<u8>, QLoraMetadata)>;
    /// Reconstructs data from `compressed` bytes and their `metadata`, or `None` on failure.
    fn decompress(&self, compressed: &[u8], metadata: &QLoraMetadata) -> Option<Vec<u8>>;
    /// The provider's own configuration (what the `compress_qlora` convenience function uses).
    fn config(&self) -> &QLoraConfig;
    /// Short identifier for the provider (the fallback provider reports `"no-op"`).
    fn name(&self) -> &str;
    /// Whether this provider can use GPU acceleration.
    fn supports_gpu(&self) -> bool;
    /// Cumulative counters for this provider; the default implementation reports all zeros.
    fn statistics(&self) -> QLoraStatistics {
        QLoraStatistics::default()
    }
}
/// Running counters a provider may report through `QLoraProvider::statistics`.
#[derive(Debug, Clone, Default)]
pub struct QLoraStatistics {
    /// Bytes consumed by compression (the numerator of the compression ratio).
    pub bytes_in: u64,
    /// Bytes produced by compression (the denominator of the compression ratio).
    pub bytes_out: u64,
    /// Number of compress calls.
    pub compress_count: u64,
    /// Number of decompress calls.
    pub decompress_count: u64,
    /// Inputs recognised as QLoRA candidates.
    pub candidates_detected: u64,
    /// Inputs rejected as candidates.
    pub candidates_rejected: u64,
    /// Total time spent compressing, in microseconds.
    pub compress_time_us: u64,
    /// Total time spent decompressing, in microseconds.
    pub decompress_time_us: u64,
}

impl QLoraStatistics {
    /// `bytes_in / bytes_out`, or `0.0` if nothing has been emitted yet.
    pub fn compression_ratio(&self) -> f64 {
        match self.bytes_out {
            0 => 0.0,
            out => self.bytes_in as f64 / out as f64,
        }
    }

    /// Compression throughput in megabytes (10^6 bytes) per second, despite the `mbps` suffix.
    /// Returns `0.0` when no time has been recorded.
    pub fn compress_throughput_mbps(&self) -> f64 {
        Self::megabytes_per_second(self.bytes_in, self.compress_time_us)
    }

    /// Decompression throughput in megabytes (10^6 bytes) per second. Like the compression
    /// figure, it is measured against `bytes_in`. Returns `0.0` when no time has been recorded.
    pub fn decompress_throughput_mbps(&self) -> f64 {
        Self::megabytes_per_second(self.bytes_in, self.decompress_time_us)
    }

    // Shared throughput formula; a zero duration yields 0.0 rather than infinity.
    fn megabytes_per_second(bytes: u64, micros: u64) -> f64 {
        if micros == 0 {
            0.0
        } else {
            let megabytes = bytes as f64 / 1_000_000.0;
            let seconds = micros as f64 / 1_000_000.0;
            megabytes / seconds
        }
    }
}
/// Fallback provider that never claims, compresses, or decompresses anything.
pub struct NoOpQLoraProvider {
    config: QLoraConfig,
}

impl NoOpQLoraProvider {
    // Configuration reported by the no-op provider: 4-bit, rank 16, groups of 64, symmetric.
    const CONFIG: QLoraConfig = QLoraConfig {
        bits: 4,
        rank: 16,
        group_size: 64,
        symmetric: true,
        min_size: 4096,
        max_size: 1024 * 1024 * 1024,
        quality_threshold: 0.97,
    };

    /// Creates the provider; `const` so it can back a `static`.
    pub const fn new() -> Self {
        Self {
            config: Self::CONFIG,
        }
    }
}

impl Default for NoOpQLoraProvider {
    fn default() -> Self {
        NoOpQLoraProvider::new()
    }
}
impl QLoraProvider for NoOpQLoraProvider {
    /// Always `"no-op"`.
    fn name(&self) -> &str {
        "no-op"
    }

    /// The no-op provider never uses a GPU.
    fn supports_gpu(&self) -> bool {
        false
    }

    /// The fixed configuration this provider was built with.
    fn config(&self) -> &QLoraConfig {
        &self.config
    }

    /// Never treats any input as a candidate.
    fn is_qlora_candidate(&self, _: &[u8]) -> bool {
        false
    }

    /// Declines every compression request.
    fn compress(&self, _: &[u8], _: &QLoraConfig) -> Option<(Vec<u8>, QLoraMetadata)> {
        None
    }

    /// Declines every decompression request.
    fn decompress(&self, _: &[u8], _: &QLoraMetadata) -> Option<Vec<u8>> {
        None
    }
}
/// Fallback returned by [`get_qlora_provider`] until a real provider is registered.
static NOOP_PROVIDER: NoOpQLoraProvider = NoOpQLoraProvider::new();
/// Write-once slot holding the globally registered provider.
static QLORA_PROVIDER: spin::Once<&'static dyn QLoraProvider> = spin::Once::new();

/// Installs `provider` as the global QLoRA provider if no provider has been registered yet.
///
/// Returns `true` if `provider` was installed. Returns `false` if another registration already
/// won, in which case `provider` is ignored. If a concurrent registration is in progress, this
/// waits for it and then reports `false`.
pub fn try_register_qlora_provider(provider: &'static dyn QLoraProvider) -> bool {
    let mut installed = false;
    QLORA_PROVIDER.call_once(|| {
        installed = true;
        provider
    });
    installed
}

/// Installs `provider` as the global QLoRA provider.
///
/// Only the first registration takes effect; later calls are ignored. Use
/// [`try_register_qlora_provider`] to learn whether the call took effect.
pub fn register_qlora_provider(provider: &'static dyn QLoraProvider) {
    try_register_qlora_provider(provider);
}

/// Returns the registered provider, or the no-op fallback if none has been registered.
pub fn get_qlora_provider() -> &'static dyn QLoraProvider {
    QLORA_PROVIDER.get().copied().unwrap_or(&NOOP_PROVIDER)
}

/// Returns `true` once a provider has been registered (even if that provider is itself a no-op).
pub fn is_qlora_available() -> bool {
    QLORA_PROVIDER.get().is_some()
}
/// Asks the active provider whether `data` is worth QLoRA-compressing.
pub fn is_qlora_candidate(data: &[u8]) -> bool {
    let active = get_qlora_provider();
    active.is_qlora_candidate(data)
}

/// Compresses `data` with the active provider, using that provider's own configuration.
pub fn compress_qlora(data: &[u8]) -> Option<(Vec<u8>, QLoraMetadata)> {
    let active = get_qlora_provider();
    let provider_config = active.config();
    active.compress(data, provider_config)
}

/// Compresses `data` with the active provider, using a caller-supplied configuration.
pub fn compress_qlora_with_config(
    data: &[u8],
    config: &QLoraConfig,
) -> Option<(Vec<u8>, QLoraMetadata)> {
    let active = get_qlora_provider();
    active.compress(data, config)
}

/// Decompresses `compressed` with the active provider, guided by `metadata`.
pub fn decompress_qlora(compressed: &[u8], metadata: &QLoraMetadata) -> Option<Vec<u8>> {
    let active = get_qlora_provider();
    active.decompress(compressed, metadata)
}
/// On-disk container format of a tensor file, as guessed by [`TensorFormat::detect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorFormat {
    /// Hugging Face safetensors (little-endian length-prefixed JSON header).
    SafeTensors,
    /// GGUF (`GGUF` magic).
    Gguf,
    /// Legacy GGML-family files (GGML / GGMF / GGJT magics).
    Ggml,
    /// PyTorch checkpoint stored as a ZIP archive (`PK` magic).
    PyTorch,
    /// NumPy `.npy` array (`\x93NUMPY` magic).
    NumPy,
    /// Raw tensor bytes; never produced by `detect`.
    Raw,
    /// Anything `detect` does not recognise.
    Unknown,
}

impl TensorFormat {
    /// Sniffs the container format from the leading bytes of `data`.
    ///
    /// Inputs shorter than 8 bytes are always [`TensorFormat::Unknown`]. `data` may be just a
    /// prefix of the file; only the first few bytes are inspected.
    pub fn detect(data: &[u8]) -> Self {
        // Maximum safetensors header length allowed by the format specification.
        const SAFETENSORS_MAX_HEADER_LEN: u64 = 100_000_000;

        if data.len() < 8 {
            return TensorFormat::Unknown;
        }

        // SafeTensors: u64 little-endian header length, then a JSON object starting with `{`.
        // Only byte 8 is needed, so a minimal file (length prefix + "{}") is recognised too.
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&data[..8]);
        let header_len = u64::from_le_bytes(len_bytes);
        if header_len > 0 && header_len <= SAFETENSORS_MAX_HEADER_LEN && data.get(8) == Some(&b'{')
        {
            return TensorFormat::SafeTensors;
        }

        match &data[..4] {
            b"GGUF" => TensorFormat::Gguf,
            // ASCII "GGML", plus the legacy ggml/ggmf/ggjt magics. Those are u32 constants
            // written little-endian, so their ASCII appears byte-reversed on disk.
            b"GGML" | b"lmgg" | b"fmgg" | b"tjgg" => TensorFormat::Ggml,
            _ if data.starts_with(b"\x93NUMPY") => TensorFormat::NumPy,
            _ if data.starts_with(b"PK") => TensorFormat::PyTorch,
            _ => TensorFormat::Unknown,
        }
    }

    /// Conventional file extension (including the leading dot); empty for `Unknown`.
    pub const fn extension(&self) -> &'static str {
        match self {
            TensorFormat::SafeTensors => ".safetensors",
            TensorFormat::Gguf => ".gguf",
            TensorFormat::Ggml => ".ggml",
            TensorFormat::PyTorch => ".pt",
            TensorFormat::NumPy => ".npy",
            TensorFormat::Raw => ".bin",
            TensorFormat::Unknown => "",
        }
    }
}
// Unit tests covering the metadata wire format, the dtype and config helpers, the no-op
// provider fallback, and tensor-format sniffing.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    #[test]
    fn test_tensor_dtype_bits() {
        assert_eq!(TensorDtype::Float32.bits(), 32);
        assert_eq!(TensorDtype::Float16.bits(), 16);
        assert_eq!(TensorDtype::BFloat16.bits(), 16);
        assert_eq!(TensorDtype::Float64.bits(), 64);
        assert_eq!(TensorDtype::Float8E4M3.bits(), 8);
    }
    #[test]
    fn test_tensor_dtype_from_u8() {
        assert_eq!(TensorDtype::from_u8(0), Some(TensorDtype::Float32));
        assert_eq!(TensorDtype::from_u8(1), Some(TensorDtype::Float16));
        // Tags outside the known range are rejected.
        assert_eq!(TensorDtype::from_u8(255), None);
    }
    #[test]
    fn test_qlora_metadata_new() {
        let meta = QLoraMetadata::new(1024, 4, 16);
        assert_eq!(meta.original_size, 1024);
        assert_eq!(meta.bits, 4);
        assert_eq!(meta.rank, 16);
    }
    #[test]
    fn test_qlora_metadata_expected_ratio() {
        let mut meta = QLoraMetadata::new(1024, 4, 16);
        meta.dtype = TensorDtype::Float32;
        let ratio = meta.expected_ratio();
        // 32-bit source / 4-bit target = 8x, discounted by the 0.85 factor to 6.8.
        assert!(ratio > 6.0 && ratio < 8.0);
    }
    #[test]
    fn test_qlora_metadata_serialization() {
        // Round-trip through to_bytes/from_bytes with no zero points.
        let mut meta = QLoraMetadata::new(1024, 4, 16);
        meta.dimensions = vec![32, 32];
        meta.scale_factors = vec![1.0, 0.5, 0.25];
        meta.original_checksum = 0xDEADBEEF;
        meta.timestamp = 1234567890;
        let bytes = meta.to_bytes();
        let recovered = QLoraMetadata::from_bytes(&bytes).unwrap();
        assert_eq!(recovered.original_size, meta.original_size);
        assert_eq!(recovered.bits, meta.bits);
        assert_eq!(recovered.rank, meta.rank);
        assert_eq!(recovered.dimensions, meta.dimensions);
        assert_eq!(recovered.scale_factors.len(), meta.scale_factors.len());
        assert_eq!(recovered.original_checksum, meta.original_checksum);
        assert_eq!(recovered.timestamp, meta.timestamp);
    }
    #[test]
    fn test_qlora_config_presets() {
        let aggressive = QLoraConfig::aggressive();
        assert_eq!(aggressive.bits, 4);
        assert_eq!(aggressive.rank, 8);
        assert!(aggressive.is_valid());
        let quality = QLoraConfig::quality();
        assert_eq!(quality.bits, 8);
        assert_eq!(quality.rank, 64);
        assert!(quality.is_valid());
        let balanced = QLoraConfig::balanced();
        assert!(balanced.is_valid());
    }
    #[test]
    fn test_qlora_config_validation() {
        let mut config = QLoraConfig::default();
        assert!(config.is_valid());
        // Only 4- and 8-bit quantization are accepted.
        config.bits = 3; assert!(!config.is_valid());
        config.bits = 4;
        // Rank must be non-zero.
        config.rank = 0; assert!(!config.is_valid());
    }
    #[test]
    fn test_noop_provider() {
        let provider = NoOpQLoraProvider::new();
        assert!(!provider.is_qlora_candidate(&[1, 2, 3, 4]));
        assert!(
            provider
                .compress(&[1, 2, 3, 4], &QLoraConfig::default())
                .is_none()
        );
        assert!(
            provider
                .decompress(&[], &QLoraMetadata::default())
                .is_none()
        );
        assert_eq!(provider.name(), "no-op");
        assert!(!provider.supports_gpu());
    }
    #[test]
    fn test_global_provider_fallback() {
        // Assumes no test in this binary registers a provider, so the no-op fallback is active.
        let provider = get_qlora_provider();
        assert_eq!(provider.name(), "no-op");
        assert!(!is_qlora_candidate(&[1, 2, 3, 4]));
    }
    #[test]
    fn test_tensor_format_detection() {
        let gguf_data = b"GGUF\x00\x00\x00\x00";
        assert_eq!(TensorFormat::detect(gguf_data), TensorFormat::Gguf);
        let ggml_data = b"GGML\x00\x00\x00\x00";
        assert_eq!(TensorFormat::detect(ggml_data), TensorFormat::Ggml);
        let numpy_data = [0x93, b'N', b'U', b'M', b'P', b'Y', 0, 0];
        assert_eq!(TensorFormat::detect(&numpy_data), TensorFormat::NumPy);
        let pytorch_data = b"PK\x03\x04\x00\x00\x00\x00";
        assert_eq!(TensorFormat::detect(pytorch_data), TensorFormat::PyTorch);
        let unknown_data = b"UNKNOWN_";
        assert_eq!(TensorFormat::detect(unknown_data), TensorFormat::Unknown);
        // Inputs shorter than 8 bytes are never classified.
        assert_eq!(TensorFormat::detect(b"short"), TensorFormat::Unknown);
    }
    #[test]
    fn test_tensor_format_extension() {
        assert_eq!(TensorFormat::SafeTensors.extension(), ".safetensors");
        assert_eq!(TensorFormat::Gguf.extension(), ".gguf");
        assert_eq!(TensorFormat::PyTorch.extension(), ".pt");
    }
    #[test]
    fn test_statistics_calculations() {
        // 1 MB in, 100 KB out; 1 s compressing and 0.5 s decompressing.
        let stats = QLoraStatistics {
            bytes_in: 1_000_000,
            bytes_out: 100_000,
            compress_count: 10,
            decompress_count: 5,
            candidates_detected: 15,
            candidates_rejected: 5,
            compress_time_us: 1_000_000, decompress_time_us: 500_000, };
        assert!((stats.compression_ratio() - 10.0).abs() < 0.01);
        assert!((stats.compress_throughput_mbps() - 1.0).abs() < 0.01);
        assert!((stats.decompress_throughput_mbps() - 2.0).abs() < 0.01);
    }
    #[test]
    fn test_metadata_validity() {
        // Start from a valid value, then break one field at a time.
        let mut meta = QLoraMetadata::new(1024, 4, 16);
        meta.scale_factors = vec![1.0];
        assert!(meta.is_valid());
        meta.bits = 3;
        assert!(!meta.is_valid());
        meta.bits = 4;
        meta.rank = 0;
        assert!(!meta.is_valid());
        meta.rank = 16;
        meta.scale_factors.clear();
        assert!(!meta.is_valid());
    }
    #[test]
    fn test_convenience_functions() {
        // With only the no-op fallback active, every convenience wrapper declines.
        assert!(!is_qlora_candidate(&[1, 2, 3, 4]));
        assert!(compress_qlora(&[1, 2, 3, 4]).is_none());
        assert!(compress_qlora_with_config(&[1, 2, 3, 4], &QLoraConfig::aggressive()).is_none());
        assert!(decompress_qlora(&[], &QLoraMetadata::default()).is_none());
    }
}