entrenar/storage/preflight/checks/
data_integrity.rs1use std::collections::HashMap;
4
5use super::{CheckResult, CheckType, PreflightCheck};
6
7impl PreflightCheck {
8 pub fn no_nan_values() -> Self {
14 Self::new(
15 "no_nan_values",
16 CheckType::DataIntegrity,
17 "Ensures no NaN values exist in the dataset",
18 |data, _ctx| {
19 let mut nan_count = 0;
20 let mut nan_locations = Vec::new();
21
22 for (row_idx, row) in data.iter().enumerate() {
23 for (col_idx, val) in row.iter().enumerate() {
24 if val.is_nan() {
25 nan_count += 1;
26 if nan_locations.len() < 5 {
27 nan_locations.push(format!("({row_idx}, {col_idx})"));
28 }
29 }
30 }
31 }
32
33 if nan_count == 0 {
34 CheckResult::passed("No NaN values found")
35 } else {
36 CheckResult::failed_with_details(
37 format!("Found {nan_count} NaN values"),
38 format!("First locations: {}", nan_locations.join(", ")),
39 )
40 }
41 },
42 )
43 }
44
45 pub fn no_inf_values() -> Self {
47 Self::new(
48 "no_inf_values",
49 CheckType::DataIntegrity,
50 "Ensures no infinite values exist in the dataset",
51 |data, _ctx| {
52 let mut inf_count = 0;
53 let mut inf_locations = Vec::new();
54
55 for (row_idx, row) in data.iter().enumerate() {
56 for (col_idx, val) in row.iter().enumerate() {
57 if val.is_infinite() {
58 inf_count += 1;
59 if inf_locations.len() < 5 {
60 inf_locations.push(format!("({row_idx}, {col_idx})"));
61 }
62 }
63 }
64 }
65
66 if inf_count == 0 {
67 CheckResult::passed("No infinite values found")
68 } else {
69 CheckResult::failed_with_details(
70 format!("Found {inf_count} infinite values"),
71 format!("First locations: {}", inf_locations.join(", ")),
72 )
73 }
74 },
75 )
76 }
77
78 pub fn min_samples(min: usize) -> Self {
80 Self::new(
81 "min_samples",
82 CheckType::DataIntegrity,
83 format!("Ensures at least {min} samples exist"),
84 move |data, ctx| {
85 let min_required = ctx.min_samples.unwrap_or(min);
86 let actual = data.len();
87
88 if actual >= min_required {
89 CheckResult::passed(format!("Found {actual} samples (minimum: {min_required})"))
90 } else {
91 CheckResult::failed(format!(
92 "Only {actual} samples found (minimum: {min_required})"
93 ))
94 }
95 },
96 )
97 }
98
99 pub fn min_features(min: usize) -> Self {
101 Self::new(
102 "min_features",
103 CheckType::DataIntegrity,
104 format!("Ensures at least {min} features exist"),
105 move |data, ctx| {
106 let min_required = ctx.min_features.unwrap_or(min);
107 let actual = data.first().map_or(0, Vec::len);
108
109 if actual >= min_required {
110 CheckResult::passed(format!(
111 "Found {actual} features (minimum: {min_required})"
112 ))
113 } else {
114 CheckResult::failed(format!(
115 "Only {actual} features found (minimum: {min_required})"
116 ))
117 }
118 },
119 )
120 }
121
122 pub fn consistent_dimensions() -> Self {
124 Self::new(
125 "consistent_dimensions",
126 CheckType::DataIntegrity,
127 "Ensures all rows have the same number of features",
128 |data, _ctx| {
129 if data.is_empty() {
130 return CheckResult::skipped("No data to check");
131 }
132
133 let expected_len = data[0].len();
134 let mut inconsistent = Vec::new();
135
136 for (idx, row) in data.iter().enumerate() {
137 if row.len() != expected_len {
138 inconsistent.push(format!("row {idx}: {} features", row.len()));
139 if inconsistent.len() >= 5 {
140 break;
141 }
142 }
143 }
144
145 if inconsistent.is_empty() {
146 CheckResult::passed(format!(
147 "All {} rows have {expected_len} features",
148 data.len()
149 ))
150 } else {
151 CheckResult::failed_with_details(
152 format!("Inconsistent dimensions (expected {expected_len} features)"),
153 inconsistent.join(", "),
154 )
155 }
156 },
157 )
158 }
159
160 pub fn no_constant_features() -> Self {
162 Self::new(
163 "no_constant_features",
164 CheckType::DataIntegrity,
165 "Ensures no features have zero variance",
166 |data, _ctx| {
167 if data.is_empty() || data[0].is_empty() {
168 return CheckResult::skipped("No data to check");
169 }
170
171 let n_features = data[0].len();
172 let mut constant_features = Vec::new();
173
174 for col in 0..n_features {
175 let values: Vec<f64> = data.iter().map(|row| row[col]).collect();
176 let first = values[0];
177
178 if values.iter().all(|v| (*v - first).abs() < f64::EPSILON) {
179 constant_features.push(col);
180 }
181 }
182
183 if constant_features.is_empty() {
184 CheckResult::passed("No constant features found")
185 } else {
186 CheckResult::warning(format!(
187 "Found {} constant feature(s): {:?}",
188 constant_features.len(),
189 constant_features
190 ))
191 }
192 },
193 )
194 .optional()
195 }
196
197 pub fn label_balance(max_imbalance_ratio: f64) -> Self {
199 Self::new(
200 "label_balance",
201 CheckType::DataIntegrity,
202 format!("Ensures class imbalance ratio <= {max_imbalance_ratio}"),
203 move |data, _ctx| {
204 if data.is_empty() || data[0].is_empty() {
205 return CheckResult::skipped("No data to check");
206 }
207
208 let labels: Vec<i64> =
210 data.iter().map(|row| *row.last().unwrap_or(&0.0) as i64).collect();
211
212 let mut counts: HashMap<i64, usize> = HashMap::new();
213 for label in &labels {
214 *counts.entry(*label).or_default() += 1;
215 }
216
217 if counts.is_empty() {
218 return CheckResult::skipped("No labels found");
219 }
220
221 let max_count = *counts.values().max().unwrap_or(&0);
222 let min_count = *counts.values().min().unwrap_or(&0);
223
224 if min_count == 0 {
225 return CheckResult::failed("One or more classes have zero samples");
226 }
227
228 let ratio = max_count as f64 / min_count as f64;
229
230 if ratio <= max_imbalance_ratio {
231 CheckResult::passed(format!(
232 "Class imbalance ratio {ratio:.2} <= {max_imbalance_ratio}"
233 ))
234 } else {
235 CheckResult::warning(format!(
236 "Class imbalance ratio {ratio:.2} > {max_imbalance_ratio}"
237 ))
238 }
239 },
240 )
241 .optional()
242 }
243}