hugr_model/v0/binary/
read.rs

1use crate::capnp::hugr_v0_capnp as hugr_capnp;
2use crate::v0 as model;
3use crate::v0::table;
4use bumpalo::collections::Vec as BumpVec;
5use bumpalo::Bump;
6use std::io::BufRead;
7
8/// An error encounted while deserialising a model.
9#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
10#[non_exhaustive]
11pub enum ReadError {
12    #[from(forward)]
13    /// An error encounted while decoding a model from a `capnproto` buffer.
14    DecodingError(capnp::Error),
15}
16
17type ReadResult<T> = Result<T, ReadError>;
18
19/// Read a hugr module from a byte slice.
20pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Module<'a>> {
21    read_from_reader(slice, bump)
22}
23
24/// Read a hugr module from an impl of [BufRead].
25pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Module<'_>> {
26    let reader =
27        capnp::serialize_packed::read_message(reader, capnp::message::ReaderOptions::new())?;
28    let root = reader.get_root::<hugr_capnp::module::Reader>()?;
29    read_module(bump, root)
30}
31
/// Read a list of structs from a reader into a slice allocated through the bump allocator.
///
/// - `$bump`: the bump allocator. It is evaluated once for the vector and once per item, so
///   pass a plain variable.
/// - `$reader`: a list reader, evaluated exactly once.
/// - `$read`: called as `$read($bump, item_reader)` for every item. It must return a
///   `ReadResult`; the first error is propagated with `?` from the enclosing function.
macro_rules! read_list {
    ($bump:expr, $reader:expr, $read:expr) => {{
        // Only `len()` and `iter()` are called on the list reader, so the binding does not
        // need to be mutable.
        let __list_reader = $reader;
        let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump);
        for __item_reader in __list_reader.iter() {
            __list.push($read($bump, __item_reader)?);
        }
        __list.into_bump_slice()
    }};
}
43
/// Read a list of scalars from a reader into a slice allocated through the bump allocator.
///
/// - `$bump`: the bump allocator.
/// - `$reader`: the struct reader. The list is obtained by calling `$reader.$get()`, and a
///   failure there is propagated with `?` from the enclosing function.
/// - `$wrap`: applied to every scalar item, e.g. a newtype constructor such as
///   `table::TermId`.
macro_rules! read_scalar_list {
    ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{
        // Only `len()` and `iter()` are called on the list reader, so the binding does not
        // need to be mutable.
        let __list_reader = $reader.$get()?;
        let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump);
        for __item in __list_reader.iter() {
            __list.push($wrap(__item));
        }
        __list.into_bump_slice()
    }};
}
55
56fn read_module<'a>(
57    bump: &'a Bump,
58    reader: hugr_capnp::module::Reader,
59) -> ReadResult<table::Module<'a>> {
60    let root = table::RegionId(reader.get_root());
61
62    let nodes = reader
63        .get_nodes()?
64        .iter()
65        .map(|r| read_node(bump, r))
66        .collect::<ReadResult<_>>()?;
67
68    let regions = reader
69        .get_regions()?
70        .iter()
71        .map(|r| read_region(bump, r))
72        .collect::<ReadResult<_>>()?;
73
74    let terms = reader
75        .get_terms()?
76        .iter()
77        .map(|r| read_term(bump, r))
78        .collect::<ReadResult<_>>()?;
79
80    Ok(table::Module {
81        root,
82        nodes,
83        regions,
84        terms,
85    })
86}
87
88fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<table::Node<'a>> {
89    let operation = read_operation(bump, reader.get_operation()?)?;
90    let inputs = read_scalar_list!(bump, reader, get_inputs, table::LinkIndex);
91    let outputs = read_scalar_list!(bump, reader, get_outputs, table::LinkIndex);
92    let regions = read_scalar_list!(bump, reader, get_regions, table::RegionId);
93    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
94    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
95
96    Ok(table::Node {
97        operation,
98        inputs,
99        outputs,
100        regions,
101        meta,
102        signature,
103    })
104}
105
106fn read_operation<'a>(
107    bump: &'a Bump,
108    reader: hugr_capnp::operation::Reader,
109) -> ReadResult<table::Operation<'a>> {
110    use hugr_capnp::operation::Which;
111    Ok(match reader.which()? {
112        Which::Invalid(()) => table::Operation::Invalid,
113        Which::Dfg(()) => table::Operation::Dfg,
114        Which::Cfg(()) => table::Operation::Cfg,
115        Which::Block(()) => table::Operation::Block,
116        Which::FuncDefn(reader) => {
117            let reader = reader?;
118            let name = bump.alloc_str(reader.get_name()?.to_str()?);
119            let params = read_list!(bump, reader.get_params()?, read_param);
120            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
121            let signature = table::TermId(reader.get_signature());
122            let symbol = bump.alloc(table::Symbol {
123                name,
124                params,
125                constraints,
126                signature,
127            });
128            table::Operation::DefineFunc(symbol)
129        }
130        Which::FuncDecl(reader) => {
131            let reader = reader?;
132            let name = bump.alloc_str(reader.get_name()?.to_str()?);
133            let params = read_list!(bump, reader.get_params()?, read_param);
134            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
135            let signature = table::TermId(reader.get_signature());
136            let symbol = bump.alloc(table::Symbol {
137                name,
138                params,
139                constraints,
140                signature,
141            });
142            table::Operation::DeclareFunc(symbol)
143        }
144        Which::AliasDefn(reader) => {
145            let symbol = reader.get_symbol()?;
146            let value = table::TermId(reader.get_value());
147            let name = bump.alloc_str(symbol.get_name()?.to_str()?);
148            let params = read_list!(bump, symbol.get_params()?, read_param);
149            let signature = table::TermId(symbol.get_signature());
150            let symbol = bump.alloc(table::Symbol {
151                name,
152                params,
153                constraints: &[],
154                signature,
155            });
156            table::Operation::DefineAlias(symbol, value)
157        }
158        Which::AliasDecl(reader) => {
159            let reader = reader?;
160            let name = bump.alloc_str(reader.get_name()?.to_str()?);
161            let params = read_list!(bump, reader.get_params()?, read_param);
162            let signature = table::TermId(reader.get_signature());
163            let symbol = bump.alloc(table::Symbol {
164                name,
165                params,
166                constraints: &[],
167                signature,
168            });
169            table::Operation::DeclareAlias(symbol)
170        }
171        Which::ConstructorDecl(reader) => {
172            let reader = reader?;
173            let name = bump.alloc_str(reader.get_name()?.to_str()?);
174            let params = read_list!(bump, reader.get_params()?, read_param);
175            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
176            let signature = table::TermId(reader.get_signature());
177            let symbol = bump.alloc(table::Symbol {
178                name,
179                params,
180                constraints,
181                signature,
182            });
183            table::Operation::DeclareConstructor(symbol)
184        }
185        Which::OperationDecl(reader) => {
186            let reader = reader?;
187            let name = bump.alloc_str(reader.get_name()?.to_str()?);
188            let params = read_list!(bump, reader.get_params()?, read_param);
189            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
190            let signature = table::TermId(reader.get_signature());
191            let symbol = bump.alloc(table::Symbol {
192                name,
193                params,
194                constraints,
195                signature,
196            });
197            table::Operation::DeclareOperation(symbol)
198        }
199        Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)),
200        Which::TailLoop(()) => table::Operation::TailLoop,
201        Which::Conditional(()) => table::Operation::Conditional,
202        Which::Import(name) => table::Operation::Import {
203            name: bump.alloc_str(name?.to_str()?),
204        },
205    })
206}
207
208fn read_region<'a>(
209    bump: &'a Bump,
210    reader: hugr_capnp::region::Reader,
211) -> ReadResult<table::Region<'a>> {
212    let kind = match reader.get_kind()? {
213        hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow,
214        hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow,
215        hugr_capnp::RegionKind::Module => model::RegionKind::Module,
216    };
217
218    let sources = read_scalar_list!(bump, reader, get_sources, table::LinkIndex);
219    let targets = read_scalar_list!(bump, reader, get_targets, table::LinkIndex);
220    let children = read_scalar_list!(bump, reader, get_children, table::NodeId);
221    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
222    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
223
224    let scope = if reader.has_scope() {
225        Some(read_region_scope(reader.get_scope()?)?)
226    } else {
227        None
228    };
229
230    Ok(table::Region {
231        kind,
232        sources,
233        targets,
234        children,
235        meta,
236        signature,
237        scope,
238    })
239}
240
241fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<table::RegionScope> {
242    let links = reader.get_links();
243    let ports = reader.get_ports();
244    Ok(table::RegionScope { links, ports })
245}
246
247fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<table::Term<'a>> {
248    use hugr_capnp::term::Which;
249    Ok(match reader.which()? {
250        Which::Wildcard(()) => table::Term::Wildcard,
251        Which::String(value) => table::Term::Literal(model::Literal::Str(value?.to_str()?.into())),
252        Which::Nat(value) => table::Term::Literal(model::Literal::Nat(value)),
253
254        Which::Variable(reader) => {
255            let node = table::NodeId(reader.get_node());
256            let index = reader.get_index();
257            table::Term::Var(table::VarId(node, index))
258        }
259
260        Which::Apply(reader) => {
261            let symbol = table::NodeId(reader.get_symbol());
262            let args = read_scalar_list!(bump, reader, get_args, table::TermId);
263            table::Term::Apply(symbol, args)
264        }
265
266        Which::List(reader) => {
267            let parts = read_list!(bump, reader?, read_seq_part);
268            table::Term::List(parts)
269        }
270
271        Which::ExtSet(reader) => {
272            let parts = read_list!(bump, reader?, read_ext_set_part);
273            table::Term::ExtSet(parts)
274        }
275
276        Which::Tuple(reader) => {
277            let parts = read_list!(bump, reader?, read_seq_part);
278            table::Term::Tuple(parts)
279        }
280
281        Which::ConstFunc(region) => table::Term::ConstFunc(table::RegionId(region)),
282
283        Which::Bytes(bytes) => table::Term::Literal(model::Literal::Bytes(bytes?.into())),
284        Which::Float(value) => table::Term::Literal(model::Literal::Float(value.into())),
285    })
286}
287
288fn read_seq_part(
289    _: &Bump,
290    reader: hugr_capnp::term::seq_part::Reader,
291) -> ReadResult<table::SeqPart> {
292    use hugr_capnp::term::seq_part::Which;
293    Ok(match reader.which()? {
294        Which::Item(term) => table::SeqPart::Item(table::TermId(term)),
295        Which::Splice(list) => table::SeqPart::Splice(table::TermId(list)),
296    })
297}
298
299fn read_ext_set_part<'a>(
300    bump: &'a Bump,
301    reader: hugr_capnp::term::ext_set_part::Reader,
302) -> ReadResult<table::ExtSetPart<'a>> {
303    use hugr_capnp::term::ext_set_part::Which;
304    Ok(match reader.which()? {
305        Which::Extension(ext) => table::ExtSetPart::Extension(bump.alloc_str(ext?.to_str()?)),
306        Which::Splice(list) => table::ExtSetPart::Splice(table::TermId(list)),
307    })
308}
309
310fn read_param<'a>(
311    bump: &'a Bump,
312    reader: hugr_capnp::param::Reader,
313) -> ReadResult<table::Param<'a>> {
314    let name = bump.alloc_str(reader.get_name()?.to_str()?);
315    let r#type = table::TermId(reader.get_type());
316    Ok(table::Param { name, r#type })
317}