pub use itertools::Either;
use derive_more::From;
use itertools::Either::{Left, Right};
use crate::hugr::HugrError;
/// Identifier of a node, wrapping a [`portgraph::NodeIndex`].
///
/// Serialized transparently as the wrapped index.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct Node {
    index: portgraph::NodeIndex,
}
/// A port on a node, identified by a direction and an offset.
///
/// Wraps a [`portgraph::PortOffset`] and is serialized transparently as it.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, From)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct Port {
    offset: portgraph::PortOffset,
}
/// Conversion of a port-like value into its raw offset as a `usize`.
pub trait PortIndex {
    /// Returns the raw port offset.
    fn index(self) -> usize;
}
/// Conversion of a node handle into its raw index as a `usize`.
pub trait NodeIndex {
    /// Returns the raw node index.
    fn index(self) -> usize;
}
/// The offset of an incoming port on a node, stored as a `u16`.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct IncomingPort {
    index: u16,
}
/// The offset of an outgoing port on a node, stored as a `u16`.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct OutgoingPort {
    index: u16,
}
/// Direction of a port (`Incoming` or `Outgoing`); a type alias of [`portgraph::Direction`].
pub type Direction = portgraph::Direction;
/// A wire, identified by the node it leaves and the outgoing port on that node.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Wire(Node, OutgoingPort);
impl Node {
    /// Returns the wrapped `portgraph` node index.
    #[inline]
    pub(crate) fn pg_index(self) -> portgraph::NodeIndex {
        let Self { index } = self;
        index
    }
}
impl Port {
#[inline]
pub fn new(direction: Direction, port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new(direction, port),
}
}
#[inline]
pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
self.as_directed()
.left()
.ok_or(HugrError::InvalidPortDirection(self.direction()))
}
#[inline]
pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
self.as_directed()
.right()
.ok_or(HugrError::InvalidPortDirection(self.direction()))
}
#[inline]
pub fn as_directed(&self) -> Either<IncomingPort, OutgoingPort> {
match self.direction() {
Direction::Incoming => Left(IncomingPort {
index: self.index() as u16,
}),
Direction::Outgoing => Right(OutgoingPort {
index: self.index() as u16,
}),
}
}
#[inline]
pub fn direction(self) -> Direction {
self.offset.direction()
}
#[inline]
pub(crate) fn pg_offset(self) -> portgraph::PortOffset {
self.offset
}
}
impl PortIndex for Port {
#[inline(always)]
fn index(self) -> usize {
self.offset.index()
}
}
/// A bare `usize` is already a raw port offset, so it is returned unchanged.
impl PortIndex for usize {
    #[inline(always)]
    fn index(self) -> usize {
        self
    }
}
/// Widens the stored `u16` offset to `usize`.
impl PortIndex for IncomingPort {
    #[inline(always)]
    fn index(self) -> usize {
        usize::from(self.index)
    }
}
/// Widens the stored `u16` offset to `usize`.
impl PortIndex for OutgoingPort {
    #[inline(always)]
    fn index(self) -> usize {
        usize::from(self.index)
    }
}
impl From<usize> for IncomingPort {
    /// Creates an incoming port from a raw offset.
    ///
    /// # Panics
    ///
    /// Panics if `index` exceeds `u16::MAX`. A bare `as u16` cast would
    /// silently wrap the value and produce a different, valid-looking port.
    #[inline(always)]
    fn from(index: usize) -> Self {
        assert!(
            index <= usize::from(u16::MAX),
            "port index {} does not fit in a u16",
            index
        );
        Self {
            // Lossless: the range was checked above.
            index: index as u16,
        }
    }
}
impl From<usize> for OutgoingPort {
    /// Creates an outgoing port from a raw offset.
    ///
    /// # Panics
    ///
    /// Panics if `index` exceeds `u16::MAX`. A bare `as u16` cast would
    /// silently wrap the value and produce a different, valid-looking port.
    #[inline(always)]
    fn from(index: usize) -> Self {
        assert!(
            index <= usize::from(u16::MAX),
            "port index {} does not fit in a u16",
            index
        );
        Self {
            // Lossless: the range was checked above.
            index: index as u16,
        }
    }
}
impl From<IncomingPort> for Port {
    /// Wraps a typed incoming port as a [`Port`] with the incoming direction.
    fn from(value: IncomingPort) -> Self {
        let offset = portgraph::PortOffset::new_incoming(value.index());
        Self { offset }
    }
}
impl From<OutgoingPort> for Port {
    /// Wraps a typed outgoing port as a [`Port`] with the outgoing direction.
    fn from(value: OutgoingPort) -> Self {
        let offset = portgraph::PortOffset::new_outgoing(value.index());
        Self { offset }
    }
}
impl NodeIndex for Node {
fn index(self) -> usize {
self.index.into()
}
}
impl Wire {
#[inline]
pub fn new(node: Node, port: impl Into<OutgoingPort>) -> Self {
Self(node, port.into())
}
#[inline]
pub fn node(&self) -> Node {
self.0
}
#[inline]
pub fn source(&self) -> OutgoingPort {
self.1
}
}
impl std::fmt::Display for Wire {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Wire({}, {})", self.0.index(), self.1.index)
}
}
/// A unit of a circuit: either a [`Wire`] or a linear unit referenced by a `usize` index.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum CircuitUnit {
    /// A unit given by a [`Wire`].
    Wire(Wire),
    /// A linear unit, given by its index.
    Linear(usize),
}
impl CircuitUnit {
    /// Returns `true` if this unit is a [`CircuitUnit::Wire`].
    pub fn is_wire(&self) -> bool {
        match self {
            CircuitUnit::Wire(_) => true,
            CircuitUnit::Linear(_) => false,
        }
    }

