use crate::error::MullamaError;
use crate::sys;
use crate::Model;
use std::collections::HashMap;
use std::path::Path;
/// Configuration for a quantization run.
///
/// Usually built from [`Default::default`] and then adjusted with the
/// `with_*` builder methods.
#[derive(Clone, Debug, PartialEq)]
pub struct QuantizationParams {
    /// Target format. `QuantizationEngine::quantize_file` translates it into a
    /// llama.cpp file type and `QuantizationEngine::quantize` dispatches on it.
    pub quantization_type: QuantizationType,
    /// Thread count forwarded to llama.cpp as `nthread` by `quantize_file`.
    /// `QuantizationEngine::quantize` rejects values that are not positive.
    pub n_threads: i32,
    /// When `true` and [`importance_matrix`](Self::importance_matrix) is
    /// `None`, `QuantizationEngine::quantize` computes an importance matrix
    /// itself before dispatching.
    pub use_importance_matrix: bool,
    /// Caller-supplied importance values. Only its presence is inspected by
    /// this module; the values themselves are not read here.
    pub importance_matrix: Option<Vec<f32>>,
    /// Destination path recorded by `with_output_path`. Stored only; nothing
    /// in this module reads it (`quantize_file` takes its own output path).
    pub output_path: Option<String>,
    /// Quality threshold. `QuantizationEngine::quantize` rejects values
    /// outside `0.0..=1.0`; the value is not otherwise consumed here.
    pub quality_threshold: f32,
    /// Layer indices recorded by `with_preserved_layers`. Stored only —
    /// presumably layers to leave unquantized; confirm against consumers.
    pub preserve_layers: Vec<usize>,
    /// Per-layer settings keyed by a string. Stored only — the key is
    /// presumably a layer name; confirm against consumers.
    pub layer_settings: HashMap<String, LayerQuantizationSettings>,
    /// Set to `true` by `with_calibration_data`. Not read elsewhere in this
    /// module.
    pub enable_calibration: bool,
    /// Calibration sequences. When present, the importance-matrix step takes
    /// its data-driven branch instead of the default one. The inner `i32`s
    /// are presumably token ids — TODO confirm.
    pub calibration_data: Option<Vec<Vec<i32>>>,
}
/// Target quantization format.
///
/// Every variant except [`Custom`](Self::Custom) is translated into a
/// `sys::llama_ftype` value by `QuantizationEngine::quantize_file`. Two
/// variants do not map to an ftype of the same name there: `Q8_K` is
/// quantized as `Q8_0` and `MXFP4_MOE` is quantized as `Q4_K_M`.
// Variant names deliberately mirror the llama.cpp ftype spelling, hence the
// lint exemption.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, PartialEq)]
pub enum QuantizationType {
    // Floating-point formats.
    F32,
    F16,

    // Fixed-block formats.
    Q8_0,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,

    // `_K` formats.
    Q2_K,
    Q3_K_S,
    Q3_K_M,
    Q3_K_L,
    Q4_K_S,
    Q4_K_M,
    Q5_K_S,
    Q5_K_M,
    Q6_K,
    /// Quantized as `Q8_0` by `quantize_file`.
    Q8_K,

    // `IQ` formats.
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ3_XS,
    IQ4_NL,
    IQ4_XS,

    /// Quantized as `Q4_K_M` by `quantize_file`.
    MXFP4_MOE,

