use std::collections::HashMap;
use manifoldb_core::PointId;
use manifoldb_storage::StorageEngine;
use manifoldb_vector::store::PointStore;
use manifoldb_vector::types::{
CollectionName as VectorCollectionName, CollectionSchema, NamedVector,
Payload as VectorPayload, VectorConfig as VectorStoreConfig,
VectorType as VectorStoreVectorType,
};
use serde_json::Value as JsonValue;
use super::config::VectorConfig;
use super::error::{ApiError, ApiResult};
use super::filter::Filter;
use super::metadata::CollectionName;
use super::point::{PointStruct, ScoredPoint, Vector};
use super::search::{FusionStrategy, HybridSearchBuilder, SearchBuilder};
/// Handle to a single named collection stored through a [`PointStore`].
///
/// Besides the store itself, the handle caches the collection's named-vector
/// configuration so it can be inspected without going back to storage.
pub struct CollectionHandle<E: StorageEngine> {
    /// Collection name as exposed through the public API.
    name: CollectionName,
    /// The same collection name, converted to the vector crate's name type.
    vector_name: VectorCollectionName,
    /// Named-vector configuration, keyed by vector name.
    vectors: HashMap<String, VectorConfig>,
    /// Storage layer holding the collection's points, payloads, and vectors.
    point_store: PointStore<E>,
}
impl<E: StorageEngine> CollectionHandle<E> {
    /// Creates a new collection in storage and returns a handle to it.
    ///
    /// Each `(name, config)` pair in `vectors` becomes a named vector in the
    /// collection schema.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` cannot be converted into a storage collection
    /// name, or if the point store fails to create the collection.
    pub(crate) fn create(
        engine: E,
        name: CollectionName,
        vectors: Vec<(String, VectorConfig)>,
    ) -> ApiResult<Self> {
        let vector_name = VectorCollectionName::new(name.as_str())?;

        let mut schema = CollectionSchema::new();
        for (vec_name, config) in &vectors {
            schema = schema.with_vector(vec_name.clone(), vector_config_to_store_config(config));
        }

        let point_store = PointStore::new(engine);
        point_store.create_collection(&vector_name, schema)?;

        Ok(Self { point_store, name, vector_name, vectors: vectors.into_iter().collect() })
    }

    /// Opens a handle to an existing collection and rebuilds the named-vector
    /// configuration from the stored schema.
    ///
    /// NOTE(review): configs are rebuilt with `store_config_to_vector_config`, which
    /// as written always reports a Cosine metric for dense vectors. Settings that are
    /// not kept in the store schema may therefore not round-trip.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` is invalid or the collection cannot be loaded.
    pub(crate) fn open(engine: E, name: CollectionName) -> ApiResult<Self> {
        let vector_name = VectorCollectionName::new(name.as_str())?;
        let point_store = PointStore::new(engine);
        let collection = point_store.get_collection(&vector_name)?;
        let schema = collection.schema();

        let mut vectors = HashMap::new();
        for (vec_name, store_config) in schema.vectors() {
            vectors.insert(vec_name.clone(), store_config_to_vector_config(store_config));
        }

        Ok(Self { point_store, name, vector_name, vectors })
    }

    /// Returns the collection name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the cached named-vector configuration, keyed by vector name.
    #[must_use]
    pub fn vectors(&self) -> &HashMap<String, VectorConfig> {
        &self.vectors
    }

    /// Returns `true` if the collection defines a vector called `name`.
    #[must_use]
    pub fn has_vector(&self, name: &str) -> bool {
        self.vectors.contains_key(name)
    }

    /// Inserts `point`, or overwrites it if a point with the same id exists.
    ///
    /// A point without a payload is stored with `VectorPayload::default()`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the point store.
    pub fn upsert_point(&self, point: PointStruct) -> ApiResult<()> {
        let payload = point.payload.map(|v| VectorPayload::from_value(v)).unwrap_or_default();
        let vectors = self.convert_vectors_to_store(point.vectors)?;
        self.point_store.upsert_point(&self.vector_name, point.id, payload, vectors)?;
        Ok(())
    }

    /// Inserts `point`, failing if a point with the same id already exists.
    ///
    /// NOTE(review): the existence check and the write are separate store calls, so
    /// a concurrent writer could insert the same id in between.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::PointAlreadyExists`] if the id is taken, or any store error.
    pub fn insert_point(&self, point: PointStruct) -> ApiResult<()> {
        if self.point_exists(point.id)? {
            return Err(ApiError::PointAlreadyExists {
                point_id: point.id,
                collection: self.name.as_str().to_string(),
            });
        }
        self.upsert_point(point)
    }

