use super::Sysno;
use core::fmt;
use core::num::NonZeroUsize;
/// Returns the number of bits occupied by a value of type `T`.
///
/// Zero-sized types report `0`; the byte-to-bit multiplication saturates at
/// `usize::MAX` instead of overflowing.
const fn bits_per<T>() -> usize {
    const BITS_PER_BYTE: usize = 8;
    let bytes = core::mem::size_of::<T>();
    bytes.saturating_mul(BITS_PER_BYTE)
}
/// Returns how many values of type `T` are needed to hold `bits` bits,
/// rounding up to a whole number of values.
///
/// A zero-sized `T` cannot hold any bits, so the result is `0` in that case.
const fn words<T>(bits: usize) -> usize {
    let width = bits_per::<T>();
    if width == 0 {
        0
    } else {
        let whole = bits / width;
        // Divide without computing `bits + width - 1`, which could overflow.
        if bits % width == 0 {
            whole
        } else {
            whole + 1
        }
    }
}
/// A set of syscalls stored as a fixed-size bitset.
///
/// Each syscall occupies one bit. The storage holds
/// `words::<usize>(Sysno::table_size())` machine words.
#[derive(PartialEq, Eq, Clone)]
pub struct SysnoSet {
    /// Backing bitset words; bit positions are assigned by `get_idx_mask`.
    pub(crate) data: [usize; words::<usize>(Sysno::table_size())],
}
impl Default for SysnoSet {
fn default() -> Self {
Self::empty()
}
}
impl SysnoSet {
    /// Precomputed set containing every syscall in [`Sysno::ALL`]; backs
    /// [`SysnoSet::all`].
    const ALL: &'static Self = &Self::new(Sysno::ALL);

    /// Number of bits held by one storage word.
    const WORD_WIDTH: usize = usize::BITS as usize;

    /// Maps `sysno` to the index of the word that holds its bit and a mask
    /// selecting that bit within the word.
    ///
    /// Bits are numbered relative to [`Sysno::first`], so the lowest syscall
    /// occupies bit 0 of word 0.
    #[inline]
    pub(crate) const fn get_idx_mask(sysno: Sysno) -> (usize, usize) {
        let bit = (sysno.id() as usize) - (Sysno::first().id() as usize);
        (bit / Self::WORD_WIDTH, 1 << (bit % Self::WORD_WIDTH))
    }

    /// Builds a set containing each syscall in `syscalls`.
    ///
    /// This is a `const fn`, so it can initialize a `static`. Duplicate
    /// entries are harmless.
    pub const fn new(syscalls: &[Sysno]) -> Self {
        let mut set = Self::empty();
        // Index loop: iterators are not usable inside `const fn`.
        let mut i = 0;
        while i < syscalls.len() {
            let (idx, mask) = Self::get_idx_mask(syscalls[i]);
            set.data[idx] |= mask;
            i += 1;
        }
        set
    }

    /// Returns a set with no syscalls in it.
    pub const fn empty() -> Self {
        Self {
            data: [0; words::<usize>(Sysno::table_size())],
        }
    }

    /// Returns a set containing every syscall in [`Sysno::ALL`].
    pub const fn all() -> Self {
        Self {
            data: Self::ALL.data,
        }
    }

    /// Returns `true` if `sysno` is a member of the set.
    pub const fn contains(&self, sysno: Sysno) -> bool {
        let (idx, mask) = Self::get_idx_mask(sysno);
        self.data[idx] & mask != 0
    }

    /// Returns `true` if the set contains no syscalls.
    pub fn is_empty(&self) -> bool {
        self.data.iter().all(|&word| word == 0)
    }

    /// Removes every syscall from the set.
    pub fn clear(&mut self) {
        self.data.fill(0);
    }

    /// Returns the number of syscalls in the set (the total number of set bits).
    pub fn count(&self) -> usize {
        self.data
            .iter()
            .map(|word| word.count_ones() as usize)
            .sum()
    }

    /// Adds `sysno` to the set.
    ///
    /// Returns `true` if it was not already present.
    pub fn insert(&mut self, sysno: Sysno) -> bool {
        let (idx, mask) = Self::get_idx_mask(sysno);
        let old_value = self.data[idx] & mask;
        self.data[idx] |= mask;
        old_value == 0
    }

    /// Removes `sysno` from the set.
    ///
    /// Returns `true` if it was present.
    pub fn remove(&mut self, sysno: Sysno) -> bool {
        let (idx, mask) = Self::get_idx_mask(sysno);
        let old_value = self.data[idx] & mask;
        self.data[idx] &= !mask;
        old_value != 0
    }

    /// Returns the syscalls that are in `self`, `other`, or both.
    #[must_use]
    pub const fn union(mut self, other: &Self) -> Self {
        // Index loop: iterators are not usable inside `const fn`.
        let mut i = 0;
        let n = self.data.len();
        while i < n {
            self.data[i] |= other.data[i];
            i += 1;
        }
        self
    }

