1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
use std::sync::Arc;

use crate::{
    Elementary::{self, *},
    Error,
};

use super::classification::Category;
use super::polynomial;

impl Elementary {
    pub fn simplify(&self) -> Result<Self, Error> {
        let new_function: Self = match self.classify()? {
            Category::Constant => self.simplify_constant()?,
            Category::Polynomial => polynomial::simplify_polynomial(self.clone())?,
            Category::ClusterFuck => self.simplify_operations()?,
            _ => self.clone(),
        };

        self.check_simplification(&new_function)
    }

    // makes sure that the simplified funciton is correct, that is, it will yield the same result
    // upon calling for all numbers within its definition set.
    fn check_simplification(&self, new_function: &Self) -> Result<Self, Error> {
        for i in -1000..1000 {
            if !self.is_within_margin(new_function, i as f64) {
                return Err(Error::InternalError(String::from(
                    format!("while attempting to simplify {self:?}, the simplification method yielded inconsistent results. Found that self({i}) != new_function({i})." ))));
            }
        }
        Ok(new_function.to_owned())
    }

    fn is_within_margin(&self, other: &Self, point: f64) -> bool {
        let callable_self = self.clone().call();
        let callable_new = other.clone().call();

        // if the original function returns Nan, then the simplification may have gotten rid of the
        // issue
        if callable_self(point).is_nan() {
            return true;
        } else if callable_self(point).is_infinite() && callable_new(point).is_infinite() {
            return true;
        } else {
            // if not, the value must be within 1% of the original value + 1e-5
            let margin = (callable_self(point) * 0.01).abs() + 1e-5;

            callable_new(point) > (callable_self(point) - margin)
                && callable_new(point) < (callable_self(point) + margin)
        }
    }

    // used for functions of category ClusterFuck in order to break down and simplify each
    // individual funciton individually.
    pub fn simplify_operations(&self) -> Result<Self, Error> {
        match self {
            Mul(func1, func2) => Ok(func1.simplify()? * func2.simplify()?),
            Div(func1, func2) => Ok((func1.simplify()? / func2.simplify()?).divide()?),
            Add(func1, func2) => Ok(func1.simplify()? + func2.simplify()?),
            Sub(func1, func2) => Ok(func1.simplify()? - func2.simplify()?),
            Pow(func1, func2) => Self::simplify_power(func1, func2),
            Log(func1, func2) => Ok(Log(
                Arc::new(func1.simplify()?),
                Arc::new(func2.simplify()?),
            )),
            _ => Ok(self.to_owned()),
        }
    }

    pub fn simplify_power(base: &Self, exp: &Self) -> Result<Self, Error> {
        match exp.clone() {
            Con(numb) => {
                if numb == 0. {
                    Ok(Con(1.))
                } else if numb == 1. {
                    Ok(base.clone())
                } else {
                    match base {
                        X => Ok(Pow(base.clone().into(), exp.clone().into())),
                        Pow(inner_base, inner_exp) => Ok(Pow(
                            inner_base.clone(),
                            (exp.clone() * (**inner_exp).clone()).simplify()?.into(),
                        )),
                        _ => Ok(Pow(base.simplify()?.into(), exp.simplify()?.into())),
                    }
                }
            }
            _ => match base {
                X => Ok(Pow(base.clone().into(), exp.clone().into())),
                Pow(inner_base, inner_exp) => Ok(Pow(
                    inner_base.clone(),
                    (exp.clone() * (**inner_exp).clone()).simplify()?.into(),
                )),
                _ => Ok(Pow(base.simplify()?.into(), exp.simplify()?.into())),
            },
        }
    }

    pub fn divide(&self) -> Result<Self, Error> {
        if let Div(numerator, denomenator) = self {
            let numerator = numerator.factor()?;
            let denomenator = denomenator.factor()?;

            let mut removed_numerator: Vec<usize> = Vec::new();
            let mut removed_denomenator: Vec<usize> = Vec::new();

            let mut constant_factor = 1.;

            for i in 0..numerator.len() {
                for j in 0..denomenator.len() {
                    if numerator[i] == denomenator[j] {
                        removed_numerator.push(i);
                        removed_denomenator.push(j);
                    } else if let (Con(numb1), Con(numb2)) =
                        (numerator[i].clone(), denomenator[j].clone())
                    {
                        constant_factor *= numb1 / numb2;
                        removed_numerator.push(i);
                        removed_denomenator.push(j);
                    }
                }
            }

            let mut new_numerator = Con(constant_factor);
            for (i, term) in numerator.iter().enumerate() {
                if !removed_numerator.contains(&i) {
                    new_numerator *= term.clone();
                }
            }
            let mut new_denomenator = Con(1.);
            for (i, term) in denomenator.iter().enumerate() {
                if !removed_denomenator.contains(&i) {
                    new_denomenator *= term.clone();
                }
            }

            if new_denomenator == Con(1.) {
                Ok(new_numerator.simplify()?)
            } else {
                Ok(new_numerator / new_denomenator)
            }
        } else {
            Err(Error::SimplifyError(
                self.to_owned(),
                String::from("Attempted to divide a non-divisible expression while simplifying"),
            ))
        }
    }

    pub fn factor(&self) -> Result<Vec<Self>, Error> {
        let mut factors: Vec<Self> = Vec::new();
        if let Mul(func1, func2) = self {
            for factor in func1.factor()? {
                factors.push(factor);
            }
            for factor in func2.factor()? {
                factors.push(factor);
            }
        } else if let Add(func1, func2) = self {
            for f1 in func1.factor()? {
                for f2 in func2.factor()? {
                    if f1.clone() == f2.clone() {
                        factors.push(f1.clone());
                        factors.push(
                            (Div(func1.to_owned(), Arc::new(f1.clone())).divide()?
                                + Div(func2.to_owned(), Arc::new(f2.clone())).divide()?)
                            .simplify()?,
                        );
                    } else if let (Con(numb1), Con(numb2)) = (f1.clone(), f2.clone()) {
                        let gcd = Con(gcd(numb1, numb2));
                        factors.push(gcd.clone());
                        factors.push(
                            ((func1.clone() / gcd.clone()).divide()?)
                                + (func2.clone() / gcd).divide()?,
                        );
                    }
                }
            }
        } else {
            factors.push(self.to_owned());
        }

        let res: Vec<Self> = factors
            .iter()
            .filter(|e| **e != Con(1.))
            .map(|e| e.to_owned())
            .collect();

        Ok(res)
    }

    pub fn simplify_constant(&self) -> Result<Self, Error> {
        if self.classify()? == Category::Constant {
            let value = self.clone().call()(0.);
            Ok(Con(value))
        } else {
            Err(Error::SimplifyError(
                self.to_owned(),
                String::from("Attempted to constant-simplify a non-constant expression"),
            ))
        }
    }
}

/// Greatest common divisor of two floats, computed with the Euclidean algorithm using `%` as the
/// remainder operation.
///
/// The result is non-negative. `gcd(x, 0.)` is `|x|`, and `gcd(0., 0.)` is `0.`.
///
/// If either input is NaN or infinite, returns `1.` (the neutral factor). With such inputs the
/// remainder becomes NaN, and `NaN != 0.` would keep the loop running forever.
fn gcd(numb1: f64, numb2: f64) -> f64 {
    if !numb1.is_finite() || !numb2.is_finite() {
        return 1.;
    }

    let mut a = numb1;
    let mut b = numb2;
    while b != 0. {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a.abs()
}