use std::fs::File;
use std::num::NonZeroUsize;
use std::path::{Path, PathBuf};
use std::simd::{f32x8, num::SimdFloat, Simd};
type F32x16 = Simd<f32, 16>;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use anyhow::{bail, Context, Result};
use lru::LruCache;
use once_cell::sync::Lazy;
use rayon::prelude::*;
use tempfile::NamedTempFile;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
/// Batch size at which `VectorIndex::add_batch` / `add_batch_flat` switch from a sequential
/// insertion loop to a rayon parallel iterator; batches smaller than this are inserted one
/// vector at a time on the calling thread.
const PARALLEL_BATCH_THRESHOLD: usize = 100;
/// Distance metric used by a [`VectorIndex`].
///
/// Each variant maps onto a usearch `MetricKind` (see `Metric::to_usearch`);
/// `InnerProduct` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Metric {
/// Inner-product distance (usearch `MetricKind::IP`). Default.
#[default]
InnerProduct,
/// Cosine distance (usearch `MetricKind::Cos`).
Cosine,
/// Squared Euclidean distance (usearch `MetricKind::L2sq`).
L2Squared,
}
impl Metric {
#[must_use]
fn to_usearch(self) -> MetricKind {
match self {
Self::InnerProduct => MetricKind::IP,
Self::Cosine => MetricKind::Cos,
Self::L2Squared => MetricKind::L2sq,
}
}
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::InnerProduct => "inner_product",
Self::Cosine => "cosine",
Self::L2Squared => "l2_squared",
}
}
}
impl std::fmt::Display for Metric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
impl std::str::FromStr for Metric {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"ip" | "inner_product" | "innerproduct" => Ok(Self::InnerProduct),
"cos" | "cosine" => Ok(Self::Cosine),
"l2" | "l2sq" | "l2_squared" | "euclidean" => Ok(Self::L2Squared),
_ => bail!(
"Unknown metric: {}. Valid options: inner_product, cosine, l2_squared",
s
),
}
}
}
/// Scalar type usearch uses to store vector components inside the index.
///
/// Maps onto a usearch `ScalarKind` via `Quantization::to_usearch`; `F32` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Quantization {
/// 32-bit float, 4 bytes per element. Default.
#[default]
F32,
/// 16-bit IEEE half float, 2 bytes per element.
F16,
/// 16-bit bfloat, 2 bytes per element.
BF16,
/// 8-bit integer, 1 byte per element.
I8,
}
impl Quantization {
#[must_use]
fn to_usearch(self) -> ScalarKind {
match self {
Self::F32 => ScalarKind::F32,
Self::F16 => ScalarKind::F16,
Self::BF16 => ScalarKind::BF16,
Self::I8 => ScalarKind::I8,
}
}
#[must_use]
pub fn bytes_per_element(self) -> usize {
match self {
Self::F32 => 4,
Self::F16 | Self::BF16 => 2,
Self::I8 => 1,
}
}
}
/// Construction parameters for a [`VectorIndex`].
///
/// The fields are copied one-to-one into usearch `IndexOptions` by
/// `IndexConfig::to_usearch_options`.
#[derive(Debug, Clone)]
pub struct IndexConfig {
/// Number of components in every vector.
pub dimensions: usize,
/// Distance metric.
pub metric: Metric,
/// Scalar type used to store vector components.
pub quantization: Quantization,
/// Passed through as usearch `connectivity`. `IndexConfig::new` sets 0, which presumably
/// means "use the usearch default" (TODO confirm against the usearch docs).
pub connectivity: usize,
/// Passed through as usearch `expansion_add` (0 from `IndexConfig::new`; see `connectivity`).
pub expansion_add: usize,
/// Passed through as usearch `expansion_search` (0 from `IndexConfig::new`; see `connectivity`).
pub expansion_search: usize,
/// Passed through as usearch `multi`. Presumably this allows several vectors per key
/// (compare `VectorIndex::count_key`); confirm against the usearch docs.
pub multi: bool,
}
impl IndexConfig {
    /// Creates a configuration for `dimensions`-dimensional vectors.
    ///
    /// Field values:
    /// - `metric` and `quantization` take their `Default` variants (`InnerProduct`, `F32`).
    /// - `connectivity`, `expansion_add` and `expansion_search` are 0.
    /// - `multi` is `false`.
    ///
    /// The zeros are handed to usearch unchanged; presumably 0 means "library default"
    /// (confirm against the usearch docs).
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            metric: Metric::default(),
            quantization: Quantization::default(),
            connectivity: 0,
            expansion_add: 0,
            expansion_search: 0,
            multi: false,
        }
    }

    /// Returns this configuration with `metric` as the distance metric.
    #[must_use]
    pub fn with_metric(mut self, metric: Metric) -> Self {
        self.metric = metric;
        self
    }

    /// Returns this configuration with `quantization` as the storage scalar type.
    #[must_use]
    pub fn with_quantization(mut self, quantization: Quantization) -> Self {
        self.quantization = quantization;
        self
    }

    /// Returns this configuration with the given usearch `connectivity`.
    #[must_use]
    pub fn with_connectivity(mut self, connectivity: usize) -> Self {
        self.connectivity = connectivity;
        self
    }

    /// Returns this configuration with the given usearch `expansion_add`.
    #[must_use]
    pub fn with_expansion_add(mut self, expansion: usize) -> Self {
        self.expansion_add = expansion;
        self
    }

    /// Returns this configuration with the given usearch `expansion_search`.
    #[must_use]
    pub fn with_expansion_search(mut self, expansion: usize) -> Self {
        self.expansion_search = expansion;
        self
    }

    /// Returns this configuration with the given usearch `multi` flag.
    #[must_use]
    pub fn with_multi(mut self, multi: bool) -> Self {
        self.multi = multi;
        self
    }

    /// Converts this configuration into usearch construction options, field for field.
    fn to_usearch_options(&self) -> IndexOptions {
        IndexOptions {
            dimensions: self.dimensions,
            metric: self.metric.to_usearch(),
            quantization: self.quantization.to_usearch(),
            connectivity: self.connectivity,
            expansion_add: self.expansion_add,
            expansion_search: self.expansion_search,
            multi: self.multi,
        }
    }

    /// Rough estimate, in bytes, of the memory needed to hold `vector_count` vectors.
    ///
    /// Computed as `vector_count * dimensions * bytes_per_element * 2`. The factor of 2 is a
    /// flat allowance for index overhead, not a measured value.
    ///
    /// Every multiplication saturates at `usize::MAX`. Plain `*` would panic in debug builds
    /// and silently wrap to a tiny number in release builds for very large inputs.
    #[must_use]
    pub fn estimate_memory(&self, vector_count: usize) -> usize {
        const OVERHEAD_FACTOR: usize = 2;
        let bytes_per_vector = self
            .dimensions
            .saturating_mul(self.quantization.bytes_per_element());
        vector_count
            .saturating_mul(bytes_per_vector)
            .saturating_mul(OVERHEAD_FACTOR)
    }
}
impl Default for IndexConfig {
fn default() -> Self {
Self::new(768) }
}
/// Summary returned by `VectorIndex::save_checked` after a successful save.
#[derive(Debug, Clone)]
pub struct SaveInfo {
    /// Destination the index was written to.
    pub path: PathBuf,
    /// Number of bytes saved.
    pub size_bytes: usize,
    /// Wall-clock time the save took.
    pub elapsed: Duration,
    /// Free space, in bytes, reported for the destination before saving.
    pub available_before: u64,
    /// `available_before` minus the saved size, saturating at zero.
    pub space_remaining: u64,
}

impl SaveInfo {
    /// Average throughput in bytes per second.
    ///
    /// Returns `f64::INFINITY` when no elapsed time was measured.
    #[must_use]
    pub fn bytes_per_second(&self) -> f64 {
        let seconds = self.elapsed.as_secs_f64();
        if seconds <= 0.0 {
            return f64::INFINITY;
        }
        self.size_bytes as f64 / seconds
    }

    /// Average throughput in MiB per second (1 MiB = 1_048_576 bytes).
    #[must_use]
    pub fn mb_per_second(&self) -> f64 {
        const BYTES_PER_MIB: f64 = 1_048_576.0;
        self.bytes_per_second() / BYTES_PER_MIB
    }

    /// Formats `size_bytes` with a 1024-based unit and two decimals, e.g. `"1.50 MB"`.
    ///
    /// Sizes below 1 KB are printed as a plain byte count, e.g. `"512 bytes"`.
    #[must_use]
    pub fn human_size(&self) -> String {
        const KB: usize = 1024;
        const MB: usize = KB * 1024;
        const GB: usize = MB * 1024;
        let bytes = self.size_bytes;
        let (scale, unit) = if bytes >= GB {
            (GB, "GB")
        } else if bytes >= MB {
            (MB, "MB")
        } else if bytes >= KB {
            (KB, "KB")
        } else {
            return format!("{} bytes", bytes);
        };
        format!("{:.2} {}", bytes as f64 / scale as f64, unit)
    }
}

impl std::fmt::Display for SaveInfo {
    /// One-line summary: size, destination, duration and throughput.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let size = self.human_size();
        let rate = self.mb_per_second();
        write!(
            f,
            "Saved {} to '{}' in {:?} ({:.1} MB/s)",
            size,
            self.path.display(),
            self.elapsed,
            rate
        )
    }
}
/// A usearch [`Index`] paired with the [`IndexConfig`] it was created or loaded with.
///
/// NOTE: indexes obtained through `restore`/`view_restore` only learn their dimensions from
/// the file. The recorded metric, quantization and tuning fields are `IndexConfig::new`
/// defaults, not values read from disk.
pub struct VectorIndex {
/// The underlying usearch index.
inner: Index,
/// Configuration recorded alongside `inner` (see the note above).
config: IndexConfig,
}
/// A [`VectorIndex`] opened with `VectorIndex::view_safe`, bundled with an open handle to its
/// backing file and that file's path.
pub struct IndexView {
/// The viewed index.
inner: VectorIndex,
/// Handle to the index file, opened by `view_safe` and kept for the lifetime of the view.
_file_handle: Arc<File>,
/// Location of the index file on disk.
path: PathBuf,
}
impl IndexView {
    /// Returns `true` while a file still exists at the path this view was opened from.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        Path::exists(&self.path)
    }

    /// Path of the index file backing this view.
    #[must_use]
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }

    /// Configuration recorded for the viewed index.
    #[must_use]
    pub fn config(&self) -> &IndexConfig {
        self.inner.config()
    }

    /// Number of vectors in the viewed index.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the viewed index holds no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Vector dimensionality of the viewed index.
    #[must_use]
    pub fn dimensions(&self) -> usize {
        self.inner.dimensions()
    }

    /// Whether `key` is present in the viewed index.
    #[must_use]
    pub fn contains(&self, key: u64) -> bool {
        self.inner.contains(key)
    }

    /// k-nearest-neighbour search; delegates to [`VectorIndex::search`].
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
        self.inner.search(query, k)
    }

    /// Search keeping only hits with distance `<= threshold`.
    ///
    /// Delegates to [`VectorIndex::search_threshold`].
    pub fn search_threshold(
        &self,
        query: &[f32],
        k: usize,
        threshold: f32,
    ) -> Result<Vec<(u64, f32)>> {
        self.inner.search_threshold(query, k, threshold)
    }

    /// Fetches the vector stored under `key`.
    ///
    /// Unlike [`VectorIndex::get`], a missing key is reported as an error rather than `None`.
    pub fn get(&self, key: u64) -> Result<Vec<f32>> {
        match self.inner.get(key) {
            Some(vector) => Ok(vector),
            None => Err(anyhow::anyhow!("Key {} not found in index", key)),
        }
    }

    /// Unwraps the view into its [`VectorIndex`], dropping the file handle and path.
    #[must_use]
    pub fn into_inner(self) -> VectorIndex {
        let Self { inner, .. } = self;
        inner
    }
}

