use std::{fmt, mem, ops};
use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::Direction;
/// Index of a node in a port graph.
///
/// Stored as a [`BitField`] over `U` (defaulting to `u32`); constructors leave
/// the flag bit unset.
#[repr(transparent)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "pyo3", derive(IntoPyObject))]
pub struct NodeIndex<U = u32>(BitField<U>);
/// Index of a port in a port graph.
///
/// Stored as a [`BitField`] over `U` (defaulting to `u32`); constructors leave
/// the flag bit unset.
#[repr(transparent)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "pyo3", derive(IntoPyObject))]
pub struct PortIndex<U = u32>(BitField<U>);
/// Implements the shared API of the plain index newtypes (`NodeIndex`,
/// `PortIndex`): construction, `usize` conversions, `Debug`, a reinterpreting
/// `AsMut` from `BitField`, and optional serde support.
macro_rules! index_impls {
    ($name:ident) => {
        impl<U: IndexBase> $name<U> {
            /// Largest index value representable with the base type `U`.
            pub fn max_index() -> usize {
                BitField::<U>::max_index()
            }

            /// Creates a new index.
            ///
            /// # Panics
            ///
            /// Panics if `index` is larger than [`Self::max_index`].
            #[inline(always)]
            pub fn new(index: usize) -> Self {
                Self(BitField::new(index, false))
            }

            /// Returns the index as a `usize`.
            #[inline(always)]
            pub fn index(&self) -> usize {
                self.0.index_unchecked()
            }
        }

        impl<U: IndexBase> TryFrom<usize> for $name<U> {
            type Error = IndexError;

            /// Fallible counterpart of `new`: returns an [`IndexError`] instead
            /// of panicking when `index` is too large.
            #[inline(always)]
            fn try_from(index: usize) -> Result<Self, Self::Error> {
                BitField::try_from(index).map(Self)
            }
        }

        impl<U: IndexBase> From<$name<U>> for usize {
            #[inline(always)]
            fn from(index: $name<U>) -> Self {
                Self::from(index.0)
            }
        }

        impl<U: IndexBase> std::fmt::Debug for $name<U> {
            /// Formats as `TypeName(index)`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), "({})"), self.index())
            }
        }

        impl<U: IndexBase> AsMut<$name<U>> for BitField<U> {
            fn as_mut(&mut self) -> &mut $name<U> {
                // SAFETY: every type this macro is invoked for is a
                // `#[repr(transparent)]` newtype over `BitField<U>`, so both
                // types have identical layout and alignment and the pointer
                // cast yields a valid, uniquely-borrowed reference.
                unsafe { &mut *(self as *mut BitField<U> as *mut $name<U>) }
            }
        }

        /// Serialized as the plain `usize` index.
        #[cfg(feature = "serde")]
        impl<U: serde::Serialize + IndexBase> serde::Serialize for $name<U> {
            #[inline]
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                self.index().serialize(serializer)
            }
        }

        /// Deserialized from a plain `usize` index.
        #[cfg(feature = "serde")]
        impl<'de, U: IndexBase + serde::Deserialize<'de>> serde::Deserialize<'de> for $name<U> {
            #[inline]
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                let index: usize = serde::Deserialize::deserialize(deserializer)?;
                // Report out-of-range input as a deserialization error rather
                // than letting `new` panic on data we do not control.
                Self::try_from(index).map_err(<D::Error as serde::de::Error>::custom)
            }
        }
    };
}
// Constructors, `usize` conversions, `Debug`, `AsMut` from `BitField`, and
// (with the `serde` feature) serialization for the two plain index types.
index_impls!(NodeIndex);
index_impls!(PortIndex);
/// The default port index is `0`.
impl<U: IndexBase> Default for PortIndex<U> {
    #[inline]
    fn default() -> Self {
        Self::new(0)
    }
}
/// An optional [`NodeIndex`] with the same size as `NodeIndex` itself.
///
/// `None` is encoded as the bit field's reserved "none" value.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct MaybeNodeIndex<U = u32>(BitField<U>);
/// An optional [`PortIndex`] with the same size as `PortIndex` itself.
///
/// `None` is encoded as the bit field's reserved "none" value.
#[repr(transparent)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct MaybePortIndex<U = u32>(BitField<U>);
/// Implements an `Option`-like API for a space-efficient optional index
/// (`MaybeNodeIndex`, `MaybePortIndex`) wrapping the given plain index type.
macro_rules! maybe_index_impls {
    ($maybe_index:ident, $index:ident) => {
        impl<U: IndexBase> $maybe_index<U> {
            /// Wraps an optional index. `None` is stored as the bit field's
            /// reserved "none" value, so no extra space is used.
            pub fn new(value: Option<$index<U>>) -> Self {
                match value {
                    Some(value) => {
                        debug_assert!(!value.0.is_none(), "invalid index");
                        Self(value.0)
                    }
                    None => Self(BitField::new_none()),
                }
            }

            /// Converts into a standard `Option`.
            #[inline(always)]
            pub fn to_option(self) -> Option<$index<U>> {
                if self.0.is_none() {
                    None
                } else {
                    Some($index(self.0))
                }
            }

            /// Returns a mutable reference to the contained index, if any.
            #[inline(always)]
            pub fn as_mut(&mut self) -> Option<&mut $index<U>> {
                if self.is_none() {
                    None
                } else {
                    Some(self.0.as_mut())
                }
            }

            /// Returns `true` if an index is present.
            #[inline(always)]
            pub fn is_some(self) -> bool {
                !self.is_none()
            }

            /// Returns `true` if no index is present.
            #[inline(always)]
            pub fn is_none(self) -> bool {
                self.0.is_none()
            }

            /// Returns the contained index.
            ///
            /// # Panics
            ///
            /// Panics with `msg` if no index is present. Like `Option::expect`,
            /// the panic is reported at the caller's location.
            #[inline(always)]
            #[track_caller]
            pub fn expect(self, msg: &str) -> $index<U> {
                if self.is_none() {
                    panic!("{msg}");
                }
                $index(self.0)
            }

            /// Returns the contained index.
            ///
            /// # Panics
            ///
            /// Panics if no index is present; the panic is reported at the
            /// caller's location.
            #[inline(always)]
            #[track_caller]
            pub fn unwrap(self) -> $index<U> {
                self.expect("unwrap called on None")
            }

            /// Takes the value out, leaving `None` in its place.
            #[inline(always)]
            pub fn take(&mut self) -> Self {
                mem::take(self)
            }

            /// Stores `value`, returning the previous contents.
            #[inline(always)]
            pub fn replace(&mut self, value: $index<U>) -> Self {
                mem::replace(self, value.into())
            }
        }

        impl<U: IndexBase> fmt::Debug for $maybe_index<U> {
            /// Formats exactly like the equivalent `Option`.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "{:?}", self.to_option())
            }
        }

        impl<U: IndexBase> From<$index<U>> for $maybe_index<U> {
            #[inline(always)]
            fn from(index: $index<U>) -> Self {
                Self(index.0)
            }
        }

        impl<U: IndexBase> From<$maybe_index<U>> for Option<$index<U>> {
            #[inline(always)]
            fn from(maybe_index: $maybe_index<U>) -> Self {
                maybe_index.to_option()
            }
        }

        impl<U: IndexBase> From<Option<$index<U>>> for $maybe_index<U> {
            #[inline(always)]
            fn from(index: Option<$index<U>>) -> Self {
                Self::new(index)
            }
        }

        /// Defaults to `None`.
        impl<U: IndexBase> Default for $maybe_index<U> {
            #[inline(always)]
            fn default() -> Self {
                Self(BitField::new_none())
            }
        }

        /// Yields the contained index, if any.
        impl<U: IndexBase> IntoIterator for $maybe_index<U> {
            type Item = $index<U>;
            type IntoIter = std::option::IntoIter<Self::Item>;

            fn into_iter(self) -> Self::IntoIter {
                self.to_option().into_iter()
            }
        }

        /// Serialized like the equivalent `Option` of the plain index.
        #[cfg(feature = "serde")]
        impl<U: serde::Serialize + IndexBase> serde::Serialize for $maybe_index<U> {
            #[inline]
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                self.to_option().serialize(serializer)
            }
        }

        /// Deserialized like the equivalent `Option` of the plain index.
        #[cfg(feature = "serde")]
        impl<'de, U: IndexBase + serde::Deserialize<'de>> serde::Deserialize<'de>
            for $maybe_index<U>
        {
            #[inline]
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                Ok($maybe_index::new(serde::Deserialize::deserialize(
                    deserializer,
                )?))
            }
        }
    };
}
// Option-like API, conversions, `Default`, `IntoIterator`, and (with the
// `serde` feature) serialization for the optional index types.
maybe_index_impls!(MaybeNodeIndex, NodeIndex);
maybe_index_impls!(MaybePortIndex, PortIndex);
/// Offset of a port on a node, together with its direction.
///
/// Stored as a [`BitField`] over `U` (defaulting to `u16`); the flag bit is
/// set for outgoing ports and clear for incoming ones.
#[repr(transparent)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PortOffset<U = u16>(BitField<U>);
impl<U: IndexBase> PortOffset<U> {
    /// Largest offset representable with the base type `U`.
    pub fn max_offset() -> usize {
        BitField::<U>::max_index()
    }

