polymath_rs/
transformations.rs1use std::collections::VecDeque;
2
3use crate::{
4 ast::{
5 BiExpression, Binary, Expression, Expressions, Group, Literal, Table, TableRow,
6 TriExpression, Unary, AST,
7 },
8 tokens::types::TokenType,
9};
10
11pub fn transform(ast: AST) -> AST {
12 AST {
14 expressions: transform_expressions(ast.expressions),
15 }
16}
17
18fn transform_expressions(expressions: Expressions) -> Expressions {
19 Expressions {
20 expressions: expressions
21 .expressions
22 .into_iter()
23 .map(|expr| transform_expression(expr))
24 .collect(),
25 }
26}
27
28fn transform_expression(expr: Expression) -> Expression {
29 match expr {
30 Expression::Group(group) => transform_group(Group {
31 l_brace: group.l_brace,
32 expressions: transform_expressions(group.expressions),
33 r_brace: group.r_brace,
34 }),
35 Expression::Frac(frac) => transform_frac(frac),
36 Expression::Sub(bi_expression) => Expression::Sub(transform_bi_expression(bi_expression)),
37 Expression::Pow(bi_expression) => Expression::Pow(transform_bi_expression(bi_expression)),
38 Expression::SubPow(tri_expression) => {
39 Expression::SubPow(transform_tri_expression(tri_expression))
40 }
41 Expression::Unary(unary) => Expression::Unary(transform_unary(unary)),
42 Expression::Binary(binary) => Expression::Binary(transform_binary(binary)),
43 Expression::Literal(lit) => Expression::Literal(lit),
44 Expression::Expressions(expressions) => {
45 Expression::Expressions(transform_expressions(expressions))
46 }
47 Expression::Unit => Expression::Unit,
48 }
49}
50
51fn transform_binary(binary: crate::ast::Binary) -> crate::ast::Binary {
52 Binary {
53 operator: binary.operator,
54 expression_1: Box::new(transform_expression(*binary.expression_1)),
55 expression_2: Box::new(transform_expression(*binary.expression_2)),
56 }
57}
58
59fn transform_tri_expression(
60 tri_expression: crate::ast::TriExpression,
61) -> crate::ast::TriExpression {
62 TriExpression {
63 expression_1: Box::new(transform_expression(*tri_expression.expression_1)),
64 expression_2: Box::new(transform_expression(*tri_expression.expression_2)),
65 expression_3: Box::new(transform_expression(*tri_expression.expression_3)),
66 }
67}
68
69fn transform_bi_expression(bi_expression: BiExpression) -> BiExpression {
70 BiExpression {
71 expression_1: Box::new(transform_expression(*bi_expression.expression_1)),
72 expression_2: Box::new(transform_expression(*bi_expression.expression_2)),
73 }
74}
75
76fn transform_frac(frac: crate::ast::BiExpression) -> Expression {
77 Expression::Frac(transform_bi_expression(frac))
78}
79
80fn transform_unary(unary: Unary) -> Unary {
81 Unary {
82 operator: unary.operator,
83 expression: Box::new(transform_expression(*unary.expression)),
84 }
85}
86
87fn transform_group<'a>(group: Group<'a>) -> Expression<'a> {
88 let group = Group {
89 l_brace: group.l_brace,
90 expressions: transform_expressions(group.expressions),
91 r_brace: group.r_brace,
92 };
93
94 let expressions = &group.expressions.expressions;
95
96 let group_expressions = expressions
97 .iter()
98 .enumerate()
99 .filter_map(|(index, expr)| if index % 2 == 0 { Some(expr) } else { None })
100 .collect::<Vec<&Expression<'a>>>();
101
102 let commas = expressions
103 .iter()
104 .skip(1)
105 .enumerate()
106 .filter_map(|(index, expr)| if index % 2 == 0 { Some(expr) } else { None })
107 .collect::<Vec<&Expression<'a>>>();
108
109 let all_groups = group_expressions
110 .iter()
111 .all(|expr| matches!(expr, Expression::Group(_)));
112
113 let all_commas = commas.iter().all(|comma| match comma {
114 Expression::Literal(lit) => {
115 matches!(lit, crate::ast::Literal::Literal(token) if match token.token_type {
116 TokenType::Symbol => matches!(token.span.text, ","),
117 _ => false,
118 })
119 }
120 _ => false,
121 });
122
123 if !all_groups || !all_commas {
124 return Expression::Group(group);
125 }
126
127 let groups = group_expressions
128 .iter()
129 .map(|expression| match expression {
130 Expression::Group(group) => group,
131 _ => unreachable!(),
132 })
133 .collect::<Vec<&Group<'a>>>();
134
135 let group_comma_indicies = groups
136 .iter()
137 .map(|group| {
138 group
139 .expressions
140 .expressions
141 .iter()
142 .enumerate()
143 .filter_map(|(index, expr)| match expr {
144 Expression::Literal(lit) => {
145 if let crate::ast::Literal::Literal(token) = lit {
146 if let TokenType::Symbol = token.token_type {
147 if let "," = token.span.text {
148 Some(index)
149 } else {
150 None
151 }
152 } else {
153 None
154 }
155 } else {
156 None
157 }
158 }
159 _ => None,
160 })
161 .collect::<Vec<usize>>()
162 })
163 .collect::<Vec<Vec<usize>>>();
164
165 let comma_counts_match = group_comma_indicies
166 .iter()
167 .all(|commas| commas.len() == group_comma_indicies[0].len());
168
169 if !comma_counts_match {
170 return Expression::Group(group);
171 }
172
173 let mut table_contents = groups
174 .iter()
175 .zip(&group_comma_indicies)
176 .map(|(group, group_commas)| {
177 let mut group_ranges =
178 group_commas
179 .iter()
180 .fold(Vec::<(usize, usize)>::new(), |mut vec, comma_pos| {
181 if let Some((_, last_pos)) = vec.last() {
182 vec.push((*last_pos + 1, *comma_pos));
183 } else {
184 vec.push((0, *comma_pos));
185 }
186 vec
187 });
188
189 group_ranges.push((
190 group_commas.last().map(|index| index + 1).unwrap_or(0),
191 group.expressions.expressions.len(),
192 ));
193
194 TableRow {
195 cols: group_ranges
196 .iter()
197 .map(|(start_pos, end_pos)| Expressions {
198 expressions: group
199 .expressions
200 .expressions
201 .iter()
202 .skip(*start_pos)
203 .take(end_pos - start_pos)
204 .cloned()
205 .collect::<VecDeque<Expression<'a>>>(),
206 })
207 .collect::<Vec<Expressions<'a>>>(),
208 }
209 })
210 .collect::<Vec<TableRow>>();
211
212 let seperator_pos = table_contents
213 .first()
214 .iter()
215 .flat_map(|row| {
216 row.cols
217 .iter()
218 .enumerate()
219 .filter(|(_, col)| {
220 col.expressions.iter().any(|expr| match expr {
221 Expression::Literal(Literal::Literal(token)) => {
222 matches!(token.span.text, "|")
223 }
224 _ => false,
225 })
226 })
227 .map(|(index, _)| index)
228 })
229 .collect::<Vec<usize>>();
230
231 table_contents = table_contents
233 .into_iter()
234 .map(|row| TableRow {
235 cols: row
236 .cols
237 .into_iter()
238 .filter(|col| {
239 !col.expressions.iter().any(|expr| {
240 if let Expression::Literal(Literal::Literal(token)) = expr {
241 matches!(token.span.text, "|")
242 } else {
243 false
244 }
245 })
246 })
247 .collect::<Vec<Expressions<'a>>>(),
248 })
249 .collect::<Vec<TableRow>>();
250
251 if let Some(second_row) = table_contents.get(1).map(|row| row.cols.len()) {
253 if let Some(first_row) = table_contents.get_mut(0) {
254 for i in 0..second_row {
255 if first_row.cols.get(i).is_none() {
256 first_row.cols.push(Expressions {
257 expressions: VecDeque::new(),
258 });
259 }
260 }
261 }
262 }
263
264 Expression::Literal(crate::ast::Literal::Table(Table {
266 seperators: seperator_pos,
267 l_brace: group.l_brace,
268 rows: table_contents,
269 r_brace: group.r_brace,
270 }))
271}