use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::vector::search::searcher::VectorIndexQueryResults;
/// Settings that control how `VectorRanker` orders and rescales vector query results.
///
/// Serializable and deserializable through the serde derives below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingConfig {
/// Ranking strategy that `VectorRanker::rank_results` dispatches on.
pub method: RankingMethod,
/// When `true`, `VectorRanker::rank_results` min-max rescales the `similarity`
/// of every result after the ranking step has run.
pub normalize_scores: bool,
/// Per-key `f32` values. Presumably boosts keyed by a field or term name — TODO confirm.
/// NOTE(review): nothing visible in this file reads this map.
pub boost_factors: std::collections::HashMap<String, f32>,
}
impl Default for RankingConfig {
    /// Builds the baseline configuration: similarity-based ranking, score
    /// normalization switched on, and an empty boost-factor map.
    fn default() -> Self {
        let method = RankingMethod::Similarity;
        let boost_factors = std::collections::HashMap::new();
        Self {
            method,
            normalize_scores: true,
            boost_factors,
        }
    }
}
/// Strategy selector matched on by `VectorRanker::rank_results`.
///
/// `RankingConfig::default()` picks `Similarity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RankingMethod {
/// Orders results by calling `sort_by_similarity()` on the result set.
Similarity,
/// Orders results by calling `sort_by_distance()` on the result set.
Distance,
/// Routed to the ranker's weighted hook, which in this file performs the
/// same similarity sort as `Similarity`.
Weighted,
/// Routed to the ranker's custom hook, which in this file leaves the
/// results untouched.
Custom,
}
/// Applies a `RankingConfig` to vector query results: ordering them and,
/// if configured, rescaling their similarity scores.
pub struct VectorRanker {
// Configuration captured at construction; exposed read-only via `config()`.
config: RankingConfig,
}
impl VectorRanker {
    /// Creates a ranker that uses `config` for every call to `rank_results`.
    pub fn new(config: RankingConfig) -> Self {
        Self { config }
    }

    /// Orders `results` in place according to the configured method and then,
    /// when `normalize_scores` is enabled, rescales the similarity scores.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the weighted, custom, or normalization
    /// steps. The implementations in this file always return `Ok(())`.
    pub fn rank_results(&self, results: &mut VectorIndexQueryResults) -> Result<()> {
        match self.config.method {
            RankingMethod::Similarity => {
                results.sort_by_similarity();
            }
            RankingMethod::Distance => {
                results.sort_by_distance();
            }
            RankingMethod::Weighted => {
                self.apply_weighted_ranking(results)?;
            }
            RankingMethod::Custom => {
                self.apply_custom_ranking(results)?;
            }
        }
        if self.config.normalize_scores {
            self.normalize_scores(results)?;
        }
        Ok(())
    }

    /// Weighted ranking hook.
    ///
    /// Currently a placeholder: it performs the same similarity sort as
    /// `RankingMethod::Similarity` and does not consult `boost_factors`.
    fn apply_weighted_ranking(&self, results: &mut VectorIndexQueryResults) -> Result<()> {
        results.sort_by_similarity();
        Ok(())
    }

    /// Custom ranking hook.
    ///
    /// Currently a no-op: the results are left exactly as they were passed in.
    fn apply_custom_ranking(&self, _results: &mut VectorIndexQueryResults) -> Result<()> {
        Ok(())
    }

    /// Min-max rescales every result's `similarity` to `(s - min) / (max - min)`.
    ///
    /// The scores are left unchanged when:
    /// - there are no results,
    /// - all scores are equal (the range is zero), or
    /// - the range is not a finite positive number, e.g. because a score is
    ///   infinite or `max - min` overflows `f32`. Dividing by such a range
    ///   would collapse finite scores to `0.0` and turn infinite ones into NaN.
    fn normalize_scores(&self, results: &mut VectorIndexQueryResults) -> Result<()> {
        if results.results.is_empty() {
            return Ok(());
        }

        // Single pass for both bounds. `f32::min`/`f32::max` return the other
        // operand when one is NaN, so NaN scores do not affect the bounds.
        let (min_score, max_score) = results
            .results
            .iter()
            .map(|r| r.similarity)
            .fold(
                (f32::INFINITY, f32::NEG_INFINITY),
                |(lo, hi): (f32, f32), score| (f32::min(lo, score), f32::max(hi, score)),
            );

        let range = max_score - min_score;
        // `range > 0.0` alone lets `+inf` through; require a finite divisor too.
        if range > 0.0 && range.is_finite() {
            for result in &mut results.results {
                result.similarity = (result.similarity - min_score) / range;
            }
        }
        Ok(())
    }

    /// Returns the configuration this ranker was constructed with.
    pub fn config(&self) -> &RankingConfig {
        &self.config
    }
}