    /// Returns `true` if this unit is a [`CircuitUnit::Linear`].
    pub fn is_linear(&self) -> bool {
        match self {
            CircuitUnit::Linear(_) => true,
            CircuitUnit::Wire(_) => false,
        }
    }
}
impl From<usize> for CircuitUnit {
fn from(value: usize) -> Self {
CircuitUnit::Linear(value)
}
}
impl From<Wire> for CircuitUnit {
fn from(value: Wire) -> Self {
CircuitUnit::Wire(value)
}
}
/// Formats a node as `Node(<index>)`.
impl std::fmt::Debug for Node {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.index();
        f.debug_tuple("Node").field(&raw).finish()
    }
}
/// Formats a port as `Port(<direction>, <offset>)`.
impl std::fmt::Debug for Port {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let direction = self.offset.direction();
        let index = self.index();
        let mut tuple = f.debug_tuple("Port");
        tuple.field(&direction).field(&index);
        tuple.finish()
    }
}
/// Formats an incoming port as `IncomingPort(<offset>)`.
impl std::fmt::Debug for IncomingPort {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("IncomingPort");
        tuple.field(&self.index);
        tuple.finish()
    }
}
/// Formats an outgoing port as `OutgoingPort(<offset>)`.
impl std::fmt::Debug for OutgoingPort {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("OutgoingPort");
        tuple.field(&self.index);
        tuple.finish()
    }
}
impl std::fmt::Debug for Wire {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Wire")
.field("node", &self.0.index())
.field("port", &self.1)
.finish()
}
}
impl std::fmt::Debug for CircuitUnit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Wire(w) => f
.debug_struct("WireUnit")
.field("node", &w.0.index())
.field("port", &w.1)
.finish(),
Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
}
}
}
/// Implements [`std::fmt::Display`] for each listed type by forwarding to its
/// [`std::fmt::Debug`] implementation, so that both formats print the same text.
macro_rules! impl_display_from_debug {
    ($($ty:ty),*) => {
        $(
            impl std::fmt::Display for $ty {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    std::fmt::Debug::fmt(self, f)
                }
            }
        )*
    };
}
// `Wire` is deliberately not listed: it has its own hand-written `Display` impl.
impl_display_from_debug!(
    Node,
    Port,
    IncomingPort,
    OutgoingPort,
    CircuitUnit
);