hugr_model/v0/table/
mod.rs

//! Table representation of hugr modules.
//!
//! Instead of directly nesting data structures, we store them in tables and
//! refer to them by their id in the table. Variables, symbols and links are
//! fully resolved: uses refer to the id of the declaration. This allows the
//! table representation to be read from the [binary format] and imported into
//! the core data structures without having to perform potentially costly name
//! resolutions.
//!
//! The tabling is also used for deduplication of terms. In practice, many terms
//! will share the same subterms, and we can save memory and validation time by
//! storing them only once. However, we allow non-deduplicated terms for cases in
//! which terms carry additional identity over just their structure. For
//! instance, structurally identical terms could originate from different
//! locations in a text file and therefore should be treated differently when
//! locating type errors.
//!
//! This format is intended to be used as an intermediary data structure to
//! convert between different representations (such as the [binary format], the
//! [text format] or internal compiler data structures). To make this efficient,
//! we use arena allocation via the [`bumpalo`] crate, which keeps both
//! constructing and tearing down this representation cheap. The data structures
//! in this module therefore carry a lifetime parameter that indicates the
//! lifetime of the arena.
//!
//! [binary format]: crate::v0::binary
//! [text format]: crate::v0::ast
27
28use smol_str::SmolStr;
29use thiserror::Error;
30
31mod view;
32use super::{ast, Literal, RegionKind};
33pub use view::View;
34
35/// A module consisting of a hugr graph together with terms.
36///
37/// See [`ast::Module`] for the AST representation.
38///
39/// [`ast::Module`]: crate::v0::ast::Module
40#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
41pub struct Module<'a> {
42    /// The id of the root region.
43    pub root: RegionId,
44    /// Table of [`Node`]s.
45    pub nodes: Vec<Node<'a>>,
46    /// Table of [`Region`]s.
47    pub regions: Vec<Region<'a>>,
48    /// Table of [`Term`]s.
49    pub terms: Vec<Term<'a>>,
50}
51
52impl<'a> Module<'a> {
53    /// Return the node data for a given node id.
54    #[inline]
55    pub fn get_node(&self, node_id: NodeId) -> Option<&Node<'a>> {
56        self.nodes.get(node_id.index())
57    }
58
59    /// Return a mutable reference to the node data for a given node id.
60    #[inline]
61    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut Node<'a>> {
62        self.nodes.get_mut(node_id.index())
63    }
64
65    /// Insert a new node into the module and return its id.
66    pub fn insert_node(&mut self, node: Node<'a>) -> NodeId {
67        let id = NodeId::new(self.nodes.len());
68        self.nodes.push(node);
69        id
70    }
71
72    /// Return the term data for a given term id.
73    #[inline]
74    pub fn get_term(&self, term_id: TermId) -> Option<&Term<'a>> {
75        self.terms.get(term_id.index())
76    }
77
78    /// Return a mutable reference to the term data for a given term id.
79    #[inline]
80    pub fn get_term_mut(&mut self, term_id: TermId) -> Option<&mut Term<'a>> {
81        self.terms.get_mut(term_id.index())
82    }
83
84    /// Insert a new term into the module and return its id.
85    pub fn insert_term(&mut self, term: Term<'a>) -> TermId {
86        let id = TermId::new(self.terms.len());
87        self.terms.push(term);
88        id
89    }
90
91    /// Return the region data for a given region id.
92    #[inline]
93    pub fn get_region(&self, region_id: RegionId) -> Option<&Region<'a>> {
94        self.regions.get(region_id.index())
95    }
96
97    /// Return a mutable reference to the region data for a given region id.
98    #[inline]
99    pub fn get_region_mut(&mut self, region_id: RegionId) -> Option<&mut Region<'a>> {
100        self.regions.get_mut(region_id.index())
101    }
102
103    /// Insert a new region into the module and return its id.
104    pub fn insert_region(&mut self, region: Region<'a>) -> RegionId {
105        let id = RegionId::new(self.regions.len());
106        self.regions.push(region);
107        id
108    }
109
110    /// Attempt to view a part of this module via a [`View`] instance.
111    pub fn view<S, V: View<'a, S>>(&'a self, src: S) -> Option<V> {
112        V::view(self, src)
113    }
114
115    /// Convert the module to the [ast] representation.
116    ///
117    /// [ast]: crate::v0::ast
118    pub fn as_ast(&self) -> Option<ast::Module> {
119        let root = self.view(self.root)?;
120        Some(ast::Module { root })
121    }
122}
123
124/// Nodes in the hugr graph.
125///
126/// See [`ast::Node`] for the AST representation.
127///
128/// [`ast::Node`]: crate::v0::ast::Node
129#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
130pub struct Node<'a> {
131    /// The operation that the node performs.
132    pub operation: Operation<'a>,
133    /// The input ports of the node.
134    pub inputs: &'a [LinkIndex],
135    /// The output ports of the node.
136    pub outputs: &'a [LinkIndex],
137    /// The regions of the node.
138    pub regions: &'a [RegionId],
139    /// The meta information attached to the node.
140    pub meta: &'a [TermId],
141    /// The signature of the node.
142    ///
143    /// Can be `None` to indicate that the node's signature should be inferred,
144    /// or for nodes with operations that do not have a signature.
145    pub signature: Option<TermId>,
146}
147
148/// Operations that nodes can perform.
149///
150/// See [`ast::Operation`] for the AST representation.
151///
152/// [`ast::Operation`]: crate::v0::ast::Operation
153#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
154pub enum Operation<'a> {
155    /// Invalid operation to be used as a placeholder.
156    /// This is useful for modules that have non-contiguous node ids, or modules
157    /// that have not yet been fully constructed.
158    #[default]
159    Invalid,
160    /// Data flow graphs.
161    Dfg,
162    /// Control flow graphs.
163    Cfg,
164    /// Basic blocks in a control flow graph.
165    Block,
166    /// Function definitions.
167    DefineFunc(&'a Symbol<'a>),
168    /// Function declarations.
169    DeclareFunc(&'a Symbol<'a>),
170    /// Custom operation.
171    Custom(TermId),
172    /// Alias definitions.
173    DefineAlias(&'a Symbol<'a>, TermId),
174    /// Alias declarations.
175    DeclareAlias(&'a Symbol<'a>),
176    /// Tail controlled loop.
177    /// Nodes with this operation contain a dataflow graph that is executed in a loop.
178    /// The loop body is executed at least once, producing a result that indicates whether
179    /// to continue the loop or return the result.
180    ///
181    /// # Port Types
182    ///
183    /// - **Inputs**: `inputs` + `rest`
184    /// - **Outputs**: `outputs` + `rest`
185    /// - **Sources**: `inputs` + `rest`
186    /// - **Targets**: `(adt [inputs outputs])` + `rest`
187    TailLoop,
188
189    /// Conditional operation.
190    ///
191    /// # Port types
192    ///
193    /// - **Inputs**: `[(adt inputs)]` + `context`
194    /// - **Outputs**: `outputs`
195    Conditional,
196
197    /// Declaration for a term constructor.
198    ///
199    /// Nodes with this operation must be within a module region.
200    DeclareConstructor(&'a Symbol<'a>),
201
202    /// Declaration for a operation.
203    ///
204    /// Nodes with this operation must be within a module region.
205    DeclareOperation(&'a Symbol<'a>),
206
207    /// Import a symbol.
208    Import {
209        /// The name of the symbol to be imported.
210        name: &'a str,
211    },
212}
213
214impl<'a> Operation<'a> {
215    /// Returns the symbol introduced by the operation, if any.
216    pub fn symbol(&self) -> Option<&'a str> {
217        match self {
218            Operation::DefineFunc(symbol) => Some(symbol.name),
219            Operation::DeclareFunc(symbol) => Some(symbol.name),
220            Operation::DefineAlias(symbol, _) => Some(symbol.name),
221            Operation::DeclareAlias(symbol) => Some(symbol.name),
222            Operation::DeclareConstructor(symbol) => Some(symbol.name),
223            Operation::DeclareOperation(symbol) => Some(symbol.name),
224            Operation::Import { name } => Some(name),
225            _ => None,
226        }
227    }
228}
229
230/// A region in the hugr.
231///
232/// See [`ast::Region`] for the AST representation.
233///
234/// [`ast::Region`]: crate::v0::ast::Region
235#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
236pub struct Region<'a> {
237    /// The kind of the region. See [`RegionKind`] for details.
238    pub kind: RegionKind,
239    /// The source ports of the region.
240    pub sources: &'a [LinkIndex],
241    /// The target ports of the region.
242    pub targets: &'a [LinkIndex],
243    /// The nodes in the region. The order of the nodes is not significant.
244    pub children: &'a [NodeId],
245    /// The metadata attached to the region.
246    pub meta: &'a [TermId],
247    /// The signature of the region.
248    pub signature: Option<TermId>,
249    /// Information about the scope defined by this region, if the region is closed.
250    pub scope: Option<RegionScope>,
251}
252
/// Describes the scope that a closed region defines.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RegionScope {
    /// How many links the scope contains.
    pub links: u32,
    /// How many ports the scope contains.
    pub ports: u32,
}
261
262/// A symbol.
263///
264/// See [`ast::Symbol`] for the AST representation.
265///
266/// [`ast::Symbol`]: crate::v0::ast::Symbol
267#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
268pub struct Symbol<'a> {
269    /// The name of the symbol.
270    pub name: &'a str,
271    /// The static parameters.
272    pub params: &'a [Param<'a>],
273    /// The constraints on the static parameters.
274    pub constraints: &'a [TermId],
275    /// The signature of the symbol.
276    pub signature: TermId,
277}
278
/// Position of a variable in the parameter list of a node.
pub type VarIndex = u16;
281
282/// A term in the compile time meta language.
283///
284/// See [`ast::Term`] for the AST representation.
285///
286/// [`ast::Term`]: crate::v0::ast::Term
287#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
288pub enum Term<'a> {
289    /// Standin for any term.
290    #[default]
291    Wildcard,
292
293    /// A local variable.
294    Var(VarId),
295
296    /// Apply a symbol to a sequence of arguments.
297    ///
298    /// The symbol is defined by a node in the same graph. The type of this term
299    /// is derived from instantiating the symbol's parameters in the symbol's
300    /// signature.
301    Apply(NodeId, &'a [TermId]),
302
303    /// List of static data.
304    ///
305    /// Lists can include individual items or other lists to be spliced in.
306    ///
307    /// **Type:** `(core.list ?t)`
308    List(&'a [SeqPart]),
309
310    /// A static literal value.
311    Literal(Literal),
312
313    /// Extension set.
314    ///
315    /// **Type:** `core.ext_set`
316    ExtSet(&'a [ExtSetPart<'a>]),
317
318    /// A constant anonymous function.
319    ///
320    /// **Type:** `(core.const (core.fn ?ins ?outs ?ext) (ext))`
321    ConstFunc(RegionId),
322
323    /// Tuple of static data.
324    ///
325    /// Tuples can include individual items or other tuples to be spliced in.
326    ///
327    /// **Type:** `(core.tuple ?types)`
328    Tuple(&'a [SeqPart]),
329}
330
331impl From<Literal> for Term<'_> {
332    fn from(value: Literal) -> Self {
333        Self::Literal(value)
334    }
335}
336
337/// A part of a list/tuple term.
338///
339/// See [`ast::SeqPart`] for the AST representation.
340///
341/// [`ast::SeqPart`]: crate::v0::ast::SeqPart
342#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
343pub enum SeqPart {
344    /// A single item.
345    Item(TermId),
346    /// A list to be spliced into the parent list/tuple.
347    Splice(TermId),
348}
349
350/// A part of an extension set term.
351#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
352pub enum ExtSetPart<'a> {
353    /// An extension.
354    Extension(&'a str),
355    /// An extension set to be spliced into the parent extension set.
356    Splice(TermId),
357}
358
359/// A parameter to a function or alias.
360///
361/// Parameter names must be unique within a parameter list.
362///
363/// See [`ast::Param`] for the AST representation.
364///
365/// [`ast::Param`]: crate::v0::ast::Param
366#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
367pub struct Param<'a> {
368    /// The name of the parameter.
369    pub name: &'a str,
370    /// The type of the parameter.
371    pub r#type: TermId,
372}
373
/// Defines a `u32` newtype that serves as an index/id type.
///
/// The generated type is `#[repr(transparent)]` and receives the attributes
/// passed to the macro, together with `new`, `index`, `unwrap_slice` and
/// `wrap_slice` helpers.
macro_rules! define_index {
    ($(#[$meta:meta])* $vis:vis struct $name:ident(pub u32);) => {
        #[repr(transparent)]
        $(#[$meta])*
        $vis struct $name(pub u32);

        impl $name {
            /// Create a new index.
            ///
            /// # Panics
            ///
            /// Panics if the index does not fit into a `u32`, i.e. if it is
            /// 2^32 or larger.
            pub fn new(index: usize) -> Self {
                // A checked conversion accepts the full `u32` range, including
                // `u32::MAX`, and rejects exactly the values that would be
                // truncated by a cast.
                match <u32 as ::std::convert::TryFrom<usize>>::try_from(index) {
                    Ok(raw) => Self(raw),
                    Err(_) => panic!("index out of bounds"),
                }
            }

            /// Returns the index as a `usize` to conveniently use it as a slice index.
            #[inline]
            pub fn index(self) -> usize {
                self.0 as usize
            }

            /// Convert a slice of this index type into a slice of `u32`s.
            pub fn unwrap_slice(slice: &[Self]) -> &[u32] {
                // SAFETY: The type is declared `#[repr(transparent)]` over a single
                // `u32` field, so it has the same size and alignment as `u32`. The
                // pointer and length come from a valid slice, and the returned slice
                // borrows for the same lifetime.
                unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u32, slice.len()) }
            }

            /// Convert a slice of `u32`s into a slice of this index type.
            pub fn wrap_slice(slice: &[u32]) -> &[Self] {
                // SAFETY: The type is declared `#[repr(transparent)]` over a single
                // `u32` field, so every `u32` is a valid value of it with identical
                // size and alignment. The pointer and length come from a valid slice,
                // and the returned slice borrows for the same lifetime.
                unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) }
            }
        }
    };
}
411
412define_index! {
413    /// Id of a node in a hugr graph.
414    #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
415    pub struct NodeId(pub u32);
416}
417
418define_index! {
419    /// Index of a link in a hugr graph.
420    #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
421    pub struct LinkIndex(pub u32);
422}
423
424define_index! {
425    /// Id of a region in a hugr graph.
426    #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
427    pub struct RegionId(pub u32);
428}
429
430define_index! {
431    /// Id of a term in a hugr graph.
432    #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
433    pub struct TermId(pub u32);
434}
435
436/// The id of a link consisting of its region and the link index.
437#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
438#[display("{_0}#{_1}")]
439pub struct LinkId(pub RegionId, pub LinkIndex);
440
441/// The id of a variable consisting of its node and the variable index.
442#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
443#[display("{_0}#{_1}")]
444pub struct VarId(pub NodeId, pub VarIndex);
445
446/// Errors that can occur when traversing and interpreting the model.
447#[derive(Debug, Clone, Error)]
448pub enum ModelError {
449    /// There is a reference to a node that does not exist.
450    #[error("node not found: {0}")]
451    NodeNotFound(NodeId),
452    /// There is a reference to a term that does not exist.
453    #[error("term not found: {0}")]
454    TermNotFound(TermId),
455    /// There is a reference to a region that does not exist.
456    #[error("region not found: {0}")]
457    RegionNotFound(RegionId),
458    /// Invalid variable reference.
459    #[error("variable {0} invalid")]
460    InvalidVar(VarId),
461    /// Invalid symbol reference.
462    #[error("symbol reference {0} invalid")]
463    InvalidSymbol(NodeId),
464    /// The model contains an operation in a place where it is not allowed.
465    #[error("unexpected operation on node: {0}")]
466    UnexpectedOperation(NodeId),
467    /// There is a term that is not well-typed.
468    #[error("type error in term: {0}")]
469    TypeError(TermId),
470    /// There is a node whose regions are not well-formed according to the node's operation.
471    #[error("node has invalid regions: {0}")]
472    InvalidRegions(NodeId),
473    /// There is a name that is not well-formed.
474    #[error("malformed name: {0}")]
475    MalformedName(SmolStr),
476    /// There is a condition node that lacks a case for a tag or
477    /// defines two cases for the same tag.
478    #[error("condition node is malformed: {0}")]
479    MalformedCondition(NodeId),
480    /// There is a node that is not well-formed or has the invalid operation.
481    #[error("invalid operation on node: {0}")]
482    InvalidOperation(NodeId),
483}