langchainrust 0.2.15

A LangChain-inspired framework for building LLM applications in Rust. Supports OpenAI, Agents, Tools, Memory, Chains, RAG, BM25, Hybrid Retrieval, LangGraph, HyDE, Reranking, MultiQuery, and native Function Calling.
use super::algorithm::{bm25_score, BM25Params};
use super::index::BM25Index;
use super::tokenizer::Tokenizer;
use crate::vector_stores::{Document, SearchResult};

/// Lexical retriever that ranks documents against a query with the BM25 scoring function.
///
/// Documents are tokenized with the configured [`Tokenizer`] when they are added and stored in a
/// [`BM25Index`]. Queries are tokenized with the same tokenizer, so document and query terms are
/// comparable, and every indexed document is scored with [`bm25_score`].
pub struct BM25Retriever {
    /// Holds the stored documents plus the per-document statistics used for scoring.
    index: BM25Index,
    /// Applied to both document contents and queries.
    tokenizer: Tokenizer,
}

impl BM25Retriever {
    /// Creates an empty retriever using `BM25Index::new()` and `Tokenizer::new()`.
    pub fn new() -> Self {
        Self {
            index: BM25Index::new(),
            tokenizer: Tokenizer::new(),
        }
    }

    /// Creates an empty retriever with custom BM25 parameters and the default tokenizer.
    ///
    /// `k1` and `b` are the standard BM25 parameters (conventionally: term-frequency saturation
    /// and document-length normalization). They are forwarded to `BM25Params::with_values`.
    pub fn with_params(k1: f64, b: f64) -> Self {
        Self::with_params_and_tokenizer(k1, b, Tokenizer::new())
    }

    /// Creates an empty retriever with the default index parameters and a custom tokenizer.
    pub fn with_tokenizer(tokenizer: Tokenizer) -> Self {
        Self {
            index: BM25Index::new(),
            tokenizer,
        }
    }

    /// Creates an empty retriever with both custom BM25 parameters and a custom tokenizer.
    ///
    /// Combines [`Self::with_params`] and [`Self::with_tokenizer`], which otherwise cannot be
    /// used together.
    pub fn with_params_and_tokenizer(k1: f64, b: f64, tokenizer: Tokenizer) -> Self {
        Self {
            index: BM25Index::with_params(BM25Params::with_values(k1, b)),
            tokenizer,
        }
    }

    /// Tokenizes `document.content` and adds the document, with its terms, to the index.
    pub fn add_document(&mut self, document: Document) {
        let terms = self.tokenizer.tokenize(&document.content);
        self.index.add_document(document, terms);
    }

    /// Adds every document in `documents`, in order, via [`Self::add_document`].
    pub fn add_documents_sync(&mut self, documents: Vec<Document>) {
        for doc in documents {
            self.add_document(doc);
        }
    }

    /// Returns up to `k` documents matching `query`, highest BM25 score first.
    ///
    /// The query is tokenized with the same tokenizer used for indexing. Only documents with a
    /// strictly positive score are returned, so the result can be shorter than `k`. Documents
    /// with equal scores keep ascending document-id order. `SearchResult::score` is the BM25
    /// score narrowed to `f32`.
    ///
    /// Returns an empty vector when `k == 0`, the index is empty, or the query yields no terms.
    pub fn search(&mut self, query: &str, k: usize) -> Vec<SearchResult> {
        if k == 0 || self.index.n_docs() == 0 {
            return Vec::new();
        }

        let query_terms = self.tokenizer.tokenize(query);
        if query_terms.is_empty() {
            return Vec::new();
        }

        let idf_values = self.index.compute_idf_for_terms(&query_terms);

        // Corpus-wide statistics do not change while scoring, so read them once, not per document.
        let avgdl = self.index.avgdl();
        let params = self.index.params();

        let mut scored_docs: Vec<(usize, f64)> = (0..self.index.n_docs())
            .filter_map(|doc_id| {
                let term_freqs = self.index.get_doc_term_freq(doc_id)?;
                let doc_length = self.index.get_doc_length(doc_id);
                let score = bm25_score(
                    &query_terms,
                    term_freqs,
                    doc_length,
                    avgdl,
                    &idf_values,
                    params,
                );
                (score > 0.0).then_some((doc_id, score))
            })
            .collect();

        // Descending by score. The sort is stable, so ties stay in ascending doc-id order.
        scored_docs.sort_by(|a, b| b.1.total_cmp(&a.1));

        scored_docs
            .into_iter()
            .filter_map(|(doc_id, score)| {
                // Every scored id should resolve. If the index is ever inconsistent, skip the
                // entry instead of fabricating an empty document that carries a real score.
                self.index
                    .get_document(doc_id)
                    .cloned()
                    .map(|document| SearchResult {
                        document,
                        score: score as f32,
                    })
            })
            .take(k)
            .collect()
    }