    /// Creates a port offset in the given `direction`.
    ///
    /// # Panics
    ///
    /// Panics if `offset` is larger than [`PortOffset::max_offset`].
    #[inline(always)]
    pub fn new(direction: Direction, offset: usize) -> Self {
        let outgoing = direction == Direction::Outgoing;
        Self(BitField::new(offset, outgoing))
    }

    /// Shorthand for `PortOffset::new(Direction::Incoming, offset)`.
    #[inline(always)]
    pub fn new_incoming(offset: usize) -> Self {
        Self::new(Direction::Incoming, offset)
    }

    /// Shorthand for `PortOffset::new(Direction::Outgoing, offset)`.
    #[inline(always)]
    pub fn new_outgoing(offset: usize) -> Self {
        Self::new(Direction::Outgoing, offset)
    }

    /// Direction of the port, decoded from the flag bit (set = outgoing).
    #[inline(always)]
    pub fn direction(self) -> Direction {
        match self.0.bit_flag().expect("invalid port offset") {
            true => Direction::Outgoing,
            false => Direction::Incoming,
        }
    }

    /// Offset value, without the direction.
    #[inline(always)]
    pub fn index(self) -> usize {
        self.0.index().expect("invalid port offset")
    }

    /// The same offset in the opposite direction.
    #[inline(always)]
    pub fn opposite(&self) -> Self {
        let flipped = self.0.flip_bit_flag();
        Self(flipped)
    }
}
impl<U: IndexBase> Default for PortOffset<U> {
fn default() -> Self {
PortOffset::new_outgoing(0)
}
}
impl<U: IndexBase> std::fmt::Debug for PortOffset<U> {
    /// Formats as `Incoming(offset)` or `Outgoing(offset)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self.direction() {
            Direction::Incoming => "Incoming",
            Direction::Outgoing => "Outgoing",
        };
        write!(f, "{label}({})", self.index())
    }
}
#[cfg(feature = "serde")]
mod serde_port_offset_impl {
    //! Serializes a `PortOffset` as an `{ index, direction }` record rather than
    //! its packed bit-field representation.
    use super::*;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// Wire representation of a `PortOffset`.
    #[derive(Serialize, Deserialize)]
    struct PortOffsetSer {
        index: usize,
        direction: Direction,
    }