    /// User-defined scheme. Every code path that handles it in this module
    /// currently returns a "not implemented" error. Boxed so the enum stays
    /// small.
    Custom(Box<CustomQuantizationScheme>),
}
/// Description of a user-defined quantization scheme carried by
/// `QuantizationType::Custom`.
///
/// Only [`method`](Self::method) is inspected by this module (to pick a
/// backend); the remaining fields are stored for the backend's use.
#[derive(Clone, Debug, PartialEq)]
pub struct CustomQuantizationScheme {
    /// Bit width for weights.
    pub weight_bits: u8,
    /// Bit width for activations.
    pub activation_bits: u8,
    /// Algorithm selector; decides which backend `quantize` dispatches to.
    pub method: QuantizationMethod,
    /// Block size — presumably the number of values quantized together;
    /// TODO confirm once a backend consumes it.
    pub block_size: usize,
    /// Whether the scheme is symmetric.
    pub symmetric: bool,
}
/// Algorithm used by a [`CustomQuantizationScheme`].
///
/// Each variant selects one backend in `QuantizationEngine`; all four
/// backends currently report "not implemented".
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum QuantizationMethod {
    /// Linear quantization.
    Linear,
    /// K-means quantization.
    KMeans,
    /// Vector quantization.
    VectorQuantization,
    /// Learned quantization.
    Learned,
}
/// Quantization settings for a single entry of
/// `QuantizationParams::layer_settings`.
///
/// Nothing in this module reads these fields yet; they are carried as
/// configuration only.
#[derive(Clone, Debug, PartialEq)]
pub struct LayerQuantizationSettings {
    /// Format to use for this layer.
    pub quantization_type: QuantizationType,
    /// Flag asking for this layer to be skipped.
    pub skip_quantization: bool,
    /// Scale factor for this layer; how it is applied is not established by
    /// this module.
    pub scale_factor: f32,
}
/// Summary of a quantization run, as stored by the engine and returned from
/// `QuantizationEngine::last_metrics`.
#[derive(Clone, Debug, PartialEq)]
pub struct QuantizationMetrics {
    /// Perplexity estimate for the original model.
    pub original_perplexity: f32,
    /// Perplexity estimate for the quantized model.
    pub quantized_perplexity: f32,
    /// Original size divided by quantized size.
    pub compression_ratio: f32,
    /// Original size minus quantized size.
    pub size_reduction: u64,
    /// Relative perplexity change:
    /// `(quantized_perplexity - original_perplexity) / original_perplexity`.
    pub accuracy_loss: f32,
    /// Per-layer metrics keyed by a string; the engine currently leaves this
    /// map empty.
    pub layer_metrics: HashMap<String, LayerMetrics>,
}
/// Quality metrics for one layer, stored in
/// `QuantizationMetrics::layer_metrics`.
///
/// Nothing in this module computes these values yet, so units and
/// reference points are not established here.
#[derive(Clone, Debug, PartialEq)]
pub struct LayerMetrics {
    /// Signal-to-noise ratio.
    pub snr: f32,
    /// Mean squared error.
    pub mse: f32,
    /// Cosine similarity.
    pub cosine_similarity: f32,
}
/// Drives quantization of a loaded [`Model`] according to a set of
/// [`QuantizationParams`].
// Several fields are only touched by stubbed or not-yet-wired code paths,
// hence the lint exemption.
#[allow(dead_code)]
#[derive(Debug)]
pub struct QuantizationEngine {
    /// The model this engine was created for.
    model: Model,
    /// Active configuration; replaceable through `set_params`.
    params: QuantizationParams,
    /// Computed importance vectors keyed by name. The engine currently only
    /// ever writes the key `"default"`.
    importance_cache: HashMap<String, Vec<f32>>,
    /// Metrics from the most recent metrics calculation, if any.
    last_metrics: Option<QuantizationMetrics>,
}
impl QuantizationEngine {
    /// Length of the placeholder importance vectors produced by the stubs
    /// below.
    const PLACEHOLDER_EMBEDDING_DIM: usize = 768;
    /// Divisor used by the placeholder importance ramp.
    const PLACEHOLDER_NUM_LAYERS: usize = 32;

    /// Creates an engine that owns `model` and quantizes according to
    /// `params`. The importance cache starts empty and no metrics are
    /// recorded.
    pub fn new(model: Model, params: QuantizationParams) -> Self {
        Self {
            model,
            params,
            importance_cache: HashMap::new(),
            last_metrics: None,
        }
    }