    /// Upserts each point in order.
    ///
    /// Stops at the first error. Points written before the failure are not rolled back.
    pub fn upsert_points(&self, points: impl IntoIterator<Item = PointStruct>) -> ApiResult<()> {
        for point in points {
            self.upsert_point(point)?;
        }
        Ok(())
    }

    /// Returns the payload of point `id`, or `None` if the store reports it as not found.
    ///
    /// # Errors
    ///
    /// Propagates any store error other than "not found".
    pub fn get_payload(&self, id: PointId) -> ApiResult<Option<JsonValue>> {
        match self.point_store.get_payload(&self.vector_name, id) {
            Ok(payload) => Ok(Some(payload.into_value())),
            Err(manifoldb_vector::error::VectorError::EmbeddingNotFound { .. }) => Ok(None),
            Err(e) => Err(ApiError::from(e)),
        }
    }

    /// Returns the vector `vector_name` of point `id`, or `None` if the store reports
    /// it as not found.
    ///
    /// # Errors
    ///
    /// Propagates any store error other than "not found".
    pub fn get_vector(&self, id: PointId, vector_name: &str) -> ApiResult<Option<Vector>> {
        match self.point_store.get_vector(&self.vector_name, id, vector_name) {
            Ok(named_vec) => Ok(Some(named_vector_to_vector(named_vec))),
            Err(manifoldb_vector::error::VectorError::EmbeddingNotFound { .. }) => Ok(None),
            Err(e) => Err(ApiError::from(e)),
        }
    }

    /// Returns every named vector stored for point `id`.
    ///
    /// # Errors
    ///
    /// Propagates any store error unchanged. A missing point is not mapped to an empty map.
    pub fn get_all_vectors(&self, id: PointId) -> ApiResult<HashMap<String, Vector>> {
        let store_vectors = self.point_store.get_all_vectors(&self.vector_name, id)?;
        Ok(store_vectors.into_iter().map(|(k, v)| (k, named_vector_to_vector(v))).collect())
    }

    /// Replaces the payload of an existing point.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::PointNotFound`] if the point does not exist, or any store error.
    pub fn update_payload(&self, id: PointId, payload: JsonValue) -> ApiResult<()> {
        if !self.point_exists(id)? {
            return Err(ApiError::PointNotFound {
                point_id: id,
                collection: self.name.as_str().to_string(),
            });
        }
        self.point_store.update_payload(
            &self.vector_name,
            id,
            VectorPayload::from_value(payload),
        )?;
        Ok(())
    }

    /// Replaces the vector `vector_name` of point `id`.
    ///
    /// Unlike [`Self::update_payload`], this does not check for the point's existence
    /// itself. Any error comes from the point store.
    pub fn update_vector(
        &self,
        id: PointId,
        vector_name: &str,
        vector: impl Into<Vector>,
    ) -> ApiResult<()> {
        let named_vec = vector_to_named_vector(vector.into());
        self.point_store.update_vector(&self.vector_name, id, vector_name, named_vec)?;
        Ok(())
    }

    /// Deletes point `id` and returns the store's boolean result (presumably whether a
    /// point was actually removed — confirm against `PointStore::delete_point`).
    pub fn delete_point(&self, id: PointId) -> ApiResult<bool> {
        Ok(self.point_store.delete_point(&self.vector_name, id)?)
    }

    /// Deletes each id in turn and returns how many calls to [`Self::delete_point`]
    /// returned `true`.
    ///
    /// Stops at the first error. Earlier deletions are not rolled back.
    pub fn delete_points(&self, ids: impl IntoIterator<Item = PointId>) -> ApiResult<usize> {
        let mut deleted = 0;
        for id in ids {
            if self.delete_point(id)? {
                deleted += 1;
            }
        }
        Ok(deleted)
    }

    /// Deletes the vector `vector_name` from point `id` and returns the store's boolean result.
    pub fn delete_vector(&self, id: PointId, vector_name: &str) -> ApiResult<bool> {
        Ok(self.point_store.delete_vector(&self.vector_name, id, vector_name)?)
    }

    /// Returns whether point `id` exists in the collection.
    pub fn point_exists(&self, id: PointId) -> ApiResult<bool> {
        Ok(self.point_store.point_exists(&self.vector_name, id)?)
    }

