use manifoldb_storage::StorageEngine;
use super::error::{ApiError, ApiResult};
use super::filter::Filter;
use super::handle::CollectionHandle;
use super::point::{ScoredPoint, Vector};
/// Builder for a single-vector similarity search against one named vector of a collection.
///
/// Created with the crate-internal `SearchBuilder::new`, configured through the chained
/// setter methods, and run with `execute`.
pub struct SearchBuilder<'a, E>
where
    E: StorageEngine,
{
    /// Collection the search is executed against.
    pub(crate) handle: &'a CollectionHandle<E>,
    /// Name of the vector field to search.
    pub(crate) vector_name: String,
    /// Query vector; `execute` rejects the search while this is `None`.
    pub(crate) query: Option<Vector>,
    /// Maximum number of results; `execute` rejects a value of zero.
    pub(crate) limit: usize,
    /// Number of leading results to skip (presumably for pagination).
    pub(crate) offset: usize,
    /// Optional filter forwarded to the search.
    pub(crate) filter: Option<Filter>,
    /// Whether payloads are requested in the results.
    pub(crate) with_payload: bool,
    /// Whether stored vectors are requested in the results.
    pub(crate) with_vectors: bool,
    /// Optional score cutoff forwarded to the search.
    pub(crate) score_threshold: Option<f32>,
    /// Optional override for the index's `ef` search parameter.
    pub(crate) ef: Option<usize>,
}
impl<'a, E> SearchBuilder<'a, E>
where
    E: StorageEngine,
{
    /// Starts a search over the vector named `vector_name` in `handle`'s collection.
    ///
    /// Defaults: no query, `limit = 10`, `offset = 0`, no filter, payloads and vectors
    /// excluded, no score threshold and no `ef` override.
    pub(crate) fn new(handle: &'a CollectionHandle<E>, vector_name: impl Into<String>) -> Self {
        let vector_name: String = vector_name.into();
        Self {
            handle,
            vector_name,
            query: None,
            limit: 10,
            offset: 0,
            filter: None,
            with_payload: false,
            with_vectors: false,
            score_threshold: None,
            ef: None,
        }
    }

    /// Sets the query vector. This is required; `execute` fails without it.
    #[must_use]
    pub fn query(self, vector: impl Into<Vector>) -> Self {
        Self {
            query: Some(vector.into()),
            ..self
        }
    }

    /// Sets the maximum number of results. Must be non-zero for `execute` to succeed.
    #[must_use]
    pub const fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets how many leading results to skip.
    #[must_use]
    pub const fn offset(mut self, offset: usize) -> Self {
        self.offset = offset;
        self
    }

    /// Restricts the search with `filter`, replacing any previously set filter.
    #[must_use]
    pub fn filter(self, filter: Filter) -> Self {
        Self {
            filter: Some(filter),
            ..self
        }
    }

    /// Chooses whether payloads are included in the returned points.
    #[must_use]
    pub const fn with_payload(mut self, include: bool) -> Self {
        self.with_payload = include;
        self
    }

    /// Chooses whether stored vectors are included in the returned points.
    #[must_use]
    pub const fn with_vectors(mut self, include: bool) -> Self {
        self.with_vectors = include;
        self
    }

    /// Sets a score cutoff that is forwarded to the search.
    ///
    /// NOTE(review): whether this acts as a minimum or a maximum (for example, for distance
    /// metrics) is decided by `CollectionHandle::execute_search`. Confirm there.
    #[must_use]
    pub fn score_threshold(self, threshold: f32) -> Self {
        Self {
            score_threshold: Some(threshold),
            ..self
        }
    }

    /// Overrides the index's `ef` search parameter for this query.
    ///
    /// This is presumably the HNSW candidate-list size; verify against the index implementation.
    #[must_use]
    pub const fn ef(mut self, ef: usize) -> Self {
        self.ef = Some(ef);
        self
    }

    /// Runs the configured search and returns the scored points.
    ///
    /// # Errors
    ///
    /// - `ApiError::EmptyQueryVector` if no query vector was set.
    /// - `ApiError::InvalidSearchLimit` if the limit is zero.
    /// - Any error returned by `CollectionHandle::execute_search`.
    pub fn execute(self) -> ApiResult<Vec<ScoredPoint>> {
        let Self {
            handle,
            vector_name,
            query,
            limit,
            offset,
            filter,
            with_payload,
            with_vectors,
            score_threshold,
            ef,
        } = self;

        // Validation order matters to callers: a missing query is reported before a bad limit.
        let query = match query {
            Some(vector) => vector,
            None => return Err(ApiError::EmptyQueryVector),
        };
        if limit == 0 {
            return Err(ApiError::InvalidSearchLimit);
        }

        handle.execute_search(
            &vector_name,
            query,
            limit,
            offset,
            filter,
            with_payload,
            with_vectors,
            score_threshold,
            ef,
        )
    }
}
/// Builder for a hybrid search that runs several named-vector queries and fuses the results.
///
/// Each entry in `queries` is a `(vector_name, query_vector, weight)` triple. The entries are
/// handed to `CollectionHandle::execute_hybrid_search` together with the chosen `fusion`
/// strategy.
pub struct HybridSearchBuilder<'a, E: StorageEngine> {
    /// Collection the search is executed against.
    pub(crate) handle: &'a CollectionHandle<E>,
    /// Queries to fuse, as `(vector_name, query_vector, weight)` triples.
    pub(crate) queries: Vec<(String, Vector, f32)>,
    /// Maximum number of results; `execute` rejects a value of zero.
    pub(crate) limit: usize,
    /// Number of leading results to skip (presumably for pagination).
    pub(crate) offset: usize,
    /// Optional filter forwarded to the search.
    pub(crate) filter: Option<Filter>,
    /// Whether payloads are requested in the results.
    pub(crate) with_payload: bool,
    /// Whether stored vectors are requested in the results.
    pub(crate) with_vectors: bool,
    /// Strategy used to combine the per-query result lists.
    pub(crate) fusion: FusionStrategy,
}
impl<'a, E: StorageEngine> HybridSearchBuilder<'a, E> {
    /// Starts a hybrid search over `handle`'s collection.
    ///
    /// Defaults: no queries, `limit = 10`, `offset = 0`, no filter, payloads and vectors
    /// excluded, and `FusionStrategy::default()` as the fusion strategy.
    pub(crate) fn new(handle: &'a CollectionHandle<E>) -> Self {
        Self {
            handle,
            queries: Vec::new(),
            limit: 10,
            offset: 0,
            filter: None,
            with_payload: false,
            with_vectors: false,
            // Delegate to `FusionStrategy::default` instead of repeating its literal, so the
            // builder's default and the enum's `Default` impl cannot drift apart.
            fusion: FusionStrategy::default(),
        }
    }