impl std::ops::Deref for IndexView {
    type Target = VectorIndex;

    /// Exposes the wrapped index's methods (those not shadowed by `IndexView`'s own).
    fn deref(&self) -> &VectorIndex {
        let Self { inner, .. } = self;
        inner
    }
}

impl std::fmt::Debug for IndexView {
    /// Prints the wrapped index, the path, and whether the file currently exists.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let valid = self.is_valid();
        let mut builder = formatter.debug_struct("IndexView");
        builder.field("inner", &self.inner);
        builder.field("path", &self.path);
        builder.field("is_valid", &valid);
        builder.finish()
    }
}
impl VectorIndex {
pub fn new(dimensions: usize, metric: Metric) -> Result<Self> {
let config = IndexConfig::new(dimensions).with_metric(metric);
Self::with_config(config)
}
pub fn with_config(config: IndexConfig) -> Result<Self> {
let options = config.to_usearch_options();
let inner = Index::new(&options)
.map_err(|e| anyhow::anyhow!("Failed to create usearch index: {}", e))?;
Ok(Self { inner, config })
}
#[must_use]
pub fn config(&self) -> &IndexConfig {
&self.config
}
#[must_use]
pub fn dimensions(&self) -> usize {
self.inner.dimensions()
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.size()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn capacity(&self) -> usize {
self.inner.capacity()
}
#[must_use]
pub fn memory_usage(&self) -> usize {
self.inner.memory_usage()
}
#[must_use]
pub fn metric(&self) -> Metric {
self.config.metric
}
#[must_use]
pub fn to_similarity_scores(&self, distances: &[f32]) -> Vec<f32> {
distances_to_scores_for_metric(distances, self.config.metric)
}
pub fn reserve(&self, capacity: usize) -> Result<()> {
self.inner
.reserve(capacity)
.map_err(|e| anyhow::anyhow!("Failed to reserve capacity: {}", e))
}
#[must_use]
pub fn contains(&self, key: u64) -> bool {
self.inner.contains(key)
}
#[must_use]
pub fn count_key(&self, key: u64) -> usize {
self.inner.count(key)
}
#[must_use]
pub fn get(&self, key: u64) -> Option<Vec<f32>> {
let dimensions = self.inner.dimensions();
let mut buffer = vec![0.0f32; dimensions];
match self.inner.get(key, &mut buffer) {
Ok(count) if count > 0 => Some(buffer),
_ => None,
}
}
#[must_use]
pub fn get_batch(&self, keys: &[u64]) -> Vec<Option<Vec<f32>>> {
keys.iter().map(|&key| self.get(key)).collect()
}
pub fn add(&self, key: u64, vector: &[f32]) -> Result<()> {
self.validate_vector_dimensions(vector)?;
self.inner
.add(key, vector)
.map_err(|e| anyhow::anyhow!("Failed to add vector {}: {}", key, e))
}
pub fn add_batch(&self, keys: &[u64], vectors: &[Vec<f32>]) -> Result<()> {
if keys.len() != vectors.len() {
bail!(
"Keys and vectors length mismatch: {} keys, {} vectors",
keys.len(),
vectors.len()
);
}
if keys.is_empty() {
return Ok(());
}
for (i, vector) in vectors.iter().enumerate() {
self.validate_vector_dimensions(vector)
.with_context(|| format!("Vector at index {}", i))?;
}
if keys.len() < PARALLEL_BATCH_THRESHOLD {
for (key, vector) in keys.iter().zip(vectors.iter()) {
self.inner
.add(*key, vector)
.map_err(|e| anyhow::anyhow!("Failed to add vector {}: {}", key, e))?;
}
return Ok(());
}
let error_count = AtomicUsize::new(0);
let first_error_key = AtomicUsize::new(0);
keys.par_iter()
.zip(vectors.par_iter())
.for_each(|(key, vector)| {
if let Err(_e) = self.inner.add(*key, vector) {
if error_count.fetch_add(1, Ordering::Relaxed) == 0 {
first_error_key.store(*key as usize, Ordering::Relaxed);
}
}
});
let errors = error_count.load(Ordering::Relaxed);
if errors > 0 {
let first_key = first_error_key.load(Ordering::Relaxed);
bail!(
"Failed to add {} vector(s), first failure at key {}",
errors,
first_key
);
}
Ok(())
}
/// Inserts `vectors[i]` under `keys[i]` one at a time, on the calling thread.
///
/// Unlike [`Self::add_batch`], each vector is validated just before it is inserted. An error
/// therefore leaves every earlier vector in the index.
///
/// # Errors
/// Fails on a length mismatch, a dimension mismatch (with the offending position as
/// context), or the first insertion usearch rejects.
pub fn add_batch_sequential(&self, keys: &[u64], vectors: &[Vec<f32>]) -> Result<()> {
    if keys.len() != vectors.len() {
        bail!(
            "Keys and vectors length mismatch: {} keys, {} vectors",
            keys.len(),
            vectors.len()
        );
    }
    keys.iter()
        .zip(vectors)
        .enumerate()
        .try_for_each(|(i, (&key, vector))| {
            self.validate_vector_dimensions(vector)
                .with_context(|| format!("Vector at index {}", i))?;
            self.inner
                .add(key, vector)
                .map_err(|e| anyhow::anyhow!("Failed to add vector {}: {}", key, e))
        })
}
pub fn add_batch_flat(&self, keys: &[u64], vectors_flat: &[f32]) -> Result<()> {
let expected_len = keys.len() * self.dimensions();
if vectors_flat.len() != expected_len {
bail!(
"Flat vector array size mismatch: expected {} ({} keys * {} dims), got {}",
expected_len,
keys.len(),
self.dimensions(),
vectors_flat.len()
);
}
if keys.is_empty() {
return Ok(());
}
let dims = self.dimensions();
if keys.len() < PARALLEL_BATCH_THRESHOLD {
for (i, key) in keys.iter().enumerate() {
let start = i * dims;
let end = start + dims;
let vector = &vectors_flat[start..end];
self.inner
.add(*key, vector)
.map_err(|e| anyhow::anyhow!("Failed to add vector {}: {}", key, e))?;
}
return Ok(());
}
let error_count = AtomicUsize::new(0);
let first_error_key = AtomicUsize::new(0);
keys.par_iter()
.enumerate()
.for_each(|(i, key)| {
let start = i * dims;
let end = start + dims;
let vector = &vectors_flat[start..end];
if let Err(_e) = self.inner.add(*key, vector) {
if error_count.fetch_add(1, Ordering::Relaxed) == 0 {
first_error_key.store(*key as usize, Ordering::Relaxed);
}
}
});
let errors = error_count.load(Ordering::Relaxed);
if errors > 0 {
let first_key = first_error_key.load(Ordering::Relaxed);
bail!(
"Failed to add {} vector(s), first failure at key {}",
errors,
first_key
);
}
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
self.validate_vector_dimensions(query)?;
if self.is_empty() {
return Ok(Vec::new());
}
let k = k.min(self.len());
let matches = self
.inner
.search(query, k)
.map_err(|e| anyhow::anyhow!("Search failed: {}", e))?;
Ok(matches.keys.into_iter().zip(matches.distances).collect())
}
pub fn search_exact(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
self.validate_vector_dimensions(query)?;
if self.is_empty() {
return Ok(Vec::new());
}
let k = k.min(self.len());
let matches = self
.inner
.exact_search(query, k)
.map_err(|e| anyhow::anyhow!("Exact search failed: {}", e))?;
Ok(matches.keys.into_iter().zip(matches.distances).collect())
}
pub fn search_filtered<F>(&self, query: &[f32], k: usize, filter: F) -> Result<Vec<(u64, f32)>>
where
F: Fn(u64) -> bool,
{
self.validate_vector_dimensions(query)?;
if self.is_empty() {
return Ok(Vec::new());
}
let k = k.min(self.len());
let matches = self
.inner
.filtered_search(query, k, filter)
.map_err(|e| anyhow::anyhow!("Filtered search failed: {}", e))?;
Ok(matches.keys.into_iter().zip(matches.distances).collect())
}
pub fn search_threshold(
&self,
query: &[f32],
k: usize,
threshold: f32,
) -> Result<Vec<(u64, f32)>> {
self.validate_vector_dimensions(query)?;
if self.is_empty() {
return Ok(Vec::new());
}
let k = k.min(self.len());
let matches = self
.inner
.search(query, k)
.map_err(|e| anyhow::anyhow!("Search failed: {}", e))?;
Ok(matches
.keys
.into_iter()
.zip(matches.distances)
.filter(|(_, dist)| *dist <= threshold)
.collect())
}
pub fn remove(&self, key: u64) -> Result<usize> {
self.inner
.remove(key)
.map_err(|e| anyhow::anyhow!("Failed to remove vector {}: {}", key, e))
}
/// Moves the vector(s) stored under `from_key` to `to_key`.
///
/// Returns the count reported by usearch. Renaming a key to itself is a no-op that returns
/// the number of vectors stored under it.
///
/// # Errors
/// - `to_key` already holds a vector (use [`Self::rename_overwrite`] to replace it).
/// - usearch rejects the rename.
pub fn rename(&self, from_key: u64, to_key: u64) -> Result<usize> {
    if from_key == to_key {
        // Without this, an existing key would be rejected as "target already exists".
        return Ok(self.count_key(from_key));
    }
    if self.contains(to_key) {
        bail!(
            "Cannot rename key {} to {}: target key already exists",
            from_key,
            to_key
        );
    }
    self.inner
        .rename(from_key, to_key)
        .map_err(|e| anyhow::anyhow!("Failed to rename key {} to {}: {}", from_key, to_key, e))
}

/// Moves the vector(s) stored under `from_key` to `to_key`, replacing any existing `to_key`.
///
/// Cases that previously lost data are now no-ops:
/// - Renaming a key to itself returns its vector count without touching it. Previously the
///   key was removed first.
/// - When `from_key` is absent, nothing is moved and `to_key` is left intact (returns 0).
///
/// # Errors
/// Fails if removing the old `to_key` or the usearch rename fails.
pub fn rename_overwrite(&self, from_key: u64, to_key: u64) -> Result<usize> {
    if from_key == to_key {
        return Ok(self.count_key(from_key));
    }
    if !self.contains(from_key) {
        return Ok(0);
    }
    if self.contains(to_key) {
        self.remove(to_key)?;
    }
    self.inner
        .rename(from_key, to_key)
        .map_err(|e| anyhow::anyhow!("Failed to rename key {} to {}: {}", from_key, to_key, e))
}

/// Removes every vector from the index (usearch `reset`).
///
/// # Errors
/// Fails if usearch rejects the reset.
pub fn clear(&self) -> Result<()> {
    self.inner
        .reset()
        .map_err(|e| anyhow::anyhow!("Failed to clear index: {}", e))
}
/// Atomically writes the index to `path`.
///
/// Steps:
/// 1. Serialize into a temporary file created in the destination directory (so the final
///    rename stays on one filesystem).
/// 2. Flush that file to stable storage.
/// 3. Rename it over `path`.
///
/// A bare file name such as `"index.usearch"` is saved into the current directory.
///
/// # Errors
/// Fails if the destination directory does not exist, or if creating, writing, syncing or
/// renaming the temporary file fails.
pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
    let path = path.as_ref();
    // `Path::parent` yields `Some("")` (not `None`) for a bare file name, and `""` does not
    // "exist". Treat that case, and a path with no parent, as the current directory.
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    if !parent.exists() {
        bail!("Parent directory does not exist: {}", parent.display());
    }
    let temp_file = NamedTempFile::new_in(parent)
        .with_context(|| format!("Failed to create temp file in '{}'", parent.display()))?;
    // usearch takes a `&str` path, so non-UTF-8 components are replaced lossily.
    // NOTE(review): a non-UTF-8 parent directory would make usearch write somewhere else.
    let temp_path_str = temp_file.path().to_string_lossy();
    self.inner.save(&temp_path_str).map_err(|e| {
        anyhow::anyhow!(
            "Failed to save index to temp file '{}': {}",
            temp_path_str,
            e
        )
    })?;
    // Flush the file usearch wrote before renaming it into place. Otherwise a crash shortly
    // after the rename can leave an empty or truncated file at `path`.
    temp_file
        .as_file()
        .sync_all()
        .with_context(|| format!("Failed to sync temp file '{}'", temp_path_str))?;
    temp_file.persist(path).with_context(|| {
        format!(
            "Failed to atomically rename temp file to '{}'",
            path.display()
        )
    })?;
    Ok(())
}
pub fn save_unsafe(&self, path: impl AsRef<Path>) -> Result<()> {
let path_str = path.as_ref().to_string_lossy();
self.inner
.save(&path_str)
.map_err(|e| anyhow::anyhow!("Failed to save index to '{}': {}", path_str, e))
}
pub fn restore(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
if !path.exists() {
bail!("Index file not found: {}", path.display());
}
let path_str = path.to_string_lossy();
let default_opts = IndexOptions::default();
let inner = Index::new(&default_opts)
.map_err(|e| anyhow::anyhow!("Failed to create usearch index: {}", e))?;
inner
.load(&path_str)
.map_err(|e| anyhow::anyhow!("Failed to load index from '{}': {}", path_str, e))?;
let dimensions = inner.dimensions();
if dimensions == 0 {
bail!(
"Invalid index file '{}': loaded index has 0 dimensions",
path_str
);
}
let config = IndexConfig {
dimensions,
metric: Metric::default(),
quantization: Quantization::default(),
connectivity: 0,
expansion_add: 0,
expansion_search: 0,
multi: false,
};
Ok(Self { inner, config })
}
/// Loads an index with [`Self::restore`] and requires its dimensionality to equal
/// `config.dimensions`.
///
/// On success, `config` replaces the placeholder configuration.
///
/// # Errors
/// Fails if loading fails or the dimensions differ.
pub fn load_validated(path: impl AsRef<Path>, config: IndexConfig) -> Result<Self> {
    let mut index = Self::restore(path.as_ref())?;
    let loaded = index.dimensions();
    if loaded != config.dimensions {
        bail!(
            "Dimension mismatch: expected {} but loaded index has {}. \
             Use VectorIndex::restore() to load without dimension constraints.",
            config.dimensions,
            loaded
        );
    }
    index.config = config;
    Ok(index)
}

/// Deprecated alias for [`Self::load_validated`].
#[deprecated(
    since = "1.2.0",
    note = "Use `restore()` for safe loading or `load_validated()` for validated loading"
)]
// Thin forwarding shim: the error contract is documented on `load_validated`.
#[allow(clippy::missing_errors_doc)]
pub fn load(path: impl AsRef<Path>, config: IndexConfig) -> Result<Self> {
    Self::load_validated(path, config)
}