    /// Returns the ids of all points in the collection.
    pub fn list_points(&self) -> ApiResult<Vec<PointId>> {
        Ok(self.point_store.list_points(&self.vector_name)?)
    }

    /// Returns the number of points in the collection.
    pub fn count_points(&self) -> ApiResult<usize> {
        Ok(self.point_store.count_points(&self.vector_name)?)
    }

    /// Starts building a search against the named vector `vector_name`.
    #[must_use]
    pub fn search(&self, vector_name: impl Into<String>) -> SearchBuilder<'_, E> {
        SearchBuilder::new(self, vector_name)
    }

    /// Starts building a search that fuses results across several named vectors.
    #[must_use]
    pub fn hybrid_search(&self) -> HybridSearchBuilder<'_, E> {
        HybridSearchBuilder::new(self)
    }

    /// Brute-force search: scores every point that has a `vector_name` vector against
    /// `query` and returns one page of results, sorted by descending score.
    ///
    /// Points without the requested vector are skipped. When `filter` is set, only
    /// points whose payload matches are scored. Points scoring below `score_threshold`
    /// are dropped. `_ef` is accepted but unused, because no index is consulted here.
    ///
    /// # Errors
    ///
    /// Propagates storage errors, including errors reading a point's vector other
    /// than "not found". Such errors were previously ignored.
    pub(crate) fn execute_search(
        &self,
        vector_name: &str,
        query: Vector,
        limit: usize,
        offset: usize,
        filter: Option<Filter>,
        with_payload: bool,
        with_vectors: bool,
        score_threshold: Option<f32>,
        _ef: Option<usize>,
    ) -> ApiResult<Vec<ScoredPoint>> {
        // The payload is only read when the filter or the caller needs it.
        let needs_payload = with_payload || filter.is_some();

        // (id, score, payload-to-return). Vectors are fetched later, for the final page only.
        let mut candidates: Vec<(PointId, f32, Option<JsonValue>)> = Vec::new();
        for point_id in self.point_store.list_points(&self.vector_name)? {
            let stored_vector = match self.get_vector(point_id, vector_name)? {
                Some(vector) => vector,
                None => continue,
            };

            let payload_value = if needs_payload {
                Some(self.point_store.get_payload(&self.vector_name, point_id)?.into_value())
            } else {
                None
            };

            if let Some(f) = &filter {
                // `needs_payload` is true whenever a filter is set, so the payload was loaded.
                match &payload_value {
                    Some(value) if f.matches(value) => {}
                    _ => continue,
                }
            }

            let score = compute_similarity(&query, &stored_vector);
            if let Some(threshold) = score_threshold {
                if score < threshold {
                    continue;
                }
            }

            let payload_value = if with_payload { payload_value } else { None };
            candidates.push((point_id, score, payload_value));
        }

        // Stable sort, so ties keep the order in which the store listed the points.
        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut results = Vec::new();
        for (point_id, score, payload_value) in candidates.into_iter().skip(offset).take(limit) {
            let mut scored = ScoredPoint::new(point_id, score);
            if let Some(payload) = payload_value {
                scored = scored.with_payload(payload);
            }
            if with_vectors {
                scored = scored.with_vectors(self.get_all_vectors(point_id)?);
            }
            results.push(scored);
        }
        Ok(results)
    }

    /// Runs one [`Self::execute_search`] per `(vector_name, query, weight)` entry,
    /// applying `filter` to each. The ranked lists are then combined with
    /// `fuse_results` using `fusion`, and one page of the fused ranking is returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from the sub-searches and from payload or vector lookups.
    pub(crate) fn execute_hybrid_search(
        &self,
        queries: Vec<(String, Vector, f32)>,
        limit: usize,
        offset: usize,
        filter: Option<Filter>,
        with_payload: bool,
        with_vectors: bool,
        fusion: FusionStrategy,
    ) -> ApiResult<Vec<ScoredPoint>> {
        // Each sub-search must surface enough candidates to fill the requested page
        // (`offset + limit`). The 3x over-fetch gives fusion room to promote points that
        // rank lower in an individual list. Saturating math avoids overflow for huge limits.
        let per_query_limit = offset.saturating_add(limit).saturating_mul(3);

        let mut all_results: Vec<Vec<(PointId, f32)>> = Vec::with_capacity(queries.len());
        let mut weights: Vec<f32> = Vec::with_capacity(queries.len());
        for (vector_name, query, weight) in queries {
            let hits = self.execute_search(
                &vector_name,
                query,
                per_query_limit,
                0,
                filter.clone(),
                false,
                false,
                None,
                None,
            )?;
            all_results.push(hits.into_iter().map(|hit| (hit.id, hit.score)).collect());
            weights.push(weight);
        }

        let mut ranked: Vec<(PointId, f32)> =
            fuse_results(all_results, weights, fusion).into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut final_results = Vec::new();
        for (point_id, score) in ranked.into_iter().skip(offset).take(limit) {
            let mut scored = ScoredPoint::new(point_id, score);
            if with_payload {
                // A payload the store reports as not found is simply omitted.
                // Other storage errors are propagated instead of being silently dropped.
                if let Some(payload) = self.get_payload(point_id)? {
                    scored = scored.with_payload(payload);
                }
            }
            if with_vectors {
                scored = scored.with_vectors(self.get_all_vectors(point_id)?);
            }
            final_results.push(scored);
        }
        Ok(final_results)
    }

    /// Converts API vectors into the storage representation, keeping their names.
    ///
    /// This never returns `Err` today; the `Result` is part of the internal signature.
    fn convert_vectors_to_store(
        &self,
        vectors: HashMap<String, Vector>,
    ) -> ApiResult<HashMap<String, NamedVector>> {
        Ok(vectors.into_iter().map(|(name, vector)| (name, vector_to_named_vector(vector))).collect())
    }
}
/// Converts an API [`Vector`] into the storage crate's [`NamedVector`].
///
/// Each representation (dense, sparse, multi) maps to the variant of the same
/// name, and the inner data is moved across unchanged.
fn vector_to_named_vector(vector: Vector) -> NamedVector {
    match vector {
        Vector::Dense(values) => NamedVector::Dense(values),
        Vector::Sparse(entries) => NamedVector::Sparse(entries),
        Vector::Multi(tokens) => NamedVector::Multi(tokens),
    }
}
/// Converts a storage [`NamedVector`] back into the API-level [`Vector`].
///
/// This is the inverse of `vector_to_named_vector`. Variants map one-to-one and the
/// data is moved without copying.
fn named_vector_to_vector(vector: NamedVector) -> Vector {
    match vector {
        NamedVector::Dense(values) => Vector::Dense(values),
        NamedVector::Sparse(entries) => Vector::Sparse(entries),
        NamedVector::Multi(tokens) => Vector::Multi(tokens),
    }
}
/// Converts an API-level [`VectorConfig`] into the storage layer's vector configuration.
///
/// Only `config.vector_type` is read, so no distance metric is carried into the
/// storage config. Binary vectors have no dedicated storage type here. They are
/// mapped to a dense config whose dimension is the number of bytes needed to hold
/// `bits` bits, rounded up so a trailing partial byte is not dropped.
///
/// NOTE(review): the `as u32` casts silently truncate if the source integer type is
/// wider than `u32` and holds a larger value.
fn vector_config_to_store_config(config: &VectorConfig) -> VectorStoreConfig {
    match &config.vector_type {
        super::config::VectorType::Dense { dimension } => {
            VectorStoreConfig::dense(*dimension as u32)
        }
        super::config::VectorType::Sparse { max_dimension } => {
            VectorStoreConfig::sparse(*max_dimension)
        }
        super::config::VectorType::Multi { token_dim } => {
            VectorStoreConfig::multi(*token_dim as u32)
        }
        super::config::VectorType::Binary { bits } => {
            // Ceiling division: 12 bits need 2 bytes (flooring gave 1), and 1..=7 bits
            // need 1 byte (flooring gave 0).
            let bytes = *bits / 8 + if *bits % 8 == 0 { 0 } else { 1 };
            VectorStoreConfig::dense(bytes as u32)
        }
    }
}
/// Rebuilds an API-level [`VectorConfig`] from a stored vector configuration.
///
/// Only the vector type and dimension are read. Dense vectors are always reported
/// with `DistanceMetric::Cosine`, whatever metric the collection was created with.
fn store_config_to_vector_config(config: &VectorStoreConfig) -> VectorConfig {
    use manifoldb_vector::distance::DistanceMetric;

    let dimension = config.dimension;
    match config.vector_type {
        VectorStoreVectorType::Sparse => VectorConfig::sparse(dimension),
        VectorStoreVectorType::Multi => VectorConfig::multi_vector(dimension as usize),
        VectorStoreVectorType::Dense => {
            VectorConfig::dense(dimension as usize, DistanceMetric::Cosine)
        }
    }
}
/// Scores `a` against `b` with the similarity measure for their shared representation.
///
/// - Dense pairs use cosine similarity.
/// - Sparse pairs use a dot product over shared indices.
/// - Multi-vector pairs use MaxSim.
///
/// Pairs with different representations score `0.0`.
fn compute_similarity(a: &Vector, b: &Vector) -> f32 {
    match (a, b) {
        (Vector::Dense(lhs), Vector::Dense(rhs)) => cosine_similarity(lhs, rhs),
        (Vector::Sparse(lhs), Vector::Sparse(rhs)) => sparse_dot_product(lhs, rhs),
        (Vector::Multi(lhs), Vector::Multi(rhs)) => max_sim(lhs, rhs),
        // Different representations are not comparable.
        (_, _) => 0.0,
    }
}
/// Cosine similarity of two dense vectors.
///
/// Returns `0.0` when the lengths differ or when either vector has zero magnitude,
/// so callers never see a division by zero from a zero vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate the dot product and both squared magnitudes in a single pass.
    let (dot, sq_norm_a, sq_norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, sa, sb), (&x, &y)| {
            (dot + x * y, sa + x * x, sb + y * y)
        });

    let (norm_a, norm_b) = (sq_norm_a.sqrt(), sq_norm_b.sqrt());
    if norm_a == 0.0 || norm_b == 0.0 {
        // A zero vector has no direction, so report "no similarity".
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Dot product of two sparse vectors given as `(index, value)` pairs.
///
/// Walks both inputs as a merge-join, so it assumes each slice is sorted by
/// ascending index (not checked here). Only indices present in both vectors
/// contribute to the result.
fn sparse_dot_product(a: &[(u32, f32)], b: &[(u32, f32)]) -> f32 {
    use std::cmp::Ordering;

    let (mut ia, mut ib) = (0usize, 0usize);
    let mut acc = 0.0;
    while let (Some(&(idx_a, val_a)), Some(&(idx_b, val_b))) = (a.get(ia), b.get(ib)) {
        match idx_a.cmp(&idx_b) {
            Ordering::Less => ia += 1,
            Ordering::Greater => ib += 1,
            Ordering::Equal => {
                acc += val_a * val_b;
                ia += 1;
                ib += 1;
            }
        }
    }
    acc
}
/// MaxSim score between two multi-vectors.
///
/// For every query token, takes the highest cosine similarity against any document
/// token, then sums those per-token maxima. Non-finite maxima are skipped. Returns
/// `0.0` if either side has no tokens.
fn max_sim(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
    if query.is_empty() || doc.is_empty() {
        return 0.0;
    }

    query
        .iter()
        .map(|q_token| {
            doc.iter()
                .map(|d_token| cosine_similarity(q_token, d_token))
                .fold(f32::NEG_INFINITY, |best, sim| if sim > best { sim } else { best })
        })
        .filter(|best| best.is_finite())
        .fold(0.0, |total, best| total + best)
}
/// Combines several ranked `(point, score)` lists into one score per point.
///
/// List `i` is paired with `weights[i]`; if one side is longer, its extra entries
/// are ignored. Strategies:
/// - `Rrf { k }`: each appearance adds `weight / (k + rank + 1)`, with rank 0-based.
/// - `WeightedAverage`: scores are min-max normalized per list (`1.0` when the list's
///   scores are all equal), then added after multiplying by the weight.
/// - `WeightedSum`: raw scores multiplied by the weight are added.
fn fuse_results(
    results: Vec<Vec<(PointId, f32)>>,
    weights: Vec<f32>,
    strategy: FusionStrategy,
) -> HashMap<PointId, f32> {
    let mut fused: HashMap<PointId, f32> = HashMap::new();
    let weighted_lists = results.iter().zip(&weights);

    match strategy {
        FusionStrategy::Rrf { k } => {
            for (ranked, &weight) in weighted_lists {
                for (position, (id, _)) in ranked.iter().enumerate() {
                    let contribution = weight / (k + (position as f32) + 1.0);
                    *fused.entry(*id).or_insert(0.0) += contribution;
                }
            }
        }
        FusionStrategy::WeightedAverage => {
            for (ranked, &weight) in weighted_lists {
                let (lo, hi) = ranked
                    .iter()
                    .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &(_, s)| {
                        (lo.min(s), hi.max(s))
                    });
                let span = hi - lo;
                for &(id, s) in ranked {
                    let normalized = if span > 0.0 { (s - lo) / span } else { 1.0 };
                    *fused.entry(id).or_insert(0.0) += normalized * weight;
                }
            }
        }
        FusionStrategy::WeightedSum => {
            for (ranked, &weight) in weighted_lists {
                for &(id, s) in ranked {
                    *fused.entry(id).or_insert(0.0) += s * weight;
                }
            }
        }
    }

    fused
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Creating a handle over an `Arc`-wrapped engine should succeed.
    #[test]
    fn test_create_with_arc_engine_works() {
        use manifoldb_storage::backends::RedbEngine;

        let engine = Arc::new(RedbEngine::in_memory().unwrap());
        // A second reference is kept alive while the handle is created.
        let engine_clone = Arc::clone(&engine);
        let name = CollectionName::new("test_collection").unwrap();
        let vectors = vec![(
            "embedding".to_string(),
            VectorConfig::dense(128, manifoldb_vector::distance::DistanceMetric::Cosine),
        )];

        let handle = match CollectionHandle::create(engine, name, vectors) {
            Ok(handle) => handle,
            Err(_) => panic!("Creating handle with Arc should work"),
        };
        assert_eq!(handle.name(), "test_collection");
        assert!(handle.has_vector("embedding"));
        assert!(!handle.has_vector("missing"));
        drop(engine_clone);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &b).abs() < 0.001);

        let a = [1.0, 1.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        let expected = 1.0 / 2.0_f32.sqrt();
        assert!((cosine_similarity(&a, &b) - expected).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths and zero vectors are defined to score 0.0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_sparse_dot_product() {
        let a = [(0, 1.0), (2, 2.0), (5, 3.0)];
        let b = [(0, 0.5), (2, 1.0), (10, 1.0)];
        assert!((sparse_dot_product(&a, &b) - 2.5).abs() < 0.001);

        // No shared indices means no contribution.
        assert_eq!(sparse_dot_product(&[(1, 2.0)], &[(2, 3.0)]), 0.0);
        assert_eq!(sparse_dot_product(&[], &[(1, 1.0)]), 0.0);
    }

    #[test]
    fn test_max_sim() {
        let query: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let doc: Vec<Vec<f32>> = vec![vec![1.0, 0.0]];
        // First query token matches exactly (1.0); the second is orthogonal (0.0).
        assert!((max_sim(&query, &doc) - 1.0).abs() < 0.001);
        assert_eq!(max_sim(&[], &doc), 0.0);
        assert_eq!(max_sim(&query, &[]), 0.0);
    }

    #[test]
    fn test_rrf_fusion() {
        let results = vec![
            vec![(PointId::new(1), 0.9), (PointId::new(2), 0.8)],
            vec![(PointId::new(2), 0.95), (PointId::new(1), 0.85)],
        ];
        let weights = vec![0.5, 0.5];
        let fused = fuse_results(results, weights, FusionStrategy::Rrf { k: 60.0 });

        assert_eq!(fused.len(), 2);
        // Each point is ranked once at position 0 and once at position 1.
        let expected: f32 = 0.5 / 61.0 + 0.5 / 62.0;
        assert!((fused[&PointId::new(1)] - expected).abs() < 1e-6);
        assert!((fused[&PointId::new(2)] - expected).abs() < 1e-6);
    }

    #[test]
    fn test_weighted_sum_fusion() {
        let results = vec![
            vec![(PointId::new(1), 0.5)],
            vec![(PointId::new(1), 0.25), (PointId::new(2), 1.0)],
        ];
        let fused = fuse_results(results, vec![2.0, 4.0], FusionStrategy::WeightedSum);

        assert!((fused[&PointId::new(1)] - 2.0).abs() < 1e-6);
        assert!((fused[&PointId::new(2)] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_weighted_average_fusion() {
        let results = vec![
            vec![(PointId::new(1), 0.9), (PointId::new(2), 0.1)],
            vec![(PointId::new(3), 0.7)],
        ];
        let fused = fuse_results(results, vec![2.0, 0.5], FusionStrategy::WeightedAverage);

        // First list is min-max normalized to [0, 1]; a single-entry list normalizes to 1.0.
        assert!((fused[&PointId::new(1)] - 2.0).abs() < 1e-6);
        assert!(fused[&PointId::new(2)].abs() < 1e-6);
        assert!((fused[&PointId::new(3)] - 0.5).abs() < 1e-6);
    }
}