aftermath/
expr.rs

1use crate::token_stream::{self, Token};
2
3use bumpalo::Bump;
4
5/// A token Tree representing a whole Expression
6/// It lives inside an [Arena](bumpalo::Bump)
7#[derive(Debug, PartialEq)]
8pub enum Expr<'arena> {
9    /// A real number
10    RealNumber {
11        /// The value of the real number
12        val: f64,
13    },
14    /// An imaginary number
15    ImaginaryNumber {
16        /// The value of the imaginary number, without the `i` unit
17        val: f64,
18    },
19    /// Complex number
20    ComplexNumber {
21        /// The value of the complex number's node
22        val: num_complex::Complex64,
23    },
24    /// A variable
25    Binding {
26        /// The name of the variable
27        name: &'arena mut str,
28    },
29    /// A function call, with an variable amount of arguments
30    FunctionCall {
31        /// Name of the function
32        ident: &'arena mut str,
33        /// List of argument in order they appeard
34        args: bumpalo::collections::Vec<'arena, &'arena mut Expr<'arena>>,
35    },
36    /// An operation
37    Operator {
38        /// The operator
39        op: Operator,
40        /// Left side of the operation
41        rhs: &'arena mut Expr<'arena>,
42        /// Right side of the operation
43        lhs: &'arena mut Expr<'arena>,
44    },
45}
46
47impl<'arena> Expr<'arena> {
48    /// Clone an AST with another backing [arena](bumpalo::Bump)
49    #[allow(clippy::mut_from_ref)]
50    pub fn clone_in<'new_arena>(
51        &self,
52        arena: &'new_arena Bump,
53    ) -> &'new_arena mut Expr<'new_arena> {
54        use Expr::{Binding, ComplexNumber, FunctionCall, ImaginaryNumber, Operator, RealNumber};
55        arena.alloc(match self {
56            RealNumber { val } => RealNumber { val: *val },
57            ImaginaryNumber { val } => ImaginaryNumber { val: *val },
58            ComplexNumber { val } => ComplexNumber { val: *val },
59            Binding { name } => Binding {
60                name: arena.alloc_str(name),
61            },
62            FunctionCall { ident, args } => FunctionCall {
63                ident: arena.alloc_str(ident),
64                args: bumpalo::collections::FromIteratorIn::from_iter_in(
65                    args.iter().map(|c| c.clone_in(arena)),
66                    arena,
67                ),
68            },
69            Operator { op, rhs, lhs } => Operator {
70                op: *op,
71                rhs: rhs.clone_in(arena),
72                lhs: lhs.clone_in(arena),
73            },
74        })
75    }
76}
77
/// An arithmetic operator that can appear in an [`Expr::Operator`] node.
///
/// The discriminant encodes the precedence class: its tens digit is the class, and an
/// operator of a higher class binds tighter than one of a lower class.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[repr(u16)]
pub enum Operator {
    /// Binary addition, `a + b`
    Plus = 1,
    /// Binary subtraction, `a - b`
    Minus = 2,

    /// Binary multiplication, `a * b`
    Multiply = 11,
    /// Binary division, `a / b`
    Divide = 12,
    /// Binary modulo, `a % b`
    Modulo = 13,

    /// Exponentiation, `a ^ b`
    Pow = 21,

