1mod iter;
7
8use sealed::Array;
9pub(crate) use sealed::BitValueImpl;
10
11use std::{ffi::c_ulong, fmt, slice};
12
13mod sealed {
14 use super::Word;
15
16 pub trait BitValueImpl {
17 #[doc(hidden)]
18 type __PrivateArray: AsRef<[Word]>
19 + AsMut<[Word]>
20 + Copy
21 + IntoIterator<Item = Word, IntoIter: Clone>;
22 #[doc(hidden)]
23 const __PRIVATE_ZERO: Self::__PrivateArray;
24 fn from_index(index: usize) -> Self;
28 fn into_index(self) -> usize;
29 }
30
31 pub(crate) type Array<V> = <V as BitValueImpl>::__PrivateArray;
32}
33
34pub type Word = c_ulong;
38
39pub trait BitValue: Copy + sealed::BitValueImpl {
44 const MAX: Self;
53}
54
55pub struct BitSet<V: BitValue> {
57 pub(crate) words: Array<V>,
58}
59
60impl<V: BitValue> Copy for BitSet<V> {}
61impl<V: BitValue> Clone for BitSet<V> {
62 fn clone(&self) -> Self {
63 *self
64 }
65}
66impl<V: BitValue> Default for BitSet<V> {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl<V: BitValue> BitSet<V> {
73 pub const fn new() -> Self {
75 Self {
76 words: V::__PRIVATE_ZERO,
77 }
78 }
79
80 pub fn words(&self) -> &[Word] {
92 self.words.as_ref()
93 }
94
95 pub fn words_mut(&mut self) -> &mut [Word] {
103 self.words.as_mut()
104 }
105
106 pub fn len(&self) -> usize {
108 self.words
109 .as_ref()
110 .iter()
111 .map(|w| w.count_ones() as usize)
112 .sum::<usize>()
113 }
114
115 pub fn is_empty(&self) -> bool {
117 self.words.as_ref().iter().all(|&w| w == 0)
118 }
119
120 pub fn contains(&self, value: V) -> bool {
122 if value.into_index() > V::MAX.into_index() {
123 return false;
124 }
125 let index = value.into_index();
126 let wordpos = index / Word::BITS as usize;
127 let bitpos = index % Word::BITS as usize;
128
129 let word = self.words.as_ref()[wordpos];
130 let bit = word & (1 << bitpos) != 0;
131 bit
132 }
133
134 pub fn insert(&mut self, value: V) -> bool {
142 assert!(
143 value.into_index() <= V::MAX.into_index(),
144 "value out of range for `BitSet` storage (value's index is {}, max is {})",
145 value.into_index(),
146 V::MAX.into_index(),
147 );
148
149 let present = self.contains(value);
150
151 let index = value.into_index();
152 let wordpos = index / Word::BITS as usize;
153 let bitpos = index % Word::BITS as usize;
154 self.words.as_mut()[wordpos] |= 1 << bitpos;
155 present
156 }
157
158 pub fn remove(&mut self, value: V) -> bool {
162 if value.into_index() > V::MAX.into_index() {
163 return false;
164 }
165 let present = self.contains(value);
166
167 let index = value.into_index();
168 let wordpos = index / Word::BITS as usize;
169 let bitpos = index % Word::BITS as usize;
170 self.words.as_mut()[wordpos] &= !(1 << bitpos);
171 present
172 }
173
174 pub fn iter(&self) -> Iter<'_, V> {
176 Iter {
177 imp: iter::IterImpl::new(self.words.as_ref().iter().copied()),
178 }
179 }
180
181 pub(crate) fn symmetric_difference<'a>(
184 &'a self,
185 other: &'a BitSet<V>,
186 ) -> SymmetricDifference<'a, V> {
187 SymmetricDifference {
188 imp: iter::IterImpl::new(SymmDiffWords {
189 a: self.words.as_ref().iter(),
190 b: other.words.as_ref().iter(),
191 }),
192 }
193 }
194}
195
196impl<V: BitValue + fmt::Debug> fmt::Debug for BitSet<V> {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 f.debug_set().entries(self.iter()).finish()
199 }
200}
201
202impl<V: BitValue> PartialEq for BitSet<V> {
203 fn eq(&self, other: &Self) -> bool {
204 self.words.as_ref() == other.words.as_ref()
205 }
206}
207impl<V: BitValue> Eq for BitSet<V> {}
208
209impl<V: BitValue> FromIterator<V> for BitSet<V> {
210 fn from_iter<T: IntoIterator<Item = V>>(iter: T) -> Self {
211 let mut this = Self::new();
212 this.extend(iter);
213 this
214 }
215}
216impl<V: BitValue> Extend<V> for BitSet<V> {
217 fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
218 for item in iter {
219 self.insert(item);
220 }
221 }
222}
223
224impl<'a, V: BitValue> IntoIterator for &'a BitSet<V> {
225 type Item = V;
226 type IntoIter = Iter<'a, V>;
227
228 fn into_iter(self) -> Self::IntoIter {
229 self.iter()
230 }
231}
232impl<V: BitValue> IntoIterator for BitSet<V> {
233 type Item = V;
234 type IntoIter = IntoIter<V>;
235
236 fn into_iter(self) -> Self::IntoIter {
237 IntoIter {
238 imp: iter::IterImpl::new(self.words.into_iter()),
239 }
240 }
241}
242
243pub struct IntoIter<V: BitValue> {
245 imp: iter::IterImpl<V, <Array<V> as IntoIterator>::IntoIter>,
246}
247impl<V: BitValue> Iterator for IntoIter<V> {
248 type Item = V;
249 fn next(&mut self) -> Option<Self::Item> {
250 self.imp.next()
251 }
252}
253impl<V: BitValue + fmt::Debug> fmt::Debug for IntoIter<V> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_tuple("IntoIter")
256 .field(&DebugAsSet(self.imp.clone()))
257 .finish()
258 }
259}
260
261pub struct Iter<'a, V: BitValue> {
263 imp: iter::IterImpl<V, std::iter::Copied<slice::Iter<'a, Word>>>,
264}
265
266impl<V: BitValue> Iterator for Iter<'_, V> {
267 type Item = V;
268
269 fn next(&mut self) -> Option<Self::Item> {
270 self.imp.next()
271 }
272}
273
274impl<V: BitValue + fmt::Debug> fmt::Debug for Iter<'_, V> {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 f.debug_tuple("Iter")
277 .field(&DebugAsSet(self.imp.clone()))
278 .finish()
279 }
280}
281
/// Debug-formats the items remaining in an iterator as a set (`{a, b, c}`).
///
/// The wrapped iterator is cloned before formatting, so it is never consumed.
struct DebugAsSet<I>(I);

