1use asexp::Sexp;
2use expression::{Expression, ExpressionError};
3use num_traits::{One, Zero};
4use std::fmt::Debug;
5use std::ops::{Add, Div, Mul, Sub};
6
7pub trait NumType:
8 Debug
9 + Copy
10 + Clone
11 + PartialEq
12 + PartialOrd
13 + Default
14 + Zero
15 + One
16 + Add<Output = Self>
17 + Sub<Output = Self>
18 + Mul<Output = Self>
19 + Div<Output = Self>
20{
21}
22
23impl NumType for f32 {}
24impl NumType for f64 {}
25impl NumType for u32 {}
26impl NumType for u64 {}
27
28#[derive(Clone, PartialEq, Eq, Debug)]
30pub enum NumExpr<T: NumType> {
31 Const(T),
33
34 Var(usize),
36
37 Add(Box<NumExpr<T>>, Box<NumExpr<T>>),
38 Sub(Box<NumExpr<T>>, Box<NumExpr<T>>),
39 Mul(Box<NumExpr<T>>, Box<NumExpr<T>>),
40 Div(Box<NumExpr<T>>, Box<NumExpr<T>>),
41
42 Divz(Box<NumExpr<T>>, Box<NumExpr<T>>),
44
45 Recip(Box<NumExpr<T>>),
47
48 Recipz(Box<NumExpr<T>>),
50}
51
52impl<T: NumType> NumExpr<T> {
53 pub fn op_add(self, other: NumExpr<T>) -> NumExpr<T> {
54 match (self, other) {
55 (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a + b),
56 (a, b) => NumExpr::Add(Box::new(a), Box::new(b)),
57 }
58 }
59
60 pub fn op_sub(self, other: NumExpr<T>) -> NumExpr<T> {
61 match (self, other) {
62 (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a - b),
63 (a, b) => NumExpr::Sub(Box::new(a), Box::new(b)),
64 }
65 }
66
67 pub fn op_mul(self, other: NumExpr<T>) -> NumExpr<T> {
68 match (self, other) {
69 (NumExpr::Const(a), NumExpr::Const(b)) => NumExpr::Const(a * b),
70 (a, b) => NumExpr::Mul(Box::new(a), Box::new(b)),
71 }
72 }
73
74 pub fn op_div(self, other: NumExpr<T>) -> NumExpr<T> {
75 match (self, other) {
76 (NumExpr::Const(a), NumExpr::Const(b)) if b != T::zero() => NumExpr::Const(a * b),
77 (a, b) => NumExpr::Div(Box::new(a), Box::new(b)),
78 }
79 }
80
81 pub fn op_divz(self, other: NumExpr<T>) -> NumExpr<T> {
82 match (self, other) {
83 (NumExpr::Const(a), NumExpr::Const(b)) => {
84 if b == T::zero() {
85 NumExpr::Const(T::zero())
86 } else {
87 NumExpr::Const(a * b)
88 }
89 }
90 (a, b) => NumExpr::Divz(Box::new(a), Box::new(b)),
91 }
92 }
93
94 pub fn op_recip(self) -> NumExpr<T> {
95 match self {
96 NumExpr::Const(a) if a != T::zero() => NumExpr::Const(T::one() / a),
97 a => NumExpr::Recip(Box::new(a)),
98 }
99 }
100
101 pub fn op_recipz(self) -> NumExpr<T> {
102 match self {
103 NumExpr::Const(a) => {
104 if a == T::zero() {
105 NumExpr::Const(T::zero())
106 } else {
107 NumExpr::Const(T::one() / a)
108 }
109 }
110 a => NumExpr::Recipz(Box::new(a)),
111 }
112 }
113}
114
115impl<T: NumType> Expression for NumExpr<T> {
116 type Element = T;
117
118 fn evaluate(&self, variables: &[Self::Element]) -> Result<Self::Element, ExpressionError> {
119 Ok(match self {
120 &NumExpr::Var(n) => variables
121 .get(n)
122 .ok_or(ExpressionError::InvalidVariable)?
123 .clone(),
124 &NumExpr::Const(val) => val,
125 &NumExpr::Add(ref e1, ref e2) => e1.evaluate(variables)? + e2.evaluate(variables)?,
126 &NumExpr::Sub(ref e1, ref e2) => e1.evaluate(variables)? - e2.evaluate(variables)?,
127 &NumExpr::Mul(ref e1, ref e2) => e1.evaluate(variables)? * e2.evaluate(variables)?,
128 &NumExpr::Div(ref e1, ref e2) => {
129 let a = e1.evaluate(variables)?;
130 let div = e2.evaluate(variables)?;
131 if div == T::zero() {
132 return Err(ExpressionError::DivByZero);
133 }
134 a / div
135 }
136 &NumExpr::Divz(ref e1, ref e2) => {
137 let a = e1.evaluate(variables)?;
138 let div = e2.evaluate(variables)?;
139 if div == T::zero() {
140 div
141 } else {
142 a / div
143 }
144 }
145 &NumExpr::Recip(ref e1) => {
146 let div = e1.evaluate(variables)?;
147 if div == T::zero() {
148 return Err(ExpressionError::DivByZero);
149 } else {
150 T::one() / div
151 }
152 }
153 &NumExpr::Recipz(ref e1) => {
154 let div = e1.evaluate(variables)?;
155 if div == T::zero() {
156 div
157 } else {
158 T::one() / div
159 }
160 }
161 })
162 }
163}
164
165impl<'a, T: NumType + Into<Sexp>> Into<Sexp> for &'a NumExpr<T> {
166 fn into(self) -> Sexp {
167 match self {
168 &NumExpr::Const(n) => n.into(),
169 &NumExpr::Var(n) => Sexp::from(format!("${}", n)),
170 &NumExpr::Add(ref a, ref b) => Sexp::from((
171 "+",
172 Into::<Sexp>::into(a.as_ref()),
173 Into::<Sexp>::into(b.as_ref()),
174 )),
175 &NumExpr::Sub(ref a, ref b) => Sexp::from((
176 "-",
177 Into::<Sexp>::into(a.as_ref()),
178 Into::<Sexp>::into(b.as_ref()),
179 )),
180 &NumExpr::Mul(ref a, ref b) => Sexp::from((
181 "*",
182 Into::<Sexp>::into(a.as_ref()),
183 Into::<Sexp>::into(b.as_ref()),
184 )),
185 &NumExpr::Div(ref a, ref b) => Sexp::from((
186 "/",
187 Into::<Sexp>::into(a.as_ref()),
188 Into::<Sexp>::into(b.as_ref()),
189 )),
190 &NumExpr::Divz(ref a, ref b) => Sexp::from((
191 "divz",
192 Into::<Sexp>::into(a.as_ref()),
193 Into::<Sexp>::into(b.as_ref()),
194 )),
195 &NumExpr::Recip(ref a) => Sexp::from(("recip", Into::<Sexp>::into(a.as_ref()))),
196 &NumExpr::Recipz(ref a) => Sexp::from(("recipz", Into::<Sexp>::into(a.as_ref()))),
197 }
198 }
199}
200
/// Empty variable table for test expressions that contain no `Var` nodes.
#[cfg(test)]
const NO_VARS: [f32; 0] = [0.0; 0];
203
/// `Divz` with a zero divisor evaluates to zero rather than failing.
#[test]
fn test_expr_divz() {
    let numerator = Box::new(NumExpr::Const(1.0));
    let denominator = Box::new(NumExpr::Const(0.0));
    let result = NumExpr::Divz(numerator, denominator).evaluate(&NO_VARS);
    assert_eq!(Ok(0.0), result);
}
209
/// `Recipz` maps zero to zero and any other value `x` to `1 / x`.
#[test]
fn test_expr_recipz() {
    // (operand, expected reciprocal)
    let cases = [(0.0, 0.0), (1.0, 1.0), (0.5, 2.0)];
    for &(operand, expected) in cases.iter() {
        let expr = NumExpr::Recipz(Box::new(NumExpr::Const(operand)));
        assert_eq!(Ok(expected), expr.evaluate(&NO_VARS));
    }
}
221
/// Evaluates `0 - ((1 + $0) * $1) / 2` and compares it with the same formula
/// written directly in Rust.
#[test]
fn test_expr() {
    // Build the tree bottom-up.
    let sum = NumExpr::Add(Box::new(NumExpr::Const(1.0)), Box::new(NumExpr::Var(0)));
    let product = NumExpr::Mul(Box::new(sum), Box::new(NumExpr::Var(1)));
    let quotient = NumExpr::Div(Box::new(product), Box::new(NumExpr::Const(2.0)));
    let expr: NumExpr<f32> = NumExpr::Sub(Box::new(NumExpr::Const(0.0)), Box::new(quotient));

    fn reference(a: f32, b: f32) -> f32 {
        0.0 - ((1.0 + a) * b) / 2.0
    }

    let inputs = [(123.0, 4444.0), (0.0, -12.0)];
    for &(a, b) in inputs.iter() {
        assert_eq!(Ok(reference(a, b)), expr.evaluate(&[a, b]));
    }
}
249
/// `op_add` folds two constants, but builds an `Add` node when an operand is a
/// variable.
#[test]
fn test_constant_folding() {
    let folded = NumExpr::Const(1.0).op_add(NumExpr::Const(2.0));
    assert_eq!(NumExpr::Const(1.0 + 2.0), folded);

    let unfolded = NumExpr::Var(1).op_add(NumExpr::Const(2.0));
    let expected = NumExpr::Add(Box::new(NumExpr::Var(1)), Box::new(NumExpr::Const(2.0)));
    assert_eq!(expected, unfolded);
}