    /// Prefix minus, `-a`
    UnaryMinus = 31,
    /// Prefix plus, `+a`
    UnaryPlus = 32,
}
94
/// Which way a chain of operators of the same precedence class groups.
///
/// With an operator `op`, `Left` reads `a op b op c` as `(a op b) op c`, while `Right`
/// reads it as `a op (b op c)`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum Associativity {
    /// Groups toward the right: `a op (b op c)`
    Right,
    /// Groups toward the left: `(a op b) op c`
    Left,
}
100
101impl Operator {
102    #[must_use]
103    /// Get a static str representation of the
104    pub fn as_str(self) -> &'static str {
105        match self {
106            Self::Pow => "^",
107            Self::Plus | Self::UnaryPlus => "+",
108            Self::Minus | Self::UnaryMinus => "-",
109            Self::Divide => "/",
110            Self::Multiply => "*",
111            Self::Modulo => "%",
112        }
113    }
114
115    pub(crate) fn from_str(input: &str) -> Option<Self> {
116        match input {
117            "^" => Some(Self::Pow),
118            "+" => Some(Self::Plus),
119            "-" => Some(Self::Minus),
120            "/" => Some(Self::Divide),
121            "*" => Some(Self::Multiply),
122            "%" => Some(Self::Modulo),
123            _ => None,
124        }
125    }
126
127    fn associativity(self) -> Associativity {
128        match self {
129            Self::Pow => Associativity::Left,
130            _ => Associativity::Right,
131        }
132    }
133
134    fn class(self) -> u8 {
135        self as u8 / 10
136    }
137}
138
139fn function_pass<'input>(
140    mut iter: std::iter::Peekable<
141        impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input,
142    >,
143) -> impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input {
144    let mut need_sep = None;
145    std::iter::from_fn(move || {
146        if let Some(n) = need_sep.as_mut() {
147            *n -= 1;
148            if *n == 0u8 {
149                need_sep = None;
150                Some(Ok(token_stream::Token::Whitespace))
151            } else {
152                iter.next()
153            }
154        } else {
155            let next = iter.next();
156            match &next {
157                Some(Ok(token_stream::Token::ReservedWord(_))) => {
158                    if let Some(Ok(token_stream::Token::LeftParenthesis)) = iter.peek() {
159                        need_sep = Some(2);
160                    }
161                }
162                Some(Ok(token_stream::Token::Comma)) => {
163                    need_sep = Some(1);
164                }
165                _ => {}
166            };
167            next
168        }
169    })
170}
171
172fn implicit_multiple_pass<'input>(
173    mut iter: std::iter::Peekable<
174        impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input,
175    >,
176) -> impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input {
177    let mut need_sep = None;
178    std::iter::from_fn(move || {
179        if let Some(n) = need_sep.as_mut() {
180            *n -= 1;
181            if *n == 0u8 {
182                need_sep = None;
183                Some(Ok(token_stream::Token::Operator(Operator::Multiply)))
184            } else {
185                iter.next()
186            }
187        } else {
188            let next = iter.next();
189            if matches!(
190                &next,
191                Some(Ok(token_stream::Token::Ident(_)
192                    | token_stream::Token::Literal(_)
193                    | token_stream::Token::RightParenthesis))
194            ) {
195                if let Some(Ok(
196                    token_stream::Token::LeftParenthesis
197                    | token_stream::Token::Ident(_)
198                    | token_stream::Token::ReservedWord(_)
199                    | token_stream::Token::Literal(_),
200                )) = iter.peek()
201                {
202                    need_sep = Some(1);
203                }
204            }
205            next
206        }
207    })
208}
209
210fn unary_pass<'input>(
211    mut iter: std::iter::Peekable<
212        impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input,
213    >,
214) -> impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>> + 'input {
215    let _next = iter.peek_mut().map(|next| match next {
216        Ok(token_stream::Token::Operator(op @ Operator::Minus)) => {
217            *op = Operator::UnaryMinus;
218        }
219        Ok(token_stream::Token::Operator(op @ Operator::Plus)) => {
220            *op = Operator::UnaryPlus;
221        }
222        _ => (),
223    });
224    std::iter::from_fn(move || {
225        let next = iter.next();
226        if let Some(Ok(
227            token_stream::Token::Operator(_)
228            | token_stream::Token::Comma
229            | token_stream::Token::Whitespace
230            | token_stream::Token::LeftParenthesis,
231        )) = next
232        {
233            match iter.peek_mut() {
234                Some(Ok(token_stream::Token::Operator(op @ Operator::Minus))) => {
235                    *op = Operator::UnaryMinus;
236                }
237                Some(Ok(token_stream::Token::Operator(op @ Operator::Plus))) => {
238                    *op = Operator::UnaryPlus;
239                }
240                _ => (),
241            }
242        }
243        next
244    })
245}
246
247pub use token_stream::InvalidToken;
248
249/// Error returned by [Expr::parse](Expr::parse)
250#[derive(Debug, Clone, PartialEq, Eq)]
251pub enum BuildError<'input> {
252    /// An underlying token in the input is wrong
253    InvalidToken(InvalidToken<'input>),
254    /// At least one parenthesis is missing
255    MissingParenthesis,
256    /// At least one operator is missing
257    MissingOperator,
258    /// At least one operand is missing
259    MissingOperand,
260    /// An unknown error occured
261    UnkownError,
262}
263
264impl<'input> From<InvalidToken<'input>> for BuildError<'input> {
265    fn from(value: InvalidToken<'input>) -> Self {
266        Self::InvalidToken(value)
267    }
268}
269
270impl<'arena> std::fmt::Display for Expr<'arena> {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        if f.alternate() {
273            self.to_string_inner(f)
274        } else {
275            self.to_string_inner_min_parens(f, None)
276        }
277    }
278}
279
280// fn print<T: std::fmt::Debug + ?Sized>(t: &T, level: u16) {
281//     println!("{:width$}[{level}]{t:?}", "", width = (level * 4) as usize);
282// }
283
284// thread_local! {static CURRENT_LEVEL: std::cell::Cell<u16> = 0.into();}
285
286impl<'arena> Expr<'arena> {
287    /// Create an AST from an input str
288    ///
289    /// # Errors
290    /// This will error on any wrong input
291    pub fn parse<'input, 'words: 'input + 'word, 'word: 'input>(
292        arena: &'arena Bump,
293        input: &'input str,
294        reserved_words: &'words [&'word str],
295    ) -> Result<&'arena mut Expr<'arena>, BuildError<'input>> {
296        let iter = token_stream::parse_tokens(input, reserved_words);
297        let iter = function_pass(iter.peekable());
298        let iter = implicit_multiple_pass(iter.peekable());
299        let iter = unary_pass(iter.peekable());
300        let iter = iter.fuse();
301        // let iter = iter.inspect(|t| print(&t, CURRENT_LEVEL.with(std::cell::Cell::get)));
302
303        Self::parse_iter(arena, iter, &(true.into()), 0)
304    }
305
306    fn parse_iter<'input, 'words: 'input + 'word, 'word: 'input>(
307        arena: &'arena Bump,
308        mut iter: impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>>,
309        check_func_sep: &std::cell::Cell<bool>,
310        level: u16,
311    ) -> Result<&'arena mut Expr<'arena>, BuildError<'input>> {
312        let mut output = Vec::<&mut Self>::new();
313        let mut operator = Vec::<Token<'input>>::new();
314        let mut was_function_call = false;
315        loop {
316            if let Some(token) = iter.next() {
317                //print(&format_args!("Output Buffer: {output:?}"), level);
318                match token? {
319                    Token::Whitespace => {
320                        Self::handle_whitespace(
321                            arena,
322                            &mut iter,
323                            check_func_sep,
324                            &mut output,
325                            level,
326                        )?;
327                    }
328                    Token::Literal(v) => output.push(arena.alloc(Expr::RealNumber { val: v })),
329                    Token::Ident(name) => {
330                        output.push(arena.alloc(Expr::Binding {
331                            name: arena.alloc_str(name),
332                        }));
333                    }
334                    Token::ReservedWord(name) => {
335                        was_function_call = true;
336                        // print("FUNCTION CALL", level);
337                        output.push(arena.alloc(Expr::FunctionCall {
338                            ident: arena.alloc_str(name),
339                            args: bumpalo::collections::Vec::with_capacity_in(2, arena),
340                        }));
341                    }
342
343                    Token::Comma => {
344                        Self::handle_comma(arena, &mut operator, &mut output /* level */)?;
345                    }
346                    t @ Token::LeftParenthesis if !was_function_call => operator.push(t),
347                    Token::LeftParenthesis => was_function_call = false,
348                    Token::Operator(op) => {
349                        Self::handle_operator(arena, op, &mut operator, &mut output)?;
350                    }
351                    Token::RightParenthesis => loop {
352                        let Some(op) = operator.pop() else {
353                            // print("Missing Parenthesis Error", level);
354                            return Err(dbg!(BuildError::MissingParenthesis));
355                        };
356                        match op {
357                            Token::LeftParenthesis => break,
358                            Token::Operator(o @ (Operator::UnaryMinus | Operator::UnaryPlus)) => {
359                                let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
360                                output.push(arena.alloc(Expr::Operator {
361                                    op: o,
362                                    lhs: arena.alloc(Expr::RealNumber { val: 0.0 }),
363                                    rhs,
364                                }));
365                            }
366                            Token::Operator(o) => {
367                                let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
368                                let lhs = output.pop().ok_or(BuildError::MissingOperand)?;
369
370                                output.push(arena.alloc(Expr::Operator { op: o, rhs, lhs }));
371                            }
372                            _ => (),
373                        }
374                    },
375                }
376            } else {
377                for op in operator.into_iter().rev() {
378                    match op {
379                        Token::LeftParenthesis => return Err(BuildError::MissingParenthesis),
380                        Token::Operator(o @ (Operator::UnaryMinus | Operator::UnaryPlus)) => {
381                            let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
382                            output.push(arena.alloc(Expr::Operator {
383                                op: o,
384                                lhs: arena.alloc(Expr::RealNumber { val: 0.0 }),
385                                rhs,
386                            }));
387                        }
388                        Token::Operator(o) => {
389                            let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
390                            let lhs = output.pop().ok_or(BuildError::MissingOperand)?;
391
392                            output.push(arena.alloc(Expr::Operator { op: o, rhs, lhs }));
393                        }
394                        Token::Comma | Token::Whitespace => { /* No-op but still an operator */ }
395                        _ => (),
396                    }
397                }
398                break;
399            }
400        }
401        //print(&format_args!("End: {}", output.len()), level);
402        output.pop().ok_or(match output.len() {
403            0 => BuildError::UnkownError,
404            _ => dbg!(BuildError::MissingOperator),
405        })
406    }
407
408    fn handle_comma<'input>(
409        arena: &'arena Bump,
410        operator: &mut Vec<Token>,
411        output: &mut Vec<&'arena mut Self>,
412        // level: u16,
413    ) -> Result<&'arena mut Self, BuildError<'input>> {
414        loop {
415            let Some(op) = operator.pop() else {
416                // print("Missing Parenthesis Error", level);
417                break;
418            };
419            match op {
420                Token::LeftParenthesis => break,
421                Token::Operator(o @ (Operator::UnaryMinus | Operator::UnaryPlus)) => {
422                    let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
423                    output.push(arena.alloc(Expr::Operator {
424                        op: o,
425                        lhs: arena.alloc(Expr::RealNumber { val: 0.0 }),
426                        rhs,
427                    }));
428                }
429                Token::Operator(o) => {
430                    let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
431                    let lhs = output.pop().ok_or(BuildError::MissingOperand)?;
432
433                    output.push(arena.alloc(Expr::Operator { op: o, rhs, lhs }));
434                }
435                _ => (),
436            }
437        }
438        // print(&format_args!("Comma: {:?}", &output), level - 1);
439        // CURRENT_LEVEL.with(|c| c.set(level - 1));
440        output.pop().ok_or(match output.len() {
441            0 => BuildError::UnkownError,
442            _ => BuildError::MissingOperator,
443        })
444    }
445
446    fn handle_operator<'input>(
447        arena: &'arena Bump,
448        op1: Operator,
449        operator: &mut Vec<Token>,
450        output: &mut Vec<&'arena mut Self>,
451    ) -> Result<(), BuildError<'input>> {
452        loop {
453            let Some(peek) = operator.last() else {break;};
454            match peek {
455                Token::Operator(op2)
456                    if op2.class() > op1.class()
457                        || (op1.class() == op2.class()
458                            && op1.associativity() == Associativity::Left) =>
459                {
460                    let op = operator.pop().unwrap();
461                    match op {
462                        Token::Operator(o @ (Operator::UnaryMinus | Operator::UnaryPlus)) => {
463                            let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
464                            output.push(arena.alloc(Expr::Operator {
465                                op: o,
466                                lhs: arena.alloc(Expr::RealNumber { val: 0.0 }),
467                                rhs,
468                            }));
469                        }
470                        Token::Operator(o) => {
471                            let rhs = output.pop().ok_or(BuildError::MissingOperand)?;
472                            let lhs = output.pop().ok_or(BuildError::MissingOperand)?;
473
474                            output.push(arena.alloc(Expr::Operator { op: o, rhs, lhs }));
475                        }
476                        _ => (),
477                    }
478                }
479                _ => break,
480            }
481        }
482        operator.push(Token::Operator(op1));
483        Ok(())
484    }
485
486    #[allow(trivial_casts)]
487    fn handle_whitespace<'input>(
488        arena: &'arena Bump,
489        iter: &mut impl Iterator<Item = Result<Token<'input>, InvalidToken<'input>>>,
490        check_func_sep: &std::cell::Cell<bool>,
491        output: &mut [&'arena mut Self],
492        level: u16,
493    ) -> Result<(), BuildError<'input>> {
494        check_func_sep.set(false);
495        let parens_count = std::cell::Cell::new(1u16);
496        let child_check_func_sep = std::cell::Cell::new(true);
497        let error = std::cell::Cell::new(false);
498        let mut sub_iter = iter
499            .by_ref()
500            .inspect(|t| {
501                // print(
502                //     &((child_whitespace.get() && !matches!(t, &Ok(Token::Comma)))
503                //         || parens_count.get() != 0),
504                // );
505                match t {
506                    Ok(Token::LeftParenthesis) => {
507                        parens_count.set(if let Some(n) = parens_count.get().checked_add(1) {
508                            n
509                        } else {
510                            // print("Error underflow", level);
511                            error.set(true);
512                            255
513                        });
514                    }
515                    Ok(Token::RightParenthesis) => {
516                        parens_count.set(if let Some(n) = parens_count.get().checked_sub(1) {
517                            n
518                        } else {
519                            // print("Error overflow", level);
520                            error.set(true);
521                            255
522                        });
523                    }
524                    _ => (),
525                }
526            })
527            .take_while(|token| {
528                !(matches!(token, Ok(Token::Comma) if child_check_func_sep.get())
529                    || matches!(token, Ok(Token::RightParenthesis) if parens_count.get() == 0,))
530            });
531        // print("LEVEL START", level);
532        // CURRENT_LEVEL.with(|c| c.set(level));
533        let ast = Self::parse_iter(
534            arena,
535            &mut sub_iter as &mut dyn Iterator<Item = Result<Token<'_>, InvalidToken<'_>>>,
536            &child_check_func_sep,
537            level + 1,
538        );
539        // CURRENT_LEVEL.with(|c| c.set(level));
540        // print(
541        //     &format_args!(
542        //         "LEVEL END: {}",
543        //         if error.get() { "Error" } else { "No Error" }
544        //     ),
545        //     level + 1,
546        // );
547        if error.get() {
548            return Err(dbg!(BuildError::UnkownError));
549        }
550        check_func_sep.set(true);
551        // print(&format_args!("output = {:?}", &output), level);
552        match output.last_mut() {
553            Some(Expr::FunctionCall { args, .. }) => {
554                args.push(ast?);
555            }
556            _ => {
557                return Err(BuildError::MissingOperator);
558            }
559        }
560        Ok(())
561    }
562}
563
564/// The real implementation of display
565impl<'arena> Expr<'arena> {
566    fn to_string_inner_min_parens(
567        &self,
568        buf: &mut impl std::fmt::Write,
569        parent_precedence: Option<u8>,
570    ) -> std::fmt::Result {
571        match self {
572            Expr::FunctionCall { ident, args } => {
573                write!(buf, "{ident}(")?;
574                for arg in args.iter().take(args.len() - 1) {
575                    arg.to_string_inner_min_parens(buf, None)?;
576                    write!(buf, ", ")?;
577                }
578                if let Some(arg) = args.last() {
579                    arg.to_string_inner_min_parens(buf, None)?;
580                }
581                write!(buf, ")")?;
582            }
583            Expr::RealNumber { val } if val.is_sign_negative() => write!(buf, "({val})")?,
584            Expr::RealNumber { val } => write!(buf, "{val}")?,
585            Expr::ImaginaryNumber { val } if val.is_sign_negative() => write!(buf, "({val}i)")?,
586            Expr::ImaginaryNumber { val } => write!(buf, "{val}i")?,
587            Expr::ComplexNumber { val }
588                if val.re.is_sign_negative() || val.im.is_sign_negative() =>
589            {
590                write!(buf, "({val})")?;
591            }
592            Expr::ComplexNumber { val } => write!(buf, "{val}")?,
593            Expr::Binding { name } => write!(buf, "{name}")?,
594            Expr::Operator {
595                op: op @ (Operator::UnaryMinus | Operator::UnaryPlus),
596                rhs,
597                ..
598            } => {
599                if parent_precedence.map_or(false, |p| op.class() < p) {
600                    write!(buf, "(")?;
601                    write!(buf, "{}", op.as_str())?;
602                    rhs.to_string_inner_min_parens(buf, Some(op.class()))?;
603                    write!(buf, ")")?;
604                } else {
605                    write!(buf, "{}", op.as_str())?;
606                    rhs.to_string_inner_min_parens(buf, Some(op.class()))?;
607                }
608            }
609            Expr::Operator { op, rhs, lhs } => {
610                if parent_precedence.map_or(false, |p| op.class() < p) {
611                    write!(buf, "(")?;
612                    lhs.to_string_inner_min_parens(buf, Some(op.class()))?;
613                    write!(buf, " {} ", op.as_str())?;
614                    rhs.to_string_inner_min_parens(buf, Some(op.class()))?;
615                    write!(buf, ")")?;
616                } else {
617                    lhs.to_string_inner_min_parens(buf, Some(op.class()))?;
618                    write!(buf, " {} ", op.as_str())?;
619                    rhs.to_string_inner_min_parens(buf, Some(op.class()))?;
620                }
621            }
622        }
623        Ok(())
624    }
625
626    fn to_string_inner(&self, buf: &mut impl std::fmt::Write) -> std::fmt::Result {
627        match self {
628            Expr::FunctionCall { ident, args } => {
629                write!(buf, "{ident}(")?;
630                for arg in args.iter().take(args.len() - 1) {
631                    arg.to_string_inner(buf)?;
632                    write!(buf, ", ")?;
633                }
634                if let Some(arg) = args.last() {
635                    arg.to_string_inner(buf)?;
636                }
637                write!(buf, ")")?;
638            }
639            Expr::RealNumber { val } => write!(buf, "({val})")?,
640            Expr::ImaginaryNumber { val } => write!(buf, "({val}i)")?,
641            Expr::ComplexNumber { val } => write!(buf, "({val})")?,
642            Expr::Binding { name } => write!(buf, "{name}")?,
643            Expr::Operator {
644                op: op @ (Operator::UnaryMinus | Operator::UnaryPlus),
645                rhs,
646                ..
647            } => {
648                write!(buf, "({}", op.as_str())?;
649                rhs.to_string_inner(buf)?;
650                write!(buf, ")")?;
651            }
652            Expr::Operator { op, rhs, lhs } => {
653                write!(buf, "(")?;
654                lhs.to_string_inner(buf)?;
655                write!(buf, " {} ", op.as_str())?;
656                rhs.to_string_inner(buf)?;
657                write!(buf, ")")?;
658            }
659        }
660        Ok(())
661    }
662}
663
#[cfg(test)]
mod tests {
    use super::token_stream::{
        stream_to_string,
        Token::{Comma, LeftParenthesis, Literal, ReservedWord, RightParenthesis, Whitespace},
    };
    use super::*;

