use std::sync::Arc;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
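/// Applies fake quantization element-wise:
/// `fq(x) = (clamp(round(x / scale + zero_point), qmin, qmax) - zero_point) * scale`.
/// The forward pass simulates quantization error; the backward pass uses a
/// straight-through estimator (STE), passing gradients through for inputs inside the
/// representable range and zeroing them for clamped inputs.
///
/// A minimal usage sketch (not a doctest; it mirrors the constructors used in the tests below):
///
/// ```ignore
/// let x = Tensor::from_storage(TensorStorage::cpu(vec![0.05_f32, 3.2]), vec![2], true)?;
/// // int8-style range with scale 0.1: values are snapped to multiples of 0.1 in [-12.8, 12.7].
/// let fq = fake_quantize_differentiable(&x, 0.1, 0, -128, 127)?;
/// ```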
pub fn fake_quantize_differentiable<T: Float>(
input: &Tensor<T>,
scale: f64,
zero_point: i32,
qmin: i32,
qmax: i32,
) -> FerrotorchResult<Tensor<T>> {
use crate::error::FerrotorchError;
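// `!(scale > 0.0)` rejects zero, negative, and NaN scales in a single check.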
if !(scale > 0.0) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"fake_quantize_differentiable: scale must be > 0, got {scale}"
),
});
}
if qmin >= qmax {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"fake_quantize_differentiable: qmin ({qmin}) must be < qmax ({qmax})"
),
});
}
let data = input.data_vec()?;
let scale_f = T::from(scale).unwrap();
let zp_f = T::from(zero_point as f64).unwrap();
let qmin_f = T::from(qmin as f64).unwrap();
let qmax_f = T::from(qmax as f64).unwrap();
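// Real-valued range representable by the quantizer; the STE backward pass lets
// gradients through only for inputs inside this range.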
let range_min: T = (qmin_f - zp_f) * scale_f;
let range_max: T = (qmax_f - zp_f) * scale_f;
let mut out = Vec::with_capacity(data.len());
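// Quantize-dequantize round trip: scale and shift into the integer grid, round,
// clamp to [qmin, qmax], then map back to the real domain.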
for &x in &data {
let scaled = x / scale_f + zp_f;
let rounded = scaled.round();
let clamped = if rounded < qmin_f {
qmin_f
} else if rounded > qmax_f {
qmax_f
} else {
rounded
};
let dq = (clamped - zp_f) * scale_f;
out.push(dq);
}
let storage = TensorStorage::cpu(out);
let shape = input.shape().to_vec();
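// Attach a grad_fn only when the input tracks gradients and grad mode is enabled
// (i.e. we are not inside a no_grad scope).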
if input.requires_grad() && crate::autograd::no_grad::is_grad_enabled() {
let grad_fn = Arc::new(FakeQuantizeBackward::<T> {
input: input.clone(),
range_min,
range_max,
});
Tensor::from_operation(storage, shape, grad_fn)
} else {
Tensor::from_storage(storage, shape, false)
}
}
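/// Straight-through estimator backward for fake quantization: the upstream gradient is
/// passed through unchanged where the input lies in `[range_min, range_max]` and zeroed
/// where the forward pass clamped.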
#[derive(Debug)]
struct FakeQuantizeBackward<T: Float> {
input: Tensor<T>,
range_min: T,
range_max: T,
}
impl<T: Float> GradFn<T> for FakeQuantizeBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.input.requires_grad() {
return Ok(vec![None]);
}
let grad_data = grad_output.data_vec()?;
let input_data = self.input.data_vec()?;
let zero = <T as num_traits::Zero>::zero();
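// STE mask: pass the gradient where the input was representable, zero it where it was clamped.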
let grad: Vec<T> = input_data
.iter()
.zip(grad_data.iter())
.map(|(&x, &g)| {
if x >= self.range_min && x <= self.range_max {
g
} else {
zero
}
})
.collect();
let storage = TensorStorage::cpu(grad);
let shape = self.input.shape().to_vec();
Ok(vec![Some(Tensor::from_storage(storage, shape, false)?)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"FakeQuantizeBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::backward;
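// Helper for building small f32 CPU tensors used throughout these tests.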
fn t(data: Vec<f32>, shape: Vec<usize>, req_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, req_grad).unwrap()
}
#[test]
fn fake_quantize_round_trips_representable_values() {
let scale = 0.1;
let zp = 0;
let qmin = -128;
let qmax = 127;
let input = t(vec![0.0, 0.1, 0.2, -0.1, -0.2], vec![5], false);
let out = fake_quantize_differentiable(&input, scale, zp, qmin, qmax).unwrap();
let data = out.data().unwrap();
for (got, expected) in data.iter().zip([0.0, 0.1, 0.2, -0.1, -0.2].iter()) {
assert!(
(got - expected).abs() < 1e-5,
"expected {expected}, got {got}"
);
}
}
#[test]
fn fake_quantize_clamps_out_of_range_values() {
let input = t(vec![-200.0, -100.0, 0.0, 100.0, 200.0], vec![5], false);
let out = fake_quantize_differentiable(&input, 1.0, 0, -128, 127).unwrap();
let data = out.data().unwrap();
assert_eq!(data[0], -128.0);
assert_eq!(data[1], -100.0);
assert_eq!(data[2], 0.0);
assert_eq!(data[3], 100.0);
assert_eq!(data[4], 127.0);
}
#[test]
fn fake_quantize_rejects_zero_scale() {
let input = t(vec![1.0], vec![1], false);
let result = fake_quantize_differentiable(&input, 0.0, 0, -128, 127);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("scale must be > 0"));
}
#[test]
fn fake_quantize_rejects_negative_scale() {
let input = t(vec![1.0], vec![1], false);
let result = fake_quantize_differentiable(&input, -0.1, 0, -128, 127);
assert!(result.is_err());
}
#[test]
fn fake_quantize_rejects_inverted_range() {
let input = t(vec![1.0], vec![1], false);
let result = fake_quantize_differentiable(&input, 1.0, 0, 128, -128);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("qmin"));
}
#[test]
fn fake_quantize_asymmetric_with_zero_point() {
let input = t(vec![-128.0, 0.0, 127.0], vec![3], false);
let out = fake_quantize_differentiable(&input, 1.0, 128, 0, 255).unwrap();
let data = out.data().unwrap();
assert_eq!(data, &[-128.0, 0.0, 127.0]);
}
#[test]
fn fake_quantize_ste_passes_grad_for_in_range_values() {
let input = t(vec![-10.0, 0.0, 10.0, 50.0], vec![4], true);
let out = fake_quantize_differentiable(&input, 1.0, 0, -128, 127).unwrap();
let loss = out
.data_vec()
.unwrap()
.into_iter()
.fold(0.0f32, |a, b| a + b);
let sum = crate::grad_fns::reduction::sum(&out).unwrap();
backward(&sum).unwrap();
let grad = input.grad().unwrap().unwrap();
let grad_data = grad.data().unwrap();
for &g in grad_data {
assert_eq!(g, 1.0);
}
}
#[test]
fn fake_quantize_ste_zeros_grad_for_out_of_range_values() {
let input = t(
vec![-5.0, -1.0, 0.0, 1.0, 5.0, 100.0],
vec![6],
true,
);
let out = fake_quantize_differentiable(&input, 0.01, 0, -128, 127).unwrap();
let sum = crate::grad_fns::reduction::sum(&out).unwrap();
backward(&sum).unwrap();
let grad = input.grad().unwrap().unwrap();
let grad_data = grad.data().unwrap();
assert_eq!(grad_data[0], 0.0);
assert_eq!(grad_data[1], 1.0);
assert_eq!(grad_data[2], 1.0);
assert_eq!(grad_data[3], 1.0);
assert_eq!(grad_data[4], 0.0);
assert_eq!(grad_data[5], 0.0);
}
#[test]
fn fake_quantize_no_grad_when_input_doesnt_require_grad() {
let input = t(vec![1.0, 2.0], vec![2], false);
let out = fake_quantize_differentiable(&input, 1.0, 0, -128, 127).unwrap();
assert!(!out.requires_grad());
assert!(out.grad_fn().is_none());
}
#[test]
fn fake_quantize_preserves_grad_fn_when_input_requires_grad() {
let input = t(vec![1.0, 2.0], vec![2], true);
let out = fake_quantize_differentiable(&input, 1.0, 0, -128, 127).unwrap();
assert!(out.requires_grad());
assert!(out.grad_fn().is_some());
}
#[test]
fn fake_quantize_no_grad_context_skips_grad_fn() {
use crate::autograd::no_grad::no_grad;
let input = t(vec![1.0, 2.0], vec![2], true);
let out = no_grad(|| {
fake_quantize_differentiable(&input, 1.0, 0, -128, 127)
})
.unwrap();
assert!(out.grad_fn().is_none());
}
#[test]
fn fake_quantize_chains_through_autograd_with_relu() {
let input = t(vec![-2.0, -0.5, 0.5, 2.0], vec![4], true);
let fq = fake_quantize_differentiable(&input, 0.01, 0, -128, 127).unwrap();
let relu_out = crate::grad_fns::activation::relu(&fq).unwrap();
let sum = crate::grad_fns::reduction::sum(&relu_out).unwrap();
backward(&sum).unwrap();
let grad = input.grad().unwrap().unwrap();
let grad_data = grad.data().unwrap();
assert_eq!(grad_data[0], 0.0);
assert_eq!(grad_data[1], 0.0);
assert_eq!(grad_data[2], 1.0);
assert_eq!(grad_data[3], 0.0);
}
}