number_diff/simplify/operations.rs

1use std::sync::Arc;
2
3use crate::{
4    Elementary::{self, *},
5    Error,
6};
7
8use super::classification::Category;
9use super::polynomial;
10
11impl Elementary {
12    pub fn simplify(&self) -> Result<Self, Error> {
13        let new_function: Self = match self.classify()? {
14            Category::Constant => self.simplify_constant()?,
15            Category::Polynomial => polynomial::simplify_polynomial(self.clone())?,
16            Category::ClusterFuck => self.simplify_operations()?,
17            _ => self.clone(),
18        };
19
20        self.check_simplification(&new_function)
21    }
22
23    // makes sure that the simplified funciton is correct, that is, it will yield the same result
24    // upon calling for all numbers within its definition set.
25    fn check_simplification(&self, new_function: &Self) -> Result<Self, Error> {
26        for i in -1000..1000 {
27            if !self.is_within_margin(new_function, i as f64) {
28                return Err(Error::InternalError(String::from(
29                    format!("while attempting to simplify {self:?}, the simplification method yielded inconsistent results. Found that self({i}) != new_function({i})." ))));
30            }
31        }
32        Ok(new_function.to_owned())
33    }
34
35    fn is_within_margin(&self, other: &Self, point: f64) -> bool {
36        let callable_self = self.clone().call();
37        let callable_new = other.clone().call();
38
39        // if the original function returns Nan, then the simplification may have gotten rid of the
40        // issue
41        if callable_self(point).is_nan() {
42            return true;
43        } else if callable_self(point).is_infinite() && callable_new(point).is_infinite() {
44            return true;
45        } else {
46            // if not, the value must be within 1% of the original value + 1e-5
47            let margin = (callable_self(point) * 0.01).abs() + 1e-5;
48
49            callable_new(point) > (callable_self(point) - margin)
50                && callable_new(point) < (callable_self(point) + margin)
51        }
52    }
53
54    // used for functions of category ClusterFuck in order to break down and simplify each
55    // individual funciton individually.
56    pub fn simplify_operations(&self) -> Result<Self, Error> {
57        match self {
58            Mul(func1, func2) => Ok(func1.simplify()? * func2.simplify()?),
59            Div(func1, func2) => Ok((func1.simplify()? / func2.simplify()?).divide()?),
60            Add(func1, func2) => Ok(func1.simplify()? + func2.simplify()?),
61            Sub(func1, func2) => Ok(func1.simplify()? - func2.simplify()?),
62            Pow(func1, func2) => Self::simplify_power(func1, func2),
63            Log(func1, func2) => Ok(Log(
64                Arc::new(func1.simplify()?),
65                Arc::new(func2.simplify()?),
66            )),
67            _ => Ok(self.to_owned()),
68        }
69    }
70
71    pub fn simplify_power(base: &Self, exp: &Self) -> Result<Self, Error> {
72        match exp.clone() {
73            Con(numb) => {
74                if numb == 0. {
75                    Ok(Con(1.))
76                } else if numb == 1. {
77                    Ok(base.clone())
78                } else {
79                    match base {
80                        X => Ok(Pow(base.clone().into(), exp.clone().into())),
81                        Pow(inner_base, inner_exp) => Ok(Pow(
82                            inner_base.clone(),
83                            (exp.clone() * (**inner_exp).clone()).simplify()?.into(),
84                        )),
85                        _ => Ok(Pow(base.simplify()?.into(), exp.simplify()?.into())),
86                    }
87                }
88            }
89            _ => match base {
90                X => Ok(Pow(base.clone().into(), exp.clone().into())),
91                Pow(inner_base, inner_exp) => Ok(Pow(
92                    inner_base.clone(),
93                    (exp.clone() * (**inner_exp).clone()).simplify()?.into(),
94                )),
95                _ => Ok(Pow(base.simplify()?.into(), exp.simplify()?.into())),
96            },
97        }
98    }
99
100    pub fn divide(&self) -> Result<Self, Error> {
101        if let Div(numerator, denomenator) = self {
102            let numerator = numerator.factor()?;
103            let denomenator = denomenator.factor()?;
104
105            let mut removed_numerator: Vec<usize> = Vec::new();
106            let mut removed_denomenator: Vec<usize> = Vec::new();
107
108            let mut constant_factor = 1.;
109
110            for i in 0..numerator.len() {
111                for j in 0..denomenator.len() {
112                    if numerator[i] == denomenator[j] {
113                        removed_numerator.push(i);
114                        removed_denomenator.push(j);
115                    } else if let (Con(numb1), Con(numb2)) =
116                        (numerator[i].clone(), denomenator[j].clone())
117                    {
118                        constant_factor *= numb1 / numb2;
119                        removed_numerator.push(i);
120                        removed_denomenator.push(j);
121                    }
122                }
123            }
124
125            let mut new_numerator = Con(constant_factor);
126            for (i, term) in numerator.iter().enumerate() {
127                if !removed_numerator.contains(&i) {
128                    new_numerator *= term.clone();
129                }
130            }
131            let mut new_denomenator = Con(1.);
132            for (i, term) in denomenator.iter().enumerate() {
133                if !removed_denomenator.contains(&i) {
134                    new_denomenator *= term.clone();
135                }
136            }
137
138            if new_denomenator == Con(1.) {
139                Ok(new_numerator.simplify()?)
140            } else {
141                Ok(new_numerator / new_denomenator)
142            }
143        } else {
144            Err(Error::SimplifyError(
145                self.to_owned(),
146                String::from("Attempted to divide a non-divisible expression while simplifying"),
147            ))
148        }
149    }
150
151    pub fn factor(&self) -> Result<Vec<Self>, Error> {
152        let mut factors: Vec<Self> = Vec::new();
153        if let Mul(func1, func2) = self {
154            for factor in func1.factor()? {
155                factors.push(factor);
156            }
157            for factor in func2.factor()? {
158                factors.push(factor);
159            }
160        } else if let Add(func1, func2) = self {
161            for f1 in func1.factor()? {
162                for f2 in func2.factor()? {
163                    if f1.clone() == f2.clone() {
164                        factors.push(f1.clone());
165                        factors.push(
166                            (Div(func1.to_owned(), Arc::new(f1.clone())).divide()?
167                                + Div(func2.to_owned(), Arc::new(f2.clone())).divide()?)
168                            .simplify()?,
169                        );
170                    } else if let (Con(numb1), Con(numb2)) = (f1.clone(), f2.clone()) {
171                        let gcd = Con(gcd(numb1, numb2));
172                        factors.push(gcd.clone());
173                        factors.push(
174                            ((func1.clone() / gcd.clone()).divide()?)
175                                + (func2.clone() / gcd).divide()?,
176                        );
177                    }
178                }
179            }
180        } else {
181            factors.push(self.to_owned());
182        }
183
184        let res: Vec<Self> = factors
185            .iter()
186            .filter(|e| **e != Con(1.))
187            .map(|e| e.to_owned())
188            .collect();
189
190        Ok(res)
191    }
192
193    pub fn simplify_constant(&self) -> Result<Self, Error> {
194        if self.classify()? == Category::Constant {
195            let value = self.clone().call()(0.);
196            Ok(Con(value))
197        } else {
198            Err(Error::SimplifyError(
199                self.to_owned(),
200                String::from("Attempted to constant-simplify a non-constant expression"),
201            ))
202        }
203    }
204}
205
/// Greatest common divisor of two floating-point numbers, via the Euclidean algorithm.
///
/// The loop uses the floating-point remainder `%`, so:
/// - the sign of the result follows the remainder chain and can be negative
///   (e.g. `gcd(-3.0, -2.0) == -1.0`);
/// - non-integer inputs give a floating-point "gcd" (e.g. `gcd(0.5, 1.0) == 0.5`);
/// - `gcd(x, 0.0) == x`.
///
/// Returns `1.0` (no usable common factor) when either input is NaN or infinite. For such
/// inputs the remainder never reaches zero (`x % NaN` and `inf % y` are NaN), so the loop
/// would not terminate.
fn gcd(numb1: f64, numb2: f64) -> f64 {
    if !numb1.is_finite() || !numb2.is_finite() {
        return 1.;
    }

    let mut a = numb1;
    let mut b = numb2;
    while b != 0. {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a
}