1#![allow(clippy::cast_possible_truncation)]
4#![allow(clippy::cast_sign_loss)]
5#![allow(clippy::cast_precision_loss)]
6#![allow(clippy::cast_possible_wrap)]
7
8use std::{collections::HashMap, sync::Arc};
28
29use arrow::array::{Array, RecordBatch};
30
31use crate::{
32 error::{Error, Result},
33 transform::{Skip, Take, Transform},
34 ArrowDataset, Dataset,
35};
36
/// The result of splitting a dataset into train/test(/validation) parts.
#[derive(Debug, Clone)]
pub struct DatasetSplit {
    /// Rows used for model training.
    pub train: ArrowDataset,
    /// Rows held out for final evaluation.
    pub test: ArrowDataset,
    /// Optional rows held out for tuning; `None` for a two-way split.
    pub validation: Option<ArrowDataset>,
}
47
48impl DatasetSplit {
49 pub fn new(train: ArrowDataset, test: ArrowDataset) -> Self {
51 Self {
52 train,
53 test,
54 validation: None,
55 }
56 }
57
58 pub fn with_validation(
60 train: ArrowDataset,
61 test: ArrowDataset,
62 validation: ArrowDataset,
63 ) -> Self {
64 Self {
65 train,
66 test,
67 validation: Some(validation),
68 }
69 }
70
    /// Returns the training partition.
    pub fn train(&self) -> &ArrowDataset {
        &self.train
    }
75
    /// Returns the test partition.
    pub fn test(&self) -> &ArrowDataset {
        &self.test
    }
80
    /// Returns the validation partition, if one was produced.
    pub fn validation(&self) -> Option<&ArrowDataset> {
        self.validation.as_ref()
    }
85
86 pub fn from_ratios(
98 dataset: &ArrowDataset,
99 train_ratio: f64,
100 test_ratio: f64,
101 val_ratio: Option<f64>,
102 seed: Option<u64>,
103 ) -> Result<Self> {
104 let total = train_ratio + test_ratio + val_ratio.unwrap_or(0.0);
106 if (total - 1.0).abs() > 1e-9 {
107 return Err(Error::invalid_config(format!(
108 "Split ratios must sum to 1.0, got {total}"
109 )));
110 }
111
112 if train_ratio <= 0.0 || test_ratio <= 0.0 {
113 return Err(Error::invalid_config(
114 "Train and test ratios must be positive",
115 ));
116 }
117
118 if let Some(v) = val_ratio {
119 if v <= 0.0 {
120 return Err(Error::invalid_config(
121 "Validation ratio must be positive if specified",
122 ));
123 }
124 }
125
126 let len = dataset.len();
127 if len == 0 {
128 return Err(Error::empty_dataset("Cannot split empty dataset"));
129 }
130
131 let batch = concatenate_batches(dataset)?;
133
134 let batch = if let Some(s) = seed {
135 shuffle_batch(&batch, s)?
136 } else {
137 batch
138 };
139
140 let train_size = ((len as f64) * train_ratio).round() as usize;
142 let test_size = ((len as f64) * test_ratio).round() as usize;
143 let val_size = val_ratio.map(|v| ((len as f64) * v).round() as usize);
144
145 let train_size = train_size.max(1);
147 let test_size = test_size.max(1);
148
149 let train_batch = Take::new(train_size).apply(batch.clone())?;
151 let remaining = Skip::new(train_size).apply(batch)?;
152
153 let (test_batch, validation) = if val_size.is_some() {
154 let test_batch = Take::new(test_size).apply(remaining.clone())?;
155 let val_batch = Skip::new(test_size).apply(remaining)?;
156 (test_batch, Some(ArrowDataset::from_batch(val_batch)?))
157 } else {
158 (remaining, None)
159 };
160
161 Ok(Self {
162 train: ArrowDataset::from_batch(train_batch)?,
163 test: ArrowDataset::from_batch(test_batch)?,
164 validation,
165 })
166 }
167
168 pub fn stratified(
181 dataset: &ArrowDataset,
182 label_column: &str,
183 train_ratio: f64,
184 test_ratio: f64,
185 val_ratio: Option<f64>,
186 seed: Option<u64>,
187 ) -> Result<Self> {
188 let total = train_ratio + test_ratio + val_ratio.unwrap_or(0.0);
190 if (total - 1.0).abs() > 1e-9 {
191 return Err(Error::invalid_config(format!(
192 "Split ratios must sum to 1.0, got {total}"
193 )));
194 }
195
196 let len = dataset.len();
197 if len == 0 {
198 return Err(Error::empty_dataset("Cannot split empty dataset"));
199 }
200
201 let batch = concatenate_batches(dataset)?;
203
204 let schema = batch.schema();
206 let label_idx = schema.index_of(label_column).map_err(|_| {
207 Error::invalid_config(format!("Label column '{label_column}' not found"))
208 })?;
209
210 let label_array = batch.column(label_idx);
211
212 let groups = group_by_label(label_array)?;
214
215 let mut train_indices = Vec::new();
217 let mut test_indices = Vec::new();
218 let mut val_indices = Vec::new();
219
220 let base_seed = seed.unwrap_or(0);
221
222 for (label_value, mut indices) in groups {
223 if seed.is_some() {
225 let group_seed = base_seed.wrapping_add(label_value as u64);
227 shuffle_indices(&mut indices, group_seed);
228 }
229
230 let group_len = indices.len();
231 let group_train = ((group_len as f64) * train_ratio).round() as usize;
232 let group_test = ((group_len as f64) * test_ratio).round() as usize;
233
234 let group_train = group_train.max(1).min(group_len);
235
236 train_indices.extend_from_slice(&indices[..group_train]);
237
238 if val_ratio.is_some() {
239 let remaining = group_len.saturating_sub(group_train);
240 let group_test = group_test.min(remaining);
241 test_indices.extend_from_slice(&indices[group_train..group_train + group_test]);
242 val_indices.extend_from_slice(&indices[group_train + group_test..]);
243 } else {
244 test_indices.extend_from_slice(&indices[group_train..]);
245 }
246 }
247
248 let train_batch = take_indices(&batch, &train_indices)?;
250 let test_batch = take_indices(&batch, &test_indices)?;
251
252 let validation = if val_ratio.is_some() && !val_indices.is_empty() {
253 Some(ArrowDataset::from_batch(take_indices(
254 &batch,
255 &val_indices,
256 )?)?)
257 } else {
258 None
259 };
260
261 Ok(Self {
262 train: ArrowDataset::from_batch(train_batch)?,
263 test: ArrowDataset::from_batch(test_batch)?,
264 validation,
265 })
266 }
267}
268
269fn concatenate_batches(dataset: &ArrowDataset) -> Result<RecordBatch> {
271 use arrow::compute::concat_batches;
272
273 let schema = dataset.schema();
274 let batches: Vec<RecordBatch> = dataset.iter().collect();
275
276 if batches.is_empty() {
277 return Err(Error::empty_dataset("Dataset has no batches"));
278 }
279
280 if batches.len() == 1 {
281 return batches
282 .into_iter()
283 .next()
284 .ok_or_else(|| Error::empty_dataset("Dataset has no batches"));
285 }
286
287 concat_batches(&schema, &batches).map_err(Error::Arrow)
288}
289
290fn shuffle_batch(batch: &RecordBatch, seed: u64) -> Result<RecordBatch> {
292 let len = batch.num_rows();
293 let mut indices: Vec<usize> = (0..len).collect();
294 shuffle_indices(&mut indices, seed);
295 take_indices(batch, &indices)
296}
297
/// Deterministic in-place Fisher-Yates shuffle driven by a 64-bit LCG.
///
/// For an LCG with modulus 2^64 the low-order bits have very short
/// periods (the lowest bit merely alternates), so deriving the swap
/// target as `state % (i + 1)` is badly biased whenever `i + 1` is a
/// power of two. The swap target is therefore taken from the high bits
/// of the state instead.
fn shuffle_indices(indices: &mut [usize], seed: u64) {
    let mut state = seed;
    for i in (1..indices.len()).rev() {
        // Knuth's MMIX LCG constants.
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let j = ((state >> 32) as usize) % (i + 1);
        indices.swap(i, j);
    }
}
309
310fn group_by_label(label_array: &Arc<dyn Array>) -> Result<HashMap<i64, Vec<usize>>> {
312 use arrow::{
313 array::{Int32Array, Int64Array, StringArray, UInt32Array, UInt64Array},
314 datatypes::DataType,
315 };
316
317 let mut groups: HashMap<i64, Vec<usize>> = HashMap::new();
318
319 match label_array.data_type() {
320 DataType::Int32 => {
321 let arr = downcast_label::<Int32Array>(label_array, "Int32Array")?;
322 collect_groups(arr.iter(), &mut groups, i64::from);
323 }
324 DataType::Int64 => {
325 let arr = downcast_label::<Int64Array>(label_array, "Int64Array")?;
326 collect_groups(arr.iter(), &mut groups, |v| v);
327 }
328 DataType::UInt32 => {
329 let arr = downcast_label::<UInt32Array>(label_array, "UInt32Array")?;
330 collect_groups(arr.iter(), &mut groups, i64::from);
331 }
332 DataType::UInt64 => {
333 let arr = downcast_label::<UInt64Array>(label_array, "UInt64Array")?;
334 collect_groups(arr.iter(), &mut groups, |v| v as i64);
336 }
337 DataType::Utf8 | DataType::LargeUtf8 => {
338 let arr = downcast_label::<StringArray>(label_array, "StringArray")?;
340 collect_groups(arr.iter(), &mut groups, |s: &str| {
341 let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
343 for byte in s.as_bytes() {
344 hash ^= u64::from(*byte);
345 hash = hash.wrapping_mul(0x0100_0000_01b3);
346 }
347 hash as i64
348 });
349 }
350 dt => {
351 return Err(Error::invalid_config(format!(
352 "Unsupported label type for stratification: {dt:?}"
353 )))
354 }
355 }
356
357 Ok(groups)
358}
359
360fn downcast_label<'a, T: 'static>(array: &'a Arc<dyn Array>, type_name: &str) -> Result<&'a T> {
362 array
363 .as_any()
364 .downcast_ref::<T>()
365 .ok_or_else(|| Error::invalid_config(format!("Failed to downcast {type_name}")))
366}
367
/// Records, for every non-null value yielded by `iter`, the value's
/// position into the group keyed by `to_i64(value)`.
fn collect_groups<V, F>(
    iter: impl Iterator<Item = Option<V>>,
    groups: &mut HashMap<i64, Vec<usize>>,
    to_i64: F,
) where
    F: Fn(V) -> i64,
{
    // Nulls still occupy a position (via enumerate) but join no group.
    let present = iter
        .enumerate()
        .filter_map(|(position, value)| value.map(|v| (position, v)));
    for (position, value) in present {
        groups.entry(to_i64(value)).or_default().push(position);
    }
}
382
383fn take_indices(batch: &RecordBatch, indices: &[usize]) -> Result<RecordBatch> {
385 use arrow::{array::UInt32Array, compute::take};
386
387 let indices_array = UInt32Array::from(indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
388
389 let columns: Vec<Arc<dyn Array>> = batch
390 .columns()
391 .iter()
392 .map(|col| take(col.as_ref(), &indices_array, None).map_err(Error::Arrow))
393 .collect::<Result<Vec<_>>>()?;
394
395 RecordBatch::try_new(batch.schema(), columns).map_err(Error::Arrow)
396}
397
398#[cfg(test)]
399mod tests {
400 use arrow::{
401 array::{Float64Array, Int32Array},
402 datatypes::{DataType, Field, Schema},
403 };
404
405 use super::*;
406
407 fn make_test_dataset(n: usize) -> ArrowDataset {
409 let schema = Arc::new(Schema::new(vec![
410 Field::new("feature", DataType::Float64, false),
411 Field::new("label", DataType::Int32, false),
412 ]));
413
414 let features: Vec<f64> = (0..n).map(|i| i as f64).collect();
415 let labels: Vec<i32> = (0..n).map(|i| (i % 3) as i32).collect(); let batch = RecordBatch::try_new(
418 schema,
419 vec![
420 Arc::new(Float64Array::from(features)),
421 Arc::new(Int32Array::from(labels)),
422 ],
423 )
424 .expect("batch creation failed");
425
426 ArrowDataset::from_batch(batch).expect("dataset creation failed")
427 }
428
    #[test]
    fn test_new_creates_split_without_validation() {
        // `new` must store the partitions as-is and leave validation unset.
        let train = make_test_dataset(80);
        let test = make_test_dataset(20);

        let split = DatasetSplit::new(train, test);

        assert_eq!(split.train().len(), 80);
        assert_eq!(split.test().len(), 20);
        assert!(split.validation().is_none());
    }
442
    #[test]
    fn test_with_validation_creates_three_way_split() {
        // `with_validation` must carry all three partitions unchanged.
        let train = make_test_dataset(70);
        let test = make_test_dataset(15);
        let val = make_test_dataset(15);

        let split = DatasetSplit::with_validation(train, test, val);

        assert_eq!(split.train().len(), 70);
        assert_eq!(split.test().len(), 15);
        assert!(split.validation().is_some());
        assert_eq!(split.validation().expect("val").len(), 15);
    }
458
    #[test]
    fn test_from_ratios_80_20_split() {
        // 100 rows at 0.8/0.2 should land exactly on 80/20.
        let dataset = make_test_dataset(100);

        let split =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");

        assert_eq!(split.train().len(), 80);
        assert_eq!(split.test().len(), 20);
        assert!(split.validation().is_none());
    }
472
    #[test]
    fn test_from_ratios_70_15_15_split() {
        // 100 rows at 0.7/0.15/0.15 should land exactly on 70/15/15.
        let dataset = make_test_dataset(100);

        let split =
            DatasetSplit::from_ratios(&dataset, 0.7, 0.15, Some(0.15), None).expect("split failed");

        assert_eq!(split.train().len(), 70);
        assert_eq!(split.test().len(), 15);
        assert!(split.validation().is_some());
        assert_eq!(split.validation().expect("val").len(), 15);
    }
485
    #[test]
    fn test_from_ratios_with_seed_is_deterministic() {
        // The same seed must reproduce the identical shuffle and split.
        let dataset = make_test_dataset(100);

        let split1 =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
        let split2 =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");

        let train1 = split1.train().get(0).expect("batch");
        let train2 = split2.train().get(0).expect("batch");

        assert_eq!(train1.num_rows(), train2.num_rows());
        let col1 = train1
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("downcast");
        let col2 = train2
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("downcast");

        // The feature values double as row identities, so compare row-by-row.
        for i in 0..col1.len() {
            assert!(
                (col1.value(i) - col2.value(i)).abs() < 1e-9,
                "Mismatch at index {i}"
            );
        }
    }
519
    #[test]
    fn test_from_ratios_different_seeds_produce_different_splits() {
        // Distinct seeds should yield visibly different row orderings.
        let dataset = make_test_dataset(100);

        let split1 =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");
        let split2 =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(123)).expect("split failed");

        let train1 = split1.train().get(0).expect("batch");
        let train2 = split2.train().get(0).expect("batch");

        let col1 = train1
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("downcast");
        let col2 = train2
            .column(0)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("downcast");

        // At least one position must differ between the two shuffles.
        let mut differs = false;
        for i in 0..col1.len().min(col2.len()) {
            if (col1.value(i) - col2.value(i)).abs() > 1e-9 {
                differs = true;
                break;
            }
        }
        assert!(differs, "Different seeds should produce different shuffles");
    }
553
    #[test]
    fn test_from_ratios_rejects_invalid_ratios() {
        let dataset = make_test_dataset(100);

        // Ratios summing to 0.8, not 1.0.
        let result = DatasetSplit::from_ratios(&dataset, 0.5, 0.3, None, None);
        assert!(result.is_err());

        // Zero train ratio.
        let result = DatasetSplit::from_ratios(&dataset, 0.0, 1.0, None, None);
        assert!(result.is_err());

        // Zero test ratio.
        let result = DatasetSplit::from_ratios(&dataset, 1.0, 0.0, None, None);
        assert!(result.is_err());

        // 0.8 + 0.19 + 0.0 = 0.99, so this also fails the sum check.
        let result = DatasetSplit::from_ratios(&dataset, 0.8, 0.19, Some(0.0), None);
        assert!(result.is_err());
    }
574
    #[test]
    fn test_from_ratios_rejects_empty_dataset() {
        // A zero-row dataset cannot be split and must error out.
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float64, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Float64Array::from(Vec::<f64>::new()))],
        )
        .expect("batch");
        let dataset = ArrowDataset::from_batch(batch).expect("dataset");

        let result = DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None);
        assert!(result.is_err());
    }
