1use core::hash::{Hash, Hasher};
2use core::iter::FusedIterator;
3use core::marker::PhantomData;
4use core::ops::BitAndAssign;
5use num_traits::{AsPrimitive, PrimInt};
6
7use super::bitset::PrimBitSetIter;
8
/// Iterator over the indices of the set bits in a word slice, converted into `V`.
///
/// Words are visited in slice order; bit order within one word is whatever
/// `PrimBitSetIter` yields (presumably lowest bit first — NOTE(review): confirm).
/// Produced by `BitSlice::iter` and by `IntoIterator for &BitSlice`.
pub struct BitSliceIter<'a, T: PrimInt, V> {
    // The backing words; bit `b` of word `i` stands for index `i * bits_in_T + b`.
    words: &'a [T],
    // Index of the next word to load into `current`. While `current` still has
    // bits, they belong to word `word_idx - 1`.
    word_idx: usize,
    // Bit-position iterator over the word currently being drained.
    current: PrimBitSetIter<T, usize>,
    // Carries the yielded item type `V` without storing a value of it.
    _marker: PhantomData<V>,
}
16
17impl<T: PrimInt + BitAndAssign, V> BitSliceIter<'_, T, V> {
18 #[inline]
19 fn remaining_len(&self) -> usize {
20 self.current.len()
21 + self.words[self.word_idx..]
22 .iter()
23 .map(|w| w.count_ones() as usize)
24 .sum::<usize>()
25 }
26}
27
28impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> Iterator for BitSliceIter<'a, T, V> {
29 type Item = V;
30
31 fn next(&mut self) -> Option<V> {
32 let bits_per = core::mem::size_of::<T>() * 8;
33 loop {
34 if let Some(pos) = self.current.next() {
35 let idx = (self.word_idx - 1) * bits_per + pos;
36 let converted = V::try_from(idx);
37 debug_assert!(converted.is_ok());
38 match converted {
39 Ok(value) => return Some(value),
40 Err(_) => unsafe { core::hint::unreachable_unchecked() },
41 }
42 }
43 if self.word_idx >= self.words.len() {
44 return None;
45 }
46 self.current = PrimBitSetIter::from_raw(self.words[self.word_idx]);
47 self.word_idx += 1;
48 }
49 }
50
51 #[inline]
52 fn size_hint(&self) -> (usize, Option<usize>) {
53 let len = self.remaining_len();
54 (len, Some(len))
55 }
56
57 #[inline]
58 fn count(self) -> usize
59 where
60 Self: Sized,
61 {
62 self.remaining_len()
63 }
64}
65
66impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> ExactSizeIterator for BitSliceIter<'_, T, V> {
67 #[inline]
68 fn len(&self) -> usize {
69 self.remaining_len()
70 }
71}
72
// After `next` returns `None`, `word_idx` is already at the end of the slice, so
// later calls keep returning `None` — provided `PrimBitSetIter` stays exhausted
// once drained (NOTE(review): confirm that helper is fused).
impl<T: PrimInt + BitAndAssign, V: TryFrom<usize>> FusedIterator for BitSliceIter<'_, T, V> {}
74
/// Unsized view of a bit set stored as a slice of words `T`, with element
/// indices typed as `V`.
///
/// Bit `b` of word `i` represents index `i * bits_in_T + b`.
// `repr(transparent)` gives this type the same layout as `[T]` (the
// `PhantomData` field is zero-sized). The pointer casts in `from_slice_ref` and
// `from_slice_mut` depend on that.
#[repr(transparent)]
pub struct BitSlice<T, V>(PhantomData<V>, [T]);
84
85impl<T, V> BitSlice<T, V> {
86 pub(crate) fn from_slice_ref(s: &[T]) -> &Self {
87 unsafe { &*(s as *const [T] as *const Self) }
90 }
91
92 pub(crate) fn from_slice_mut(s: &mut [T]) -> &mut Self {
93 unsafe { &mut *(s as *mut [T] as *mut Self) }
95 }
96}
97
98impl<T: PrimInt, V> BitSlice<T, V> {
99 const BITS_PER: usize = core::mem::size_of::<T>() * 8;
100
101 #[inline]
102 fn index_of(idx: usize) -> (usize, T) {
103 (
104 idx / Self::BITS_PER,
105 T::one().unsigned_shl((idx % Self::BITS_PER) as u32),
106 )
107 }
108
109 #[inline]
110 pub fn capacity(&self) -> usize {
111 self.1.len() * Self::BITS_PER
112 }
113
114 #[inline]
115 pub fn len(&self) -> usize {
116 self.1.iter().map(|w| w.count_ones() as usize).sum()
117 }
118
119 #[inline]
120 pub fn is_empty(&self) -> bool {
121 self.1.iter().all(|w| w.is_zero())
122 }
123
124 #[inline]
125 pub fn first(&self) -> Option<V>
126 where
127 V: TryFrom<usize>,
128 {
129 for (i, &word) in self.1.iter().enumerate() {
130 if !word.is_zero() {
131 let bit = word.trailing_zeros() as usize;
132 let idx = i * Self::BITS_PER + bit;
133 let converted = V::try_from(idx);
134 debug_assert!(converted.is_ok());
135 return Some(match converted {
136 Ok(value) => value,
137 Err(_) => unsafe { core::hint::unreachable_unchecked() },
138 });
139 }
140 }
141 None
142 }
143
144 #[inline]
145 pub fn last(&self) -> Option<V>
146 where
147 V: TryFrom<usize>,
148 {
149 for (i, &word) in self.1.iter().enumerate().rev() {
150 if !word.is_zero() {
151 let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
152 let idx = i * Self::BITS_PER + bit;
153 let converted = V::try_from(idx);
154 debug_assert!(converted.is_ok());
155 return Some(match converted {
156 Ok(value) => value,
157 Err(_) => unsafe { core::hint::unreachable_unchecked() },
158 });
159 }
160 }
161 None
162 }
163
164 #[inline]
165 pub fn pop_first(&mut self) -> Option<V>
166 where
167 V: TryFrom<usize>,
168 {
169 for (i, word) in self.1.iter_mut().enumerate() {
170 if !word.is_zero() {
171 let bit = word.trailing_zeros() as usize;
172 let mask = T::one().unsigned_shl(bit as u32);
173 *word = *word & !mask;
174 let idx = i * Self::BITS_PER + bit;
175 let converted = V::try_from(idx);
176 debug_assert!(converted.is_ok());
177 return Some(match converted {
178 Ok(value) => value,
179 Err(_) => unsafe { core::hint::unreachable_unchecked() },
180 });
181 }
182 }
183 None
184 }
185
186 #[inline]
187 pub fn pop_last(&mut self) -> Option<V>
188 where
189 V: TryFrom<usize>,
190 {
191 for (i, word) in self.1.iter_mut().enumerate().rev() {
192 if !word.is_zero() {
193 let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
194 let mask = T::one().unsigned_shl(bit as u32);
195 *word = *word & !mask;
196 let idx = i * Self::BITS_PER + bit;
197 let converted = V::try_from(idx);
198 debug_assert!(converted.is_ok());
199 return Some(match converted {
200 Ok(value) => value,
201 Err(_) => unsafe { core::hint::unreachable_unchecked() },
202 });
203 }
204 }
205 None
206 }
207
208 #[inline]
209 pub fn contains(&self, id: &V) -> bool
210 where
211 V: Copy + AsPrimitive<usize>,
212 {
213 let idx = (*id).as_();
214 debug_assert!(
215 idx < self.capacity(),
216 "index {idx} out of range for capacity {}",
217 self.capacity()
218 );
219 let (seg, mask) = Self::index_of(idx);
220 self.1.get(seg).is_some_and(|w| *w & mask != T::zero())
221 }
222
223 #[inline]
224 pub fn set(&mut self, id: V, value: bool)
225 where
226 V: AsPrimitive<usize>,
227 {
228 let idx = id.as_();
229 debug_assert!(
230 idx < self.capacity(),
231 "index {idx} out of range for capacity {}",
232 self.capacity()
233 );
234 let (seg, mask) = Self::index_of(idx);
235 if let Some(word) = self.1.get_mut(seg) {
236 if value {
237 *word = *word | mask;
238 } else {
239 *word = *word & !mask;
240 }
241 }
242 }
243
244 #[inline]
245 pub fn insert(&mut self, id: V) -> bool
246 where
247 V: AsPrimitive<usize>,
248 {
249 let idx = id.as_();
250 debug_assert!(
251 idx < self.capacity(),
252 "index {idx} out of range for capacity {}",
253 self.capacity()
254 );
255 let (seg, mask) = Self::index_of(idx);
256 let Some(word) = self.1.get_mut(seg) else {
257 return false;
258 };
259 let was_absent = *word & mask == T::zero();
260 *word = *word | mask;
261 was_absent
262 }
263
264 #[inline]
265 pub fn remove(&mut self, id: V) -> bool
266 where
267 V: AsPrimitive<usize>,
268 {
269 let idx = id.as_();
270 debug_assert!(
271 idx < self.capacity(),
272 "index {idx} out of range for capacity {}",
273 self.capacity()
274 );
275 let (seg, mask) = Self::index_of(idx);
276 let Some(word) = self.1.get_mut(seg) else {
277 return false;
278 };
279 let was_present = *word & mask != T::zero();
280 *word = *word & !mask;
281 was_present
282 }
283
284 #[inline]
285 pub fn toggle(&mut self, id: V)
286 where
287 V: AsPrimitive<usize>,
288 {
289 let idx = id.as_();
290 debug_assert!(
291 idx < self.capacity(),
292 "index {idx} out of range for capacity {}",
293 self.capacity()
294 );
295 let (seg, mask) = Self::index_of(idx);
296 if let Some(word) = self.1.get_mut(seg) {
297 *word = *word ^ mask;
298 }
299 }
300
301 #[inline]
302 pub fn clear(&mut self) {
303 self.1.fill(T::zero());
304 }
305
306 pub fn retain(&mut self, mut f: impl FnMut(V) -> bool)
307 where
308 V: TryFrom<usize>,
309 {
310 for (i, word) in self.1.iter_mut().enumerate() {
311 let mut w = *word;
312 while !w.is_zero() {
313 let bit = w.trailing_zeros() as usize;
314 let mask = T::one().unsigned_shl(bit as u32);
315 w = w & !mask;
316 let idx = i * Self::BITS_PER + bit;
317 let converted = V::try_from(idx);
318 debug_assert!(converted.is_ok());
319 let value = match converted {
320 Ok(v) => v,
321 Err(_) => unsafe { core::hint::unreachable_unchecked() },
322 };
323 if !f(value) {
324 *word = *word & !mask;
325 }
326 }
327 }
328 }
329
330 pub fn append(&mut self, other: &mut Self) {
331 let min = self.1.len().min(other.1.len());
332 for i in 0..min {
333 self.1[i] = self.1[i] | other.1[i];
334 other.1[i] = T::zero();
335 }
336 }
337
338 #[inline]
339 pub fn iter(&self) -> BitSliceIter<'_, T, V>
340 where
341 T: BitAndAssign,
342 V: TryFrom<usize>,
343 {
344 BitSliceIter {
345 words: &self.1,
346 word_idx: 0,
347 current: PrimBitSetIter::empty(),
348 _marker: PhantomData,
349 }
350 }
351
352 #[inline]
353 pub fn is_subset(&self, other: &Self) -> bool {
354 let min = self.1.len().min(other.1.len());
355 self.1[..min]
356 .iter()
357 .zip(other.1[..min].iter())
358 .all(|(a, b)| *a & *b == *a)
359 && self.1[min..].iter().all(|w| w.is_zero())
360 }
361
362 #[inline]
363 pub fn is_superset(&self, other: &Self) -> bool {
364 other.is_subset(self)
365 }
366
367 #[inline]
368 pub fn is_disjoint(&self, other: &Self) -> bool {
369 self.1
370 .iter()
371 .zip(other.1.iter())
372 .all(|(a, b)| (*a & *b).is_zero())
373 }
374
375 fn word_op_iter<'a>(
376 a: &'a [T],
377 b: &'a [T],
378 len: usize,
379 op: impl Fn(T, T) -> T + 'a,
380 ) -> impl Iterator<Item = V> + 'a
381 where
382 T: BitAndAssign,
383 V: TryFrom<usize>,
384 {
385 let bits_per = Self::BITS_PER;
386 (0..len).flat_map(move |i| {
387 let w_a = a.get(i).copied().unwrap_or(T::zero());
388 let w_b = b.get(i).copied().unwrap_or(T::zero());
389 let combined = op(w_a, w_b);
390 let offset = i * bits_per;
391 PrimBitSetIter::<T, usize>(combined, PhantomData).map(move |pos| {
392 let idx = offset + pos;
393 debug_assert!(V::try_from(idx).is_ok());
394 match V::try_from(idx) {
395 Ok(v) => v,
396 Err(_) => unsafe { core::hint::unreachable_unchecked() },
397 }
398 })
399 })
400 }
401
402 #[inline]
403 pub fn difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
404 where
405 T: BitAndAssign,
406 V: TryFrom<usize>,
407 {
408 Self::word_op_iter(&self.1, &other.1, self.1.len(), |a, b| a & !b)
409 }
410
411 #[inline]
412 pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
413 where
414 T: BitAndAssign,
415 V: TryFrom<usize>,
416 {
417 Self::word_op_iter(
418 &self.1,
419 &other.1,
420 self.1.len().min(other.1.len()),
421 |a, b| a & b,
422 )
423 }
424
425 #[inline]
426 pub fn union<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
427 where
428 T: BitAndAssign,
429 V: TryFrom<usize>,
430 {
431 Self::word_op_iter(
432 &self.1,
433 &other.1,
434 self.1.len().max(other.1.len()),
435 |a, b| a | b,
436 )
437 }
438
439 #[inline]
440 pub fn symmetric_difference<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = V> + 'a
441 where
442 T: BitAndAssign,
443 V: TryFrom<usize>,
444 {
445 Self::word_op_iter(
446 &self.1,
447 &other.1,
448 self.1.len().max(other.1.len()),
449 |a, b| a ^ b,
450 )
451 }
452
453 #[cfg(feature = "bitvec")]
456 #[inline]
457 pub fn as_bitvec_slice(&self) -> &bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
458 where
459 T: bitvec::store::BitStore,
460 {
461 bitvec::slice::BitSlice::from_slice(&self.1)
462 }
463
464 #[cfg(feature = "bitvec")]
465 #[inline]
466 pub fn as_mut_bitvec_slice(&mut self) -> &mut bitvec::slice::BitSlice<T, bitvec::order::Lsb0>
467 where
468 T: bitvec::store::BitStore,
469 {
470 bitvec::slice::BitSlice::from_slice_mut(&mut self.1)
471 }
472
473 #[inline]
475 pub fn raw_words(&self) -> &[T] {
476 &self.1
477 }
478}
479
480impl<'a, T: PrimInt + BitAndAssign, V: TryFrom<usize>> IntoIterator for &'a BitSlice<T, V> {
481 type Item = V;
482 type IntoIter = BitSliceIter<'a, T, V>;
483
484 #[inline]
485 fn into_iter(self) -> Self::IntoIter {
486 self.iter()
487 }
488}
489
490impl<T: PrimInt, V> PartialEq for BitSlice<T, V> {
491 fn eq(&self, other: &Self) -> bool {
492 let min = self.1.len().min(other.1.len());
493 self.1[..min] == other.1[..min]
494 && self.1[min..].iter().all(|w| w.is_zero())
495 && other.1[min..].iter().all(|w| w.is_zero())
496 }
497}
498
// `PartialEq` compares integer words exactly (padding zeros aside), which is
// reflexive, so it is a full equivalence relation.
impl<T: PrimInt, V> Eq for BitSlice<T, V> {}
500
501impl<T: PrimInt + Hash, V> Hash for BitSlice<T, V> {
502 fn hash<H: Hasher>(&self, state: &mut H) {
503 let effective_len = self
505 .1
506 .iter()
507 .rposition(|w| !w.is_zero())
508 .map_or(0, |i| i + 1);
509 for w in &self.1[..effective_len] {
510 w.hash(state);
511 }
512 }
513}
514
515impl<T: PrimInt + BitAndAssign, V> core::fmt::Debug for BitSlice<T, V> {
516 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
517 let bits_per = core::mem::size_of::<T>() * 8;
518 f.write_str("{")?;
519 let mut first = true;
520 for (i, &word) in self.1.iter().enumerate() {
521 let offset = i * bits_per;
522 for pos in PrimBitSetIter::<T, usize>(word, PhantomData) {
523 if !first {
524 f.write_str(", ")?;
525 }
526 first = false;
527 write!(f, "{}", offset + pos)?;
528 }
529 }
530 f.write_str("}")
531 }
532}