    /// Number of documents in the index.
    pub fn len(&self) -> usize {
        self.index.n_docs()
    }

    /// `true` when the index holds no documents.
    pub fn is_empty(&self) -> bool {
        self.index.n_docs() == 0
    }

    /// Clears the index via `BM25Index::clear`; the tokenizer is kept.
    pub fn clear(&mut self) {
        self.index.clear();
    }

    /// Read-only access to the underlying index.
    pub fn index(&self) -> &BM25Index {
        &self.index
    }
}

impl Default for BM25Retriever {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bm25_retriever_basic() {
        let mut retriever = BM25Retriever::new();

        retriever.add_documents_sync(vec![
            Document::new("Rust is a systems programming language"),
            Document::new("Python is a scripting language"),
            Document::new("JavaScript is used for web development"),
        ]);

        assert_eq!(retriever.len(), 3);

        let results = retriever.search("programming language", 2);
        assert_eq!(results.len(), 2);

        assert!(results[0].document.content.contains("programming"));
    }

    #[test]
    fn test_bm25_retriever_chinese() {
        let mut retriever = BM25Retriever::new();

        retriever.add_documents_sync(vec![
            Document::new("Rust 是一门系统编程语言"),
            Document::new("Python 是脚本语言"),
            Document::new("JavaScript 用于网页开发"),
        ]);

        let results = retriever.search("编程语言", 2);
        assert!(!results.is_empty());

        assert!(results[0].document.content.contains("编程"));
    }

    #[test]
    fn test_bm25_retriever_empty() {
        let mut retriever = BM25Retriever::new();

        let results = retriever.search("test", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_retriever_params() {
        let mut retriever = BM25Retriever::with_params(2.0, 0.5);

        retriever.add_documents_sync(vec![
            Document::new("Rust programming"),
            Document::new("Python scripting"),
        ]);

        let results = retriever.search("programming", 1);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_bm25_retriever_no_match() {
        let mut retriever = BM25Retriever::new();

        retriever.add_documents_sync(vec![
            Document::new("Rust programming language"),
            Document::new("Python scripting language"),
        ]);

        let results = retriever.search("javascript typescript", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_retriever_zero_k() {
        let mut retriever = BM25Retriever::new();

        retriever.add_documents_sync(vec![
            Document::new("Rust programming language"),
            Document::new("Python scripting language"),
        ]);

        assert!(retriever.search("programming", 0).is_empty());
    }

    #[test]
    fn test_bm25_retriever_results_sorted_by_score() {
        let mut retriever = BM25Retriever::new();

        retriever.add_documents_sync(vec![
            Document::new("rust rust rust"),
            Document::new("rust python"),
            Document::new("python java"),
        ]);

        let results = retriever.search("rust", 10);
        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.score > 0.0));
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
    }

    #[test]
    fn test_bm25_retriever_clear() {
        let mut retriever = BM25Retriever::new();

        retriever.add_document(Document::new("Rust programming"));
        assert!(!retriever.is_empty());

        retriever.clear();
        assert!(retriever.is_empty());
        assert!(retriever.search("programming", 5).is_empty());
    }
}