rvf_runtime/
membership.rs1use rvf_types::membership::{FilterMode, MembershipHeader, MEMBERSHIP_MAGIC};
14use rvf_types::{ErrorCode, RvfError};
15
16pub struct MembershipFilter {
18 mode: FilterMode,
20 bitmap: Vec<u64>,
22 vector_count: u64,
24 member_count: u64,
26 generation_id: u32,
28}
29
30impl MembershipFilter {
31 pub fn new_include(vector_count: u64) -> Self {
33 let words = vector_count.div_ceil(64) as usize;
34 Self {
35 mode: FilterMode::Include,
36 bitmap: vec![0u64; words],
37 vector_count,
38 member_count: 0,
39 generation_id: 0,
40 }
41 }
42
43 pub fn new_exclude(vector_count: u64) -> Self {
45 let words = vector_count.div_ceil(64) as usize;
46 Self {
47 mode: FilterMode::Exclude,
48 bitmap: vec![0u64; words],
49 vector_count,
50 member_count: 0,
51 generation_id: 0,
52 }
53 }
54
55 pub fn add(&mut self, vector_id: u64) {
57 if vector_id >= self.vector_count {
58 return;
59 }
60 let word = (vector_id / 64) as usize;
61 let bit = vector_id % 64;
62 if word < self.bitmap.len() {
63 let mask = 1u64 << bit;
64 if self.bitmap[word] & mask == 0 {
65 self.bitmap[word] |= mask;
66 self.member_count += 1;
67 }
68 }
69 }
70
71 pub fn remove(&mut self, vector_id: u64) {
73 if vector_id >= self.vector_count {
74 return;
75 }
76 let word = (vector_id / 64) as usize;
77 let bit = vector_id % 64;
78 if word < self.bitmap.len() {
79 let mask = 1u64 << bit;
80 if self.bitmap[word] & mask != 0 {
81 self.bitmap[word] &= !mask;
82 self.member_count -= 1;
83 }
84 }
85 }
86
87 fn bitmap_contains(&self, vector_id: u64) -> bool {
89 if vector_id >= self.vector_count {
90 return false;
91 }
92 let word = (vector_id / 64) as usize;
93 let bit = vector_id % 64;
94 if word < self.bitmap.len() {
95 self.bitmap[word] & (1u64 << bit) != 0
96 } else {
97 false
98 }
99 }
100
101 pub fn contains(&self, vector_id: u64) -> bool {
106 match self.mode {
107 FilterMode::Include => self.bitmap_contains(vector_id),
108 FilterMode::Exclude => !self.bitmap_contains(vector_id),
109 }
110 }
111
112 pub fn member_count(&self) -> u64 {
114 self.member_count
115 }
116
117 pub fn vector_count(&self) -> u64 {
119 self.vector_count
120 }
121
122 pub fn mode(&self) -> FilterMode {
124 self.mode
125 }
126
127 pub fn generation_id(&self) -> u32 {
129 self.generation_id
130 }
131
132 pub fn bump_generation(&mut self) {
134 self.generation_id += 1;
135 }
136
137 pub fn serialize(&self) -> Vec<u8> {
139 let mut buf = Vec::with_capacity(self.bitmap.len() * 8);
140 for &word in &self.bitmap {
141 buf.extend_from_slice(&word.to_le_bytes());
142 }
143 buf
144 }
145
146 pub fn deserialize(data: &[u8], header: &MembershipHeader) -> Result<Self, RvfError> {
148 let mode = FilterMode::try_from(header.filter_mode)
149 .map_err(|_| RvfError::Code(ErrorCode::MembershipInvalid))?;
150
151 let word_count = header.vector_count.div_ceil(64) as usize;
152 let expected_bytes = word_count * 8;
153 if data.len() < expected_bytes {
154 return Err(RvfError::Code(ErrorCode::MembershipInvalid));
155 }
156
157 let mut bitmap = Vec::with_capacity(word_count);
158 for i in 0..word_count {
159 let offset = i * 8;
160 let word = u64::from_le_bytes([
161 data[offset],
162 data[offset + 1],
163 data[offset + 2],
164 data[offset + 3],
165 data[offset + 4],
166 data[offset + 5],
167 data[offset + 6],
168 data[offset + 7],
169 ]);
170 bitmap.push(word);
171 }
172
173 let member_count: u64 = bitmap.iter().map(|w| w.count_ones() as u64).sum();
175
176 Ok(Self {
177 mode,
178 bitmap,
179 vector_count: header.vector_count,
180 member_count,
181 generation_id: header.generation_id,
182 })
183 }
184
185 pub fn to_header(&self) -> MembershipHeader {
187 let bitmap_bytes = self.serialize();
188 let filter_hash = crate::store::simple_shake256_256(&bitmap_bytes);
189
190 MembershipHeader {
191 magic: MEMBERSHIP_MAGIC,
192 version: 1,
193 filter_type: rvf_types::membership::FilterType::Bitmap as u8,
194 filter_mode: self.mode as u8,
195 vector_count: self.vector_count,
196 member_count: self.member_count,
197 filter_offset: 96, filter_size: bitmap_bytes.len() as u32,
199 generation_id: self.generation_id,
200 filter_hash,
201 bloom_offset: 0,
202 bloom_size: 0,
203 _reserved: 0,
204 _reserved2: [0u8; 8],
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn include_mode_empty_is_empty_view() {
        // A fresh Include filter hides every id.
        let view = MembershipFilter::new_include(100);
        assert!((0..100).all(|id| !view.contains(id)));
    }

    #[test]
    fn include_mode_add_and_check() {
        let mut view = MembershipFilter::new_include(100);
        let members = [10u64, 50, 99];
        members.iter().for_each(|&id| view.add(id));

        assert!(members.iter().all(|&id| view.contains(id)));
        assert!(!view.contains(0));
        assert!(!view.contains(11));
        assert_eq!(view.member_count(), members.len() as u64);
    }

    #[test]
    fn exclude_mode() {
        let mut view = MembershipFilter::new_exclude(100);
        // Nothing excluded yet, so everything is visible.
        assert!(view.contains(0));
        assert!(view.contains(50));

        view.add(50);
        // Only the excluded id disappears.
        assert!(!view.contains(50));
        assert!(view.contains(0));
        assert!(view.contains(99));
    }

    #[test]
    fn add_remove() {
        let mut view = MembershipFilter::new_include(64);

        view.add(10);
        assert_eq!(view.member_count(), 1);
        assert!(view.contains(10));

        view.remove(10);
        assert_eq!(view.member_count(), 0);
        assert!(!view.contains(10));
    }

    #[test]
    fn add_out_of_bounds_ignored() {
        let mut view = MembershipFilter::new_include(10);
        view.add(100);
        assert_eq!(view.member_count(), 0);
    }

    #[test]
    fn double_add_no_double_count() {
        let mut view = MembershipFilter::new_include(64);
        for _ in 0..2 {
            view.add(5);
        }
        assert_eq!(view.member_count(), 1);
    }

    #[test]
    fn serialize_deserialize_round_trip() {
        let ids = [0u64, 63, 64, 127, 199];
        let mut source = MembershipFilter::new_include(200);
        for &id in &ids {
            source.add(id);
        }

        let header = source.to_header();
        let bytes = source.serialize();
        let restored = MembershipFilter::deserialize(&bytes, &header).unwrap();

        assert_eq!(restored.vector_count(), 200);
        assert_eq!(restored.member_count(), ids.len() as u64);
        for &id in &ids {
            assert!(restored.contains(id), "id {} should survive the round trip", id);
        }
        assert!(!restored.contains(1));
        assert!(!restored.contains(100));
    }

    #[test]
    fn generation_bump() {
        let mut view = MembershipFilter::new_include(10);
        assert_eq!(view.generation_id(), 0);
        view.bump_generation();
        assert_eq!(view.generation_id(), 1);
    }

    #[test]
    fn bitmap_word_boundary() {
        let mut view = MembershipFilter::new_include(130);
        for &id in &[63u64, 64, 128] {
            view.add(id);
        }

        let expectations = [
            (63u64, true),
            (64, true),
            (128, true),
            (62, false),
            (65, false),
            (129, false),
        ];
        for &(id, expected) in &expectations {
            assert_eq!(view.contains(id), expected, "unexpected visibility for id {}", id);
        }
    }
}