symbolic_regression/
symbolic_regression.rs

1extern crate rand;
2extern crate evco;
3
4use std::fmt;
5use rand::{OsRng, Rng};
6use std::cmp::Ordering;
7use std::collections::BinaryHeap;
8
9use evco::gp::*;
10use evco::gp::tree::*;
11
/// A node of an arithmetic expression tree over a single input variable.
///
/// Operator variants hold their operands as boxed subtrees; `Int` and `Input`
/// carry no subtrees and are the only variants produced by `leaf`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Equation {
    /// Sum of the two subtrees.
    Add(BoxTree<Equation>, BoxTree<Equation>),
    /// First subtree minus the second.
    Sub(BoxTree<Equation>, BoxTree<Equation>),
    /// Product of the two subtrees.
    Mul(BoxTree<Equation>, BoxTree<Equation>),
    /// First subtree divided by the second, evaluated through `protected_div`.
    Div(BoxTree<Equation>, BoxTree<Equation>),
    /// Arithmetic negation of the subtree.
    Neg(BoxTree<Equation>),
    /// Sine of the subtree's value.
    Sin(BoxTree<Equation>),
    /// Cosine of the subtree's value.
    Cos(BoxTree<Equation>),
    /// An integer constant; converted to `f64` when evaluated.
    Int(i64),
    /// The input variable (displayed as `x`); evaluates to the environment value.
    Input,
}
24
25use Equation::*;
26
27impl Tree for Equation {
28    type Environment = f64;
29    type Action = f64;
30
31    fn branch<R: Rng>(tg: &mut TreeGen<R>, current_depth: usize) -> BoxTree<Self> {
32        let left = Self::child(tg, current_depth + 1);
33        let right = Self::child(tg, current_depth + 1);
34        match tg.gen_range(0, 7) {
35                0 => Add(left, right),
36                1 => Sub(left, right),
37                2 => Mul(left, right),
38                3 => Div(left, right),
39                4 => Neg(left),
40                5 => Sin(left),
41                6 => Cos(left),
42                _ => unreachable!(),
43            }
44            .into()
45    }
46
47    fn leaf<R: Rng>(tg: &mut TreeGen<R>, _: usize) -> BoxTree<Self> {
48        match tg.gen_range(0, 2) {
49                0 => Int(tg.gen_range(-1, 2)),
50                1 => Input,
51                _ => unreachable!(),
52            }
53            .into()
54    }
55
56    fn count_children(&mut self) -> usize {
57        match *self {
58            Int(_) => 0,
59            _ => 2,
60        }
61    }
62
63    fn children(&self) -> Vec<&BoxTree<Self>> {
64        match *self {
65            Add(ref left, ref right) |
66            Sub(ref left, ref right) |
67            Mul(ref left, ref right) |
68            Div(ref left, ref right) => vec![left, right],
69            Neg(ref left) | Sin(ref left) | Cos(ref left) => vec![left],
70            Int(_) | Input => vec![],
71        }
72    }
73
74    fn children_mut(&mut self) -> Vec<&mut BoxTree<Self>> {
75        match *self {
76            Add(ref mut left, ref mut right) |
77            Sub(ref mut left, ref mut right) |
78            Mul(ref mut left, ref mut right) |
79            Div(ref mut left, ref mut right) => vec![left, right],
80            Neg(ref mut left) |
81            Sin(ref mut left) |
82            Cos(ref mut left) => vec![left],
83            Int(_) | Input => vec![],
84        }
85    }
86
87    fn evaluate(&self, env: &Self::Environment) -> Self::Action {
88        match *self {
89            Add(ref left, ref right) => left.evaluate(env) + right.evaluate(env),
90            Sub(ref left, ref right) => left.evaluate(env) - right.evaluate(env),
91            Mul(ref left, ref right) => left.evaluate(env) * right.evaluate(env),
92            Div(ref left, ref right) => protected_div(left.evaluate(env), right.evaluate(env)),
93            Neg(ref left) => -left.evaluate(env),
94            Sin(ref left) => left.evaluate(env).sin(),
95            Cos(ref left) => left.evaluate(env).cos(),
96            Int(i) => i as f64,
97            Input => *env,
98        }
99    }
100}
101
102impl fmt::Display for Equation {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        match *self {
105            Add(ref left, ref right) => write!(f, "{} + {}", left, right),
106            Sub(ref left, ref right) => write!(f, "{} - {}", left, right),
107            Mul(ref left, ref right) => write!(f, "({}) * ({})", left, right),
108            Div(ref left, ref right) => write!(f, "({}) / ({})", left, right),
109            Neg(ref left) => write!(f, "-{}", left),
110            Sin(ref left) => write!(f, "sin({})", left),
111            Cos(ref left) => write!(f, "cos({})", left),
112            Int(i) => write!(f, "({})", i),
113            Input => write!(f, "x"),
114        }
115    }
116}
// Example: `Add(Int(1), Div(Input, Input))` represents 1 + x/x
// and is displayed as `(1) + (x) / (x)`.
121
/// Divides `numerator` by `denominator`, substituting `1.0` whenever the
/// quotient is not finite (infinite or NaN, e.g. on division by zero).
fn protected_div(numerator: f64, denominator: f64) -> f64 {
    match numerator / denominator {
        quotient if quotient.is_finite() => quotient,
        _ => 1.0,
    }
}
126
127#[derive(Debug, Clone)]
128struct RankedIndividual(f64, Individual<Equation>);
129
130impl Ord for RankedIndividual {
131    fn cmp(&self, other: &RankedIndividual) -> Ordering {
132        self.0.partial_cmp(&other.0).unwrap_or(Ordering::Greater)
133    }
134}
135
136impl PartialOrd for RankedIndividual {
137    fn partial_cmp(&self, other: &RankedIndividual) -> Option<std::cmp::Ordering> {
138        Some(self.cmp(other))
139    }
140}
141
142impl PartialEq for RankedIndividual {
143    fn eq(&self, other: &RankedIndividual) -> bool {
144        self.0 == other.0
145    }
146}
147
148impl Eq for RankedIndividual {}
149
150fn main() {
151    let mut rng = OsRng::new().unwrap();
152    let mut tree_gen = TreeGen::full(&mut rng, 1, 4);
153
154    let mut rng = OsRng::new().unwrap();
155    let crossover = Crossover::one_point();
156
157    let mut mutate_rng = OsRng::new().unwrap();
158    let mut mut_tree_gen = TreeGen::full(&mut mutate_rng, 1, 2);
159    let mutation = Mutation::uniform();
160
161    let inputs: Vec<f64> = (-10..11).map(|i| (i as f64) / 10.0).collect();
162    let expecteds: Vec<f64> = inputs.iter()
163        .cloned()
164        .map(|i| i.powi(4) + i.powi(3) + i.powi(2) + i)
165        .collect();
166
167    let mut population: Vec<Individual<Equation>> =
168        (0..200).map(|_| Individual::new(&mut tree_gen)).collect();
169    for round in 0..40 {
170        let mut ranking = BinaryHeap::new();
171        for individual in population.drain(..) {
172            let mut sum_of_squared_errors = 0.0;
173            for i in 0..inputs.len() {
174                let input = inputs[i];
175                let expected = expecteds[i];
176                let output = individual.tree.evaluate(&input);
177                let squared_error = (output - expected).powi(2);
178                sum_of_squared_errors += squared_error;
179            }
180            if !sum_of_squared_errors.is_finite() {
181                sum_of_squared_errors = 100000000000.0;
182            }
183            ranking.push(RankedIndividual(sum_of_squared_errors, individual));
184        }
185
186        let ranking = ranking.into_sorted_vec();
187        //println!("{:?}", ranking);
188
189        println!("=== ROUND {} ===", round);
190        for i in 0..3 {
191            println!("Rank {:?}\n  Range = [-1.0, 1.0]    Step = +0.1\n  Comparing to x^4 + x^3 \
192                      + x^2 + x\n  Sum of squared error = {}\n  Equation = {}",
193                     i,
194                     ranking[i].0,
195                     ranking[i].1);
196        }
197
198        for i in 0..100 {
199            let RankedIndividual(_, mut indv1) = ranking[i].clone();
200            let RankedIndividual(_, mut indv2) = ranking[i + 1].clone();
201
202            population.push(indv1.clone());
203            population.push(indv2.clone());
204
205            crossover.mate(&mut indv1, &mut indv2, &mut rng);
206
207            if rng.gen() {
208                mutation.mutate(&mut indv1, &mut mut_tree_gen);
209            }
210            if rng.gen() {
211                mutation.mutate(&mut indv2, &mut mut_tree_gen);
212            }
213
214            population.push(indv1);
215            population.push(indv2);
216        }
217
218        println!();
219    }
220}