    impl<U: IndexBase> From<PortOffset<U>> for PortOffsetSer {
        fn from(port_offset: PortOffset<U>) -> Self {
            Self {
                index: port_offset.index(),
                direction: port_offset.direction(),
            }
        }
    }

    impl<U: IndexBase> From<PortOffsetSer> for PortOffset<U> {
        /// # Panics
        ///
        /// Panics if `index` exceeds `PortOffset::<U>::max_offset()`.
        fn from(port_offset: PortOffsetSer) -> Self {
            Self::new(port_offset.direction, port_offset.index)
        }
    }

    impl<U: Serialize + IndexBase> Serialize for PortOffset<U> {
        #[inline]
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let port_offset: PortOffsetSer = (*self).into();
            port_offset.serialize(serializer)
        }
    }

    impl<'de, U: IndexBase + Deserialize<'de>> Deserialize<'de> for PortOffset<U> {
        #[inline]
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let port_offset: PortOffsetSer = Deserialize::deserialize(deserializer)?;
            // Validate before converting: `PortOffset::new` panics on an
            // out-of-range offset, and input data must not be able to trigger that.
            if port_offset.index > PortOffset::<U>::max_offset() {
                return Err(<D::Error as serde::de::Error>::custom(IndexError {
                    index: port_offset.index,
                }));
            }
            Ok(port_offset.into())
        }
    }
}
/// An index and a one-bit flag packed into a single unsigned integer `U`.
///
/// The most significant bit of `U` holds the flag and the remaining bits hold
/// the index. The all-ones value is reserved as a "none" sentinel, so the
/// largest storable index is one less than the largest value the index bits
/// could otherwise hold.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "pyo3", derive(IntoPyObject))]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct BitField<U>(U);
/// Private module holding the `Sealed` supertrait of `IndexBase`, so that only
/// the primitive types listed below can ever implement `IndexBase`.
mod sealed {
    /// Marker trait implemented exactly for the supported index base types.
    pub trait Sealed {}

