#![allow(unused_variables)]
use crate::errors::{file_not_found, invalid_input, runtime_error, TrustformersError};
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// A named collection of calibration samples plus precomputed statistics.
///
/// Tensor payloads (`samples`, `targets`) are `#[serde(skip)]`-ped, so a
/// dataset serialized to JSON carries only its name, metadata, and statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationDataset {
    /// Unique name; used as the registry key in `CalibrationToolkit`.
    pub name: String,
    /// Input samples fed to calibration (not serialized).
    #[serde(skip)]
    pub samples: Vec<Tensor>,
    /// Optional per-sample targets (not serialized).
    #[serde(skip)]
    pub targets: Option<Vec<Tensor>>,
    /// Provenance and descriptive information.
    pub metadata: CalibrationMetadata,
    /// Summary statistics computed over `samples`.
    pub statistics: DatasetStatistics,
}
/// Descriptive, provenance-oriented information about a calibration dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationMetadata {
    /// Human-readable description of the dataset.
    pub description: String,
    /// Where the data came from.
    pub source: String,
    /// Dataset version string.
    pub version: String,
    /// Creation timestamp; presumably UNIX-epoch seconds (matching
    /// `CalibrationReport::generated_at`) — confirm with producers.
    pub created_at: u64,
    /// Free-form tags for search/classification.
    pub tags: Vec<String>,
    /// Model family the dataset is intended for.
    pub model_type: String,
    /// Calibration methods known to work well with this dataset.
    pub recommended_methods: Vec<CalibrationMethod>,
}
/// Aggregate statistics describing a calibration dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStatistics {
    /// Number of samples in the dataset.
    pub sample_count: usize,
    /// Input tensor shapes (currently only the first sample's shape is
    /// recorded by `calculate_dataset_statistics`).
    pub input_shapes: Vec<Vec<usize>>,
    /// Per-dimension tensor statistics.
    pub statistics: TensorStatistics,
    /// Value-range analysis used for clipping suggestions.
    pub dynamic_range: DynamicRange,
    /// Distribution-shape analysis.
    pub distribution: DistributionAnalysis,
}
/// Per-dimension descriptive statistics for tensor data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorStatistics {
    pub mean: Vec<f32>,
    pub std: Vec<f32>,
    pub min: Vec<f32>,
    pub max: Vec<f32>,
    /// Percentile values per dimension (inner vec = one dimension's percentiles).
    pub percentiles: Vec<Vec<f32>>,
    /// Third standardized moment per dimension.
    pub skewness: Vec<f32>,
    /// Fourth standardized moment per dimension (3.0 for a normal distribution).
    pub kurtosis: Vec<f32>,
}
/// Value-range analysis used to derive clipping suggestions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicRange {
    /// Overall value range observed across the dataset.
    pub overall_range: f32,
    /// Per-channel value ranges.
    pub channel_ranges: Vec<f32>,
    /// Fraction of values considered outliers.
    pub outlier_ratio: f32,
    /// Suggested lower clipping bound.
    pub suggested_clip_min: f32,
    /// Suggested upper clipping bound.
    pub suggested_clip_max: f32,
}
/// Shape analysis of the data distribution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionAnalysis {
    /// Best-fit distribution family.
    pub distribution_type: DistributionType,
    /// p-value of a normality test (higher = more normal-looking).
    pub normality_p_value: f32,
    /// Entropy estimate of the distribution.
    pub entropy: f32,
    /// How concentrated the mass is; presumably in [0, 1] — confirm.
    pub concentration: f32,
    /// Whether more than one mode was detected.
    pub is_multimodal: bool,
    /// Number of detected modes, when known.
    pub mode_count: Option<usize>,
}
/// Distribution families the analysis can classify data into.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum DistributionType {
    Normal,
    Uniform,
    Exponential,
    Laplace,
    Gamma,
    Beta,
    /// More than one mode detected.
    Multimodal,
    /// Could not be classified.
    Unknown,
}
/// Strategies for choosing quantization parameters during calibration.
///
/// Derives `Hash + Eq` so methods can key maps (see `MethodComparison`).
/// Default tuning knobs per method live in
/// `CalibrationToolkit::get_default_parameters`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Hash, Eq)]
pub enum CalibrationMethod {
    /// Entropy/divergence-based range selection (`num_bins`,
    /// `divergence_threshold` defaults).
    Entropy,
    /// Clip at a fixed percentile of observed values.
    Percentile,
    /// Iterative MSE minimization (`learning_rate`, `max_iterations` defaults).
    MSE,
    SQNR,
    CrossEntropy,
    Hessian,
    ActivationAware,
    Smooth,
    SensitivityBased,
    Learned,
}
/// Full configuration for one calibration run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationConfig {
    /// Primary method to run.
    pub method: CalibrationMethod,
    /// Methods to fall back to if the primary fails.
    pub fallback_methods: Vec<CalibrationMethod>,
    /// Method-specific tuning knobs (see `get_default_parameters`).
    pub parameters: HashMap<String, CalibrationParameter>,
    /// Acceptance criteria for the resulting quality metrics.
    pub quality_thresholds: QualityThresholds,
    /// Cross-validation settings.
    pub cross_validation: CrossValidationConfig,
}
/// A dynamically-typed calibration tuning value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CalibrationParameter {
    Float(f32),
    Int(i32),
    Bool(bool),
    String(String),
    FloatArray(Vec<f32>),
    IntArray(Vec<i32>),
}
/// Minimum acceptable quality for a calibration result.
/// See the `Default` impl for the standard values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityThresholds {
    /// Lowest acceptable accuracy retention (fraction, e.g. 0.95).
    pub min_accuracy_retention: f32,
    /// Largest acceptable SQNR degradation in dB.
    pub max_sqnr_degradation: f32,
    /// Largest acceptable KL divergence from the original distribution.
    pub max_kl_divergence: f32,
    /// Largest acceptable relative latency increase.
    pub max_latency_increase: f32,
    /// Smallest acceptable compression ratio (e.g. 2.0 = 2x).
    pub min_compression_ratio: f32,
}
/// Settings controlling k-fold cross-validation of calibration results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationConfig {
    /// Whether cross-validation runs at all.
    pub enabled: bool,
    /// Number of folds.
    pub folds: usize,
    /// Fraction of data held out for validation.
    pub validation_split: f32,
    /// Seed for deterministic fold assignment.
    pub random_seed: u64,
    /// Whether folds preserve class balance.
    pub stratified: bool,
}
/// Everything produced by one calibration run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationResult {
    /// Method that produced this result.
    pub method: CalibrationMethod,
    /// Whether the primary method succeeded (vs. a fallback being used).
    pub primary_success: bool,
    /// The quantization parameters that were derived.
    pub parameters: CalibrationParameters,
    /// Quality measurements for those parameters.
    pub quality_metrics: QualityMetrics,
    /// Cross-validation outcome, when it was run.
    pub cross_validation: Option<CrossValidationResults>,
    /// Method comparison, when several methods were evaluated.
    pub method_comparison: Option<MethodComparison>,
    /// Suggested follow-up actions.
    pub recommendations: Vec<CalibrationRecommendation>,
}
/// Derived per-layer quantization parameters, keyed by layer name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationParameters {
    /// Per-layer scale factors.
    pub scales: HashMap<String, Vec<f32>>,
    /// Per-layer zero points.
    pub zero_points: HashMap<String, Vec<i32>>,
    /// Per-layer (min, max) clipping ranges.
    pub clip_ranges: HashMap<String, (f32, f32)>,
    /// Per-layer bit-width assignments.
    pub bit_allocations: HashMap<String, Vec<u8>>,
    /// Method-specific extras that don't fit the fields above.
    pub extra_params: HashMap<String, CalibrationParameter>,
}
/// Model-level quality measurements for a set of calibration parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Fraction of original accuracy retained (1.0 = lossless).
    pub accuracy_retention: f32,
    /// Signal-to-quantization-noise ratio in dB (higher is better).
    pub sqnr_db: f32,
    /// KL divergence between original and quantized distributions.
    pub kl_divergence: f32,
    /// Model size reduction factor (e.g. 4.0 = 4x smaller).
    pub compression_ratio: f32,
    /// Inference speedup factor.
    pub speedup_factor: f32,
    /// Fraction of memory saved.
    pub memory_reduction: f32,
    /// Per-layer breakdown, keyed by layer name.
    pub layer_metrics: HashMap<String, LayerQualityMetrics>,
}
/// Quality measurements for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQualityMetrics {
    pub layer_name: String,
    pub layer_type: String,
    /// Quantization error for this layer (lower is better).
    pub quantization_error: f32,
    /// Similarity between pre- and post-quantization value distributions.
    pub distribution_similarity: f32,
    /// How well gradients are preserved through the quantized layer.
    pub gradient_preservation: f32,
    /// How well activations are preserved through the quantized layer.
    pub activation_preservation: f32,
}
/// Aggregated quality metrics over all cross-validation folds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationResults {
    /// Mean of each metric across folds.
    pub mean_metrics: QualityMetrics,
    /// Standard deviation of each metric across folds.
    pub std_metrics: QualityMetrics,
    /// Raw metrics per fold.
    pub fold_results: Vec<QualityMetrics>,
    /// Overall cross-validation score.
    pub cv_score: f32,
    /// How stable metrics are across folds (higher = more stable).
    pub stability_score: f32,
}
/// Outcome of evaluating several calibration methods against one dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MethodComparison {
    /// Quality metrics achieved by each method.
    pub method_results: HashMap<CalibrationMethod, QualityMetrics>,
    /// Methods paired with their overall scores, best first.
    pub method_ranking: Vec<(CalibrationMethod, f32)>,
    /// The top-ranked method.
    pub recommended_method: CalibrationMethod,
    /// Per-method trade-off breakdown.
    pub trade_offs: Vec<TradeOffAnalysis>,
}
/// Pairwise trade-off scores for one calibration method
/// (see `CalibrationToolkit::analyze_trade_offs`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeOffAnalysis {
    pub method: CalibrationMethod,
    /// Accuracy retained per unit of (normalized) compression.
    pub accuracy_compression: f32,
    /// Speedup weighted by retained accuracy.
    pub speed_quality: f32,
    /// Memory savings weighted by retained accuracy.
    pub memory_accuracy: f32,
    /// Single combined score (see `calculate_overall_score`).
    pub balance_score: f32,
}
/// A suggested follow-up action when calibration quality falls short.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationRecommendation {
    /// Category of the recommendation.
    pub recommendation_type: RecommendationType,
    /// Human-readable explanation.
    pub description: String,
    /// Estimated quality improvement if applied (fraction).
    pub expected_improvement: f32,
    /// Implementation difficulty; presumably a small ordinal scale
    /// (values 2-3 are emitted in this file) — confirm the range.
    pub difficulty: u8,
    /// Urgency; higher values are used for accuracy problems in
    /// `validate_calibration`.
    pub priority: u8,
    /// Concrete steps to apply the recommendation.
    pub action_steps: Vec<String>,
}
/// Categories of corrective action a recommendation can suggest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecommendationType {
    IncreaseDataset,
    TryDifferentMethod,
    AdjustBitWidth,
    UseMixedBit,
    ApplyClipping,
    ChangeGranularity,
    PreprocessDataset,
    UseEnsemble,
    PostQuantTuning,
    OptimizeHardware,
}
/// Stateful front-end for running and caching calibration work.
///
/// Owns registered datasets, saved configs, a chronological history of
/// results, and a result cache keyed by `generate_cache_key`.
pub struct CalibrationToolkit {
    /// Registered datasets, keyed by `CalibrationDataset::name`.
    datasets: HashMap<String, CalibrationDataset>,
    /// Saved configurations; currently never read or written outside
    /// construction (hence the allow).
    #[allow(dead_code)]
    configs: HashMap<String, CalibrationConfig>,
    /// Every result ever produced, in order of production.
    history: Vec<CalibrationResult>,
    /// Cached results, keyed by `generate_cache_key`.
    cache: HashMap<String, CalibrationResult>,
}
impl CalibrationToolkit {
pub fn new() -> Self {
Self {
datasets: HashMap::new(),
configs: HashMap::new(),
history: Vec::new(),
cache: HashMap::new(),
}
}
/// Validates `dataset` and stores it under its name, computing statistics
/// first when the recorded sample count is zero.
///
/// # Errors
/// Propagates validation or statistics-computation failures.
pub fn register_dataset(
    &mut self,
    mut dataset: CalibrationDataset,
) -> Result<(), TrustformersError> {
    self.validate_dataset(&dataset)?;
    // A zero sample_count signals that statistics were never computed.
    if dataset.statistics.sample_count == 0 {
        dataset.statistics = self.calculate_dataset_statistics(&dataset.samples)?;
    }
    let key = dataset.name.clone();
    self.datasets.insert(key, dataset);
    Ok(())
}
/// Builds a `CalibrationDataset` from raw samples, computing its statistics.
///
/// # Errors
/// Fails (via `calculate_dataset_statistics`) when `samples` is empty.
pub fn create_dataset(
    &self,
    name: String,
    samples: Vec<Tensor>,
    metadata: CalibrationMetadata,
) -> Result<CalibrationDataset, TrustformersError> {
    let stats = self.calculate_dataset_statistics(&samples)?;
    let dataset = CalibrationDataset {
        name,
        samples,
        targets: None,
        metadata,
        statistics: stats,
    };
    Ok(dataset)
}
pub fn load_dataset<P: AsRef<Path>>(
&mut self,
path: P,
) -> Result<CalibrationDataset, TrustformersError> {
use std::fs;
let path = path.as_ref();
let contents = fs::read_to_string(path).map_err(|e| runtime_error(e.to_string()))?;
if path.extension().and_then(|s| s.to_str()) == Some("json") {
let dataset: CalibrationDataset = serde_json::from_str(&contents)
.map_err(|e| runtime_error(format!("Failed to parse JSON dataset: {}", e)))?;
self.datasets.insert(dataset.name.clone(), dataset.clone());
Ok(dataset)
} else {
Err(invalid_input(format!(
"Unsupported dataset format: {:?}. Only JSON (.json) is currently supported.",
path.extension()
)))
}
}
/// Serializes `dataset` to pretty-printed JSON at `path`.
/// Only the `.json` extension is supported.
///
/// # Errors
/// `invalid_input` for a non-JSON extension; `runtime_error` for
/// serialization or write failures.
pub fn save_dataset<P: AsRef<Path>>(
    &self,
    dataset: &CalibrationDataset,
    path: P,
) -> Result<(), TrustformersError> {
    use std::fs;
    let path = path.as_ref();
    // Dispatch on the extension; anything but "json" is rejected.
    match path.extension().and_then(|s| s.to_str()) {
        Some("json") => {
            let json_content = serde_json::to_string_pretty(dataset).map_err(|e| {
                runtime_error(format!("Failed to serialize dataset to JSON: {}", e))
            })?;
            fs::write(path, json_content).map_err(|e| runtime_error(e.to_string()))?;
            Ok(())
        },
        _ => Err(invalid_input(format!(
            "Unsupported dataset format: {:?}. Only JSON (.json) is currently supported.",
            path.extension()
        ))),
    }
}
/// Calibrates a registered dataset with `config`, serving repeated
/// (dataset, config) requests from the in-memory cache and appending every
/// freshly computed result to `history`.
///
/// # Errors
/// Returns `file_not_found` when `dataset_name` is not registered;
/// failures from `run_calibration` are propagated.
pub fn calibrate(
    &mut self,
    dataset_name: &str,
    config: CalibrationConfig,
) -> Result<CalibrationResult, TrustformersError> {
    let dataset = match self.datasets.get(dataset_name) {
        Some(found) => found,
        None => {
            return Err(file_not_found(format!("Dataset '{}' not found", dataset_name)));
        },
    };
    let cache_key = self.generate_cache_key(dataset_name, &config);
    // Cached hits are returned without recomputing (and without being
    // re-appended to history).
    if let Some(hit) = self.cache.get(&cache_key) {
        return Ok(hit.clone());
    }
    let result = self.run_calibration(dataset, &config)?;
    self.cache.insert(cache_key, result.clone());
    self.history.push(result.clone());
    Ok(result)
}
pub fn compare_methods(
&mut self,
dataset_name: &str,
methods: Vec<CalibrationMethod>,
) -> Result<MethodComparison, TrustformersError> {
let mut method_results = HashMap::new();
for method in &methods {
let config = CalibrationConfig {
method: *method,
fallback_methods: Vec::new(),
parameters: self.get_default_parameters(*method),
quality_thresholds: QualityThresholds::default(),
cross_validation: CrossValidationConfig::default(),
};
let result = self.calibrate(dataset_name, config)?;
method_results.insert(*method, result.quality_metrics);
}
let mut method_ranking: Vec<_> = method_results
.iter()
.map(|(method, metrics)| (*method, self.calculate_overall_score(metrics)))
.collect();
method_ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
let recommended_method = method_ranking[0].0;
let trade_offs = methods
.iter()
.map(|method| self.analyze_trade_offs(*method, &method_results[method]))
.collect();
Ok(MethodComparison {
method_results,
method_ranking,
recommended_method,
trade_offs,
})
}
/// Checks a calibration result against `thresholds` and returns one
/// recommendation per violated threshold; an empty vector means every
/// checked threshold was met.
///
/// Checked here: accuracy retention, compression ratio, and SQNR. The KL
/// and latency thresholds are currently not evaluated by this method.
pub fn validate_calibration(
    &self,
    result: &CalibrationResult,
    thresholds: &QualityThresholds,
) -> Vec<CalibrationRecommendation> {
    let mut recommendations = Vec::new();
    // Accuracy below the floor: highest-priority recommendation.
    if result.quality_metrics.accuracy_retention < thresholds.min_accuracy_retention {
        recommendations.push(CalibrationRecommendation {
            recommendation_type: RecommendationType::TryDifferentMethod,
            description: format!(
                "Accuracy retention {:.3} is below threshold {:.3}. Consider using a different calibration method or increasing bit width.",
                result.quality_metrics.accuracy_retention,
                thresholds.min_accuracy_retention
            ),
            expected_improvement: 0.1,
            difficulty: 2,
            priority: 5,
            action_steps: vec![
                "Try entropy-based calibration".to_string(),
                "Increase quantization bit width".to_string(),
                "Use mixed-bit quantization for critical layers".to_string(),
            ],
        });
    }
    // Compression below target: suggest more aggressive quantization.
    if result.quality_metrics.compression_ratio < thresholds.min_compression_ratio {
        recommendations.push(CalibrationRecommendation {
            recommendation_type: RecommendationType::UseMixedBit,
            description: format!(
                "Compression ratio {:.2}x is below target {:.2}x. Consider using more aggressive quantization.",
                result.quality_metrics.compression_ratio,
                thresholds.min_compression_ratio
            ),
            expected_improvement: 0.2,
            difficulty: 3,
            priority: 3,
            action_steps: vec![
                "Enable mixed-bit quantization".to_string(),
                "Reduce bit width for less critical layers".to_string(),
                "Apply weight pruning before quantization".to_string(),
            ],
        });
    }
    // NOTE(review): this fires only when sqnr_db is strongly negative
    // (below -max_sqnr_degradation). Since run_calibration reports an
    // absolute SQNR (e.g. 40 dB), this check may have been meant to compare
    // a degradation delta rather than the absolute value — confirm.
    if result.quality_metrics.sqnr_db < -thresholds.max_sqnr_degradation {
        recommendations.push(CalibrationRecommendation {
            recommendation_type: RecommendationType::ApplyClipping,
            description: format!(
                "SQNR degradation {:.2} dB exceeds threshold {:.2} dB. Apply outlier clipping or increase calibration data.",
                result.quality_metrics.sqnr_db,
                thresholds.max_sqnr_degradation
            ),
            expected_improvement: 0.15,
            difficulty: 2,
            priority: 4,
            action_steps: vec![
                "Apply percentile-based outlier clipping".to_string(),
                "Increase calibration dataset size".to_string(),
                "Use more representative calibration data".to_string(),
            ],
        });
    }
    recommendations
}
/// Assembles a report for a calibration run, attaching dataset metadata
/// (when the dataset is still registered) and recommendations evaluated
/// against the default quality thresholds.
pub fn generate_report(
    &self,
    result: &CalibrationResult,
    dataset_name: &str,
) -> CalibrationReport {
    let dataset_info = self
        .datasets
        .get(dataset_name)
        .map(|d| d.metadata.clone());
    let recommendations = self.validate_calibration(result, &QualityThresholds::default());
    // Seconds since the UNIX epoch; a system clock before 1970 is a
    // broken-environment invariant, hence the expect.
    let generated_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("SystemTime should be after UNIX_EPOCH")
        .as_secs();
    CalibrationReport {
        dataset_name: dataset_name.to_string(),
        dataset_info,
        calibration_result: result.clone(),
        recommendations,
        generated_at,
    }
}
/// Checks that a dataset is non-empty and that every sample shares the
/// shape of the first one.
///
/// # Errors
/// `invalid_input` for an empty dataset or a shape mismatch.
fn validate_dataset(&self, dataset: &CalibrationDataset) -> Result<(), TrustformersError> {
    let expected = dataset
        .samples
        .first()
        .ok_or_else(|| invalid_input("Dataset cannot be empty".to_string()))?
        .shape();
    // Index 0 is compared against itself, which is harmless and keeps the
    // reported index aligned with the sample's position.
    for (i, sample) in dataset.samples.iter().enumerate() {
        if sample.shape() != expected {
            return Err(invalid_input(format!(
                "Sample {} has inconsistent shape",
                i
            )));
        }
    }
    Ok(())
}
/// Computes dataset-level statistics for a set of samples.
///
/// NOTE(review): this is currently a stub — apart from `sample_count` and
/// the first sample's shape, every value below is a fixed placeholder
/// (zero mean, unit std, [-1, 1] range, normal distribution) rather than
/// being derived from the data.
///
/// # Errors
/// Returns `invalid_input` when `samples` is empty.
fn calculate_dataset_statistics(
    &self,
    samples: &[Tensor],
) -> Result<DatasetStatistics, TrustformersError> {
    if samples.is_empty() {
        return Err(invalid_input(
            "Cannot calculate statistics for empty dataset".to_string(),
        ));
    }
    let sample_count = samples.len();
    // Only the first sample's shape is recorded.
    let input_shapes = vec![samples[0].shape().to_vec()];
    // Presumably the element count of the first tensor — confirm the
    // semantics of Tensor::len.
    let dim_count = samples[0].len();
    // Placeholder per-dimension statistics (kurtosis 3.0 = normal).
    let statistics = TensorStatistics {
        mean: vec![0.0; dim_count],
        std: vec![1.0; dim_count],
        min: vec![-1.0; dim_count],
        max: vec![1.0; dim_count],
        percentiles: vec![vec![0.0; 5]; dim_count],
        skewness: vec![0.0; dim_count],
        kurtosis: vec![3.0; dim_count],
    };
    // Placeholder range analysis assuming values in [-1, 1].
    let dynamic_range = DynamicRange {
        overall_range: 2.0,
        channel_ranges: vec![2.0; dim_count],
        outlier_ratio: 0.05,
        suggested_clip_min: -1.0,
        suggested_clip_max: 1.0,
    };
    // Placeholder: assumes a unimodal normal distribution.
    let distribution = DistributionAnalysis {
        distribution_type: DistributionType::Normal,
        normality_p_value: 0.5,
        entropy: 3.0,
        concentration: 0.5,
        is_multimodal: false,
        mode_count: Some(1),
    };
    Ok(DatasetStatistics {
        sample_count,
        input_shapes,
        statistics,
        dynamic_range,
        distribution,
    })
}
/// Builds a cache key identifying a (dataset, config) pairing.
///
/// The key folds in the calibration parameters — sorted by name so the key
/// is deterministic despite `HashMap` iteration order. The previous key
/// used only the method, so two configs sharing a method but differing in
/// parameters collided and returned each other's cached results.
fn generate_cache_key(&self, dataset_name: &str, config: &CalibrationConfig) -> String {
    let mut params: Vec<String> = config
        .parameters
        .iter()
        .map(|(name, value)| format!("{}={:?}", name, value))
        .collect();
    params.sort();
    format!("{}_{:?}_{}", dataset_name, config.method, params.join(","))
}
/// Executes one calibration run for `dataset` with `config`.
///
/// NOTE(review): this is currently a stub — it ignores the dataset and all
/// config fields except `method`, derives no parameters, and reports fixed
/// optimistic quality metrics. The crate-level `#![allow(unused_variables)]`
/// keeps the unused arguments warning-free.
fn run_calibration(
    &self,
    dataset: &CalibrationDataset,
    config: &CalibrationConfig,
) -> Result<CalibrationResult, TrustformersError> {
    // No quantization parameters are actually derived yet.
    let parameters = CalibrationParameters {
        scales: HashMap::new(),
        zero_points: HashMap::new(),
        clip_ranges: HashMap::new(),
        bit_allocations: HashMap::new(),
        extra_params: HashMap::new(),
    };
    // Fixed placeholder metrics.
    let quality_metrics = QualityMetrics {
        accuracy_retention: 0.95,
        sqnr_db: 40.0,
        kl_divergence: 0.01,
        compression_ratio: 4.0,
        speedup_factor: 2.0,
        memory_reduction: 0.75,
        layer_metrics: HashMap::new(),
    };
    Ok(CalibrationResult {
        method: config.method,
        primary_success: true,
        parameters,
        quality_metrics,
        cross_validation: None,
        method_comparison: None,
        recommendations: Vec::new(),
    })
}
/// Returns the default tuning parameters for `method`.
/// Methods without dedicated defaults share a single tolerance setting.
fn get_default_parameters(
    &self,
    method: CalibrationMethod,
) -> HashMap<String, CalibrationParameter> {
    // Name/value table per method; collected into the map below.
    let defaults: Vec<(&str, CalibrationParameter)> = match method {
        CalibrationMethod::Entropy => vec![
            ("num_bins", CalibrationParameter::Int(2048)),
            ("divergence_threshold", CalibrationParameter::Float(0.01)),
        ],
        CalibrationMethod::Percentile => vec![
            ("percentile", CalibrationParameter::Float(99.99)),
            ("symmetric", CalibrationParameter::Bool(true)),
        ],
        CalibrationMethod::MSE => vec![
            ("learning_rate", CalibrationParameter::Float(0.001)),
            ("max_iterations", CalibrationParameter::Int(1000)),
        ],
        _ => vec![("tolerance", CalibrationParameter::Float(1e-6))],
    };
    defaults
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect()
}
/// Collapses quality metrics into a single weighted score, nominally in
/// [0, 1]: 40% accuracy, 20% SQNR (normalized to 50 dB), 20% compression
/// (normalized to 8x), 10% speedup (normalized to 4x), 10% memory savings.
///
/// NOTE(review): unlike the SQNR and compression terms, the speedup term
/// is not clamped with .min(1.0), so speedups above 4x push the score past
/// the nominal range — confirm whether that is intended.
fn calculate_overall_score(&self, metrics: &QualityMetrics) -> f32 {
    0.4 * metrics.accuracy_retention
        + 0.2 * (metrics.sqnr_db / 50.0).min(1.0)
        + 0.2 * (metrics.compression_ratio / 8.0).min(1.0)
        + 0.1 * metrics.speedup_factor / 4.0
        + 0.1 * metrics.memory_reduction
}
/// Derives the pairwise trade-off ratios for one method's quality metrics.
///
/// A zero compression ratio yields an infinite accuracy/compression term
/// (f32 division; no panic).
fn analyze_trade_offs(
    &self,
    method: CalibrationMethod,
    metrics: &QualityMetrics,
) -> TradeOffAnalysis {
    // Normalize against the nominal targets: 4x compression, 4x speedup.
    let normalized_compression = metrics.compression_ratio / 4.0;
    let normalized_speed = metrics.speedup_factor / 4.0;
    let accuracy_compression = metrics.accuracy_retention / normalized_compression;
    let speed_quality = normalized_speed * metrics.accuracy_retention;
    let memory_accuracy = metrics.memory_reduction * metrics.accuracy_retention;
    TradeOffAnalysis {
        method,
        accuracy_compression,
        speed_quality,
        memory_accuracy,
        balance_score: self.calculate_overall_score(metrics),
    }
}
}
/// A self-contained summary of one calibration run, built by
/// `CalibrationToolkit::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationReport {
    /// Name the dataset was registered under.
    pub dataset_name: String,
    /// Dataset metadata, if the dataset was still registered at report time.
    pub dataset_info: Option<CalibrationMetadata>,
    /// The full calibration result being reported on.
    pub calibration_result: CalibrationResult,
    /// Recommendations evaluated against the default thresholds.
    pub recommendations: Vec<CalibrationRecommendation>,
    /// Report creation time, seconds since the UNIX epoch.
    pub generated_at: u64,
}
impl Default for QualityThresholds {
    /// Conservative defaults: >=95% accuracy retention, <=5 dB SQNR loss,
    /// <=0.1 KL divergence, <=10% latency increase, >=2x compression.
    fn default() -> Self {
        Self {
            min_accuracy_retention: 0.95,
            max_sqnr_degradation: 5.0,
            max_kl_divergence: 0.1,
            max_latency_increase: 0.1,
            min_compression_ratio: 2.0,
        }
    }
}
impl Default for CrossValidationConfig {
    /// Standard setup: enabled 5-fold CV with an 80/20 split and a fixed
    /// seed for reproducibility; stratification off.
    fn default() -> Self {
        Self {
            enabled: true,
            folds: 5,
            validation_split: 0.2,
            random_seed: 42,
            stratified: false,
        }
    }
}
impl Default for CalibrationToolkit {
    /// Equivalent to `CalibrationToolkit::new()`.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering construction, dataset validation, serde
    //! round-trips, `Default` impls, and `Clone` behavior of the plain
    //! data types in this module.
    use super::*;

