hugr_core/
core.rs

1//! Definitions for the core types used in the Hugr.
2//!
3//! These types are re-exported in the root of the crate.
4
5pub use itertools::Either;
6
7use derive_more::From;
8use itertools::Either::{Left, Right};
9
10use crate::{HugrView, hugr::HugrError};
11
12/// A handle to a node in the HUGR.
13#[derive(
14    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From, serde::Serialize, serde::Deserialize,
15)]
16#[serde(transparent)]
17pub struct Node {
18    index: portgraph::NodeIndex,
19}
20
21/// A handle to a port for a node in the HUGR.
22#[derive(
23    Clone,
24    Copy,
25    PartialEq,
26    PartialOrd,
27    Eq,
28    Ord,
29    Hash,
30    Default,
31    From,
32    serde::Serialize,
33    serde::Deserialize,
34)]
35#[serde(transparent)]
36pub struct Port {
37    offset: portgraph::PortOffset<u32>,
38}
39
/// A trait for getting the undirected index of a port.
///
/// "Undirected" means the returned value is only the port's offset: it does
/// not say whether the port is incoming or outgoing.
pub trait PortIndex {
    /// Returns the offset of the port.
    fn index(self) -> usize;
}
45
/// A trait for getting the index of a node.
pub trait NodeIndex {
    /// Returns the index of the node as a plain `usize`.
    fn index(self) -> usize;
}
51
/// A trait for nodes in the Hugr.
///
/// This is a pure bundle of bounds: every type that is `Copy`, `Ord`, `Hash`,
/// and printable with both `Debug` and `Display` is a `HugrNode` through the
/// blanket implementation that follows.
pub trait HugrNode: Copy + Ord + std::hash::Hash + std::fmt::Debug + std::fmt::Display {}

