1use std::{
4 fmt::{
5 Display,
6 Result,
7 Formatter,
8 },
9 collections::HashMap,
10};
11
12use crate::{
13 standard::get_std_function,
14 Matrix,
15};
16
17use crate::error::*;
18
/// A node of the expression tree handled by the evaluator.
///
/// Values are cloned freely while simplifying, hence the `Clone` derive;
/// `Debug` is available for diagnostics.
#[derive(Clone, Debug)]
pub enum Expression {
    /// Binding of `value` to the variable called `identifier`.
    Assignment {
        identifier: String,
        value: Box<Expression>,
    },
    /// Reference to a variable by name.
    Identifier(String),
    /// Signed 64-bit integer value.
    Int(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Matrix with `rows` x `cols` cells; the cell at (row, col) is stored
    /// at `values[row * cols + col]`.
    Matrix {
        rows: usize,
        cols: usize,
        values: Vec<Expression>,
    },
    /// Binary operation; `op` holds the operator text (for example `"+"`,
    /// `"-"`, `"*"` or `"/"`).
    BinOp {
        left: Box<Expression>,
        op: String,
        right: Box<Expression>,
    },
    /// Call of the function called `name` with the given arguments.
    Call {
        name: String,
        args: Vec<Expression>,
    },
    /// The empty expression.
    Nil,
}
45
46impl Display for Expression {
49 fn fmt(&self, f: &mut Formatter<'_>) -> Result {
51 match self {
52 Expression::Assignment {
53 identifier: _,
54 value: v,
55 } => {
56 write!(f, "{}", v)
57 },
58 Expression::Identifier (s) => {
59 write!(f, "{}", s)
60 },
61 Expression::Int (i) => {
62 write!(f, "{}", i)
63 },
64 Expression::Float (float) => {
65 write!(f, "{:.8}", float)
66 },
67 Expression::Matrix {
68 rows: r,
69 cols: c,
70 values: v,
71 } => {
72 let mut result = String::new();
73 for i in 0..*r {
74 result.push('[');
75 for j in 0..*c {
76 let index = i*c + j;
77 result.push_str(
78 &format!(
79 "{:^10}",
80 format!("{}", v[index as usize])
81 )
82 );
83
84 if j != c - 1 {
86 result.push(' ');
87 }
88 }
89 result.push(']');
90 result.push('\n');
91 }
92 write!(f, "{}", result)
93 }
94 Expression::BinOp {
95 left: l,
96 op: o,
97 right: r,
98 } => {
99 write!(f, "{} {} {}", l, o, r)
100 },
101 Expression::Call {
102 name: _,
103 args: _,
104 } => {
105 unreachable!()
106 }
107 Expression::Nil => {
108 write!(f, "")
109 },
110 }
111 }
112}
113
114impl Expression {
115 pub fn simplify(&self, variables: &mut HashMap<String, Expression>) -> Self {
117 match self {
118 Expression::Identifier (s) => {
120 let expr = match variables.get(s) {
121 Some(e) => (*e).to_owned(),
122 None => {
123 throw(UndeclaredVariable (s.to_string()));
124 return Expression::Nil;
125 },
126 };
127 expr.simplify(variables)
129 },
130
131 Expression::Assignment {
133 identifier: ref i,
134 value: ref v,
135 } => {
136 let simplified = (**v).simplify(variables);
138
139 variables.insert(i.to_owned(), simplified.to_owned());
141
142 simplified.to_owned()
144 }
145
146 Expression::BinOp {
148 left: l,
149 op: o,
150 right: r,
151 } => {
152 let left = l.simplify(variables);
154 let right = r.simplify(variables);
155
156 if let Expression::Int (l) = left {
157 if let Expression::Int (r) = right {
158 let f = binop(l as f64, r as f64, &o);
160 if f.fract() == 0.0 {
161 Expression::Int (f as i64)
162 } else {
163 Expression::Float (f)
164 }
165 } else if let Expression::Float (r) = right {
166 let left_float = l as f64;
167 Expression::Float (binop(left_float, r, &o))
168 } else if let Expression::Matrix {
169 rows: r,
170 cols: c,
171 values: v,
172 } = right {
173 let mut values = Vec::new();
174
175 for val in v {
176 values.push(Expression::BinOp {
177 left: Box::new(left.to_owned()),
178 op: "*".to_string(),
179 right: Box::new(val),
180 }.simplify(variables));
181 }
182
183 Expression::Matrix {
184 rows: r,
185 cols: c,
186 values,
187 }
188 } else {
189 throw(InvalidOperands);
190 return Expression::Nil;
191 }
192 } else if let Expression::Float (l) = left {
193 if let Expression::Int (r) = right {
194 let right_float = r as f64;
195 Expression::Float (binop(l, right_float, &o))
196 } else if let Expression::Float (r) = right {
197 Expression::Float (binop(l, r, &o))
198 } else {
199 throw(InvalidOperands);
200 return Expression::Nil;
201 }
202 } else if let Expression::Matrix {
203 rows: r,
204 cols: k1,
205 values: vl,
206 } = left {
207 if let Expression::Matrix {
208 rows: k2,
209 cols: c,
210 values: vr,
211 } = right {
212 if k1 != k2 {
213 throw(ImproperDimensions);
214 return Expression::Nil;
215 }
216
217 matrix_dot(vl, vr, r, c, k1)
218 } else {
219 throw(InvalidOperands);
220 return Expression::Nil;
221 }
222 } else {
223 throw(InvalidOperands);
224 return Expression::Nil;
225 }
226 },
227
228 Expression::Int (_) => self.to_owned(),
230
231 Expression::Float (f) => {
233 if f.fract() == 0.0 {
234 Expression::Int (*f as i64)
235 } else {
236 Expression::Float (*f)
237 }
238 },
239
240 Expression::Matrix {
242 rows: r,
243 cols: c,
244 values: v,
245 } => {
246 let mut new = Vec::new();
247
248 for val in v {
249 new.push(val.simplify(variables));
250 }
251
252 if *r == 1 && *c == 1 {
253 v[0].simplify(variables).to_owned()
254 } else {
255 Expression::Matrix {
256 rows: *r,
257 cols: *c,
258 values: new,
259 }
260 }
261 },
262
263 Expression::Call {
268 name: n,
269 args: a,
270 } => {
271 let mut args = Vec::<Matrix>::new();
273 for arg in a {
274 let simplified = arg.simplify(variables);
275 if let Expression::Matrix {
276 rows: r,
277 cols: c,
278 values: v,
279 } = simplified {
280 let mut values: Vec<f64> = Vec::new();
282 for value in v {
283 if let Self::Int (i) = value {
284 values.push(i as f64);
285 } else if let Self::Float (f) = value {
286 values.push(f);
287 } else {
288 throw(InvalidValue);
290 return Expression::Nil;
291 }
292 }
293 args.push(Matrix::new(r, c, values));
294 } else if let Expression::Int (i) = simplified {
295 args.push(Matrix::new(1, 1, vec![i as f64]));
297 } else if let Expression::Float (f) = simplified {
298 args.push(Matrix::new(1, 1, vec![f]));
300 } else {
301 throw(InvalidOperands);
303 }
304 }
305
306 let stdfn = get_std_function(n.to_owned());
307 let output_matrix = stdfn.eval(args);
308
309 let values = output_matrix.copy_vals().iter().map(|x| Self::Float (*x)).collect::<Vec<Self>>();
310
311 Self::Matrix {
312 rows: output_matrix.rows(),
313 cols: output_matrix.cols(),
314 values,
315 }.simplify(variables)
316 },
317
318 Expression::Nil => self.to_owned(),
320 }
321 }
322}
323
324
325pub fn binop(x: f64, y: f64, binop: &str) -> f64 {
327 match binop {
328 "+" => x + y,
329 "-" => x - y,
330 "*" => x * y,
331 "/" => {
332 if y == 0.0 {
333 throw(DividedByZero);
334 0.0
335 } else {
336 x / y
337 }
338 },
339 _ => {
340 throw(InvalidOperator);
341 0.0
342 },
343 }
344}
345
346
347pub fn matrix_dot(left: Vec<Expression>, right: Vec<Expression>, rows: usize, cols: usize, count: usize) -> Expression {
349 let mut values = Vec::new();
350 for i in 0..rows {
351 for j in 0..cols {
352 let mut cell = Expression::Int (0);
353 for k in 0..count {
354 let addend = Expression::BinOp {
356 left: Box::new(left[i*count + k].to_owned()),
357 right: Box::new(right[k*cols + j].to_owned()),
358 op: "*".to_string(),
359 };
360
361 cell = Expression::BinOp {
362 left: Box::new(cell),
363 right: Box::new(addend),
364 op: "+".to_string(),
365 };
366 }
367 values.push(cell.simplify(&mut HashMap::new()));
369 }
370 }
371
372 Expression::Matrix {
373 rows,
374 cols,
375 values,
376 }
377}