gep_toolkit/operations/
stack_op.rs

1use std::fmt;
2use std::rc::Rc;
3
4use super::expressions::Expression;
5use super::primitives::{Argument, Constant, Modifier, Operator, PrimitiveOperation};
6
/// A single step that can be applied to a [`Stack`] via `update_stack`.
#[derive(Debug, Clone, PartialEq)]
pub enum StackOperation {
    /// A primitive operation: a constant, an argument, a modifier or an operator.
    Primitive(PrimitiveOperation),
    /// A reference-counted expression paired with an index.
    ///
    /// The index is not used when evaluating; it is what `Display` prints as
    /// `EXP[index]`.
    Expression(Rc<Expression>, usize),
}
12impl fmt::Display for StackOperation {
13    fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result {
14        match self {
15            StackOperation::Primitive(pr) => write!(f, "{}", pr),
16            StackOperation::Expression(_expr, index) => write!(f, "EXP[{}]", *index),
17        }
18    }
19}
20
/// Evaluation stack of `f64` values together with a borrowed list of input
/// arguments that can be looked up by index.
#[derive(Debug)]
pub struct Stack<'a> {
    /// Working values; the last element is the top of the stack.
    value: Vec<f64>,
    /// Input arguments, read by index through `get_arg`.
    args: &'a Vec<f64>,
}

impl<'a> Stack<'a> {
    /// Creates a stack with no values that reads its arguments from `args`.
    pub fn new(args: &'a Vec<f64>) -> Self {
        Stack {
            value: Vec::new(),
            args,
        }
    }

    /// Pushes `value` on top of the stack.
    fn push(&mut self, value: f64) {
        self.value.push(value);
    }

    /// Returns the argument stored at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a valid position in the argument list.
    fn get_arg(&self, index: usize) -> f64 {
        self.args.get(index).copied().unwrap_or_else(|| {
            panic!(
                "Can't find arg with index {} on Stack with {} args",
                index,
                self.args.len()
            )
        })
    }

    /// Removes the top value from the stack and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the stack holds no values; check [`Stack::len`] or
    /// [`Stack::is_empty`] first.
    pub fn pop(&mut self) -> f64 {
        self.value
            .pop()
            .expect("Can't pop a value from an empty Stack")
    }

    /// Number of values currently on the stack (arguments are not counted).
    pub fn len(&self) -> usize {
        self.value.len()
    }

    /// Returns `true` when the stack holds no values.
    pub fn is_empty(&self) -> bool {
        self.value.is_empty()
    }

    /// Returns the top value without removing it, or `0.0` when the stack
    /// holds no values.
    pub fn result(&self) -> f64 {
        self.value.last().copied().unwrap_or(0.0)
    }
}
58
59// impl Clone for StackOperation {
60//     fn clone(&self) -> Self {
61//         match self {
62//             StackOperation::Primitive(p) => StackOperation::Primitive(p.clone()),
63//             StackOperation::Expression(expr) => StackOperation::Expression(Rc::clone(expr))
64//         }
65//     }
66// }
67
68impl StackOperation {
69    pub fn update_stack(&self, stack: &mut Stack) {
70        match self {
71            StackOperation::Expression(expression, _) => {
72                if stack.len() >= 2 {
73                    let arg2 = stack.pop();
74                    let arg1 = stack.pop();
75                    let result = expression.compute_result(&vec![arg1, arg2]);
76
77                    let result = if result.is_nan() {
78                        0.0
79                    } else if result.is_infinite() {
80                        f64::MAX
81                    } else {
82                        result
83                    };
84                    stack.push(result);
85                }
86            }
87
88            StackOperation::Primitive(prim_op) => {
89                match prim_op {
90                    PrimitiveOperation::Constant(cons) => {
91                        stack.push(cons.value());
92                    }
93
94                    PrimitiveOperation::Argument(arg) => {
95                        let Argument::Arg(index) = arg;
96                        let arg_value = stack.get_arg(*index as usize);
97                        stack.push(arg_value);
98                    }
99
100                    PrimitiveOperation::Modifier(modifier) => {
101                        if stack.len() >= 1 {
102                            let arg = stack.pop();
103                            let result = modifier.compute(arg);
104
105                            let result = if result.is_nan() {
106                                0.0
107                            } else if result.is_infinite() {
108                                f64::MAX
109                            } else {
110                                result
111                            };
112                            stack.push(result);
113                        }
114                    }
115
116                    PrimitiveOperation::Operator(op) => {
117                        if stack.len() >= 2 {
118                            let arg2 = stack.pop();
119                            let arg1 = stack.pop();
120                            let result = op.compute(arg1, arg2);
121
122                            let result = if result.is_nan() {
123                                0.0
124                            } else if result.is_infinite() {
125                                f64::MAX
126                            } else {
127                                result
128                            };
129                            stack.push(result);
130                        }
131                    }
132                }
133            }
134        }
135    }
136
137    // pub fn is_pure_operation(&self) -> bool {
138    //     match self {
139    //         StackOperation::Constant(_) => true,
140    //         StackOperation::Modifier(_) => true,
141    //         StackOperation::Operator(_) => true,
142    //         _ => false
143    //     }
144    // }
145
146    pub fn construct(operation: impl StackOperationConstructor) -> StackOperation {
147        operation.stack_operation()
148    }
149}
150
151
/// Conversion of a value into a [`StackOperation`].
pub trait StackOperationConstructor {
    /// Consumes `self` and returns the `StackOperation` that represents it.
    fn stack_operation(self) -> StackOperation;
}
155
156impl StackOperationConstructor for Expression {
157    fn stack_operation(self) -> StackOperation {
158        StackOperation::Expression(Rc::new(self), usize::MAX)
159    }
160}
161impl StackOperationConstructor for PrimitiveOperation {
162    fn stack_operation(self) -> StackOperation {
163        StackOperation::Primitive(self)
164    }
165}
166impl StackOperationConstructor for Argument {
167    fn stack_operation(self) -> StackOperation {
168        PrimitiveOperation::Argument(self).stack_operation()
169    }
170}
171impl StackOperationConstructor for Constant {
172    fn stack_operation(self) -> StackOperation {
173        PrimitiveOperation::Constant(self).stack_operation()
174    }
175}
176impl StackOperationConstructor for Modifier {
177    fn stack_operation(self) -> StackOperation {
178        PrimitiveOperation::Modifier(self).stack_operation()
179    }
180}
181impl StackOperationConstructor for Operator {
182    fn stack_operation(self) -> StackOperation {
183        PrimitiveOperation::Operator(self).stack_operation()
184    }
185}
186
187
// Unit tests: each one applies a single stack operation to a small stack and
// checks the resulting top value.
#[cfg(test)]
mod tests {
    use super::*;

