1use std::sync::Arc;
4
5use arrow::array::{Array, RecordBatch};
6#[cfg(feature = "shuffle")]
7use rand::{seq::SliceRandom, SeedableRng};
8
9use super::Transform;
10use crate::error::{Error, Result};
11
/// Transform that randomly permutes the rows of a record batch.
///
/// Every column is reordered with the same permutation, so values that
/// shared a row before the shuffle still share a row afterwards.
/// Only compiled when the `shuffle` feature is enabled.
#[cfg(feature = "shuffle")]
#[derive(Debug, Clone, Default)]
pub struct Shuffle {
    /// Seed for the random number generator; `None` means the generator is
    /// seeded from entropy on every application.
    seed: Option<u64>,
}

#[cfg(feature = "shuffle")]
impl Shuffle {
    /// Creates an unseeded shuffle (equivalent to [`Shuffle::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a shuffle whose random number generator is seeded with `seed`,
    /// making repeated applications to the same batch reproducible.
    pub fn with_seed(seed: u64) -> Self {
        Shuffle { seed: Some(seed) }
    }
}
52
#[cfg(feature = "shuffle")]
impl Transform for Shuffle {
    /// Returns a batch with the same schema whose rows are randomly permuted.
    ///
    /// Batches with at most one row, or with no columns, are returned
    /// unchanged. With a seed the permutation is reproducible; without one
    /// the generator is seeded from entropy on each call.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Arrow`] if gathering a column or rebuilding the batch
    /// fails.
    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
        let num_rows = batch.num_rows();
        // Nothing to permute. The column-less guard also avoids handing
        // `RecordBatch::try_new` an empty column list, from which it cannot
        // infer a row count.
        if num_rows <= 1 || batch.num_columns() == 0 {
            return Ok(batch);
        }

        let mut rng = match self.seed {
            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
            None => rand::rngs::StdRng::from_entropy(),
        };

        // Build the permutation once, directly in the index type `take`
        // consumes, instead of re-materialising an index array per column.
        let mut order: Vec<u64> = (0..num_rows as u64).collect();
        order.shuffle(&mut rng);
        let order = arrow::array::UInt64Array::from(order);

        let new_columns = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &order, None).map_err(Error::Arrow))
            .collect::<Result<Vec<Arc<dyn Array>>>>()?;

        RecordBatch::try_new(batch.schema(), new_columns).map_err(Error::Arrow)
    }
}
85
/// Transform that keeps a random subset of the rows of a record batch.
///
/// The subset size is given either as an absolute row count
/// ([`Sample::new`]) or as a fraction of the batch ([`Sample::fraction`]).
/// Only compiled when the `shuffle` feature is enabled.
#[cfg(feature = "shuffle")]
#[derive(Debug, Clone)]
pub struct Sample {
    /// Absolute number of rows to keep, when sampling by count.
    count: Option<usize>,
    /// Share of rows to keep in `[0.0, 1.0]`, when sampling by fraction.
    fraction: Option<f64>,
    /// Seed for the random number generator; `None` means entropy-seeded.
    seed: Option<u64>,
}

#[cfg(feature = "shuffle")]
impl Sample {
    /// Creates a sampler that keeps `count` rows.
    pub fn new(count: usize) -> Self {
        Self {
            count: Some(count),
            fraction: None,
            seed: None,
        }
    }

    /// Creates a sampler that keeps the given fraction of rows.
    ///
    /// `frac` is clamped to `[0.0, 1.0]`; a NaN input is treated as `0.0`.
    pub fn fraction(frac: f64) -> Self {
        // `f64::clamp` returns NaN for a NaN receiver, which would leave a
        // value outside the documented range in the struct. Map it to 0.0.
        let fraction = if frac.is_nan() {
            0.0
        } else {
            frac.clamp(0.0, 1.0)
        };
        Self {
            count: None,
            fraction: Some(fraction),
            seed: None,
        }
    }

    /// Sets the seed of the random number generator so that sampling the
    /// same batch twice selects the same rows.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Returns the configured row count, or `None` when sampling by fraction.
    pub fn count(&self) -> Option<usize> {
        self.count
    }

    /// Returns the configured (already clamped) fraction, or `None` when
    /// sampling by count.
    pub fn sample_fraction(&self) -> Option<f64> {
        self.fraction
    }
}
151
#[cfg(feature = "shuffle")]
impl Transform for Sample {
    /// Returns a batch containing a random subset of the input rows.
    ///
    /// The selected rows keep their original relative order. The input is
    /// returned unchanged when it is empty, when neither a count nor a
    /// fraction is configured, or when the requested sample size is at least
    /// the number of rows.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Arrow`] if gathering a column or rebuilding the batch
    /// fails.
    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
        let num_rows = batch.num_rows();
        if num_rows == 0 {
            return Ok(batch);
        }

        // The usize -> f64 -> usize round trip is intentional: the fraction
        // is within [0, 1], so the rounded product cannot exceed `num_rows`.
        #[allow(
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss,
            clippy::cast_precision_loss
        )]
        let sample_size = match (self.count, self.fraction) {
            (Some(c), _) => c.min(num_rows),
            (None, Some(f)) => ((num_rows as f64) * f).round() as usize,
            (None, None) => return Ok(batch),
        };

        if sample_size >= num_rows {
            return Ok(batch);
        }

        let mut rng = match self.seed {
            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
            None => rand::rngs::StdRng::from_entropy(),
        };

        // Shuffle all row positions and keep a prefix, then sort so the
        // surviving rows appear in their original order. The positions are
        // produced directly as `u64` so the index array is built exactly
        // once rather than once per column.
        let mut picked: Vec<u64> = (0..num_rows as u64).collect();
        picked.shuffle(&mut rng);
        picked.truncate(sample_size);
        picked.sort_unstable();
        let picked = arrow::array::UInt64Array::from(picked);

        let new_columns = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &picked, None).map_err(Error::Arrow))
            .collect::<Result<Vec<Arc<dyn Array>>>>()?;

        RecordBatch::try_new(batch.schema(), new_columns).map_err(Error::Arrow)
    }
}
201
/// Transform that keeps only the leading rows of a record batch.
#[derive(Debug, Clone, Copy)]
pub struct Take {
    /// Maximum number of rows to keep from the start of the batch.
    count: usize,
}

