Skip to main content

zer_blocking/
index.rs

1use std::collections::{HashMap, HashSet};
2use zer_core::{record::RecordId, traits::BlockIndex};
3
4/// Inverted index mapping blocking keys to record IDs.
5pub struct InvertedIndex {
6    buckets:     HashMap<String, Vec<RecordId>>,
7    record_keys: HashMap<RecordId, Vec<String>>,
8}
9
10impl InvertedIndex {
11    pub fn new() -> Self {
12        Self {
13            buckets:     HashMap::new(),
14            record_keys: HashMap::new(),
15        }
16    }
17
18    pub fn insert(&mut self, record_id: RecordId, keys: Vec<String>) {
19        for key in &keys {
20            self.buckets.entry(key.clone()).or_default().push(record_id);
21        }
22        self.record_keys.insert(record_id, keys);
23    }
24
25    /// Returns all record IDs sharing at least one key with the query keys,
26    /// excluding `exclude` (the querying record itself). Result is deduplicated.
27    pub fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId> {
28        let mut seen: HashSet<RecordId> = HashSet::new();
29        for key in keys {
30            if let Some(ids) = self.buckets.get(key) {
31                for &id in ids {
32                    if id != exclude {
33                        seen.insert(id);
34                    }
35                }
36            }
37        }
38        seen.into_iter().collect()
39    }
40
41    /// Like [`Self::lookup_union`] but skips any bucket whose size exceeds
42    /// `max_bucket_size`.  Overfull buckets have low selectivity and produce
43    /// O(n²) spurious pairs; capping them prevents unbounded memory growth.
44    ///
45    /// Pass `max_bucket_size = 0` to disable the cap (same as `lookup_union`).
46    pub fn lookup_union_capped(
47        &self,
48        keys: &[String],
49        exclude: RecordId,
50        max_bucket_size: usize,
51    ) -> Vec<RecordId> {
52        let mut seen: HashSet<RecordId> = HashSet::new();
53        for key in keys {
54            if let Some(ids) = self.buckets.get(key) {
55                if max_bucket_size > 0 && ids.len() > max_bucket_size {
56                    continue;
57                }
58                for &id in ids {
59                    if id != exclude {
60                        seen.insert(id);
61                    }
62                }
63            }
64        }
65        seen.into_iter().collect()
66    }
67
68    /// Returns the size of a specific bucket, or 0 if not present.
69    pub fn bucket_size(&self, key: &str) -> usize {
70        self.buckets.get(key).map_or(0, |v| v.len())
71    }
72
73    /// Returns the number of buckets exceeding `max_size`.
74    pub fn oversized_buckets(&self, max_size: usize) -> usize {
75        self.buckets.values().filter(|v| v.len() > max_size).count()
76    }
77
78    /// Enumerate all canonical `(i < j)` candidate pairs directly from bucket contents.
79    ///
80    /// Pass `max_bucket_size = 0` to disable the cap.
81    pub fn all_pairs(
82        &self,
83        id_to_idx:       &HashMap<RecordId, usize>,
84        max_bucket_size: usize,
85    ) -> Vec<(usize, usize)> {
86        let mut pairs: Vec<(usize, usize)> = Vec::new();
87        for bucket in self.buckets.values() {
88            if max_bucket_size > 0 && bucket.len() > max_bucket_size { continue; }
89            let indices: Vec<usize> = bucket.iter()
90                .filter_map(|id| id_to_idx.get(id).copied())
91                .collect();
92            for a in 0..indices.len() {
93                for b in (a + 1)..indices.len() {
94                    let (i, j) = (indices[a], indices[b]);
95                    pairs.push(if i < j { (i, j) } else { (j, i) });
96                }
97            }
98        }
99        pairs.sort_unstable();
100        pairs.dedup();
101        pairs
102    }
103
104    pub fn remove(&mut self, record_id: RecordId) {
105        if let Some(keys) = self.record_keys.remove(&record_id) {
106            for key in keys {
107                if let Some(bucket) = self.buckets.get_mut(&key) {
108                    bucket.retain(|&id| id != record_id);
109                }
110            }
111        }
112    }
113
114    pub fn len(&self) -> usize {
115        self.buckets.len()
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.buckets.is_empty()
120    }
121
122    pub fn record_count(&self) -> usize {
123        self.record_keys.len()
124    }
125}
126
127impl Default for InvertedIndex {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133impl BlockIndex for InvertedIndex {
134    fn insert(&mut self, record_id: RecordId, keys: Vec<String>) {
135        self.insert(record_id, keys);
136    }
137
138    fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId> {
139        self.lookup_union(keys, exclude)
140    }
141
142    fn remove(&mut self, record_id: RecordId) {
143        self.remove(record_id);
144    }
145
146    fn as_any(&self) -> &dyn std::any::Any {
147        self
148    }
149
150    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
151        self
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Three records chained through shared keys:
    /// 1 -{key_b}- 2 -{key_c}- 3, with `key_a` and `key_d` private to 1 and 3.
    fn make_index() -> InvertedIndex {
        let mut idx = InvertedIndex::new();
        idx.insert(1, vec!["key_a".into(), "key_b".into()]);
        idx.insert(2, vec!["key_b".into(), "key_c".into()]);
        idx.insert(3, vec!["key_c".into(), "key_d".into()]);
        idx
    }

    #[test]
    fn lookup_union_returns_all_matching() {
        let idx = make_index();
        let mut result = idx.lookup_union(&["key_b".into()], 99);
        result.sort();
        assert_eq!(result, vec![1, 2]);
    }

    #[test]
    fn lookup_union_unknown_key_is_empty() {
        let idx = make_index();
        assert!(idx.lookup_union(&["missing".into()], 99).is_empty());
    }

    #[test]
    fn lookup_union_deduplicates() {
        let mut idx = InvertedIndex::new();
        idx.insert(1, vec!["k1".into(), "k2".into()]);
        idx.insert(2, vec!["k1".into(), "k2".into()]);

        // Both records are reachable through both keys, yet each is reported once.
        let mut result = idx.lookup_union(&["k1".into(), "k2".into()], 99);
        result.sort();
        assert_eq!(result, vec![1, 2]);
    }

    #[test]
    fn no_self_candidates() {
        let idx = make_index();
        let result = idx.lookup_union(&["key_a".into(), "key_b".into()], 1);
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn remove_cleans_up() {
        let mut idx = make_index();
        idx.remove(1);
        let result = idx.lookup_union(&["key_a".into(), "key_b".into()], 99);
        assert_eq!(result, vec![2]);
        assert_eq!(idx.record_count(), 2);
    }

    #[test]
    fn remove_unknown_record_is_noop() {
        let mut idx = make_index();
        idx.remove(42);
        assert_eq!(idx.record_count(), 3);
        assert_eq!(idx.bucket_size("key_b"), 2);
    }

    #[test]
    fn bucket_size_and_counts() {
        let idx = make_index();
        assert_eq!(idx.bucket_size("key_a"), 1);
        assert_eq!(idx.bucket_size("key_b"), 2);
        assert_eq!(idx.bucket_size("missing"), 0);
        assert_eq!(idx.len(), 4);
        assert_eq!(idx.record_count(), 3);
        assert!(!idx.is_empty());
        assert!(InvertedIndex::default().is_empty());
    }

    #[test]
    fn block_index_trait_insert_and_lookup() {
        let mut idx: Box<dyn BlockIndex> = Box::new(InvertedIndex::new());
        idx.insert(10, vec!["k".into()]);
        idx.insert(20, vec!["k".into()]);
        let mut result = idx.lookup_union(&["k".into()], 99);
        result.sort();
        assert_eq!(result, vec![10, 20]);
    }

    #[test]
    fn block_index_trait_remove() {
        let mut idx: Box<dyn BlockIndex> = Box::new(InvertedIndex::new());
        idx.insert(1, vec!["x".into()]);
        idx.remove(1);
        let result = idx.lookup_union(&["x".into()], 99);
        assert!(result.is_empty());
    }

    #[test]
    fn block_index_trait_downcasts_via_as_any() {
        let mut idx: Box<dyn BlockIndex> = Box::new(InvertedIndex::new());
        idx.insert(1, vec!["x".into()]);
        let concrete = idx
            .as_any()
            .downcast_ref::<InvertedIndex>()
            .expect("trait object wraps an InvertedIndex");
        assert_eq!(concrete.bucket_size("x"), 1);
    }

    #[test]
    fn lookup_union_capped_skips_oversized_bucket() {
        let mut idx = InvertedIndex::new();
        // "big_key" bucket has 5 records; "small_key" has 2.
        for id in 1u64..=5 {
            idx.insert(id, vec!["big_key".into()]);
        }
        idx.insert(10u64, vec!["small_key".into()]);
        idx.insert(11u64, vec!["small_key".into()]);

        // cap=3: big_key (5 entries) is skipped; small_key (2 entries) is used.
        let mut result =
            idx.lookup_union_capped(&["big_key".into(), "small_key".into()], 99, 3);
        result.sort();
        assert_eq!(result, vec![10, 11]);
    }

    #[test]
    fn lookup_union_capped_keeps_bucket_exactly_at_cap() {
        let mut idx = InvertedIndex::new();
        for id in 1u64..=3 {
            idx.insert(id, vec!["k".into()]);
        }
        // Only buckets strictly larger than the cap are skipped.
        let mut result = idx.lookup_union_capped(&["k".into()], 99, 3);
        result.sort();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn lookup_union_capped_zero_cap_disables_limit() {
        let mut idx = InvertedIndex::new();
        for id in 1u64..=10 {
            idx.insert(id, vec!["k".into()]);
        }
        // cap=0 means no limit; all 9 non-excluded records returned.
        let result = idx.lookup_union_capped(&["k".into()], 1, 0);
        assert_eq!(result.len(), 9);
        assert!(!result.contains(&1));
    }

    #[test]
    fn oversized_buckets_count_is_correct() {
        let mut idx = InvertedIndex::new();
        for id in 1u64..=5 {
            idx.insert(id, vec!["big".into()]);
        }
        idx.insert(10u64, vec!["small".into()]);
        assert_eq!(idx.oversized_buckets(4), 1);
        assert_eq!(idx.oversized_buckets(5), 0);
        assert_eq!(idx.oversized_buckets(0), 2);
    }

    #[test]
    fn all_pairs_are_canonical_sorted_and_unique() {
        let mut idx = make_index();
        // Record 4 shares two buckets with record 3; the pair must appear once.
        idx.insert(4, vec!["key_c".into(), "key_d".into()]);

        // Reverse mapping so that canonical ordering has to swap the members.
        let mut id_to_idx: HashMap<RecordId, usize> = HashMap::new();
        id_to_idx.insert(1, 3);
        id_to_idx.insert(2, 2);
        id_to_idx.insert(3, 1);
        id_to_idx.insert(4, 0);

        // key_b: {1,2} -> (2,3); key_c: {2,3,4} -> (1,2),(0,2),(0,1); key_d: {3,4} -> (0,1).
        assert_eq!(
            idx.all_pairs(&id_to_idx, 0),
            vec![(0, 1), (0, 2), (1, 2), (2, 3)]
        );
    }

    #[test]
    fn all_pairs_respects_cap_and_unmapped_ids() {
        let mut idx = InvertedIndex::new();
        for id in 1u64..=3 {
            idx.insert(id, vec!["big".into()]);
        }
        idx.insert(10u64, vec!["small".into()]);
        idx.insert(11u64, vec!["small".into()]);
        idx.insert(12u64, vec!["small".into()]);

        // Record 12 has no index and is ignored.
        let mut id_to_idx: HashMap<RecordId, usize> = HashMap::new();
        id_to_idx.insert(1, 0);
        id_to_idx.insert(2, 1);
        id_to_idx.insert(3, 2);
        id_to_idx.insert(10, 3);
        id_to_idx.insert(11, 4);

        // No cap: all pairs inside "big" plus the one mapped pair inside "small".
        assert_eq!(
            idx.all_pairs(&id_to_idx, 0),
            vec![(0, 1), (0, 2), (1, 2), (3, 4)]
        );
        // cap=2: both buckets hold 3 IDs and are skipped.
        assert!(idx.all_pairs(&id_to_idx, 2).is_empty());
    }
}