use crate::config::{ObserverType, QScheme, QuantBackend, QuantConfig};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// High-level optimization goal that drives automatic configuration selection.
///
/// Each objective maps to a different selection strategy in
/// `AutoConfigurator` and to different scheme/backend scoring tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigObjective {
    /// Prefer the smallest quantized representation (binary/ternary/int4/group-wise).
    MaximumCompression,
    /// Prefer the configuration expected to minimize quantization error.
    MaximumAccuracy,
    /// Trade off compression against accuracy.
    BalancedQuality,
    /// Prefer schemes/backends with the fastest kernels (e.g. per-tensor int8 + FBGEMM).
    MaximumSpeed,
    /// Prefer the lowest-memory representation.
    MinimumMemory,
    /// Prefer mobile/edge-friendly settings (int8 + QNNPACK + cheap observer).
    EdgeOptimized,
}
/// Summary of a tensor's characteristics used to pick a quantization config.
#[derive(Debug, Clone)]
pub struct TensorProfile {
    /// Tensor dimensions as reported by `tensor.shape().dims()`.
    pub shape: Vec<usize>,
    /// Total element count.
    pub numel: usize,
    /// Basic statistics (min/max/mean/std, outlier flag, near-zero ratio).
    pub stats: TensorStats,
    /// Fraction of elements whose absolute value is below 1e-8 (in [0, 1]).
    pub sparsity: f32,
    /// Coarse classification of the value distribution.
    pub distribution: DistributionProfile,
}
/// Basic descriptive statistics computed over a tensor's flattened data.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Smallest element value.
    pub min: f32,
    /// Largest element value.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Population standard deviation (divides by N, not N-1).
    pub std_dev: f32,
    /// `max - min`.
    pub range: f32,
    /// True if any value lies outside the 1.5*IQR Tukey fences.
    pub has_outliers: bool,
    /// Fraction of elements within 1% of the range around zero.
    pub near_zero_ratio: f32,
}
/// Coarse shape classification of a tensor's value distribution,
/// derived from sample skewness and kurtosis (see `classify_distribution`).
#[derive(Debug, Clone, PartialEq)]
pub enum DistributionProfile {
    /// Approximately Gaussian: |skewness| < 0.5 and kurtosis near 3.
    Normal,
    /// Flat-ish: kurtosis < 2.
    Uniform,
    /// Fat tails: kurtosis > 4.
    HeavyTailed,
    /// Fallback bucket when no other shape matches.
    Bimodal,
    /// |skewness| > 1.
    Skewed,
    /// More than 60% of values are near zero.
    Sparse,
}
/// Recommends quantization configurations for tensors based on a target
/// objective, tensor statistics, and (optionally) observed performance history.
pub struct AutoConfigurator {
    /// Optimization goal guiding all selection and scoring decisions.
    objective: ConfigObjective,
    /// Observed (config, profile, error, compression) records used to adapt weights.
    history: Vec<ConfigPerformance>,
    /// Learned weights applied when scoring candidate configurations.
    feature_weights: FeatureWeights,
}
/// One observed outcome of applying a quantization config to a tensor.
/// Fed back via `AutoConfigurator::update_performance` to tune feature weights.
#[derive(Debug, Clone)]
struct ConfigPerformance {
    // Retained for future analysis; only `profile` and `error` are read today.
    #[allow(dead_code)]
    config: QuantConfig,
    /// Profile of the tensor the config was applied to.
    profile: TensorProfile,
    /// Observed quantization error (lower is better).
    error: f32,
    #[allow(dead_code)]
    compression: f32,
    #[allow(dead_code)]
    speedup: Option<f32>,
}
/// Multipliers applied to the per-feature scores in
/// `AutoConfigurator::score_configuration`. Adapted from history over time.
#[derive(Debug, Clone)]
struct FeatureWeights {
    /// Weight of the observer-type score (range/outlier handling).
    range_weight: f32,
    /// Weight adjusted based on sparse-vs-dense historical error.
    sparsity_weight: f32,
    /// Weight of the scheme score (distribution fit).
    distribution_weight: f32,
    /// Weight of the tensor-size score.
    size_weight: f32,
}
impl Default for FeatureWeights {
    /// Initial hand-tuned weights; `update_feature_weights` later adjusts
    /// `sparsity_weight` within [0.5, 2.0] based on observed errors.
    fn default() -> Self {
        Self {
            range_weight: 1.0,
            sparsity_weight: 0.8,
            distribution_weight: 0.9,
            size_weight: 0.7,
        }
    }
}
impl AutoConfigurator {
    /// Creates a configurator tuned toward `objective`, with default feature
    /// weights and an empty performance history.
    pub fn new(objective: ConfigObjective) -> Self {
        Self {
            objective,
            history: Vec::new(),
            feature_weights: FeatureWeights::default(),
        }
    }

