Skip to main content

alimentar/
split.rs

1// Allow casts for size calculations - these are intentional and safe for
2// dataset sizes
3#![allow(clippy::cast_possible_truncation)]
4#![allow(clippy::cast_sign_loss)]
5#![allow(clippy::cast_precision_loss)]
6#![allow(clippy::cast_possible_wrap)]
7
8//! Dataset splitting utilities
9//!
10//! Provides train/test/validation splitting with stratification support.
11//!
12//! # Example
13//!
14//! ```ignore
15//! use alimentar::split::DatasetSplit;
16//!
17//! // Simple ratio split
18//! let split = DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None)?;
19//!
20//! // With validation
21//! let split = DatasetSplit::from_ratios(&dataset, 0.7, 0.15, Some(0.15), Some(42))?;
22//!
23//! // Stratified by label column
24//! let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, None)?;
25//! ```
26
27use std::{collections::HashMap, sync::Arc};
28
29use arrow::array::{Array, RecordBatch};
30
31use crate::{
32    error::{Error, Result},
33    transform::{Skip, Take, Transform},
34    ArrowDataset, Dataset,
35};
36
37/// Dataset split with optional validation set
38#[derive(Debug, Clone)]
39pub struct DatasetSplit {
40    /// Training dataset (required)
41    pub train: ArrowDataset,
42    /// Test/holdout dataset (required)
43    pub test: ArrowDataset,
44    /// Validation dataset (optional)
45    pub validation: Option<ArrowDataset>,
46}
47
48impl DatasetSplit {
49    /// Create train/test split (no validation)
50    pub fn new(train: ArrowDataset, test: ArrowDataset) -> Self {
51        Self {
52            train,
53            test,
54            validation: None,
55        }
56    }
57
58    /// Create train/test/validation split
59    pub fn with_validation(
60        train: ArrowDataset,
61        test: ArrowDataset,
62        validation: ArrowDataset,
63    ) -> Self {
64        Self {
65            train,
66            test,
67            validation: Some(validation),
68        }
69    }
70
71    /// Get training data
72    pub fn train(&self) -> &ArrowDataset {
73        &self.train
74    }
75
76    /// Get test data
77    pub fn test(&self) -> &ArrowDataset {
78        &self.test
79    }
80
81    /// Get validation data (if present)
82    pub fn validation(&self) -> Option<&ArrowDataset> {
83        self.validation.as_ref()
84    }
85
86    /// Split dataset by ratios
87    ///
88    /// # Arguments
89    /// * `dataset` - Source dataset to split
90    /// * `train_ratio` - Fraction for training (0.0 to 1.0)
91    /// * `test_ratio` - Fraction for testing (0.0 to 1.0)
92    /// * `val_ratio` - Optional fraction for validation
93    /// * `seed` - Optional random seed for shuffling
94    ///
95    /// # Errors
96    /// Returns error if ratios don't sum to 1.0 or dataset is empty
97    pub fn from_ratios(
98        dataset: &ArrowDataset,
99        train_ratio: f64,
100        test_ratio: f64,
101        val_ratio: Option<f64>,
102        seed: Option<u64>,
103    ) -> Result<Self> {
104        // Validate ratios
105        let total = train_ratio + test_ratio + val_ratio.unwrap_or(0.0);
106        if (total - 1.0).abs() > 1e-9 {
107            return Err(Error::invalid_config(format!(
108                "Split ratios must sum to 1.0, got {total}"
109            )));
110        }
111
112        if train_ratio <= 0.0 || test_ratio <= 0.0 {
113            return Err(Error::invalid_config(
114                "Train and test ratios must be positive",
115            ));
116        }
117
118        if let Some(v) = val_ratio {
119            if v <= 0.0 {
120                return Err(Error::invalid_config(
121                    "Validation ratio must be positive if specified",
122                ));
123            }
124        }
125
126        let len = dataset.len();
127        if len == 0 {
128            return Err(Error::empty_dataset("Cannot split empty dataset"));
129        }
130
131        // Get all data as a single batch (concatenate if multiple batches)
132        let batch = concatenate_batches(dataset)?;
133
134        let batch = if let Some(s) = seed {
135            shuffle_batch(&batch, s)?
136        } else {
137            batch
138        };
139
140        // Calculate sizes
141        let train_size = ((len as f64) * train_ratio).round() as usize;
142        let test_size = ((len as f64) * test_ratio).round() as usize;
143        let val_size = val_ratio.map(|v| ((len as f64) * v).round() as usize);
144
145        // Adjust for rounding errors
146        let train_size = train_size.max(1);
147        let test_size = test_size.max(1);
148
149        // Split the batch
150        let train_batch = Take::new(train_size).apply(batch.clone())?;
151        let remaining = Skip::new(train_size).apply(batch)?;
152
153        let (test_batch, validation) = if val_size.is_some() {
154            let test_batch = Take::new(test_size).apply(remaining.clone())?;
155            let val_batch = Skip::new(test_size).apply(remaining)?;
156            (test_batch, Some(ArrowDataset::from_batch(val_batch)?))
157        } else {
158            (remaining, None)
159        };
160
161        Ok(Self {
162            train: ArrowDataset::from_batch(train_batch)?,
163            test: ArrowDataset::from_batch(test_batch)?,
164            validation,
165        })
166    }
167
168    /// Stratified split preserving label distribution
169    ///
170    /// # Arguments
171    /// * `dataset` - Source dataset to split
172    /// * `label_column` - Name of the label/target column
173    /// * `train_ratio` - Fraction for training
174    /// * `test_ratio` - Fraction for testing
175    /// * `val_ratio` - Optional fraction for validation
176    /// * `seed` - Optional random seed
177    ///
178    /// # Errors
179    /// Returns error if label column not found or ratios invalid
180    pub fn stratified(
181        dataset: &ArrowDataset,
182        label_column: &str,
183        train_ratio: f64,
184        test_ratio: f64,
185        val_ratio: Option<f64>,
186        seed: Option<u64>,
187    ) -> Result<Self> {
188        // Validate ratios
189        let total = train_ratio + test_ratio + val_ratio.unwrap_or(0.0);
190        if (total - 1.0).abs() > 1e-9 {
191            return Err(Error::invalid_config(format!(
192                "Split ratios must sum to 1.0, got {total}"
193            )));
194        }
195
196        let len = dataset.len();
197        if len == 0 {
198            return Err(Error::empty_dataset("Cannot split empty dataset"));
199        }
200
201        // Get all data as a single batch
202        let batch = concatenate_batches(dataset)?;
203
204        // Find label column
205        let schema = batch.schema();
206        let label_idx = schema.index_of(label_column).map_err(|_| {
207            Error::invalid_config(format!("Label column '{label_column}' not found"))
208        })?;
209
210        let label_array = batch.column(label_idx);
211
212        // Group indices by label value
213        let groups = group_by_label(label_array)?;
214
215        // Split each group proportionally
216        let mut train_indices = Vec::new();
217        let mut test_indices = Vec::new();
218        let mut val_indices = Vec::new();
219
220        let base_seed = seed.unwrap_or(0);
221
222        for (label_value, mut indices) in groups {
223            // Shuffle within group
224            if seed.is_some() {
225                // Simple deterministic shuffle using label as additional seed component
226                let group_seed = base_seed.wrapping_add(label_value as u64);
227                shuffle_indices(&mut indices, group_seed);
228            }
229
230            let group_len = indices.len();
231            let group_train = ((group_len as f64) * train_ratio).round() as usize;
232            let group_test = ((group_len as f64) * test_ratio).round() as usize;
233
234            let group_train = group_train.max(1).min(group_len);
235
236            train_indices.extend_from_slice(&indices[..group_train]);
237
238            if val_ratio.is_some() {
239                let remaining = group_len.saturating_sub(group_train);
240                let group_test = group_test.min(remaining);
241                test_indices.extend_from_slice(&indices[group_train..group_train + group_test]);
242                val_indices.extend_from_slice(&indices[group_train + group_test..]);
243            } else {
244                test_indices.extend_from_slice(&indices[group_train..]);
245            }
246        }
247
248        // Build split datasets from indices
249        let train_batch = take_indices(&batch, &train_indices)?;
250        let test_batch = take_indices(&batch, &test_indices)?;
251
252        let validation = if val_ratio.is_some() && !val_indices.is_empty() {
253            Some(ArrowDataset::from_batch(take_indices(
254                &batch,
255                &val_indices,
256            )?)?)
257        } else {
258            None
259        };
260
261        Ok(Self {
262            train: ArrowDataset::from_batch(train_batch)?,
263            test: ArrowDataset::from_batch(test_batch)?,
264            validation,
265        })
266    }
267}
268
269/// Concatenate all batches from a dataset into a single batch
270fn concatenate_batches(dataset: &ArrowDataset) -> Result<RecordBatch> {
271    use arrow::compute::concat_batches;
272
273    let schema = dataset.schema();
274    let batches: Vec<RecordBatch> = dataset.iter().collect();
275
276    if batches.is_empty() {
277        return Err(Error::empty_dataset("Dataset has no batches"));
278    }
279
280    if batches.len() == 1 {
281        return batches
282            .into_iter()
283            .next()
284            .ok_or_else(|| Error::empty_dataset("Dataset has no batches"));
285    }
286
287    concat_batches(&schema, &batches).map_err(Error::Arrow)
288}
289
290/// Shuffle a record batch deterministically
291fn shuffle_batch(batch: &RecordBatch, seed: u64) -> Result<RecordBatch> {
292    let len = batch.num_rows();
293    let mut indices: Vec<usize> = (0..len).collect();
294    shuffle_indices(&mut indices, seed);
295    take_indices(batch, &indices)
296}
297
/// Shuffle `indices` in place with a seeded Fisher–Yates shuffle.
///
/// The permutation depends only on `seed` and `indices.len()`, so the same
/// seed always reproduces the same order. Random draws come from SplitMix64.
/// The `% (i + 1)` reduction has a modulo bias of roughly `i / 2^64`, which
/// is negligible for any realistic row count.
fn shuffle_indices(indices: &mut [usize], seed: u64) {
    let mut state = seed;
    for i in (1..indices.len()).rev() {
        // Reduce in u64 so 32-bit targets do not discard the high half of the draw.
        let j = (splitmix64_next(&mut state) % (i as u64 + 1)) as usize;
        indices.swap(i, j);
    }
}

