//! Train/test split validation.
//!
//! Validates split ratios, data leakage, and distribution preservation.

use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Results of split analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitAnalysis {
    /// Metrics computed over the train set.
    pub train_metrics: SplitMetrics,
    /// Metrics for the validation set; `None` when no validation split was supplied.
    pub validation_metrics: Option<SplitMetrics>,
    /// Metrics computed over the test set.
    pub test_metrics: SplitMetrics,
    /// Whether the actual ratios are within tolerance of the expected ones.
    pub ratio_valid: bool,
    /// Actual split ratios observed in the data.
    pub actual_ratios: SplitRatios,
    /// Expected split ratios supplied by the caller.
    pub expected_ratios: SplitRatios,
    /// Whether any entity or temporal leakage was detected between splits.
    pub leakage_detected: bool,
    /// Human-readable description of each detected leak (empty when none).
    pub leakage_details: Vec<String>,
    /// Whether the train/test class distributions are close (low KL divergence).
    pub distribution_preserved: bool,
    /// Distribution shift score: KL divergence between train and test class distributions.
    pub distribution_shift: f64,
    /// Overall validity: ratios valid, no leakage, and distribution preserved.
    pub is_valid: bool,
    /// All issues found during analysis (ratio, leakage, and distribution problems).
    pub issues: Vec<String>,
}

/// Metrics for a single split.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitMetrics {
    /// Number of samples in the split.
    pub sample_count: usize,
    /// Relative class frequencies keyed by label (values sum to ~1.0 when labels exist).
    pub class_distribution: HashMap<String, f64>,
    /// Count of distinct entity IDs (used for leakage detection).
    pub unique_entities: usize,
    /// (min, max) of the split's date strings, if any dates are present.
    /// NOTE(review): min/max are lexicographic, which assumes a sortable
    /// date format such as ISO 8601 — confirm with data producers.
    pub date_range: Option<(String, String)>,
}

/// Split ratios. Expected to be fractions of the whole dataset
/// (train + validation + test ≈ 1.0).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Fraction of samples in the train split.
    pub train: f64,
    /// Fraction of samples in the validation split (0.0 when absent).
    pub validation: f64,
    /// Fraction of samples in the test split.
    pub test: f64,
}

62impl Default for SplitRatios {
63    fn default() -> Self {
64        Self {
65            train: 0.7,
66            validation: 0.15,
67            test: 0.15,
68        }
69    }
70}
71
/// Input for split analysis.
#[derive(Debug, Clone)]
pub struct SplitData {
    /// Train set data.
    pub train: SplitSetData,
    /// Validation set data; `None` when the dataset has no validation split.
    pub validation: Option<SplitSetData>,
    /// Test set data.
    pub test: SplitSetData,
    /// Expected split ratios the actual ratios are validated against.
    pub expected_ratios: SplitRatios,
}

/// Data for a single split set.
#[derive(Debug, Clone, Default)]
pub struct SplitSetData {
    /// Number of samples in the split.
    /// NOTE(review): tracked independently of the vectors below; nothing
    /// enforces that `labels`/`entity_ids`/`dates` have `sample_count`
    /// elements — confirm the producer keeps these consistent.
    pub sample_count: usize,
    /// Class labels (used to compute the class distribution).
    pub labels: Vec<String>,
    /// Entity IDs (used to detect entity overlap between splits).
    pub entity_ids: Vec<String>,
    /// Date strings (used for temporal leakage detection); compared
    /// lexicographically, so a sortable format (e.g. ISO 8601) is assumed.
    pub dates: Vec<String>,
}

/// Analyzer for train/test splits.
pub struct SplitAnalyzer {
    /// Maximum allowed absolute deviation between each actual and expected ratio.
    ratio_tolerance: f64,
    /// Maximum KL divergence (train vs. test class distribution) for the
    /// distribution to count as preserved.
    max_kl_divergence: f64,
}

