Skip to main content

entrenar/storage/preflight/checks/
data_integrity.rs

1//! Data integrity preflight checks.
2
3use std::collections::HashMap;
4
5use super::{CheckResult, CheckType, PreflightCheck};
6
7impl PreflightCheck {
8    // =========================================================================
9    // Built-in Data Integrity Checks
10    // =========================================================================
11
12    /// Check for NaN values in data
13    pub fn no_nan_values() -> Self {
14        Self::new(
15            "no_nan_values",
16            CheckType::DataIntegrity,
17            "Ensures no NaN values exist in the dataset",
18            |data, _ctx| {
19                let mut nan_count = 0;
20                let mut nan_locations = Vec::new();
21
22                for (row_idx, row) in data.iter().enumerate() {
23                    for (col_idx, val) in row.iter().enumerate() {
24                        if val.is_nan() {
25                            nan_count += 1;
26                            if nan_locations.len() < 5 {
27                                nan_locations.push(format!("({row_idx}, {col_idx})"));
28                            }
29                        }
30                    }
31                }
32
33                if nan_count == 0 {
34                    CheckResult::passed("No NaN values found")
35                } else {
36                    CheckResult::failed_with_details(
37                        format!("Found {nan_count} NaN values"),
38                        format!("First locations: {}", nan_locations.join(", ")),
39                    )
40                }
41            },
42        )
43    }
44
45    /// Check for infinite values in data
46    pub fn no_inf_values() -> Self {
47        Self::new(
48            "no_inf_values",
49            CheckType::DataIntegrity,
50            "Ensures no infinite values exist in the dataset",
51            |data, _ctx| {
52                let mut inf_count = 0;
53                let mut inf_locations = Vec::new();
54
55                for (row_idx, row) in data.iter().enumerate() {
56                    for (col_idx, val) in row.iter().enumerate() {
57                        if val.is_infinite() {
58                            inf_count += 1;
59                            if inf_locations.len() < 5 {
60                                inf_locations.push(format!("({row_idx}, {col_idx})"));
61                            }
62                        }
63                    }
64                }
65
66                if inf_count == 0 {
67                    CheckResult::passed("No infinite values found")
68                } else {
69                    CheckResult::failed_with_details(
70                        format!("Found {inf_count} infinite values"),
71                        format!("First locations: {}", inf_locations.join(", ")),
72                    )
73                }
74            },
75        )
76    }
77
78    /// Check minimum number of samples
79    pub fn min_samples(min: usize) -> Self {
80        Self::new(
81            "min_samples",
82            CheckType::DataIntegrity,
83            format!("Ensures at least {min} samples exist"),
84            move |data, ctx| {
85                let min_required = ctx.min_samples.unwrap_or(min);
86                let actual = data.len();
87
88                if actual >= min_required {
89                    CheckResult::passed(format!("Found {actual} samples (minimum: {min_required})"))
90                } else {
91                    CheckResult::failed(format!(
92                        "Only {actual} samples found (minimum: {min_required})"
93                    ))
94                }
95            },
96        )
97    }
98
99    /// Check minimum number of features
100    pub fn min_features(min: usize) -> Self {
101        Self::new(
102            "min_features",
103            CheckType::DataIntegrity,
104            format!("Ensures at least {min} features exist"),
105            move |data, ctx| {
106                let min_required = ctx.min_features.unwrap_or(min);
107                let actual = data.first().map_or(0, Vec::len);
108
109                if actual >= min_required {
110                    CheckResult::passed(format!(
111                        "Found {actual} features (minimum: {min_required})"
112                    ))
113                } else {
114                    CheckResult::failed(format!(
115                        "Only {actual} features found (minimum: {min_required})"
116                    ))
117                }
118            },
119        )
120    }
121
122    /// Check for consistent row lengths
123    pub fn consistent_dimensions() -> Self {
124        Self::new(
125            "consistent_dimensions",
126            CheckType::DataIntegrity,
127            "Ensures all rows have the same number of features",
128            |data, _ctx| {
129                if data.is_empty() {
130                    return CheckResult::skipped("No data to check");
131                }
132
133                let expected_len = data[0].len();
134                let mut inconsistent = Vec::new();
135
136                for (idx, row) in data.iter().enumerate() {
137                    if row.len() != expected_len {
138                        inconsistent.push(format!("row {idx}: {} features", row.len()));
139                        if inconsistent.len() >= 5 {
140                            break;
141                        }
142                    }
143                }
144
145                if inconsistent.is_empty() {
146                    CheckResult::passed(format!(
147                        "All {} rows have {expected_len} features",
148                        data.len()
149                    ))
150                } else {
151                    CheckResult::failed_with_details(
152                        format!("Inconsistent dimensions (expected {expected_len} features)"),
153                        inconsistent.join(", "),
154                    )
155                }
156            },
157        )
158    }
159
160    /// Check feature variance (detect constant features)
161    pub fn no_constant_features() -> Self {
162        Self::new(
163            "no_constant_features",
164            CheckType::DataIntegrity,
165            "Ensures no features have zero variance",
166            |data, _ctx| {
167                if data.is_empty() || data[0].is_empty() {
168                    return CheckResult::skipped("No data to check");
169                }
170
171                let n_features = data[0].len();
172                let mut constant_features = Vec::new();
173
174                for col in 0..n_features {
175                    let values: Vec<f64> = data.iter().map(|row| row[col]).collect();
176                    let first = values[0];
177
178                    if values.iter().all(|v| (*v - first).abs() < f64::EPSILON) {
179                        constant_features.push(col);
180                    }
181                }
182
183                if constant_features.is_empty() {
184                    CheckResult::passed("No constant features found")
185                } else {
186                    CheckResult::warning(format!(
187                        "Found {} constant feature(s): {:?}",
188                        constant_features.len(),
189                        constant_features
190                    ))
191                }
192            },
193        )
194        .optional()
195    }
196
197    /// Check for label balance (classification)
198    pub fn label_balance(max_imbalance_ratio: f64) -> Self {
199        Self::new(
200            "label_balance",
201            CheckType::DataIntegrity,
202            format!("Ensures class imbalance ratio <= {max_imbalance_ratio}"),
203            move |data, _ctx| {
204                if data.is_empty() || data[0].is_empty() {
205                    return CheckResult::skipped("No data to check");
206                }
207
208                // Assume last column is label
209                let labels: Vec<i64> =
210                    data.iter().map(|row| *row.last().unwrap_or(&0.0) as i64).collect();
211
212                let mut counts: HashMap<i64, usize> = HashMap::new();
213                for label in &labels {
214                    *counts.entry(*label).or_default() += 1;
215                }
216
217                if counts.is_empty() {
218                    return CheckResult::skipped("No labels found");
219                }
220
221                let max_count = *counts.values().max().unwrap_or(&0);
222                let min_count = *counts.values().min().unwrap_or(&0);
223
224                if min_count == 0 {
225                    return CheckResult::failed("One or more classes have zero samples");
226                }
227
228                let ratio = max_count as f64 / min_count as f64;
229
230                if ratio <= max_imbalance_ratio {
231                    CheckResult::passed(format!(
232                        "Class imbalance ratio {ratio:.2} <= {max_imbalance_ratio}"
233                    ))
234                } else {
235                    CheckResult::warning(format!(
236                        "Class imbalance ratio {ratio:.2} > {max_imbalance_ratio}"
237                    ))
238                }
239            },
240        )
241        .optional()
242    }
243}