1mod iter;
7
8use sealed::Array;
9pub(crate) use sealed::BitValueImpl;
10
11use std::{ffi::c_ulong, fmt, slice};
12
mod sealed {
    use super::Word;

    /// Crate-private supertrait of [`BitValue`](super::BitValue) carrying the storage details
    /// of a [`BitSet`](super::BitSet).
    ///
    /// It lives in the private `sealed` module and is only re-exported as `pub(crate)`, so code
    /// outside this crate cannot name it and therefore cannot implement `BitValue` either.
    pub trait BitValueImpl {
        /// Word array backing a `BitSet<Self>`.
        ///
        /// It must be viewable as a (mutable) slice of words, be `Copy`, and be iterable by
        /// value with a cloneable iterator (the owning iterator's `Debug` impl clones it).
        /// NOTE(review): presumably long enough to contain bit `MAX.into_index()`; the slice
        /// indexing in `BitSet::contains`/`insert`/`remove` would panic otherwise — confirm in
        /// implementors.
        #[doc(hidden)]
        type __PrivateArray: AsRef<[Word]>
            + AsMut<[Word]>
            + Copy
            + IntoIterator<Item = Word, IntoIter: Clone>;
        /// Storage value that `BitSet::new` starts from (all zero words, going by the name).
        #[doc(hidden)]
        const __PRIVATE_ZERO: Self::__PrivateArray;
        /// Converts a bit index back into a value.
        /// NOTE(review): presumably the inverse of `into_index`, used by the iterators — confirm.
        fn from_index(index: usize) -> Self;
        /// Converts a value into the index of the bit that represents it in a `BitSet`.
        fn into_index(self) -> usize;
    }

