Skip to main content

datasynth_eval/ml/
splits.rs

1//! Train/test split validation.
2//!
3//! Validates split ratios, data leakage, and distribution preservation.
4
5use crate::error::EvalResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Results of split analysis.
///
/// Produced by `SplitAnalyzer::analyze`; bundles per-split metrics with the
/// outcome of the ratio, leakage, and class-distribution checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitAnalysis {
    /// Metrics for the train set.
    pub train_metrics: SplitMetrics,
    /// Metrics for the validation set, or `None` when no validation set was supplied.
    pub validation_metrics: Option<SplitMetrics>,
    /// Metrics for the test set.
    pub test_metrics: SplitMetrics,
    /// Whether the actual ratios match the expected ratios within the
    /// analyzer's per-split tolerance.
    pub ratio_valid: bool,
    /// Ratios derived from the `sample_count` of each split.
    pub actual_ratios: SplitRatios,
    /// Expected split ratios, copied from the input.
    pub expected_ratios: SplitRatios,
    /// Whether any entity or temporal leakage was found between splits.
    pub leakage_detected: bool,
    /// One human-readable message per leakage finding; empty when none was found.
    pub leakage_details: Vec<String>,
    /// Whether the train/test class distributions are close enough
    /// (`distribution_shift` no larger than the analyzer's KL limit).
    /// Also `true` when either split has no labels to compare.
    pub distribution_preserved: bool,
    /// KL divergence KL(train || test) between the train and test class
    /// distributions; `0.0` when either split has no labels.
    pub distribution_shift: f64,
    /// `true` only when the ratios are valid, no leakage was detected, and the
    /// class distribution is preserved.
    pub is_valid: bool,
    /// Human-readable description of every problem found.
    pub issues: Vec<String>,
}
37
/// Metrics for a single split.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitMetrics {
    /// Number of samples, copied from the input `sample_count` (not derived
    /// from the number of labels).
    pub sample_count: usize,
    /// Fraction of labels belonging to each class; values total roughly 1.0
    /// when any labels are present, and the map is empty otherwise.
    pub class_distribution: HashMap<String, f64>,
    /// Number of distinct entity IDs in the split.
    pub unique_entities: usize,
    /// Lexicographically smallest and largest date string, or `None` when the
    /// split has no dates.
    pub date_range: Option<(String, String)>,
}
50
/// Split ratios: each split's share of the total sample count.
///
/// The `Default` is 70% train, 15% validation, 15% test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Train share of all samples.
    pub train: f64,
    /// Validation share of all samples (`0.0` for a plain train/test split).
    pub validation: f64,
    /// Test share of all samples.
    pub test: f64,
}
61
62impl Default for SplitRatios {
63    fn default() -> Self {
64        Self {
65            train: 0.7,
66            validation: 0.15,
67            test: 0.15,
68        }
69    }
70}
71
/// Input for split analysis.
///
/// Holds the raw data of every split together with the ratios the split was
/// expected to follow.
#[derive(Debug, Clone)]
pub struct SplitData {
    /// Train set data.
    pub train: SplitSetData,
    /// Validation set data, or `None` for a plain train/test split. When absent,
    /// the actual validation ratio is `0.0`, so the expected validation ratio
    /// should be `0.0` (or within the ratio tolerance of it) to pass.
    pub validation: Option<SplitSetData>,
    /// Test set data.
    pub test: SplitSetData,
    /// Ratios the split is expected to follow.
    pub expected_ratios: SplitRatios,
}
84
/// Data for a single split set.
#[derive(Debug, Clone, Default)]
pub struct SplitSetData {
    /// Number of samples; used for the split ratios. It is not derived from
    /// `labels`, so callers should keep the two consistent.
    pub sample_count: usize,
    /// Class labels; their relative frequencies form the class distribution.
    pub labels: Vec<String>,
    /// Entity IDs (for leakage detection); an ID present in two splits is
    /// reported as entity leakage.
    pub entity_ids: Vec<String>,
    /// Dates (for temporal leakage detection), compared as plain strings.
    /// NOTE(review): lexicographic order only matches chronological order for
    /// a sortable format such as ISO 8601 (`YYYY-MM-DD`) — confirm producers use one.
    pub dates: Vec<String>,
}
97
/// Analyzer for train/test splits.
///
/// Checks split ratios against an expected layout, looks for entity and
/// temporal leakage between splits, and compares the train and test class
/// distributions via KL divergence.
pub struct SplitAnalyzer {
    /// Maximum allowed absolute difference between each actual and expected
    /// split ratio.
    ratio_tolerance: f64,
    /// Maximum KL divergence between the train and test class distributions
    /// for the distribution to count as preserved.
    max_kl_divergence: f64,
}
105
106impl SplitAnalyzer {
107    /// Create a new analyzer.
108    pub fn new() -> Self {
109        Self {
110            ratio_tolerance: 0.05,
111            max_kl_divergence: 0.1,
112        }
113    }
114
115    /// Analyze split quality.
116    pub fn analyze(&self, data: &SplitData) -> EvalResult<SplitAnalysis> {
117        let mut issues = Vec::new();
118
119        // Calculate actual ratios
120        let total = data.train.sample_count
121            + data
122                .validation
123                .as_ref()
124                .map(|v| v.sample_count)
125                .unwrap_or(0)
126            + data.test.sample_count;
127
128        let actual_ratios = if total > 0 {
129            SplitRatios {
130                train: data.train.sample_count as f64 / total as f64,
131                validation: data
132                    .validation
133                    .as_ref()
134                    .map(|v| v.sample_count as f64 / total as f64)
135                    .unwrap_or(0.0),
136                test: data.test.sample_count as f64 / total as f64,
137            }
138        } else {
139            SplitRatios::default()
140        };
141
142        // Validate ratios
143        let ratio_valid = self.validate_ratios(&actual_ratios, &data.expected_ratios);
144        if !ratio_valid {
145            issues.push(format!(
146                "Split ratios deviate from expected: actual {:.2}/{:.2}/{:.2}, expected {:.2}/{:.2}/{:.2}",
147                actual_ratios.train,
148                actual_ratios.validation,
149                actual_ratios.test,
150                data.expected_ratios.train,
151                data.expected_ratios.validation,
152                data.expected_ratios.test
153            ));
154        }
155
156        // Check for data leakage (entity overlap)
157        let (leakage_detected, leakage_details) = self.check_leakage(data);
158        if leakage_detected {
159            issues.extend(leakage_details.clone());
160        }
161
162        // Calculate distribution metrics
163        let train_metrics = self.calculate_metrics(&data.train);
164        let validation_metrics = data.validation.as_ref().map(|v| self.calculate_metrics(v));
165        let test_metrics = self.calculate_metrics(&data.test);
166
167        // Check distribution preservation
168        let (distribution_preserved, distribution_shift) =
169            self.check_distribution(&train_metrics, &test_metrics);
170        if !distribution_preserved {
171            issues.push(format!(
172                "Class distribution shift detected: KL divergence = {:.4}",
173                distribution_shift
174            ));
175        }
176
177        let is_valid = ratio_valid && !leakage_detected && distribution_preserved;
178
179        Ok(SplitAnalysis {
180            train_metrics,
181            validation_metrics,
182            test_metrics,
183            ratio_valid,
184            actual_ratios,
185            expected_ratios: data.expected_ratios.clone(),
186            leakage_detected,
187            leakage_details,
188            distribution_preserved,
189            distribution_shift,
190            is_valid,
191            issues,
192        })
193    }
194
195    /// Validate split ratios against expected.
196    fn validate_ratios(&self, actual: &SplitRatios, expected: &SplitRatios) -> bool {
197        (actual.train - expected.train).abs() <= self.ratio_tolerance
198            && (actual.validation - expected.validation).abs() <= self.ratio_tolerance
199            && (actual.test - expected.test).abs() <= self.ratio_tolerance
200    }
201
202    /// Check for data leakage between splits.
203    fn check_leakage(&self, data: &SplitData) -> (bool, Vec<String>) {
204        let mut leakage = false;
205        let mut details = Vec::new();
206
207        let train_entities: std::collections::HashSet<_> = data.train.entity_ids.iter().collect();
208        let test_entities: std::collections::HashSet<_> = data.test.entity_ids.iter().collect();
209
210        let overlap: Vec<_> = train_entities.intersection(&test_entities).collect();
211        if !overlap.is_empty() {
212            leakage = true;
213            details.push(format!(
214                "Entity leakage: {} entities appear in both train and test",
215                overlap.len()
216            ));
217        }
218
219        // Check temporal leakage (test dates before train dates)
220        if !data.train.dates.is_empty() && !data.test.dates.is_empty() {
221            let train_max = data.train.dates.iter().max();
222            let test_min = data.test.dates.iter().min();
223
224            if let (Some(train_max), Some(test_min)) = (train_max, test_min) {
225                if test_min < train_max {
226                    leakage = true;
227                    details.push(format!(
228                        "Temporal leakage: test min date {} < train max date {}",
229                        test_min, train_max
230                    ));
231                }
232            }
233        }
234
235        if let Some(ref val) = data.validation {
236            let val_entities: std::collections::HashSet<_> = val.entity_ids.iter().collect();
237
238            let train_val_overlap: Vec<_> = train_entities.intersection(&val_entities).collect();
239            if !train_val_overlap.is_empty() {
240                leakage = true;
241                details.push(format!(
242                    "Entity leakage: {} entities appear in both train and validation",
243                    train_val_overlap.len()
244                ));
245            }
246
247            let val_test_overlap: Vec<_> = val_entities.intersection(&test_entities).collect();
248            if !val_test_overlap.is_empty() {
249                leakage = true;
250                details.push(format!(
251                    "Entity leakage: {} entities appear in both validation and test",
252                    val_test_overlap.len()
253                ));
254            }
255        }
256
257        (leakage, details)
258    }
259
260    /// Calculate metrics for a split set.
261    fn calculate_metrics(&self, data: &SplitSetData) -> SplitMetrics {
262        let mut class_counts: HashMap<String, usize> = HashMap::new();
263        for label in &data.labels {
264            *class_counts.entry(label.clone()).or_insert(0) += 1;
265        }
266
267        let total = data.labels.len();
268        let class_distribution: HashMap<String, f64> = class_counts
269            .iter()
270            .map(|(k, v)| {
271                (
272                    k.clone(),
273                    if total > 0 {
274                        *v as f64 / total as f64
275                    } else {
276                        0.0
277                    },
278                )
279            })
280            .collect();
281
282        let unique_entities = data
283            .entity_ids
284            .iter()
285            .collect::<std::collections::HashSet<_>>()
286            .len();
287
288        let date_range = if !data.dates.is_empty() {
289            let min = data.dates.iter().min().cloned();
290            let max = data.dates.iter().max().cloned();
291            match (min, max) {
292                (Some(min), Some(max)) => Some((min, max)),
293                _ => None,
294            }
295        } else {
296            None
297        };
298
299        SplitMetrics {
300            sample_count: data.sample_count,
301            class_distribution,
302            unique_entities,
303            date_range,
304        }
305    }
306
307    /// Check distribution preservation between train and test.
308    fn check_distribution(&self, train: &SplitMetrics, test: &SplitMetrics) -> (bool, f64) {
309        if train.class_distribution.is_empty() || test.class_distribution.is_empty() {
310            return (true, 0.0);
311        }
312
313        // Calculate KL divergence: KL(P||Q) = sum(P(x) * log(P(x)/Q(x)))
314        let mut kl_divergence = 0.0;
315        let epsilon = 1e-10;
316
317        for (class, train_prob) in &train.class_distribution {
318            let test_prob = test.class_distribution.get(class).unwrap_or(&epsilon);
319            let p = *train_prob + epsilon;
320            let q = *test_prob + epsilon;
321            kl_divergence += p * (p / q).ln();
322        }
323
324        // Also account for classes in test not in train
325        for (class, test_prob) in &test.class_distribution {
326            if !train.class_distribution.contains_key(class) {
327                let p = epsilon;
328                let q = *test_prob + epsilon;
329                kl_divergence += p * (p / q).ln();
330            }
331        }
332
333        let preserved = kl_divergence <= self.max_kl_divergence;
334        (preserved, kl_divergence)
335    }
336}
337
338impl Default for SplitAnalyzer {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// A 70/15/15 split with disjoint entity IDs and near-identical class
    /// proportions (about 71/29 vs 73/27) should pass every check.
    #[test]
    fn test_valid_split() {
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec!["A".to_string(); 50]
                    .into_iter()
                    .chain(vec!["B".to_string(); 20])
                    .collect(),
                // E0..E69: disjoint from the validation and test ranges below.
                entity_ids: (0..70).map(|i| format!("E{}", i)).collect(),
                dates: vec![],
            },
            validation: Some(SplitSetData {
                sample_count: 15,
                labels: vec!["A".to_string(); 11]
                    .into_iter()
                    .chain(vec!["B".to_string(); 4])
                    .collect(),
                entity_ids: (70..85).map(|i| format!("E{}", i)).collect(),
                dates: vec![],
            }),
            test: SplitSetData {
                sample_count: 15,
                labels: vec!["A".to_string(); 11]
                    .into_iter()
                    .chain(vec!["B".to_string(); 4])
                    .collect(),
                entity_ids: (85..100).map(|i| format!("E{}", i)).collect(),
                dates: vec![],
            },
            expected_ratios: SplitRatios::default(),
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.ratio_valid);
        assert!(!result.leakage_detected);
        assert!(result.is_valid);
    }

    /// An entity ID shared by train and test must be reported as leakage and
    /// make the split invalid. Labels are empty, so the distribution check is
    /// skipped and cannot mask the result.
    #[test]
    fn test_entity_leakage() {
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E2".to_string(), "E3".to_string()],
                dates: vec![],
            },
            validation: None,
            test: SplitSetData {
                sample_count: 30,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E4".to_string()], // E1 is in both
                dates: vec![],
            },
            // No validation split, so its expected ratio is 0.0.
            expected_ratios: SplitRatios {
                train: 0.7,
                validation: 0.0,
                test: 0.3,
            },
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.leakage_detected);
        assert!(!result.is_valid);
    }
}
418}