// Stable modules, glob re-exported at the crate root.
pub mod config;
pub use config::*;
pub mod algorithms;
pub use algorithms::*;
pub mod observers;
pub use observers::*;
pub mod specialized;
pub use specialized::*;
pub mod metrics;
pub use metrics::*;
pub mod analysis;
pub use analysis::*;
pub mod memory_pool;
pub use memory_pool::*;
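// SIMD kernels are re-exported selectively rather than by glob; note that
// `TensorStats` is renamed to `SimdTensorStats` to avoid a name collision.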
pub mod simd_ops;
pub use simd_ops::{
calculate_tensor_stats_simd, dequantize_per_tensor_affine_simd, find_min_max_simd,
get_mobile_optimization_hints, get_simd_width, is_simd_available,
quantize_batch_consistent_simd, quantize_mobile_optimized, quantize_per_channel_simd,
quantize_per_tensor_affine_simd, quantize_to_int8_simd, MobileOptimizationHints,
TensorStats as SimdTensorStats,
};
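// NEON fast paths are only available on aarch64 targets.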
#[cfg(target_arch = "aarch64")]
pub use simd_ops::{find_min_max_neon, quantize_neon_optimized};
pub mod quantum;
pub use quantum::*;
pub mod quantum_enhanced;
pub use quantum_enhanced::*;
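// Benchmark config/result types are renamed on re-export to keep them
// distinct from similarly named types elsewhere in the public API.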
pub mod benchmarks;
pub use benchmarks::{
BaselineMetrics, BenchmarkConfig as SuiteBenchmarkConfig,
BenchmarkResult as SuiteBenchmarkResult, HardwareInfo, QuantizationBenchmarkSuite,
};
pub mod utils;
pub use utils::*;
pub mod auto_config;
pub use auto_config::*;
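// The modules below are gated behind the "experimental" feature and are
// not re-exported from the crate root.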
#[cfg(feature = "experimental")]
pub mod quantize;
#[cfg(feature = "experimental")]
pub mod dequantize;
#[cfg(feature = "experimental")]
pub mod advanced;
#[cfg(feature = "experimental")]
pub mod compression;
#[cfg(feature = "experimental")]
pub mod fake_quantize;
#[cfg(feature = "experimental")]
pub mod qat;
#[cfg(feature = "experimental")]
pub mod post_training;
#[cfg(feature = "experimental")]
pub mod optimizer;
#[cfg(feature = "experimental")]
pub mod realtime_adaptive;
#[cfg(feature = "experimental")]
pub mod hardware;
#[cfg(feature = "experimental")]
pub mod fusion;
#[cfg(feature = "experimental")]
pub mod profiler;
#[cfg(feature = "experimental")]
pub mod debugging;
#[cfg(feature = "experimental")]
pub mod neural_codecs;
#[cfg(feature = "experimental")]
pub mod research;
#[cfg(feature = "experimental")]
pub mod export;
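// Core error and tensor types, re-exported for convenience.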
pub use torsh_core::{error::Result as TorshResult, DType, TorshError};
pub use torsh_tensor::Tensor;
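/// Convenience prelude re-exporting the stable, non-experimental API.
///
/// A minimal int8 round-trip sketch mirroring the tests below (the
/// `torsh_quant` crate path is assumed here; substitute the actual crate
/// name):
///
/// ```ignore
/// use torsh_quant::prelude::*;
/// use torsh_tensor::creation::tensor_1d;
///
/// let data = vec![0.0, 1.0, 2.0, 3.0];
/// let tensor = tensor_1d(&data).unwrap();
/// let config = QuantConfig::int8();
/// let (quantized, scale, zero_point) = quantize_with_config(&tensor, &config).unwrap();
/// let restored = dequantize(&quantized, scale, zero_point).unwrap();
/// ```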
pub mod prelude {
pub use crate::algorithms::*;
pub use crate::analysis::*;
pub use crate::auto_config::*;
pub use crate::config::*;
pub use crate::memory_pool::*;
pub use crate::metrics::*;
pub use crate::observers::*;
pub use crate::specialized::*;
pub use crate::utils::*;
}
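// Smoke tests exercising the stable, non-experimental public API.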
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::tensor_1d;
#[test]
fn test_basic_quantization_workflow() {
let data = vec![0.0, 1.0, 2.0, 3.0];
let tensor = tensor_1d(&data).unwrap();
let config = QuantConfig::int8();
let result = quantize_with_config(&tensor, &config);
assert!(result.is_ok());
let (quantized, scale, zero_point) = result.unwrap();
let quantized_data = quantized.data().unwrap();
let all_in_range = quantized_data.iter().all(|&x| (-128.0..=127.0).contains(&x));
assert!(
all_in_range,
"Quantized values should be in I8 range [-128, 127]"
);
assert!(scale > 0.0);
let dequantized = dequantize(&quantized, scale, zero_point).unwrap();
assert_eq!(dequantized.dtype(), DType::F32);
}
#[test]
fn test_configuration_validation() {
let valid_config = QuantConfig::int8();
assert!(valid_config.validate().is_ok());
let per_channel_config = QuantConfig::per_channel(0);
assert!(per_channel_config.validate().is_ok());
}
#[test]
fn test_specialized_quantization() {
let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
// The tensor itself is unused; this test only checks that the specialized
// configs validate.
let _tensor = tensor_1d(&data).unwrap();
let int4_config = QuantConfig::int4();
assert!(int4_config.validate().is_ok());
let binary_config = QuantConfig::binary();
assert!(binary_config.validate().is_ok());
let ternary_config = QuantConfig::ternary();
assert!(ternary_config.validate().is_ok());
}
#[test]
fn test_utils_functionality() {
let data = vec![0.0, 1.0, 2.0, 3.0];
let tensor = tensor_1d(&data).unwrap();
let config = QuantConfig::int8();
let suggestions = validate_config_with_suggestions(&config).unwrap();
assert!(!suggestions.is_empty());
// `get_optimization_hints` may legitimately return no hints for a small
// tensor, so only check that the call completes without panicking.
let _hints = get_optimization_hints(&tensor, &config);
let json = export_config_to_json(&config).unwrap();
let imported_config = import_config_from_json(&json).unwrap();
assert_eq!(config.dtype, imported_config.dtype);
assert_eq!(config.scheme, imported_config.scheme);
}
#[test]
fn test_batch_processing() {
let data1 = vec![0.0, 1.0, 2.0, 3.0];
let data2 = vec![4.0, 5.0, 6.0, 7.0];
let tensor1 = tensor_1d(&data1).unwrap();
let tensor2 = tensor_1d(&data2).unwrap();
let tensors = vec![&tensor1, &tensor2];
let config = QuantConfig::int8();
let results = quantize_batch_consistent(&tensors, &config).unwrap();
assert_eq!(results.len(), 2);
let (_, scale1, zp1) = &results[0];
let (_, scale2, zp2) = &results[1];
// Consistent batch quantization shares a single set of parameters across
// all tensors, so exact float equality is intended here.
assert_eq!(scale1, scale2);
assert_eq!(zp1, zp2);
}
#[test]
fn test_metrics_calculation() {
let data = vec![0.0, 1.0, 2.0, 3.0];
let tensor = tensor_1d(&data).unwrap();
let config = QuantConfig::int8();
let (quantized, scale, zero_point) = quantize_with_config(&tensor, &config).unwrap();
let dequantized = dequantize(&quantized, scale, zero_point).unwrap();
// Bit widths: 32 for the f32 original, 8 for the quantized tensor.
let metrics = calculate_quantization_metrics(&tensor, &dequantized, 32, 8).unwrap();
assert!(metrics.psnr > 0.0);
assert!(metrics.snr > 0.0);
assert!(metrics.compression_ratio > 1.0);
assert!((0.0..=1.0).contains(&metrics.cosine_similarity));
}
#[test]
fn test_configuration_comparison() {
let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
let tensor = tensor_1d(&data).unwrap();
let configs = vec![
QuantConfig::int8(),
QuantConfig::per_channel(0),
QuantConfig::int4(),
];
let comparison = compare_quantization_configs(&tensor, &configs).unwrap();
assert_eq!(comparison.len(), 3);
// Results are expected sorted by PSNR, best configuration first.
for pair in comparison.windows(2) {
assert!(pair[0].1.psnr >= pair[1].1.psnr);
}
}
#[test]
fn test_auto_calibration() {
let data1 = vec![0.0, 1.0, 2.0, 3.0];
let data2 = vec![4.0, 5.0, 6.0, 7.0];
let tensor1 = tensor_1d(&data1).unwrap();
let tensor2 = tensor_1d(&data2).unwrap();
let calibration_tensors = vec![&tensor1, &tensor2];
let target_psnr = 30.0; // dB
let max_compression = 8.0;
let optimal_config =
auto_calibrate_quantization(&calibration_tensors, target_psnr, max_compression)
.unwrap();
assert!(optimal_config.validate().is_ok());
}
#[test]
fn test_report_generation() {
let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
let tensor = tensor_1d(&data).unwrap();
let configs = vec![QuantConfig::int8(), QuantConfig::int4()];
let report = generate_quantization_report(&tensor, &configs).unwrap();
assert!(report.contains("# Quantization Analysis Report"));
assert!(report.contains("## Quantization Configuration Comparison"));
assert!(report.contains("## Detailed Metrics"));
assert!(report.contains("## Recommendations"));
}
#[test]
fn test_error_diagnostics() {
let data = vec![0.0, 1.0, 2.0, 3.0];
let tensor = tensor_1d(&data).unwrap();
let config = QuantConfig::int8();
let error = TorshError::InvalidArgument("Test error".to_string());
let diagnosis = diagnose_quantization_failure(&tensor, &config, &error);
assert!(diagnosis.contains("Quantization failed with error"));
assert!(diagnosis.contains("Tensor Analysis"));
assert!(diagnosis.contains("Configuration Analysis"));
assert!(diagnosis.contains("Recovery Suggestions"));
}
#[test]
fn test_optimized_config_creation() {
let inference_config = create_optimized_config("inference_cpu", "x86").unwrap();
assert!(inference_config.validate().is_ok());
let mobile_config = create_optimized_config("inference_mobile", "arm").unwrap();
assert!(mobile_config.validate().is_ok());
let training_config = create_optimized_config("training", "gpu").unwrap();
assert!(training_config.validate().is_ok());
let invalid_result = create_optimized_config("invalid_use_case", "x86");
assert!(invalid_result.is_err());
}
}