use super::*;
use static_assertions::const_assert_eq;
/// Pairs a borrowed slice of enum values with a `BigEnumSet` over the same element type `T`.
///
/// NOTE(review): judging by the name, this exists so that (presumably macro-generated) code can
/// force `unified` and `set` to be inferred with one and the same `T` — confirm against the
/// macro that constructs it.
pub struct EnumSetSameTypeHack<'a, T: BigEnumSetType + 'static> {
/// Borrowed enum values whose element type is tied to `set`'s.
pub unified: &'a [T],
/// Set whose element type must match `unified`.
pub set: BigEnumSet<T>,
}
#[cfg(feature = "serde")]
pub use serde2 as serde;
/// Internal trait describing how an enum is stored inside a `BigEnumSet`: the word-array
/// representation, its notable constant values, and the enum <-> `u16` index conversions.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this file. Presumably implementors must
/// guarantee that `REPR_ALL` has exactly the bits of valid variants set and that
/// `enum_into_u16` / `enum_from_u16` round-trip for every valid variant — confirm against the
/// code that implements this trait.
pub unsafe trait BigEnumSetTypePrivate {
/// Fixed-size storage for the set's bits, accessible as a `[usize]` word slice through
/// `AsRef`/`AsMut`.
type Repr: AsMut<[usize]> + AsRef<[usize]> + Copy + Clone + Hash + PartialEq + Eq + PartialOrd + Ord;
/// Presumably the number of `usize` words in `Repr` — TODO confirm.
const REPR_LEN: usize;
/// Presumably the representation with no bits set (the empty set) — TODO confirm.
const REPR_NONE: Self::Repr;
/// Representation with every valid bit set; bits outside it are treated as "unknown" by the
/// serde byte decoder.
const REPR_ALL: Self::Repr;
/// Converts an enum value into its bit index.
fn enum_into_u16(self) -> u16;
/// Converts a bit index back into an enum value.
///
/// # Safety
///
/// NOTE(review): presumably `val` must be an index produced by `enum_into_u16` for a valid
/// variant; passing anything else is undefined behavior — confirm with the implementations.
unsafe fn enum_from_u16(val: u16) -> Self;
/// Serializes `set` in the representation chosen for this enum type.
#[cfg(feature = "serde")]
fn serialize<S: serde::Serializer>(set: &BigEnumSet<Self>, ser: S) -> Result<S::Ok, S::Error>
where Self: BigEnumSetType;
/// Deserializes a set in the representation chosen for this enum type.
#[cfg(feature = "serde")]
fn deserialize<'de, D: serde::Deserializer<'de>>(de: D) -> Result<BigEnumSet<Self>, D::Error>
where Self: BigEnumSetType;
}
/// Number of bits held by one storage word (`usize`).
pub const WORD_BITS: u16 = (mem::size_of::<usize>() * 8) as u16;

/// `WORD_BITS - 1`; presumably used to extract a bit's offset within its word.
pub const WORD_MASK: u16 = WORD_BITS - 1;

/// `log2(WORD_BITS)`; presumably used to turn a bit index into a word index.
pub const WORD_SHIFT: u8 = WORD_BITS.trailing_zeros() as u8;

