1use std::collections::{HashMap, HashSet};
2use zer_core::{record::RecordId, traits::BlockIndex};
3
4pub struct InvertedIndex {
6 buckets: HashMap<String, Vec<RecordId>>,
7 record_keys: HashMap<RecordId, Vec<String>>,
8}
9
10impl InvertedIndex {
11 pub fn new() -> Self {
12 Self {
13 buckets: HashMap::new(),
14 record_keys: HashMap::new(),
15 }
16 }
17
18 pub fn insert(&mut self, record_id: RecordId, keys: Vec<String>) {
19 for key in &keys {
20 self.buckets.entry(key.clone()).or_default().push(record_id);
21 }
22 self.record_keys.insert(record_id, keys);
23 }
24
25 pub fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId> {
28 let mut seen: HashSet<RecordId> = HashSet::new();
29 for key in keys {
30 if let Some(ids) = self.buckets.get(key) {
31 for &id in ids {
32 if id != exclude {
33 seen.insert(id);
34 }
35 }
36 }
37 }
38 seen.into_iter().collect()
39 }
40
41 pub fn lookup_union_capped(
47 &self,
48 keys: &[String],
49 exclude: RecordId,
50 max_bucket_size: usize,
51 ) -> Vec<RecordId> {
52 let mut seen: HashSet<RecordId> = HashSet::new();
53 for key in keys {
54 if let Some(ids) = self.buckets.get(key) {
55 if max_bucket_size > 0 && ids.len() > max_bucket_size {
56 continue;
57 }
58 for &id in ids {
59 if id != exclude {
60 seen.insert(id);
61 }
62 }
63 }
64 }
65 seen.into_iter().collect()
66 }
67
68 pub fn bucket_size(&self, key: &str) -> usize {
70 self.buckets.get(key).map_or(0, |v| v.len())
71 }
72
73 pub fn oversized_buckets(&self, max_size: usize) -> usize {
75 self.buckets.values().filter(|v| v.len() > max_size).count()
76 }
77
78 pub fn all_pairs(
82 &self,
83 id_to_idx: &HashMap<RecordId, usize>,
84 max_bucket_size: usize,
85 ) -> Vec<(usize, usize)> {
86 let mut pairs: Vec<(usize, usize)> = Vec::new();
87 for bucket in self.buckets.values() {
88 if max_bucket_size > 0 && bucket.len() > max_bucket_size { continue; }
89 let indices: Vec<usize> = bucket.iter()
90 .filter_map(|id| id_to_idx.get(id).copied())
91 .collect();
92 for a in 0..indices.len() {
93 for b in (a + 1)..indices.len() {
94 let (i, j) = (indices[a], indices[b]);
95 pairs.push(if i < j { (i, j) } else { (j, i) });
96 }
97 }
98 }
99 pairs.sort_unstable();
100 pairs.dedup();
101 pairs
102 }
103
104 pub fn remove(&mut self, record_id: RecordId) {
105 if let Some(keys) = self.record_keys.remove(&record_id) {
106 for key in keys {
107 if let Some(bucket) = self.buckets.get_mut(&key) {
108 bucket.retain(|&id| id != record_id);
109 }
110 }
111 }
112 }
113
114 pub fn len(&self) -> usize {
115 self.buckets.len()
116 }
117
118 pub fn is_empty(&self) -> bool {
119 self.buckets.is_empty()
120 }
121
122 pub fn record_count(&self) -> usize {
123 self.record_keys.len()
124 }
125}
126
127impl Default for InvertedIndex {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl BlockIndex for InvertedIndex {
134 fn insert(&mut self, record_id: RecordId, keys: Vec<String>) {
135 self.insert(record_id, keys);
136 }
137
138 fn lookup_union(&self, keys: &[String], exclude: RecordId) -> Vec<RecordId> {
139 self.lookup_union(keys, exclude)
140 }
141
142 fn remove(&mut self, record_id: RecordId) {
143 self.remove(record_id);
144 }
145
146 fn as_any(&self) -> &dyn std::any::Any {
147 self
148 }
149
150 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
151 self
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `String` keys the index API takes.
    fn owned(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| s.to_string()).collect()
    }

    /// Sorts ids ascending so assertions do not depend on hash-set iteration order.
    fn ascending(mut ids: Vec<RecordId>) -> Vec<RecordId> {
        ids.sort();
        ids
    }

    /// Three records linked in a chain: 1 and 2 share `key_b`, 2 and 3 share `key_c`.
    fn sample_index() -> InvertedIndex {
        let mut index = InvertedIndex::new();
        index.insert(1, owned(&["key_a", "key_b"]));
        index.insert(2, owned(&["key_b", "key_c"]));
        index.insert(3, owned(&["key_c", "key_d"]));
        index
    }

    #[test]
    fn lookup_union_returns_all_matching() {
        let found = ascending(sample_index().lookup_union(&owned(&["key_b"]), 99));
        assert_eq!(found, vec![1, 2]);
    }

    #[test]
    fn lookup_union_deduplicates() {
        let mut index = InvertedIndex::new();
        index.insert(1, owned(&["k1", "k2"]));
        index.insert(2, owned(&["k1", "k2"]));

        // Both records match both keys, yet each must be reported only once.
        let found = index.lookup_union(&owned(&["k1", "k2"]), 99);
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn no_self_candidates() {
        let found = sample_index().lookup_union(&owned(&["key_a", "key_b"]), 1);
        assert!(!found.contains(&1));
    }

    #[test]
    fn remove_cleans_up() {
        let mut index = sample_index();
        index.remove(1);
        let found = index.lookup_union(&owned(&["key_a", "key_b"]), 99);
        assert!(!found.contains(&1));
    }

    #[test]
    fn block_index_trait_insert_and_lookup() {
        let mut boxed: Box<dyn BlockIndex> = Box::new(InvertedIndex::new());
        boxed.insert(10, owned(&["k"]));
        boxed.insert(20, owned(&["k"]));
        let found = ascending(boxed.lookup_union(&owned(&["k"]), 99));
        assert_eq!(found, vec![10, 20]);
    }

    #[test]
    fn block_index_trait_remove() {
        let mut boxed: Box<dyn BlockIndex> = Box::new(InvertedIndex::new());
        boxed.insert(1, owned(&["x"]));
        boxed.remove(1);
        assert!(boxed.lookup_union(&owned(&["x"]), 99).is_empty());
    }

    #[test]
    fn lookup_union_capped_skips_oversized_bucket() {
        let mut index = InvertedIndex::new();
        // Five records under one key: larger than the cap of 3 used below.
        (1u64..=5).for_each(|id| index.insert(id, owned(&["big_key"])));
        index.insert(10u64, owned(&["small_key"]));
        index.insert(11u64, owned(&["small_key"]));

        let found = index.lookup_union_capped(&owned(&["big_key", "small_key"]), 99, 3);
        assert!(!found.contains(&1), "big_key bucket must be skipped");
        assert!(found.contains(&10), "small_key bucket must be included");
        assert!(found.contains(&11), "small_key bucket must be included");
    }

    #[test]
    fn lookup_union_capped_zero_cap_disables_limit() {
        let mut index = InvertedIndex::new();
        (1u64..=10).for_each(|id| index.insert(id, owned(&["k"])));
        // Cap 0 means "no limit": all ten records except the excluded id 1.
        let found = index.lookup_union_capped(&owned(&["k"]), 1, 0);
        assert_eq!(found.len(), 9);
    }

    #[test]
    fn oversized_buckets_count_is_correct() {
        let mut index = InvertedIndex::new();
        (1u64..=5).for_each(|id| index.insert(id, owned(&["big"])));
        index.insert(10u64, owned(&["small"]));
        assert_eq!(index.oversized_buckets(4), 1);
        assert_eq!(index.oversized_buckets(5), 0);
        assert_eq!(index.oversized_buckets(0), 2);
    }
}