//! rabitq_rs/python_bindings.rs
//!
//! Python bindings (via PyO3) for the MSTG and IVF-RaBitQ indexes.
2#![allow(non_local_definitions)]
3
4#[cfg(feature = "python")]
5use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2};
6#[cfg(feature = "python")]
7use pyo3::prelude::*;
8
9use crate::mstg::{MstgConfig, MstgIndex, ScalarPrecision, SearchParams};
10use crate::rotation::RotatorType;
11use crate::Metric;
12
/// Python-facing wrapper around an MSTG index.
///
/// The configuration is kept alongside the index so that query-time search
/// parameters can be adjusted after building.
#[cfg(feature = "python")]
#[pyclass(name = "MstgIndex")]
pub struct PyMstgIndex {
    /// Vector dimensionality that every input array is checked against.
    dimension: usize,
    /// Build configuration; its search fields are also read at query time.
    config: MstgConfig,
    /// `None` until `fit` succeeds (or the wrapper is created by `load`).
    index: Option<MstgIndex>,
}
20
#[cfg(feature = "python")]
#[pymethods]
impl PyMstgIndex {
    /// Create a new, not-yet-built MSTG index.
    ///
    /// `metric` accepts `"euclidean"`/`"l2"` (L2) or `"angular"`/`"ip"`/`"inner_product"`
    /// (inner product). `centroid_precision` accepts `"fp32"`, `"bf16"`, `"fp16"` or
    /// `"int8"`. Every other argument is copied unchanged into [`MstgConfig`].
    ///
    /// NOTE(review): `"angular"` maps to inner product, which equals cosine similarity
    /// only for unit-normalised vectors — confirm callers normalise their data.
    ///
    /// Raises `ValueError` for an unrecognised metric or precision string.
    #[new]
    // One keyword argument per MstgConfig field, hence the long parameter list.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        dimension,
        metric="euclidean",
        max_posting_size=16,
        branching_factor=10,
        balance_weight=1.0,
        closure_epsilon=0.15,
        max_replicas=8,
        rabitq_bits=7,
        faster_config=true,
        hnsw_m=32,
        hnsw_ef_construction=400,
        centroid_precision="bf16",
        default_ef_search=150,
        pruning_epsilon=0.6
    ))]
    fn new(
        dimension: usize,
        metric: &str,
        max_posting_size: usize,
        branching_factor: usize,
        balance_weight: f32,
        closure_epsilon: f32,
        max_replicas: usize,
        rabitq_bits: usize,
        faster_config: bool,
        hnsw_m: usize,
        hnsw_ef_construction: usize,
        centroid_precision: &str,
        default_ef_search: usize,
        pruning_epsilon: f32,
    ) -> PyResult<Self> {
        let metric = match metric {
            "euclidean" | "l2" => Metric::L2,
            "angular" | "ip" | "inner_product" => Metric::InnerProduct,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid metric: {}. Use 'euclidean' or 'angular'",
                    metric
                )))
            }
        };

        let centroid_precision = match centroid_precision {
            "fp32" => ScalarPrecision::FP32,
            "bf16" => ScalarPrecision::BF16,
            "fp16" => ScalarPrecision::FP16,
            "int8" => ScalarPrecision::INT8,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid precision: {}. Use 'fp32', 'bf16', 'fp16', or 'int8'",
                    centroid_precision
                )))
            }
        };

        let config = MstgConfig {
            max_posting_size,
            branching_factor,
            balance_weight,
            closure_epsilon,
            max_replicas,
            rabitq_bits,
            faster_config,
            metric,
            hnsw_m,
            hnsw_ef_construction,
            centroid_precision,
            default_ef_search,
            pruning_epsilon,
        };

        Ok(Self {
            index: None,
            config,
            dimension,
        })
    }

    /// Build the index from a float32 numpy array of shape (N, D).
    ///
    /// Any memory layout is accepted; rows are copied out before building.
    /// The GIL is released while the index is built so other Python threads can run.
    ///
    /// Raises `ValueError` if D differs from the constructor's `dimension`, and
    /// `RuntimeError` if `MstgIndex::build` fails.
    fn fit(&mut self, py: Python, data: PyReadonlyArray2<f32>) -> PyResult<()> {
        let data = data.as_array();

        // PyReadonlyArray2 is always 2-D, so only the column count needs checking.
        if data.ncols() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data dimension {} does not match expected {}",
                data.ncols(),
                self.dimension
            )));
        }

        let vectors: Vec<Vec<f32>> = data.outer_iter().map(|row| row.to_vec()).collect();
        let config = self.config.clone();

        // Stringify the error inside the closure so only owned, Send data crosses
        // the GIL boundary; the rendered text matches the original `{}` formatting.
        let built = py.allow_threads(move || {
            MstgIndex::build(&vectors, config).map_err(|e| e.to_string())
        });

        let index = built.map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to build index: {}", e))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Override the query-time search parameters used by `query` and `batch_query`.
    ///
    /// Arguments left as `None` keep their current value.
    #[pyo3(signature = (ef_search=None, pruning_epsilon=None))]
    fn set_query_arguments(&mut self, ef_search: Option<usize>, pruning_epsilon: Option<f32>) {
        if let Some(ef) = ef_search {
            self.config.default_ef_search = ef;
        }
        if let Some(eps) = pruning_epsilon {
            self.config.pruning_epsilon = eps;
        }
    }

    /// Search for the `k` nearest neighbours of a single float32 query vector.
    ///
    /// Returns a float32 array of shape (n, 2) whose rows are `[id, distance]`, where
    /// n is the number of hits returned by the index. Ids are stored as float32 and
    /// are therefore exact only up to 2**24.
    ///
    /// Raises `RuntimeError` if the index is not built, and `ValueError` on a
    /// dimension mismatch.
    fn query(&self, py: Python, query: PyReadonlyArray1<f32>, k: usize) -> PyResult<PyObject> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let view = query.as_array();

        if view.len() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                view.len(),
                self.dimension
            )));
        }

        // `as_slice` is only `Some` for contiguous data; copy strided views instead of failing.
        let query_buf: std::borrow::Cow<'_, [f32]> = match view.as_slice() {
            Some(slice) => std::borrow::Cow::Borrowed(slice),
            None => std::borrow::Cow::Owned(view.to_vec()),
        };
        let query: &[f32] = &query_buf;

        let params = SearchParams::new(
            self.config.default_ef_search,
            self.config.pruning_epsilon,
            k,
        );

        let results = index.search(query, &params);
        mstg_hits_to_numpy(
            py,
            results
                .iter()
                .map(|hit| (hit.vector_id as f32, hit.distance)),
        )
    }

    /// Run `query` for every row of an (N, D) float32 array.
    ///
    /// Returns a list of N arrays, each shaped (n_i, 2) with `[id, distance]` rows.
    /// The GIL is released while the index's `batch_search` runs.
    ///
    /// Raises `RuntimeError` if the index is not built, and `ValueError` on a
    /// dimension mismatch.
    fn batch_query(
        &self,
        py: Python,
        queries: PyReadonlyArray2<f32>,
        k: usize,
    ) -> PyResult<Vec<PyObject>> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let queries = queries.as_array();

        if queries.ncols() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                queries.ncols(),
                self.dimension
            )));
        }

        let params = SearchParams::new(
            self.config.default_ef_search,
            self.config.pruning_epsilon,
            k,
        );

        // Copy rows out of the numpy buffer before the GIL is released.
        let query_vecs: Vec<Vec<f32>> = queries.outer_iter().map(|row| row.to_vec()).collect();

        // The search touches no Python objects, so let other Python threads run meanwhile.
        let all_results = py.allow_threads(|| index.batch_search(&query_vecs, &params));

        all_results
            .into_iter()
            .map(|results| {
                mstg_hits_to_numpy(
                    py,
                    results
                        .iter()
                        .map(|hit| (hit.vector_id as f32, hit.distance)),
                )
            })
            .collect()
    }

    /// Estimated memory usage in bytes.
    ///
    /// This is the centroid index's `memory_usage()` plus the sum of every posting
    /// list's `memory_size()`. Raises `RuntimeError` if the index is not built.
    fn get_memory_usage(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        let centroid_mem = index.centroid_index.memory_usage();
        let posting_mem: usize = index.posting_lists.iter().map(|p| p.memory_size()).sum();

        Ok(centroid_mem + posting_mem)
    }

    /// Number of entries stored across all posting lists.
    ///
    /// NOTE(review): this sums each posting list's `len()`. If a vector can be
    /// replicated into several posting lists (see `max_replicas`), the value counts
    /// replicas rather than distinct vectors — confirm against `MstgIndex`.
    ///
    /// Raises `RuntimeError` if the index is not built.
    fn __len__(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.posting_lists.iter().map(|p| p.len()).sum())
    }

    fn __repr__(&self) -> String {
        format!(
            "MstgIndex(dimension={}, metric={:?}, built={})",
            self.dimension,
            self.config.metric,
            self.index.is_some()
        )
    }

    /// Save the built index to `path`.
    ///
    /// Raises `RuntimeError` if the index is not built or the save fails.
    fn save(&self, path: &str) -> PyResult<()> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        index.save_to_path(path).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e))
        })
    }

    /// Load a previously saved index from `path` into a new wrapper.
    ///
    /// The configuration is taken from the loaded index. The dimension is taken from
    /// the first posting list's centroid length. An index with no posting lists gets
    /// dimension 0, so every later non-empty query fails the dimension check.
    ///
    /// Raises `RuntimeError` if loading fails.
    #[staticmethod]
    fn load(path: &str) -> PyResult<Self> {
        let index = MstgIndex::load_from_path(path).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Load failed: {}", e))
        })?;

        let dimension = index
            .posting_lists
            .first()
            .map_or(0, |posting| posting.centroid.len());

        let config = index.config.clone();

        Ok(Self {
            index: Some(index),
            config,
            dimension,
        })
    }
}