588
    #[test]
    fn test_from_ratios_handles_small_dataset() {
        let dataset = make_test_dataset(3);

        let split =
            DatasetSplit::from_ratios(&dataset, 0.7, 0.3, None, None).expect("split failed");

        // Both partitions get at least one row, and no row is lost.
        assert!(split.train().len() >= 1);
        assert!(split.test().len() >= 1);
        assert_eq!(split.train().len() + split.test().len(), 3);
    }
601
    #[test]
    fn test_stratified_preserves_class_distribution() {
        // Fixture with imbalanced classes: 60% label 0, 30% label 1, 10% label 2.
        let schema = Arc::new(Schema::new(vec![
            Field::new("feature", DataType::Float64, false),
            Field::new("label", DataType::Int32, false),
        ]));

        let n = 100;
        let features: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let labels: Vec<i32> = (0..n)
            .map(|i| {
                if i < 60 {
                    0
                } else if i < 90 {
                    1
                } else {
                    2
                }
            })
            .collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(features)),
                Arc::new(Int32Array::from(labels)),
            ],
        )
        .expect("batch");
        let dataset = ArrowDataset::from_batch(batch).expect("dataset");

        let split =
            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");

        // Count label occurrences in the train partition.
        let mut train_counts = [0usize; 3];
        for batch in split.train().iter() {
            let labels = batch
                .column(1)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("downcast");
            for val in labels.iter().flatten() {
                train_counts[val as usize] += 1;
            }
        }

        // Count label occurrences in the test partition.
        let mut test_counts = [0usize; 3];
        for batch in split.test().iter() {
            let labels = batch
                .column(1)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("downcast");
            for val in labels.iter().flatten() {
                test_counts[val as usize] += 1;
            }
        }

        let train_total = train_counts.iter().sum::<usize>() as f64;
        let test_total = test_counts.iter().sum::<usize>() as f64;

        let train_ratio_0 = train_counts[0] as f64 / train_total;
        let test_ratio_0 = test_counts[0] as f64 / test_total;

        // Class 0 makes up 60% overall; both partitions should stay close.
        assert!(
            (train_ratio_0 - 0.6).abs() < 0.15,
            "Train class 0 ratio {train_ratio_0} too far from 0.6"
        );
        assert!(
            (test_ratio_0 - 0.6).abs() < 0.15,
            "Test class 0 ratio {test_ratio_0} too far from 0.6"
        );
    }