    /// Validates the parameters, computes an importance matrix when one is
    /// requested but not supplied, and dispatches on the configured
    /// quantization type.
    ///
    /// # Errors
    ///
    /// * `MullamaError::InvalidInput` when the parameters fail validation.
    /// * `MullamaError::NotImplemented` otherwise: every backend reachable
    ///   from here is currently a stub, so this method cannot yet return
    ///   `Ok`. Use [`quantize_file`](Self::quantize_file) for standard
    ///   formats.
    pub fn quantize(&mut self) -> Result<Model, MullamaError> {
        self.validate_params()?;
        if self.params.use_importance_matrix && self.params.importance_matrix.is_none() {
            self.calculate_importance_matrix()?;
        }
        // Cloned so `self` is free to be borrowed mutably by the backends.
        let qtype = self.params.quantization_type.clone();
        match qtype {
            QuantizationType::Custom(scheme) => self.quantize_custom(&scheme),
            _ => self.quantize_standard(),
        }
    }

    /// Stub for the built-in formats: always fails, pointing the caller at
    /// [`quantize_file`](Self::quantize_file).
    fn quantize_standard(&mut self) -> Result<Model, MullamaError> {
        Err(MullamaError::NotImplemented(
            "Standard quantization requires source model path. Use quantize_file() instead."
                .to_string(),
        ))
    }

    /// Quantizes the model file at `input_path` into `output_path` via
    /// llama.cpp and loads the result.
    ///
    /// Only `params.quantization_type` and `params.n_threads` are forwarded;
    /// every other field of `params` is ignored by this function.
    ///
    /// # Errors
    ///
    /// * `MullamaError::InvalidInput` when a path is not valid UTF-8 or
    ///   contains an interior NUL byte.
    /// * `MullamaError::NotImplemented` for `QuantizationType::Custom`.
    /// * `MullamaError::QuantizationError` when llama.cpp returns a non-zero
    ///   status.
    /// * Whatever `Model::load` returns if the written file cannot be loaded.
    pub fn quantize_file<P: AsRef<Path>>(
        input_path: P,
        output_path: P,
        params: &QuantizationParams,
    ) -> Result<Model, MullamaError> {
        // A lossy conversion would silently swap undecodable bytes for
        // U+FFFD and make llama.cpp read or write a *different* file, so
        // non-UTF-8 paths are rejected up front instead.
        let input_str = input_path.as_ref().to_str().ok_or_else(|| {
            MullamaError::InvalidInput("Input path is not valid UTF-8".to_string())
        })?;
        let c_input = std::ffi::CString::new(input_str)
            .map_err(|_| MullamaError::InvalidInput("Invalid input path".to_string()))?;

        let output_str = output_path
            .as_ref()
            .to_str()
            .ok_or_else(|| {
                MullamaError::InvalidInput("Output path is not valid UTF-8".to_string())
            })?
            .to_owned();
        let c_output = std::ffi::CString::new(output_str.as_str())
            .map_err(|_| MullamaError::InvalidInput("Invalid output path".to_string()))?;

        // SAFETY: the call takes no arguments and hands back a parameter
        // struct by value; there are no pointers for us to keep valid.
        let mut llama_params = unsafe { sys::llama_model_quantize_default_params() };
        llama_params.ftype = match &params.quantization_type {
            QuantizationType::F32 => sys::llama_ftype::LLAMA_FTYPE_ALL_F32,
            QuantizationType::F16 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_F16,
            QuantizationType::Q8_0 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q8_0,
            QuantizationType::Q4_0 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q4_0,
            QuantizationType::Q4_1 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q4_1,
            QuantizationType::Q5_0 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q5_0,
            QuantizationType::Q5_1 => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q5_1,
            QuantizationType::Q2_K => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q2_K,
            QuantizationType::Q3_K_S => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q3_K_S,
            QuantizationType::Q3_K_M => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q3_K_M,
            QuantizationType::Q3_K_L => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q3_K_L,
            QuantizationType::Q4_K_S => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q4_K_S,
            QuantizationType::Q4_K_M => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q4_K_M,
            QuantizationType::Q5_K_S => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q5_K_S,
            QuantizationType::Q5_K_M => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q5_K_M,
            QuantizationType::Q6_K => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q6_K,
            // No Q8_K file type is used here; the request is served as Q8_0.
            QuantizationType::Q8_K => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q8_0,
            QuantizationType::IQ2_XXS => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ2_XXS,
            QuantizationType::IQ2_XS => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ2_XS,
            QuantizationType::IQ3_XXS => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ3_XXS,
            QuantizationType::IQ3_XS => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ3_XS,
            QuantizationType::IQ4_NL => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ4_NL,
            QuantizationType::IQ4_XS => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_IQ4_XS,
            // Served as Q4_K_M.
            // NOTE(review): newer llama.cpp headers define a dedicated
            // MXFP4_MOE file type; switch to it if `sys` exposes one.
            QuantizationType::MXFP4_MOE => sys::llama_ftype::LLAMA_FTYPE_MOSTLY_Q4_K_M,
            QuantizationType::Custom(_) => {
                return Err(MullamaError::NotImplemented(
                    "Custom quantization schemes not yet implemented".to_string(),
                ));
            }
        };
        llama_params.nthread = params.n_threads;

        // SAFETY: `c_input` and `c_output` are NUL-terminated C strings that
        // stay alive until after the call returns, and `llama_params` is a
        // fully initialised struct borrowed only for the duration of the call.
        let result = unsafe {
            sys::llama_model_quantize(c_input.as_ptr(), c_output.as_ptr(), &llama_params)
        };
        if result != 0 {
            return Err(MullamaError::QuantizationError(format!(
                "Quantization failed with error code: {}",
                result
            )));
        }
        Model::load(&output_str)
    }

