1use crate::capnp::hugr_v0_capnp as hugr_capnp;
2use crate::v0::table;
3use crate::{CURRENT_VERSION, v0 as model};
4use bumpalo::Bump;
5use bumpalo::collections::Vec as BumpVec;
6use std::io::BufRead;
7
8#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
10#[non_exhaustive]
11#[display("Error reading a HUGR model payload.")]
12pub enum ReadError {
13 #[from(forward)]
14 DecodingError(capnp::Error),
16
17 #[display("Can not read file with version {actual} (tooling version {current}).")]
19 VersionError {
20 current: semver::Version,
22 actual: semver::Version,
24 },
25}
26
27type ReadResult<T> = Result<T, ReadError>;
28
29pub fn read_from_slice<'a>(slice: &[u8], bump: &'a Bump) -> ReadResult<table::Package<'a>> {
31 read_from_reader(slice, bump)
32}
33
34pub fn read_from_reader(reader: impl BufRead, bump: &Bump) -> ReadResult<table::Package<'_>> {
36 let reader =
37 capnp::serialize_packed::read_message(reader, capnp::message::ReaderOptions::new())?;
38 let root = reader.get_root::<hugr_capnp::package::Reader>()?;
39 read_package(bump, root)
40}
41
/// Converts every element of a list reader with `$read` and collects the
/// results into a slice allocated in `$bump`.
///
/// `$read` is called as `$read($bump, element)` and its result is unwrapped
/// with `?`, so the macro may only be used inside a function whose error type
/// accepts that error. The first failing element aborts the enclosing function.
macro_rules! read_list {
    ($bump:expr, $reader:expr, $read:expr) => {{
        let source = $reader;
        // The element count is known up front, so allocate once.
        let mut items = BumpVec::with_capacity_in(source.len() as usize, $bump);
        for entry in source.iter() {
            items.push($read($bump, entry)?);
        }
        items.into_bump_slice()
    }};
}
53
/// Fetches a list of scalars via `$reader.$get()?`, wraps each element with
/// `$wrap`, and collects the results into a slice allocated in `$bump`.
///
/// The getter's result is unwrapped with `?`, so the macro may only be used
/// inside a function whose error type accepts that error.
macro_rules! read_scalar_list {
    ($bump:expr, $reader:expr, $get:ident, $wrap:path) => {{
        let source = $reader.$get()?;
        // The element count is known up front, so allocate once.
        let mut items = BumpVec::with_capacity_in(source.len() as usize, $bump);
        items.extend(source.iter().map($wrap));
        items.into_bump_slice()
    }};
}
65
66fn read_package<'a>(
67 bump: &'a Bump,
68 reader: hugr_capnp::package::Reader,
69) -> ReadResult<table::Package<'a>> {
70 let version = read_version(reader.get_version()?)?;
71
72 if version.major != CURRENT_VERSION.major || version.minor > CURRENT_VERSION.minor {
73 return Err(ReadError::VersionError {
74 current: CURRENT_VERSION.clone(),
75 actual: version,
76 });
77 }
78
79 let modules = reader
80 .get_modules()?
81 .iter()
82 .map(|m| read_module(bump, m))
83 .collect::<ReadResult<_>>()?;
84
85 Ok(table::Package { modules })
86}
87
88fn read_version(reader: hugr_capnp::version::Reader) -> ReadResult<semver::Version> {
89 let major = reader.get_major();
90 let minor = reader.get_minor();
91 Ok(semver::Version::new(major as u64, minor as u64, 0))
92}
93
94fn read_module<'a>(
95 bump: &'a Bump,
96 reader: hugr_capnp::module::Reader,
97) -> ReadResult<table::Module<'a>> {
98 let root = table::RegionId(reader.get_root());
99
100 let nodes = reader
101 .get_nodes()?
102 .iter()
103 .map(|r| read_node(bump, r))
104 .collect::<ReadResult<_>>()?;
105
106 let regions = reader
107 .get_regions()?
108 .iter()
109 .map(|r| read_region(bump, r))
110 .collect::<ReadResult<_>>()?;
111
112 let terms = reader
113 .get_terms()?
114 .iter()
115 .map(|r| read_term(bump, r))
116 .collect::<ReadResult<_>>()?;
117
118 Ok(table::Module {
119 root,
120 nodes,
121 regions,
122 terms,
123 })
124}
125
126fn read_node<'a>(bump: &'a Bump, reader: hugr_capnp::node::Reader) -> ReadResult<table::Node<'a>> {
127 let operation = read_operation(bump, reader.get_operation()?)?;
128 let inputs = read_scalar_list!(bump, reader, get_inputs, table::LinkIndex);
129 let outputs = read_scalar_list!(bump, reader, get_outputs, table::LinkIndex);
130 let regions = read_scalar_list!(bump, reader, get_regions, table::RegionId);
131 let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
132 let signature = reader.get_signature().checked_sub(1).map(table::TermId);
133
134 Ok(table::Node {
135 operation,
136 inputs,
137 outputs,
138 regions,
139 meta,
140 signature,
141 })
142}
143
144fn read_operation<'a>(
145 bump: &'a Bump,
146 reader: hugr_capnp::operation::Reader,
147) -> ReadResult<table::Operation<'a>> {
148 use hugr_capnp::operation::Which;
149 Ok(match reader.which()? {
150 Which::Invalid(()) => table::Operation::Invalid,
151 Which::Dfg(()) => table::Operation::Dfg,
152 Which::Cfg(()) => table::Operation::Cfg,
153 Which::Block(()) => table::Operation::Block,
154 Which::FuncDefn(reader) => table::Operation::DefineFunc(read_symbol(bump, reader?, None)?),
155 Which::FuncDecl(reader) => table::Operation::DeclareFunc(read_symbol(bump, reader?, None)?),
156 Which::AliasDefn(reader) => {
157 let symbol = reader.get_symbol()?;
158 let value = table::TermId(reader.get_value());
159 table::Operation::DefineAlias(read_symbol(bump, symbol, Some(&[]))?, value)
160 }
161 Which::AliasDecl(reader) => {
162 table::Operation::DeclareAlias(read_symbol(bump, reader?, Some(&[]))?)
163 }
164 Which::ConstructorDecl(reader) => {
165 table::Operation::DeclareConstructor(read_symbol(bump, reader?, None)?)
166 }
167 Which::OperationDecl(reader) => {
168 table::Operation::DeclareOperation(read_symbol(bump, reader?, None)?)
169 }
170 Which::Custom(operation) => table::Operation::Custom(table::TermId(operation)),
171 Which::TailLoop(()) => table::Operation::TailLoop,
172 Which::Conditional(()) => table::Operation::Conditional,
173 Which::Import(name) => table::Operation::Import {
174 name: bump.alloc_str(name?.to_str()?),
175 },
176 })
177}
178
179fn read_region<'a>(
180 bump: &'a Bump,
181 reader: hugr_capnp::region::Reader,
182) -> ReadResult<table::Region<'a>> {
183 let kind = match reader.get_kind()? {
184 hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow,
185 hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow,
186 hugr_capnp::RegionKind::Module => model::RegionKind::Module,
187 };
188
189 let sources = read_scalar_list!(bump, reader, get_sources, table::LinkIndex);
190 let targets = read_scalar_list!(bump, reader, get_targets, table::LinkIndex);
191 let children = read_scalar_list!(bump, reader, get_children, table::NodeId);
192 let meta = read_scalar_list!(bump, reader, get_meta, table::TermId);
193 let signature = reader.get_signature().checked_sub(1).map(table::TermId);
194
195 let scope = if reader.has_scope() {
196 Some(read_region_scope(reader.get_scope()?)?)
197 } else {
198 None
199 };
200
201 Ok(table::Region {
202 kind,
203 sources,
204 targets,
205 children,
206 meta,
207 signature,
208 scope,
209 })
210}
211
212fn read_region_scope(reader: hugr_capnp::region_scope::Reader) -> ReadResult<table::RegionScope> {
213 let links = reader.get_links();
214 let ports = reader.get_ports();
215 Ok(table::RegionScope { links, ports })
216}
217
218impl From<hugr_capnp::Visibility> for Option<model::Visibility> {
219 fn from(value: hugr_capnp::Visibility) -> Self {
220 match value {
221 hugr_capnp::Visibility::Unspecified => None,
222 hugr_capnp::Visibility::Private => Some(model::Visibility::Private),
223 hugr_capnp::Visibility::Public => Some(model::Visibility::Public),
224 }
225 }
226}
227
228fn read_symbol<'a>(
230 bump: &'a Bump,
231 reader: hugr_capnp::symbol::Reader,
232 constraints: Option<&'a [table::TermId]>,
233) -> ReadResult<&'a mut table::Symbol<'a>> {
234 let name = bump.alloc_str(reader.get_name()?.to_str()?);
235 let visibility = reader.get_visibility()?.into();
236 let visibility = bump.alloc(visibility);
237 let params = read_list!(bump, reader.get_params()?, read_param);
238 let constraints = match constraints {
239 Some(cs) => cs,
240 None => read_scalar_list!(bump, reader, get_constraints, table::TermId),
241 };
242 let signature = table::TermId(reader.get_signature());
243 Ok(bump.alloc(table::Symbol {
244 visibility,
245 name,
246 params,
247 constraints,
248 signature,
249 }))
250}
251
252fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult<table::Term<'a>> {
253 use hugr_capnp::term::Which;
254 Ok(match reader.which()? {
255 Which::Wildcard(()) => table::Term::Wildcard,
256 Which::String(value) => table::Term::Literal(model::Literal::Str(value?.to_str()?.into())),
257 Which::Nat(value) => table::Term::Literal(model::Literal::Nat(value)),
258
259 Which::Variable(reader) => {
260 let node = table::NodeId(reader.get_node());
261 let index = reader.get_index();
262 table::Term::Var(table::VarId(node, index))
263 }
264
265 Which::Apply(reader) => {
266 let symbol = table::NodeId(reader.get_symbol());
267 let args = read_scalar_list!(bump, reader, get_args, table::TermId);
268 table::Term::Apply(symbol, args)
269 }
270
271 Which::List(reader) => {
272 let parts = read_list!(bump, reader?, read_seq_part);
273 table::Term::List(parts)
274 }
275
276 Which::Tuple(reader) => {
277 let parts = read_list!(bump, reader?, read_seq_part);
278 table::Term::Tuple(parts)
279 }
280
281 Which::Func(region) => table::Term::Func(table::RegionId(region)),
282
283 Which::Bytes(bytes) => table::Term::Literal(model::Literal::Bytes(bytes?.into())),
284 Which::Float(value) => table::Term::Literal(model::Literal::Float(value.into())),
285 })
286}
287
288fn read_seq_part(
289 _: &Bump,
290 reader: hugr_capnp::term::seq_part::Reader,
291) -> ReadResult<table::SeqPart> {
292 use hugr_capnp::term::seq_part::Which;
293 Ok(match reader.which()? {
294 Which::Item(term) => table::SeqPart::Item(table::TermId(term)),
295 Which::Splice(list) => table::SeqPart::Splice(table::TermId(list)),
296 })
297}
298
299fn read_param<'a>(
300 bump: &'a Bump,
301 reader: hugr_capnp::param::Reader,
302) -> ReadResult<table::Param<'a>> {
303 let name = bump.alloc_str(reader.get_name()?.to_str()?);
304 let r#type = table::TermId(reader.get_type());
305 Ok(table::Param { name, r#type })
306}