manifoldb_vector/quantization/
config.rs

1//! Product Quantization configuration.
2
3use crate::distance::DistanceMetric;
4use crate::error::VectorError;
5
/// Configuration for Product Quantization.
///
/// # Parameters
///
/// - `dimension`: The dimension of vectors to quantize (D)
/// - `num_segments`: Number of subspaces (M). Must divide `dimension` evenly.
/// - `num_centroids`: Number of centroids per subspace (K). Typically 256 (8 bits per code).
/// - `distance_metric`: Distance metric for codebook training and distance computation
/// - `training_iterations`: Number of k-means iterations used during training
/// - `seed`: Optional random seed; `None` leaves training non-deterministic
///
/// # Memory Usage
///
/// - Codebooks: `M × K × (D/M) × 4` bytes
/// - Per-vector codes: `ceil(M × ceil(log2(K)) / 8)` bytes, as computed by
///   [`PQConfig::bytes_per_code`]
///
/// For typical settings (D=128, M=8, K=256):
/// - Codebooks: 8 × 256 × 16 × 4 = 128KB
/// - Per-vector: 8 bytes (compression ratio: 64x)
#[derive(Debug, Clone)]
pub struct PQConfig {
    /// Dimension of input vectors.
    pub dimension: usize,
    /// Number of subspaces (segments). Must be non-zero and divide `dimension`.
    pub num_segments: usize,
    /// Number of centroids per subspace. Must be non-zero.
    pub num_centroids: usize,
    /// Distance metric for training and search.
    pub distance_metric: DistanceMetric,
    /// Number of training iterations for k-means.
    pub training_iterations: usize,
    /// Random seed for reproducible training; `None` means no fixed seed.
    pub seed: Option<u64>,
}
38
39impl PQConfig {
40    /// Create a new PQ configuration.
41    ///
42    /// # Arguments
43    ///
44    /// - `dimension`: The dimension of vectors to quantize
45    /// - `num_segments`: Number of subspaces. Must divide `dimension` evenly.
46    ///
47    /// # Defaults
48    ///
49    /// - `num_centroids`: 256 (8-bit codes)
50    /// - `distance_metric`: Euclidean
51    /// - `training_iterations`: 25
52    /// - `seed`: None (non-deterministic)
53    ///
54    /// # Panics
55    ///
56    /// Panics if `num_segments` is 0 or doesn't divide `dimension` evenly.
57    #[must_use]
58    pub fn new(dimension: usize, num_segments: usize) -> Self {
59        assert!(num_segments > 0, "num_segments must be > 0");
60        assert!(
61            dimension % num_segments == 0,
62            "dimension ({}) must be divisible by num_segments ({})",
63            dimension,
64            num_segments
65        );
66
67        Self {
68            dimension,
69            num_segments,
70            num_centroids: 256,
71            distance_metric: DistanceMetric::Euclidean,
72            training_iterations: 25,
73            seed: None,
74        }
75    }
76
77    /// Set the number of centroids per subspace.
78    ///
79    /// Common values:
80    /// - 256 (8-bit codes, default)
81    /// - 65536 (16-bit codes, more accurate but larger codebooks)
82    #[must_use]
83    pub const fn with_num_centroids(mut self, k: usize) -> Self {
84        self.num_centroids = k;
85        self
86    }
87
88    /// Set the distance metric.
89    #[must_use]
90    pub const fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
91        self.distance_metric = metric;
92        self
93    }
94
95    /// Set the number of training iterations.
96    #[must_use]
97    pub const fn with_training_iterations(mut self, iterations: usize) -> Self {
98        self.training_iterations = iterations;
99        self
100    }
101
102    /// Set the random seed for reproducible training.
103    #[must_use]
104    pub const fn with_seed(mut self, seed: u64) -> Self {
105        self.seed = Some(seed);
106        self
107    }
108
109    /// Get the dimension of each subspace.
110    #[must_use]
111    pub const fn subspace_dimension(&self) -> usize {
112        self.dimension / self.num_segments
113    }
114
115    /// Calculate the number of bits per code.
116    ///
117    /// Returns the number of bits needed to represent a centroid index.
118    #[must_use]
119    #[allow(clippy::cast_possible_truncation)]
120    #[allow(clippy::cast_sign_loss)]
121    pub fn bits_per_code(&self) -> usize {
122        // ceil(log2(num_centroids))
123        if self.num_centroids <= 1 {
124            1
125        } else {
126            (self.num_centroids as f64).log2().ceil() as usize
127        }
128    }
129
130    /// Calculate bytes per encoded vector.
131    #[must_use]
132    pub fn bytes_per_code(&self) -> usize {
133        // For 256 centroids, each code is 1 byte
134        // For 65536 centroids, each code is 2 bytes
135        let bits = self.bits_per_code();
136        let total_bits = bits * self.num_segments;
137        total_bits.div_ceil(8)
138    }
139
140    /// Validate the configuration.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if:
145    /// - `dimension` is 0
146    /// - `num_segments` is 0 or doesn't divide `dimension`
147    /// - `num_centroids` is 0
148    pub fn validate(&self) -> Result<(), VectorError> {
149        if self.dimension == 0 {
150            return Err(VectorError::InvalidDimension { expected: 1, actual: 0 });
151        }
152
153        if self.num_segments == 0 {
154            return Err(VectorError::Encoding("num_segments must be > 0".to_string()));
155        }
156
157        if self.dimension % self.num_segments != 0 {
158            return Err(VectorError::Encoding(format!(
159                "dimension ({}) must be divisible by num_segments ({})",
160                self.dimension, self.num_segments
161            )));
162        }
163
164        if self.num_centroids == 0 {
165            return Err(VectorError::Encoding("num_centroids must be > 0".to_string()));
166        }
167
168        Ok(())
169    }
170
171    /// Calculate compression ratio compared to full f32 vectors.
172    #[must_use]
173    pub fn compression_ratio(&self) -> f32 {
174        let original_bytes = self.dimension * 4; // 4 bytes per f32
175        let compressed_bytes = self.bytes_per_code();
176        original_bytes as f32 / compressed_bytes as f32
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the 128-dimensional, 8-segment configuration used throughout
    /// these tests, with `k` centroids per subspace.
    fn config_with_centroids(k: usize) -> PQConfig {
        PQConfig::new(128, 8).with_num_centroids(k)
    }

    /// `new` stores its arguments and applies the documented defaults.
    #[test]
    fn test_basic_config() {
        let cfg = PQConfig::new(128, 8);
        assert_eq!(cfg.dimension, 128);
        assert_eq!(cfg.num_segments, 8);
        assert_eq!(cfg.num_centroids, 256);
        assert_eq!(cfg.subspace_dimension(), 16);
    }

    /// Power-of-two centroid counts map to their exact bit widths.
    #[test]
    fn test_bits_per_code() {
        let expectations: [(usize, usize); 3] = [(256, 8), (65536, 16), (16, 4)];
        for (centroids, bits) in expectations {
            assert_eq!(config_with_centroids(centroids).bits_per_code(), bits);
        }
    }

    /// Encoded size is segments × bits per code, expressed in bytes.
    #[test]
    fn test_bytes_per_code() {
        // 8 segments × 8 bits = 64 bits = 8 bytes
        assert_eq!(config_with_centroids(256).bytes_per_code(), 8);

        // 8 segments × 16 bits = 128 bits = 16 bytes
        assert_eq!(config_with_centroids(65536).bytes_per_code(), 16);
    }

    /// 128 × 4 = 512 bytes original, 8 bytes compressed = 64x.
    #[test]
    fn test_compression_ratio() {
        let ratio = config_with_centroids(256).compression_ratio();
        assert!((ratio - 64.0).abs() < 0.01);
    }

    /// A default configuration validates; zero centroids does not.
    #[test]
    fn test_validation() {
        assert!(PQConfig::new(128, 8).validate().is_ok());

        // Invalid: num_centroids = 0
        let zero_centroids = PQConfig { num_centroids: 0, ..PQConfig::new(128, 8) };
        assert!(zero_centroids.validate().is_err());
    }

    /// Constructing with zero segments trips the first assertion in `new`.
    #[test]
    #[should_panic(expected = "num_segments must be > 0")]
    fn test_zero_segments_panics() {
        let _config = PQConfig::new(128, 0);
    }

    /// 128 is not a multiple of 7, so `new` must refuse the combination.
    #[test]
    #[should_panic(expected = "must be divisible by")]
    fn test_indivisible_dimension_panics() {
        let _config = PQConfig::new(128, 7);
    }
}