/// Advance a SplitMix64 state and return the next 64-bit output.
///
/// Used instead of a bare power-of-two LCG because the low bits of such an LCG
/// have very short periods (bit 0 simply alternates). Taking `% n` of the raw
/// LCG state would therefore produce strongly biased swap positions.
fn splitmix64_next(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}
309
310/// Group row indices by label value
311fn group_by_label(label_array: &Arc<dyn Array>) -> Result<HashMap<i64, Vec<usize>>> {
312    use arrow::{
313        array::{Int32Array, Int64Array, StringArray, UInt32Array, UInt64Array},
314        datatypes::DataType,
315    };
316
317    let mut groups: HashMap<i64, Vec<usize>> = HashMap::new();
318
319    match label_array.data_type() {
320        DataType::Int32 => {
321            let arr = downcast_label::<Int32Array>(label_array, "Int32Array")?;
322            collect_groups(arr.iter(), &mut groups, i64::from);
323        }
324        DataType::Int64 => {
325            let arr = downcast_label::<Int64Array>(label_array, "Int64Array")?;
326            collect_groups(arr.iter(), &mut groups, |v| v);
327        }
328        DataType::UInt32 => {
329            let arr = downcast_label::<UInt32Array>(label_array, "UInt32Array")?;
330            collect_groups(arr.iter(), &mut groups, i64::from);
331        }
332        DataType::UInt64 => {
333            let arr = downcast_label::<UInt64Array>(label_array, "UInt64Array")?;
334            // May truncate very large values
335            collect_groups(arr.iter(), &mut groups, |v| v as i64);
336        }
337        DataType::Utf8 | DataType::LargeUtf8 => {
338            // Hash string labels to i64 for grouping (alimentar#37)
339            let arr = downcast_label::<StringArray>(label_array, "StringArray")?;
340            collect_groups(arr.iter(), &mut groups, |s: &str| {
341                // FNV-1a hash — deterministic, fast, good distribution
342                let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
343                for byte in s.as_bytes() {
344                    hash ^= u64::from(*byte);
345                    hash = hash.wrapping_mul(0x0100_0000_01b3);
346                }
347                hash as i64
348            });
349        }
350        dt => {
351            return Err(Error::invalid_config(format!(
352                "Unsupported label type for stratification: {dt:?}"
353            )))
354        }
355    }
356
357    Ok(groups)
358}
359
360/// Downcast a label array to a concrete Arrow array type
361fn downcast_label<'a, T: 'static>(array: &'a Arc<dyn Array>, type_name: &str) -> Result<&'a T> {
362    array
363        .as_any()
364        .downcast_ref::<T>()
365        .ok_or_else(|| Error::invalid_config(format!("Failed to downcast {type_name}")))
366}
367
/// Append each non-null value's row position to the group keyed by `to_i64(value)`.
///
/// Row positions are the iterator's enumeration order, starting at 0. `None`
/// entries are skipped but still advance the row counter. Existing entries in
/// `groups` are extended rather than replaced.
fn collect_groups<V, F>(
    iter: impl Iterator<Item = Option<V>>,
    groups: &mut HashMap<i64, Vec<usize>>,
    to_i64: F,
) where
    F: Fn(V) -> i64,
{
    iter.enumerate()
        .filter_map(|(row, value)| value.map(|present| (row, present)))
        .for_each(|(row, present)| groups.entry(to_i64(present)).or_default().push(row));
}
382
383/// Take rows at given indices from a batch
384fn take_indices(batch: &RecordBatch, indices: &[usize]) -> Result<RecordBatch> {
385    use arrow::{array::UInt32Array, compute::take};
386
387    let indices_array = UInt32Array::from(indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
388
389    let columns: Vec<Arc<dyn Array>> = batch
390        .columns()
391        .iter()
392        .map(|col| take(col.as_ref(), &indices_array, None).map_err(Error::Arrow))
393        .collect::<Result<Vec<_>>>()?;
394
395    RecordBatch::try_new(batch.schema(), columns).map_err(Error::Arrow)
396}
397
398#[cfg(test)]
399mod tests {
400    use arrow::{
401        array::{Float64Array, Int32Array},
402        datatypes::{DataType, Field, Schema},
403    };
404
405    use super::*;
406
407    /// Helper to create a test dataset with n samples
408    fn make_test_dataset(n: usize) -> ArrowDataset {
409        let schema = Arc::new(Schema::new(vec![
410            Field::new("feature", DataType::Float64, false),
411            Field::new("label", DataType::Int32, false),
412        ]));
413
414        let features: Vec<f64> = (0..n).map(|i| i as f64).collect();
415        let labels: Vec<i32> = (0..n).map(|i| (i % 3) as i32).collect(); // 3 classes
416
417        let batch = RecordBatch::try_new(
418            schema,
419            vec![
420                Arc::new(Float64Array::from(features)),
421                Arc::new(Int32Array::from(labels)),
422            ],
423        )
424        .expect("batch creation failed");
425
426        ArrowDataset::from_batch(batch).expect("dataset creation failed")
427    }
428
429    // ========== DatasetSplit::new tests ==========
430
431    #[test]
432    fn test_new_creates_split_without_validation() {
433        let train = make_test_dataset(80);
434        let test = make_test_dataset(20);
435
436        let split = DatasetSplit::new(train, test);
437
438        assert_eq!(split.train().len(), 80);
439        assert_eq!(split.test().len(), 20);
440        assert!(split.validation().is_none());
441    }
442
443    // ========== DatasetSplit::with_validation tests ==========
444
445    #[test]
446    fn test_with_validation_creates_three_way_split() {
447        let train = make_test_dataset(70);
448        let test = make_test_dataset(15);
449        let val = make_test_dataset(15);
450
451        let split = DatasetSplit::with_validation(train, test, val);
452
453        assert_eq!(split.train().len(), 70);
454        assert_eq!(split.test().len(), 15);
455        assert!(split.validation().is_some());
456        assert_eq!(split.validation().expect("val").len(), 15);
457    }
458
459    // ========== DatasetSplit::from_ratios tests ==========
460
461    #[test]
462    fn test_from_ratios_80_20_split() {
463        let dataset = make_test_dataset(100);
464
465        let split =
466            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");
467
468        assert_eq!(split.train().len(), 80);
469        assert_eq!(split.test().len(), 20);
470        assert!(split.validation().is_none());
471    }
472
473    #[test]
474    fn test_from_ratios_70_15_15_split() {
475        let dataset = make_test_dataset(100);
476
477        let split =
478            DatasetSplit::from_ratios(&dataset, 0.7, 0.15, Some(0.15), None).expect("split failed");
479
480        assert_eq!(split.train().len(), 70);
481        assert_eq!(split.test().len(), 15);
482        assert!(split.validation().is_some());
483        assert_eq!(split.validation().expect("val").len(), 15);
484    }
485
486    #[test]
487    fn test_from_ratios_with_seed_is_deterministic() {
488        let dataset = make_test_dataset(100);
489
490        let split1 =
491            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
492        let split2 =
493            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
494
495        // Same seed should produce same split
496        let train1 = split1.train().get(0).expect("batch");
497        let train2 = split2.train().get(0).expect("batch");
498
499        assert_eq!(train1.num_rows(), train2.num_rows());
500        // Check first column values match
501        let col1 = train1
502            .column(0)
503            .as_any()
504            .downcast_ref::<Float64Array>()
505            .expect("downcast");
506        let col2 = train2
507            .column(0)
508            .as_any()
509            .downcast_ref::<Float64Array>()
510            .expect("downcast");
511
512        for i in 0..col1.len() {
513            assert!(
514                (col1.value(i) - col2.value(i)).abs() < 1e-9,
515                "Mismatch at index {i}"
516            );
517        }
518    }
519
520    #[test]
521    fn test_from_ratios_different_seeds_produce_different_splits() {
522        let dataset = make_test_dataset(100);
523
524        let split1 =
525            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
526        let split2 =
527            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(123)).expect("split failed");
528
529        let train1 = split1.train().get(0).expect("batch");
530        let train2 = split2.train().get(0).expect("batch");
531
532        let col1 = train1
533            .column(0)
534            .as_any()
535            .downcast_ref::<Float64Array>()
536            .expect("downcast");
537        let col2 = train2
538            .column(0)
539            .as_any()
540            .downcast_ref::<Float64Array>()
541            .expect("downcast");
542
543        // At least some values should differ
544        let mut differs = false;
545        for i in 0..col1.len().min(col2.len()) {
546            if (col1.value(i) - col2.value(i)).abs() > 1e-9 {
547                differs = true;
548                break;
549            }
550        }
551        assert!(differs, "Different seeds should produce different shuffles");
552    }
553
554    #[test]
555    fn test_from_ratios_rejects_invalid_ratios() {
556        let dataset = make_test_dataset(100);
557
558        // Ratios don't sum to 1.0
559        let result = DatasetSplit::from_ratios(&dataset, 0.5, 0.3, None, None);
560        assert!(result.is_err());
561
562        // Zero train ratio
563        let result = DatasetSplit::from_ratios(&dataset, 0.0, 1.0, None, None);
564        assert!(result.is_err());
565
566        // Zero test ratio
567        let result = DatasetSplit::from_ratios(&dataset, 1.0, 0.0, None, None);
568        assert!(result.is_err());
569
570        // Zero validation ratio
571        let result = DatasetSplit::from_ratios(&dataset, 0.8, 0.19, Some(0.0), None);
572        assert!(result.is_err());
573    }
574
575    #[test]
576    fn test_from_ratios_rejects_empty_dataset() {
577        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, false)]));
578        let batch = RecordBatch::try_new(
579            schema,
580            vec![Arc::new(Float64Array::from(Vec::<f64>::new()))],
581        )
582        .expect("batch");
583        let dataset = ArrowDataset::from_batch(batch).expect("dataset");
584
585        let result = DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None);
586        assert!(result.is_err());
587    }
588
589    #[test]
590    fn test_from_ratios_handles_small_dataset() {
591        let dataset = make_test_dataset(3);
592
593        let split =
594            DatasetSplit::from_ratios(&dataset, 0.7, 0.3, None, None).expect("split failed");
595
596        // Should have at least 1 in each
597        assert!(split.train().len() >= 1);
598        assert!(split.test().len() >= 1);
599        assert_eq!(split.train().len() + split.test().len(), 3);
600    }
601
602    // ========== DatasetSplit::stratified tests ==========
603
604    #[test]
605    fn test_stratified_preserves_class_distribution() {
606        // Create dataset with known class distribution: 60% class 0, 30% class 1, 10%
607        // class 2
608        let schema = Arc::new(Schema::new(vec![
609            Field::new("feature", DataType::Float64, false),
610            Field::new("label", DataType::Int32, false),
611        ]));
612
613        let n = 100;
614        let features: Vec<f64> = (0..n).map(|i| i as f64).collect();
615        let labels: Vec<i32> = (0..n)
616            .map(|i| {
617                if i < 60 {
618                    0
619                } else if i < 90 {
620                    1
621                } else {
622                    2
623                }
624            })
625            .collect();
626
627        let batch = RecordBatch::try_new(
628            schema,
629            vec![
630                Arc::new(Float64Array::from(features)),
631                Arc::new(Int32Array::from(labels)),
632            ],
633        )
634        .expect("batch");
635        let dataset = ArrowDataset::from_batch(batch).expect("dataset");
636
637        let split =
638            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");
639
640        // Count classes in train (iterate all batches)
641        let mut train_counts = [0usize; 3];
642        for batch in split.train().iter() {
643            let labels = batch
644                .column(1)
645                .as_any()
646                .downcast_ref::<Int32Array>()
647                .expect("downcast");
648            for val in labels.iter().flatten() {
649                train_counts[val as usize] += 1;
650            }
651        }
652
653        // Count classes in test
654        let mut test_counts = [0usize; 3];
655        for batch in split.test().iter() {
656            let labels = batch
657                .column(1)
658                .as_any()
659                .downcast_ref::<Int32Array>()
660                .expect("downcast");
661            for val in labels.iter().flatten() {
662                test_counts[val as usize] += 1;
663            }
664        }
665
666        // Verify proportions are approximately preserved (within tolerance)
667        // Original: 60/30/10, expect ~same ratio in both splits
668        let train_total = train_counts.iter().sum::<usize>() as f64;
669        let test_total = test_counts.iter().sum::<usize>() as f64;
670
671        let train_ratio_0 = train_counts[0] as f64 / train_total;
672        let test_ratio_0 = test_counts[0] as f64 / test_total;
673
674        // Should be close to 60% in both
675        assert!(
676            (train_ratio_0 - 0.6).abs() < 0.15,
677            "Train class 0 ratio {train_ratio_0} too far from 0.6"
678        );
679        assert!(
680            (test_ratio_0 - 0.6).abs() < 0.15,
681            "Test class 0 ratio {test_ratio_0} too far from 0.6"
682        );
683    }
684
685    #[test]
686    fn test_stratified_with_validation() {
687        let dataset = make_test_dataset(90); // Divisible by 3 classes
688
689        let split = DatasetSplit::stratified(&dataset, "label", 0.7, 0.15, Some(0.15), Some(42))
690            .expect("split");
691
692        assert!(split.validation().is_some());
693        let total = split.train().len() + split.test().len() + split.validation().expect("v").len();
694        assert_eq!(total, 90);
695    }
696
697    #[test]
698    fn test_stratified_rejects_missing_column() {
699        let dataset = make_test_dataset(100);
700
701        let result = DatasetSplit::stratified(&dataset, "nonexistent", 0.8, 0.2, None, None);
702        assert!(result.is_err());
703    }
704
705    #[test]
706    fn test_stratified_rejects_invalid_ratios() {
707        let dataset = make_test_dataset(100);
708
709        let result = DatasetSplit::stratified(&dataset, "label", 0.5, 0.3, None, None);
710        assert!(result.is_err());
711    }
712
713    #[test]
714    fn test_stratified_is_deterministic_with_seed() {
715        let dataset = make_test_dataset(100);
716
717        let split1 =
718            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");
719        let split2 =
720            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");
721
722        assert_eq!(split1.train().len(), split2.train().len());
723        assert_eq!(split1.test().len(), split2.test().len());
724    }
725
726    // ========== Edge case tests ==========
727
728    #[test]
729    fn test_split_preserves_schema() {
730        let dataset = make_test_dataset(100);
731        let original_schema = dataset.schema();
732
733        let split =
734            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");
735
736        assert_eq!(split.train().schema(), original_schema);
737        assert_eq!(split.test().schema(), original_schema);
738    }
739
740    #[test]
741    fn test_split_no_data_overlap() {
742        let dataset = make_test_dataset(100);
743
744        let split =
745            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
746
747        // Collect all train values
748        let mut train_set: std::collections::HashSet<u64> = std::collections::HashSet::new();
749        for batch in split.train().iter() {
750            let features = batch
751                .column(0)
752                .as_any()
753                .downcast_ref::<Float64Array>()
754                .expect("downcast");
755            for val in features.iter().flatten() {
756                train_set.insert(val.to_bits());
757            }
758        }
759
760        // Check no test values are in train
761        for batch in split.test().iter() {
762            let features = batch
763                .column(0)
764                .as_any()
765                .downcast_ref::<Float64Array>()
766                .expect("downcast");
767            for val in features.iter().flatten() {
768                assert!(
769                    !train_set.contains(&val.to_bits()),
770                    "Found overlapping value {val} in train and test"
771                );
772            }
773        }
774    }
775
776    #[test]
777    fn test_stratified_with_int64_labels() {
778        use arrow::array::Int64Array;
779
780        let schema = Arc::new(Schema::new(vec![
781            Field::new("feature", DataType::Float64, false),
782            Field::new("label", DataType::Int64, false),
783        ]));
784
785        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
786        let labels: Vec<i64> = (0..100).map(|i| (i % 3) as i64).collect();
787
788        let batch = RecordBatch::try_new(
789            schema,
790            vec![
791                Arc::new(Float64Array::from(features)),
792                Arc::new(Int64Array::from(labels)),
793            ],
794        )
795        .expect("batch creation failed");
796
797        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");
798
799        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
800            .expect("split failed");
801
802        assert!(split.train().len() > 0);
803        assert!(split.test().len() > 0);
804    }
805
806    #[test]
807    fn test_stratified_with_uint32_labels() {
808        use arrow::array::UInt32Array;
809
810        let schema = Arc::new(Schema::new(vec![
811            Field::new("feature", DataType::Float64, false),
812            Field::new("label", DataType::UInt32, false),
813        ]));
814
815        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
816        let labels: Vec<u32> = (0..100).map(|i| (i % 3) as u32).collect();
817
818        let batch = RecordBatch::try_new(
819            schema,
820            vec![
821                Arc::new(Float64Array::from(features)),
822                Arc::new(UInt32Array::from(labels)),
823            ],
824        )
825        .expect("batch creation failed");
826
827        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");
828
829        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
830            .expect("split failed");
831
832        assert!(split.train().len() > 0);
833        assert!(split.test().len() > 0);
834    }
835
836    #[test]
837    fn test_stratified_with_uint64_labels() {
838        use arrow::array::UInt64Array;
839
840        let schema = Arc::new(Schema::new(vec![
841            Field::new("feature", DataType::Float64, false),
842            Field::new("label", DataType::UInt64, false),
843        ]));
844
845        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
846        let labels: Vec<u64> = (0..100).map(|i| (i % 3) as u64).collect();
847
848        let batch = RecordBatch::try_new(
849            schema,
850            vec![
851                Arc::new(Float64Array::from(features)),
852                Arc::new(UInt64Array::from(labels)),
853            ],
854        )
855        .expect("batch creation failed");
856
857        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");
858
859        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
860            .expect("split failed");
861
862        assert!(split.train().len() > 0);
863        assert!(split.test().len() > 0);
864    }
865
866    #[test]
867    fn test_stratified_with_string_labels() {
868        use arrow::array::StringArray;
869
870        let schema = Arc::new(Schema::new(vec![
871            Field::new("feature", DataType::Float64, false),
872            Field::new("label", DataType::Utf8, false),
873        ]));
874
875        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
876        let labels: Vec<&str> = (0..100)
877            .map(|i| if i % 2 == 0 { "a" } else { "b" })
878            .collect();
879
880        let batch = RecordBatch::try_new(
881            schema,
882            vec![
883                Arc::new(Float64Array::from(features)),
884                Arc::new(StringArray::from(labels)),
885            ],
886        )
887        .expect("batch creation failed");
888
889        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");
890
891        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
892            .expect("stratified split with string labels should succeed");
893        assert_eq!(split.train().len() + split.test().len(), 100);
894        assert!(split.train().len() > 0);
895        assert!(split.test().len() > 0);
896    }
897
898    #[test]
899    fn test_stratified_without_seed() {
900        let dataset = make_test_dataset(100);
901
902        // Without seed, should still work
903        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, None)
904            .expect("split failed");
905
906        assert!(split.train().len() > 0);
907        assert!(split.test().len() > 0);
908    }
909
910    #[test]
911    fn test_split_debug() {
912        let dataset = make_test_dataset(100);
913        let split =
914            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");
915
916        let debug = format!("{:?}", split);
917        assert!(debug.contains("DatasetSplit"));
918    }
919
920    #[test]
921    fn test_split_clone() {
922        let dataset = make_test_dataset(100);
923        let split =
924            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");
925
926        let cloned = split.clone();
927        assert_eq!(cloned.train().len(), split.train().len());
928        assert_eq!(cloned.test().len(), split.test().len());
929    }
930
931    #[test]
932    fn test_extreme_ratio_99_1() {
933        let dataset = make_test_dataset(100);
934        let split =
935            DatasetSplit::from_ratios(&dataset, 0.99, 0.01, None, None).expect("split failed");
936
937        assert_eq!(split.train().len(), 99);
938        assert_eq!(split.test().len(), 1);
939    }
940
941    #[test]
942    fn test_extreme_ratio_50_50() {
943        let dataset = make_test_dataset(100);
944        let split =
945            DatasetSplit::from_ratios(&dataset, 0.5, 0.5, None, None).expect("split failed");
946
947        assert_eq!(split.train().len(), 50);
948        assert_eq!(split.test().len(), 50);
949    }
950
951    #[test]
952    fn test_negative_train_ratio_rejected() {
953        let dataset = make_test_dataset(100);
954        let result = DatasetSplit::from_ratios(&dataset, -0.5, 0.5, None, None);
955        assert!(result.is_err());
956    }
957
958    #[test]
959    fn test_zero_test_ratio_rejected() {
960        let dataset = make_test_dataset(100);
961        let result = DatasetSplit::from_ratios(&dataset, 1.0, 0.0, None, None);
962        assert!(result.is_err());
963    }
964
965    #[test]
966    fn test_negative_val_ratio_rejected() {
967        let dataset = make_test_dataset(100);
968        let result = DatasetSplit::from_ratios(&dataset, 0.6, 0.5, Some(-0.1), None);
969        assert!(result.is_err());
970    }
971
972    #[test]
973    fn test_single_row_minimum_sizes() {
974        let dataset = make_test_dataset(2);
975        let split =
976            DatasetSplit::from_ratios(&dataset, 0.5, 0.5, None, None).expect("split failed");
977
978        // Each should get at least 1 row
979        assert!(split.train().len() >= 1);
980        assert!(split.test().len() >= 1);
981    }
982
983    #[test]
984    fn test_ratios_slightly_over_one() {
985        let dataset = make_test_dataset(100);
986        // Sum is 1.01, should be rejected
987        let result = DatasetSplit::from_ratios(&dataset, 0.81, 0.2, None, None);
988        assert!(result.is_err());
989    }
990
991    #[test]
992    fn test_ratios_slightly_under_one() {
993        let dataset = make_test_dataset(100);
994        // Sum is 0.99, should be rejected
995        let result = DatasetSplit::from_ratios(&dataset, 0.79, 0.2, None, None);
996        assert!(result.is_err());
997    }
998
999    #[test]
1000    fn test_getters_return_correct_data() {
1001        let train = make_test_dataset(80);
1002        let test = make_test_dataset(20);
1003        let val = make_test_dataset(10);
1004
1005        let split = DatasetSplit::with_validation(train.clone(), test.clone(), val.clone());
1006
1007        assert_eq!(split.train().len(), 80);
1008        assert_eq!(split.test().len(), 20);
1009        assert_eq!(split.validation().map(|v| v.len()), Some(10));
1010    }
1011
1012    #[test]
1013    fn test_validation_none_for_two_way_split() {
1014        let train = make_test_dataset(80);
1015        let test = make_test_dataset(20);
1016
1017        let split = DatasetSplit::new(train, test);
1018
1019        assert!(split.validation().is_none());
1020    }
1021
1022    #[test]
1023    fn test_stratified_empty_dataset() {
1024        let schema = Arc::new(Schema::new(vec![
1025            Field::new("x", DataType::Float64, false),
1026            Field::new("label", DataType::Int32, false),
1027        ]));
1028        let x_array = arrow::array::Float64Array::from(Vec::<f64>::new());
1029        let label_array = Int32Array::from(Vec::<i32>::new());
1030        let batch = RecordBatch::try_new(schema, vec![Arc::new(x_array), Arc::new(label_array)])
1031            .expect("batch");
1032        let dataset = ArrowDataset::from_batch(batch).expect("dataset");
1033
1034        let result = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, None);
1035        assert!(result.is_err());
1036    }
1037
1038    #[test]
1039    fn test_stratified_zero_test_ratio_rejected() {
1040        let dataset = make_test_dataset(100);
1041        let result = DatasetSplit::stratified(&dataset, "y", 1.0, 0.0, None, None);
1042        assert!(result.is_err());
1043    }
1044
1045    #[test]
1046    fn test_split_preserves_all_rows() {
1047        let dataset = make_test_dataset(100);
1048        let split =
1049            DatasetSplit::from_ratios(&dataset, 0.6, 0.2, Some(0.2), None).expect("split failed");
1050
1051        let total = split.train().len()
1052            + split.test().len()
1053            + split.validation().map(|v| v.len()).unwrap_or(0);
1054        assert_eq!(total, 100);
1055    }
1056}