684
685 #[test]
686 fn test_stratified_with_validation() {
687 let dataset = make_test_dataset(90); let split = DatasetSplit::stratified(&dataset, "label", 0.7, 0.15, Some(0.15), Some(42))
690 .expect("split");
691
692 assert!(split.validation().is_some());
693 let total = split.train().len() + split.test().len() + split.validation().expect("v").len();
694 assert_eq!(total, 90);
695 }
696
    #[test]
    fn test_stratified_rejects_missing_column() {
        // A label column that does not exist must produce a config error.
        let dataset = make_test_dataset(100);

        let result = DatasetSplit::stratified(&dataset, "nonexistent", 0.8, 0.2, None, None);
        assert!(result.is_err());
    }
704
    #[test]
    fn test_stratified_rejects_invalid_ratios() {
        // 0.5 + 0.3 does not sum to 1.0, so the split must be rejected.
        let dataset = make_test_dataset(100);

        let result = DatasetSplit::stratified(&dataset, "label", 0.5, 0.3, None, None);
        assert!(result.is_err());
    }
712
    #[test]
    fn test_stratified_is_deterministic_with_seed() {
        // The same seed must yield identical partition sizes on both runs.
        let dataset = make_test_dataset(100);

        let split1 =
            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");
        let split2 =
            DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42)).expect("split");

        assert_eq!(split1.train().len(), split2.train().len());
        assert_eq!(split1.test().len(), split2.test().len());
    }
