// radiate_gp/regression/accuracy.rs

1use super::{DataSet, Loss};
2use std::fmt::Debug;
3
4#[derive(Clone)]
5pub struct Accuracy<'a> {
6    name: String,
7    data_set: &'a DataSet,
8    loss_fn: Loss,
9}
10
11impl<'a> Accuracy<'a> {
12    pub fn new(name: impl Into<String>, data_set: &'a DataSet, loss_fn: Loss) -> Self {
13        Accuracy {
14            name: name.into(),
15            data_set,
16            loss_fn,
17        }
18    }
19
20    pub fn calc<F>(&self, mut eval: F) -> AccuracyResult
21    where
22        F: FnMut(&Vec<f32>) -> Vec<f32>,
23    {
24        let mut outputs = Vec::new();
25        let mut total_samples = 0.0;
26        let mut correct_predictions = 0.0;
27        let mut is_regression = true;
28
29        let mut mae = 0.0; // Mean Absolute Error
30        let mut mse = 0.0; // Mean Squared Error
31        let mut min_output = f32::MAX;
32        let mut max_output = f32::MIN;
33        let mut ss_total = 0.0; // Sum of squares total for R²
34        let mut ss_residual = 0.0; // Sum of squares residual for R²
35        let mut y_mean = 0.0;
36
37        let mut tp = 0.0; // True Positives (for classification)
38        let mut fp = 0.0; // False Positives
39        let mut fn_ = 0.0; // False Negatives
40
41        let loss = self.loss_fn.calculate(&self.data_set, &mut eval);
42
43        // Compute the mean of actual values for R² calculation
44        let total_values: usize = self.data_set.len();
45        if total_values > 0 {
46            y_mean =
47                self.data_set.iter().map(|row| row.output()[0]).sum::<f32>() / total_values as f32;
48        }
49
50        for row in self.data_set.iter() {
51            let output = eval(row.input());
52            outputs.push(output.clone());
53
54            if output.len() == 1 {
55                // Regression case
56                is_regression = true;
57                let y_true = row.output()[0];
58                let y_pred = output[0];
59
60                mae += (y_true - y_pred).abs();
61                mse += (y_true - y_pred).powi(2);
62                ss_residual += (y_true - y_pred).powi(2);
63                ss_total += (y_true - y_mean).powi(2);
64
65                min_output = min_output.min(y_true);
66                max_output = max_output.max(y_true);
67                total_samples += 1.0;
68            } else {
69                // Classification case
70                is_regression = false;
71                if let Some((max_idx, _)) = output
72                    .iter()
73                    .enumerate()
74                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
75                {
76                    if let Some(target) = row.output().iter().position(|&x| x == 1.0) {
77                        total_samples += 1.0;
78                        if max_idx == target {
79                            correct_predictions += 1.0;
80                            tp += 1.0;
81                        } else {
82                            fp += 1.0;
83                        }
84                    } else {
85                        fn_ += 1.0;
86                    }
87                }
88            }
89        }
90
91        // Compute final accuracy
92        let accuracy = if is_regression {
93            if total_samples > 0.0 && (max_output - min_output) > 0.0 {
94                1.0 - (mae / total_samples) / (max_output - min_output) // Scaled MAE-based accuracy
95            } else {
96                0.0
97            }
98        } else {
99            if total_samples > 0.0 {
100                correct_predictions / total_samples
101            } else {
102                0.0
103            }
104        };
105
106        // Compute classification metrics only if it's a classification task
107        let (precision, recall, f1_score) = if is_regression {
108            (0.0, 0.0, 0.0) // Not applicable for regression
109        } else {
110            let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
111            let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
112            let f1_score = if precision + recall > 0.0 {
113                2.0 * (precision * recall) / (precision + recall)
114            } else {
115                0.0
116            };
117            (precision, recall, f1_score)
118        };
119
120        let rmse = if total_samples > 0.0 {
121            (mse / total_samples).sqrt()
122        } else {
123            0.0
124        };
125
126        // Compute R² score
127        let r_squared = if ss_total > 0.0 {
128            1.0 - (ss_residual / ss_total)
129        } else {
130            0.0 // If ss_total is 0, all y_true are the same, meaning the model is perfect (or there's no variance)
131        };
132
133        AccuracyResult {
134            name: self.name.clone(),
135            accuracy,
136            precision,
137            recall,
138            f1_score,
139            rmse,
140            r_squared,
141            loss,
142            loss_fn: self.loss_fn.clone(),
143            sample_count: self.data_set.len(),
144            is_regression,
145        }
146    }
147}
148
149pub struct AccuracyResult {
150    name: String,
151    accuracy: f32,
152    precision: f32, // Only for classification
153    recall: f32,    // Only for classification
154    f1_score: f32,  // Only for classification
155    rmse: f32,      // Only for regression
156    r_squared: f32, // Only for regression
157    sample_count: usize,
158    loss: f32,
159    loss_fn: Loss,
160    is_regression: bool,
161}
162
163impl Debug for AccuracyResult {
164    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
165        if self.is_regression {
166            write!(
167                f,
168                "Regression Accuracy - {:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tR² Score: {:.5}\n\tRMSE: {:.5}\n\tLoss ({:?}): {:.5}\n}}",
169                self.name,
170                self.sample_count,
171                self.accuracy * 100.0,
172                self.r_squared,
173                self.rmse,
174                self.loss_fn,
175                self.loss
176            )
177        } else {
178            write!(
179                f,
180                "Classification Accuracy - {:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tPrecision: {:.2}%\n\tRecall: {:.2}%\n\tF1 Score: {:.2}%\n\tLoss ({:?}): {:.5}\n}}",
181                self.name,
182                self.sample_count,
183                self.accuracy * 100.0,
184                self.precision * 100.0,
185                self.recall * 100.0,
186                self.f1_score * 100.0,
187                self.loss_fn,
188                self.loss
189            )
190        }
191    }
192}