    #[test]
    fn function_sep() {
        let tokens = token_stream::parse_tokens("max(1, 5)", token_stream::RESTRICTED_WORD);
        let separated: Vec<_> = function_pass(tokens.peekable())
            .collect::<Result<_, _>>()
            .unwrap();

        let expected = vec![
            ReservedWord("max"),
            LeftParenthesis,
            Whitespace,
            Literal(1.0),
            Comma,
            Whitespace,
            Literal(5.0),
            RightParenthesis,
        ];
        assert_eq!(separated, expected);
    }

    #[test]
    fn implicit_multiple() {
        let tokens = token_stream::parse_tokens(
            "a(1) + 1(1) + 1a + aa + (1)(1)1",
            token_stream::RESTRICTED_WORD,
        );
        // Interleave `Whitespace` so the rendered stream is space-separated.
        let spaced = implicit_multiple_pass(tokens.peekable())
            .flat_map(|token| [Ok(Whitespace), token].into_iter())
            .skip(1);
        let rendered = stream_to_string(spaced).unwrap();

        assert_eq!(
            rendered,
            "a * ( 1 ) + 1 * ( 1 ) + 1 * a + a * a + ( 1 ) * ( 1 ) * 1"
        );
    }

    // NOTE(review): this only exercises the tokenizer; `unary_pass` is never applied here.
    #[test]
    fn unary() {
        let tokens = token_stream::parse_tokens("-(-1) + -(+a)", token_stream::RESTRICTED_WORD);
        let spaced = tokens
            .flat_map(|token| [Ok(Whitespace), token].into_iter())
            .skip(1);
        let rendered = stream_to_string(spaced).unwrap();

        assert_eq!(rendered, "- ( - 1 ) + - ( + a )");
    }

