bitnet_quantize/layer/
ste.rs1use candle_core::Tensor;
7
8use crate::error::Result;
9
10pub fn ste_forward(input: &Tensor, scale: f32, min_val: f32, max_val: f32) -> Result<Tensor> {
27 let scaled = input.affine(1.0 / f64::from(scale), 0.0)?;
29
30 let rounded = scaled.round()?;
32 let clamped = rounded.clamp(min_val, max_val)?;
33
34 let dequantized = clamped.affine(f64::from(scale), 0.0)?;
36
37 Ok(dequantized)
38}
39
40#[must_use]
54pub fn ste_backward(grad_output: &Tensor) -> Tensor {
55 grad_output.clone()
56}
57
58pub fn ternary_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
69 ste_forward(input, scale, -1.0, 1.0)
70}
71
72pub fn int8_ste(input: &Tensor, scale: f32) -> Result<Tensor> {
83 ste_forward(input, scale, -127.0, 127.0)
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a 1-D `f32` tensor on the CPU from `values`.
    fn cpu_tensor(values: &[f32]) -> Tensor {
        Tensor::from_vec(values.to_vec(), (values.len(),), &Device::Cpu).unwrap()
    }

    /// Copies a 1-D tensor back into a plain vector for comparison.
    fn to_f32s(tensor: &Tensor) -> Vec<f32> {
        tensor.to_vec1().unwrap()
    }

    #[test]
    fn test_ternary_ste() {
        let input = cpu_tensor(&[-2.0, -0.5, 0.0, 0.5, 2.0]);

        let got = to_f32s(&ternary_ste(&input, 1.0).unwrap());

        // Out-of-range inputs are clamped to +/-1 and the +/-0.5 inputs are
        // expected to round to +/-1.
        assert_eq!(got, [-1.0, -1.0, 0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_int8_ste() {
        let input = cpu_tensor(&[-200.0, -50.0, 0.0, 50.0, 200.0]);

        let got = to_f32s(&int8_ste(&input, 1.0).unwrap());

        // Only the +/-200 entries exceed the +/-127 range.
        assert_eq!(got, [-127.0, -50.0, 0.0, 50.0, 127.0]);
    }

    #[test]
    fn test_ste_with_scale() {
        let input = cpu_tensor(&[0.5, 1.0, 1.5, 2.0]);

        let got = to_f32s(&ternary_ste(&input, 2.0).unwrap());

        // With scale 2.0 the outputs are multiples of 2.0: 0.5 maps to 0.0,
        // the other inputs map to 2.0.
        let want = [0.0_f32, 2.0, 2.0, 2.0];
        assert_eq!(got.len(), want.len());
        for (actual, expected) in got.iter().zip(want.iter()) {
            assert!((actual - expected).abs() < 0.01);
        }
    }

    #[test]
    fn test_ste_backward_identity() {
        let grad = cpu_tensor(&[1.0, 2.0, 3.0]);

        let passed_through = ste_backward(&grad);

        // The gradient must come back element-for-element unchanged.
        assert_eq!(to_f32s(&grad), to_f32s(&passed_through));
    }
}