// ruve-db 0.1.1
//
// A hybrid vector and full-text search database with HNSW approximate nearest-neighbour
// indexing and BM25 ranking.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};

/// Postings list for a single term: one `(doc_id, term_frequency)` pair per document that
/// contains the term.
///
/// Despite the name, this is not the IR "document frequency" (the number of documents that
/// contain a term) — that number is the length of this list.
type DocFrequency = Vec<(String, u32)>;

/// BM25 full-text index over caller-supplied token lists, serialisable with serde.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Bm25Index {
    /// Every known term mapped to the documents that contain it,
    /// e.g. `{"rust": [("doc1", 3), ("doc2", 1)]}`.
    term_to_counts: HashMap<String, DocFrequency>,
    /// Total token count per document, used to normalise term frequency by length,
    /// e.g. `{"doc1": 120, "doc2": 45}`.
    doc_lengths: HashMap<String, u32>,
    /// Average document length across all documents, updated on insert/remove, e.g. `82.5`.
    avg_dl: f32,
    /// Total number of indexed documents, used in IDF, e.g. `1000`.
    n_docs: u32,
}

impl Bm25Index {
    /// Default BM25 `k1`: how quickly repeated occurrences of a term stop adding score.
    pub const DEFAULT_K1: f32 = 1.5;
    /// Default BM25 `b`: how strongly scores are normalised by document length
    /// (`0.0` = no length normalisation, `1.0` = full normalisation).
    pub const DEFAULT_B: f32 = 0.75;

    /// Creates an empty index.
    pub fn new() -> Self {
        Bm25Index {
            term_to_counts: HashMap::new(),
            doc_lengths: HashMap::new(),
            avg_dl: 0.0,
            n_docs: 0,
        }
    }

    /// Loads an index previously written by [`save`](Self::save), or returns an empty index
    /// if `path` does not exist.
    ///
    /// # Panics
    ///
    /// Panics if the file exists but cannot be read, or its contents are not a valid index.
    /// Use [`try_load`](Self::try_load) to handle those cases instead.
    pub fn load(path: &str) -> Self {
        match Self::try_load(path) {
            Ok(index) => index,
            Err(e) => panic!("failed to load BM25 index from {path}: {e}"),
        }
    }