    mod ast {
        use super::Expr;

        /// `name: "input" = "output"` asserts that `input` displays as `output` and that
        /// `output` itself round-trips unchanged.
        /// `name: "input"` is an exploration helper: it prints the display and always panics.
        macro_rules! ast_test {
            ($name:ident: $input:literal = $output:literal) => {
                #[test]
                fn $name() {
                    println!("{}", $input);
                    let bump = bumpalo::Bump::with_capacity(1024);
                    let parsed =
                        Expr::parse(&bump, $input, super::token_stream::RESTRICTED_WORD).unwrap();
                    println!("==================================================");

                    let reparsed =
                        Expr::parse(&bump, $output, super::token_stream::RESTRICTED_WORD).unwrap();

                    assert_eq!(parsed.to_string(), $output);
                    assert_eq!(reparsed.to_string(), $output);
                }
            };

            ($name:ident: $input:literal $(=)?) => {
                #[test]
                fn $name() {
                    let bump = bumpalo::Bump::with_capacity(1024);
                    let parsed =
                        Expr::parse(&bump, $input, super::token_stream::RESTRICTED_WORD).unwrap();

                    dbg!(parsed.to_string());
                    panic!();
                }
            };
        }

        ast_test! {simple_addition: "1+1" = "1 + 1"}
        ast_test! {simple_substraction: "1-1" = "1 - 1"}
        ast_test! {simple_multiplication: "1*1" = "1 * 1"}
        ast_test! {simple_division: "1/1" = "1 / 1"}
        ast_test! {simple_modulo: "1%1" = "1 % 1"}
        ast_test! {simple_unary_minus: "--1" = "--1"}
        ast_test! {simple_unary_plus: "++1" = "++1"}

