use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use crate::distance::Distance;
use crate::filter::parse_simple_filter;
#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
use crate::persistence::BackendConfig;
use crate::quantization::Quantization;
use crate::{EmbedVec, Metadata};
/// Python-facing wrapper around the native [`EmbedVec`] index.
///
/// Exposed to Python under the class name `EmbedVec`; every method on the
/// Python class delegates to the wrapped `inner` value.
#[pyclass(name = "EmbedVec")]
pub struct PyEmbedVec {
/// The wrapped native index that all Python-visible methods forward to.
inner: EmbedVec,
}
#[pymethods]
impl PyEmbedVec {
/// Create a new index.
///
/// Arguments:
/// * `dim` - vector dimension, forwarded to `EmbedVec::new_internal`.
/// * `metric` - `"cosine"`, `"euclidean"`/`"l2"`, or
///   `"dot"`/`"dotproduct"`/`"inner"` (case-insensitive).
/// * `m`, `ef_construction` - index construction parameters, forwarded
///   unchanged to `EmbedVec::new_internal`.
/// * `persist_path` - location handed to the persistence backend. Only
///   usable when the extension was built with a persistence feature.
/// * `quantization` - `None`/`"none"`, or one of `"e8"`, `"e8-8bit"`,
///   `"e8-10bit"`, `"e8-12bit"` and their `"e8p..."` spellings
///   (case-insensitive).
/// * `random_seed` - seed passed to the E8 quantizer; defaults to
///   `0xcafef00d`.
///
/// Raises `ValueError` for an unknown metric or quantization name, and
/// `RuntimeError` if the index cannot be constructed or if `persist_path`
/// is given to a build without persistence support.
#[new]
#[pyo3(signature = (dim, metric="cosine", m=32, ef_construction=200, persist_path=None, quantization=None, random_seed=None))]
fn new(
    dim: usize,
    metric: &str,
    m: usize,
    ef_construction: usize,
    persist_path: Option<String>,
    quantization: Option<&str>,
    random_seed: Option<u64>,
) -> PyResult<Self> {
    use pyo3::exceptions::{PyRuntimeError, PyValueError};

    let distance = match metric.to_lowercase().as_str() {
        "cosine" => Distance::Cosine,
        "euclidean" | "l2" => Distance::Euclidean,
        "dot" | "dotproduct" | "inner" => Distance::DotProduct,
        _ => {
            return Err(PyValueError::new_err(format!(
                "Unknown metric: {}. Use 'cosine', 'euclidean', or 'dot'",
                metric
            )))
        }
    };

    let seed = random_seed.unwrap_or(0xcafef00d);

    // Normalise case so quantization names are matched the same way metric
    // names are.
    let quantization_name = quantization.map(str::to_lowercase);
    // NOTE(review): the `e8p*` names build exactly the same quantizer as the
    // `e8*` names — confirm this aliasing is intended.
    let quant = match quantization_name.as_deref() {
        None | Some("none") => Quantization::None,
        Some("e8" | "e8-10bit" | "e8p" | "e8p-10bit") => Quantization::e8(10, true, seed),
        Some("e8-8bit" | "e8p-8bit") => Quantization::e8(8, true, seed),
        Some("e8-12bit" | "e8p-12bit") => Quantization::e8(12, true, seed),
        Some(_) => {
            // Report the caller's original spelling, not the lowercased one.
            return Err(PyValueError::new_err(format!(
                "Unknown quantization: {}. Use 'none', 'e8-8bit', 'e8-10bit', or 'e8-12bit'",
                quantization.unwrap_or_default()
            )))
        }
    };

    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    let inner = {
        let persistence_config = persist_path.map(|p| BackendConfig::new(p));
        EmbedVec::new_internal(dim, distance, m, ef_construction, quant, persistence_config)
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
    };

    #[cfg(not(any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
    let inner = {
        // Without a persistence backend the path cannot be honoured; fail
        // loudly instead of silently creating a purely in-memory index.
        if persist_path.is_some() {
            return Err(PyRuntimeError::new_err(
                "persist_path was given, but this build has no persistence backend \
                 (enable the 'persistence-sled' or 'persistence-rocksdb' feature)",
            ));
        }
        EmbedVec::new_internal(dim, distance, m, ef_construction, quant)
            .map_err(|e| PyRuntimeError::new_err(e.to_string()))?
    };

    Ok(Self { inner })
}
/// Insert one vector together with its payload dict.
///
/// The payload is converted to JSON-style metadata first; the value
/// returned by the native `add_internal` call is handed back to Python.
/// Conversion errors propagate as raised; native insertion errors are
/// raised as `RuntimeError`.
fn add(&mut self, vector: Vec<f32>, payload: &Bound<'_, PyDict>) -> PyResult<usize> {
    let metadata = pydict_to_metadata(payload)?;
    match self.inner.add_internal(&vector, metadata) {
        Ok(id) => Ok(id),
        Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
    }
}
fn add_many(&mut self, vectors: Vec<Vec<f32>>, payloads: Bound<'_, PyList>) -> PyResult<()> {
if vectors.len() != payloads.len() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
format!(
"Mismatched lengths: {} vectors, {} payloads",
vectors.len(),
payloads.len()
),
));
}
for (vector, payload_obj) in vectors.iter().zip(payloads.iter()) {
let payload = payload_obj.downcast::<PyDict>()
.map_err(|_| PyErr::new::<pyo3::exceptions::PyTypeError, _>("payloads must be a list of dicts"))?;
let metadata = pydict_to_metadata(payload)?;
self.inner
.add_internal(vector, metadata)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
}
Ok(())
}
/// Search the index and return a list of hit dicts.
///
/// Each hit is a dict with the keys `"id"`, `"score"` and `"payload"`.
/// `filter`, when given, is converted to metadata and passed through
/// `parse_simple_filter`; if that yields nothing the search runs
/// unfiltered. Native search errors are raised as `RuntimeError`.
#[pyo3(signature = (query_vector, k=10, ef_search=128, filter=None))]
fn search(
    &self,
    py: Python<'_>,
    query_vector: Vec<f32>,
    k: usize,
    ef_search: usize,
    filter: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyList>> {
    let filter_expr = match filter {
        Some(raw_filter) => {
            let as_metadata = pydict_to_metadata(raw_filter)?;
            parse_simple_filter(&as_metadata)
        }
        None => None,
    };

    let hits = match self
        .inner
        .search_internal(&query_vector, k, ef_search, filter_expr)
    {
        Ok(found) => found,
        Err(err) => return Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
    };

    let out = PyList::empty_bound(py);
    for hit in hits {
        let entry = PyDict::new_bound(py);
        entry.set_item("id", hit.id)?;
        entry.set_item("score", hit.score)?;
        let payload = metadata_to_pyobject(py, &hit.payload)?;
        entry.set_item("payload", payload)?;
        out.append(entry)?;
    }
    Ok(out.unbind())
}
/// Support for Python's `len(index)`: the length reported by the
/// underlying storage.
fn __len__(&self) -> usize {
    let storage = self.inner.storage.read();
    storage.len()
}
/// Length reported by the underlying storage (taken under a read lock).
fn len(&self) -> usize {
    let storage = self.inner.storage.read();
    storage.len()
}
/// `True` when the underlying storage reports that it is empty.
fn is_empty(&self) -> bool {
    let storage = self.inner.storage.read();
    storage.is_empty()
}
/// Empty the in-memory storage, metadata and index, in that order.
///
/// Each component is cleared under its own write lock, released before
/// the next one is taken. Always returns successfully.
fn clear(&mut self) -> PyResult<()> {
    // NOTE(review): only these three in-memory components are touched here;
    // confirm whether a configured persistence backend needs clearing too.
    let db = &self.inner;
    db.storage.write().clear();
    db.metadata.write().clear();
    db.index.write().clear();
    Ok(())
}
/// Read-only property: the dimension reported by the wrapped index.
#[getter]
fn dimension(&self) -> usize {
    let dim = self.inner.dimension();
    dim
}
/// Read-only property: canonical name of the configured distance metric
/// (`"cosine"`, `"euclidean"` or `"dot"`).
#[getter]
fn metric(&self) -> &'static str {
    let distance = self.inner.distance();
    match distance {
        Distance::DotProduct => "dot",
        Distance::Euclidean => "euclidean",
        Distance::Cosine => "cosine",
    }
}
/// Memory usage in bytes as reported by the underlying storage.
fn memory_bytes(&self) -> usize {
    let storage = self.inner.storage.read();
    storage.memory_bytes()
}
/// Compression ratio reported by the configured quantization for this
/// index's dimension.
fn compression_ratio(&self) -> f32 {
    let quantization = self.inner.quantization();
    let dim = self.inner.dimension();
    quantization.compression_ratio(dim)
}
/// Python `repr()`: a summary with dimension, metric, entry count and
/// memory usage in whole kilobytes (bytes divided by 1024, truncated).
fn __repr__(&self) -> String {
    let dim = self.inner.dimension();
    let metric = self.metric();
    let count = self.len();
    let memory_kb = self.memory_bytes() / 1024;
    format!(
        "EmbedVec(dim={}, metric='{}', len={}, memory={}KB)",
        dim, metric, count, memory_kb
    )
}
}
/// Convert a Python dict into a JSON object value.
///
/// Keys must extract as `String`; values go through [`pyobject_to_json`].
/// The first failing key or value aborts the conversion and its error is
/// returned.
fn pydict_to_metadata(dict: &Bound<'_, PyDict>) -> PyResult<Metadata> {
    let fields = dict
        .iter()
        .map(|(key, value)| -> PyResult<(String, serde_json::Value)> {
            let name: String = key.extract()?;
            let converted = pyobject_to_json(&value)?;
            Ok((name, converted))
        })
        .collect::<PyResult<serde_json::Map<String, serde_json::Value>>>()?;
    Ok(serde_json::Value::Object(fields))
}
fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> PyResult<serde_json::Value> {
if obj.is_none() {
Ok(serde_json::Value::Null)
} else if let Ok(b) = obj.extract::<bool>() {
Ok(serde_json::Value::Bool(b))
} else if let Ok(i) = obj.extract::<i64>() {
Ok(serde_json::Value::Number(i.into()))
} else if let Ok(f) = obj.extract::<f64>() {
Ok(serde_json::json!(f))
} else if let Ok(s) = obj.extract::<String>() {
Ok(serde_json::Value::String(s))
} else if let Ok(list) = obj.downcast::<PyList>() {
let arr: Result<Vec<serde_json::Value>, _> = list
.iter()
.map(|item| pyobject_to_json(&item))
.collect();
Ok(serde_json::Value::Array(arr?))
} else if let Ok(dict) = obj.downcast::<PyDict>() {
let metadata = pydict_to_metadata(dict)?;
Ok(metadata)
} else {
let s = obj.str()?.to_string();
Ok(serde_json::Value::String(s))
}
}
/// Convert a JSON value into the corresponding Python object.
///
/// `null` -> `None`, booleans -> `bool`, strings -> `str`, arrays -> `list`
/// and objects -> `dict` (both recursively). Numbers become a Python `int`
/// when they fit `i64` or `u64`, otherwise a `float`; a number that is none
/// of these becomes `None`.
fn metadata_to_pyobject(py: Python<'_>, value: &Metadata) -> PyResult<PyObject> {
    use pyo3::conversion::ToPyObject;
    match value {
        serde_json::Value::Null => Ok(py.None()),
        serde_json::Value::Bool(b) => Ok(b.to_object(py)),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                Ok(i.to_object(py))
            } else if let Some(u) = n.as_u64() {
                // Values above i64::MAX stay exact integers rather than
                // being rounded to a float.
                Ok(u.to_object(py))
            } else if let Some(f) = n.as_f64() {
                Ok(f.to_object(py))
            } else {
                Ok(py.None())
            }
        }
        serde_json::Value::String(s) => Ok(s.to_object(py)),
        serde_json::Value::Array(arr) => {
            let list = PyList::empty_bound(py);
            for item in arr {
                list.append(metadata_to_pyobject(py, item)?)?;
            }
            Ok(list.unbind().into())
        }
        serde_json::Value::Object(map) => {
            let dict = PyDict::new_bound(py);
            for (k, v) in map {
                dict.set_item(k, metadata_to_pyobject(py, v)?)?;
            }
            Ok(dict.unbind().into())
        }
    }
}
/// Module initialiser: registers the `EmbedVec` class and sets the
/// `__version__` and `__doc__` module attributes.
#[pymodule]
fn embedvec_py(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // NOTE(review): the version is hard-coded here; confirm it is kept in
    // step with the crate's own version.
    const VERSION: &str = "0.5.0";
    const MODULE_DOC: &str =
        "Fast, lightweight, in-process vector database with HNSW indexing and E8 quantization";

    m.add_class::<PyEmbedVec>()?;
    m.add("__version__", VERSION)?;
    m.add("__doc__", MODULE_DOC)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    /// Checks that lowercasing the literal `"cosine"` yields `"cosine"`.
    ///
    /// NOTE(review): this only exercises `str::to_lowercase` on a literal; it
    /// does not call the metric parsing in the Python constructor.
    #[test]
    fn test_distance_parsing() {
        let normalized = "cosine".to_lowercase();
        assert_eq!(normalized.as_str(), "cosine");
    }
}