use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::vector::core::vector::Vector;
/// A single query against a vector index: the query vector, the search
/// parameters, and an optional field name.
#[derive(Debug, Clone)]
pub struct VectorIndexQuery {
/// The vector to search with.
pub query: Vector,
/// Search parameters; `VectorIndexQuery::new` initialises them with
/// `VectorIndexQueryParams::default()`.
pub params: VectorIndexQueryParams,
/// Optional field name, `None` unless set through the `field_name` builder.
/// NOTE(review): presumably restricts the search to that field — confirm
/// against the searcher implementations.
pub field_name: Option<String>,
}
impl VectorIndexQuery {
pub fn new(query: Vector) -> Self {
VectorIndexQuery {
query,
params: VectorIndexQueryParams::default(),
field_name: None,
}
}
pub fn top_k(mut self, top_k: usize) -> Self {
self.params.top_k = top_k;
self
}
pub fn min_similarity(mut self, threshold: f32) -> Self {
self.params.min_similarity = threshold;
self
}
pub fn include_scores(mut self, include: bool) -> Self {
self.params.include_scores = include;
self
}
pub fn include_vectors(mut self, include: bool) -> Self {
self.params.include_vectors = include;
self
}
pub fn timeout_ms(mut self, timeout: u64) -> Self {
self.params.timeout_ms = Some(timeout);
self
}
pub fn field_name(mut self, field_name: String) -> Self {
self.field_name = Some(field_name);
self
}
}
/// Tunable parameters carried by a `VectorIndexQuery`.
///
/// Serializable with serde; the `Default` implementation supplies the
/// starting values used by `VectorIndexQuery::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndexQueryParams {
/// Number of results requested (default 10).
pub top_k: usize,
/// Similarity threshold (default 0.0).
/// NOTE(review): presumably results below this value are discarded by the
/// searcher — confirm against implementations.
pub min_similarity: f32,
/// Whether scores are requested in the results (default `true`).
pub include_scores: bool,
/// Whether vectors are requested in the results (default `false`).
pub include_vectors: bool,
/// Optional timeout; the name indicates milliseconds (default `None`).
pub timeout_ms: Option<u64>,
/// Optional reranking configuration (default `None`).
pub reranking: Option<crate::vector::search::scoring::ranking::RankingConfig>,
}
impl Default for VectorIndexQueryParams {
fn default() -> Self {
Self {
top_k: 10,
min_similarity: 0.0,
include_scores: true,
include_vectors: false,
timeout_ms: None,
reranking: None,
}
}
}
/// One hit returned from a vector index query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndexQueryResult {
/// Identifier of the matching document.
pub doc_id: u64,
/// Name of the field this hit belongs to.
pub field_name: String,
/// Similarity score; `VectorIndexQueryResults::sort_by_similarity` and
/// `best_result` treat larger values as better.
pub similarity: f32,
/// Distance value; `VectorIndexQueryResults::sort_by_distance` orders
/// smaller values first.
pub distance: f32,
/// The vector itself, when present.
/// NOTE(review): presumably populated only when `include_vectors` was
/// requested — confirm against the searcher implementations.
pub vector: Option<Vector>,
}
/// The collection of hits produced by a vector index query, together with
/// bookkeeping about the search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndexQueryResults {
/// The individual hits.
pub results: Vec<VectorIndexQueryResult>,
/// Count of candidates examined; starts at 0 in `new()`.
pub candidates_examined: usize,
/// Search duration; the name indicates milliseconds. Starts at 0.0 in `new()`.
pub search_time_ms: f64,
/// Free-form string key/value metadata about the query; empty in `new()`.
pub query_metadata: std::collections::HashMap<String, String>,
}
impl VectorIndexQueryResults {
    /// Creates an empty result set: no hits, zeroed counters, no metadata.
    pub fn new() -> Self {
        Self {
            results: Vec::new(),
            candidates_examined: 0,
            search_time_ms: 0.0,
            query_metadata: std::collections::HashMap::new(),
        }
    }

    /// Returns `true` when the set holds no hits.
    pub fn is_empty(&self) -> bool {
        self.results.is_empty()
    }

    /// Returns the number of hits in the set.
    pub fn len(&self) -> usize {
        self.results.len()
    }

    /// Sorts the hits by similarity, highest first.
    ///
    /// The sort is stable, so hits with equal similarity keep their relative
    /// order. Hits whose similarity is NaN are placed after every hit with a
    /// numeric similarity.
    pub fn sort_by_similarity(&mut self) {
        self.results
            .sort_by(|a, b| Self::similarity_rank(a.similarity, b.similarity));
    }

    /// Sorts the hits by distance, lowest first.
    ///
    /// The sort is stable, so hits with equal distance keep their relative
    /// order. Hits whose distance is NaN are placed after every hit with a
    /// numeric distance.
    pub fn sort_by_distance(&mut self) {
        self.results
            .sort_by(|a, b| Self::distance_rank(a.distance, b.distance));
    }

    /// Keeps only the first `k` hits; does nothing when there are `k` or fewer.
    pub fn take_top_k(&mut self, k: usize) {
        // `Vec::truncate` is already a no-op when `k >= len`.
        self.results.truncate(k);
    }

    /// Removes every hit whose similarity is below `min_similarity`.
    ///
    /// A hit with a NaN similarity is removed as well, because `NaN >= x`
    /// is false for every `x`.
    pub fn filter_by_similarity(&mut self, min_similarity: f32) {
        self.results
            .retain(|result| result.similarity >= min_similarity);
    }

