use std::collections::HashMap;
use std::sync::Arc;
use crate::candidate_gate::AllowedSet;
use crate::filtered_vector_search::ScoredResult;
/// BM25 tuning parameters together with the corpus statistics the formula needs.
#[derive(Debug, Clone)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter.
    pub k1: f32,
    /// Weight of document-length normalization (scales `doc_len / avgdl`).
    pub b: f32,
    /// Average document length that `doc_len` is compared against.
    pub avgdl: f32,
    /// Total number of documents in the corpus (`N` in the IDF formula).
    pub total_docs: u64,
}

impl Default for Bm25Params {
    /// `k1 = 1.2`, `b = 0.75`, an average length of 100 and a corpus of one
    /// million documents.
    fn default() -> Self {
        Self {
            k1: 1.2,
            b: 0.75,
            avgdl: 100.0,
            total_docs: 1_000_000,
        }
    }
}

impl Bm25Params {
    /// Inverse document frequency of a term that occurs in `doc_freq` documents.
    ///
    /// Computes `ln((N - df + 0.5) / (df + 0.5) + 1)`. The `+ 1` inside the
    /// logarithm keeps the result non-negative whenever `doc_freq <= total_docs`.
    pub fn idf(&self, doc_freq: u64) -> f32 {
        let n = self.total_docs as f32;
        let df = doc_freq as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// BM25 contribution of one term to one document.
    ///
    /// * `tf` - frequency of the term in the document.
    /// * `doc_len` - length of the document, in the same unit as `avgdl`.
    /// * `idf` - inverse document frequency of the term, see [`Bm25Params::idf`].
    ///
    /// Degenerate statistics never produce NaN from a division by zero: when
    /// `avgdl` is not positive the document is treated as average-length, and a
    /// zero denominator yields a score of `0.0`.
    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
        // `doc_len / avgdl` is NaN or infinite when no meaningful average exists
        // (e.g. an empty corpus); fall back to a neutral normalization factor.
        let length_norm = if self.avgdl > 0.0 {
            1.0 - self.b + self.b * doc_len / self.avgdl
        } else {
            1.0
        };
        let numerator = tf * (self.k1 + 1.0);
        let denominator = tf + self.k1 * length_norm;
        if denominator == 0.0 {
            // Happens for `tf == 0` with `k1 == 0`; an absent term scores nothing.
            return 0.0;
        }
        idf * numerator / denominator
    }
}
/// Postings of a single term, stored as two parallel columns.
#[derive(Debug, Clone)]
pub struct PostingList {
    /// The term these postings belong to.
    pub term: String,
    /// Documents containing the term; entry `i` pairs with `term_freqs[i]`.
    pub doc_ids: Vec<u64>,
    /// Frequency of the term in the document at the same index of `doc_ids`.
    pub term_freqs: Vec<u32>,
    /// Number of postings the list was built from.
    pub doc_freq: u64,
}

impl PostingList {
    /// Builds a posting list from `(doc_id, term_frequency)` pairs, keeping the
    /// order of `entries` and recording their count as `doc_freq`.
    pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
        let term = term.into();
        let doc_freq = entries.len() as u64;
        let (doc_ids, term_freqs): (Vec<u64>, Vec<u32>) = entries.into_iter().unzip();
        Self {
            term,
            doc_ids,
            term_freqs,
            doc_freq,
        }
    }

    /// Returns the `(doc_id, term_frequency)` pairs whose document is in
    /// `allowed`, in posting order.
    ///
    /// `AllowedSet::All` and `AllowedSet::None` are answered without any
    /// membership test; every other variant is probed through `contains`.
    pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
        let postings = self
            .doc_ids
            .iter()
            .copied()
            .zip(self.term_freqs.iter().copied());
        match allowed {
            AllowedSet::None => Vec::new(),
            AllowedSet::All => postings.collect(),
            _ => postings
                .filter(|&(doc_id, _)| allowed.contains(doc_id))
                .collect(),
        }
    }
}
/// Read access to an inverted index, as needed by the BM25 executors in this file.
///
/// Implementations must be `Send + Sync`.
pub trait InvertedIndex: Send + Sync {
/// Looks up the posting list for `term`; `None` when the index has no postings
/// to return for it.
fn get_posting_list(&self, term: &str) -> Option<PostingList>;
/// Returns the length of document `doc_id`, if the index knows it. The executors
/// in this file leave documents without a length out of their results.
fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
/// Returns the BM25 parameters used to score documents of this index.
fn get_params(&self) -> &Bm25Params;
}
/// Conjunctive (AND) BM25 search restricted to an [`AllowedSet`].
///
/// The filter is applied to the rarest term's postings first, so later terms
/// only ever narrow an already-filtered candidate set.
pub struct FilteredBm25Executor<I: InvertedIndex> {
    index: Arc<I>,
}

