pub use itertools::Either;
use derive_more::From;
use itertools::Either::{Left, Right};
use crate::{HugrView, hugr::HugrError};
/// Identifier of a node, wrapping a [`portgraph::NodeIndex`].
///
/// Serialized transparently as the wrapped index.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    From,
    serde::Serialize,
    serde::Deserialize,
)]
#[serde(transparent)]
pub struct Node {
    index: portgraph::NodeIndex,
}
/// A port of a node: a direction together with an offset, stored as a
/// [`portgraph::PortOffset`].
///
/// Serialized transparently as the wrapped offset.
#[derive(
    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, From, serde::Serialize,
    serde::Deserialize,
)]
#[serde(transparent)]
pub struct Port {
    offset: portgraph::PortOffset<u32>,
}
/// Types that can be interpreted as a port offset.
pub trait PortIndex {
    /// Returns the port offset as a `usize`.
    fn index(self) -> usize;
}

/// Types that can be interpreted as a node index.
pub trait NodeIndex {
    /// Returns the node index as a `usize`.
    fn index(self) -> usize;
}

/// Bound required of node identifier types: cheap to copy, totally ordered, hashable,
/// and printable with both `Debug` and `Display`.
pub trait HugrNode: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash {}

// Every type meeting the bounds is automatically a `HugrNode`; no manual impl is needed.
impl<T> HugrNode for T where
    T: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash
{
}
/// The offset of an incoming port of a node, stored as a `u16`.
#[derive(
    Clone,
    Copy,
    PartialEq,
    PartialOrd,
    Eq,
    Ord,
    Hash,
    Default,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct IncomingPort {
    index: u16,
}

/// The offset of an outgoing port of a node, stored as a `u16`.
#[derive(
    Clone,
    Copy,
    PartialEq,
    PartialOrd,
    Eq,
    Ord,
    Hash,
    Default,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct OutgoingPort {
    index: u16,
}
/// Direction of a port (incoming or outgoing); re-export of [`portgraph::Direction`].
pub type Direction = portgraph::Direction;

/// A wire, identified by the node and the outgoing port it starts from.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    serde::Serialize,
    serde::Deserialize,
)]
pub struct Wire<N = Node>(N, OutgoingPort);
impl Node {
    /// Unwraps this node into the underlying [`portgraph::NodeIndex`].
    #[inline]
    pub(crate) fn into_portgraph(self) -> portgraph::NodeIndex {
        let Self { index } = self;
        index
    }
}
impl Port {
#[inline]
#[must_use]
pub fn new(direction: Direction, port: usize) -> Self {
Self {
offset: portgraph::PortOffset::new(direction, port),
}
}
#[inline]
pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
self.as_directed()
.left()
.ok_or(HugrError::InvalidPortDirection(self.direction()))
}
#[inline]
pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
self.as_directed()
.right()
.ok_or(HugrError::InvalidPortDirection(self.direction()))
}
#[inline]
#[must_use]
pub fn as_directed(&self) -> Either<IncomingPort, OutgoingPort> {
match self.direction() {
Direction::Incoming => Left(IncomingPort {
index: self.index() as u16,
}),
Direction::Outgoing => Right(OutgoingPort {
index: self.index() as u16,
}),
}
}
#[inline]
#[must_use]
pub fn direction(self) -> Direction {
self.offset.direction()
}
#[inline]
pub(crate) fn pg_offset(self) -> portgraph::PortOffset<u32> {
self.offset
}
}
impl PortIndex for Port {
    /// Returns the offset of this port within its direction.
    #[inline(always)]
    fn index(self) -> usize {
        let Self { offset } = self;
        offset.index()
    }
}
impl PortIndex for usize {
    /// A raw `usize` is used as the port offset unchanged.
    #[inline(always)]
    fn index(self) -> usize {
        self
    }
}
impl PortIndex for IncomingPort {
    /// Widens the stored `u16` offset to `usize` (lossless).
    #[inline(always)]
    fn index(self) -> usize {
        usize::from(self.index)
    }
}
impl PortIndex for OutgoingPort {
    /// Widens the stored `u16` offset to `usize` (lossless).
    #[inline(always)]
    fn index(self) -> usize {
        usize::from(self.index)
    }
}
impl From<usize> for IncomingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}
impl From<usize> for OutgoingPort {
#[inline(always)]
fn from(index: usize) -> Self {
Self {
index: index as u16,
}
}
}
impl From<IncomingPort> for Port {
    /// Wraps an [`IncomingPort`] as an incoming [`Port`] with the same offset.
    fn from(value: IncomingPort) -> Self {
        let offset = portgraph::PortOffset::new_incoming(usize::from(value.index));
        Self { offset }
    }
}
impl From<OutgoingPort> for Port {
    /// Wraps an [`OutgoingPort`] as an outgoing [`Port`] with the same offset.
    fn from(value: OutgoingPort) -> Self {
        let offset = portgraph::PortOffset::new_outgoing(usize::from(value.index));
        Self { offset }
    }
}
impl NodeIndex for Node {
    /// Converts the wrapped [`portgraph::NodeIndex`] into a `usize`.
    fn index(self) -> usize {
        let Self { index } = self;
        index.into()
    }
}
impl<N: HugrNode> Wire<N> {
#[inline]
pub fn new(node: N, port: impl Into<OutgoingPort>) -> Self {
Self(node, port.into())
}
#[inline]
pub fn from_connected_port(
node: N,
port: impl Into<Port>,
hugr: &impl HugrView<Node = N>,
) -> Self {
let (node, outgoing) = match port.into().as_directed() {
Either::Left(incoming) => hugr
.single_linked_output(node, incoming)
.expect("invalid dfg port"),
Either::Right(outgoing) => (node, outgoing),
};
Self::new(node, outgoing)
}
#[inline]
pub fn node(&self) -> N {
self.0
}
#[inline]
pub fn source(&self) -> OutgoingPort {
self.1
}
pub fn all_connected_ports<'h, H: HugrView<Node = N>>(
&self,
hugr: &'h H,
) -> impl Iterator<Item = (N, Port)> + use<'h, N, H> {
let node = self.node();
let out_port = self.source();
std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port))
}
}
impl<N: HugrNode> std::fmt::Display for Wire<N> {
    /// Formats as `Wire(<node>, <port offset>)`, using the node's `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(node, port) = self;
        write!(f, "Wire({}, {})", node, port.index)
    }
}
/// Visibility level: public or private.
///
/// Convertible to and from [`hugr_model::v0::Visibility`].
#[derive(
    Clone, Debug, derive_more::Display, PartialEq, Eq, PartialOrd, Ord, serde::Serialize,
    serde::Deserialize,
)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
#[non_exhaustive]
pub enum Visibility {
    /// Public visibility.
    Public,
    /// Private visibility.
    Private,
}
impl From<hugr_model::v0::Visibility> for Visibility {
    /// Maps each model visibility to the variant of the same name.
    fn from(value: hugr_model::v0::Visibility) -> Self {
        use hugr_model::v0::Visibility as ModelVisibility;
        match value {
            ModelVisibility::Public => Self::Public,
            ModelVisibility::Private => Self::Private,
        }
    }
}
impl From<Visibility> for hugr_model::v0::Visibility {
    /// Maps each visibility to the model variant of the same name.
    fn from(value: Visibility) -> Self {
        match value {
            Visibility::Private => Self::Private,
            Visibility::Public => Self::Public,
        }
    }
}
/// A unit of a circuit: either a [`Wire`] or a linear unit identified by a `usize`.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    serde::Serialize,
    serde::Deserialize,
)]
pub enum CircuitUnit<N = Node> {
    /// A wire, given by its source node and outgoing port.
    Wire(Wire<N>),
    /// A linear unit, given by its index.
    Linear(usize),
}
// Generic over the node type: the variant checks do not depend on `N`, so they are
// available for every `CircuitUnit<N>`, not only the default `CircuitUnit<Node>`.
impl<N> CircuitUnit<N> {
    /// Returns `true` if this unit is a [`CircuitUnit::Wire`].
    #[must_use]
    pub fn is_wire(&self) -> bool {
        matches!(self, Self::Wire(_))
    }

