agent-source-repository 0.1.0

Agent Source Repository: a local context registry for coding agents
Documentation
use std::collections::HashMap;

/// BM25 inverted index over a fixed corpus of pre-tokenized documents.
///
/// Documents are identified by their position in the slice the index was built
/// from. The index is `Clone` and serde-(de)serializable; because the serde
/// derives use the field names (and, for non-self-describing formats, the
/// field order), renaming or reordering fields may break previously persisted
/// indexes.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct Bm25Index {
    /// Term -> `(doc_id, term_frequency)` pairs, one pair per document that
    /// contains the term.
    postings: HashMap<String, Vec<(usize, f32)>>,
    /// Term -> inverse-document-frequency weight.
    idf: HashMap<String, f64>,
    /// Token count of each document, indexed by doc id.
    doc_lengths: Vec<f32>,
    /// Mean of `doc_lengths` (0.0 for an empty corpus).
    avg_dl: f32,
    /// Number of indexed documents; also the length of the score vector.
    num_docs: usize,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 document-length normalization parameter.
    b: f32,
}

impl Bm25Index {
    /// Default term-frequency saturation parameter used by [`Bm25Index::new`].
    pub const DEFAULT_K1: f32 = 1.5;
    /// Default document-length normalization parameter used by [`Bm25Index::new`].
    pub const DEFAULT_B: f32 = 0.75;

    /// Builds a BM25 index over pre-tokenized documents using the default
    /// parameters (`k1 = 1.5`, `b = 0.75`).
    ///
    /// A document's id is its position in `documents`; ids index the vector
    /// returned by [`Bm25Index::get_scores`].
    pub fn new(documents: &[Vec<String>]) -> Self {
        Self::with_params(documents, Self::DEFAULT_K1, Self::DEFAULT_B)
    }

    /// Builds a BM25 index with explicit parameters.
    ///
    /// * `k1` - term-frequency saturation; larger values let repeated terms
    ///   keep adding score for longer.
    /// * `b` - document-length normalization strength; `0.0` disables it.
    pub fn with_params(documents: &[Vec<String>], k1: f32, b: f32) -> Self {
        let num_docs = documents.len();
        let mut postings: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
        let mut doc_lengths = Vec::with_capacity(num_docs);
        // Summed as an integer: accumulating `f32` lengths stops being exact
        // once the corpus exceeds 2^24 tokens.
        let mut total_len: usize = 0;

        for (doc_id, tokens) in documents.iter().enumerate() {
            doc_lengths.push(tokens.len() as f32);
            total_len += tokens.len();

            let mut tf: HashMap<&str, f32> = HashMap::new();
            for token in tokens {
                *tf.entry(token.as_str()).or_default() += 1.0;
            }

            // Each distinct term is pushed once per document, in doc-id order,
            // so a posting list's length is the term's document frequency.
            for (term, freq) in tf {
                if let Some(list) = postings.get_mut(term) {
                    list.push((doc_id, freq));
                } else {
                    // Only allocate the owned key the first time a term is seen.
                    postings.insert(term.to_owned(), vec![(doc_id, freq)]);
                }
            }
        }

        let avg_dl = if num_docs > 0 {
            (total_len as f64 / num_docs as f64) as f32
        } else {
            0.0
        };

        let n = num_docs as f64;
        let idf: HashMap<String, f64> = postings
            .iter()
            .map(|(term, list)| {
                let doc_freq = list.len() as f64;
                // The `+ 1.0` inside the log keeps the IDF strictly positive,
                // even for terms that occur in most documents.
                let idf_val = ((n - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
                (term.clone(), idf_val)
            })
            .collect();

        Self {
            postings,
            idf,
            doc_lengths,
            avg_dl,
            num_docs,
            k1,
            b,
        }
    }

    /// Scores every indexed document against `query_tokens`.
    ///
    /// Returns one score per document, indexed by doc id. Query tokens that do
    /// not occur in the corpus contribute nothing, and a token repeated in the
    /// query contributes once per occurrence.
    ///
    /// When `weight_mask` is `Some`, only documents whose id lies within the
    /// mask and whose entry is `true` are scored; every other document keeps a
    /// score of `0.0`.
    pub fn get_scores(&self, query_tokens: &[String], weight_mask: Option<&[bool]>) -> Vec<f32> {
        let mut scores = vec![0.0f32; self.num_docs];
        // Ids past the end of the mask count as deselected.
        let selected = |doc_id: usize| match weight_mask {
            Some(mask) => mask.get(doc_id).copied().unwrap_or(false),
            None => true,
        };

        for token in query_tokens {
            let idf = match self.idf.get(token.as_str()) {
                Some(&v) => v as f32,
                None => continue,
            };

            if let Some(posting_list) = self.postings.get(token.as_str()) {
                for &(doc_id, tf) in posting_list {
                    if !selected(doc_id) {
                        continue;
                    }

                    let dl = self.doc_lengths[doc_id];
                    let tf_component = (tf * (self.k1 + 1.0))
                        / (tf + self.k1 * (1.0 - self.b + self.b * dl / self.avg_dl));
                    scores[doc_id] += idf * tf_component;
                }
            }
        }

        scores
    }
}