    macro_rules! seal {
        ($($ty:ty),+ $(,)?) => {
            $(impl Sealed for $ty {})+
        };
    }

    seal!(u8, u16, u32, u64, usize);
}
/// Unsigned integer types usable as the storage of a [`BitField`].
///
/// The trait is sealed: only `u8`, `u16`, `u32`, `u64` and `usize` implement it.
pub trait IndexBase:
    sealed::Sealed
    + Copy
    + std::fmt::Debug
    + std::hash::Hash
    + Ord
    + num_traits::Unsigned
    + num_traits::Bounded
    + num_traits::NumAssign
    + num_traits::ToPrimitive
    + num_traits::FromPrimitive
    + ops::Shr<u8, Output = Self>
    + ops::Not<Output = Self>
    + ops::BitAnd<Output = Self>
    + ops::BitOr<Output = Self>
    + ops::BitXor<Output = Self>
{
    /// The `std::num::NonZero*` counterpart of this type.
    type NonZero: Copy + std::fmt::Debug + Eq;

    /// Converts the value to a `usize`.
    ///
    /// # Panics
    ///
    /// Panics if the value does not fit in a `usize`.
    #[inline(always)]
    fn to_usize(self) -> usize {
        num_traits::ToPrimitive::to_usize(&self).expect("value too large for usize")
    }

    /// Converts to the non-zero counterpart, returning `None` for zero.
    fn to_nonzero(self) -> Option<Self::NonZero>;

    /// Converts to the non-zero counterpart without checking for zero.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `self` is not zero; passing zero is
    /// undefined behavior.
    unsafe fn to_nonzero_unchecked(self) -> Self::NonZero;

    /// Converts a non-zero value back to the base type.
    fn from_nonzero(nonzero: Self::NonZero) -> Self;
}
impl IndexBase for u8 {
type NonZero = std::num::NonZeroU8;
#[inline(always)]
fn to_nonzero(self) -> Option<Self::NonZero> {
std::num::NonZeroU8::new(self)
}
#[inline(always)]
unsafe fn to_nonzero_unchecked(self) -> Self::NonZero {
Self::NonZero::new_unchecked(self)
}
#[inline(always)]
fn from_nonzero(nonzero: Self::NonZero) -> Self {
nonzero.get()
}
}
impl IndexBase for u16 {
type NonZero = std::num::NonZeroU16;
#[inline(always)]
fn to_nonzero(self) -> Option<Self::NonZero> {
std::num::NonZeroU16::new(self)
}
#[inline(always)]
unsafe fn to_nonzero_unchecked(self) -> Self::NonZero {
Self::NonZero::new_unchecked(self)
}
#[inline(always)]
fn from_nonzero(nonzero: Self::NonZero) -> Self {
nonzero.get()
}
}
impl IndexBase for u32 {
type NonZero = std::num::NonZeroU32;
#[inline(always)]
fn to_nonzero(self) -> Option<Self::NonZero> {
std::num::NonZeroU32::new(self)
}
#[inline(always)]
unsafe fn to_nonzero_unchecked(self) -> Self::NonZero {
Self::NonZero::new_unchecked(self)
}
#[inline(always)]
fn from_nonzero(nonzero: Self::NonZero) -> Self {
nonzero.get()
}
}
impl IndexBase for u64 {
type NonZero = std::num::NonZeroU64;
#[inline(always)]
fn to_nonzero(self) -> Option<Self::NonZero> {
std::num::NonZeroU64::new(self)
}
#[inline(always)]
unsafe fn to_nonzero_unchecked(self) -> Self::NonZero {
Self::NonZero::new_unchecked(self)
}
#[inline(always)]
fn from_nonzero(nonzero: Self::NonZero) -> Self {
nonzero.get()
}
}
impl IndexBase for usize {
type NonZero = std::num::NonZeroUsize;
#[inline(always)]
fn to_nonzero(self) -> Option<Self::NonZero> {
std::num::NonZeroUsize::new(self)
}
#[inline(always)]
unsafe fn to_nonzero_unchecked(self) -> Self::NonZero {
Self::NonZero::new_unchecked(self)
}
#[inline(always)]
fn from_nonzero(nonzero: Self::NonZero) -> Self {
nonzero.get()
}
}
impl<U: IndexBase> BitField<U> {
#[inline(always)]
pub(crate) fn new(index: usize, bit_flag: bool) -> Self {
if index > Self::max_index() {
panic!("index too large");
}
let u = U::from_usize(index).unwrap();
let ret = Self(u);
if bit_flag {
ret.set_bit_flag()
} else {
ret
}
}
#[inline(always)]
pub(crate) fn new_none() -> Self {
Self(U::max_value())
}
#[inline(always)]
fn max_index() -> usize {
(U::max_value().to_usize() >> 1) - 1
}
#[inline(always)]
pub(crate) fn index(self) -> Option<usize> {
if self.is_none() {
return None;
}
Some(self.index_unchecked())
}
#[inline(always)]
pub(crate) fn index_unchecked(self) -> usize {
(self.0 & !Self::msb_mask()).to_usize()
}
#[allow(unused)]
#[inline(always)]
fn set_index(self, index: usize) -> Self {
let u = U::from_usize(index).expect("index too large");
let msb = if self.is_none() {
U::zero()
} else {
self.0 & Self::msb_mask()
};
Self(u | msb)
}
#[inline(always)]
pub(crate) fn bit_flag(self) -> Option<bool> {
if self.is_none() {
return None;
}
Some(self.0 & Self::msb_mask() != U::zero())
}
#[inline(always)]
fn set_bit_flag(self) -> Self {
assert!(!self.is_none(), "bit field is unset");
let msb = self.0 | Self::msb_mask();
Self(msb)
}
#[allow(unused)]
#[inline(always)]
fn unset_bit_flag(self) -> Self {
assert!(!self.is_none(), "bit field is unset");
let msb = self.0 & !Self::msb_mask();
Self(msb)
}
#[inline(always)]
fn flip_bit_flag(self) -> Self {
Self(self.0 ^ Self::msb_mask())
}
#[inline(always)]
fn msb_mask() -> U {
U::max_value() - (U::max_value() >> 1)
}
#[inline(always)]
pub(crate) fn is_none(self) -> bool {
self == Self::new_none()
}
}
impl<U: IndexBase> TryFrom<usize> for BitField<U> {
    type Error = IndexError;

