Skip to main content

radiate_gp/regression/
regression.rs

1use super::{DataSet, Loss};
2use crate::{Graph, GraphChromosome, GraphEvaluator, Op, Tree, TreeChromosome, eval::EvalIntoMut};
3use radiate_core::{BatchFitnessFunction, Genotype, fitness::FitnessFunction};
4use std::cell::RefCell;
5
thread_local! {
    /// Per-thread scratch storage whose prefix is handed to `Loss::calculate` as the
    /// mutable `f32` slice during `Regression` scoring.
    ///
    /// Being thread-local, each thread gets its own buffer; it is reused across calls so it
    /// does not have to be reallocated on every evaluation.
    static LOSS_BUFFER: RefCell<Vec<f32>> = RefCell::default();
}
9
10#[derive(Clone)]
11pub struct Regression {
12    data_set: DataSet<f32>,
13    loss: Loss,
14}
15
16impl Regression {
17    pub fn new(sample_set: impl Into<DataSet<f32>>, loss: Loss) -> Self {
18        Regression {
19            data_set: sample_set.into(),
20            loss,
21        }
22    }
23
24    #[inline]
25    fn calc_into_buff_mut<EV>(&self, eval: &mut EV) -> f32
26    where
27        EV: EvalIntoMut<[f32], [f32]>,
28    {
29        let out_len = self.data_set.shape().2;
30        LOSS_BUFFER.with(|cell| {
31            let mut buf = cell.borrow_mut();
32
33            if buf.len() < out_len {
34                buf.resize(out_len, 0.0);
35            }
36
37            self.loss
38                .calculate(&self.data_set, &mut buf[..out_len], |x, y| {
39                    eval.eval_into_mut(x, y)
40                })
41        })
42    }
43}
44
45/// --- Graphs ---
46impl<'a> FitnessFunction<&'a Genotype<GraphChromosome<Op<f32>>>, f32> for Regression {
47    #[inline]
48    fn evaluate(&self, input: &'a Genotype<GraphChromosome<Op<f32>>>) -> f32 {
49        let mut evaluator = GraphEvaluator::new(&input[0]);
50        self.calc_into_buff_mut(&mut evaluator)
51    }
52}
53
54impl FitnessFunction<Graph<Op<f32>>, f32> for Regression {
55    #[inline]
56    fn evaluate(&self, input: Graph<Op<f32>>) -> f32 {
57        let mut evaluator = GraphEvaluator::new(&input);
58        self.calc_into_buff_mut(&mut evaluator)
59    }
60}
61
62impl BatchFitnessFunction<Graph<Op<f32>>, f32> for Regression {
63    #[inline]
64    fn evaluate(&self, inputs: Vec<Graph<Op<f32>>>) -> Vec<f32> {
65        let mut results = Vec::with_capacity(inputs.len());
66        for input in inputs {
67            let mut evaluator = GraphEvaluator::new(&input);
68            results.push(self.calc_into_buff_mut(&mut evaluator));
69        }
70
71        results
72    }
73}
74
75impl<'a> BatchFitnessFunction<&'a Genotype<GraphChromosome<Op<f32>>>, f32> for Regression {
76    #[inline]
77    fn evaluate(&self, inputs: Vec<&'a Genotype<GraphChromosome<Op<f32>>>>) -> Vec<f32> {
78        let mut results = Vec::with_capacity(inputs.len());
79        for input in inputs {
80            let mut evaluator = GraphEvaluator::new(&input[0]);
81            results.push(self.calc_into_buff_mut(&mut evaluator));
82        }
83
84        results
85    }
86}
87
88/// --- Trees ---
89impl FitnessFunction<Tree<Op<f32>>, f32> for Regression {
90    #[inline]
91    fn evaluate(&self, mut input: Tree<Op<f32>>) -> f32 {
92        self.calc_into_buff_mut(&mut input)
93    }
94}
95
96impl FitnessFunction<Vec<Tree<Op<f32>>>, f32> for Regression {
97    #[inline]
98    fn evaluate(&self, mut input: Vec<Tree<Op<f32>>>) -> f32 {
99        self.calc_into_buff_mut(&mut input)
100    }
101}
102
103impl BatchFitnessFunction<Tree<Op<f32>>, f32> for Regression {
104    #[inline]
105    fn evaluate(&self, mut inputs: Vec<Tree<Op<f32>>>) -> Vec<f32> {
106        let mut results = Vec::with_capacity(inputs.len());
107        for input in inputs.iter_mut() {
108            results.push(self.calc_into_buff_mut(input));
109        }
110
111        results
112    }
113}
114
115impl BatchFitnessFunction<Vec<Tree<Op<f32>>>, f32> for Regression {
116    #[inline]
117    fn evaluate(&self, mut inputs: Vec<Vec<Tree<Op<f32>>>>) -> Vec<f32> {
118        let mut results = Vec::with_capacity(inputs.len());
119        for input in inputs.iter_mut() {
120            results.push(self.calc_into_buff_mut(input));
121        }
122
123        results
124    }
125}
126
127impl<'a> FitnessFunction<&'a Genotype<TreeChromosome<Op<f32>>>, f32> for Regression {
128    #[inline]
129    fn evaluate(&self, input: &'a Genotype<TreeChromosome<Op<f32>>>) -> f32 {
130        let roots = input.iter().map(|c| c.root()).collect::<Vec<_>>();
131        self.calc_into_buff_mut(&mut roots.as_slice())
132    }
133}
134
135impl<'a> BatchFitnessFunction<&'a Genotype<TreeChromosome<Op<f32>>>, f32> for Regression {
136    #[inline]
137    fn evaluate(&self, inputs: Vec<&'a Genotype<TreeChromosome<Op<f32>>>>) -> Vec<f32> {
138        let mut results = Vec::with_capacity(inputs.len());
139        for input in inputs {
140            let roots = input.iter().map(|c| c.root()).collect::<Vec<_>>();
141            results.push(self.calc_into_buff_mut(&mut roots.as_slice()));
142        }
143
144        results
145    }
146}