    /// Returns `true` if this unit is a [`CircuitUnit::Linear`].
    #[must_use]
    pub fn is_linear(&self) -> bool {
        matches!(self, Self::Linear(_))
    }
}

// Intentionally kept specific to the default node type: `N` does not appear in `usize`,
// so a generic impl would leave `N` uninferable at call sites such as
// `let u = CircuitUnit::from(0);`.
impl From<usize> for CircuitUnit {
    /// Wraps `value` as a [`CircuitUnit::Linear`] unit.
    fn from(value: usize) -> Self {
        CircuitUnit::Linear(value)
    }
}

impl<N> From<Wire<N>> for CircuitUnit<N> {
    /// Wraps `value` as a [`CircuitUnit::Wire`] unit; `N` is inferred from the wire.
    fn from(value: Wire<N>) -> Self {
        Self::Wire(value)
    }
}
impl std::fmt::Debug for Node {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Node").field(&self.index()).finish()
}
}
impl std::fmt::Debug for Port {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Port")
.field(&self.offset.direction())
.field(&self.index())
.finish()
}
}
impl std::fmt::Debug for IncomingPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("IncomingPort").field(&self.index).finish()
}
}
impl std::fmt::Debug for OutgoingPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("OutgoingPort").field(&self.index).finish()
}
}
impl<N: HugrNode> std::fmt::Debug for Wire<N> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Wire")
.field("node", &self.0)
.field("port", &self.1)
.finish()
}
}
impl std::fmt::Debug for CircuitUnit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Wire(w) => f
.debug_struct("WireUnit")
.field("node", &w.0.index())
.field("port", &w.1)
.finish(),
Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
}
}
}
// Implements `Display` for each listed type by delegating to its `Debug` implementation,
// so both formats print identically.
macro_rules! impl_display_from_debug {
    ($($ty:ty),*) => {
        $(
            impl std::fmt::Display for $ty {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    std::fmt::Debug::fmt(self, f)
                }
            }
        )*
    };
}
impl_display_from_debug!(Node, Port, IncomingPort, OutgoingPort, CircuitUnit);