/// Fixed-width vector of 2048 bits packed into 32 `u64` words.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`, i.e. the
/// least-significant bit of word 0 is bit 0 of the vector.
///
/// `Hash` is derived alongside `Eq` so values can be used as keys in hashed
/// collections (`HashSet`, `HashMap`) for de-duplication or lookup.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BitVec2048 {
    /// Backing storage: 32 words x 64 bits = 2048 bits.
    words: [u64; 32],
}
11
12impl Default for BitVec2048 {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl BitVec2048 {
19 pub fn new() -> Self {
21 Self { words: [0u64; 32] }
22 }
23
24 pub fn set(&mut self, bit: usize) {
29 assert!(
30 bit < 2048,
31 "bit index {bit} out of range for BitVec2048 (max 2047)"
32 );
33 self.words[bit / 64] |= 1u64 << (bit % 64);
34 }
35
36 pub fn get(&self, bit: usize) -> bool {
41 assert!(
42 bit < 2048,
43 "bit index {bit} out of range for BitVec2048 (max 2047)"
44 );
45 (self.words[bit / 64] >> (bit % 64)) & 1 == 1
46 }
47
48 pub fn popcount(&self) -> u32 {
50 self.words.iter().map(|w| w.count_ones()).sum()
51 }
52
53 pub fn and(&self, other: &Self) -> Self {
55 let mut result = Self::new();
56 for (out, (a, b)) in result
57 .words
58 .iter_mut()
59 .zip(self.words.iter().zip(other.words.iter()))
60 {
61 *out = a & b;
62 }
63 result
64 }
65
66 pub fn or(&self, other: &Self) -> Self {
68 let mut result = Self::new();
69 for (out, (a, b)) in result
70 .words
71 .iter_mut()
72 .zip(self.words.iter().zip(other.words.iter()))
73 {
74 *out = a | b;
75 }
76 result
77 }
78
79 pub fn intersection_popcount(&self, other: &Self) -> u32 {
85 self.words.iter().zip(other.words.iter())
86 .map(|(a, b)| (a & b).count_ones())
87 .sum()
88 }
89
90 #[inline]
100 pub fn tanimoto_with_counts(&self, other: &Self, self_popcount: u32, other_popcount: u32) -> f32 {
101 let inter = self.intersection_popcount(other) as f32;
102 let union = self_popcount as f32 + other_popcount as f32 - inter;
103 if union == 0.0 { 1.0 } else { inter / union }
104 }
105
106 pub fn tanimoto(&self, other: &Self) -> f64 {
110 let intersection = self.and(other).popcount() as f64;
111 let a = self.popcount() as f64;
112 let b = other.popcount() as f64;
113 let union = a + b - intersection;
114 if union == 0.0 {
115 1.0
116 } else {
117 intersection / union
118 }
119 }
120
121 pub fn dice(&self, other: &Self) -> f64 {
125 let intersection = self.and(other).popcount() as f64;
126 let a = self.popcount() as f64;
127 let b = other.popcount() as f64;
128 let denom = a + b;
129 if denom == 0.0 {
130 1.0
131 } else {
132 2.0 * intersection / denom
133 }
134 }
135
136 pub fn fold(&self, bits: usize) -> Self {
146 assert!(
147 matches!(bits, 256 | 512 | 1024),
148 "fold target must be 1024, 512, or 256; got {bits}"
149 );
150
151 let mut current = self.clone();
152 let mut current_bits = 2048usize;
153
154 while current_bits > bits {
155 let half_words = current_bits / 64 / 2; let mut folded = Self::new();
157 for i in 0..half_words {
158 folded.words[i] = current.words[i] ^ current.words[i + half_words];
159 }
160 current = folded;
161 current_bits /= 2;
162 }
163
164 current
165 }
166
167 pub fn to_bitvecn(&self) -> BitVecN {
169 BitVecN {
170 words: self.words.to_vec(),
171 bits: 2048,
172 }
173 }
174}
175
/// Heap-backed bit vector whose width is chosen at construction time.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`.
///
/// `Hash` is derived alongside `Eq` so values can be used as keys in hashed
/// collections (`HashSet`, `HashMap`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BitVecN {
    /// Backing storage, 64 bits per word.
    words: Vec<u64>,
    /// Logical width in bits; valid bit indices are `0..bits`.
    bits: usize,
}
185
186impl BitVecN {
187 pub fn new(bits: usize) -> Self {
192 assert!(bits > 0, "BitVecN must have at least 1 bit");
193 let num_words = bits.div_ceil(64);
194 Self {
195 words: vec![0u64; num_words],
196 bits,
197 }
198 }
199
200 pub fn bit_width(&self) -> usize {
202 self.bits
203 }
204
205 pub fn set(&mut self, bit: usize) {
210 assert!(
211 bit < self.bits,
212 "bit index {bit} out of range for BitVecN (max {})",
213 self.bits - 1
214 );
215 self.words[bit / 64] |= 1u64 << (bit % 64);
216 }
217
218 pub fn get(&self, bit: usize) -> bool {
223 assert!(
224 bit < self.bits,
225 "bit index {bit} out of range for BitVecN (max {})",
226 self.bits - 1
227 );
228 (self.words[bit / 64] >> (bit % 64)) & 1 == 1
229 }
230
231 pub fn popcount(&self) -> u32 {
233 self.words.iter().map(|w| w.count_ones()).sum()
234 }
235
236 pub fn tanimoto(&self, other: &Self) -> f64 {
243 assert_eq!(
244 self.bits, other.bits,
245 "BitVecN tanimoto requires same bit width"
246 );
247 let mut intersection = 0u32;
248 for (a, b) in self.words.iter().zip(other.words.iter()) {
249 intersection += (a & b).count_ones();
250 }
251 let intersection = intersection as f64;
252 let a = self.popcount() as f64;
253 let b = other.popcount() as f64;
254 let union = a + b - intersection;
255 if union == 0.0 {
256 1.0
257 } else {
258 intersection / union
259 }
260 }
261
262 pub fn from_bitvec2048(bv: &BitVec2048) -> Self {
264 BitVecN {
265 words: bv.words.to_vec(),
266 bits: 2048,
267 }
268 }
269
270 pub fn to_bitvec2048(&self) -> Option<BitVec2048> {
274 if self.bits != 2048 {
275 return None;
276 }
277 let mut arr = [0u64; 32];
278 for (i, &w) in self.words.iter().enumerate() {
279 arr[i] = w;
280 }
281 Some(BitVec2048 { words: arr })
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BitVec2048` with exactly the listed bit positions set.
    fn fixed_with(positions: &[usize]) -> BitVec2048 {
        let mut bv = BitVec2048::new();
        for &pos in positions {
            bv.set(pos);
        }
        bv
    }

    /// Builds a `BitVecN` of the given width with the listed positions set.
    fn dynamic_with(width: usize, positions: &[usize]) -> BitVecN {
        let mut bv = BitVecN::new(width);
        for &pos in positions {
            bv.set(pos);
        }
        bv
    }

    // ---- BitVec2048 ----

    #[test]
    fn new_bitvec_is_all_zero() {
        let empty = BitVec2048::new();
        assert_eq!(empty.popcount(), 0, "new bitvec must have popcount 0");
    }

    #[test]
    fn set_and_get_boundary_bits() {
        let edges = fixed_with(&[0, 2047]);
        assert_eq!(edges.popcount(), 2);
        assert!(edges.get(0));
        assert!(!edges.get(1));
        assert!(edges.get(2047));
    }

    #[test]
    fn tanimoto_identical_nonzero() {
        let original = fixed_with(&[42, 100]);
        let twin = original.clone();
        assert_eq!(original.tanimoto(&twin), 1.0);
    }

    #[test]
    fn tanimoto_disjoint_is_zero() {
        let left = fixed_with(&[10]);
        let right = fixed_with(&[20]);
        assert_eq!(left.tanimoto(&right), 0.0);
    }

    #[test]
    fn dice_identical_nonzero() {
        let original = fixed_with(&[7]);
        let twin = original.clone();
        assert_eq!(original.dice(&twin), 1.0);
    }

    #[test]
    fn fold_1024_popcount_leq_original() {
        // Every third bit across the whole 2048-bit range.
        let mut source = BitVec2048::new();
        (0..2048).step_by(3).for_each(|i| source.set(i));

        let folded = source.fold(1024);
        assert!(
            folded.popcount() <= source.popcount(),
            "folded popcount ({}) must be <= original ({})",
            folded.popcount(),
            source.popcount()
        );
    }

    #[test]
    fn and_or_basic_correctness() {
        let a = fixed_with(&[5, 10]);
        let b = fixed_with(&[10, 15]);

        let both = a.and(&b);
        assert!(both.get(10), "AND: shared bit 10 should be set");
        assert!(!both.get(5), "AND: bit 5 only in A should be clear");
        assert!(!both.get(15), "AND: bit 15 only in B should be clear");

        let either = a.or(&b);
        assert!(either.get(5), "OR: bit 5 should be set");
        assert!(either.get(10), "OR: bit 10 should be set");
        assert!(either.get(15), "OR: bit 15 should be set");
        assert!(!either.get(0), "OR: bit 0 should be clear");
    }

    // ---- BitVecN ----

    #[test]
    fn bitvecn_new_creates_zero_vector() {
        let empty = BitVecN::new(512);
        assert_eq!(empty.bit_width(), 512);
        assert_eq!(empty.popcount(), 0);
    }

    #[test]
    fn bitvecn_set_get_basic() {
        let bv = dynamic_with(1024, &[0, 512, 1023]);
        assert!(bv.get(0));
        assert!(bv.get(512));
        assert!(bv.get(1023));
        assert!(!bv.get(1));
        assert_eq!(bv.popcount(), 3);
    }

    #[test]
    fn bitvecn_tanimoto_identical() {
        let original = dynamic_with(256, &[10, 50]);
        let twin = original.clone();
        assert_eq!(original.tanimoto(&twin), 1.0);
    }

    #[test]
    fn bitvecn_tanimoto_disjoint() {
        let left = dynamic_with(256, &[5]);
        let right = dynamic_with(256, &[10]);
        assert_eq!(left.tanimoto(&right), 0.0);
    }

    #[test]
    fn bitvecn_conversion_to_from_2048() {
        let fixed = fixed_with(&[42, 100]);

        let dynamic = BitVecN::from_bitvec2048(&fixed);
        assert_eq!(dynamic.bit_width(), 2048);
        assert!(dynamic.get(42));
        assert!(dynamic.get(100));
        assert_eq!(dynamic.popcount(), 2);

        // Round trip back to the fixed-size representation.
        let round_tripped = dynamic.to_bitvec2048();
        assert_eq!(round_tripped, Some(fixed));
    }

    #[test]
    fn bitvecn_arbitrary_sizes() {
        for size in [128, 256, 512, 1024, 2048, 4096] {
            // First and last valid index for this width.
            let bv = dynamic_with(size, &[0, size - 1]);
            assert_eq!(bv.popcount(), 2);
            assert_eq!(bv.bit_width(), size);
        }
    }
}