725
    #[test]
    fn test_split_preserves_schema() {
        // Splitting must not alter the column schema of either partition.
        let dataset = make_test_dataset(100);
        let original_schema = dataset.schema();

        let split =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");

        assert_eq!(split.train().schema(), original_schema);
        assert_eq!(split.test().schema(), original_schema);
    }
739
    #[test]
    fn test_split_no_data_overlap() {
        // Each row must land in exactly one partition; the unique feature
        // values act as row identities here.
        let dataset = make_test_dataset(100);

        let split =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, Some(42)).expect("split failed");

        // Collect every train feature value (bit patterns for exact compare).
        let mut train_set: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for batch in split.train().iter() {
            let features = batch
                .column(0)
                .as_any()
                .downcast_ref::<Float64Array>()
                .expect("downcast");
            for val in features.iter().flatten() {
                train_set.insert(val.to_bits());
            }
        }

        // No test feature value may appear in the train set.
        for batch in split.test().iter() {
            let features = batch
                .column(0)
                .as_any()
                .downcast_ref::<Float64Array>()
                .expect("downcast");
            for val in features.iter().flatten() {
                assert!(
                    !train_set.contains(&val.to_bits()),
                    "Found overlapping value {val} in train and test"
                );
            }
        }
    }
775
    #[test]
    fn test_stratified_with_int64_labels() {
        // Int64 label columns must be accepted by stratification.
        use arrow::array::Int64Array;

        let schema = Arc::new(Schema::new(vec![
            Field::new("feature", DataType::Float64, false),
            Field::new("label", DataType::Int64, false),
        ]));

        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let labels: Vec<i64> = (0..100).map(|i| (i % 3) as i64).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(features)),
                Arc::new(Int64Array::from(labels)),
            ],
        )
        .expect("batch creation failed");

        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");

        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
            .expect("split failed");

        assert!(split.train().len() > 0);
        assert!(split.test().len() > 0);
    }
