1use crate::parser::{ASTNode, ASTNodeType, ParserError, parse};
2use crate::tokenizer::Associativity;
3
4pub fn pretty_print(ast: &ASTNode) -> String {
14 pretty_print_node(ast)
15}
16
/// Which operand slot of a binary operator a child expression occupies.
///
/// Associativity rules differ between the two slots, so the parenthesisation
/// logic needs to know which side it is looking at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Side {
    /// The left-hand operand.
    Left,
    /// The right-hand operand.
    Right,
}
22
23fn infix_info(op: &str) -> (u8, Associativity) {
24 match op {
25 ":" => (10, Associativity::Left),
26 " " => (9, Associativity::Left),
27 "," => (8, Associativity::Left),
28 "^" => (5, Associativity::Right),
29 "*" | "/" => (4, Associativity::Left),
30 "+" | "-" => (3, Associativity::Left),
31 "&" => (2, Associativity::Left),
32 "=" | "<" | ">" | "<=" | ">=" | "<>" => (1, Associativity::Left),
33 _ => (0, Associativity::Left),
34 }
35}
36
/// Returns the binding strength of the unary operator `op`.
///
/// `#` (postfix) binds tightest at 11 and `%` (postfix) comes next at 7.
/// Every other unary operator (the prefix signs) gets 6.
fn unary_precedence(op: &str) -> u8 {
    if op == "#" {
        11
    } else if op == "%" {
        7
    } else {
        6
    }
}
44
45fn node_precedence(ast: &ASTNode) -> u8 {
46 match &ast.node_type {
47 ASTNodeType::BinaryOp { op, .. } => infix_info(op).0,
48 ASTNodeType::UnaryOp { op, .. } => unary_precedence(op),
49 _ => 10,
51 }
52}
53
54fn child_needs_parens(
55 child: &ASTNode,
56 parent_op: &str,
57 parent_prec: u8,
58 parent_assoc: Associativity,
59 side: Side,
60) -> bool {
61 let child_prec = node_precedence(child);
62 if child_prec < parent_prec {
63 return true;
64 }
65 if child_prec > parent_prec {
66 return false;
67 }
68
69 match side {
71 Side::Left => {
72 if parent_assoc == Associativity::Right {
73 matches!(child.node_type, ASTNodeType::BinaryOp { .. })
75 } else {
76 false
77 }
78 }
79 Side::Right => {
80 if parent_assoc == Associativity::Left {
81 if let ASTNodeType::BinaryOp { op: child_op, .. } = &child.node_type {
82 if child_op != parent_op {
83 return true;
84 }
85
86 if parent_op == "-" || parent_op == "/" {
88 return true;
89 }
90 }
91 false
92 } else {
93 if let ASTNodeType::BinaryOp { op: child_op, .. } = &child.node_type {
95 return child_op != parent_op;
96 }
97 false
98 }
99 }
100 }
101}
102
103fn unary_operand_needs_parens(unary_op: &str, operand: &ASTNode) -> bool {
104 match unary_op {
105 "%" | "#" => matches!(operand.node_type, ASTNodeType::BinaryOp { .. }),
106 _ => {
107 let operand_prec = node_precedence(operand);
108 operand_prec < unary_precedence(unary_op)
109 && matches!(operand.node_type, ASTNodeType::BinaryOp { .. })
110 }
111 }
112}
113
114fn pretty_child(
115 child: &ASTNode,
116 parent_op: &str,
117 parent_prec: u8,
118 parent_assoc: Associativity,
119 side: Side,
120) -> String {
121 let s = pretty_print_node(child);
122 if child_needs_parens(child, parent_op, parent_prec, parent_assoc, side) {
123 format!("({s})")
124 } else {
125 s
126 }
127}
128
129fn pretty_print_node(ast: &ASTNode) -> String {
130 match &ast.node_type {
131 ASTNodeType::Literal(value) => match value {
132 crate::LiteralValue::Text(s) => {
134 let escaped = s.replace('"', "\"\"");
135 format!("\"{escaped}\"")
136 }
137 _ => format!("{value}"),
138 },
139 ASTNodeType::Reference { reference, .. } => reference.normalise(),
140 ASTNodeType::UnaryOp { op, expr } => {
141 let inner = pretty_print_node(expr);
142 let inner = if unary_operand_needs_parens(op, expr) {
143 format!("({inner})")
144 } else {
145 inner
146 };
147
148 if op == "%" || op == "#" {
149 format!("{inner}{op}")
150 } else {
151 format!("{op}{inner}")
152 }
153 }
154 ASTNodeType::BinaryOp { op, left, right } => {
155 let (prec, assoc) = infix_info(op);
156 let left_s = pretty_child(left, op, prec, assoc, Side::Left);
157 let right_s = pretty_child(right, op, prec, assoc, Side::Right);
158
159 match op.as_str() {
160 ":" => format!("{left_s}:{right_s}"),
163 " " => format!("{left_s} {right_s}"),
164 "," => format!("{left_s}, {right_s}"),
165 _ => format!("{left_s} {op} {right_s}"),
166 }
167 }
168 ASTNodeType::Function { name, args } => {
169 let args_str = args
170 .iter()
171 .map(pretty_print_node)
172 .collect::<Vec<String>>()
173 .join(", ");
174
175 format!("{}({})", name.to_uppercase(), args_str)
176 }
177 ASTNodeType::Call { callee, args } => {
178 let callee_str = pretty_print_node(callee);
179 let callee_rendered = match &callee.node_type {
183 ASTNodeType::Function { .. } | ASTNodeType::Call { .. } => callee_str,
184 _ => format!("({callee_str})"),
185 };
186 let args_str = args
187 .iter()
188 .map(pretty_print_node)
189 .collect::<Vec<String>>()
190 .join(", ");
191 format!("{callee_rendered}({args_str})")
192 }
193 ASTNodeType::Array(rows) => {
194 let rows_str = rows
195 .iter()
196 .map(|row| {
197 row.iter()
198 .map(pretty_print_node)
199 .collect::<Vec<String>>()
200 .join(", ")
201 })
202 .collect::<Vec<String>>()
203 .join("; ");
204
205 format!("{{{rows_str}}}")
206 }
207 }
208}
209
210pub fn canonical_formula(ast: &ASTNode) -> String {
215 format!("={}", pretty_print(ast))
216}
217
218pub fn pretty_parse_render(formula: &str) -> Result<String, ParserError> {
222 if formula.is_empty() {
224 return Ok(String::new());
225 }
226
227 let needs_equals = !formula.starts_with('=');
229 let formula_to_parse = if needs_equals {
230 format!("={formula}")
231 } else {
232 formula.to_string()
233 };
234
235 let ast = parse(&formula_to_parse)?;
237
238 let pretty_printed = pretty_print(&ast);
240
241 if needs_equals {
243 Ok(pretty_printed)
244 } else {
245 Ok(format!("={pretty_printed}"))
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Pretty-prints `formula`, panicking if it fails to parse.
    fn render(formula: &str) -> String {
        pretty_parse_render(formula).unwrap()
    }

    #[test]
    fn test_pretty_print_validation() {
        let pretty = render("= sum( a1 ,2 ) ");
        assert_eq!(pretty, "=SUM(A1, 2)");

        // Rendering already-normalised output must be a no-op.
        assert_eq!(pretty, render(&pretty));
    }

    #[test]
    fn test_ast_canonicalization() {
        let pretty = render("=sum( a1, b2 )");
        assert_eq!(pretty, "=SUM(A1, B2)");

        // Canonical text is a fixed point of the renderer.
        assert_eq!(pretty, render(&pretty));
    }

    #[test]
    fn test_pretty_print_operators() {
        for input in ["=a1+b2*3", "=a1 + b2 * 3"].iter() {
            assert_eq!(render(input), "=A1 + B2 * 3");
        }
    }

    #[test]
    fn test_pretty_print_inserts_parentheses_when_needed() {
        assert_eq!(render("=(a1+b2)*c3"), "=(A1 + B2) * C3");
    }

    #[test]
    fn test_pretty_print_function_nesting() {
        assert_eq!(
            render("=if(a1>0, sum(b1:b10), average(c1:c10))"),
            "=IF(A1 > 0, SUM(B1:B10), AVERAGE(C1:C10))"
        );
    }

    #[test]
    fn test_pretty_print_arrays() {
        for input in ["={1,2;3,4}", "={1, 2; 3, 4}"].iter() {
            assert_eq!(render(input), "={1, 2; 3, 4}");
        }
    }

    #[test]
    fn test_pretty_print_references() {
        let cases = [
            ("=Sheet1!$a$1:$b$2", "=Sheet1!$A$1:$B$2"),
            ("='My Sheet'!a1", "='My Sheet'!A1"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(render(input), *expected);
        }
    }

    #[test]
    fn test_pretty_print_text_literals_in_functions() {
        assert_eq!(
            render("=SUMIFS(A:A, B:B, \"*Parking*\")"),
            "=SUMIFS(A:A, B:B, \"*Parking*\")"
        );
    }

    #[test]
    fn test_pretty_print_text_concatenation_and_escaping() {
        let cases = [
            ("=\">=\"&DATE(2024,1,1)", "=\">=\" & DATE(2024, 1, 1)"),
            ("=\"He said \"\"Hi\"\"\"", "=\"He said \"\"Hi\"\"\""),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(render(input), *expected);
        }
    }

    #[test]
    fn test_pretty_print_text_in_arrays() {
        assert_eq!(
            render("={\"A\", \"B\"; \"C\", \"D\"}"),
            "={\"A\", \"B\"; \"C\", \"D\"}"
        );
    }
}
351}