    /// Adds a query against the vector named `vector_name` with the given `weight`.
    ///
    /// At least two queries are required before `execute` succeeds. The weight is forwarded
    /// unvalidated; how it is used depends on the chosen `FusionStrategy`.
    #[must_use]
    pub fn query(
        mut self,
        vector_name: impl Into<String>,
        vector: impl Into<Vector>,
        weight: f32,
    ) -> Self {
        self.queries.push((vector_name.into(), vector.into(), weight));
        self
    }

    /// Sets the maximum number of results. Must be non-zero for `execute` to succeed.
    #[must_use]
    pub const fn limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets how many leading results to skip.
    #[must_use]
    pub const fn offset(mut self, offset: usize) -> Self {
        self.offset = offset;
        self
    }

    /// Restricts the search with `filter`, replacing any previously set filter.
    #[must_use]
    pub fn filter(mut self, filter: Filter) -> Self {
        self.filter = Some(filter);
        self
    }

    /// Chooses whether payloads are included in the returned points.
    #[must_use]
    pub const fn with_payload(mut self, include: bool) -> Self {
        self.with_payload = include;
        self
    }

    /// Chooses whether stored vectors are included in the returned points.
    #[must_use]
    pub const fn with_vectors(mut self, include: bool) -> Self {
        self.with_vectors = include;
        self
    }

    /// Selects how the per-query result lists are combined.
    #[must_use]
    pub const fn fusion(mut self, strategy: FusionStrategy) -> Self {
        self.fusion = strategy;
        self
    }

    /// Runs the hybrid search and returns the fused, scored points.
    ///
    /// # Errors
    ///
    /// - `ApiError::InsufficientVectorsForHybrid` if fewer than two queries were added.
    /// - `ApiError::InvalidSearchLimit` if the limit is zero.
    /// - Any error returned by `CollectionHandle::execute_hybrid_search`.
    pub fn execute(self) -> ApiResult<Vec<ScoredPoint>> {
        // A "hybrid" search with a single query has nothing to fuse.
        if self.queries.len() < 2 {
            return Err(ApiError::InsufficientVectorsForHybrid);
        }
        if self.limit == 0 {
            return Err(ApiError::InvalidSearchLimit);
        }
        self.handle.execute_hybrid_search(
            self.queries,
            self.limit,
            self.offset,
            self.filter,
            self.with_payload,
            self.with_vectors,
            self.fusion,
        )
    }
}
/// Strategy for combining the result lists of the individual queries in a hybrid search.
///
/// It is selected with `HybridSearchBuilder::fusion` and passed on to
/// `CollectionHandle::execute_hybrid_search`, which performs the actual fusion.
/// NOTE(review): the variant descriptions below are inferred from their names. Confirm them
/// against the fusion implementation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum FusionStrategy {
    /// Rank-based fusion (presumably Reciprocal Rank Fusion) with smoothing constant `k`.
    Rrf { k: f32 },
    /// Presumably a weighted average of the per-query scores, using each query's weight.
    WeightedAverage,
    /// Presumably a weighted sum of the per-query scores, using each query's weight.
    WeightedSum,
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::Rrf { k: 60.0 }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `FusionStrategy::default()` must be the `Rrf` variant with `k == 60.0`.
    #[test]
    fn test_fusion_strategy_default() {
        match FusionStrategy::default() {
            FusionStrategy::Rrf { k } => assert!((k - 60.0).abs() < 0.001),
            other => panic!("expected FusionStrategy::Rrf, got {:?}", other),
        }
    }
}