805
    #[test]
    fn test_stratified_with_uint32_labels() {
        // UInt32 label columns must be accepted by stratification.
        use arrow::array::UInt32Array;

        let schema = Arc::new(Schema::new(vec![
            Field::new("feature", DataType::Float64, false),
            Field::new("label", DataType::UInt32, false),
        ]));

        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let labels: Vec<u32> = (0..100).map(|i| (i % 3) as u32).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(features)),
                Arc::new(UInt32Array::from(labels)),
            ],
        )
        .expect("batch creation failed");

        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");

        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
            .expect("split failed");

        assert!(split.train().len() > 0);
        assert!(split.test().len() > 0);
    }
835
    #[test]
    fn test_stratified_with_uint64_labels() {
        // UInt64 label columns must be accepted by stratification.
        use arrow::array::UInt64Array;

        let schema = Arc::new(Schema::new(vec![
            Field::new("feature", DataType::Float64, false),
            Field::new("label", DataType::UInt64, false),
        ]));

        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let labels: Vec<u64> = (0..100).map(|i| (i % 3) as u64).collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(features)),
                Arc::new(UInt64Array::from(labels)),
            ],
        )
        .expect("batch creation failed");

        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");

        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
            .expect("split failed");

        assert!(split.train().len() > 0);
        assert!(split.test().len() > 0);
    }
