usearch/
lib.rs

1//! # USearch Crate for Rust
2//!
3//! `usearch` is a high-performance library for Approximate Nearest Neighbor (ANN) search in high-dimensional spaces.
4//! It offers efficient and scalable solutions for indexing and querying dense vector spaces with support for multiple distance metrics and vector types.
5//!
6//! This crate wraps the native functionalities of USearch, providing Rust-friendly interfaces and integration capabilities.
7//! It is designed to facilitate rapid development and deployment of applications requiring fast and accurate vector search functionalities, such as recommendation systems, image retrieval systems, and natural language processing tasks.
8//!
9//! ## Features
10//!
11//! - SIMD-accelerated distance calculations for various metrics.
12//! - Support for `f32`, `f64`, `i8`, custom `f16`, and binary (`b1x8`) vector types.
13//! - Extensible with custom distance metrics and filtering predicates.
14//! - Efficient serialization and deserialization for persistence and network transfers.
15//!
16//! ## Quick Start
17//!
18//! Refer to the `Index` struct for detailed usage examples.
19
20/// Returns the version of the USearch crate.
21pub fn version() -> &'static str {
22    env!("CARGO_PKG_VERSION")
23}
24
/// Identifier attached to every vector stored in an index (a 64-bit unsigned integer).
pub type Key = u64;

/// Scalar produced by distance computations (a 32-bit float); smaller means closer.
pub type Distance = f32;

/// C-ABI callback shape for a user-defined metric.
///
/// The callback is implemented on the Rust side and invoked from C++. Its arguments are
/// the two vectors being compared, followed by an opaque, mutable user-state pointer.
pub type StatefulMetric = unsafe extern "C" fn(
    *const std::ffi::c_void, // first vector
    *const std::ffi::c_void, // second vector
    *mut std::ffi::c_void,   // user state
) -> Distance;

/// C-ABI callback shape for a user-defined filter predicate.
///
/// The callback is implemented on the Rust side and invoked from C++ with a candidate
/// key and an opaque user-state pointer.
pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;
42
/// Failure reported by the bit-level accessors of [`BitAddressable`].
#[derive(Debug)]
pub enum BitAddressableError {
    /// The requested bit position lies outside the addressed storage.
    IndexOutOfRange,
}

impl std::fmt::Display for BitAddressableError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::IndexOutOfRange => "Index out of range",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BitAddressableError {}
59
60/// Trait for types that can be addressed at the bit level.
61/// Provides methods to set and get individual bits within the implementing type.
62pub trait BitAddressable {
63    /// Sets a bit at the specified index.
64    /// Returns an error if the index is out of range.
65    ///
66    /// # Arguments
67    ///
68    /// * `index` - The index of the bit to set.
69    /// * `value` - The value to set the bit to (`true` for 1, `false` for 0).
70    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError>;
71
72    /// Gets the value of a bit at the specified index.
73    /// Returns an error if the index is out of range.
74    ///
75    /// # Arguments
76    ///
77    /// * `index` - The index of the bit to retrieve.
78    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
79}
80
/// Eight packed bits stored in one byte.
///
/// `#[repr(transparent)]` gives this type exactly the layout of `u8`. That is what makes
/// the zero-copy slice reinterpretations below sound.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct b1x8(pub u8);

impl b1x8 {
    /// Views a byte slice as packed-bit words without copying.
    pub fn from_u8s(slice: &[u8]) -> &[Self] {
        let (data, len) = (slice.as_ptr().cast::<Self>(), slice.len());
        // SAFETY: `b1x8` is `repr(transparent)` over `u8` (same size, alignment and valid
        // bit patterns), and the result borrows `slice` for the same lifetime.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views a mutable byte slice as mutable packed-bit words without copying.
    pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] {
        let (data, len) = (slice.as_mut_ptr().cast::<Self>(), slice.len());
        // SAFETY: identical layout (see `from_u8s`); the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }

    /// Views packed-bit words as plain bytes without copying.
    pub fn to_u8s(slice: &[Self]) -> &[u8] {
        let (data, len) = (slice.as_ptr().cast::<u8>(), slice.len());
        // SAFETY: identical layout; every `b1x8` is a valid `u8`.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views mutable packed-bit words as mutable bytes without copying.
    pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] {
        let (data, len) = (slice.as_mut_ptr().cast::<u8>(), slice.len());
        // SAFETY: identical layout; any `u8` written back is a valid `b1x8`.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }
}
114
/// IEEE 754 half-precision (binary16) value, stored as its raw 16-bit pattern.
///
/// The layout is 1 sign bit, then 5 exponent bits, then 10 mantissa bits.
/// `#[repr(transparent)]` makes the type layout-identical to `i16`, which the zero-copy
/// slice reinterpretations below rely on.
#[repr(transparent)]
#[allow(non_camel_case_types)]
#[derive(Clone, Copy)]
pub struct f16(i16);

impl f16 {
    /// Views raw `i16` bit patterns as half-precision values without copying.
    pub fn from_i16s(slice: &[i16]) -> &[Self] {
        let (data, len) = (slice.as_ptr().cast::<Self>(), slice.len());
        // SAFETY: `f16` is `repr(transparent)` over `i16`; the borrow lifetime is preserved.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views mutable raw `i16` bit patterns as mutable half-precision values without copying.
    pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
        let (data, len) = (slice.as_mut_ptr().cast::<Self>(), slice.len());
        // SAFETY: identical layout; the exclusive borrow is carried over.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }

    /// Views half-precision values as their raw `i16` bit patterns without copying.
    pub fn to_i16s(slice: &[Self]) -> &[i16] {
        let (data, len) = (slice.as_ptr().cast::<i16>(), slice.len());
        // SAFETY: identical layout; every `f16` is a valid `i16`.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views mutable half-precision values as mutable raw `i16` bit patterns without copying.
    pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
        let (data, len) = (slice.as_mut_ptr().cast::<i16>(), slice.len());
        // SAFETY: identical layout; any `i16` written back is a valid `f16` bit pattern.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }
}
150
151impl BitAddressable for b1x8 {
152    /// Sets a bit at a specific index within the byte.
153    ///
154    /// # Arguments
155    ///
156    /// * `index` - The 0-based index of the bit to set, ranging from 0 to 7.
157    /// * `value` - The boolean value to assign to the bit (`true` for 1, `false` for 0).
158    ///
159    /// # Returns
160    ///
161    /// This method returns `Ok(())` if the bit was successfully set, or an `Err(BitAddressableError::IndexOutOfRange)`
162    /// if the provided index is outside the valid range.
163    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
164        if index >= 8 {
165            Err(BitAddressableError::IndexOutOfRange)
166        } else {
167            if value {
168                self.0 |= 1 << index;
169            } else {
170                self.0 &= !(1 << index);
171            }
172            Ok(())
173        }
174    }
175
176    /// Retrieves the value of a bit at a specific index within the byte.
177    ///
178    /// # Arguments
179    ///
180    /// * `index` - The 0-based index of the bit to retrieve, ranging from 0 to 7.
181    ///
182    /// # Returns
183    ///
184    /// Returns `Ok(true)` if the bit is set (1), `Ok(false)` if the bit is not set (0),
185    /// or an `Err(BitAddressableError::IndexOutOfRange)` if the provided index is outside
186    /// the valid range.
187    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
188        if index >= 8 {
189            Err(BitAddressableError::IndexOutOfRange)
190        } else {
191            Ok(((self.0 >> index) & 1) == 1)
192        }
193    }
194}
195
196impl BitAddressable for [b1x8] {
197    /// Sets a bit at a specific index across the slice of `b1x8`.
198    fn set_bit(&mut self, index: usize, value: bool) -> Result<(), BitAddressableError> {
199        let byte_index = index / 8;
200        let bit_index = index % 8;
201        if byte_index >= self.len() {
202            Err(BitAddressableError::IndexOutOfRange)
203        } else {
204            self[byte_index].set_bit(bit_index, value)
205        }
206    }
207
208    /// Gets a bit at a specific index across the slice of `b1x8`.
209    fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError> {
210        let byte_index = index / 8;
211        let bit_index = index % 8;
212        if byte_index >= self.len() {
213            Err(BitAddressableError::IndexOutOfRange)
214        } else {
215            self[byte_index].get_bit(bit_index)
216        }
217    }
218}
219
220impl PartialEq for f16 {
221    fn eq(&self, other: &Self) -> bool {
222        // Check for NaN values first (exponent all ones and non-zero mantissa)
223        let nan_self = (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x03FF) != 0;
224        let nan_other = (other.0 & 0x7C00) == 0x7C00 && (other.0 & 0x03FF) != 0;
225        if nan_self || nan_other {
226            return false;
227        }
228
229        self.0 == other.0
230    }
231}
232
233impl std::fmt::Debug for b1x8 {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        write!(f, "{:08b}", self.0)
236    }
237}
238
239impl std::fmt::Debug for f16 {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        let bits = self.0;
242        let sign = (bits >> 15) & 1;
243        let exponent = (bits >> 10) & 0x1F;
244        let mantissa = bits & 0x3FF;
245        write!(f, "{}|{:05b}|{:010b}", sign, exponent, mantissa)
246    }
247}
248
// Bridge between this crate and the C++ wrapper declared in `lib.hpp`.
// The shared enums and structs below exist on both sides of the FFI boundary.
#[cxx::bridge]
pub mod ffi {