impl<I: InvertedIndex> FilteredBm25Executor<I> {
    /// Creates an executor that reads from `index`.
    pub fn new(index: Arc<I>) -> Self {
        Self { index }
    }

    /// Returns up to `k` allowed documents that contain every query term,
    /// best score first.
    ///
    /// `query` is split on whitespace and terms shorter than two bytes are
    /// ignored. The result is empty when `k` is zero, when `allowed` is empty,
    /// when no term survives tokenization, or when any surviving term has no
    /// posting list: under AND semantics such a term matches no document.
    /// Documents with equal scores are ordered by ascending `doc_id` so that
    /// results are reproducible.
    pub fn search(
        &self,
        query: &str,
        k: usize,
        allowed: &AllowedSet,
    ) -> Vec<ScoredResult> {
        if k == 0 || allowed.is_empty() {
            return Vec::new();
        }
        let terms: Vec<&str> = query
            .split_whitespace()
            .filter(|t| t.len() >= 2)
            .collect();
        if terms.is_empty() {
            return Vec::new();
        }
        // Collecting into `Option` stops at the first term without postings.
        // Dropping such a term instead would silently widen an AND query.
        let mut posting_lists: Vec<PostingList> = match terms
            .iter()
            .map(|t| self.index.get_posting_list(t))
            .collect::<Option<Vec<_>>>()
        {
            Some(lists) => lists,
            None => return Vec::new(),
        };
        // Rarest term first keeps the candidate set as small as possible.
        posting_lists.sort_by_key(|pl| pl.doc_freq);
        let candidates = self.progressive_intersection(&posting_lists, allowed);
        if candidates.is_empty() {
            return Vec::new();
        }
        let params = self.index.get_params();
        let scores = self.score_candidates(&candidates, &posting_lists, params);
        self.top_k(scores, k)
    }

    /// Intersects the posting lists in the given order, applying `allowed` to
    /// the first one.
    ///
    /// Maps each surviving document to its term frequencies, one entry per
    /// posting list and in the same order as `posting_lists`.
    fn progressive_intersection(
        &self,
        posting_lists: &[PostingList],
        allowed: &AllowedSet,
    ) -> HashMap<u64, Vec<u32>> {
        let (rarest, rest) = match posting_lists.split_first() {
            Some(parts) => parts,
            None => return HashMap::new(),
        };
        let mut candidates: HashMap<u64, Vec<u32>> = rarest
            .intersect_with_allowed(allowed)
            .into_iter()
            .map(|(id, tf)| (id, vec![tf]))
            .collect();
        for posting_list in rest {
            if candidates.is_empty() {
                break;
            }
            // Only postings of current candidates can matter, so the lookup
            // table never grows beyond the candidate set.
            let term_postings: HashMap<u64, u32> = posting_list
                .doc_ids
                .iter()
                .copied()
                .zip(posting_list.term_freqs.iter().copied())
                .filter(|(doc_id, _)| candidates.contains_key(doc_id))
                .collect();
            candidates.retain(|doc_id, tfs| match term_postings.get(doc_id) {
                Some(&tf) => {
                    tfs.push(tf);
                    true
                }
                None => false,
            });
        }
        candidates
    }

    /// Sums the per-term BM25 scores of every candidate.
    ///
    /// Candidates whose length is unknown to the index are left out.
    fn score_candidates(
        &self,
        candidates: &HashMap<u64, Vec<u32>>,
        posting_lists: &[PostingList],
        params: &Bm25Params,
    ) -> Vec<ScoredResult> {
        // Same order as the term frequencies stored per candidate.
        let idfs: Vec<f32> = posting_lists
            .iter()
            .map(|pl| params.idf(pl.doc_freq))
            .collect();
        candidates
            .iter()
            .filter_map(|(&doc_id, tfs)| {
                let doc_len = self.index.get_doc_length(doc_id)? as f32;
                let score: f32 = tfs
                    .iter()
                    .zip(idfs.iter())
                    .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
                    .sum();
                Some(ScoredResult::new(doc_id, score))
            })
            .collect()
    }

    /// Keeps the `k` best results, highest score first.
    fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
        // Candidates arrive in hash-map order; the doc-id tie-break makes both
        // the ordering and the truncation deterministic.
        scores.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        scores.truncate(k);
        scores
    }
}
/// Disjunctive (OR) BM25 search restricted to an [`AllowedSet`].
pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
    index: Arc<I>,
}

impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
    /// Creates an executor that reads from `index`.
    pub fn new(index: Arc<I>) -> Self {
        Self { index }
    }

    /// Returns up to `k` allowed documents that contain at least one query
    /// term, best score first.
    ///
    /// `query` is split on whitespace; terms without a posting list contribute
    /// nothing. A document's score is the sum of the BM25 scores of the terms
    /// it contains, and documents whose length is unknown to the index are
    /// skipped. Documents with equal scores are ordered by ascending `doc_id`
    /// so that results are reproducible.
    pub fn search(
        &self,
        query: &str,
        k: usize,
        allowed: &AllowedSet,
    ) -> Vec<ScoredResult> {
        if k == 0 || allowed.is_empty() {
            return Vec::new();
        }
        let params = self.index.get_params();
        let mut scores: HashMap<u64, f32> = HashMap::new();
        // Terms are processed one at a time so that only a single posting list
        // is held in memory at once.
        for term in query.split_whitespace() {
            let posting_list = match self.index.get_posting_list(term) {
                Some(list) => list,
                None => continue,
            };
            let idf = params.idf(posting_list.doc_freq);
            let postings = posting_list
                .doc_ids
                .iter()
                .zip(posting_list.term_freqs.iter());
            for (&doc_id, &tf) in postings {
                if !allowed.contains(doc_id) {
                    continue;
                }
                if let Some(doc_len) = self.index.get_doc_length(doc_id) {
                    let term_score = params.term_score(tf as f32, doc_len as f32, idf);
                    *scores.entry(doc_id).or_insert(0.0) += term_score;
                }
            }
        }
        let mut results: Vec<ScoredResult> = scores
            .into_iter()
            .map(|(id, score)| ScoredResult::new(id, score))
            .collect();
        // Scores come out of a hash map in arbitrary order; the doc-id
        // tie-break makes both the ordering and the truncation deterministic.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results.truncate(k);
        results
    }
}
/// Occurrences of one term within a single document.
#[derive(Debug, Clone)]
pub struct PositionalPosting {
/// Identifier of the document the positions belong to.
pub doc_id: u64,
/// Positions at which the term occurs in the document.
/// NOTE(review): `FilteredPhraseExecutor` runs `binary_search` over this list,
/// so it presumably has to be sorted ascending — confirm with the index writer.
pub positions: Vec<u32>,
}
/// An [`InvertedIndex`] that can additionally report where a term occurs inside
/// each document.
pub trait PositionalIndex: InvertedIndex {
/// Returns the positional postings of `term`, or `None` when the index has none
/// for it; `FilteredPhraseExecutor::search` then returns no results at all.
/// NOTE(review): presumably at most one entry per document — the phrase matcher
/// keys the entries by `doc_id`; confirm with implementors.
fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
}
/// Exact-phrase search restricted to an [`AllowedSet`], scored with BM25.
pub struct FilteredPhraseExecutor<I: PositionalIndex> {
    index: Arc<I>,
}

impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
    /// Creates an executor that reads from `index`.
    pub fn new(index: Arc<I>) -> Self {
        Self { index }
    }

    /// Returns up to `k` allowed documents that contain the terms of `phrase`
    /// at consecutive positions, best score first.
    ///
    /// The result is empty when `phrase` or `allowed` is empty, or when any
    /// term has no positional postings. Each match is scored as a single BM25
    /// term: its frequency is the number of phrase occurrences in the document
    /// and its IDF is derived from the smallest posting count among the phrase
    /// terms. Documents whose length is unknown to the index are skipped, and
    /// documents with equal scores are ordered by ascending `doc_id`.
    pub fn search(
        &self,
        phrase: &[&str],
        k: usize,
        allowed: &AllowedSet,
    ) -> Vec<ScoredResult> {
        if phrase.is_empty() || allowed.is_empty() {
            return Vec::new();
        }
        // Collecting into `Option` stops at the first term the index lacks.
        let positional_postings: Vec<Vec<PositionalPosting>> = match phrase
            .iter()
            .map(|term| self.index.get_positional_posting(term))
            .collect::<Option<Vec<_>>>()
        {
            Some(postings) => postings,
            None => return Vec::new(),
        };
        let matches = self.find_phrase_matches(&positional_postings, allowed);
        if matches.is_empty() {
            return Vec::new();
        }
        let params = self.index.get_params();
        // The IDF is the same for every match, so compute it once up front.
        let min_df = positional_postings
            .iter()
            .map(|postings| postings.len() as u64)
            .min()
            .unwrap_or(1);
        let idf = params.idf(min_df);
        let mut results: Vec<ScoredResult> = matches
            .into_iter()
            .filter_map(|(doc_id, phrase_freq)| {
                let doc_len = self.index.get_doc_length(doc_id)? as f32;
                let score = params.term_score(phrase_freq as f32, doc_len, idf);
                Some(ScoredResult::new(doc_id, score))
            })
            .collect();
        // Matches are discovered in hash-map order; the doc-id tie-break makes
        // both the ordering and the truncation deterministic.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results.truncate(k);
        results
    }

    /// Counts phrase occurrences per allowed document.
    ///
    /// `positional_postings[i]` holds the postings of the phrase's `i`-th term.
    /// Returns `(doc_id, occurrences)` for every document with at least one
    /// occurrence, in no particular order.
    fn find_phrase_matches(
        &self,
        positional_postings: &[Vec<PositionalPosting>],
        allowed: &AllowedSet,
    ) -> Vec<(u64, u32)> {
        // One `doc_id -> positions` table per term, limited to allowed documents.
        let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
            .iter()
            .map(|postings| {
                postings
                    .iter()
                    .filter(|p| allowed.contains(p.doc_id))
                    .map(|p| (p.doc_id, &p.positions))
                    .collect()
            })
            .collect();
        let (lead, followers) = match indexed.split_first() {
            Some(parts) => parts,
            None => return Vec::new(),
        };
        let mut matches = Vec::new();
        for (&doc_id, lead_positions) in lead {
            // A document lacking any of the later terms cannot hold the phrase.
            let follower_positions: Vec<&Vec<u32>> = match followers
                .iter()
                .map(|by_doc| by_doc.get(&doc_id).copied())
                .collect::<Option<Vec<_>>>()
            {
                Some(lists) => lists,
                None => continue,
            };
            let phrase_count = lead_positions
                .iter()
                .filter(|&&start| {
                    follower_positions
                        .iter()
                        .enumerate()
                        .all(|(offset, positions)| {
                            // `followers[offset]` is the phrase term `offset + 1`.
                            // A start so close to `u32::MAX` that the expected
                            // position overflows cannot begin a phrase.
                            // `binary_search` assumes sorted positions.
                            start
                                .checked_add(offset as u32 + 1)
                                .map_or(false, |expected| {
                                    positions.binary_search(&expected).is_ok()
                                })
                        })
                })
                .count() as u32;
            if phrase_count > 0 {
                matches.push((doc_id, phrase_count));
            }
        }
        matches
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;

    /// In-memory index: three terms spread over documents 1..=5, each document
    /// 100 units long.
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }

    impl MockIndex {
        fn new() -> Self {
            let fixtures: Vec<(&str, Vec<(u64, u32)>)> = vec![
                ("rust", vec![(1, 3), (2, 1), (3, 2), (5, 1)]),
                ("database", vec![(1, 1), (3, 4), (4, 1)]),
                ("vector", vec![(1, 2), (2, 3), (4, 1), (5, 2)]),
            ];
            let postings: HashMap<String, PostingList> = fixtures
                .into_iter()
                .map(|(term, entries)| (term.to_string(), PostingList::new(term, entries)))
                .collect();
            let doc_lengths: HashMap<u64, u32> =
                (1..=5).map(|doc_id| (doc_id, 100)).collect();
            let params = Bm25Params {
                k1: 1.2,
                b: 0.75,
                avgdl: 100.0,
                total_docs: 1000,
            };
            Self {
                postings,
                doc_lengths,
                params,
            }
        }
    }

    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }

        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }

    /// Conjunctive executor over a fresh mock index.
    fn conjunctive() -> FilteredBm25Executor<MockIndex> {
        FilteredBm25Executor::new(Arc::new(MockIndex::new()))
    }

    #[test]
    fn test_conjunctive_search() {
        // Only documents 1 and 3 contain both terms.
        let hits = conjunctive().search("rust database", 10, &AllowedSet::All);
        assert_eq!(hits.len(), 2);
        for expected in [1u64, 3] {
            assert!(hits.iter().any(|hit| hit.doc_id == expected));
        }
    }

    #[test]
    fn test_filter_pushdown() {
        // Restricting the allowed set to document 1 removes document 3.
        let only_first = AllowedSet::SortedVec(Arc::new(vec![1]));
        let hits = conjunctive().search("rust database", 10, &only_first);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].doc_id, 1);
    }

    #[test]
    fn test_empty_allowed_set() {
        let hits = conjunctive().search("rust", 10, &AllowedSet::None);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_disjunctive_search() {
        let executor = DisjunctiveBm25Executor::new(Arc::new(MockIndex::new()));
        let hits = executor.search("rust database", 10, &AllowedSet::All);
        assert!(hits.len() >= 4);
    }

    #[test]
    fn test_term_ordering_by_df() {
        let rare = PostingList::new("rare", vec![(1, 1), (2, 1)]);
        let common = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]);
        // Start in the "wrong" order so the sort has something to do.
        let mut lists = vec![common, rare];
        lists.sort_by_key(|pl| pl.doc_freq);
        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }

    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();

        // Rarer terms carry more weight.
        let rare_idf = params.idf(10);
        let common_idf = params.idf(100_000);
        assert!(rare_idf > common_idf);

        // More occurrences score higher at equal document length.
        let single = params.term_score(1.0, 100.0, rare_idf);
        let repeated = params.term_score(5.0, 100.0, rare_idf);
        assert!(repeated > single);
    }
}