impl<T> HugrNode for T where
    T: Copy + Ord + std::hash::Hash + std::fmt::Debug + std::fmt::Display
{
}
56
57/// A port in the incoming direction.
58#[derive(
59    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
60)]
61pub struct IncomingPort {
62    index: u16,
63}
64
65/// A port in the outgoing direction.
66#[derive(
67    Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, serde::Serialize, serde::Deserialize,
68)]
69pub struct OutgoingPort {
70    index: u16,
71}
72
/// The direction of a port.
///
/// This is an alias for `portgraph`'s direction type, so its variants
/// (`Direction::Incoming`, `Direction::Outgoing`) can be used directly.
pub type Direction = portgraph::Direction;
75
76#[derive(
77    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
78)]
79/// A `DataFlow` wire, defined by a Value-kind output port of a node
80// Stores node and offset to output port
81pub struct Wire<N = Node>(N, OutgoingPort);
82
83impl Node {
84    /// Returns the node as a portgraph `NodeIndex`.
85    #[inline]
86    pub(crate) fn into_portgraph(self) -> portgraph::NodeIndex {
87        self.index
88    }
89}
90
91impl Port {
92    /// Creates a new port.
93    #[inline]
94    #[must_use]
95    pub fn new(direction: Direction, port: usize) -> Self {
96        Self {
97            offset: portgraph::PortOffset::new(direction, port),
98        }
99    }
100
101    /// Converts to an [`IncomingPort`] if this port is one; else fails with
102    /// [`HugrError::InvalidPortDirection`]
103    #[inline]
104    pub fn as_incoming(&self) -> Result<IncomingPort, HugrError> {
105        self.as_directed()
106            .left()
107            .ok_or(HugrError::InvalidPortDirection(self.direction()))
108    }
109
110    /// Converts to an [`OutgoingPort`] if this port is one; else fails with
111    /// [`HugrError::InvalidPortDirection`]
112    #[inline]
113    pub fn as_outgoing(&self) -> Result<OutgoingPort, HugrError> {
114        self.as_directed()
115            .right()
116            .ok_or(HugrError::InvalidPortDirection(self.direction()))
117    }
118
119    /// Converts to either an [`IncomingPort`] or an [`OutgoingPort`], as appropriate.
120    #[inline]
121    #[must_use]
122    pub fn as_directed(&self) -> Either<IncomingPort, OutgoingPort> {
123        match self.direction() {
124            Direction::Incoming => Left(IncomingPort {
125                index: self.index() as u16,
126            }),
127            Direction::Outgoing => Right(OutgoingPort {
128                index: self.index() as u16,
129            }),
130        }
131    }
132
133    /// Returns the direction of the port.
134    #[inline]
135    #[must_use]
136    pub fn direction(self) -> Direction {
137        self.offset.direction()
138    }
139
140    /// Returns the port as a portgraph `PortOffset`.
141    #[inline]
142    pub(crate) fn pg_offset(self) -> portgraph::PortOffset<u32> {
143        self.offset
144    }
145}
146
147impl PortIndex for Port {
148    #[inline(always)]
149    fn index(self) -> usize {
150        self.offset.index()
151    }
152}
153
154impl PortIndex for usize {
155    #[inline(always)]
156    fn index(self) -> usize {
157        self
158    }
159}
160
161impl PortIndex for IncomingPort {
162    #[inline(always)]
163    fn index(self) -> usize {
164        self.index as usize
165    }
166}
167
168impl PortIndex for OutgoingPort {
169    #[inline(always)]
170    fn index(self) -> usize {
171        self.index as usize
172    }
173}
174
175impl From<usize> for IncomingPort {
176    #[inline(always)]
177    fn from(index: usize) -> Self {
178        Self {
179            index: index as u16,
180        }
181    }
182}
183
184impl From<usize> for OutgoingPort {
185    #[inline(always)]
186    fn from(index: usize) -> Self {
187        Self {
188            index: index as u16,
189        }
190    }
191}
192
193impl From<IncomingPort> for Port {
194    fn from(value: IncomingPort) -> Self {
195        Self {
196            offset: portgraph::PortOffset::new_incoming(value.index()),
197        }
198    }
199}
200
201impl From<OutgoingPort> for Port {
202    fn from(value: OutgoingPort) -> Self {
203        Self {
204            offset: portgraph::PortOffset::new_outgoing(value.index()),
205        }
206    }
207}
208
209impl NodeIndex for Node {
210    fn index(self) -> usize {
211        self.index.into()
212    }
213}
214
215impl<N: HugrNode> Wire<N> {
216    /// Create a new wire from a node and a port.
217    #[inline]
218    pub fn new(node: N, port: impl Into<OutgoingPort>) -> Self {
219        Self(node, port.into())
220    }
221
222    /// Create a new wire from a node and a port that is connected to the wire.
223    ///
224    /// If `port` is an incoming port, the wire is traversed to find the unique
225    /// outgoing port that is connected to the wire. Otherwise, this is
226    /// equivalent to constructing a wire using [`Wire::new`].
227    ///
228    /// ## Panics
229    ///
230    /// This will panic if the wire is not connected to a unique outgoing port.
231    #[inline]
232    pub fn from_connected_port(
233        node: N,
234        port: impl Into<Port>,
235        hugr: &impl HugrView<Node = N>,
236    ) -> Self {
237        let (node, outgoing) = match port.into().as_directed() {
238            Either::Left(incoming) => hugr
239                .single_linked_output(node, incoming)
240                .expect("invalid dfg port"),
241            Either::Right(outgoing) => (node, outgoing),
242        };
243        Self::new(node, outgoing)
244    }
245
246    /// The node of the unique outgoing port that the wire is connected to.
247    #[inline]
248    pub fn node(&self) -> N {
249        self.0
250    }
251
252    /// The unique outgoing port that the wire is connected to.
253    #[inline]
254    pub fn source(&self) -> OutgoingPort {
255        self.1
256    }
257
258    /// Get all ports connected to the wire.
259    ///
260    /// Return a chained iterator of the unique outgoing port, followed by all
261    /// incoming ports connected to the wire.
262    pub fn all_connected_ports<'h, H: HugrView<Node = N>>(
263        &self,
264        hugr: &'h H,
265    ) -> impl Iterator<Item = (N, Port)> + use<'h, N, H> {
266        let node = self.node();
267        let out_port = self.source();
268
269        std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port))
270    }
271}
272
273impl<N: HugrNode> std::fmt::Display for Wire<N> {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        write!(f, "Wire({}, {})", self.0, self.1.index)
276    }
277}
278
279/// Marks [FuncDefn](crate::ops::FuncDefn)s and [FuncDecl](crate::ops::FuncDecl)s as
280/// to whether they should be considered for linking.
281#[derive(
282    Clone,
283    Debug,
284    derive_more::Display,
285    PartialEq,
286    Eq,
287    PartialOrd,
288    Ord,
289    serde::Serialize,
290    serde::Deserialize,
291)]
292#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
293#[non_exhaustive]
294pub enum Visibility {
295    /// Function is visible or exported
296    Public,
297    /// Function is hidden, for use within the hugr only
298    Private,
299}
300
301impl From<hugr_model::v0::Visibility> for Visibility {
302    fn from(value: hugr_model::v0::Visibility) -> Self {
303        match value {
304            hugr_model::v0::Visibility::Private => Self::Private,
305            hugr_model::v0::Visibility::Public => Self::Public,
306        }
307    }
308}
309
310impl From<Visibility> for hugr_model::v0::Visibility {
311    fn from(value: Visibility) -> Self {
312        match value {
313            Visibility::Public => hugr_model::v0::Visibility::Public,
314            Visibility::Private => hugr_model::v0::Visibility::Private,
315        }
316    }
317}
318
319/// Enum for uniquely identifying the origin of linear wires in a circuit-like
320/// dataflow region.
321///
322/// Falls back to [`Wire`] if the wire is not linear or if it's not possible to
323/// track the origin.
324#[derive(
325    Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
326)]
327pub enum CircuitUnit<N = Node> {
328    /// Arbitrary input wire.
329    Wire(Wire<N>),
330    /// Index to region input.
331    Linear(usize),
332}
333
334impl CircuitUnit {
335    /// Check if this is a wire.
336    #[must_use]
337    pub fn is_wire(&self) -> bool {
338        matches!(self, CircuitUnit::Wire(_))
339    }
340
341    /// Check if this is a linear unit.
342    #[must_use]
343    pub fn is_linear(&self) -> bool {
344        matches!(self, CircuitUnit::Linear(_))
345    }
346}
347
348impl From<usize> for CircuitUnit {
349    fn from(value: usize) -> Self {
350        CircuitUnit::Linear(value)
351    }
352}
353
354impl From<Wire> for CircuitUnit {
355    fn from(value: Wire) -> Self {
356        CircuitUnit::Wire(value)
357    }
358}
359
360impl std::fmt::Debug for Node {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        f.debug_tuple("Node").field(&self.index()).finish()
363    }
364}
365
366impl std::fmt::Debug for Port {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        f.debug_tuple("Port")
369            .field(&self.offset.direction())
370            .field(&self.index())
371            .finish()
372    }
373}
374
375impl std::fmt::Debug for IncomingPort {
376    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        f.debug_tuple("IncomingPort").field(&self.index).finish()
378    }
379}
380
381impl std::fmt::Debug for OutgoingPort {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        f.debug_tuple("OutgoingPort").field(&self.index).finish()
384    }
385}
386
387impl<N: HugrNode> std::fmt::Debug for Wire<N> {
388    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
389        f.debug_struct("Wire")
390            .field("node", &self.0)
391            .field("port", &self.1)
392            .finish()
393    }
394}
395
396impl std::fmt::Debug for CircuitUnit {
397    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        match self {
399            Self::Wire(w) => f
400                .debug_struct("WireUnit")
401                .field("node", &w.0.index())
402                .field("port", &w.1)
403                .finish(),
404            Self::Linear(id) => f.debug_tuple("LinearUnit").field(id).finish(),
405        }
406    }
407}
408
/// Implements `Display` for each listed type by forwarding to its `Debug`
/// implementation, including any formatter flags.
macro_rules! impl_display_from_debug {
    ($($ty:ty),*) => {
        $(
            impl ::std::fmt::Display for $ty {
                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                    ::std::fmt::Debug::fmt(self, f)
                }
            }
        )*
    };
}
// Each of these handle types prints the same text through `Display` as it does
// through its hand-written `Debug` implementation.
impl_display_from_debug!(Node, Port, IncomingPort, OutgoingPort, CircuitUnit);