    /// The metric kind used to differentiate built-in distance functions.
    #[derive(Debug)]
    #[repr(i32)]
    enum MetricKind {
        /// Unspecified metric kind.
        Unknown,
        /// The Inner Product metric, defined as `IP = 1 - sum(a[i] * b[i])`.
        IP,
        /// The squared Euclidean Distance metric, defined as `L2 = sum((a[i] - b[i])^2)`.
        L2sq,
        /// The Cosine distance metric (one minus cosine similarity), defined as `Cos = 1 - sum(a[i] * b[i]) / (sqrt(sum(a[i]^2)) * sqrt(sum(b[i]^2)))`.
        Cos,
        /// The Pearson Correlation metric.
        Pearson,
        /// The Haversine (Great Circle) Distance metric.
        Haversine,
        /// The Jensen Shannon Divergence metric.
        Divergence,
        /// The bit-level Hamming Distance metric, defined as the number of differing bits.
        Hamming,
        /// The bit-level Tanimoto (Jaccard) metric, defined as the number of intersecting bits divided by the number of union bits.
        Tanimoto,
        /// The bit-level Sorensen metric.
        Sorensen,
    }

    /// The scalar kind used to differentiate built-in vector element types.
    #[derive(Debug)]
    #[repr(i32)]
    enum ScalarKind {
        /// Unspecified scalar kind.
        Unknown,
        /// 64-bit double-precision IEEE 754 floating-point number.
        F64,
        /// 32-bit single-precision IEEE 754 floating-point number.
        F32,
        /// 16-bit half-precision IEEE 754 floating-point number (different from `bf16`).
        F16,
        /// 16-bit brain floating-point number.
        BF16,
        /// 8-bit signed integer.
        I8,
        /// 1-bit binary value, packed 8 per byte.
        B1,
    }

    /// The resulting matches from a search operation.
    /// It contains the keys and distances of the closest vectors.
    #[derive(Debug)]
    struct Matches {
        // Keys of the matched vectors.
        keys: Vec<u64>,
        // Distances of the matched vectors, in the same order as `keys`.
        distances: Vec<f32>,
    }

    /// The index options used to configure the dense index during creation.
    /// It contains the number of dimensions, the metric kind, the scalar kind,
    /// the connectivity, the expansion values, and the multi-flag.
    #[derive(Debug, PartialEq)]
    struct IndexOptions {
        // Number of scalar components in each vector.
        dimensions: usize,
        // Built-in distance function.
        metric: MetricKind,
        // Scalar type the index stores internally.
        quantization: ScalarKind,
        // NOTE(review): graph-tuning knobs; 0 presumably selects a native default — confirm in lib.hpp.
        connectivity: usize,
        expansion_add: usize,
        expansion_search: usize,
        // NOTE(review): presumably allows several vectors per key (see `count`) — confirm.
        multi: bool,
    }

    // C++ types and signatures exposed to Rust.
    unsafe extern "C++" {
        include!("lib.hpp");

        /// Low-level C++ interface that is further wrapped into the high-level `Index`
        type NativeIndex;

        // Read and adjust the expansion factors and the built-in metric of a live index.
        pub fn expansion_add(self: &NativeIndex) -> usize;
        pub fn expansion_search(self: &NativeIndex) -> usize;
        pub fn change_expansion_add(self: &NativeIndex, n: usize);
        pub fn change_expansion_search(self: &NativeIndex, n: usize);
        pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind);

        /// Changes the metric function used to calculate the distance between vectors.
        /// Avoids the `std::ffi::c_void` type and the `StatefulMetric` type, that the FFI
        /// does not support, replacing them with basic pointer-sized integer types.
        /// The first two arguments are the pointers to the vectors to compare, and the third
        /// argument is the `metric_state` propagated from the Rust layer.
        pub fn change_metric(self: &NativeIndex, metric: usize, metric_state: usize);

        // Construction and capacity management; C++ exceptions surface as `cxx::Exception`.
        pub fn new_native_index(options: &IndexOptions) -> Result<UniquePtr<NativeIndex>>;
        pub fn reserve(self: &NativeIndex, capacity: usize) -> Result<()>;
        pub fn reserve_capacity_and_threads(
            self: &NativeIndex,
            capacity: usize,
            threads: usize,
        ) -> Result<()>;

        // Introspection.
        pub fn dimensions(self: &NativeIndex) -> usize;
        pub fn connectivity(self: &NativeIndex) -> usize;
        pub fn size(self: &NativeIndex) -> usize;
        pub fn capacity(self: &NativeIndex) -> usize;
        pub fn serialized_length(self: &NativeIndex) -> usize;

        // Insertion, one entry point per scalar type. Half-precision vectors cross the
        // bridge as their raw `i16` bit patterns, and binary vectors as packed `u8` bytes.
        pub fn add_b1x8(self: &NativeIndex, key: u64, vector: &[u8]) -> Result<()>;
        pub fn add_i8(self: &NativeIndex, key: u64, vector: &[i8]) -> Result<()>;
        pub fn add_f16(self: &NativeIndex, key: u64, vector: &[i16]) -> Result<()>;
        pub fn add_f32(self: &NativeIndex, key: u64, vector: &[f32]) -> Result<()>;
        pub fn add_f64(self: &NativeIndex, key: u64, vector: &[f64]) -> Result<()>;

        // Approximate k-nearest-neighbour search returning up to `count` matches.
        pub fn search_b1x8(self: &NativeIndex, query: &[u8], count: usize) -> Result<Matches>;
        pub fn search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
        pub fn search_f16(self: &NativeIndex, query: &[i16], count: usize) -> Result<Matches>;
        pub fn search_f32(self: &NativeIndex, query: &[f32], count: usize) -> Result<Matches>;
        pub fn search_f64(self: &NativeIndex, query: &[f64], count: usize) -> Result<Matches>;

        // Exhaustive (brute-force) search over every stored vector.
        pub fn exact_search_b1x8(self: &NativeIndex, query: &[u8], count: usize)
            -> Result<Matches>;
        pub fn exact_search_i8(self: &NativeIndex, query: &[i8], count: usize) -> Result<Matches>;
        pub fn exact_search_f16(self: &NativeIndex, query: &[i16], count: usize)
            -> Result<Matches>;
        pub fn exact_search_f32(self: &NativeIndex, query: &[f32], count: usize)
            -> Result<Matches>;
        pub fn exact_search_f64(self: &NativeIndex, query: &[f64], count: usize)
            -> Result<Matches>;

