use crate::quantization::{fake_quantize, fake_quantize_vector, QuantizationMethod};
use scirs2_core::ndarray::array;
/// Runs `fake_quantize` on a 2x3 `f32` matrix with 8 bits and the `Uniform`
/// method, then checks two things about the result:
///
/// * the largest absolute element-wise difference from the input is below 6.0;
/// * the result is not equal to the input.
#[test]
fn test_fake_quantize() {
    let a = array![[1.0_f32, 2.5, 3.7], [4.2, 5.0, 6.1]];

    let a_view = a.view();
    let a_fake_q = fake_quantize(&a_view, 8, QuantizationMethod::Uniform);

    // Element-wise |a - a_fake_q|, reduced to its maximum (seeded with 0.0).
    let abs_error = (&a - &a_fake_q).mapv(f32::abs);
    let worst_error = abs_error.iter().copied().fold(0.0_f32, f32::max);

    println!("Max error (Fake Quantize): {}", worst_error);

    assert!(worst_error < 6.0, "Max error too large: {}", worst_error);
    assert!(a != a_fake_q);
}
/// Runs `fake_quantize` on a 2x3 `f32` matrix with 4 bits and the `Int4`
/// method, then checks two things about the result:
///
/// * the largest absolute element-wise difference from the input is below 10.0;
/// * the result is not equal to the input.
#[test]
fn test_fake_quantize_int4() {
    let a = array![[1.0_f32, 2.5, 3.7], [4.2, 5.0, 6.1]];

    let a_view = a.view();
    let a_fake_q = fake_quantize(&a_view, 4, QuantizationMethod::Int4);

    // Element-wise |a - a_fake_q|, reduced to its maximum (seeded with 0.0).
    let abs_error = (&a - &a_fake_q).mapv(f32::abs);
    let worst_error = abs_error.iter().copied().fold(0.0_f32, f32::max);

    println!("Max error (Fake Quantize Int4): {}", worst_error);

    assert!(worst_error < 10.0, "Max error too large: {}", worst_error);
    assert!(a != a_fake_q);
}
/// Runs `fake_quantize_vector` on a six-element `f32` vector with 8 bits and
/// the `Uniform` method, then checks two things about the result:
///
/// * the largest absolute element-wise difference from the input is below 6.0;
/// * the result is not equal to the input.
#[test]
fn test_fake_quantize_vector() {
    let a = array![1.0_f32, 2.5, 3.7, 4.2, 5.0, 6.1];

    let a_view = a.view();
    let a_fake_q = fake_quantize_vector(&a_view, 8, QuantizationMethod::Uniform);

    // Element-wise |a - a_fake_q|, reduced to its maximum (seeded with 0.0).
    let abs_error = (&a - &a_fake_q).mapv(f32::abs);
    let worst_error = abs_error.iter().copied().fold(0.0_f32, f32::max);

    println!("Max error (Fake Quantize Vector): {}", worst_error);

    assert!(worst_error < 6.0, "Max error too large: {}", worst_error);
    assert!(a != a_fake_q);
}
/// Runs `fake_quantize_vector` on a six-element `f32` vector with 4 bits and
/// the `UInt4` method, then checks two things about the result:
///
/// * the largest absolute element-wise difference from the input is below 10.0;
/// * the result is not equal to the input.
#[test]
fn test_fake_quantize_vector_uint4() {
    let a = array![1.0_f32, 2.5, 3.7, 4.2, 5.0, 6.1];

    let a_view = a.view();
    let a_fake_q = fake_quantize_vector(&a_view, 4, QuantizationMethod::UInt4);

    // Element-wise |a - a_fake_q|, reduced to its maximum (seeded with 0.0).
    let abs_error = (&a - &a_fake_q).mapv(f32::abs);
    let worst_error = abs_error.iter().copied().fold(0.0_f32, f32::max);

    println!("Max error (Fake Quantize Vector UInt4): {}", worst_error);

    assert!(worst_error < 10.0, "Max error too large: {}", worst_error);
    assert!(a != a_fake_q);
}