impl<I> fmt::Debug for DebugAsSet<I>
where
    I: Iterator + Clone,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let DebugAsSet(inner) = self;
        let mut out = f.debug_set();
        out.entries(inner.clone());
        out.finish()
    }
}
291
292pub(crate) struct SymmetricDifference<'a, V: BitValue> {
296 imp: iter::IterImpl<V, SymmDiffWords<'a>>,
297}
298
299struct SymmDiffWords<'a> {
300 a: slice::Iter<'a, Word>,
301 b: slice::Iter<'a, Word>,
302}
303
304impl Iterator for SymmDiffWords<'_> {
305 type Item = Word;
306 fn next(&mut self) -> Option<Word> {
307 Some(self.a.next().copied()? ^ self.b.next().copied()?)
308 }
309}
310
311impl<V: BitValue> Iterator for SymmetricDifference<'_, V> {
312 type Item = V;
313
314 fn next(&mut self) -> Option<Self::Item> {
315 self.imp.next()
316 }
317}
318
#[cfg(test)]
mod tests {
    use crate::{
        InputProp,
        event::{Abs, EventType, Key, Led, Misc, Rel},
    };

    use super::*;

    #[test]
    fn sizes() {
        assert_eq!(size_of::<BitSet<EventType>>(), size_of::<Word>());
        assert_eq!(size_of::<BitSet<InputProp>>(), size_of::<Word>());
        assert_eq!(size_of::<BitSet<Rel>>(), size_of::<Word>());
        assert_eq!(size_of::<BitSet<Misc>>(), size_of::<Word>());
        assert_eq!(size_of::<BitSet<Led>>(), size_of::<Word>());
    }