/// Opens an index file through usearch `view` rather than `load`.
///
/// `view` presumably memory-maps the file instead of copying it (confirm against the usearch
/// docs). As with [`Self::restore`], only the dimensionality comes from the file.
///
/// # Errors
/// Fails if the file is missing, usearch cannot view it, or it reports 0 dimensions.
pub fn view_restore(path: impl AsRef<Path>) -> Result<Self> {
    let path = path.as_ref();
    if !path.exists() {
        bail!("Index file not found: {}", path.display());
    }
    let path_str = path.to_string_lossy();
    let inner = Index::new(&IndexOptions::default())
        .map_err(|e| anyhow::anyhow!("Failed to create usearch index: {}", e))?;
    if let Err(e) = inner.view(&path_str) {
        bail!("Failed to view index from '{}': {}", path_str, e);
    }
    let dimensions = match inner.dimensions() {
        0 => bail!(
            "Invalid index file '{}': viewed index has 0 dimensions",
            path_str
        ),
        dims => dims,
    };
    Ok(Self {
        inner,
        config: IndexConfig::new(dimensions),
    })
}

/// Views an index with [`Self::view_restore`] and requires its dimensionality to equal
/// `config.dimensions`.
///
/// On success, `config` replaces the placeholder configuration.
///
/// # Errors
/// Fails if viewing fails or the dimensions differ.
pub fn view_validated(path: impl AsRef<Path>, config: IndexConfig) -> Result<Self> {
    let mut index = Self::view_restore(path.as_ref())?;
    let viewed = index.dimensions();
    if viewed != config.dimensions {
        bail!(
            "Dimension mismatch: expected {} but viewed index has {}. \
             Use VectorIndex::view_restore() to view without dimension constraints.",
            config.dimensions,
            viewed
        );
    }
    index.config = config;
    Ok(index)
}

/// Deprecated alias for [`Self::view_validated`].
#[deprecated(
    since = "1.2.0",
    note = "Use `view_safe()` for proper file lifecycle management"
)]
// Thin forwarding shim: the error contract is documented on `view_validated`.
#[allow(clippy::missing_errors_doc)]
pub fn view(path: impl AsRef<Path>, config: IndexConfig) -> Result<Self> {
    Self::view_validated(path, config)
}

/// Opens the file at `path`, then views it with [`Self::view_restore`].
///
/// Returns an [`IndexView`] that keeps the opened file handle and the path.
///
/// # Errors
/// Fails if the file cannot be opened or viewed.
pub fn view_safe(path: impl AsRef<Path>) -> Result<IndexView> {
    let owned_path: PathBuf = path.as_ref().into();
    let handle = File::open(&owned_path)
        .with_context(|| format!("Failed to open index file: {}", owned_path.display()))?;
    let inner = Self::view_restore(&owned_path)?;
    Ok(IndexView {
        inner,
        _file_handle: Arc::new(handle),
        path: owned_path,
    })
}
pub fn view_validated_safe(
path: impl AsRef<Path>,
config: IndexConfig,
) -> Result<IndexView> {
let view = Self::view_safe(&path)?;
if config.dimensions != 0 && view.inner.dimensions() != config.dimensions {
bail!(
"Dimension mismatch: expected {} but viewed index has {}. \
Use VectorIndex::view_safe() to view without dimension constraints.",
config.dimensions,
view.inner.dimensions()
);
}
Ok(view)
}
/// Number of bytes usearch reports the serialized index will occupy.
#[must_use]
pub fn serialized_size(&self) -> usize {
    self.inner.serialized_length()
}

/// Reports whether the filesystem holding `path` has room for the serialized index.
///
/// The requirement includes a safety margin; see [`Self::disk_space_info`].
///
/// # Errors
/// Same as [`Self::disk_space_info`].
pub fn check_disk_space(&self, path: impl AsRef<Path>) -> Result<bool> {
    let (_available, _required, sufficient) = self.disk_space_info(path)?;
    Ok(sufficient)
}

/// Returns `(available, required_with_margin, sufficient)` for saving to `path`.
///
/// - `available` is the free space, in bytes, of the filesystem holding `path`'s directory.
/// - `required_with_margin` is the serialized size plus `max(size / 10, 1 MiB)`.
/// - `sufficient` is `available >= required_with_margin`.
///
/// A bare file name is checked against the current directory.
///
/// # Errors
/// Fails if the destination directory does not exist or its free space cannot be queried.
pub fn disk_space_info(&self, path: impl AsRef<Path>) -> Result<(u64, u64, bool)> {
    let path = path.as_ref();
    let required = self.serialized_size() as u64;
    let safety_margin = std::cmp::max(required / 10, 1024 * 1024);
    let required_with_margin = required.saturating_add(safety_margin);
    // `Path::parent` yields `Some("")` for a bare file name; that means "current directory".
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    if !parent.exists() {
        bail!(
            "Cannot check disk space: parent directory does not exist: {}",
            parent.display()
        );
    }
    let available = fs2::available_space(parent)
        .with_context(|| format!("Failed to check disk space for '{}'", parent.display()))?;
    Ok((
        available,
        required_with_margin,
        available >= required_with_margin,
    ))
}
/// Checks free disk space, then saves atomically via [`Self::save`] and reports timing.
///
/// The space check requires the serialized size plus a safety margin of
/// `max(size / 10, 1 MiB)`. A bare file name is saved into the current directory.
///
/// # Errors
/// - The destination directory does not exist.
/// - Free space cannot be queried, or is insufficient.
/// - [`Self::save`] fails.
pub fn save_checked(&self, path: impl AsRef<Path>) -> Result<SaveInfo> {
    let path = path.as_ref();
    let required_bytes = self.serialized_size();
    let required = required_bytes as u64;
    let safety_margin = std::cmp::max(required / 10, 1024 * 1024);
    let required_with_margin = required.saturating_add(safety_margin);
    // `Path::parent` yields `Some("")` for a bare file name; that means "current directory".
    let parent = match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        _ => Path::new("."),
    };
    if !parent.exists() {
        bail!("Parent directory does not exist: {}", parent.display());
    }
    let available = fs2::available_space(parent)
        .with_context(|| format!("Failed to check disk space for '{}'", parent.display()))?;
    if available < required_with_margin {
        // Report the actual margin: it is 10% only when that exceeds the 1 MiB floor.
        bail!(
            "Insufficient disk space to save index: \
             need {} bytes ({} bytes of index data plus a {} byte safety margin), \
             but only {} bytes available on '{}'",
            required_with_margin,
            required,
            safety_margin,
            available,
            parent.display()
        );
    }
    let start = Instant::now();
    self.save(path)?;
    let elapsed = start.elapsed();
    Ok(SaveInfo {
        path: path.to_path_buf(),
        size_bytes: required_bytes,
        elapsed,
        available_before: available,
        space_remaining: available.saturating_sub(required),
    })
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
let size = self.serialized_size();
let mut buffer = vec![0u8; size];
self.inner
.save_to_buffer(&mut buffer)
.map_err(|e| anyhow::anyhow!("Failed to serialize index to buffer: {}", e))?;
Ok(buffer)
}
pub fn from_bytes(data: &[u8], config: IndexConfig) -> Result<Self> {
let options = config.to_usearch_options();
let index = Index::new(&options)
.map_err(|e| anyhow::anyhow!("Failed to create usearch index: {}", e))?;
index
.load_from_buffer(data)
.map_err(|e| anyhow::anyhow!("Failed to deserialize index from buffer: {}", e))?;
let loaded_dims = index.dimensions();
if config.dimensions != 0 && loaded_dims != config.dimensions {
bail!(
"Dimension mismatch: config specifies {} but loaded index has {}",
config.dimensions,
loaded_dims
);
}
let final_config = if config.dimensions == 0 {
IndexConfig {
dimensions: loaded_dims,
..config
}
} else {
config
};
Ok(Self {
inner: index,
config: final_config,
})
}
pub fn from_bytes_unchecked(data: &[u8]) -> Result<Self> {
let config = IndexConfig {
dimensions: 0,
..Default::default()
};
Self::from_bytes(data, config)
}
#[must_use]
pub fn hardware_acceleration(&self) -> String {
self.inner.hardware_acceleration()
}
pub fn set_expansion_search(&self, expansion: usize) {
self.inner.change_expansion_search(expansion);
}
#[must_use]
pub fn expansion_search(&self) -> usize {
self.inner.expansion_search()
}
pub fn set_expansion_add(&self, expansion: usize) {
self.inner.change_expansion_add(expansion);
}
#[must_use]
pub fn expansion_add(&self) -> usize {
self.inner.expansion_add()
}
fn validate_vector_dimensions(&self, vector: &[f32]) -> Result<()> {
let expected = self.dimensions();
let actual = vector.len();
if actual != expected {
bail!(
"Vector dimension mismatch: expected {}, got {}",
expected,
actual
);
}
Ok(())
}
}
impl std::fmt::Debug for VectorIndex {
    /// Shows the index's shape and resource usage rather than its stored vectors.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let config = &self.config;
        let mut out = f.debug_struct("VectorIndex");
        out.field("dimensions", &self.dimensions());
        out.field("metric", &config.metric);
        out.field("quantization", &config.quantization);
        out.field("size", &self.len());
        out.field("capacity", &self.capacity());
        out.field("memory_usage", &self.memory_usage());
        out.finish()
    }
}
/// Maps raw usearch distances to similarity scores, element by element.
///
/// - `InnerProduct` / `Cosine`: `clamp(1 - distance, 0, 1)`.
/// - `L2Squared`: `1 / (1 + max(distance, 0))`.
///
/// Full 8-element chunks are processed with `f32x8`; any leftover elements use the scalar
/// formula.
#[must_use]
pub fn distances_to_scores_for_metric(distances: &[f32], metric: Metric) -> Vec<f32> {
    const LANES: usize = 8;
    let ones = f32x8::splat(1.0);
    let zeros = f32x8::splat(0.0);
    let mut scores = vec![0.0_f32; distances.len()];
    let mut inputs = distances.chunks_exact(LANES);
    let mut outputs = scores.chunks_exact_mut(LANES);
    match metric {
        Metric::InnerProduct | Metric::Cosine => {
            for (chunk, out) in inputs.by_ref().zip(outputs.by_ref()) {
                (ones - f32x8::from_slice(chunk))
                    .simd_clamp(zeros, ones)
                    .copy_to_slice(out);
            }
            for (&d, out) in inputs.remainder().iter().zip(outputs.into_remainder()) {
                *out = (1.0 - d).clamp(0.0, 1.0);
            }
        }
        Metric::L2Squared => {
            for (chunk, out) in inputs.by_ref().zip(outputs.by_ref()) {
                let d = f32x8::from_slice(chunk);
                (ones / (ones + d.simd_max(zeros))).copy_to_slice(out);
            }
            for (&d, out) in inputs.remainder().iter().zip(outputs.into_remainder()) {
                *out = 1.0 / (1.0 + d.max(0.0));
            }
        }
    }
    scores
}

