1use crate::error::QuantError;
2use crate::metrics::QuantMetrics;
3use crate::scheme::QuantScheme;
4use crate::vector::QuantizedVector;
5use lnmp_embedding::Vector;
6
7pub fn quantize_embedding(
26 emb: &Vector,
27 scheme: QuantScheme,
28) -> Result<QuantizedVector, QuantError> {
29 if emb.dtype != lnmp_embedding::EmbeddingType::F32 {
31 return Err(QuantError::EncodingFailed(
32 "Only F32 embeddings are currently supported for quantization".to_string(),
33 ));
34 }
35
36 if emb.dim == 0 {
37 return Err(QuantError::InvalidDimension(
38 "Cannot quantize zero-dimensional vector".to_string(),
39 ));
40 }
41
42 match scheme {
43 QuantScheme::QInt8 => quantize_qint8(emb),
44 QuantScheme::QInt4 => crate::qint4::quantize_qint4(emb),
45 QuantScheme::Binary => crate::binary::quantize_binary(emb),
46 QuantScheme::FP16Passthrough => crate::fp16::quantize_fp16(emb),
47 }
48}
49
50fn quantize_qint8(emb: &Vector) -> Result<QuantizedVector, QuantError> {
52 let values = emb
54 .as_f32()
55 .map_err(|e| QuantError::EncodingFailed(format!("Failed to convert to F32: {}", e)))?;
56
57 if values.is_empty() {
58 return Err(QuantError::InvalidDimension(
59 "Empty embedding vector".to_string(),
60 ));
61 }
62
63 let min_val = values
65 .iter()
66 .copied()
67 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
68 .unwrap();
69 let max_val = values
70 .iter()
71 .copied()
72 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
73 .unwrap();
74
75 let (scale, zero_point) = if (max_val - min_val).abs() < 1e-10 {
77 (1.0, 0)
78 } else {
79 let range = max_val - min_val;
81 let scale = range / 255.0; let zero_point = 0i8;
85
86 (scale, zero_point)
87 };
88
89 let mut quantized_data = Vec::with_capacity(values.len());
91
92 let inv_scale = if scale.abs() > 1e-10 {
94 1.0 / scale
95 } else {
96 1.0
97 };
98
99 for &value in &values {
101 let normalized = (value - min_val) * inv_scale;
103 let quantized = (normalized as i32 - 128).clamp(-128, 127) as i8;
104 quantized_data.push(quantized as u8);
105 }
106
107 let loss_ratio = calculate_loss_ratio(&values, &quantized_data, scale, min_val);
109
110 let _metrics = QuantMetrics::new(min_val, max_val, loss_ratio);
111
112 Ok(QuantizedVector::new(
113 emb.dim as u32,
114 QuantScheme::QInt8,
115 scale,
116 zero_point,
117 min_val,
118 quantized_data,
119 ))
120}
121
/// Measures the relative L1 reconstruction error of a QInt8 encoding.
///
/// Each byte in `quantized` is reinterpreted as a signed code `q` and
/// reconstructed as `(q + 128) * scale + min_val`. The result is
/// `sum(|original - reconstructed|) / sum(|original|)`. It returns `0.0` when
/// `original` is empty or when its total magnitude is below `1e-10`.
///
/// Only the first `original.len()` bytes of `quantized` are read.
///
/// # Panics
///
/// Panics if `original` is non-empty and `quantized` is shorter than it.
fn calculate_loss_ratio(original: &[f32], quantized: &[u8], scale: f32, min_val: f32) -> f32 {
    if original.is_empty() {
        return 0.0;
    }

    // Slicing up front keeps the original's panic on a too-short code buffer.
    let codes = &quantized[..original.len()];

    // Accumulate both sums in one left-to-right pass, in the same order as a
    // plain loop, so the floating-point result is unchanged.
    let (abs_error_sum, magnitude_sum) = original.iter().zip(codes).fold(
        (0.0f32, 0.0f32),
        |(err_acc, mag_acc), (&x, &byte)| {
            let level = i32::from(byte as i8) + 128;
            let approx = level as f32 * scale + min_val;
            (err_acc + (x - approx).abs(), mag_acc + x.abs())
        },
    );

    if magnitude_sum < 1e-10 {
        0.0
    } else {
        abs_error_sum / magnitude_sum
    }
}
147
#[cfg(test)]
mod tests {
    use super::*;
    use lnmp_embedding::Vector;

    /// Runs the QInt8 quantizer over `values`, failing the test if it errors.
    fn quantize_q8(values: Vec<f32>) -> QuantizedVector {
        let emb = Vector::from_f32(values);
        quantize_embedding(&emb, QuantScheme::QInt8).expect("QInt8 quantization should succeed")
    }

    #[test]
    fn test_quantize_simple() {
        let q = quantize_q8(vec![0.12, -0.45, 0.33]);
        assert_eq!(q.dim, 3);
        assert_eq!(q.scheme, QuantScheme::QInt8);
        assert_eq!(q.data.len(), 3);
    }

    #[test]
    fn test_quantize_large_vector() {
        let ramp = (0..512).map(|i| i as f32 / 512.0 - 0.5).collect();
        let q = quantize_q8(ramp);
        assert_eq!(q.dim, 512);
        assert_eq!(q.data.len(), 512);
    }

    #[test]
    fn test_quantize_uniform_values() {
        let q = quantize_q8(vec![0.5; 4]);
        assert_eq!(q.dim, 4);
    }

    #[test]
    fn test_quantize_empty_fails() {
        let emb = Vector::from_f32(Vec::new());
        assert!(quantize_embedding(&emb, QuantScheme::QInt8).is_err());
    }
}