//! radiate_gp/regression/accuracy.rs
//!
//! Accuracy metrics for models evaluated against a `DataSet`: regression
//! metrics (normalized-MAE accuracy, RMSE, R²) for single-output models and
//! classification metrics (accuracy, precision, recall, F1) for multi-output,
//! one-hot-target models.
use super::{DataSet, Loss};
use crate::{Eval, EvalMut, Graph, GraphEvaluator, Op, Tree};
use std::fmt::Debug;
4
/// Builder for computing accuracy metrics of a model against a `DataSet`.
///
/// Configure with `named(..)`, `on(..)`, and `loss(..)`, then call `calc(..)`
/// with any evaluator implementing `EvalMut<[f32], Vec<f32>>`.
#[derive(Clone, Default)]
pub struct Accuracy<'a> {
    // Optional display name reported in the resulting `AccuracyResult`.
    name: Option<String>,
    // Borrowed data set the metrics are computed over; required before `calc`.
    data_set: Option<&'a DataSet<f32>>,
    // Loss function evaluated alongside the accuracy metrics; required before `calc`.
    loss_fn: Option<Loss>,
}
11
12impl<'a> Accuracy<'a> {
13    pub fn named(mut self, name: impl Into<String>) -> Self {
14        self.name = Some(name.into());
15        self
16    }
17
18    pub fn on(mut self, data_set: &'a DataSet<f32>) -> Self {
19        self.data_set = Some(data_set);
20        self
21    }
22
23    pub fn loss(mut self, loss_fn: Loss) -> Self {
24        self.loss_fn = Some(loss_fn);
25        self
26    }
27
28    pub fn calc(&self, eval: &mut impl EvalMut<[f32], Vec<f32>>) -> AccuracyResult {
29        let data_set = self
30            .data_set
31            .expect("DataSet reference must be provided for accuracy calculation");
32        let loss_fn = self
33            .loss_fn
34            .expect("Loss function must be provided for accuracy calculation");
35
36        self.calc_internal(eval, data_set, loss_fn)
37    }
38
39    pub fn calc_internal(
40        &self,
41        eval: &mut impl EvalMut<[f32], Vec<f32>>,
42        data_set: &DataSet<f32>,
43        loss_fn: Loss,
44    ) -> AccuracyResult {
45        let mut outputs = Vec::new();
46        let mut total_samples = 0.0;
47        let mut correct_predictions = 0.0;
48        let mut is_regression = true;
49
50        let mut mae = 0.0;
51        let mut mse = 0.0;
52        let mut min_output = f32::MAX;
53        let mut max_output = f32::MIN;
54        let mut ss_total = 0.0;
55        let mut ss_residual = 0.0;
56        let mut y_mean = 0.0;
57
58        let mut tp = 0.0;
59        let mut fp = 0.0;
60        let mut fn_ = 0.0;
61
62        let loss = loss_fn.calc(data_set, eval);
63
64        let total_values = data_set.len();
65        if total_values > 0 {
66            y_mean = data_set.iter().map(|row| row.output()[0]).sum::<f32>() / total_values as f32;
67        }
68
69        for row in data_set.iter() {
70            let output = eval.eval_mut(row.input());
71            outputs.push(output.clone());
72
73            if output.len() == 1 {
74                is_regression = true;
75                let y_true = row.output()[0];
76                let y_pred = output[0];
77
78                mae += (y_true - y_pred).abs();
79                mse += (y_true - y_pred).powi(2);
80                ss_residual += (y_true - y_pred).powi(2);
81                ss_total += (y_true - y_mean).powi(2);
82
83                min_output = min_output.min(y_true);
84                max_output = max_output.max(y_true);
85                total_samples += 1.0;
86            } else {
87                is_regression = false;
88                if let Some((max_idx, _)) = output
89                    .iter()
90                    .enumerate()
91                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
92                {
93                    if let Some(target) = row.output().iter().position(|&x| x == 1.0) {
94                        total_samples += 1.0;
95                        if max_idx == target {
96                            correct_predictions += 1.0;
97                            tp += 1.0;
98                        } else {
99                            fp += 1.0;
100                        }
101                    } else {
102                        fn_ += 1.0;
103                    }
104                }
105            }
106        }
107
108        // Compute final accuracy
109        let accuracy = if is_regression {
110            if total_samples > 0.0 && (max_output - min_output) > 0.0 {
111                1.0 - (mae / total_samples) / (max_output - min_output)
112            } else {
113                0.0
114            }
115        } else if total_samples > 0.0 {
116            correct_predictions / total_samples
117        } else {
118            0.0
119        };
120
121        // Compute classification metrics only if it's a classification task
122        let (precision, recall, f1_score) = if is_regression {
123            (0.0, 0.0, 0.0) // Not applicable for regression
124        } else {
125            let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
126            let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
127            let f1_score = if precision + recall > 0.0 {
128                2.0 * (precision * recall) / (precision + recall)
129            } else {
130                0.0
131            };
132            (precision, recall, f1_score)
133        };
134
135        let rmse = if total_samples > 0.0 {
136            (mse / total_samples).sqrt()
137        } else {
138            0.0
139        };
140
141        // Compute R² score
142        let r_squared = if ss_total > 0.0 {
143            1.0 - (ss_residual / ss_total)
144        } else {
145            0.0
146        };
147
148        AccuracyResult {
149            name: match &self.name {
150                Some(name) => name.clone(),
151                None => {
152                    if is_regression {
153                        "Regression Accuracy".to_string()
154                    } else {
155                        "Classification Accuracy".to_string()
156                    }
157                }
158            },
159            accuracy,
160            precision,
161            recall,
162            f1_score,
163            rmse,
164            r_squared,
165            loss,
166            loss_fn,
167            sample_count: data_set.len(),
168            is_regression,
169        }
170    }
171}
172
/// Metrics produced by `Accuracy::calc`; which fields are meaningful depends
/// on whether the evaluated task was regression or classification (see
/// `is_regression`). Non-applicable fields are left at `0.0`.
pub struct AccuracyResult {
    name: String,
    accuracy: f32,
    precision: f32, // Only for classification
    recall: f32,    // Only for classification
    f1_score: f32,  // Only for classification
    rmse: f32,      // Only for regression
    r_squared: f32, // Only for regression
    sample_count: usize,
    loss: f32,
    loss_fn: Loss,
    is_regression: bool,
}
186
impl AccuracyResult {
    /// Display name of this result (user-supplied or a task-based default).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Overall accuracy in [0, 1]: fraction correct for classification,
    /// 1 - normalized MAE for regression.
    pub fn accuracy(&self) -> f32 {
        self.accuracy
    }

