Skip to main content

bitnet_quantize/layer/
ste.rs

1//! Straight-Through Estimator (STE) for gradient estimation.
2//!
3//! STE allows gradients to flow through non-differentiable quantization
4//! by using the identity function in the backward pass.
5
6use candle_core::Tensor;
7
8use crate::error::Result;
9
10/// Apply STE forward pass (quantize then dequantize).
11///
12/// This creates a "fake quantized" tensor that:
13/// - Has values as if they were quantized and dequantized
14/// - Can receive gradients normally during backprop
15///
16/// # Arguments
17///
18/// * `input` - Input tensor to quantize
19/// * `scale` - Scale factor for quantization
20/// * `min_val` - Minimum quantized value (e.g., -1 for ternary)
21/// * `max_val` - Maximum quantized value (e.g., +1 for ternary)
22///
23/// # Errors
24///
25/// Returns error if tensor operations fail.
26pub fn ste_forward(input: &Tensor, scale: f32, min_val: f32, max_val: f32) -> Result<Tensor> {
27    // Quantize: round(x / scale) clamped to [min_val, max_val]
28    let scaled = input.affine(1.0 / f64::from(scale), 0.0)?;
29
30    // Round and clamp (fake quantization)
31    let rounded = scaled.round()?;
32    let clamped = rounded.clamp(min_val, max_val)?;
33
34    // Dequantize: multiply by scale
35    let dequantized = clamped.affine(f64::from(scale), 0.0)?;
36
37    Ok(dequantized)
38}
39
40/// Compute STE backward pass (identity gradient).
41///
42/// In the backward pass, gradients flow through unchanged.
43/// This is handled automatically by the tensor operations,
44/// but this function is provided for clarity.
45///
46/// # Arguments
47///
48/// * `grad_output` - Gradient from the next layer
49///
50/// # Returns
51///
52/// The same gradient (identity function)
53#[must_use]
54pub fn ste_backward(grad_output: &Tensor) -> Tensor {
55    grad_output.clone()
56}
57
58/// Apply ternary STE (quantize to {-1, 0, +1}).
59///
60/// # Arguments
61///
62/// * `input` - Input tensor
63/// * `scale` - Scale factor (typically AbsMean of the input)
64///
65/// # Errors
66///
67/// Returns error if tensor operations fail.
68pub fn ternary_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
69    ste_forward(input, scale, -1.0, 1.0)
70}
71
72/// Apply INT8 STE (quantize to [-127, 127]).
73///
74/// # Arguments
75///
76/// * `input` - Input tensor
77/// * `scale` - Scale factor (typically AbsMax / 127)
78///
79/// # Errors
80///
81/// Returns error if tensor operations fail.
82pub fn int8_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
83    ste_forward(input, scale, -127.0, 127.0)
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a 1-D f32 tensor on the CPU holding `data`.
    fn cpu_tensor(data: &[f32]) -> Tensor {
        Tensor::from_vec(data.to_vec(), (data.len(),), &Device::Cpu).unwrap()
    }

    #[test]
    fn test_ternary_ste() {
        // With scale 1.0: halves round away from zero, large magnitudes clamp to +/-1.
        let input = cpu_tensor(&[-2.0, -0.5, 0.0, 0.5, 2.0]);
        let got: Vec<f32> = ternary_ste(&input, 1.0).unwrap().to_vec1().unwrap();
        assert_eq!(got, [-1.0, -1.0, 0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_int8_ste() {
        // Out-of-range values saturate at +/-127; in-range integers pass through.
        let input = cpu_tensor(&[-200.0, -50.0, 0.0, 50.0, 200.0]);
        let got: Vec<f32> = int8_ste(&input, 1.0).unwrap().to_vec1().unwrap();
        assert_eq!(got, [-127.0, -50.0, 0.0, 50.0, 127.0]);
    }

    #[test]
    fn test_ste_with_scale() {
        // x / 2 -> round -> clamp to [-1, 1] -> * 2:
        // 0.25 -> 0, 0.5 -> 1, 0.75 -> 1, 1.0 -> 1
        let input = cpu_tensor(&[0.5, 1.0, 1.5, 2.0]);
        let got: Vec<f32> = ternary_ste(&input, 2.0).unwrap().to_vec1().unwrap();
        let expected = [0.0_f32, 2.0, 2.0, 2.0];
        assert_eq!(got.len(), expected.len());
        for (idx, (actual, want)) in got.iter().zip(expected.iter()).enumerate() {
            assert!(
                (actual - want).abs() < 0.01,
                "element {idx}: got {actual}, expected {want}"
            );
        }
    }

    #[test]
    fn test_ste_backward_identity() {
        let grad = cpu_tensor(&[1.0, 2.0, 3.0]);
        let passed = ste_backward(&grad);
        assert_eq!(
            passed.to_vec1::<f32>().unwrap(),
            grad.to_vec1::<f32>().unwrap()
        );
    }
}