hugr_core/
core.rs

1//! Definitions for the core types used in the Hugr.
2//!
3//! These types are re-exported in the root of the crate.
4
5pub use itertools::Either;
6
7use derive_more::From;
8use itertools::Either::{Left, Right};
9
10use crate::hugr::HugrError;
11
12/// A handle to a node in the HUGR.
13#[derive(
14    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From, serde::Serialize, serde::Deserialize,
15)]
16#[serde(transparent)]
17pub struct Node {
18    index: portgraph::NodeIndex,
19}
20
21/// A handle to a port for a node in the HUGR.
22#[derive(
23    Clone,
24    Copy,
25    PartialEq,
26    PartialOrd,
27    Eq,
28    Ord,
29    Hash,
30    Default,
31    From,
32    serde::Serialize,
33    serde::Deserialize,
34)]
35#[serde(transparent)]
36pub struct Port {
37    offset: portgraph::PortOffset,
38}
39
/// Implemented by port-like values that can report their undirected index,
/// i.e. the position of the port without regard to its direction.
pub trait PortIndex {
    /// Consumes the value and yields the port's offset as a `usize`.
    fn index(self) -> usize;
}
45
/// Implemented by node-like values that can report their numeric index.
pub trait NodeIndex {
    /// Consumes the value and yields the node's index as a `usize`.
    fn index(self) -> usize;
}
51
52/// A port in the incoming direction.
53#[derive(
54    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
55)]
56pub struct IncomingPort {
57    index: u16,
58}
59
60/// A port in the outgoing direction.
61#[derive(
62    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
63)]
64pub struct OutgoingPort {
65    index: u16,
66}
67
68/// The direction of a port.
69pub type Direction = portgraph::Direction;
70
71#[derive(
72    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
73)]
74/// A DataFlow wire, defined by a Value-kind output port of a node
75// Stores node and offset to output port
76pub struct Wire(Node, OutgoingPort);
77
78impl Node {
79    /// Returns the node as a portgraph `NodeIndex`.
80    #[inline]
81    pub(crate) fn pg_index(self) -> portgraph::NodeIndex {
82        self.index
83    }
84}
85
86impl Port {
87    /// Creates a new port.
88    #[inline]
89    pub fn new(direction: Direction, port: usize) -> Self {
90        Self {
91            offset: portgraph::PortOffset::new(direction, port),
92        }
93    }
94
95    /// Converts to an [IncomingPort] if this port is one; else fails with
96    /// [HugrError::InvalidPortDirection]
97    #[inline]
98    pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
99        self.as_directed()
100            .left()
101            .ok_or(HugrError::InvalidPortDirection(self.direction()))
102    }
103
104    /// Converts to an [OutgoingPort] if this port is one; else fails with
105    /// [HugrError::InvalidPortDirection]
106    #[inline]
107    pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
108        self.as_directed()
109            .right()
110            .ok_or(HugrError::InvalidPortDirection(self.direction()))
111    }
112
113    /// Converts to either an [IncomingPort] or an [OutgoingPort], as appropriate.
114    #[inline]
115    pub fn as_directed(&self) -> Either<IncomingPort, OutgoingPort> {
116        match self.direction() {
117            Direction::Incoming => Left(IncomingPort {
118                index: self.index() as u16,
119            }),
120            Direction::Outgoing => Right(OutgoingPort {
121                index: self.index() as u16,
122            }),
123        }
124    }
125
126    /// Returns the direction of the port.
127    #[inline]
128    pub fn direction(self) -> Direction {
129        self.offset.direction()
130    }
131
132    /// Returns the port as a portgraph `PortOffset`.
133    #[inline]
134    pub(crate) fn pg_offset(self) -> portgraph::PortOffset {
135        self.offset
136    }
137}
138
139impl PortIndex for Port {
140    #[inline(always)]
141    fn index(self) -> usize {
142        self.offset.index()
143    }
144}
145
146impl PortIndex for usize {
147    #[inline(always)]
148    fn index(self) -> usize {
149        self
150    }
151}
152
153impl PortIndex for IncomingPort {
154    #[inline(always)]
155    fn index(self) -> usize {
156        self.index as usize
157    }
158}
159
160impl PortIndex for OutgoingPort {
161    #[inline(always)]
162    fn index(self) -> usize {
163        self.index as usize
164    }
165}
166
167impl From<usize> for IncomingPort {
168    #[inline(always)]
169    fn from(index: usize) -> Self {
170        Self {
171            index: index as u16,
172        }
173    }
174}
175
176impl From<usize> for OutgoingPort {
177    #[inline(always)]
178    fn from(index: usize) -> Self {
179        Self {
180            index: index as u16,
181        }
182    }
183}
184
185impl From<IncomingPort> for Port {
186    fn from(value: IncomingPort) -> Self {
187        Self {
188            offset: portgraph::PortOffset::new_incoming(value.index()),
189        }
190    }
191}
192
193impl From<OutgoingPort> for Port {
194    fn from(value: OutgoingPort) -> Self {
195        Self {
196            offset: portgraph::PortOffset::new_outgoing(value.index()),
197        }
198    }
199}
200
201impl NodeIndex for Node {
202    fn index(self) -> usize {
203        self.index.into()
204    }
205}
206
207impl Wire {
208    /// Create a new wire from a node and a port.
209    #[inline]
210    pub fn new(node: Node, port: impl Into<OutgoingPort>) -> Self {
211        Self(node, port.into())
212    }
213
214    /// The node that this wire is connected to.
215    #[inline]
216    pub fn node(&self) -> Node {
217        self.0
218    }
219
220    /// The output port that this wire is connected to.
221    #[inline]
222    pub fn source(&self) -> OutgoingPort {
223        self.1
224    }
225}
226
227impl std::fmt::Display for Wire {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        write!(f, "Wire({}, {})", self.0.index(), self.1.index)
230    }
231}
232
233/// Enum for uniquely identifying the origin of linear wires in a circuit-like
234/// dataflow region.
235///
236/// Falls back to [`Wire`] if the wire is not linear or if it's not possible to
237/// track the origin.
238#[derive(
239    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
240)]
241pub enum CircuitUnit {
242    /// Arbitrary input wire.
243    Wire(Wire),
244    /// Index to region input.
245    Linear(usize),
246}
247
248impl CircuitUnit {
249    /// Check if this is a wire.
250    pub fn is_wire(&self) -> bool {
251        matches!(self, CircuitUnit::Wire(_))
252    }
253
254    /// Check if this is a linear unit.
255    pub fn is_linear(&self) -> bool {
256        matches!(self, CircuitUnit::Linear(_))
257    }
258}
259
260impl From<usize> for CircuitUnit {
261    fn from(value: usize) -> Self {
262        CircuitUnit::Linear(value)
263    }
264}
265
266impl From<Wire> for CircuitUnit {
267    fn from(value: Wire) -> Self {
268        CircuitUnit::Wire(value)
269    }
270}
271
272impl std::fmt::Debug for Node {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        f.debug_tuple("Node").field(&self.index()).finish()
275    }
276}
277
278impl std::fmt::Debug for Port {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        f.debug_tuple("Port")
281            .field(&self.offset.direction())
282            .field(&self.index())
283            .finish()
284    }
285}
286
287impl std::fmt::Debug for IncomingPort {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        f.debug_tuple("IncomingPort").field(&self.index).finish()
290    }
291}
292
293impl std::fmt::Debug for OutgoingPort {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        f.debug_tuple("OutgoingPort").field(&self.index).finish()
296    }
297}
298
299impl std::fmt::Debug for Wire {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.debug_struct("Wire")
302            .field("node", &self.0.index())
303            .field("port", &self.1)
304            .finish()
305    }
306}
307
308impl std::fmt::Debug for CircuitUnit {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        match self {
311            Self::Wire(w) => f
312                .debug_struct("WireUnit")
313                .field("node", &w.0.index())
314                .field("port", &w.1)
315                .finish(),
316            Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
317        }
318    }
319}
320
/// Implements `std::fmt::Display` for every listed type by forwarding to that
/// type's `std::fmt::Debug` implementation, so both render identically.
macro_rules! impl_display_from_debug {
    ($($target:ty),*) => {
        $(
            impl std::fmt::Display for $target {
                fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    std::fmt::Debug::fmt(self, formatter)
                }
            }
        )*
    };
}
332impl_display_from_debug!(Node, Port, IncomingPort, OutgoingPort, CircuitUnit);