// The mask/shift constants above are only meaningful when the word width is a power of two.
const_assert_eq!(WORD_BITS.count_ones(), 1);
/// Replaces `this` with `this ∪ other` by handing `apply_op` a bitwise-OR combiner.
pub(crate) fn union<T>(this: &mut BigEnumSet<T>, other: &BigEnumSet<T>)
where
    T: BigEnumSetType,
{
    this.apply_op(other, |mine, theirs| mine | theirs);
}
/// Replaces `this` with `this ∩ other` by handing `apply_op` a bitwise-AND combiner.
pub(crate) fn intersection<T>(this: &mut BigEnumSet<T>, other: &BigEnumSet<T>)
where
    T: BigEnumSetType,
{
    this.apply_op(other, |mine, theirs| mine & theirs);
}
/// Replaces `this` with `this − other`: bits set in `other` are masked out of `this`.
pub(crate) fn difference<T>(this: &mut BigEnumSet<T>, other: &BigEnumSet<T>)
where
    T: BigEnumSetType,
{
    this.apply_op(other, |mine, theirs| mine & !theirs);
}
/// Replaces `this` with the symmetric difference `this △ other` (bitwise XOR).
pub(crate) fn symmetrical_difference<T>(this: &mut BigEnumSet<T>, other: &BigEnumSet<T>)
where
    T: BigEnumSetType,
{
    this.apply_op(other, |mine, theirs| mine ^ theirs);
}
/// Replaces `this` with its complement.
///
/// Inverting the raw bits alone would also set bits that correspond to no variant, so the
/// result is masked with `BigEnumSet::all()`.
pub(crate) fn complement<T>(this: &mut BigEnumSet<T>)
where
    T: BigEnumSetType,
{
    let everything = BigEnumSet::<T>::all();
    this.apply_op(&everything, |mine, valid| !mine & valid);
}
/// Implements `set | value` in place by inserting `value` into `this`.
pub fn union_enum<T>(this: &mut BigEnumSet<T>, value: T)
where
    T: BigEnumSetType,
{
    this.insert(value);
}
/// Implements `set & value` in place: `this` is cleared, and `value` is re-inserted only when it
/// was contained beforehand.
pub fn intersection_enum<T>(this: &mut BigEnumSet<T>, value: T)
where
    T: BigEnumSetType,
{
    if this.contains(value) {
        this.clear();
        this.insert(value);
    } else {
        this.clear();
    }
}
/// Implements `set - value` in place by removing `value` from `this`.
pub(crate) fn difference_enum<T>(this: &mut BigEnumSet<T>, value: T)
where
    T: BigEnumSetType,
{
    this.remove(value);
}
/// Replaces `this` with `{value} − this`: `this` is cleared, and `value` is inserted only when
/// it was absent beforehand.
pub fn difference_enum_reverse<T>(this: &mut BigEnumSet<T>, value: T)
where
    T: BigEnumSetType,
{
    let absent = !this.contains(value);
    this.clear();
    if absent {
        this.insert(value);
    }
}
/// Implements `set ^ value` in place: toggles the membership of `value` in `this`.
pub fn symmetrical_difference_enum<T>(this: &mut BigEnumSet<T>, value: T)
where
    T: BigEnumSetType,
{
    if this.contains(value) {
        this.remove(value);
        return;
    }
    this.insert(value);
}
/// Serde (de)serialization helpers for `BigEnumSet`, compiled only with the `serde` feature.
///
/// Two encodings are provided:
/// * a sequence of the contained enum values (`serialize_as_list` / `deserialize_from_list`);
/// * a fixed-length tuple of `u8` taken little-endian from the set's `usize` storage words
///   (`serialize_as_bytes` / `deserialize_from_bytes`).
#[cfg(feature = "serde")]
pub mod serde_impl {
use super::serde::{Deserialize, Deserializer};
use super::serde::{Serialize, Serializer};
use crate::{BigEnumSet, BigEnumSetType};
use core::fmt;
use core::mem;
/// Serializes `set` as a sequence of its elements.
///
/// The sequence length is announced up front as `set.len()`, and elements are emitted in the
/// order yielded by `set.iter()`.
///
/// # Errors
///
/// Propagates any error reported by the serializer.
pub fn serialize_as_list<S, T>(set: &BigEnumSet<T>, ser: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: Serialize + BigEnumSetType,
{
use super::serde::ser::SerializeSeq;
let mut seq = ser.serialize_seq(Some(set.len()))?;
for v in set.iter() {
seq.serialize_element(&v)?;
}
seq.end()
}
/// Deserializes a set from a sequence of `T` values.
///
/// Starting from `BigEnumSet::new()`, every element read is merged into the accumulator with
/// `|=`.
///
/// # Errors
///
/// Propagates deserializer errors, including failures to deserialize an individual `T`.
/// Mismatched input is reported as expecting "a list of <type name of T>".
pub fn deserialize_from_list<'de, D, T>(deser: D) -> Result<BigEnumSet<T>, D::Error>
where
D: Deserializer<'de>,
T: BigEnumSetType + Deserialize<'de>,
{
use super::serde::de::{SeqAccess, Visitor};
use core::any::type_name;
use core::marker::PhantomData;
// Stateless visitor; `PhantomData` only carries the element type.
struct SetVisitor<T>(PhantomData<T>);
impl<'de, T> Visitor<'de> for SetVisitor<T>
where T: BigEnumSetType + Deserialize<'de> {
type Value = BigEnumSet<T>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "a list of {}", type_name::<T>())
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> where A: SeqAccess<'de> {
let mut accum = BigEnumSet::<T>::new();
// Read elements until the sequence is exhausted, OR-ing each one into the set.
while let Some(val) = seq.next_element::<T>()? {
accum |= val;
}
Ok(accum)
}
}
deser.deserialize_seq(SetVisitor(PhantomData))
}
/// Size in bytes of one storage word (`usize`).
const WORD_SIZE: usize = mem::size_of::<usize>();
/// Serializes `set` as a tuple of exactly `n_bytes` `u8` elements.
///
/// Every storage word except the last is written in full, little-endian. The last word is
/// truncated to the bytes still available, and any remainder is padded with zero bytes.
///
/// Preconditions (checked only by `debug_assert!`): the storage has at least one word, and
/// `n_bytes` is large enough to hold every word but the last in full. If the second condition
/// is violated in a build without overflow checks, `bytes_left` wraps around.
///
/// # Panics
///
/// Panics via `unwrap` if the storage word slice is empty.
///
/// # Errors
///
/// Propagates any error reported by the serializer.
pub fn serialize_as_bytes<S, T>(set: &BigEnumSet<T>, ser: S, n_bytes: usize) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: BigEnumSetType,
{
use super::serde::ser::SerializeTuple;
let words = set.__repr.as_ref();
debug_assert!(!words.is_empty());
debug_assert!((words.len() - 1) * WORD_SIZE < n_bytes);
let mut bytes_left = n_bytes;
let mut seq = ser.serialize_tuple(n_bytes)?;
// The last word is handled separately because it may be cut short.
let (last, rest) = words.split_last().unwrap();
for word in rest.iter() {
for b in word.to_le_bytes().iter() {
seq.serialize_element(b)?;
}
}
bytes_left -= rest.len() * WORD_SIZE;
// Only the low-order bytes of the last word that still fit into `n_bytes` are written.
let last_bytes = bytes_left.min(WORD_SIZE);
for b in last.to_le_bytes()[0 .. last_bytes].iter() {
seq.serialize_element(b)?;
}
bytes_left -= last_bytes;
// Zero-pad when `n_bytes` exceeds the storage size.
for _i in 0 .. bytes_left {
seq.serialize_element(&0u8)?;
}
seq.end()
}
/// Deserializes a set from a tuple of `n_bytes` `u8` elements, the inverse of
/// `serialize_as_bytes`.
///
/// Bytes fill the storage words in order, little-endian. Bytes beyond the storage carry no
/// representable bits:
/// * with `check_unknown`, any non-zero such byte is an error;
/// * otherwise such bytes are ignored.
///
/// # Errors
///
/// * `invalid_length` if the number of bytes read differs from `n_bytes`.
/// * "BigEnumSet contains unknown bits" if `check_unknown` is set and a bit outside
///   `T::REPR_ALL` (or a non-zero trailing byte) is present.
///
/// Without `check_unknown`, bits outside `T::REPR_ALL` are cleared instead of rejected.
pub fn deserialize_from_bytes<'de, D, T>(deser: D, n_bytes: usize, check_unknown: bool) -> Result<BigEnumSet<T>, D::Error>
where
D: Deserializer<'de>,
T: BigEnumSetType,
{
use super::serde::de::{Error, SeqAccess, Visitor};
use core::marker::PhantomData;
// Visitor state: the expected byte count and the unknown-bit policy.
struct SetVisitor<T> {
n_bytes: usize,
check_unknown: bool,
pd: PhantomData<T>,
}
impl<T> SetVisitor<T>
where T: BigEnumSetType {
// Final checks shared by both exit paths of `visit_seq`: the byte count must equal
// `n_bytes`, then bits outside `T::REPR_ALL` are either rejected or masked off.
fn validate<E: Error>(&self, bytes_read: usize, mut set: BigEnumSet<T>) -> Result<BigEnumSet<T>, E> {
if bytes_read != self.n_bytes {
return Err(Error::invalid_length(bytes_read, self));
}
if self.check_unknown {
let has_invalid = set.__repr.as_ref().iter()
.zip(T::REPR_ALL.as_ref().iter())
.any(|(&w1, &w2)| w1 & !w2 != 0);
if has_invalid {
return Err(Error::custom("BigEnumSet contains unknown bits"));
}
} else {
// Lenient mode: silently drop bits that do not correspond to a valid variant.
set.__repr.as_mut().iter_mut()
.zip(T::REPR_ALL.as_ref().iter())
.for_each(|(w1, w2)| *w1 &= *w2);
}
Ok(set)
}
}
impl<'de, T> Visitor<'de> for SetVisitor<T>
where T: BigEnumSetType {
type Value = BigEnumSet<T>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "a byte array of length {}", self.n_bytes)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where A: SeqAccess<'de> {
// Start from `BigEnumSet::default()` and overwrite its storage words in place.
let mut set = BigEnumSet::<T>::default();
let words = set.__repr.as_mut();
let mut bytes_read = 0;
for word in words.iter_mut() {
// Assemble one word from up to WORD_SIZE bytes, little-endian; unread bytes stay 0.
let mut word_arr = [0u8; WORD_SIZE];
for b in word_arr.iter_mut() {
*b = match seq.next_element::<u8>()? {
Some(b) => b,
None => {
// Input ended before the storage was full: keep the partially read word
// and let `validate` decide whether `bytes_read` matches `n_bytes`.
*word = usize::from_le_bytes(word_arr);
return self.validate(bytes_read, set);
}
};
bytes_read += 1;
}
*word = usize::from_le_bytes(word_arr);
}
// Bytes past the storage cannot be represented; under `check_unknown` they must be zero.
while let Some(b) = seq.next_element::<u8>()? {
if self.check_unknown && b != 0 {
return Err(Error::custom("BigEnumSet contains unknown bits"));
}
bytes_read += 1;
}
self.validate(bytes_read, set)
}
}
deser.deserialize_tuple(n_bytes, SetVisitor { n_bytes, check_unknown, pd: PhantomData })
}
}