        // Predicate-filtered search. `filter` is the address of an `extern "C"` trampoline
        // and `filter_state` the address of the Rust closure that trampoline invokes. Both
        // are passed as `usize` because the bridge cannot carry `c_void` pointers.
        pub fn filtered_search_b1x8(
            self: &NativeIndex,
            query: &[u8],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_i8(
            self: &NativeIndex,
            query: &[i8],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f16(
            self: &NativeIndex,
            query: &[i16],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f32(
            self: &NativeIndex,
            query: &[f32],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;
        pub fn filtered_search_f64(
            self: &NativeIndex,
            query: &[f64],
            count: usize,
            filter: usize,
            filter_state: usize,
        ) -> Result<Matches>;

        // Retrieval into a caller-provided buffer. The returned count is documented by
        // `VectorType::get` as the number of elements written (NOTE(review): confirm vs. vectors).
        pub fn get_b1x8(self: &NativeIndex, key: u64, buffer: &mut [u8]) -> Result<usize>;
        pub fn get_i8(self: &NativeIndex, key: u64, buffer: &mut [i8]) -> Result<usize>;
        pub fn get_f16(self: &NativeIndex, key: u64, buffer: &mut [i16]) -> Result<usize>;
        pub fn get_f32(self: &NativeIndex, key: u64, buffer: &mut [f32]) -> Result<usize>;
        pub fn get_f64(self: &NativeIndex, key: u64, buffer: &mut [f64]) -> Result<usize>;

        // Key management. NOTE(review): the `usize` results presumably count affected entries — confirm.
        pub fn remove(self: &NativeIndex, key: u64) -> Result<usize>;
        pub fn rename(self: &NativeIndex, from: u64, to: u64) -> Result<usize>;
        pub fn contains(self: &NativeIndex, key: u64) -> bool;
        pub fn count(self: &NativeIndex, key: u64) -> usize;

        // File-based persistence and housekeeping.
        pub fn save(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn load(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn view(self: &NativeIndex, path: &str) -> Result<()>;
        pub fn reset(self: &NativeIndex) -> Result<()>;
        pub fn memory_usage(self: &NativeIndex) -> usize;
        pub fn hardware_acceleration(self: &NativeIndex) -> *const c_char;

        // In-memory persistence. NOTE(review): `buffer` presumably must hold at least
        // `serialized_length()` bytes for `save_to_buffer` — confirm.
        pub fn save_to_buffer(self: &NativeIndex, buffer: &mut [u8]) -> Result<()>;
        pub fn load_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
        pub fn view_from_buffer(self: &NativeIndex, buffer: &[u8]) -> Result<()>;
    }
}
433
434// Re-export the FFI structs and enums at the crate root for easy access
435pub use ffi::{IndexOptions, MetricKind, ScalarKind};
436
437/// Represents custom metric functions for calculating distances between vectors in various formats.
438///
439/// This enum allows the encapsulation of custom distance calculation logic for vectors of different
440/// data types, facilitating the use of custom metrics in vector space operations. Each variant of this
441/// enum holds a boxed function pointer (`std::boxed::Box<dyn Fn(...) -> Distance + Send + Sync>`) that defines
442/// the distance calculation between two vectors of a specific type. The function returns a `Distance`, which
443/// is typically a floating-point value representing the calculated distance between the two vectors.
444///
445/// # Variants
446///
447/// - `B1X8Metric`: A metric function for binary vectors packed in `u8` containers, represented here by `b1x8`.
448/// - `I8Metric`: A metric function for vectors of 8-bit signed integers (`i8`).
449/// - `F16Metric`: A metric function for vectors of 16-bit half-precision floating-point numbers (`f16`).
450/// - `F32Metric`: A metric function for vectors of 32-bit floating-point numbers (`f32`).
451/// - `F64Metric`: A metric function for vectors of 64-bit floating-point numbers (`f64`).
452///
453/// Each metric function takes two pointers to the vectors of the respective type and returns a `Distance`.
454///
455/// # Usage
456///
457/// Custom metric functions can be used to define how distances are calculated between vectors, enabling
458/// the implementation of various distance metrics such as Euclidean distance, Manhattan distance, or
459/// Cosine similarity, depending on the specific requirements of the application.
460///
461/// # Safety
462///
463/// Since these functions operate on raw pointers, care must be taken to ensure that the pointers are valid
464/// and that the lifetime of the referenced data extends at least as long as the lifetime of the metric
465/// function's use. Improper use of these functions can lead to undefined behavior.
466///
467/// # Examples
468///
469/// ```
470/// use usearch::{Distance, f16, b1x8};
471///
472/// let euclidean_fn = Box::new(|a: *const f32, b: *const f32| -> f32 {
473///     let dimensions = 256;
474///     let a = unsafe { std::slice::from_raw_parts(a, dimensions) };
475///     let b = unsafe { std::slice::from_raw_parts(b, dimensions) };
476///     a.iter().zip(b.iter())
477///         .map(|(a, b)| (*a - *b).powi(2))
478///         .sum::<f32>()
479///         .sqrt()
480/// });
481/// ```
482///
483/// In this example, `dimensions` should be defined and valid for the vectors `a` and `b`.
484pub enum MetricFunction {
485    B1X8Metric(*mut std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
486    I8Metric(*mut std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
487    F16Metric(*mut std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
488    F32Metric(*mut std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
489    F64Metric(*mut std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
490}
491
492/// Approximate Nearest Neighbors search index for dense vectors.
493///
494/// The `Index` struct provides an abstraction over a dense vector space, allowing
495/// for efficient addition, search, and management of high-dimensional vectors.
496/// It supports various distance metrics and vector types through generic interfaces.
497///
498/// # Examples
499///
500/// Basic usage:
501///
502/// ```rust
503/// use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
504///
505/// let mut options = IndexOptions::default();
506/// options.dimensions = 4; // Set the number of dimensions for vectors
507/// options.metric = MetricKind::Cos; // Use cosine similarity for distance measurement
508/// options.quantization = ScalarKind::F32; // Use 32-bit floating point numbers
509///
510/// let index = Index::new(&options).expect("Failed to create index.");
511/// index.reserve(1000).expect("Failed to reserve capacity.");
512///
513/// // Add vectors to the index
514/// let vector1: Vec<f32> = vec![0.0, 1.0, 0.0, 1.0];
515/// let vector2: Vec<f32> = vec![1.0, 0.0, 1.0, 0.0];
516/// index.add(1, &vector1).expect("Failed to add vector1.");
517/// index.add(2, &vector2).expect("Failed to add vector2.");
518///
519/// // Search for the nearest neighbors to a query vector
520/// let query: Vec<f32> = vec![0.5, 0.5, 0.5, 0.5];
521/// let results = index.search(&query, 5).expect("Search failed.");
522/// for (key, distance) in results.keys.iter().zip(results.distances.iter()) {
523///     println!("Key: {}, Distance: {}", key, distance);
524/// }
525/// ```
526/// For more examples, including how to add vectors to the index and perform searches,
527/// refer to the individual method documentation.
528pub struct Index {
529    inner: cxx::UniquePtr<ffi::NativeIndex>,
530    metric_fn: Option<MetricFunction>,
531}
532
533unsafe impl Send for Index {}
534unsafe impl Sync for Index {}
535
536impl Drop for Index {
537    fn drop(&mut self) {
538        if let Some(metric) = &self.metric_fn {
539            match metric {
540                MetricFunction::B1X8Metric(pointer) => unsafe {
541                    drop(Box::from_raw(*pointer));
542                },
543                MetricFunction::I8Metric(pointer) => unsafe {
544                    drop(Box::from_raw(*pointer));
545                },
546                MetricFunction::F16Metric(pointer) => unsafe {
547                    drop(Box::from_raw(*pointer));
548                },
549                MetricFunction::F32Metric(pointer) => unsafe {
550                    drop(Box::from_raw(*pointer));
551                },
552                MetricFunction::F64Metric(pointer) => unsafe {
553                    drop(Box::from_raw(*pointer));
554                },
555            }
556        }
557    }
558}
559
560impl Default for ffi::IndexOptions {
561    fn default() -> Self {
562        Self {
563            dimensions: 256,
564            metric: MetricKind::Cos,
565            quantization: ScalarKind::BF16,
566            connectivity: 0,
567            expansion_add: 0,
568            expansion_search: 0,
569            multi: false,
570        }
571    }
572}
573
574impl Clone for ffi::IndexOptions {
575    fn clone(&self) -> Self {
576        ffi::IndexOptions {
577            dimensions: (self.dimensions),
578            metric: (self.metric),
579            quantization: (self.quantization),
580            connectivity: (self.connectivity),
581            expansion_add: (self.expansion_add),
582            expansion_search: (self.expansion_search),
583            multi: (self.multi),
584        }
585    }
586}
587
588/// The `VectorType` trait defines operations for managing and querying vectors
589/// in an index. It supports generic operations on vectors of different types,
590/// allowing for the addition, retrieval, and search of vectors within an index.
591pub trait VectorType {
592    /// Adds a vector to the index under the specified key.
593    ///
594    /// # Parameters
595    /// - `index`: A reference to the `Index` where the vector is to be added.
596    /// - `key`: The key under which the vector should be stored.
597    /// - `vector`: A slice representing the vector to be added.
598    ///
599    /// # Returns
600    /// - `Ok(())` if the vector was successfully added to the index.
601    /// - `Err(cxx::Exception)` if an error occurred during the operation.
602    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception>
603    where
604        Self: Sized;
605
606    /// Retrieves a vector from the index by its key.
607    ///
608    /// # Parameters
609    /// - `index`: A reference to the `Index` from which the vector is to be retrieved.
610    /// - `key`: The key of the vector to retrieve.
611    /// - `buffer`: A mutable slice where the retrieved vector will be stored. The size of the
612    ///   buffer determines the maximum number of elements that can be retrieved.
613    ///
614    /// # Returns
615    /// - `Ok(usize)` indicating the number of elements actually written into the `buffer`.
616    /// - `Err(cxx::Exception)` if an error occurred during the operation.
617    fn get(index: &Index, key: Key, buffer: &mut [Self]) -> Result<usize, cxx::Exception>
618    where
619        Self: Sized;
620
621    /// Performs a search in the index using the given query vector, returning
622    /// up to `count` closest matches.
623    ///
624    /// # Parameters
625    /// - `index`: A reference to the `Index` where the search is to be performed.
626    /// - `query`: A slice representing the query vector.
627    /// - `count`: The maximum number of matches to return.
628    ///
629    /// # Returns
630    /// - `Ok(ffi::Matches)` containing the matches found.
631    /// - `Err(cxx::Exception)` if an error occurred during the search operation.
632    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception>
633    where
634        Self: Sized;
635
636    /// Performs an exact (brute force) search in the index using the given query vector, returning
637    /// up to `count` closest matches. This search checks all vectors in the index, guaranteeing to find
638    /// the true nearest neighbors, but will be slower especially for large indices.
639    ///
640    /// # Parameters
641    /// - `index`: A reference to the `Index` where the search is to be performed.
642    /// - `query`: A slice representing the query vector.
643    /// - `count`: The maximum number of matches to return.
644    ///
645    /// # Returns
646    /// - `Ok(ffi::Matches)` containing the matches found.
647    /// - `Err(cxx::Exception)` if an error occurred during the search operation.
648    fn exact_search(
649        index: &Index,
650        query: &[Self],
651        count: usize,
652    ) -> Result<ffi::Matches, cxx::Exception>
653    where
654        Self: Sized;
655
656    /// Performs a filtered search in the index using a query vector and a custom
657    /// filter function, returning up to `count` matches that satisfy the filter.
658    ///
659    /// # Parameters
660    /// - `index`: A reference to the `Index` where the search is to be performed.
661    /// - `query`: A slice representing the query vector.
662    /// - `count`: The maximum number of matches to return.
663    /// - `filter`: A closure that takes a `Key` and returns `true` if the corresponding
664    ///   vector should be included in the search results, or `false` otherwise.
665    ///
666    /// # Returns
667    /// - `Ok(ffi::Matches)` containing the matches that satisfy the filter.
668    /// - `Err(cxx::Exception)` if an error occurred during the filtered search operation.
669    fn filtered_search<F>(
670        index: &Index,
671        query: &[Self],
672        count: usize,
673        filter: F,
674    ) -> Result<ffi::Matches, cxx::Exception>
675    where
676        Self: Sized,
677        F: Fn(Key) -> bool;
678
679    /// Changes the metric used for distance calculations within the index.
680    ///
681    /// # Parameters
682    /// - `index`: A mutable reference to the `Index` for which the metric is to be changed.
683    /// - `metric`: A boxed closure that defines the new metric for distance calculation. The
684    ///   closure must take two pointers to elements of type `Self` and return a `Distance`.
685    ///
686    /// # Returns
687    /// - `Ok(())` if the metric was successfully changed.
688    /// - `Err(cxx::Exception)` if an error occurred during the operation.
689    fn change_metric(
690        index: &mut Index,
691        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
692    ) -> Result<(), cxx::Exception>
693    where
694        Self: Sized;
695}
696
697impl VectorType for f32 {
698    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
699        index.inner.search_f32(query, count)
700    }
701
702    fn exact_search(
703        index: &Index,
704        query: &[Self],
705        count: usize,
706    ) -> Result<ffi::Matches, cxx::Exception> {
707        index.inner.exact_search_f32(query, count)
708    }
709
710    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
711        index.inner.get_f32(key, vector)
712    }
713
714    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
715        index.inner.add_f32(key, vector)
716    }
717
718    fn filtered_search<F>(
719        index: &Index,
720        query: &[Self],
721        count: usize,
722        filter: F,
723    ) -> Result<ffi::Matches, cxx::Exception>
724    where
725        Self: Sized,
726        F: Fn(Key) -> bool,
727    {
728        // Trampoline is the function that knows how to call the Rust closure.
729        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
730            let closure = closure_address as *const F;
731            unsafe { (*closure)(key) }
732        }
733
734        // Temporarily cast the closure to a raw pointer for passing.
735        let trampoline_fn: usize = trampoline::<F> as *const () as usize;
736        let closure_address: usize = &filter as *const F as usize;
737        index
738            .inner
739            .filtered_search_f32(query, count, trampoline_fn, closure_address)
740    }
741
742    fn change_metric(
743        index: &mut Index,
744        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
745    ) -> Result<(), cxx::Exception> {
746        // Store the metric function in the Index.
747        type MetricFn = Box<dyn Fn(*const f32, *const f32) -> Distance>;
748        index.metric_fn = Some(MetricFunction::F32Metric(Box::into_raw(Box::new(metric))));
749
750        // Trampoline is the function that knows how to call the Rust closure.
751        // The `first` is a pointer to the first vector, `second` is a pointer to the second vector,
752        // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function
753        // and the number of dimensions.
754        extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
755            let first_ptr = first as *const f32;
756            let second_ptr = second as *const f32;
757            let closure: *mut MetricFn = closure_address as *mut MetricFn;
758            unsafe { (*closure)(first_ptr, second_ptr) }
759        }
760
761        let trampoline_fn: usize = trampoline as *const () as usize;
762        let closure_address = match index.metric_fn {
763            Some(MetricFunction::F32Metric(metric)) => metric as *mut () as usize,
764            _ => panic!("Expected F32Metric"),
765        };
766        index.inner.change_metric(trampoline_fn, closure_address);
767
768        Ok(())
769    }
770}
771
772impl VectorType for i8 {
773    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
774        index.inner.search_i8(query, count)
775    }
776
777    fn exact_search(
778        index: &Index,
779        query: &[Self],
780        count: usize,
781    ) -> Result<ffi::Matches, cxx::Exception> {
782        index.inner.exact_search_i8(query, count)
783    }
784
785    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
786        index.inner.get_i8(key, vector)
787    }
788
789    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
790        index.inner.add_i8(key, vector)
791    }
792
793    fn filtered_search<F>(
794        index: &Index,
795        query: &[Self],
796        count: usize,
797        filter: F,
798    ) -> Result<ffi::Matches, cxx::Exception>
799    where
800        Self: Sized,
801        F: Fn(Key) -> bool,
802    {
803        // Trampoline is the function that knows how to call the Rust closure.
804        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
805            let closure = closure_address as *const F;
806            unsafe { (*closure)(key) }
807        }
808
809        // Temporarily cast the closure to a raw pointer for passing.
810        let trampoline_fn: usize = trampoline::<F> as *const () as usize;
811        let closure_address: usize = &filter as *const F as usize;
812        index
813            .inner
814            .filtered_search_i8(query, count, trampoline_fn, closure_address)
815    }
816    fn change_metric(
817        index: &mut Index,
818        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
819    ) -> Result<(), cxx::Exception> {
820        // Store the metric function in the Index.
821        type MetricFn = Box<dyn Fn(*const i8, *const i8) -> Distance>;
822        index.metric_fn = Some(MetricFunction::I8Metric(Box::into_raw(Box::new(metric))));
823
824        // Trampoline is the function that knows how to call the Rust closure.
825        // The `first` is a pointer to the first vector, `second` is a pointer to the second vector,
826        // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function
827        // and the number of dimensions.
828        extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
829            let first_ptr = first as *const i8;
830            let second_ptr = second as *const i8;
831            let closure: *mut MetricFn = closure_address as *mut MetricFn;
832            unsafe { (*closure)(first_ptr, second_ptr) }
833        }
834
835        let trampoline_fn: usize = trampoline as *const () as usize;
836        let closure_address = match index.metric_fn {
837            Some(MetricFunction::I8Metric(metric)) => metric as *mut () as usize,
838            _ => panic!("Expected I8Metric"),
839        };
840        index.inner.change_metric(trampoline_fn, closure_address);
841
842        Ok(())
843    }
844}
845
846impl VectorType for f64 {
847    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
848        index.inner.search_f64(query, count)
849    }
850
851    fn exact_search(
852        index: &Index,
853        query: &[Self],
854        count: usize,
855    ) -> Result<ffi::Matches, cxx::Exception> {
856        index.inner.exact_search_f64(query, count)
857    }
858
859    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
860        index.inner.get_f64(key, vector)
861    }
862
863    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
864        index.inner.add_f64(key, vector)
865    }
866
867    fn filtered_search<F>(
868        index: &Index,
869        query: &[Self],
870        count: usize,
871        filter: F,
872    ) -> Result<ffi::Matches, cxx::Exception>
873    where
874        Self: Sized,
875        F: Fn(Key) -> bool,
876    {
877        // Trampoline is the function that knows how to call the Rust closure.
878        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
879            let closure = closure_address as *const F;
880            unsafe { (*closure)(key) }
881        }
882
883        // Temporarily cast the closure to a raw pointer for passing.
884        let trampoline_fn: usize = trampoline::<F> as *const () as usize;
885        let closure_address: usize = &filter as *const F as usize;
886        index
887            .inner
888            .filtered_search_f64(query, count, trampoline_fn, closure_address)
889    }
890    fn change_metric(
891        index: &mut Index,
892        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
893    ) -> Result<(), cxx::Exception> {
894        // Store the metric function in the Index.
895        type MetricFn = Box<dyn Fn(*const f64, *const f64) -> Distance>;
896        index.metric_fn = Some(MetricFunction::F64Metric(Box::into_raw(Box::new(metric))));
897
898        // Trampoline is the function that knows how to call the Rust closure.
899        // The `first` is a pointer to the first vector, `second` is a pointer to the second vector,
900        // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function
901        // and the number of dimensions.
902        extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
903            let first_ptr = first as *const f64;
904            let second_ptr = second as *const f64;
905            let closure: *mut MetricFn = closure_address as *mut MetricFn;
906            unsafe { (*closure)(first_ptr, second_ptr) }
907        }
908
909        let trampoline_fn: usize = trampoline as *const () as usize;
910        let closure_address = match index.metric_fn {
911            Some(MetricFunction::F64Metric(metric)) => metric as *mut () as usize,
912            _ => panic!("Expected F64Metric"),
913        };
914        index.inner.change_metric(trampoline_fn, closure_address);
915
916        Ok(())
917    }
918}
919
920impl VectorType for f16 {
921    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
922        index.inner.search_f16(f16::to_i16s(query), count)
923    }
924
925    fn exact_search(
926        index: &Index,
927        query: &[Self],
928        count: usize,
929    ) -> Result<ffi::Matches, cxx::Exception> {
930        index.inner.exact_search_f16(f16::to_i16s(query), count)
931    }
932
933    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
934        index.inner.get_f16(key, f16::to_mut_i16s(vector))
935    }
936
937    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
938        index.inner.add_f16(key, f16::to_i16s(vector))
939    }
940
941    fn filtered_search<F>(
942        index: &Index,
943        query: &[Self],
944        count: usize,
945        filter: F,
946    ) -> Result<ffi::Matches, cxx::Exception>
947    where
948        Self: Sized,
949        F: Fn(Key) -> bool,
950    {
951        // Trampoline is the function that knows how to call the Rust closure.
952        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
953            let closure = closure_address as *const F;
954            unsafe { (*closure)(key) }
955        }
956
957        // Temporarily cast the closure to a raw pointer for passing.
958        let trampoline_fn: usize = trampoline::<F> as *const () as usize;
959        let closure_address: usize = &filter as *const F as usize;
960        index.inner.filtered_search_f16(
961            f16::to_i16s(query),
962            count,
963            trampoline_fn,
964            closure_address,
965        )
966    }
967
968    fn change_metric(
969        index: &mut Index,
970        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
971    ) -> Result<(), cxx::Exception> {
972        // Store the metric function in the Index.
973        type MetricFn = Box<dyn Fn(*const f16, *const f16) -> Distance>;
974        index.metric_fn = Some(MetricFunction::F16Metric(Box::into_raw(Box::new(metric))));
975
976        // Trampoline is the function that knows how to call the Rust closure.
977        // The `first` is a pointer to the first vector, `second` is a pointer to the second vector,
978        // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function
979        // and the number of dimensions.
980        extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
981            let first_ptr = first as *const f16;
982            let second_ptr = second as *const f16;
983            let closure: *mut MetricFn = closure_address as *mut MetricFn;
984            unsafe { (*closure)(first_ptr, second_ptr) }
985        }
986
987        let trampoline_fn: usize = trampoline as *const () as usize;
988        let closure_address = match index.metric_fn {
989            Some(MetricFunction::F16Metric(metric)) => metric as *mut () as usize,
990            _ => panic!("Expected F16Metric"),
991        };
992        index.inner.change_metric(trampoline_fn, closure_address);
993
994        Ok(())
995    }
996}
997
998impl VectorType for b1x8 {
999    fn search(index: &Index, query: &[Self], count: usize) -> Result<ffi::Matches, cxx::Exception> {
1000        index.inner.search_b1x8(b1x8::to_u8s(query), count)
1001    }
1002
1003    fn exact_search(
1004        index: &Index,
1005        query: &[Self],
1006        count: usize,
1007    ) -> Result<ffi::Matches, cxx::Exception> {
1008        index.inner.exact_search_b1x8(b1x8::to_u8s(query), count)
1009    }
1010
1011    fn get(index: &Index, key: Key, vector: &mut [Self]) -> Result<usize, cxx::Exception> {
1012        index.inner.get_b1x8(key, b1x8::to_mut_u8s(vector))
1013    }
1014
1015    fn add(index: &Index, key: Key, vector: &[Self]) -> Result<(), cxx::Exception> {
1016        index.inner.add_b1x8(key, b1x8::to_u8s(vector))
1017    }
1018
1019    fn filtered_search<F>(
1020        index: &Index,
1021        query: &[Self],
1022        count: usize,
1023        filter: F,
1024    ) -> Result<ffi::Matches, cxx::Exception>
1025    where
1026        Self: Sized,
1027        F: Fn(Key) -> bool,
1028    {
1029        // Trampoline is the function that knows how to call the Rust closure.
1030        extern "C" fn trampoline<F: Fn(u64) -> bool>(key: u64, closure_address: usize) -> bool {
1031            let closure = closure_address as *const F;
1032            unsafe { (*closure)(key) }
1033        }
1034
1035        // Temporarily cast the closure to a raw pointer for passing.
1036        let trampoline_fn: usize = trampoline::<F> as *const () as usize;
1037        let closure_address: usize = &filter as *const F as usize;
1038        index.inner.filtered_search_b1x8(
1039            b1x8::to_u8s(query),
1040            count,
1041            trampoline_fn,
1042            closure_address,
1043        )
1044    }
1045
1046    fn change_metric(
1047        index: &mut Index,
1048        metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
1049    ) -> Result<(), cxx::Exception> {
1050        // Store the metric function in the Index.
1051        type MetricFn = Box<dyn Fn(*const b1x8, *const b1x8) -> Distance>;
1052        index.metric_fn = Some(MetricFunction::B1X8Metric(Box::into_raw(Box::new(metric))));
1053
1054        // Trampoline is the function that knows how to call the Rust closure.
1055        // The `first` is a pointer to the first vector, `second` is a pointer to the second vector,
1056        // and `index_wrapper` is a pointer to the `index` itself, from which we can infer the metric function
1057        // and the number of dimensions.
1058        extern "C" fn trampoline(first: usize, second: usize, closure_address: usize) -> Distance {
1059            let first_ptr = first as *const b1x8;
1060            let second_ptr = second as *const b1x8;
1061            let closure: *mut MetricFn = closure_address as *mut MetricFn;
1062            unsafe { (*closure)(first_ptr, second_ptr) }
1063        }
1064
1065        let trampoline_fn: usize = trampoline as *const () as usize;
1066        let closure_address = match index.metric_fn {
1067            Some(MetricFunction::B1X8Metric(metric)) => metric as *mut () as usize,
1068            _ => panic!("Expected F1X8Metric"),
1069        };
1070        index.inner.change_metric(trampoline_fn, closure_address);
1071
1072        Ok(())
1073    }
1074}
1075
1076impl Index {
1077    pub fn new(options: &ffi::IndexOptions) -> Result<Self, cxx::Exception> {
1078        match ffi::new_native_index(options) {
1079            Ok(inner) => Result::Ok(Self {
1080                inner,
1081                metric_fn: None,
1082            }),
1083            Err(err) => Err(err),
1084        }
1085    }
1086
1087    /// Retrieves the expansion value used during index creation.
1088    pub fn expansion_add(self: &Index) -> usize {
1089        self.inner.expansion_add()
1090    }
1091
1092    /// Retrieves the expansion value used during search.
1093    pub fn expansion_search(self: &Index) -> usize {
1094        self.inner.expansion_search()
1095    }
1096
1097    /// Updates the expansion value used during index creation. Rarely used.
1098    pub fn change_expansion_add(self: &Index, n: usize) {
1099        self.inner.change_expansion_add(n)
1100    }
1101
1102    /// Updates the expansion value used during search operations.
1103    pub fn change_expansion_search(self: &Index, n: usize) {
1104        self.inner.change_expansion_search(n)
1105    }
1106
1107    /// Changes the metric kind used to calculate the distance between vectors.
1108    pub fn change_metric_kind(self: &Index, metric: ffi::MetricKind) {
1109        self.inner.change_metric_kind(metric)
1110    }
1111
1112    /// Overrides the metric function used to calculate the distance between vectors.
1113    pub fn change_metric<T: VectorType>(
1114        self: &mut Index,
1115        metric: std::boxed::Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
1116    ) {
1117        T::change_metric(self, metric).unwrap();
1118    }
1119
1120    /// Retrieves the hardware acceleration information.
1121    pub fn hardware_acceleration(&self) -> String {
1122        use core::ffi::CStr;
1123        unsafe {
1124            let c_str = CStr::from_ptr(self.inner.hardware_acceleration());
1125            c_str.to_string_lossy().into_owned()
1126        }
1127    }
1128
1129    /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query.
1130    ///
1131    /// # Arguments
1132    ///
1133    /// * `query` - A slice containing the query vector data.
1134    /// * `count` - The maximum number of neighbors to search for.
1135    ///
1136    /// # Returns
1137    ///
1138    /// A `Result` containing the matches found.
1139    pub fn search<T: VectorType>(
1140        self: &Index,
1141        query: &[T],
1142        count: usize,
1143    ) -> Result<ffi::Matches, cxx::Exception> {
1144        T::search(self, query, count)
1145    }
1146
1147    /// Performs exact (brute force) Nearest Neighbors Search for closest vectors to the provided query.
1148    /// This search checks all vectors in the index, guaranteeing to find the true nearest neighbors,
1149    /// but may be slower for large indices.
1150    ///
1151    /// # Arguments
1152    ///
1153    /// * `query` - A slice containing the query vector data.
1154    /// * `count` - The maximum number of neighbors to search for.
1155    ///
1156    /// # Returns
1157    ///
1158    /// A `Result` containing the matches found.
1159    pub fn exact_search<T: VectorType>(
1160        self: &Index,
1161        query: &[T],
1162        count: usize,
1163    ) -> Result<ffi::Matches, cxx::Exception> {
1164        T::exact_search(self, query, count)
1165    }
1166
1167    /// Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to the provided query
1168    /// satisfying a custom filter function.
1169    ///
1170    /// # Arguments
1171    ///
1172    /// * `query` - A slice containing the query vector data.
1173    /// * `count` - The maximum number of neighbors to search for.
1174    /// * `filter` - A closure that takes a `Key` and returns `true` if the corresponding vector should be included in the search results, or `false` otherwise.
1175    ///
1176    /// # Returns
1177    ///
1178    /// A `Result` containing the matches found.
1179    pub fn filtered_search<T: VectorType, F>(
1180        self: &Index,
1181        query: &[T],
1182        count: usize,
1183        filter: F,
1184    ) -> Result<ffi::Matches, cxx::Exception>
1185    where
1186        F: Fn(Key) -> bool,
1187    {
1188        T::filtered_search(self, query, count, filter)
1189    }
1190
1191    /// Adds a vector with a specified key to the index.
1192    ///
1193    /// # Arguments
1194    ///
1195    /// * `key` - The key associated with the vector.
1196    /// * `vector` - A slice containing the vector data.
1197    pub fn add<T: VectorType>(self: &Index, key: Key, vector: &[T]) -> Result<(), cxx::Exception> {
1198        T::add(self, key, vector)
1199    }
1200
1201    /// Extracts one or more vectors matching the specified key.
1202    /// The `vector` slice must be a multiple of the number of dimensions in the index.
1203    /// After the execution, return the number `X` of vectors found.
1204    /// The vector slice's first `X * dimensions` elements will be filled.
1205    ///
1206    /// If you are a novice user, consider `export`.
1207    ///
1208    /// # Arguments
1209    ///
1210    /// * `key` - The key associated with the vector.
1211    /// * `vector` - A slice containing the vector data.
1212    pub fn get<T: VectorType>(
1213        self: &Index,
1214        key: Key,
1215        vector: &mut [T],
1216    ) -> Result<usize, cxx::Exception> {
1217        T::get(self, key, vector)
1218    }
1219
1220    /// Extracts one or more vectors matching specified key into supplied resizable vector.
1221    /// The `vector` is resized to a multiple of the number of dimensions in the index.
1222    ///
1223    /// # Arguments
1224    ///
1225    /// * `key` - The key associated with the vector.
1226    /// * `vector` - A mutable vector containing the vector data.
1227    pub fn export<T: VectorType + Default + Clone>(
1228        self: &Index,
1229        key: Key,
1230        vector: &mut Vec<T>,
1231    ) -> Result<usize, cxx::Exception> {
1232        let dim = self.dimensions();
1233        let max_matches = self.count(key);
1234        vector.resize(dim * max_matches, T::default());
1235        let matches = T::get(self, key, &mut vector[..])?;
1236        vector.resize(dim * matches, T::default());
1237        Ok(matches)
1238    }
1239
1240    /// Reserves memory for a specified number of incoming vectors.
1241    ///
1242    /// # Arguments
1243    ///
1244    /// * `capacity` - The desired total capacity, including the current size.
1245    pub fn reserve(self: &Index, capacity: usize) -> Result<(), cxx::Exception> {
1246        self.inner.reserve(capacity)
1247    }
1248
1249    /// Reserves memory for a specified number of incoming vectors & active threads.
1250    ///
1251    /// # Arguments
1252    ///
1253    /// * `capacity` - The desired total capacity, including the current size.
1254    /// * `threads` - The number of threads to use for the operation.
1255    pub fn reserve_capacity_and_threads(
1256        self: &Index,
1257        capacity: usize,
1258        threads: usize,
1259    ) -> Result<(), cxx::Exception> {
1260        self.inner.reserve_capacity_and_threads(capacity, threads)
1261    }
1262
1263    /// Retrieves the number of dimensions in the vectors indexed.
1264    pub fn dimensions(self: &Index) -> usize {
1265        self.inner.dimensions()
1266    }
1267
1268    /// Retrieves the connectivity parameter that limits connections-per-node in the graph.
1269    pub fn connectivity(self: &Index) -> usize {
1270        self.inner.connectivity()
1271    }
1272
1273    /// Retrieves the current number of vectors in the index.
1274    pub fn size(self: &Index) -> usize {
1275        self.inner.size()
1276    }
1277
1278    /// Retrieves the total capacity of the index, including reserved space.
1279    pub fn capacity(self: &Index) -> usize {
1280        self.inner.capacity()
1281    }
1282
1283    /// Reports expected file size after serialization.
1284    pub fn serialized_length(self: &Index) -> usize {
1285        self.inner.serialized_length()
1286    }
1287
1288    /// Removes the vector associated with the given key from the index.
1289    ///
1290    /// # Arguments
1291    ///
1292    /// * `key` - The key of the vector to be removed.
1293    ///
1294    /// # Returns
1295    ///
1296    /// `true` if the vector is successfully removed, `false` otherwise.
1297    pub fn remove(self: &Index, key: Key) -> Result<usize, cxx::Exception> {
1298        self.inner.remove(key)
1299    }
1300
1301    /// Renames the vector under a specific key.
1302    ///
1303    /// # Arguments
1304    ///
1305    /// * `from` - The key of the vector to be renamed.
1306    /// * `to` - The new name.
1307    ///
1308    /// # Returns
1309    ///
1310    /// `true` if the vector is renamed, `false` otherwise.
1311    pub fn rename(self: &Index, from: Key, to: Key) -> Result<usize, cxx::Exception> {
1312        self.inner.rename(from, to)
1313    }
1314
1315    /// Checks if the index contains a vector with a specified key.
1316    ///
1317    /// # Arguments
1318    ///
1319    /// * `key` - The key to be checked.
1320    ///
1321    /// # Returns
1322    ///
1323    /// `true` if the index contains the vector with the given key, `false` otherwise.
1324    pub fn contains(self: &Index, key: Key) -> bool {
1325        self.inner.contains(key)
1326    }
1327
1328    /// Count the count of vectors with the same specified key.
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `key` - The key to be checked.
1333    ///
1334    /// # Returns
1335    ///
1336    /// Number of vectors found.
1337    pub fn count(self: &Index, key: Key) -> usize {
1338        self.inner.count(key)
1339    }
1340
1341    /// Saves the index to a specified file.
1342    ///
1343    /// # Arguments
1344    ///
1345    /// * `path` - The file path where the index will be saved.
1346    pub fn save(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1347        self.inner.save(path)
1348    }
1349
1350    /// Loads the index from a specified file.
1351    ///
1352    /// # Arguments
1353    ///
1354    /// * `path` - The file path from where the index will be loaded.
1355    pub fn load(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1356        self.inner.load(path)
1357    }
1358
1359    /// Creates a view of the index from a file without loading it into memory.
1360    ///
1361    /// # Arguments
1362    ///
1363    /// * `path` - The file path from where the view will be created.
1364    pub fn view(self: &Index, path: &str) -> Result<(), cxx::Exception> {
1365        self.inner.view(path)
1366    }
1367
1368    /// Erases all members from the index, closes files, and returns RAM to OS.
1369    pub fn reset(self: &Index) -> Result<(), cxx::Exception> {
1370        self.inner.reset()
1371    }
1372
1373    /// A relatively accurate lower bound on the amount of memory consumed by the system.
1374    /// In practice, its error will be below 10%.
1375    pub fn memory_usage(self: &Index) -> usize {
1376        self.inner.memory_usage()
1377    }
1378
1379    /// Saves the index to a specified file.
1380    ///
1381    /// # Arguments
1382    ///
1383    /// * `buffer` - The buffer where the index will be saved.
1384    pub fn save_to_buffer(self: &Index, buffer: &mut [u8]) -> Result<(), cxx::Exception> {
1385        self.inner.save_to_buffer(buffer)
1386    }
1387
1388    /// Loads the index from a specified file.
1389    ///
1390    /// # Arguments
1391    ///
1392    /// * `buffer` - The buffer from where the index will be loaded.
1393    pub fn load_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
1394        self.inner.load_from_buffer(buffer)
1395    }
1396
1397    /// Creates a view of the index from a file without loading it into memory.
1398    ///
1399    /// # Arguments
1400    ///
1401    /// * `buffer` - The buffer from where the view will be created.
1402    ///
1403    /// # Safety
1404    ///
1405    /// This function is marked as `unsafe` because it stores a pointer to the input buffer.
1406    /// The caller must ensure that the buffer outlives the index and is not dropped
1407    /// or modified for the duration of the index's use. Dereferencing a pointer to a
1408    /// temporary buffer after it has been dropped can lead to undefined behavior,
1409    /// which violates Rust's memory safety guarantees.
1410    ///
1411    /// Example of misuse:
1412    ///
1413    /// ```rust,ignore
1414    /// let index: usearch::Index = usearch::new_index(&usearch::IndexOptions::default()).unwrap();
1415    ///
1416    /// let temporary = vec![0u8; 100];
1417    /// index.view_from_buffer(&temporary);
1418    /// std::mem::drop(temporary);
1419    ///
1420    /// let query = vec![0.0; 256];
1421    /// let results = index.search(&query, 5).unwrap();
1422    /// ```
1423    ///
1424    /// The above example would result in use-after-free and undefined behavior.
1425    pub unsafe fn view_from_buffer(self: &Index, buffer: &[u8]) -> Result<(), cxx::Exception> {
1426        self.inner.view_from_buffer(buffer)
1427    }
1428}
1429
1430pub fn new_index(options: &ffi::IndexOptions) -> Result<Index, cxx::Exception> {
1431    Index::new(options)
1432}
1433
1434#[cfg(test)]
1435mod tests {
1436    use crate::ffi::IndexOptions;
1437    use crate::ffi::MetricKind;
1438    use crate::ffi::ScalarKind;
1439
1440    use crate::b1x8;
1441    use crate::new_index;
1442    use crate::Index;
1443    use crate::Key;
1444
1445    use std::env;
1446
1447    #[test]
1448    fn print_specs() {
1449        println!("--------------------------------------------------");
1450        println!("OS: {}", env::consts::OS);
1451        println!(
1452            "Rust version: {}",
1453            env::var("RUST_VERSION").unwrap_or_else(|_| "unknown".into())
1454        );
1455
1456        // Create indexes with different configurations
1457        let f64_index = Index::new(&IndexOptions {
1458            dimensions: 256,
1459            metric: MetricKind::Cos,
1460            quantization: ScalarKind::F64,
1461            ..Default::default()
1462        })
1463        .unwrap();
1464
1465        let f32_index = Index::new(&IndexOptions {
1466            dimensions: 256,
1467            metric: MetricKind::Cos,
1468            quantization: ScalarKind::F32,
1469            ..Default::default()
1470        })
1471        .unwrap();
1472
1473        let f16_index = Index::new(&IndexOptions {
1474            dimensions: 256,
1475            metric: MetricKind::Cos,
1476            quantization: ScalarKind::F16,
1477            ..Default::default()
1478        })
1479        .unwrap();
1480
1481        let i8_index = Index::new(&IndexOptions {
1482            dimensions: 256,
1483            metric: MetricKind::Cos,
1484            quantization: ScalarKind::I8,
1485            ..Default::default()
1486        })
1487        .unwrap();
1488
1489        let b1_index = Index::new(&IndexOptions {
1490            dimensions: 256,
1491            metric: MetricKind::Hamming,
1492            quantization: ScalarKind::B1,
1493            ..Default::default()
1494        })
1495        .unwrap();
1496
1497        println!(
1498            "f64 hardware acceleration: {}",
1499            f64_index.hardware_acceleration()
1500        );
1501        println!(
1502            "f32 hardware acceleration: {}",
1503            f32_index.hardware_acceleration()
1504        );
1505        println!(
1506            "f16 hardware acceleration: {}",
1507            f16_index.hardware_acceleration()
1508        );
1509        println!(
1510            "i8 hardware acceleration: {}",
1511            i8_index.hardware_acceleration()
1512        );
1513        println!(
1514            "b1 hardware acceleration: {}",
1515            b1_index.hardware_acceleration()
1516        );
1517        println!("--------------------------------------------------");
1518    }
1519
1520    #[test]
1521    fn test_add_get_vector() {
1522        let options = IndexOptions {
1523            dimensions: 5,
1524            quantization: ScalarKind::F32,
1525            ..Default::default()
1526        };
1527        let index = Index::new(&options).unwrap();
1528        assert!(index.reserve(10).is_ok());
1529
1530        let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
1531        let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
1532        let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1];
1533        let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0];
1534        assert!(index.add(1, &first).is_ok());
1535        assert!(index.add(2, &second).is_ok());
1536        assert!(index.add(3, &too_long).is_err());
1537        assert!(index.add(4, &too_short).is_err());
1538        assert_eq!(index.size(), 2);
1539
1540        // Test using Vec<T>
1541        let mut found_vec: Vec<f32> = Vec::new();
1542        assert_eq!(index.export(1, &mut found_vec).unwrap(), 1);
1543        assert_eq!(found_vec.len(), 5);
1544        assert_eq!(found_vec, first.to_vec());
1545
1546        // Test using slice
1547        let mut found_slice = [0.0f32; 5];
1548        assert_eq!(index.get(1, &mut found_slice).unwrap(), 1);
1549        assert_eq!(found_slice, first);
1550
1551        // Create a slice with incorrect size
1552        let mut found = [0.0f32; 6]; // This isn't a multiple of the index's dimensions.
1553        let result = index.get(1, &mut found);
1554        assert!(result.is_err());
1555    }
1556    #[test]
1557    fn test_search_vector() {
1558        let options = IndexOptions {
1559            dimensions: 5,
1560            quantization: ScalarKind::F32,
1561            ..Default::default()
1562        };
1563        let index = Index::new(&options).unwrap();
1564        assert!(index.reserve(10).is_ok());
1565
1566        let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
1567        let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
1568        let too_long: [f32; 6] = [0.3, 0.2, 0.4, 0.0, 0.1, 0.1];
1569        let too_short: [f32; 4] = [0.3, 0.2, 0.4, 0.0];
1570        assert!(index.add(1, &first).is_ok());
1571        assert!(index.add(2, &second).is_ok());
1572        assert_eq!(index.size(), 2);
1573        //assert!(index.add(3, &too_long).is_err());
1574        //assert!(index.add(4, &too_short).is_err());
1575
1576        assert!(index.search(&too_long, 1).is_err());
1577        assert!(index.search(&too_short, 1).is_err());
1578    }
1579
1580    #[test]
1581    fn test_add_remove_vector() {
1582        let options = IndexOptions {
1583            dimensions: 4,
1584            metric: MetricKind::IP,
1585            quantization: ScalarKind::F64,
1586            connectivity: 10,
1587            expansion_add: 128,
1588            expansion_search: 3,
1589            ..Default::default()
1590        };
1591        let index = Index::new(&options).unwrap();
1592        assert!(index.reserve(10).is_ok());
1593        assert!(index.capacity() >= 10);
1594
1595        let first: [f32; 4] = [0.2, 0.1, 0.2, 0.1];
1596        let second: [f32; 4] = [0.3, 0.2, 0.4, 0.0];
1597
1598        // IDs until 18446744073709551615 should be fine:
1599        let id1 = 483367403120493160;
1600        let id2 = 483367403120558696;
1601        let id3 = 483367403120624232;
1602        let id4 = 483367403120624233;
1603
1604        assert!(index.add(id1, &first).is_ok());
1605        let mut found_slice = [0.0f32; 4];
1606        assert_eq!(index.get(id1, &mut found_slice).unwrap(), 1);
1607        assert!(index.remove(id1).is_ok());
1608
1609        assert!(index.add(id2, &second).is_ok());
1610        let mut found_slice = [0.0f32; 4];
1611        assert_eq!(index.get(id2, &mut found_slice).unwrap(), 1);
1612        assert!(index.remove(id2).is_ok());
1613
1614        assert!(index.add(id3, &second).is_ok());
1615        let mut found_slice = [0.0f32; 4];
1616        assert_eq!(index.get(id3, &mut found_slice).unwrap(), 1);
1617        assert!(index.remove(id3).is_ok());
1618
1619        assert!(index.add(id4, &second).is_ok());
1620        let mut found_slice = [0.0f32; 4];
1621        assert_eq!(index.get(id4, &mut found_slice).unwrap(), 1);
1622        assert!(index.remove(id4).is_ok());
1623
1624        assert_eq!(index.size(), 0);
1625    }
1626
1627    #[test]
1628    fn integration() {
1629        let mut options = IndexOptions {
1630            dimensions: 5,
1631            ..Default::default()
1632        };
1633
1634        let index = Index::new(&options).unwrap();
1635
1636        assert!(index.expansion_add() > 0);
1637        assert!(index.expansion_search() > 0);
1638
1639        assert!(index.reserve(10).is_ok());
1640        assert!(index.capacity() >= 10);
1641        assert!(index.connectivity() != 0);
1642        assert_eq!(index.dimensions(), 5);
1643        assert_eq!(index.size(), 0);
1644
1645        let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
1646        let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
1647
1648        println!("--------------------------------------------------");
1649        println!(
1650            "before add, memory_usage: {} \
1651            cap: {} \
1652            ",
1653            index.memory_usage(),
1654            index.capacity(),
1655        );
1656        index.change_expansion_add(10);
1657        assert_eq!(index.expansion_add(), 10);
1658        assert!(index.add(42, &first).is_ok());
1659        index.change_expansion_add(12);
1660        assert_eq!(index.expansion_add(), 12);
1661        assert!(index.add(43, &second).is_ok());
1662        assert_eq!(index.size(), 2);
1663        println!(
1664            "after add, memory_usage: {} \
1665            cap: {} \
1666            ",
1667            index.memory_usage(),
1668            index.capacity(),
1669        );
1670
1671        index.change_expansion_search(10);
1672        assert_eq!(index.expansion_search(), 10);
1673        // Read back the tags
1674        let results = index.search(&first, 10).unwrap();
1675        println!("{:?}", results);
1676        assert_eq!(results.keys.len(), 2);
1677
1678        index.change_expansion_search(12);
1679        assert_eq!(index.expansion_search(), 12);
1680        let results = index.search(&first, 10).unwrap();
1681        println!("{:?}", results);
1682        assert_eq!(results.keys.len(), 2);
1683        println!("--------------------------------------------------");
1684
1685        // Validate serialization
1686        assert!(index.save("index.rust.usearch").is_ok());
1687        assert!(index.load("index.rust.usearch").is_ok());
1688        assert!(index.view("index.rust.usearch").is_ok());
1689
1690        // Make sure every function is called at least once
1691        assert!(new_index(&options).is_ok());
1692        options.metric = MetricKind::L2sq;
1693        assert!(new_index(&options).is_ok());
1694        options.metric = MetricKind::Cos;
1695        assert!(new_index(&options).is_ok());
1696        options.metric = MetricKind::Haversine;
1697        options.quantization = ScalarKind::F32;
1698        options.dimensions = 2;
1699        assert!(new_index(&options).is_ok());
1700
1701        let mut serialization_buffer = vec![0; index.serialized_length()];
1702        assert!(index.save_to_buffer(&mut serialization_buffer).is_ok());
1703
1704        let deserialized_index = new_index(&options).unwrap();
1705        assert!(deserialized_index
1706            .load_from_buffer(&serialization_buffer)
1707            .is_ok());
1708        assert_eq!(index.size(), deserialized_index.size());
1709
1710        // reset
1711        assert_ne!(index.memory_usage(), 0);
1712        assert!(index.reset().is_ok());
1713        assert_eq!(index.size(), 0);
1714        assert_eq!(index.memory_usage(), 0);
1715
1716        // clone
1717        options.metric = MetricKind::Haversine;
1718        let mut opts = options.clone();
1719        assert_eq!(opts.metric, options.metric);
1720        assert_eq!(opts.quantization, options.quantization);
1721        assert_eq!(opts, options);
1722        opts.metric = MetricKind::Cos;
1723        assert_ne!(opts.metric, options.metric);
1724        assert!(new_index(&opts).is_ok());
1725    }
1726
1727    #[test]
1728    fn test_search_with_stateless_filter() {
1729        let options = IndexOptions {
1730            dimensions: 5,
1731            ..Default::default()
1732        };
1733        let index = Index::new(&options).unwrap();
1734        index.reserve(10).unwrap();
1735
1736        // Adding sample vectors to the index
1737        let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
1738        let second: [f32; 5] = [0.3, 0.2, 0.4, 0.0, 0.1];
1739        index.add(1, &first).unwrap();
1740        index.add(2, &second).unwrap();
1741
1742        // Stateless filter: checks if the key is odd
1743        let is_odd = |key: Key| key % 2 == 1;
1744        let query = vec![0.2, 0.1, 0.2, 0.1, 0.3]; // Example query vector
1745        let results = index.filtered_search(&query, 10, is_odd).unwrap();
1746        assert!(
1747            results.keys.iter().all(|&key| key % 2 == 1),
1748            "All keys must be odd"
1749        );
1750    }
1751
1752    #[test]
1753    fn test_search_with_stateful_filter() {
1754        use std::collections::HashSet;
1755
1756        let options = IndexOptions {
1757            dimensions: 5,
1758            ..Default::default()
1759        };
1760        let index = Index::new(&options).unwrap();
1761        index.reserve(10).unwrap();
1762
1763        // Adding sample vectors to the index
1764        let first: [f32; 5] = [0.2, 0.1, 0.2, 0.1, 0.3];
1765        index.add(1, &first).unwrap();
1766        index.add(2, &first).unwrap();
1767
1768        let allowed_keys = vec![1, 2, 3].into_iter().collect::<HashSet<Key>>();
1769        // Clone `allowed_keys` for use in the closure
1770        let filter_keys = allowed_keys.clone();
1771        let stateful_filter = move |key: Key| filter_keys.contains(&key);
1772
1773        let query = vec![0.2, 0.1, 0.2, 0.1, 0.3]; // Example query vector
1774        let results = index.filtered_search(&query, 10, stateful_filter).unwrap();
1775
1776        // Use the original `allowed_keys` for assertion
1777        assert!(
1778            results.keys.iter().all(|&key| allowed_keys.contains(&key)),
1779            "All keys must be in the allowed set"
1780        );
1781    }
1782
1783    #[test]
1784    fn test_zero_distances() {
1785        let options = IndexOptions {
1786            dimensions: 8,
1787            metric: MetricKind::L2sq,
1788            quantization: ScalarKind::F16,
1789            ..Default::default()
1790        };
1791
1792        let index = new_index(&options).unwrap();
1793        index.reserve(10).unwrap();
1794        index
1795            .add(0, &[0.4, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0])
1796            .unwrap();
1797        index
1798            .add(1, &[0.5, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0])
1799            .unwrap();
1800        index
1801            .add(2, &[0.6, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0])
1802            .unwrap();
1803
1804        // Make sure non of the distances are zeros
1805        let matches = index
1806            .search(&[0.05, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0], 2)
1807            .unwrap();
1808        for distance in matches.distances.iter() {
1809            assert_ne!(*distance, 0.0);
1810        }
1811    }
1812
1813    #[test]
1814    fn test_exact_search() {
1815        use std::collections::HashSet;
1816
1817        // Create an index with many vectors
1818        let options = IndexOptions {
1819            dimensions: 4,
1820            metric: MetricKind::L2sq,
1821            quantization: ScalarKind::F32,
1822            ..Default::default()
1823        };
1824        let index = new_index(&options).unwrap();
1825        index.reserve(100).unwrap();
1826        // Add 100 vectors to the index
1827        for i in 0..100 {
1828            let vec = vec![
1829                i as f32 * 0.1,
1830                (i as f32 * 0.05).sin(),
1831                (i as f32 * 0.05).cos(),
1832                0.0,
1833            ];
1834            index.add(i, &vec).unwrap();
1835        }
1836        // Query vector
1837        let query = vec![4.5, 0.0, 1.0, 0.0];
1838        // Compare approximate and exact search results
1839        let approx_matches = index.search(&query, 10).unwrap();
1840        let exact_matches = index.exact_search(&query, 10).unwrap();
1841        // Collect the keys from both result sets
1842        let approx_keys: HashSet<Key> = approx_matches.keys.iter().cloned().collect();
1843        let exact_keys: HashSet<Key> = exact_matches.keys.iter().cloned().collect();
1844        // Check that both methods return 10 results
1845        assert_eq!(approx_matches.keys.len(), 10);
1846        assert_eq!(exact_matches.keys.len(), 10);
1847
1848        // The exact search should find the true nearest neighbors
1849        // Verify that the minimum distance in exact results is <= minimum distance in approximate results
1850        assert!(exact_matches.distances[0] <= approx_matches.distances[0]);
1851        // The nearest neighbor according to exact search might be different from approximate search
1852        println!(
1853            "Approximate search first match: key={}, distance={}",
1854            approx_matches.keys[0], approx_matches.distances[0]
1855        );
1856        println!(
1857            "Exact search first match: key={}, distance={}",
1858            exact_matches.keys[0], exact_matches.distances[0]
1859        );
1860        // Results from both should be mostly similar, but may differ due to approximation
1861        let intersection: HashSet<_> = approx_keys.intersection(&exact_keys).collect();
1862        println!(
1863            "Number of common results between approximate and exact search: {}",
1864            intersection.len()
1865        );
1866    }
1867
1868    #[test]
1869    fn test_change_distance_function() {
1870        let options = IndexOptions {
1871            dimensions: 2, // Adjusted for simplicity in creating test vectors
1872            ..Default::default()
1873        };
1874        let mut index = Index::new(&options).unwrap();
1875        index.reserve(10).unwrap();
1876
1877        // Adding a simple vector to test the distance function changes
1878        let vector: [f32; 2] = [1.0, 0.0];
1879        index.add(1, &vector).unwrap();
1880
1881        // Stateful distance function with adjustments for pointer to slice conversion
1882        let first_factor: f32 = 2.0;
1883        let second_factor: f32 = 0.7;
1884        let stateful_distance = Box::new(move |a: *const f32, b: *const f32| unsafe {
1885            let a_slice = std::slice::from_raw_parts(a, 2);
1886            let b_slice = std::slice::from_raw_parts(b, 2);
1887            (a_slice[0] - b_slice[0]).abs() * first_factor
1888                + (a_slice[1] - b_slice[1]).abs() * second_factor
1889        });
1890        index.change_metric(stateful_distance);
1891
1892        let another_vector: [f32; 2] = [0.0, 1.0];
1893        index.add(2, &another_vector).unwrap();
1894    }
1895
1896    #[test]
1897    fn test_binary_vectors_and_hamming_distance() {
1898        let index = Index::new(&IndexOptions {
1899            dimensions: 8,
1900            metric: MetricKind::Hamming,
1901            quantization: ScalarKind::B1,
1902            ..Default::default()
1903        })
1904        .unwrap();
1905
1906        // Binary vectors represented as `b1x8` slices
1907        let vector42: Vec<b1x8> = vec![b1x8(0b00001111)];
1908        let vector43: Vec<b1x8> = vec![b1x8(0b11110000)];
1909        let query: Vec<b1x8> = vec![b1x8(0b01111000)];
1910
1911        // Adding binary vectors to the index
1912        index.reserve(10).unwrap();
1913        index.add(42, &vector42).unwrap();
1914        index.add(43, &vector43).unwrap();
1915
1916        let results = index.search(&query, 5).unwrap();
1917
1918        // Validate the search results based on Hamming distance
1919        assert_eq!(results.keys.len(), 2);
1920        assert_eq!(results.keys[0], 43);
1921        assert_eq!(results.distances[0], 2.0);
1922        assert_eq!(results.keys[1], 42);
1923        assert_eq!(results.distances[1], 6.0);
1924    }
1925
1926    #[test]
1927    fn test_concurrency() {
1928        use fork_union as fu;
1929        use rand::{Rng, SeedableRng};
1930        use rand_chacha::ChaCha8Rng;
1931        use rand_distr::Uniform;
1932        use std::sync::Arc;
1933
1934        const DIMENSIONS: usize = 128;
1935        const VECTOR_COUNT: usize = 1000;
1936        const THREAD_COUNT: usize = 4;
1937
1938        let options = IndexOptions {
1939            dimensions: DIMENSIONS,
1940            metric: MetricKind::Cos,
1941            quantization: ScalarKind::F32,
1942            ..Default::default()
1943        };
1944
1945        let index = Arc::new(Index::new(&options).unwrap());
1946        index
1947            .reserve_capacity_and_threads(VECTOR_COUNT, THREAD_COUNT)
1948            .unwrap();
1949
1950        // Generate deterministic vectors using rand crate for reproducible testing
1951        let seed = 42; // Fixed seed for reproducibility
1952        let mut rng = ChaCha8Rng::seed_from_u64(seed);
1953        let uniform = Uniform::new(-1.0f32, 1.0f32).unwrap();
1954
1955        // Store reference vectors for validation
1956        let mut reference_vectors: Vec<[f32; DIMENSIONS]> = Vec::with_capacity(VECTOR_COUNT);
1957        for _ in 0..VECTOR_COUNT {
1958            let mut vector = [0.0f32; DIMENSIONS];
1959            // Fill with random values in [-1, 1]
1960            for item in vector.iter_mut().take(DIMENSIONS) {
1961                *item = rng.sample(uniform);
1962            }
1963            reference_vectors.push(vector);
1964        }
1965
1966        let mut pool = fu::spawn(THREAD_COUNT);
1967
1968        // Concurrent indexing
1969        pool.for_n(VECTOR_COUNT, |prong| {
1970            let index_clone = Arc::clone(&index);
1971            let i = prong.task_index;
1972            let vector = reference_vectors[i];
1973            index_clone.add(i as u64, &vector).unwrap();
1974        });
1975
1976        assert_eq!(index.size(), VECTOR_COUNT);
1977
1978        // Concurrent retrieval and validation
1979        let mut pool = fu::spawn(THREAD_COUNT);
1980        let validation_results = Arc::new(std::sync::Mutex::new(Vec::new()));
1981
1982        pool.for_n(VECTOR_COUNT, |prong| {
1983            let index_clone = Arc::clone(&index);
1984            let results_clone = Arc::clone(&validation_results);
1985            let i = prong.task_index;
1986            let expected_vector = &reference_vectors[i];
1987
1988            let mut retrieved_vector = [0.0f32; DIMENSIONS];
1989            let count = index_clone.get(i as u64, &mut retrieved_vector).unwrap();
1990            assert_eq!(count, 1);
1991
1992            // Validate retrieved vector matches expected
1993            let matches = retrieved_vector
1994                .iter()
1995                .zip(expected_vector.iter())
1996                .all(|(a, b)| (a - b).abs() < 1e-6);
1997
1998            let mut results = results_clone.lock().unwrap();
1999            results.push(matches);
2000        });
2001
2002        let validation_results = validation_results.lock().unwrap();
2003        assert_eq!(validation_results.len(), VECTOR_COUNT);
2004        assert!(
2005            validation_results.iter().all(|&x| x),
2006            "All retrieved vectors should match the original ones"
2007        );
2008
2009        // Concurrent search testing
2010        let mut pool = fu::spawn(THREAD_COUNT);
2011        let search_results = Arc::new(std::sync::Mutex::new(Vec::new()));
2012
2013        pool.for_n(100, |prong| {
2014            // Test 100 searches
2015            let index_clone = Arc::clone(&index);
2016            let results_clone = Arc::clone(&search_results);
2017            let query_idx = prong.task_index % VECTOR_COUNT;
2018            let query_vector = &reference_vectors[query_idx];
2019
2020            let matches = index_clone.exact_search(query_vector, 10).unwrap();
2021
2022            // The first result should be the exact match with distance ~0
2023            let exact_match_found = !matches.keys.is_empty()
2024                && matches.keys[0] == query_idx as u64
2025                && matches.distances[0] < 1e-6;
2026
2027            let mut results = results_clone.lock().unwrap();
2028            results.push(exact_match_found);
2029        });
2030
2031        let search_results = search_results.lock().unwrap();
2032        assert_eq!(search_results.len(), 100);
2033        assert!(
2034            search_results.iter().all(|&x| x),
2035            "All searches should find exact matches"
2036        );
2037    }
2038}