impl Take {
    /// Creates a transform that keeps at most `count` leading rows.
    pub fn new(count: usize) -> Self {
        Take { count }
    }

    /// Returns the maximum number of rows this transform keeps.
    pub fn count(&self) -> usize {
        let Take { count } = *self;
        count
    }
}
227
228impl Transform for Take {
229 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
230 let num_rows = batch.num_rows();
231 if self.count >= num_rows {
232 return Ok(batch);
233 }
234
235 Ok(batch.slice(0, self.count))
236 }
237}
238
/// Transform that drops the leading rows of a record batch.
#[derive(Debug, Clone, Copy)]
pub struct Skip {
    /// Number of rows to drop from the start of the batch.
    count: usize,
}

impl Skip {
    /// Creates a transform that drops the first `count` rows.
    pub fn new(count: usize) -> Self {
        Skip { count }
    }

    /// Returns the number of rows this transform drops.
    pub fn count(&self) -> usize {
        let Skip { count } = *self;
        count
    }
}
264
265impl Transform for Skip {
266 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
267 let num_rows = batch.num_rows();
268 if self.count >= num_rows {
269 return Ok(batch.slice(0, 0));
271 }
272
273 let remaining = num_rows - self.count;
274 Ok(batch.slice(self.count, remaining))
275 }
276}
277
/// Direction in which a column is sorted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SortOrder {
    /// Smallest values first (the default).
    #[default]
    Ascending,
    /// Largest values first.
    Descending,
}

/// Transform that sorts the rows of a record batch by one or more columns.
#[derive(Debug, Clone)]
pub struct Sort {
    /// Sort keys in priority order, each with its own direction.
    columns: Vec<(String, SortOrder)>,
    /// Whether nulls sort before non-null values; applied to every key.
    nulls_first: bool,
}

impl Sort {
    /// Creates a sort on a single column, ascending, with nulls last.
    pub fn by<S: Into<String>>(column: S) -> Self {
        Self::by_columns([(column, SortOrder::Ascending)])
    }

    /// Creates a sort on several columns; earlier entries take precedence.
    /// Nulls sort last unless changed with [`Sort::nulls_first`].
    pub fn by_columns<S: Into<String>>(columns: impl IntoIterator<Item = (S, SortOrder)>) -> Self {
        let mut keys = Vec::new();
        for (name, direction) in columns {
            keys.push((name.into(), direction));
        }
        Sort {
            columns: keys,
            nulls_first: false,
        }
    }

    /// Sets the direction of the first sort column only; other columns keep
    /// their direction. Does nothing when there are no sort columns.
    #[must_use]
    pub fn order(mut self, order: SortOrder) -> Self {
        if let Some(primary) = self.columns.first_mut() {
            primary.1 = order;
        }
        self
    }

    /// Chooses whether nulls are placed before (`true`) or after (`false`)
    /// non-null values.
    #[must_use]
    pub fn nulls_first(mut self, nulls_first: bool) -> Self {
        self.nulls_first = nulls_first;
        self
    }