    #[test]
    fn test_calibration_toolkit_creation() {
        // A fresh toolkit starts with no registered state of any kind.
        let toolkit = CalibrationToolkit::new();
        assert!(toolkit.datasets.is_empty());
        assert!(toolkit.configs.is_empty());
        assert!(toolkit.history.is_empty());
    }

    #[test]
    fn test_dataset_validation() {
        // A dataset with zero samples must be rejected by validate_dataset,
        // regardless of what its (here all-zero/empty) statistics claim.
        let toolkit = CalibrationToolkit::new();
        let empty_dataset = CalibrationDataset {
            name: "empty".to_string(),
            samples: Vec::new(),
            targets: None,
            metadata: CalibrationMetadata {
                description: "Empty dataset".to_string(),
                source: "test".to_string(),
                version: "1.0".to_string(),
                created_at: 0,
                tags: Vec::new(),
                model_type: "test".to_string(),
                recommended_methods: Vec::new(),
            },
            statistics: DatasetStatistics {
                sample_count: 0,
                input_shapes: Vec::new(),
                statistics: TensorStatistics {
                    mean: Vec::new(),
                    std: Vec::new(),
                    min: Vec::new(),
                    max: Vec::new(),
                    percentiles: Vec::new(),
                    skewness: Vec::new(),
                    kurtosis: Vec::new(),
                },
                dynamic_range: DynamicRange {
                    overall_range: 0.0,
                    channel_ranges: Vec::new(),
                    outlier_ratio: 0.0,
                    suggested_clip_min: 0.0,
                    suggested_clip_max: 0.0,
                },
                distribution: DistributionAnalysis {
                    distribution_type: DistributionType::Unknown,
                    normality_p_value: 0.0,
                    entropy: 0.0,
                    concentration: 0.0,
                    is_multimodal: false,
                    mode_count: None,
                },
            },
        };
        assert!(toolkit.validate_dataset(&empty_dataset).is_err());
    }