    /// Returns the syscalls that are in both `self` and `other`.
    #[must_use]
    pub const fn intersection(mut self, other: &Self) -> Self {
        let mut i = 0;
        let n = self.data.len();
        while i < n {
            self.data[i] &= other.data[i];
            i += 1;
        }
        self
    }

    /// Returns the syscalls that are in `self` but not in `other`.
    #[must_use]
    pub const fn difference(mut self, other: &Self) -> Self {
        let mut i = 0;
        let n = self.data.len();
        while i < n {
            self.data[i] &= !other.data[i];
            i += 1;
        }
        self
    }

    /// Returns the syscalls that are in exactly one of `self` and `other`.
    #[must_use]
    pub const fn symmetric_difference(mut self, other: &Self) -> Self {
        let mut i = 0;
        let n = self.data.len();
        while i < n {
            self.data[i] ^= other.data[i];
            i += 1;
        }
        self
    }

    /// Returns an iterator over the syscalls in the set, in ascending order of
    /// bit position (and therefore of syscall number).
    pub fn iter(&self) -> SysnoSetIter {
        SysnoSetIter::new(self.data.iter())
    }
}
impl fmt::Debug for SysnoSet {
    /// Formats the set with `debug_set` (e.g. `{a, b}`), listing members in
    /// iteration order.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut members = f.debug_set();
        for sysno in self {
            members.entry(&sysno);
        }
        members.finish()
    }
}
impl core::ops::BitOr for SysnoSet {
    type Output = Self;

    /// Returns the union of both sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        self.union(&rhs)
    }
}
impl core::ops::BitOrAssign<&Self> for SysnoSet {
    /// Adds every syscall in `rhs` to `self` by OR-ing the bitset words together.
    fn bitor_assign(&mut self, rhs: &Self) {
        self.data
            .iter_mut()
            .zip(rhs.data.iter())
            .for_each(|(dst, &src)| *dst |= src);
    }
}
impl core::ops::BitOrAssign for SysnoSet {
    /// Adds every syscall in `rhs` to `self`, consuming `rhs`.
    fn bitor_assign(&mut self, rhs: Self) {
        let other: &Self = &rhs;
        *self |= other;
    }
}
impl core::ops::BitOrAssign<Sysno> for SysnoSet {
    /// Adds a single syscall to the set; it is a no-op if already present.
    fn bitor_assign(&mut self, sysno: Sysno) {
        let _newly_added = self.insert(sysno);
    }
}
impl FromIterator<Sysno> for SysnoSet {
    /// Builds a set containing every syscall yielded by `iter`.
    fn from_iter<I: IntoIterator<Item = Sysno>>(iter: I) -> Self {
        iter.into_iter().fold(SysnoSet::empty(), |mut acc, sysno| {
            acc.insert(sysno);
            acc
        })
    }
}
impl Extend<Sysno> for SysnoSet {
    /// Inserts every syscall yielded by `iter` into the set.
    fn extend<T: IntoIterator<Item = Sysno>>(&mut self, iter: T) {
        iter.into_iter().for_each(|sysno| {
            self.insert(sysno);
        });
    }
}
impl<'a> IntoIterator for &'a SysnoSet {
type Item = Sysno;
type IntoIter = SysnoSetIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator adapter over a slice of bitset words that yields only the
/// non-zero words.
struct NonZeroUsizeIter<'a> {
    /// Number of words pulled from `iter` so far, skipped zero words included.
    count: usize,
    /// The underlying words, consumed front to back.
    iter: core::slice::Iter<'a, usize>,
}
impl<'a> NonZeroUsizeIter<'a> {
    /// Wraps `iter` with no words consumed yet.
    pub fn new(iter: core::slice::Iter<'a, usize>) -> Self {
        Self { count: 0, iter }
    }
}
impl<'a> Iterator for NonZeroUsizeIter<'a> {
    type Item = NonZeroUsize;

    /// Returns the next non-zero word.
    ///
    /// Every word consumed along the way, including skipped zero words, is
    /// counted in `count`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let word = *self.iter.next()?;
            self.count += 1;
            if let Some(nonzero) = NonZeroUsize::new(word) {
                return Some(nonzero);
            }
        }
    }
}
/// Iterator over the syscalls in a [`SysnoSet`], in ascending bit order.
///
/// Created by [`SysnoSet::iter`] or by iterating over `&SysnoSet`.
pub struct SysnoSetIter<'a> {
    /// Bits of the word being drained that have not been yielded yet, or
    /// `None` once every word has been exhausted.
    current: Option<NonZeroUsize>,
    /// Source of the remaining non-zero words.
    iter: NonZeroUsizeIter<'a>,
}
impl<'a> SysnoSetIter<'a> {
fn new(iter: core::slice::Iter<'a, usize>) -> Self {
let mut iter = NonZeroUsizeIter::new(iter);
let current = iter.next();
Self { iter, current }
}
}
impl<'a> Iterator for SysnoSetIter<'a> {
    type Item = Sysno;