    /// Builds a bit field with the flag unset, or an [`IndexError`] if `index`
    /// is larger than the maximum storable index.
    #[inline]
    fn try_from(index: usize) -> Result<Self, Self::Error> {
        let fits = index <= Self::max_index();
        fits.then(|| Self::new(index, false))
            .ok_or(IndexError { index })
    }
}
impl<U: IndexBase> From<BitField<U>> for usize {
    /// Extracts the stored index.
    ///
    /// # Panics
    ///
    /// Panics if the bit field holds the "none" value.
    #[inline]
    fn from(bit_field: BitField<U>) -> Self {
        match bit_field.index() {
            Some(index) => index,
            None => panic!("invalid index"),
        }
    }
}
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[error("the index {index} is too large.")]
pub struct IndexError {
pub(crate) index: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    fn test_opposite() {
        let (incoming, outgoing) = (
            PortOffset::<u16>::new_incoming(5),
            PortOffset::<u16>::new_outgoing(5),
        );
        assert_eq!(incoming.opposite(), outgoing);
        assert_eq!(outgoing.opposite(), incoming);
    }

    #[rstest]
    #[case(0u16, false)]
    #[case(1u16, false)]
    #[case(16u16, true)]
    fn test_create_bitfield(#[case] value: u16, #[case] flag: bool) {
        let expected = usize::from(value);

        let field = BitField::<u16>::new(expected, flag);
        assert_eq!(field.index().unwrap(), expected);
        assert_eq!(field.bit_flag().unwrap(), flag);

        let field = field.flip_bit_flag();
        assert_eq!(field.bit_flag().unwrap(), !flag);

        let field = field.set_bit_flag();
        assert!(field.bit_flag().unwrap());

        let field = field.unset_bit_flag();
        assert!(!field.bit_flag().unwrap());

        let doubled = 2 * expected;
        let field = field.set_index(doubled);
        assert_eq!(field.index().unwrap(), doubled);
    }