/// [`distances_to_scores_for_metric`] with [`Metric::InnerProduct`].
#[must_use]
pub fn distances_to_scores(distances: &[f32]) -> Vec<f32> {
    let metric = Metric::InnerProduct;
    distances_to_scores_for_metric(distances, metric)
}
/// Sum of squares of `vector`: 16-lane SIMD accumulation, reduced, then a scalar tail.
#[inline]
fn squared_norm_simd16(vector: &[f32]) -> f32 {
    let blocks = vector.chunks_exact(16);
    let tail = blocks.remainder();
    let lanes = blocks.fold(F32x16::splat(0.0), |acc, block| {
        let v = F32x16::from_slice(block);
        acc + v * v
    });
    tail.iter().fold(lanes.reduce_sum(), |sum, &x| sum + x * x)
}
/// Sum of squares of `vector`: 8-lane SIMD accumulation, reduced, then a scalar tail.
#[inline]
fn squared_norm_simd8(vector: &[f32]) -> f32 {
    let blocks = vector.chunks_exact(8);
    let tail = blocks.remainder();
    let lanes = blocks.fold(f32x8::splat(0.0), |acc, block| {
        let v = f32x8::from_slice(block);
        acc + v * v
    });
    tail.iter().fold(lanes.reduce_sum(), |sum, &x| sum + x * x)
}
/// Sum of squares of `vector`, picking an implementation by length and CPU.
///
/// - Fewer than 8 elements: plain iterator sum.
/// - On x86 with at least 16 elements and `avx512f` detected at runtime: the 16-lane
///   version.
/// - Otherwise: the 8-lane version.
///
/// NOTE(review): runtime detection without `#[target_feature(enable = "avx512f")]` probably
/// does not change codegen. Both SIMD paths compile for the build target. They do
/// accumulate in different orders, so results may differ in the last bits between machines.
#[inline]
fn squared_norm_simd(vector: &[f32]) -> f32 {
    if vector.len() >= 8 {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if vector.len() >= 16 && is_x86_feature_detected!("avx512f") {
                return squared_norm_simd16(vector);
            }
        }
        squared_norm_simd8(vector)
    } else {
        vector.iter().map(|x| x * x).sum()
    }
}
/// Multiplies every element of `vector` by `inv_norm` in place, 16 lanes at a time, then
/// handles the tail with scalar code.
#[inline]
fn normalize_inplace_simd16(vector: &mut [f32], inv_norm: f32) {
    let scale = F32x16::splat(inv_norm);
    let mut blocks = vector.chunks_exact_mut(16);
    for block in blocks.by_ref() {
        (F32x16::from_slice(block) * scale).copy_to_slice(block);
    }
    for x in blocks.into_remainder() {
        *x *= inv_norm;
    }
}
/// Multiplies every element of `vector` by `inv_norm` in place, 8 lanes at a time, then
/// handles the tail with scalar code.
#[inline]
fn normalize_inplace_simd8(vector: &mut [f32], inv_norm: f32) {
    let scale = f32x8::splat(inv_norm);
    let mut blocks = vector.chunks_exact_mut(8);
    for block in blocks.by_ref() {
        (f32x8::from_slice(block) * scale).copy_to_slice(block);
    }
    for x in blocks.into_remainder() {
        *x *= inv_norm;
    }
}
/// Scales `vector` by `inv_norm` in place, picking an implementation by length and CPU.
///
/// Dispatch mirrors `squared_norm_simd`, including its caveat about runtime feature
/// detection:
/// - Fewer than 8 elements: scalar loop.
/// - On x86 with at least 16 elements and `avx512f` detected: the 16-lane version.
/// - Otherwise: the 8-lane version.
#[inline]
fn normalize_inplace_simd(vector: &mut [f32], inv_norm: f32) {
    if vector.len() >= 8 {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if vector.len() >= 16 && is_x86_feature_detected!("avx512f") {
                return normalize_inplace_simd16(vector, inv_norm);
            }
        }
        normalize_inplace_simd8(vector, inv_norm);
    } else {
        vector.iter_mut().for_each(|x| *x *= inv_norm);
    }
}
/// Returns a unit-length copy of `vector`.
///
/// Returns `None` when its Euclidean norm is below `f32::EPSILON`, i.e. it is treated as
/// zero.
#[must_use]
pub fn normalize_vector(vector: &[f32]) -> Option<Vec<f32>> {
    let norm = squared_norm_simd(vector).sqrt();
    if norm < f32::EPSILON {
        return None;
    }
    let mut unit = Vec::from(vector);
    normalize_inplace_simd(&mut unit, 1.0 / norm);
    Some(unit)
}

/// Normalizes `vector` to unit length in place.
///
/// A vector whose norm is below `f32::EPSILON` is returned unchanged.
#[must_use]
pub fn normalize_vector_or_original(mut vector: Vec<f32>) -> Vec<f32> {
    let norm = squared_norm_simd(&vector).sqrt();
    if norm < f32::EPSILON {
        return vector;
    }
    let inv_norm = 1.0 / norm;
    normalize_inplace_simd(&mut vector, inv_norm);
    vector
}
/// `true` when every component of `vector` is finite (no NaN or ±infinity).
///
/// An empty vector is valid.
#[must_use]
pub fn is_valid_vector(vector: &[f32]) -> bool {
    let blocks = vector.chunks_exact(8);
    let tail = blocks.remainder();
    blocks
        .map(f32x8::from_slice)
        .all(|lanes| lanes.is_finite().all())
        && tail.iter().all(|x| x.is_finite())
}

