batuta/oracle/rag/quantization/params.rs
1//! Quantization parameters for int8 scalar quantization
2//!
3//! Implements symmetric quantization: Q(x) = round(x / scale)
4//! Following Jacob et al. (2018) and Wu et al. (2020)
5
6use super::error::QuantizationError;
7
/// Parameters describing an int8 scalar quantization of an embedding.
///
/// Symmetric scheme: `Q(x) = round(x / scale)` with the zero point fixed at 0,
/// following Jacob et al. (2018) and Wu et al. (2020).
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizationParams {
    /// Step size between adjacent quantized levels
    /// (`absmax / 127.0` for symmetric quantization).
    pub scale: f32,
    /// Offset subtracted from a quantized value when dequantizing
    /// (0 for symmetric quantization).
    pub zero_point: i8,
    /// Number of dimensions of the original embedding.
    pub dims: usize,
}
21
22impl QuantizationParams {
23 /// Create new quantization parameters
24 ///
25 /// # Errors
26 /// Returns error if scale is invalid (zero, negative, or non-finite)
27 pub fn new(scale: f32, dims: usize) -> Result<Self, QuantizationError> {
28 if !scale.is_finite() || scale <= 0.0 {
29 return Err(QuantizationError::InvalidScale { scale });
30 }
31 Ok(Self {
32 scale,
33 zero_point: 0, // Symmetric quantization
34 dims,
35 })
36 }
37
38 /// Create from absmax value (symmetric quantization)
39 pub fn from_absmax(absmax: f32, dims: usize) -> Result<Self, QuantizationError> {
40 let scale = absmax / 127.0;
41 Self::new(scale, dims)
42 }
43
44 /// Quantize a single f32 value to i8
45 #[inline]
46 pub fn quantize_value(&self, value: f32) -> i8 {
47 let q = (value / self.scale).round() as i32;
48 q.clamp(-128, 127) as i8
49 }
50
51 /// Dequantize a single i8 value to f32
52 #[inline]
53 pub fn dequantize_value(&self, value: i8) -> f32 {
54 (value as f32 - self.zero_point as f32) * self.scale
55 }
56
57 /// Maximum quantization error bound: scale / 2
58 pub fn max_error_bound(&self) -> f32 {
59 self.scale / 2.0
60 }
61}