1use std::{borrow::Cow, fmt::Display};
2
3use base64::{Engine as _, prelude::BASE64_STANDARD};
4use pretty::{Arena, DocAllocator as _, RefDoc};
5
6use crate::v0::{Literal, RegionKind};
7
8use super::{
9 LinkName, Module, Node, Operation, Package, Param, Region, SeqPart, Symbol, SymbolName, Term,
10 VarName, Visibility,
11};
12
13struct Printer<'a> {
14 arena: &'a Arena<'a>,
16 docs: Vec<RefDoc<'a>>,
18 docs_stack: Vec<usize>,
20}
21
22impl<'a> Printer<'a> {
23 fn new(arena: &'a Arena<'a>) -> Self {
24 Self {
25 arena,
26 docs: Vec::new(),
27 docs_stack: Vec::new(),
28 }
29 }
30
31 fn finish(self) -> RefDoc<'a> {
32 let sep = self
33 .arena
34 .concat([self.arena.hardline(), self.arena.hardline()]);
35 self.arena.intersperse(self.docs, sep).into_doc()
36 }
37
38 fn parens_enter(&mut self) {
39 self.delim_open();
40 }
41
42 fn parens_exit(&mut self) {
43 self.delim_close("(", ")", 2);
44 }
45
46 fn brackets_enter(&mut self) {
47 self.delim_open();
48 }
49
50 fn brackets_exit(&mut self) {
51 self.delim_close("[", "]", 1);
52 }
53
54 fn group_enter(&mut self) {
55 self.delim_open();
56 }
57
58 fn group_exit(&mut self) {
59 self.delim_close("", "", 0);
60 }
61
62 fn delim_open(&mut self) {
63 self.docs_stack.push(self.docs.len());
64 }
65
66 fn delim_close(&mut self, open: &'static str, close: &'static str, nesting: isize) {
67 let docs = self.docs.drain(self.docs_stack.pop().unwrap()..);
68 let doc = self.arena.concat([
69 self.arena.text(open),
70 self.arena
71 .intersperse(docs, self.arena.line())
72 .nest(nesting)
73 .group(),
74 self.arena.text(close),
75 ]);
76 self.docs.push(doc.into_doc());
77 }
78
79 fn text(&mut self, text: impl Into<Cow<'a, str>>) {
80 self.docs.push(self.arena.text(text).into_doc());
81 }
82
83 fn int(&mut self, value: u64) {
84 self.text(format!("{value}"));
85 }
86
87 fn string(&mut self, string: &str) {
88 let mut output = String::with_capacity(string.len() + 2);
89 output.push('"');
90
91 for c in string.chars() {
92 match c {
93 '\\' => output.push_str("\\\\"),
94 '"' => output.push_str("\\\""),
95 '\n' => output.push_str("\\n"),
96 '\r' => output.push_str("\\r"),
97 '\t' => output.push_str("\\t"),
98 _ => output.push(c),
99 }
100 }
101
102 output.push('"');
103 self.text(output);
104 }
105
106 fn bytes(&mut self, bytes: &[u8]) {
107 let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4);
109 output.push('"');
110 BASE64_STANDARD.encode_string(bytes, &mut output);
111 output.push('"');
112 self.text(output);
113 }
114}
115
116fn print_term<'a>(printer: &mut Printer<'a>, term: &'a Term) {
117 match term {
118 Term::Wildcard => printer.text("_"),
119 Term::Var(var) => print_var_name(printer, var),
120 Term::Apply(symbol, terms) => {
121 if terms.is_empty() {
122 print_symbol_name(printer, symbol);
123 } else {
124 printer.parens_enter();
125 print_symbol_name(printer, symbol);
126
127 for term in terms.iter() {
128 print_term(printer, term);
129 }
130
131 printer.parens_exit();
132 }
133 }
134 Term::List(list_parts) => {
135 printer.brackets_enter();
136 print_list_parts(printer, list_parts);
137 printer.brackets_exit();
138 }
139 Term::Literal(literal) => {
140 print_literal(printer, literal);
141 }
142 Term::Tuple(tuple_parts) => {
143 printer.parens_enter();
144 printer.text("tuple");
145 print_tuple_parts(printer, tuple_parts);
146 printer.parens_exit();
147 }
148 Term::Func(region) => {
149 printer.parens_enter();
150 printer.text("fn");
151 print_region(printer, region);
152 printer.parens_exit();
153 }
154 }
155}
156
157fn print_literal<'a>(printer: &mut Printer<'a>, literal: &'a Literal) {
158 match literal {
159 Literal::Str(str) => {
160 printer.string(str);
161 }
162 Literal::Nat(nat) => {
163 printer.int(*nat);
164 }
165 Literal::Bytes(bytes) => {
166 printer.parens_enter();
167 printer.text("bytes");
168 printer.bytes(bytes);
169 printer.parens_exit();
170 }
171 Literal::Float(float) => {
172 printer.text(format!("{:.?}", float.into_inner()));
174 }
175 }
176}
177
178fn print_seq_splice<'a>(printer: &mut Printer<'a>, term: &'a Term) {
179 printer.group_enter();
180 print_term(printer, term);
181 printer.text("...");
182 printer.group_exit();
183}
184
185fn print_seq_part<'a>(printer: &mut Printer<'a>, part: &'a SeqPart) {
187 match part {
188 SeqPart::Item(term) => print_term(printer, term),
189 SeqPart::Splice(term) => print_seq_splice(printer, term),
190 }
191}
192
193fn print_list_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
195 for part in parts {
196 match part {
197 SeqPart::Item(term) => print_term(printer, term),
198 SeqPart::Splice(Term::List(nested)) => print_list_parts(printer, nested),
199 SeqPart::Splice(term) => print_seq_splice(printer, term),
200 }
201 }
202}
203
204fn print_tuple_parts<'a>(printer: &mut Printer<'a>, parts: &'a [SeqPart]) {
206 for part in parts {
207 match part {
208 SeqPart::Item(term) => print_term(printer, term),
209 SeqPart::Splice(Term::Tuple(nested)) => print_tuple_parts(printer, nested),
210 SeqPart::Splice(term) => print_seq_splice(printer, term),
211 }
212 }
213}
214
215fn print_symbol_name<'a>(printer: &mut Printer<'a>, name: &'a SymbolName) {
216 printer.text(name.0.as_str());
217}
218
219fn print_var_name<'a>(printer: &mut Printer<'a>, name: &'a VarName) {
220 printer.text(format!("{name}"));
221}
222
223fn print_link_name<'a>(printer: &mut Printer<'a>, name: &'a LinkName) {
224 printer.text(format!("{name}"));
225}
226
227fn print_port_lists<'a>(
228 printer: &mut Printer<'a>,
229 inputs: &'a [LinkName],
230 outputs: &'a [LinkName],
231) {
232 if inputs.is_empty() && outputs.is_empty() {
236 return;
237 }
238
239 printer.group_enter();
242 printer.brackets_enter();
243 for input in inputs {
244 print_link_name(printer, input);
245 }
246 printer.brackets_exit();
247 printer.brackets_enter();
248 for output in outputs {
249 print_link_name(printer, output);
250 }
251 printer.brackets_exit();
252 printer.group_exit();
253}
254
255fn print_package<'a>(printer: &mut Printer<'a>, package: &'a Package) {
256 printer.parens_enter();
257 printer.text("hugr");
258 printer.text("0");
259 printer.parens_exit();
260
261 for module in &package.modules {
262 printer.parens_enter();
263 printer.text("mod");
264 printer.parens_exit();
265
266 print_module(printer, module);
267 }
268}
269
270fn print_module<'a>(printer: &mut Printer<'a>, module: &'a Module) {
271 for meta in &module.root.meta {
272 print_meta_item(printer, meta);
273 }
274
275 for child in &module.root.children {
276 print_node(printer, child);
277 }
278}
279
280fn print_node<'a>(printer: &mut Printer<'a>, node: &'a Node) {
281 printer.parens_enter();
282
283 printer.group_enter();
284 match &node.operation {
285 Operation::Invalid => printer.text("invalid"),
286 Operation::Dfg => printer.text("dfg"),
287 Operation::Cfg => printer.text("cfg"),
288 Operation::Block => printer.text("block"),
289 Operation::DefineFunc(symbol_signature) => {
290 printer.text("define-func");
291 print_symbol(printer, symbol_signature);
292 }
293 Operation::DeclareFunc(symbol_signature) => {
294 printer.text("declare-func");
295 print_symbol(printer, symbol_signature);
296 }
297 Operation::Custom(term) => {
298 print_term(printer, term);
299 }
300 Operation::DefineAlias(symbol_signature, value) => {
301 printer.text("define-alias");
302 print_symbol(printer, symbol_signature);
303 print_term(printer, value);
304 }
305 Operation::DeclareAlias(symbol_signature) => {
306 printer.text("declare-alias");
307 print_symbol(printer, symbol_signature);
308 }
309 Operation::TailLoop => printer.text("tail-loop"),
310 Operation::Conditional => printer.text("cond"),
311 Operation::DeclareConstructor(symbol_signature) => {
312 printer.text("declare-ctr");
313 print_symbol(printer, symbol_signature);
314 }
315 Operation::DeclareOperation(symbol_signature) => {
316 printer.text("declare-operation");
317 print_symbol(printer, symbol_signature);
318 }
319 Operation::Import(symbol) => {
320 printer.text("import");
321 print_symbol_name(printer, symbol);
322 }
323 }
324
325 print_port_lists(printer, &node.inputs, &node.outputs);
326 printer.group_exit();
327
328 if let Some(signature) = &node.signature {
329 print_signature(printer, signature);
330 }
331
332 for meta in &node.meta {
333 print_meta_item(printer, meta);
334 }
335
336 for region in &node.regions {
337 print_region(printer, region);
338 }
339
340 printer.parens_exit();
341}
342
343fn print_region<'a>(printer: &mut Printer<'a>, region: &'a Region) {
344 printer.parens_enter();
345 printer.group_enter();
346
347 printer.text(match region.kind {
348 RegionKind::DataFlow => "dfg",
349 RegionKind::ControlFlow => "cfg",
350 RegionKind::Module => "mod",
351 });
352
353 print_port_lists(printer, ®ion.sources, ®ion.targets);
354 printer.group_exit();
355
356 if let Some(signature) = ®ion.signature {
357 print_signature(printer, signature);
358 }
359
360 for meta in ®ion.meta {
361 print_meta_item(printer, meta);
362 }
363
364 for child in ®ion.children {
365 print_node(printer, child);
366 }
367
368 printer.parens_exit();
369}
370
371fn print_symbol<'a>(printer: &mut Printer<'a>, symbol: &'a Symbol) {
372 match symbol.visibility {
373 None => (),
374 Some(Visibility::Private) => printer.text("private"),
375 Some(Visibility::Public) => printer.text("public"),
376 }
377
378 print_symbol_name(printer, &symbol.name);
379
380 for param in &symbol.params {
381 print_param(printer, param);
382 }
383
384 for constraint in &symbol.constraints {
385 print_constraint(printer, constraint);
386 }
387
388 print_term(printer, &symbol.signature);
389}
390
391fn print_param<'a>(printer: &mut Printer<'a>, param: &'a Param) {
392 printer.parens_enter();
393 printer.text("param");
394 print_var_name(printer, ¶m.name);
395 print_term(printer, ¶m.r#type);
396 printer.parens_exit();
397}
398
399fn print_constraint<'a>(printer: &mut Printer<'a>, constraint: &'a Term) {
400 printer.parens_enter();
401 printer.text("where");
402 print_term(printer, constraint);
403 printer.parens_exit();
404}
405
406fn print_meta_item<'a>(printer: &mut Printer<'a>, meta: &'a Term) {
407 printer.parens_enter();
408 printer.text("meta");
409 print_term(printer, meta);
410 printer.parens_exit();
411}
412
413fn print_signature<'a>(printer: &mut Printer<'a>, signature: &'a Term) {
414 printer.parens_enter();
415 printer.text("signature");
416 print_term(printer, signature);
417 printer.parens_exit();
418}
419
420macro_rules! impl_display {
421 ($t:ident, $print:expr) => {
422 impl Display for $t {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 let arena = Arena::new();
425 let mut printer = Printer::new(&arena);
426 $print(&mut printer, self);
427 let doc = printer.finish();
428 doc.render_fmt(80, f)
429 }
430 }
431 };
432}
433
434impl_display!(Package, print_package);
435impl_display!(Module, print_module);
436impl_display!(Node, print_node);
437impl_display!(Region, print_region);
438impl_display!(Param, print_param);
439impl_display!(Term, print_term);
440impl_display!(SeqPart, print_seq_part);
441impl_display!(Literal, print_literal);
442impl_display!(Symbol, print_symbol);
443
444impl Display for VarName {
445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
446 write!(f, "?{}", self.0)
447 }
448}
449
450impl Display for SymbolName {
451 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
452 write!(f, "{}", self.0)
453 }
454}
455
456impl Display for LinkName {
457 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458 write!(f, "%{}", self.0)
459 }
460}