use std::collections::HashMap;
use nodedb_types::Surrogate;
use crate::posting::Posting;
const FULL_PHRASE_BOOST: f32 = 3.0;
/// Scores how much of the query appears as a contiguous phrase in one document.
///
/// `doc_postings` maps query-token index -> that token's posting within the
/// document. Every adjacent token pair `(i, i + 1)` whose position lists
/// contain some `p`, `p + 1` counts as matched; the returned multiplier grows
/// linearly from 1.0 (no matched pairs, or a single-token query) up to
/// `FULL_PHRASE_BOOST` (every pair matched).
pub fn phrase_boost(query_tokens: &[String], doc_postings: &HashMap<usize, &Posting>) -> f32 {
    // Zero or one token: there is no adjacency to reward.
    let Some(pair_total) = query_tokens.len().checked_sub(1).filter(|&n| n > 0) else {
        return 1.0;
    };
    let matched = (0..pair_total)
        .filter(|&i| match (doc_postings.get(&i), doc_postings.get(&(i + 1))) {
            (Some(left), Some(right)) => {
                has_consecutive_positions(&left.positions, &right.positions)
            }
            // Either token absent from this document: the pair cannot match.
            _ => false,
        })
        .count();
    if matched == 0 {
        return 1.0;
    }
    let fraction = matched as f32 / pair_total as f32;
    1.0 + fraction * (FULL_PHRASE_BOOST - 1.0)
}
/// Returns `true` when some position `p` in `a` has `p + 1` present in `b`,
/// i.e. the two terms occur back to back at least once.
///
/// Linear merge scan over both lists, O(|a| + |b|); assumes both slices are
/// sorted ascending as produced by the index — TODO confirm at the call site.
fn has_consecutive_positions(a: &[u32], b: &[u32]) -> bool {
    let mut i = 0;
    let mut j = 0;
    while i < a.len() && j < b.len() {
        // `a[i] + 1` overflows for u32::MAX (panic in debug; wraps to 0 in
        // release, where a stray `b[j] == 0` would falsely match). A position
        // at u32::MAX has no successor, so it can simply be skipped.
        let Some(target) = a[i].checked_add(1) else {
            i += 1;
            continue;
        };
        match b[j].cmp(&target) {
            std::cmp::Ordering::Equal => return true,
            std::cmp::Ordering::Less => j += 1,
            std::cmp::Ordering::Greater => i += 1,
        }
    }
    false
}
/// Groups postings by document, keyed by the query-token index that produced
/// them, so phrase scoring can look up adjacent tokens per document.
///
/// For multi-token queries, documents matching fewer than two distinct query
/// tokens are dropped: they can never contain a consecutive pair, so they
/// would only ever receive the neutral 1.0 boost anyway.
pub(crate) fn collect_doc_postings<'a>(
    query_tokens: &[String],
    term_postings: &'a [(Vec<Posting>, bool)],
) -> HashMap<Surrogate, HashMap<usize, &'a Posting>> {
    let mut grouped: HashMap<Surrogate, HashMap<usize, &Posting>> = HashMap::new();
    // Flatten to (token index, posting) pairs, then bucket by document id.
    let token_postings = term_postings
        .iter()
        .enumerate()
        .flat_map(|(token_idx, (postings, _is_fuzzy))| {
            postings.iter().map(move |posting| (token_idx, posting))
        });
    for (token_idx, posting) in token_postings {
        grouped
            .entry(posting.doc_id)
            .or_default()
            .insert(token_idx, posting);
    }
    if query_tokens.len() > 1 {
        grouped.retain(|_doc, by_token| by_token.len() > 1);
    }
    grouped
}
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Surrogate;

    /// Builds a posting for `doc_id`; `term_freq` mirrors the position count.
    fn make_posting(doc_id: u32, positions: Vec<u32>) -> Posting {
        Posting {
            doc_id: Surrogate(doc_id),
            term_freq: positions.len() as u32,
            positions,
        }
    }

    #[test]
    fn full_phrase_match() {
        let tokens: Vec<String> = vec!["hello".into(), "world".into()];
        let first = make_posting(1, vec![0, 5]);
        let second = make_posting(1, vec![1, 8]);
        let doc: HashMap<usize, &Posting> = HashMap::from([(0, &first), (1, &second)]);
        // The only token pair is adjacent (0 -> 1): maximum boost.
        let boost = phrase_boost(&tokens, &doc);
        assert!((boost - FULL_PHRASE_BOOST).abs() < f32::EPSILON);
    }

    #[test]
    fn no_phrase_match() {
        let tokens: Vec<String> = vec!["hello".into(), "world".into()];
        let first = make_posting(1, vec![0]);
        let second = make_posting(1, vec![5]);
        let doc: HashMap<usize, &Posting> = HashMap::from([(0, &first), (1, &second)]);
        // Positions 0 and 5 are not consecutive: boost stays neutral.
        let boost = phrase_boost(&tokens, &doc);
        assert!((boost - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn partial_phrase_match() {
        let tokens: Vec<String> = vec!["the".into(), "quick".into(), "brown".into()];
        let first = make_posting(1, vec![0]);
        let second = make_posting(1, vec![1]);
        let third = make_posting(1, vec![5]);
        let doc: HashMap<usize, &Posting> =
            HashMap::from([(0, &first), (1, &second), (2, &third)]);
        // One of two pairs matches: 1.0 + 0.5 * (3.0 - 1.0) = 2.0.
        let boost = phrase_boost(&tokens, &doc);
        assert!((boost - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn single_token_no_boost() {
        let tokens: Vec<String> = vec!["hello".into()];
        let doc: HashMap<usize, &Posting> = HashMap::new();
        assert!((phrase_boost(&tokens, &doc) - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn consecutive_positions_merge_scan() {
        // Hits: some a-position is immediately followed by a b-position.
        assert!(has_consecutive_positions(&[0, 3, 7], &[1, 5, 9]));
        assert!(has_consecutive_positions(&[2, 5, 10], &[3, 8, 11]));
        // Misses: no adjacency anywhere.
        assert!(!has_consecutive_positions(&[0, 5, 10], &[3, 8, 15]));
        // Empty on either side can never match.
        assert!(!has_consecutive_positions(&[], &[1, 2]));
        assert!(!has_consecutive_positions(&[0], &[]));
    }
}