    #[test]
    fn bit0() {
        let mut set = BitSet::new();
        set.insert(InputProp(0));

        assert!(set.contains(InputProp::POINTER));
        assert!(!set.contains(InputProp::DIRECT));
        assert!(!set.contains(InputProp::MAX));
        assert!(!set.contains(InputProp(InputProp::MAX.0 + 1)));
        assert!(!set.contains(InputProp(u8::MAX)));

        assert_eq!(set.iter().collect::<Vec<_>>(), &[InputProp::POINTER]);
    }

    #[test]
    fn max() {
        let mut set = BitSet::new();
        set.insert(InputProp::MAX);

        assert!(!set.contains(InputProp::POINTER));
        assert!(!set.contains(InputProp::DIRECT));
        assert!(set.contains(InputProp::MAX));
        assert!(!set.contains(InputProp(InputProp::MAX.0 + 1)));
        assert!(!set.remove(InputProp(InputProp::MAX.0 + 1)));
    }

    #[test]
    #[should_panic = "value out of range for `BitSet`"]
    fn above_max() {
        let mut set = BitSet::new();
        set.insert(Abs::from_raw(Abs::MAX.raw() + 1));
    }

    #[test]
    fn debug() {
        let set = BitSet::from_iter([Abs::X, Abs::Y, Abs::BRAKE]);
        assert_eq!(format!("{set:?}"), "{ABS_X, ABS_Y, ABS_BRAKE}");

        let mut iter = set.iter();
        assert_eq!(iter.next(), Some(Abs::X));
        assert_eq!(format!("{iter:?}"), "Iter({ABS_Y, ABS_BRAKE})");

        let mut iter = set.into_iter();
        assert_eq!(iter.next(), Some(Abs::X));
        assert_eq!(format!("{iter:?}"), "IntoIter({ABS_Y, ABS_BRAKE})");
    }

    #[test]
    fn multiple() {
        let mut set = BitSet::new();
        set.insert(Key::KEY_RESERVED);
        set.insert(Key::KEY_Q);
        set.insert(Key::MAX);
        set.insert(Key::KEY_MACRO1);

        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            &[Key::KEY_RESERVED, Key::KEY_Q, Key::KEY_MACRO1, Key::MAX]
        );
    }

    #[test]
    fn symmdiff() {
        let mut a = BitSet::new();
        a.insert(Key::KEY_B);

        assert_eq!(
            a.symmetric_difference(&BitSet::new()).collect::<Vec<_>>(),
            &[Key::KEY_B]
        );
        assert_eq!(
            BitSet::new().symmetric_difference(&a).collect::<Vec<_>>(),
            &[Key::KEY_B]
        );

        let mut b = BitSet::new();
        b.insert(Key::KEY_A);

        assert_eq!(
            a.symmetric_difference(&b).collect::<Vec<_>>(),
            &[Key::KEY_A, Key::KEY_B]
        );
        assert_eq!(
            b.symmetric_difference(&a).collect::<Vec<_>>(),
            &[Key::KEY_A, Key::KEY_B]
        );

        assert_eq!(a.symmetric_difference(&a).collect::<Vec<_>>(), &[]);
        assert_eq!(
            BitSet::<Key>::new()
                .symmetric_difference(&BitSet::new())
                .collect::<Vec<_>>(),
            &[]
        );
    }

    /// Pins the return-value conventions of `insert`/`remove` together with
    /// `len`/`is_empty`.
    #[test]
    fn insert_remove_len() {
        let mut set = BitSet::<Key>::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);

        // `insert` reports whether the value was already present.
        assert!(!set.insert(Key::KEY_A));
        assert!(set.insert(Key::KEY_A));
        assert!(!set.insert(Key::MAX));
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());

        // `remove` reports whether the value was present before removal.
        assert!(set.remove(Key::KEY_A));
        assert!(!set.remove(Key::KEY_A));
        assert!(!set.contains(Key::KEY_A));
        assert!(set.contains(Key::MAX));
        assert_eq!(set.len(), 1);

        assert!(set.remove(Key::MAX));
        assert!(set.is_empty());
        assert_eq!(set, BitSet::new());
    }

    /// `FromIterator` and manual insertion must agree, independent of order.
    #[test]
    fn from_iter_eq() {
        let a = BitSet::from_iter([Key::KEY_Q, Key::KEY_A]);

        let mut b = BitSet::new();
        b.insert(Key::KEY_A);
        b.insert(Key::KEY_Q);
        assert_eq!(a, b);

        b.remove(Key::KEY_Q);
        assert_ne!(a, b);
    }
}