865
    #[test]
    fn test_stratified_with_string_labels() {
        // Utf8 label columns must be accepted (grouped via a string hash).
        use arrow::array::StringArray;

        let schema = Arc::new(Schema::new(vec![
            Field::new("feature", DataType::Float64, false),
            Field::new("label", DataType::Utf8, false),
        ]));

        let features: Vec<f64> = (0..100).map(|i| i as f64).collect();
        let labels: Vec<&str> = (0..100)
            .map(|i| if i % 2 == 0 { "a" } else { "b" })
            .collect();

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(features)),
                Arc::new(StringArray::from(labels)),
            ],
        )
        .expect("batch creation failed");

        let dataset = ArrowDataset::from_batch(batch).expect("dataset creation failed");

        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, Some(42))
            .expect("stratified split with string labels should succeed");
        assert_eq!(split.train().len() + split.test().len(), 100);
        assert!(split.train().len() > 0);
        assert!(split.test().len() > 0);
    }
897
    #[test]
    fn test_stratified_without_seed() {
        // Without a seed, stratification keeps each group's original order
        // but must still produce non-empty partitions.
        let dataset = make_test_dataset(100);

        let split = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, None)
            .expect("split failed");

        assert!(split.train().len() > 0);
        assert!(split.test().len() > 0);
    }
