use crate::dsl::Field;
use crate::segment::SegmentReader;
use super::VectorResultScorer;
use super::combiner::MultiValueCombiner;
use crate::query::traits::{CountFuture, Query, Scorer, ScorerFuture};
/// A dense vector similarity query against a single field.
///
/// Construct with [`DenseVectorQuery::new`] and tune with the `with_*` builder
/// methods. Execution delegates to the segment reader's dense-vector search
/// (see the `Query` impl in this file).
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
/// Field whose dense vectors are searched.
pub field: Field,
/// Query vector; its length is what `Display` reports as `dim`.
pub vector: Vec<f32>,
/// Probe count forwarded to the reader's search (default 32 via `new`).
/// Presumably the number of index partitions/clusters visited — confirm
/// against the index implementation.
pub nprobe: usize,
/// Rerank factor forwarded to the reader's search (default 3.0 via `new`).
/// Presumably an over-fetch multiplier for exact reranking — confirm.
pub rerank_factor: f32,
/// Combiner forwarded to the reader's search (default
/// `MultiValueCombiner::Max` via `new`). Presumably decides how scores from a
/// document's multiple vectors are merged — inferred from the type name.
pub combiner: MultiValueCombiner,
}
impl std::fmt::Display for DenseVectorQuery {
    /// Writes a compact one-line summary of the query.
    ///
    /// The vector's contents are not printed, only its length (as `dim`).
    /// The combiner is not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            field,
            vector,
            nprobe,
            rerank_factor,
            ..
        } = self;
        let dim = vector.len();
        write!(
            f,
            "Dense({}, dim={}, nprobe={}, rerank={})",
            field.0, dim, nprobe, rerank_factor
        )
    }
}
impl DenseVectorQuery {
    /// Probe count applied by [`DenseVectorQuery::new`].
    const DEFAULT_NPROBE: usize = 32;
    /// Rerank factor applied by [`DenseVectorQuery::new`].
    const DEFAULT_RERANK_FACTOR: f32 = 3.0;

    /// Creates a query that searches `field` with `vector`.
    ///
    /// Starts from `nprobe = 32`, `rerank_factor = 3.0` and the
    /// [`MultiValueCombiner::Max`] combiner. Override any of them with the
    /// `with_*` methods.
    pub fn new(field: Field, vector: Vec<f32>) -> Self {
        Self {
            combiner: MultiValueCombiner::Max,
            rerank_factor: Self::DEFAULT_RERANK_FACTOR,
            nprobe: Self::DEFAULT_NPROBE,
            vector,
            field,
        }
    }

    /// Returns the query with `nprobe` replaced; other settings are kept.
    pub fn with_nprobe(self, nprobe: usize) -> Self {
        Self { nprobe, ..self }
    }

    /// Returns the query with `rerank_factor` set to `factor`; other settings
    /// are kept. The value is stored as given, without validation.
    pub fn with_rerank_factor(self, factor: f32) -> Self {
        Self {
            rerank_factor: factor,
            ..self
        }
    }

    /// Returns the query with `combiner` replaced; other settings are kept.
    pub fn with_combiner(self, combiner: MultiValueCombiner) -> Self {
        Self { combiner, ..self }
    }
}
impl Query for DenseVectorQuery {
    /// Runs the dense-vector search on `reader` asynchronously and wraps the
    /// hits in a [`VectorResultScorer`].
    ///
    /// The query is cloned before the future is built, so the returned future
    /// borrows only `reader` (lifetime `'a`), not `self`.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let query = self.clone();
        Box::pin(async move {
            let Self {
                field,
                vector,
                nprobe,
                rerank_factor,
                combiner,
            } = query;
            let hits = reader
                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
                .await?;
            Ok(Box::new(VectorResultScorer::new(hits, field.0)) as Box<dyn Scorer>)
        })
    }

    /// Blocking counterpart of [`Query::scorer`], compiled only with the
    /// `sync` feature.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let Self {
            field,
            vector,
            nprobe,
            rerank_factor,
            combiner,
        } = self;
        let hits = reader.search_dense_vector_sync(
            *field,
            vector,
            limit,
            *nprobe,
            *rerank_factor,
            *combiner,
        )?;
        Ok(Box::new(VectorResultScorer::new(hits, field.0)) as Box<dyn Scorer>)
    }

    /// Does not inspect the reader. It always resolves to `u32::MAX`, which
    /// callers presumably read as "unknown/unbounded" (confirm at call sites).
    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async { Ok(u32::MAX) })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder overrides land in the right fields and leave the rest alone.
    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);
        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5.0);
        // The combiner was not overridden, so it keeps the default from `new`.
        assert!(matches!(query.combiner, MultiValueCombiner::Max));
    }

    /// Pins the defaults chosen by `new` so they cannot drift silently.
    #[test]
    fn test_dense_vector_query_defaults() {
        let query = DenseVectorQuery::new(Field(1), vec![0.5; 4]);
        assert_eq!(query.field, Field(1));
        assert_eq!(query.vector, vec![0.5; 4]);
        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3.0);
        // `matches!` avoids requiring `PartialEq` on `MultiValueCombiner`.
        assert!(matches!(query.combiner, MultiValueCombiner::Max));
    }

    /// `with_combiner` must not clobber previously configured settings.
    #[test]
    fn test_dense_vector_query_with_combiner_keeps_other_fields() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0])
            .with_nprobe(8)
            .with_rerank_factor(2.0)
            .with_combiner(MultiValueCombiner::Max);
        assert_eq!(query.nprobe, 8);
        assert_eq!(query.rerank_factor, 2.0);
        assert!(matches!(query.combiner, MultiValueCombiner::Max));
    }

    /// The `Display` summary reports dim (vector length), nprobe and rerank.
    #[test]
    fn test_dense_vector_query_display() {
        let query = DenseVectorQuery::new(Field(2), vec![1.0, 2.0, 3.0])
            .with_nprobe(8)
            .with_rerank_factor(1.5);
        assert_eq!(query.to_string(), "Dense(2, dim=3, nprobe=8, rerank=1.5)");
    }
}