use candle_core::Tensor;
use crate::error::Result;
pub fn ste_forward(input: &Tensor, scale: f32, min_val: f32, max_val: f32) -> Result<Tensor> {
let scaled = input.affine(1.0 / f64::from(scale), 0.0)?;
let rounded = scaled.round()?;
let clamped = rounded.clamp(min_val, max_val)?;
let dequantized = clamped.affine(f64::from(scale), 0.0)?;
Ok(dequantized)
}
#[must_use]
pub fn ste_backward(grad_output: &Tensor) -> Tensor {
grad_output.clone()
}
pub fn ternary_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
ste_forward(input, scale, -1.0, 1.0)
}
pub fn int8_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
ste_forward(input, scale, -127.0, 127.0)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn test_ternary_ste() {
let device = Device::Cpu;
let values: Vec<f32> = vec![-2.0, -0.5, 0.0, 0.5, 2.0];
let input = Tensor::from_vec(values, (5,), &device).unwrap();
let output = ternary_ste(&input, 1.0).unwrap();
let result: Vec<f32> = output.to_vec1().unwrap();
assert_eq!(result[0], -1.0); assert_eq!(result[1], -1.0); assert_eq!(result[2], 0.0); assert_eq!(result[3], 1.0); assert_eq!(result[4], 1.0); }
#[test]
fn test_int8_ste() {
let device = Device::Cpu;
let values: Vec<f32> = vec![-200.0, -50.0, 0.0, 50.0, 200.0];
let input = Tensor::from_vec(values, (5,), &device).unwrap();
let output = int8_ste(&input, 1.0).unwrap();
let result: Vec<f32> = output.to_vec1().unwrap();
assert_eq!(result[0], -127.0);
assert_eq!(result[1], -50.0);
assert_eq!(result[2], 0.0);
assert_eq!(result[3], 50.0);
assert_eq!(result[4], 127.0);
}
#[test]
fn test_ste_with_scale() {
let device = Device::Cpu;
let values: Vec<f32> = vec![0.5, 1.0, 1.5, 2.0];
let input = Tensor::from_vec(values, (4,), &device).unwrap();
let output = ternary_ste(&input, 2.0).unwrap();
let result: Vec<f32> = output.to_vec1().unwrap();
assert!((result[0] - 0.0).abs() < 0.01);
assert!((result[1] - 2.0).abs() < 0.01);
assert!((result[2] - 2.0).abs() < 0.01);
assert!((result[3] - 2.0).abs() < 0.01);
}
#[test]
fn test_ste_backward_identity() {
let device = Device::Cpu;
let grad = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap();
let result = ste_backward(&grad);
let grad_vec: Vec<f32> = grad.to_vec1().unwrap();
let result_vec: Vec<f32> = result.to_vec1().unwrap();
assert_eq!(grad_vec, result_vec);
}
}