909
    #[test]
    fn test_split_debug() {
        // The derived Debug impl should mention the type name.
        let dataset = make_test_dataset(100);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");

        let debug = format!("{:?}", split);
        assert!(debug.contains("DatasetSplit"));
    }
919
    #[test]
    fn test_split_clone() {
        // The derived Clone impl should preserve partition sizes.
        let dataset = make_test_dataset(100);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.8, 0.2, None, None).expect("split failed");

        let cloned = split.clone();
        assert_eq!(cloned.train().len(), split.train().len());
        assert_eq!(cloned.test().len(), split.test().len());
    }
930
    #[test]
    fn test_extreme_ratio_99_1() {
        // A heavily skewed ratio still yields exact sizes for 100 rows.
        let dataset = make_test_dataset(100);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.99, 0.01, None, None).expect("split failed");

        assert_eq!(split.train().len(), 99);
        assert_eq!(split.test().len(), 1);
    }
940
    #[test]
    fn test_extreme_ratio_50_50() {
        // An even split of 100 rows yields exactly 50/50.
        let dataset = make_test_dataset(100);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.5, 0.5, None, None).expect("split failed");

        assert_eq!(split.train().len(), 50);
        assert_eq!(split.test().len(), 50);
    }
950
    #[test]
    fn test_negative_train_ratio_rejected() {
        // Negative ratios must be rejected (-0.5 + 0.5 also fails the sum check).
        let dataset = make_test_dataset(100);
        let result = DatasetSplit::from_ratios(&dataset, -0.5, 0.5, None, None);
        assert!(result.is_err());
    }