        ast_test! {mult1: "4 + 2 * 3" = "4 + 2 * 3"}
        ast_test! {implicit_multi1: "2a2" = "2 * a * 2"}

        ast_test! {complex1: "3 + 4 * 2 / (1 - 5) ^ 2 ^ 3" = "3 + 4 * 2 / (1 - 5) ^ 2 ^ 3"}

        ast_test! {function: "max(exp(7, 10), 3)" = "max(exp(7, 10), 3)"}
        ast_test! {function2: "max(2exp(7, 10), 3)" = "max(2 * exp(7, 10), 3)"}
        ast_test! {function3:
        "exp(exp(exp(exp(exp(exp(1), exp(1))) + 56, 2exp(exp(exp(exp(exp(1), exp(1))), exp(exp(exp(1), exp(exp(exp(1), exp(1))))))))), exp(exp(exp(exp(exp(exp(exp(5 + 7 + 54), exp(5 + 7 + 54))), exp(5 + 7 + 54))), exp(5 + 7 + 54))))" =
        "exp(exp(exp(exp(exp(exp(1), exp(1))) + 56, 2 * exp(exp(exp(exp(exp(1), exp(1))), exp(exp(exp(1), exp(exp(exp(1), exp(1))))))))), exp(exp(exp(exp(exp(exp(exp(5 + 7 + 54), exp(5 + 7 + 54))), exp(5 + 7 + 54))), exp(5 + 7 + 54))))"}
        ast_test! {function4: "max(1, 2, 4, 4, 5, 7, 30)" = "max(1, 2, 4, 4, 5, 7, 30)"}
    }
}