fixit/
algorithm.rs

1use super::tokens::{BinaryOperator, InfixToken, PostfixToken, StackToken};
2use std::{error::Error, fmt};
3
/// Error produced when an infix token stream cannot be converted to postfix.
#[derive(Debug, PartialEq, Eq)]
pub enum ConvertError {
    /// The [`GroupStart`](InfixToken::GroupStart) and [`GroupEnd`](InfixToken::GroupEnd)
    /// tokens did not pair up.
    ///
    /// The sign of the payload tells which way the groups are unbalanced: a positive
    /// value means at least one group was left open, a negative value means at least
    /// one group end had no matching group start.
    UnbalancedGroups(i32),
}

impl fmt::Display for ConvertError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing message deliberately omits the depth payload.
        let message = match *self {
            Self::UnbalancedGroups(_) => "Unbalanced groups",
        };
        f.write_str(message)
    }
}

impl Error for ConvertError {}
20
21/// Converts an iterator of [`InfixToken`] to a [`Vec`] of [`PostfixToken`].
22///
23/// Note that the only error checking this function performs is to make sure
24/// that the number of group starts and finishes is equal. This is checked to avoid
25/// executing `unreachable` code.
26///
27/// Critically, this function will not check that the `tokens` argument is a valid
28/// infix token stream, particularly that each [`BinaryOp`](InfixToken::BinaryOp)
29/// has two adjacent [`Operand`](InfixToken::Operand)s.
30///
31/// Returns `Ok(...)` on success, otherwise returns a [`ConvertError`].
32///
33/// # Errors
34///
35/// See [`ConvertError`].
36pub fn convert<Operand, BinaryOp, I>(
37    tokens: I,
38) -> Result<Vec<PostfixToken<Operand, BinaryOp>>, ConvertError>
39where
40    I: IntoIterator<Item = InfixToken<Operand, BinaryOp>>,
41    BinaryOp: BinaryOperator,
42{
43    let mut result = vec![];
44    let mut stack: Vec<StackToken<BinaryOp>> = vec![];
45    let mut group_depth = 0;
46
47    tokens.into_iter().for_each(|token| match token {
48        InfixToken::Operand(name) => result.push(PostfixToken::Operand(name)),
49        InfixToken::BinaryOp(op) => {
50            while stack
51                .last()
52                .map_or(false, |last| stack_to_result(last, &op))
53            {
54                // Safe to `unwrap` because `stack.last()` returned `Some`
55                result.push(stack.pop().unwrap().into());
56            }
57
58            stack.push(StackToken::BinaryOp(op));
59        }
60        InfixToken::GroupStart => {
61            stack.push(StackToken::GroupStart);
62            group_depth += 1;
63        }
64        InfixToken::GroupEnd => {
65            while let Some(last) = stack.pop() {
66                match last {
67                    StackToken::BinaryOp(op) => result.push(PostfixToken::BinaryOp(op)),
68                    StackToken::GroupStart => break,
69                }
70            }
71            group_depth -= 1;
72        }
73    });
74
75    match group_depth {
76        0 => {
77            result.extend(stack.into_iter().rev().map(Into::into));
78            Ok(result)
79        }
80        group_depth => Err(ConvertError::UnbalancedGroups(group_depth)),
81    }
82}
83
84fn stack_to_result<BinaryOp>(last: &StackToken<BinaryOp>, op: &BinaryOp) -> bool
85where
86    BinaryOp: BinaryOperator,
87{
88    match last {
89        StackToken::BinaryOp(last_op) => last_op.precedence() >= op.precedence(),
90        StackToken::GroupStart => false,
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::{convert, BinaryOperator, ConvertError, InfixToken, PostfixToken};

    /// Minimal operator set for exercising precedence: `Add`/`Sub` bind more loosely
    /// than `Mul`/`Div`.
    #[derive(Debug, PartialEq, Eq)]
    enum TestBinaryOp {
        Add,
        Sub,
        Mul,
        Div,
    }

    impl BinaryOperator for TestBinaryOp {
        fn precedence(&self) -> u8 {
            match self {
                TestBinaryOp::Add => 1,
                TestBinaryOp::Sub => 1,
                TestBinaryOp::Mul => 2,
                TestBinaryOp::Div => 2,
            }
        }
    }

    #[test]
    fn test_ok_1() {
        let infix_tokens = vec![
            InfixToken::Operand("m"),
            InfixToken::BinaryOp(TestBinaryOp::Mul),
            InfixToken::Operand("n"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::GroupStart,
            InfixToken::Operand("p"),
            InfixToken::BinaryOp(TestBinaryOp::Sub),
            InfixToken::Operand("q"),
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("r"),
        ];

        assert_eq!(
            convert(infix_tokens).unwrap(),
            vec![
                PostfixToken::Operand("m"),
                PostfixToken::Operand("n"),
                PostfixToken::BinaryOp(TestBinaryOp::Mul),
                PostfixToken::Operand("p"),
                PostfixToken::Operand("q"),
                PostfixToken::BinaryOp(TestBinaryOp::Sub),
                PostfixToken::BinaryOp(TestBinaryOp::Add),
                PostfixToken::Operand("r"),
                PostfixToken::BinaryOp(TestBinaryOp::Add),
            ]
        );
    }

    #[test]
    fn test_ok_2() {
        let infix_tokens = vec![
            InfixToken::Operand("a"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("b"),
            InfixToken::BinaryOp(TestBinaryOp::Mul),
            InfixToken::Operand("c"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("d"),
        ];

        assert_eq!(
            convert(infix_tokens).unwrap(),
            vec![
                PostfixToken::Operand("a"),
                PostfixToken::Operand("b"),
                PostfixToken::Operand("c"),
                PostfixToken::BinaryOp(TestBinaryOp::Mul),
                PostfixToken::BinaryOp(TestBinaryOp::Add),
                PostfixToken::Operand("d"),
                PostfixToken::BinaryOp(TestBinaryOp::Add)
            ]
        );
    }

    #[test]
    fn test_ok_3() {
        let infix_tokens = vec![
            InfixToken::GroupStart,
            InfixToken::GroupStart,
            InfixToken::Operand("a"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("b"),
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Sub),
            InfixToken::Operand("c"),
            InfixToken::BinaryOp(TestBinaryOp::Mul),
            InfixToken::GroupStart,
            InfixToken::Operand("d"),
            InfixToken::BinaryOp(TestBinaryOp::Div),
            InfixToken::Operand("e"),
            InfixToken::GroupEnd,
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("f"),
        ];

        assert_eq!(
            convert(infix_tokens).unwrap(),
            vec![
                PostfixToken::Operand("a"),
                PostfixToken::Operand("b"),
                PostfixToken::BinaryOp(TestBinaryOp::Add),
                PostfixToken::Operand("c"),
                PostfixToken::Operand("d"),
                PostfixToken::Operand("e"),
                PostfixToken::BinaryOp(TestBinaryOp::Div),
                PostfixToken::BinaryOp(TestBinaryOp::Mul),
                PostfixToken::BinaryOp(TestBinaryOp::Sub),
                PostfixToken::Operand("f"),
                PostfixToken::BinaryOp(TestBinaryOp::Add),
            ]
        );
    }

    /// An empty token stream converts to an empty postfix stream.
    #[test]
    fn test_ok_empty() {
        let infix_tokens: Vec<InfixToken<&str, TestBinaryOp>> = vec![];

        assert!(convert(infix_tokens).unwrap().is_empty());
    }

    /// A lone operand passes straight through.
    #[test]
    fn test_ok_single_operand() {
        let infix_tokens: Vec<InfixToken<&str, TestBinaryOp>> = vec![InfixToken::Operand("m")];

        assert_eq!(
            convert(infix_tokens).unwrap(),
            vec![PostfixToken::Operand("m")]
        );
    }

    #[test]
    fn test_bad_1() {
        let infix_tokens = vec![
            InfixToken::GroupStart,
            InfixToken::GroupStart,
            InfixToken::Operand("a"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("b"),
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Sub),
            InfixToken::Operand("c"),
            InfixToken::BinaryOp(TestBinaryOp::Mul),
            InfixToken::GroupStart,
            InfixToken::Operand("d"),
            InfixToken::BinaryOp(TestBinaryOp::Div),
            InfixToken::Operand("e"),
            InfixToken::GroupEnd,
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("f"),
            InfixToken::GroupStart, // Extra
        ];

        let result = convert(infix_tokens).unwrap_err();

        assert_eq!(result, ConvertError::UnbalancedGroups(1));
        assert_eq!(result.to_string(), "Unbalanced groups");
    }

    #[test]
    fn test_bad_2() {
        let infix_tokens = vec![
            InfixToken::GroupStart,
            InfixToken::GroupStart,
            InfixToken::Operand("a"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("b"),
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Sub),
            InfixToken::Operand("c"),
            InfixToken::BinaryOp(TestBinaryOp::Mul),
            InfixToken::GroupStart,
            InfixToken::Operand("d"),
            InfixToken::BinaryOp(TestBinaryOp::Div),
            InfixToken::Operand("e"),
            InfixToken::GroupEnd,
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("f"),
            // Missing GroupEnd
        ];

        let result = convert(infix_tokens).unwrap_err();

        assert_eq!(result, ConvertError::UnbalancedGroups(1));
        assert_eq!(result.to_string(), "Unbalanced groups");
    }

    /// A surplus group end produces a negative imbalance.
    #[test]
    fn test_bad_3() {
        let infix_tokens = vec![
            InfixToken::GroupStart,
            InfixToken::Operand("a"),
            InfixToken::BinaryOp(TestBinaryOp::Add),
            InfixToken::Operand("b"),
            InfixToken::GroupEnd,
            InfixToken::GroupEnd, // Extra
        ];

        let result = convert(infix_tokens).unwrap_err();

        assert_eq!(result, ConvertError::UnbalancedGroups(-1));
        assert_eq!(result.to_string(), "Unbalanced groups");
    }
}