elemental/expression.rs

1//! Abstracts over Elemental expressions.
2
3use std::{
4    fmt::{
5        Display,
6        Result,
7        Formatter,
8    },
9    collections::HashMap,
10};
11
12use crate::{
13    standard::get_std_function,
14    Matrix,
15};
16
17use crate::error::*;
18
/// Defines the expression types that are available in Elemental.
#[derive(Clone, Debug, PartialEq)]
pub enum Expression {
    /// Binds the value of `value` to the variable named `identifier`.
    Assignment {
        identifier: String,
        value: Box<Expression>,
    },
    /// A reference to a variable by name.
    Identifier (String),
    /// An integer literal.
    Int (i64),
    /// A floating-point literal.
    Float (f64),
    /// A `rows` x `cols` matrix whose elements are stored row by row in
    /// `values` (element `(i, j)` lives at index `i * cols + j`).
    Matrix {
        rows: usize,
        cols: usize,
        values: Vec<Expression>,
    },
    /// A binary operation `left op right`; `op` is the operator symbol.
    BinOp {
        left: Box<Expression>,
        op: String,
        right: Box<Expression>,
    },
    /// A call of the function `name` with the argument list `args`.
    Call {
        name: String,
        args: Vec<Expression>,
    },
    /// The absence of a value (also produced after an error is reported).
    Nil,
}

