radiate_gp/regression/
accuracy.rs1use super::{DataSet, Loss};
2use crate::{Eval, EvalMut, Graph, GraphEvaluator, Op, Tree};
3use std::fmt::Debug;
4
5#[derive(Clone, Default)]
6pub struct Accuracy<'a> {
7 name: Option<String>,
8 data_set: Option<&'a DataSet<f32>>,
9 loss_fn: Option<Loss>,
10}
11
12impl<'a> Accuracy<'a> {
13 pub fn named(mut self, name: impl Into<String>) -> Self {
14 self.name = Some(name.into());
15 self
16 }
17
18 pub fn on(mut self, data_set: &'a DataSet<f32>) -> Self {
19 self.data_set = Some(data_set);
20 self
21 }
22
23 pub fn loss(mut self, loss_fn: Loss) -> Self {
24 self.loss_fn = Some(loss_fn);
25 self
26 }
27
28 pub fn calc(&self, eval: &mut impl EvalMut<[f32], Vec<f32>>) -> AccuracyResult {
29 let data_set = self
30 .data_set
31 .expect("DataSet reference must be provided for accuracy calculation");
32 let loss_fn = self
33 .loss_fn
34 .expect("Loss function must be provided for accuracy calculation");
35
36 self.calc_internal(eval, data_set, loss_fn)
37 }
38
39 pub fn calc_internal(
40 &self,
41 eval: &mut impl EvalMut<[f32], Vec<f32>>,
42 data_set: &DataSet<f32>,
43 loss_fn: Loss,
44 ) -> AccuracyResult {
45 let mut outputs = Vec::new();
46 let mut total_samples = 0.0;
47 let mut correct_predictions = 0.0;
48 let mut is_regression = true;
49
50 let mut mae = 0.0;
51 let mut mse = 0.0;
52 let mut min_output = f32::MAX;
53 let mut max_output = f32::MIN;
54 let mut ss_total = 0.0;
55 let mut ss_residual = 0.0;
56 let mut y_mean = 0.0;
57
58 let mut tp = 0.0;
59 let mut fp = 0.0;
60 let mut fn_ = 0.0;
61
62 let loss = loss_fn.calc(data_set, eval);
63
64 let total_values = data_set.len();
65 if total_values > 0 {
66 y_mean = data_set.iter().map(|row| row.output()[0]).sum::<f32>() / total_values as f32;
67 }
68
69 for row in data_set.iter() {
70 let output = eval.eval_mut(row.input());
71 outputs.push(output.clone());
72
73 if output.len() == 1 {
74 is_regression = true;
75 let y_true = row.output()[0];
76 let y_pred = output[0];
77
78 mae += (y_true - y_pred).abs();
79 mse += (y_true - y_pred).powi(2);
80 ss_residual += (y_true - y_pred).powi(2);
81 ss_total += (y_true - y_mean).powi(2);
82
83 min_output = min_output.min(y_true);
84 max_output = max_output.max(y_true);
85 total_samples += 1.0;
86 } else {
87 is_regression = false;
88 if let Some((max_idx, _)) = output
89 .iter()
90 .enumerate()
91 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
92 {
93 if let Some(target) = row.output().iter().position(|&x| x == 1.0) {
94 total_samples += 1.0;
95 if max_idx == target {
96 correct_predictions += 1.0;
97 tp += 1.0;
98 } else {
99 fp += 1.0;
100 }
101 } else {
102 fn_ += 1.0;
103 }
104 }
105 }
106 }
107
108 let accuracy = if is_regression {
110 if total_samples > 0.0 && (max_output - min_output) > 0.0 {
111 1.0 - (mae / total_samples) / (max_output - min_output)
112 } else {
113 0.0
114 }
115 } else if total_samples > 0.0 {
116 correct_predictions / total_samples
117 } else {
118 0.0
119 };
120
121 let (precision, recall, f1_score) = if is_regression {
123 (0.0, 0.0, 0.0) } else {
125 let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
126 let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
127 let f1_score = if precision + recall > 0.0 {
128 2.0 * (precision * recall) / (precision + recall)
129 } else {
130 0.0
131 };
132 (precision, recall, f1_score)
133 };
134
135 let rmse = if total_samples > 0.0 {
136 (mse / total_samples).sqrt()
137 } else {
138 0.0
139 };
140
141 let r_squared = if ss_total > 0.0 {
143 1.0 - (ss_residual / ss_total)
144 } else {
145 0.0
146 };
147
148 AccuracyResult {
149 name: match &self.name {
150 Some(name) => name.clone(),
151 None => {
152 if is_regression {
153 "Regression Accuracy".to_string()
154 } else {
155 "Classification Accuracy".to_string()
156 }
157 }
158 },
159 accuracy,
160 precision,
161 recall,
162 f1_score,
163 rmse,
164 r_squared,
165 loss,
166 loss_fn,
167 sample_count: data_set.len(),
168 is_regression,
169 }
170 }
171}
172
173pub struct AccuracyResult {
174 name: String,
175 accuracy: f32,
176 precision: f32, recall: f32, f1_score: f32, rmse: f32, r_squared: f32, sample_count: usize,
182 loss: f32,
183 loss_fn: Loss,
184 is_regression: bool,
185}
186
187impl AccuracyResult {
188 pub fn name(&self) -> &str {
189 &self.name
190 }
191
192 pub fn accuracy(&self) -> f32 {
193 self.accuracy
194 }
195
196 pub fn precision(&self) -> f32 {
197 self.precision
198 }
199
200 pub fn recall(&self) -> f32 {
201 self.recall
202 }
203
204 pub fn f1_score(&self) -> f32 {
205 self.f1_score
206 }
207
208 pub fn rmse(&self) -> f32 {
209 self.rmse
210 }
211
212 pub fn r_squared(&self) -> f32 {
213 self.r_squared
214 }
215
216 pub fn sample_count(&self) -> usize {
217 self.sample_count
218 }
219
220 pub fn loss(&self) -> f32 {
221 self.loss
222 }
223
224 pub fn loss_fn(&self) -> Loss {
225 self.loss_fn
226 }
227}
228
229impl Debug for AccuracyResult {
230 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231 if self.is_regression {
232 write!(
233 f,
234 "{:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tR² Score: {:.5}\n\tRMSE: {:.5}\n\tLoss ({:?}): {:.5}\n}}",
235 self.name,
236 self.sample_count,
237 self.accuracy * 100.0,
238 self.r_squared,
239 self.rmse,
240 self.loss_fn,
241 self.loss
242 )
243 } else {
244 write!(
245 f,
246 "{:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tPrecision: {:.2}%\n\tRecall: {:.2}%\n\tF1 Score: {:.2}%\n\tLoss ({:?}): {:.5}\n}}",
247 self.name,
248 self.sample_count,
249 self.accuracy * 100.0,
250 self.precision * 100.0,
251 self.recall * 100.0,
252 self.f1_score * 100.0,
253 self.loss_fn,
254 self.loss
255 )
256 }
257 }
258}
259
260impl Eval<Graph<Op<f32>>, Option<AccuracyResult>> for Accuracy<'_> {
261 fn eval(&self, graph: &Graph<Op<f32>>) -> Option<AccuracyResult> {
262 let mut evaluator = GraphEvaluator::new(graph);
263 Some(self.calc(&mut evaluator))
264 }
265}
266
267impl Eval<Tree<Op<f32>>, Option<AccuracyResult>> for Accuracy<'_> {
268 fn eval(&self, tree: &Tree<Op<f32>>) -> Option<AccuracyResult> {
269 Some(self.calc(&mut tree.clone()))
270 }
271}
272
273impl Eval<Vec<Tree<Op<f32>>>, Option<AccuracyResult>> for Accuracy<'_> {
274 fn eval(&self, trees: &Vec<Tree<Op<f32>>>) -> Option<AccuracyResult> {
275 let mut cloned_trees = trees.clone();
276 Some(self.calc(&mut cloned_trees))
277 }
278}