use super::super::QuantizationParams;
use super::utils::*;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::ArrayView1;
use std::fmt::Debug;
/// Calibrates quantization parameters from the smallest and largest finite
/// values of `vector`.
///
/// Every element is converted to `f32`; non-finite values (NaN, ±infinity)
/// are skipped. If all finite values are identical, the range is widened by
/// 1.0 on each side so that it is not degenerate. The resulting range is
/// passed to `create_params_from_range` together with `bits` and `symmetric`.
///
/// # Errors
///
/// Returns `LinalgError::ValueError` when the vector has no finite values
/// (this includes an empty vector), and propagates any error returned by
/// `create_params_from_range`.
#[allow(dead_code)]
pub(super) fn calibrate_vector_minmax<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    // Seed with an "empty" range of (+inf, -inf). Seeding with f32::MAX /
    // f32::MIN would leave both accumulators finite when no element is
    // accepted, so the "no finite values" check below could never trigger.
    let mut min_val = f32::INFINITY;
    let mut max_val = f32::NEG_INFINITY;

    for &val in vector.iter() {
        let val_f32: f32 = val.as_();
        if val_f32.is_finite() {
            min_val = min_val.min(val_f32);
            max_val = max_val.max(val_f32);
        }
    }

    // Still infinite means no finite element was seen.
    if !min_val.is_finite() || !max_val.is_finite() {
        return Err(LinalgError::ValueError(
            "Vector contains no finite values".to_string(),
        ));
    }

    // Avoid a zero-width range when every finite value is the same.
    if min_val == max_val {
        min_val -= 1.0;
        max_val += 1.0;
    }

    create_params_from_range(bits, min_val, max_val, symmetric)
}
/// Calibrates quantization parameters from the averaged extremes of `vector`.
///
/// Elements are converted to `f32`, non-finite values are dropped, and the
/// remainder is sorted ascending. When there are more than `windowsize`
/// finite values, the lower bound is the mean of the `windowsize` smallest
/// values and the upper bound is the mean of the `windowsize` largest values.
/// Otherwise the plain minimum and maximum are used. The range is then passed
/// to `create_params_from_range` together with `bits` and `symmetric`.
///
/// # Errors
///
/// Returns `LinalgError::ValueError` when the vector has no finite values or
/// when `windowsize` is zero, and propagates any error returned by
/// `create_params_from_range`.
#[allow(dead_code)]
pub(super) fn calibrate_vector_moving_average<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    windowsize: usize,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    let mut values: Vec<f32> = vector
        .iter()
        .map(|&x| -> f32 { x.as_() })
        .filter(|val| val.is_finite())
        .collect();

    if values.is_empty() {
        return Err(LinalgError::ValueError(
            "Vector contains no finite values".to_string(),
        ));
    }

    // A zero-sized window would average zero elements (0.0 / 0.0 = NaN) and
    // produce a NaN range, so reject it explicitly.
    if windowsize == 0 {
        return Err(LinalgError::ValueError(
            "Window size must be greater than zero".to_string(),
        ));
    }

    // Only finite values remain, so `partial_cmp` always yields an ordering;
    // the fallback is never taken in practice.
    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let (min_val, max_val) = if values.len() <= windowsize {
        // Not enough samples to fill a window: use the raw extremes.
        // `values` is non-empty here, so both indices are in bounds.
        (values[0], values[values.len() - 1])
    } else {
        let window = windowsize as f32;
        let low_mean = values.iter().take(windowsize).sum::<f32>() / window;
        let high_mean = values.iter().rev().take(windowsize).sum::<f32>() / window;
        (low_mean, high_mean)
    };

    create_params_from_range(bits, min_val, max_val, symmetric)
}
/// Calibrates quantization parameters from a percentile-clipped range of
/// `vector`.
///
/// Elements are converted to `f32`, non-finite values are dropped, and the
/// remainder is sorted ascending. Two positions are computed by rounding
/// `(1 - percentile) * len` and `percentile * len`, each clamped to the last
/// valid index. The values at the smaller and larger of those positions
/// become the lower and upper bound passed to `create_params_from_range`.
///
/// Ordering the two positions means a `percentile` below 0.5 no longer yields
/// an inverted range (lower bound above upper bound); results for
/// `percentile >= 0.5` are unaffected.
///
/// # Errors
///
/// Returns `LinalgError::ValueError` when `percentile` is outside `0.0..=1.0`
/// (NaN included) or when the vector has no finite values, and propagates any
/// error returned by `create_params_from_range`.
#[allow(dead_code)]
pub(super) fn calibrate_vector_percentile<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    percentile: f32,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    if !(0.0..=1.0).contains(&percentile) {
        return Err(LinalgError::ValueError(
            "Percentile must be between 0.0 and 1.0".to_string(),
        ));
    }

    let mut values: Vec<f32> = vector
        .iter()
        .map(|&x| -> f32 { x.as_() })
        .filter(|val| val.is_finite())
        .collect();

    if values.is_empty() {
        return Err(LinalgError::ValueError(
            "Vector contains no finite values".to_string(),
        ));
    }

    // Only finite values remain, so `partial_cmp` always yields an ordering.
    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let last = values.len() - 1;
    let len_f = values.len() as f32;
    // Rounded positions can equal `len` (e.g. percentile == 1.0); clamp them.
    let idx_a = (((1.0 - percentile) * len_f).round() as usize).min(last);
    let idx_b = ((percentile * len_f).round() as usize).min(last);

    // For percentile < 0.5 the "low" position lands above the "high" one;
    // order them so the range is never inverted.
    let (low_idx, high_idx) = if idx_a <= idx_b {
        (idx_a, idx_b)
    } else {
        (idx_b, idx_a)
    };

    let min_val = values[low_idx];
    let max_val = values[high_idx];
    create_params_from_range(bits, min_val, max_val, symmetric)
}
/// Calibrates quantization parameters for `vector` using the histogram-based
/// threshold search provided by the `utils` helpers.
///
/// The steps are pure delegation:
/// 1. `find_min_max_vec` yields the data range of `vector`.
/// 2. `create_histogram_vec` builds a histogram over that range with
///    `num_bins` bins.
/// 3. `optimize_thresholds_kl_divergence` selects a (min, max) pair from the
///    histogram for the given `bits` and `symmetric` mode.
/// 4. `create_params_from_range` turns the selected pair into parameters.
///
/// NOTE(review): how the helpers treat empty input, non-finite values, or
/// `num_bins == 0` is not visible here; confirm in `utils`.
///
/// # Errors
///
/// Propagates any error returned by `create_params_from_range`.
#[allow(dead_code)]
pub(super) fn calibrate_vector_entropy<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    num_bins: usize,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    // Full observed range of the input.
    let (data_min, data_max) = find_min_max_vec(vector);

    // Distribution of the input over that range.
    let hist = create_histogram_vec(vector, data_min, data_max, num_bins);

    // Thresholds chosen by the KL-divergence helper.
    let clipped = optimize_thresholds_kl_divergence(&hist, data_min, data_max, bits, symmetric);

    create_params_from_range(bits, clipped.0, clipped.1, symmetric)
}
/// Calibrates quantization parameters for `vector` by refining a min-max
/// calibration with the optimizers from `utils`.
///
/// Starts from the parameters produced by `calibrate_vector_minmax`. In
/// symmetric mode only `scale` is replaced, with the result of
/// `optimize_symmetric_scale_vec`. Otherwise both `scale` and `zero_point`
/// are replaced with the pair returned by `optimize_affine_params_vec`. All
/// other fields of the min-max parameters are returned unchanged.
///
/// # Errors
///
/// Propagates any error returned by `calibrate_vector_minmax`.
#[allow(dead_code)]
pub(super) fn calibrate_vector_mse<F>(
    vector: &ArrayView1<F>,
    bits: u8,
    symmetric: bool,
) -> LinalgResult<QuantizationParams>
where
    F: scirs2_core::numeric::Float
        + Debug
        + scirs2_core::numeric::AsPrimitive<f32>
        + scirs2_core::numeric::FromPrimitive,
    f32: scirs2_core::numeric::AsPrimitive<F>,
{
    // The min-max result is both the optimizers' starting point and the
    // value returned, so it is updated in place; no clone is needed.
    let mut params = calibrate_vector_minmax(vector, bits, symmetric)?;

    if symmetric {
        params.scale = optimize_symmetric_scale_vec(vector, bits, params.scale);
    } else {
        let (scale, zero_point) =
            optimize_affine_params_vec(vector, bits, params.scale, params.zero_point);
        params.scale = scale;
        params.zero_point = zero_point;
    }

    Ok(params)
}