rawsys_linux/
set.rs

1//! Enables the creation of a syscall bitset.
2
3use super::Sysno;
4
5use core::fmt;
6use core::num::NonZeroUsize;
7
/// Number of bits occupied by a value of type `T`.
///
/// The multiplication saturates at `usize::MAX` instead of overflowing.
const fn bits_per<T>() -> usize {
    let bytes = core::mem::size_of::<T>();
    bytes.saturating_mul(u8::BITS as usize)
}
11
12/// Returns the number of words of type `T` required to hold the specified
13/// number of `bits`.
14const fn words<T>(bits: usize) -> usize {
15    let width = bits_per::<T>();
16    if width == 0 {
17        return 0;
18    }
19
20    bits / width + (!bits.is_multiple_of(width) as usize)
21}
22
23#[allow(clippy::doc_markdown)]
24/// A set of syscalls.
25///
26/// Backed by a compact bitset, this provides constant-time membership checks
27/// and set algebra (union, intersection, difference) with predictable cost.
28/// The bit layout matches the `Sysno` table so that conversions are trivial.
29///
30/// Complexity
31/// - `contains`/`insert`/`remove`: O(1)
32/// - `count`/`is_empty`: O(n_words)
33///
34/// # Examples
35///
36/// ```
37/// # use rawsys_linux::{Sysno, SysnoSet};
38/// let syscalls = SysnoSet::new(&[Sysno::read, Sysno::write, Sysno::openat, Sysno::close]);
39/// assert!(syscalls.contains(Sysno::read));
40/// assert!(syscalls.contains(Sysno::close));
41/// ```
42/// Most operations can be done at compile-time as well.
43/// ```
44/// # use rawsys_linux::{Sysno, SysnoSet};
45/// const SYSCALLS: SysnoSet =
46///     SysnoSet::new(&[Sysno::read, Sysno::write, Sysno::close])
47///         .union(&SysnoSet::new(&[Sysno::openat]));
48/// const _: () = assert!(SYSCALLS.contains(Sysno::read));
49/// const _: () = assert!(SYSCALLS.contains(Sysno::openat));
50/// ```
51#[derive(Clone, Eq, PartialEq)]
52pub struct SysnoSet {
53    pub(crate) data: [usize; words::<usize>(Sysno::table_size())],
54}
55
56impl Default for SysnoSet {
57    fn default() -> Self {
58        Self::empty()
59    }
60}
61
62impl SysnoSet {
63    /// The set of all valid syscalls.
64    pub(crate) const ALL: &'static Self = &Self::new(Sysno::ALL);
65
66    const WORD_WIDTH: usize = usize::BITS as usize;
67
68    /// Compute the index and mask for the given syscall as stored in the set data.
69    #[inline]
70    pub(crate) const fn get_idx_mask(sysno: Sysno) -> (usize, usize) {
71        let bit = (sysno.id() as usize) - (Sysno::first().id() as usize);
72        (bit / Self::WORD_WIDTH, 1 << (bit % Self::WORD_WIDTH))
73    }
74
75    /// Initialize the syscall set with the given slice of syscalls.
76    ///
77    /// Since this is a `const fn`, this can be used at compile-time.
78    pub const fn new(syscalls: &[Sysno]) -> Self {
79        let mut set = Self::empty();
80
81        // Use while-loop because for-loops are not yet allowed in const-fns.
82        // https://github.com/rust-lang/rust/issues/87575
83        let mut i = 0;
84        while i < syscalls.len() {
85            let (idx, mask) = Self::get_idx_mask(syscalls[i]);
86            set.data[idx] |= mask;
87            i += 1;
88        }
89
90        set
91    }
92
93    /// Creates an empty set of syscalls.
94    pub const fn empty() -> Self {
95        Self {
96            data: [0; words::<usize>(Sysno::table_size())],
97        }
98    }
99
100    /// Creates a set containing all valid syscalls.
101    ///
102    /// Note: This returns a by-value copy of the bitset. Prefer borrowing
103    /// `SysnoSet::ALL` directly when you only need membership checks to avoid
104    /// copying the entire array.
105    pub const fn all() -> Self {
106        Self {
107            data: Self::ALL.data,
108        }
109    }
110
111    /// Returns true if the set contains the given syscall.
112    pub const fn contains(&self, sysno: Sysno) -> bool {
113        let (idx, mask) = Self::get_idx_mask(sysno);
114        self.data[idx] & mask != 0
115    }
116
117    /// Returns true if the set is empty. Although this is an O(1) operation
118    /// (because the total number of possible syscalls is always constant), it
119    /// must go through the whole bit set to count the number of bits. Thus,
120    /// this may have a large, constant overhead.
121    pub fn is_empty(&self) -> bool {
122        self.data.iter().all(|&x| x == 0)
123    }
124
125    /// Clears the set, removing all syscalls.
126    pub fn clear(&mut self) {
127        for word in &mut self.data {
128            *word = 0;
129        }
130    }
131
132    /// Returns the number of syscalls in the set. Although this is an O(1)
133    /// operation (because the total number of syscalls is always constant), it
134    /// must go through the whole bit set to count the number of bits. Thus,
135    /// this may have a large, constant overhead.
136    pub fn count(&self) -> usize {
137        self.data
138            .iter()
139            .fold(0, |acc, x| acc + x.count_ones() as usize)
140    }
141
142    /// Inserts the given syscall into the set. Returns true if the syscall was
143    /// not already in the set.
144    pub fn insert(&mut self, sysno: Sysno) -> bool {
145        // The returned value computation will be optimized away by the compiler
146        // if not needed.
147        let (idx, mask) = Self::get_idx_mask(sysno);
148        let old_value = self.data[idx] & mask;
149        self.data[idx] |= mask;
150        old_value == 0
151    }
152
153    /// Removes the given syscall from the set. Returns true if the syscall was
154    /// in the set.
155    pub fn remove(&mut self, sysno: Sysno) -> bool {
156        // The returned value computation will be optimized away by the compiler
157        // if not needed.
158        let (idx, mask) = Self::get_idx_mask(sysno);
159        let old_value = self.data[idx] & mask;
160        self.data[idx] &= !mask;
161        old_value != 0
162    }
163
164    /// Does a set union with this set and another.
165    #[must_use]
166    pub const fn union(mut self, other: &Self) -> Self {
167        let mut i = 0;
168        let n = self.data.len();
169        while i < n {
170            self.data[i] |= other.data[i];
171            i += 1;
172        }
173
174        self
175    }
176
177    /// Does a set intersection with this set and another.
178    #[must_use]
179    pub const fn intersection(mut self, other: &Self) -> Self {
180        let mut i = 0;
181        let n = self.data.len();
182        while i < n {
183            self.data[i] &= other.data[i];
184            i += 1;
185        }
186
187        self
188    }
189
190    /// Calculates the difference with this set and another. That is, the
191    /// resulting set only includes the syscalls that are in `self` but not in
192    /// `other`.
193    #[must_use]
194    pub const fn difference(mut self, other: &Self) -> Self {
195        let mut i = 0;
196        let n = self.data.len();
197        while i < n {
198            self.data[i] &= !other.data[i];
199            i += 1;
200        }
201
202        self
203    }
204
205    /// Calculates the symmetric difference with this set and another. That is,
206    /// the resulting set only includes the syscalls that are in `self` or in
207    /// `other`, but not in both.
208    #[must_use]
209    pub const fn symmetric_difference(mut self, other: &Self) -> Self {
210        let mut i = 0;
211        let n = self.data.len();
212        while i < n {
213            self.data[i] ^= other.data[i];
214            i += 1;
215        }
216
217        self
218    }
219
220    /// Returns an iterator that iterates over the syscalls contained in the set.
221    pub fn iter(&self) -> SysnoSetIter<'_> {
222        SysnoSetIter::new(self.data.iter())
223    }
224}
225
226impl fmt::Debug for SysnoSet {
227    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228        f.debug_set().entries(self.iter()).finish()
229    }
230}
231
232impl core::ops::BitOr for SysnoSet {
233    type Output = Self;
234
235    fn bitor(mut self, rhs: Self) -> Self::Output {
236        self |= rhs;
237        self
238    }
239}
240
241impl core::ops::BitOrAssign<&Self> for SysnoSet {
242    fn bitor_assign(&mut self, rhs: &Self) {
243        for (left, right) in self.data.iter_mut().zip(rhs.data.iter()) {
244            *left |= right;
245        }
246    }
247}
248
249impl core::ops::BitOrAssign for SysnoSet {
250    fn bitor_assign(&mut self, rhs: Self) {
251        *self |= &rhs;
252    }
253}
254
255impl core::ops::BitOrAssign<Sysno> for SysnoSet {
256    fn bitor_assign(&mut self, sysno: Sysno) {
257        self.insert(sysno);
258    }
259}
260
261impl FromIterator<Sysno> for SysnoSet {
262    fn from_iter<I: IntoIterator<Item = Sysno>>(iter: I) -> Self {
263        let mut set = SysnoSet::empty();
264        set.extend(iter);
265        set
266    }
267}
268
269impl Extend<Sysno> for SysnoSet {
270    fn extend<T: IntoIterator<Item = Sysno>>(&mut self, iter: T) {
271        for sysno in iter {
272            self.insert(sysno);
273        }
274    }
275}
276
277impl<'a> IntoIterator for &'a SysnoSet {
278    type Item = Sysno;
279    type IntoIter = SysnoSetIter<'a>;
280
281    fn into_iter(self) -> Self::IntoIter {
282        self.iter()
283    }
284}
285
/// Yields only the non-zero words of a bitset, skipping zero words.
///
/// `count` tracks how many words have been pulled from `iter` so far,
/// including skipped zero words. After a word is returned, its position in
/// the bitset is therefore `count - 1`.
struct NonZeroUsizeIter<'a> {
    /// Iterator over every word of the bitset.
    iter: core::slice::Iter<'a, usize>,
    /// Number of words consumed from `iter`.
    count: usize,
}
291
292impl<'a> NonZeroUsizeIter<'a> {
293    pub fn new(iter: core::slice::Iter<'a, usize>) -> Self {
294        Self { iter, count: 0 }
295    }
296}
297
298impl Iterator for NonZeroUsizeIter<'_> {
299    type Item = NonZeroUsize;
300
301    fn next(&mut self) -> Option<Self::Item> {
302        for item in &mut self.iter {
303            self.count += 1;
304
305            if let Some(item) = NonZeroUsize::new(*item) {
306                return Some(item);
307            }
308        }
309
310        None
311    }
312}
313
314/// An iterator over the syscalls contained in a [`SysnoSet`].
315pub struct SysnoSetIter<'a> {
316    // Our iterator over nonzero words in the bitset.
317    iter: NonZeroUsizeIter<'a>,
318
319    // The current word in the set we're operating on. This is only None if the
320    // iterator has been exhausted. The next bit that is set is found by
321    // counting the number of leading zeros. When found, we just mask it off.
322    current: Option<NonZeroUsize>,
323}
324
325impl<'a> SysnoSetIter<'a> {
326    fn new(iter: core::slice::Iter<'a, usize>) -> Self {
327        let mut iter = NonZeroUsizeIter::new(iter);
328        let current = iter.next();
329        Self { iter, current }
330    }
331}
332
333impl Iterator for SysnoSetIter<'_> {
334    type Item = Sysno;
335
336    fn next(&mut self) -> Option<Self::Item> {
337        // Construct a mask where all but the last bit is set. This is then
338        // shifted to remove the first bit we find.
339        const MASK: usize = !1usize;
340
341        if let Some(word) = self.current.take() {
342            let index = self.iter.count.wrapping_sub(1);
343
344            // Get the index of the next bit. For example:
345            //      0b0000000010000
346            //                ^
347            // Here, there are 4 trailing zeros, so 4 is the next set bit. Since
348            // we're only iterating over non-zero words, we are guaranteed to
349            // get a valid index.
350            let bit = word.trailing_zeros();
351
352            // Mask off that bit and store the resulting word for next time.
353            let next_word =
354                NonZeroUsize::new(word.get() & MASK.rotate_left(bit));
355
356            self.current = next_word.or_else(|| self.iter.next());
357
358            let offset = Sysno::first().id() as u32;
359            let sysno = index as u32 * usize::BITS + bit + offset;
360
361            // SAFETY: `index` is the position of a non-zero word from the
362            // bitset that represents this SysnoSet, and `bit` is the position
363            // of a set bit within that word. SysnoSet only ever sets bits for
364            // valid syscalls (constructed from `Sysno` values and closed under
365            // set ops), so the computed `sysno` always corresponds to a valid
366            // `Sysno` discriminant. The enum is `#[repr(i32)]`, so transmuting
367            // the integer value is sound.
368            #[cfg(debug_assertions)]
369            debug_assert!(Sysno::new(sysno as usize).is_some());
370
371            let s = unsafe { core::mem::transmute::<i32, Sysno>(sysno as i32) };
372            return Some(s);
373        }
374
375        None
376    }
377}
378
379#[cfg(feature = "serde")]
380use serde::{
381    de::{Deserialize, Deserializer, SeqAccess, Visitor},
382    ser::{Serialize, SerializeSeq, Serializer},
383};
384
#[cfg(feature = "serde")]
impl Serialize for SysnoSet {
    /// Serializes the set as a sequence of its syscalls, announcing the
    /// element count up front.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let len = self.count();
        let mut seq = serializer.serialize_seq(Some(len))?;
        self.iter()
            .try_for_each(|sysno| seq.serialize_element(&sysno))?;
        seq.end()
    }
}
398
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for SysnoSet {
    /// Deserializes a set from a sequence of syscalls; repeated entries are
    /// accepted and stored once.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that accumulates sequence elements into a `SysnoSet`.
        struct SysnoSetVisitor;

        impl<'de> Visitor<'de> for SysnoSetVisitor {
            type Value = SysnoSet;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a sequence")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut set = SysnoSet::empty();
                while let Some(sysno) = seq.next_element::<Sysno>()? {
                    set.insert(sysno);
                }
                Ok(set)
            }
        }

        deserializer.deserialize_seq(SysnoSetVisitor)
    }
}
431
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_words() {
        assert_eq!(words::<u64>(42), 1);
        assert_eq!(words::<u64>(0), 0);
        assert_eq!(words::<u32>(42), 2);
        assert_eq!(words::<()>(42), 0);
        // Exact multiples must not round up; one extra bit must.
        assert_eq!(words::<u64>(64), 1);
        assert_eq!(words::<u64>(65), 2);
    }

    #[test]
    fn test_bits_per() {
        assert_eq!(bits_per::<()>(), 0);
        assert_eq!(bits_per::<u8>(), 8);
        assert_eq!(bits_per::<u32>(), 32);
        assert_eq!(bits_per::<u64>(), 64);
    }

    #[test]
    fn test_default() {
        assert_eq!(SysnoSet::default(), SysnoSet::empty());
    }

    #[test]
    fn test_const_new() {
        static SYSCALLS: SysnoSet =
            SysnoSet::new(&[Sysno::openat, Sysno::read, Sysno::close]);

        assert!(SYSCALLS.contains(Sysno::openat));
        assert!(SYSCALLS.contains(Sysno::read));
        assert!(SYSCALLS.contains(Sysno::close));
        assert!(!SYSCALLS.contains(Sysno::write));
    }

    #[test]
    fn test_contains() {
        let set = SysnoSet::empty();
        assert!(!set.contains(Sysno::openat));
        assert!(!set.contains(Sysno::first()));
        assert!(!set.contains(Sysno::last()));

        let set = SysnoSet::all();
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::first()));
        assert!(set.contains(Sysno::last()));
    }

    #[test]
    fn test_is_empty() {
        let mut set = SysnoSet::empty();
        assert!(set.is_empty());
        assert!(set.insert(Sysno::openat));
        assert!(!set.is_empty());
        assert!(set.remove(Sysno::openat));
        assert!(set.is_empty());
        assert!(set.insert(Sysno::last()));
        assert!(!set.is_empty());
    }

    #[test]
    fn test_count() {
        let mut set = SysnoSet::empty();
        assert_eq!(set.count(), 0);
        assert!(set.insert(Sysno::openat));
        assert!(set.insert(Sysno::last()));
        assert_eq!(set.count(), 2);
    }

    #[test]
    fn test_insert() {
        let mut set = SysnoSet::empty();
        assert!(set.insert(Sysno::openat));
        assert!(set.insert(Sysno::read));
        assert!(set.insert(Sysno::close));
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::read));
        assert!(set.contains(Sysno::close));
        assert_eq!(set.count(), 3);

        // Re-inserting reports that the syscall was already present.
        assert!(!set.insert(Sysno::openat));
        assert_eq!(set.count(), 3);
    }

    #[test]
    fn test_remove() {
        let mut set = SysnoSet::all();
        assert!(set.remove(Sysno::openat));
        assert!(!set.contains(Sysno::openat));
        assert!(set.contains(Sysno::close));

        // Removing an absent syscall reports that nothing was removed.
        assert!(!set.remove(Sysno::openat));
        assert_eq!(set.count(), Sysno::count() - 1);
    }

    #[test]
    fn test_from_iter() {
        let set = SysnoSet::from_iter([
            Sysno::openat,
            Sysno::read,
            Sysno::close,
            Sysno::read,
        ]);
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::read));
        assert!(set.contains(Sysno::close));
        assert_eq!(set.count(), 3);
    }

    #[test]
    fn test_extend() {
        let mut set = SysnoSet::new(&[Sysno::read]);
        set.extend([Sysno::read, Sysno::write]);
        assert_eq!(set, SysnoSet::new(&[Sysno::read, Sysno::write]));
    }

    #[test]
    fn test_all() {
        let mut all = SysnoSet::all();
        assert_eq!(all.count(), Sysno::count());

        assert!(all.contains(Sysno::openat));
        assert!(all.contains(Sysno::first()));
        assert!(all.contains(Sysno::last()));

        all.clear();

        assert_eq!(all.count(), 0);
        assert!(all.is_empty());
    }

    #[test]
    fn test_union() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.union(&b),
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::openat,
                Sysno::close
            ])
        );
    }

    #[test]
    fn test_bitorassign() {
        let mut a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        a |= &b;
        a |= b;
        a |= Sysno::openat;

        assert_eq!(
            a,
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::close,
                Sysno::openat,
            ])
        );
    }

    #[test]
    fn test_bitor() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a | b,
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::openat,
                Sysno::close,
            ])
        );
    }

    #[test]
    fn test_intersection() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.intersection(&b),
            SysnoSet::new(&[Sysno::openat, Sysno::close])
        );
    }

    #[test]
    fn test_difference() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(a.difference(&b), SysnoSet::new(&[Sysno::read]));
    }

    #[test]
    fn test_symmetric_difference() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.symmetric_difference(&b),
            SysnoSet::new(&[Sysno::read, Sysno::write])
        );
    }

    #[test]
    fn test_iter() {
        let syscalls = &[Sysno::read, Sysno::openat, Sysno::close];
        let set = SysnoSet::new(syscalls);

        assert_eq!(set.iter().count(), 3);
        assert!(set.iter().all(|sysno| syscalls.contains(&sysno)));
    }

    #[test]
    fn test_iter_full() {
        assert_eq!(SysnoSet::all().iter().count(), Sysno::count());
    }

    #[test]
    fn test_iter_roundtrip() {
        let all = SysnoSet::all();
        assert_eq!(all.iter().collect::<SysnoSet>(), all);
    }

    #[test]
    fn test_iter_ascending() {
        let all = SysnoSet::all();
        assert_eq!(all.iter().next(), Some(Sysno::first()));

        let mut ids = all.iter().map(|sysno| sysno.id());
        let mut prev = ids.next().expect("the full set is non-empty");
        for id in ids {
            assert!(prev < id, "iteration must be strictly ascending");
            prev = id;
        }
    }

    #[test]
    fn test_into_iter() {
        let syscalls = &[Sysno::read, Sysno::openat, Sysno::close];
        let set = SysnoSet::new(syscalls);

        assert_eq!(set.into_iter().count(), 3);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_debug() {
        let syscalls = &[Sysno::openat, Sysno::read];
        let set = SysnoSet::new(syscalls);
        // The order of the debug output is not guaranteed, so we can't do an exact match
        let result = format!("{:?}", set);
        assert_eq!(result.len(), "{read, openat}".len());
        assert!(result.starts_with('{'));
        assert!(result.ends_with('}'));
        assert!(result.contains(", "));
        assert!(result.contains("read"));
        assert!(result.contains("openat"));
    }

    #[test]
    fn test_iter_empty() {
        assert_eq!(SysnoSet::empty().iter().next(), None);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde_roundtrip() {
        let syscalls = SysnoSet::new(&[
            Sysno::read,
            Sysno::write,
            Sysno::close,
            Sysno::openat,
        ]);

        let s = serde_json::to_string_pretty(&syscalls).unwrap();

        assert_eq!(serde_json::from_str::<SysnoSet>(&s).unwrap(), syscalls);
    }
}