    /// Recommends a single quantization configuration for `tensor`,
    /// optionally restricted by `constraints`.
    ///
    /// # Errors
    /// Returns an error if the tensor data cannot be read or the tensor is empty.
    pub fn recommend(
        &self,
        tensor: &Tensor,
        constraints: Option<ConfigConstraints>,
    ) -> TorshResult<QuantConfig> {
        let profile = self.analyze_tensor(tensor)?;
        self.select_configuration(&profile, constraints)
    }

    /// Returns up to `top_k` candidate configurations ranked by descending
    /// score. May return fewer than `top_k` entries if constraints filter
    /// out candidates.
    ///
    /// # Errors
    /// Returns an error if the tensor data cannot be read or the tensor is empty.
    pub fn recommend_ranked(
        &self,
        tensor: &Tensor,
        top_k: usize,
        constraints: Option<ConfigConstraints>,
    ) -> TorshResult<Vec<(QuantConfig, f32)>> {
        let profile = self.analyze_tensor(tensor)?;
        let mut candidates = self.generate_candidates(&profile, constraints)?;
        for (config, score) in &mut candidates {
            *score = self.score_configuration(config, &profile);
        }
        // total_cmp is a total order (NaN-safe), unlike partial_cmp.
        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
        candidates.truncate(top_k);
        Ok(candidates)
    }

    /// Records the observed outcome of applying `config` to `tensor`, and
    /// re-tunes the feature weights once enough history has accumulated.
    ///
    /// # Errors
    /// Returns an error if the tensor data cannot be read or the tensor is empty.
    pub fn update_performance(
        &mut self,
        config: &QuantConfig,
        tensor: &Tensor,
        observed_error: f32,
        observed_compression: f32,
        speedup: Option<f32>,
    ) -> TorshResult<()> {
        let profile = self.analyze_tensor(tensor)?;
        self.history.push(ConfigPerformance {
            config: config.clone(),
            profile,
            error: observed_error,
            compression: observed_compression,
            speedup,
        });
        if self.history.len() >= 10 {
            self.update_feature_weights();
        }
        Ok(())
    }

    /// Builds a `TensorProfile` (stats, sparsity, distribution) for `tensor`.
    fn analyze_tensor(&self, tensor: &Tensor) -> TorshResult<TensorProfile> {
        let data = tensor.data()?;
        let shape = tensor.shape().dims().to_vec();
        let numel = tensor.shape().numel();
        let stats = self.calculate_stats(&data)?;
        let sparsity = self.calculate_sparsity(&data);
        let distribution = self.classify_distribution(&data, &stats);
        Ok(TensorProfile {
            shape,
            numel,
            stats,
            sparsity,
            distribution,
        })
    }

