lnmp_quant/
binary.rs

//! Binary: 1-bit quantization for maximum compression
//!
//! This module implements binary (1-bit) quantization, where each value is represented
//! by a single bit recording whether it lies at or above the vector's mean. Relative to
//! `f32` storage this gives 32x compression, and the resulting sign vectors are
//! particularly suited to similarity search tasks.
6
7use crate::error::QuantError;
8use crate::scheme::QuantScheme;
9use crate::vector::QuantizedVector;
10use lnmp_embedding::Vector;
11
12/// Quantizes an embedding vector to binary (1-bit per value)
13///
14/// Binary quantization uses a simple threshold-based approach:
15/// - Calculate the mean of all values
16/// - Each value above mean → bit 1
17/// - Each value below mean → bit 0
18///
19/// This is extremely lossy but preserves relative ordering and is fast for similarity.
20///
21/// # Arguments
22/// * `emb` - The embedding vector to quantize
23///
24/// # Returns
25/// * `Ok(QuantizedVector)` - Quantized vector with 32x compression
26/// * `Err(QuantError)` - If quantization fails
27///
28/// # Example
29/// ```
30/// use lnmp_quant::binary::quantize_binary;
31/// use lnmp_embedding::Vector;
32///
33/// let emb = Vector::from_f32(vec![0.1, -0.2, 0.3, -0.4]);
34/// let quantized = quantize_binary(&emb).unwrap();
35/// assert_eq!(quantized.scheme, lnmp_quant::QuantScheme::Binary);
36/// ```
37pub fn quantize_binary(emb: &Vector) -> Result<QuantizedVector, QuantError> {
38    // Validate input
39    if emb.dtype != lnmp_embedding::EmbeddingType::F32 {
40        return Err(QuantError::EncodingFailed(
41            "Only F32 embeddings are supported for binary quantization".to_string(),
42        ));
43    }
44
45    let values = emb
46        .as_f32()
47        .map_err(|e| QuantError::EncodingFailed(format!("Failed to convert to F32: {}", e)))?;
48
49    if values.is_empty() {
50        return Err(QuantError::InvalidDimension(
51            "Cannot quantize empty vector".to_string(),
52        ));
53    }
54
55    // Calculate mean for thresholding
56    let mean: f32 = values.iter().sum::<f32>() / values.len() as f32;
57
58    // Pack bits: 8 values per byte
59    let num_bytes = values.len().div_ceil(8); // Round up
60    let mut data = Vec::with_capacity(num_bytes);
61
62    for chunk in values.chunks(8) {
63        let mut byte = 0u8;
64        for (i, &val) in chunk.iter().enumerate() {
65            if val >= mean {
66                byte |= 1 << i; // Set bit if value >= mean
67            }
68            // Otherwise bit remains 0
69        }
70        data.push(byte);
71    }
72
73    // Store mean in min_val field for use during dequantization
74    // Scale and zero_point are not really used for binary quantization
75    Ok(QuantizedVector::new(
76        emb.dim as u32,
77        QuantScheme::Binary,
78        1.0,  // Not used for binary
79        0,    // Not used for binary
80        mean, // Store mean as min_val for reconstruction
81        data,
82    ))
83}
84
85/// Dequantizes a binary quantized vector back to f32
86///
87/// Converts each bit back to either +1.0 or -1.0:
88/// - Bit = 1 → +1.0 (above mean)
89/// - Bit = 0 → -1.0 (below mean)
90///
91/// This creates a normalized vector suitable for similarity comparison.
92///
93/// # Arguments
94/// * `qv` - The quantized vector to dequantize
95///
96/// # Returns
97/// * `Ok(Vector)` - Restored f32 vector with normalized values
98/// * `Err(QuantError)` - If dequantization fails
99pub fn dequantize_binary(qv: &QuantizedVector) -> Result<Vector, QuantError> {
100    if qv.scheme != QuantScheme::Binary {
101        return Err(QuantError::InvalidScheme(format!(
102            "Expected Binary, got {:?}",
103            qv.scheme
104        )));
105    }
106
107    let dim = qv.dim as usize;
108    let mut values = Vec::with_capacity(dim);
109
110    for &byte in &qv.data {
111        for i in 0..8 {
112            if values.len() >= dim {
113                break; // Stop if we've reached the target dimension
114            }
115
116            let bit = (byte >> i) & 1;
117            // Convert to +1 or -1 for normalized representation
118            let value = if bit == 1 { 1.0 } else { -1.0 };
119            values.push(value);
120        }
121    }
122
123    // Truncate to exact dimension
124    values.truncate(dim);
125
126    Ok(Vector::from_f32(values))
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use lnmp_embedding::SimilarityMetric;

    #[test]
    fn test_binary_basic() {
        // Mean of [0.5, -0.5, 0.3, -0.3] is exactly 0.0, so only the positive entries set bits.
        let vec = Vector::from_f32(vec![0.5, -0.5, 0.3, -0.3]);
        let quantized = quantize_binary(&vec).unwrap();

        // Should pack 4 values into 1 byte, with bits 0 and 2 set.
        assert_eq!(quantized.data.len(), 1);
        assert_eq!(quantized.scheme, QuantScheme::Binary);
        assert_eq!(quantized.data[0], 0b0000_0101);

        let restored = dequantize_binary(&quantized).unwrap();
        assert_eq!(restored.dim, 4);
        assert_eq!(restored.as_f32().unwrap(), vec![1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_binary_compression() {
        // Test with larger vector to verify compression
        let values: Vec<f32> = (0..512).map(|i| (i as f32) / 512.0 - 0.5).collect();
        let vec = Vector::from_f32(values);

        let quantized = quantize_binary(&vec).unwrap();

        // Original: 512 * 4 = 2048 bytes
        // Binary: 512 / 8 = 64 bytes
        // Compression: 32x
        let original_bytes = 512 * 4;
        let quantized_bytes = quantized.data.len();
        let compression_ratio = original_bytes as f32 / quantized_bytes as f32;

        assert_eq!(compression_ratio, 32.0);
    }

    #[test]
    fn test_binary_bit_packing() {
        // Verify bit packing is correct
        // Mean of [1.0, -1.0] = 0.0
        // 1.0 >= 0.0 → bit 1
        // -1.0 < 0.0 → bit 0
        let vec = Vector::from_f32(vec![1.0, -1.0]);
        let quantized = quantize_binary(&vec).unwrap();

        assert_eq!(quantized.data.len(), 1);
        // First bit should be set (1.0), second bit should be clear (-1.0)
        let byte = quantized.data[0];
        assert_eq!(byte & 0x01, 1); // First bit set
        assert_eq!(byte & 0x02, 0); // Second bit clear
    }

    #[test]
    fn test_binary_odd_dimension() {
        // Test with dimension not divisible by 8
        let vec = Vector::from_f32(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let quantized = quantize_binary(&vec).unwrap();

        // 5 values need 1 byte (with 3 unused bits)
        assert_eq!(quantized.data.len(), 1);
        // 0.1 and 0.2 are below the mean (~0.3), and 0.4 and 0.5 are above it. The three
        // padding bits must stay clear. Bit 2 (0.3 vs. a mean of ~0.3) is masked out
        // because its value depends on floating-point rounding.
        assert_eq!(quantized.data[0] & 0b1111_1011, 0b0001_1000);

        let restored = dequantize_binary(&quantized).unwrap();
        assert_eq!(restored.dim, 5);
    }

    #[test]
    fn test_binary_roundtrip_similarity() {
        // Binary quantization is very lossy, but should preserve relative similarity
        let vec1 = Vector::from_f32(vec![0.5, 0.3, -0.2, 0.8, -0.4, 0.1]);
        let vec2 = Vector::from_f32(vec![0.4, 0.35, -0.15, 0.75, -0.35, 0.15]);

        // Original similarity
        let original_similarity = vec1.similarity(&vec2, SimilarityMetric::Cosine).unwrap();

        // Quantize both
        let q1 = quantize_binary(&vec1).unwrap();
        let q2 = quantize_binary(&vec2).unwrap();

        // Dequantize
        let r1 = dequantize_binary(&q1).unwrap();
        let r2 = dequantize_binary(&q2).unwrap();

        // Binary similarity
        let binary_similarity = r1.similarity(&r2, SimilarityMetric::Cosine).unwrap();

        // Binary quantization preserves general similarity patterns
        // Both should be positive (similar vectors)
        assert!(original_similarity > 0.5);
        assert!(binary_similarity > 0.0);

        // Print for information
        println!("Original similarity: {}", original_similarity);
        println!("Binary similarity: {}", binary_similarity);
    }

    #[test]
    fn test_binary_uniform_values() {
        // All values equal the mean exactly (sum 2.0 / 4 = 0.5). Ties map to bit 1
        // because the threshold test is `>=`.
        let vec = Vector::from_f32(vec![0.5, 0.5, 0.5, 0.5]);
        let quantized = quantize_binary(&vec).unwrap();
        assert_eq!(quantized.data[0], 0b0000_1111);

        let restored = dequantize_binary(&quantized).unwrap();
        assert_eq!(restored.dim, 4);
        assert_eq!(restored.as_f32().unwrap(), vec![1.0; 4]);
    }

    #[test]
    fn test_binary_normalized_output() {
        // Binary dequantization always outputs +1 or -1
        let vec = Vector::from_f32(vec![0.1, -0.5, 0.8, -0.2, 0.4]);
        let quantized = quantize_binary(&vec).unwrap();
        let restored = dequantize_binary(&quantized).unwrap();

        let restored_values = restored.as_f32().unwrap();
        for val in restored_values {
            assert!(val == 1.0 || val == -1.0, "Value: {}", val);
        }
    }

    #[test]
    fn test_binary_empty_fails() {
        let vec = Vector::from_f32(vec![]);
        let result = quantize_binary(&vec);
        assert!(result.is_err());
    }

    #[test]
    fn test_binary_large_vector() {
        // Test with typical embedding dimension
        let values: Vec<f32> = (0..1536)
            .map(|i| ((i % 100) as f32) / 100.0 - 0.5)
            .collect();
        let vec = Vector::from_f32(values);

        let quantized = quantize_binary(&vec).unwrap();

        // 1536 values / 8 = 192 bytes
        assert_eq!(quantized.data.len(), 192);

        let restored = dequantize_binary(&quantized).unwrap();
        assert_eq!(restored.dim, 1536);

        // Every restored value is a sign (+1 or -1)
        let restored_values = restored.as_f32().unwrap();
        for val in restored_values {
            assert!(val == 1.0 || val == -1.0);
        }
    }
}