hugr_model/v0/ast/
print.rs

1use std::{borrow::Cow, fmt::Display};
2
3use base64::{prelude::BASE64_STANDARD, Engine as _};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::{
9    LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
10};
11
12struct Printer<'a> {
13    /// The arena in which to allocate the pretty-printed documents.
14    arena: &'a Arena<'a>,
15    /// Parts of the document to be concatenated.
16    docs: Vec<RefDoc<'a>>,
17    /// Stack of indices into `docs` denoting the current nesting.
18    docs_stack: Vec<usize>,
19}
20
21impl<'a> Printer<'a> {
22    fn new(arena: &'a Arena<'a>) -> Self {
23        Self {
24            arena,
25            docs: Vec::new(),
26            docs_stack: Vec::new(),
27        }
28    }
29
30    fn finish(self) -> RefDoc<'a> {
31        let sep = self
32            .arena
33            .concat([self.arena.hardline(), self.arena.hardline()]);
34        self.arena.intersperse(self.docs, sep).into_doc()
35    }
36
37    fn parens_enter(&mut self) {
38        self.delim_open();
39    }
40
41    fn parens_exit(&mut self) {
42        self.delim_close("(", ")", 2);
43    }
44
45    fn brackets_enter(&mut self) {
46        self.delim_open();
47    }
48
49    fn brackets_exit(&mut self) {
50        self.delim_close("[", "]", 1);
51    }
52
53    fn group_enter(&mut self) {
54        self.delim_open();
55    }
56
57    fn group_exit(&mut self) {
58        self.delim_close("", "", 0);
59    }
60
61    fn delim_open(&mut self) {
62        self.docs_stack.push(self.docs.len());
63    }
64
65    fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
66        let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
67        let doc = self.arena.concat([
68            self.arena.text(open),
69            self.arena
70                .intersperse(docs, self.arena.line())
71                .nest(nesting)
72                .group(),
73            self.arena.text(close),
74        ]);
75        self.docs.push(doc.into_doc());
76    }
77
78    fn text(&mut self, text: impl Into<Cow<'a, str>>) {
79        self.docs.push(self.arena.text(text).into_doc());
80    }
81
82    fn int(&mut self, value: u64) {
83        self.text(format!("{}", value));
84    }
85
86    fn string(&mut self, string: &str) {
87        let mut output = String::with_capacity(string.len() + 2);
88        output.push('"');
89
90        for c in string.chars() {
91            match c {
92                '\\' => output.push_str("\\\\"),
93                '"' => output.push_str("\\\""),
94                '\n' => output.push_str("\\n"),
95                '\r' => output.push_str("\\r"),
96                '\t' => output.push_str("\\t"),
97                _ => output.push(c),
98            }
99        }
100
101        output.push('"');
102        self.text(output);
103    }
104
105    fn bytes(&mut self, bytes: &[u8]) {
106        // every 3 bytes are encoded into 4 characters
107        let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
108        output.push('"');
109        BASE64_STANDARD.encode_string(bytes, &mut output);
110        output.push('"');
111        self.text(output);
112    }
113}
114
115fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
116    match term {
117        Term::Wildcard => printer.text("_"),
118        Term::Var(var) => print_var_name(printer, var),
119        Term::Apply(symbol, terms) => {
120            if terms.is_empty() {
121                print_symbol_name(printer, symbol);
122            } else {
123                printer.parens_enter();
124                print_symbol_name(printer, symbol);
125
126                for term in terms.iter() {
127                    print_term(printer, term);
128                }
129
130                printer.parens_exit();
131            }
132        }
133        Term::List(list_parts) => {
134            printer.brackets_enter();
135            print_list_parts(printer, list_parts);
136            printer.brackets_exit();
137        }
138        Term::Literal(literal) => {
139            print_literal(printer, literal);
140        }
141        Term::Tuple(tuple_parts) => {
142            printer.parens_enter();
143            printer.text("tuple");
144            print_tuple_parts(printer, tuple_parts);
145            printer.parens_exit();
146        }
147        Term::ExtSet => {
148            printer.parens_enter();
149            printer.text("ext");
150            printer.parens_exit();
151        }
152        Term::Func(region) => {
153            printer.parens_enter();
154            printer.text("fn");
155            print_region(printer, region);
156            printer.parens_exit();
157        }
158    }
159}
160
161fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
162    match literal {
163        Literal::Str(str) => {
164            printer.string(str);
165        }
166        Literal::Nat(nat) => {
167            printer.int(*nat);
168        }
169        Literal::Bytes(bytes) => {
170            printer.parens_enter();
171            printer.text("bytes");
172            printer.bytes(bytes);
173            printer.parens_exit();
174        }
175        Literal::Float(float) => {
176            // The debug representation of a float always includes a decimal point.
177            printer.text(format!("{:.?}", float.into_inner()));
178        }
179    }
180}
181
182fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
183    printer.group_enter();
184    print_term(printer, term);
185    printer.text("...");
186    printer.group_exit();
187}
188
189/// Print a [`SeqPart`] in isolation for the [`Display`] instance.
190fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
191    match part {
192        SeqPart::Item(term) => print_term(printer, term),
193        SeqPart::Splice(term) => print_seq_splice(printer, term),
194    }
195}
196
197/// Print the parts of a list [`Term`], merging spreaded lists.
198fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
199    for part in parts {
200        match part {
201            SeqPart::Item(term) => print_term(printer, term),
202            SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
203            SeqPart::Splice(term) => print_seq_splice(printer, term),
204        }
205    }
206}
207
208/// Print the parts of a tuple [`Term`], merging spreaded tuples.
209fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
210    for part in parts {
211        match part {
212            SeqPart::Item(term) => print_term(printer, term),
213            SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
214            SeqPart::Splice(term) => print_seq_splice(printer, term),
215        }
216    }
217}
218
219fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
220    printer.text(name.0.as_str())
221}
222
223fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
224    printer.text(format!("{}", name))
225}
226
227fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
228    printer.text(format!("{}", name))
229}
230
231fn print_port_lists<'a>(
232    printer: &mut Printer<'a>,
233    inputs: &'a [LinkName],
234    outputs: &'a [LinkName],
235) {
236    // If the node/region has no ports, we avoid printing the port lists.
237    // This is especially important for the syntax of nodes that introduce symbols
238    // since these nodes never have any input or output ports.
239    if inputs.is_empty() && outputs.is_empty() {
240        return;
241    }
242
243    // The group encodes the preference that the port lists occur on the same
244    // line whenever possible.
245    printer.group_enter();
246    printer.brackets_enter();
247    for input in inputs {
248        print_link_name(printer, input);
249    }
250    printer.brackets_exit();
251    printer.brackets_enter();
252    for output in outputs {
253        print_link_name(printer, output);
254    }
255    printer.brackets_exit();
256    printer.group_exit();
257}
258
259fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
260    printer.parens_enter();
261    printer.text("hugr");
262    printer.text("0");
263    printer.parens_exit();
264
265    for meta in module.root.meta.iter() {
266        print_meta_item(printer, meta);
267    }
268
269    for child in module.root.children.iter() {
270        print_node(printer, child);
271    }
272}
273
274fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
275    printer.parens_enter();
276
277    printer.group_enter();
278    match &node.operation {
279        Operation::Invalid => printer.text("invalid"),
280        Operation::Dfg => printer.text("dfg"),
281        Operation::Cfg => printer.text("cfg"),
282        Operation::Block => printer.text("block"),
283        Operation::DefineFunc(symbol_signature) => {
284            printer.text("define-func");
285            print_symbol(printer, symbol_signature);
286        }
287        Operation::DeclareFunc(symbol_signature) => {
288            printer.text("declare-func");
289            print_symbol(printer, symbol_signature);
290        }
291        Operation::Custom(term) => {
292            print_term(printer, term);
293        }
294        Operation::DefineAlias(symbol_signature, value) => {
295            printer.text("define-alias");
296            print_symbol(printer, symbol_signature);
297            print_term(printer, value);
298        }
299        Operation::DeclareAlias(symbol_signature) => {
300            printer.text("declare-alias");
301            print_symbol(printer, symbol_signature);
302        }
303        Operation::TailLoop => printer.text("tail-loop"),
304        Operation::Conditional => printer.text("cond"),
305        Operation::DeclareConstructor(symbol_signature) => {
306            printer.text("declare-ctr");
307            print_symbol(printer, symbol_signature);
308        }
309        Operation::DeclareOperation(symbol_signature) => {
310            printer.text("declare-operation");
311            print_symbol(printer, symbol_signature);
312        }
313        Operation::Import(symbol) => {
314            printer.text("import");
315            print_symbol_name(printer, symbol);
316        }
317    }
318
319    print_port_lists(printer, &node.inputs, &node.outputs);
320    printer.group_exit();
321
322    if let Some(signature) = &node.signature {
323        print_signature(printer, signature);
324    }
325
326    for meta in node.meta.iter() {
327        print_meta_item(printer, meta);
328    }
329
330    for region in node.regions.iter() {
331        print_region(printer, region);
332    }
333
334    printer.parens_exit();
335}
336
337fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
338    printer.parens_enter();
339    printer.group_enter();
340
341    printer.text(match region.kind {
342        RegionKind::DataFlow => "dfg",
343        RegionKind::ControlFlow => "cfg",
344        RegionKind::Module => "mod",
345    });
346
347    print_port_lists(printer, &region.sources, &region.targets);
348    printer.group_exit();
349
350    if let Some(signature) = &region.signature {
351        print_signature(printer, signature);
352    }
353
354    for meta in region.meta.iter() {
355        print_meta_item(printer, meta);
356    }
357
358    for child in region.children.iter() {
359        print_node(printer, child);
360    }
361
362    printer.parens_exit();
363}
364
365fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
366    print_symbol_name(printer, &symbol.name);
367
368    for param in symbol.params.iter() {
369        print_param(printer, param);
370    }
371
372    for constraint in symbol.constraints.iter() {
373        print_constraint(printer, constraint);
374    }
375
376    print_term(printer, &symbol.signature);
377}
378
379fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
380    printer.parens_enter();
381    printer.text("param");
382    print_var_name(printer, &param.name);
383    print_term(printer, &param.r#type);
384    printer.parens_exit();
385}
386
387fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
388    printer.parens_enter();
389    printer.text("where");
390    print_term(printer, constraint);
391    printer.parens_exit();
392}
393
394fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
395    printer.parens_enter();
396    printer.text("meta");
397    print_term(printer, meta);
398    printer.parens_exit();
399}
400
401fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
402    printer.parens_enter();
403    printer.text("signature");
404    print_term(printer, signature);
405    printer.parens_exit();
406}
407
408macro_rules! impl_display {
409    ($t:ident, $print:expr) => {
410        impl Display for $t {
411            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412                let arena = Arena::new();
413                let mut printer = Printer::new(&arena);
414                $print(&mut printer, self);
415                let doc = printer.finish();
416                doc.render_fmt(80, f)
417            }
418        }
419    };
420}
421
422impl_display!(Module, print_module);
423impl_display!(Node, print_node);
424impl_display!(Region, print_region);
425impl_display!(Param, print_param);
426impl_display!(Term, print_term);
427impl_display!(SeqPart, print_seq_part);
428impl_display!(Literal, print_literal);
429
430impl Display for VarName {
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        write!(f, "?{}", self.0)
433    }
434}
435
436impl Display for SymbolName {
437    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438        write!(f, "{}", self.0)
439    }
440}
441
442impl Display for LinkName {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        write!(f, "%{}", self.0)
445    }
446}