/// Pack `(id, distance)` pairs into a float32 numpy array of shape (n, 2).
///
/// Returns an error instead of panicking if numpy rejects the reshape.
#[cfg(feature = "python")]
fn mstg_hits_to_numpy<I>(py: Python<'_>, hits: I) -> PyResult<PyObject>
where
    I: ExactSizeIterator<Item = (f32, f32)>,
{
    let n = hits.len();
    let mut flat = Vec::with_capacity(n * 2);
    for (id, distance) in hits {
        flat.push(id);
        flat.push(distance);
    }

    let array = PyArray1::<f32>::from_vec(py, flat).reshape([n, 2])?;
    Ok(array.to_owned().into_py(py))
}
333
// ============================================================================
// IVF RaBitQ Index Python Bindings
// ============================================================================

/// Python-facing wrapper around an IVF + RaBitQ index.
#[cfg(feature = "python")]
#[pyclass(name = "IvfRabitqIndex")]
pub struct PyIvfRabitqIndex {
    /// Metric chosen at construction; forwarded to training.
    metric: Metric,
    /// Vector dimensionality that every input array is checked against.
    dimension: usize,
    /// `None` until `fit`, `fit_with_clusters` or `load` succeeds.
    index: Option<crate::ivf::IvfRabitqIndex>,
}
345
#[cfg(feature = "python")]
#[pymethods]
impl PyIvfRabitqIndex {
    /// Create a new, not-yet-built IVF RaBitQ index.
    ///
    /// `metric` accepts `"euclidean"`/`"l2"` (L2) or `"angular"`/`"ip"`/`"inner_product"`
    /// (inner product). Raises `ValueError` for any other string.
    #[new]
    #[pyo3(signature = (dimension, metric="euclidean"))]
    fn new(dimension: usize, metric: &str) -> PyResult<Self> {
        let metric = match metric {
            "euclidean" | "l2" => Metric::L2,
            "angular" | "ip" | "inner_product" => Metric::InnerProduct,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid metric: {}. Use 'euclidean' or 'angular'",
                    metric
                )))
            }
        };

        Ok(Self {
            index: None,
            dimension,
            metric,
        })
    }

    /// Build the index from a float32 array of shape (N, D), with automatic
    /// k-means clustering into `nlist` clusters.
    ///
    /// `rotator_type` accepts `"fht"`/`"random"` or `"matrix"`/`"identity"`.
    /// The GIL is released while training runs.
    ///
    /// Raises `ValueError` on a dimension mismatch or an unknown rotator, and
    /// `RuntimeError` if training fails.
    // The Python token adds one parameter on top of the public keyword arguments.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (data, nlist, total_bits=7, rotator_type="random", seed=42, faster_config=true))]
    fn fit(
        &mut self,
        py: Python,
        data: PyReadonlyArray2<f32>,
        nlist: usize,
        total_bits: usize,
        rotator_type: &str,
        seed: u64,
        faster_config: bool,
    ) -> PyResult<()> {
        let data = data.as_array();

        // PyReadonlyArray2 is always 2-D, so only the column count needs checking.
        if data.ncols() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data dimension {} does not match expected {}",
                data.ncols(),
                self.dimension
            )));
        }

        let rotator_type = parse_ivf_rotator_type(rotator_type)?;
        let vectors: Vec<Vec<f32>> = data.outer_iter().map(|row| row.to_vec()).collect();
        let metric = self.metric;

        // Stringify the error inside the closure so only owned, Send data crosses
        // the GIL boundary; the rendered text matches the original `{}` formatting.
        let trained = py.allow_threads(move || {
            crate::ivf::IvfRabitqIndex::train(
                &vectors,
                nlist,
                total_bits,
                metric,
                rotator_type,
                seed,
                faster_config,
            )
            .map_err(|e| e.to_string())
        });

        let index = trained.map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to build index: {}", e))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Build the index from data plus pre-computed centroids and per-row cluster
    /// assignments.
    ///
    /// `data` is (N, D) and `centroids` is (C, D). `assignments` has length N, and
    /// every entry must lie in `[0, C)`. The GIL is released while training runs.
    ///
    /// Raises `ValueError` on shape or dimension mismatches, out-of-range
    /// assignments, or an unknown rotator. Raises `RuntimeError` if training fails.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (data, centroids, assignments, total_bits=7, rotator_type="random", seed=42, faster_config=true))]
    fn fit_with_clusters(
        &mut self,
        py: Python,
        data: PyReadonlyArray2<f32>,
        centroids: PyReadonlyArray2<f32>,
        assignments: PyReadonlyArray1<i32>,
        total_bits: usize,
        rotator_type: &str,
        seed: u64,
        faster_config: bool,
    ) -> PyResult<()> {
        let data = data.as_array();
        let centroids = centroids.as_array();
        // Read through an ndarray view so non-contiguous assignment arrays are accepted.
        let assignments = assignments.as_array();

        if data.ncols() != self.dimension || centroids.ncols() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data/centroids dimension must match expected {}",
                self.dimension
            )));
        }

        if data.nrows() != assignments.len() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Data and assignments must have same length",
            ));
        }

        let rotator_type = parse_ivf_rotator_type(rotator_type)?;

        // Reject negative or out-of-range cluster ids up front. A bare `as usize`
        // cast would silently turn e.g. -1 into usize::MAX.
        let n_centroids = centroids.nrows();
        let assignments_usize: Vec<usize> = assignments
            .iter()
            .enumerate()
            .map(|(row, &cluster)| {
                usize::try_from(cluster)
                    .ok()
                    .filter(|&c| c < n_centroids)
                    .ok_or_else(|| {
                        pyo3::exceptions::PyValueError::new_err(format!(
                            "Assignment {} for row {} is out of range [0, {})",
                            cluster, row, n_centroids
                        ))
                    })
            })
            .collect::<PyResult<_>>()?;

        let data_vecs: Vec<Vec<f32>> = data.outer_iter().map(|row| row.to_vec()).collect();
        let centroid_vecs: Vec<Vec<f32>> =
            centroids.outer_iter().map(|row| row.to_vec()).collect();
        let metric = self.metric;

        let trained = py.allow_threads(move || {
            crate::ivf::IvfRabitqIndex::train_with_clusters(
                &data_vecs,
                &centroid_vecs,
                &assignments_usize,
                total_bits,
                metric,
                rotator_type,
                seed,
                faster_config,
            )
            .map_err(|e| e.to_string())
        });

        let index = trained.map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "Failed to build index with clusters: {}",
                e
            ))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Search for the `k` nearest neighbours of one float32 query, probing
    /// `nprobe` clusters.
    ///
    /// Returns a float32 array of shape (n, 2) whose rows are `[id, score]`. Ids are
    /// stored as float32 and are therefore exact only up to 2**24.
    ///
    /// Raises `RuntimeError` if the index is not built or the search fails, and
    /// `ValueError` on a dimension mismatch.
    #[pyo3(signature = (query, k, nprobe=1))]
    fn query(
        &self,
        py: Python,
        query: PyReadonlyArray1<f32>,
        k: usize,
        nprobe: usize,
    ) -> PyResult<PyObject> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let view = query.as_array();

        if view.len() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                view.len(),
                self.dimension
            )));
        }

        // `as_slice` is only `Some` for contiguous data; copy strided views instead of failing.
        let query_buf: std::borrow::Cow<'_, [f32]> = match view.as_slice() {
            Some(slice) => std::borrow::Cow::Borrowed(slice),
            None => std::borrow::Cow::Owned(view.to_vec()),
        };
        let query_slice: &[f32] = &query_buf;

        let params = crate::ivf::SearchParams::new(k, nprobe);
        let results = index.search(query_slice, params).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Search failed: {}", e))
        })?;

        ivf_hits_to_numpy(py, results.iter().map(|hit| (hit.id as f32, hit.score)))
    }

    /// Run `query` for every row of an (N, D) float32 array.
    ///
    /// Returns a list of N arrays, each shaped (n_i, 2) with `[id, score]` rows.
    /// The GIL is released while the index's `batch_search` runs.
    ///
    /// Raises `RuntimeError` if the index is not built or any individual search fails
    /// (the message names the failing query), and `ValueError` on a dimension mismatch.
    #[pyo3(signature = (queries, k, nprobe=1))]
    fn batch_query(
        &self,
        py: Python,
        queries: PyReadonlyArray2<f32>,
        k: usize,
        nprobe: usize,
    ) -> PyResult<Vec<PyObject>> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let queries_arr = queries.as_array();

        if queries_arr.ncols() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                queries_arr.ncols(),
                self.dimension
            )));
        }

        let params = crate::ivf::SearchParams::new(k, nprobe);

        // Copy rows out of the numpy buffer once, then hand borrowed slices to the search.
        let query_vecs: Vec<Vec<f32>> =
            queries_arr.outer_iter().map(|row| row.to_vec()).collect();
        let query_refs: Vec<&[f32]> = query_vecs.iter().map(Vec::as_slice).collect();

        // The search touches no Python objects, so let other Python threads run meanwhile.
        let all_results = py.allow_threads(|| index.batch_search(&query_refs, params));

        all_results
            .into_iter()
            .enumerate()
            .map(|(i, outcome)| {
                let hits = outcome.map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
                        "Batch search failed at query {}: {}",
                        i, e
                    ))
                })?;
                ivf_hits_to_numpy(py, hits.iter().map(|hit| (hit.id as f32, hit.score)))
            })
            .collect()
    }

    /// Save the built index to `path`.
    ///
    /// Raises `RuntimeError` if the index is not built, and `IOError` if saving fails.
    fn save(&self, path: &str) -> PyResult<()> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        index.save_to_path(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to save index: {}", e))
        })
    }

    /// Replace this wrapper's index with one loaded from `path`.
    ///
    /// NOTE(review): `self.dimension` and `self.metric` keep their constructor values
    /// and are not refreshed from the loaded index. If they differ from the saved
    /// index, queries are validated against the wrong dimension — confirm intended.
    ///
    /// Raises `IOError` if loading fails.
    fn load(&mut self, path: &str) -> PyResult<()> {
        let index = crate::ivf::IvfRabitqIndex::load_from_path(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to load index: {}", e))
        })?;

        self.index = Some(index);
        Ok(())
    }

    /// Number of vectors in the index, as reported by the index's `len()`.
    ///
    /// Raises `RuntimeError` if the index is not built.
    fn __len__(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.len())
    }

    /// Number of clusters in the index.
    ///
    /// Raises `RuntimeError` if the index is not built.
    fn cluster_count(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.cluster_count())
    }

    fn __repr__(&self) -> String {
        format!(
            "IvfRabitqIndex(dimension={}, metric={:?}, built={}, clusters={})",
            self.dimension,
            self.metric,
            self.index.is_some(),
            self.index
                .as_ref()
                .map(|idx| idx.cluster_count())
                .unwrap_or(0)
        )
    }
}