    /// Routes a custom scheme to the backend selected by `scheme.method`.
    fn quantize_custom(
        &mut self,
        scheme: &CustomQuantizationScheme,
    ) -> Result<Model, MullamaError> {
        match scheme.method {
            QuantizationMethod::Linear => self.quantize_linear(scheme),
            QuantizationMethod::KMeans => self.quantize_kmeans(scheme),
            QuantizationMethod::VectorQuantization => self.quantize_vector(scheme),
            QuantizationMethod::Learned => self.quantize_learned(scheme),
        }
    }

    /// Stub: linear custom quantization is not implemented.
    fn quantize_linear(
        &mut self,
        _scheme: &CustomQuantizationScheme,
    ) -> Result<Model, MullamaError> {
        Err(MullamaError::NotImplemented(
            "Custom linear quantization not yet implemented".to_string(),
        ))
    }

    /// Stub: K-means custom quantization is not implemented.
    fn quantize_kmeans(
        &mut self,
        _scheme: &CustomQuantizationScheme,
    ) -> Result<Model, MullamaError> {
        Err(MullamaError::NotImplemented(
            "Custom K-means quantization not yet implemented".to_string(),
        ))
    }

    /// Stub: vector custom quantization is not implemented.
    fn quantize_vector(
        &mut self,
        _scheme: &CustomQuantizationScheme,
    ) -> Result<Model, MullamaError> {
        Err(MullamaError::NotImplemented(
            "Custom vector quantization not yet implemented".to_string(),
        ))
    }

    /// Stub: learned custom quantization is not implemented.
    fn quantize_learned(
        &mut self,
        _scheme: &CustomQuantizationScheme,
    ) -> Result<Model, MullamaError> {
        Err(MullamaError::NotImplemented(
            "Custom learned quantization not yet implemented".to_string(),
        ))
    }

