// radiate_gp/regression/accuracy.rs
use super::{DataSet, Loss};
use std::fmt::Debug;

/// Computes evaluation metrics (accuracy, precision/recall/F1, RMSE, R², loss)
/// for a model evaluated against a borrowed [`DataSet`].
#[derive(Clone)]
pub struct Accuracy<'a> {
    // Human-readable label carried into the resulting `AccuracyResult`.
    name: String,
    // Data set the metrics are computed over; borrowed for the calculator's lifetime.
    data_set: &'a DataSet,
    // Loss function used to compute the scalar loss reported alongside the metrics.
    loss_fn: Loss,
}
10
11impl<'a> Accuracy<'a> {
12 pub fn new(name: impl Into<String>, data_set: &'a DataSet, loss_fn: Loss) -> Self {
13 Accuracy {
14 name: name.into(),
15 data_set,
16 loss_fn,
17 }
18 }
19
20 pub fn calc<F>(&self, mut eval: F) -> AccuracyResult
21 where
22 F: FnMut(&Vec<f32>) -> Vec<f32>,
23 {
24 let mut outputs = Vec::new();
25 let mut total_samples = 0.0;
26 let mut correct_predictions = 0.0;
27 let mut is_regression = true;
28
29 let mut mae = 0.0; let mut mse = 0.0; let mut min_output = f32::MAX;
32 let mut max_output = f32::MIN;
33 let mut ss_total = 0.0; let mut ss_residual = 0.0; let mut y_mean = 0.0;
36
37 let mut tp = 0.0; let mut fp = 0.0; let mut fn_ = 0.0; let loss = self.loss_fn.calculate(&self.data_set, &mut eval);
42
43 let total_values: usize = self.data_set.len();
45 if total_values > 0 {
46 y_mean =
47 self.data_set.iter().map(|row| row.output()[0]).sum::<f32>() / total_values as f32;
48 }
49
50 for row in self.data_set.iter() {
51 let output = eval(row.input());
52 outputs.push(output.clone());
53
54 if output.len() == 1 {
55 is_regression = true;
57 let y_true = row.output()[0];
58 let y_pred = output[0];
59
60 mae += (y_true - y_pred).abs();
61 mse += (y_true - y_pred).powi(2);
62 ss_residual += (y_true - y_pred).powi(2);
63 ss_total += (y_true - y_mean).powi(2);
64
65 min_output = min_output.min(y_true);
66 max_output = max_output.max(y_true);
67 total_samples += 1.0;
68 } else {
69 is_regression = false;
71 if let Some((max_idx, _)) = output
72 .iter()
73 .enumerate()
74 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
75 {
76 if let Some(target) = row.output().iter().position(|&x| x == 1.0) {
77 total_samples += 1.0;
78 if max_idx == target {
79 correct_predictions += 1.0;
80 tp += 1.0;
81 } else {
82 fp += 1.0;
83 }
84 } else {
85 fn_ += 1.0;
86 }
87 }
88 }
89 }
90
91 let accuracy = if is_regression {
93 if total_samples > 0.0 && (max_output - min_output) > 0.0 {
94 1.0 - (mae / total_samples) / (max_output - min_output) } else {
96 0.0
97 }
98 } else {
99 if total_samples > 0.0 {
100 correct_predictions / total_samples
101 } else {
102 0.0
103 }
104 };
105
106 let (precision, recall, f1_score) = if is_regression {
108 (0.0, 0.0, 0.0) } else {
110 let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
111 let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
112 let f1_score = if precision + recall > 0.0 {
113 2.0 * (precision * recall) / (precision + recall)
114 } else {
115 0.0
116 };
117 (precision, recall, f1_score)
118 };
119
120 let rmse = if total_samples > 0.0 {
121 (mse / total_samples).sqrt()
122 } else {
123 0.0
124 };
125
126 let r_squared = if ss_total > 0.0 {
128 1.0 - (ss_residual / ss_total)
129 } else {
130 0.0 };
132
133 AccuracyResult {
134 name: self.name.clone(),
135 accuracy,
136 precision,
137 recall,
138 f1_score,
139 rmse,
140 r_squared,
141 loss,
142 loss_fn: self.loss_fn.clone(),
143 sample_count: self.data_set.len(),
144 is_regression,
145 }
146 }
147}
148
/// Aggregated evaluation metrics produced by [`Accuracy::calc`].
pub struct AccuracyResult {
    // Label copied from the `Accuracy` that produced this result.
    name: String,
    // Regression: 1 - range-normalized MAE; classification: fraction correct.
    accuracy: f32,
    // Classification-only metrics; 0.0 for regression results.
    precision: f32,
    recall: f32,
    f1_score: f32,
    // Regression-only metrics; 0.0 for classification results.
    rmse: f32,
    r_squared: f32,
    // Number of rows in the evaluated data set.
    sample_count: usize,
    // Scalar loss computed by `loss_fn` over the data set.
    loss: f32,
    loss_fn: Loss,
    // Selects the regression vs. classification `Debug` rendering.
    is_regression: bool,
}
162
163impl Debug for AccuracyResult {
164 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
165 if self.is_regression {
166 write!(
167 f,
168 "Regression Accuracy - {:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tR² Score: {:.5}\n\tRMSE: {:.5}\n\tLoss ({:?}): {:.5}\n}}",
169 self.name,
170 self.sample_count,
171 self.accuracy * 100.0,
172 self.r_squared,
173 self.rmse,
174 self.loss_fn,
175 self.loss
176 )
177 } else {
178 write!(
179 f,
180 "Classification Accuracy - {:?} {{\n\tN: {:?} \n\tAccuracy: {:.2}%\n\tPrecision: {:.2}%\n\tRecall: {:.2}%\n\tF1 Score: {:.2}%\n\tLoss ({:?}): {:.5}\n}}",
181 self.name,
182 self.sample_count,
183 self.accuracy * 100.0,
184 self.precision * 100.0,
185 self.recall * 100.0,
186 self.f1_score * 100.0,
187 self.loss_fn,
188 self.loss
189 )
190 }
191 }
192}