radiate_gp/regression/
regression.rs1use super::{DataSet, Loss};
2use crate::{Graph, GraphChromosome, GraphEvaluator, Op, Tree, TreeChromosome, eval::EvalIntoMut};
3use radiate_core::{BatchFitnessFunction, Genotype, fitness::FitnessFunction};
4use std::cell::RefCell;
5
thread_local! {
    /// Per-thread scratch buffer of `f32` values, created empty on first use
    /// by each thread and reused afterwards so callers need not allocate a
    /// fresh vector every time they need temporary output space.
    static LOSS_BUFFER: RefCell<Vec<f32>> = RefCell::new(Vec::new());
}
9
10#[derive(Clone)]
11pub struct Regression {
12 data_set: DataSet<f32>,
13 loss: Loss,
14}
15
16impl Regression {
17 pub fn new(sample_set: impl Into<DataSet<f32>>, loss: Loss) -> Self {
18 Regression {
19 data_set: sample_set.into(),
20 loss,
21 }
22 }
23
24 #[inline]
25 fn calc_into_buff_mut<EV>(&self, eval: &mut EV) -> f32
26 where
27 EV: EvalIntoMut<[f32], [f32]>,
28 {
29 let out_len = self.data_set.shape().2;
30 LOSS_BUFFER.with(|cell| {
31 let mut buf = cell.borrow_mut();
32
33 if buf.len() < out_len {
34 buf.resize(out_len, 0.0);
35 }
36
37 self.loss
38 .calculate(&self.data_set, &mut buf[..out_len], |x, y| {
39 eval.eval_into_mut(x, y)
40 })
41 })
42 }
43}
44
45impl<'a> FitnessFunction<&'a Genotype<GraphChromosome<Op<f32>>>, f32> for Regression {
47 #[inline]
48 fn evaluate(&self, input: &'a Genotype<GraphChromosome<Op<f32>>>) -> f32 {
49 let mut evaluator = GraphEvaluator::new(&input[0]);
50 self.calc_into_buff_mut(&mut evaluator)
51 }
52}
53
54impl FitnessFunction<Graph<Op<f32>>, f32> for Regression {
55 #[inline]
56 fn evaluate(&self, input: Graph<Op<f32>>) -> f32 {
57 let mut evaluator = GraphEvaluator::new(&input);
58 self.calc_into_buff_mut(&mut evaluator)
59 }
60}
61
62impl BatchFitnessFunction<Graph<Op<f32>>, f32> for Regression {
63 #[inline]
64 fn evaluate(&self, inputs: Vec<Graph<Op<f32>>>) -> Vec<f32> {
65 let mut results = Vec::with_capacity(inputs.len());
66 for input in inputs {
67 let mut evaluator = GraphEvaluator::new(&input);
68 results.push(self.calc_into_buff_mut(&mut evaluator));
69 }
70
71 results
72 }
73}
74
75impl<'a> BatchFitnessFunction<&'a Genotype<GraphChromosome<Op<f32>>>, f32> for Regression {
76 #[inline]
77 fn evaluate(&self, inputs: Vec<&'a Genotype<GraphChromosome<Op<f32>>>>) -> Vec<f32> {
78 let mut results = Vec::with_capacity(inputs.len());
79 for input in inputs {
80 let mut evaluator = GraphEvaluator::new(&input[0]);
81 results.push(self.calc_into_buff_mut(&mut evaluator));
82 }
83
84 results
85 }
86}
87
88impl FitnessFunction<Tree<Op<f32>>, f32> for Regression {
90 #[inline]
91 fn evaluate(&self, mut input: Tree<Op<f32>>) -> f32 {
92 self.calc_into_buff_mut(&mut input)
93 }
94}
95
96impl FitnessFunction<Vec<Tree<Op<f32>>>, f32> for Regression {
97 #[inline]
98 fn evaluate(&self, mut input: Vec<Tree<Op<f32>>>) -> f32 {
99 self.calc_into_buff_mut(&mut input)
100 }
101}
102
103impl BatchFitnessFunction<Tree<Op<f32>>, f32> for Regression {
104 #[inline]
105 fn evaluate(&self, mut inputs: Vec<Tree<Op<f32>>>) -> Vec<f32> {
106 let mut results = Vec::with_capacity(inputs.len());
107 for input in inputs.iter_mut() {
108 results.push(self.calc_into_buff_mut(input));
109 }
110
111 results
112 }
113}
114
115impl BatchFitnessFunction<Vec<Tree<Op<f32>>>, f32> for Regression {
116 #[inline]
117 fn evaluate(&self, mut inputs: Vec<Vec<Tree<Op<f32>>>>) -> Vec<f32> {
118 let mut results = Vec::with_capacity(inputs.len());
119 for input in inputs.iter_mut() {
120 results.push(self.calc_into_buff_mut(input));
121 }
122
123 results
124 }
125}
126
127impl<'a> FitnessFunction<&'a Genotype<TreeChromosome<Op<f32>>>, f32> for Regression {
128 #[inline]
129 fn evaluate(&self, input: &'a Genotype<TreeChromosome<Op<f32>>>) -> f32 {
130 let roots = input.iter().map(|c| c.root()).collect::<Vec<_>>();
131 self.calc_into_buff_mut(&mut roots.as_slice())
132 }
133}
134
135impl<'a> BatchFitnessFunction<&'a Genotype<TreeChromosome<Op<f32>>>, f32> for Regression {
136 #[inline]
137 fn evaluate(&self, inputs: Vec<&'a Genotype<TreeChromosome<Op<f32>>>>) -> Vec<f32> {
138 let mut results = Vec::with_capacity(inputs.len());
139 for input in inputs {
140 let roots = input.iter().map(|c| c.root()).collect::<Vec<_>>();
141 results.push(self.calc_into_buff_mut(&mut roots.as_slice()));
142 }
143
144 results
145 }
146}