1#![allow(non_local_definitions)]
3
4#[cfg(feature = "python")]
5use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2};
6#[cfg(feature = "python")]
7use pyo3::prelude::*;
8
9use crate::mstg::{MstgConfig, MstgIndex, ScalarPrecision, SearchParams};
10use crate::rotation::RotatorType;
11use crate::Metric;
12
/// Python-facing handle (exported as `MstgIndex`) around an optionally-built
/// [`MstgIndex`] together with the configuration used to build and query it.
#[cfg(feature = "python")]
#[pyclass(name = "MstgIndex")]
pub struct PyMstgIndex {
    /// Expected length of every data and query vector.
    dimension: usize,
    /// Build configuration; also supplies the default query parameters.
    config: MstgConfig,
    /// `None` until `fit()` or `load()` has produced an index.
    index: Option<MstgIndex>,
}
20
#[cfg(feature = "python")]
impl PyMstgIndex {
    /// Packs `(id, distance)` pairs into a float32 NumPy array of shape `(len, 2)`.
    ///
    /// Column 0 holds the vector id (already cast to `f32` by the caller) and
    /// column 1 the distance. This lives in a plain `impl` block so PyO3 does
    /// not export it as a Python method.
    fn pairs_to_array<I>(py: Python, pairs: I) -> PyResult<PyObject>
    where
        I: IntoIterator<Item = (f32, f32)>,
    {
        let pairs = pairs.into_iter();
        let mut flat = Vec::with_capacity(pairs.size_hint().0 * 2);
        for (id, distance) in pairs {
            flat.push(id);
            flat.push(distance);
        }
        let rows = flat.len() / 2;
        // `2 * rows` elements always reshape to `(rows, 2)`; still propagate
        // the error instead of panicking inside a Python call.
        let array = PyArray1::<f32>::from_vec(py, flat).reshape([rows, 2])?;
        Ok(array.to_owned().into_py(py))
    }
}