/// `true` when the Euclidean norm of `vector` is within `epsilon` of 1.
///
/// The comparison is strict, so `|norm - 1| < epsilon`.
#[must_use]
pub fn is_normalized(vector: &[f32], epsilon: f32) -> bool {
    let blocks = vector.chunks_exact(8);
    let tail = blocks.remainder();
    let lanes = blocks.fold(f32x8::splat(0.0), |acc, block| {
        let v = f32x8::from_slice(block);
        acc + v * v
    });
    let sum_of_squares = tail.iter().fold(lanes.reduce_sum(), |sum, &x| sum + x * x);
    (sum_of_squares.sqrt() - 1.0).abs() < epsilon
}
/// Maximum number of query embeddings held by `QUERY_EMBEDDING_CACHE`.
const QUERY_CACHE_DEFAULT_CAPACITY: usize = 100;
/// Process-wide LRU cache from query text to its embedding, behind a `std::sync::Mutex`.
///
/// It is created lazily on first access.
static QUERY_EMBEDDING_CACHE: Lazy<Mutex<LruCache<String, Vec<f32>>>> = Lazy::new(|| {
    Mutex::new(LruCache::new(
        NonZeroUsize::new(QUERY_CACHE_DEFAULT_CAPACITY)
            .expect("QUERY_CACHE_DEFAULT_CAPACITY must be non-zero"),
    ))
});
pub fn get_cached_query_embedding<F>(query: &str, compute_fn: F) -> Result<Vec<f32>>
where
F: FnOnce(&str) -> Result<Vec<f32>>,
{
let cache_key = query.to_string();
{
let mut cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
if let Some(embedding) = cache.get(&cache_key) {
tracing::debug!(query = %query, "Query embedding cache hit");
return Ok(embedding.clone());
}
}
tracing::debug!(query = %query, "Query embedding cache miss, computing...");
let embedding = compute_fn(query)?;
{
let mut cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
cache.put(cache_key, embedding.clone());
}
Ok(embedding)
}
pub async fn get_cached_query_embedding_async<F, Fut>(
query: &str,
compute_fn: F,
) -> Result<Vec<f32>>
where
F: FnOnce(String) -> Fut,
Fut: std::future::Future<Output = Result<Vec<f32>>>,
{
let cache_key = query.to_string();
{
let mut cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
if let Some(embedding) = cache.get(&cache_key) {
tracing::debug!(query = %query, "Query embedding cache hit (async)");
return Ok(embedding.clone());
}
}
tracing::debug!(query = %query, "Query embedding cache miss (async), computing...");
let embedding = compute_fn(cache_key.clone()).await?;
{
let mut cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
cache.put(cache_key, embedding.clone());
}
Ok(embedding)
}
pub fn clear_query_cache() -> Result<()> {
let mut cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
cache.clear();
tracing::debug!("Query embedding cache cleared");
Ok(())
}
pub fn query_cache_stats() -> Result<(usize, usize)> {
let cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
Ok((cache.len(), cache.cap().get()))
}
pub fn query_in_cache(query: &str) -> Result<bool> {
let cache = QUERY_EMBEDDING_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to acquire cache lock: {}", e))?;
Ok(cache.contains(query))
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_index() -> VectorIndex {
VectorIndex::new(4, Metric::InnerProduct).expect("Failed to create test index")
}
#[test]
fn test_create_index() {
let index = create_test_index();
assert_eq!(index.dimensions(), 4);
assert_eq!(index.len(), 0);
assert!(index.is_empty());
}
#[test]
fn test_add_and_search() {
let index = create_test_index();
index.reserve(10).unwrap();
index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
index.add(2, &[0.0, 0.0, 1.0, 0.0]).unwrap();
assert_eq!(index.len(), 3);
assert!(!index.is_empty());
let results = index.search(&[1.0, 0.0, 0.0, 0.0], 3).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].0, 0); }
#[test]
fn test_add_batch() {
let index = create_test_index();
index.reserve(3).unwrap();
let keys = vec![0, 1, 2];
let vectors = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
];
index.add_batch(&keys, &vectors).unwrap();
assert_eq!(index.len(), 3);
}
#[test]
fn test_add_batch_flat() {
let index = create_test_index();
index.reserve(2).unwrap();
let keys = vec![0, 1];
let vectors_flat = vec![
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ];
index.add_batch_flat(&keys, &vectors_flat).unwrap();
assert_eq!(index.len(), 2);
}
#[test]
fn test_dimension_validation() {
let index = create_test_index();
index.reserve(10).unwrap();
let result = index.add(0, &[1.0, 0.0, 0.0]); assert!(result.is_err());
let result = index.add(0, &[1.0, 0.0, 0.0, 0.0, 0.0]); assert!(result.is_err());
let result = index.add(0, &[1.0, 0.0, 0.0, 0.0]);
assert!(result.is_ok());
}
#[test]
fn test_contains() {
let index = create_test_index();
index.reserve(10).unwrap();
assert!(!index.contains(42));
index.add(42, &[1.0, 0.0, 0.0, 0.0]).unwrap();
assert!(index.contains(42));
assert!(!index.contains(43));
}
#[test]
fn test_remove() {
let index = create_test_index();
index.reserve(10).unwrap();
index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
assert!(index.contains(0));
let removed = index.remove(0).unwrap();
assert_eq!(removed, 1);
}
/// Filtered search only returns keys accepted by the predicate (even keys here).
#[test]
fn test_filtered_search() {
    let index = create_test_index();
    index.reserve(20).unwrap();
    for i in 0..10 {
        index.add(i, &[i as f32 * 0.1, 0.0, 0.0, 0.0]).unwrap();
    }
    let results = index
        .search_filtered(&[0.5, 0.0, 0.0, 0.0], 10, |key| key % 2 == 0)
        .unwrap();
    // Five even keys exist. An empty result would make the loop below pass
    // vacuously and hide a broken filter path.
    assert!(!results.is_empty());
    for (key, _) in results {
        assert_eq!(key % 2, 0);
    }
}
/// Short metric aliases parse to their variants; unknown names are rejected.
#[test]
fn test_metric_parsing() {
    let cases = [
        ("ip", Metric::InnerProduct),
        ("cosine", Metric::Cosine),
        ("l2", Metric::L2Squared),
    ];
    for &(text, expected) in cases.iter() {
        assert_eq!(text.parse::<Metric>().unwrap(), expected);
    }
    assert!("invalid".parse::<Metric>().is_err());
}
/// Normalizing the (3, 4) vector yields a result with unit L2 norm.
#[test]
fn test_normalize_vector() {
    let input = vec![3.0, 4.0];
    let unit = normalize_vector(&input).expect("non-zero vector should normalize");
    let length = unit.iter().map(|c| c * c).sum::<f32>().sqrt();
    assert!((length - 1.0).abs() < 1e-6);
}
/// Zero vectors and vectors whose norm is below `f32::EPSILON` cannot be
/// normalized and yield `None`.
#[test]
fn test_normalize_vector_zero() {
    let degenerate = [vec![0.0, 0.0, 0.0], vec![f32::EPSILON / 2.0, 0.0]];
    for candidate in &degenerate {
        assert!(normalize_vector(candidate).is_none());
    }
}
/// A normalizable vector comes back with unit norm; a zero vector is
/// returned unchanged.
#[test]
fn test_normalize_vector_or_original() {
    let scaled = normalize_vector_or_original(vec![3.0, 4.0]);
    let length = scaled.iter().map(|c| c * c).sum::<f32>().sqrt();
    assert!((length - 1.0).abs() < 1e-6);

    let zeros = vec![0.0, 0.0, 0.0];
    assert_eq!(normalize_vector_or_original(zeros.clone()), zeros);
}
/// `is_valid_vector` accepts slices of finite values (including the empty
/// slice and extreme finite values) and rejects any NaN or infinity.
#[test]
fn test_is_valid_vector() {
    let finite: [&[f32]; 4] = [
        &[1.0, 2.0, 3.0],
        &[0.0, 0.0, 0.0],
        &[-1.0, f32::MAX, f32::MIN],
        &[],
    ];
    for &candidate in finite.iter() {
        assert!(is_valid_vector(candidate));
    }
    let non_finite: [&[f32]; 4] = [
        &[1.0, f32::NAN, 3.0],
        &[f32::NAN],
        &[f32::INFINITY, 2.0, 3.0],
        &[1.0, f32::NEG_INFINITY],
    ];
    for &candidate in non_finite.iter() {
        assert!(!is_valid_vector(candidate));
    }
}
/// `is_normalized` accepts a unit vector (0.6, 0.8) and rejects the
/// non-unit vector (3, 4) at a 1e-6 tolerance.
#[test]
fn test_is_normalized() {
    let normalized = vec![0.6, 0.8];
    let not_normalized = vec![3.0, 4.0];
    assert!(is_normalized(&normalized, 1e-6));
    assert!(!is_normalized(&not_normalized, 1e-6));
}
/// Inner-product distances map to `1 - d`, clamped into `[0, 1]`.
#[test]
fn test_distances_to_scores_inner_product() {
    let distances = vec![0.0, 0.5, 1.0, 1.5];
    let scores = distances_to_scores_for_metric(&distances, Metric::InnerProduct);
    let expected = [1.0, 0.5, 0.0, 0.0];
    for (idx, want) in expected.iter().enumerate() {
        assert!((scores[idx] - want).abs() < 1e-6);
    }
}
/// Cosine distances map to `1 - d` within `[0, 1]`.
#[test]
fn test_distances_to_scores_cosine() {
    let distances = vec![0.0, 0.25, 0.75, 1.0];
    let scores = distances_to_scores_for_metric(&distances, Metric::Cosine);
    let expected = [1.0, 0.75, 0.25, 0.0];
    for (idx, want) in expected.iter().enumerate() {
        assert!((scores[idx] - want).abs() < 1e-6);
    }
}
/// Squared-L2 distances map to `1 / (1 + d)`.
#[test]
fn test_distances_to_scores_l2_squared() {
    let distances = vec![0.0, 1.0, 4.0, 9.0];
    let scores = distances_to_scores_for_metric(&distances, Metric::L2Squared);
    let expected = [1.0, 0.5, 0.2, 0.1];
    for (idx, want) in expected.iter().enumerate() {
        assert!((scores[idx] - want).abs() < 1e-6);
    }
}
/// Very large squared-L2 distances still produce small but strictly positive scores.
#[test]
fn test_distances_to_scores_l2_large_distances() {
    let distances = vec![99.0, 999.0, 9999.0];
    let scores = distances_to_scores_for_metric(&distances, Metric::L2Squared);
    // Upper bound for each score: about twice the exact 1 / (1 + d) value.
    let ceilings = [0.02, 0.002, 0.0002];
    for (idx, &ceiling) in ceilings.iter().enumerate() {
        let score = scores[idx];
        assert!(score > 0.0 && score < ceiling);
    }
}
/// The metric-less `distances_to_scores` yields the same values as the
/// inner-product mapping for these inputs.
#[test]
fn test_distances_to_scores_backward_compat() {
    let distances = vec![0.0, 0.5, 1.0, 1.5];
    let scores = distances_to_scores(&distances);
    let expected = [1.0, 0.5, 0.0, 0.0];
    for (idx, want) in expected.iter().enumerate() {
        assert!((scores[idx] - want).abs() < 1e-6);
    }
}
/// Longer inputs must match the scalar reference formulas element by element.
/// The lengths are 16 and 11; the odd length presumably leaves a remainder
/// after the vectorized chunks.
#[test]
fn test_distances_to_scores_simd_path() {
    let distances_ip: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
    let scores_ip = distances_to_scores_for_metric(&distances_ip, Metric::InnerProduct);
    assert_eq!(scores_ip.len(), 16);
    for (i, (&score, &d)) in scores_ip.iter().zip(&distances_ip).enumerate() {
        let expected = (1.0 - d).clamp(0.0, 1.0);
        assert!(
            (score - expected).abs() < 1e-6,
            "IP mismatch at {i}: got {score}, expected {expected}"
        );
    }

    let distances_l2: Vec<f32> = (0..16).map(|i| i as f32).collect();
    let scores_l2 = distances_to_scores_for_metric(&distances_l2, Metric::L2Squared);
    assert_eq!(scores_l2.len(), 16);
    for (i, (&score, &d)) in scores_l2.iter().zip(&distances_l2).enumerate() {
        let expected = 1.0 / (1.0 + d.max(0.0));
        assert!(
            (score - expected).abs() < 1e-6,
            "L2 mismatch at {i}: got {score}, expected {expected}"
        );
    }

    let distances_odd: Vec<f32> = (0..11).map(|i| i as f32 * 0.05).collect();
    let scores_odd = distances_to_scores_for_metric(&distances_odd, Metric::Cosine);
    assert_eq!(scores_odd.len(), 11);
    for (i, (&score, &d)) in scores_odd.iter().zip(&distances_odd).enumerate() {
        let expected = (1.0 - d).clamp(0.0, 1.0);
        assert!(
            (score - expected).abs() < 1e-6,
            "Odd mismatch at {i}: got {score}, expected {expected}"
        );
    }
}
/// `to_similarity_scores` applies the conversion for the index's own metric.
#[test]
fn test_vector_index_to_similarity_scores() {
    let ip_index = VectorIndex::new(4, Metric::InnerProduct).unwrap();
    let l2_index = VectorIndex::new(4, Metric::L2Squared).unwrap();
    let distances = vec![0.5, 4.0];

    let ip_scores = ip_index.to_similarity_scores(&distances);
    assert!((ip_scores[0] - 0.5).abs() < 1e-6);
    assert!((ip_scores[1] - 0.0).abs() < 1e-6);

    let l2_scores = l2_index.to_similarity_scores(&distances);
    assert!((l2_scores[0] - (1.0 / 1.5)).abs() < 1e-6);
    assert!((l2_scores[1] - 0.2).abs() < 1e-6);
}
/// `metric()` returns the metric the index was constructed with.
#[test]
fn test_metric_getter() {
    for &chosen in [Metric::Cosine, Metric::L2Squared].iter() {
        let index = VectorIndex::new(4, chosen).unwrap();
        assert_eq!(index.metric(), chosen);
    }
}
/// Every `IndexConfig` builder method stores its argument in the matching field.
#[test]
fn test_config_builder() {
    let cfg = IndexConfig::new(768)
        .with_metric(Metric::Cosine)
        .with_quantization(Quantization::F16)
        .with_connectivity(32)
        .with_expansion_add(128)
        .with_expansion_search(64);

    assert_eq!(cfg.dimensions, 768);
    assert_eq!(cfg.metric, Metric::Cosine);
    assert_eq!(cfg.quantization, Quantization::F16);
    assert_eq!(cfg.connectivity, 32);
    assert_eq!(cfg.expansion_add, 128);
    assert_eq!(cfg.expansion_search, 64);
}
/// The default `expansion_add` is positive, and a new value set at runtime
/// is read back unchanged.
#[test]
fn test_expansion_add_getter_setter() {
    let index = VectorIndex::with_config(IndexConfig::new(4)).unwrap();
    let initial = index.expansion_add();
    assert!(initial > 0, "Default expansion_add should be positive");
    index.set_expansion_add(256);
    assert_eq!(index.expansion_add(), 256);
}
/// Expansion parameters supplied through `IndexConfig` are applied to the
/// index when it is constructed.
#[test]
fn test_expansion_config_applied() {
    let config = IndexConfig::new(4)
        .with_expansion_add(512)
        .with_expansion_search(128);
    let index = VectorIndex::with_config(config).unwrap();
    // Compare exactly. A positivity check would also pass if the config were
    // silently ignored and the library defaults were used instead.
    assert_eq!(
        index.expansion_add(),
        512,
        "expansion_add should be set from config"
    );
    assert_eq!(
        index.expansion_search(),
        128,
        "expansion_search should be set from config"
    );
}
/// The default `expansion_search` is positive, and a new value set at runtime
/// is read back unchanged.
#[test]
fn test_expansion_search_getter_setter() {
    let index = create_test_index();
    let initial = index.expansion_search();
    assert!(initial > 0, "Default expansion_search should be positive");
    index.set_expansion_search(128);
    assert_eq!(index.expansion_search(), 128);
}
/// For 10k f32 vectors of 768 dims, the memory estimate falls strictly
/// between 50 MB and 100 MB.
#[test]
fn test_estimate_memory() {
    let estimate = IndexConfig::new(768)
        .with_quantization(Quantization::F32)
        .estimate_memory(10000);
    let (floor, ceiling) = (50_000_000, 100_000_000);
    assert!(estimate > floor);
    assert!(estimate < ceiling);
}
/// Searching an index with no entries succeeds and returns nothing.
#[test]
fn test_empty_search() {
    let hits = create_test_index()
        .search(&[1.0, 0.0, 0.0, 0.0], 10)
        .unwrap();
    assert!(hits.is_empty());
}
/// Asking for more neighbours than exist returns only the stored entries.
#[test]
fn test_search_k_larger_than_index() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    let oversized_k = 100;
    let hits = index.search(&[1.0, 0.0, 0.0, 0.0], oversized_k).unwrap();
    assert_eq!(hits.len(), 2);
}
/// The `Debug` output names the type and shows its dimensionality.
#[test]
fn test_debug_format() {
    let index = create_test_index();
    let rendered = format!("{:?}", index);
    for &needle in ["VectorIndex", "dimensions: 4"].iter() {
        assert!(rendered.contains(needle));
    }
}
/// `get` returns the stored vector for a known key and `None` for an unknown one.
#[test]
fn test_get_vector() {
    let index = VectorIndex::with_config(IndexConfig {
        dimensions: 4,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();
    let stored = vec![1.0, 0.0, 0.0, 0.0];
    index.add(42, &stored).unwrap();

    let fetched = index.get(42);
    assert!(fetched.is_some());
    assert_eq!(fetched.unwrap(), stored);
    assert!(index.get(999).is_none());
}
/// `get_batch` returns one `Option` per requested key, in request order,
/// with `None` for keys that were never added.
#[test]
fn test_get_batch_vectors() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(3, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    index.add(5, &[0.0, 0.0, 1.0, 0.0]).unwrap();

    let fetched = index.get_batch(&[1, 2, 3, 4, 5]);
    // Odd keys were inserted above; even keys were not.
    let expected_present = [true, false, true, false, true];
    for (slot, &present) in expected_present.iter().enumerate() {
        if present {
            assert!(fetched[slot].is_some());
        } else {
            assert!(fetched[slot].is_none());
        }
    }
    assert_eq!(fetched[0].as_ref().unwrap(), &vec![1.0, 0.0, 0.0, 0.0]);
    assert_eq!(fetched[2].as_ref().unwrap(), &vec![0.0, 1.0, 0.0, 0.0]);
    assert_eq!(fetched[4].as_ref().unwrap(), &vec![0.0, 0.0, 1.0, 0.0]);
}
/// In a non-multi index, `count_key` is 0 before insertion and 1 after.
#[test]
fn test_count_key_standard_mode() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    let tracked = 42;
    assert_eq!(index.count_key(tracked), 0);
    index.add(tracked, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    assert_eq!(index.count_key(tracked), 1);
    assert_eq!(index.count_key(999), 0);
}
/// A multi index stores several vectors under one key, and `count_key`
/// reports all of them.
#[test]
fn test_count_key_multi_index() {
    let config = IndexConfig::new(4)
        .with_metric(Metric::InnerProduct)
        .with_multi(true);
    let index = VectorIndex::with_config(config).unwrap();
    index.reserve(10).unwrap();

    let shared_key = 42u64;
    let rows: [[f32; 4]; 3] = [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
    ];
    for row in &rows {
        index.add(shared_key, row).unwrap();
    }
    assert_eq!(index.count_key(shared_key), 3);
    assert_eq!(index.len(), 3);
    assert_eq!(index.count_key(999), 0);
}
/// `with_multi` sets the `multi` flag to whatever value it is given.
#[test]
fn test_with_multi_builder() {
    let enabled = IndexConfig::new(768)
        .with_metric(Metric::InnerProduct)
        .with_multi(true);
    assert!(enabled.multi);
    let disabled = IndexConfig::new(768).with_multi(false);
    assert!(!disabled.multi);
}
/// `rename` moves a vector from one key to an unused key and keeps its contents.
#[test]
fn test_rename() {
    let index = VectorIndex::with_config(IndexConfig {
        dimensions: 4,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();
    let (old_key, new_key) = (42, 100);
    let payload = vec![1.0, 0.0, 0.0, 0.0];
    index.add(old_key, &payload).unwrap();
    assert!(index.contains(old_key));
    assert!(!index.contains(new_key));

    assert_eq!(index.rename(old_key, new_key).unwrap(), 1);
    assert!(!index.contains(old_key));
    assert!(index.contains(new_key));
    assert_eq!(index.get(new_key).unwrap(), payload);
}
/// `rename` refuses to overwrite an existing target key and leaves both keys present.
#[test]
fn test_rename_fails_if_target_exists() {
    let index = VectorIndex::with_config(IndexConfig {
        dimensions: 4,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();
    let (source, target) = (42, 100);
    index.add(source, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(target, &[0.0, 1.0, 0.0, 0.0]).unwrap();

    assert!(index.rename(source, target).is_err());
    assert!(index.contains(source));
    assert!(index.contains(target));
}
/// `rename_overwrite` replaces the target key's vector with the source's vector.
#[test]
fn test_rename_overwrite() {
    let index = VectorIndex::with_config(IndexConfig {
        dimensions: 4,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();
    let winner = vec![1.0, 0.0, 0.0, 0.0];
    let loser = vec![0.0, 1.0, 0.0, 0.0];
    index.add(42, &winner).unwrap();
    index.add(100, &loser).unwrap();

    assert_eq!(index.rename_overwrite(42, 100).unwrap(), 1);
    assert!(!index.contains(42));
    assert!(index.contains(100));
    assert_eq!(index.get(100).unwrap(), winner);
}
/// Renaming a key that was never added succeeds and reports zero renamed entries.
#[test]
fn test_rename_nonexistent_key() {
    let index = VectorIndex::with_config(IndexConfig {
        dimensions: 4,
        ..Default::default()
    })
    .unwrap();
    index.reserve(10).unwrap();
    assert_eq!(index.rename(42, 100).unwrap(), 0);
}
/// `restore` recovers a saved index and reads its dimensionality from the
/// file rather than from the caller.
#[test]
fn test_restore_reads_dimensions_from_file() {
    // A private temporary directory avoids path collisions with concurrent
    // test processes in the shared system temp dir. It is also deleted on
    // drop, even if an assertion below panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_restore_index.usearch");
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    index.save(&path).unwrap();
    let restored = VectorIndex::restore(&path).unwrap();
    assert_eq!(restored.dimensions(), 4);
    assert_eq!(restored.len(), 2);
    assert!(restored.contains(0));
    assert!(restored.contains(1));
}
/// `load_validated` accepts a file whose dimensionality matches the supplied
/// config, and the loaded index reports the config's metric.
#[test]
fn test_load_validated_success() {
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_load_validated_success.usearch");
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(42, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.save(&path).unwrap();
    let config = IndexConfig::new(4).with_metric(Metric::Cosine);
    let loaded = VectorIndex::load_validated(&path, config).unwrap();
    assert_eq!(loaded.dimensions(), 4);
    assert_eq!(loaded.config().metric, Metric::Cosine);
    assert!(loaded.contains(42));
}
/// `load_validated` rejects a file whose dimensionality differs from the
/// config and names both dimensions in the error.
#[test]
fn test_load_validated_dimension_mismatch() {
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_load_validated_mismatch.usearch");
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.save(&path).unwrap();
    let wrong_config = IndexConfig::new(768);
    let result = VectorIndex::load_validated(&path, wrong_config);
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    assert!(err_msg.contains("Dimension mismatch"));
    assert!(err_msg.contains("expected 768"));
    assert!(err_msg.contains("loaded index has 4"));
}
/// `view_restore` opens a saved index as a view and reads its
/// dimensionality from the file.
#[test]
fn test_view_restore_reads_dimensions() {
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop (after `viewed`, which is declared later).
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_view_restore.usearch");
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.save(&path).unwrap();
    let viewed = VectorIndex::view_restore(&path).unwrap();
    assert_eq!(viewed.dimensions(), 4);
    assert!(viewed.contains(0));
}
/// `view_validated` rejects a file whose dimensionality differs from the config.
#[test]
fn test_view_validated_dimension_mismatch() {
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_view_validated_mismatch.usearch");
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.save(&path).unwrap();
    let wrong_config = IndexConfig::new(1024);
    let result = VectorIndex::view_validated(&path, wrong_config);
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    assert!(err_msg.contains("Dimension mismatch"));
}
/// Restoring from a missing path fails with a "not found" error.
#[test]
fn test_restore_nonexistent_file() {
    let outcome = VectorIndex::restore("/nonexistent/path/index.usearch");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("not found"));
}
/// Viewing a missing path fails with a "not found" error.
#[test]
fn test_view_restore_nonexistent_file() {
    let outcome = VectorIndex::view_restore("/nonexistent/path/index.usearch");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("not found"));
}
/// `save` creates a file at the target path that restores with every saved key.
#[test]
fn test_atomic_save_creates_valid_file() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    index.add(2, &[0.0, 0.0, 1.0, 0.0]).unwrap();
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_atomic_save.usearch");
    index.save(&path).unwrap();
    assert!(path.exists());
    let restored = VectorIndex::restore(&path).unwrap();
    assert_eq!(restored.len(), 3);
    assert!(restored.contains(0));
    assert!(restored.contains(1));
    assert!(restored.contains(2));
}
/// Saving over an existing file replaces its contents entirely: keys from
/// the first save are gone after the second.
#[test]
fn test_atomic_save_overwrites_existing_file() {
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_atomic_overwrite.usearch");
    let index1 = create_test_index();
    index1.reserve(5).unwrap();
    index1.add(100, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index1.save(&path).unwrap();
    let index2 = create_test_index();
    index2.reserve(5).unwrap();
    index2.add(200, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    index2.add(201, &[0.0, 0.0, 1.0, 0.0]).unwrap();
    index2.save(&path).unwrap();
    let restored = VectorIndex::restore(&path).unwrap();
    assert_eq!(restored.len(), 2);
    assert!(!restored.contains(100));
    assert!(restored.contains(200));
    assert!(restored.contains(201));
}
/// `save` into a directory that does not exist fails with a descriptive error.
#[test]
fn test_save_fails_on_nonexistent_parent() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let orphan = std::path::Path::new("/nonexistent_dir_abc123/index.usearch");
    let outcome = index.save(orphan);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Parent directory does not exist"));
}
/// `save_unsafe` writes a file that `restore` can read back.
#[test]
fn test_save_unsafe_works() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(42, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_save_unsafe.usearch");
    index.save_unsafe(&path).unwrap();
    let restored = VectorIndex::restore(&path).unwrap();
    assert!(restored.contains(42));
}
/// Save-then-restore round trip through `save`.
///
/// NOTE(review): despite its name, this test writes into a temporary
/// directory, not the process's current working directory. A relative path
/// with an empty parent component is not exercised here.
#[test]
fn test_save_to_current_directory() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_save_cwd.usearch");
    index.save(&path).unwrap();
    let restored = VectorIndex::restore(&path).unwrap();
    assert!(restored.contains(0));
}
/// `to_bytes` followed by `from_bytes` preserves the number of entries and their keys.
#[test]
fn test_serialization_roundtrip() {
    let cfg = IndexConfig {
        dimensions: 4,
        ..Default::default()
    };
    let original = VectorIndex::with_config(cfg.clone()).unwrap();
    original.reserve(10).unwrap();
    original.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    original.add(2, &[0.0, 1.0, 0.0, 0.0]).unwrap();

    let blob = original.to_bytes().unwrap();
    let copy = VectorIndex::from_bytes(&blob, cfg).unwrap();
    assert_eq!(copy.len(), 2);
    assert!(copy.contains(1));
    assert!(copy.contains(2));
}
/// A deserialized index still answers searches correctly and returns the
/// original vector contents.
#[test]
fn test_serialization_preserves_vectors() {
    let config = IndexConfig::new(4).with_metric(Metric::InnerProduct);
    let index = VectorIndex::with_config(config.clone()).unwrap();
    index.reserve(10).unwrap();
    let entries: [(u64, [f32; 4]); 3] = [
        (10, [1.0, 0.0, 0.0, 0.0]),
        (20, [0.0, 1.0, 0.0, 0.0]),
        (30, [0.0, 0.0, 1.0, 0.0]),
    ];
    for (key, vector) in &entries {
        index.add(*key, vector).unwrap();
    }

    let blob = index.to_bytes().unwrap();
    let loaded = VectorIndex::from_bytes(&blob, config).unwrap();
    let probe = entries[0].1;
    let hits = loaded.search(&probe, 3).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].0, 10);
    assert_eq!(loaded.get(10).unwrap(), probe.to_vec());
}
/// `from_bytes` rejects a buffer whose dimensionality differs from the
/// config and mentions the expected size.
#[test]
fn test_serialization_dimension_mismatch() {
    let index = VectorIndex::with_config(IndexConfig::new(4)).unwrap();
    index.reserve(5).unwrap();
    index.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let blob = index.to_bytes().unwrap();

    let outcome = VectorIndex::from_bytes(&blob, IndexConfig::new(768));
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Dimension mismatch"));
    assert!(message.contains("768"));
}
/// `from_bytes_unchecked` rebuilds an index without a config, taking the
/// dimensionality from the bytes.
#[test]
fn test_from_bytes_unchecked() {
    let index = VectorIndex::with_config(IndexConfig::new(4)).unwrap();
    index.reserve(5).unwrap();
    index.add(42, &[1.0, 0.5, 0.0, 0.0]).unwrap();
    let blob = index.to_bytes().unwrap();

    let rebuilt = VectorIndex::from_bytes_unchecked(&blob).unwrap();
    assert_eq!(rebuilt.dimensions(), 4);
    assert!(rebuilt.contains(42));
    assert_eq!(rebuilt.len(), 1);
}
/// `serialized_size` predicts the exact length of the `to_bytes` output.
#[test]
fn test_serialized_size_consistency() {
    let index = VectorIndex::with_config(IndexConfig::new(4)).unwrap();
    index.reserve(10).unwrap();
    index.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(2, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    let predicted = index.serialized_size();
    let actual = index.to_bytes().unwrap().len();
    assert_eq!(actual, predicted);
}
/// An empty index survives a byte round trip and keeps its dimensionality.
#[test]
fn test_serialization_empty_index() {
    let cfg = IndexConfig::new(768);
    let blob = VectorIndex::with_config(cfg.clone())
        .unwrap()
        .to_bytes()
        .unwrap();
    let copy = VectorIndex::from_bytes(&blob, cfg).unwrap();
    assert_eq!(copy.len(), 0);
    assert!(copy.is_empty());
    assert_eq!(copy.dimensions(), 768);
}
/// An F16-quantized cosine index round-trips and keeps its metric and
/// quantization settings.
#[test]
fn test_serialization_with_quantization() {
    let cfg = IndexConfig::new(4)
        .with_metric(Metric::Cosine)
        .with_quantization(Quantization::F16);
    let index = VectorIndex::with_config(cfg.clone()).unwrap();
    index.reserve(5).unwrap();
    index.add(1, &[0.5, 0.5, 0.5, 0.5]).unwrap();
    let blob = index.to_bytes().unwrap();

    let copy = VectorIndex::from_bytes(&blob, cfg).unwrap();
    assert_eq!(copy.dimensions(), 4);
    assert!(copy.contains(1));
    assert_eq!(copy.config().metric, Metric::Cosine);
    assert_eq!(copy.config().quantization, Quantization::F16);
}
/// A 100-entry, 128-dimensional index keeps every key through a byte round trip.
#[test]
fn test_serialization_large_index() {
    let cfg = IndexConfig::new(128);
    let index = VectorIndex::with_config(cfg.clone()).unwrap();
    index.reserve(1000).unwrap();
    for key in 0..100u64 {
        let mut one_hot = vec![0.0f32; 128];
        one_hot[key as usize % 128] = 1.0;
        index.add(key, &one_hot).unwrap();
    }
    let blob = index.to_bytes().unwrap();
    let copy = VectorIndex::from_bytes(&blob, cfg).unwrap();
    assert_eq!(copy.len(), 100);
    assert!((0..100u64).all(|key| copy.contains(key)));
}
/// A small index reports enough free space for a target in the system temp dir.
#[test]
fn test_check_disk_space_valid_path() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let target = std::env::temp_dir().join("test_disk_space_check.usearch");
    assert!(index.check_disk_space(&target).unwrap());
}
/// The disk-space check errors out when the target's parent directory is missing.
#[test]
fn test_check_disk_space_nonexistent_parent() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let orphan = std::path::Path::new("/nonexistent_dir_xyz123/index.usearch");
    let outcome = index.check_disk_space(orphan);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("parent directory does not exist"));
}
/// `disk_space_info` reports positive free space, a requirement of at least
/// 1 MiB, and sufficiency for a tiny index.
#[test]
fn test_disk_space_info() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    let target = std::env::temp_dir().join("test_disk_space_info.usearch");
    let (free, needed, sufficient) = index.disk_space_info(&target).unwrap();
    assert!(free > 0);
    assert!(needed >= 1024 * 1024);
    assert!(sufficient);
}
/// `save_checked` writes the index and returns a `SaveInfo` describing the
/// write; the resulting file restores with all entries.
#[test]
fn test_save_checked_success() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    index.add(1, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    // A private temporary directory avoids collisions in the shared temp dir
    // and is cleaned up on drop, even if an assertion panics.
    let temp_dir = tempfile::tempdir().unwrap();
    let path = temp_dir.path().join("test_save_checked.usearch");
    let info = index.save_checked(&path).unwrap();
    assert_eq!(info.path, path);
    assert!(info.size_bytes > 0);
    assert!(info.available_before > 0);
    assert!(info.space_remaining > 0);
    assert!(path.exists());
    let restored = VectorIndex::restore(&path).unwrap();
    assert_eq!(restored.len(), 2);
}
/// `save_checked` into a missing directory fails with a descriptive error.
#[test]
fn test_save_checked_nonexistent_parent() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let orphan = std::path::Path::new("/nonexistent_dir_abc999/index.usearch");
    let outcome = index.save_checked(orphan);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Parent directory does not exist"));
}
/// `Display` for `SaveInfo` includes the human-readable size, the path and a
/// throughput figure in MB/s.
#[test]
fn test_save_info_display() {
    let info = SaveInfo {
        path: PathBuf::from("/test/index.usearch"),
        size_bytes: 5 * 1024 * 1024,
        elapsed: std::time::Duration::from_secs(1),
        available_before: 100 * 1024 * 1024 * 1024,
        space_remaining: 99 * 1024 * 1024 * 1024,
    };
    let rendered = info.to_string();
    for &needle in ["5.00 MB", "/test/index.usearch", "MB/s"].iter() {
        assert!(rendered.contains(needle));
    }
}
/// `human_size` picks bytes/KB/MB/GB units with two decimals above the byte range.
#[test]
fn test_save_info_human_size() {
    let info_of_size = |size_bytes| SaveInfo {
        path: PathBuf::from("test"),
        size_bytes,
        elapsed: std::time::Duration::from_secs(1),
        available_before: 0,
        space_remaining: 0,
    };
    let cases = [
        (500, "500 bytes"),
        (2048, "2.00 KB"),
        (5 * 1024 * 1024, "5.00 MB"),
        (2 * 1024 * 1024 * 1024, "2.00 GB"),
    ];
    for &(bytes, label) in cases.iter() {
        assert_eq!(info_of_size(bytes).human_size(), label);
    }
}
/// Throughput is size divided by elapsed time: 10 MiB in 2 s gives 5 MiB/s.
#[test]
fn test_save_info_bytes_per_second() {
    let info = SaveInfo {
        path: PathBuf::from("test"),
        size_bytes: 10 * 1024 * 1024,
        elapsed: std::time::Duration::from_secs(2),
        available_before: 0,
        space_remaining: 0,
    };
    let expected_bps = (10.0 * 1024.0 * 1024.0) / 2.0;
    assert!((info.bytes_per_second() - expected_bps).abs() < 1.0);
    assert!((info.mb_per_second() - 5.0).abs() < 0.01);
}
/// A zero elapsed time yields infinite throughput rather than a panic.
#[test]
fn test_save_info_zero_elapsed() {
    let instant_save = SaveInfo {
        path: PathBuf::from("test"),
        size_bytes: 1000,
        elapsed: std::time::Duration::ZERO,
        available_before: 0,
        space_remaining: 0,
    };
    assert!(instant_save.bytes_per_second().is_infinite());
    assert!(instant_save.mb_per_second().is_infinite());
}
/// The required space reported by `disk_space_info` covers the serialized
/// size plus at least a 1 MiB margin.
#[test]
fn test_check_disk_space_includes_safety_margin() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    index.add(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let payload = index.serialized_size();
    let target = std::env::temp_dir().join("test_safety_margin.usearch");
    let (_, required, _) = index.disk_space_info(&target).unwrap();
    let margin = 1024 * 1024u64;
    assert!(
        required >= payload as u64 + margin,
        "Required {} should be >= serialized_size {} + margin {}",
        required,
        payload,
        margin
    );
}
/// A 500-vector batch (above the 100-item `PARALLEL_BATCH_THRESHOLD`,
/// presumably taking the parallel path) is inserted completely.
#[test]
fn test_add_batch_parallel_large() {
    let dims = 128;
    let count = 500;
    let index = VectorIndex::new(dims, Metric::InnerProduct).unwrap();
    index.reserve(count).unwrap();
    let keys: Vec<u64> = (0..count as u64).collect();
    let vectors: Vec<Vec<f32>> = (0..count)
        .map(|row| {
            (0..dims)
                .map(|col| if col == row % dims { 1.0 } else { 0.0 })
                .collect()
        })
        .collect();
    index.add_batch(&keys, &vectors).unwrap();
    assert_eq!(index.len(), count);
    for key in keys {
        assert!(index.contains(key), "Missing key {}", key);
    }
}
/// A 200-row flat batch (above the parallel threshold) is inserted completely.
#[test]
fn test_add_batch_flat_parallel_large() {
    let dims = 64;
    let count = 200;
    let index = VectorIndex::new(dims, Metric::InnerProduct).unwrap();
    index.reserve(count).unwrap();
    let keys: Vec<u64> = (0..count as u64).collect();
    let mut vectors_flat = vec![0.0f32; count * dims];
    for (row, chunk) in vectors_flat.chunks_mut(dims).enumerate() {
        chunk[row % dims] = 1.0;
    }
    index.add_batch_flat(&keys, &vectors_flat).unwrap();
    assert_eq!(index.len(), count);
    for key in keys {
        assert!(index.contains(key), "Missing key {}", key);
    }
}
/// `add_batch_sequential` inserts all five supplied vectors.
#[test]
fn test_add_batch_sequential_method() {
    let index = create_test_index();
    index.reserve(5).unwrap();
    let ids = vec![0, 1, 2, 3, 4];
    let rows = vec![
        vec![1.0, 0.0, 0.0, 0.0],
        vec![0.0, 1.0, 0.0, 0.0],
        vec![0.0, 0.0, 1.0, 0.0],
        vec![0.0, 0.0, 0.0, 1.0],
        vec![0.5, 0.5, 0.0, 0.0],
    ];
    index.add_batch_sequential(&ids, &rows).unwrap();
    assert_eq!(index.len(), 5);
}
/// An empty `add_batch` call succeeds and leaves the index empty.
#[test]
fn test_add_batch_empty() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    let no_keys: Vec<u64> = Vec::new();
    let no_vectors: Vec<Vec<f32>> = Vec::new();
    index.add_batch(&no_keys, &no_vectors).unwrap();
    assert_eq!(index.len(), 0);
}
/// An empty `add_batch_flat` call succeeds and leaves the index empty.
#[test]
fn test_add_batch_flat_empty() {
    let index = create_test_index();
    index.reserve(10).unwrap();
    let no_keys: Vec<u64> = Vec::new();
    let no_values: Vec<f32> = Vec::new();
    index.add_batch_flat(&no_keys, &no_values).unwrap();
    assert_eq!(index.len(), 0);
}
/// A 50-vector batch (below the 100-item parallel threshold) is inserted completely.
#[test]
fn test_add_batch_below_parallel_threshold() {
    let index = create_test_index();
    let count = 50;
    index.reserve(count).unwrap();
    let keys: Vec<u64> = (0..count as u64).collect();
    let vectors: Vec<Vec<f32>> = vec![vec![0.25, 0.25, 0.25, 0.25]; count];
    index.add_batch(&keys, &vectors).unwrap();
    assert_eq!(index.len(), count);
}
/// A view opened with `view_safe` exposes the file path, dimensionality,
/// entry count and key membership of the saved index.
#[test]
fn test_view_safe_basic() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_safe.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.add(2, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_safe(&path).unwrap();
    assert!(view.is_valid());
    assert_eq!(view.path(), path);
    assert_eq!(view.dimensions(), 4);
    assert_eq!(view.len(), 2);
    assert!(!view.is_empty());
    assert!(view.contains(1));
    assert!(view.contains(2));
    assert!(!view.contains(3));
}
/// Searching through a safe view ranks the exact match first.
#[test]
fn test_view_safe_search() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_safe_search.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.add(2, &[0.0, 1.0, 0.0, 0.0]).unwrap();
    source.add(3, &[0.0, 0.0, 1.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_safe(&path).unwrap();
    let hits = view.search(&[1.0, 0.0, 0.0, 0.0], 3).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].0, 1);
}
/// `view_validated_safe` opens a file whose dimensionality matches the config.
#[test]
fn test_view_validated_safe_matching_dimensions() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_validated_safe.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_validated_safe(&path, IndexConfig::new(4)).unwrap();
    assert!(view.is_valid());
    assert_eq!(view.dimensions(), 4);
    assert!(view.contains(1));
}
/// `view_validated_safe` rejects a file whose dimensionality differs from the config.
#[test]
fn test_view_validated_safe_dimension_mismatch() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch
        .path()
        .join("test_view_validated_safe_mismatch.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let outcome = VectorIndex::view_validated_safe(&path, IndexConfig::new(768));
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Dimension mismatch"));
}
/// A config with 0 dimensions skips the dimension check, and the view takes
/// the file's dimensionality.
#[test]
fn test_view_validated_safe_zero_dimension_skips_validation() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_validated_safe_zero.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_validated_safe(&path, IndexConfig::new(0)).unwrap();
    assert!(view.is_valid());
    assert_eq!(view.dimensions(), 4);
}
/// `view_safe` on a missing path returns an error.
#[test]
fn test_view_safe_nonexistent_file() {
    let missing = "/nonexistent/path/index.usearch";
    assert!(VectorIndex::view_safe(missing).is_err());
}
/// `into_inner` unwraps the safe view into the wrapped index, which keeps
/// the saved data.
#[test]
fn test_view_safe_into_inner() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_safe_into_inner.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let unwrapped = VectorIndex::view_safe(&path).unwrap().into_inner();
    assert_eq!(unwrapped.dimensions(), 4);
    assert!(unwrapped.contains(1));
}
/// A safe view can be passed wherever `&VectorIndex` is expected, via deref coercion.
#[test]
fn test_view_safe_deref() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_safe_deref.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_safe(&path).unwrap();
    // A plain function with a `&VectorIndex` parameter forces the coercion
    // under test.
    fn borrows_as_index(target: &VectorIndex) -> bool {
        target.contains(1)
    }
    assert!(borrows_as_index(&view));
}
/// On Unix, deleting the backing file makes `is_valid` report false, but
/// lookups and searches through the existing view keep working.
#[cfg(unix)]
#[test]
fn test_view_safe_keeps_file_open_on_delete() {
    let scratch = tempfile::tempdir().unwrap();
    let path = scratch.path().join("test_view_safe_delete.usearch");
    let source = create_test_index();
    source.reserve(5).unwrap();
    source.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    source.save(&path).unwrap();

    let view = VectorIndex::view_safe(&path).unwrap();
    assert!(view.is_valid());
    assert!(view.contains(1));

    std::fs::remove_file(&path).unwrap();
    assert!(!view.is_valid());
    assert!(view.contains(1));
    let hits = view.search(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].0, 1);
}
/// Serializes the tests below that touch the global query-embedding cache.
///
/// `clear_query_cache`, `query_cache_stats`, `query_in_cache` and
/// `get_cached_query_embedding{,_async}` take no cache handle, so they all operate on one
/// shared cache. libtest runs `#[test]` functions on parallel threads by default, and
/// these tests both clear that cache and assert exact sizes (0, 1, 5). Without mutual
/// exclusion, one test's `clear_query_cache()` or inserts can land between another test's
/// steps and make its assertions fail intermittently.
// NOTE(review): any other test in this file that reads or mutates the query cache should
// also hold this lock; that cannot be verified from this part of the file.
static QUERY_CACHE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Acquires [`QUERY_CACHE_TEST_LOCK`] for the duration of a test.
///
/// Poisoning is ignored on purpose. The lock guards no data, so a panic in one test (which
/// poisons it) must not cascade into spurious failures in every later cache test.
fn lock_query_cache() -> std::sync::MutexGuard<'static, ()> {
    QUERY_CACHE_TEST_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Miss -> compute -> hit -> clear lifecycle of the synchronous cache API.
