1use core::hash::{Hash, Hasher};
2use core::iter::FusedIterator;
3use core::marker::PhantomData;
4use core::ops::BitAndAssign;
5use num_traits::{AsPrimitive, PrimInt};
6
7use super::bitset::PrimBitSetIter;
8
9pub struct WordSetIter<S, T: PrimInt, V> {
13 store: S,
14 word_idx: usize,
15 current: PrimBitSetIter<T, usize>,
16 _marker: PhantomData<V>,
17}
18
19impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V> WordSetIter<S, T, V> {
20 #[inline]
21 pub(crate) fn new(store: S) -> Self {
22 Self {
23 store,
24 word_idx: 0,
25 current: PrimBitSetIter::empty(),
26 _marker: PhantomData,
27 }
28 }
29
30 #[inline]
31 fn remaining_len(&self) -> usize {
32 self.current.len()
33 + self.store.as_ref()[self.word_idx..]
34 .iter()
35 .map(|w| w.count_ones() as usize)
36 .sum::<usize>()
37 }
38}
39
40impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator
41 for WordSetIter<S, T, V>
42{
43 type Item = V;
44
45 fn next(&mut self) -> Option<V> {
46 let words = self.store.as_ref();
47 let bits_per = core::mem::size_of::<T>() * 8;
48 loop {
49 if let Some(pos) = self.current.next() {
50 let idx = (self.word_idx - 1) * bits_per + pos;
51 let converted = V::try_from(idx);
52 debug_assert!(converted.is_ok());
53 match converted {
54 Ok(value) => return Some(value),
55 Err(_) => unsafe { core::hint::unreachable_unchecked() },
56 }
57 }
58 if self.word_idx >= words.len() {
59 return None;
60 }
61 self.current = PrimBitSetIter::from_raw(words[self.word_idx]);
62 self.word_idx += 1;
63 }
64 }
65
66 #[inline]
67 fn size_hint(&self) -> (usize, Option<usize>) {
68 let len = self.remaining_len();
69 (len, Some(len))
70 }
71
72 #[inline]
73 fn count(self) -> usize
74 where
75 Self: Sized,
76 {
77 self.remaining_len()
78 }
79}
80
81impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator
82 for WordSetIter<S, T, V>
83{
84 #[inline]
85 fn len(&self) -> usize {
86 self.remaining_len()
87 }
88}
89
90impl<S: AsRef<[T]>, T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator
91 for WordSetIter<S, T, V>
92{
93}
94
95pub type BitSliceIter<'a, T, V> = WordSetIter<&'a [T], T, V>;
97
98pub struct Drain<'a, T: PrimInt, V> {
103 words: &'a mut [T],
104 word_idx: usize,
105 current: PrimBitSetIter<T, usize>,
106 _marker: PhantomData<V>,
107}
108
109impl<T: PrimInt + BitAndAssign, V> Drain<'_, T, V> {
110 #[inline]
111 fn remaining_len(&self) -> usize {
112 self.current.len()
113 + self.words[self.word_idx..]
114 .iter()
115 .map(|w| w.count_ones() as usize)
116 .sum::<usize>()
117 }
118}
119
120impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator for Drain<'_, T, V> {
121 type Item = V;
122
123 fn next(&mut self) -> Option<V> {
124 let bits_per = core::mem::size_of::<T>() * 8;
125 loop {
126 if let Some(pos) = self.current.next() {
127 let idx = (self.word_idx - 1) * bits_per + pos;
128 let converted = V::try_from(idx);
129 debug_assert!(converted.is_ok());
130 match converted {
131 Ok(value) => return Some(value),
132 Err(_) => unsafe { core::hint::unreachable_unchecked() },
133 }
134 }
135 if self.word_idx >= self.words.len() {
136 return None;
137 }
138 self.current = PrimBitSetIter::from_raw(self.words[self.word_idx]);
139 self.words[self.word_idx] = T::zero();
140 self.word_idx += 1;
141 }
142 }
143
144 #[inline]
145 fn size_hint(&self) -> (usize, Option<usize>) {
146 let len = self.remaining_len();
147 (len, Some(len))
148 }
149
150 #[inline]
151 fn count(self) -> usize
152 where
153 Self: Sized,
154 {
155 self.remaining_len()
156 }
157}
158
159impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator for Drain<'_, T, V> {
160 #[inline]
161 fn len(&self) -> usize {
162 self.remaining_len()
163 }
164}
165
166impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator for Drain<'_, T, V> {}
167
168impl<T: PrimInt, V> Drop for Drain<'_, T, V> {
169 fn drop(&mut self) {
170 for w in &mut self.words[self.word_idx..] {
172 *w = T::zero();
173 }
174 }
175}
176
/// Unsized bit-set view over a slice of primitive words `T`, whose elements
/// are bit indices exposed as `V`.
///
/// Bit `b` of word `w` stands for the index `w * BITS + b`, where `BITS` is
/// the bit width of `T`. The layout is transparent over `[T]`, which is what
/// lets `&[T]` be reinterpreted as `&BitSlice<T, V>`.
#[repr(transparent)]
pub struct BitSlice<T, V>(core::marker::PhantomData<V>, [T]);
186
187impl<T, V> BitSlice<T, V> {
188 pub(crate) fn from_slice_ref(s: &[T]) -> &Self {
189 unsafe { &*(s as *const [T] as *const Self) }
192 }
193
194 pub(crate) fn from_slice_mut(s: &mut [T]) -> &mut Self {
195 unsafe { &mut *(s as *mut [T] as *mut Self) }
197 }
198}
199
200impl<T: PrimInt, V> BitSlice<T, V> {
201 const BITS_PER: usize = core::mem::size_of::<T>() * 8;
202
203 #[inline]
204 fn index_of(idx: usize) -> (usize, T) {
205 (
206 idx / Self::BITS_PER,
207 T::one().unsigned_shl((idx % Self::BITS_PER) as u32),
208 )
209 }
210
211 #[inline]
212 pub fn capacity(&self) -> usize {
213 self.1.len() * Self::BITS_PER
214 }
215
216 #[inline]
217 pub fn len(&self) -> usize {
218 self.1.iter().map(|w| w.count_ones() as usize).sum()
219 }
220
221 #[inline]
222 pub fn is_empty(&self) -> bool {
223 self.1.iter().all(|w| w.is_zero())
224 }
225
226 #[inline]
227 pub fn first(&self) -> Option<V>
228 where
229 V: TryFrom<usize>,
230 {
231 for (i, &word) in self.1.iter().enumerate() {
232 if !word.is_zero() {
233 let bit = word.trailing_zeros() as usize;
234 let idx = i * Self::BITS_PER + bit;
235 let converted = V::try_from(idx);
236 debug_assert!(converted.is_ok());
237 return Some(match converted {
238 Ok(value) => value,
239 Err(_) => unsafe { core::hint::unreachable_unchecked() },
240 });
241 }
242 }
243 None
244 }
245
246 #[inline]
247 pub fn last(&self) -> Option<V>
248 where
249 V: TryFrom<usize>,
250 {
251 for (i, &word) in self.1.iter().enumerate().rev() {
252 if !word.is_zero() {
253 let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
254 let idx = i * Self::BITS_PER + bit;
255 let converted = V::try_from(idx);
256 debug_assert!(converted.is_ok());
257 return Some(match converted {
258 Ok(value) => value,
259 Err(_) => unsafe { core::hint::unreachable_unchecked() },
260 });
261 }
262 }
263 None
264 }
265
266 #[inline]
267 pub fn pop_first(&mut self) -> Option<V>
268 where
269 V: TryFrom<usize>,
270 {
271 for (i, word) in self.1.iter_mut().enumerate() {
272 if !word.is_zero() {
273 let bit = word.trailing_zeros() as usize;
274 let mask = T::one().unsigned_shl(bit as u32);
275 *word = *word & !mask;
276 let idx = i * Self::BITS_PER + bit;
277 let converted = V::try_from(idx);
278 debug_assert!(converted.is_ok());
279 return Some(match converted {
280 Ok(value) => value,
281 Err(_) => unsafe { core::hint::unreachable_unchecked() },
282 });
283 }
284 }
285 None
286 }
287
288 #[inline]
289 pub fn pop_last(&mut self) -> Option<V>
290 where
291 V: TryFrom<usize>,
292 {
293 for (i, word) in self.1.iter_mut().enumerate().rev() {
294 if !word.is_zero() {
295 let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
296 let mask = T::one().unsigned_shl(bit as u32);
297 *word = *word & !mask;
298 let idx = i * Self::BITS_PER + bit;
299 let converted = V::try_from(idx);
300 debug_assert!(converted.is_ok());
301 return Some(match converted {
302 Ok(value) => value,
303 Err(_) => unsafe { core::hint::unreachable_unchecked() },
304 });
305 }
306 }
307 None
308 }
309
310 #[inline]
311 pub fn contains(&self, id: &V) -> bool
312 where
313 V: Copy + AsPrimitive<usize>,
314 {
315 let idx = (*id).as_();
316 debug_assert!(
317 idx < self.capacity(),
318 "index {idx} out of range for capacity {}",
319 self.capacity()
320 );
321 let (seg, mask) = Self::index_of(idx);
322 self.1.get(seg).is_some_and(|w| *w & mask != T::zero())
323 }
324
325 #[inline]
326 pub fn set(&mut self, id: V, value: bool)
327 where
328 V: AsPrimitive<usize>,
329 {
330 let idx = id.as_();
331 debug_assert!(
332 idx < self.capacity(),
333 "index {idx} out of range for capacity {}",
334 self.capacity()
335 );
336 let (seg, mask) = Self::index_of(idx);
337 if let Some(word) = self.1.get_mut(seg) {
338 if value {
339 *word = *word | mask;
340 } else {
341 *word = *word & !mask;
342 }
343 }
344 }
345
346 #[inline]
347 pub fn insert(&mut self, id: V) -> bool
348 where
349 V: AsPrimitive<usize>,
350 {
351 let idx = id.as_();
352 debug_assert!(
353 idx < self.capacity(),
354 "index {idx} out of range for capacity {}",
355 self.capacity()
356 );
357 let (seg, mask) = Self::index_of(idx);
358 let Some(word) = self.1.get_mut(seg) else {
359 return false;
360 };
361 let was_absent = *word & mask == T::zero();
362 *word = *word | mask;
363 was_absent
364 }
365
366 #[inline]
367 pub fn remove(&mut self, id: V) -> bool
368 where
369 V: AsPrimitive<usize>,
370 {
371 let idx = id.as_();
372 debug_assert!(
373 idx < self.capacity(),
374 "index {idx} out of range for capacity {}",
375 self.capacity()
376 );
377 let (seg, mask) = Self::index_of(idx);
378 let Some(word) = self.1.get_mut(seg) else {
379 return false;
380 };
381 let was_present = *word & mask != T::zero();
382 *word = *word & !mask;
383 was_present
384 }
385
386 #[inline]
387 pub fn toggle(&mut self, id: V)
388 where
389 V: AsPrimitive<usize>,
390 {
391 let idx = id.as_();
392 debug_assert!(
393 idx < self.capacity(),
394 "index {idx} out of range for capacity {}",
395 self.capacity()
396 );
397 let (seg, mask) = Self::index_of(idx);
398 if let Some(word) = self.1.get_mut(seg) {
399 *word = *word ^ mask;
400 }
401 }
402
403 #[inline]
404 pub fn clear(&mut self) {
405 self.1.fill(T::zero());
406 }
407
408 #[inline]
409 pub fn drain(&mut self) -> Drain<'_, T, V>
410 where
411 T: BitAndAssign,
412 V: TryFrom<usize>,
413 {
414 Drain {
415 words: &mut self.1,
416 word_idx: 0,
417 current: PrimBitSetIter::empty(),
418 _marker: PhantomData,
419 }
420 }
421
422 pub fn retain(&mut self, mut f: impl FnMut(V) -> bool)
423 where
424 V: TryFrom<usize>,
425 {
426 for (i, word) in self.1.iter_mut().enumerate() {
427 let mut w = *word;
428 while !w.is_zero() {
429 let bit = w.trailing_zeros() as usize;
430 let mask = T::one().unsigned_shl(bit as u32);
431 w = w & !mask;
432 let idx = i * Self::BITS_PER + bit;
433 let converted = V::try_from(idx);
434 debug_assert!(converted.is_ok());
435 let value = match converted {
436 Ok(v) => v,
437 Err(_) => unsafe { core::hint::unreachable_unchecked() },
438 };
439 if !f(value) {
440 *word = *word & !mask;
441 }
442 }
443 }
444 }
445
446 pub fn union_from(&mut self, other: &Self) {
447 let min = self.1.len().min(other.1.len());
448 for i in 0..min {
449 self.1[i] = self.1[i] | other.1[i];
450 }
451 }
452
453 pub fn append(&mut self, other: &mut Self) {
454 let min = self.1.len().min(other.1.len());
455 for i in 0..min {
456 self.1[i] = self.1[i] | other.1[i];
457 other.1[i] = T::zero();
458 }
459 }
460
461 #[inline]
462 pub fn iter(&self) -> BitSliceIter<'_, T, V>
463 where
464 T: BitAndAssign,
465 V: TryFrom<usize>,
466 {
467 WordSetIter::new(&self.1)
468 }
469
470 #[inline]
471 pub fn is_subset(&self, other: &Self) -> bool {
472 let min = self.1.len().min(other.1.len());
473 self.1[..min]
474 .iter()
475 .zip(other.1[..min].iter())
476 .all(|(a, b)| *a & *b == *a)
477 && self.1[min..].iter().all(|w| w.is_zero())
478 }
479
480 #[inline]
481 pub fn is_superset(&self, other: &Self) -> bool {
482 other.is_subset(self)
483 }
484
485 #[inline]
486 pub fn is_disjoint(&self, other: &Self) -> bool {
487 self.1
488 .iter()
489 .zip(other.1.iter())
490 .all(|(a, b)| (*a & *b).is_zero())
491 }
492
493 fn word_op_iter<'a>(
494 a: &'a [T],
495 b: &'a [T],
496 len: usize,
497 op: impl Fn(T, T) -> T + 'a,
498 ) -> impl Iterator<Item = V> + 'a
499 where
500 T: BitAndAssign,
501 V: TryFrom<usize>,
502 {
503 let bits_per = Self::BITS_PER;
504 (0..len).flat_map(move |i| {
505 let w_a = a.get(i).copied().unwrap_or(T::zero());
506 let w_b = b.get(i).copied().unwrap_or(T::zero());
507 let combined = op(w_a, w_b);
508 let offset = i * bits_per;
509 PrimBitSetIter::<T, usize>(combined, PhantomData).map(move |pos| {
510 let idx = offset + pos;
511 debug_assert!(V::try_from(idx).is_ok());
512 match V::try_from(idx) {
513 Ok(v) => v,
514 Err(_) => unsafe { core::hint::unreachable_unchecked() },
515 }
516 })
517 })
518 }
519
520 #[inline]
521 pub fn difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
522 where
523 T: BitAndAssign,
524 V: TryFrom<usize>,
525 {
526 Self::word_op_iter(&self.1, &other.1, self.1.len(), |a, b| a & !b)
527 }
528
529 #[inline]
530 pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
531 where
532 T: BitAndAssign,
533 V: TryFrom<usize>,
534 {
535 Self::word_op_iter(
536 &self.1,
537 &other.1,
538 self.1.len().min(other.1.len()),
539 |a, b| a & b,
540 )
541 }
542
543 #[inline]
544 pub fn union<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
545 where
546 T: BitAndAssign,
547 V: TryFrom<usize>,
548 {
549 Self::word_op_iter(
550 &self.1,
551 &other.1,
552 self.1.len().max(other.1.len()),
553 |a, b| a | b,
554 )
555 }
556
557 #[inline]
558 pub fn symmetric_difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
559 where
560 T: BitAndAssign,
561 V: TryFrom<usize>,
562 {
563 Self::word_op_iter(
564 &self.1,
565 &other.1,
566 self.1.len().max(other.1.len()),
567 |a, b| a ^ b,
568 )
569 }
570
571 #[cfg(feature = "bitvec")]
574 #[inline]
575 pub fn as_bitvec_slice(&self) -> &bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
576 where
577 T: bitvec::store::BitStore,
578 {
579 bitvec::slice::BitSlice::from_slice(&self.1)
580 }
581
582 #[cfg(feature = "bitvec")]
583 #[inline]
584 pub fn as_mut_bitvec_slice(&mut self) -> &mut bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
585 where
586 T: bitvec::store::BitStore,
587 {
588 bitvec::slice::BitSlice::from_slice_mut(&mut self.1)
589 }
590
591 #[inline]
593 pub fn raw_words(&self) -> &[T] {
594 &self.1
595 }
596}
597
598impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> IntoIterator for &'a BitSlice<T, V> {
599 type Item = V;
600 type IntoIter = BitSliceIter<'a, T, V>;
601
602 #[inline]
603 fn into_iter(self) -> Self::IntoIter {
604 self.iter()
605 }
606}
607
608impl<T: PrimInt, V> PartialEq for BitSlice<T, V> {
609 fn eq(&self, other: &Self) -> bool {
610 let min = self.1.len().min(other.1.len());
611 self.1[..min] == other.1[..min]
612 && self.1[min..].iter().all(|w| w.is_zero())
613 && other.1[min..].iter().all(|w| w.is_zero())
614 }
615}
616
617impl<T: PrimInt, V> Eq for BitSlice<T, V> {}
618
619impl<T: PrimInt + Hash, V> Hash for BitSlice<T, V> {
620 fn hash<H: Hasher>(&self, state: &mut H) {
621 let effective_len = self
623 .1
624 .iter()
625 .rposition(|w| !w.is_zero())
626 .map_or(0, |i| i + 1);
627 for w in &self.1[..effective_len] {
628 w.hash(state);
629 }
630 }
631}
632
633impl<T: PrimInt + BitAndAssign, V> core::fmt::Debug for BitSlice<T, V> {
634 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
635 let bits_per = core::mem::size_of::<T>() * 8;
636 f.write_str("{")?;
637 let mut first = true;
638 for (i, &word) in self.1.iter().enumerate() {
639 let offset = i * bits_per;
640 for pos in PrimBitSetIter::<T, usize>(word, PhantomData) {
641 if !first {
642 f.write_str(", ")?;
643 }
644 first = false;
645 write!(f, "{}", offset + pos)?;
646 }
647 }
648 f.write_str("}")
649 }
650}
651
652impl<T: PrimInt + BitAndAssign, V> core::fmt::Display for BitSlice<T, V> {
653 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
654 core::fmt::Debug::fmt(self, f)
655 }
656}