    // A constant pushes its value onto an empty stack.
    #[test]
    fn test_constant_changes_stack() {
        let args = vec![];
        let mut stack = Stack::new(&args);

        Constant::C1
            .stack_operation()
            .update_stack(&mut stack);

        assert_eq!(stack.result(), 1_f64);
    }

    // An argument operation pushes the stack argument at its index.
    #[test]
    fn test_argument_changes_stack() {
        let args = vec![100_f64];
        let mut stack = Stack::new(&args);
        // With no values pushed yet, `result` reports 0.
        assert_eq!(stack.result(), 0_f64);

        Argument::Arg(0)
            .stack_operation()
            .update_stack(&mut stack);

        assert_eq!(stack.result(), 100_f64);
    }

    // Index 1 does not exist in a one-element argument list, so the lookup panics.
    #[test]
    #[should_panic]
    fn test_stack_panics_given_invalid_argument_index() {
        let args = vec![100_f64];
        let mut stack = Stack::new(&args);

        Argument::Arg(1)
            .stack_operation()
            .update_stack(&mut stack);
    }

    // A modifier replaces the top value with its result: 6 squared is 36.
    #[test]
    fn test_modifier_changes_stack() {
        let args = vec![];
        let mut stack = Stack::new(&args);
        stack.push(6_f64);

        PrimitiveOperation::Modifier(Modifier::Sqr)
            .stack_operation()
            .update_stack(&mut stack);

        assert_eq!(stack.result(), 36_f64);
    }

    // #[test]
    // fn test_custom_modifier_changes_stack() {
    //     let args = vec![];
    //     let mut stack = Stack::new(&args);
    //     stack.push(2_f64);
    //
    //     PrimitiveOperation::Modifier(Modifier { func: |x| x * x * 3_f64 }, "ABC")
    //         .stack_operation()
    //         .update_stack(&mut stack);
    //
    //     assert_eq!(stack.result(), 12_f64);
    // }

    // An operator replaces the top two values with its result: 2 * 3 = 6.
    #[test]
    fn test_operator_changes_stack() {
        let args = vec![];
        let mut stack = Stack::new(&args);
        stack.push(2_f64);
        stack.push(3_f64);

        PrimitiveOperation::Operator(Operator::Multiply)
            .stack_operation()
            .update_stack(&mut stack);

        assert_eq!(stack.result(), 6_f64);
    }

    // TODO:
    // #[test]
    // fn test_custom_operator_changes_stack() {
    //     let args = vec![];
    //     let mut stack = Stack::new(&args);
    //     stack.push(2_f64);
    //     stack.push(3_f64);
    //
    //     PrimitiveOperation::Operator(Operator { func: |x, y| (x - y) * (x + y) }, "ABC")
    //         .stack_operation()
    //         .update_stack(&mut stack);
    //
    //     assert_eq!(stack.pop(), -5_f64);
    //     assert_eq!(stack.result(), 0_f64);
    // }

    // An expression consumes the top two stack values as its arguments and
    // leaves a single result behind.
    #[test]
    fn test_expression_changes_stack() {
        let args = vec![];
        let mut stack = Stack::new(&args);
        stack.push(2_f64);
        stack.push(3_f64);


        // (x + 1)^2 - y
        // x = 2, y = 3  =>  (2 + 1)^2 - 3 = 6
        Expression::new(vec![
            Argument::Arg(0).stack_operation(),
            Constant::C1.stack_operation(),
            Operator::Plus.stack_operation(),
            Modifier::Sqr.stack_operation(),
            Argument::Arg(1).stack_operation(),
            Operator::Minus.stack_operation(),
        ])
            .stack_operation()
            .update_stack(&mut stack);

        // The result is the only value left: popping it empties the stack,
        // after which `result` reports 0.
        assert_eq!(stack.pop(), 6_f64);
        assert_eq!(stack.result(), 0_f64);
    }
}