// mathml_rs/methods/evaluate.rs

1use super::super::structs::constants::Constant;
2use super::super::structs::math_node::{MathNode, NodeIndex};
3use super::super::structs::numbers::{NumType, Number};
4use super::super::structs::op::Op;
5use math::round;
6use mathru::statistics::combins::factorial;
7use std::collections::HashMap;
8
9pub fn evaluate_node(
10    nodes: &[MathNode],
11    head_idx: NodeIndex,
12    values: &HashMap<String, f64>,
13    functions: &HashMap<String, Vec<MathNode>>,
14) -> Result<f64, String> {
15    let head = nodes[head_idx].clone();
16    //dbg!(values);
17    match head {
18        MathNode::Root(root) => {
19            if root.children.len() != 1 {
20                return Err("Root with multiple/zero children!".to_string());
21            }
22            evaluate_node(nodes, root.children[0], values, functions)
23        }
24        MathNode::Apply(apply) => {
25            let op_result = apply.get_op(nodes);
26            // If this is a regular mathematical operator, go ahead
27            if let Ok(op) = op_result {
28                match op {
29                    Op::Times => {
30                        let mut result = 1.0;
31                        for operand_idx in apply.operands {
32                            result *= evaluate_node(nodes, operand_idx, values, functions)?;
33                        }
34                        Ok(result)
35                    }
36                    Op::Plus => {
37                        let mut result = 0.0;
38                        for operand_idx in apply.operands {
39                            result += evaluate_node(nodes, operand_idx, values, functions)?;
40                        }
41                        Ok(result)
42                    }
43                    Op::Minus => {
44                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
45                        match apply.operands.len() {
46                            1 => Ok(-a),
47                            2 => {
48                                let b = evaluate_node(nodes, apply.operands[1], values, functions)?;
49                                Ok(a - b)
50                            }
51                            _ => Err("Too many operands!".to_string()),
52                        }
53                    }
54                    Op::Divide => {
55                        if apply.operands.len() != 2 {
56                            return Err("Invalid number of operands.".to_string());
57                        }
58                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
59                        let b = evaluate_node(nodes, apply.operands[1], values, functions)?;
60                        Ok(a / b)
61                    }
62                    Op::Power => {
63                        if apply.operands.len() != 2 {
64                            return Err("Invalid number of operands.".to_string());
65                        }
66                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
67                        let b = evaluate_node(nodes, apply.operands[1], values, functions)?;
68                        Ok(a.powf(b))
69                    }
70                    Op::Ceiling => {
71                        if apply.operands.len() != 1 {
72                            return Err("Invalid number of operands.".to_string());
73                        }
74                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
75                        Ok(round::ceil(a, 0))
76                    }
77                    Op::Floor => {
78                        if apply.operands.len() != 1 {
79                            return Err("Invalid number of operands.".to_string());
80                        }
81                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
82                        Ok(round::floor(a, 0))
83                    }
84                    Op::Factorial => {
85                        if apply.operands.len() != 1 {
86                            return Err("Invalid number of operands.".to_string());
87                        }
88                        let a = evaluate_node(nodes, apply.operands[0], values, functions)?;
89                        Ok(factorial(a as u32) as f64)
90                    }
91                    _ => Err("Evaluation not supported for operator.".to_string()),
92                }
93            } else {
94                // Evaluate as a lambda function
95                let mut res = None;
96                if let MathNode::Ci(ci) = &nodes[apply.operator.unwrap()] {
97                    let lambda_name = ci.name.as_ref().unwrap();
98                    if let Some(lambda) = functions.get(lambda_name) {
99                        let mut argument_values = Vec::new();
100                        for operand in apply.operands {
101                            argument_values.push(evaluate_node(nodes, operand, values, functions)?);
102                        }
103                        res = Some(evaluate_lambda(
104                            lambda,
105                            0,
106                            &argument_values,
107                            values,
108                            functions,
109                        )?);
110                        //dbg!(lambda_name, res);
111                    }
112                }
113                if let Some(value) = res {
114                    Ok(value)
115                } else {
116                    Err("Invalid operator".to_string())
117                }
118            }
119        }
120        MathNode::Cn(cn) => match &cn.r#type {
121            Some(NumType::Integer) => {
122                if let Some(Number::Integer(i)) = cn.value {
123                    let result = i.into();
124                    //println!("Returning {} from cn", result);
125                    Ok(result)
126                } else {
127                    Err("Wrong type".to_string())
128                }
129            }
130            Some(NumType::Real) | None => {
131                if let Some(Number::Real(r)) = cn.value {
132                    Ok(r)
133                } else {
134                    Err("Wrong type".to_string())
135                }
136            }
137            Some(NumType::Rational) => {
138                if let Some(Number::Rational(x, y)) = cn.value {
139                    Ok((x as f64) / (y as f64))
140                } else {
141                    Err("Wrong type".to_string())
142                }
143            }
144            Some(NumType::ENotation) => {
145                if let Some(Number::ENotation(x, y)) = cn.value {
146                    Ok(x * 10.0_f64.powf(y as f64))
147                } else {
148                    Err("Wrong type".to_string())
149                }
150            }
151            _ => Err("Invalid Cn type".to_string()),
152        },
153        MathNode::Ci(ci) => {
154            let name = ci.name.expect("Ci element with no content!");
155            if values.contains_key(&name) {
156                let result = *values.get(&name).unwrap();
157                //println!("Returning {} from ci", result);
158                Ok(result)
159            } else {
160                let error = format!("No value found for Ci {}", name);
161                Err(error)
162            }
163        }
164        MathNode::Piecewise(..) => Ok(evaluate_piecewise(nodes, head_idx, values, functions)?),
165        _ => {
166            let error = format!("Couldn't evaluate operator {}", head);
167            Err(error)
168        }
169    }
170}
171
172pub fn evaluate_lambda(
173    nodes: &[MathNode],
174    head_idx: NodeIndex,
175    argument_values: &[f64],
176    values: &HashMap<String, f64>,
177    functions: &HashMap<String, Vec<MathNode>>,
178) -> Result<f64, String> {
179    let head = nodes[head_idx].clone();
180    match head {
181        MathNode::Root(root) => {
182            if root.children.len() != 1 {
183                return Err("Root with multiple/zero children!".to_string());
184            }
185            evaluate_lambda(nodes, root.children[0], argument_values, values, functions)
186        }
187        MathNode::Lambda(lambda) => {
188            let mut argument_names = Vec::new();
189            for binding in lambda.bindings {
190                if let MathNode::BVar(bvar) = nodes[binding].clone() {
191                    for child in bvar.children {
192                        if let MathNode::Ci(ci) = nodes[child].clone() {
193                            argument_names.push(ci.name.unwrap());
194                        }
195                    }
196                }
197            }
198
199            if argument_values.len() != argument_names.len() {
200                Err("Argument names and values mismatch".to_string())
201            } else {
202                let mut assignments: HashMap<String, f64> = HashMap::new();
203                for i in 0..argument_values.len() {
204                    assignments.insert(argument_names[i].clone(), argument_values[i]);
205                }
206                let res = evaluate_node(nodes, lambda.expr.unwrap(), &assignments, functions)?;
207                Ok(res)
208            }
209        }
210        _ => evaluate_node(nodes, head_idx, values, functions),
211    }
212}
213
214pub fn evaluate_piecewise(
215    nodes: &[MathNode],
216    head_idx: NodeIndex,
217    values: &HashMap<String, f64>,
218    functions: &HashMap<String, Vec<MathNode>>,
219) -> Result<f64, String> {
220    let head = nodes[head_idx].clone();
221    match head {
222        MathNode::Piecewise(piecewise) => {
223            let pieces_idx = piecewise.pieces;
224            let otherwise_idx = piecewise.otherwise;
225            let mut result = None;
226            for piece_idx in pieces_idx {
227                let (condition, value) = evaluate_piece(nodes, piece_idx, values, functions)?;
228                if condition {
229                    if let Some(..) = value {
230                        result = value;
231                        break;
232                    }
233                }
234            }
235            if let Some(value) = result {
236                Ok(value)
237            } else if let Some(otherwise_idx_value) = otherwise_idx {
238                Ok(evaluate_piecewise(
239                    nodes,
240                    otherwise_idx_value,
241                    values,
242                    functions,
243                )?)
244            } else {
245                Err("All pieces evaluated to false and no otherwise branch found.".to_string())
246            }
247        }
248        MathNode::Otherwise(otherwise) => {
249            let expr_idx = otherwise.expr.expect("Otherwise branch is empty!");
250            Ok(evaluate_node(nodes, expr_idx, values, functions)?)
251        }
252        _ => {
253            //dbg!(head);
254            Err("haha couldn't parse".to_string())
255        }
256    }
257}
258
259pub fn evaluate_piece(
260    nodes: &[MathNode],
261    head_idx: NodeIndex,
262    values: &HashMap<String, f64>,
263    functions: &HashMap<String, Vec<MathNode>>,
264) -> Result<(bool, Option<f64>), String> {
265    let head = nodes[head_idx].clone();
266    match head {
267        MathNode::Piece(piece) => {
268            let expr_idx = piece.expr.expect("Piece has no expression!");
269            let condition_idx = piece.condition.expect("Piece condition is empty!");
270            let condition_result = evaluate_condition(nodes, condition_idx, values, functions)?;
271            if condition_result {
272                let expr_result = evaluate_node(nodes, expr_idx, values, functions)?;
273                Ok((true, Some(expr_result)))
274            } else {
275                Ok((false, None))
276            }
277        }
278        _ => {
279            //dbg!(head);
280            Err("haha couldn't parse".to_string())
281        }
282    }
283}
284
285pub fn evaluate_condition(
286    nodes: &[MathNode],
287    head_idx: NodeIndex,
288    values: &HashMap<String, f64>,
289    functions: &HashMap<String, Vec<MathNode>>,
290) -> Result<bool, String> {
291    let head = nodes[head_idx].clone();
292    match head {
293        MathNode::Constant(constantnode) => {
294            if let Some(constant) = constantnode.constant {
295                match constant {
296                    Constant::False => Ok(false),
297                    Constant::True => Ok(true),
298                    _ => Err("haha".to_string()),
299                }
300            } else {
301                Err("hh".to_string())
302            }
303        }
304        MathNode::Apply(apply) => {
305            let op_result = apply.get_op(nodes);
306            let mut result = None;
307            // If this is a regular mathematical operator, go ahead
308            if let Ok(op) = op_result {
309                let mut operand_results = Vec::<f64>::new();
310                let mut child_condition_results = Vec::<bool>::new();
311                match op {
312                    Op::Eq | Op::Neq | Op::Geq | Op::Leq | Op::Gt | Op::Lt => {
313                        if apply.operands.len() != 2 {
314                            return Err("Invalid number of operands.".to_string());
315                        }
316                        for operand_location in apply.operands {
317                            operand_results.push(evaluate_node(
318                                nodes,
319                                operand_location,
320                                values,
321                                functions,
322                            )?);
323                        }
324                    }
325                    Op::And | Op::Or | Op::Xor => {
326                        for operand_location in apply.operands {
327                            child_condition_results.push(evaluate_condition(
328                                nodes,
329                                operand_location,
330                                values,
331                                functions,
332                            )?);
333                        }
334                    }
335                    _ => {}
336                }
337
338                let condition_count = child_condition_results.len();
339                let true_count = child_condition_results.iter().filter(|x| **x).count();
340                match op {
341                    Op::Eq => {
342                        result =
343                            Some((operand_results[0] - operand_results[1]).abs() <= f64::EPSILON)
344                    }
345                    Op::Neq => {
346                        result =
347                            Some((operand_results[0] - operand_results[1]).abs() > f64::EPSILON)
348                    }
349                    Op::Gt => result = Some(operand_results[0] > operand_results[1]),
350                    Op::Lt => result = Some(operand_results[0] < operand_results[1]),
351                    Op::Geq => result = Some(operand_results[0] >= operand_results[1]),
352                    Op::Leq => result = Some(operand_results[0] <= operand_results[1]),
353                    Op::And => result = Some(condition_count == true_count),
354                    Op::Or => result = Some(true_count > 0),
355                    Op::Xor => {
356                        result = Some(true_count % 2 == 1);
357                        //dbg!(child_condition_results);
358                        //if result.unwrap() {
359                        //dbg!(result.unwrap());
360                        //}
361                    }
362                    _ => {}
363                }
364            }
365            if let Some(value) = result {
366                Ok(value)
367            } else {
368                Err("Condition not supported".to_string())
369            }
370        }
371        _ => {
372            let error = format!("Couldn't evaluate operator {}", head);
373            Err(error)
374        }
375    }
376}