hugr_model/v0/binary/
write.rs

1use std::io::Write;
2
3use crate::CURRENT_VERSION;
4use crate::capnp::hugr_v0_capnp as hugr_capnp;
5use crate::v0::{self as model, table};
6
/// An error encountered while serializing a model.
///
/// The enum is `#[non_exhaustive]`, so code matching on it outside this crate
/// must include a wildcard arm. The `Display` message is declared once on the
/// enum itself rather than per variant.
#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
#[non_exhaustive]
#[display("Error encoding a package in HUGR model format.")]
pub enum WriteError {
    #[from(forward)]
    /// An error encountered while encoding a `capnproto` buffer.
    EncodingError(capnp::Error),
}
16
/// Write a list of items into a list builder.
///
/// The expansion allocates the list by calling `$builder.reborrow().$init(len)`,
/// where `len` is the length of `$list`, and then calls `$write(slot, item)` once
/// per item in iteration order. `slot` is the element builder obtained through
/// `get` at the item's position.
///
/// Both the length and the positions are converted with `as _`, so they take
/// whatever integer type the builder methods expect.
macro_rules! write_list {
    ($builder:expr, $init:ident, $write:expr, $list:expr) => {{
        let mut __elements = $builder.reborrow().$init($list.len() as _);
        for (__position, __entry) in $list.iter().enumerate() {
            $write(__elements.reborrow().get(__position as _), __entry);
        }
    }};
}
26
27/// Writes a package to an impl of [Write].
28pub fn write_to_writer(package: &table::Package, writer: impl Write) -> Result<(), WriteError> {
29    let mut message = capnp::message::Builder::new_default();
30    let builder = message.init_root();
31    write_package(builder, package);
32
33    Ok(capnp::serialize_packed::write_message(writer, &message)?)
34}
35
36/// Writes a package to a byte vector.
37#[must_use]
38pub fn write_to_vec(package: &table::Package) -> Vec<u8> {
39    let mut message = capnp::message::Builder::new_default();
40    let builder = message.init_root();
41    write_package(builder, package);
42
43    let mut output = Vec::new();
44    let _ = capnp::serialize_packed::write_message(&mut output, &message);
45    output
46}
47
48fn write_package(mut builder: hugr_capnp::package::Builder, package: &table::Package) {
49    write_list!(builder, init_modules, write_module, package.modules);
50    write_version(builder.init_version(), &CURRENT_VERSION);
51}
52
53fn write_version(mut builder: hugr_capnp::version::Builder, version: &semver::Version) {
54    builder.set_major(version.major as u32);
55    builder.set_minor(version.minor as u32);
56}
57
58fn write_module(mut builder: hugr_capnp::module::Builder, module: &table::Module) {
59    builder.set_root(module.root.0);
60    write_list!(builder, init_nodes, write_node, module.nodes);
61    write_list!(builder, init_regions, write_region, module.regions);
62    write_list!(builder, init_terms, write_term, module.terms);
63}
64
65fn write_node(mut builder: hugr_capnp::node::Builder, node: &table::Node) {
66    write_operation(builder.reborrow().init_operation(), &node.operation);
67    let _ = builder.set_inputs(table::LinkIndex::unwrap_slice(node.inputs));
68    let _ = builder.set_outputs(table::LinkIndex::unwrap_slice(node.outputs));
69    let _ = builder.set_meta(table::TermId::unwrap_slice(node.meta));
70    let _ = builder.set_regions(table::RegionId::unwrap_slice(node.regions));
71    builder.set_signature(node.signature.map_or(0, |t| t.0 + 1));
72}
73
74fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &table::Operation) {
75    match operation {
76        table::Operation::Dfg => builder.set_dfg(()),
77        table::Operation::Cfg => builder.set_cfg(()),
78        table::Operation::Block => builder.set_block(()),
79        table::Operation::TailLoop => builder.set_tail_loop(()),
80        table::Operation::Conditional => builder.set_conditional(()),
81        table::Operation::Custom(operation) => builder.set_custom(operation.0),
82
83        table::Operation::DefineFunc(symbol) => {
84            let builder = builder.init_func_defn();
85            write_symbol(builder, symbol);
86        }
87        table::Operation::DeclareFunc(symbol) => {
88            let builder = builder.init_func_decl();
89            write_symbol(builder, symbol);
90        }
91
92        table::Operation::DefineAlias(symbol, value) => {
93            let mut builder = builder.init_alias_defn();
94            write_symbol(builder.reborrow().init_symbol(), symbol);
95            builder.set_value(value.0);
96        }
97        table::Operation::DeclareAlias(symbol) => {
98            let builder = builder.init_alias_decl();
99            write_symbol(builder, symbol);
100        }
101
102        table::Operation::DeclareConstructor(symbol) => {
103            let builder = builder.init_constructor_decl();
104            write_symbol(builder, symbol);
105        }
106        table::Operation::DeclareOperation(symbol) => {
107            let builder = builder.init_operation_decl();
108            write_symbol(builder, symbol);
109        }
110
111        table::Operation::Import { name } => {
112            builder.set_import(*name);
113        }
114
115        table::Operation::Invalid => builder.set_invalid(()),
116    }
117}
118
119fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &table::Symbol) {
120    builder.set_name(symbol.name);
121    if let Some(vis) = symbol.visibility {
122        builder.set_visibility(match vis {
123            model::Visibility::Private => hugr_capnp::Visibility::Private,
124            model::Visibility::Public => hugr_capnp::Visibility::Public,
125        })
126    } // else, None -> use capnp default == Unspecified
127    write_list!(builder, init_params, write_param, symbol.params);
128    let _ = builder.set_constraints(table::TermId::unwrap_slice(symbol.constraints));
129    builder.set_signature(symbol.signature.0);
130}
131
132fn write_param(mut builder: hugr_capnp::param::Builder, param: &table::Param) {
133    builder.set_name(param.name);
134    builder.set_type(param.r#type.0);
135}
136
137fn write_region(mut builder: hugr_capnp::region::Builder, region: &table::Region) {
138    builder.set_kind(match region.kind {
139        model::RegionKind::DataFlow => hugr_capnp::RegionKind::DataFlow,
140        model::RegionKind::ControlFlow => hugr_capnp::RegionKind::ControlFlow,
141        model::RegionKind::Module => hugr_capnp::RegionKind::Module,
142    });
143
144    let _ = builder.set_sources(table::LinkIndex::unwrap_slice(region.sources));
145    let _ = builder.set_targets(table::LinkIndex::unwrap_slice(region.targets));
146    let _ = builder.set_children(table::NodeId::unwrap_slice(region.children));
147    let _ = builder.set_meta(table::TermId::unwrap_slice(region.meta));
148    builder.set_signature(region.signature.map_or(0, |t| t.0 + 1));
149
150    if let Some(scope) = &region.scope {
151        write_region_scope(builder.init_scope(), scope);
152    }
153}
154
155fn write_region_scope(mut builder: hugr_capnp::region_scope::Builder, scope: &table::RegionScope) {
156    builder.set_links(scope.links);
157    builder.set_ports(scope.ports);
158}
159
160fn write_term(mut builder: hugr_capnp::term::Builder, term: &table::Term) {
161    match term {
162        table::Term::Wildcard => builder.set_wildcard(()),
163        table::Term::Var(table::VarId(node, index)) => {
164            let mut builder = builder.init_variable();
165            builder.set_node(node.0);
166            builder.set_index(*index);
167        }
168
169        table::Term::Literal(value) => match value {
170            model::Literal::Str(value) => builder.set_string(value),
171            model::Literal::Nat(value) => builder.set_nat(*value),
172            model::Literal::Bytes(value) => builder.set_bytes(value),
173            model::Literal::Float(value) => builder.set_float(value.into_inner()),
174        },
175
176        table::Term::Func(region) => builder.set_func(region.0),
177        table::Term::Apply(symbol, args) => {
178            let mut builder = builder.init_apply();
179            builder.set_symbol(symbol.0);
180            let _ = builder.set_args(table::TermId::unwrap_slice(args));
181        }
182
183        table::Term::List(parts) => {
184            write_list!(builder, init_list, write_seq_part, parts);
185        }
186
187        table::Term::Tuple(parts) => {
188            write_list!(builder, init_tuple, write_seq_part, parts);
189        }
190    }
191}
192
193fn write_seq_part(mut builder: hugr_capnp::term::seq_part::Builder, part: &table::SeqPart) {
194    match part {
195        table::SeqPart::Item(term_id) => builder.set_item(term_id.0),
196        table::SeqPart::Splice(term_id) => builder.set_splice(term_id.0),
197    }
198}