    #[test]
    fn test_from_usize_overflow_error() {
        // The first index above `max_index` for a `u16` base.
        let first_invalid = usize::from(u16::MAX >> 1);
        assert!(BitField::<u16>::try_from(first_invalid).is_err());
    }

    #[test]
    fn test_bitfield_none() {
        let none = BitField::<u16>::new_none();
        assert_eq!(none.index(), None);
        assert_eq!(none.bit_flag(), None);

        let field = none.set_index(1);
        assert_eq!(field.index().unwrap(), 1);
        assert!(!field.bit_flag().unwrap());
    }

    #[test]
    fn test_port_offset_direction_and_index() {
        let incoming = PortOffset::<u16>::new_incoming(42);
        assert_eq!(incoming.direction(), Direction::Incoming);
        assert_eq!(incoming.index(), 42);
        assert_eq!(format!("{incoming:?}"), "Incoming(42)");

        let outgoing = PortOffset::<u16>::new_outgoing(99);
        assert_eq!(outgoing.direction(), Direction::Outgoing);
        assert_eq!(outgoing.index(), 99);
        assert_eq!(format!("{outgoing:?}"), "Outgoing(99)");
    }
}
// Snapshot test of the JSON serialization of the graph produced by the
// `line_graph` fixture.
#[cfg(test)]
#[cfg(feature = "serde")]
mod test_serde {
use crate::boundary::test::line_graph;
use crate::MultiPortGraph;
use crate::NodeIndex;
use rstest::rstest;
// NOTE(review): insta presumably keys the stored snapshot on this module path
// and test name; renaming either likely orphans the existing `.snap` file.
#[rstest]
fn test_serde_node_index(line_graph: (MultiPortGraph, [NodeIndex; 5])) {
let (graph, _) = line_graph;
insta::assert_snapshot!(serde_json::to_string_pretty(&graph).unwrap(),)
}
}