    #[test]
    fn test_quality_thresholds_default() {
        // Defaults must stay in sync with the Default impl.
        let thresholds = QualityThresholds::default();
        assert_eq!(thresholds.min_accuracy_retention, 0.95);
        assert_eq!(thresholds.max_sqnr_degradation, 5.0);
        assert_eq!(thresholds.min_compression_ratio, 2.0);
    }

    #[test]
    fn test_calibration_method_enum() {
        // A method variant must survive a JSON round-trip unchanged.
        let method = CalibrationMethod::Entropy;
        assert_eq!(method, CalibrationMethod::Entropy);
        let serialized = serde_json::to_string(&method).expect("JSON serialization failed");
        let deserialized: CalibrationMethod =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");
        assert_eq!(method, deserialized);
    }

    #[test]
    fn test_all_calibration_methods() {
        // Pairwise equality: each variant equals only itself.
        let methods = [
            CalibrationMethod::Entropy,
            CalibrationMethod::Percentile,
            CalibrationMethod::MSE,
            CalibrationMethod::SQNR,
            CalibrationMethod::CrossEntropy,
            CalibrationMethod::Hessian,
            CalibrationMethod::ActivationAware,
            CalibrationMethod::Smooth,
            CalibrationMethod::SensitivityBased,
            CalibrationMethod::Learned,
        ];
        for (i, a) in methods.iter().enumerate() {
            for (j, b) in methods.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    #[test]
    fn test_distribution_type_variants() {
        // Compile-time check that every variant is constructible.
        let _types = [
            DistributionType::Normal,
            DistributionType::Uniform,
            DistributionType::Exponential,
            DistributionType::Laplace,
            DistributionType::Gamma,
            DistributionType::Beta,
            DistributionType::Multimodal,
            DistributionType::Unknown,
        ];
    }

    #[test]
    fn test_distribution_type_eq() {
        assert_eq!(DistributionType::Normal, DistributionType::Normal);
        assert_ne!(DistributionType::Normal, DistributionType::Uniform);
    }

    // The Debug output of each CalibrationParameter variant must surface
    // the stored value (checked via substring below).
    #[test]
    fn test_calibration_parameter_float() {
        let param = CalibrationParameter::Float(std::f32::consts::PI);
        let debug = format!("{:?}", param);
        assert!(debug.contains("3.14"));
    }

    #[test]
    fn test_calibration_parameter_int() {
        let param = CalibrationParameter::Int(42);
        let debug = format!("{:?}", param);
        assert!(debug.contains("42"));
    }

    #[test]
    fn test_calibration_parameter_bool() {
        let param = CalibrationParameter::Bool(true);
        let debug = format!("{:?}", param);
        assert!(debug.contains("true"));
    }

    #[test]
    fn test_calibration_parameter_string() {
        let param = CalibrationParameter::String("test".to_string());
        let debug = format!("{:?}", param);
        assert!(debug.contains("test"));
    }

    #[test]
    fn test_calibration_parameter_float_array() {
        let param = CalibrationParameter::FloatArray(vec![1.0, 2.0, 3.0]);
        let debug = format!("{:?}", param);
        assert!(debug.contains("FloatArray"));
    }

    #[test]
    fn test_calibration_parameter_int_array() {
        let param = CalibrationParameter::IntArray(vec![1, 2, 3]);
        let debug = format!("{:?}", param);
        assert!(debug.contains("IntArray"));
    }

    #[test]
    fn test_quality_thresholds_clone() {
        let thresholds = QualityThresholds::default();
        let cloned = thresholds.clone();
        assert_eq!(
            cloned.min_accuracy_retention,
            thresholds.min_accuracy_retention
        );
        assert_eq!(cloned.max_sqnr_degradation, thresholds.max_sqnr_degradation);
    }

    #[test]
    fn test_quality_thresholds_custom() {
        // Custom threshold values round-trip through the struct unchanged.
        let thresholds = QualityThresholds {
            min_accuracy_retention: 0.99,
            max_sqnr_degradation: 1.0,
            max_kl_divergence: 0.01,
            max_latency_increase: 0.05,
            min_compression_ratio: 4.0,
        };
        assert!((thresholds.min_accuracy_retention - 0.99).abs() < 1e-6);
        assert!((thresholds.min_compression_ratio - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_validation_config_default() {
        let config = CrossValidationConfig::default();
        assert!(config.enabled);
        assert_eq!(config.folds, 5);
        assert!((config.validation_split - 0.2).abs() < 1e-6);
        assert_eq!(config.random_seed, 42);
        assert!(!config.stratified);
    }

    #[test]
    fn test_cross_validation_config_clone() {
        let config = CrossValidationConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.folds, config.folds);
        assert_eq!(cloned.random_seed, config.random_seed);
    }

    #[test]
    fn test_toolkit_default() {
        // Default delegates to new(): an empty toolkit.
        let toolkit = CalibrationToolkit::default();
        assert!(toolkit.datasets.is_empty());
    }

    #[test]
    fn test_toolkit_non_empty_dataset_validation() {
        // A dataset holding a single sample must pass validation.
        let toolkit = CalibrationToolkit::new();
        let tensor = Tensor::ones(&[2, 3]).expect("Tensor creation failed");
        let dataset = CalibrationDataset {
            name: "valid".to_string(),
            samples: vec![tensor],
            targets: None,
            metadata: CalibrationMetadata {
                description: "Test dataset".to_string(),
                source: "test".to_string(),
                version: "1.0".to_string(),
                created_at: 0,
                tags: Vec::new(),
                model_type: "test".to_string(),
                recommended_methods: vec![CalibrationMethod::Entropy],
            },
            statistics: DatasetStatistics {
                sample_count: 1,
                input_shapes: vec![vec![2, 3]],
                statistics: TensorStatistics {
                    mean: vec![1.0],
                    std: vec![0.0],
                    min: vec![1.0],
                    max: vec![1.0],
                    percentiles: Vec::new(),
                    skewness: vec![0.0],
                    kurtosis: vec![0.0],
                },
                dynamic_range: DynamicRange {
                    overall_range: 0.0,
                    channel_ranges: Vec::new(),
                    outlier_ratio: 0.0,
                    suggested_clip_min: 1.0,
                    suggested_clip_max: 1.0,
                },
                distribution: DistributionAnalysis {
                    distribution_type: DistributionType::Normal,
                    normality_p_value: 0.5,
                    entropy: 0.0,
                    concentration: 1.0,
                    is_multimodal: false,
                    mode_count: Some(1),
                },
            },
        };
        assert!(toolkit.validate_dataset(&dataset).is_ok());
    }

    #[test]
    fn test_dynamic_range_clone() {
        let range = DynamicRange {
            overall_range: 10.0,
            channel_ranges: vec![5.0, 8.0],
            outlier_ratio: 0.01,
            suggested_clip_min: -5.0,
            suggested_clip_max: 5.0,
        };
        let cloned = range.clone();
        assert!((cloned.overall_range - 10.0).abs() < 1e-6);
        assert_eq!(cloned.channel_ranges.len(), 2);
    }

    #[test]
    fn test_distribution_analysis_clone() {
        let analysis = DistributionAnalysis {
            distribution_type: DistributionType::Normal,
            normality_p_value: 0.95,
            entropy: 2.5,
            concentration: 0.8,
            is_multimodal: false,
            mode_count: Some(1),
        };
        let cloned = analysis.clone();
        assert_eq!(cloned.distribution_type, DistributionType::Normal);
        assert!((cloned.entropy - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_statistics_clone() {
        let stats = TensorStatistics {
            mean: vec![0.0, 1.0],
            std: vec![1.0, 0.5],
            min: vec![-3.0, -1.0],
            max: vec![3.0, 2.0],
            percentiles: vec![vec![0.1, 0.5, 0.9]],
            skewness: vec![0.0],
            kurtosis: vec![3.0],
        };
        let cloned = stats.clone();
        assert_eq!(cloned.mean, vec![0.0, 1.0]);
        assert_eq!(cloned.std, vec![1.0, 0.5]);
    }

    #[test]
    fn test_calibration_metadata_clone() {
        let metadata = CalibrationMetadata {
            description: "Test".to_string(),
            source: "test_source".to_string(),
            version: "1.0".to_string(),
            created_at: 12345,
            tags: vec!["tag1".to_string()],
            model_type: "transformer".to_string(),
            recommended_methods: vec![CalibrationMethod::MSE],
        };
        let cloned = metadata.clone();
        assert_eq!(cloned.description, "Test");
        assert_eq!(cloned.recommended_methods, vec![CalibrationMethod::MSE]);
    }

    #[test]
    fn test_quality_metrics_clone() {
        let metrics = QualityMetrics {
            accuracy_retention: 0.98,
            sqnr_db: 30.0,
            kl_divergence: 0.01,
            compression_ratio: 4.0,
            speedup_factor: 2.0,
            memory_reduction: 0.75,
            layer_metrics: HashMap::new(),
        };
        let cloned = metrics.clone();
        assert!((cloned.accuracy_retention - 0.98).abs() < 1e-6);
        assert!((cloned.compression_ratio - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_layer_quality_metrics() {
        let metrics = LayerQualityMetrics {
            layer_name: "linear_0".to_string(),
            layer_type: "Linear".to_string(),
            quantization_error: 0.001,
            distribution_similarity: 0.99,
            gradient_preservation: 0.95,
            activation_preservation: 0.98,
        };
        let cloned = metrics.clone();
        assert_eq!(cloned.layer_name, "linear_0");
        assert!((cloned.quantization_error - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_calibration_parameters_empty() {
        let params = CalibrationParameters {
            scales: HashMap::new(),
            zero_points: HashMap::new(),
            clip_ranges: HashMap::new(),
            bit_allocations: HashMap::new(),
            extra_params: HashMap::new(),
        };
        assert!(params.scales.is_empty());
        assert!(params.zero_points.is_empty());
    }

    #[test]
    fn test_calibration_parameters_with_data() {
        // Per-layer entries must be retrievable after insertion.
        let mut params = CalibrationParameters {
            scales: HashMap::new(),
            zero_points: HashMap::new(),
            clip_ranges: HashMap::new(),
            bit_allocations: HashMap::new(),
            extra_params: HashMap::new(),
        };
        params.scales.insert("layer_0".to_string(), vec![0.1, 0.2]);
        params.zero_points.insert("layer_0".to_string(), vec![0, 1]);
        params.clip_ranges.insert("layer_0".to_string(), (-1.0, 1.0));
        params.bit_allocations.insert("layer_0".to_string(), vec![4, 8]);
        assert_eq!(params.scales.get("layer_0").expect("should exist").len(), 2);
        assert_eq!(
            params.clip_ranges.get("layer_0").expect("should exist"),
            &(-1.0, 1.0)
        );
    }

    #[test]
    fn test_calibration_result_clone() {
        // A minimal result must survive cloning with key fields intact.
        let result = CalibrationResult {
            method: CalibrationMethod::Entropy,
            primary_success: true,
            parameters: CalibrationParameters {
                scales: HashMap::new(),
                zero_points: HashMap::new(),
                clip_ranges: HashMap::new(),
                bit_allocations: HashMap::new(),
                extra_params: HashMap::new(),
            },
            quality_metrics: QualityMetrics {
                accuracy_retention: 0.99,
                sqnr_db: 35.0,
                kl_divergence: 0.001,
                compression_ratio: 4.0,
                speedup_factor: 2.5,
                memory_reduction: 0.75,
                layer_metrics: HashMap::new(),
            },
            cross_validation: None,
            method_comparison: None,
            recommendations: Vec::new(),
        };
        let cloned = result.clone();
        assert_eq!(cloned.method, CalibrationMethod::Entropy);
        assert!(cloned.primary_success);
    }
}