    /// Yields the syscall for the lowest set bit of the current word.
    ///
    /// That bit is then cleared, and the iterator advances to the next
    /// non-zero word once the current one is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let word = self.current.take()?;

        // `count` is how many words the inner iterator has consumed, so the
        // word being drained sits at index `count - 1`. This must be read
        // before the inner iterator is advanced below.
        let word_index = self.iter.count.wrapping_sub(1);
        let bit = word.trailing_zeros();

        // Clear the bit being yielded; only fetch a new word once this one
        // has no bits left.
        let remaining = word.get() & !(1usize << bit);
        self.current = match NonZeroUsize::new(remaining) {
            Some(rest) => Some(rest),
            None => self.iter.next(),
        };

        // Bit positions are relative to the first syscall number.
        let first = Sysno::first().id() as u32;
        Some(Sysno::from(word_index as u32 * usize::BITS + bit + first))
    }
}
#[cfg(feature = "serde")]
use serde::{
de::{Deserialize, Deserializer, SeqAccess, Visitor},
ser::{Serialize, SerializeSeq, Serializer},
};
#[cfg(feature = "serde")]
impl Serialize for SysnoSet {
    /// Serializes the set as a sequence of syscalls in iteration order.
    /// The sequence length is announced up front.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let len = self.count();
        let mut seq = serializer.serialize_seq(Some(len))?;
        self.iter()
            .try_for_each(|sysno| seq.serialize_element(&sysno))?;
        seq.end()
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for SysnoSet {
    /// Deserializes a set from a sequence of syscalls.
    ///
    /// Repeated entries collapse into a single member.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that inserts each sequence element into a `SysnoSet`.
        struct SysnoSetVisitor;

        impl<'de> Visitor<'de> for SysnoSetVisitor {
            type Value = SysnoSet;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a sequence")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let mut set = SysnoSet::empty();
                while let Some(sysno) = seq.next_element::<Sysno>()? {
                    set.insert(sysno);
                }
                Ok(set)
            }
        }

