vq/
opq.rs

1//! # Optimized Product Quantizer Implementation
2//!
3//! This module implements an Optimized Product Quantizer (OPQ) that first learns an optimal
4//! rotation of the input data before performing product quantization. The OPQ algorithm reduces
5//! quantization error by rotating the data, partitioning the rotated data into `m` subspaces,
6//! and learning a separate codebook for each subspace via the LBG algorithm. During quantization,
7//! the input vector is rotated and each sub-vector is quantized by selecting the nearest centroid
8//! (using a specified distance metric). The final quantized representation is obtained by concatenating
9//! the selected codewords and converting them to half-precision (`f16`).
10//!
//! # Panics
12//! The `fit` and `quantize` methods panic with custom errors from the exceptions module when:
13//! - The training data is empty.
14//! - The dimension of the training vectors is less than `m` or not divisible by `m`.
15//! - The input vector's dimension in `quantize` does not match the expected dimension.
16//!
17//! # Example
18//! ```
19//! use vq::vector::Vector;
20//! use vq::distances::Distance;
21//! use vq::opq::OptimizedProductQuantizer;
22//! use nalgebra::DMatrix;
23//!
24//! // Create a small training dataset. Each vector has dimension 4.
25//! let training_data = vec![
26//!     Vector::new(vec![0.0, 0.0, 0.0, 0.0]),
27//!     Vector::new(vec![1.0, 1.0, 1.0, 1.0]),
28//!     Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
29//! ];
30//!
31//! // Partition the 4-dimensional vectors into m = 2 subspaces (each of dimension 2).
32//! let m = 2;
33//! // Use k = 2 centroids per subspace (training data length 3 is sufficient for k = 2).
34//! let k = 2;
35//! let max_iters = 10;
36//! let opq_iters = 5;
37//! let seed = 42;
38//! let distance = Distance::Euclidean;
39//!
40//! // Fit the optimized product quantizer with the training data.
41//! let opq = OptimizedProductQuantizer::fit(&training_data, m, k, max_iters, opq_iters, distance, seed);
42//!
43//! // Quantize an input vector (dimension must equal 4).
44//! let input = Vector::new(vec![0.2, 0.8, 0.3, 0.7]);
45//! let quantized = opq.quantize(&input);
46//! println!("Quantized vector: {:?}", quantized);
47//! ```
48
49use crate::distances::Distance;
50use crate::exceptions::VqError;
51use crate::utils::lbg_quantize;
52use crate::vector::Vector;
53use half::f16;
54use nalgebra::DMatrix;
55use rayon::prelude::*;
56
57pub struct OptimizedProductQuantizer {
58    /// The learned rotation matrix (of size `dim x dim`).
59    rotation: DMatrix<f32>,
60    /// A vector of codebooks (one for each subspace). Each codebook is a vector of centroids.
61    codebooks: Vec<Vec<Vector<f32>>>,
62    /// The dimensionality of each subspace (i.e. `dim / m`).
63    sub_dim: usize,
64    /// The number of subspaces into which the rotated vector is partitioned.
65    m: usize,
66    /// The overall dimensionality of the input vectors.
67    dim: usize,
68    /// The distance metric used for selecting codewords during quantization.
69    distance: Distance,
70}
71
72impl OptimizedProductQuantizer {
73    /// Constructs a new `OptimizedProductQuantizer` from training data.
74    ///
75    /// # Parameters
76    /// - `training_data`: A slice of training vectors (`Vector<f32>`) used for learning the quantizer.
77    /// - `m`: The number of subspaces into which the rotated data will be partitioned.
78    /// - `k`: The number of centroids (codewords) per subspace.
79    /// - `max_iters`: The maximum number of iterations for the LBG quantization algorithm.
80    /// - `opq_iters`: The number of OPQ iterations (i.e. the number of times the algorithm alternates
81    ///    between codebook learning, reconstruction, rotation update, and re-rotation).
82    /// - `distance`: The distance metric to use for comparing subvectors during codeword selection.
83    /// - `seed`: A random seed for initializing LBG quantization (each subspace uses `seed + i`).
84    ///
85    /// # Panics
86    /// Panics with a custom error if:
87    /// - `training_data` is empty.
88    /// - The dimension of the training vectors is less than `m`.
89    /// - The dimension of the training vectors is not divisible by `m`.
90    pub fn fit(
91        training_data: &[Vector<f32>],
92        m: usize,
93        k: usize,
94        max_iters: usize,
95        opq_iters: usize,
96        distance: Distance,
97        seed: u64,
98    ) -> Self {
99        if training_data.is_empty() {
100            panic!("{}", VqError::EmptyInput);
101        }
102        let dim = training_data[0].len();
103        if dim < m {
104            panic!(
105                "{}",
106                VqError::InvalidParameter("Dimension must be at least m".to_string())
107            );
108        }
109        if dim % m != 0 {
110            panic!(
111                "{}",
112                VqError::InvalidParameter("Dimension must be divisible by m".to_string())
113            );
114        }
115        let sub_dim = dim / m;
116        let n = training_data.len();
117
118        // Start with an identity rotation.
119        let mut rotation = DMatrix::<f32>::identity(dim, dim);
120        // Initially, no rotation is applied.
121        let mut rotated_data: Vec<Vector<f32>> = training_data.to_vec();
122        let mut codebooks = Vec::with_capacity(m);
123
124        for _ in 0..opq_iters {
125            // --- Codebook Learning ---
126            // Learn a codebook for each subspace in parallel.
127            codebooks = (0..m)
128                .into_par_iter()
129                .map(|i| {
130                    // Extract the sub-training data for subspace `i`.
131                    let sub_training: Vec<Vector<f32>> = rotated_data
132                        .iter()
133                        .map(|v| {
134                            let start = i * sub_dim;
135                            let end = start + sub_dim;
136                            Vector::new(v.data[start..end].to_vec())
137                        })
138                        .collect();
139                    // Learn a codebook for the subspace using LBG quantization.
140                    lbg_quantize(&sub_training, k, max_iters, seed + i as u64)
141                })
142                .collect();
143
144            // --- Reconstruction ---
145            // For each rotated vector, compute its reconstruction using the current codebooks.
146            let reconstructions: Vec<Vector<f32>> = rotated_data
147                .par_iter()
148                .map(|v| {
149                    let mut rec = Vec::with_capacity(dim);
150                    // Use enumerate to iterate over codebooks.
151                    for (i, codebook) in codebooks.iter().enumerate() {
152                        let start = i * sub_dim;
153                        let end = start + sub_dim;
154                        let sub_vector = &v.data[start..end];
155                        let mut best_index = 0;
156                        let mut best_dist = distance.compute(sub_vector, &codebook[0].data);
157                        for (j, centroid) in codebook.iter().enumerate().skip(1) {
158                            let dist = distance.compute(sub_vector, &centroid.data);
159                            if dist < best_dist {
160                                best_dist = dist;
161                                best_index = j;
162                            }
163                        }
164                        rec.extend_from_slice(&codebook[best_index].data);
165                    }
166                    Vector::new(rec)
167                })
168                .collect();
169
170            // --- Rotation Update ---
171            // Prepare data matrices: x_mat for rotated_data, y_mat for reconstructions.
172            let mut x_data: Vec<f32> = Vec::with_capacity(dim * n);
173            let mut y_data: Vec<f32> = Vec::with_capacity(dim * n);
174            // Flatten rotated_data and reconstructions.
175            rotated_data.iter().for_each(|v| x_data.extend(&v.data));
176            reconstructions.iter().for_each(|v| y_data.extend(&v.data));
177            let x_mat = DMatrix::from_column_slice(dim, n, &x_data);
178            let y_mat = DMatrix::from_column_slice(dim, n, &y_data);
179            let a: DMatrix<f32> = &y_mat * x_mat.transpose();
180            let svd = a.svd(true, true);
181            let u = svd.u.expect("SVD failed to produce U");
182            let v_t = svd.v_t.expect("SVD failed to produce Váµ€");
183            rotation = v_t.transpose() * u.transpose();
184
185            // --- Re-rotate the Original Data ---
186            rotated_data = training_data
187                .par_iter()
188                .map(|v| {
189                    let x = DMatrix::from_column_slice(dim, 1, &v.data);
190                    let y = &rotation * x;
191                    let y_vec: Vec<f32> = y.column(0).iter().cloned().collect();
192                    Vector::new(y_vec)
193                })
194                .collect();
195        }
196
197        Self {
198            rotation,
199            codebooks,
200            sub_dim,
201            m,
202            dim,
203            distance,
204        }
205    }
206
207    /// Quantizes an input vector using the learned rotation and codebooks.
208    ///
209    /// The input vector is first rotated using the learned rotation matrix. It is then partitioned into `m`
210    /// sub-vectors, each of dimension `sub_dim`. For each subspace, the nearest codeword is selected using the
211    /// stored distance metric. The selected codewords (one from each subspace) are concatenated and converted
212    /// to half-precision (`f16`), resulting in the final quantized representation.
213    ///
214    /// # Parameters
215    /// - `vector`: The input vector (`Vector<f32>`) to be quantized.
216    ///
217    /// # Returns
218    /// A quantized vector (`Vector<f16>`) representing the input vector.
219    ///
220    /// # Panics
221    /// Panics with a custom error if the input vector's dimension does not match the expected dimension.
222    pub fn quantize(&self, vector: &Vector<f32>) -> Vector<f16> {
223        if vector.len() != self.dim {
224            panic!(
225                "{}",
226                VqError::DimensionMismatch {
227                    expected: self.dim,
228                    found: vector.len()
229                }
230            );
231        }
232        let x = DMatrix::from_column_slice(self.dim, 1, &vector.data);
233        let y = &self.rotation * x;
234        let y_vec: Vec<f32> = y.column(0).iter().cloned().collect();
235        if y_vec.len() != self.sub_dim * self.m {
236            panic!(
237                "{}",
238                VqError::DimensionMismatch {
239                    expected: self.sub_dim * self.m,
240                    found: y_vec.len()
241                }
242            );
243        }
244        let mut quantized_data = Vec::with_capacity(y_vec.len());
245        // Use enumerate to iterate over the codebooks.
246        for (i, codebook) in self.codebooks.iter().enumerate() {
247            let start = i * self.sub_dim;
248            let end = start + self.sub_dim;
249            let sub_vector = &y_vec[start..end];
250            let mut best_index = 0;
251            let mut best_dist = self.distance.compute(sub_vector, &codebook[0].data);
252            for (j, centroid) in codebook.iter().enumerate().skip(1) {
253                let dist = self.distance.compute(sub_vector, &centroid.data);
254                if dist < best_dist {
255                    best_dist = dist;
256                    best_index = j;
257                }
258            }
259            for &val in &codebook[best_index].data {
260                quantized_data.push(f16::from_f32(val));
261            }
262        }
263        Vector::new(quantized_data)
264    }
265}