polymath_rs/
transformations.rs

1use std::collections::VecDeque;
2
3use crate::{
4    ast::{
5        BiExpression, Binary, Expression, Expressions, Group, Literal, Table, TableRow,
6        TriExpression, Unary, AST,
7    },
8    tokens::types::TokenType,
9};
10
11pub fn transform(ast: AST) -> AST {
12    // TODO apply matrix / vector transformations here
13    AST {
14        expressions: transform_expressions(ast.expressions),
15    }
16}
17
18fn transform_expressions(expressions: Expressions) -> Expressions {
19    Expressions {
20        expressions: expressions
21            .expressions
22            .into_iter()
23            .map(|expr| transform_expression(expr))
24            .collect(),
25    }
26}
27
28fn transform_expression(expr: Expression) -> Expression {
29    match expr {
30        Expression::Group(group) => transform_group(Group {
31            l_brace: group.l_brace,
32            expressions: transform_expressions(group.expressions),
33            r_brace: group.r_brace,
34        }),
35        Expression::Frac(frac) => transform_frac(frac),
36        Expression::Sub(bi_expression) => Expression::Sub(transform_bi_expression(bi_expression)),
37        Expression::Pow(bi_expression) => Expression::Pow(transform_bi_expression(bi_expression)),
38        Expression::SubPow(tri_expression) => {
39            Expression::SubPow(transform_tri_expression(tri_expression))
40        }
41        Expression::Unary(unary) => Expression::Unary(transform_unary(unary)),
42        Expression::Binary(binary) => Expression::Binary(transform_binary(binary)),
43        Expression::Literal(lit) => Expression::Literal(lit),
44        Expression::Expressions(expressions) => {
45            Expression::Expressions(transform_expressions(expressions))
46        }
47        Expression::Unit => Expression::Unit,
48    }
49}
50
51fn transform_binary(binary: crate::ast::Binary) -> crate::ast::Binary {
52    Binary {
53        operator: binary.operator,
54        expression_1: Box::new(transform_expression(*binary.expression_1)),
55        expression_2: Box::new(transform_expression(*binary.expression_2)),
56    }
57}
58
59fn transform_tri_expression(
60    tri_expression: crate::ast::TriExpression,
61) -> crate::ast::TriExpression {
62    TriExpression {
63        expression_1: Box::new(transform_expression(*tri_expression.expression_1)),
64        expression_2: Box::new(transform_expression(*tri_expression.expression_2)),
65        expression_3: Box::new(transform_expression(*tri_expression.expression_3)),
66    }
67}
68
69fn transform_bi_expression(bi_expression: BiExpression) -> BiExpression {
70    BiExpression {
71        expression_1: Box::new(transform_expression(*bi_expression.expression_1)),
72        expression_2: Box::new(transform_expression(*bi_expression.expression_2)),
73    }
74}
75
76fn transform_frac(frac: crate::ast::BiExpression) -> Expression {
77    Expression::Frac(transform_bi_expression(frac))
78}
79
80fn transform_unary(unary: Unary) -> Unary {
81    Unary {
82        operator: unary.operator,
83        expression: Box::new(transform_expression(*unary.expression)),
84    }
85}
86
87fn transform_group<'a>(group: Group<'a>) -> Expression<'a> {
88    let group = Group {
89        l_brace: group.l_brace,
90        expressions: transform_expressions(group.expressions),
91        r_brace: group.r_brace,
92    };
93
94    let expressions = &group.expressions.expressions;
95
96    let group_expressions = expressions
97        .iter()
98        .enumerate()
99        .filter_map(|(index, expr)| if index % 2 == 0 { Some(expr) } else { None })
100        .collect::<Vec<&Expression<'a>>>();
101
102    let commas = expressions
103        .iter()
104        .skip(1)
105        .enumerate()
106        .filter_map(|(index, expr)| if index % 2 == 0 { Some(expr) } else { None })
107        .collect::<Vec<&Expression<'a>>>();
108
109    let all_groups = group_expressions
110        .iter()
111        .all(|expr| matches!(expr, Expression::Group(_)));
112
113    let all_commas = commas.iter().all(|comma| match comma {
114        Expression::Literal(lit) => {
115            matches!(lit, crate::ast::Literal::Literal(token) if match token.token_type {
116                TokenType::Symbol => matches!(token.span.text, ","),
117                _ => false,
118            })
119        }
120        _ => false,
121    });
122
123    if !all_groups || !all_commas {
124        return Expression::Group(group);
125    }
126
127    let groups = group_expressions
128        .iter()
129        .map(|expression| match expression {
130            Expression::Group(group) => group,
131            _ => unreachable!(),
132        })
133        .collect::<Vec<&Group<'a>>>();
134
135    let group_comma_indicies = groups
136        .iter()
137        .map(|group| {
138            group
139                .expressions
140                .expressions
141                .iter()
142                .enumerate()
143                .filter_map(|(index, expr)| match expr {
144                    Expression::Literal(lit) => {
145                        if let crate::ast::Literal::Literal(token) = lit {
146                            if let TokenType::Symbol = token.token_type {
147                                if let "," = token.span.text {
148                                    Some(index)
149                                } else {
150                                    None
151                                }
152                            } else {
153                                None
154                            }
155                        } else {
156                            None
157                        }
158                    }
159                    _ => None,
160                })
161                .collect::<Vec<usize>>()
162        })
163        .collect::<Vec<Vec<usize>>>();
164
165    let comma_counts_match = group_comma_indicies
166        .iter()
167        .all(|commas| commas.len() == group_comma_indicies[0].len());
168
169    if !comma_counts_match {
170        return Expression::Group(group);
171    }
172
173    let mut table_contents = groups
174        .iter()
175        .zip(&group_comma_indicies)
176        .map(|(group, group_commas)| {
177            let mut group_ranges =
178                group_commas
179                    .iter()
180                    .fold(Vec::<(usize, usize)>::new(), |mut vec, comma_pos| {
181                        if let Some((_, last_pos)) = vec.last() {
182                            vec.push((*last_pos + 1, *comma_pos));
183                        } else {
184                            vec.push((0, *comma_pos));
185                        }
186                        vec
187                    });
188
189            group_ranges.push((
190                group_commas.last().map(|index| index + 1).unwrap_or(0),
191                group.expressions.expressions.len(),
192            ));
193
194            TableRow {
195                cols: group_ranges
196                    .iter()
197                    .map(|(start_pos, end_pos)| Expressions {
198                        expressions: group
199                            .expressions
200                            .expressions
201                            .iter()
202                            .skip(*start_pos)
203                            .take(end_pos - start_pos)
204                            .cloned()
205                            .collect::<VecDeque<Expression<'a>>>(),
206                    })
207                    .collect::<Vec<Expressions<'a>>>(),
208            }
209        })
210        .collect::<Vec<TableRow>>();
211
212    let seperator_pos = table_contents
213        .first()
214        .iter()
215        .flat_map(|row| {
216            row.cols
217                .iter()
218                .enumerate()
219                .filter(|(_, col)| {
220                    col.expressions.iter().any(|expr| match expr {
221                        Expression::Literal(Literal::Literal(token)) => {
222                            matches!(token.span.text, "|")
223                        }
224                        _ => false,
225                    })
226                })
227                .map(|(index, _)| index)
228        })
229        .collect::<Vec<usize>>();
230
231    // remove all vertical bars
232    table_contents = table_contents
233        .into_iter()
234        .map(|row| TableRow {
235            cols: row
236                .cols
237                .into_iter()
238                .filter(|col| {
239                    !col.expressions.iter().any(|expr| {
240                        if let Expression::Literal(Literal::Literal(token)) = expr {
241                            matches!(token.span.text, "|")
242                        } else {
243                            false
244                        }
245                    })
246                })
247                .collect::<Vec<Expressions<'a>>>(),
248        })
249        .collect::<Vec<TableRow>>();
250
251    // fill first row back up again
252    if let Some(second_row) = table_contents.get(1).map(|row| row.cols.len()) {
253        if let Some(first_row) = table_contents.get_mut(0) {
254            for i in 0..second_row {
255                if first_row.cols.get(i).is_none() {
256                    first_row.cols.push(Expressions {
257                        expressions: VecDeque::new(),
258                    });
259                }
260            }
261        }
262    }
263
264    // TODO some combinations of parenthesis are actually not accepted for a table layout in the original asciimath
265    Expression::Literal(crate::ast::Literal::Table(Table {
266        seperators: seperator_pos,
267        l_brace: group.l_brace,
268        rows: table_contents,
269        r_brace: group.r_brace,
270    }))
271}