// hugr_model/v0/binary/read.rs

1use crate::capnp::hugr_v0_capnp as hugr_capnp;
2use crate::v0 as model;
3use crate::v0::table;
4use bumpalo::Bump;
5use bumpalo::collections::Vec as BumpVec;
6use std::io::BufRead;
7
8/// An error encountered while deserialising a model.
9#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
10#[non_exhaustive]
11pub enum ReadError {
12    #[from(forward)]
13    /// An error encountered while decoding a model from a `capnproto` buffer.
14    DecodingError(capnp::Error),
15}
16
17type ReadResult<T> = Result<T, ReadError>;
18
19/// Read a hugr package from a byte slice.
20pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Package<'a>> {
21    read_from_reader(slice, bump)
22}
23
24/// Read a hugr package from an impl of [`BufRead`].
25pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Package<'_>> {
26    let reader =
27        capnp::serialize_packed::read_message(reader, capnp::message::ReaderOptions::new())?;
28    let root = reader.get_root::<hugr_capnp::package::Reader>()?;
29    read_package(bump, root)
30}
31
/// Read a list of structs from a reader into a slice allocated through the bump allocator.
///
/// `$reader` is evaluated exactly once, followed by `$bump`. `$read` is then called as
/// `$read(bump, item_reader)` for every element, in list order. Any error it returns is
/// propagated with `?` out of the enclosing function.
macro_rules! read_list {
    ($bump:expr, $reader:expr, $read:expr) => {{
        // capnp list readers are `Copy` and `len`/`iter` do not need `&mut`, so
        // no `mut` binding is required here.
        let __list_reader = $reader;
        let __bump = $bump;
        let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, __bump);
        for __item_reader in __list_reader.iter() {
            __list.push($read(__bump, __item_reader)?);
        }
        __list.into_bump_slice()
    }};
}
43
/// Read a list of scalars from a reader into a slice allocated through the bump allocator.
///
/// Calls `$reader.$get()?` to obtain the list. It then wraps every element with
/// `$wrap` (typically a newtype constructor such as `table::TermId`).
macro_rules! read_scalar_list {
    ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{
        // The list reader is only read from, so it does not need a `mut` binding.
        let __list_reader = $reader.$get()?;
        let mut __list = BumpVec::with_capacity_in(__list_reader.len() as _, $bump);
        __list.extend(__list_reader.iter().map($wrap));
        __list.into_bump_slice()
    }};
}
55
56fn read_package<'a>(
57    bump: &'a Bump,
58    reader: hugr_capnp::package::Reader,
59) -> ReadResult<table::Package<'a>> {
60    let modules = reader
61        .get_modules()?
62        .iter()
63        .map(|m| read_module(bump, m))
64        .collect::<ReadResult<_>>()?;
65
66    Ok(table::Package { modules })
67}
68
69fn read_module<'a>(
70    bump: &'a Bump,
71    reader: hugr_capnp::module::Reader,
72) -> ReadResult<table::Module<'a>> {
73    let root = table::RegionId(reader.get_root());
74
75    let nodes = reader
76        .get_nodes()?
77        .iter()
78        .map(|r| read_node(bump, r))
79        .collect::<ReadResult<_>>()?;
80
81    let regions = reader
82        .get_regions()?
83        .iter()
84        .map(|r| read_region(bump, r))
85        .collect::<ReadResult<_>>()?;
86
87    let terms = reader
88        .get_terms()?
89        .iter()
90        .map(|r| read_term(bump, r))
91        .collect::<ReadResult<_>>()?;
92
93    Ok(table::Module {
94        root,
95        nodes,
96        regions,
97        terms,
98    })
99}
100
101fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<table::Node<'a>> {
102    let operation = read_operation(bump, reader.get_operation()?)?;
103    let inputs = read_scalar_list!(bump, reader, get_inputs, table::LinkIndex);
104    let outputs = read_scalar_list!(bump, reader, get_outputs, table::LinkIndex);
105    let regions = read_scalar_list!(bump, reader, get_regions, table::RegionId);
106    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
107    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
108
109    Ok(table::Node {
110        operation,
111        inputs,
112        outputs,
113        regions,
114        meta,
115        signature,
116    })
117}
118
119fn read_operation<'a>(
120    bump: &'a Bump,
121    reader: hugr_capnp::operation::Reader,
122) -> ReadResult<table::Operation<'a>> {
123    use hugr_capnp::operation::Which;
124    Ok(match reader.which()? {
125        Which::Invalid(()) => table::Operation::Invalid,
126        Which::Dfg(()) => table::Operation::Dfg,
127        Which::Cfg(()) => table::Operation::Cfg,
128        Which::Block(()) => table::Operation::Block,
129        Which::FuncDefn(reader) => {
130            let reader = reader?;
131            let name = bump.alloc_str(reader.get_name()?.to_str()?);
132            let params = read_list!(bump, reader.get_params()?, read_param);
133            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
134            let signature = table::TermId(reader.get_signature());
135            let symbol = bump.alloc(table::Symbol {
136                name,
137                params,
138                constraints,
139                signature,
140            });
141            table::Operation::DefineFunc(symbol)
142        }
143        Which::FuncDecl(reader) => {
144            let reader = reader?;
145            let name = bump.alloc_str(reader.get_name()?.to_str()?);
146            let params = read_list!(bump, reader.get_params()?, read_param);
147            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
148            let signature = table::TermId(reader.get_signature());
149            let symbol = bump.alloc(table::Symbol {
150                name,
151                params,
152                constraints,
153                signature,
154            });
155            table::Operation::DeclareFunc(symbol)
156        }
157        Which::AliasDefn(reader) => {
158            let symbol = reader.get_symbol()?;
159            let value = table::TermId(reader.get_value());
160            let name = bump.alloc_str(symbol.get_name()?.to_str()?);
161            let params = read_list!(bump, symbol.get_params()?, read_param);
162            let signature = table::TermId(symbol.get_signature());
163            let symbol = bump.alloc(table::Symbol {
164                name,
165                params,
166                constraints: &[],
167                signature,
168            });
169            table::Operation::DefineAlias(symbol, value)
170        }
171        Which::AliasDecl(reader) => {
172            let reader = reader?;
173            let name = bump.alloc_str(reader.get_name()?.to_str()?);
174            let params = read_list!(bump, reader.get_params()?, read_param);
175            let signature = table::TermId(reader.get_signature());
176            let symbol = bump.alloc(table::Symbol {
177                name,
178                params,
179                constraints: &[],
180                signature,
181            });
182            table::Operation::DeclareAlias(symbol)
183        }
184        Which::ConstructorDecl(reader) => {
185            let reader = reader?;
186            let name = bump.alloc_str(reader.get_name()?.to_str()?);
187            let params = read_list!(bump, reader.get_params()?, read_param);
188            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
189            let signature = table::TermId(reader.get_signature());
190            let symbol = bump.alloc(table::Symbol {
191                name,
192                params,
193                constraints,
194                signature,
195            });
196            table::Operation::DeclareConstructor(symbol)
197        }
198        Which::OperationDecl(reader) => {
199            let reader = reader?;
200            let name = bump.alloc_str(reader.get_name()?.to_str()?);
201            let params = read_list!(bump, reader.get_params()?, read_param);
202            let constraints = read_scalar_list!(bump, reader, get_constraints, table::TermId);
203            let signature = table::TermId(reader.get_signature());
204            let symbol = bump.alloc(table::Symbol {
205                name,
206                params,
207                constraints,
208                signature,
209            });
210            table::Operation::DeclareOperation(symbol)
211        }
212        Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)),
213        Which::TailLoop(()) => table::Operation::TailLoop,
214        Which::Conditional(()) => table::Operation::Conditional,
215        Which::Import(name) => table::Operation::Import {
216            name: bump.alloc_str(name?.to_str()?),
217        },
218    })
219}
220
221fn read_region<'a>(
222    bump: &'a Bump,
223    reader: hugr_capnp::region::Reader,
224) -> ReadResult<table::Region<'a>> {
225    let kind = match reader.get_kind()? {
226        hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow,
227        hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow,
228        hugr_capnp::RegionKind::Module => model::RegionKind::Module,
229    };
230
231    let sources = read_scalar_list!(bump, reader, get_sources, table::LinkIndex);
232    let targets = read_scalar_list!(bump, reader, get_targets, table::LinkIndex);
233    let children = read_scalar_list!(bump, reader, get_children, table::NodeId);
234    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
235    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
236
237    let scope = if reader.has_scope() {
238        Some(read_region_scope(reader.get_scope()?)?)
239    } else {
240        None
241    };
242
243    Ok(table::Region {
244        kind,
245        sources,
246        targets,
247        children,
248        meta,
249        signature,
250        scope,
251    })
252}
253
254fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<table::RegionScope> {
255    let links = reader.get_links();
256    let ports = reader.get_ports();
257    Ok(table::RegionScope { links, ports })
258}
259
260fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<table::Term<'a>> {
261    use hugr_capnp::term::Which;
262    Ok(match reader.which()? {
263        Which::Wildcard(()) => table::Term::Wildcard,
264        Which::String(value) => table::Term::Literal(model::Literal::Str(value?.to_str()?.into())),
265        Which::Nat(value) => table::Term::Literal(model::Literal::Nat(value)),
266
267        Which::Variable(reader) => {
268            let node = table::NodeId(reader.get_node());
269            let index = reader.get_index();
270            table::Term::Var(table::VarId(node, index))
271        }
272
273        Which::Apply(reader) => {
274            let symbol = table::NodeId(reader.get_symbol());
275            let args = read_scalar_list!(bump, reader, get_args, table::TermId);
276            table::Term::Apply(symbol, args)
277        }
278
279        Which::List(reader) => {
280            let parts = read_list!(bump, reader?, read_seq_part);
281            table::Term::List(parts)
282        }
283
284        Which::Tuple(reader) => {
285            let parts = read_list!(bump, reader?, read_seq_part);
286            table::Term::Tuple(parts)
287        }
288
289        Which::Func(region) => table::Term::Func(table::RegionId(region)),
290
291        Which::Bytes(bytes) => table::Term::Literal(model::Literal::Bytes(bytes?.into())),
292        Which::Float(value) => table::Term::Literal(model::Literal::Float(value.into())),
293    })
294}
295
296fn read_seq_part(
297    _: &Bump,
298    reader: hugr_capnp::term::seq_part::Reader,
299) -> ReadResult<table::SeqPart> {
300    use hugr_capnp::term::seq_part::Which;
301    Ok(match reader.which()? {
302        Which::Item(term) => table::SeqPart::Item(table::TermId(term)),
303        Which::Splice(list) => table::SeqPart::Splice(table::TermId(list)),
304    })
305}
306
307fn read_param<'a>(
308    bump: &'a Bump,
309    reader: hugr_capnp::param::Reader,
310) -> ReadResult<table::Param<'a>> {
311    let name = bump.alloc_str(reader.get_name()?.to_str()?);
312    let r#type = table::TermId(reader.get_type());
313    Ok(table::Param { name, r#type })
314}