oxirs_vec/
opq.rs

//! Optimized Product Quantization (OPQ) implementation.
//!
//! OPQ learns an orthogonal rotation matrix that minimizes quantization error
//! by aligning the data before product quantization is applied. This typically
//! gives better compression quality than standard PQ.
6
7use crate::{
8    pq::{PQConfig, PQIndex},
9    Vector,
10};
11use anyhow::{anyhow, Result};
12use nalgebra::{DMatrix, DVector, SVD};
13
14/// Configuration for Optimized Product Quantization
15#[derive(Debug, Clone)]
16pub struct OPQConfig {
17    /// Base PQ configuration
18    pub pq_config: PQConfig,
19    /// Number of iterations for alternating optimization
20    pub n_iterations: usize,
21    /// Whether to center data before rotation
22    pub center_data: bool,
23    /// Regularization parameter for rotation matrix
24    pub regularization: f32,
25}
26
27impl Default for OPQConfig {
28    fn default() -> Self {
29        Self {
30            pq_config: PQConfig::default(),
31            n_iterations: 10,
32            center_data: true,
33            regularization: 0.0,
34        }
35    }
36}
37
38/// Optimized Product Quantization index
39pub struct OPQIndex {
40    /// Configuration
41    config: OPQConfig,
42    /// Rotation matrix R (d x d)
43    rotation_matrix: Option<DMatrix<f32>>,
44    /// Data mean for centering
45    data_mean: Option<DVector<f32>>,
46    /// Underlying PQ index
47    pq_index: PQIndex,
48    /// Whether the model is trained
49    is_trained: bool,
50}
51
52impl OPQIndex {
53    /// Create a new OPQ index
54    pub fn new(config: OPQConfig) -> Self {
55        Self {
56            pq_index: PQIndex::new(config.pq_config.clone()),
57            config,
58            rotation_matrix: None,
59            data_mean: None,
60            is_trained: false,
61        }
62    }
63
64    /// Train the OPQ model using alternating optimization
65    pub fn train(&mut self, vectors: &[Vector]) -> Result<()> {
66        if vectors.is_empty() {
67            return Err(anyhow!("Cannot train OPQ with empty data"));
68        }
69
70        let n_samples = vectors.len();
71        let dimensions = vectors[0].dimensions;
72
73        // Convert vectors to matrix (samples x dimensions)
74        let mut data_matrix = DMatrix::zeros(n_samples, dimensions);
75        for (i, vector) in vectors.iter().enumerate() {
76            let vec_f32 = vector.as_f32();
77            for (j, &val) in vec_f32.iter().enumerate() {
78                data_matrix[(i, j)] = val;
79            }
80        }
81
82        // Center data if requested
83        if self.config.center_data {
84            let mean = self.compute_mean(&data_matrix);
85            self.center_data_matrix(&mut data_matrix, &mean);
86            self.data_mean = Some(mean);
87        }
88
89        // Initialize rotation matrix as identity
90        let mut rotation = DMatrix::identity(dimensions, dimensions);
91
92        // Alternating optimization
93        for iteration in 0..self.config.n_iterations {
94            println!(
95                "OPQ iteration {}/{}",
96                iteration + 1,
97                self.config.n_iterations
98            );
99
100            // Step 1: Fix R, optimize C (codebooks)
101            let rotated_data = self.apply_rotation(&data_matrix, &rotation);
102            let rotated_vectors = self.matrix_to_vectors(&rotated_data);
103
104            // Train PQ on rotated data
105            self.pq_index.train(&rotated_vectors)?;
106
107            // Step 2: Fix C, optimize R
108            rotation = self.optimize_rotation(&data_matrix, &rotated_vectors)?;
109
110            // Compute reconstruction error for monitoring
111            let error = self.compute_reconstruction_error(&data_matrix, &rotation)?;
112            println!("Reconstruction error: {error}");
113        }
114
115        self.rotation_matrix = Some(rotation);
116        self.is_trained = true;
117
118        Ok(())
119    }
120
121    /// Compute mean of data matrix
122    fn compute_mean(&self, data: &DMatrix<f32>) -> DVector<f32> {
123        let n_samples = data.nrows() as f32;
124        let mut mean = DVector::zeros(data.ncols());
125
126        for i in 0..data.ncols() {
127            mean[i] = data.column(i).sum() / n_samples;
128        }
129
130        mean
131    }
132
133    /// Center data matrix by subtracting mean
134    fn center_data_matrix(&self, data: &mut DMatrix<f32>, mean: &DVector<f32>) {
135        for i in 0..data.nrows() {
136            for j in 0..data.ncols() {
137                data[(i, j)] -= mean[j];
138            }
139        }
140    }
141
142    /// Apply rotation matrix to data
143    fn apply_rotation(&self, data: &DMatrix<f32>, rotation: &DMatrix<f32>) -> DMatrix<f32> {
144        data * rotation.transpose()
145    }
146
147    /// Convert matrix back to vectors
148    fn matrix_to_vectors(&self, matrix: &DMatrix<f32>) -> Vec<Vector> {
149        let mut vectors = Vec::with_capacity(matrix.nrows());
150
151        for i in 0..matrix.nrows() {
152            let row: Vec<f32> = matrix.row(i).iter().cloned().collect();
153            vectors.push(Vector::new(row));
154        }
155
156        vectors
157    }
158
159    /// Optimize rotation matrix using SVD
160    fn optimize_rotation(
161        &self,
162        data: &DMatrix<f32>,
163        rotated_vectors: &[Vector],
164    ) -> Result<DMatrix<f32>> {
165        // Reconstruct data using current codebooks
166        let mut reconstructed = DMatrix::zeros(data.nrows(), data.ncols());
167
168        for (i, vector) in rotated_vectors.iter().enumerate() {
169            // Encode and decode to get reconstruction
170            if let Ok(reconstructed_vec) = self.pq_index.reconstruct(vector) {
171                let rec_f32 = reconstructed_vec.as_f32();
172                for (j, &val) in rec_f32.iter().enumerate() {
173                    reconstructed[(i, j)] = val;
174                }
175            }
176        }
177
178        // Solve orthogonal Procrustes problem: min ||X - Y*R||_F
179        // Solution: R = U*V^T where X^T*Y = U*S*V^T
180        let correlation = data.transpose() * &reconstructed;
181
182        // Add regularization if needed
183        let mut reg_correlation = correlation.clone();
184        if self.config.regularization > 0.0 {
185            for i in 0..reg_correlation.ncols().min(reg_correlation.nrows()) {
186                reg_correlation[(i, i)] += self.config.regularization;
187            }
188        }
189
190        // Compute SVD
191        let svd = SVD::new(reg_correlation, true, true);
192        let u = svd.u.ok_or_else(|| anyhow!("SVD failed to compute U"))?;
193        let v_t = svd
194            .v_t
195            .ok_or_else(|| anyhow!("SVD failed to compute V^T"))?;
196
197        // Optimal rotation is U * V^T
198        Ok(u * v_t)
199    }
200
201    /// Compute reconstruction error
202    fn compute_reconstruction_error(
203        &self,
204        data: &DMatrix<f32>,
205        rotation: &DMatrix<f32>,
206    ) -> Result<f32> {
207        let rotated = self.apply_rotation(data, rotation);
208        let rotated_vecs = self.matrix_to_vectors(&rotated);
209
210        let mut total_error = 0.0;
211        for (i, vec) in rotated_vecs.iter().enumerate() {
212            if let Ok(reconstructed) = self.pq_index.reconstruct(vec) {
213                let rec_f32 = reconstructed.as_f32();
214                for (j, &val) in rec_f32.iter().enumerate() {
215                    let diff = rotated[(i, j)] - val;
216                    total_error += diff * diff;
217                }
218            }
219        }
220
221        Ok((total_error / (data.nrows() * data.ncols()) as f32).sqrt())
222    }
223
224    /// Encode a vector using OPQ
225    pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
226        if !self.is_trained {
227            return Err(anyhow!("OPQ index must be trained before encoding"));
228        }
229
230        // Apply centering and rotation
231        let transformed = self.transform_vector(vector)?;
232
233        // Use PQ to encode
234        self.pq_index.encode(&transformed)
235    }
236
237    /// Decode PQ codes to approximate vector
238    pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
239        if !self.is_trained {
240            return Err(anyhow!("OPQ index must be trained before decoding"));
241        }
242
243        // Decode using PQ
244        let rotated = self.pq_index.decode(codes)?;
245
246        // Apply inverse transformation
247        self.inverse_transform_vector(&rotated)
248    }
249
250    /// Transform vector: center and rotate
251    fn transform_vector(&self, vector: &Vector) -> Result<Vector> {
252        let rotation = self
253            .rotation_matrix
254            .as_ref()
255            .ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
256
257        let vec_f32 = vector.as_f32();
258        let mut vec_dv = DVector::from_vec(vec_f32.to_vec());
259
260        // Center if needed
261        if let Some(ref mean) = self.data_mean {
262            vec_dv -= mean;
263        }
264
265        // Apply rotation
266        let rotated = rotation.transpose() * vec_dv;
267
268        Ok(Vector::new(rotated.iter().cloned().collect()))
269    }
270
271    /// Inverse transform: rotate back and uncenter
272    fn inverse_transform_vector(&self, vector: &Vector) -> Result<Vector> {
273        let rotation = self
274            .rotation_matrix
275            .as_ref()
276            .ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
277
278        let vec_f32 = vector.as_f32();
279        let vec_dv = DVector::from_vec(vec_f32.to_vec());
280
281        // Apply inverse rotation
282        let unrotated = rotation * vec_dv;
283
284        // Uncenter if needed
285        let mut result = unrotated;
286        if let Some(ref mean) = self.data_mean {
287            result += mean;
288        }
289
290        Ok(Vector::new(result.iter().cloned().collect()))
291    }
292
293    /// Search for nearest neighbors using asymmetric distance computation
294    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
295        if !self.is_trained {
296            return Err(anyhow!("OPQ index must be trained before searching"));
297        }
298
299        // Transform query
300        let transformed_query = self.transform_vector(query)?;
301
302        // Use PQ search
303        self.pq_index.search(&transformed_query, k)
304    }
305
306    /// Get compression statistics
307    pub fn stats(&self) -> OPQStats {
308        let pq_stats = self.pq_index.stats();
309
310        OPQStats {
311            pq_stats,
312            is_trained: self.is_trained,
313            has_rotation: self.rotation_matrix.is_some(),
314            rotation_rank: self
315                .rotation_matrix
316                .as_ref()
317                .map(|r| r.rank(1e-6))
318                .unwrap_or(0),
319        }
320    }
321}
322
323/// Statistics for OPQ index
324#[derive(Debug, Clone)]
325pub struct OPQStats {
326    pub pq_stats: crate::pq::PQStats,
327    pub is_trained: bool,
328    pub has_rotation: bool,
329    pub rotation_rank: usize,
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VectorIndex;

    /// Panic unless `matrix * matrix^T` equals the identity within `1e-3`.
    fn assert_orthonormal(matrix: &DMatrix<f32>) {
        let gram = matrix * matrix.transpose();
        let identity = DMatrix::<f32>::identity(gram.nrows(), gram.ncols());
        let max_deviation = (gram - identity)
            .iter()
            .fold(0.0f32, |acc, v| acc.max(v.abs()));
        assert!(
            max_deviation < 1e-3,
            "rotation is not orthonormal (max deviation {max_deviation})"
        );
    }

    #[test]
    fn test_opq_basic() -> Result<()> {
        let config = OPQConfig {
            pq_config: PQConfig {
                n_subquantizers: 4,
                n_centroids: 16,
                ..Default::default()
            },
            n_iterations: 3,
            ..Default::default()
        };

        let mut opq = OPQIndex::new(config);

        // Deterministic 16-dimensional training data.
        let vectors: Vec<Vector> = (0..100)
            .map(|i| {
                let values: Vec<f32> = (0..16)
                    .map(|j| (i as f32 * 0.1 + j as f32) % 10.0)
                    .collect();
                Vector::new(values)
            })
            .collect();

        opq.train(&vectors)?;

        // The learned rotation must be orthonormal so that it preserves
        // distances and is inverted by its transpose.
        let rotation = opq
            .rotation_matrix
            .as_ref()
            .expect("a trained index stores a rotation matrix");
        assert_eq!(rotation.shape(), (16, 16));
        assert_orthonormal(rotation);

        // Test encoding/decoding
        let test_vec = Vector::new(vec![1.0; 16]);
        let codes = opq.encode(&test_vec)?;
        let reconstructed = opq.decode(&codes)?;
        assert_eq!(reconstructed.dimensions, 16);

        // Forward and inverse transforms must cancel out (up to float error);
        // only the PQ step itself is lossy.
        let round_trip = opq.inverse_transform_vector(&opq.transform_vector(&test_vec)?)?;
        assert_eq!(round_trip.dimensions, 16);
        for (got, want) in round_trip.as_f32().iter().zip(test_vec.as_f32().iter()) {
            assert!(
                (got - want).abs() < 1e-3,
                "transform round trip drifted: {got} vs {want}"
            );
        }

        Ok(())
    }

    #[test]
    fn test_opq_search() -> Result<()> {
        let config = OPQConfig::default();
        let mut opq = OPQIndex::new(config);

        // Deterministic 8-dimensional training data.
        let vectors: Vec<Vector> = (0..50)
            .map(|i| {
                let values: Vec<f32> = (0..8).map(|j| ((i * j) as f32).sin()).collect();
                Vector::new(values)
            })
            .collect();

        opq.train(&vectors)?;

        // `search` transforms the query, so stored vectors must live in the
        // same centered/rotated space to be comparable with it.
        for (i, vec) in vectors.iter().enumerate() {
            let transformed = opq.transform_vector(vec)?;
            opq.pq_index.insert(format!("vec_{i}"), transformed)?;
        }

        // Search
        let query = Vector::new(vec![0.5; 8]);
        let results = opq.search(&query, 5)?;

        assert_eq!(results.len(), 5);

        Ok(())
    }
}