use crate::{
    pq::{PQConfig, PQIndex},
    Vector, VectorIndex,
};
use anyhow::{anyhow, Result};
use nalgebra::{DMatrix, DVector, SVD};
14#[derive(Debug, Clone)]
16pub struct OPQConfig {
17 pub pq_config: PQConfig,
19 pub n_iterations: usize,
21 pub center_data: bool,
23 pub regularization: f32,
25}
27impl Default for OPQConfig {
28 fn default() -> Self {
29 Self {
30 pq_config: PQConfig::default(),
31 n_iterations: 10,
32 center_data: true,
33 regularization: 0.0,
34 }
35 }
36}
38pub struct OPQIndex {
40 config: OPQConfig,
42 rotation_matrix: Option<DMatrix<f32>>,
44 data_mean: Option<DVector<f32>>,
46 pq_index: PQIndex,
48 is_trained: bool,
50}
52impl OPQIndex {
53 pub fn new(config: OPQConfig) -> Self {
55 Self {
56 pq_index: PQIndex::new(config.pq_config.clone()),
57 config,
58 rotation_matrix: None,
59 data_mean: None,
60 is_trained: false,
61 }
62 }
63
64 pub fn train(&mut self, vectors: &[Vector]) -> Result<()> {
66 if vectors.is_empty() {
67 return Err(anyhow!("Cannot train OPQ with empty data"));
68 }
69
70 let n_samples = vectors.len();
71 let dimensions = vectors[0].dimensions;
72
73 let mut data_matrix = DMatrix::zeros(n_samples, dimensions);
75 for (i, vector) in vectors.iter().enumerate() {
76 let vec_f32 = vector.as_f32();
77 for (j, &val) in vec_f32.iter().enumerate() {
78 data_matrix[(i, j)] = val;
79 }
80 }
81
82 if self.config.center_data {
84 let mean = self.compute_mean(&data_matrix);
85 self.center_data_matrix(&mut data_matrix, &mean);
86 self.data_mean = Some(mean);
87 }
88
89 let mut rotation = DMatrix::identity(dimensions, dimensions);
91
92 for iteration in 0..self.config.n_iterations {
94 println!(
95 "OPQ iteration {}/{}",
96 iteration + 1,
97 self.config.n_iterations
98 );
99
100 let rotated_data = self.apply_rotation(&data_matrix, &rotation);
102 let rotated_vectors = self.matrix_to_vectors(&rotated_data);
103
104 self.pq_index.train(&rotated_vectors)?;
106
107 rotation = self.optimize_rotation(&data_matrix, &rotated_vectors)?;
109
110 let error = self.compute_reconstruction_error(&data_matrix, &rotation)?;
112 println!("Reconstruction error: {error}");
113 }
114
115 self.rotation_matrix = Some(rotation);
116 self.is_trained = true;
117
118 Ok(())
119 }
120
121 fn compute_mean(&self, data: &DMatrix<f32>) -> DVector<f32> {
123 let n_samples = data.nrows() as f32;
124 let mut mean = DVector::zeros(data.ncols());
125
126 for i in 0..data.ncols() {
127 mean[i] = data.column(i).sum() / n_samples;
128 }
129
130 mean
131 }
132
133 fn center_data_matrix(&self, data: &mut DMatrix<f32>, mean: &DVector<f32>) {
135 for i in 0..data.nrows() {
136 for j in 0..data.ncols() {
137 data[(i, j)] -= mean[j];
138 }
139 }
140 }
141
142 fn apply_rotation(&self, data: &DMatrix<f32>, rotation: &DMatrix<f32>) -> DMatrix<f32> {
144 data * rotation.transpose()
145 }
146
147 fn matrix_to_vectors(&self, matrix: &DMatrix<f32>) -> Vec<Vector> {
149 let mut vectors = Vec::with_capacity(matrix.nrows());
150
151 for i in 0..matrix.nrows() {
152 let row: Vec<f32> = matrix.row(i).iter().cloned().collect();
153 vectors.push(Vector::new(row));
154 }
155
156 vectors
157 }
158
159 fn optimize_rotation(
161 &self,
162 data: &DMatrix<f32>,
163 rotated_vectors: &[Vector],
164 ) -> Result<DMatrix<f32>> {
165 let mut reconstructed = DMatrix::zeros(data.nrows(), data.ncols());
167
168 for (i, vector) in rotated_vectors.iter().enumerate() {
169 if let Ok(reconstructed_vec) = self.pq_index.reconstruct(vector) {
171 let rec_f32 = reconstructed_vec.as_f32();
172 for (j, &val) in rec_f32.iter().enumerate() {
173 reconstructed[(i, j)] = val;
174 }
175 }
176 }
177
178 let correlation = data.transpose() * &reconstructed;
181
182 let mut reg_correlation = correlation.clone();
184 if self.config.regularization > 0.0 {
185 for i in 0..reg_correlation.ncols().min(reg_correlation.nrows()) {
186 reg_correlation[(i, i)] += self.config.regularization;
187 }
188 }
189
190 let svd = SVD::new(reg_correlation, true, true);
192 let u = svd.u.ok_or_else(|| anyhow!("SVD failed to compute U"))?;
193 let v_t = svd
194 .v_t
195 .ok_or_else(|| anyhow!("SVD failed to compute V^T"))?;
196
197 Ok(u * v_t)
199 }
200
201 fn compute_reconstruction_error(
203 &self,
204 data: &DMatrix<f32>,
205 rotation: &DMatrix<f32>,
206 ) -> Result<f32> {
207 let rotated = self.apply_rotation(data, rotation);
208 let rotated_vecs = self.matrix_to_vectors(&rotated);
209
210 let mut total_error = 0.0;
211 for (i, vec) in rotated_vecs.iter().enumerate() {
212 if let Ok(reconstructed) = self.pq_index.reconstruct(vec) {
213 let rec_f32 = reconstructed.as_f32();
214 for (j, &val) in rec_f32.iter().enumerate() {
215 let diff = rotated[(i, j)] - val;
216 total_error += diff * diff;
217 }
218 }
219 }
220
221 Ok((total_error / (data.nrows() * data.ncols()) as f32).sqrt())
222 }
223
224 pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
226 if !self.is_trained {
227 return Err(anyhow!("OPQ index must be trained before encoding"));
228 }
229
230 let transformed = self.transform_vector(vector)?;
232
233 self.pq_index.encode(&transformed)
235 }
236
237 pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
239 if !self.is_trained {
240 return Err(anyhow!("OPQ index must be trained before decoding"));
241 }
242
243 let rotated = self.pq_index.decode(codes)?;
245
246 self.inverse_transform_vector(&rotated)
248 }
249
250 fn transform_vector(&self, vector: &Vector) -> Result<Vector> {
252 let rotation = self
253 .rotation_matrix
254 .as_ref()
255 .ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
256
257 let vec_f32 = vector.as_f32();
258 let mut vec_dv = DVector::from_vec(vec_f32.to_vec());
259
260 if let Some(ref mean) = self.data_mean {
262 vec_dv -= mean;
263 }
264
265 let rotated = rotation.transpose() * vec_dv;
267
268 Ok(Vector::new(rotated.iter().cloned().collect()))
269 }
270
271 fn inverse_transform_vector(&self, vector: &Vector) -> Result<Vector> {
273 let rotation = self
274 .rotation_matrix
275 .as_ref()
276 .ok_or_else(|| anyhow!("Rotation matrix not initialized"))?;
277
278 let vec_f32 = vector.as_f32();
279 let vec_dv = DVector::from_vec(vec_f32.to_vec());
280
281 let unrotated = rotation * vec_dv;
283
284 let mut result = unrotated;
286 if let Some(ref mean) = self.data_mean {
287 result += mean;
288 }
289
290 Ok(Vector::new(result.iter().cloned().collect()))
291 }
292
293 pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
295 if !self.is_trained {
296 return Err(anyhow!("OPQ index must be trained before searching"));
297 }
298
299 let transformed_query = self.transform_vector(query)?;
301
302 self.pq_index.search(&transformed_query, k)
304 }
305
306 pub fn stats(&self) -> OPQStats {
308 let pq_stats = self.pq_index.stats();
309
310 OPQStats {
311 pq_stats,
312 is_trained: self.is_trained,
313 has_rotation: self.rotation_matrix.is_some(),
314 rotation_rank: self
315 .rotation_matrix
316 .as_ref()
317 .map(|r| r.rank(1e-6))
318 .unwrap_or(0),
319 }
320 }
321}
323#[derive(Debug, Clone)]
325pub struct OPQStats {
326 pub pq_stats: crate::pq::PQStats,
327 pub is_trained: bool,
328 pub has_rotation: bool,
329 pub rotation_rank: usize,
330}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VectorIndex;

    /// 4 sub-quantizers x 16 centroids over 16-d data, 3 OPQ rounds.
    fn small_config() -> OPQConfig {
        OPQConfig {
            pq_config: PQConfig {
                n_subquantizers: 4,
                n_centroids: 16,
                ..Default::default()
            },
            n_iterations: 3,
            ..Default::default()
        }
    }

    /// 100 deterministic 16-dimensional training vectors.
    fn training_vectors() -> Vec<Vector> {
        (0..100)
            .map(|i| {
                let values: Vec<f32> = (0..16)
                    .map(|j| (i as f32 * 0.1 + j as f32) % 10.0)
                    .collect();
                Vector::new(values)
            })
            .collect()
    }

    #[test]
    fn test_opq_basic() -> Result<()> {
        let mut opq = OPQIndex::new(small_config());
        opq.train(&training_vectors())?;

        let test_vec = Vector::new(vec![1.0; 16]);
        let codes = opq.encode(&test_vec)?;
        let reconstructed = opq.decode(&codes)?;

        assert_eq!(reconstructed.dimensions, 16);

        Ok(())
    }

    #[test]
    fn test_opq_rotation_is_orthogonal() -> Result<()> {
        let mut opq = OPQIndex::new(small_config());
        opq.train(&training_vectors())?;

        let rotation = opq
            .rotation_matrix
            .as_ref()
            .ok_or_else(|| anyhow!("trained index must hold a rotation"))?;
        let gram = rotation.transpose() * rotation;
        let deviation = (gram - DMatrix::<f32>::identity(16, 16)).amax();
        assert!(deviation < 1e-3, "R^T R deviates from identity by {deviation}");

        Ok(())
    }

    #[test]
    fn test_opq_transform_roundtrip() -> Result<()> {
        let mut opq = OPQIndex::new(small_config());
        opq.train(&training_vectors())?;

        let original = Vector::new((0..16).map(|j| j as f32 * 0.5).collect());
        let rotated = opq.transform_vector(&original)?;
        let restored = opq.inverse_transform_vector(&rotated)?;

        assert_eq!(restored.dimensions, 16);
        let expected = original.as_f32();
        let actual = restored.as_f32();
        for (a, b) in expected.iter().zip(actual.iter()) {
            assert!((a - b).abs() < 1e-3, "expected {a}, got {b}");
        }

        Ok(())
    }

    #[test]
    fn test_opq_search() -> Result<()> {
        let config = OPQConfig::default();
        let mut opq = OPQIndex::new(config);

        let vectors: Vec<Vector> = (0..50)
            .map(|i| {
                let values: Vec<f32> = (0..8).map(|j| ((i * j) as f32).sin()).collect();
                Vector::new(values)
            })
            .collect();

        opq.train(&vectors)?;

        // Database entries must live in the same centred and rotated space that `search`
        // maps queries into, so transform them before inserting into the inner PQ.
        for (i, vec) in vectors.iter().enumerate() {
            let transformed = opq.transform_vector(vec)?;
            opq.pq_index.insert(format!("vec_{i}"), transformed)?;
        }

        let query = Vector::new(vec![0.5; 8]);
        let results = opq.search(&query, 5)?;

        assert_eq!(results.len(), 5);

        Ok(())
    }
}