    /// Fills the importance cache, using the calibration data when the
    /// parameters carry some and the default ramp otherwise.
    fn calculate_importance_matrix(&mut self) -> Result<(), MullamaError> {
        // Move the calibration data out for the duration of the call rather
        // than deep-copying a potentially large `Vec<Vec<i32>>` just to
        // satisfy the borrow checker; it is put back before returning.
        let calibration_data = self.params.calibration_data.take();
        let outcome = match calibration_data.as_deref() {
            Some(data) => self.calculate_importance_from_data(data),
            None => self.calculate_default_importance(),
        };
        self.params.calibration_data = calibration_data;
        outcome
    }

    /// Placeholder: ignores `_data` and caches a uniform importance vector of
    /// ones under the key `"default"`.
    fn calculate_importance_from_data(&mut self, _data: &[Vec<i32>]) -> Result<(), MullamaError> {
        let importance_matrix = vec![1.0; Self::PLACEHOLDER_EMBEDDING_DIM];
        self.importance_cache
            .insert("default".to_string(), importance_matrix);
        Ok(())
    }

    /// Placeholder: caches a linearly increasing importance vector under the
    /// key `"default"`, where entry `i` is `1.0 + i / 32`.
    fn calculate_default_importance(&mut self) -> Result<(), MullamaError> {
        // NOTE(review): the index runs over the embedding dimension but is
        // divided by a layer count, so the "ratio" grows well past 1.0.
        // Kept as-is because nothing consumes the cache yet; revisit when a
        // real heuristic is wired in.
        let importance_matrix: Vec<f32> = (0..Self::PLACEHOLDER_EMBEDDING_DIM)
            .map(|i| 1.0 + i as f32 / Self::PLACEHOLDER_NUM_LAYERS as f32)
            .collect();
        self.importance_cache
            .insert("default".to_string(), importance_matrix);
        Ok(())
    }

    /// Computes summary metrics for `quantized_model` and stores them as the
    /// engine's last metrics.
    ///
    /// The model sizes are hard-coded placeholders and the perplexities come
    /// from [`estimate_perplexity`](Self::estimate_perplexity), itself a
    /// stub, so the stored numbers are not measurements yet.
    // Not called from anywhere yet.
    #[allow(dead_code)]
    fn calculate_metrics(&mut self, quantized_model: &Model) -> Result<(), MullamaError> {
        // Placeholder sizes in bytes.
        let original_size = 1000000000u64;
        let quantized_size = 500000000u64;
        let compression_ratio = original_size as f32 / quantized_size as f32;
        let size_reduction = original_size - quantized_size;

        let original_perplexity = self.estimate_perplexity(&self.model)?;
        let quantized_perplexity = self.estimate_perplexity(quantized_model)?;
        let accuracy_loss = (quantized_perplexity - original_perplexity) / original_perplexity;

        self.last_metrics = Some(QuantizationMetrics {
            original_perplexity,
            quantized_perplexity,
            compression_ratio,
            size_reduction,
            accuracy_loss,
            // Per-layer metrics are not computed yet.
            layer_metrics: HashMap::new(),
        });
        Ok(())
    }

    /// Placeholder perplexity estimate: ignores the model and returns `10.0`.
    #[allow(dead_code)]
    fn estimate_perplexity(&self, _model: &Model) -> Result<f32, MullamaError> {
        Ok(10.0)
    }

    /// Checks the parameters that `quantize` depends on.
    ///
    /// # Errors
    ///
    /// `MullamaError::InvalidInput` when `quality_threshold` is not within
    /// `0.0..=1.0` (NaN included) or when `n_threads` is not positive.
    fn validate_params(&self) -> Result<(), MullamaError> {
        // `contains` is false for NaN, whereas a pair of `<` / `>` tests lets
        // NaN slip through because both comparisons are false.
        if !(0.0..=1.0).contains(&self.params.quality_threshold) {
            return Err(MullamaError::InvalidInput(
                "Quality threshold must be between 0.0 and 1.0".to_string(),
            ));
        }
        if self.params.n_threads <= 0 {
            return Err(MullamaError::InvalidInput(
                "Number of threads must be positive".to_string(),
            ));
        }
        Ok(())
    }