/// Map a Python-facing rotator name onto a [`RotatorType`].
///
/// NOTE(review): `"identity"` maps to `MatrixRotator`; confirm this is intended
/// rather than a true no-op rotation.
#[cfg(feature = "python")]
fn parse_ivf_rotator_type(name: &str) -> PyResult<RotatorType> {
    match name {
        "fht" | "random" => Ok(RotatorType::FhtKacRotator),
        "matrix" | "identity" => Ok(RotatorType::MatrixRotator),
        _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Invalid rotator_type: {}. Use 'fht', 'random', 'matrix', or 'identity'",
            name
        ))),
    }
}

/// Pack `(id, score)` pairs into a float32 numpy array of shape (n, 2).
///
/// Returns an error instead of panicking if numpy rejects the reshape.
#[cfg(feature = "python")]
fn ivf_hits_to_numpy<I>(py: Python<'_>, hits: I) -> PyResult<PyObject>
where
    I: ExactSizeIterator<Item = (f32, f32)>,
{
    let n = hits.len();
    let mut flat = Vec::with_capacity(n * 2);
    for (id, score) in hits {
        flat.push(id);
        flat.push(score);
    }

    let array = PyArray1::<f32>::from_vec(py, flat).reshape([n, 2])?;
    Ok(array.to_owned().into_py(py))
}
721
/// Module entry point for `import rabitq_rs`: registers both index classes.
#[cfg(feature = "python")]
#[pymodule]
fn rabitq_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyMstgIndex>()?;
    // The last registration's result doubles as the initializer's result.
    m.add_class::<PyIvfRabitqIndex>()
}