use std::collections::HashMap;
use crate::lexical::query::Query;
use crate::lexical::search::searcher::{LexicalSearchQuery, SortField};
use crate::vector::VectorScoreMode;
pub use crate::vector::search::searcher::VectorSearchQuery;
/// The query payload of a `SearchRequest`.
///
/// Exactly one of four shapes: a DSL string, a lexical-only query, a
/// vector-only query, or a hybrid pairing of both.
#[derive(Debug)]
// NOTE(review): this lint is silenced without a stated reason; presumably the
// larger variants are kept unboxed on purpose — confirm and document why.
#[allow(clippy::large_enum_variant)]
pub enum SearchQuery {
/// A query supplied as a DSL string. This file only stores the text; how it
/// is parsed is not visible here.
Dsl(String),
/// A lexical-only query.
Lexical(LexicalSearchQuery),
/// A vector-only query.
Vector(VectorSearchQuery),
/// A lexical query and a vector query carried together.
Hybrid {
/// The lexical half of the hybrid query.
lexical: LexicalSearchQuery,
/// The vector half of the hybrid query.
vector: VectorSearchQuery,
/// How the two halves are combined; see `HybridMode`.
mode: HybridMode,
},
}
/// Combination mode attached to `SearchQuery::Hybrid`.
///
/// NOTE(review): the variant meanings below are inferred from their names;
/// the code that interprets them is not in this file — confirm there.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum HybridMode {
/// The default mode. Presumably keeps hits found by either sub-query.
#[default]
Union,
/// Presumably keeps only hits found by both sub-queries.
Intersection,
}
/// Tunables that apply to the lexical side of a search request.
///
/// The `Default` implementation yields no boosts, a minimum score of `0.0`,
/// no timeout, `parallel = false`, and sorting by `SortField::Score`.
#[derive(Debug, Clone)]
pub struct LexicalSearchOptions {
/// Boost value per field name (key = field, value = boost).
pub field_boosts: HashMap<String, f32>,
/// Minimum score setting for lexical hits; how it is enforced is not
/// visible in this file.
pub min_score: f32,
/// Optional timeout; milliseconds going by the field name. `None` means no
/// timeout was configured.
pub timeout_ms: Option<u64>,
/// Flag requesting parallel execution — presumably of the lexical search;
/// the consumer is not visible here.
pub parallel: bool,
/// Which field the results are sorted by.
pub sort_by: SortField,
}
impl Default for LexicalSearchOptions {
fn default() -> Self {
Self {
field_boosts: HashMap::new(),
min_score: 0.0,
timeout_ms: None,
parallel: false,
sort_by: SortField::Score,
}
}
}
/// Tunables that apply to the vector side of a search request.
///
/// The `Default` implementation uses `VectorScoreMode::WeightedSum` and a
/// minimum score of `0.0`.
#[derive(Debug, Clone)]
pub struct VectorSearchOptions {
/// Score mode passed through to the vector search.
pub score_mode: VectorScoreMode,
/// Minimum score setting for vector hits; how it is enforced is not visible
/// in this file.
pub min_score: f32,
}
impl Default for VectorSearchOptions {
fn default() -> Self {
Self {
score_mode: VectorScoreMode::WeightedSum,
min_score: 0.0,
}
}
}
/// A fully specified search request.
///
/// Can be built field by field, via `Default`, or through
/// `SearchRequestBuilder`.
pub struct SearchRequest {
/// The query to run.
pub query: SearchQuery,
/// Result limit (defaults to 10).
pub limit: usize,
/// Result offset (defaults to 0).
pub offset: usize,
/// Optional fusion algorithm; `None` when the caller did not choose one.
pub fusion_algorithm: Option<FusionAlgorithm>,
/// Optional boxed filter query; `None` when no filter was supplied.
pub filter_query: Option<Box<dyn Query>>,
/// Options for the lexical side of the search.
pub lexical_options: LexicalSearchOptions,
/// Options for the vector side of the search.
pub vector_options: VectorSearchOptions,
}
/// Algorithm choice stored in `SearchRequest::fusion_algorithm`.
///
/// NOTE(review): presumably used to merge lexical and vector scores for
/// hybrid queries — the consumer is not in this file; confirm there.
#[derive(Debug, Clone, Copy)]
pub enum FusionAlgorithm {
/// Presumably reciprocal rank fusion — confirm against the implementation.
RRF {
/// The `k` parameter of the algorithm; not validated in this file.
k: f64,
},
/// Fusion driven by one weight per side.
///
/// `SearchRequestBuilder::fusion_algorithm` clamps both weights to
/// `[0.0, 1.0]`; a value constructed directly is not clamped by this file.
WeightedSum {
/// Weight for the lexical side.
lexical_weight: f32,
/// Weight for the vector side.
vector_weight: f32,
},
}
impl Default for SearchRequest {
fn default() -> Self {
Self {
query: SearchQuery::Dsl(String::new()),
limit: 10,
offset: 0,
fusion_algorithm: None,
filter_query: None,
lexical_options: LexicalSearchOptions::default(),
vector_options: VectorSearchOptions::default(),
}
}
}
/// Consuming builder for `SearchRequest`.
///
/// Every setter takes the builder by value and returns it, so calls chain.
/// `build` picks the `SearchQuery` variant from whichever of `dsl`,
/// `lexical_query` and `vector_query` were set.
pub struct SearchRequestBuilder {
// DSL text; when set, `build` uses it and ignores the two queries below.
dsl: Option<String>,
// Lexical query, if one was supplied.
lexical_query: Option<LexicalSearchQuery>,
// Vector query, if one was supplied.
vector_query: Option<VectorSearchQuery>,
// Result limit; `new` starts it at 10.
limit: usize,
// Result offset; `new` starts it at 0.
offset: usize,
// Fusion algorithm, stored after weight clamping in the setter.
fusion_algorithm: Option<FusionAlgorithm>,
// Optional boxed filter query.
filter_query: Option<Box<dyn Query>>,
// Lexical-side options, edited in place by the lexical setters.
lexical_options: LexicalSearchOptions,
// Vector-side options, edited in place by the vector setters.
vector_options: VectorSearchOptions,
}
impl Default for SearchRequestBuilder {
fn default() -> Self {
Self::new()
}
}
impl SearchRequestBuilder {
pub fn new() -> Self {
Self {
dsl: None,
lexical_query: None,
vector_query: None,
limit: 10,
offset: 0,
fusion_algorithm: None,
filter_query: None,
lexical_options: LexicalSearchOptions::default(),
vector_options: VectorSearchOptions::default(),
}
}
pub fn query_dsl(mut self, dsl: impl Into<String>) -> Self {
self.dsl = Some(dsl.into());
self
}
pub fn lexical_query(mut self, query: LexicalSearchQuery) -> Self {
self.lexical_query = Some(query);
self
}
pub fn vector_query(mut self, query: VectorSearchQuery) -> Self {
self.vector_query = Some(query);
self
}
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn offset(mut self, offset: usize) -> Self {
self.offset = offset;
self
}
pub fn fusion_algorithm(mut self, fusion: FusionAlgorithm) -> Self {
let fusion = match fusion {
FusionAlgorithm::WeightedSum {
lexical_weight,
vector_weight,
} => FusionAlgorithm::WeightedSum {
lexical_weight: lexical_weight.clamp(0.0, 1.0),
vector_weight: vector_weight.clamp(0.0, 1.0),
},
other => other,
};
self.fusion_algorithm = Some(fusion);
self
}
pub fn filter_query(mut self, query: Box<dyn Query>) -> Self {
self.filter_query = Some(query);
self
}
pub fn add_field_boost(mut self, field: impl Into<String>, boost: f32) -> Self {
self.lexical_options
.field_boosts
.insert(field.into(), boost);
self
}
pub fn lexical_min_score(mut self, min_score: f32) -> Self {
self.lexical_options.min_score = min_score;
self
}
pub fn lexical_timeout_ms(mut self, timeout_ms: u64) -> Self {
self.lexical_options.timeout_ms = Some(timeout_ms);
self
}
pub fn lexical_parallel(mut self, parallel: bool) -> Self {
self.lexical_options.parallel = parallel;
self
}
pub fn sort_by(mut self, sort_by: SortField) -> Self {
self.lexical_options.sort_by = sort_by;
self
}
pub fn vector_score_mode(mut self, score_mode: VectorScoreMode) -> Self {
self.vector_options.score_mode = score_mode;
self
}
pub fn vector_min_score(mut self, min_score: f32) -> Self {
self.vector_options.min_score = min_score;
self
}
pub fn build(self) -> SearchRequest {
let query = if let Some(dsl) = self.dsl {
SearchQuery::Dsl(dsl)
} else {
match (self.lexical_query, self.vector_query) {
(Some(lexical), Some(vector)) => SearchQuery::Hybrid {
lexical,
vector,
mode: HybridMode::default(),
},
(Some(lexical), None) => SearchQuery::Lexical(lexical),
(None, Some(vector)) => SearchQuery::Vector(vector),
(None, None) => SearchQuery::Dsl(String::new()),
}
};
SearchRequest {
query,
limit: self.limit,
offset: self.offset,
fusion_algorithm: self.fusion_algorithm,
filter_query: self.filter_query,
lexical_options: self.lexical_options,
vector_options: self.vector_options,
}
}
}
/// A single search hit.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Identifier of the hit, as a string.
pub id: String,
/// Score assigned to the hit; the scale depends on the producer, which is
/// not visible in this file.
pub score: f32,
/// The document for this hit when present; `None` otherwise. NOTE(review):
/// presumably only populated when documents are loaded — confirm at the
/// call site.
pub document: Option<crate::data::Document>,
}