Skip to main content

batuta/oracle/rag/quantization/
params.rs

1//! Quantization parameters for int8 scalar quantization
2//!
3//! Implements symmetric quantization: Q(x) = round(x / scale)
4//! Following Jacob et al. (2018) and Wu et al. (2020)
5
6use super::error::QuantizationError;
7
/// Parameters describing an int8 scalar quantization of an embedding.
///
/// The scheme is symmetric: a real value `x` is stored as
/// `round(x / scale)`, so a single positive `scale` is the only quantity
/// that has to be calibrated (see Jacob et al., 2018 and Wu et al., 2020).
#[derive(Clone, Debug, PartialEq)]
pub struct QuantizationParams {
    /// Step size between adjacent quantized levels.
    ///
    /// For symmetric quantization this is `absmax / 127.0`.
    pub scale: f32,
    /// Quantized value that represents real zero.
    ///
    /// `new` and `from_absmax` always set this to `0` (symmetric
    /// quantization).
    pub zero_point: i8,
    /// Number of dimensions of the original (unquantized) embedding.
    pub dims: usize,
}
21
22impl QuantizationParams {
23    /// Create new quantization parameters
24    ///
25    /// # Errors
26    /// Returns error if scale is invalid (zero, negative, or non-finite)
27    pub fn new(scale: f32, dims: usize) -> Result<Self, QuantizationError> {
28        if !scale.is_finite() || scale <= 0.0 {
29            return Err(QuantizationError::InvalidScale { scale });
30        }
31        Ok(Self {
32            scale,
33            zero_point: 0, // Symmetric quantization
34            dims,
35        })
36    }
37
38    /// Create from absmax value (symmetric quantization)
39    pub fn from_absmax(absmax: f32, dims: usize) -> Result<Self, QuantizationError> {
40        let scale = absmax / 127.0;
41        Self::new(scale, dims)
42    }
43
44    /// Quantize a single f32 value to i8
45    #[inline]
46    pub fn quantize_value(&self, value: f32) -> i8 {
47        let q = (value / self.scale).round() as i32;
48        q.clamp(-128, 127) as i8
49    }
50
51    /// Dequantize a single i8 value to f32
52    #[inline]
53    pub fn dequantize_value(&self, value: i8) -> f32 {
54        (value as f32 - self.zero_point as f32) * self.scale
55    }
56
57    /// Maximum quantization error bound: scale / 2
58    pub fn max_error_bound(&self) -> f32 {
59        self.scale / 2.0
60    }
61}