    /// Computes descriptive statistics over `data`.
    ///
    /// Outliers are flagged with Tukey's 1.5*IQR fences; `near_zero_ratio`
    /// counts values within 1% of the data range around zero.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for empty input.
    fn calculate_stats(&self, data: &[f32]) -> TorshResult<TensorStats> {
        if data.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot calculate stats for empty tensor".to_string(),
            ));
        }
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let range = max - min;
        let mean = data.iter().sum::<f32>() / data.len() as f32;
        // Population variance (divide by N).
        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let std_dev = variance.sqrt();
        let mut sorted = data.to_vec();
        // total_cmp gives a NaN-safe total order.
        sorted.sort_by(|a, b| a.total_cmp(b));
        let q1_idx = sorted.len() / 4;
        let q3_idx = 3 * sorted.len() / 4;
        let q1 = sorted[q1_idx];
        let q3 = sorted[q3_idx];
        let iqr = q3 - q1;
        let outlier_threshold_low = q1 - 1.5 * iqr;
        let outlier_threshold_high = q3 + 1.5 * iqr;
        let has_outliers = data
            .iter()
            .any(|&x| x < outlier_threshold_low || x > outlier_threshold_high);
        // Floor the threshold so a constant tensor (range == 0) still counts
        // exact zeros as "near zero" instead of using a zero-width window.
        let zero_threshold = (range.abs() * 0.01).max(1e-8);
        let near_zero_count = data.iter().filter(|&&x| x.abs() < zero_threshold).count();
        let near_zero_ratio = near_zero_count as f32 / data.len() as f32;
        Ok(TensorStats {
            min,
            max,
            mean,
            std_dev,
            range,
            has_outliers,
            near_zero_ratio,
        })
    }

    /// Fraction of elements that are (numerically) zero. Returns 0.0 for
    /// empty input rather than a 0/0 NaN.
    fn calculate_sparsity(&self, data: &[f32]) -> f32 {
        if data.is_empty() {
            return 0.0;
        }
        let zero_count = data.iter().filter(|&&x| x.abs() < 1e-8).count();
        zero_count as f32 / data.len() as f32
    }

    /// Classifies the value distribution using sample skewness and kurtosis.
    fn classify_distribution(&self, data: &[f32], stats: &TensorStats) -> DistributionProfile {
        if stats.near_zero_ratio > 0.6 {
            return DistributionProfile::Sparse;
        }
        // A (near-)constant tensor has no meaningful higher moments; dividing
        // by a zero std_dev would produce NaN and silently fall through to
        // Bimodal. Treat it as Uniform instead.
        if stats.std_dev <= f32::EPSILON {
            return DistributionProfile::Uniform;
        }
        let skewness = data
            .iter()
            .map(|&x| ((x - stats.mean) / stats.std_dev).powi(3))
            .sum::<f32>()
            / data.len() as f32;
        let kurtosis = data
            .iter()
            .map(|&x| ((x - stats.mean) / stats.std_dev).powi(4))
            .sum::<f32>()
            / data.len() as f32;
        if skewness.abs() > 1.0 {
            DistributionProfile::Skewed
        } else if kurtosis > 4.0 {
            DistributionProfile::HeavyTailed
        } else if (kurtosis - 3.0).abs() < 0.5 && skewness.abs() < 0.5 {
            DistributionProfile::Normal
        } else if kurtosis < 2.0 {
            DistributionProfile::Uniform
        } else {
            DistributionProfile::Bimodal
        }
    }

    /// Dispatches to the objective-specific selection strategy, then applies
    /// any hard constraints on top.
    fn select_configuration(
        &self,
        profile: &TensorProfile,
        constraints: Option<ConfigConstraints>,
    ) -> TorshResult<QuantConfig> {
        let mut config = match self.objective {
            ConfigObjective::MaximumCompression => self.select_for_compression(profile),
            ConfigObjective::MaximumAccuracy => self.select_for_accuracy(profile),
            ConfigObjective::BalancedQuality => self.select_balanced(profile),
            ConfigObjective::MaximumSpeed => self.select_for_speed(profile),
            ConfigObjective::MinimumMemory => self.select_for_memory(profile),
            ConfigObjective::EdgeOptimized => self.select_for_edge(profile),
        }?;
        if let Some(constraints) = constraints {
            config = self.apply_constraints(config, constraints)?;
        }
        Ok(config)
    }

    /// Compression-first: binary/ternary for sparse tensors, int4 for small
    /// ones, group-wise otherwise.
    fn select_for_compression(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
        if profile.sparsity > 0.5 {
            if profile.distribution == DistributionProfile::Sparse {
                Ok(QuantConfig::binary())
            } else {
                Ok(QuantConfig::ternary())
            }
        } else if profile.numel < 1000 {
            Ok(QuantConfig::int4())
        } else {
            // Roughly 100 groups, clamped to a practical [16, 128] group size.
            let group_size = (profile.numel / 100).clamp(16, 128);
            Ok(QuantConfig::group_wise(0, group_size))
        }
    }

    /// Accuracy-first: pick observers robust to outliers / wide ranges.
    fn select_for_accuracy(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
        let mut config = if profile.stats.has_outliers
            || profile.distribution == DistributionProfile::HeavyTailed
        {
            QuantConfig::int8().with_observer(ObserverType::Histogram)
        } else if profile.stats.range > 1000.0 {
            QuantConfig::per_channel(0).with_observer(ObserverType::Percentile)
        } else {
            QuantConfig::int8().with_observer(ObserverType::Percentile)
        };
        if profile.stats.range > 10000.0 {
            config = config.with_reduce_range(crate::config::ReduceRange::Reduce);
        }
        Ok(config)
    }

    /// Balanced: group-wise for large dense tensors, int4 for sparse ones,
    /// int8 + histogram otherwise.
    fn select_balanced(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
        if profile.numel > 100000 && profile.sparsity < 0.1 {
            // Smaller groups when outliers are present, to localize their effect.
            let group_size = if profile.stats.has_outliers { 32 } else { 64 };
            Ok(QuantConfig::group_wise(0, group_size).with_observer(ObserverType::Histogram))
        } else if profile.sparsity > 0.3 {
            Ok(QuantConfig::int4().with_observer(ObserverType::MinMax))
        } else {
            Ok(QuantConfig::int8().with_observer(ObserverType::Histogram))
        }
    }

    /// Speed-first: int8 with the cheap MinMax observer for large tensors,
    /// always on the FBGEMM backend.
    fn select_for_speed(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
        let config = if profile.numel < 10000 {
            QuantConfig::int8()
        } else {
            QuantConfig::int8().with_observer(ObserverType::MinMax)
        };
        Ok(config.with_backend(QuantBackend::Fbgemm))
    }

    /// Memory-first: binary for sparse, int4 for large, int8 otherwise.
    fn select_for_memory(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
        if profile.sparsity > 0.4 {
            Ok(QuantConfig::binary())
        } else if profile.numel > 50000 {
            Ok(QuantConfig::int4())
        } else {
            Ok(QuantConfig::int8())
        }
    }

    /// Edge: fixed int8 + QNNPACK + MinMax regardless of the profile.
    fn select_for_edge(&self, _profile: &TensorProfile) -> TorshResult<QuantConfig> {
        Ok(QuantConfig::int8()
            .with_backend(QuantBackend::Qnnpack)
            .with_observer(ObserverType::MinMax))
    }

    /// Builds the candidate pool for ranking (scores start at 0.0 and are
    /// filled in by the caller). Constraint filtering may leave it empty.
    fn generate_candidates(
        &self,
        profile: &TensorProfile,
        constraints: Option<ConfigConstraints>,
    ) -> TorshResult<Vec<(QuantConfig, f32)>> {
        let mut candidates = vec![
            (QuantConfig::int8(), 0.0),
            (QuantConfig::int4(), 0.0),
            (QuantConfig::per_channel(0), 0.0),
        ];
        if profile.sparsity > 0.3 {
            candidates.push((QuantConfig::binary(), 0.0));
            candidates.push((QuantConfig::ternary(), 0.0));
        }
        if profile.numel > 10000 {
            candidates.push((QuantConfig::group_wise(0, 64), 0.0));
            candidates.push((QuantConfig::group_wise(0, 32), 0.0));
        }
        if let Some(constraints) = constraints {
            candidates.retain(|(config, _)| self.satisfies_constraints(config, &constraints));
        }
        Ok(candidates)
    }

    /// Weighted sum of the scheme, observer, backend and size sub-scores.
    fn score_configuration(&self, config: &QuantConfig, profile: &TensorProfile) -> f32 {
        let mut score = 0.0;
        score += self.score_scheme(config.scheme, profile) * self.feature_weights.distribution_weight;
        score += self.score_observer(config.observer_type, profile) * self.feature_weights.range_weight;
        // Backend has a fixed, lower influence than the learned features.
        score += self.score_backend(config.backend, profile) * 0.5;
        score += self.score_size(config.scheme, profile.numel) * self.feature_weights.size_weight;
        score
    }

    /// How well `scheme` matches the current objective (10 = ideal, 5 = neutral).
    fn score_scheme(&self, scheme: QScheme, _profile: &TensorProfile) -> f32 {
        match (self.objective, scheme) {
            (ConfigObjective::MaximumCompression, QScheme::Binary) => 10.0,
            (ConfigObjective::MaximumCompression, QScheme::Ternary) => 9.0,
            (ConfigObjective::MaximumCompression, QScheme::Int4PerTensor) => 8.0,
            (ConfigObjective::MaximumAccuracy, QScheme::PerChannelAffine) => 10.0,
            (ConfigObjective::MaximumAccuracy, QScheme::PerTensorAffine) => 8.5,
            (ConfigObjective::MaximumSpeed, QScheme::PerTensorAffine) => 10.0,
            (ConfigObjective::MaximumSpeed, QScheme::PerTensorSymmetric) => 9.5,
            (ConfigObjective::BalancedQuality, QScheme::GroupWise) => 9.0,
            (ConfigObjective::BalancedQuality, QScheme::PerTensorAffine) => 8.0,
            _ => 5.0,
        }
    }

    /// How well `observer` matches the tensor's statistics.
    fn score_observer(&self, observer: ObserverType, profile: &TensorProfile) -> f32 {
        match observer {
            ObserverType::Histogram if profile.stats.has_outliers => 10.0,
            ObserverType::Percentile
                if profile.distribution == DistributionProfile::HeavyTailed =>
            {
                9.5
            }
            ObserverType::MinMax => 7.0,
            _ => 6.0,
        }
    }

    /// How well `backend` matches the current objective.
    fn score_backend(&self, backend: QuantBackend, _profile: &TensorProfile) -> f32 {
        match (self.objective, backend) {
            (ConfigObjective::MaximumSpeed, QuantBackend::Fbgemm) => 10.0,
            (ConfigObjective::EdgeOptimized, QuantBackend::Qnnpack) => 10.0,
            _ => 5.0,
        }
    }

    /// How well `scheme` matches the tensor size (e.g. binary on tiny
    /// tensors saves little and is penalized).
    fn score_size(&self, scheme: QScheme, numel: usize) -> f32 {
        match scheme {
            QScheme::GroupWise if numel > 100000 => 10.0,
            QScheme::PerChannelAffine if numel > 10000 => 8.0,
            QScheme::Binary if numel < 1000 => 3.0,
            _ => 5.0,
        }
    }

    /// Force the config to satisfy hard constraints: pins the backend and
    /// upgrades sub-8-bit schemes to int8 when `min_bits >= 8`.
    // NOTE(review): `max_memory` and `target_compression` constraints are
    // currently not enforced here — confirm whether that is intentional.
    fn apply_constraints(
        &self,
        mut config: QuantConfig,
        constraints: ConfigConstraints,
    ) -> TorshResult<QuantConfig> {
        if let Some(backend) = constraints.required_backend {
            config = config.with_backend(backend);
        }
        if let Some(min_bits) = constraints.min_bits {
            if min_bits >= 8
                && matches!(
                    config.scheme,
                    QScheme::Int4PerTensor | QScheme::Binary | QScheme::Ternary
                )
            {
                config = QuantConfig::int8();
            }
        }
        Ok(config)
    }

    /// Whether `config` meets the backend and bit-width constraints
    /// (memory/compression targets are not checked here).
    fn satisfies_constraints(&self, config: &QuantConfig, constraints: &ConfigConstraints) -> bool {
        if let Some(backend) = constraints.required_backend {
            if config.backend != backend {
                return false;
            }
        }
        if let Some(min_bits) = constraints.min_bits {
            let scheme_bits = match config.scheme {
                QScheme::Binary => 1,
                QScheme::Ternary => 2,
                QScheme::Int4PerTensor | QScheme::Int4PerChannel => 4,
                _ => 8,
            };
            if scheme_bits < min_bits {
                return false;
            }
        }
        true
    }

    /// Nudges `sparsity_weight` up or down depending on whether historical
    /// runs on sparse tensors had lower error than runs on dense ones.
    fn update_feature_weights(&mut self) {
        if self.history.len() < 10 {
            return;
        }
        let (sparse_configs, dense_configs): (Vec<&ConfigPerformance>, Vec<&ConfigPerformance>) =
            self.history.iter().partition(|p| p.profile.sparsity > 0.3);
        // Both partitions must be non-empty: an empty one would make its
        // average a 0/0 NaN, and the NaN comparison below would silently
        // decay the weight instead of being a no-op.
        if sparse_configs.is_empty() || dense_configs.is_empty() {
            return;
        }
        let avg_sparse_error =
            sparse_configs.iter().map(|p| p.error).sum::<f32>() / sparse_configs.len() as f32;
        let avg_dense_error =
            dense_configs.iter().map(|p| p.error).sum::<f32>() / dense_configs.len() as f32;
        if avg_sparse_error < avg_dense_error {
            self.feature_weights.sparsity_weight *= 1.1;
        } else {
            self.feature_weights.sparsity_weight *= 0.95;
        }
        self.feature_weights.sparsity_weight = self.feature_weights.sparsity_weight.clamp(0.5, 2.0);
    }
}
/// Hard requirements that a recommended configuration must satisfy.
///
/// Used both to filter ranked candidates and to adjust a selected config.
#[derive(Debug, Clone, Default)]
pub struct ConfigConstraints {
    /// If set, the config must use exactly this backend.
    pub required_backend: Option<QuantBackend>,
    /// If set, schemes with fewer bits per value are rejected/upgraded.
    pub min_bits: Option<u32>,
    /// Maximum memory budget in bytes.
    // NOTE(review): not currently enforced by apply/satisfies logic — confirm.
    pub max_memory: Option<usize>,
    /// Desired compression ratio.
    // NOTE(review): not currently enforced by apply/satisfies logic — confirm.
    pub target_compression: Option<f32>,
}
impl ConfigConstraints {
    /// Creates an empty (fully unconstrained) set of constraints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requires the configuration to use exactly `backend`.
    pub fn with_backend(self, backend: QuantBackend) -> Self {
        Self {
            required_backend: Some(backend),
            ..self
        }
    }

