hugr_model/v0/binary/read.rs

1use crate::capnp::hugr_v0_capnp as hugr_capnp;
2use crate::v0 as model;
3use crate::v0::table;
4use bumpalo::Bump;
5use bumpalo::collections::Vec as BumpVec;
6use std::io::BufRead;
7
8/// An error encountered while deserialising a model.
9#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
10#[non_exhaustive]
11pub enum ReadError {
12    #[from(forward)]
13    /// An error encountered while decoding a model from a `capnproto` buffer.
14    DecodingError(capnp::Error),
15}
16
17type ReadResult<T> = Result<T, ReadError>;
18
19/// Read a hugr package from a byte slice.
20pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Package<'a>> {
21    read_from_reader(slice, bump)
22}
23
24/// Read a hugr package from an impl of [`BufRead`].
25pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Package<'_>> {
26    let reader =
27        capnp::serialize_packed::read_message(reader, capnp::message::ReaderOptions::new())?;
28    let root = reader.get_root::<hugr_capnp::package::Reader>()?;
29    read_package(bump, root)
30}
31
/// Reads a `capnproto` list of structs into a slice allocated in a bump arena.
///
/// - `$bump`: the `Bump` arena that backs the resulting slice.
/// - `$reader`: a list reader; it is evaluated exactly once.
/// - `$read`: a callable invoked as `$read($bump, item_reader)` that returns a
///   `Result`; the first error is propagated with `?` out of the enclosing
///   function.
///
/// Evaluates to a slice borrowed from the arena, in list order.
macro_rules! read_list {
    ($bump:expr, $reader:expr, $read:expr) => {{
        let __items = $reader;
        let mut __out = BumpVec::with_capacity_in(__items.len() as _, $bump);
        for __item in __items.iter() {
            let __value = $read($bump, __item)?;
            __out.push(__value);
        }
        __out.into_bump_slice()
    }};
}
43
/// Reads a `capnproto` list of scalars into a slice allocated in a bump arena.
///
/// - `$bump`: the `Bump` arena that backs the resulting slice.
/// - `$reader`: the struct reader that owns the list field.
/// - `$get`: the getter on `$reader` that returns the list; its error is
///   propagated with `?` out of the enclosing function.
/// - `$wrap`: a path (e.g. a tuple-struct constructor) applied to each element.
///
/// Evaluates to a slice borrowed from the arena, in list order.
macro_rules! read_scalar_list {
    ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{
        let __items = $reader.$get()?;
        let mut __out = BumpVec::with_capacity_in(__items.len() as _, $bump);
        __out.extend(__items.iter().map($wrap));
        __out.into_bump_slice()
    }};
}
55
56fn read_package<'a>(
57    bump: &'a Bump,
58    reader: hugr_capnp::package::Reader,
59) -> ReadResult<table::Package<'a>> {
60    let modules = reader
61        .get_modules()?
62        .iter()
63        .map(|m| read_module(bump, m))
64        .collect::<ReadResult<_>>()?;
65
66    Ok(table::Package { modules })
67}
68
69fn read_module<'a>(
70    bump: &'a Bump,
71    reader: hugr_capnp::module::Reader,
72) -> ReadResult<table::Module<'a>> {
73    let root = table::RegionId(reader.get_root());
74
75    let nodes = reader
76        .get_nodes()?
77        .iter()
78        .map(|r| read_node(bump, r))
79        .collect::<ReadResult<_>>()?;
80
81    let regions = reader
82        .get_regions()?
83        .iter()
84        .map(|r| read_region(bump, r))
85        .collect::<ReadResult<_>>()?;
86
87    let terms = reader
88        .get_terms()?
89        .iter()
90        .map(|r| read_term(bump, r))
91        .collect::<ReadResult<_>>()?;
92
93    Ok(table::Module {
94        root,
95        nodes,
96        regions,
97        terms,
98    })
99}
100
101fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<table::Node<'a>> {
102    let operation = read_operation(bump, reader.get_operation()?)?;
103    let inputs = read_scalar_list!(bump, reader, get_inputs, table::LinkIndex);
104    let outputs = read_scalar_list!(bump, reader, get_outputs, table::LinkIndex);
105    let regions = read_scalar_list!(bump, reader, get_regions, table::RegionId);
106    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
107    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
108
109    Ok(table::Node {
110        operation,
111        inputs,
112        outputs,
113        regions,
114        meta,
115        signature,
116    })
117}
118
119fn read_operation<'a>(
120    bump: &'a Bump,
121    reader: hugr_capnp::operation::Reader,
122) -> ReadResult<table::Operation<'a>> {
123    use hugr_capnp::operation::Which;
124    Ok(match reader.which()? {
125        Which::Invalid(()) => table::Operation::Invalid,
126        Which::Dfg(()) => table::Operation::Dfg,
127        Which::Cfg(()) => table::Operation::Cfg,
128        Which::Block(()) => table::Operation::Block,
129        Which::FuncDefn(reader) => table::Operation::DefineFunc(read_symbol(bump, reader?, None)?),
130        Which::FuncDecl(reader) => table::Operation::DeclareFunc(read_symbol(bump, reader?, None)?),
131        Which::AliasDefn(reader) => {
132            let symbol = reader.get_symbol()?;
133            let value = table::TermId(reader.get_value());
134            table::Operation::DefineAlias(read_symbol(bump, symbol, Some(&[]))?, value)
135        }
136        Which::AliasDecl(reader) => {
137            table::Operation::DeclareAlias(read_symbol(bump, reader?, Some(&[]))?)
138        }
139        Which::ConstructorDecl(reader) => {
140            table::Operation::DeclareConstructor(read_symbol(bump, reader?, None)?)
141        }
142        Which::OperationDecl(reader) => {
143            table::Operation::DeclareOperation(read_symbol(bump, reader?, None)?)
144        }
145        Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)),
146        Which::TailLoop(()) => table::Operation::TailLoop,
147        Which::Conditional(()) => table::Operation::Conditional,
148        Which::Import(name) => table::Operation::Import {
149            name: bump.alloc_str(name?.to_str()?),
150        },
151    })
152}
153
154fn read_region<'a>(
155    bump: &'a Bump,
156    reader: hugr_capnp::region::Reader,
157) -> ReadResult<table::Region<'a>> {
158    let kind = match reader.get_kind()? {
159        hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow,
160        hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow,
161        hugr_capnp::RegionKind::Module => model::RegionKind::Module,
162    };
163
164    let sources = read_scalar_list!(bump, reader, get_sources, table::LinkIndex);
165    let targets = read_scalar_list!(bump, reader, get_targets, table::LinkIndex);
166    let children = read_scalar_list!(bump, reader, get_children, table::NodeId);
167    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
168    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
169
170    let scope = if reader.has_scope() {
171        Some(read_region_scope(reader.get_scope()?)?)
172    } else {
173        None
174    };
175
176    Ok(table::Region {
177        kind,
178        sources,
179        targets,
180        children,
181        meta,
182        signature,
183        scope,
184    })
185}
186
187fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<table::RegionScope> {
188    let links = reader.get_links();
189    let ports = reader.get_ports();
190    Ok(table::RegionScope { links, ports })
191}
192
193impl From<hugr_capnp::Visibility> for Option<model::Visibility> {
194    fn from(value: hugr_capnp::Visibility) -> Self {
195        match value {
196            hugr_capnp::Visibility::Unspecified => None,
197            hugr_capnp::Visibility::Private => Some(model::Visibility::Private),
198            hugr_capnp::Visibility::Public => Some(model::Visibility::Public),
199        }
200    }
201}
202
203/// (Only) if `constraints` are None, then they are read from the `reader`
204fn read_symbol<'a>(
205    bump: &'a Bump,
206    reader: hugr_capnp::symbol::Reader,
207    constraints: Option<&'a [table::TermId]>,
208) -> ReadResult<&'a mut table::Symbol<'a>> {
209    let name = bump.alloc_str(reader.get_name()?.to_str()?);
210    let visibility = reader.get_visibility()?.into();
211    let visibility = bump.alloc(visibility);
212    let params = read_list!(bump, reader.get_params()?, read_param);
213    let constraints = match constraints {
214        Some(cs) => cs,
215        None => read_scalar_list!(bump, reader, get_constraints, table::TermId),
216    };
217    let signature = table::TermId(reader.get_signature());
218    Ok(bump.alloc(table::Symbol {
219        visibility,
220        name,
221        params,
222        constraints,
223        signature,
224    }))
225}
226
227fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<table::Term<'a>> {
228    use hugr_capnp::term::Which;
229    Ok(match reader.which()? {
230        Which::Wildcard(()) => table::Term::Wildcard,
231        Which::String(value) => table::Term::Literal(model::Literal::Str(value?.to_str()?.into())),
232        Which::Nat(value) => table::Term::Literal(model::Literal::Nat(value)),
233
234        Which::Variable(reader) => {
235            let node = table::NodeId(reader.get_node());
236            let index = reader.get_index();
237            table::Term::Var(table::VarId(node, index))
238        }
239
240        Which::Apply(reader) => {
241            let symbol = table::NodeId(reader.get_symbol());
242            let args = read_scalar_list!(bump, reader, get_args, table::TermId);
243            table::Term::Apply(symbol, args)
244        }
245
246        Which::List(reader) => {
247            let parts = read_list!(bump, reader?, read_seq_part);
248            table::Term::List(parts)
249        }
250
251        Which::Tuple(reader) => {
252            let parts = read_list!(bump, reader?, read_seq_part);
253            table::Term::Tuple(parts)
254        }
255
256        Which::Func(region) => table::Term::Func(table::RegionId(region)),
257
258        Which::Bytes(bytes) => table::Term::Literal(model::Literal::Bytes(bytes?.into())),
259        Which::Float(value) => table::Term::Literal(model::Literal::Float(value.into())),
260    })
261}
262
263fn read_seq_part(
264    _: &Bump,
265    reader: hugr_capnp::term::seq_part::Reader,
266) -> ReadResult<table::SeqPart> {
267    use hugr_capnp::term::seq_part::Which;
268    Ok(match reader.which()? {
269        Which::Item(term) => table::SeqPart::Item(table::TermId(term)),
270        Which::Splice(list) => table::SeqPart::Splice(table::TermId(list)),
271    })
272}
273
274fn read_param<'a>(
275    bump: &'a Bump,
276    reader: hugr_capnp::param::Reader,
277) -> ReadResult<table::Param<'a>> {
278    let name = bump.alloc_str(reader.get_name()?.to_str()?);
279    let r#type = table::TermId(reader.get_type());
280    Ok(table::Param { name, r#type })
281}