    /// Returns the hit with the highest similarity, or `None` when empty.
    ///
    /// When several hits share the highest similarity the last of them is
    /// returned. A hit with a NaN similarity is only returned when every hit
    /// has a NaN similarity.
    pub fn best_result(&self) -> Option<&VectorIndexQueryResult> {
        // `similarity_rank` orders best-first; reversing it makes the best
        // hit the maximum, which keeps `max_by`'s "last of equals" tie rule.
        self.results
            .iter()
            .max_by(|a, b| Self::similarity_rank(a.similarity, b.similarity).reverse())
    }

    /// Best-first ordering of two similarity scores: larger before smaller,
    /// NaN after every number.
    ///
    /// Treating NaN as "equal to everything" is not a total order, which
    /// leaves the result of `sort_by` unspecified and may make it panic.
    fn similarity_rank(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            // Neither value is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, so the NaN side compares as greater (sorted last).
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }

    /// Best-first ordering of two distances: smaller before larger, NaN
    /// after every number.
    fn distance_rank(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            // Neither value is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, so the NaN side compares as greater (sorted last).
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }
}
impl Default for VectorIndexQueryResults {
fn default() -> Self {
Self::new()
}
}
/// Search interface over a vector index.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait VectorIndexSearcher: Send + Sync + std::fmt::Debug {
/// Runs `request` and returns the resulting hits, or an error.
fn search(&self, request: &VectorIndexQuery) -> Result<VectorIndexQueryResults>;
/// Returns a count for `request`, or an error.
///
/// NOTE(review): presumably the number of matching vectors — confirm with
/// implementors. Unlike `search`, this takes the query by value.
fn count(&self, request: VectorIndexQuery) -> Result<u64>;
/// Hook for preparing the searcher before use.
///
/// The default implementation does nothing and returns `Ok(())`.
fn warmup(&mut self) -> Result<()> {
Ok(())
}
}
/// The query part of a `VectorSearchRequest`, in one of two forms.
#[derive(Debug, Clone)]
pub enum VectorSearchQuery {
/// A list of query payloads.
/// NOTE(review): presumably raw inputs that still need embedding — confirm
/// against `QueryPayload`.
Payloads(Vec<crate::vector::store::request::QueryPayload>),
/// A list of query vectors.
Vectors(Vec<crate::vector::store::request::QueryVector>),
}
/// Result limit used when none is supplied: 10.
fn default_query_limit() -> usize {
    const DEFAULT_QUERY_LIMIT: usize = 10;
    DEFAULT_QUERY_LIMIT
}
/// Overfetch factor used when none is supplied: 1.0.
fn default_overfetch() -> f32 {
    const DEFAULT_OVERFETCH: f32 = 1.0;
    DEFAULT_OVERFETCH
}
/// Parameters of a `VectorSearchRequest`.
///
/// Every serialized field has a serde default, so each may be omitted from
/// the input when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchParams {
/// Optional field selectors; `None` when omitted.
/// NOTE(review): presumably `None` means "all fields" — confirm with the
/// searcher.
#[serde(default)]
pub fields: Option<Vec<crate::vector::store::request::FieldSelector>>,
/// Result limit; falls back to `default_query_limit()` (10) when omitted.
#[serde(default = "default_query_limit")]
pub limit: usize,
/// Score mode; falls back to `VectorScoreMode::default()` when omitted.
#[serde(default)]
pub score_mode: crate::vector::store::request::VectorScoreMode,
/// Overfetch factor; falls back to `default_overfetch()` (1.0) when omitted.
/// NOTE(review): presumably a multiplier on `limit` for candidate
/// retrieval — confirm with the searcher.
#[serde(default = "default_overfetch")]
pub overfetch: f32,
/// Score threshold; 0.0 when omitted.
#[serde(default)]
pub min_score: f32,
/// Optional list of ids. Skipped by serde: never serialized, and always
/// `None` after deserialization.
/// NOTE(review): presumably restricts results to these ids — confirm.
#[serde(skip)]
pub allowed_ids: Option<Vec<u64>>,
}
impl Default for VectorSearchParams {
    /// Baseline parameters: no field selection, the default limit, score
    /// mode and overfetch, a zero minimum score, and no id allow-list.
    fn default() -> Self {
        let limit = default_query_limit();
        let score_mode = crate::vector::store::request::VectorScoreMode::default();
        let overfetch = default_overfetch();

        Self {
            fields: None,
            limit,
            score_mode,
            overfetch,
            min_score: 0.0,
            allowed_ids: None,
        }
    }
}
/// A search request handed to a `VectorSearcher`: what to search for plus
/// how to search.
#[derive(Debug, Clone)]
pub struct VectorSearchRequest {
/// The query, as payloads or as vectors.
pub query: VectorSearchQuery,
/// Parameters controlling the search.
pub params: VectorSearchParams,
}
impl Default for VectorSearchRequest {
    /// A request with an empty list of query vectors and default parameters.
    fn default() -> Self {
        let query = VectorSearchQuery::Vectors(Vec::new());
        let params = VectorSearchParams::default();
        Self { query, params }
    }
}
/// Search interface that works on `VectorSearchRequest`s.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait VectorSearcher: Send + Sync + std::fmt::Debug {
/// Runs `request` and returns the search results, or an error.
fn search(
&self,
request: &VectorSearchRequest,
) -> crate::error::Result<crate::vector::store::response::VectorSearchResults>;
/// Returns a count for `request`, or an error.
/// NOTE(review): presumably the number of matches — confirm with
/// implementors.
fn count(&self, request: &VectorSearchRequest) -> crate::error::Result<u64>;
}