#[cfg(feature = "python")]
#[pymethods]
impl PyMstgIndex {
    /// Creates an unbuilt index; call `fit()` (or use `load()`) before querying.
    ///
    /// `metric` accepts `"euclidean"`/`"l2"` or `"angular"`/`"ip"`/`"inner_product"`;
    /// `centroid_precision` accepts `"fp32"`, `"bf16"`, `"fp16"` or `"int8"`.
    /// The remaining arguments are copied verbatim into [`MstgConfig`].
    ///
    /// Raises `ValueError` for `dimension == 0` or an unknown metric/precision.
    #[new]
    // The constructor mirrors the Python keyword arguments one-to-one.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        dimension,
        metric="euclidean",
        max_posting_size=16,
        branching_factor=10,
        balance_weight=1.0,
        closure_epsilon=0.15,
        max_replicas=8,
        rabitq_bits=7,
        faster_config=true,
        hnsw_m=32,
        hnsw_ef_construction=400,
        centroid_precision="bf16",
        default_ef_search=150,
        pruning_epsilon=0.6
    ))]
    fn new(
        dimension: usize,
        metric: &str,
        max_posting_size: usize,
        branching_factor: usize,
        balance_weight: f32,
        closure_epsilon: f32,
        max_replicas: usize,
        rabitq_bits: usize,
        faster_config: bool,
        hnsw_m: usize,
        hnsw_ef_construction: usize,
        centroid_precision: &str,
        default_ef_search: usize,
        pruning_epsilon: f32,
    ) -> PyResult<Self> {
        // A zero-dimensional index can never accept data; fail at construction.
        if dimension == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "dimension must be greater than 0",
            ));
        }

        let metric = match metric {
            "euclidean" | "l2" => Metric::L2,
            "angular" | "ip" | "inner_product" => Metric::InnerProduct,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid metric: {}. Use 'euclidean' or 'angular'",
                    metric
                )))
            }
        };

        let centroid_precision = match centroid_precision {
            "fp32" => ScalarPrecision::FP32,
            "bf16" => ScalarPrecision::BF16,
            "fp16" => ScalarPrecision::FP16,
            "int8" => ScalarPrecision::INT8,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid precision: {}. Use 'fp32', 'bf16', 'fp16', or 'int8'",
                    centroid_precision
                )))
            }
        };

        let config = MstgConfig {
            max_posting_size,
            branching_factor,
            balance_weight,
            closure_epsilon,
            max_replicas,
            rabitq_bits,
            faster_config,
            metric,
            hnsw_m,
            hnsw_ef_construction,
            centroid_precision,
            default_ef_search,
            pruning_epsilon,
        };

        Ok(Self {
            index: None,
            config,
            dimension,
        })
    }

    /// Builds the index from an `(N, D)` float32 array, replacing any previous index.
    ///
    /// Raises `ValueError` when `D` differs from the constructor's `dimension`
    /// and `RuntimeError` when [`MstgIndex::build`] fails.
    fn fit(&mut self, data: PyReadonlyArray2<f32>) -> PyResult<()> {
        let data = data.as_array();
        // `PyReadonlyArray2` is statically 2-D, so only the column count needs checking.
        let (_, dim) = data.dim();
        if dim != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data dimension {} does not match expected {}",
                dim, self.dimension
            )));
        }

        let vectors: Vec<Vec<f32>> = data
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();

        let index = MstgIndex::build(&vectors, self.config.clone()).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to build index: {}", e))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Overrides the query-time defaults; arguments left as `None` keep their value.
    ///
    /// The new values are used by subsequent `query()` / `batch_query()` calls.
    #[pyo3(signature = (ef_search=None, pruning_epsilon=None))]
    fn set_query_arguments(&mut self, ef_search: Option<usize>, pruning_epsilon: Option<f32>) {
        if let Some(ef) = ef_search {
            self.config.default_ef_search = ef;
        }
        if let Some(eps) = pruning_epsilon {
            self.config.pruning_epsilon = eps;
        }
    }

    /// Searches for the `k` nearest neighbours of a single 1-D query vector.
    ///
    /// Returns a float32 array of shape `(n, 2)` whose rows are `(id, distance)`.
    /// NOTE(review): ids are cast to `f32`, which is exact only for integers up
    /// to 2^24 (16_777_216); larger ids may be rounded.
    ///
    /// Raises `RuntimeError` if the index is not built and `ValueError` on a
    /// dimension mismatch.
    fn query(&self, py: Python, query: PyReadonlyArray1<f32>, k: usize) -> PyResult<PyObject> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        // Copy through the ndarray view so strided / non-contiguous inputs
        // (e.g. `X[:, 0]`) are accepted; `as_slice()` would reject them.
        let query: Vec<f32> = query.as_array().iter().copied().collect();

        if query.len() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                query.len(),
                self.dimension
            )));
        }

        let params = SearchParams::new(
            self.config.default_ef_search,
            self.config.pruning_epsilon,
            k,
        );

        let results = index.search(query.as_slice(), &params);
        Self::pairs_to_array(
            py,
            results.iter().map(|r| (r.vector_id as f32, r.distance)),
        )
    }

    /// Searches every row of an `(N, D)` query array.
    ///
    /// Returns a list with one `(n, 2)` float32 `(id, distance)` array per query
    /// (same layout and id-precision caveat as `query()`).
    fn batch_query(
        &self,
        py: Python,
        queries: PyReadonlyArray2<f32>,
        k: usize,
    ) -> PyResult<Vec<PyObject>> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let queries = queries.as_array();
        let (_, dim) = queries.dim();
        if dim != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                dim, self.dimension
            )));
        }

        let params = SearchParams::new(
            self.config.default_ef_search,
            self.config.pruning_epsilon,
            k,
        );

        let query_vecs: Vec<Vec<f32>> = queries
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();

        let all_results = index.batch_search(&query_vecs, &params);

        all_results
            .into_iter()
            .map(|results| {
                Self::pairs_to_array(
                    py,
                    results.iter().map(|r| (r.vector_id as f32, r.distance)),
                )
            })
            .collect()
    }

    /// Sum of the centroid index's `memory_usage()` and every posting list's
    /// `memory_size()` (presumably bytes — confirm against those helpers).
    fn get_memory_usage(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        let centroid_mem = index.centroid_index.memory_usage();
        let posting_mem: usize = index.posting_lists.iter().map(|p| p.memory_size()).sum();

        Ok(centroid_mem + posting_mem)
    }

    /// Total number of entries across all posting lists.
    ///
    /// NOTE(review): if vectors are replicated into several posting lists
    /// (cf. `max_replicas`), this counts replicas rather than unique vectors —
    /// confirm that is the intended `len()`.
    fn __len__(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.posting_lists.iter().map(|p| p.len()).sum())
    }

    fn __repr__(&self) -> String {
        format!(
            "MstgIndex(dimension={}, metric={:?}, built={})",
            self.dimension,
            self.config.metric,
            self.index.is_some()
        )
    }

    /// Persists the built index to `path`; raises `RuntimeError` on failure.
    fn save(&self, path: &str) -> PyResult<()> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        index.save_to_path(path).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e))
        })?;

        Ok(())
    }

    /// Loads an index previously written by `save()`; raises `RuntimeError` on failure.
    ///
    /// The dimension is inferred from the first posting list's centroid length.
    /// An index with no posting lists yields dimension 0, so any non-empty
    /// query will then be rejected as a dimension mismatch.
    #[staticmethod]
    fn load(path: &str) -> PyResult<Self> {
        let index = MstgIndex::load_from_path(path).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Load failed: {}", e))
        })?;

        let dimension = index
            .posting_lists
            .first()
            .map_or(0, |posting| posting.centroid.len());

        let config = index.config.clone();

        Ok(Self {
            index: Some(index),
            config,
            dimension,
        })
    }
}
333
/// Python-facing handle (exported as `IvfRabitqIndex`) around an
/// optionally-trained IVF + RaBitQ index.
#[cfg(feature = "python")]
#[pyclass(name = "IvfRabitqIndex")]
pub struct PyIvfRabitqIndex {
    /// Distance metric chosen at construction and passed to training.
    metric: Metric,
    /// Expected length of every data, centroid and query vector.
    dimension: usize,
    /// `None` until `fit()`, `fit_with_clusters()` or `load()` succeeds.
    index: Option<crate::ivf::IvfRabitqIndex>,
}
345
#[cfg(feature = "python")]
impl PyIvfRabitqIndex {
    /// Maps a user-facing rotator name onto a [`RotatorType`].
    ///
    /// NOTE(review): `"random"` selects `FhtKacRotator` and `"identity"` selects
    /// `MatrixRotator`; confirm these aliases are intended.
    fn parse_rotator_type(name: &str) -> PyResult<RotatorType> {
        match name {
            "fht" | "random" => Ok(RotatorType::FhtKacRotator),
            "matrix" | "identity" => Ok(RotatorType::MatrixRotator),
            _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Invalid rotator_type: {}. Use 'fht', 'random', 'matrix', or 'identity'",
                name
            ))),
        }
    }

    /// Packs `(id, score)` pairs into a float32 NumPy array of shape `(len, 2)`.
    ///
    /// This lives in a plain `impl` block so PyO3 does not export it to Python.
    fn pairs_to_array<I>(py: Python, pairs: I) -> PyResult<PyObject>
    where
        I: IntoIterator<Item = (f32, f32)>,
    {
        let pairs = pairs.into_iter();
        let mut flat = Vec::with_capacity(pairs.size_hint().0 * 2);
        for (id, score) in pairs {
            flat.push(id);
            flat.push(score);
        }
        let rows = flat.len() / 2;
        // `2 * rows` elements always reshape to `(rows, 2)`; still propagate
        // the error instead of panicking inside a Python call.
        let array = PyArray1::<f32>::from_vec(py, flat).reshape([rows, 2])?;
        Ok(array.to_owned().into_py(py))
    }
}