    /// Shorthand for the word array type backing a `BitSet<V>`.
    pub(crate) type Array<V> = <V as BitValueImpl>::__PrivateArray;
}
33
/// Machine word that `BitSet` storage is made of.
///
/// This is C's `unsigned long`, so its width (`Word::BITS`) is platform dependent.
/// NOTE(review): presumably chosen to match C-side bitmaps exchanged over FFI — confirm.
pub type Word = c_ulong;
38
/// A value type that can be stored in a [`BitSet`].
///
/// The supertrait `BitValueImpl` is crate-private, so this trait cannot be implemented outside
/// this crate.
pub trait BitValue: Copy + sealed::BitValueImpl {
    /// The largest value a `BitSet` of this type can hold.
    ///
    /// `BitSet::insert` panics for values whose index exceeds `MAX`'s index, while
    /// `BitSet::contains` and `BitSet::remove` treat such values as absent.
    const MAX: Self;
}
49
/// A fixed-capacity set of `V` values stored as a bitmap of [`Word`]s.
///
/// A value `v` with `i = v.into_index()` is represented by bit `i % Word::BITS` of word
/// `i / Word::BITS`.
pub struct BitSet<V: BitValue> {
    /// The bitmap words; see the type-level docs for the bit layout.
    pub(crate) words: Array<V>,
}
54
55impl<V: BitValue> Copy for BitSet<V> {}
56impl<V: BitValue> Clone for BitSet<V> {
57 fn clone(&self) -> Self {
58 *self
59 }
60}
61impl<V: BitValue> Default for BitSet<V> {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl<V: BitValue> BitSet<V> {
68 pub const fn new() -> Self {
70 Self {
71 words: V::__PRIVATE_ZERO,
72 }
73 }
74
75 pub fn words(&self) -> &[Word] {
85 self.words.as_ref()
86 }
87
88 pub fn words_mut(&mut self) -> &mut [Word] {
96 self.words.as_mut()
97 }
98
99 pub fn len(&self) -> usize {
101 self.words
102 .as_ref()
103 .iter()
104 .map(|w| w.count_ones() as usize)
105 .sum::<usize>()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.words.as_ref().iter().all(|&w| w == 0)
111 }
112
113 pub fn contains(&self, value: V) -> bool {
115 if value.into_index() > V::MAX.into_index() {
116 return false;
117 }
118 let index = value.into_index();
119 let wordpos = index / Word::BITS as usize;
120 let bitpos = index % Word::BITS as usize;
121
122 let word = self.words.as_ref()[wordpos];
123 let bit = word & (1 << bitpos) != 0;
124 bit
125 }
126
127 pub fn insert(&mut self, value: V) -> bool {
135 assert!(
136 value.into_index() <= V::MAX.into_index(),
137 "value out of range for `BitSet` storage (value's index is {}, max is {})",
138 value.into_index(),
139 V::MAX.into_index(),
140 );
141
142 let present = self.contains(value);
143
144 let index = value.into_index();
145 let wordpos = index / Word::BITS as usize;
146 let bitpos = index % Word::BITS as usize;
147 self.words.as_mut()[wordpos] |= 1 << bitpos;
148 present
149 }
150
151 pub fn remove(&mut self, value: V) -> bool {
155 if value.into_index() > V::MAX.into_index() {
156 return false;
157 }
158 let present = self.contains(value);
159
160 let index = value.into_index();
161 let wordpos = index / Word::BITS as usize;
162 let bitpos = index % Word::BITS as usize;
163 self.words.as_mut()[wordpos] &= !(1 << bitpos);
164 present
165 }
166
167 pub fn iter(&self) -> Iter<'_, V> {
169 Iter {
170 imp: iter::IterImpl::new(self.words.as_ref().iter().copied()),
171 }
172 }
173
174 pub(crate) fn symmetric_difference<'a>(
177 &'a self,
178 other: &'a BitSet<V>,
179 ) -> SymmetricDifference<'a, V> {
180 SymmetricDifference {
181 imp: iter::IterImpl::new(SymmDiffWords {
182 a: self.words.as_ref().iter(),
183 b: other.words.as_ref().iter(),
184 }),
185 }
186 }
187}
188
189impl<V: BitValue + fmt::Debug> fmt::Debug for BitSet<V> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_set().entries(self.iter()).finish()
192 }
193}
194
195impl<V: BitValue> PartialEq for BitSet<V> {
196 fn eq(&self, other: &Self) -> bool {
197 self.words.as_ref() == other.words.as_ref()
198 }
199}
200impl<V: BitValue> Eq for BitSet<V> {}
201
202impl<V: BitValue> FromIterator<V> for BitSet<V> {
203 fn from_iter<T: IntoIterator<Item = V>>(iter: T) -> Self {
204 let mut this = Self::new();
205 this.extend(iter);
206 this
207 }
208}
209impl<V: BitValue> Extend<V> for BitSet<V> {
210 fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
211 for item in iter {
212 self.insert(item);
213 }
214 }
215}
216
217impl<'a, V: BitValue> IntoIterator for &'a BitSet<V> {
218 type Item = V;
219 type IntoIter = Iter<'a, V>;
220
221 fn into_iter(self) -> Self::IntoIter {
222 self.iter()
223 }
224}
225impl<V: BitValue> IntoIterator for BitSet<V> {
226 type Item = V;
227 type IntoIter = IntoIter<V>;
228
229 fn into_iter(self) -> Self::IntoIter {
230 IntoIter {
231 imp: iter::IterImpl::new(self.words.into_iter()),
232 }
233 }
234}
235
236pub struct IntoIter<V: BitValue> {
238 imp: iter::IterImpl<V, <Array<V> as IntoIterator>::IntoIter>,
239}
240impl<V: BitValue> Iterator for IntoIter<V> {
241 type Item = V;
242 fn next(&mut self) -> Option<Self::Item> {
243 self.imp.next()
244 }
245}
246impl<V: BitValue + fmt::Debug> fmt::Debug for IntoIter<V> {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 f.debug_tuple("IntoIter")
249 .field(&DebugAsSet(self.imp.clone()))
250 .finish()
251 }
252}
253
254pub struct Iter<'a, V: BitValue> {
256 imp: iter::IterImpl<V, std::iter::Copied<slice::Iter<'a, Word>>>,
257}
258
259impl<V: BitValue> Iterator for Iter<'_, V> {
260 type Item = V;
261
262 fn next(&mut self) -> Option<Self::Item> {
263 self.imp.next()
264 }
265}
266
267impl<V: BitValue + fmt::Debug> fmt::Debug for Iter<'_, V> {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 f.debug_tuple("Iter")
270 .field(&DebugAsSet(self.imp.clone()))
271 .finish()
272 }
273}
274
/// Adapter whose `Debug` output lists the items of the wrapped iterator as a set.
///
/// `fmt` only has `&self`, so it walks a clone of the iterator.
struct DebugAsSet<I>(I);

impl<I> fmt::Debug for DebugAsSet<I>
where
    I: Clone + Iterator,
    I::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut set = f.debug_set();
        for item in self.0.clone() {
            set.entry(&item);
        }
        set.finish()
    }
}
284
/// Iterator over the values contained in exactly one of two [`BitSet`]s, returned by
/// `BitSet::symmetric_difference`.
pub(crate) struct SymmetricDifference<'a, V: BitValue> {
    // Yields values from the XOR of the two sets' words (see `SymmDiffWords`).
    imp: iter::IterImpl<V, SymmDiffWords<'a>>,
}
291
/// Pairs up two word slices and yields their element-wise XOR.
///
/// Iteration ends as soon as either side runs out. `a` is always advanced before `b`.
struct SymmDiffWords<'a> {
    a: slice::Iter<'a, Word>,
    b: slice::Iter<'a, Word>,
}

impl Iterator for SymmDiffWords<'_> {
    type Item = Word;

    fn next(&mut self) -> Option<Word> {
        let left = *self.a.next()?;
        let right = *self.b.next()?;
        Some(left ^ right)
    }
}
303
304impl<V: BitValue> Iterator for SymmetricDifference<'_, V> {
305 type Item = V;
306
307 fn next(&mut self) -> Option<Self::Item> {
308 self.imp.next()
309 }
310}
311
#[cfg(test)]
mod tests {
    use std::mem;