    /// Precision in [0, 1]; 0.0 for regression results.
    pub fn precision(&self) -> f32 {
        self.precision
    }

    /// Recall in [0, 1]; 0.0 for regression results.
    pub fn recall(&self) -> f32 {
        self.recall
    }

    /// F1 score in [0, 1]; 0.0 for regression results.
    pub fn f1_score(&self) -> f32 {
        self.f1_score
    }

    /// Root mean squared error; 0.0 for classification results.
    pub fn rmse(&self) -> f32 {
        self.rmse
    }

    /// R² (coefficient of determination); 0.0 for classification results.
    pub fn r_squared(&self) -> f32 {
        self.r_squared
    }

    /// Number of samples that contributed to the metrics.
    pub fn sample_count(&self) -> usize {
        self.sample_count
    }

    /// Value of the configured loss function over the data set.
    pub fn loss(&self) -> f32 {
        self.loss
    }

    /// The loss function that produced `loss`.
    pub fn loss_fn(&self) -> Loss {
        self.loss_fn
    }
}
228
229impl Debug for AccuracyResult {
230    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
231        if self.is_regression {
232            write!(
233                f,
234                "{:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tR² Score: {:.5}\n\tRMSE: {:.5}\n\tLoss ({:?}): {:.5}\n}}",
235                self.name,
236                self.sample_count,
237                self.accuracy * 100.0,
238                self.r_squared,
239                self.rmse,
240                self.loss_fn,
241                self.loss
242            )
243        } else {
244            write!(
245                f,
246                "{:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tPrecision: {:.2}%\n\tRecall: {:.2}%\n\tF1 Score: {:.2}%\n\tLoss ({:?}): {:.5}\n}}",
247                self.name,
248                self.sample_count,
249                self.accuracy * 100.0,
250                self.precision * 100.0,
251                self.recall * 100.0,
252                self.f1_score * 100.0,
253                self.loss_fn,
254                self.loss
255            )
256        }
257    }
258}
259
260impl Eval<Graph<Op<f32>>, Option<AccuracyResult>> for Accuracy<'_> {
261    fn eval(&self, graph: &Graph<Op<f32>>) -> Option<AccuracyResult> {
262        let mut evaluator = GraphEvaluator::new(graph);
263        Some(self.calc(&mut evaluator))
264    }
265}
266
267impl Eval<Tree<Op<f32>>, Option<AccuracyResult>> for Accuracy<'_> {
268    fn eval(&self, tree: &Tree<Op<f32>>) -> Option<AccuracyResult> {
269        Some(self.calc(&mut tree.clone()))
270    }
271}
272
273impl Eval<Vec<Tree<Op<f32>>>, Option<AccuracyResult>> for Accuracy<'_> {
274    fn eval(&self, trees: &Vec<Tree<Op<f32>>>) -> Option<AccuracyResult> {
275        let mut cloned_trees = trees.clone();
276        Some(self.calc(&mut cloned_trees))
277    }
278}