use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{creation::ones, Tensor};
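/// Quantization scheme selector: `Uniform` uses a fixed scale and zero point,
/// `Dynamic` derives them from the observed tensor range, and `NonUniform` is
/// reserved for codebook-style schemes such as weight clustering.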
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantizationScheme {
Uniform,
NonUniform,
Dynamic,
}
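/// Target integer width for quantized values. The representable range is
/// `[-128, 127]` for `Int8`, `[0, 255]` for `UInt8`, `[-32768, 32767]` for
/// `Int16`, and `[-8, 7]` for `Int4`.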
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantizationType {
Int8,
UInt8,
Int16,
Int4,
}
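/// Affine (uniform) quantization: `q = clamp(round(x / scale + zero_point), qmin, qmax)`,
/// where `[qmin, qmax]` is the representable range of `qtype`. The quantized
/// values are returned in a float tensor, alongside the `(scale, zero_point)`
/// pair needed by [`uniform_dequantize`].
///
/// A minimal usage sketch, mirroring the `randn` constructor used by this
/// module's tests:
///
/// ```ignore
/// let x = randn(&[4, 4], None, None, None)?;
/// let (q, scale, zp) = uniform_quantize(&x, 0.1, 128, QuantizationType::UInt8)?;
/// let x_hat = uniform_dequantize(&q, scale, zp)?; // x_hat ≈ x, up to 0.5 * scale
/// ```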
pub fn uniform_quantize(
input: &Tensor,
scale: f32,
zero_point: i32,
qtype: QuantizationType,
) -> TorshResult<(Tensor, f32, i32)> {
let (qmin, qmax) = match qtype {
QuantizationType::Int8 => (-128i32, 127i32),
QuantizationType::UInt8 => (0i32, 255i32),
QuantizationType::Int16 => (-32768i32, 32767i32),
QuantizationType::Int4 => (-8i32, 7i32),
};
let scaled = input.div_scalar(scale)?;
let shifted = scaled.add_scalar(zero_point as f32)?;
let rounded = shifted.round()?;
let clamped = crate::math::clamp(&rounded, qmin as f32, qmax as f32)?;
Ok((clamped, scale, zero_point))
}
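/// Inverse of [`uniform_quantize`]: `x ≈ (q - zero_point) * scale`.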
pub fn uniform_dequantize(quantized: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
    let mut shifted = quantized.clone();
    shifted.sub_scalar_(zero_point as f32)?;
    shifted.mul_scalar(scale)
}
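/// Dynamic quantization: derives `scale = (max - min) / (qmax - qmin)` and a
/// matching zero point from the observed tensor range, then applies
/// [`uniform_quantize`]. With `reduce_range`, the target range is halved
/// (e.g. `[-64, 63]` for `Int8`), which some backends use to avoid overflow
/// in intermediate accumulations.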
pub fn dynamic_quantize(
input: &Tensor,
qtype: QuantizationType,
reduce_range: bool,
) -> TorshResult<(Tensor, f32, i32)> {
let (qmin, qmax) = match qtype {
QuantizationType::Int8 => {
if reduce_range {
(-64i32, 63i32)
} else {
(-128i32, 127i32)
}
}
QuantizationType::UInt8 => {
if reduce_range {
(0i32, 127i32)
} else {
(0i32, 255i32)
}
}
QuantizationType::Int16 => {
if reduce_range {
(-16384i32, 16383i32)
} else {
(-32768i32, 32767i32)
}
}
QuantizationType::Int4 => {
if reduce_range {
(-4i32, 3i32)
} else {
(-8i32, 7i32)
}
}
};
    let input_min = input.min()?.data()?[0];
    let input_max = input.max(None, false)?.data()?[0];
    let scale = (input_max - input_min) / (qmax - qmin) as f32;
    // Guard against a zero range (constant input) before deriving the zero
    // point, so the division below cannot produce NaN or infinity.
    let safe_scale = if scale == 0.0 { 1.0 } else { scale };
    let zero_point = (qmin as f32 - input_min / safe_scale)
        .round()
        .clamp(qmin as f32, qmax as f32) as i32;
    uniform_quantize(input, safe_scale, zero_point, qtype)
}
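/// Quantize-dequantize round trip ("fake" quantization). The output stays in
/// floating point but carries the rounding and clamping error of `qtype`,
/// which is the standard building block for quantization-aware training.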
pub fn fake_quantize(
input: &Tensor,
scale: f32,
zero_point: i32,
qtype: QuantizationType,
) -> TorshResult<Tensor> {
let (quantized, scale, zero_point) = uniform_quantize(input, scale, zero_point, qtype)?;
uniform_dequantize(&quantized, scale, zero_point)
}
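/// Magnitude pruning: zeroes the `sparsity` fraction of weights with the
/// smallest magnitudes and returns `(pruned_weights, mask)`. With
/// `structured = true`, whole filters (slices along dim 0) are removed by
/// their L2 norm instead of individual weights.
///
/// A minimal usage sketch, mirroring the `randn` constructor used by this
/// module's tests:
///
/// ```ignore
/// let w = randn(&[10, 10], None, None, None)?;
/// let (pruned, mask) = magnitude_prune(&w, 0.5, false)?; // ~50% of entries zeroed
/// ```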
pub fn magnitude_prune(
weights: &Tensor,
sparsity: f32,
structured: bool,
) -> TorshResult<(Tensor, Tensor)> {
if sparsity < 0.0 || sparsity >= 1.0 {
return Err(TorshError::invalid_argument_with_context(
"Sparsity must be in range [0.0, 1.0)",
"magnitude_prune",
));
}
if structured {
let weight_shape_ref = weights.shape();
let weight_shape = weight_shape_ref.dims();
if weight_shape.len() < 2 {
return Err(TorshError::invalid_argument_with_context(
"Structured pruning requires at least 2D weights",
"magnitude_prune",
));
}
        let num_filters = weight_shape[0];
        let num_to_prune = (num_filters as f32 * sparsity) as usize;
        let dims_to_reduce: Vec<i32> = (1..weight_shape.len()).map(|i| i as i32).collect();
        // L2 norm of each filter (slice along dim 0); the filters with the
        // smallest norms are removed in their entirety.
        let filter_norms: Vec<f32> = weights
            .pow_scalar(2.0)?
            .sum_dim(&dims_to_reduce, false)?
            .sqrt()?
            .data()?;
        let mut order: Vec<usize> = (0..num_filters).collect();
        order.sort_by(|&a, &b| filter_norms[a].total_cmp(&filter_norms[b]));
        // Zero out whole filters, assuming the row-major layout used by
        // `Tensor::from_data` elsewhere in this module.
        let filter_size: usize = weight_shape[1..].iter().product();
        let mut mask_data = vec![1.0f32; weights.numel()];
        for &filter in order.iter().take(num_to_prune.min(num_filters)) {
            mask_data[filter * filter_size..(filter + 1) * filter_size].fill(0.0);
        }
        let mask = Tensor::from_data(mask_data, weight_shape.to_vec(), weights.device())?;
        let pruned_weights = weights.mul_op(&mask)?;
        Ok((pruned_weights, mask))
} else {
let abs_weights = weights.abs()?;
let threshold = calculate_pruning_threshold(&abs_weights, sparsity)?;
let bool_mask = abs_weights.gt_scalar(threshold)?;
let mask_data: Vec<f32> = bool_mask
.data()?
.iter()
.map(|&b| if b { 1.0 } else { 0.0 })
.collect();
let mask = Tensor::from_data(mask_data, weights.shape().dims().to_vec(), weights.device())?;
let pruned_weights = weights.mul_op(&mask)?;
Ok((pruned_weights, mask))
}
}
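/// Returns the `sparsity`-quantile of the weight magnitudes: masking out all
/// magnitudes at or below this threshold removes (approximately, up to ties)
/// the requested fraction of weights.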
fn calculate_pruning_threshold(abs_weights: &Tensor, sparsity: f32) -> TorshResult<f32> {
    let mut magnitudes: Vec<f32> = abs_weights.data()?;
    if magnitudes.is_empty() {
        return Ok(0.0);
    }
    magnitudes.sort_by(|a, b| a.total_cmp(b));
    // Prune the `sparsity` fraction with the smallest magnitudes: the
    // threshold is the largest magnitude among the weights to remove.
    let num_to_prune = (magnitudes.len() as f32 * sparsity) as usize;
    if num_to_prune == 0 {
        // Nothing falls below the cut; a negative threshold keeps every weight.
        return Ok(-1.0);
    }
    Ok(magnitudes[num_to_prune - 1])
}
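/// Gradual magnitude pruning with a smoothstep schedule: for progress
/// `p = (t - t_start) / (t_end - t_start)`, the current sparsity is
/// `s(p) = s_i + (s_f - s_i) * (3p^2 - 2p^3)`, so pruning ramps up gently at
/// both ends of the schedule.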
pub fn gradual_magnitude_prune(
weights: &Tensor,
current_step: usize,
start_step: usize,
end_step: usize,
initial_sparsity: f32,
final_sparsity: f32,
) -> TorshResult<(Tensor, f32, Tensor)> {
if current_step < start_step {
let mask = ones(&weights.shape().dims())?;
return Ok((weights.clone(), initial_sparsity, mask));
}
if current_step >= end_step {
let (pruned, mask) = magnitude_prune(weights, final_sparsity, false)?;
return Ok((pruned, final_sparsity, mask));
}
let progress = (current_step - start_step) as f32 / (end_step - start_step) as f32;
let current_sparsity = initial_sparsity
+ (final_sparsity - initial_sparsity) * (3.0 * progress.powi(2) - 2.0 * progress.powi(3));
let (pruned, mask) = magnitude_prune(weights, current_sparsity, false)?;
Ok((pruned, current_sparsity, mask))
}
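/// Weight clustering (weight sharing): replaces each weight with the centroid
/// of its cluster so the tensor holds only `num_clusters` distinct values.
/// Returns `(clustered_weights, centroids, cluster_assignments)`; assignments
/// are stored as cluster indices cast to `f32`. Clustering runs as a simple
/// host-side one-dimensional k-means over the flattened weights.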
pub fn weight_clustering(
    weights: &Tensor,
    num_clusters: usize,
) -> TorshResult<(Tensor, Tensor, Tensor)> {
    if num_clusters == 0 {
        return Err(TorshError::invalid_argument_with_context(
            "Number of clusters must be positive",
            "weight_clustering",
        ));
    }
    let weight_shape = weights.shape().dims().to_vec();
    let data: Vec<f32> = weights.data()?;
    let min_weight = data.iter().cloned().fold(f32::INFINITY, f32::min);
    let max_weight = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Spread the initial centroids evenly across the observed weight range.
    let mut centroids: Vec<f32> = (0..num_clusters)
        .map(|k| min_weight + (max_weight - min_weight) * (k as f32 + 0.5) / num_clusters as f32)
        .collect();
    let mut assignments = vec![0usize; data.len()];
    for _ in 0..10 {
        // Assignment step: each weight joins its nearest centroid.
        for (i, &w) in data.iter().enumerate() {
            assignments[i] = (0..num_clusters)
                .min_by(|&a, &b| (w - centroids[a]).abs().total_cmp(&(w - centroids[b]).abs()))
                .unwrap_or(0);
        }
        // Update step: each centroid moves to the mean of its members.
        let mut sums = vec![0.0f32; num_clusters];
        let mut counts = vec![0usize; num_clusters];
        for (i, &w) in data.iter().enumerate() {
            sums[assignments[i]] += w;
            counts[assignments[i]] += 1;
        }
        for k in 0..num_clusters {
            if counts[k] > 0 {
                centroids[k] = sums[k] / counts[k] as f32;
            }
        }
    }
    let clustered: Vec<f32> = assignments.iter().map(|&k| centroids[k]).collect();
    let assignment_data: Vec<f32> = assignments.iter().map(|&k| k as f32).collect();
    let clustered_weights = Tensor::from_data(clustered, weight_shape.clone(), weights.device())?;
    let centroid_tensor = Tensor::from_data(centroids, vec![num_clusters], weights.device())?;
    let cluster_assignments = Tensor::from_data(assignment_data, weight_shape, weights.device())?;
    Ok((clustered_weights, centroid_tensor, cluster_assignments))
}
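/// Lottery ticket hypothesis pruning: the mask is derived from the magnitudes
/// of the *trained* weights, then applied to the *initial* weights, yielding
/// the "winning ticket" subnetwork to be retrained from its original
/// initialization.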
pub fn lottery_ticket_prune(
weights: &Tensor,
initial_weights: &Tensor,
sparsity: f32,
) -> TorshResult<(Tensor, Tensor)> {
if weights.shape().dims() != initial_weights.shape().dims() {
return Err(TorshError::invalid_argument_with_context(
"Weight tensors must have same shape",
"lottery_ticket_prune",
));
}
let (_, mask) = magnitude_prune(weights, sparsity, false)?;
let winning_subnetwork = initial_weights.mul_op(&mask)?;
Ok((mask, winning_subnetwork))
}
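/// Compares a tensor against its quantized reconstruction and returns
/// `(mse, max_abs_error, snr_db)`, where `snr_db = 10 * log10(signal_power / mse)`
/// and `signal_power` is the mean of the squared original values. A lossless
/// reconstruction yields `mse = 0` and an infinite SNR.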
pub fn quantization_error_analysis(
original: &Tensor,
quantized: &Tensor,
) -> TorshResult<(f32, f32, f32)> {
if original.shape().dims() != quantized.shape().dims() {
return Err(TorshError::invalid_argument_with_context(
"Tensors must have same shape",
"quantization_error_analysis",
));
}
let error = original.sub(quantized)?;
let mse_tensor = error.pow_scalar(2.0)?.mean(None, false)?;
let mse = mse_tensor.data()?[0];
let abs_error = error.abs()?;
let max_error_tensor = abs_error.max(None, false)?;
let max_error = max_error_tensor.data()?[0];
let signal_power_tensor = original.pow_scalar(2.0)?.mean(None, false)?;
let signal_power = signal_power_tensor.data()?[0];
let snr_db = if mse > 0.0 {
10.0 * (signal_power / mse).log10()
} else {
f32::INFINITY
};
Ok((mse, max_error, snr_db))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::random_ops::randn;
#[test]
fn test_uniform_quantization() {
let input = randn(&[4, 4], None, None, None).unwrap();
let (quantized, scale, zero_point) =
uniform_quantize(&input, 0.1, 128, QuantizationType::UInt8).unwrap();
assert_eq!(quantized.shape().dims(), input.shape().dims());
let dequantized = uniform_dequantize(&quantized, scale, zero_point).unwrap();
assert_eq!(dequantized.shape().dims(), input.shape().dims());
}
#[test]
fn test_dynamic_quantization() {
let input = randn(&[3, 3], None, None, None).unwrap();
let (quantized, scale, _zero_point) =
dynamic_quantize(&input, QuantizationType::Int8, false).unwrap();
assert_eq!(quantized.shape().dims(), input.shape().dims());
assert!(scale > 0.0);
}
#[test]
fn test_fake_quantization() {
let input = randn(&[2, 2], None, None, None).unwrap();
let fake_quantized = fake_quantize(&input, 0.1, 0, QuantizationType::Int8).unwrap();
assert_eq!(fake_quantized.shape().dims(), input.shape().dims());
}
#[test]
fn test_magnitude_pruning() {
let weights = randn(&[10, 10], None, None, None).unwrap();
let (pruned, mask) = magnitude_prune(&weights, 0.5, false).unwrap();
assert_eq!(pruned.shape().dims(), weights.shape().dims());
assert_eq!(mask.shape().dims(), weights.shape().dims());
}
#[test]
fn test_gradual_pruning() {
let weights = randn(&[5, 5], None, None, None).unwrap();
let (pruned, sparsity, mask) =
gradual_magnitude_prune(&weights, 50, 10, 100, 0.0, 0.8).unwrap();
assert_eq!(pruned.shape().dims(), weights.shape().dims());
assert!(sparsity >= 0.0 && sparsity <= 0.8);
assert_eq!(mask.shape().dims(), weights.shape().dims());
}
#[test]
fn test_lottery_ticket() {
let trained_weights = randn(&[4, 4], None, None, None).unwrap();
let initial_weights = randn(&[4, 4], None, None, None).unwrap();
let (mask, winning_subnetwork) =
lottery_ticket_prune(&trained_weights, &initial_weights, 0.6).unwrap();
assert_eq!(mask.shape().dims(), trained_weights.shape().dims());
assert_eq!(
winning_subnetwork.shape().dims(),
initial_weights.shape().dims()
);
}
#[test]
fn test_quantization_error_analysis() {
let original = randn(&[3, 3], None, None, None).unwrap();
let quantized = original.clone();
let (mse, max_error, snr_db) = quantization_error_analysis(&original, &quantized).unwrap();
assert!(mse <= 1e-6);
assert!(max_error <= 1e-6);
assert!(snr_db > 60.0 || snr_db.is_infinite());
}
}