    use crate::{
        InputProp,
        event::{Abs, EventType, Key, Led, Misc, Rel},
    };

    use super::*;

    #[test]
    fn sizes() {
        assert_eq!(mem::size_of::<BitSet<EventType>>(), mem::size_of::<Word>());
        assert_eq!(mem::size_of::<BitSet<InputProp>>(), mem::size_of::<Word>());
        assert_eq!(mem::size_of::<BitSet<Rel>>(), mem::size_of::<Word>());
        assert_eq!(mem::size_of::<BitSet<Misc>>(), mem::size_of::<Word>());
        assert_eq!(mem::size_of::<BitSet<Led>>(), mem::size_of::<Word>());
    }

    #[test]
    fn bit0() {
        let mut set = BitSet::new();
        set.insert(InputProp(0));

        assert!(set.contains(InputProp::POINTER));
        assert!(!set.contains(InputProp::DIRECT));
        assert!(!set.contains(InputProp::MAX));
        assert!(!set.contains(InputProp::CNT));
        assert!(!set.contains(InputProp(u8::MAX)));

        assert_eq!(set.iter().collect::<Vec<_>>(), &[InputProp::POINTER]);
    }

    #[test]
    fn max() {
        let mut set = BitSet::new();
        set.insert(InputProp::MAX);

        assert!(!set.contains(InputProp::POINTER));
        assert!(!set.contains(InputProp::DIRECT));
        assert!(set.contains(InputProp::MAX));
        assert!(!set.contains(InputProp::CNT));
        assert!(!set.remove(InputProp::CNT));
    }

    #[test]
    #[should_panic = "value out of range for `BitSet`"]
    fn above_max() {
        let mut set = BitSet::new();
        set.insert(Abs::from_raw(Abs::MAX.raw() + 1));
    }

    #[test]
    fn debug() {
        let set = BitSet::from_iter([Abs::X, Abs::Y, Abs::BRAKE]);
        assert_eq!(format!("{set:?}"), "{ABS_X, ABS_Y, ABS_BRAKE}");

        let mut iter = set.iter();
        assert_eq!(iter.next(), Some(Abs::X));
        assert_eq!(format!("{iter:?}"), "Iter({ABS_Y, ABS_BRAKE})");

        let mut iter = set.into_iter();
        assert_eq!(iter.next(), Some(Abs::X));
        assert_eq!(format!("{iter:?}"), "IntoIter({ABS_Y, ABS_BRAKE})");
    }

    #[test]
    fn multiple() {
        let mut set = BitSet::new();
        set.insert(Key::KEY_RESERVED);
        set.insert(Key::KEY_Q);
        set.insert(Key::KEY_MAX);
        set.insert(Key::KEY_MACRO1);

        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            &[Key::KEY_RESERVED, Key::KEY_Q, Key::KEY_MACRO1, Key::KEY_MAX]
        );
    }

    /// Covers `len`, `is_empty`, the return value of `remove`, and equality, none of which
    /// the other tests exercise.
    #[test]
    fn len_is_empty_and_remove() {
        let mut set = BitSet::<Key>::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);

        set.insert(Key::KEY_A);
        set.insert(Key::KEY_Q);
        // Re-inserting an existing value must not change the count.
        set.insert(Key::KEY_Q);
        assert!(!set.is_empty());
        assert_eq!(set.len(), 2);

        assert!(set.remove(Key::KEY_Q));
        assert!(!set.remove(Key::KEY_Q));
        assert!(!set.contains(Key::KEY_Q));
        assert_eq!(set.len(), 1);

        assert_eq!(set, BitSet::from_iter([Key::KEY_A]));
        assert_ne!(set, BitSet::new());
    }

    #[test]
    fn symmdiff() {
        let mut a = BitSet::new();
        a.insert(Key::KEY_B);

        assert_eq!(
            a.symmetric_difference(&BitSet::new()).collect::<Vec<_>>(),
            &[Key::KEY_B]
        );
        assert_eq!(
            BitSet::new().symmetric_difference(&a).collect::<Vec<_>>(),
            &[Key::KEY_B]
        );

        let mut b = BitSet::new();
        b.insert(Key::KEY_A);

        assert_eq!(
            a.symmetric_difference(&b).collect::<Vec<_>>(),
            &[Key::KEY_A, Key::KEY_B]
        );
        assert_eq!(
            b.symmetric_difference(&a).collect::<Vec<_>>(),
            &[Key::KEY_A, Key::KEY_B]
        );

        assert_eq!(a.symmetric_difference(&a).collect::<Vec<_>>(), &[]);
        assert_eq!(
            BitSet::<Key>::new()
                .symmetric_difference(&BitSet::new())
                .collect::<Vec<_>>(),
            &[]
        );
    }
}