Skip to main content

hugr_model/v0/ast/
print.rs

1use std::{borrow::Cow, fmt::Display};
2
3use base64::{Engine as _, prelude::BASE64_STANDARD};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::parse::is_bare_symbol_name;
9use super::{
10    LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolIdent,
11    SymbolName, Term, VarName, Visibility,
12};
13
14struct Printer<'a> {
15    /// The arena in which to allocate the pretty-printed documents.
16    arena: &'a Arena<'a>,
17    /// Parts of the document to be concatenated.
18    docs: Vec<RefDoc<'a>>,
19    /// Stack of indices into `docs` denoting the current nesting.
20    docs_stack: Vec<usize>,
21}
22
23impl<'a> Printer<'a> {
24    fn new(arena: &'a Arena<'a>) -> Self {
25        Self {
26            arena,
27            docs: Vec::new(),
28            docs_stack: Vec::new(),
29        }
30    }
31
32    fn finish(self) -> RefDoc<'a> {
33        let sep = self
34            .arena
35            .concat([self.arena.hardline(), self.arena.hardline()]);
36        self.arena.intersperse(self.docs, sep).into_doc()
37    }
38
39    fn parens_enter(&mut self) {
40        self.delim_open();
41    }
42
43    fn parens_exit(&mut self) {
44        self.delim_close("(", ")", 2);
45    }
46
47    fn brackets_enter(&mut self) {
48        self.delim_open();
49    }
50
51    fn brackets_exit(&mut self) {
52        self.delim_close("[", "]", 1);
53    }
54
55    fn group_enter(&mut self) {
56        self.delim_open();
57    }
58
59    fn group_exit(&mut self) {
60        self.delim_close("", "", 0);
61    }
62
63    fn delim_open(&mut self) {
64        self.docs_stack.push(self.docs.len());
65    }
66
67    fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
68        let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
69        let doc = self.arena.concat([
70            self.arena.text(open),
71            self.arena
72                .intersperse(docs, self.arena.line())
73                .nest(nesting)
74                .group(),
75            self.arena.text(close),
76        ]);
77        self.docs.push(doc.into_doc());
78    }
79
80    fn text(&mut self, text: impl Into<Cow<'a, str>>) {
81        self.docs.push(self.arena.text(text).into_doc());
82    }
83
84    fn int(&mut self, value: u64) {
85        self.text(format!("{value}"));
86    }
87
88    fn string(&mut self, string: &str) {
89        self.text(quote_string(string));
90    }
91
92    fn bytes(&mut self, bytes: &[u8]) {
93        // every 3 bytes are encoded into 4 characters
94        let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
95        output.push('"');
96        BASE64_STANDARD.encode_string(bytes, &mut output);
97        output.push('"');
98        self.text(output);
99    }
100}
101
/// Quote `string` as a double-quoted string literal.
///
/// Backslashes, double quotes, newlines, carriage returns and tabs are
/// replaced by their backslash escapes; every other character is copied
/// through unchanged.
fn quote_string(string: &str) -> String {
    let mut quoted = String::with_capacity(string.len() + 2);
    quoted.push('"');

    for ch in string.chars() {
        let escaped = match ch {
            '\\' => "\\\\",
            '"' => "\\\"",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            other => {
                quoted.push(other);
                continue;
            }
        };
        quoted.push_str(escaped);
    }

    quoted.push('"');
    quoted
}
121
122fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
123    match term {
124        Term::Wildcard => printer.text("_"),
125        Term::Var(var) => print_var_name(printer, var),
126        Term::Apply(symbol, terms) => {
127            if terms.is_empty() {
128                print_symbol_ident(printer, symbol);
129            } else {
130                printer.parens_enter();
131                print_symbol_ident(printer, symbol);
132
133                for term in terms.iter() {
134                    print_term(printer, term);
135                }
136
137                printer.parens_exit();
138            }
139        }
140        Term::List(list_parts) => {
141            printer.brackets_enter();
142            print_list_parts(printer, list_parts);
143            printer.brackets_exit();
144        }
145        Term::Literal(literal) => {
146            print_literal(printer, literal);
147        }
148        Term::Tuple(tuple_parts) => {
149            printer.parens_enter();
150            printer.text("tuple");
151            print_tuple_parts(printer, tuple_parts);
152            printer.parens_exit();
153        }
154        Term::Func(region) => {
155            printer.parens_enter();
156            printer.text("fn");
157            print_region(printer, region);
158            printer.parens_exit();
159        }
160    }
161}
162
163fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
164    match literal {
165        Literal::Str(str) => {
166            printer.string(str);
167        }
168        Literal::Nat(nat) => {
169            printer.int(*nat);
170        }
171        Literal::Bytes(bytes) => {
172            printer.parens_enter();
173            printer.text("bytes");
174            printer.bytes(bytes);
175            printer.parens_exit();
176        }
177        Literal::Float(float) => {
178            // The debug representation of a float always includes a decimal point.
179            printer.text(format!("{:.?}", float.into_inner()));
180        }
181    }
182}
183
184fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
185    printer.group_enter();
186    print_term(printer, term);
187    printer.text("...");
188    printer.group_exit();
189}
190
191/// Print a [`SeqPart`] in isolation for the [`Display`] instance.
192fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
193    match part {
194        SeqPart::Item(term) => print_term(printer, term),
195        SeqPart::Splice(term) => print_seq_splice(printer, term),
196    }
197}
198
199/// Print the parts of a list [`Term`], merging spread lists.
200fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
201    for part in parts {
202        match part {
203            SeqPart::Item(term) => print_term(printer, term),
204            SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
205            SeqPart::Splice(term) => print_seq_splice(printer, term),
206        }
207    }
208}
209
210/// Print the parts of a tuple [`Term`], merging spread tuples.
211fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
212    for part in parts {
213        match part {
214            SeqPart::Item(term) => print_term(printer, term),
215            SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
216            SeqPart::Splice(term) => print_seq_splice(printer, term),
217        }
218    }
219}
220
221fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
222    printer.text(format_symbol_name(name.as_ref()));
223}
224
225fn print_symbol_ident<'a>(printer: &mut Printer<'a>, symbol_ident: &'a SymbolIdent) {
226    print_versioned_symbol_name(printer, &symbol_ident.name, symbol_ident.version.as_ref());
227}
228
229fn print_versioned_symbol_name<'a>(
230    printer: &mut Printer<'a>,
231    name: &'a SymbolName,
232    version: Option<&semver::Version>,
233) {
234    match version {
235        Some(version) => printer.text(format!("{name}@{version}")),
236        None => print_symbol_name(printer, name),
237    }
238}
239
240fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
241    printer.text(format!("{name}"));
242}
243
244fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
245    printer.text(format!("{name}"));
246}
247
248fn print_port_lists<'a>(
249    printer: &mut Printer<'a>,
250    inputs: &'a [LinkName],
251    outputs: &'a [LinkName],
252) {
253    // If the node/region has no ports, we avoid printing the port lists.
254    // This is especially important for the syntax of nodes that introduce symbols
255    // since these nodes never have any input or output ports.
256    if inputs.is_empty() && outputs.is_empty() {
257        return;
258    }
259
260    // The group encodes the preference that the port lists occur on the same
261    // line whenever possible.
262    printer.group_enter();
263    printer.brackets_enter();
264    for input in inputs {
265        print_link_name(printer, input);
266    }
267    printer.brackets_exit();
268    printer.brackets_enter();
269    for output in outputs {
270        print_link_name(printer, output);
271    }
272    printer.brackets_exit();
273    printer.group_exit();
274}
275
276fn print_package<'a>(printer: &mut Printer<'a>, package: &'a Package) {
277    printer.parens_enter();
278    printer.text("hugr");
279    printer.text("0");
280    printer.parens_exit();
281
282    for module in &package.modules {
283        printer.parens_enter();
284        printer.text("mod");
285        printer.parens_exit();
286
287        print_module(printer, module);
288    }
289}
290
291fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
292    for meta in &module.root.meta {
293        print_meta_item(printer, meta);
294    }
295
296    for child in &module.root.children {
297        print_node(printer, child);
298    }
299}
300
301fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
302    printer.parens_enter();
303
304    printer.group_enter();
305    match &node.operation {
306        Operation::Invalid => printer.text("invalid"),
307        Operation::Dfg => printer.text("dfg"),
308        Operation::Cfg => printer.text("cfg"),
309        Operation::Block => printer.text("block"),
310        Operation::DefineFunc(symbol_signature) => {
311            printer.text("define-func");
312            print_symbol(printer, symbol_signature);
313        }
314        Operation::DeclareFunc(symbol_signature) => {
315            printer.text("declare-func");
316            print_symbol(printer, symbol_signature);
317        }
318        Operation::Custom(term) => {
319            print_term(printer, term);
320        }
321        Operation::DefineAlias(symbol_signature, value) => {
322            printer.text("define-alias");
323            print_symbol(printer, symbol_signature);
324            print_term(printer, value);
325        }
326        Operation::DeclareAlias(symbol_signature) => {
327            printer.text("declare-alias");
328            print_symbol(printer, symbol_signature);
329        }
330        Operation::TailLoop => printer.text("tail-loop"),
331        Operation::Conditional => printer.text("cond"),
332        Operation::DeclareConstructor(symbol_signature) => {
333            printer.text("declare-ctr");
334            print_symbol(printer, symbol_signature);
335        }
336        Operation::DeclareOperation(symbol_signature) => {
337            printer.text("declare-operation");
338            print_symbol(printer, symbol_signature);
339        }
340        Operation::Import(symbol) => {
341            printer.text("import");
342            print_symbol_ident(printer, symbol);
343        }
344    }
345
346    print_port_lists(printer, &node.inputs, &node.outputs);
347    printer.group_exit();
348
349    if let Some(signature) = &node.signature {
350        print_signature(printer, signature);
351    }
352
353    for meta in &node.meta {
354        print_meta_item(printer, meta);
355    }
356
357    for region in &node.regions {
358        print_region(printer, region);
359    }
360
361    printer.parens_exit();
362}
363
364fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
365    printer.parens_enter();
366    printer.group_enter();
367
368    printer.text(match region.kind {
369        RegionKind::DataFlow => "dfg",
370        RegionKind::ControlFlow => "cfg",
371        RegionKind::Module => "mod",
372    });
373
374    print_port_lists(printer, &region.sources, &region.targets);
375    printer.group_exit();
376
377    if let Some(signature) = &region.signature {
378        print_signature(printer, signature);
379    }
380
381    for meta in &region.meta {
382        print_meta_item(printer, meta);
383    }
384
385    for child in &region.children {
386        print_node(printer, child);
387    }
388
389    printer.parens_exit();
390}
391
392fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
393    printer.group_enter();
394
395    match symbol.visibility {
396        None => (),
397        Some(Visibility::Private) => printer.text("private"),
398        Some(Visibility::Public) => printer.text("public"),
399    }
400
401    print_versioned_symbol_name(printer, &symbol.name, symbol.version.as_ref());
402
403    for param in &symbol.params {
404        print_param(printer, param);
405    }
406
407    for constraint in &symbol.constraints {
408        print_constraint(printer, constraint);
409    }
410
411    print_term(printer, &symbol.signature);
412    printer.group_exit();
413}
414
415fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
416    printer.parens_enter();
417    printer.text("param");
418    print_var_name(printer, &param.name);
419    print_term(printer, &param.r#type);
420    printer.parens_exit();
421}
422
423fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
424    printer.parens_enter();
425    printer.text("where");
426    print_term(printer, constraint);
427    printer.parens_exit();
428}
429
430fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
431    printer.parens_enter();
432    printer.text("meta");
433    print_term(printer, meta);
434    printer.parens_exit();
435}
436
437fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
438    printer.parens_enter();
439    printer.text("signature");
440    print_term(printer, signature);
441    printer.parens_exit();
442}
443
// Implement `Display` for an AST type. The value is printed with the given
// printing function into a fresh arena, and the resulting document is
// rendered at a width of 80 columns.
macro_rules! impl_display {
    ($t:ident, $print:expr) => {
        impl Display for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                const WIDTH: usize = 80;
                let arena = Arena::new();
                let doc = {
                    let mut printer = Printer::new(&arena);
                    $print(&mut printer, self);
                    printer.finish()
                };
                doc.render_fmt(WIDTH, f)
            }
        }
    };
}
457
458impl_display!(Package, print_package);
459impl_display!(Module, print_module);
460impl_display!(Node, print_node);
461impl_display!(Region, print_region);
462impl_display!(Param, print_param);
463impl_display!(Term, print_term);
464impl_display!(SeqPart, print_seq_part);
465impl_display!(Literal, print_literal);
466impl_display!(Symbol, print_symbol);
467impl_display!(SymbolIdent, print_symbol_ident);
468
469impl Display for VarName {
470    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        write!(f, "?{}", self.0)
472    }
473}
474
475impl Display for SymbolName {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        write!(f, "{}", format_symbol_name(self.as_ref()))
478    }
479}
480
481impl Display for LinkName {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        write!(f, "%{}", self.0)
484    }
485}
486
487/// Format a symbol name as a parseable token.
488///
489/// The text syntax keeps common dotted identifiers bare for readability. Names
490/// produced by frontends may contain punctuation or whitespace, so those names
491/// use a Rust-style raw string form to preserve roundtripping.
492fn format_symbol_name(name: &str) -> Cow<'_, str> {
493    if is_bare_symbol_name(name) {
494        Cow::Borrowed(name)
495    } else {
496        Cow::Owned(quote_raw_symbol_name(name))
497    }
498}
499
/// Wrap `name` in a raw-string token of the form `r#"..."#`.
///
/// `name` itself is not escaped. Instead, `#` characters are added until the
/// closing delimiter (a `"` followed by the hashes) no longer occurs anywhere
/// inside `name`.
fn quote_raw_symbol_name(name: &str) -> String {
    // The closing delimiter: one quote followed by one or more hashes.
    let mut closing = String::from("\"#");
    while name.contains(closing.as_str()) {
        closing.push('#');
    }

    // Skip the leading quote (a single ASCII byte) to get just the hashes.
    let hashes = &closing[1..];
    format!("r{hashes}\"{name}\"{hashes}")
}