use qdrant_edge::{
AnyVariants, Condition, FieldCondition, Filter, Fusion, JsonPath, Match, NamedQuery, Payload,
Prefetch, QueryEnum, QueryRequest, ScoredPoint, ScoringQuery, SparseVector,
WithPayloadInterface, WithVector, external::ordered_float::OrderedFloat,
};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use crate::init_env::{EmbeddingPayload, SPARSE_VECTORS_NAME, VECTORS_NAME, load_qdrant_edge};
pub const RERANK_TOP_N: usize = 5;
pub const DENSE_WEIGHT: f32 = 0.5;
pub const SPARSE_WEIGHT: f32 = 0.5;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultWithScore {
pub content: String,
pub document_path: String,
pub score: f32,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SearchResults {
pub raw: Vec<ScoredPoint>,
pub processed: Vec<ResultWithScore>,
}
fn payload_to_struct(payload: &Payload) -> Result<EmbeddingPayload, Box<dyn std::error::Error>> {
let ep: EmbeddingPayload = serde_json::from_value(serde_json::to_value(payload)?)?;
Ok(ep)
}
pub fn search(
query_embedding: Vec<f32>,
sparse_query_embedding: SparseVector,
document_paths: Option<Vec<String>>,
limit: Option<usize>,
score_threshold: Option<f32>,
) -> Result<SearchResults, Box<dyn std::error::Error>> {
let edge_shard = load_qdrant_edge()?;
let mut all_results: Vec<ResultWithScore> = vec![];
let threshold: Option<OrderedFloat<f32>> = score_threshold.map(OrderedFloat);
let top_k = limit.unwrap_or(10);
let stmt_filter = match document_paths {
Some(d) => Some(Filter::new_must(Condition::Field(
FieldCondition::new_match(
JsonPath::from_str("document_path").map_err(|_| "invalid json path")?,
Match::from(AnyVariants::Strings(d.iter().cloned().collect())),
),
))),
None => None,
};
let shard_query = QueryRequest {
prefetches: vec![
Prefetch {
prefetches: vec![],
query: Some(ScoringQuery::Vector(QueryEnum::Nearest(NamedQuery {
query: sparse_query_embedding.into(),
using: Some(SPARSE_VECTORS_NAME.to_string()),
}))),
limit: top_k,
params: None,
filter: None,
score_threshold: threshold,
},
Prefetch {
prefetches: vec![],
query: Some(ScoringQuery::Vector(QueryEnum::Nearest(NamedQuery {
query: query_embedding.into(),
using: Some(VECTORS_NAME.to_string()),
}))),
limit: top_k,
params: None,
filter: None,
score_threshold: None,
},
],
query: Some(ScoringQuery::Fusion(Fusion::Rrf {
k: RERANK_TOP_N,
weights: Some(vec![
OrderedFloat(SPARSE_WEIGHT),
OrderedFloat(DENSE_WEIGHT),
]),
})),
filter: stmt_filter,
score_threshold: None,
limit: top_k,
offset: 0,
params: None,
with_vector: WithVector::Bool(false),
with_payload: WithPayloadInterface::Bool(true),
};
let results = edge_shard.query(shard_query)?;
for r in results.clone() {
let payload = match r.payload {
Some(p) => p,
None => return Err("Found a None payload when searching".into()),
};
let embd_payload = payload_to_struct(&payload)?;
let scored_result = ResultWithScore {
content: embd_payload.content,
document_path: embd_payload.document_path,
score: r.score,
};
all_results.push(scored_result);
}
Ok(SearchResults {
raw: results,
processed: all_results,
})
}