    /// Returns the sort keys in priority order.
    pub fn columns(&self) -> &[(String, SortOrder)] {
        self.columns.as_slice()
    }
}
351
352impl Transform for Sort {
353 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
354 use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
355
356 if batch.num_rows() <= 1 || self.columns.is_empty() {
357 return Ok(batch);
358 }
359
360 let schema = batch.schema();
361
362 let sort_columns: Vec<SortColumn> = self
364 .columns
365 .iter()
366 .map(|(col_name, order)| {
367 let (idx, _) = schema
368 .column_with_name(col_name)
369 .ok_or_else(|| Error::column_not_found(col_name))?;
370
371 Ok(SortColumn {
372 values: Arc::clone(batch.column(idx)),
373 options: Some(SortOptions {
374 descending: *order == SortOrder::Descending,
375 nulls_first: self.nulls_first,
376 }),
377 })
378 })
379 .collect::<Result<Vec<_>>>()?;
380
381 let indices = lexsort_to_indices(&sort_columns, None).map_err(Error::Arrow)?;
383
384 let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
386 .map(|col_idx| {
387 let col = batch.column(col_idx);
388 take(col.as_ref(), &indices, None)
389 .map_err(Error::Arrow)
390 .map(Arc::from)
391 })
392 .collect::<Result<Vec<_>>>()?;
393
394 RecordBatch::try_new(schema, new_columns).map_err(Error::Arrow)
395 }
396}
397
398#[derive(Debug, Clone)]
416pub struct Unique {
417 columns: Option<Vec<String>>,
418 keep_last: bool,
419}
420
421impl Unique {
422 pub fn all() -> Self {
424 Self {
425 columns: None,
426 keep_last: false,
427 }
428 }
429
430 pub fn by<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
432 Self {
433 columns: Some(columns.into_iter().map(Into::into).collect()),
434 keep_last: false,
435 }
436 }
437
438 #[must_use]
440 pub fn keep_first(mut self) -> Self {
441 self.keep_last = false;
442 self
443 }
444
445 #[must_use]
447 pub fn keep_last(mut self) -> Self {
448 self.keep_last = true;
449 self
450 }
451
452 pub fn columns(&self) -> Option<&[String]> {
454 self.columns.as_deref()
455 }
456
457 fn row_key(batch: &RecordBatch, row_idx: usize, key_indices: &[usize]) -> u64 {
458 use std::hash::{DefaultHasher, Hash, Hasher};
459
460 use arrow::array::{
461 BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
462 };
463
464 let mut hasher = DefaultHasher::new();
465
466 for &col_idx in key_indices {
467 let col = batch.column(col_idx);
468 if col.is_null(row_idx) {
469 0u8.hash(&mut hasher);
471 } else if let Some(arr) = col.as_any().downcast_ref::<Int32Array>() {
472 1u8.hash(&mut hasher);
473 arr.value(row_idx).hash(&mut hasher);
474 } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
475 2u8.hash(&mut hasher);
476 arr.value(row_idx).hash(&mut hasher);
477 } else if let Some(arr) = col.as_any().downcast_ref::<Float32Array>() {
478 3u8.hash(&mut hasher);
480 arr.value(row_idx).to_bits().hash(&mut hasher);
481 } else if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
482 4u8.hash(&mut hasher);
483 arr.value(row_idx).to_bits().hash(&mut hasher);
484 } else if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
485 5u8.hash(&mut hasher);
486 arr.value(row_idx).hash(&mut hasher);
487 } else if let Some(arr) = col.as_any().downcast_ref::<BooleanArray>() {
488 6u8.hash(&mut hasher);
489 arr.value(row_idx).hash(&mut hasher);
490 } else {
491 7u8.hash(&mut hasher);
493 format!("{:?}", col.data_type()).hash(&mut hasher);
494 }
495 }
496
497 hasher.finish()
498 }
499}
500
501impl Transform for Unique {
502 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
503 use std::collections::HashMap;
504
505 let num_rows = batch.num_rows();
506 if num_rows <= 1 {
507 return Ok(batch);
508 }
509
510 let schema = batch.schema();
511
512 let key_indices: Vec<usize> = match &self.columns {
514 Some(cols) => cols
515 .iter()
516 .map(|name| {
517 schema
518 .column_with_name(name)
519 .map(|(idx, _)| idx)
520 .ok_or_else(|| Error::column_not_found(name))
521 })
522 .collect::<Result<Vec<_>>>()?,
523 None => (0..schema.fields().len()).collect(),
524 };
525
526 let mut seen: HashMap<u64, usize> = HashMap::new();
529 let mut keep_indices: Vec<usize> = Vec::new();
530
531 let row_iter: Box<dyn Iterator<Item = usize>> = if self.keep_last {
532 Box::new((0..num_rows).rev())
533 } else {
534 Box::new(0..num_rows)
535 };
536
537 for row_idx in row_iter {
538 let key = Self::row_key(&batch, row_idx, &key_indices);
539
540 if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) {
541 e.insert(row_idx);
542 keep_indices.push(row_idx);
543 }
544 }
545
546 if self.keep_last {
547 keep_indices.reverse();
548 }
549
550 if keep_indices.len() == num_rows {
551 return Ok(batch);
552 }
553
554 let indices_array =
556 arrow::array::UInt64Array::from_iter_values(keep_indices.iter().map(|&i| i as u64));
557
558 let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
559 .map(|col_idx| {
560 let col = batch.column(col_idx);
561 arrow::compute::take(col.as_ref(), &indices_array, None)
562 .map_err(Error::Arrow)
563 .map(Arc::from)
564 })
565 .collect::<Result<Vec<_>>>()?;
566
567 RecordBatch::try_new(schema, new_columns).map_err(Error::Arrow)
568 }
569}
570
571#[cfg(test)]
572#[allow(
573 clippy::float_cmp,
574 clippy::cast_precision_loss,
575 clippy::redundant_closure
576)]
577mod tests {
578 use arrow::{
579 array::{Int32Array, StringArray},
580 datatypes::{DataType, Field, Schema},
581 };
582
583 use super::*;
584
585 fn create_test_batch() -> RecordBatch {
586 let schema = Arc::new(Schema::new(vec![
587 Field::new("id", DataType::Int32, false),
588 Field::new("name", DataType::Utf8, false),
589 Field::new("value", DataType::Int32, false),
590 ]));
591
592 let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
593 let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
594 let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);
595
596 RecordBatch::try_new(
597 schema,
598 vec![
599 Arc::new(id_array),
600 Arc::new(name_array),
601 Arc::new(value_array),
602 ],
603 )
604 .ok()
605 .unwrap_or_else(|| panic!("Should create batch"))
606 }
607
608 #[cfg(feature = "shuffle")]
609 #[test]
610 fn test_shuffle_transform_deterministic() {
611 let batch = create_test_batch();
612 let transform = Shuffle::with_seed(42);
613
614 let result1 = transform.apply(batch.clone());
615 let result2 = transform.apply(batch);
616
617 assert!(result1.is_ok());
618 assert!(result2.is_ok());
619
620 let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
621 let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
622
623 let col1 = result1
625 .column(0)
626 .as_any()
627 .downcast_ref::<Int32Array>()
628 .unwrap_or_else(|| panic!("Should be Int32Array"));
629 let col2 = result2
630 .column(0)
631 .as_any()
632 .downcast_ref::<Int32Array>()
633 .unwrap_or_else(|| panic!("Should be Int32Array"));
634
635 for i in 0..col1.len() {
636 assert_eq!(col1.value(i), col2.value(i));
637 }
638 }
639
640 #[cfg(feature = "shuffle")]
641 #[test]
642 fn test_shuffle_preserves_row_integrity() {
643 let batch = create_test_batch();
644 let transform = Shuffle::with_seed(42);
645
646 let result = transform.apply(batch);
647 assert!(result.is_ok());
648 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
649
650 let ids = result
652 .column(0)
653 .as_any()
654 .downcast_ref::<Int32Array>()
655 .unwrap_or_else(|| panic!("Should be Int32Array"));
656 let values = result
657 .column(2)
658 .as_any()
659 .downcast_ref::<Int32Array>()
660 .unwrap_or_else(|| panic!("Should be Int32Array"));
661
662 for i in 0..ids.len() {
664 let id = ids.value(i);
665 let value = values.value(i);
666 assert_eq!(value, id * 10);
667 }
668 }
669
670 #[cfg(feature = "shuffle")]
671 #[test]
672 fn test_shuffle_single_row() {
673 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
674 let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
675 .ok()
676 .unwrap_or_else(|| panic!("Should create batch"));
677
678 let transform = Shuffle::new();
679 let result = transform.apply(batch);
680 assert!(result.is_ok());
681 }
682
683 #[cfg(feature = "shuffle")]
684 #[test]
685 fn test_shuffle_default() {
686 let shuffle = Shuffle::default();
687 let batch = create_test_batch();
688 let result = shuffle.apply(batch);
689 assert!(result.is_ok());
690 }
691
692 #[cfg(feature = "shuffle")]
693 #[test]
694 fn test_shuffle_debug() {
695 let shuffle = Shuffle::new();
696 let debug_str = format!("{:?}", shuffle);
697 assert!(debug_str.contains("Shuffle"));
698 }
699
700 #[cfg(feature = "shuffle")]
701 #[test]
702 fn test_shuffle_with_seed() {
703 let batch = create_test_batch();
704 let shuffle = Shuffle::with_seed(12345);
705
706 let result1 = shuffle.apply(batch.clone());
707 let result2 = shuffle.apply(batch);
708
709 assert!(result1.is_ok());
710 assert!(result2.is_ok());
711
712 let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
713 let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
714
715 let ids1 = result1
716 .column(0)
717 .as_any()
718 .downcast_ref::<Int32Array>()
719 .unwrap_or_else(|| panic!("Should be Int32Array"));
720 let ids2 = result2
721 .column(0)
722 .as_any()
723 .downcast_ref::<Int32Array>()
724 .unwrap_or_else(|| panic!("Should be Int32Array"));
725
726 for i in 0..ids1.len() {
727 assert_eq!(ids1.value(i), ids2.value(i));
728 }
729 }
730
731 #[cfg(feature = "shuffle")]
734 #[test]
735 fn test_sample_by_count() {
736 let batch = create_test_batch();
737 let transform = Sample::new(3).with_seed(42);
738
739 let result = transform.apply(batch);
740 assert!(result.is_ok());
741 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
742 assert_eq!(result.num_rows(), 3);
743 }
744
745 #[cfg(feature = "shuffle")]
746 #[test]
747 fn test_sample_by_fraction() {
748 let batch = create_test_batch();
749 let transform = Sample::fraction(0.4).with_seed(42);
750
751 let result = transform.apply(batch);
752 assert!(result.is_ok());
753 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
754 assert_eq!(result.num_rows(), 2); }
756
757 #[cfg(feature = "shuffle")]
758 #[test]
759 fn test_sample_deterministic() {
760 let batch = create_test_batch();
761 let transform = Sample::new(3).with_seed(42);
762
763 let result1 = transform.apply(batch.clone());
764 let result2 = transform.apply(batch);
765
766 assert!(result1.is_ok());
767 assert!(result2.is_ok());
768
769 let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
770 let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
771
772 let col1 = result1
773 .column(0)
774 .as_any()
775 .downcast_ref::<Int32Array>()
776 .unwrap_or_else(|| panic!("Should be Int32Array"));
777 let col2 = result2
778 .column(0)
779 .as_any()
780 .downcast_ref::<Int32Array>()
781 .unwrap_or_else(|| panic!("Should be Int32Array"));
782
783 for i in 0..col1.len() {
784 assert_eq!(col1.value(i), col2.value(i));
785 }
786 }
787
788 #[cfg(feature = "shuffle")]
789 #[test]
790 fn test_sample_preserves_row_integrity() {
791 let batch = create_test_batch();
792 let transform = Sample::new(3).with_seed(42);
793
794 let result = transform.apply(batch);
795 assert!(result.is_ok());
796 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
797
798 let ids = result
799 .column(0)
800 .as_any()
801 .downcast_ref::<Int32Array>()
802 .unwrap_or_else(|| panic!("Should be Int32Array"));
803 let values = result
804 .column(2)
805 .as_any()
806 .downcast_ref::<Int32Array>()
807 .unwrap_or_else(|| panic!("Should be Int32Array"));
808
809 for i in 0..ids.len() {
811 let id = ids.value(i);
812 let value = values.value(i);
813 assert_eq!(value, id * 10);
814 }
815 }
816
817 #[cfg(feature = "shuffle")]
818 #[test]
819 fn test_sample_count_larger_than_batch() {
820 let batch = create_test_batch();
821 let transform = Sample::new(100);
822
823 let result = transform.apply(batch.clone());
824 assert!(result.is_ok());
825 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
826 assert_eq!(result.num_rows(), batch.num_rows());
827 }
828
829 #[cfg(feature = "shuffle")]
830 #[test]
831 fn test_sample_getters() {
832 let sample = Sample::new(10).with_seed(42);
833 assert_eq!(sample.count(), Some(10));
834 assert!(sample.sample_fraction().is_none());
835
836 let sample2 = Sample::fraction(0.5);
837 assert!(sample2.count().is_none());
838 assert_eq!(sample2.sample_fraction(), Some(0.5));
839 }
840
841 #[cfg(feature = "shuffle")]
842 #[test]
843 fn test_sample_debug() {
844 let sample = Sample::new(10);
845 let debug_str = format!("{:?}", sample);
846 assert!(debug_str.contains("Sample"));
847 }
848
849 #[cfg(feature = "shuffle")]
850 #[test]
851 fn test_sample_with_seed() {
852 let batch = create_test_batch();
853 let sample = Sample::new(3).with_seed(42);
854
855 let result1 = sample.apply(batch.clone());
856 let result2 = sample.apply(batch);
857
858 assert!(result1.is_ok());
859 assert!(result2.is_ok());
860
861 let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
862 let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
863
864 assert_eq!(result1.num_rows(), result2.num_rows());
865 }
866
867 #[cfg(feature = "shuffle")]
868 #[test]
869 fn test_sample_zero_count() {
870 let batch = create_test_batch();
871 let sample = Sample::new(0);
872 let result = sample.apply(batch);
873 assert!(result.is_ok());
874 let result = result.ok().unwrap();
875 assert_eq!(result.num_rows(), 0);
876 }
877
878 #[cfg(feature = "shuffle")]
879 #[test]
880 fn test_sample_fraction_zero() {
881 let batch = create_test_batch();
882 let sample = Sample::fraction(0.0);
883 let result = sample.apply(batch);
884 assert!(result.is_ok());
885 let result = result.ok().unwrap();
886 assert_eq!(result.num_rows(), 0);
887 }
888
889 #[cfg(feature = "shuffle")]
890 #[test]
891 fn test_sample_fraction_full() {
892 let batch = create_test_batch();
893 let sample = Sample::fraction(1.0);
894 let result = sample.apply(batch.clone());
895 assert!(result.is_ok());
896 let result = result.ok().unwrap();
897 assert_eq!(result.num_rows(), batch.num_rows());
898 }
899
900 #[cfg(feature = "shuffle")]
901 #[test]
902 fn test_sample_fraction_negative_clamped() {
903 let sample = Sample::fraction(-0.5);
904 assert_eq!(sample.sample_fraction(), Some(0.0));
906
907 let batch = create_test_batch();
908 let result = sample.apply(batch);
909 assert!(result.is_ok());
910 let result = result.ok().unwrap();
911 assert_eq!(result.num_rows(), 0);
912 }
913
914 #[cfg(feature = "shuffle")]
915 #[test]
916 fn test_sample_fraction_over_one_clamped() {
917 let sample = Sample::fraction(1.5);
918 assert_eq!(sample.sample_fraction(), Some(1.0));
920
921 let batch = create_test_batch();
922 let result = sample.apply(batch.clone());
923 assert!(result.is_ok());
924 let result = result.ok().unwrap();
925 assert_eq!(result.num_rows(), batch.num_rows());
926 }
927
928 #[test]
931 fn test_take_transform() {
932 let batch = create_test_batch();
933 let transform = Take::new(3);
934
935 let result = transform.apply(batch);
936 assert!(result.is_ok());
937 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
938 assert_eq!(result.num_rows(), 3);
939
940 let ids = result
941 .column(0)
942 .as_any()
943 .downcast_ref::<Int32Array>()
944 .unwrap_or_else(|| panic!("Should be Int32Array"));
945 assert_eq!(ids.value(0), 1);
946 assert_eq!(ids.value(1), 2);
947 assert_eq!(ids.value(2), 3);
948 }
949
950 #[test]
951 fn test_take_more_than_available() {
952 let batch = create_test_batch();
953 let transform = Take::new(100);
954
955 let result = transform.apply(batch.clone());
956 assert!(result.is_ok());
957 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
958 assert_eq!(result.num_rows(), batch.num_rows());
959 }
960
961 #[test]
962 fn test_take_count_getter() {
963 let take = Take::new(42);
964 assert_eq!(take.count(), 42);
965 }
966
967 #[test]
968 fn test_take_debug() {
969 let take = Take::new(10);
970 let debug_str = format!("{:?}", take);
971 assert!(debug_str.contains("Take"));
972 }
973
974 #[test]
975 fn test_take_zero_rows() {
976 let batch = create_test_batch();
977 let take = Take::new(0);
978 let result = take.apply(batch);
979 assert!(result.is_ok());
980 let result = result.ok().unwrap();
981 assert_eq!(result.num_rows(), 0);
982 }
983
984 #[test]
985 fn test_take_beyond_bounds() {
986 let batch = create_test_batch(); let take = Take::new(100); let result = take.apply(batch);
989 assert!(result.is_ok());
990 let result = result.ok().unwrap();
991 assert_eq!(result.num_rows(), 5); }
993
994 #[test]
997 fn test_skip_transform() {
998 let batch = create_test_batch();
999 let transform = Skip::new(2);
1000
1001 let result = transform.apply(batch);
1002 assert!(result.is_ok());
1003 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1004 assert_eq!(result.num_rows(), 3);
1005
1006 let ids = result
1007 .column(0)
1008 .as_any()
1009 .downcast_ref::<Int32Array>()
1010 .unwrap_or_else(|| panic!("Should be Int32Array"));
1011 assert_eq!(ids.value(0), 3);
1012 assert_eq!(ids.value(1), 4);
1013 assert_eq!(ids.value(2), 5);
1014 }
1015
1016 #[test]
1017 fn test_skip_all_rows() {
1018 let batch = create_test_batch();
1019 let transform = Skip::new(10);
1020
1021 let result = transform.apply(batch);
1022 assert!(result.is_ok());
1023 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1024 assert_eq!(result.num_rows(), 0);
1025 }
1026
1027 #[test]
1028 fn test_skip_count_getter() {
1029 let skip = Skip::new(5);
1030 assert_eq!(skip.count(), 5);
1031 }
1032
1033 #[test]
1034 fn test_skip_debug() {
1035 let skip = Skip::new(5);
1036 let debug_str = format!("{:?}", skip);
1037 assert!(debug_str.contains("Skip"));
1038 }
1039
1040 #[test]
1041 fn test_skip_more_than_batch_size() {
1042 let batch = create_test_batch();
1043 let skip = Skip::new(100);
1044 let result = skip.apply(batch);
1045 assert!(result.is_ok());
1046 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1047 assert_eq!(result.num_rows(), 0);
1048 }
1049
1050 #[test]
1051 fn test_skip_beyond_bounds() {
1052 let batch = create_test_batch(); let skip = Skip::new(100); let result = skip.apply(batch);
1055 assert!(result.is_ok());
1056 let result = result.ok().unwrap();
1057 assert_eq!(result.num_rows(), 0); }
1059
1060 #[test]
1061 fn test_skip_zero_rows() {
1062 let batch = create_test_batch();
1063 let original_rows = batch.num_rows();
1064 let skip = Skip::new(0);
1065 let result = skip.apply(batch);
1066 assert!(result.is_ok());
1067 let result = result.ok().unwrap();
1068 assert_eq!(result.num_rows(), original_rows);
1070 }
1071
1072 #[test]
1075 fn test_sort_ascending() {
1076 let schema = Arc::new(Schema::new(vec![Field::new(
1077 "value",
1078 DataType::Int32,
1079 false,
1080 )]));
1081 let values = Int32Array::from(vec![3, 1, 4, 1, 5]);
1082 let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])
1083 .ok()
1084 .unwrap_or_else(|| panic!("Should create batch"));
1085
1086 let transform = Sort::by("value");
1087 let result = transform.apply(batch);
1088 assert!(result.is_ok());
1089 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1090
1091 let col = result
1092 .column(0)
1093 .as_any()
1094 .downcast_ref::<Int32Array>()
1095 .unwrap_or_else(|| panic!("Should be Int32Array"));
1096 assert_eq!(col.value(0), 1);
1097 assert_eq!(col.value(1), 1);
1098 assert_eq!(col.value(2), 3);
1099 assert_eq!(col.value(3), 4);
1100 assert_eq!(col.value(4), 5);
1101 }
1102
1103 #[test]
1104 fn test_sort_descending() {
1105 let schema = Arc::new(Schema::new(vec![Field::new(
1106 "value",
1107 DataType::Int32,
1108 false,
1109 )]));
1110 let values = Int32Array::from(vec![3, 1, 4, 1, 5]);
1111 let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])
1112 .ok()
1113 .unwrap_or_else(|| panic!("Should create batch"));
1114
1115 let transform = Sort::by("value").order(SortOrder::Descending);
1116 let result = transform.apply(batch);
1117 assert!(result.is_ok());
1118 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1119
1120 let col = result
1121 .column(0)
1122 .as_any()
1123 .downcast_ref::<Int32Array>()
1124 .unwrap_or_else(|| panic!("Should be Int32Array"));
1125 assert_eq!(col.value(0), 5);
1126 assert_eq!(col.value(1), 4);
1127 assert_eq!(col.value(2), 3);
1128 assert_eq!(col.value(3), 1);
1129 assert_eq!(col.value(4), 1);
1130 }
1131
1132 #[test]
1133 fn test_sort_preserves_row_integrity() {
1134 let batch = create_test_batch();
1135 let transform = Sort::by("id").order(SortOrder::Descending);
1136
1137 let result = transform.apply(batch);
1138 assert!(result.is_ok());
1139 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1140
1141 let ids = result
1142 .column(0)
1143 .as_any()
1144 .downcast_ref::<Int32Array>()
1145 .unwrap_or_else(|| panic!("Should be Int32Array"));
1146 let values = result
1147 .column(2)
1148 .as_any()
1149 .downcast_ref::<Int32Array>()
1150 .unwrap_or_else(|| panic!("Should be Int32Array"));
1151
1152 for i in 0..ids.len() {
1154 let id = ids.value(i);
1155 let value = values.value(i);
1156 assert_eq!(value, id * 10);
1157 }
1158
1159 assert_eq!(ids.value(0), 5);
1161 assert_eq!(ids.value(4), 1);
1162 }
1163
1164 #[test]
1165 fn test_sort_multiple_columns() {
1166 let schema = Arc::new(Schema::new(vec![
1167 Field::new("group", DataType::Int32, false),
1168 Field::new("value", DataType::Int32, false),
1169 ]));
1170 let groups = Int32Array::from(vec![1, 2, 1, 2, 1]);
1171 let values = Int32Array::from(vec![30, 10, 10, 20, 20]);
1172 let batch = RecordBatch::try_new(schema, vec![Arc::new(groups), Arc::new(values)])
1173 .ok()
1174 .unwrap_or_else(|| panic!("Should create batch"));
1175
1176 let transform = Sort::by_columns(vec![
1177 ("group", SortOrder::Ascending),
1178 ("value", SortOrder::Ascending),
1179 ]);
1180 let result = transform.apply(batch);
1181 assert!(result.is_ok());
1182 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1183
1184 let groups = result
1185 .column(0)
1186 .as_any()
1187 .downcast_ref::<Int32Array>()
1188 .unwrap_or_else(|| panic!("Should be Int32Array"));
1189 let values = result
1190 .column(1)
1191 .as_any()
1192 .downcast_ref::<Int32Array>()
1193 .unwrap_or_else(|| panic!("Should be Int32Array"));
1194
1195 assert_eq!(groups.value(0), 1);
1197 assert_eq!(values.value(0), 10);
1198 assert_eq!(groups.value(1), 1);
1199 assert_eq!(values.value(1), 20);
1200 assert_eq!(groups.value(2), 1);
1201 assert_eq!(values.value(2), 30);
1202 assert_eq!(groups.value(3), 2);
1204 assert_eq!(values.value(3), 10);
1205 assert_eq!(groups.value(4), 2);
1206 assert_eq!(values.value(4), 20);
1207 }
1208
1209 #[test]
1210 fn test_sort_column_not_found() {
1211 let batch = create_test_batch();
1212 let transform = Sort::by("nonexistent");
1213
1214 let result = transform.apply(batch);
1215 assert!(result.is_err());
1216 }
1217
1218 #[test]
1219 fn test_sort_columns_getter() {
1220 let sort = Sort::by("value").order(SortOrder::Descending);
1221 let cols = sort.columns();
1222 assert_eq!(cols.len(), 1);
1223 assert_eq!(cols[0].0, "value");
1224 assert_eq!(cols[0].1, SortOrder::Descending);
1225 }
1226
1227 #[test]
1228 fn test_sort_order_default() {
1229 let order = SortOrder::default();
1230 assert_eq!(order, SortOrder::Ascending);
1231 }
1232
1233 #[test]
1234 fn test_sort_single_row() {
1235 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1236 let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
1237 .ok()
1238 .unwrap_or_else(|| panic!("Should create batch"));
1239
1240 let transform = Sort::by("id");
1241 let result = transform.apply(batch);
1242 assert!(result.is_ok());
1243 }
1244
1245 #[test]
1246 fn test_sort_debug() {
1247 let sort = Sort::by("col");
1248 let debug_str = format!("{:?}", sort);
1249 assert!(debug_str.contains("Sort"));
1250 }
1251
1252 #[test]
1253 fn test_sort_empty_batch() {
1254 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1255 let empty_batch = RecordBatch::new_empty(schema);
1256
1257 let sort = Sort::by("id");
1258 let result = sort.apply(empty_batch);
1259 assert!(result.is_ok());
1260 }
1261
1262 #[test]
1263 fn test_sort_empty_columns_vector() {
1264 let batch = create_test_batch();
1265 let sort = Sort::by_columns::<String>(vec![]);
1266 let result = sort.apply(batch.clone());
1267 assert!(result.is_ok());
1268 let result = result.ok().unwrap();
1269 assert_eq!(result.num_rows(), batch.num_rows());
1271 }
1272
1273 #[test]
1274 fn test_sort_multi_column_one_missing() {
1275 let batch = create_test_batch();
1276 let sort = Sort::by_columns(vec![
1277 ("value".to_string(), SortOrder::Ascending),
1278 ("nonexistent".to_string(), SortOrder::Ascending),
1279 ]);
1280 let result = sort.apply(batch);
1281 assert!(result.is_err());
1283 }
1284
1285 #[test]
1286 fn test_sort_single_row_unchanged() {
1287 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1289 let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![42]))])
1290 .ok()
1291 .unwrap();
1292
1293 let sort = Sort::by_columns(vec![("id".to_string(), SortOrder::Ascending)]);
1294 let result = sort.apply(batch.clone());
1295 assert!(result.is_ok());
1296 let result = result.ok().unwrap();
1297 assert_eq!(result.num_rows(), 1);
1299 }
1300
1301 #[test]
1304 fn test_unique_all_columns() {
1305 let schema = Arc::new(Schema::new(vec![
1306 Field::new("id", DataType::Int32, false),
1307 Field::new("value", DataType::Int32, false),
1308 ]));
1309 let ids = Int32Array::from(vec![1, 2, 1, 2, 1]); let values = Int32Array::from(vec![10, 20, 10, 20, 30]); let batch = RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(values)])
1312 .ok()
1313 .unwrap_or_else(|| panic!("Should create batch"));
1314
1315 let transform = Unique::all();
1316 let result = transform.apply(batch);
1317 assert!(result.is_ok());
1318 let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1319
1320 assert_eq!(result.num_rows(), 3); }
1322
/// De-duplicating on `id` alone keeps three rows; the first two output rows
/// carry the values of the first occurrence of each id.
#[test]
fn test_unique_by_column() {
    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let id_column = Int32Array::from(vec![1, 2, 1, 2, 3]);
    let value_column = Int32Array::from(vec![10, 20, 30, 40, 50]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(id_column), Arc::new(value_column)])
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

    let outcome = Unique::by(vec!["id"]).apply(input);
    assert!(outcome.is_ok());
    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));

    assert_eq!(deduped.num_rows(), 3);

    let id_array = deduped.column(0).as_any().downcast_ref::<Int32Array>();
    let ids = id_array.unwrap_or_else(|| panic!("Should be Int32Array"));
    let value_array = deduped.column(1).as_any().downcast_ref::<Int32Array>();
    let values = value_array.unwrap_or_else(|| panic!("Should be Int32Array"));

    // Expected (id, value) for output rows 0 and 1.
    let expected_head = [(1, 10), (2, 20)];
    for (row, &(id, value)) in expected_head.iter().enumerate() {
        assert_eq!(ids.value(row), id);
        assert_eq!(values.value(row), value);
    }
}
1359
/// With `keep_last()`, de-duplicating on `id` keeps three rows; the first two
/// output rows carry the values of the last occurrence of each id.
#[test]
fn test_unique_keep_last() {
    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let id_column = Int32Array::from(vec![1, 2, 1, 2, 3]);
    let value_column = Int32Array::from(vec![10, 20, 30, 40, 50]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(id_column), Arc::new(value_column)])
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

    let outcome = Unique::by(vec!["id"]).keep_last().apply(input);
    assert!(outcome.is_ok());
    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));

    assert_eq!(deduped.num_rows(), 3);

    let id_array = deduped.column(0).as_any().downcast_ref::<Int32Array>();
    let ids = id_array.unwrap_or_else(|| panic!("Should be Int32Array"));
    let value_array = deduped.column(1).as_any().downcast_ref::<Int32Array>();
    let values = value_array.unwrap_or_else(|| panic!("Should be Int32Array"));

    // Expected (id, value) for output rows 0 and 1.
    let expected_head = [(1, 30), (2, 40)];
    for (row, &(id, value)) in expected_head.iter().enumerate() {
        assert_eq!(ids.value(row), id);
        assert_eq!(values.value(row), value);
    }
}
1396
/// `Unique::all()` applied to the shared fixture batch is expected to return
/// the same number of rows it was given.
#[test]
fn test_unique_no_duplicates() {
    let input = create_test_batch();
    let expected_rows = input.num_rows();

    let outcome = Unique::all().apply(input);
    assert!(outcome.is_ok());
    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));

    assert_eq!(deduped.num_rows(), expected_rows);
}
1408
/// De-duplicating on a column named `nonexistent` must return an error.
#[test]
fn test_unique_column_not_found() {
    let outcome = Unique::by(vec!["nonexistent"]).apply(create_test_batch());
    assert!(outcome.is_err());
}
1417
/// `columns()` reports the key columns for `Unique::by` and `None` for `Unique::all`.
#[test]
fn test_unique_columns_getter() {
    let keyed = Unique::by(vec!["a", "b"]);
    assert!(keyed.columns().is_some());

    let key_columns = keyed
        .columns()
        .unwrap_or_else(|| panic!("Should have columns"));
    assert_eq!(key_columns, &["a", "b"]);

    let whole_row = Unique::all();
    assert!(whole_row.columns().is_none());
}
1432
/// A batch holding a single row passes through `Unique::all()` as one row.
#[test]
fn test_unique_single_row() {
    let id_field = Field::new("id", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![id_field]));
    let input = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"));

    let outcome = Unique::all().apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 1);
}
1446
/// The `Debug` rendering of a `Unique` transform mentions the type name.
#[test]
fn test_unique_debug() {
    let rendered = format!("{:?}", Unique::all());
    assert!(rendered.contains("Unique"));
}
1453
/// De-duplicating on an `Int64` key column with three distinct values yields three rows.
#[test]
fn test_unique_with_int64_column() {
    use arrow::array::Int64Array;

    let fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let keys = Int64Array::from(vec![1_i64, 1, 2, 2, 3]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(keys), Arc::new(names)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(vec!["id"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 3);
}
1476
/// De-duplicating on a `Float64` key column with three distinct values yields three rows.
#[test]
fn test_unique_with_float64_column() {
    use arrow::array::Float64Array;

    let fields = vec![
        Field::new("val", DataType::Float64, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let keys = Float64Array::from(vec![1.0_f64, 1.0, 2.0, 2.0, 3.0]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(keys), Arc::new(names)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(vec!["val"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 3);
}
1499
/// De-duplicating on a `Float32` key column with three distinct values yields three rows.
#[test]
fn test_unique_with_float32_column() {
    use arrow::array::Float32Array;

    let fields = vec![
        Field::new("val", DataType::Float32, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let keys = Float32Array::from(vec![1.0_f32, 1.0, 2.0, 2.0, 3.0]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(keys), Arc::new(names)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(vec!["val"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 3);
}
1522
/// De-duplicating on a `Boolean` key column containing both `true` and `false`
/// yields two rows.
#[test]
fn test_unique_with_bool_column() {
    use arrow::array::BooleanArray;

    let fields = vec![
        Field::new("flag", DataType::Boolean, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let flags = BooleanArray::from(vec![true, true, false, false, true]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(flags), Arc::new(names)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(vec!["flag"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 2);
}
1545
/// A nullable key column holding `1`, `2` and two nulls de-duplicates to three
/// rows, i.e. the nulls are expected to collapse into a single key.
#[test]
fn test_unique_with_null_values() {
    let fields = vec![
        Field::new("id", DataType::Int32, true),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let keys = Int32Array::from(vec![Some(1), None, Some(1), None, Some(2)]);
    let names = StringArray::from(vec!["a", "b", "c", "d", "e"]);
    let input = RecordBatch::try_new(schema, vec![Arc::new(keys), Arc::new(names)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(vec!["id"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap_or_else(|| panic!("Should succeed"));
    assert_eq!(deduped.num_rows(), 3);
}
1566
/// De-duplicating a zero-row batch must succeed and return zero rows.
#[test]
fn test_unique_empty_batch() {
    let id_field = Field::new("id", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![id_field]));
    let no_ids = Int32Array::from(Vec::<i32>::new());
    let input = RecordBatch::try_new(schema, vec![Arc::new(no_ids)])
        .ok()
        .unwrap_or_else(|| panic!("batch"));

    let outcome = Unique::by(["id"]).apply(input);
    assert!(outcome.is_ok());

    let deduped = outcome.ok().unwrap();
    assert_eq!(deduped.num_rows(), 0);
}
1581}