#[test]
fn test_query_cache_basic_operations() {
    let _cache_guard = lock_query_cache();
    clear_query_cache().unwrap();
    let (size, capacity) = query_cache_stats().unwrap();
    assert_eq!(size, 0);
    assert_eq!(capacity, 100);
    assert!(!query_in_cache("test query").unwrap());

    let call_count = std::sync::atomic::AtomicUsize::new(0);
    let embedding = get_cached_query_embedding("test query", |_q| {
        call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Ok(vec![0.1, 0.2, 0.3, 0.4])
    })
    .unwrap();
    assert_eq!(embedding, vec![0.1, 0.2, 0.3, 0.4]);
    assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
    assert!(query_in_cache("test query").unwrap());
    let (size, _) = query_cache_stats().unwrap();
    assert_eq!(size, 1);

    // A second lookup must be served from the cache: the compute closure (which would
    // return a different vector) is not invoked.
    let embedding2 = get_cached_query_embedding("test query", |_q| {
        call_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Ok(vec![0.5, 0.6, 0.7, 0.8])
    })
    .unwrap();
    assert_eq!(embedding2, vec![0.1, 0.2, 0.3, 0.4]);
    assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);

    clear_query_cache().unwrap();
    let (size, _) = query_cache_stats().unwrap();
    assert_eq!(size, 0);
    assert!(!query_in_cache("test query").unwrap());
}