    /// Rejects schemes narrower than `bits` bits per value.
    pub fn with_min_bits(self, bits: u32) -> Self {
        Self {
            min_bits: Some(bits),
            ..self
        }
    }

    /// Sets a memory budget of `bytes`.
    pub fn with_max_memory(self, bytes: usize) -> Self {
        Self {
            max_memory: Some(bytes),
            ..self
        }
    }

    /// Sets a desired compression ratio.
    pub fn with_target_compression(self, ratio: f32) -> Self {
        Self {
            target_compression: Some(ratio),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// The default path produces a config that passes its own validation.
    #[test]
    fn test_auto_configurator_basic() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = configurator.recommend(&tensor, None).unwrap();
        assert!(config.validate().is_ok());
    }

    /// Profiling detects the single extreme value (100.0) as an outlier.
    #[test]
    fn test_tensor_profile_analysis() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        // Nine values near 2.0 plus one far outlier.
        let data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 100.0];
        let tensor = tensor_1d(&data).unwrap();
        let profile = configurator.analyze_tensor(&tensor).unwrap();
        assert!(
            profile.stats.has_outliers,
            "Expected outliers to be detected"
        );
        assert_eq!(profile.numel, 10);
        assert!(profile.stats.max > 90.0, "Max value should be around 100");
    }

    /// A mostly-zero tensor under MaximumCompression yields binary or ternary.
    #[test]
    fn test_sparse_tensor_recommendation() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumCompression);
        let data = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = configurator.recommend(&tensor, None).unwrap();
        assert!(matches!(config.scheme, QScheme::Binary | QScheme::Ternary));
    }

    /// A required-backend constraint overrides the objective's default backend.
    #[test]
    fn test_constraints_application() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumSpeed);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let constraints = ConfigConstraints::new()
            .with_backend(QuantBackend::Qnnpack)
            .with_min_bits(8);
        let config = configurator.recommend(&tensor, Some(constraints)).unwrap();
        assert_eq!(config.backend, QuantBackend::Qnnpack);
    }

    /// Ranked results are returned in descending score order.
    #[test]
    fn test_ranked_recommendations() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();
        let ranked = configurator.recommend_ranked(&tensor, 3, None).unwrap();
        assert_eq!(ranked.len(), 3);
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    /// Feedback records are appended to the history.
    #[test]
    fn test_performance_update() {
        let mut configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();
        configurator
            .update_performance(&config, &tensor, 0.1, 4.0, Some(1.5))
            .unwrap();
        assert_eq!(configurator.history.len(), 1);
    }

    /// A tensor that is mostly zeros is classified as Sparse.
    #[test]
    fn test_distribution_classification() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        // Dense case: just verify analysis succeeds.
        let normal_data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0];
        let tensor = tensor_1d(&normal_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        // Sparse case: 7 of 8 values are zero.
        let sparse_data = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let tensor = tensor_1d(&sparse_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        assert_eq!(_profile.distribution, DistributionProfile::Sparse);
    }

    /// Every objective produces a valid configuration for the same input.
    #[test]
    fn test_objective_specific_selection() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();
        let objectives = vec![
            ConfigObjective::MaximumCompression,
            ConfigObjective::MaximumAccuracy,
            ConfigObjective::BalancedQuality,
            ConfigObjective::MaximumSpeed,
            ConfigObjective::MinimumMemory,
            ConfigObjective::EdgeOptimized,
        ];
        for objective in objectives {
            let configurator = AutoConfigurator::new(objective);
            let config = configurator.recommend(&tensor, None).unwrap();
            assert!(
                config.validate().is_ok(),
                "Failed for objective {:?}",
                objective
            );
        }
    }
}