    /// Builds llama.cpp quantize parameters from the engine configuration.
    ///
    /// Only the thread count is copied over; every other field keeps the
    /// llama.cpp default.
    // Not called from anywhere yet.
    #[allow(dead_code)]
    fn create_llama_quantize_params(
        &self,
    ) -> Result<sys::llama_model_quantize_params, MullamaError> {
        // SAFETY: the call takes no arguments and hands back a parameter
        // struct by value; there are no pointers for us to keep valid.
        let mut params = unsafe { sys::llama_model_quantize_default_params() };
        params.nthread = self.params.n_threads;
        Ok(params)
    }

    /// Returns the metrics recorded by the most recent metrics calculation,
    /// or `None` when none has run.
    pub fn last_metrics(&self) -> Option<&QuantizationMetrics> {
        self.last_metrics.as_ref()
    }

    /// Replaces the engine configuration. The importance cache and the
    /// recorded metrics are left untouched.
    pub fn set_params(&mut self, params: QuantizationParams) {
        self.params = params;
    }
}
impl Default for QuantizationParams {
    /// Baseline configuration: `Q4_K_M`, one thread per CPU reported by
    /// `num_cpus`, a quality threshold of `0.95`, and every optional feature
    /// (importance matrix, calibration, output path, per-layer overrides)
    /// switched off or empty.
    fn default() -> Self {
        let n_threads = num_cpus::get() as i32;
        Self {
            quantization_type: QuantizationType::Q4_K_M,
            n_threads,
            quality_threshold: 0.95,

            // Importance matrix: off, nothing supplied.
            use_importance_matrix: false,
            importance_matrix: None,

            // Calibration: off, nothing supplied.
            enable_calibration: false,
            calibration_data: None,

            output_path: None,
            preserve_layers: Vec::default(),
            layer_settings: HashMap::default(),
        }
    }
}
/// Consuming builder methods: each takes the parameters by value, replaces
/// the field(s) it is named after, and returns the updated value so calls
/// can be chained.
impl QuantizationParams {
    /// Sets the target quantization format.
    pub fn with_type(self, qtype: QuantizationType) -> Self {
        Self {
            quantization_type: qtype,
            ..self
        }
    }

    /// Sets the thread count. No validation happens here.
    pub fn with_threads(self, threads: i32) -> Self {
        Self {
            n_threads: threads,
            ..self
        }
    }

    /// Supplies an importance matrix and turns importance-matrix use on.
    pub fn with_importance_matrix(self, matrix: Vec<f32>) -> Self {
        Self {
            use_importance_matrix: true,
            importance_matrix: Some(matrix),
            ..self
        }
    }

    /// Records an output path, converted to a `String` lossily (undecodable
    /// sequences become U+FFFD).
    pub fn with_output_path<P: AsRef<Path>>(self, path: P) -> Self {
        let recorded = path.as_ref().to_string_lossy().into_owned();
        Self {
            output_path: Some(recorded),
            ..self
        }
    }

    /// Sets the quality threshold. No validation happens here.
    pub fn with_quality_threshold(self, threshold: f32) -> Self {
        Self {
            quality_threshold: threshold,
            ..self
        }
    }

    /// Replaces the list of preserved layer indices.
    pub fn with_preserved_layers(self, layers: Vec<usize>) -> Self {
        Self {
            preserve_layers: layers,
            ..self
        }
    }

