zvec-bindings 0.4.1

Idiomatic Rust bindings for zvec vector database
use std::cmp::Ordering;
use std::collections::HashMap;

use crate::types::MetricType;

pub struct RrfReRanker {
    topn: usize,
    rank_constant: i32,
}

impl RrfReRanker {
    pub fn new(topn: usize) -> Self {
        Self {
            topn,
            rank_constant: 60,
        }
    }

    pub fn with_rank_constant(mut self, rank_constant: i32) -> Self {
        self.rank_constant = rank_constant;
        self
    }

    pub fn topn(&self) -> usize {
        self.topn
    }

    pub fn rank_constant(&self) -> i32 {
        self.rank_constant
    }

    fn rrf_score(&self, rank: usize) -> f64 {
        1.0 / (self.rank_constant as f64 + rank as f64 + 1.0)
    }

    pub fn rerank<K: AsRef<str>>(
        &self,
        query_results: &HashMap<K, Vec<(String, f32)>>,
    ) -> Vec<(String, f32)> {
        let mut rrf_scores: HashMap<String, f64> = HashMap::new();

        for docs in query_results.values() {
            for (rank, (doc_id, _)) in docs.iter().enumerate() {
                let rrf_score = self.rrf_score(rank);
                *rrf_scores.entry(doc_id.clone()).or_insert(0.0) += rrf_score;
            }
        }

        let mut scored: Vec<_> = rrf_scores.into_iter().collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scored.truncate(self.topn);

        scored
            .into_iter()
            .map(|(id, score)| (id, score as f32))
            .collect()
    }
}

pub struct WeightedReRanker {
    topn: usize,
    metric: MetricType,
    weights: HashMap<String, f64>,
}

impl WeightedReRanker {
    pub fn new(topn: usize, metric: MetricType) -> Self {
        Self {
            topn,
            metric,
            weights: HashMap::new(),
        }
    }

    pub fn with_weight(mut self, field: impl Into<String>, weight: f64) -> Self {
        self.weights.insert(field.into(), weight);
        self
    }

    pub fn with_weights(mut self, weights: HashMap<String, f64>) -> Self {
        self.weights = weights;
        self
    }

    pub fn topn(&self) -> usize {
        self.topn
    }

    pub fn metric(&self) -> MetricType {
        self.metric
    }

    fn normalize_score(&self, score: f32) -> f64 {
        match self.metric {
            MetricType::L2 => 1.0 - 2.0 * (score as f64).atan() / std::f64::consts::PI,
            MetricType::Ip => 0.5 + (score as f64).atan() / std::f64::consts::PI,
            MetricType::Cosine => 1.0 - score as f64 / 2.0,
            _ => score as f64,
        }
    }

    pub fn rerank<K: AsRef<str>>(
        &self,
        query_results: &HashMap<K, Vec<(String, f32)>>,
    ) -> Vec<(String, f32)> {
        let mut weighted_scores: HashMap<String, f64> = HashMap::new();

        for (vector_name, docs) in query_results.iter() {
            let weight = self
                .weights
                .get(vector_name.as_ref())
                .copied()
                .unwrap_or(1.0);
            for (doc_id, score) in docs.iter() {
                let normalized = self.normalize_score(*score);
                let weighted = normalized * weight;
                *weighted_scores.entry(doc_id.clone()).or_insert(0.0) += weighted;
            }
        }

        let mut scored: Vec<_> = weighted_scores.into_iter().collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scored.truncate(self.topn);

        scored
            .into_iter()
            .map(|(id, score)| (id, score as f32))
            .collect()
    }
}