use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
8
/// Full result of analyzing a train/validation/test split: per-split
/// metrics plus the outcome of the ratio, leakage, and distribution checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitAnalysis {
    /// Descriptive statistics for the training split.
    pub train_metrics: SplitMetrics,
    /// Statistics for the validation split, when one exists.
    pub validation_metrics: Option<SplitMetrics>,
    /// Descriptive statistics for the test split.
    pub test_metrics: SplitMetrics,
    /// True when actual ratios are within tolerance of the expected ones.
    pub ratio_valid: bool,
    /// Ratios observed in the data (each split's share of total samples).
    pub actual_ratios: SplitRatios,
    /// Ratios the caller declared the split should have.
    pub expected_ratios: SplitRatios,
    /// True when entity or temporal leakage was found between splits.
    pub leakage_detected: bool,
    /// One human-readable message per leakage finding.
    pub leakage_details: Vec<String>,
    /// True when train/test class distributions agree (low KL divergence).
    pub distribution_preserved: bool,
    /// KL divergence between the train and test class distributions.
    pub distribution_shift: f64,
    /// True only when all three checks (ratio, leakage, distribution) pass.
    pub is_valid: bool,
    /// Aggregated messages for every check that failed.
    pub issues: Vec<String>,
}
37
/// Descriptive statistics computed for a single split.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitMetrics {
    /// Number of samples in the split.
    pub sample_count: usize,
    /// Relative frequency of each class label (sums to ~1.0 when labels exist).
    pub class_distribution: HashMap<String, f64>,
    /// Count of distinct entity ids in the split.
    pub unique_entities: usize,
    /// `(min, max)` date strings when the split carries dates, else `None`.
    /// Dates are compared as strings — assumes an ISO-8601-like format.
    pub date_range: Option<(String, String)>,
}
50
/// Fractions of the total dataset assigned to each split.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Fraction of samples in the training split.
    pub train: f64,
    /// Fraction of samples in the validation split (0.0 when absent).
    pub validation: f64,
    /// Fraction of samples in the test split.
    pub test: f64,
}
61
62impl Default for SplitRatios {
63 fn default() -> Self {
64 Self {
65 train: 0.7,
66 validation: 0.15,
67 test: 0.15,
68 }
69 }
70}
71
/// Input to the analyzer: the raw per-split data plus the ratios the
/// split was supposed to follow.
#[derive(Debug, Clone)]
pub struct SplitData {
    /// Training split contents.
    pub train: SplitSetData,
    /// Validation split contents, if the split scheme includes one.
    pub validation: Option<SplitSetData>,
    /// Test split contents.
    pub test: SplitSetData,
    /// Ratios the split is expected to satisfy (within tolerance).
    pub expected_ratios: SplitRatios,
}
84
/// Raw contents of one split.
///
/// NOTE(review): `sample_count` is stored independently of `labels.len()`
/// and the two are not required to agree — `labels`/`entity_ids`/`dates`
/// may legitimately be empty while `sample_count` is non-zero (see tests).
#[derive(Debug, Clone, Default)]
pub struct SplitSetData {
    /// Number of samples in this split (used for ratio checks).
    pub sample_count: usize,
    /// Class label per sample (used for distribution checks; may be empty).
    pub labels: Vec<String>,
    /// Entity id per sample (used for entity-leakage checks; may be empty).
    pub entity_ids: Vec<String>,
    /// Date string per sample (used for temporal-leakage checks; compared
    /// lexicographically, so an ISO-8601-like format is assumed).
    pub dates: Vec<String>,
}
97
/// Validates dataset splits for ratio drift, leakage, and class-distribution
/// shift. Construct with [`SplitAnalyzer::new`] or `Default`.
pub struct SplitAnalyzer {
    /// Maximum absolute deviation allowed between actual and expected ratios.
    ratio_tolerance: f64,
    /// Maximum KL divergence tolerated between train and test class
    /// distributions before the split is flagged.
    max_kl_divergence: f64,
}
105
106impl SplitAnalyzer {
107 pub fn new() -> Self {
109 Self {
110 ratio_tolerance: 0.05,
111 max_kl_divergence: 0.1,
112 }
113 }
114
115 pub fn analyze(&self, data: &SplitData) -> EvalResult<SplitAnalysis> {
117 let mut issues = Vec::new();
118
119 let total = data.train.sample_count
121 + data
122 .validation
123 .as_ref()
124 .map(|v| v.sample_count)
125 .unwrap_or(0)
126 + data.test.sample_count;
127
128 let actual_ratios = if total > 0 {
129 SplitRatios {
130 train: data.train.sample_count as f64 / total as f64,
131 validation: data
132 .validation
133 .as_ref()
134 .map(|v| v.sample_count as f64 / total as f64)
135 .unwrap_or(0.0),
136 test: data.test.sample_count as f64 / total as f64,
137 }
138 } else {
139 SplitRatios::default()
140 };
141
142 let ratio_valid = self.validate_ratios(&actual_ratios, &data.expected_ratios);
144 if !ratio_valid {
145 issues.push(format!(
146 "Split ratios deviate from expected: actual {:.2}/{:.2}/{:.2}, expected {:.2}/{:.2}/{:.2}",
147 actual_ratios.train,
148 actual_ratios.validation,
149 actual_ratios.test,
150 data.expected_ratios.train,
151 data.expected_ratios.validation,
152 data.expected_ratios.test
153 ));
154 }
155
156 let (leakage_detected, leakage_details) = self.check_leakage(data);
158 if leakage_detected {
159 issues.extend(leakage_details.clone());
160 }
161
162 let train_metrics = self.calculate_metrics(&data.train);
164 let validation_metrics = data.validation.as_ref().map(|v| self.calculate_metrics(v));
165 let test_metrics = self.calculate_metrics(&data.test);
166
167 let (distribution_preserved, distribution_shift) =
169 self.check_distribution(&train_metrics, &test_metrics);
170 if !distribution_preserved {
171 issues.push(format!(
172 "Class distribution shift detected: KL divergence = {distribution_shift:.4}"
173 ));
174 }
175
176 let is_valid = ratio_valid && !leakage_detected && distribution_preserved;
177
178 Ok(SplitAnalysis {
179 train_metrics,
180 validation_metrics,
181 test_metrics,
182 ratio_valid,
183 actual_ratios,
184 expected_ratios: data.expected_ratios.clone(),
185 leakage_detected,
186 leakage_details,
187 distribution_preserved,
188 distribution_shift,
189 is_valid,
190 issues,
191 })
192 }
193
194 fn validate_ratios(&self, actual: &SplitRatios, expected: &SplitRatios) -> bool {
196 (actual.train - expected.train).abs() <= self.ratio_tolerance
197 && (actual.validation - expected.validation).abs() <= self.ratio_tolerance
198 && (actual.test - expected.test).abs() <= self.ratio_tolerance
199 }
200
201 fn check_leakage(&self, data: &SplitData) -> (bool, Vec<String>) {
203 let mut leakage = false;
204 let mut details = Vec::new();
205
206 let train_entities: std::collections::HashSet<_> = data.train.entity_ids.iter().collect();
207 let test_entities: std::collections::HashSet<_> = data.test.entity_ids.iter().collect();
208
209 let overlap: Vec<_> = train_entities.intersection(&test_entities).collect();
210 if !overlap.is_empty() {
211 leakage = true;
212 details.push(format!(
213 "Entity leakage: {} entities appear in both train and test",
214 overlap.len()
215 ));
216 }
217
218 if !data.train.dates.is_empty() && !data.test.dates.is_empty() {
220 let train_max = data.train.dates.iter().max();
221 let test_min = data.test.dates.iter().min();
222
223 if let (Some(train_max), Some(test_min)) = (train_max, test_min) {
224 if test_min < train_max {
225 leakage = true;
226 details.push(format!(
227 "Temporal leakage: test min date {test_min} < train max date {train_max}"
228 ));
229 }
230 }
231 }
232
233 if let Some(ref val) = data.validation {
234 let val_entities: std::collections::HashSet<_> = val.entity_ids.iter().collect();
235
236 let train_val_overlap: Vec<_> = train_entities.intersection(&val_entities).collect();
237 if !train_val_overlap.is_empty() {
238 leakage = true;
239 details.push(format!(
240 "Entity leakage: {} entities appear in both train and validation",
241 train_val_overlap.len()
242 ));
243 }
244
245 let val_test_overlap: Vec<_> = val_entities.intersection(&test_entities).collect();
246 if !val_test_overlap.is_empty() {
247 leakage = true;
248 details.push(format!(
249 "Entity leakage: {} entities appear in both validation and test",
250 val_test_overlap.len()
251 ));
252 }
253 }
254
255 (leakage, details)
256 }
257
258 fn calculate_metrics(&self, data: &SplitSetData) -> SplitMetrics {
260 let mut class_counts: HashMap<String, usize> = HashMap::new();
261 for label in &data.labels {
262 *class_counts.entry(label.clone()).or_insert(0) += 1;
263 }
264
265 let total = data.labels.len();
266 let class_distribution: HashMap<String, f64> = class_counts
267 .iter()
268 .map(|(k, v)| {
269 (
270 k.clone(),
271 if total > 0 {
272 *v as f64 / total as f64
273 } else {
274 0.0
275 },
276 )
277 })
278 .collect();
279
280 let unique_entities = data
281 .entity_ids
282 .iter()
283 .collect::<std::collections::HashSet<_>>()
284 .len();
285
286 let date_range = if !data.dates.is_empty() {
287 let min = data.dates.iter().min().cloned();
288 let max = data.dates.iter().max().cloned();
289 match (min, max) {
290 (Some(min), Some(max)) => Some((min, max)),
291 _ => None,
292 }
293 } else {
294 None
295 };
296
297 SplitMetrics {
298 sample_count: data.sample_count,
299 class_distribution,
300 unique_entities,
301 date_range,
302 }
303 }
304
305 fn check_distribution(&self, train: &SplitMetrics, test: &SplitMetrics) -> (bool, f64) {
307 if train.class_distribution.is_empty() || test.class_distribution.is_empty() {
308 return (true, 0.0);
309 }
310
311 let mut kl_divergence = 0.0;
313 let epsilon = 1e-10;
314
315 for (class, train_prob) in &train.class_distribution {
316 let test_prob = test.class_distribution.get(class).unwrap_or(&epsilon);
317 let p = *train_prob + epsilon;
318 let q = *test_prob + epsilon;
319 kl_divergence += p * (p / q).ln();
320 }
321
322 for (class, test_prob) in &test.class_distribution {
324 if !train.class_distribution.contains_key(class) {
325 let p = epsilon;
326 let q = *test_prob + epsilon;
327 kl_divergence += p * (p / q).ln();
328 }
329 }
330
331 let preserved = kl_divergence <= self.max_kl_divergence;
332 (preserved, kl_divergence)
333 }
334}
335
336impl Default for SplitAnalyzer {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A clean 70/15/15 split with disjoint entity ids and roughly equal
    /// class distributions must pass every check.
    #[test]
    fn test_valid_split() {
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec!["A".to_string(); 50]
                    .into_iter()
                    .chain(vec!["B".to_string(); 20])
                    .collect(),
                entity_ids: (0..70).map(|i| format!("E{i}")).collect(),
                dates: vec![],
            },
            validation: Some(SplitSetData {
                sample_count: 15,
                labels: vec!["A".to_string(); 11]
                    .into_iter()
                    .chain(vec!["B".to_string(); 4])
                    .collect(),
                entity_ids: (70..85).map(|i| format!("E{i}")).collect(),
                dates: vec![],
            }),
            test: SplitSetData {
                sample_count: 15,
                labels: vec!["A".to_string(); 11]
                    .into_iter()
                    .chain(vec!["B".to_string(); 4])
                    .collect(),
                entity_ids: (85..100).map(|i| format!("E{i}")).collect(),
                dates: vec![],
            },
            expected_ratios: SplitRatios::default(),
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.ratio_valid);
        assert!(!result.leakage_detected);
        assert!(result.is_valid);
    }

    /// An entity ("E1") shared by train and test must be flagged as
    /// leakage and invalidate the split.
    #[test]
    fn test_entity_leakage() {
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E2".to_string(), "E3".to_string()],
                dates: vec![],
            },
            validation: None,
            test: SplitSetData {
                sample_count: 30,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E4".to_string()],
                dates: vec![],
            },
            expected_ratios: SplitRatios {
                train: 0.7,
                validation: 0.0,
                test: 0.3,
            },
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.leakage_detected);
        assert!(!result.is_valid);
    }
}