1use std::{borrow::Cow, fmt::Display};
2
3use base64::{Engine as _, prelude::BASE64_STANDARD};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::parse::is_bare_symbol_name;
9use super::{
10 LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolIdent,
11 SymbolName, Term, VarName, Visibility,
12};
13
14struct Printer<'a> {
15 arena: &'a Arena<'a>,
17 docs: Vec<RefDoc<'a>>,
19 docs_stack: Vec<usize>,
21}
22
23impl<'a> Printer<'a> {
24 fn new(arena: &'a Arena<'a>) -> Self {
25 Self {
26 arena,
27 docs: Vec::new(),
28 docs_stack: Vec::new(),
29 }
30 }
31
32 fn finish(self) -> RefDoc<'a> {
33 let sep = self
34 .arena
35 .concat([self.arena.hardline(), self.arena.hardline()]);
36 self.arena.intersperse(self.docs, sep).into_doc()
37 }
38
39 fn parens_enter(&mut self) {
40 self.delim_open();
41 }
42
43 fn parens_exit(&mut self) {
44 self.delim_close("(", ")", 2);
45 }
46
47 fn brackets_enter(&mut self) {
48 self.delim_open();
49 }
50
51 fn brackets_exit(&mut self) {
52 self.delim_close("[", "]", 1);
53 }
54
55 fn group_enter(&mut self) {
56 self.delim_open();
57 }
58
59 fn group_exit(&mut self) {
60 self.delim_close("", "", 0);
61 }
62
63 fn delim_open(&mut self) {
64 self.docs_stack.push(self.docs.len());
65 }
66
67 fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
68 let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
69 let doc = self.arena.concat([
70 self.arena.text(open),
71 self.arena
72 .intersperse(docs, self.arena.line())
73 .nest(nesting)
74 .group(),
75 self.arena.text(close),
76 ]);
77 self.docs.push(doc.into_doc());
78 }
79
80 fn text(&mut self, text: impl Into<Cow<'a, str>>) {
81 self.docs.push(self.arena.text(text).into_doc());
82 }
83
84 fn int(&mut self, value: u64) {
85 self.text(format!("{value}"));
86 }
87
88 fn string(&mut self, string: &str) {
89 self.text(quote_string(string));
90 }
91
92 fn bytes(&mut self, bytes: &[u8]) {
93 let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
95 output.push('"');
96 BASE64_STANDARD.encode_string(bytes, &mut output);
97 output.push('"');
98 self.text(output);
99 }
100}
101
/// Wraps `string` in double quotes, escaping backslashes, double quotes,
/// newlines, carriage returns and tabs with a backslash sequence.
///
/// Every other character is copied through unchanged.
fn quote_string(string: &str) -> String {
    let mut quoted = String::with_capacity(string.len() + 2);
    quoted.push('"');

    for ch in string.chars() {
        let escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };

        match escape {
            Some(sequence) => quoted.push_str(sequence),
            None => quoted.push(ch),
        }
    }

    quoted.push('"');
    quoted
}
121
122fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
123 match term {
124 Term::Wildcard => printer.text("_"),
125 Term::Var(var) => print_var_name(printer, var),
126 Term::Apply(symbol, terms) => {
127 if terms.is_empty() {
128 print_symbol_ident(printer, symbol);
129 } else {
130 printer.parens_enter();
131 print_symbol_ident(printer, symbol);
132
133 for term in terms.iter() {
134 print_term(printer, term);
135 }
136
137 printer.parens_exit();
138 }
139 }
140 Term::List(list_parts) => {
141 printer.brackets_enter();
142 print_list_parts(printer, list_parts);
143 printer.brackets_exit();
144 }
145 Term::Literal(literal) => {
146 print_literal(printer, literal);
147 }
148 Term::Tuple(tuple_parts) => {
149 printer.parens_enter();
150 printer.text("tuple");
151 print_tuple_parts(printer, tuple_parts);
152 printer.parens_exit();
153 }
154 Term::Func(region) => {
155 printer.parens_enter();
156 printer.text("fn");
157 print_region(printer, region);
158 printer.parens_exit();
159 }
160 }
161}
162
163fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
164 match literal {
165 Literal::Str(str) => {
166 printer.string(str);
167 }
168 Literal::Nat(nat) => {
169 printer.int(*nat);
170 }
171 Literal::Bytes(bytes) => {
172 printer.parens_enter();
173 printer.text("bytes");
174 printer.bytes(bytes);
175 printer.parens_exit();
176 }
177 Literal::Float(float) => {
178 printer.text(format!("{:.?}", float.into_inner()));
180 }
181 }
182}
183
184fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
185 printer.group_enter();
186 print_term(printer, term);
187 printer.text("...");
188 printer.group_exit();
189}
190
191fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
193 match part {
194 SeqPart::Item(term) => print_term(printer, term),
195 SeqPart::Splice(term) => print_seq_splice(printer, term),
196 }
197}
198
199fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
201 for part in parts {
202 match part {
203 SeqPart::Item(term) => print_term(printer, term),
204 SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
205 SeqPart::Splice(term) => print_seq_splice(printer, term),
206 }
207 }
208}
209
210fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
212 for part in parts {
213 match part {
214 SeqPart::Item(term) => print_term(printer, term),
215 SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
216 SeqPart::Splice(term) => print_seq_splice(printer, term),
217 }
218 }
219}
220
221fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
222 printer.text(format_symbol_name(name.as_ref()));
223}
224
225fn print_symbol_ident<'a>(printer: &mut Printer<'a>, symbol_ident: &'a SymbolIdent) {
226 print_versioned_symbol_name(printer, &symbol_ident.name, symbol_ident.version.as_ref());
227}
228
229fn print_versioned_symbol_name<'a>(
230 printer: &mut Printer<'a>,
231 name: &'a SymbolName,
232 version: Option<&semver::Version>,
233) {
234 match version {
235 Some(version) => printer.text(format!("{name}@{version}")),
236 None => print_symbol_name(printer, name),
237 }
238}
239
240fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
241 printer.text(format!("{name}"));
242}
243
244fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
245 printer.text(format!("{name}"));
246}
247
248fn print_port_lists<'a>(
249 printer: &mut Printer<'a>,
250 inputs: &'a [LinkName],
251 outputs: &'a [LinkName],
252) {
253 if inputs.is_empty() && outputs.is_empty() {
257 return;
258 }
259
260 printer.group_enter();
263 printer.brackets_enter();
264 for input in inputs {
265 print_link_name(printer, input);
266 }
267 printer.brackets_exit();
268 printer.brackets_enter();
269 for output in outputs {
270 print_link_name(printer, output);
271 }
272 printer.brackets_exit();
273 printer.group_exit();
274}
275
276fn print_package<'a>(printer: &mut Printer<'a>, package: &'a Package) {
277 printer.parens_enter();
278 printer.text("hugr");
279 printer.text("0");
280 printer.parens_exit();
281
282 for module in &package.modules {
283 printer.parens_enter();
284 printer.text("mod");
285 printer.parens_exit();
286
287 print_module(printer, module);
288 }
289}
290
291fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
292 for meta in &module.root.meta {
293 print_meta_item(printer, meta);
294 }
295
296 for child in &module.root.children {
297 print_node(printer, child);
298 }
299}
300
301fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
302 printer.parens_enter();
303
304 printer.group_enter();
305 match &node.operation {
306 Operation::Invalid => printer.text("invalid"),
307 Operation::Dfg => printer.text("dfg"),
308 Operation::Cfg => printer.text("cfg"),
309 Operation::Block => printer.text("block"),
310 Operation::DefineFunc(symbol_signature) => {
311 printer.text("define-func");
312 print_symbol(printer, symbol_signature);
313 }
314 Operation::DeclareFunc(symbol_signature) => {
315 printer.text("declare-func");
316 print_symbol(printer, symbol_signature);
317 }
318 Operation::Custom(term) => {
319 print_term(printer, term);
320 }
321 Operation::DefineAlias(symbol_signature, value) => {
322 printer.text("define-alias");
323 print_symbol(printer, symbol_signature);
324 print_term(printer, value);
325 }
326 Operation::DeclareAlias(symbol_signature) => {
327 printer.text("declare-alias");
328 print_symbol(printer, symbol_signature);
329 }
330 Operation::TailLoop => printer.text("tail-loop"),
331 Operation::Conditional => printer.text("cond"),
332 Operation::DeclareConstructor(symbol_signature) => {
333 printer.text("declare-ctr");
334 print_symbol(printer, symbol_signature);
335 }
336 Operation::DeclareOperation(symbol_signature) => {
337 printer.text("declare-operation");
338 print_symbol(printer, symbol_signature);
339 }
340 Operation::Import(symbol) => {
341 printer.text("import");
342 print_symbol_ident(printer, symbol);
343 }
344 }
345
346 print_port_lists(printer, &node.inputs, &node.outputs);
347 printer.group_exit();
348
349 if let Some(signature) = &node.signature {
350 print_signature(printer, signature);
351 }
352
353 for meta in &node.meta {
354 print_meta_item(printer, meta);
355 }
356
357 for region in &node.regions {
358 print_region(printer, region);
359 }
360
361 printer.parens_exit();
362}
363
364fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
365 printer.parens_enter();
366 printer.group_enter();
367
368 printer.text(match region.kind {
369 RegionKind::DataFlow => "dfg",
370 RegionKind::ControlFlow => "cfg",
371 RegionKind::Module => "mod",
372 });
373
374 print_port_lists(printer, ®ion.sources, ®ion.targets);
375 printer.group_exit();
376
377 if let Some(signature) = ®ion.signature {
378 print_signature(printer, signature);
379 }
380
381 for meta in ®ion.meta {
382 print_meta_item(printer, meta);
383 }
384
385 for child in ®ion.children {
386 print_node(printer, child);
387 }
388
389 printer.parens_exit();
390}
391
392fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
393 printer.group_enter();
394
395 match symbol.visibility {
396 None => (),
397 Some(Visibility::Private) => printer.text("private"),
398 Some(Visibility::Public) => printer.text("public"),
399 }
400
401 print_versioned_symbol_name(printer, &symbol.name, symbol.version.as_ref());
402
403 for param in &symbol.params {
404 print_param(printer, param);
405 }
406
407 for constraint in &symbol.constraints {
408 print_constraint(printer, constraint);
409 }
410
411 print_term(printer, &symbol.signature);
412 printer.group_exit();
413}
414
415fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
416 printer.parens_enter();
417 printer.text("param");
418 print_var_name(printer, ¶m.name);
419 print_term(printer, ¶m.r#type);
420 printer.parens_exit();
421}
422
423fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
424 printer.parens_enter();
425 printer.text("where");
426 print_term(printer, constraint);
427 printer.parens_exit();
428}
429
430fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
431 printer.parens_enter();
432 printer.text("meta");
433 print_term(printer, meta);
434 printer.parens_exit();
435}
436
437fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
438 printer.parens_enter();
439 printer.text("signature");
440 print_term(printer, signature);
441 printer.parens_exit();
442}
443
444macro_rules! impl_display {
445 ($t:ident, $print:expr) => {
446 impl Display for $t {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 let arena = Arena::new();
449 let mut printer = Printer::new(&arena);
450 $print(&mut printer, self);
451 let doc = printer.finish();
452 doc.render_fmt(80, f)
453 }
454 }
455 };
456}
457
458impl_display!(Package, print_package);
459impl_display!(Module, print_module);
460impl_display!(Node, print_node);
461impl_display!(Region, print_region);
462impl_display!(Param, print_param);
463impl_display!(Term, print_term);
464impl_display!(SeqPart, print_seq_part);
465impl_display!(Literal, print_literal);
466impl_display!(Symbol, print_symbol);
467impl_display!(SymbolIdent, print_symbol_ident);
468
469impl Display for VarName {
470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 write!(f, "?{}", self.0)
472 }
473}
474
475impl Display for SymbolName {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 write!(f, "{}", format_symbol_name(self.as_ref()))
478 }
479}
480
481impl Display for LinkName {
482 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483 write!(f, "%{}", self.0)
484 }
485}
486
487fn format_symbol_name(name: &str) -> Cow<'_, str> {
493 if is_bare_symbol_name(name) {
494 Cow::Borrowed(name)
495 } else {
496 Cow::Owned(quote_raw_symbol_name(name))
497 }
498}
499
/// Wraps `name` in a raw quoted form `r#"…"#`.
///
/// The number of `#` characters is the smallest count (at least one) for
/// which the closing delimiter, `"` followed by that many `#`, does not
/// occur inside `name` itself.
fn quote_raw_symbol_name(name: &str) -> String {
    // Every piece after the first one starts right behind a `"` in `name`,
    // so its leading `#` characters are exactly the run following that quote.
    let longest_run_after_quote = name
        .split('"')
        .skip(1)
        .map(|piece| piece.bytes().take_while(|&byte| byte == b'#').count())
        .max()
        .unwrap_or(0);

    let hashes = "#".repeat(longest_run_after_quote + 1);
    format!("r{hashes}\"{name}\"{hashes}")
}