1pub use itertools::Either;
6
7use derive_more::From;
8use itertools::Either::{Left, Right};
9
10use crate::hugr::HugrError;
11
12#[derive(
14    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From, serde::Serialize, serde::Deserialize,
15)]
16#[serde(transparent)]
17pub struct Node {
18    index: portgraph::NodeIndex,
19}
20
21#[derive(
23    Clone,
24    Copy,
25    PartialEq,
26    PartialOrd,
27    Eq,
28    Ord,
29    Hash,
30    Default,
31    From,
32    serde::Serialize,
33    serde::Deserialize,
34)]
35#[serde(transparent)]
36pub struct Port {
37    offset: portgraph::PortOffset,
38}
39
/// Types that can be read as the offset index of a port.
pub trait PortIndex {
    /// Returns the port's offset index as a `usize`, consuming `self`.
    fn index(self) -> usize;
}
45
/// Types that can be read as a numeric node index.
pub trait NodeIndex {
    /// Returns the node's index as a `usize`, consuming `self`.
    fn index(self) -> usize;
}
51
/// The bounds every node-handle type must satisfy: copyable, totally ordered, hashable,
/// and printable with both `Debug` and `Display`.
///
/// This is a pure marker trait: it has no methods and is blanket-implemented below.
pub trait HugrNode: Copy + Ord + std::hash::Hash + std::fmt::Debug + std::fmt::Display {}

// Blanket implementation: any type meeting the bounds is automatically a `HugrNode`.
impl<T> HugrNode for T where
    T: Copy + Ord + std::hash::Hash + std::fmt::Debug + std::fmt::Display
{
}
56
57#[derive(
59    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
60)]
61pub struct IncomingPort {
62    index: u16,
63}
64
65#[derive(
67    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
68)]
69pub struct OutgoingPort {
70    index: u16,
71}
72
73pub type Direction = portgraph::Direction;
75
76#[derive(
77    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
78)]
79pub struct Wire<N = Node>(N, OutgoingPort);
82
83impl Node {
84    #[inline]
86    pub(crate) fn pg_index(self) -> portgraph::NodeIndex {
87        self.index
88    }
89}
90
91impl Port {
92    #[inline]
94    pub fn new(direction: Direction, port: usize) -> Self {
95        Self {
96            offset: portgraph::PortOffset::new(direction, port),
97        }
98    }
99
100    #[inline]
103    pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
104        self.as_directed()
105            .left()
106            .ok_or(HugrError::InvalidPortDirection(self.direction()))
107    }
108
109    #[inline]
112    pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
113        self.as_directed()
114            .right()
115            .ok_or(HugrError::InvalidPortDirection(self.direction()))
116    }
117
118    #[inline]
120    pub fn as_directed(&self) -> Either<IncomingPort, OutgoingPort> {
121        match self.direction() {
122            Direction::Incoming => Left(IncomingPort {
123                index: self.index() as u16,
124            }),
125            Direction::Outgoing => Right(OutgoingPort {
126                index: self.index() as u16,
127            }),
128        }
129    }
130
131    #[inline]
133    pub fn direction(self) -> Direction {
134        self.offset.direction()
135    }
136
137    #[inline]
139    pub(crate) fn pg_offset(self) -> portgraph::PortOffset {
140        self.offset
141    }
142}
143
144impl PortIndex for Port {
145    #[inline(always)]
146    fn index(self) -> usize {
147        self.offset.index()
148    }
149}
150
151impl PortIndex for usize {
152    #[inline(always)]
153    fn index(self) -> usize {
154        self
155    }
156}
157
158impl PortIndex for IncomingPort {
159    #[inline(always)]
160    fn index(self) -> usize {
161        self.index as usize
162    }
163}
164
165impl PortIndex for OutgoingPort {
166    #[inline(always)]
167    fn index(self) -> usize {
168        self.index as usize
169    }
170}
171
172impl From<usize> for IncomingPort {
173    #[inline(always)]
174    fn from(index: usize) -> Self {
175        Self {
176            index: index as u16,
177        }
178    }
179}
180
181impl From<usize> for OutgoingPort {
182    #[inline(always)]
183    fn from(index: usize) -> Self {
184        Self {
185            index: index as u16,
186        }
187    }
188}
189
190impl From<IncomingPort> for Port {
191    fn from(value: IncomingPort) -> Self {
192        Self {
193            offset: portgraph::PortOffset::new_incoming(value.index()),
194        }
195    }
196}
197
198impl From<OutgoingPort> for Port {
199    fn from(value: OutgoingPort) -> Self {
200        Self {
201            offset: portgraph::PortOffset::new_outgoing(value.index()),
202        }
203    }
204}
205
206impl NodeIndex for Node {
207    fn index(self) -> usize {
208        self.index.into()
209    }
210}
211
212impl<N: HugrNode> Wire<N> {
213    #[inline]
215    pub fn new(node: N, port: impl Into<OutgoingPort>) -> Self {
216        Self(node, port.into())
217    }
218
219    #[inline]
221    pub fn node(&self) -> N {
222        self.0
223    }
224
225    #[inline]
227    pub fn source(&self) -> OutgoingPort {
228        self.1
229    }
230}
231
232impl<N: HugrNode> std::fmt::Display for Wire<N> {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        write!(f, "Wire({}, {})", self.0, self.1.index)
235    }
236}
237
238#[derive(
244    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
245)]
246pub enum CircuitUnit<N = Node> {
247    Wire(Wire<N>),
249    Linear(usize),
251}
252
253impl CircuitUnit {
254    pub fn is_wire(&self) -> bool {
256        matches!(self, CircuitUnit::Wire(_))
257    }
258
259    pub fn is_linear(&self) -> bool {
261        matches!(self, CircuitUnit::Linear(_))
262    }
263}
264
265impl From<usize> for CircuitUnit {
266    fn from(value: usize) -> Self {
267        CircuitUnit::Linear(value)
268    }
269}
270
271impl From<Wire> for CircuitUnit {
272    fn from(value: Wire) -> Self {
273        CircuitUnit::Wire(value)
274    }
275}
276
277impl std::fmt::Debug for Node {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        f.debug_tuple("Node").field(&self.index()).finish()
280    }
281}
282
283impl std::fmt::Debug for Port {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_tuple("Port")
286            .field(&self.offset.direction())
287            .field(&self.index())
288            .finish()
289    }
290}
291
292impl std::fmt::Debug for IncomingPort {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        f.debug_tuple("IncomingPort").field(&self.index).finish()
295    }
296}
297
298impl std::fmt::Debug for OutgoingPort {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.debug_tuple("OutgoingPort").field(&self.index).finish()
301    }
302}
303
304impl std::fmt::Debug for Wire {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        f.debug_struct("Wire")
307            .field("node", &self.0.index())
308            .field("port", &self.1)
309            .finish()
310    }
311}
312
313impl std::fmt::Debug for CircuitUnit {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        match self {
316            Self::Wire(w) => f
317                .debug_struct("WireUnit")
318                .field("node", &w.0.index())
319                .field("port", &w.1)
320                .finish(),
321            Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
322        }
323    }
324}
325
326macro_rules! impl_display_from_debug {
327    ($($t:ty),*) => {
328        $(
329            impl std::fmt::Display for $t {
330                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331                    <Self as std::fmt::Debug>::fmt(self, f)
332                }
333            }
334        )*
335    };
336}
337impl_display_from_debug!(Node, Port, IncomingPort, OutgoingPort, CircuitUnit);