use std::collections::HashMap;
use sketch_oxide::membership::BinaryFuseFilter;
use xxhash_rust::xxh3::xxh3_64;
/// Posting-list entry: all positions at which one N-gram hash occurs
/// within a single document.
#[derive(Debug, Clone)]
pub struct NgramEntry {
    /// Identifier of the document containing the N-gram.
    pub doc_id: usize,
    /// Zero-based token-window positions of this N-gram in the document.
    pub positions: Vec<usize>,
}
/// Per-document approximate-membership filter over its distinct N-gram hashes.
///
/// Wraps a binary fuse filter so overlap between two documents can be
/// estimated cheaply before an exact set intersection is computed.
/// The filter may report false positives but never false negatives.
pub struct DocumentFilter {
    /// Identifier of the document this filter was built from.
    pub doc_id: usize,
    // Approximate membership structure (not Debug-printable; see the
    // manual Debug impl below, which omits it).
    filter: BinaryFuseFilter,
    /// Number of distinct N-gram hashes the filter was built over.
    pub ngram_count: usize,
}
impl DocumentFilter {
    /// Builds a filter over the distinct hashes in `ngrams`.
    ///
    /// Returns `None` when `ngrams` is empty or when the underlying
    /// binary fuse filter cannot be constructed.
    pub fn build(doc_id: usize, ngrams: &[u64]) -> Option<Self> {
        if ngrams.is_empty() {
            return None;
        }
        let distinct: std::collections::HashSet<u64> = ngrams.iter().copied().collect();
        let ngram_count = distinct.len();
        match BinaryFuseFilter::from_items(distinct, 8) {
            Ok(filter) => Some(Self {
                doc_id,
                filter,
                ngram_count,
            }),
            Err(_) => None,
        }
    }

    /// Approximate membership test for a single N-gram hash.
    /// May return a false positive; never a false negative.
    pub fn contains(&self, hash: &u64) -> bool {
        self.filter.contains(hash)
    }

