lnmp_quant/encode.rs

use crate::error::QuantError;
use crate::metrics::QuantMetrics;
use crate::scheme::QuantScheme;
use crate::vector::QuantizedVector;
use lnmp_embedding::Vector;

/// Quantizes an embedding vector using the specified quantization scheme
///
/// # Arguments
/// * `emb` - The embedding vector to quantize (must be F32 type)
/// * `scheme` - The quantization scheme to use
///
/// # Returns
/// * `Ok(QuantizedVector)` - Successfully quantized vector
/// * `Err(QuantError)` - If quantization fails
///
/// # Example
/// ```
/// use lnmp_quant::{quantize_embedding, QuantScheme};
/// use lnmp_embedding::Vector;
///
/// let emb = Vector::from_f32(vec![0.12, -0.45, 0.33]);
/// let quantized = quantize_embedding(&emb, QuantScheme::QInt8).unwrap();
/// ```
pub fn quantize_embedding(
    emb: &Vector,
    scheme: QuantScheme,
) -> Result<QuantizedVector, QuantError> {
    // Validate input
    if emb.dtype != lnmp_embedding::EmbeddingType::F32 {
        return Err(QuantError::EncodingFailed(
            "Only F32 embeddings are currently supported for quantization".to_string(),
        ));
    }

    if emb.dim == 0 {
        return Err(QuantError::InvalidDimension(
            "Cannot quantize zero-dimensional vector".to_string(),
        ));
    }

    match scheme {
        QuantScheme::QInt8 => quantize_qint8(emb),
        QuantScheme::QInt4 => crate::qint4::quantize_qint4(emb),
        QuantScheme::Binary => crate::binary::quantize_binary(emb),
        QuantScheme::FP16Passthrough => crate::fp16::quantize_fp16(emb),
    }
}

/// Quantizes an embedding to QInt8 format
fn quantize_qint8(emb: &Vector) -> Result<QuantizedVector, QuantError> {
    // Convert to f32 values
    let values = emb
        .as_f32()
        .map_err(|e| QuantError::EncodingFailed(format!("Failed to convert to F32: {}", e)))?;

    if values.is_empty() {
        return Err(QuantError::InvalidDimension(
            "Empty embedding vector".to_string(),
        ));
    }

    // Find min and max values
    let min_val = values
        .iter()
        .copied()
        .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();
    let max_val = values
        .iter()
        .copied()
        .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap();

    // Handle edge case: all values are the same
    let (scale, zero_point) = if (max_val - min_val).abs() < 1e-10 {
        (1.0, 0)
    } else {
        // Calculate scale to map [min_val, max_val] onto the 256 i8 levels (-128..=127)
        let range = max_val - min_val;
        let scale = range / 255.0; // 256 levels span 255 steps across [min_val, max_val]

        // Zero point stays 0 here; the offset is carried by `min_val` at dequantization time
        let zero_point = 0i8;

        (scale, zero_point)
    };

    // Pre-allocate output vector with exact capacity (memory optimization)
    let mut quantized_data = Vec::with_capacity(values.len());

    // Cache inverse scale to avoid repeated division (performance optimization)
    let inv_scale = if scale.abs() > 1e-10 {
        1.0 / scale
    } else {
        1.0
    };

    // Quantize each value
    for &value in &values {
        // Multiply by the cached inverse scale instead of dividing by scale
        let normalized = (value - min_val) * inv_scale;
        // Round to the nearest level, shift into the signed i8 range, and clamp
        let quantized = (normalized.round() as i32 - 128).clamp(-128, 127) as i8;
        quantized_data.push(quantized as u8);
    }

    // Estimate the relative information loss introduced by quantization
    let loss_ratio = calculate_loss_ratio(&values, &quantized_data, scale, min_val);

    // Metrics are computed here but not yet attached to the returned vector
    let _metrics = QuantMetrics::new(min_val, max_val, loss_ratio);

    Ok(QuantizedVector::new(
        emb.dim as u32,
        QuantScheme::QInt8,
        scale,
        zero_point,
        min_val,
        quantized_data,
    ))
}
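
/// Illustrative sketch only (not part of this module's public API): reconstructs
/// approximate f32 values from QInt8 data produced by `quantize_qint8`, using the
/// same affine mapping that `calculate_loss_ratio` applies below.
#[allow(dead_code)]
fn dequantize_qint8_sketch(data: &[u8], scale: f32, min_val: f32) -> Vec<f32> {
    data.iter()
        .map(|&byte| {
            let q = byte as i8;
            // Invert the encode step: shift back into [0, 255], rescale, add the offset
            ((q as i32 + 128) as f32 * scale) + min_val
        })
        .collect()
}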

/// Calculates approximate information loss ratio
fn calculate_loss_ratio(original: &[f32], quantized: &[u8], scale: f32, min_val: f32) -> f32 {
    if original.is_empty() {
        return 0.0;
    }

    let mut total_error = 0.0f32;
    let mut total_magnitude = 0.0f32;

    for (i, &orig_val) in original.iter().enumerate() {
        // Dequantize
        let q_val = quantized[i] as i8;
        let reconstructed = ((q_val as i32 + 128) as f32 * scale) + min_val;

        let error = (orig_val - reconstructed).abs();
        total_error += error;
        total_magnitude += orig_val.abs();
    }

    if total_magnitude < 1e-10 {
        0.0
    } else {
        total_error / total_magnitude
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use lnmp_embedding::Vector;

    #[test]
    fn test_quantize_simple() {
        let emb = Vector::from_f32(vec![0.12, -0.45, 0.33]);
        let result = quantize_embedding(&emb, QuantScheme::QInt8);
        assert!(result.is_ok());

        let quantized = result.unwrap();
        assert_eq!(quantized.dim, 3);
        assert_eq!(quantized.scheme, QuantScheme::QInt8);
        assert_eq!(quantized.data.len(), 3);
    }

    #[test]
    fn test_quantize_large_vector() {
        let values: Vec<f32> = (0..512).map(|i| (i as f32 / 512.0) - 0.5).collect();
        let emb = Vector::from_f32(values);
        let result = quantize_embedding(&emb, QuantScheme::QInt8);
        assert!(result.is_ok());

        let quantized = result.unwrap();
        assert_eq!(quantized.dim, 512);
        assert_eq!(quantized.data.len(), 512);
    }

    #[test]
    fn test_quantize_uniform_values() {
        let emb = Vector::from_f32(vec![0.5, 0.5, 0.5, 0.5]);
        let result = quantize_embedding(&emb, QuantScheme::QInt8);
        assert!(result.is_ok());

        let quantized = result.unwrap();
        assert_eq!(quantized.dim, 4);
        // All values should be quantized to the same value
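        // Assertion sketch for the property noted above (assumes one byte per
        // element in `data`, consistent with the other tests)
        assert!(quantized.data.iter().all(|&b| b == quantized.data[0]));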
    }

    #[test]
    fn test_quantize_empty_fails() {
        let emb = Vector::from_f32(vec![]);
        let result = quantize_embedding(&emb, QuantScheme::QInt8);
        assert!(result.is_err());
    }

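    // Illustrative round-trip sketch (not an exhaustive accuracy test): recompute
    // the affine parameters the encoder derives (min and range / 255) and check
    // that dequantized values land within one quantization step of the originals.
    #[test]
    fn test_quantize_round_trip_error_bounded() {
        let original = vec![0.12f32, -0.45, 0.33, 0.99, -0.99];
        let emb = Vector::from_f32(original.clone());
        let quantized = quantize_embedding(&emb, QuantScheme::QInt8).unwrap();

        let min_val = original.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = original.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let scale = (max_val - min_val) / 255.0;

        for (i, &orig) in original.iter().enumerate() {
            let q = quantized.data[i] as i8;
            let reconstructed = ((q as i32 + 128) as f32 * scale) + min_val;
            // One step of the 256-level grid is `scale`; allow a small float margin
            assert!((orig - reconstructed).abs() <= scale + 1e-6);
        }
    }
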
    // Note: All schemes are now implemented, so there's no unsupported scheme to test
}