1use std::{borrow::Cow, fmt::Display};
2
3use base64::{prelude::BASE64_STANDARD, Engine as _};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::{
9 LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
10};
11
/// Accumulates pretty-printer documents while a syntax tree is being printed.
struct Printer<'a> {
    /// Arena in which every document built by this printer is allocated.
    arena: &'a Arena<'a>,
    /// Documents emitted so far, in emission order. Closing a delimiter
    /// replaces the tail of this list with a single combined document.
    docs: Vec<RefDoc<'a>>,
    /// For each delimiter that is currently open, the length `docs` had at
    /// the moment the delimiter was opened.
    docs_stack: Vec<usize>,
}
20
21impl<'a> Printer<'a> {
22 fn new(arena: &'a Arena<'a>) -> Self {
23 Self {
24 arena,
25 docs: Vec::new(),
26 docs_stack: Vec::new(),
27 }
28 }
29
30 fn finish(self) -> RefDoc<'a> {
31 let sep = self
32 .arena
33 .concat([self.arena.hardline(), self.arena.hardline()]);
34 self.arena.intersperse(self.docs, sep).into_doc()
35 }
36
37 fn parens_enter(&mut self) {
38 self.delim_open();
39 }
40
41 fn parens_exit(&mut self) {
42 self.delim_close("(", ")", 2);
43 }
44
45 fn brackets_enter(&mut self) {
46 self.delim_open();
47 }
48
49 fn brackets_exit(&mut self) {
50 self.delim_close("[", "]", 1);
51 }
52
53 fn group_enter(&mut self) {
54 self.delim_open();
55 }
56
57 fn group_exit(&mut self) {
58 self.delim_close("", "", 0);
59 }
60
61 fn delim_open(&mut self) {
62 self.docs_stack.push(self.docs.len());
63 }
64
65 fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
66 let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
67 let doc = self.arena.concat([
68 self.arena.text(open),
69 self.arena
70 .intersperse(docs, self.arena.line())
71 .nest(nesting)
72 .group(),
73 self.arena.text(close),
74 ]);
75 self.docs.push(doc.into_doc());
76 }
77
78 fn text(&mut self, text: impl Into<Cow<'a, str>>) {
79 self.docs.push(self.arena.text(text).into_doc());
80 }
81
82 fn int(&mut self, value: u64) {
83 self.text(format!("{}", value));
84 }
85
86 fn string(&mut self, string: &str) {
87 let mut output = String::with_capacity(string.len() + 2);
88 output.push('"');
89
90 for c in string.chars() {
91 match c {
92 '\\' => output.push_str("\\\\"),
93 '"' => output.push_str("\\\""),
94 '\n' => output.push_str("\\n"),
95 '\r' => output.push_str("\\r"),
96 '\t' => output.push_str("\\t"),
97 _ => output.push(c),
98 }
99 }
100
101 output.push('"');
102 self.text(output);
103 }
104
105 fn bytes(&mut self, bytes: &[u8]) {
106 let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
108 output.push('"');
109 BASE64_STANDARD.encode_string(bytes, &mut output);
110 output.push('"');
111 self.text(output);
112 }
113}
114
115fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
116 match term {
117 Term::Wildcard => printer.text("_"),
118 Term::Var(var) => print_var_name(printer, var),
119 Term::Apply(symbol, terms) => {
120 if terms.is_empty() {
121 print_symbol_name(printer, symbol);
122 } else {
123 printer.parens_enter();
124 print_symbol_name(printer, symbol);
125
126 for term in terms.iter() {
127 print_term(printer, term);
128 }
129
130 printer.parens_exit();
131 }
132 }
133 Term::List(list_parts) => {
134 printer.brackets_enter();
135 print_list_parts(printer, list_parts);
136 printer.brackets_exit();
137 }
138 Term::Literal(literal) => {
139 print_literal(printer, literal);
140 }
141 Term::Tuple(tuple_parts) => {
142 printer.parens_enter();
143 printer.text("tuple");
144 print_tuple_parts(printer, tuple_parts);
145 printer.parens_exit();
146 }
147 Term::ExtSet => {
148 printer.parens_enter();
149 printer.text("ext");
150 printer.parens_exit();
151 }
152 Term::Func(region) => {
153 printer.parens_enter();
154 printer.text("fn");
155 print_region(printer, region);
156 printer.parens_exit();
157 }
158 }
159}
160
161fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
162 match literal {
163 Literal::Str(str) => {
164 printer.string(str);
165 }
166 Literal::Nat(nat) => {
167 printer.int(*nat);
168 }
169 Literal::Bytes(bytes) => {
170 printer.parens_enter();
171 printer.text("bytes");
172 printer.bytes(bytes);
173 printer.parens_exit();
174 }
175 Literal::Float(float) => {
176 printer.text(format!("{:.?}", float.into_inner()));
178 }
179 }
180}
181
182fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
183 printer.group_enter();
184 print_term(printer, term);
185 printer.text("...");
186 printer.group_exit();
187}
188
189fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
191 match part {
192 SeqPart::Item(term) => print_term(printer, term),
193 SeqPart::Splice(term) => print_seq_splice(printer, term),
194 }
195}
196
197fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
199 for part in parts {
200 match part {
201 SeqPart::Item(term) => print_term(printer, term),
202 SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
203 SeqPart::Splice(term) => print_seq_splice(printer, term),
204 }
205 }
206}
207
208fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
210 for part in parts {
211 match part {
212 SeqPart::Item(term) => print_term(printer, term),
213 SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
214 SeqPart::Splice(term) => print_seq_splice(printer, term),
215 }
216 }
217}
218
219fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
220 printer.text(name.0.as_str())
221}
222
223fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
224 printer.text(format!("{}", name))
225}
226
227fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
228 printer.text(format!("{}", name))
229}
230
231fn print_port_lists<'a>(
232 printer: &mut Printer<'a>,
233 inputs: &'a [LinkName],
234 outputs: &'a [LinkName],
235) {
236 if inputs.is_empty() && outputs.is_empty() {
240 return;
241 }
242
243 printer.group_enter();
246 printer.brackets_enter();
247 for input in inputs {
248 print_link_name(printer, input);
249 }
250 printer.brackets_exit();
251 printer.brackets_enter();
252 for output in outputs {
253 print_link_name(printer, output);
254 }
255 printer.brackets_exit();
256 printer.group_exit();
257}
258
259fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
260 printer.parens_enter();
261 printer.text("hugr");
262 printer.text("0");
263 printer.parens_exit();
264
265 for meta in module.root.meta.iter() {
266 print_meta_item(printer, meta);
267 }
268
269 for child in module.root.children.iter() {
270 print_node(printer, child);
271 }
272}
273
274fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
275 printer.parens_enter();
276
277 printer.group_enter();
278 match &node.operation {
279 Operation::Invalid => printer.text("invalid"),
280 Operation::Dfg => printer.text("dfg"),
281 Operation::Cfg => printer.text("cfg"),
282 Operation::Block => printer.text("block"),
283 Operation::DefineFunc(symbol_signature) => {
284 printer.text("define-func");
285 print_symbol(printer, symbol_signature);
286 }
287 Operation::DeclareFunc(symbol_signature) => {
288 printer.text("declare-func");
289 print_symbol(printer, symbol_signature);
290 }
291 Operation::Custom(term) => {
292 print_term(printer, term);
293 }
294 Operation::DefineAlias(symbol_signature, value) => {
295 printer.text("define-alias");
296 print_symbol(printer, symbol_signature);
297 print_term(printer, value);
298 }
299 Operation::DeclareAlias(symbol_signature) => {
300 printer.text("declare-alias");
301 print_symbol(printer, symbol_signature);
302 }
303 Operation::TailLoop => printer.text("tail-loop"),
304 Operation::Conditional => printer.text("cond"),
305 Operation::DeclareConstructor(symbol_signature) => {
306 printer.text("declare-ctr");
307 print_symbol(printer, symbol_signature);
308 }
309 Operation::DeclareOperation(symbol_signature) => {
310 printer.text("declare-operation");
311 print_symbol(printer, symbol_signature);
312 }
313 Operation::Import(symbol) => {
314 printer.text("import");
315 print_symbol_name(printer, symbol);
316 }
317 }
318
319 print_port_lists(printer, &node.inputs, &node.outputs);
320 printer.group_exit();
321
322 if let Some(signature) = &node.signature {
323 print_signature(printer, signature);
324 }
325
326 for meta in node.meta.iter() {
327 print_meta_item(printer, meta);
328 }
329
330 for region in node.regions.iter() {
331 print_region(printer, region);
332 }
333
334 printer.parens_exit();
335}
336
337fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
338 printer.parens_enter();
339 printer.group_enter();
340
341 printer.text(match region.kind {
342 RegionKind::DataFlow => "dfg",
343 RegionKind::ControlFlow => "cfg",
344 RegionKind::Module => "mod",
345 });
346
347 print_port_lists(printer, ®ion.sources, ®ion.targets);
348 printer.group_exit();
349
350 if let Some(signature) = ®ion.signature {
351 print_signature(printer, signature);
352 }
353
354 for meta in region.meta.iter() {
355 print_meta_item(printer, meta);
356 }
357
358 for child in region.children.iter() {
359 print_node(printer, child);
360 }
361
362 printer.parens_exit();
363}
364
365fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
366 print_symbol_name(printer, &symbol.name);
367
368 for param in symbol.params.iter() {
369 print_param(printer, param);
370 }
371
372 for constraint in symbol.constraints.iter() {
373 print_constraint(printer, constraint);
374 }
375
376 print_term(printer, &symbol.signature);
377}
378
379fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
380 printer.parens_enter();
381 printer.text("param");
382 print_var_name(printer, ¶m.name);
383 print_term(printer, ¶m.r#type);
384 printer.parens_exit();
385}
386
387fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
388 printer.parens_enter();
389 printer.text("where");
390 print_term(printer, constraint);
391 printer.parens_exit();
392}
393
394fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
395 printer.parens_enter();
396 printer.text("meta");
397 print_term(printer, meta);
398 printer.parens_exit();
399}
400
401fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
402 printer.parens_enter();
403 printer.text("signature");
404 print_term(printer, signature);
405 printer.parens_exit();
406}
407
/// Implements `Display` for the type `$t` by running the print function
/// `$print` on a fresh `Printer` and rendering the resulting document.
macro_rules! impl_display {
    ($t:ident, $print:expr) => {
        impl Display for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Width passed to the document renderer.
                const RENDER_WIDTH: usize = 80;

                // The arena only needs to live for this one call.
                let arena = Arena::new();
                let mut printer = Printer::new(&arena);
                $print(&mut printer, self);

                let rendered = printer.finish();
                rendered.render_fmt(RENDER_WIDTH, f)
            }
        }
    };
}
421
// `Display` implementations generated by `impl_display!`: each line pairs a
// syntax tree type with the function that prints it.
impl_display!(Module, print_module);
impl_display!(Node, print_node);
impl_display!(Region, print_region);
impl_display!(Param, print_param);
impl_display!(Term, print_term);
impl_display!(SeqPart, print_seq_part);
impl_display!(Literal, print_literal);
429
430impl Display for VarName {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 write!(f, "?{}", self.0)
433 }
434}
435
436impl Display for SymbolName {
437 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438 write!(f, "{}", self.0)
439 }
440}
441
442impl Display for LinkName {
443 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444 write!(f, "%{}", self.0)
445 }
446}