use crate::{config::QuantConfig, observers::Observer};
use torsh_core::{error::Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
fn quantize_per_tensor(
tensor: &Tensor,
scale: f32,
zero_point: i32,
_dtype: torsh_core::DType,
) -> TorshResult<Tensor> {
let (quantized, _, _) =
crate::algorithms::quantize_per_tensor_affine(tensor, scale, zero_point)?;
Ok(quantized)
}
#[allow(dead_code)]
fn dequantize(tensor: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
crate::algorithms::dequantize_per_tensor_affine(tensor, scale, zero_point)
}
pub fn validate_config_with_suggestions(config: &QuantConfig) -> TorshResult<Vec<String>> {
use crate::config::{ObserverType, QScheme, QuantBackend};
let mut suggestions = Vec::new();
config.validate()?;
match config.scheme {
QScheme::PerChannelAffine | QScheme::PerChannelSymmetric => {
if config.observer_type == ObserverType::MinMax {
suggestions.push("Consider using Histogram observer for per-channel quantization for better accuracy".to_string());
}
}
QScheme::GroupWise => {
if let Some(group_size) = config.group_size {
if group_size < 8 {
suggestions.push("Very small group sizes may not provide significant benefits over per-channel quantization".to_string());
} else if group_size > 128 {
suggestions.push(
"Large group sizes may reduce the benefits of group-wise quantization"
.to_string(),
);
}
}
}
QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
if config.observer_type == ObserverType::MinMax {
suggestions.push("Consider using Histogram observer for INT4 quantization to handle outliers better".to_string());
}
}
QScheme::Binary | QScheme::Ternary => {
if config.observer_type != ObserverType::MinMax {
suggestions.push(
"MinMax observer is typically sufficient for binary/ternary quantization"
.to_string(),
);
}
}
_ => {}
}
if config.backend == QuantBackend::Native {
suggestions.push(
"Consider using FBGEMM or QNNPACK backends for better performance in production"
.to_string(),
);
}
if config.enable_fake_quant && config.observer_type != ObserverType::MovingAverage {
suggestions
.push("MovingAverage observer is recommended for QAT (fake quantization)".to_string());
}
Ok(suggestions)
}
pub fn create_optimized_config(use_case: &str, target_platform: &str) -> TorshResult<QuantConfig> {
use crate::config::{ObserverType, QuantBackend, ReduceRange};
let base_config = match use_case.to_lowercase().as_str() {
"inference_cpu" => QuantConfig::int8()
.with_backend(QuantBackend::Fbgemm)
.with_observer(ObserverType::Histogram),
"inference_mobile" => QuantConfig::int8()
.with_backend(QuantBackend::Qnnpack)
.with_observer(ObserverType::MinMax)
.with_reduce_range(ReduceRange::Reduce),
"training" => QuantConfig::qat().with_observer(ObserverType::MovingAverage),
"extreme_compression" => QuantConfig::int4().with_observer(ObserverType::Histogram),
"transformers" => QuantConfig::group_wise(0, 64).with_observer(ObserverType::Histogram),
"edge_device" => QuantConfig::binary().with_observer(ObserverType::MinMax),
_ => {
return Err(TorshError::InvalidArgument(format!(
"Unknown use case: {use_case}"
)))
}
};
let optimized_config = match target_platform.to_lowercase().as_str() {
"x86" | "x64" => base_config.with_backend(QuantBackend::Fbgemm),
"arm" | "mobile" => base_config.with_backend(QuantBackend::Qnnpack),
"gpu" => base_config.with_backend(QuantBackend::Native),
_ => base_config,
};
Ok(optimized_config)
}
pub fn quantize_batch_consistent(
tensors: &[&Tensor],
config: &QuantConfig,
) -> TorshResult<Vec<(Tensor, f32, i32)>> {
if tensors.is_empty() {
return Ok(Vec::new());
}
let mut global_observer = Observer::new(config.observer_type);
for tensor in tensors {
global_observer.update(tensor)?;
}
let (global_scale, global_zero_point) = global_observer.calculate_qparams(config.dtype)?;
let mut results = Vec::new();
for tensor in tensors {
let quantized = quantize_per_tensor(tensor, global_scale, global_zero_point, config.dtype)?;
results.push((quantized, global_scale, global_zero_point));
}
Ok(results)
}
pub fn diagnose_quantization_failure(
tensor: &Tensor,
config: &QuantConfig,
error: &TorshError,
) -> String {
let mut diagnosis = format!("Quantization failed with error: {error}\n\n");
let shape = tensor.shape();
let data_result = tensor.data();
diagnosis.push_str("Tensor Analysis:\n");
diagnosis.push_str(&format!(" Shape: {:?}\n", shape.dims()));
diagnosis.push_str(&format!(" Total elements: {}\n", shape.numel()));
if let Ok(data) = data_result {
let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let has_nan = data.iter().any(|&x| x.is_nan());
let has_inf = data.iter().any(|&x| x.is_infinite());
diagnosis.push_str(&format!(" Data range: [{min_val:.6}, {max_val:.6}]\n"));
diagnosis.push_str(&format!(" Contains NaN: {has_nan}\n"));
diagnosis.push_str(&format!(" Contains Inf: {has_inf}\n"));
if has_nan || has_inf {
diagnosis.push_str(
"\nSuggestion: Clean tensor data to remove NaN/Inf values before quantization.\n",
);
}
if max_val - min_val < 1e-6 {
diagnosis.push_str("\nSuggestion: Tensor has very small dynamic range. Consider using a different tensor or adjusting the quantization scheme.\n");
}
}
diagnosis.push_str("\nConfiguration Analysis:\n");
diagnosis.push_str(&format!(" Scheme: {:?}\n", config.scheme));
diagnosis.push_str(&format!(" Observer: {:?}\n", config.observer_type));
diagnosis.push_str(&format!(" Backend: {:?}\n", config.backend));
match config.validate() {
Ok(_) => diagnosis.push_str(" Configuration is valid\n"),
Err(e) => diagnosis.push_str(&format!(" Configuration error: {e}\n")),
}
diagnosis.push_str("\nRecovery Suggestions:\n");
diagnosis.push_str(
"1. Try a simpler quantization scheme (e.g., PerTensorAffine with MinMax observer)\n",
);
diagnosis.push_str("2. Use quantize_with_fallback() for automatic error recovery\n");
diagnosis.push_str("3. Check tensor data for NaN/Inf values\n");
diagnosis.push_str("4. Ensure tensor has sufficient dynamic range\n");
diagnosis
.push_str("5. Try a different observer type (Histogram for outlier-robust quantization)\n");
diagnosis
}
pub fn get_optimization_hints(tensor: &Tensor, config: &QuantConfig) -> Vec<String> {
use crate::config::{ObserverType, QScheme};
let mut hints = Vec::new();
let shape = tensor.shape();
let numel = shape.numel();
if numel > 1_000_000 {
hints.push("Large tensor detected. Consider using parallel processing with Rayon for better performance.".to_string());
if config.observer_type == ObserverType::Percentile {
hints.push("For large tensors, Histogram observer may be more memory-efficient than Percentile observer.".to_string());
}
}
if shape.dims().len() >= 2 && shape.dims().iter().any(|&dim| dim > 16) {
hints.push("Multi-channel tensor detected. Per-channel or group-wise quantization may provide better accuracy.".to_string());
}
match config.scheme {
QScheme::PerTensorAffine | QScheme::PerTensorSymmetric => {
if shape.dims().len() > 2 {
hints.push("Consider per-channel quantization for better accuracy with multi-dimensional tensors.".to_string());
}
}
QScheme::GroupWise => {
if let Some(group_size) = config.group_size {
let total_elements = shape.dims().iter().product::<usize>();
if total_elements / group_size < 4 {
hints.push("Too few groups for group-wise quantization. Consider per-tensor quantization instead.".to_string());
}
}
}
QScheme::Int4PerTensor | QScheme::Int4PerChannel => {
hints.push("INT4 quantization detected. Ensure your inference backend supports INT4 operations.".to_string());
}
QScheme::Binary | QScheme::Ternary => {
hints.push(
"Extreme quantization scheme detected. Verify accuracy requirements are met."
.to_string(),
);
}
_ => {}
}
hints
}
pub fn export_config_to_json(config: &QuantConfig) -> TorshResult<String> {
match serde_json::to_string_pretty(config) {
Ok(json) => Ok(json),
Err(e) => Err(TorshError::InvalidArgument(format!(
"Failed to serialize config: {e}"
))),
}
}
pub fn import_config_from_json(json: &str) -> TorshResult<QuantConfig> {
match serde_json::from_str(json) {
Ok(config) => Ok(config),
Err(e) => Err(TorshError::InvalidArgument(format!(
"Failed to deserialize config: {e}"
))),
}
}