use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
8
/// Full report produced by `SplitAnalyzer::analyze` for one
/// train/validation/test split.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitAnalysis {
    /// Summary statistics for the training split.
    pub train_metrics: SplitMetrics,
    /// Summary statistics for the validation split, if one exists.
    pub validation_metrics: Option<SplitMetrics>,
    /// Summary statistics for the test split.
    pub test_metrics: SplitMetrics,
    /// Whether the actual ratios are within tolerance of the expected ones.
    pub ratio_valid: bool,
    /// Observed split proportions, computed from sample counts.
    pub actual_ratios: SplitRatios,
    /// The proportions the caller declared as intended.
    pub expected_ratios: SplitRatios,
    /// Whether any entity or temporal leakage was found between splits.
    pub leakage_detected: bool,
    /// Human-readable description of each leakage finding.
    pub leakage_details: Vec<String>,
    /// Whether train/test class distributions are within the KL threshold.
    pub distribution_preserved: bool,
    /// KL divergence between train and test class distributions.
    pub distribution_shift: f64,
    /// True only when ratios are valid, no leakage was found, and the
    /// class distribution is preserved.
    pub is_valid: bool,
    /// All issues found, aggregated across every check.
    pub issues: Vec<String>,
}
37
/// Per-split summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitMetrics {
    /// Number of samples in the split.
    pub sample_count: usize,
    /// Relative frequency of each class label.
    pub class_distribution: HashMap<String, f64>,
    /// Count of distinct entity ids in the split.
    pub unique_entities: usize,
    /// `(min, max)` date strings, or `None` when the split has no dates.
    pub date_range: Option<(String, String)>,
}
50
/// Proportions of the dataset assigned to each split; expected to sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Fraction of samples in the training split.
    pub train: f64,
    /// Fraction of samples in the validation split (0.0 when absent).
    pub validation: f64,
    /// Fraction of samples in the test split.
    pub test: f64,
}
61
62impl Default for SplitRatios {
63 fn default() -> Self {
64 Self {
65 train: 0.7,
66 validation: 0.15,
67 test: 0.15,
68 }
69 }
70}
71
/// Input to `SplitAnalyzer::analyze`: the raw per-split data plus the
/// ratios the split was supposed to follow.
#[derive(Debug, Clone)]
pub struct SplitData {
    /// Training split contents.
    pub train: SplitSetData,
    /// Optional validation split contents.
    pub validation: Option<SplitSetData>,
    /// Test split contents.
    pub test: SplitSetData,
    /// The proportions the split was intended to follow.
    pub expected_ratios: SplitRatios,
}
84
/// Raw contents of a single split.
///
/// NOTE(review): `labels`, `entity_ids`, and `dates` appear to be parallel
/// per-sample collections, but `sample_count` is stored independently and
/// nothing here enforces that it matches their lengths — confirm callers
/// keep them consistent.
#[derive(Debug, Clone, Default)]
pub struct SplitSetData {
    /// Number of samples in this split (used for ratio calculations).
    pub sample_count: usize,
    /// Class label per sample.
    pub labels: Vec<String>,
    /// Entity identifier per sample (used for leakage detection).
    pub entity_ids: Vec<String>,
    /// Date string per sample; compared lexicographically in leakage checks.
    pub dates: Vec<String>,
}
97
/// Validates train/validation/test splits for ratio drift, data leakage,
/// and class-distribution shift.
pub struct SplitAnalyzer {
    /// Maximum allowed absolute deviation between actual and expected ratios.
    ratio_tolerance: f64,
    /// Maximum KL divergence tolerated between train and test class
    /// distributions.
    max_kl_divergence: f64,
}
105
106impl SplitAnalyzer {
107 pub fn new() -> Self {
109 Self {
110 ratio_tolerance: 0.05,
111 max_kl_divergence: 0.1,
112 }
113 }
114
115 pub fn analyze(&self, data: &SplitData) -> EvalResult<SplitAnalysis> {
117 let mut issues = Vec::new();
118
119 let total = data.train.sample_count
121 + data
122 .validation
123 .as_ref()
124 .map(|v| v.sample_count)
125 .unwrap_or(0)
126 + data.test.sample_count;
127
128 let actual_ratios = if total > 0 {
129 SplitRatios {
130 train: data.train.sample_count as f64 / total as f64,
131 validation: data
132 .validation
133 .as_ref()
134 .map(|v| v.sample_count as f64 / total as f64)
135 .unwrap_or(0.0),
136 test: data.test.sample_count as f64 / total as f64,
137 }
138 } else {
139 SplitRatios::default()
140 };
141
142 let ratio_valid = self.validate_ratios(&actual_ratios, &data.expected_ratios);
144 if !ratio_valid {
145 issues.push(format!(
146 "Split ratios deviate from expected: actual {:.2}/{:.2}/{:.2}, expected {:.2}/{:.2}/{:.2}",
147 actual_ratios.train,
148 actual_ratios.validation,
149 actual_ratios.test,
150 data.expected_ratios.train,
151 data.expected_ratios.validation,
152 data.expected_ratios.test
153 ));
154 }
155
156 let (leakage_detected, leakage_details) = self.check_leakage(data);
158 if leakage_detected {
159 issues.extend(leakage_details.clone());
160 }
161
162 let train_metrics = self.calculate_metrics(&data.train);
164 let validation_metrics = data.validation.as_ref().map(|v| self.calculate_metrics(v));
165 let test_metrics = self.calculate_metrics(&data.test);
166
167 let (distribution_preserved, distribution_shift) =
169 self.check_distribution(&train_metrics, &test_metrics);
170 if !distribution_preserved {
171 issues.push(format!(
172 "Class distribution shift detected: KL divergence = {:.4}",
173 distribution_shift
174 ));
175 }
176
177 let is_valid = ratio_valid && !leakage_detected && distribution_preserved;
178
179 Ok(SplitAnalysis {
180 train_metrics,
181 validation_metrics,
182 test_metrics,
183 ratio_valid,
184 actual_ratios,
185 expected_ratios: data.expected_ratios.clone(),
186 leakage_detected,
187 leakage_details,
188 distribution_preserved,
189 distribution_shift,
190 is_valid,
191 issues,
192 })
193 }
194
195 fn validate_ratios(&self, actual: &SplitRatios, expected: &SplitRatios) -> bool {
197 (actual.train - expected.train).abs() <= self.ratio_tolerance
198 && (actual.validation - expected.validation).abs() <= self.ratio_tolerance
199 && (actual.test - expected.test).abs() <= self.ratio_tolerance
200 }
201
202 fn check_leakage(&self, data: &SplitData) -> (bool, Vec<String>) {
204 let mut leakage = false;
205 let mut details = Vec::new();
206
207 let train_entities: std::collections::HashSet<_> = data.train.entity_ids.iter().collect();
208 let test_entities: std::collections::HashSet<_> = data.test.entity_ids.iter().collect();
209
210 let overlap: Vec<_> = train_entities.intersection(&test_entities).collect();
211 if !overlap.is_empty() {
212 leakage = true;
213 details.push(format!(
214 "Entity leakage: {} entities appear in both train and test",
215 overlap.len()
216 ));
217 }
218
219 if !data.train.dates.is_empty() && !data.test.dates.is_empty() {
221 let train_max = data.train.dates.iter().max();
222 let test_min = data.test.dates.iter().min();
223
224 if let (Some(train_max), Some(test_min)) = (train_max, test_min) {
225 if test_min < train_max {
226 leakage = true;
227 details.push(format!(
228 "Temporal leakage: test min date {} < train max date {}",
229 test_min, train_max
230 ));
231 }
232 }
233 }
234
235 if let Some(ref val) = data.validation {
236 let val_entities: std::collections::HashSet<_> = val.entity_ids.iter().collect();
237
238 let train_val_overlap: Vec<_> = train_entities.intersection(&val_entities).collect();
239 if !train_val_overlap.is_empty() {
240 leakage = true;
241 details.push(format!(
242 "Entity leakage: {} entities appear in both train and validation",
243 train_val_overlap.len()
244 ));
245 }
246
247 let val_test_overlap: Vec<_> = val_entities.intersection(&test_entities).collect();
248 if !val_test_overlap.is_empty() {
249 leakage = true;
250 details.push(format!(
251 "Entity leakage: {} entities appear in both validation and test",
252 val_test_overlap.len()
253 ));
254 }
255 }
256
257 (leakage, details)
258 }
259
260 fn calculate_metrics(&self, data: &SplitSetData) -> SplitMetrics {
262 let mut class_counts: HashMap<String, usize> = HashMap::new();
263 for label in &data.labels {
264 *class_counts.entry(label.clone()).or_insert(0) += 1;
265 }
266
267 let total = data.labels.len();
268 let class_distribution: HashMap<String, f64> = class_counts
269 .iter()
270 .map(|(k, v)| {
271 (
272 k.clone(),
273 if total > 0 {
274 *v as f64 / total as f64
275 } else {
276 0.0
277 },
278 )
279 })
280 .collect();
281
282 let unique_entities = data
283 .entity_ids
284 .iter()
285 .collect::<std::collections::HashSet<_>>()
286 .len();
287
288 let date_range = if !data.dates.is_empty() {
289 let min = data.dates.iter().min().cloned();
290 let max = data.dates.iter().max().cloned();
291 match (min, max) {
292 (Some(min), Some(max)) => Some((min, max)),
293 _ => None,
294 }
295 } else {
296 None
297 };
298
299 SplitMetrics {
300 sample_count: data.sample_count,
301 class_distribution,
302 unique_entities,
303 date_range,
304 }
305 }
306
307 fn check_distribution(&self, train: &SplitMetrics, test: &SplitMetrics) -> (bool, f64) {
309 if train.class_distribution.is_empty() || test.class_distribution.is_empty() {
310 return (true, 0.0);
311 }
312
313 let mut kl_divergence = 0.0;
315 let epsilon = 1e-10;
316
317 for (class, train_prob) in &train.class_distribution {
318 let test_prob = test.class_distribution.get(class).unwrap_or(&epsilon);
319 let p = *train_prob + epsilon;
320 let q = *test_prob + epsilon;
321 kl_divergence += p * (p / q).ln();
322 }
323
324 for (class, test_prob) in &test.class_distribution {
326 if !train.class_distribution.contains_key(class) {
327 let p = epsilon;
328 let q = *test_prob + epsilon;
329 kl_divergence += p * (p / q).ln();
330 }
331 }
332
333 let preserved = kl_divergence <= self.max_kl_divergence;
334 (preserved, kl_divergence)
335 }
336}
337
338impl Default for SplitAnalyzer {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `a` labels of class "A" followed by `b` labels of class "B".
    fn labels_ab(a: usize, b: usize) -> Vec<String> {
        let mut labels = vec!["A".to_string(); a];
        labels.extend(vec!["B".to_string(); b]);
        labels
    }

    /// Entity ids "E{n}" for every `n` in `range`.
    fn entity_range(range: std::ops::Range<usize>) -> Vec<String> {
        range.map(|i| format!("E{}", i)).collect()
    }

    #[test]
    fn test_valid_split() {
        // 70/15/15 split with disjoint entities and identical class balance
        // in every split: all checks should pass.
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: labels_ab(50, 20),
                entity_ids: entity_range(0..70),
                dates: vec![],
            },
            validation: Some(SplitSetData {
                sample_count: 15,
                labels: labels_ab(11, 4),
                entity_ids: entity_range(70..85),
                dates: vec![],
            }),
            test: SplitSetData {
                sample_count: 15,
                labels: labels_ab(11, 4),
                entity_ids: entity_range(85..100),
                dates: vec![],
            },
            expected_ratios: SplitRatios::default(),
        };

        let result = SplitAnalyzer::new().analyze(&data).unwrap();

        assert!(result.ratio_valid);
        assert!(!result.leakage_detected);
        assert!(result.is_valid);
    }

    #[test]
    fn test_entity_leakage() {
        // "E1" appears in both train and test, so leakage must be flagged
        // and the overall analysis marked invalid.
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E2".to_string(), "E3".to_string()],
                dates: vec![],
            },
            validation: None,
            test: SplitSetData {
                sample_count: 30,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E4".to_string()],
                dates: vec![],
            },
            expected_ratios: SplitRatios {
                train: 0.7,
                validation: 0.0,
                test: 0.3,
            },
        };

        let result = SplitAnalyzer::new().analyze(&data).unwrap();

        assert!(result.leakage_detected);
        assert!(!result.is_valid);
    }
}
388
389 #[test]
390 fn test_entity_leakage() {
391 let data = SplitData {
392 train: SplitSetData {
393 sample_count: 70,
394 labels: vec![],
395 entity_ids: vec!["E1".to_string(), "E2".to_string(), "E3".to_string()],
396 dates: vec![],
397 },
398 validation: None,
399 test: SplitSetData {
400 sample_count: 30,
401 labels: vec![],
402 entity_ids: vec!["E1".to_string(), "E4".to_string()], dates: vec![],
404 },
405 expected_ratios: SplitRatios {
406 train: 0.7,
407 validation: 0.0,
408 test: 0.3,
409 },
410 };
411
412 let analyzer = SplitAnalyzer::new();
413 let result = analyzer.analyze(&data).unwrap();
414
415 assert!(result.leakage_detected);
416 assert!(!result.is_valid);
417 }
418}