/// Implementing `std::fmt::Display` allows us to print expressions
/// using the default formatter.
impl Display for Expression {
    /// Renders an expression for the user.
    ///
    /// - `Assignment` shows only the assigned value;
    /// - `Float` is printed with eight decimal places;
    /// - `Matrix` prints one bracketed row per line, each element centred in a
    ///   ten-character column, elements separated by a single space;
    /// - `BinOp` prints as `left op right`;
    /// - `Call` prints as `name(arg, ...)`;
    /// - `Nil` prints nothing.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        match self {
            Expression::Assignment { value, .. } => write!(f, "{}", value),
            Expression::Identifier (s) => write!(f, "{}", s),
            Expression::Int (i) => write!(f, "{}", i),
            Expression::Float (x) => write!(f, "{:.8}", x),
            Expression::Matrix { rows, cols, values } => {
                for i in 0..*rows {
                    write!(f, "[")?;
                    for j in 0..*cols {
                        // Single space between elements, none after the last.
                        if j > 0 {
                            write!(f, " ")?;
                        }
                        // Render the element to a string first: this `fmt`
                        // writes via `write!`, which ignores the caller's
                        // width/alignment flags, so centring must be applied
                        // to the finished text.
                        write!(f, "{:^10}", values[i * *cols + j].to_string())?;
                    }
                    writeln!(f, "]")?;
                }
                Ok(())
            },
            Expression::BinOp { left, op, right } => write!(f, "{} {} {}", left, op, right),
            // Unsimplified calls are ordinary values (e.g. when debugging a
            // parse tree), so render them instead of panicking.
            Expression::Call { name, args } => {
                write!(f, "{}(", name)?;
                for (n, arg) in args.iter().enumerate() {
                    if n > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", arg)?;
                }
                write!(f, ")")
            },
            Expression::Nil => Ok(()),
        }
    }
}
113
114impl Expression {
115    /// Simplify this expression, given a reference to a list of variables.
116    pub fn simplify(&self, variables: &mut HashMap<String, Expression>) -> Self {
117        match self {
118            // Look up the variable and plug in
119            Expression::Identifier (s) => {
120                let expr = match variables.get(s) {
121                    Some(e) => (*e).to_owned(),
122                    None => {
123                        throw(UndeclaredVariable (s.to_string()));
124                        return Expression::Nil;
125                    },
126                };
127                // Simplify
128                expr.simplify(variables)
129            },
130
131            // Insert the assigned variable into the list of variables
132            Expression::Assignment {
133                identifier: ref i,
134                value: ref v,
135            } => {
136                // Simplify the value of assignment
137                let simplified = (**v).simplify(variables);
138
139                // Register the variable
140                variables.insert(i.to_owned(), simplified.to_owned());
141
142                // Return the simplified value
143                simplified.to_owned()
144            }
145
146            // Simplify the left and right and return
147            Expression::BinOp {
148                left: l,
149                op: o,
150                right: r,
151            } => {
152                // Simplify the left-hand and right-hand sides
153                let left = l.simplify(variables);
154                let right = r.simplify(variables);
155
156                if let Expression::Int (l) = left {
157                    if let Expression::Int (r) = right {
158                        // Evaluate this as a float, then try to cast it to an `Int`
159                        let f = binop(l as f64, r as f64, &o);
160                        if f.fract() == 0.0 {
161                            Expression::Int (f as i64)
162                        } else {
163                            Expression::Float (f)
164                        }
165                    } else if let Expression::Float (r) = right {
166                        let left_float = l as f64;
167                        Expression::Float (binop(left_float, r, &o))
168                    } else if let Expression::Matrix {
169                        rows: r,
170                        cols: c,
171                        values: v,
172                    } = right {
173                        let mut values = Vec::new();
174
175                        for val in v {
176                            values.push(Expression::BinOp {
177                                left: Box::new(left.to_owned()),
178                                op: "*".to_string(),
179                                right: Box::new(val),
180                            }.simplify(variables)); 
181                        }
182
183                        Expression::Matrix {
184                            rows: r,
185                            cols: c,
186                            values,
187                        }
188                    } else {
189                        throw(InvalidOperands);
190                        return Expression::Nil;
191                    }
192                } else if let Expression::Float (l) = left {
193                    if let Expression::Int (r) = right {
194                        let right_float = r as f64;
195                        Expression::Float (binop(l, right_float, &o))
196                    } else if let Expression::Float (r) = right {
197                        Expression::Float (binop(l, r, &o))
198                    } else {
199                        throw(InvalidOperands);
200                        return Expression::Nil;
201                    }
202                } else if let Expression::Matrix {
203                    rows: r,
204                    cols: k1,
205                    values: vl,
206                } = left {
207                    if let Expression::Matrix {
208                        rows: k2,
209                        cols: c,
210                        values: vr,
211                    } = right {
212                        if k1 != k2 {
213                            throw(ImproperDimensions);
214                            return Expression::Nil;
215                        }
216
217                        matrix_dot(vl, vr, r, c, k1)
218                    } else {
219                        throw(InvalidOperands);
220                        return Expression::Nil;
221                    }
222                } else {
223                    throw(InvalidOperands);
224                    return Expression::Nil;
225                }
226            },
227            
228            // `Int is already in simplest form
229            Expression::Int (_) => self.to_owned(),
230
231            // `Float` can be reduced to `Int` if it has no fractional part
232            Expression::Float (f) => {
233                if f.fract() == 0.0 {
234                    Expression::Int (*f as i64)
235                } else {
236                    Expression::Float (*f)
237                }
238            },
239            
240            // To simplify a `Matrix`, simplify each value
241            Expression::Matrix {
242                rows: r,
243                cols: c,
244                values: v,
245            } => {
246                let mut new = Vec::new();
247
248                for val in v {
249                    new.push(val.simplify(variables));
250                }
251
252                if *r == 1 && *c == 1 {
253                    v[0].simplify(variables).to_owned()
254                } else {
255                    Expression::Matrix {
256                        rows: *r,
257                        cols: *c,
258                        values: new,
259                    }
260                }
261            },
262
263            // To simplify a call, look up the function in the standard library
264            // and pass the arguments necessary
265            // 
266            // In Elemental, all functions act on matrices.  
267            Expression::Call {
268                name: n,
269                args: a,
270            } => {
271                // Simplify each argument and convert them to "native" matrices.
272                let mut args = Vec::<Matrix>::new();
273                for arg in a {
274                    let simplified = arg.simplify(variables);
275                    if let Expression::Matrix {
276                        rows: r,
277                        cols: c,
278                        values: v,
279                    } = simplified {
280                        // Convert each value in the matrix from `Expression` to `f64`.
281                        let mut values: Vec<f64> = Vec::new();
282                        for value in v {
283                            if let Self::Int (i) = value {
284                                values.push(i as f64);
285                            } else if let Self::Float (f) = value {
286                                values.push(f);
287                            } else {
288                                // A value in one of the matrices is not a numeric literal
289                                throw(InvalidValue);
290                                return Expression::Nil;
291                            }
292                        }
293                        args.push(Matrix::new(r, c, values));
294                    } else if let Expression::Int (i) = simplified {
295                        // If one value is a number, convert it into a 1x1 matrix
296                        args.push(Matrix::new(1, 1, vec![i as f64]));
297                    } else if let Expression::Float (f) = simplified {
298                        // If one value is a number, convert it into a 1x1 matrix
299                        args.push(Matrix::new(1, 1, vec![f]));
300                    } else {
301                        // One of the arguments is not a matrix or a number
302                        throw(InvalidOperands);
303                    }
304                }
305
306                let stdfn = get_std_function(n.to_owned());
307                let output_matrix = stdfn.eval(args);
308
309                let values = output_matrix.copy_vals().iter().map(|x| Self::Float (*x)).collect::<Vec<Self>>();
310
311                Self::Matrix {
312                    rows: output_matrix.rows(),
313                    cols: output_matrix.cols(),
314                    values,
315                }.simplify(variables)
316            },
317            
318            // `Nil` is already in simplest form
319            Expression::Nil => self.to_owned(),
320        }
321    }
322}
323
324
325/// Executes the given binary operation on two floats.
326pub fn binop(x: f64, y: f64, binop: &str) -> f64 {
327    match binop {
328        "+" => x + y,
329        "-" => x - y,
330        "*" => x * y,
331        "/" => {
332            if y == 0.0 {
333                throw(DividedByZero);
334                0.0
335            } else {
336                x / y
337            }
338        },
339        _ => {
340            throw(InvalidOperator);
341            0.0
342        },
343    }
344}
345
346
347/// Computes the dot product of two matrices.
348pub fn matrix_dot(left: Vec<Expression>, right: Vec<Expression>, rows: usize, cols: usize, count: usize) -> Expression {
349    let mut values = Vec::new();
350    for i in 0..rows {
351        for j in 0..cols {
352            let mut cell = Expression::Int (0);
353            for k in 0..count {
354                // Add the addend to the cell
355                let addend = Expression::BinOp {
356                    left: Box::new(left[i*count + k].to_owned()),
357                    right: Box::new(right[k*cols + j].to_owned()),
358                    op: "*".to_string(),
359                };
360
361                cell = Expression::BinOp {
362                    left: Box::new(cell),
363                    right: Box::new(addend),
364                    op: "+".to_string(),
365                };
366            }
367            // Push the cell to the list of values
368            values.push(cell.simplify(&mut HashMap::new()));
369        }
370    }
371
372    Expression::Matrix {
373        rows,
374        cols,
375        values,
376    }
377}