    /// Estimates how many distinct hashes in `other_ngrams` are also in
    /// this document. Upper-bound estimate (filter false positives may
    /// inflate the count slightly).
    pub fn estimate_overlap(&self, other_ngrams: &[u64]) -> usize {
        let distinct: std::collections::HashSet<u64> = other_ngrams.iter().copied().collect();
        distinct
            .into_iter()
            .filter(|h| self.filter.contains(h))
            .count()
    }
}
impl std::fmt::Debug for DocumentFilter {
    /// Debug output shows only `doc_id` and `ngram_count`; the inner
    /// filter structure is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DocumentFilter");
        builder.field("doc_id", &self.doc_id);
        builder.field("ngram_count", &self.ngram_count);
        builder.finish()
    }
}
/// Inverted index from N-gram hashes to the documents (and positions) that
/// contain them, plus per-document membership filters for cheap candidate
/// pre-screening.
#[derive(Debug)]
pub struct NgramIndex {
    // hash -> posting list of documents containing that N-gram.
    index: HashMap<u64, Vec<NgramEntry>>,
    // Number of add_document calls. NOTE(review): re-adding the same
    // doc_id increments this again — confirm callers use unique ids.
    num_docs: usize,
    // doc_id -> approximate filter; absent for documents with no N-grams.
    filters: HashMap<usize, DocumentFilter>,
}
impl NgramIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            index: HashMap::new(),
            num_docs: 0,
            filters: HashMap::new(),
        }
    }

    /// Number of documents added via [`NgramIndex::add_document`].
    pub fn num_docs(&self) -> usize {
        self.num_docs
    }

    /// Number of distinct N-gram hashes currently indexed.
    pub fn num_ngrams(&self) -> usize {
        self.index.len()
    }

    /// Indexes `ngrams` for `doc_id`: records each hash's positions in the
    /// posting lists and builds a membership filter for the document.
    ///
    /// NOTE(review): re-adding an existing `doc_id` appends duplicate
    /// posting entries and bumps `num_docs` again — confirm callers only
    /// add each document once.
    pub fn add_document(&mut self, doc_id: usize, ngrams: &[u64]) {
        // Group occurrence positions by hash so each distinct hash yields
        // exactly one posting entry for this document.
        let mut positions_by_hash: HashMap<u64, Vec<usize>> = HashMap::new();
        for (pos, &hash) in ngrams.iter().enumerate() {
            positions_by_hash.entry(hash).or_default().push(pos);
        }
        for (hash, positions) in positions_by_hash {
            self.index
                .entry(hash)
                .or_default()
                .push(NgramEntry { doc_id, positions });
        }
        // Empty documents get no filter (build returns None for them).
        if let Some(filter) = DocumentFilter::build(doc_id, ngrams) {
            self.filters.insert(doc_id, filter);
        }
        self.num_docs += 1;
    }

    /// Returns the membership filter for `doc_id`, if one was built.
    pub fn get_filter(&self, doc_id: usize) -> Option<&DocumentFilter> {
        self.filters.get(&doc_id)
    }

    /// Finds all document pairs sharing at least `min_overlap` distinct
    /// N-gram hashes by scanning the posting lists.
    ///
    /// Returns `(doc_a, doc_b, shared_hash_count)` tuples with
    /// `doc_a <= doc_b`; the result order is unspecified (HashMap
    /// iteration order).
    pub fn find_candidates(&self, min_overlap: usize) -> Vec<(usize, usize, usize)> {
        let mut pair_counts: HashMap<(usize, usize), usize> = HashMap::new();
        for entries in self.index.values() {
            // A hash seen in fewer than two documents cannot form a pair.
            if entries.len() < 2 {
                continue;
            }
            // Count every document pair sharing this hash; normalize the
            // pair order so (a, b) and (b, a) accumulate under one key.
            for i in 0..entries.len() {
                for j in (i + 1)..entries.len() {
                    let a = entries[i].doc_id.min(entries[j].doc_id);
                    let b = entries[i].doc_id.max(entries[j].doc_id);
                    *pair_counts.entry((a, b)).or_insert(0) += 1;
                }
            }
        }
        pair_counts
            .into_iter()
            .filter(|&(_, count)| count >= min_overlap)
            .map(|((a, b), count)| (a, b, count))
            .collect()
    }

    /// Like [`NgramIndex::find_candidates`], but pre-screens each pair with
    /// the approximate per-document filter before computing the exact
    /// distinct-hash intersection.
    ///
    /// `doc_ngrams` maps each candidate document id to its full N-gram
    /// hash sequence. Pairs whose second document has no filter (e.g. it
    /// was added with no N-grams) skip the pre-screen and go straight to
    /// the exact check.
    pub fn find_candidates_prescreened(
        &self,
        doc_ngrams: &HashMap<usize, Vec<u64>>,
        min_overlap: usize,
    ) -> Vec<(usize, usize, usize)> {
        let doc_ids: Vec<usize> = doc_ngrams.keys().copied().collect();
        let mut pair_results: Vec<(usize, usize, usize)> = Vec::new();
        for i in 0..doc_ids.len() {
            for j in (i + 1)..doc_ids.len() {
                let id_a = doc_ids[i].min(doc_ids[j]);
                let id_b = doc_ids[i].max(doc_ids[j]);
                // Cheap approximate check first. The filter never under-
                // counts, so a failed pre-screen safely rejects the pair.
                let pass_prescreen = match (self.filters.get(&id_b), doc_ngrams.get(&id_a)) {
                    (Some(filter_b), Some(ngrams_a)) => {
                        filter_b.estimate_overlap(ngrams_a) >= min_overlap
                    }
                    (Some(_), None) => false,
                    (None, _) => true,
                };
                if !pass_prescreen {
                    continue;
                }
                if let (Some(ngrams_a), Some(ngrams_b)) =
                    (doc_ngrams.get(&id_a), doc_ngrams.get(&id_b))
                {
                    // Exact size of the distinct-hash intersection.
                    // (Replaces a needless second collect-then-iterate over
                    // a HashSet of references.)
                    let set_a: std::collections::HashSet<u64> =
                        ngrams_a.iter().copied().collect();
                    let set_b: std::collections::HashSet<u64> =
                        ngrams_b.iter().copied().collect();
                    let overlap = set_a.intersection(&set_b).count();
                    if overlap >= min_overlap {
                        pair_results.push((id_a, id_b, overlap));
                    }
                }
            }
        }
        pair_results
    }

    /// Hashes every window of `n` consecutive tokens with xxh3, joining
    /// the window's tokens with a single space before hashing.
    ///
    /// Returns an empty vector when `n == 0` or there are fewer than `n`
    /// tokens. NOTE(review): tokens containing spaces can make distinct
    /// windows hash identically — acceptable if tokens are whitespace-free.
    pub fn ngrams_from_tokens(tokens: &[&str], n: usize) -> Vec<u64> {
        if n == 0 || tokens.len() < n {
            return Vec::new();
        }
        tokens
            .windows(n)
            .map(|window| xxh3_64(window.join(" ").as_bytes()))
            .collect()
    }
}
impl Default for NgramIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // 9 tokens with n = 3 yields 9 - 3 + 1 = 7 windows.
    #[test]
    fn test_ngrams_from_tokens_basic() {
        let tokens = vec!["fn", "foo", "(", "x", ")", "{", "return", "x", "}"];
        let ngrams = NgramIndex::ngrams_from_tokens(&tokens, 3);
        assert_eq!(ngrams.len(), 7);
    }

    // Fewer tokens than n produces no N-grams.
    #[test]
    fn test_ngrams_from_tokens_too_few() {
        let tokens = vec!["fn", "foo"];
        let ngrams = NgramIndex::ngrams_from_tokens(&tokens, 3);
        assert!(ngrams.is_empty());
    }

    // n == 0 is explicitly rejected (empty result, no panic).
    #[test]
    fn test_ngrams_from_tokens_zero_n() {
        let tokens = vec!["fn", "foo", "bar"];
        let ngrams = NgramIndex::ngrams_from_tokens(&tokens, 0);
        assert!(ngrams.is_empty());
    }

    // Hashing the same tokens twice must give identical output.
    #[test]
    fn test_ngrams_deterministic() {
        let tokens = vec!["if", "x", ">", "0", "return", "x"];
        let a = NgramIndex::ngrams_from_tokens(&tokens, 3);
        let b = NgramIndex::ngrams_from_tokens(&tokens, 3);
        assert_eq!(a, b);
    }

    // num_docs tracks each add_document call.
    #[test]
    fn test_add_document_and_num_docs() {
        let mut index = NgramIndex::new();
        let ngrams = NgramIndex::ngrams_from_tokens(&["a", "b", "c", "d"], 2);
        index.add_document(0, &ngrams);
        assert_eq!(index.num_docs(), 1);
        index.add_document(1, &ngrams);
        assert_eq!(index.num_docs(), 2);
    }

    // Two near-identical functions share most 3-grams and must surface
    // as a candidate pair with overlap >= 2.
    #[test]
    fn test_similar_documents_found_as_candidates() {
        let mut index = NgramIndex::new();
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "fn", "calculate", "(", "x", ")", "{", "return", "x", "+", "2", "}",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        index.add_document(0, &ngrams_a);
        index.add_document(1, &ngrams_b);
        let candidates = index.find_candidates(2);
        assert!(
            !candidates.is_empty(),
            "Similar documents should be candidate pairs"
        );
        let (a, b, overlap) = &candidates[0];
        assert_eq!(*a, 0);
        assert_eq!(*b, 1);
        assert!(
            *overlap >= 2,
            "Expected at least 2 shared N-grams, got {overlap}"
        );
    }

    // Documents with no shared 3-grams must not pair up.
    #[test]
    fn test_dissimilar_documents_not_candidates() {
        let mut index = NgramIndex::new();
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "class", "Widget", "extends", "Base", "implements", "Drawable",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        index.add_document(0, &ngrams_a);
        index.add_document(1, &ngrams_b);
        let candidates = index.find_candidates(2);
        assert!(
            candidates.is_empty(),
            "Dissimilar documents should not be candidates, got {candidates:?}"
        );
    }

    // Edge case: no documents at all.
    #[test]
    fn test_empty_index_no_candidates() {
        let index = NgramIndex::new();
        let candidates = index.find_candidates(1);
        assert!(candidates.is_empty());
    }

    // Edge case: one document cannot form a pair.
    #[test]
    fn test_single_document_no_candidates() {
        let mut index = NgramIndex::new();
        let ngrams = NgramIndex::ngrams_from_tokens(&["a", "b", "c", "d"], 2);
        index.add_document(0, &ngrams);
        let candidates = index.find_candidates(1);
        assert!(candidates.is_empty());
    }

    // Only the overlapping pair (0, 1) should appear; doc 2 is unrelated.
    #[test]
    fn test_multiple_documents_partial_overlap() {
        let mut index = NgramIndex::new();
        let tokens_0 = vec!["let", "x", "=", "foo", "(", "a", ")"];
        let tokens_1 = vec!["let", "y", "=", "foo", "(", "b", ")"];
        let tokens_2 = vec!["import", "os", "import", "sys", "print", "hello"];
        let ng0 = NgramIndex::ngrams_from_tokens(&tokens_0, 3);
        let ng1 = NgramIndex::ngrams_from_tokens(&tokens_1, 3);
        let ng2 = NgramIndex::ngrams_from_tokens(&tokens_2, 3);
        index.add_document(0, &ng0);
        index.add_document(1, &ng1);
        index.add_document(2, &ng2);
        let candidates = index.find_candidates(1);
        let has_01 = candidates.iter().any(|&(a, b, _)| a == 0 && b == 1);
        let has_02 = candidates.iter().any(|&(a, b, _)| a == 0 && b == 2);
        let has_12 = candidates.iter().any(|&(a, b, _)| a == 1 && b == 2);
        assert!(has_01, "Docs 0 and 1 should be candidates");
        assert!(!has_02, "Docs 0 and 2 should not be candidates");
        assert!(!has_12, "Docs 1 and 2 should not be candidates");
    }

    // min_overlap acts as a hard threshold on the shared-hash count.
    #[test]
    fn test_min_overlap_filtering() {
        let mut index = NgramIndex::new();
        let tokens = vec!["a", "b", "c", "d", "e", "f"];
        let ngrams = NgramIndex::ngrams_from_tokens(&tokens, 3);
        index.add_document(0, &ngrams);
        index.add_document(1, &ngrams);
        let candidates_low = index.find_candidates(1);
        assert!(!candidates_low.is_empty());
        let candidates_high = index.find_candidates(100);
        assert!(
            candidates_high.is_empty(),
            "Threshold exceeds actual N-gram count"
        );
    }

    // Repeated N-grams within one document must not break indexing.
    #[test]
    fn test_duplicate_ngrams_within_document() {
        let mut index = NgramIndex::new();
        let tokens = vec!["a", "b", "a", "b", "a", "b"];
        let ngrams = NgramIndex::ngrams_from_tokens(&tokens, 2);
        index.add_document(0, &ngrams);
        assert_eq!(index.num_docs(), 1);
    }

    // Non-empty N-grams produce Some(filter) with matching metadata.
    #[test]
    fn test_document_filter_build() {
        let ngrams = NgramIndex::ngrams_from_tokens(&["fn", "foo", "(", "x", ")", "{", "}"], 3);
        let filter = DocumentFilter::build(0, &ngrams);
        assert!(
            filter.is_some(),
            "Should build filter from non-empty N-grams"
        );
        let f = filter.unwrap();
        assert_eq!(f.doc_id, 0);
        assert!(f.ngram_count > 0);
    }

    // No false negatives: every inserted hash must test positive.
    #[test]
    fn test_document_filter_contains_own_ngrams() {
        let ngrams = NgramIndex::ngrams_from_tokens(&["fn", "foo", "(", "x", ")", "{", "}"], 3);
        let filter = DocumentFilter::build(0, &ngrams).unwrap();
        for &hash in &ngrams {
            assert!(
                filter.contains(&hash),
                "Filter should contain its own N-gram hash"
            );
        }
    }

    // Empty input is an explicit None, not a panic.
    #[test]
    fn test_document_filter_empty_ngrams() {
        let filter = DocumentFilter::build(0, &[]);
        assert!(filter.is_none(), "Empty N-grams should return None");
    }

    // Similar documents should show a meaningful estimated overlap.
    #[test]
    fn test_document_filter_estimate_overlap() {
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "fn", "calculate", "(", "x", ")", "{", "return", "x", "+", "2", "}",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        let filter_a = DocumentFilter::build(0, &ngrams_a).unwrap();
        let estimated = filter_a.estimate_overlap(&ngrams_b);
        assert!(
            estimated >= 2,
            "Similar documents should have estimated overlap >= 2, got {estimated}"
        );
    }

    // Dissimilar documents: allow at most 1 (filter false positives).
    #[test]
    fn test_document_filter_no_overlap_dissimilar() {
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "class", "Widget", "extends", "Base", "implements", "Drawable",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        let filter_a = DocumentFilter::build(0, &ngrams_a).unwrap();
        let estimated = filter_a.estimate_overlap(&ngrams_b);
        assert!(
            estimated <= 1,
            "Dissimilar documents should have near-zero estimated overlap, got {estimated}"
        );
    }

    // Pre-screened path accepts the same similar pair as the plain path.
    #[test]
    fn test_find_candidates_prescreened_similar() {
        let mut index = NgramIndex::new();
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "fn", "calculate", "(", "x", ")", "{", "return", "x", "+", "2", "}",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        index.add_document(0, &ngrams_a);
        index.add_document(1, &ngrams_b);
        let mut doc_ngrams = HashMap::new();
        doc_ngrams.insert(0, ngrams_a);
        doc_ngrams.insert(1, ngrams_b);
        let candidates = index.find_candidates_prescreened(&doc_ngrams, 2);
        assert!(
            !candidates.is_empty(),
            "Pre-screened search should find similar documents"
        );
        let (a, b, overlap) = &candidates[0];
        assert_eq!(*a, 0);
        assert_eq!(*b, 1);
        assert!(*overlap >= 2);
    }

    // Pre-screened path rejects a dissimilar pair.
    #[test]
    fn test_find_candidates_prescreened_dissimilar() {
        let mut index = NgramIndex::new();
        let tokens_a = vec![
            "fn", "compute", "(", "x", ")", "{", "return", "x", "+", "1", "}",
        ];
        let tokens_b = vec![
            "class", "Widget", "extends", "Base", "implements", "Drawable",
        ];
        let ngrams_a = NgramIndex::ngrams_from_tokens(&tokens_a, 3);
        let ngrams_b = NgramIndex::ngrams_from_tokens(&tokens_b, 3);
        index.add_document(0, &ngrams_a);
        index.add_document(1, &ngrams_b);
        let mut doc_ngrams = HashMap::new();
        doc_ngrams.insert(0, ngrams_a);
        doc_ngrams.insert(1, ngrams_b);
        let candidates = index.find_candidates_prescreened(&doc_ngrams, 2);
        assert!(
            candidates.is_empty(),
            "Pre-screened search should reject dissimilar documents"
        );
    }

    // No false negatives from pre-screening: every inverted-index
    // candidate must also survive the pre-screened path.
    #[test]
    fn test_prescreened_matches_inverted_index_results() {
        let mut index = NgramIndex::new();
        let tokens_0 = vec!["let", "x", "=", "foo", "(", "a", ")"];
        let tokens_1 = vec!["let", "y", "=", "foo", "(", "b", ")"];
        let tokens_2 = vec!["import", "os", "import", "sys", "print", "hello"];
        let ng0 = NgramIndex::ngrams_from_tokens(&tokens_0, 3);
        let ng1 = NgramIndex::ngrams_from_tokens(&tokens_1, 3);
        let ng2 = NgramIndex::ngrams_from_tokens(&tokens_2, 3);
        index.add_document(0, &ng0);
        index.add_document(1, &ng1);
        index.add_document(2, &ng2);
        let inverted_candidates = index.find_candidates(1);
        let mut doc_ngrams = HashMap::new();
        doc_ngrams.insert(0, ng0);
        doc_ngrams.insert(1, ng1);
        doc_ngrams.insert(2, ng2);
        let prescreened_candidates = index.find_candidates_prescreened(&doc_ngrams, 1);
        for &(a, b, _) in &inverted_candidates {
            let found = prescreened_candidates
                .iter()
                .any(|&(pa, pb, _)| pa == a && pb == b);
            assert!(
                found,
                "Inverted-index candidate ({a}, {b}) missing from prescreened results"
            );
        }
    }

    // add_document wires the filter map keyed by doc_id.
    #[test]
    fn test_add_document_builds_filter() {
        let mut index = NgramIndex::new();
        let ngrams = NgramIndex::ngrams_from_tokens(&["a", "b", "c", "d", "e"], 3);
        index.add_document(42, &ngrams);
        let filter = index.get_filter(42);
        assert!(filter.is_some(), "add_document should build a filter");
        assert_eq!(filter.unwrap().doc_id, 42);
    }
}