use crate::{
device_info::{MobileDeviceInfo, PerformanceTier},
optimization::memory_pool::MobileMemoryPool,
Result,
};
use half::f16;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::{invalid_config, invalid_input};
use trustformers_core::Tensor;
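/// On-device quantization engine that converts full-precision model weights
/// into lower-precision representations (INT4, INT8, FP16, and mixed schemes)
/// chosen from the detected device's performance tier and hardware features.
/// Quantized models are cached per model id so repeated requests reuse
/// earlier work.
///
/// A minimal usage sketch (assuming device detection succeeds and
/// `model_bytes` holds the serialized model; error handling elided):
///
/// ```ignore
/// let engine = MobileQuantizationEngine::new(
///     QuantizationConfig::default(),
///     crate::device_info::MobileDeviceDetector::detect()?,
///     None,
/// )?;
/// let quantized = engine.quantize_model("example-model", &model_bytes)?;
/// println!("compression ratio: {:.2}x", quantized.metadata.compression_ratio);
/// ```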
pub struct MobileQuantizationEngine {
config: QuantizationConfig,
device_info: MobileDeviceInfo,
calibration_data: Option<CalibrationDataset>,
quantization_cache: Arc<Mutex<HashMap<String, QuantizedModel>>>,
memory_pool: Option<Arc<MobileMemoryPool>>,
}
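/// Configuration controlling the target precision, quantization granularity,
/// and the quality/memory constraints the engine tries to honor.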
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
pub target_precision: MobilePrecision,
pub enable_mixed_precision: bool,
pub dynamic_strategy: DynamicQuantizationStrategy,
pub hardware_aware: bool,
pub granularity: QuantizationGranularity,
pub quality_threshold: f32,
pub memory_constraint_mb: usize,
pub enable_gradient_quantization: bool,
pub kl_threshold: f32,
pub enable_ptq: bool,
pub enable_qat: bool,
}
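/// Target numeric precision for quantized weights. `DYNAMIC` defers the choice
/// to the engine, which selects a precision from the device's performance tier.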
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MobilePrecision {
INT4,
INT8,
FP16,
Mixed4_8,
Mixed8_16,
DYNAMIC,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DynamicQuantizationStrategy {
BatteryAware,
ThermalAware,
MemoryAware,
PerformanceAware,
Adaptive,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
PerTensor,
PerChannel,
PerGroup { group_size: usize },
PerLayer,
}
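/// Representative input samples (optionally weighted) plus precomputed
/// statistics, used to estimate how much quality calibration preserves.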
#[derive(Debug, Clone)]
pub struct CalibrationDataset {
pub samples: Vec<Tensor>,
pub weights: Option<Vec<f32>>,
pub statistics: DatasetStatistics,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DatasetStatistics {
pub activation_ranges: HashMap<String, (f32, f32)>,
pub layer_means: HashMap<String, f32>,
pub layer_variances: HashMap<String, f32>,
pub kl_scores: HashMap<String, f32>,
}
#[derive(Debug, Clone)]
pub struct QuantizedModel {
pub weights: HashMap<String, QuantizedTensor>,
pub parameters: QuantizationParameters,
pub metadata: ModelMetadata,
pub benchmarks: QuantizationBenchmarks,
}
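/// Quantized weights for a single layer. `data` holds the raw quantized bytes:
/// one byte per value for the INT4/INT8 paths and two little-endian bytes per
/// value for FP16. For the linear schemes a value is recovered approximately
/// as `scale * (q - zero_point)` using the per-tensor `scales[0]` and
/// `zero_points[0]`.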
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
pub data: Vec<i8>,
pub scales: Vec<f32>,
pub zero_points: Vec<i32>,
pub shape: Vec<usize>,
pub scheme: QuantizationScheme,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationScheme {
pub bits: u8,
pub symmetric: bool,
pub signed: bool,
pub method: QuantizationMethod,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum QuantizationMethod {
Linear,
Logarithmic,
PowerOfTwo,
KMeans,
Learned,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelFormat {
SafeTensors,
PyTorchPickle,
TensorFlow,
ONNX,
Custom,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParameters {
pub global_scale: f32,
pub layer_scales: HashMap<String, f32>,
pub layer_zero_points: HashMap<String, i32>,
pub dequant_overhead_ms: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub original_size_bytes: usize,
pub quantized_size_bytes: usize,
pub compression_ratio: f32,
pub quality_score: f32,
pub timestamp: std::time::SystemTime,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationBenchmarks {
pub original_inference_ms: f32,
pub quantized_inference_ms: f32,
pub speedup_factor: f32,
pub memory_reduction_mb: f32,
pub power_reduction_mw: f32,
}
impl MobileQuantizationEngine {
pub fn new(
config: QuantizationConfig,
device_info: MobileDeviceInfo,
memory_pool: Option<Arc<MobileMemoryPool>>,
) -> Result<Self> {
Ok(Self {
config,
device_info,
calibration_data: None,
quantization_cache: Arc::new(Mutex::new(HashMap::new())),
memory_pool,
})
}
pub fn set_calibration_data(&mut self, dataset: CalibrationDataset) -> Result<()> {
if dataset.samples.is_empty() {
return Err(invalid_config(
"set_calibration_data",
"Calibration dataset cannot be empty",
));
}
self.calibration_data = Some(dataset);
Ok(())
}
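    /// Quantizes a serialized model, returning the cached result when the same
    /// `model_id` has been quantized before. The pipeline is: pick a precision
    /// strategy for this device, derive a hardware-aware config, quantize the
    /// parsed weights, then attach benchmark estimates.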
pub fn quantize_model(&self, model_id: &str, model_data: &[u8]) -> Result<QuantizedModel> {
{
            let cache = self.quantization_cache.lock().expect("quantization cache mutex poisoned");
if let Some(cached_model) = cache.get(model_id) {
return Ok(cached_model.clone());
}
}
let strategy = self.determine_quantization_strategy()?;
let hardware_config = self.get_hardware_quantization_config()?;
let quantized_model = self.perform_quantization(model_data, &strategy, &hardware_config)?;
let benchmarks = self.benchmark_quantized_model(&quantized_model)?;
let final_model = QuantizedModel {
weights: quantized_model.weights,
parameters: quantized_model.parameters,
metadata: quantized_model.metadata,
benchmarks,
};
{
            let mut cache = self.quantization_cache.lock().expect("quantization cache mutex poisoned");
cache.insert(model_id.to_string(), final_model.clone());
}
Ok(final_model)
}
fn determine_quantization_strategy(&self) -> Result<MobilePrecision> {
match (
&self.device_info.performance_scores.overall_tier,
&self.config.target_precision,
) {
(PerformanceTier::High, MobilePrecision::DYNAMIC) => {
Ok(MobilePrecision::Mixed8_16)
},
(PerformanceTier::Mid, MobilePrecision::DYNAMIC) => {
Ok(MobilePrecision::INT8)
},
(PerformanceTier::Budget, MobilePrecision::DYNAMIC) => {
Ok(MobilePrecision::Mixed4_8)
},
(_, precision) => Ok(*precision),
}
}
fn get_hardware_quantization_config(&self) -> Result<HardwareQuantizationConfig> {
let mut config = HardwareQuantizationConfig::default();
if self.device_info.npu_info.is_some() {
config.use_npu_kernels = true;
config.preferred_precision = MobilePrecision::INT8;
}
if self.device_info.gpu_info.is_some() {
config.use_gpu_kernels = true;
config.gpu_memory_optimization = true;
}
if self.device_info.cpu_info.architecture.contains("arm")
|| self.device_info.cpu_info.architecture.contains("aarch64")
{
config.use_neon_instructions = true;
config.arm_specific_kernels = true;
}
Ok(config)
}
fn perform_quantization(
&self,
model_data: &[u8],
strategy: &MobilePrecision,
hardware_config: &HardwareQuantizationConfig,
) -> Result<QuantizedModel> {
let weights = self.parse_model_weights(model_data)?;
let quantized_weights = match strategy {
MobilePrecision::INT4 => self.quantize_to_int4(&weights)?,
MobilePrecision::INT8 => self.quantize_to_int8(&weights)?,
MobilePrecision::FP16 => self.quantize_to_fp16(&weights)?,
MobilePrecision::Mixed4_8 => self.quantize_mixed_4_8(&weights)?,
MobilePrecision::Mixed8_16 => self.quantize_mixed_8_16(&weights)?,
MobilePrecision::DYNAMIC => self.quantize_dynamic(&weights)?,
};
let parameters = self.calculate_quantization_parameters(&quantized_weights)?;
let metadata = ModelMetadata {
original_size_bytes: model_data.len(),
quantized_size_bytes: self.calculate_quantized_size(&quantized_weights),
compression_ratio: model_data.len() as f32
/ self.calculate_quantized_size(&quantized_weights) as f32,
quality_score: self.estimate_quality_score(&quantized_weights)?,
timestamp: std::time::SystemTime::now(),
};
Ok(QuantizedModel {
weights: quantized_weights,
parameters,
metadata,
            benchmarks: QuantizationBenchmarks::default(),
        })
}
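    // Asymmetric 4-bit linear quantization (one value per byte, range 0..=15):
    // q = clamp(round(x / scale + zero_point), 0, 15), with
    // scale = (max - min) / 15 and zero_point = round(-min / scale).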
fn quantize_to_int4(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
let mut quantized = HashMap::new();
for (layer_name, tensor) in weights {
let tensor_data = tensor.data()?.to_vec();
let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            // Guard against all-constant tensors, which would yield a zero scale.
            let range = (max_val - min_val).max(f32::EPSILON);
            let scale = range / 15.0;
            let zero_point = (-min_val / scale).round() as i32;
let quantized_data: Vec<i8> = tensor_data
.iter()
.map(|&x| {
let quantized = ((x / scale) + zero_point as f32).round();
quantized.max(0.0).min(15.0) as i8
})
.collect();
let quantized_tensor = QuantizedTensor {
data: quantized_data,
scales: vec![scale],
zero_points: vec![zero_point],
shape: tensor.shape().to_vec(),
scheme: QuantizationScheme {
bits: 4,
symmetric: false,
signed: false,
method: QuantizationMethod::Linear,
},
};
quantized.insert(layer_name.clone(), quantized_tensor);
}
Ok(quantized)
}
fn quantize_to_int8(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
let mut quantized = HashMap::new();
for (layer_name, tensor) in weights {
let tensor_data = tensor.data()?.to_vec();
let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            // Guard against all-constant tensors, which would yield a zero scale.
            let range = (max_val - min_val).max(f32::EPSILON);
            let scale = range / 255.0;
            let zero_point = (-min_val / scale).round() as i32;
let quantized_data: Vec<i8> = tensor_data
.iter()
.map(|&x| {
let quantized = ((x / scale) + zero_point as f32).round();
                    // Shift from the unsigned 0..=255 range into signed i8 storage.
                    (quantized.clamp(0.0, 255.0) as i32 - 128) as i8
                })
.collect();
let quantized_tensor = QuantizedTensor {
data: quantized_data,
scales: vec![scale],
                zero_points: vec![zero_point - 128], // shifted to match the signed i8 data
shape: tensor.shape().to_vec(),
scheme: QuantizationScheme {
bits: 8,
symmetric: false,
signed: true,
method: QuantizationMethod::Linear,
},
};
quantized.insert(layer_name.clone(), quantized_tensor);
}
Ok(quantized)
}
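    // "FP16 quantization" is a cast to half precision: each f32 is converted
    // to an f16 and its bits are stored as two little-endian bytes, so the
    // scale and zero point are recorded as identity values.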
fn quantize_to_fp16(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
let mut quantized = HashMap::new();
for (layer_name, tensor) in weights {
let tensor_data = tensor.data()?.to_vec();
let quantized_data: Vec<i8> = tensor_data
.iter()
.flat_map(|&x| {
let fp16_bits = f16::from_f32(x).to_bits();
[(fp16_bits & 0xFF) as i8, ((fp16_bits >> 8) & 0xFF) as i8]
})
.collect();
let quantized_tensor = QuantizedTensor {
data: quantized_data,
                scales: vec![1.0],
                zero_points: vec![0],
shape: tensor.shape().to_vec(),
scheme: QuantizationScheme {
bits: 16,
symmetric: true,
signed: true,
method: QuantizationMethod::Linear,
},
};
quantized.insert(layer_name.clone(), quantized_tensor);
}
Ok(quantized)
}
fn quantize_mixed_4_8(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
let mut quantized = HashMap::new();
for (layer_name, tensor) in weights {
let use_4bit = self.should_use_4bit_for_layer(layer_name, tensor)?;
if use_4bit {
                let quantized_4bit = self
                    .quantize_to_int4(&HashMap::from([(layer_name.clone(), tensor.clone())]))?;
quantized.extend(quantized_4bit);
} else {
                let quantized_8bit = self
                    .quantize_to_int8(&HashMap::from([(layer_name.clone(), tensor.clone())]))?;
quantized.extend(quantized_8bit);
}
}
Ok(quantized)
}
fn quantize_mixed_8_16(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
let mut quantized = HashMap::new();
for (layer_name, tensor) in weights {
let use_8bit = self.should_use_8bit_for_layer(layer_name, tensor)?;
if use_8bit {
                let quantized_8bit = self
                    .quantize_to_int8(&HashMap::from([(layer_name.clone(), tensor.clone())]))?;
quantized.extend(quantized_8bit);
} else {
                let quantized_16bit = self
                    .quantize_to_fp16(&HashMap::from([(layer_name.clone(), tensor.clone())]))?;
quantized.extend(quantized_16bit);
}
}
Ok(quantized)
}
fn quantize_dynamic(
&self,
weights: &HashMap<String, Tensor>,
) -> Result<HashMap<String, QuantizedTensor>> {
self.quantize_to_int8(weights)
}
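    // Heuristics for the mixed-precision schemes: embedding-like and very large
    // layers tolerate the lower precision of each pair, while attention,
    // output/head, and normalization layers are kept at the higher precision
    // to protect quality.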
fn should_use_4bit_for_layer(&self, layer_name: &str, tensor: &Tensor) -> Result<bool> {
let is_embedding = layer_name.contains("embed") || layer_name.contains("token");
let is_output = layer_name.contains("output") || layer_name.contains("head");
let is_large = tensor.shape().iter().product::<usize>() > 1000000;
Ok(is_embedding || (is_large && !is_output))
}
fn should_use_8bit_for_layer(&self, layer_name: &str, tensor: &Tensor) -> Result<bool> {
let is_attention = layer_name.contains("attn") || layer_name.contains("attention");
let is_output = layer_name.contains("output") || layer_name.contains("head");
let is_norm = layer_name.contains("norm") || layer_name.contains("ln");
Ok(!(is_attention || is_output || is_norm))
}
fn parse_model_weights(&self, model_data: &[u8]) -> Result<HashMap<String, Tensor>> {
        let format = self.detect_model_format(model_data)?;
        let weights = match format {
            ModelFormat::SafeTensors => self.parse_safetensors(model_data)?,
            ModelFormat::PyTorchPickle => self.parse_pytorch_pickle(model_data)?,
            ModelFormat::TensorFlow => self.parse_tensorflow(model_data)?,
            ModelFormat::ONNX => self.parse_onnx(model_data)?,
            ModelFormat::Custom => self.parse_custom_format(model_data)?,
        };
self.validate_parsed_weights(&weights)?;
Ok(weights)
}
fn calculate_quantization_parameters(
&self,
weights: &HashMap<String, QuantizedTensor>,
) -> Result<QuantizationParameters> {
let mut layer_scales = HashMap::new();
let mut layer_zero_points = HashMap::new();
let mut total_dequant_overhead = 0.0;
for (layer_name, quantized_tensor) in weights {
let scale = quantized_tensor.scales.first().copied().unwrap_or(1.0);
let zero_point = quantized_tensor.zero_points.first().copied().unwrap_or(0);
layer_scales.insert(layer_name.clone(), scale);
layer_zero_points.insert(layer_name.clone(), zero_point);
let tensor_size = quantized_tensor.data.len();
            let overhead_factor = match quantized_tensor.scheme.bits {
                4 => 0.05,
                8 => 0.03,
                16 => 0.01,
                _ => 0.04,
            };
total_dequant_overhead += (tensor_size as f32 * overhead_factor) / 1000.0;
}
        let global_scale = if layer_scales.is_empty() {
            1.0
        } else {
            layer_scales.values().sum::<f32>() / layer_scales.len() as f32
        };
Ok(QuantizationParameters {
global_scale,
layer_scales,
layer_zero_points,
dequant_overhead_ms: total_dequant_overhead,
})
}
fn calculate_quantized_size(&self, weights: &HashMap<String, QuantizedTensor>) -> usize {
let mut total_size = 0;
for (layer_name, quantized_tensor) in weights {
let data_size = quantized_tensor.data.len();
            let metadata_size = quantized_tensor.scales.len() * 4
                + quantized_tensor.zero_points.len() * 4
                + quantized_tensor.shape.len() * 8
                + layer_name.len()
                + 32;
total_size += data_size + metadata_size;
}
total_size
}
fn estimate_quality_score(&self, weights: &HashMap<String, QuantizedTensor>) -> Result<f32> {
if weights.is_empty() {
return Ok(1.0);
}
let mut total_quality = 0.0;
let mut total_weight = 0.0;
for (layer_name, quantized_tensor) in weights {
let layer_weight = quantized_tensor.data.len() as f32;
            let base_quality = match quantized_tensor.scheme.bits {
                4 => 0.85,
                8 => 0.93,
                16 => 0.98,
                _ => 0.90,
            };
let layer_quality_factor = self.get_layer_quality_factor(layer_name);
let adjusted_quality = base_quality * layer_quality_factor;
total_quality += adjusted_quality * layer_weight;
total_weight += layer_weight;
}
        let overall_quality = if total_weight > 0.0 {
            total_quality / total_weight
        } else {
            1.0
        };
        let calibration_factor = if let Some(ref cal_data) = self.calibration_data {
            self.estimate_calibration_quality_impact(cal_data)?
        } else {
            0.95
        };
Ok((overall_quality * calibration_factor).min(1.0))
}
fn benchmark_quantized_model(&self, model: &QuantizedModel) -> Result<QuantizationBenchmarks> {
let mut benchmarks = QuantizationBenchmarks::default();
        // Assume roughly 4 bytes (FP32) per parameter in the original model.
        let original_params = model.metadata.original_size_bytes / 4;
        benchmarks.original_inference_ms =
            self.estimate_inference_time(original_params, MobilePrecision::FP16)?;
let quantized_params = model.metadata.quantized_size_bytes;
let avg_precision = self.estimate_average_precision(&model.weights);
benchmarks.quantized_inference_ms =
self.estimate_inference_time(quantized_params, avg_precision)?;
benchmarks.speedup_factor = if benchmarks.quantized_inference_ms > 0.0 {
benchmarks.original_inference_ms / benchmarks.quantized_inference_ms
} else {
1.0
};
benchmarks.memory_reduction_mb = (model.metadata.original_size_bytes
- model.metadata.quantized_size_bytes) as f32
/ (1024.0 * 1024.0);
benchmarks.power_reduction_mw = self.estimate_power_reduction(&model.weights)?;
Ok(benchmarks)
}
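    // Best-effort format detection from leading magic bytes. The "STFR" and
    // "TF" markers are conventions assumed by this crate rather than official
    // file signatures; the pickle check matches the protocol 2/3 magic
    // (0x80 0x02 / 0x80 0x03) and the ONNX check matches a protobuf
    // `ir_version` field tag.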
fn detect_model_format(&self, data: &[u8]) -> Result<ModelFormat> {
if data.len() < 8 {
return Err(invalid_input("Model data too small to detect format"));
}
if data.starts_with(b"STFR") || data.starts_with(&[0x53, 0x54, 0x46, 0x52]) {
return Ok(ModelFormat::SafeTensors);
}
if data.starts_with(&[0x80, 0x02]) || data.starts_with(&[0x80, 0x03]) {
return Ok(ModelFormat::PyTorchPickle);
}
if data.starts_with(b"TF") {
return Ok(ModelFormat::TensorFlow);
}
if data.starts_with(&[0x08, 0x01]) {
return Ok(ModelFormat::ONNX);
}
Ok(ModelFormat::Custom)
}
fn parse_safetensors(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
Ok(HashMap::new())
}
fn parse_pytorch_pickle(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
Ok(HashMap::new())
}
fn parse_tensorflow(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
Ok(HashMap::new())
}
fn parse_onnx(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
Ok(HashMap::new())
}
fn parse_custom_format(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
Ok(HashMap::new())
}
fn validate_parsed_weights(&self, weights: &HashMap<String, Tensor>) -> Result<()> {
if weights.is_empty() {
return Err(invalid_input("No weights found in model"));
}
for (layer_name, tensor) in weights {
if tensor.shape().is_empty() {
return Err(invalid_input(format!(
"Invalid tensor shape for layer: {}",
layer_name
)));
}
let total_elements: usize = tensor.shape().iter().product();
if total_elements == 0 {
return Err(invalid_input(format!(
"Empty tensor for layer: {}",
layer_name
)));
}
if total_elements > 100_000_000 {
tracing::warn!(
"Large tensor detected in layer {}: {} elements",
layer_name,
total_elements
);
}
}
Ok(())
}
    fn get_layer_quality_factor(&self, layer_name: &str) -> f32 {
        if layer_name.contains("output") || layer_name.contains("head") {
            0.95
        } else if layer_name.contains("attention") || layer_name.contains("attn") {
            0.92
        } else if layer_name.contains("norm") || layer_name.contains("ln") {
            0.98
        } else if layer_name.contains("embed") || layer_name.contains("token") {
            0.90
        } else {
            1.0
        }
    }
fn estimate_calibration_quality_impact(&self, cal_data: &CalibrationDataset) -> Result<f32> {
let sample_factor = (cal_data.samples.len() as f32 / 100.0).min(1.0);
let stats_quality =
if !cal_data.statistics.activation_ranges.is_empty() { 1.0 } else { 0.9 };
Ok(0.95 + 0.05 * sample_factor * stats_quality)
}
fn estimate_inference_time(&self, params: usize, precision: MobilePrecision) -> Result<f32> {
        let base_time_per_param = match self.device_info.performance_scores.overall_tier {
            PerformanceTier::VeryLow => 0.01,
            PerformanceTier::Low => 0.008,
            PerformanceTier::Budget => 0.005,
            PerformanceTier::Medium => 0.003,
            PerformanceTier::Mid => 0.002,
            PerformanceTier::High => 0.001,
            PerformanceTier::VeryHigh => 0.0007,
            PerformanceTier::Flagship => 0.0005,
        };
let precision_factor = match precision {
MobilePrecision::INT4 => 0.5,
MobilePrecision::INT8 => 0.7,
MobilePrecision::FP16 => 1.0,
MobilePrecision::Mixed4_8 => 0.6,
MobilePrecision::Mixed8_16 => 0.85,
MobilePrecision::DYNAMIC => 0.8,
};
let hw_factor = if self.device_info.npu_info.is_some() {
0.6
} else if self.device_info.gpu_info.is_some() {
0.8
} else {
1.0
};
let total_time = params as f32 * base_time_per_param * precision_factor * hw_factor;
Ok(total_time)
}
fn estimate_average_precision(
&self,
weights: &HashMap<String, QuantizedTensor>,
) -> MobilePrecision {
if weights.is_empty() {
return MobilePrecision::FP16;
}
let mut total_bits = 0;
let mut total_tensors = 0;
for tensor in weights.values() {
total_bits += tensor.scheme.bits as u32;
total_tensors += 1;
}
let avg_bits = total_bits as f32 / total_tensors as f32;
match avg_bits.round() as u8 {
4 => MobilePrecision::INT4,
8 => MobilePrecision::INT8,
16 => MobilePrecision::FP16,
            _ => MobilePrecision::INT8,
        }
}
fn estimate_power_reduction(&self, weights: &HashMap<String, QuantizedTensor>) -> Result<f32> {
let mut total_power_reduction = 0.0;
for tensor in weights.values() {
let tensor_size = tensor.data.len() as f32;
            let reduction_per_op = match tensor.scheme.bits {
                4 => 0.08,
                8 => 0.05,
                16 => 0.02,
                _ => 0.04,
            };
total_power_reduction += tensor_size * reduction_per_op / 1000.0;
}
Ok(total_power_reduction)
}
}
#[derive(Debug, Clone)]
struct HardwareQuantizationConfig {
use_npu_kernels: bool,
use_gpu_kernels: bool,
use_neon_instructions: bool,
arm_specific_kernels: bool,
gpu_memory_optimization: bool,
preferred_precision: MobilePrecision,
}
impl Default for HardwareQuantizationConfig {
fn default() -> Self {
Self {
use_npu_kernels: false,
use_gpu_kernels: false,
use_neon_instructions: false,
arm_specific_kernels: false,
gpu_memory_optimization: false,
preferred_precision: MobilePrecision::INT8,
}
}
}
impl Default for QuantizationBenchmarks {
fn default() -> Self {
Self {
original_inference_ms: 0.0,
quantized_inference_ms: 0.0,
speedup_factor: 1.0,
memory_reduction_mb: 0.0,
power_reduction_mw: 0.0,
}
}
}
impl Default for QuantizationConfig {
fn default() -> Self {
Self {
target_precision: MobilePrecision::INT8,
enable_mixed_precision: true,
dynamic_strategy: DynamicQuantizationStrategy::Adaptive,
hardware_aware: true,
granularity: QuantizationGranularity::PerChannel,
quality_threshold: 0.9,
memory_constraint_mb: 512,
enable_gradient_quantization: false,
kl_threshold: 0.01,
enable_ptq: true,
enable_qat: false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::Tensor;
#[test]
fn test_model_format_detection() {
let engine = create_test_engine();
let safetensors_data = b"STFR\x00\x00\x00\x00test data";
let format = engine.detect_model_format(safetensors_data).expect("Operation failed");
assert_eq!(format, ModelFormat::SafeTensors);
let pytorch_data = b"\x80\x02test data";
let format = engine.detect_model_format(pytorch_data).expect("Operation failed");
assert_eq!(format, ModelFormat::PyTorchPickle);
let tf_data = b"TFtest data";
let format = engine.detect_model_format(tf_data).expect("Operation failed");
assert_eq!(format, ModelFormat::TensorFlow);
let onnx_data = b"\x08\x01test data";
let format = engine.detect_model_format(onnx_data).expect("Operation failed");
assert_eq!(format, ModelFormat::ONNX);
let custom_data = b"custom test data";
let format = engine.detect_model_format(custom_data).expect("Operation failed");
assert_eq!(format, ModelFormat::Custom);
}
#[test]
fn test_quantization_parameters_calculation() {
let engine = create_test_engine();
let weights = create_test_quantized_weights();
let params = engine.calculate_quantization_parameters(&weights).expect("Operation failed");
assert!(params.global_scale > 0.0);
assert!(!params.layer_scales.is_empty());
assert!(!params.layer_zero_points.is_empty());
assert!(params.dequant_overhead_ms >= 0.0);
}
#[test]
fn test_quality_score_estimation() {
let engine = create_test_engine();
let weights = create_test_quantized_weights();
let quality = engine.estimate_quality_score(&weights).expect("Operation failed");
assert!((0.0..=1.0).contains(&quality));
}
#[test]
fn test_layer_quality_factors() {
let engine = create_test_engine();
assert_eq!(engine.get_layer_quality_factor("model.output.weight"), 0.95);
assert_eq!(
engine.get_layer_quality_factor("model.attention.weight"),
0.92
);
assert_eq!(
engine.get_layer_quality_factor("model.layer_norm.weight"),
0.98
);
assert_eq!(
engine.get_layer_quality_factor("model.embedding.weight"),
0.90
);
assert_eq!(engine.get_layer_quality_factor("model.hidden.weight"), 1.0);
}
#[test]
fn test_inference_time_estimation() {
let engine = create_test_engine();
let time = engine
.estimate_inference_time(1000, MobilePrecision::INT8)
.expect("Operation failed");
assert!(time > 0.0);
let time_fp16 = engine
.estimate_inference_time(1000, MobilePrecision::FP16)
.expect("Operation failed");
let time_int4 = engine
.estimate_inference_time(1000, MobilePrecision::INT4)
.expect("Operation failed");
assert!(time_int4 < time_fp16);
}
#[test]
fn test_power_reduction_estimation() {
let engine = create_test_engine();
let weights = create_test_quantized_weights();
let power_reduction = engine.estimate_power_reduction(&weights).expect("Operation failed");
assert!(power_reduction >= 0.0);
}
#[test]
fn test_quantized_size_calculation() {
let engine = create_test_engine();
let weights = create_test_quantized_weights();
let size = engine.calculate_quantized_size(&weights);
assert!(size > 0);
}
#[test]
fn test_weight_validation() {
let engine = create_test_engine();
let empty_weights = HashMap::new();
assert!(engine.validate_parsed_weights(&empty_weights).is_err());
let valid_weights = create_test_weights();
assert!(engine.validate_parsed_weights(&valid_weights).is_ok());
}
#[test]
fn test_calibration_data_validation() {
let mut engine = create_test_engine();
let empty_dataset = CalibrationDataset {
samples: vec![],
weights: None,
statistics: DatasetStatistics::default(),
};
assert!(engine.set_calibration_data(empty_dataset).is_err());
let valid_dataset = CalibrationDataset {
samples: vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Operation failed")],
weights: None,
statistics: DatasetStatistics::default(),
};
assert!(engine.set_calibration_data(valid_dataset).is_ok());
}
fn create_test_engine() -> MobileQuantizationEngine {
let config = QuantizationConfig::default();
let device_info = crate::device_info::MobileDeviceDetector::detect().unwrap_or_else(|_| {
use crate::device_info::*;
MobileDeviceInfo {
platform: crate::MobilePlatform::Generic,
basic_info: BasicDeviceInfo {
platform: crate::MobilePlatform::Generic,
manufacturer: "Test".to_string(),
model: "Test Device".to_string(),
os_version: "1.0".to_string(),
hardware_id: "test".to_string(),
device_generation: Some(2023),
},
cpu_info: CpuInfo {
architecture: "aarch64".to_string(),
total_cores: 8,
core_count: 8,
performance_cores: 4,
efficiency_cores: 4,
max_frequency_mhz: Some(3000),
l1_cache_kb: Some(64),
l2_cache_kb: Some(512),
l3_cache_kb: Some(4096),
features: vec!["neon".to_string(), "fp16".to_string()],
simd_support: SimdSupport::Advanced,
},
memory_info: MemoryInfo {
total_mb: 8192,
available_mb: 6144,
total_memory: 8192,
available_memory: 6144,
bandwidth_mbps: Some(51200),
memory_type: "LPDDR5".to_string(),
frequency_mhz: Some(6400),
is_low_memory_device: false,
},
gpu_info: Some(GpuInfo {
vendor: "ARM".to_string(),
model: "Mali-G78".to_string(),
driver_version: "1.0".to_string(),
memory_mb: Some(2048),
compute_units: Some(14),
supported_apis: vec![GpuApi::OpenGLES3, GpuApi::Vulkan11],
performance_tier: GpuPerformanceTier::High,
}),
npu_info: None,
thermal_info: ThermalInfo {
current_state: ThermalState::Nominal,
state: ThermalState::Nominal,
throttling_supported: true,
temperature_sensors: vec![],
thermal_zones: vec![],
},
power_info: PowerInfo {
battery_capacity_mah: Some(4000),
battery_level_percent: Some(80),
battery_level: Some(80),
battery_health_percent: Some(100),
charging_status: ChargingStatus::Discharging,
is_charging: false,
power_save_mode: false,
low_power_mode_available: true,
},
available_backends: vec![crate::MobileBackend::CPU, crate::MobileBackend::GPU],
performance_scores: PerformanceScores {
cpu_single_core: Some(1200),
cpu_multi_core: Some(4800),
gpu_score: Some(2500),
memory_score: Some(1800),
overall_tier: PerformanceTier::Mid,
tier: PerformanceTier::Mid,
},
}
});
MobileQuantizationEngine::new(config, device_info, None).expect("Operation failed")
}
fn create_test_quantized_weights() -> HashMap<String, QuantizedTensor> {
let mut weights = HashMap::new();
weights.insert(
"layer1.weight".to_string(),
QuantizedTensor {
data: vec![1, 2, 3, 4, 5],
scales: vec![0.1],
zero_points: vec![0],
shape: vec![5],
scheme: QuantizationScheme {
bits: 8,
symmetric: false,
signed: true,
method: QuantizationMethod::Linear,
},
},
);
weights.insert(
"layer2.weight".to_string(),
QuantizedTensor {
data: vec![6, 7, 8, 9, 10],
scales: vec![0.2],
zero_points: vec![1],
shape: vec![5],
scheme: QuantizationScheme {
bits: 4,
symmetric: false,
signed: false,
method: QuantizationMethod::Linear,
},
},
);
weights
}
fn create_test_weights() -> HashMap<String, Tensor> {
let mut weights = HashMap::new();
weights.insert(
"layer1.weight".to_string(),
Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Operation failed"),
);
weights.insert(
"layer2.weight".to_string(),
Tensor::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], &[5]).expect("Operation failed"),
);
weights
}
}