    /// Fallible variant of [`load`](Self::load).
    ///
    /// Only a missing file yields an empty index. Other I/O failures are returned instead of
    /// being replaced by an empty index, so a caller that later saves to the same path does
    /// not overwrite an index that merely failed to load.
    ///
    /// # Errors
    ///
    /// Returns an error if the file exists but cannot be read, or if its contents fail to
    /// deserialize (the `serde_json` error is converted into an `io::Error`).
    pub fn try_load(path: &str) -> std::io::Result<Self> {
        match std::fs::read_to_string(path) {
            Ok(contents) => Ok(serde_json::from_str(&contents)?),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Self::new()),
            Err(e) => Err(e),
        }
    }

    /// Serialises the index as JSON to `path`, creating the file or replacing its contents.
    ///
    /// # Panics
    ///
    /// Panics if serialisation or the write fails; use [`try_save`](Self::try_save) to
    /// handle the error instead.
    pub fn save(&self, path: &str) {
        if let Err(e) = self.try_save(path) {
            panic!("failed to save BM25 index to {path}: {e}");
        }
    }

    /// Fallible variant of [`save`](Self::save).
    ///
    /// # Errors
    ///
    /// Returns an error if serialisation or writing the file fails.
    pub fn try_save(&self, path: &str) -> std::io::Result<()> {
        let contents = serde_json::to_string(self)?;
        std::fs::write(path, contents)
    }

    /// Adds a document to the index.
    ///
    /// `tokens` are used verbatim (no lower-casing or stemming happens here); their count is
    /// the document length used for BM25 length normalisation. If `doc_id` is already
    /// indexed, its previous postings are discarded first, so re-indexing replaces the
    /// document rather than counting it twice.
    pub fn index_record(&mut self, doc_id: &str, tokens: &[String]) {
        // Without this, the old postings would linger and n_docs/avg_dl would count the
        // document twice, skewing every IDF and length norm.
        if self.doc_lengths.contains_key(doc_id) {
            self.purge_doc(doc_id);
        }

        // Term frequency of each distinct token in this document.
        let mut term_frequencies: HashMap<&str, u32> = HashMap::new();
        for token in tokens {
            *term_frequencies.entry(token.as_str()).or_insert(0) += 1;
        }

        for (term, term_frequency) in term_frequencies {
            self.term_to_counts
                .entry(term.to_owned())
                .or_default()
                .push((doc_id.to_owned(), term_frequency));
        }

        // Saturate instead of silently truncating for (absurdly) long documents.
        let length = u32::try_from(tokens.len()).unwrap_or(u32::MAX);
        self.doc_lengths.insert(doc_id.to_owned(), length);
        self.stats_add(length);
    }

    /// Removes a document from the index.
    ///
    /// Postings are removed only for the terms listed in `tokens`, so callers should pass the
    /// same tokens the document was indexed with. A term that is omitted keeps a stale
    /// posting: [`score`](Self::score) ignores it while `doc_id` is not indexed, but it still
    /// uses memory and becomes live again if the same `doc_id` is indexed later. Removing an
    /// unknown `doc_id` leaves `n_docs` and `avg_dl` unchanged.
    pub fn remove_record(&mut self, doc_id: &str, tokens: &[String]) {
        for token in tokens {
            if let Some(postings) = self.term_to_counts.get_mut(token.as_str()) {
                postings.retain(|(id, _)| id != doc_id);
                if postings.is_empty() {
                    self.term_to_counts.remove(token.as_str());
                }
            }
        }

        if let Some(length) = self.doc_lengths.remove(doc_id) {
            self.stats_remove(length);
        }
    }

    /// Scores every document that contains at least one query term, using BM25 with
    /// [`DEFAULT_K1`](Self::DEFAULT_K1) and [`DEFAULT_B`](Self::DEFAULT_B).
    ///
    /// Returns `(doc_id, score)` pairs sorted by descending score, ties broken by ascending
    /// `doc_id`. A query token that appears several times contributes once per occurrence.
    #[must_use]
    pub fn score(&self, query_tokens: &[String]) -> Vec<(String, f32)> {
        self.score_with_params(query_tokens, Self::DEFAULT_K1, Self::DEFAULT_B)
    }

    /// BM25 scoring with explicit parameters; see [`score`](Self::score) for the result format.
    ///
    /// * `k1` — term-frequency saturation; larger values let repeated terms keep adding score.
    /// * `b` — length-normalisation strength, normally in `[0, 1]`.
    #[must_use]
    pub fn score_with_params(
        &self,
        query_tokens: &[String],
        k1: f32,
        b: f32,
    ) -> Vec<(String, f32)> {
        let n_docs = self.n_docs as f32;
        let mut scores: HashMap<&str, f32> = HashMap::new();

        for token in query_tokens {
            let Some(postings) = self.term_to_counts.get(token.as_str()) else { continue };

            // IDF gives rare terms more weight. The 0.5 terms are the usual BM25 smoothing;
            // the `+ 1.0` inside `ln` (as in Lucene's BM25) keeps IDF positive even for terms
            // found in nearly every document, and the clamp keeps it positive should the
            // postings list ever be longer than n_docs.
            let df = postings.len() as f32;
            let idf = (((n_docs - df).max(0.0) + 0.5) / (df + 0.5) + 1.0).ln();

            for (doc_id, tf) in postings {
                // Skip stale postings left behind by a partial `remove_record`.
                let Some(&doc_len) = self.doc_lengths.get(doc_id) else { continue };

                // Length relative to the average (> 1 means longer than average). Falls back
                // to 1.0 when there is no positive average, which would otherwise give NaN.
                let relative_len = if self.avg_dl > 0.0 {
                    doc_len as f32 / self.avg_dl
                } else {
                    1.0
                };
                let length_norm = 1.0 - b + b * relative_len;

                // Saturating tf: grows with tf but levels off towards k1 + 1; longer documents
                // get a larger denominator and therefore a smaller contribution.
                let tf = *tf as f32;
                let tf_norm = tf * (k1 + 1.0) / (tf + k1 * length_norm);

                *scores.entry(doc_id.as_str()).or_insert(0.0) += idf * tf_norm;
            }
        }

        let mut results: Vec<(String, f32)> = scores
            .into_iter()
            .map(|(doc_id, score)| (doc_id.to_owned(), score))
            .collect();
        // `total_cmp` cannot panic (unlike `partial_cmp(..).unwrap()` on NaN); the doc_id
        // tie-break makes the order independent of HashMap iteration order.
        results.sort_by(|x, y| y.1.total_cmp(&x.1).then_with(|| x.0.cmp(&y.0)));
        results
    }

    /// Removes `doc_id` from every postings list (a full scan) and from the length table.
    fn purge_doc(&mut self, doc_id: &str) {
        self.term_to_counts.retain(|_, postings| {
            postings.retain(|(id, _)| id != doc_id);
            !postings.is_empty()
        });
        if let Some(length) = self.doc_lengths.remove(doc_id) {
            self.stats_remove(length);
        }
    }

    /// Folds a newly added document of `length` tokens into `n_docs` and `avg_dl` in O(1).
    fn stats_add(&mut self, length: u32) {
        // Rebuild the running total in f64 to limit rounding drift of the stored f32 average.
        let total = f64::from(self.avg_dl) * f64::from(self.n_docs) + f64::from(length);
        self.n_docs += 1;
        self.avg_dl = (total / f64::from(self.n_docs)) as f32;
    }

    /// Takes a removed document of `length` tokens out of `n_docs` and `avg_dl` in O(1).
    fn stats_remove(&mut self, length: u32) {
        let total = f64::from(self.avg_dl) * f64::from(self.n_docs) - f64::from(length);
        self.n_docs = self.n_docs.saturating_sub(1);
        self.avg_dl = if self.n_docs == 0 {
            0.0
        } else {
            (total.max(0.0) / f64::from(self.n_docs)) as f32
        };
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token list from string literals.
    fn toks(words: &[&str]) -> Vec<String> {
        words.iter().map(|s| s.to_string()).collect()
    }

    /// Per-process file path in the system temp dir for save/load tests.
    fn temp_index_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("bm25_{name}_{}.json", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn matching_doc_is_returned() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust", "fast"]));
        idx.index_record("doc2", &toks(&["python", "slow"]));

        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].0, "doc1");
    }

    #[test]
    fn results_are_sorted_by_score_descending() {
        let mut idx = Bm25Index::new();
        // doc2 contains the query term three times — should rank higher
        idx.index_record("doc1", &toks(&["rust"]));
        idx.index_record("doc2", &toks(&["rust", "rust", "rust"]));

        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].0, "doc2");
        assert!(scores[0].1 > scores[1].1);
    }

    #[test]
    fn no_match_returns_empty() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));

        assert!(idx.score(&toks(&["python"])).is_empty());
    }

    #[test]
    fn query_on_empty_index_returns_empty() {
        let idx = Bm25Index::new();
        assert!(idx.score(&toks(&["rust"])).is_empty());
    }

    #[test]
    fn remove_record_excludes_it_from_results() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.remove_record("doc1", &toks(&["rust"]));

        assert!(idx.score(&toks(&["rust"])).is_empty());
    }

    #[test]
    fn remove_one_of_two_docs_leaves_the_other() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.index_record("doc2", &toks(&["rust"]));
        idx.remove_record("doc1", &toks(&["rust"]));

        let scores = idx.score(&toks(&["rust"]));
        assert_eq!(scores.len(), 1);
        assert_eq!(scores[0].0, "doc2");
    }

    #[test]
    fn rare_term_scores_higher_than_common_term() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["common", "rare"]));
        idx.index_record("doc2", &toks(&["common"]));
        idx.index_record("doc3", &toks(&["common"]));

        // doc1 contains each term exactly once, so only IDF separates the two scores.
        let rare = idx.score(&toks(&["rare"]));
        assert_eq!(rare.len(), 1);
        assert_eq!(rare[0].0, "doc1");

        let common = idx.score(&toks(&["common"]));
        assert_eq!(common.len(), 3);
        let common_doc1 = common
            .iter()
            .find(|(id, _)| id == "doc1")
            .expect("doc1 contains 'common'")
            .1;
        assert!(rare[0].1 > common_doc1);
    }

    #[test]
    fn remove_record_updates_length_statistics() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["a", "b"]));
        idx.index_record("doc2", &toks(&["a", "b", "c", "d"]));
        assert_eq!(idx.n_docs, 2);
        assert_eq!(idx.avg_dl, 3.0);

        idx.remove_record("doc1", &toks(&["a", "b"]));
        assert_eq!(idx.n_docs, 1);
        assert_eq!(idx.avg_dl, 4.0);

        idx.remove_record("doc2", &toks(&["a", "b", "c", "d"]));
        assert_eq!(idx.n_docs, 0);
        assert_eq!(idx.avg_dl, 0.0);
    }

    #[test]
    fn removing_unknown_doc_is_a_noop() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        let before = idx.score(&toks(&["rust"]));

        idx.remove_record("ghost", &toks(&["rust"]));
        assert_eq!(idx.n_docs, 1);
        assert_eq!(idx.score(&toks(&["rust"])), before);
    }

    #[test]
    fn load_missing_file_returns_empty_index() {
        let path = temp_index_path("missing");
        let _ = std::fs::remove_file(&path);

        let idx = Bm25Index::load(&path);
        assert_eq!(idx.n_docs, 0);
        assert!(idx.score(&toks(&["rust"])).is_empty());
    }

    #[test]
    fn save_then_load_round_trips() {
        let mut idx = Bm25Index::new();
        idx.index_record("doc1", &toks(&["rust"]));
        idx.index_record("doc2", &toks(&["rust", "rust", "rust"]));

        let path = temp_index_path("roundtrip");
        idx.save(&path);
        let loaded = Bm25Index::load(&path);
        let _ = std::fs::remove_file(&path);

        let query = toks(&["rust"]);
        assert_eq!(loaded.score(&query), idx.score(&query));
    }
}