        deserializer.deserialize_seq(SysnoSetVisitor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_words() {
        assert_eq!(words::<u64>(42), 1);
        assert_eq!(words::<u64>(0), 0);
        assert_eq!(words::<u32>(42), 2);
        assert_eq!(words::<()>(42), 0);
    }

    #[test]
    fn test_bits_per() {
        assert_eq!(bits_per::<()>(), 0);
        assert_eq!(bits_per::<u8>(), 8);
        assert_eq!(bits_per::<u32>(), 32);
        assert_eq!(bits_per::<u64>(), 64);
    }

    #[test]
    fn test_default() {
        assert_eq!(SysnoSet::default(), SysnoSet::empty());
    }

    #[test]
    fn test_const_new() {
        static SYSCALLS: SysnoSet =
            SysnoSet::new(&[Sysno::openat, Sysno::read, Sysno::close]);

        assert!(SYSCALLS.contains(Sysno::openat));
        assert!(SYSCALLS.contains(Sysno::read));
        assert!(SYSCALLS.contains(Sysno::close));
        assert!(!SYSCALLS.contains(Sysno::write));
    }

    #[test]
    fn test_contains() {
        let set = SysnoSet::empty();
        assert!(!set.contains(Sysno::openat));
        assert!(!set.contains(Sysno::first()));
        assert!(!set.contains(Sysno::last()));

        let set = SysnoSet::all();
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::first()));
        assert!(set.contains(Sysno::last()));
    }

    #[test]
    fn test_is_empty() {
        let mut set = SysnoSet::empty();
        assert!(set.is_empty());
        assert!(set.insert(Sysno::openat));
        assert!(!set.is_empty());
        assert!(set.remove(Sysno::openat));
        assert!(set.is_empty());
        assert!(set.insert(Sysno::last()));
        assert!(!set.is_empty());
    }

    #[test]
    fn test_count() {
        let mut set = SysnoSet::empty();
        assert_eq!(set.count(), 0);
        assert!(set.insert(Sysno::openat));
        assert!(set.insert(Sysno::last()));
        assert_eq!(set.count(), 2);
    }

    #[test]
    fn test_insert() {
        let mut set = SysnoSet::empty();
        assert!(set.insert(Sysno::openat));
        assert!(set.insert(Sysno::read));
        assert!(set.insert(Sysno::close));
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::read));
        assert!(set.contains(Sysno::close));
        assert_eq!(set.count(), 3);
    }

    #[test]
    fn test_remove() {
        let mut set = SysnoSet::all();
        assert!(set.remove(Sysno::openat));
        assert!(!set.contains(Sysno::openat));
        assert!(set.contains(Sysno::close));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_from_iter() {
        let set =
            SysnoSet::from_iter(vec![Sysno::openat, Sysno::read, Sysno::close]);
        assert!(set.contains(Sysno::openat));
        assert!(set.contains(Sysno::read));
        assert!(set.contains(Sysno::close));
        assert_eq!(set.count(), 3);
    }

    #[test]
    fn test_all() {
        let mut all = SysnoSet::all();
        assert_eq!(all.count(), Sysno::count());

        // Membership checks must be asserted; a bare `contains` call checks nothing.
        assert!(all.contains(Sysno::openat));
        assert!(all.contains(Sysno::first()));
        assert!(all.contains(Sysno::last()));

        all.clear();
        assert_eq!(all.count(), 0);
        assert!(all.is_empty());
    }

    #[test]
    fn test_union() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.union(&b),
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::openat,
                Sysno::close
            ])
        );
    }

    #[test]
    fn test_bitorassign() {
        let mut a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        a |= &b;
        a |= b;
        a |= Sysno::openat;

        assert_eq!(
            a,
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::close,
                Sysno::openat,
            ])
        );
    }

    #[test]
    fn test_bitor() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a | b,
            SysnoSet::new(&[
                Sysno::read,
                Sysno::write,
                Sysno::openat,
                Sysno::close,
            ])
        );
    }

    #[test]
    fn test_intersection() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.intersection(&b),
            SysnoSet::new(&[Sysno::openat, Sysno::close])
        );
    }

    #[test]
    fn test_difference() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(a.difference(&b), SysnoSet::new(&[Sysno::read]));
    }

    #[test]
    fn test_symmetric_difference() {
        let a = SysnoSet::new(&[Sysno::read, Sysno::openat, Sysno::close]);
        let b = SysnoSet::new(&[Sysno::write, Sysno::openat, Sysno::close]);
        assert_eq!(
            a.symmetric_difference(&b),
            SysnoSet::new(&[Sysno::read, Sysno::write])
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_iter() {
        let syscalls = &[Sysno::read, Sysno::openat, Sysno::close];
        let set = SysnoSet::new(syscalls);
        let collected = set.iter().collect::<Vec<_>>();

        // Check the actual members, not just how many were yielded.
        assert_eq!(collected.len(), 3);
        for sysno in syscalls {
            assert!(collected.contains(sysno));
        }
    }

    #[test]
    fn test_iter_roundtrip() {
        // Iterating and re-collecting must reproduce exactly the same bits,
        // including at both ends of the syscall table.
        let set = SysnoSet::new(&[
            Sysno::first(),
            Sysno::read,
            Sysno::openat,
            Sysno::close,
            Sysno::last(),
        ]);
        assert_eq!(set.iter().collect::<SysnoSet>(), set);
        assert_eq!(
            SysnoSet::all().iter().collect::<SysnoSet>(),
            SysnoSet::all()
        );
    }

    #[test]
    fn test_iter_ascending() {
        // The iterator walks words front to back and bits low to high, so
        // syscall ids must come out strictly increasing.
        let mut prev = None;
        for sysno in SysnoSet::all().iter() {
            let id = sysno.id();
            if let Some(p) = prev {
                assert!(p < id);
            }
            prev = Some(id);
        }
    }

    #[test]
    fn test_iter_full() {
        assert_eq!(SysnoSet::all().iter().count(), Sysno::count());
    }

    #[test]
    fn test_into_iter() {
        let syscalls = &[Sysno::read, Sysno::openat, Sysno::close];
        let set = SysnoSet::new(syscalls);
        assert_eq!(set.into_iter().count(), 3);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_debug() {
        let syscalls = &[Sysno::openat, Sysno::read];
        let set = SysnoSet::new(syscalls);
        // The order of the debug output is not guaranteed, so we can't do an exact
        // match.
        let result = format!("{:?}", set);
        assert_eq!(result.len(), "{read, openat}".len());
        assert!(result.starts_with('{'));
        assert!(result.ends_with('}'));
        assert!(result.contains(", "));
        assert!(result.contains("read"));
        assert!(result.contains("openat"));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_iter_empty() {
        assert_eq!(SysnoSet::empty().iter().collect::<Vec<_>>(), &[]);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serde_roundtrip() {
        let syscalls = SysnoSet::new(&[
            Sysno::read,
            Sysno::write,
            Sysno::close,
            Sysno::openat,
        ]);

        let s = serde_json::to_string_pretty(&syscalls).unwrap();

        assert_eq!(serde_json::from_str::<SysnoSet>(&s).unwrap(), syscalls);
    }
}