hugr_model/v0/ast/print.rs

1use std::{borrow::Cow, fmt::Display};
2
3use base64::{Engine as _, prelude::BASE64_STANDARD};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::{
9    LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolName, Term,
10    VarName, Visibility,
11};
12
13struct Printer<'a> {
14    /// The arena in which to allocate the pretty-printed documents.
15    arena: &'a Arena<'a>,
16    /// Parts of the document to be concatenated.
17    docs: Vec<RefDoc<'a>>,
18    /// Stack of indices into `docs` denoting the current nesting.
19    docs_stack: Vec<usize>,
20}
21
22impl<'a> Printer<'a> {
23    fn new(arena: &'a Arena<'a>) -> Self {
24        Self {
25            arena,
26            docs: Vec::new(),
27            docs_stack: Vec::new(),
28        }
29    }
30
31    fn finish(self) -> RefDoc<'a> {
32        let sep = self
33            .arena
34            .concat([self.arena.hardline(), self.arena.hardline()]);
35        self.arena.intersperse(self.docs, sep).into_doc()
36    }
37
38    fn parens_enter(&mut self) {
39        self.delim_open();
40    }
41
42    fn parens_exit(&mut self) {
43        self.delim_close("(", ")", 2);
44    }
45
46    fn brackets_enter(&mut self) {
47        self.delim_open();
48    }
49
50    fn brackets_exit(&mut self) {
51        self.delim_close("[", "]", 1);
52    }
53
54    fn group_enter(&mut self) {
55        self.delim_open();
56    }
57
58    fn group_exit(&mut self) {
59        self.delim_close("", "", 0);
60    }
61
62    fn delim_open(&mut self) {
63        self.docs_stack.push(self.docs.len());
64    }
65
66    fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
67        let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
68        let doc = self.arena.concat([
69            self.arena.text(open),
70            self.arena
71                .intersperse(docs, self.arena.line())
72                .nest(nesting)
73                .group(),
74            self.arena.text(close),
75        ]);
76        self.docs.push(doc.into_doc());
77    }
78
79    fn text(&mut self, text: impl Into<Cow<'a, str>>) {
80        self.docs.push(self.arena.text(text).into_doc());
81    }
82
83    fn int(&mut self, value: u64) {
84        self.text(format!("{value}"));
85    }
86
87    fn string(&mut self, string: &str) {
88        let mut output = String::with_capacity(string.len() + 2);
89        output.push('"');
90
91        for c in string.chars() {
92            match c {
93                '\\' => output.push_str("\\\\"),
94                '"' => output.push_str("\\\""),
95                '\n' => output.push_str("\\n"),
96                '\r' => output.push_str("\\r"),
97                '\t' => output.push_str("\\t"),
98                _ => output.push(c),
99            }
100        }
101
102        output.push('"');
103        self.text(output);
104    }
105
106    fn bytes(&mut self, bytes: &[u8]) {
107        // every 3 bytes are encoded into 4 characters
108        let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
109        output.push('"');
110        BASE64_STANDARD.encode_string(bytes, &mut output);
111        output.push('"');
112        self.text(output);
113    }
114}
115
116fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
117    match term {
118        Term::Wildcard => printer.text("_"),
119        Term::Var(var) => print_var_name(printer, var),
120        Term::Apply(symbol, terms) => {
121            if terms.is_empty() {
122                print_symbol_name(printer, symbol);
123            } else {
124                printer.parens_enter();
125                print_symbol_name(printer, symbol);
126
127                for term in terms.iter() {
128                    print_term(printer, term);
129                }
130
131                printer.parens_exit();
132            }
133        }
134        Term::List(list_parts) => {
135            printer.brackets_enter();
136            print_list_parts(printer, list_parts);
137            printer.brackets_exit();
138        }
139        Term::Literal(literal) => {
140            print_literal(printer, literal);
141        }
142        Term::Tuple(tuple_parts) => {
143            printer.parens_enter();
144            printer.text("tuple");
145            print_tuple_parts(printer, tuple_parts);
146            printer.parens_exit();
147        }
148        Term::Func(region) => {
149            printer.parens_enter();
150            printer.text("fn");
151            print_region(printer, region);
152            printer.parens_exit();
153        }
154    }
155}
156
157fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
158    match literal {
159        Literal::Str(str) => {
160            printer.string(str);
161        }
162        Literal::Nat(nat) => {
163            printer.int(*nat);
164        }
165        Literal::Bytes(bytes) => {
166            printer.parens_enter();
167            printer.text("bytes");
168            printer.bytes(bytes);
169            printer.parens_exit();
170        }
171        Literal::Float(float) => {
172            // The debug representation of a float always includes a decimal point.
173            printer.text(format!("{:.?}", float.into_inner()));
174        }
175    }
176}
177
178fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
179    printer.group_enter();
180    print_term(printer, term);
181    printer.text("...");
182    printer.group_exit();
183}
184
185/// Print a [`SeqPart`] in isolation for the [`Display`] instance.
186fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
187    match part {
188        SeqPart::Item(term) => print_term(printer, term),
189        SeqPart::Splice(term) => print_seq_splice(printer, term),
190    }
191}
192
193/// Print the parts of a list [`Term`], merging spreaded lists.
194fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
195    for part in parts {
196        match part {
197            SeqPart::Item(term) => print_term(printer, term),
198            SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
199            SeqPart::Splice(term) => print_seq_splice(printer, term),
200        }
201    }
202}
203
204/// Print the parts of a tuple [`Term`], merging spreaded tuples.
205fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
206    for part in parts {
207        match part {
208            SeqPart::Item(term) => print_term(printer, term),
209            SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
210            SeqPart::Splice(term) => print_seq_splice(printer, term),
211        }
212    }
213}
214
215fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
216    printer.text(name.0.as_str());
217}
218
219fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
220    printer.text(format!("{name}"));
221}
222
223fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
224    printer.text(format!("{name}"));
225}
226
227fn print_port_lists<'a>(
228    printer: &mut Printer<'a>,
229    inputs: &'a [LinkName],
230    outputs: &'a [LinkName],
231) {
232    // If the node/region has no ports, we avoid printing the port lists.
233    // This is especially important for the syntax of nodes that introduce symbols
234    // since these nodes never have any input or output ports.
235    if inputs.is_empty() && outputs.is_empty() {
236        return;
237    }
238
239    // The group encodes the preference that the port lists occur on the same
240    // line whenever possible.
241    printer.group_enter();
242    printer.brackets_enter();
243    for input in inputs {
244        print_link_name(printer, input);
245    }
246    printer.brackets_exit();
247    printer.brackets_enter();
248    for output in outputs {
249        print_link_name(printer, output);
250    }
251    printer.brackets_exit();
252    printer.group_exit();
253}
254
255fn print_package<'a>(printer: &mut Printer<'a>, package: &'a Package) {
256    printer.parens_enter();
257    printer.text("hugr");
258    printer.text("0");
259    printer.parens_exit();
260
261    for module in &package.modules {
262        printer.parens_enter();
263        printer.text("mod");
264        printer.parens_exit();
265
266        print_module(printer, module);
267    }
268}
269
270fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
271    for meta in &module.root.meta {
272        print_meta_item(printer, meta);
273    }
274
275    for child in &module.root.children {
276        print_node(printer, child);
277    }
278}
279
280fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
281    printer.parens_enter();
282
283    printer.group_enter();
284    match &node.operation {
285        Operation::Invalid => printer.text("invalid"),
286        Operation::Dfg => printer.text("dfg"),
287        Operation::Cfg => printer.text("cfg"),
288        Operation::Block => printer.text("block"),
289        Operation::DefineFunc(symbol_signature) => {
290            printer.text("define-func");
291            print_symbol(printer, symbol_signature);
292        }
293        Operation::DeclareFunc(symbol_signature) => {
294            printer.text("declare-func");
295            print_symbol(printer, symbol_signature);
296        }
297        Operation::Custom(term) => {
298            print_term(printer, term);
299        }
300        Operation::DefineAlias(symbol_signature, value) => {
301            printer.text("define-alias");
302            print_symbol(printer, symbol_signature);
303            print_term(printer, value);
304        }
305        Operation::DeclareAlias(symbol_signature) => {
306            printer.text("declare-alias");
307            print_symbol(printer, symbol_signature);
308        }
309        Operation::TailLoop => printer.text("tail-loop"),
310        Operation::Conditional => printer.text("cond"),
311        Operation::DeclareConstructor(symbol_signature) => {
312            printer.text("declare-ctr");
313            print_symbol(printer, symbol_signature);
314        }
315        Operation::DeclareOperation(symbol_signature) => {
316            printer.text("declare-operation");
317            print_symbol(printer, symbol_signature);
318        }
319        Operation::Import(symbol) => {
320            printer.text("import");
321            print_symbol_name(printer, symbol);
322        }
323    }
324
325    print_port_lists(printer, &node.inputs, &node.outputs);
326    printer.group_exit();
327
328    if let Some(signature) = &node.signature {
329        print_signature(printer, signature);
330    }
331
332    for meta in &node.meta {
333        print_meta_item(printer, meta);
334    }
335
336    for region in &node.regions {
337        print_region(printer, region);
338    }
339
340    printer.parens_exit();
341}
342
343fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
344    printer.parens_enter();
345    printer.group_enter();
346
347    printer.text(match region.kind {
348        RegionKind::DataFlow => "dfg",
349        RegionKind::ControlFlow => "cfg",
350        RegionKind::Module => "mod",
351    });
352
353    print_port_lists(printer, &region.sources, &region.targets);
354    printer.group_exit();
355
356    if let Some(signature) = &region.signature {
357        print_signature(printer, signature);
358    }
359
360    for meta in &region.meta {
361        print_meta_item(printer, meta);
362    }
363
364    for child in &region.children {
365        print_node(printer, child);
366    }
367
368    printer.parens_exit();
369}
370
371fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
372    match symbol.visibility {
373        None => (),
374        Some(Visibility::Private) => printer.text("private"),
375        Some(Visibility::Public) => printer.text("public"),
376    }
377
378    print_symbol_name(printer, &symbol.name);
379
380    for param in &symbol.params {
381        print_param(printer, param);
382    }
383
384    for constraint in &symbol.constraints {
385        print_constraint(printer, constraint);
386    }
387
388    print_term(printer, &symbol.signature);
389}
390
391fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
392    printer.parens_enter();
393    printer.text("param");
394    print_var_name(printer, &param.name);
395    print_term(printer, &param.r#type);
396    printer.parens_exit();
397}
398
399fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
400    printer.parens_enter();
401    printer.text("where");
402    print_term(printer, constraint);
403    printer.parens_exit();
404}
405
406fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
407    printer.parens_enter();
408    printer.text("meta");
409    print_term(printer, meta);
410    printer.parens_exit();
411}
412
413fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
414    printer.parens_enter();
415    printer.text("signature");
416    print_term(printer, signature);
417    printer.parens_exit();
418}
419
// Implements `Display` for an AST type `$t`. The printing function `$print`
// runs against a fresh `Printer` backed by a local arena, and the resulting
// document is rendered at a line width of 80 columns.
macro_rules! impl_display {
    ($t:ident, $print:expr) => {
        impl Display for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let arena = Arena::new();
                let doc = {
                    let mut printer = Printer::new(&arena);
                    $print(&mut printer, self);
                    printer.finish()
                };
                doc.render_fmt(80, f)
            }
        }
    };
}
433
434impl_display!(Package, print_package);
435impl_display!(Module, print_module);
436impl_display!(Node, print_node);
437impl_display!(Region, print_region);
438impl_display!(Param, print_param);
439impl_display!(Term, print_term);
440impl_display!(SeqPart, print_seq_part);
441impl_display!(Literal, print_literal);
442impl_display!(Symbol, print_symbol);
443
444impl Display for VarName {
445    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446        write!(f, "?{}", self.0)
447    }
448}
449
450impl Display for SymbolName {
451    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452        write!(f, "{}", self.0)
453    }
454}
455
456impl Display for LinkName {
457    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458        write!(f, "%{}", self.0)
459    }
460}