use std::fmt;
use serde::{Deserialize, Serialize};
/// Constraint on which node types an edge endpoint may attach to.
///
/// Node types are identified by `u32` indices. The canonical constructor for a
/// multi-type constraint is `EdgeEndpointDef::one_of`, which sorts and
/// de-duplicates the indices, rejects an empty set, and collapses a single
/// index to `NodeType`.
///
/// NOTE(review): the variants are public and the type derives serde and rkyv
/// `Deserialize`, so a `OneOf` value built directly or decoded from storage is
/// not guaranteed to be sorted, de-duplicated, or non-empty.
/// NOTE(review): the derived serde/rkyv impls presumably encode variants by
/// position — confirm before reordering or inserting variants.
#[derive(
Clone,
Debug,
Deserialize,
Eq,
Hash,
PartialEq,
rkyv::Archive,
rkyv::Deserialize,
rkyv::Serialize,
Serialize,
)]
pub enum EdgeEndpointDef {
/// Accepts every node type.
Any,
/// Accepts exactly the node type with this index.
NodeType(u32),
/// Accepts any node type whose index appears in the list. `one_of` only
/// produces this for two or more distinct indices, stored in ascending order.
OneOf(Vec<u32>),
}
impl EdgeEndpointDef {
    /// Builds the canonical endpoint constraint for a set of node-type indices.
    ///
    /// The indices are sorted and de-duplicated. If only one distinct index
    /// remains, the result is [`EdgeEndpointDef::NodeType`]. Otherwise it is
    /// [`EdgeEndpointDef::OneOf`] holding a strictly ascending vector.
    ///
    /// # Panics
    ///
    /// Panics if `indices` yields no items.
    #[must_use]
    pub fn one_of(indices: impl IntoIterator<Item = u32>) -> Self {
        let mut buf: Vec<u32> = indices.into_iter().collect();
        buf.sort_unstable();
        buf.dedup();
        assert!(
            !buf.is_empty(),
            "EdgeEndpointDef::one_of called with empty index set"
        );
        match buf.len() {
            1 => Self::NodeType(buf[0]),
            _ => Self::OneOf(buf),
        }
    }

    /// Returns `true` if a node whose type index is `node_type` satisfies this endpoint.
    ///
    /// `OneOf` lists are scanned linearly instead of binary-searched. The variant
    /// is public and derives `Deserialize`, so its vector may never have gone
    /// through [`EdgeEndpointDef::one_of`] and may be unsorted. `binary_search`
    /// on an unsorted slice can miss an element that is present.
    #[must_use]
    pub fn matches_node_type(&self, node_type: u32) -> bool {
        match self {
            Self::Any => true,
            Self::NodeType(expected) => *expected == node_type,
            Self::OneOf(indices) => indices.contains(&node_type),
        }
    }

    /// Returns `true` if at least one node type could satisfy both `self` and `other`.
    ///
    /// `Any` on either side always overlaps, including with an empty `OneOf`.
    /// Two `OneOf` lists are compared with a linear merge when both are in
    /// ascending order. Otherwise a pairwise scan is used, so that lists that
    /// are not in canonical form still give a correct answer.
    #[must_use]
    pub fn overlaps(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Any, _) | (_, Self::Any) => true,
            (Self::NodeType(left), Self::NodeType(right)) => left == right,
            (Self::NodeType(index), Self::OneOf(indices))
            | (Self::OneOf(indices), Self::NodeType(index)) => indices.contains(index),
            (Self::OneOf(left), Self::OneOf(right)) => {
                if Self::is_ascending(left) && Self::is_ascending(right) {
                    sorted_slices_intersect(left, right)
                } else {
                    left.iter().any(|index| right.contains(index))
                }
            }
        }
    }

    /// Returns the single node-type index for [`EdgeEndpointDef::NodeType`].
    ///
    /// Returns `None` for `Any` and for every `OneOf`, including one that holds a
    /// single element.
    #[must_use]
    pub const fn node_type_index(&self) -> Option<u32> {
        match self {
            Self::Any | Self::OneOf(_) => None,
            Self::NodeType(index) => Some(*index),
        }
    }

    /// Returns `true` if `indices` is in non-decreasing order, which is the
    /// precondition of the linear merge in `sorted_slices_intersect`.
    fn is_ascending(indices: &[u32]) -> bool {
        indices.windows(2).all(|pair| pair[0] <= pair[1])
    }
}
impl fmt::Display for EdgeEndpointDef {
    /// Renders the constraint in a compact, human-readable form.
    ///
    /// `Any` prints as `Any`, a single type prints as its bare index, and a
    /// list prints as `OneOf(a, b, c)` in the order the indices are stored.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let indices = match self {
            Self::Any => return out.write_str("Any"),
            Self::NodeType(index) => return write!(out, "{index}"),
            Self::OneOf(indices) => indices,
        };

        out.write_str("OneOf(")?;
        let mut remaining = indices.iter();
        if let Some(head) = remaining.next() {
            write!(out, "{head}")?;
            // Every index after the first is preceded by its separator.
            for index in remaining {
                write!(out, ", {index}")?;
            }
        }
        out.write_str(")")
    }
}
/// Reports whether two ascending slices share at least one value.
///
/// Both slices are walked in lock-step, and the smaller head is dropped at each
/// step. This takes at most `left.len() + right.len()` comparisons. The answer is
/// only reliable when both inputs are sorted in non-decreasing order; unsorted
/// input can produce false negatives.
fn sorted_slices_intersect(left: &[u32], right: &[u32]) -> bool {
    let mut lhs = left;
    let mut rhs = right;
    while let (Some((&a, lhs_tail)), Some((&b, rhs_tail))) =
        (lhs.split_first(), rhs.split_first())
    {
        if a == b {
            return true;
        }
        if a < b {
            lhs = lhs_tail;
        } else {
            rhs = rhs_tail;
        }
    }
    false
}