/// Distinct query strings get distinct cache entries.
#[test]
fn test_query_cache_different_queries() {
    let _cache_guard = lock_query_cache();
    clear_query_cache().unwrap();
    for i in 0..5 {
        let query = format!("query {}", i);
        let embedding = get_cached_query_embedding(&query, |_q| Ok(vec![i as f32; 4])).unwrap();
        assert_eq!(embedding, vec![i as f32; 4]);
    }
    let (size, _) = query_cache_stats().unwrap();
    assert_eq!(size, 5);
    for i in 0..5 {
        let query = format!("query {}", i);
        assert!(query_in_cache(&query).unwrap());
    }
    clear_query_cache().unwrap();
}

/// A failing compute closure propagates its error and leaves nothing in the cache.
#[test]
fn test_query_cache_compute_error_not_cached() {
    let _cache_guard = lock_query_cache();
    clear_query_cache().unwrap();
    let result = get_cached_query_embedding("error query", |_q| {
        Err(anyhow::anyhow!("Simulated TEI error"))
    });
    assert!(result.is_err());
    assert!(!query_in_cache("error query").unwrap());
    let (size, _) = query_cache_stats().unwrap();
    assert_eq!(size, 0);
    clear_query_cache().unwrap();
}

/// The async API computes once and then serves repeat queries from the cache.
// Holding a std `MutexGuard` across `.await` is deliberate and safe here. `#[tokio::test]`
// drives this future on its own current-thread runtime, so nothing on that runtime
// contends for the lock. The guard only excludes the other (thread-based) cache tests.
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn test_query_cache_async_basic() {
    let _cache_guard = lock_query_cache();
    clear_query_cache().unwrap();
    let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
    let call_count_clone = call_count.clone();
    let embedding = get_cached_query_embedding_async("async test", |_q| {
        let cc = call_count_clone.clone();
        async move {
            cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(vec![1.0, 2.0, 3.0])
        }
    })
    .await
    .unwrap();
    assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
    assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);

    // Second call must hit the cache; the closure's different vector must not surface.
    let call_count_clone2 = call_count.clone();
    let embedding2 = get_cached_query_embedding_async("async test", |_q| {
        let cc = call_count_clone2.clone();
        async move {
            cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(vec![9.0, 9.0, 9.0])
        }
    })
    .await
    .unwrap();
    assert_eq!(embedding2, vec![1.0, 2.0, 3.0]);
    assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
    clear_query_cache().unwrap();
}
}