expression_num/
lib.rs

1use asexp::Sexp;
2use expression::{Expression, ExpressionError};
3use num_traits::{One, Zero};
4use std::fmt::Debug;
5use std::ops::{Add, Div, Mul, Sub};
6
7pub trait NumType:
8    Debug
9    + Copy
10    + Clone
11    + PartialEq
12    + PartialOrd
13    + Default
14    + Zero
15    + One
16    + Add<Output = Self>
17    + Sub<Output = Self>
18    + Mul<Output = Self>
19    + Div<Output = Self>
20{
21}
22
23impl NumType for f32 {}
24impl NumType for f64 {}
25impl NumType for u32 {}
26impl NumType for u64 {}
27
28/// An expression evaluates to a numeric value of type `NumType`.
29#[derive(Clone, PartialEq, Eq, Debug)]
30pub enum NumExpr<T: NumType> {
31    /// A constant value.
32    Const(T),
33
34    /// References a variable by position
35    Var(usize),
36
37    Add(Box<NumExpr<T>>, Box<NumExpr<T>>),
38    Sub(Box<NumExpr<T>>, Box<NumExpr<T>>),
39    Mul(Box<NumExpr<T>>, Box<NumExpr<T>>),
40    Div(Box<NumExpr<T>>, Box<NumExpr<T>>),
41
42    /// Safe division with x/0 = 0.0
43    Divz(Box<NumExpr<T>>, Box<NumExpr<T>>),
44
45    /// Reciprocal (1 / x).
46    Recip(Box<NumExpr<T>>),
47
48    /// Reciprocal using safe division
49    Recipz(Box<NumExpr<T>>),
50}
51
52impl<T: NumType> NumExpr<T> {
53    pub fn op_add(self, other: NumExpr<T>) -> NumExpr<T> {
54        match (self, other) {
55            (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a + b),
56            (a, b) => NumExpr::Add(Box::new(a), Box::new(b)),
57        }
58    }
59
60    pub fn op_sub(self, other: NumExpr<T>) -> NumExpr<T> {
61        match (self, other) {
62            (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a - b),
63            (a, b) => NumExpr::Sub(Box::new(a), Box::new(b)),
64        }
65    }
66
67    pub fn op_mul(self, other: NumExpr<T>) -> NumExpr<T> {
68        match (self, other) {
69            (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a * b),
70            (a, b) => NumExpr::Mul(Box::new(a), Box::new(b)),
71        }
72    }
73
74    pub fn op_div(self, other: NumExpr<T>) -> NumExpr<T> {
75        match (self, other) {
76            (NumExpr::Const(a), NumExpr::Const(b)) if b != T::zero() => NumExpr::Const(a * b),
77            (a, b) => NumExpr::Div(Box::new(a), Box::new(b)),
78        }
79    }
80
81    pub fn op_divz(self, other: NumExpr<T>) -> NumExpr<T> {
82        match (self, other) {
83            (NumExpr::Const(a), NumExpr::Const(b)) => {
84                if b == T::zero() {
85                    NumExpr::Const(T::zero())
86                } else {
87                    NumExpr::Const(a * b)
88                }
89            }
90            (a, b) => NumExpr::Divz(Box::new(a), Box::new(b)),
91        }
92    }
93
94    pub fn op_recip(self) -> NumExpr<T> {
95        match self {
96            NumExpr::Const(a) if a != T::zero() => NumExpr::Const(T::one() / a),
97            a => NumExpr::Recip(Box::new(a)),
98        }
99    }
100
101    pub fn op_recipz(self) -> NumExpr<T> {
102        match self {
103            NumExpr::Const(a) => {
104                if a == T::zero() {
105                    NumExpr::Const(T::zero())
106                } else {
107                    NumExpr::Const(T::one() / a)
108                }
109            }
110            a => NumExpr::Recipz(Box::new(a)),
111        }
112    }
113}
114
115impl<T: NumType> Expression for NumExpr<T> {
116    type Element = T;
117
118    fn evaluate(&self, variables: &[Self::Element]) -> Result<Self::Element, ExpressionError> {
119        Ok(match self {
120            &NumExpr::Var(n) => variables
121                .get(n)
122                .ok_or(ExpressionError::InvalidVariable)?
123                .clone(),
124            &NumExpr::Const(val) => val,
125            &NumExpr::Add(ref e1, ref e2) => e1.evaluate(variables)? + e2.evaluate(variables)?,
126            &NumExpr::Sub(ref e1, ref e2) => e1.evaluate(variables)? - e2.evaluate(variables)?,
127            &NumExpr::Mul(ref e1, ref e2) => e1.evaluate(variables)? * e2.evaluate(variables)?,
128            &NumExpr::Div(ref e1, ref e2) => {
129                let a = e1.evaluate(variables)?;
130                let div = e2.evaluate(variables)?;
131                if div == T::zero() {
132                    return Err(ExpressionError::DivByZero);
133                }
134                a / div
135            }
136            &NumExpr::Divz(ref e1, ref e2) => {
137                let a = e1.evaluate(variables)?;
138                let div = e2.evaluate(variables)?;
139                if div == T::zero() {
140                    div
141                } else {
142                    a / div
143                }
144            }
145            &NumExpr::Recip(ref e1) => {
146                let div = e1.evaluate(variables)?;
147                if div == T::zero() {
148                    return Err(ExpressionError::DivByZero);
149                } else {
150                    T::one() / div
151                }
152            }
153            &NumExpr::Recipz(ref e1) => {
154                let div = e1.evaluate(variables)?;
155                if div == T::zero() {
156                    div
157                } else {
158                    T::one() / div
159                }
160            }
161        })
162    }
163}
164
165impl<'a, T: NumType + Into<Sexp>> Into<Sexp> for &'a NumExpr<T> {
166    fn into(self) -> Sexp {
167        match self {
168            &NumExpr::Const(n) => n.into(),
169            &NumExpr::Var(n) => Sexp::from(format!("${}", n)),
170            &NumExpr::Add(ref a, ref b) => Sexp::from((
171                "+",
172                Into::<Sexp>::into(a.as_ref()),
173                Into::<Sexp>::into(b.as_ref()),
174            )),
175            &NumExpr::Sub(ref a, ref b) => Sexp::from((
176                "-",
177                Into::<Sexp>::into(a.as_ref()),
178                Into::<Sexp>::into(b.as_ref()),
179            )),
180            &NumExpr::Mul(ref a, ref b) => Sexp::from((
181                "*",
182                Into::<Sexp>::into(a.as_ref()),
183                Into::<Sexp>::into(b.as_ref()),
184            )),
185            &NumExpr::Div(ref a, ref b) => Sexp::from((
186                "/",
187                Into::<Sexp>::into(a.as_ref()),
188                Into::<Sexp>::into(b.as_ref()),
189            )),
190            &NumExpr::Divz(ref a, ref b) => Sexp::from((
191                "divz",
192                Into::<Sexp>::into(a.as_ref()),
193                Into::<Sexp>::into(b.as_ref()),
194            )),
195            &NumExpr::Recip(ref a) => Sexp::from(("recip", Into::<Sexp>::into(a.as_ref()))),
196            &NumExpr::Recipz(ref a) => Sexp::from(("recipz", Into::<Sexp>::into(a.as_ref()))),
197        }
198    }
199}
200
/// Empty variable slice for tests whose expressions contain no `Var` nodes.
#[cfg(test)]
const NO_VARS: [f32; 0] = [0.0; 0];
203
#[test]
fn test_expr_divz() {
    // Safe division by a zero constant yields zero instead of an error.
    let numerator = Box::new(NumExpr::Const(1.0));
    let denominator = Box::new(NumExpr::Const(0.0));
    let result = NumExpr::Divz(numerator, denominator).evaluate(&NO_VARS);
    assert_eq!(Ok(0.0), result);
}
209
#[test]
fn test_expr_recipz() {
    // (operand, expected 1/operand with the 1/0 = 0 convention)
    let cases = [(0.0, 0.0), (1.0, 1.0), (0.5, 2.0)];
    for &(input, expected) in &cases {
        let expr = NumExpr::Recipz(Box::new(NumExpr::Const(input)));
        assert_eq!(Ok(expected), expr.evaluate(&NO_VARS));
    }
}
221
#[test]
fn test_expr() {
    // Reference computation that the tree below must reproduce exactly.
    fn expected(a: f32, b: f32) -> f32 {
        0.0 - ((1.0 + a) * b) / 2.0
    }

    // Build `0 - ((1 + $0) * $1) / 2` from the innermost node outwards.
    let sum = NumExpr::Add(Box::new(NumExpr::Const(1.0)), Box::new(NumExpr::Var(0)));
    let product = NumExpr::Mul(Box::new(sum), Box::new(NumExpr::Var(1)));
    let quotient = NumExpr::Div(Box::new(product), Box::new(NumExpr::Const(2.0)));
    let expr: NumExpr<f32> = NumExpr::Sub(Box::new(NumExpr::Const(0.0)), Box::new(quotient));

    for &(a, b) in &[(123.0, 4444.0), (0.0, -12.0)] {
        assert_eq!(Ok(expected(a, b)), expr.evaluate(&[a, b]));
    }
}
249
#[test]
fn test_constant_folding() {
    // Two constants fold into a single constant.
    let expr = NumExpr::Const(1.0);
    let expr2 = expr.op_add(NumExpr::Const(2.0));
    assert_eq!(NumExpr::Const(1.0 + 2.0), expr2);

    // A non-constant operand prevents folding.
    let expr = NumExpr::Var(1);
    let expr2 = expr.op_add(NumExpr::Const(2.0));
    assert_eq!(
        NumExpr::Add(Box::new(NumExpr::Var(1)), Box::new(NumExpr::Const(2.0))),
        expr2
    );

    // The remaining arithmetic builders fold constants as well.
    assert_eq!(NumExpr::Const(2.0), NumExpr::Const(5.0).op_sub(NumExpr::Const(3.0)));
    assert_eq!(NumExpr::Const(6.0), NumExpr::Const(2.0).op_mul(NumExpr::Const(3.0)));
    assert_eq!(NumExpr::Const(0.5), NumExpr::Const(2.0).op_recip());
    assert_eq!(NumExpr::Const(0.25), NumExpr::Const(4.0).op_recipz());

    // Safe variants fold a zero divisor/operand to zero, for floats and integers.
    assert_eq!(NumExpr::Const(0.0), NumExpr::Const(4.0).op_divz(NumExpr::Const(0.0)));
    assert_eq!(NumExpr::Const(0.0), NumExpr::Const(0.0).op_recipz());
    assert_eq!(NumExpr::Const(0u32), NumExpr::Const(7u32).op_divz(NumExpr::Const(0)));

    // Unsafe variants keep the node on a zero divisor/operand so that
    // evaluation can report the division by zero.
    assert_eq!(
        NumExpr::Div(Box::new(NumExpr::Const(1.0)), Box::new(NumExpr::Const(0.0))),
        NumExpr::Const(1.0).op_div(NumExpr::Const(0.0))
    );
    assert_eq!(
        NumExpr::Recip(Box::new(NumExpr::Const(0.0))),
        NumExpr::Const(0.0).op_recip()
    );
}