957
    #[test]
    fn test_zero_test_ratio_rejected() {
        // 1.0/0.0 sums to 1.0 but the zero test ratio must still be rejected.
        let dataset = make_test_dataset(100);
        let result = DatasetSplit::from_ratios(&dataset, 1.0, 0.0, None, None);
        assert!(result.is_err());
    }
964
    #[test]
    fn test_negative_val_ratio_rejected() {
        // A negative validation ratio must be rejected.
        let dataset = make_test_dataset(100);
        let result = DatasetSplit::from_ratios(&dataset, 0.6, 0.5, Some(-0.1), None);
        assert!(result.is_err());
    }
971
    #[test]
    fn test_single_row_minimum_sizes() {
        // With only two rows, each partition should still get one.
        let dataset = make_test_dataset(2);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.5, 0.5, None, None).expect("split failed");

        assert!(split.train().len() >= 1);
        assert!(split.test().len() >= 1);
    }
982
    #[test]
    fn test_ratios_slightly_over_one() {
        // 0.81 + 0.2 = 1.01 exceeds the sum tolerance and must be rejected.
        let dataset = make_test_dataset(100);
        let result = DatasetSplit::from_ratios(&dataset, 0.81, 0.2, None, None);
        assert!(result.is_err());
    }
990
    #[test]
    fn test_ratios_slightly_under_one() {
        // 0.79 + 0.2 = 0.99 falls short of 1.0 and must be rejected.
        let dataset = make_test_dataset(100);
        let result = DatasetSplit::from_ratios(&dataset, 0.79, 0.2, None, None);
        assert!(result.is_err());
    }
998
    #[test]
    fn test_getters_return_correct_data() {
        // The accessors must return exactly the datasets that went in.
        let train = make_test_dataset(80);
        let test = make_test_dataset(20);
        let val = make_test_dataset(10);

        let split = DatasetSplit::with_validation(train.clone(), test.clone(), val.clone());

        assert_eq!(split.train().len(), 80);
        assert_eq!(split.test().len(), 20);
        assert_eq!(split.validation().map(|v| v.len()), Some(10));
    }
1011
    #[test]
    fn test_validation_none_for_two_way_split() {
        // A split built via `new` must never report a validation partition.
        let train = make_test_dataset(80);
        let test = make_test_dataset(20);

        let split = DatasetSplit::new(train, test);

        assert!(split.validation().is_none());
    }
1021
    #[test]
    fn test_stratified_empty_dataset() {
        // Stratifying a zero-row dataset must error out.
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("label", DataType::Int32, false),
        ]));
        let x_array = arrow::array::Float64Array::from(Vec::<f64>::new());
        let label_array = Int32Array::from(Vec::<i32>::new());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(x_array), Arc::new(label_array)])
            .expect("batch");
        let dataset = ArrowDataset::from_batch(batch).expect("dataset");

        let result = DatasetSplit::stratified(&dataset, "label", 0.8, 0.2, None, None);
        assert!(result.is_err());
    }
1037
    #[test]
    fn test_stratified_zero_test_ratio_rejected() {
        let dataset = make_test_dataset(100);
        // NOTE(review): the fixture's label column is "label", not "y", so
        // this currently errors on the missing column rather than on the
        // zero test ratio — consider pointing it at "label" once zero
        // ratios are rejected explicitly by `stratified`.
        let result = DatasetSplit::stratified(&dataset, "y", 1.0, 0.0, None, None);
        assert!(result.is_err());
    }
1044
    #[test]
    fn test_split_preserves_all_rows() {
        // A three-way split must account for every input row exactly once.
        let dataset = make_test_dataset(100);
        let split =
            DatasetSplit::from_ratios(&dataset, 0.6, 0.2, Some(0.2), None).expect("split failed");

        let total = split.train().len()
            + split.test().len()
            + split.validation().map(|v| v.len()).unwrap_or(0);
        assert_eq!(total, 100);
    }
1056}