// hugr_model/v0/binary/read.rs

1use crate::capnp::hugr_v0_capnp as hugr_capnp;
2use crate::v0::table;
3use crate::{CURRENT_VERSION, v0 as model};
4use bumpalo::Bump;
5use bumpalo::collections::Vec as BumpVec;
6use std::io::BufRead;
7
/// An error encountered while deserialising a model.
///
/// Returned by [`read_from_slice`] and [`read_from_reader`].
#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
#[non_exhaustive]
// NOTE(review): confirm how this enum-level `#[display]` interacts with the variant-level
// `#[display]` on `VersionError` in the derive_more version in use (override vs. per-variant).
#[display("Error reading a HUGR model payload.")]
pub enum ReadError {
    // `#[from(forward)]` makes the derived `From` generic: any error type convertible into
    // `capnp::Error` converts into this variant, which is what allows `?` to be applied
    // directly to capnp results throughout this file.
    #[from(forward)]
    /// An error encountered while decoding a model from a `capnproto` buffer.
    DecodingError(capnp::Error),

    /// The file could not be read due to a version mismatch.
    ///
    /// Produced when the file's major version differs from the tooling's, or when the file's
    /// minor version is newer than the tooling's.
    #[display("Can not read file with version {actual} (tooling version {current}).")]
    VersionError {
        /// The current version of the hugr-model format.
        current: semver::Version,
        /// The version of the hugr-model format in the file.
        actual: semver::Version,
    },
}

