use std::f32;
use crate::datasets::{dense_dataset::DenseDataset, sparse_dataset::SparseDataset};
use crate::hnsw::graph_index::GraphIndex;
use crate::hnsw_utils::config_hnsw::ConfigHnsw;
use crate::index_serializer::IndexSerializer;
use crate::plain_quantizer::PlainQuantizer;
use crate::pq::ProductQuantizer;
use crate::sparse_plain_quantizer::SparsePlainQuantizer;
use crate::{read_numpy_f32_flatten_2d, DArray1, Dataset, DenseDArray1, DistanceType};
use half::f16;
use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
use pyo3::prelude::*;
use rand::{rngs::StdRng, seq::IteratorRandom, SeedableRng};
use rayon;
/// Python-facing HNSW index over dense f32 vectors with a plain (uncompressed) quantizer.
#[pyclass]
pub struct DensePlainHNSW {
// Backing graph index. The `'static` lifetime comes from `Box::leak`-ing the
// dataset at build time (see `build_from_file` / `build_from_array`).
index: GraphIndex<'static, DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>,
}
#[pymethods]
impl DensePlainHNSW {
#[staticmethod]
#[pyo3(signature = (data_path, m=32, ef_construction=200, metric="ip".to_string()))]
pub fn build_from_file(
data_path: &str,
m: usize,
ef_construction: usize,
metric: String,
) -> PyResult<Self> {
let (data_vec, dim) = read_numpy_f32_flatten_2d(data_path.to_string());
let distance = match metric.as_str() {
"l2" => DistanceType::Euclidean,
"ip" => DistanceType::DotProduct,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid metric; choose 'l2' or 'ip'",
))
}
};
let quantizer = PlainQuantizer::<f32>::new(dim, distance);
let boxed_dataset = Box::new(DenseDataset::from_vec(data_vec, dim, quantizer.clone()));
let static_dataset: &'static DenseDataset<PlainQuantizer<f32>> = Box::leak(boxed_dataset);
let config = ConfigHnsw::new()
.num_neighbors(m)
.ef_construction(ef_construction)
.build();
let num_threads = rayon::current_num_threads();
let start = std::time::Instant::now();
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
let elapsed = start.elapsed();
println!("Time to build index: {:?}", elapsed);
Ok(DensePlainHNSW { index })
}
#[staticmethod]
#[pyo3(signature = (data_vec, dim, m=32, ef_construction=200, metric="ip".to_string()))]
pub fn build_from_array(
data_vec: PyReadonlyArray1<f32>,
dim: usize,
m: usize,
ef_construction: usize,
metric: String,
) -> PyResult<Self> {
let data_vec = data_vec.as_slice()?.to_vec();
let distance = match metric.as_str() {
"l2" => DistanceType::Euclidean,
"ip" => DistanceType::DotProduct,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid metric; choose 'l2' or 'ip'",
))
}
};
let quantizer = PlainQuantizer::<f32>::new(dim, distance);
let boxed_dataset = Box::new(DenseDataset::from_vec(data_vec, dim, quantizer.clone()));
let static_dataset: &'static DenseDataset<PlainQuantizer<f32>> = Box::leak(boxed_dataset);
let config = ConfigHnsw::new()
.num_neighbors(m)
.ef_construction(ef_construction)
.build();
let num_threads = rayon::current_num_threads();
let start = std::time::Instant::now();
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
let elapsed = start.elapsed();
println!("Time to build index: {:?}", elapsed);
Ok(DensePlainHNSW { index })
}
pub fn save(&self, path: &str) -> PyResult<()> {
IndexSerializer::save_index(path, &self.index).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Error saving index: {:?}", e))
})
}
#[staticmethod]
pub fn load(path: &str) -> PyResult<Self> {
let index = IndexSerializer::load_index::<
GraphIndex<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>,
>(path);
Ok(DensePlainHNSW { index })
}
pub fn search(
&self,
query: PyReadonlyArray1<f32>,
k: usize,
ef_search: usize,
) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
let query_slice = query.as_slice()?;
let _dim = query_slice.len();
let mut search_config = ConfigHnsw::new().build();
search_config.set_ef_search(ef_search);
let query_darray = DenseDArray1::new(query_slice);
let search_results = self
.index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
);
let mut distances_vec: Vec<f32> = Vec::with_capacity(k);
let mut ids_vec: Vec<i64> = Vec::with_capacity(k);
for (score, doc_id) in search_results {
distances_vec.push(score);
ids_vec.push(doc_id as i64);
}
Python::with_gil(|py| {
let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
Ok((distances_array.into(), ids_array.into()))
})
}
pub fn search_batch(
&self,
queries_path: &str,
k: usize,
ef_search: usize,
) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
let (queries_slice, dim) = read_numpy_f32_flatten_2d(queries_path.to_string());
let num_queries = queries_slice.len() / dim;
let mut search_config = ConfigHnsw::new().build();
search_config.set_ef_search(ef_search);
let mut ids_vec: Vec<i64> = Vec::with_capacity(num_queries * k);
let mut distances_vec: Vec<f32> = Vec::with_capacity(num_queries * k);
for query in queries_slice.chunks_exact(dim) {
let query_darray = DenseDArray1::new(query);
let search_results = self
.index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
);
for (score, doc_id) in search_results {
distances_vec.push(score);
ids_vec.push(doc_id as i64);
}
}
Python::with_gil(|py| {
let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
Ok((distances_array.into(), ids_array.into()))
})
}
}
/// Python-facing HNSW index over dense vectors stored as f16 (half precision).
/// Inputs from Python are f32 and are converted to f16 on ingestion.
#[pyclass]
pub struct DensePlainHNSWf16 {
// Backing graph index over the leaked (`'static`) f16 dataset.
index: GraphIndex<'static, DenseDataset<PlainQuantizer<f16>>, PlainQuantizer<f16>>,
}
#[pymethods]
impl DensePlainHNSWf16 {
    /// Builds an HNSW index from a 2-D f32 numpy file, converting every value
    /// to f16 before indexing (halves the memory at the cost of precision).
    ///
    /// * `metric` - `"l2"` (Euclidean) or `"ip"` (dot product); anything else
    ///   yields `PyValueError`.
    #[staticmethod]
    #[pyo3(signature = (data_path, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_file(
        data_path: &str,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        let (data_vec, dim) = read_numpy_f32_flatten_2d(data_path.to_string());
        let data_vec: Vec<f16> = data_vec.iter().map(|&v| f16::from_f32(v)).collect();
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let quantizer = PlainQuantizer::<f16>::new(dim, distance);
        let boxed_dataset = Box::new(DenseDataset::from_vec(data_vec, dim, quantizer.clone()));
        // Leaked on purpose: `GraphIndex` keeps a `'static` borrow of the dataset.
        let static_dataset: &'static DenseDataset<PlainQuantizer<f16>> = Box::leak(boxed_dataset);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let start = std::time::Instant::now();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        let elapsed = start.elapsed();
        println!("Time to build index: {:?}", elapsed);
        Ok(DensePlainHNSWf16 { index })
    }

    /// Builds an HNSW index from a flattened row-major f32 array; values are
    /// converted to f16. See [`DensePlainHNSWf16::build_from_file`].
    #[staticmethod]
    #[pyo3(signature = (data_vec, dim, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_array(
        data_vec: PyReadonlyArray1<f32>,
        dim: usize,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        let data_vec: Vec<f16> = data_vec
            .as_slice()?
            .iter()
            .map(|&v| f16::from_f32(v))
            .collect();
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let quantizer = PlainQuantizer::<f16>::new(dim, distance);
        let boxed_dataset = Box::new(DenseDataset::from_vec(data_vec, dim, quantizer.clone()));
        let static_dataset: &'static DenseDataset<PlainQuantizer<f16>> = Box::leak(boxed_dataset);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let start = std::time::Instant::now();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        let elapsed = start.elapsed();
        println!("Time to build index: {:?}", elapsed);
        Ok(DensePlainHNSWf16 { index })
    }

    /// Serializes the index to `path`, mapping failures to `PyIOError`.
    pub fn save(&self, path: &str) -> PyResult<()> {
        IndexSerializer::save_index(path, &self.index).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Error saving index: {:?}", e))
        })
    }

    /// Loads a previously saved f16 index from `path`.
    #[staticmethod]
    pub fn load(path: &str) -> PyResult<Self> {
        let index = IndexSerializer::load_index::<
            GraphIndex<DenseDataset<PlainQuantizer<f16>>, PlainQuantizer<f16>>,
        >(path);
        Ok(DensePlainHNSWf16 { index })
    }

    /// Searches the `k` nearest neighbors of a single f32 query (converted to
    /// f16 to match the dataset). Returns numpy arrays `(distances, ids)`.
    pub fn search(
        &self,
        query: PyReadonlyArray1<f32>,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        // Convert the query to f16 so it matches the stored vector type.
        let query_f16: Vec<f16> = query.as_slice()?.iter().map(|&v| f16::from_f32(v)).collect();
        let query_darray = DenseDArray1::new(query_f16.as_slice());
        let search_results = self
            .index
            .search::<DenseDataset<PlainQuantizer<f16>>, PlainQuantizer<f16>>(
                query_darray,
                k,
                &search_config,
            );
        let mut distances_vec: Vec<f32> = Vec::with_capacity(k);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(k);
        for (score, doc_id) in search_results {
            distances_vec.push(score);
            ids_vec.push(doc_id as i64);
        }
        Python::with_gil(|py| {
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }

    /// Runs `search` for every row of a 2-D f32 numpy file (each row converted
    /// to f16) and concatenates the results into flat `(distances, ids)` arrays.
    pub fn search_batch(
        &self,
        queries_path: &str,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        let (queries_slice, dim) = read_numpy_f32_flatten_2d(queries_path.to_string());
        let num_queries = queries_slice.len() / dim;
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(num_queries * k);
        let mut distances_vec: Vec<f32> = Vec::with_capacity(num_queries * k);
        for query in queries_slice.chunks_exact(dim) {
            let query: Vec<f16> = query.iter().map(|&v| f16::from_f32(v)).collect();
            let query_darray = DenseDArray1::new(query.as_slice());
            let search_results = self
                .index
                .search::<DenseDataset<PlainQuantizer<f16>>, PlainQuantizer<f16>>(
                    query_darray,
                    k,
                    &search_config,
                );
            for (score, doc_id) in search_results {
                distances_vec.push(score);
                ids_vec.push(doc_id as i64);
            }
        }
        Python::with_gil(|py| {
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }
}
/// Python-facing HNSW index over sparse f32 vectors with a plain sparse quantizer.
#[pyclass]
pub struct SparsePlainHNSW {
// Backing graph index over a leaked (`'static`) sparse dataset.
index: GraphIndex<'static, SparseDataset<SparsePlainQuantizer<f32>>, SparsePlainQuantizer<f32>>,
}
#[pymethods]
impl SparsePlainHNSW {
    /// Builds an HNSW index from a sparse dataset stored in the binary file format.
    ///
    /// * `data_file` - path to the binary sparse dataset.
    /// * `m` / `ef_construction` - HNSW graph parameters.
    /// * `metric` - `"l2"` (Euclidean) or `"ip"` (dot product).
    #[staticmethod]
    #[pyo3(signature = (data_file, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_file(
        data_file: &str,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let boxed_dataset = Box::new(
            SparseDataset::<SparsePlainQuantizer<f32>>::read_bin_file(data_file).map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                    "Error reading dataset: {:?}",
                    e
                ))
            })?,
        );
        // Leaked on purpose: `GraphIndex` keeps a `'static` borrow of the dataset.
        let static_dataset: &'static SparseDataset<SparsePlainQuantizer<f32>> =
            Box::leak(boxed_dataset);
        let quantizer = SparsePlainQuantizer::<f32>::new(static_dataset.dim(), distance);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        Ok(SparsePlainHNSW { index })
    }

    /// Builds an HNSW index from CSR-style arrays: `components` (per-entry
    /// column ids), `values` (non-zero values) and `offsets` (per-row boundaries
    /// into the other two arrays — presumably len = rows + 1; confirm against
    /// `from_vecs_f32`).
    #[staticmethod]
    #[pyo3(signature = (components, values, offsets, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_arrays(
        components: PyReadonlyArray1<i32>,
        values: PyReadonlyArray1<f32>,
        offsets: PyReadonlyArray1<i32>,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        // `?` propagates non-contiguous-array errors as Python exceptions
        // instead of panicking the interpreter.
        let components_vec = components
            .to_vec()?
            .into_iter()
            .map(|x| x as u16)
            .collect::<Vec<_>>();
        let values_slice = values.as_slice()?;
        let offsets_vec = offsets
            .to_vec()?
            .into_iter()
            .map(|x| x as usize)
            .collect::<Vec<_>>();
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let boxed_dataset = Box::new(
            SparseDataset::<SparsePlainQuantizer<f32>>::from_vecs_f32(
                components_vec.as_slice(),
                values_slice,
                offsets_vec.as_slice(),
            )
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                    "Error reading dataset: {:?}",
                    e
                ))
            })?,
        );
        let static_dataset: &'static SparseDataset<SparsePlainQuantizer<f32>> =
            Box::leak(boxed_dataset);
        let quantizer = SparsePlainQuantizer::<f32>::new(static_dataset.dim(), distance);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        Ok(SparsePlainHNSW { index })
    }

    /// Serializes the index to `path`, mapping failures to `PyIOError`.
    pub fn save(&self, path: &str) -> PyResult<()> {
        IndexSerializer::save_index(path, &self.index).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Error saving index: {:?}", e))
        })
    }

    /// Loads a previously saved sparse f32 index from `path`.
    #[staticmethod]
    pub fn load(path: &str) -> PyResult<Self> {
        let index = IndexSerializer::load_index::<
            GraphIndex<SparseDataset<SparsePlainQuantizer<f32>>, SparsePlainQuantizer<f32>>,
        >(path);
        Ok(SparsePlainHNSW { index })
    }

    /// Searches the `k` nearest neighbors of a single sparse query given as
    /// parallel `(components, values)` arrays. Returns `(distances, ids)`.
    pub fn search(
        &self,
        query_components: numpy::PyReadonlyArray1<i32>,
        query_values: numpy::PyReadonlyArray1<f32>,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        let comp_vec = query_components
            .to_vec()?
            .into_iter()
            .map(|x| x as u16)
            .collect::<Vec<_>>();
        let values_slice = query_values.as_slice()?;
        // A single row: one offset pair covering all entries.
        let offsets_vec = vec![0, values_slice.len()];
        let query_dataset = SparseDataset::<SparsePlainQuantizer<f32>>::from_vecs_f32(
            &comp_vec,
            values_slice,
            &offsets_vec,
        )
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                "Error constructing query dataset: {:?}",
                e
            ))
        })?;
        if query_dataset.len() != 1 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Expected a single query dataset.",
            ));
        }
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        let query = query_dataset.iter().next().unwrap();
        let search_results = self
            .index
            .search::<SparseDataset<SparsePlainQuantizer<f32>>, SparsePlainQuantizer<f32>>(
                query,
                k,
                &search_config,
            );
        let mut distances_vec: Vec<f32> = Vec::with_capacity(k);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(k);
        for (score, doc_id) in search_results {
            distances_vec.push(score);
            ids_vec.push(doc_id as i64);
        }
        Python::with_gil(|py| {
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }

    /// Runs `search` for every query in a binary sparse query file and
    /// concatenates the results into flat `(distances, ids)` arrays.
    pub fn search_batch(
        &self,
        query_file: &str,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        // Read queries with the same f32 quantizer as the index (the original
        // code read them as f16, inconsistent with `build_from_file`).
        let queries = SparseDataset::<SparsePlainQuantizer<f32>>::read_bin_file(query_file)
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                    "Error reading query file: {:?}",
                    e
                ))
            })?;
        let num_queries = queries.len();
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(num_queries * k);
        let mut distances_vec: Vec<f32> = Vec::with_capacity(num_queries * k);
        for query in queries.iter() {
            let search_results = self
                .index
                .search::<SparseDataset<SparsePlainQuantizer<f32>>, SparsePlainQuantizer<f32>>(
                    query,
                    k,
                    &search_config,
                );
            for (score, doc_id) in search_results {
                distances_vec.push(score);
                ids_vec.push(doc_id as i64);
            }
        }
        Python::with_gil(|py| {
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }
}
/// Python-facing HNSW index over sparse vectors with f16 (half precision) values.
/// f32 inputs from Python are converted to f16 on ingestion.
#[pyclass]
pub struct SparsePlainHNSWf16 {
// Backing graph index over a leaked (`'static`) sparse f16 dataset.
index: GraphIndex<'static, SparseDataset<SparsePlainQuantizer<f16>>, SparsePlainQuantizer<f16>>,
}
#[pymethods]
impl SparsePlainHNSWf16 {
    /// Builds an HNSW index from a binary sparse dataset, storing values as f16.
    ///
    /// * `metric` - `"l2"` (Euclidean) or `"ip"` (dot product); anything else
    ///   yields `PyValueError`.
    #[staticmethod]
    #[pyo3(signature = (data_file, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_file(
        data_file: &str,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let boxed_dataset = Box::new(
            SparseDataset::<SparsePlainQuantizer<f16>>::read_bin_file_f16(data_file, None)
                .map_err(|e| {
                    PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                        "Error reading dataset: {:?}",
                        e
                    ))
                })?,
        );
        // Leaked on purpose: `GraphIndex` keeps a `'static` borrow of the dataset.
        let static_dataset: &'static SparseDataset<SparsePlainQuantizer<f16>> =
            Box::leak(boxed_dataset);
        let quantizer = SparsePlainQuantizer::<f16>::new(static_dataset.dim(), distance);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        Ok(SparsePlainHNSWf16 { index })
    }

    /// Builds an HNSW index from CSR-style arrays; f32 values are converted to
    /// f16. See [`SparsePlainHNSW::build_from_arrays`] for array semantics.
    #[staticmethod]
    #[pyo3(signature = (components, values, offsets, m=32, ef_construction=200, metric="ip".to_string()))]
    pub fn build_from_arrays(
        components: PyReadonlyArray1<i32>,
        values: PyReadonlyArray1<f32>,
        offsets: PyReadonlyArray1<i32>,
        m: usize,
        ef_construction: usize,
        metric: String,
    ) -> PyResult<Self> {
        // `?` propagates non-contiguous-array errors as Python exceptions
        // instead of panicking the interpreter.
        let components_vec = components
            .to_vec()?
            .into_iter()
            .map(|x| x as u16)
            .collect::<Vec<_>>();
        let values_vec = values
            .to_vec()?
            .into_iter()
            .map(half::f16::from_f32)
            .collect::<Vec<_>>();
        let offsets_vec = offsets
            .to_vec()?
            .into_iter()
            .map(|x| x as usize)
            .collect::<Vec<_>>();
        let distance = match metric.as_str() {
            "l2" => DistanceType::Euclidean,
            "ip" => DistanceType::DotProduct,
            _ => {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "Invalid metric; choose 'l2' or 'ip'",
                ))
            }
        };
        let boxed_dataset = Box::new(
            SparseDataset::<SparsePlainQuantizer<f16>>::from_vecs_f16(
                components_vec.as_slice(),
                values_vec.as_slice(),
                offsets_vec.as_slice(),
            )
            .map_err(|e| {
                PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                    "Error reading dataset: {:?}",
                    e
                ))
            })?,
        );
        let static_dataset: &'static SparseDataset<SparsePlainQuantizer<f16>> =
            Box::leak(boxed_dataset);
        let quantizer = SparsePlainQuantizer::<f16>::new(static_dataset.dim(), distance);
        let config = ConfigHnsw::new()
            .num_neighbors(m)
            .ef_construction(ef_construction)
            .build();
        let num_threads = rayon::current_num_threads();
        let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
        Ok(SparsePlainHNSWf16 { index })
    }

    /// Serializes the index to `path`, mapping failures to `PyIOError`.
    pub fn save(&self, path: &str) -> PyResult<()> {
        IndexSerializer::save_index(path, &self.index).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Error saving index: {:?}", e))
        })
    }

    /// Loads a previously saved sparse f16 index from `path`.
    #[staticmethod]
    pub fn load(path: &str) -> PyResult<Self> {
        let index = IndexSerializer::load_index::<
            GraphIndex<SparseDataset<SparsePlainQuantizer<f16>>, SparsePlainQuantizer<f16>>,
        >(path);
        Ok(SparsePlainHNSWf16 { index })
    }

    /// Searches the `k` nearest neighbors of a single sparse query given as
    /// parallel `(components, values)` arrays; values are converted to f16.
    pub fn search(
        &self,
        query_components: numpy::PyReadonlyArray1<i32>,
        query_values: numpy::PyReadonlyArray1<f32>,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        let comp_vec = query_components
            .to_vec()?
            .into_iter()
            .map(|x| x as u16)
            .collect::<Vec<_>>();
        let values_f16: Vec<half::f16> = query_values
            .as_slice()?
            .iter()
            .map(|&v| half::f16::from_f32(v))
            .collect();
        // A single row: one offset pair covering all entries.
        let offsets_vec = vec![0, values_f16.len()];
        let query_dataset = SparseDataset::<SparsePlainQuantizer<half::f16>>::from_vecs_f16(
            &comp_vec,
            &values_f16,
            &offsets_vec,
        )
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                "Error constructing query dataset: {:?}",
                e
            ))
        })?;
        if query_dataset.len() != 1 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Expected a single query dataset.",
            ));
        }
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        let query = query_dataset.iter().next().unwrap();
        let search_results = self
            .index
            .search::<SparseDataset<SparsePlainQuantizer<half::f16>>, SparsePlainQuantizer<half::f16>>(
                query,
                k,
                &search_config,
            );
        let mut distances_vec: Vec<f32> = Vec::with_capacity(k);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(k);
        for (score, doc_id) in search_results {
            distances_vec.push(score);
            ids_vec.push(doc_id as i64);
        }
        Python::with_gil(|py| {
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }

    /// Runs `search` for every query in a binary sparse f16 query file and
    /// concatenates the results into flat `(distances, ids)` arrays.
    pub fn search_batch(
        &self,
        query_file: &str,
        k: usize,
        ef_search: usize,
    ) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
        let queries =
            SparseDataset::<SparsePlainQuantizer<f16>>::read_bin_file_f16(query_file, None)
                .map_err(|e| {
                    PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                        "Error reading query file: {:?}",
                        e
                    ))
                })?;
        let num_queries = queries.len();
        let mut search_config = ConfigHnsw::new().build();
        search_config.set_ef_search(ef_search);
        let mut ids_vec: Vec<i64> = Vec::with_capacity(num_queries * k);
        let mut distances_vec: Vec<f32> = Vec::with_capacity(num_queries * k);
        for query in queries.iter() {
            let search_results = self
                .index
                .search::<SparseDataset<SparsePlainQuantizer<f16>>, SparsePlainQuantizer<f16>>(
                    query,
                    k,
                    &search_config,
                );
            for (score, doc_id) in search_results {
                distances_vec.push(score);
                ids_vec.push(doc_id as i64);
            }
        }
        Python::with_gil(|py| {
            let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
            let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
            Ok((distances_array.into(), ids_array.into()))
        })
    }
}
/// Runtime dispatch over the statically supported numbers of PQ sub-quantizers.
/// Each variant holds a `GraphIndex` monomorphized for one value of the
/// `ProductQuantizer` const generic, since `m_pq` is only known at runtime.
enum DensePQHNSWEnum {
PQ8(GraphIndex<'static, DenseDataset<ProductQuantizer<8>>, ProductQuantizer<8>>),
PQ16(GraphIndex<'static, DenseDataset<ProductQuantizer<16>>, ProductQuantizer<16>>),
PQ32(GraphIndex<'static, DenseDataset<ProductQuantizer<32>>, ProductQuantizer<32>>),
PQ48(GraphIndex<'static, DenseDataset<ProductQuantizer<48>>, ProductQuantizer<48>>),
PQ64(GraphIndex<'static, DenseDataset<ProductQuantizer<64>>, ProductQuantizer<64>>),
PQ96(GraphIndex<'static, DenseDataset<ProductQuantizer<96>>, ProductQuantizer<96>>),
PQ128(GraphIndex<'static, DenseDataset<ProductQuantizer<128>>, ProductQuantizer<128>>),
PQ192(GraphIndex<'static, DenseDataset<ProductQuantizer<192>>, ProductQuantizer<192>>),
PQ256(GraphIndex<'static, DenseDataset<ProductQuantizer<256>>, ProductQuantizer<256>>),
PQ384(GraphIndex<'static, DenseDataset<ProductQuantizer<384>>, ProductQuantizer<384>>),
}
/// Python-facing HNSW index over dense vectors compressed with product quantization.
#[pyclass]
pub struct DensePQHNSW {
// Dispatch enum selecting the monomorphized index for the chosen `m_pq`.
inner: DensePQHNSWEnum,
}
#[pymethods]
impl DensePQHNSW {
#[staticmethod]
#[pyo3(signature = (data_path, m_pq, nbits=8, m=32, ef_construction=200, metric="ip".to_string(), sample_size=100_000))]
pub fn build_from_file(
data_path: &str,
m_pq: usize,
nbits: usize,
m: usize,
ef_construction: usize,
metric: String,
sample_size: usize,
) -> PyResult<Self> {
let (data_vec, dim) = read_numpy_f32_flatten_2d(data_path.to_string());
let distance = match metric.as_str() {
"l2" => DistanceType::Euclidean,
"ip" => DistanceType::DotProduct,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid metric; choose 'l2' or 'ip'",
))
}
};
let config = ConfigHnsw::new()
.num_neighbors(m)
.ef_construction(ef_construction)
.build();
let base_dataset = Box::new(DenseDataset::from_vec(
data_vec,
dim,
PlainQuantizer::<f32>::new(dim, distance),
));
let static_dataset: &'static DenseDataset<PlainQuantizer<f32>> = Box::leak(base_dataset);
let mut rng = StdRng::seed_from_u64(523);
let mut training_vec: Vec<f32> = Vec::new();
for vec in static_dataset.iter().choose_multiple(&mut rng, sample_size) {
training_vec.extend(vec.values_as_slice());
}
let training_dataset =
DenseDataset::from_vec(training_vec, dim, PlainQuantizer::<f32>::new(dim, distance));
let num_threads = rayon::current_num_threads();
let inner = match m_pq {
8 => {
let quantizer = ProductQuantizer::<8>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ8(index)
}
16 => {
let quantizer = ProductQuantizer::<16>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ16(index)
}
32 => {
let quantizer = ProductQuantizer::<32>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ32(index)
}
48 => {
let quantizer = ProductQuantizer::<48>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ48(index)
}
64 => {
let quantizer = ProductQuantizer::<64>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ64(index)
}
96 => {
let quantizer = ProductQuantizer::<96>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ96(index)
}
128 => {
let quantizer = ProductQuantizer::<128>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ128(index)
}
192 => {
let quantizer = ProductQuantizer::<192>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ192(index)
}
256 => {
let quantizer = ProductQuantizer::<256>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ256(index)
}
384 => {
let quantizer = ProductQuantizer::<384>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ384(index)
}
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Unsupported m_pq value. Supported values: 8, 16, 32, 48, 64, 96, 128, 192, 256, 384.",
))
}
};
Ok(DensePQHNSW { inner })
}
#[staticmethod]
#[pyo3(signature = (data_vec, dim, m_pq, nbits=8, m=32, ef_construction=200, metric="ip".to_string(), sample_size=100_000))]
pub fn build_from_array(
data_vec: PyReadonlyArray1<f32>,
dim: usize,
m_pq: usize,
nbits: usize,
m: usize,
ef_construction: usize,
metric: String,
sample_size: usize,
) -> PyResult<Self> {
let data_vec = data_vec.as_slice()?.to_vec();
let distance = match metric.as_str() {
"l2" => DistanceType::Euclidean,
"ip" => DistanceType::DotProduct,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid metric; choose 'l2' or 'ip'",
))
}
};
let config = ConfigHnsw::new()
.num_neighbors(m)
.ef_construction(ef_construction)
.build();
let base_dataset = Box::new(DenseDataset::from_vec(
data_vec,
dim,
PlainQuantizer::<f32>::new(dim, distance),
));
let static_dataset: &'static DenseDataset<PlainQuantizer<f32>> = Box::leak(base_dataset);
let mut rng = StdRng::seed_from_u64(523);
let mut training_vec: Vec<f32> = Vec::new();
for vec in static_dataset.iter().choose_multiple(&mut rng, sample_size) {
training_vec.extend(vec.values_as_slice());
}
let training_dataset =
DenseDataset::from_vec(training_vec, dim, PlainQuantizer::<f32>::new(dim, distance));
let num_threads = rayon::current_num_threads();
let inner = match m_pq {
8 => {
let quantizer = ProductQuantizer::<8>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ8(index)
}
16 => {
let quantizer = ProductQuantizer::<16>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ16(index)
}
32 => {
let quantizer = ProductQuantizer::<32>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ32(index)
}
48 => {
let quantizer = ProductQuantizer::<48>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ48(index)
}
64 => {
let quantizer = ProductQuantizer::<64>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ64(index)
}
96 => {
let quantizer = ProductQuantizer::<96>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ96(index)
}
128 => {
let quantizer = ProductQuantizer::<128>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ128(index)
}
192 => {
let quantizer = ProductQuantizer::<192>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ192(index)
}
256 => {
let quantizer = ProductQuantizer::<256>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ256(index)
}
384 => {
let quantizer = ProductQuantizer::<384>::train(&training_dataset, nbits, distance);
let index = GraphIndex::from_dataset(static_dataset, &config, quantizer, num_threads);
DensePQHNSWEnum::PQ384(index)
}
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Unsupported m_pq value. Supported values: 8, 16, 32, 48, 64, 96, 128, 192, 256, 384.",
))
}
};
Ok(DensePQHNSW { inner })
}
#[staticmethod]
pub fn load(path: &str, m_pq: usize) -> PyResult<Self> {
let inner = match m_pq {
8 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<8>>, ProductQuantizer<8>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ8(index)
}
16 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<16>>, ProductQuantizer<16>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ16(index)
}
32 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<32>>, ProductQuantizer<32>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ32(index)
}
48 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<48>>, ProductQuantizer<48>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ48(index)
}
64 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<64>>, ProductQuantizer<64>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ64(index)
}
96 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<96>>, ProductQuantizer<96>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ96(index)
}
128 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<128>>, ProductQuantizer<128>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ128(index)
}
192 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<192>>, ProductQuantizer<192>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ192(index)
}
256 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<256>>, ProductQuantizer<256>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ256(index)
}
384 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<384>>, ProductQuantizer<384>> =
IndexSerializer::load_index(path);
DensePQHNSWEnum::PQ384(index)
}
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Unsupported m_pq value for load. Supported values: 8, 16, 32, 48, 64, 96, 128, 192, 256, 384.",
))
}
};
Ok(DensePQHNSW { inner })
}
pub fn search(
&self,
query: PyReadonlyArray1<f32>,
k: usize,
ef_search: usize,
) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
let query_slice = query.as_slice()?;
let mut search_config = ConfigHnsw::new().build();
search_config.set_ef_search(ef_search);
let query_darray = DenseDArray1::new(query_slice);
let results = match &self.inner {
DensePQHNSWEnum::PQ8(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ16(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ32(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ48(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ64(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ96(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ128(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ192(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ256(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ384(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
};
let mut distances_vec: Vec<f32> = Vec::with_capacity(k);
let mut ids_vec: Vec<i64> = Vec::with_capacity(k);
for (score, doc_id) in results {
distances_vec.push(score);
ids_vec.push(doc_id as i64);
}
Python::with_gil(|py| {
let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
Ok((distances_array.into(), ids_array.into()))
})
}
pub fn search_batch(
&self,
queries_path: &str,
k: usize,
ef_search: usize,
) -> PyResult<(Py<PyArray1<f32>>, Py<PyArray1<i64>>)> {
let (queries_slice, dim) = read_numpy_f32_flatten_2d(queries_path.to_string());
let num_queries = queries_slice.len() / dim;
let mut search_config = ConfigHnsw::new().build();
search_config.set_ef_search(ef_search);
let mut ids_vec: Vec<i64> = Vec::with_capacity(num_queries * k);
let mut distances_vec: Vec<f32> = Vec::with_capacity(num_queries * k);
for query in queries_slice.chunks_exact(dim) {
let query_darray = DenseDArray1::new(query);
let results = match &self.inner {
DensePQHNSWEnum::PQ8(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ16(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ32(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ48(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ64(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ96(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ128(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ192(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ256(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
DensePQHNSWEnum::PQ384(index) => index
.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query_darray,
k,
&search_config,
),
};
for (score, doc_id) in results {
distances_vec.push(score);
ids_vec.push(doc_id as i64);
}
}
Python::with_gil(|py| {
let ids_array = PyArray1::from_vec(py, ids_vec).to_owned();
let distances_array = PyArray1::from_vec(py, distances_vec).to_owned();
Ok((distances_array.into(), ids_array.into()))
})
}
}