#[cfg(feature = "python")]
#[pymethods]
impl PyIvfRabitqIndex {
    /// Creates an untrained index for `dimension`-length vectors.
    ///
    /// `metric` accepts `"euclidean"`/`"l2"` or `"angular"`/`"ip"`/`"inner_product"`.
    /// Raises `ValueError` for `dimension == 0` or an unknown metric.
    #[new]
    #[pyo3(signature = (dimension, metric="euclidean"))]
    fn new(dimension: usize, metric: &str) -> PyResult<Self> {
        // A zero-dimensional index can never accept data; fail at construction.
        if dimension == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "dimension must be greater than 0",
            ));
        }

        let metric = match metric {
            "euclidean" | "l2" => Metric::L2,
            "angular" | "ip" | "inner_product" => Metric::InnerProduct,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Invalid metric: {}. Use 'euclidean' or 'angular'",
                    metric
                )))
            }
        };

        Ok(Self {
            index: None,
            dimension,
            metric,
        })
    }

    /// Trains an index with `nlist` clusters from an `(N, D)` float32 array.
    ///
    /// Raises `ValueError` on a dimension mismatch or an unknown `rotator_type`,
    /// and `RuntimeError` if training fails.
    #[pyo3(signature = (data, nlist, total_bits=7, rotator_type="random", seed=42, faster_config=true))]
    fn fit(
        &mut self,
        data: PyReadonlyArray2<f32>,
        nlist: usize,
        total_bits: usize,
        rotator_type: &str,
        seed: u64,
        faster_config: bool,
    ) -> PyResult<()> {
        let data = data.as_array();
        // `PyReadonlyArray2` is statically 2-D, so only the column count needs checking.
        let (_, dim) = data.dim();
        if dim != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data dimension {} does not match expected {}",
                dim, self.dimension
            )));
        }

        let rotator_type = Self::parse_rotator_type(rotator_type)?;

        let vectors: Vec<Vec<f32>> = data
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();

        let index = crate::ivf::IvfRabitqIndex::train(
            &vectors,
            nlist,
            total_bits,
            self.metric,
            rotator_type,
            seed,
            faster_config,
        )
        .map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to build index: {}", e))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Builds an index from precomputed clusters.
    ///
    /// `data` is `(N, D)`, `centroids` is `(C, D)` and `assignments` holds, for
    /// each data row, the index of its centroid in `0..C`. Raises `ValueError`
    /// on shape mismatches, out-of-range assignments or an unknown
    /// `rotator_type`, and `RuntimeError` if training fails.
    // The method mirrors the Python keyword arguments one-to-one.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (data, centroids, assignments, total_bits=7, rotator_type="random", seed=42, faster_config=true))]
    fn fit_with_clusters(
        &mut self,
        data: PyReadonlyArray2<f32>,
        centroids: PyReadonlyArray2<f32>,
        assignments: PyReadonlyArray1<i32>,
        total_bits: usize,
        rotator_type: &str,
        seed: u64,
        faster_config: bool,
    ) -> PyResult<()> {
        let data = data.as_array();
        let centroids = centroids.as_array();
        // Read through the view so non-contiguous assignment arrays are accepted.
        let assignments = assignments.as_array();

        let (n_data, data_dim) = data.dim();
        let (n_centroids, centroid_dim) = centroids.dim();

        if data_dim != self.dimension || centroid_dim != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Data/centroids dimension must match expected {}",
                self.dimension
            )));
        }

        if n_data != assignments.len() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Data and assignments must have same length",
            ));
        }

        let rotator_type = Self::parse_rotator_type(rotator_type)?;

        // Validate before casting: `-1i32 as usize` would wrap to `usize::MAX`,
        // and ids >= C reference a centroid that does not exist.
        let assignments_usize = assignments
            .iter()
            .enumerate()
            .map(|(pos, &cluster)| {
                if cluster < 0 || cluster as usize >= n_centroids {
                    Err(pyo3::exceptions::PyValueError::new_err(format!(
                        "Assignment {} at index {} is out of range for {} centroids",
                        cluster, pos, n_centroids
                    )))
                } else {
                    Ok(cluster as usize)
                }
            })
            .collect::<PyResult<Vec<usize>>>()?;

        let data_vecs: Vec<Vec<f32>> = data
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();
        let centroid_vecs: Vec<Vec<f32>> = centroids
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();

        let index = crate::ivf::IvfRabitqIndex::train_with_clusters(
            &data_vecs,
            &centroid_vecs,
            &assignments_usize,
            total_bits,
            self.metric,
            rotator_type,
            seed,
            faster_config,
        )
        .map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "Failed to build index with clusters: {}",
                e
            ))
        })?;
        self.index = Some(index);
        Ok(())
    }

    /// Searches `nprobe` clusters for the `k` best matches of one 1-D query.
    ///
    /// Returns a float32 array of shape `(n, 2)` whose rows are `(id, score)`.
    /// NOTE(review): ids are cast to `f32`, which is exact only for integers up
    /// to 2^24 (16_777_216); larger ids may be rounded.
    #[pyo3(signature = (query, k, nprobe=1))]
    fn query(
        &self,
        py: Python,
        query: PyReadonlyArray1<f32>,
        k: usize,
        nprobe: usize,
    ) -> PyResult<PyObject> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        // Copy through the ndarray view so strided / non-contiguous inputs are
        // accepted; `as_slice()` would reject them.
        let query_vec: Vec<f32> = query.as_array().iter().copied().collect();

        if query_vec.len() != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                query_vec.len(),
                self.dimension
            )));
        }

        let params = crate::ivf::SearchParams::new(k, nprobe);
        let results = index.search(query_vec.as_slice(), params).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Search failed: {}", e))
        })?;

        Self::pairs_to_array(py, results.iter().map(|r| (r.id as f32, r.score)))
    }

    /// Searches every row of an `(N, D)` query array.
    ///
    /// Returns one `(n, 2)` float32 `(id, score)` array per query. The first
    /// failing query aborts the call with `RuntimeError`.
    #[pyo3(signature = (queries, k, nprobe=1))]
    fn batch_query(
        &self,
        py: Python,
        queries: PyReadonlyArray2<f32>,
        k: usize,
        nprobe: usize,
    ) -> PyResult<Vec<PyObject>> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        let queries_arr = queries.as_array();
        let (_, dim) = queries_arr.dim();
        if dim != self.dimension {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Query dimension {} does not match expected {}",
                dim, self.dimension
            )));
        }

        let params = crate::ivf::SearchParams::new(k, nprobe);

        let query_vecs: Vec<Vec<f32>> = queries_arr
            .outer_iter()
            .map(|row| row.iter().copied().collect())
            .collect();
        let query_refs: Vec<&[f32]> = query_vecs.iter().map(|v| v.as_slice()).collect();

        let all_results = index.batch_search(&query_refs, params);

        all_results
            .into_iter()
            .enumerate()
            .map(|(i, result)| {
                let results = result.map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
                        "Batch search failed at query {}: {}",
                        i, e
                    ))
                })?;
                Self::pairs_to_array(py, results.iter().map(|r| (r.id as f32, r.score)))
            })
            .collect()
    }

    /// Persists the trained index to `path`; raises `IOError` on failure.
    fn save(&self, path: &str) -> PyResult<()> {
        let index = self.index.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Index not built yet. Call fit() first.")
        })?;

        index.save_to_path(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to save index: {}", e))
        })
    }

    /// Replaces this object's index with one loaded from `path`; raises `IOError` on failure.
    ///
    /// NOTE(review): `self.dimension` and `self.metric` keep their constructor
    /// values and are not checked against the loaded index. Confirm the file
    /// was built with the same settings.
    fn load(&mut self, path: &str) -> PyResult<()> {
        let index = crate::ivf::IvfRabitqIndex::load_from_path(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to load index: {}", e))
        })?;

        self.index = Some(index);
        Ok(())
    }

    /// Number of vectors reported by the underlying index's `len()`.
    fn __len__(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.len())
    }

    /// Number of clusters in the trained index.
    fn cluster_count(&self) -> PyResult<usize> {
        let index = self
            .index
            .as_ref()
            .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Index not built yet."))?;

        Ok(index.cluster_count())
    }

    fn __repr__(&self) -> String {
        format!(
            "IvfRabitqIndex(dimension={}, metric={:?}, built={}, clusters={})",
            self.dimension,
            self.metric,
            self.index.is_some(),
            self.index.as_ref().map_or(0, |idx| idx.cluster_count())
        )
    }
}
721
/// Native extension module exposing the `MstgIndex` and `IvfRabitqIndex` classes.
#[cfg(feature = "python")]
#[pymodule]
fn rabitq_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyMstgIndex>()?;
    // The last registration's result doubles as the module-init result.
    m.add_class::<PyIvfRabitqIndex>()
}
729}