hugr_model/v0/binary/write.rs

1use std::io::Write;
2
3use crate::capnp::hugr_v0_capnp as hugr_capnp;
4use crate::v0 as model;
5use crate::v0::table;
6
7/// An error encounter while serializing a model.
8#[derive(Debug, derive_more::From, derive_more::Display, derive_more::Error)]
9#[non_exhaustive]
10pub enum WriteError {
11    /// An error encountered while encoding a `capnproto` buffer.
12    EncodingError(capnp::Error),
13}
14
/// Write a list of items into a list builder.
///
/// Expands to a block that:
/// 1. initialises a list of `$list.len()` elements on `$builder` through its
///    `$init` method (for example `init_nodes`), and
/// 2. calls `$write(element_builder, item)` for every item of `$list`, in
///    iteration order, where `element_builder` is the list entry at that item's
///    index.
///
/// `$list` is evaluated exactly once and only borrowed; it must provide `len()`
/// and `iter()` (a slice or `Vec`, possibly behind references). The length and
/// each index are converted with `as _` to whatever integer type the builder's
/// `$init` / `get` methods accept.
macro_rules! write_list {
    ($builder:expr, $init:ident, $write:expr, $list:expr) => {{
        // Bind the list once so `$list` is not evaluated a second time for the
        // iteration after being evaluated for its length.
        let __items = &$list;
        let mut __list_builder = $builder.reborrow().$init(__items.len() as _);
        for (index, item) in __items.iter().enumerate() {
            $write(__list_builder.reborrow().get(index as _), item);
        }
    }};
}
24
25/// Writes a module to an impl of [Write].
26pub fn write_to_writer(module: &table::Module, writer: impl Write) -> Result<(), WriteError> {
27    let mut message = capnp::message::Builder::new_default();
28    let builder = message.init_root();
29    write_module(builder, module);
30
31    Ok(capnp::serialize_packed::write_message(writer, &message)?)
32}
33
34/// Writes a module to a byte vector.
35pub fn write_to_vec(module: &table::Module) -> Vec<u8> {
36    let mut message = capnp::message::Builder::new_default();
37    let builder = message.init_root();
38    write_module(builder, module);
39
40    let mut output = Vec::new();
41    let _ = capnp::serialize_packed::write_message(&mut output, &message);
42    output
43}
44
45fn write_module(mut builder: hugr_capnp::module::Builder, module: &table::Module) {
46    builder.set_root(module.root.0);
47    write_list!(builder, init_nodes, write_node, module.nodes);
48    write_list!(builder, init_regions, write_region, module.regions);
49    write_list!(builder, init_terms, write_term, module.terms);
50}
51
52fn write_node(mut builder: hugr_capnp::node::Builder, node: &table::Node) {
53    write_operation(builder.reborrow().init_operation(), &node.operation);
54    let _ = builder.set_inputs(table::LinkIndex::unwrap_slice(node.inputs));
55    let _ = builder.set_outputs(table::LinkIndex::unwrap_slice(node.outputs));
56    let _ = builder.set_meta(table::TermId::unwrap_slice(node.meta));
57    let _ = builder.set_regions(table::RegionId::unwrap_slice(node.regions));
58    builder.set_signature(node.signature.map_or(0, |t| t.0 + 1));
59}
60
61fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &table::Operation) {
62    match operation {
63        table::Operation::Dfg => builder.set_dfg(()),
64        table::Operation::Cfg => builder.set_cfg(()),
65        table::Operation::Block => builder.set_block(()),
66        table::Operation::TailLoop => builder.set_tail_loop(()),
67        table::Operation::Conditional => builder.set_conditional(()),
68        table::Operation::Custom(operation) => builder.set_custom(operation.0),
69
70        table::Operation::DefineFunc(symbol) => {
71            let builder = builder.init_func_defn();
72            write_symbol(builder, symbol);
73        }
74        table::Operation::DeclareFunc(symbol) => {
75            let builder = builder.init_func_decl();
76            write_symbol(builder, symbol);
77        }
78
79        table::Operation::DefineAlias(symbol, value) => {
80            let mut builder = builder.init_alias_defn();
81            write_symbol(builder.reborrow().init_symbol(), symbol);
82            builder.set_value(value.0);
83        }
84        table::Operation::DeclareAlias(symbol) => {
85            let builder = builder.init_alias_decl();
86            write_symbol(builder, symbol);
87        }
88
89        table::Operation::DeclareConstructor(symbol) => {
90            let builder = builder.init_constructor_decl();
91            write_symbol(builder, symbol);
92        }
93        table::Operation::DeclareOperation(symbol) => {
94            let builder = builder.init_operation_decl();
95            write_symbol(builder, symbol);
96        }
97
98        table::Operation::Import { name } => {
99            builder.set_import(*name);
100        }
101
102        table::Operation::Invalid => builder.set_invalid(()),
103    }
104}
105
106fn write_symbol(mut builder: hugr_capnp::symbol::Builder, symbol: &table::Symbol) {
107    builder.set_name(symbol.name);
108    write_list!(builder, init_params, write_param, symbol.params);
109    let _ = builder.set_constraints(table::TermId::unwrap_slice(symbol.constraints));
110    builder.set_signature(symbol.signature.0);
111}
112
113fn write_param(mut builder: hugr_capnp::param::Builder, param: &table::Param) {
114    builder.set_name(param.name);
115    builder.set_type(param.r#type.0);
116}
117
118fn write_region(mut builder: hugr_capnp::region::Builder, region: &table::Region) {
119    builder.set_kind(match region.kind {
120        model::RegionKind::DataFlow => hugr_capnp::RegionKind::DataFlow,
121        model::RegionKind::ControlFlow => hugr_capnp::RegionKind::ControlFlow,
122        model::RegionKind::Module => hugr_capnp::RegionKind::Module,
123    });
124
125    let _ = builder.set_sources(table::LinkIndex::unwrap_slice(region.sources));
126    let _ = builder.set_targets(table::LinkIndex::unwrap_slice(region.targets));
127    let _ = builder.set_children(table::NodeId::unwrap_slice(region.children));
128    let _ = builder.set_meta(table::TermId::unwrap_slice(region.meta));
129    builder.set_signature(region.signature.map_or(0, |t| t.0 + 1));
130
131    if let Some(scope) = &region.scope {
132        write_region_scope(builder.init_scope(), scope);
133    }
134}
135
136fn write_region_scope(mut builder: hugr_capnp::region_scope::Builder, scope: &table::RegionScope) {
137    builder.set_links(scope.links);
138    builder.set_ports(scope.ports);
139}
140
141fn write_term(mut builder: hugr_capnp::term::Builder, term: &table::Term) {
142    match term {
143        table::Term::Wildcard => builder.set_wildcard(()),
144        table::Term::Var(table::VarId(node, index)) => {
145            let mut builder = builder.init_variable();
146            builder.set_node(node.0);
147            builder.set_index(*index);
148        }
149
150        table::Term::Literal(value) => match value {
151            model::Literal::Str(value) => builder.set_string(value),
152            model::Literal::Nat(value) => builder.set_nat(*value),
153            model::Literal::Bytes(value) => builder.set_bytes(value),
154            model::Literal::Float(value) => builder.set_float(value.into_inner()),
155        },
156
157        table::Term::ConstFunc(region) => builder.set_const_func(region.0),
158        table::Term::Apply(symbol, args) => {
159            let mut builder = builder.init_apply();
160            builder.set_symbol(symbol.0);
161            let _ = builder.set_args(table::TermId::unwrap_slice(args));
162        }
163
164        table::Term::List(parts) => {
165            write_list!(builder, init_list, write_seq_part, parts);
166        }
167
168        table::Term::ExtSet(parts) => {
169            write_list!(builder, init_ext_set, write_ext_set_part, parts);
170        }
171
172        table::Term::Tuple(parts) => {
173            write_list!(builder, init_tuple, write_seq_part, parts);
174        }
175    }
176}
177
178fn write_seq_part(mut builder: hugr_capnp::term::seq_part::Builder, part: &table::SeqPart) {
179    match part {
180        table::SeqPart::Item(term_id) => builder.set_item(term_id.0),
181        table::SeqPart::Splice(term_id) => builder.set_splice(term_id.0),
182    }
183}
184
185fn write_ext_set_part(
186    mut builder: hugr_capnp::term::ext_set_part::Builder,
187    part: &table::ExtSetPart,
188) {
189    match part {
190        table::ExtSetPart::Extension(ext) => builder.set_extension(ext),
191        table::ExtSetPart::Splice(term_id) => builder.set_splice(term_id.0),
192    }
193}