/// Shorthand for the result type returned by the reading functions in this file.
type ReadResult<T> = Result<T, ReadError>;
28
29/// Read a hugr package from a byte slice.
30pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Package<'a>> {
31    read_from_reader(slice, bump)
32}
33
34/// Read a hugr package from an impl of [`BufRead`].
35pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Package<'_>> {
36    let reader =
37        capnp::serialize_packed::read_message(reader, capnp::message::ReaderOptions::new())?;
38    let root = reader.get_root::<hugr_capnp::package::Reader>()?;
39    read_package(bump, root)
40}
41
/// Read a list of structs from a reader into a slice allocated through the bump allocator.
///
/// `$read` is called as `$read($bump, item)` for every element of `$reader` and must return a
/// `ReadResult`. The first error is propagated out of the *enclosing* function via `?`.
macro_rules! read_list {
    ($bump:expr, $reader:expr, $read:expr) => {{
        let mut __items = $reader;
        let mut __slice = BumpVec::with_capacity_in(__items.len() as _, $bump);
        for __item in __items.iter() {
            let __decoded = $read($bump, __item)?;
            __slice.push(__decoded);
        }
        __slice.into_bump_slice()
    }};
}
53
/// Read a list of scalars from a reader into a slice allocated through the bump allocator.
///
/// `$get` names the getter on `$reader` that yields the list. Any error from that getter is
/// propagated out of the *enclosing* function via `?`. Each element is passed through `$wrap`,
/// e.g. a newtype constructor such as `table::TermId`.
macro_rules! read_scalar_list {
    ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{
        let mut __scalars = $reader.$get()?;
        let mut __slice = BumpVec::with_capacity_in(__scalars.len() as _, $bump);
        // Capacity is already exact, so extending never reallocates.
        __slice.extend(__scalars.iter().map($wrap));
        __slice.into_bump_slice()
    }};
}
65
66fn read_package<'a>(
67    bump: &'a Bump,
68    reader: hugr_capnp::package::Reader,
69) -> ReadResult<table::Package<'a>> {
70    let version = read_version(reader.get_version()?)?;
71
72    if version.major != CURRENT_VERSION.major || version.minor > CURRENT_VERSION.minor {
73        return Err(ReadError::VersionError {
74            current: CURRENT_VERSION.clone(),
75            actual: version,
76        });
77    }
78
79    let modules = reader
80        .get_modules()?
81        .iter()
82        .map(|m| read_module(bump, m))
83        .collect::<ReadResult<_>>()?;
84
85    Ok(table::Package { modules })
86}
87
88fn read_version(reader: hugr_capnp::version::Reader) -> ReadResult<semver::Version> {
89    let major = reader.get_major();
90    let minor = reader.get_minor();
91    Ok(semver::Version::new(major as u64, minor as u64, 0))
92}
93
94fn read_module<'a>(
95    bump: &'a Bump,
96    reader: hugr_capnp::module::Reader,
97) -> ReadResult<table::Module<'a>> {
98    let root = table::RegionId(reader.get_root());
99
100    let nodes = reader
101        .get_nodes()?
102        .iter()
103        .map(|r| read_node(bump, r))
104        .collect::<ReadResult<_>>()?;
105
106    let regions = reader
107        .get_regions()?
108        .iter()
109        .map(|r| read_region(bump, r))
110        .collect::<ReadResult<_>>()?;
111
112    let terms = reader
113        .get_terms()?
114        .iter()
115        .map(|r| read_term(bump, r))
116        .collect::<ReadResult<_>>()?;
117
118    Ok(table::Module {
119        root,
120        nodes,
121        regions,
122        terms,
123    })
124}
125
126fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<table::Node<'a>> {
127    let operation = read_operation(bump, reader.get_operation()?)?;
128    let inputs = read_scalar_list!(bump, reader, get_inputs, table::LinkIndex);
129    let outputs = read_scalar_list!(bump, reader, get_outputs, table::LinkIndex);
130    let regions = read_scalar_list!(bump, reader, get_regions, table::RegionId);
131    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
132    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
133
134    Ok(table::Node {
135        operation,
136        inputs,
137        outputs,
138        regions,
139        meta,
140        signature,
141    })
142}
143
144fn read_operation<'a>(
145    bump: &'a Bump,
146    reader: hugr_capnp::operation::Reader,
147) -> ReadResult<table::Operation<'a>> {
148    use hugr_capnp::operation::Which;
149    Ok(match reader.which()? {
150        Which::Invalid(()) => table::Operation::Invalid,
151        Which::Dfg(()) => table::Operation::Dfg,
152        Which::Cfg(()) => table::Operation::Cfg,
153        Which::Block(()) => table::Operation::Block,
154        Which::FuncDefn(reader) => table::Operation::DefineFunc(read_symbol(bump, reader?, None)?),
155        Which::FuncDecl(reader) => table::Operation::DeclareFunc(read_symbol(bump, reader?, None)?),
156        Which::AliasDefn(reader) => {
157            let symbol = reader.get_symbol()?;
158            let value = table::TermId(reader.get_value());
159            table::Operation::DefineAlias(read_symbol(bump, symbol, Some(&[]))?, value)
160        }
161        Which::AliasDecl(reader) => {
162            table::Operation::DeclareAlias(read_symbol(bump, reader?, Some(&[]))?)
163        }
164        Which::ConstructorDecl(reader) => {
165            table::Operation::DeclareConstructor(read_symbol(bump, reader?, None)?)
166        }
167        Which::OperationDecl(reader) => {
168            table::Operation::DeclareOperation(read_symbol(bump, reader?, None)?)
169        }
170        Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)),
171        Which::TailLoop(()) => table::Operation::TailLoop,
172        Which::Conditional(()) => table::Operation::Conditional,
173        Which::Import(name) => table::Operation::Import {
174            name: bump.alloc_str(name?.to_str()?),
175        },
176    })
177}
178
179fn read_region<'a>(
180    bump: &'a Bump,
181    reader: hugr_capnp::region::Reader,
182) -> ReadResult<table::Region<'a>> {
183    let kind = match reader.get_kind()? {
184        hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow,
185        hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow,
186        hugr_capnp::RegionKind::Module => model::RegionKind::Module,
187    };
188
189    let sources = read_scalar_list!(bump, reader, get_sources, table::LinkIndex);
190    let targets = read_scalar_list!(bump, reader, get_targets, table::LinkIndex);
191    let children = read_scalar_list!(bump, reader, get_children, table::NodeId);
192    let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
193    let signature = reader.get_signature().checked_sub(1).map(table::TermId);
194
195    let scope = if reader.has_scope() {
196        Some(read_region_scope(reader.get_scope()?)?)
197    } else {
198        None
199    };
200
201    Ok(table::Region {
202        kind,
203        sources,
204        targets,
205        children,
206        meta,
207        signature,
208        scope,
209    })
210}
211
212fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<table::RegionScope> {
213    let links = reader.get_links();
214    let ports = reader.get_ports();
215    Ok(table::RegionScope { links, ports })
216}
217
218impl From<hugr_capnp::Visibility> for Option<model::Visibility> {
219    fn from(value: hugr_capnp::Visibility) -> Self {
220        match value {
221            hugr_capnp::Visibility::Unspecified => None,
222            hugr_capnp::Visibility::Private => Some(model::Visibility::Private),
223            hugr_capnp::Visibility::Public => Some(model::Visibility::Public),
224        }
225    }
226}
227
228/// (Only) if `constraints` are None, then they are read from the `reader`
229fn read_symbol<'a>(
230    bump: &'a Bump,
231    reader: hugr_capnp::symbol::Reader,
232    constraints: Option<&'a [table::TermId]>,
233) -> ReadResult<&'a mut table::Symbol<'a>> {
234    let name = bump.alloc_str(reader.get_name()?.to_str()?);
235    let visibility = reader.get_visibility()?.into();
236    let visibility = bump.alloc(visibility);
237    let params = read_list!(bump, reader.get_params()?, read_param);
238    let constraints = match constraints {
239        Some(cs) => cs,
240        None => read_scalar_list!(bump, reader, get_constraints, table::TermId),
241    };
242    let signature = table::TermId(reader.get_signature());
243    Ok(bump.alloc(table::Symbol {
244        visibility,
245        name,
246        params,
247        constraints,
248        signature,
249    }))
250}
251
252fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<table::Term<'a>> {
253    use hugr_capnp::term::Which;
254    Ok(match reader.which()? {
255        Which::Wildcard(()) => table::Term::Wildcard,
256        Which::String(value) => table::Term::Literal(model::Literal::Str(value?.to_str()?.into())),
257        Which::Nat(value) => table::Term::Literal(model::Literal::Nat(value)),
258
259        Which::Variable(reader) => {
260            let node = table::NodeId(reader.get_node());
261            let index = reader.get_index();
262            table::Term::Var(table::VarId(node, index))
263        }
264
265        Which::Apply(reader) => {
266            let symbol = table::NodeId(reader.get_symbol());
267            let args = read_scalar_list!(bump, reader, get_args, table::TermId);
268            table::Term::Apply(symbol, args)
269        }
270
271        Which::List(reader) => {
272            let parts = read_list!(bump, reader?, read_seq_part);
273            table::Term::List(parts)
274        }
275
276        Which::Tuple(reader) => {
277            let parts = read_list!(bump, reader?, read_seq_part);
278            table::Term::Tuple(parts)
279        }
280
281        Which::Func(region) => table::Term::Func(table::RegionId(region)),
282
283        Which::Bytes(bytes) => table::Term::Literal(model::Literal::Bytes(bytes?.into())),
284        Which::Float(value) => table::Term::Literal(model::Literal::Float(value.into())),
285    })
286}
287
288fn read_seq_part(
289    _: &Bump,
290    reader: hugr_capnp::term::seq_part::Reader,
291) -> ReadResult<table::SeqPart> {
292    use hugr_capnp::term::seq_part::Which;
293    Ok(match reader.which()? {
294        Which::Item(term) => table::SeqPart::Item(table::TermId(term)),
295        Which::Splice(list) => table::SeqPart::Splice(table::TermId(list)),
296    })
297}
298
299fn read_param<'a>(
300    bump: &'a Bump,
301    reader: hugr_capnp::param::Reader,
302) -> ReadResult<table::Param<'a>> {
303    let name = bump.alloc_str(reader.get_name()?.to_str()?);
304    let r#type = table::TermId(reader.get_type());
305    Ok(table::Param { name, r#type })
306}