106impl SplitAnalyzer {
107    /// Create a new analyzer.
108    pub fn new() -> Self {
109        Self {
110            ratio_tolerance: 0.05,
111            max_kl_divergence: 0.1,
112        }
113    }
114
115    /// Analyze split quality.
116    pub fn analyze(&self, data: &SplitData) -> EvalResult<SplitAnalysis> {
117        let mut issues = Vec::new();
118
119        // Calculate actual ratios
120        let total = data.train.sample_count
121            + data
122                .validation
123                .as_ref()
124                .map(|v| v.sample_count)
125                .unwrap_or(0)
126            + data.test.sample_count;
127
128        let actual_ratios = if total > 0 {
129            SplitRatios {
130                train: data.train.sample_count as f64 / total as f64,
131                validation: data
132                    .validation
133                    .as_ref()
134                    .map(|v| v.sample_count as f64 / total as f64)
135                    .unwrap_or(0.0),
136                test: data.test.sample_count as f64 / total as f64,
137            }
138        } else {
139            SplitRatios::default()
140        };
141
142        // Validate ratios
143        let ratio_valid = self.validate_ratios(&actual_ratios, &data.expected_ratios);
144        if !ratio_valid {
145            issues.push(format!(
146                "Split ratios deviate from expected: actual {:.2}/{:.2}/{:.2}, expected {:.2}/{:.2}/{:.2}",
147                actual_ratios.train,
148                actual_ratios.validation,
149                actual_ratios.test,
150                data.expected_ratios.train,
151                data.expected_ratios.validation,
152                data.expected_ratios.test
153            ));
154        }
155
156        // Check for data leakage (entity overlap)
157        let (leakage_detected, leakage_details) = self.check_leakage(data);
158        if leakage_detected {
159            issues.extend(leakage_details.clone());
160        }
161
162        // Calculate distribution metrics
163        let train_metrics = self.calculate_metrics(&data.train);
164        let validation_metrics = data.validation.as_ref().map(|v| self.calculate_metrics(v));
165        let test_metrics = self.calculate_metrics(&data.test);
166
167        // Check distribution preservation
168        let (distribution_preserved, distribution_shift) =
169            self.check_distribution(&train_metrics, &test_metrics);
170        if !distribution_preserved {
171            issues.push(format!(
172                "Class distribution shift detected: KL divergence = {distribution_shift:.4}"
173            ));
174        }
175
176        let is_valid = ratio_valid && !leakage_detected && distribution_preserved;
177
178        Ok(SplitAnalysis {
179            train_metrics,
180            validation_metrics,
181            test_metrics,
182            ratio_valid,
183            actual_ratios,
184            expected_ratios: data.expected_ratios.clone(),
185            leakage_detected,
186            leakage_details,
187            distribution_preserved,
188            distribution_shift,
189            is_valid,
190            issues,
191        })
192    }
193
194    /// Validate split ratios against expected.
195    fn validate_ratios(&self, actual: &SplitRatios, expected: &SplitRatios) -> bool {
196        (actual.train - expected.train).abs() <= self.ratio_tolerance
197            && (actual.validation - expected.validation).abs() <= self.ratio_tolerance
198            && (actual.test - expected.test).abs() <= self.ratio_tolerance
199    }
200
201    /// Check for data leakage between splits.
202    fn check_leakage(&self, data: &SplitData) -> (bool, Vec<String>) {
203        let mut leakage = false;
204        let mut details = Vec::new();
205
206        let train_entities: std::collections::HashSet<_> = data.train.entity_ids.iter().collect();
207        let test_entities: std::collections::HashSet<_> = data.test.entity_ids.iter().collect();
208
209        let overlap: Vec<_> = train_entities.intersection(&test_entities).collect();
210        if !overlap.is_empty() {
211            leakage = true;
212            details.push(format!(
213                "Entity leakage: {} entities appear in both train and test",
214                overlap.len()
215            ));
216        }
217
218        // Check temporal leakage (test dates before train dates)
219        if !data.train.dates.is_empty() && !data.test.dates.is_empty() {
220            let train_max = data.train.dates.iter().max();
221            let test_min = data.test.dates.iter().min();
222
223            if let (Some(train_max), Some(test_min)) = (train_max, test_min) {
224                if test_min < train_max {
225                    leakage = true;
226                    details.push(format!(
227                        "Temporal leakage: test min date {test_min} < train max date {train_max}"
228                    ));
229                }
230            }
231        }
232
233        if let Some(ref val) = data.validation {
234            let val_entities: std::collections::HashSet<_> = val.entity_ids.iter().collect();
235
236            let train_val_overlap: Vec<_> = train_entities.intersection(&val_entities).collect();
237            if !train_val_overlap.is_empty() {
238                leakage = true;
239                details.push(format!(
240                    "Entity leakage: {} entities appear in both train and validation",
241                    train_val_overlap.len()
242                ));
243            }
244
245            let val_test_overlap: Vec<_> = val_entities.intersection(&test_entities).collect();
246            if !val_test_overlap.is_empty() {
247                leakage = true;
248                details.push(format!(
249                    "Entity leakage: {} entities appear in both validation and test",
250                    val_test_overlap.len()
251                ));
252            }
253        }
254
255        (leakage, details)
256    }
257
258    /// Calculate metrics for a split set.
259    fn calculate_metrics(&self, data: &SplitSetData) -> SplitMetrics {
260        let mut class_counts: HashMap<String, usize> = HashMap::new();
261        for label in &data.labels {
262            *class_counts.entry(label.clone()).or_insert(0) += 1;
263        }
264
265        let total = data.labels.len();
266        let class_distribution: HashMap<String, f64> = class_counts
267            .iter()
268            .map(|(k, v)| {
269                (
270                    k.clone(),
271                    if total > 0 {
272                        *v as f64 / total as f64
273                    } else {
274                        0.0
275                    },
276                )
277            })
278            .collect();
279
280        let unique_entities = data
281            .entity_ids
282            .iter()
283            .collect::<std::collections::HashSet<_>>()
284            .len();
285
286        let date_range = if !data.dates.is_empty() {
287            let min = data.dates.iter().min().cloned();
288            let max = data.dates.iter().max().cloned();
289            match (min, max) {
290                (Some(min), Some(max)) => Some((min, max)),
291                _ => None,
292            }
293        } else {
294            None
295        };
296
297        SplitMetrics {
298            sample_count: data.sample_count,
299            class_distribution,
300            unique_entities,
301            date_range,
302        }
303    }
304
305    /// Check distribution preservation between train and test.
306    fn check_distribution(&self, train: &SplitMetrics, test: &SplitMetrics) -> (bool, f64) {
307        if train.class_distribution.is_empty() || test.class_distribution.is_empty() {
308            return (true, 0.0);
309        }
310
311        // Calculate KL divergence: KL(P||Q) = sum(P(x) * log(P(x)/Q(x)))
312        let mut kl_divergence = 0.0;
313        let epsilon = 1e-10;
314
315        for (class, train_prob) in &train.class_distribution {
316            let test_prob = test.class_distribution.get(class).unwrap_or(&epsilon);
317            let p = *train_prob + epsilon;
318            let q = *test_prob + epsilon;
319            kl_divergence += p * (p / q).ln();
320        }
321
322        // Also account for classes in test not in train
323        for (class, test_prob) in &test.class_distribution {
324            if !train.class_distribution.contains_key(class) {
325                let p = epsilon;
326                let q = *test_prob + epsilon;
327                kl_divergence += p * (p / q).ln();
328            }
329        }
330
331        let preserved = kl_divergence <= self.max_kl_divergence;
332        (preserved, kl_divergence)
333    }
334}
335
336impl Default for SplitAnalyzer {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build `count_a` "A" labels followed by `count_b` "B" labels.
    fn labels(count_a: usize, count_b: usize) -> Vec<String> {
        let mut out = vec!["A".to_string(); count_a];
        out.extend(vec!["B".to_string(); count_b]);
        out
    }

    /// Build entity IDs "E{start}" through "E{end - 1}".
    fn entities(range: std::ops::Range<usize>) -> Vec<String> {
        range.map(|i| format!("E{i}")).collect()
    }

    #[test]
    fn test_valid_split() {
        // 70/15/15 split with disjoint entities and matching class mix.
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: labels(50, 20),
                entity_ids: entities(0..70),
                dates: vec![],
            },
            validation: Some(SplitSetData {
                sample_count: 15,
                labels: labels(11, 4),
                entity_ids: entities(70..85),
                dates: vec![],
            }),
            test: SplitSetData {
                sample_count: 15,
                labels: labels(11, 4),
                entity_ids: entities(85..100),
                dates: vec![],
            },
            expected_ratios: SplitRatios::default(),
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.ratio_valid);
        assert!(!result.leakage_detected);
        assert!(result.is_valid);
    }

    #[test]
    fn test_entity_leakage() {
        // E1 appears in both train and test, so leakage must be flagged.
        let data = SplitData {
            train: SplitSetData {
                sample_count: 70,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E2".to_string(), "E3".to_string()],
                dates: vec![],
            },
            validation: None,
            test: SplitSetData {
                sample_count: 30,
                labels: vec![],
                entity_ids: vec!["E1".to_string(), "E4".to_string()], // E1 is in both
                dates: vec![],
            },
            expected_ratios: SplitRatios {
                train: 0.7,
                validation: 0.0,
                test: 0.3,
            },
        };

        let analyzer = SplitAnalyzer::new();
        let result = analyzer.analyze(&data).unwrap();

        assert!(result.leakage_detected);
        assert!(!result.is_valid);
    }
}