    /// Supplies calibration data and turns calibration on.
    pub fn with_calibration_data(self, data: Vec<Vec<i32>>) -> Self {
        Self {
            enable_calibration: true,
            calibration_data: Some(data),
            ..self
        }
    }
}
pub mod utils {
use super::*;
pub fn recommend_quantization(_model: &Model) -> QuantizationType {
let size = 1000000000u64;
if size > 20_000_000_000 {
QuantizationType::Q4_K_M } else if size > 7_000_000_000 {
QuantizationType::Q5_K_M } else if size > 3_000_000_000 {
QuantizationType::Q8_0 } else {
QuantizationType::F16 }
}
pub fn compression_ratio(qtype: QuantizationType) -> f32 {
match qtype {
QuantizationType::F32 => 1.0,
QuantizationType::F16 => 2.0,
QuantizationType::Q8_0 => 4.0,
QuantizationType::Q5_0 | QuantizationType::Q5_1 => 6.4,
QuantizationType::Q4_0 | QuantizationType::Q4_1 => 8.0,
QuantizationType::Q4_K_M => 8.5,
QuantizationType::Q3_K_M => 10.7,
QuantizationType::Q2_K => 16.0,
QuantizationType::IQ2_XXS => 20.0,
_ => 8.0, }
}
pub fn speed_optimized_params() -> QuantizationParams {
QuantizationParams::default()
.with_type(QuantizationType::Q4_0)
.with_threads(num_cpus::get() as i32)
}
pub fn quality_optimized_params() -> QuantizationParams {
QuantizationParams::default()
.with_type(QuantizationType::Q8_0)
.with_quality_threshold(0.98)
}
pub fn size_optimized_params() -> QuantizationParams {
QuantizationParams::default()
.with_type(QuantizationType::Q2_K)
.with_quality_threshold(0.90)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls end up in the matching fields.
    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::default()
            .with_type(QuantizationType::Q4_K_M)
            .with_threads(8)
            .with_quality_threshold(0.95);

        assert_eq!(params.quantization_type, QuantizationType::Q4_K_M);
        assert_eq!(params.n_threads, 8);
        assert_eq!(params.quality_threshold, 0.95);
    }

    /// Spot-checks the compression-ratio table.
    #[test]
    fn test_quantization_types() {
        let expectations = [
            (QuantizationType::F32, 1.0_f32),
            (QuantizationType::F16, 2.0_f32),
            (QuantizationType::Q4_0, 8.0_f32),
        ];
        for (qtype, ratio) in expectations {
            assert_eq!(utils::compression_ratio(qtype), ratio);
        }
    }

    /// A custom scheme keeps the values it was built with.
    #[test]
    fn test_custom_quantization_scheme() {
        let scheme = CustomQuantizationScheme {
            weight_bits: 4,
            activation_bits: 8,
            method: QuantizationMethod::Linear,
            block_size: 128,
            symmetric: true,
        };

        let CustomQuantizationScheme {
            weight_bits,
            activation_bits,
            symmetric,
            ..
        } = scheme;
        assert_eq!(weight_bits, 4);
        assert_eq!(activation_bits, 8);
        assert!(symmetric);
    }

    /// Per-layer settings can be stored in and read back from a map.
    #[test]
    fn test_layer_settings() {
        let attention = LayerQuantizationSettings {
            quantization_type: QuantizationType::Q8_0,
            skip_quantization: false,
            scale_factor: 1.0,
        };
        let mut settings = HashMap::new();
        settings.insert("attention".to_string(), attention);

        assert_eq!(settings.len(), 1);
        let stored = settings
            .get("attention")
            .expect("entry inserted just above");
        assert_eq!(stored.quantization_type, QuantizationType::Q8_0);
    }

    /// Each preset selects its advertised format and threshold.
    #[test]
    fn test_optimization_presets() {
        let speed = utils::speed_optimized_params();
        assert_eq!(speed.quantization_type, QuantizationType::Q4_0);

        let quality = utils::quality_optimized_params();
        assert_eq!(quality.quantization_type, QuantizationType::Q8_0);
        assert_eq!(quality.quality_threshold, 0.98);

        let size = utils::size_optimized_params();
        assert_eq!(size.quantization_type, QuantizationType::Q2_K);
        assert_eq!(size.quality_threshold, 0.90);
    }
}