Skip to main content

alimentar/transform/
row_ops.rs

1//! Row-level operations: sorting, sampling, deduplication, and slicing.
2
3use std::sync::Arc;
4
5use arrow::array::{Array, RecordBatch};
6#[cfg(feature = "shuffle")]
7use rand::{seq::SliceRandom, SeedableRng};
8
9use super::Transform;
10use crate::error::{Error, Result};
11
/// Randomly permutes the rows of a `RecordBatch`.
///
/// Only available with the `shuffle` feature enabled.
///
/// # Example
///
/// ```ignore
/// use alimentar::Shuffle;
///
/// // Random shuffle
/// let shuffle = Shuffle::new();
///
/// // Deterministic shuffle with seed
/// let shuffle = Shuffle::with_seed(42);
/// ```
#[cfg(feature = "shuffle")]
#[derive(Debug, Clone, Default)]
pub struct Shuffle {
    /// Fixed RNG seed; `None` lets every application draw a fresh seed from entropy.
    seed: Option<u64>,
}

#[cfg(feature = "shuffle")]
impl Shuffle {
    /// Builds a shuffle whose row order varies from one application to the next.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a shuffle that always produces the same permutation for a given
    /// batch length, making results reproducible.
    pub fn with_seed(seed: u64) -> Self {
        Shuffle { seed: Some(seed) }
    }
}
52
#[cfg(feature = "shuffle")]
impl Transform for Shuffle {
    /// Returns `batch` with its rows in random order.
    ///
    /// Every column is gathered with the same permutation, so each row's values
    /// stay together. Batches with zero or one row are returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Arrow`] if gathering a column or rebuilding the batch fails.
    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
        let num_rows = batch.num_rows();
        if num_rows <= 1 {
            return Ok(batch);
        }

        let mut rng = match self.seed {
            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
            None => rand::rngs::StdRng::from_entropy(),
        };

        // Shuffle row positions directly as u64 so they can back the Arrow index
        // array without conversion. The permutation depends only on the slice
        // length, so a given seed yields the same order as before.
        let mut indices: Vec<u64> = (0..num_rows as u64).collect();
        indices.shuffle(&mut rng);

        // Built once and shared by every column.
        let indices = arrow::array::UInt64Array::from(indices);

        let new_columns = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &indices, None).map_err(Error::Arrow))
            .collect::<Result<Vec<_>>>()?;

        RecordBatch::try_new(batch.schema(), new_columns).map_err(Error::Arrow)
    }
}
85
/// A transform that randomly samples rows from a RecordBatch.
///
/// Useful for creating train/test splits or reducing dataset size.
/// Requires the `shuffle` feature.
///
/// # Example
///
/// ```ignore
/// use alimentar::Sample;
///
/// // Sample 100 rows with a fixed seed
/// let sample = Sample::new(100).with_seed(42);
///
/// // Sample 10% of rows
/// let sample = Sample::fraction(0.1);
/// ```
#[cfg(feature = "shuffle")]
#[derive(Debug, Clone)]
pub struct Sample {
    /// Exact number of rows to keep; takes precedence over `fraction` when set.
    count: Option<usize>,
    /// Share of rows to keep, always normalized into `[0.0, 1.0]`.
    fraction: Option<f64>,
    /// RNG seed for reproducible sampling; `None` seeds from entropy.
    seed: Option<u64>,
}

#[cfg(feature = "shuffle")]
impl Sample {
    /// Creates a Sample transform that selects exactly `count` rows.
    ///
    /// If the batch has fewer rows than `count`, all rows are returned.
    pub fn new(count: usize) -> Self {
        Self {
            count: Some(count),
            fraction: None,
            seed: None,
        }
    }

    /// Creates a Sample transform that selects a fraction of rows.
    ///
    /// `frac` is clamped into `[0.0, 1.0]`: negative values become `0.0` and
    /// values above one become `1.0`. A NaN fraction is treated as `0.0`
    /// (select no rows).
    pub fn fraction(frac: f64) -> Self {
        // `f64::clamp` returns NaN unchanged, so normalize it explicitly to keep
        // the stored fraction (and `sample_fraction()`) inside the documented range.
        let frac = if frac.is_nan() {
            0.0
        } else {
            frac.clamp(0.0, 1.0)
        };
        Self {
            count: None,
            fraction: Some(frac),
            seed: None,
        }
    }

    /// Sets a seed for reproducible sampling.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Returns the sample count if set.
    pub fn count(&self) -> Option<usize> {
        self.count
    }

    /// Returns the sample fraction if set (always within `[0.0, 1.0]`).
    pub fn sample_fraction(&self) -> Option<f64> {
        self.fraction
    }
}
151
#[cfg(feature = "shuffle")]
impl Transform for Sample {
    /// Returns a random subset of `batch`'s rows, kept in their original order.
    ///
    /// The subset size is `count` capped at the batch length, or else
    /// `round(num_rows * fraction)`. An empty batch, or a subset size that
    /// covers every row, returns the input unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Arrow`] if gathering a column or rebuilding the batch fails.
    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
        let num_rows = batch.num_rows();
        if num_rows == 0 {
            return Ok(batch);
        }

        #[allow(
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss,
            clippy::cast_precision_loss
        )]
        let sample_size = match (self.count, self.fraction) {
            (Some(c), _) => c.min(num_rows),
            (None, Some(f)) => ((num_rows as f64) * f).round() as usize,
            (None, None) => return Ok(batch),
        };

        if sample_size >= num_rows {
            return Ok(batch);
        }

        let mut rng = match self.seed {
            Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
            None => rand::rngs::StdRng::from_entropy(),
        };

        // Shuffle every position, keep the first `sample_size`, then sort so the
        // sampled rows keep their original relative order. Positions are u64 so
        // they can back the Arrow index array directly; the shuffle depends only
        // on the slice length, so seeded samples are unchanged.
        let mut indices: Vec<u64> = (0..num_rows as u64).collect();
        indices.shuffle(&mut rng);
        indices.truncate(sample_size);
        indices.sort_unstable();

        // Built once and shared by every column.
        let indices = arrow::array::UInt64Array::from(indices);

        let new_columns = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &indices, None).map_err(Error::Arrow))
            .collect::<Result<Vec<_>>>()?;

        RecordBatch::try_new(batch.schema(), new_columns).map_err(Error::Arrow)
    }
}
201
/// Keeps only the leading rows of a `RecordBatch`.
///
/// # Example
///
/// ```ignore
/// use alimentar::Take;
///
/// let take = Take::new(100); // Take first 100 rows
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Take {
    /// Maximum number of leading rows to keep.
    count: usize,
}

impl Take {
    /// Builds a `Take` that retains at most `count` rows from the start of a batch.
    pub fn new(count: usize) -> Self {
        Take { count }
    }

    /// The configured row limit.
    pub fn count(&self) -> usize {
        let Take { count } = *self;
        count
    }
}
227
228impl Transform for Take {
229    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
230        let num_rows = batch.num_rows();
231        if self.count >= num_rows {
232            return Ok(batch);
233        }
234
235        Ok(batch.slice(0, self.count))
236    }
237}
238
/// Drops the leading rows of a `RecordBatch`.
///
/// # Example
///
/// ```ignore
/// use alimentar::Skip;
///
/// let skip = Skip::new(10); // Skip first 10 rows
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Skip {
    /// Number of leading rows to discard.
    count: usize,
}

impl Skip {
    /// Builds a `Skip` that discards the first `count` rows of a batch.
    pub fn new(count: usize) -> Self {
        Skip { count }
    }

    /// The number of leading rows this transform discards.
    pub fn count(&self) -> usize {
        let Skip { count } = *self;
        count
    }
}
264
265impl Transform for Skip {
266    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
267        let num_rows = batch.num_rows();
268        if self.count >= num_rows {
269            // Skip all rows - return empty batch with same schema
270            return Ok(batch.slice(0, 0));
271        }
272
273        let remaining = num_rows - self.count;
274        Ok(batch.slice(self.count, remaining))
275    }
276}
277
/// Direction applied to a single sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SortOrder {
    /// Ascending order (smallest to largest)
    #[default]
    Ascending,
    /// Descending order (largest to smallest)
    Descending,
}

/// Reorders the rows of a batch by one or more key columns.
///
/// # Example
///
/// ```ignore
/// use alimentar::{Sort, SortOrder};
///
/// // Sort by single column ascending
/// let sort = Sort::by("age");
///
/// // Sort by column descending
/// let sort = Sort::by("score").order(SortOrder::Descending);
///
/// // Sort by multiple columns
/// let sort = Sort::by_columns(vec![("name", SortOrder::Ascending), ("age", SortOrder::Descending)]);
/// ```
#[derive(Debug, Clone)]
pub struct Sort {
    /// Sort keys in priority order: (column name, direction).
    columns: Vec<(String, SortOrder)>,
    /// Whether null values are placed before non-null values.
    nulls_first: bool,
}

impl Sort {
    /// Sorts by a single column, ascending; nulls are placed last.
    pub fn by<S: Into<String>>(column: S) -> Self {
        Self::by_columns(std::iter::once((column, SortOrder::Ascending)))
    }

    /// Sorts by several columns, each with its own direction; earlier entries
    /// take precedence. Nulls are placed last.
    pub fn by_columns<S: Into<String>>(columns: impl IntoIterator<Item = (S, SortOrder)>) -> Self {
        let mut keys = Vec::new();
        for (name, direction) in columns {
            keys.push((name.into(), direction));
        }
        Sort {
            columns: keys,
            nulls_first: false,
        }
    }

    /// Replaces the direction of the first (primary) sort key.
    ///
    /// Intended for single-column sorts; does nothing when there are no keys.
    #[must_use]
    pub fn order(mut self, order: SortOrder) -> Self {
        if let Some(primary) = self.columns.get_mut(0) {
            primary.1 = order;
        }
        self
    }

    /// Chooses whether nulls sort before (`true`) or after (`false`, the default)
    /// non-null values.
    #[must_use]
    pub fn nulls_first(self, nulls_first: bool) -> Self {
        Sort {
            nulls_first,
            ..self
        }
    }

    /// The configured sort keys, in priority order.
    pub fn columns(&self) -> &[(String, SortOrder)] {
        self.columns.as_slice()
    }
}
351
352impl Transform for Sort {
353    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
354        use arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions};
355
356        if batch.num_rows() <= 1 || self.columns.is_empty() {
357            return Ok(batch);
358        }
359
360        let schema = batch.schema();
361
362        // Build sort columns
363        let sort_columns: Vec<SortColumn> = self
364            .columns
365            .iter()
366            .map(|(col_name, order)| {
367                let (idx, _) = schema
368                    .column_with_name(col_name)
369                    .ok_or_else(|| Error::column_not_found(col_name))?;
370
371                Ok(SortColumn {
372                    values: Arc::clone(batch.column(idx)),
373                    options: Some(SortOptions {
374                        descending: *order == SortOrder::Descending,
375                        nulls_first: self.nulls_first,
376                    }),
377                })
378            })
379            .collect::<Result<Vec<_>>>()?;
380
381        // Get sorted indices
382        let indices = lexsort_to_indices(&sort_columns, None).map_err(Error::Arrow)?;
383
384        // Reorder all columns
385        let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
386            .map(|col_idx| {
387                let col = batch.column(col_idx);
388                take(col.as_ref(), &indices, None)
389                    .map_err(Error::Arrow)
390                    .map(Arc::from)
391            })
392            .collect::<Result<Vec<_>>>()?;
393
394        RecordBatch::try_new(schema, new_columns).map_err(Error::Arrow)
395    }
396}
397
398/// A transform that removes duplicate rows based on specified columns.
399///
400/// # Example
401///
402/// ```ignore
403/// use alimentar::Unique;
404///
405/// // Keep only unique rows based on all columns
406/// let unique = Unique::all();
407///
408/// // Keep only unique rows based on specific columns
409/// let unique = Unique::by(vec!["user_id", "date"]);
410///
411/// // Keep first occurrence (default) or last
412/// let unique = Unique::by(vec!["id"]).keep_first();
413/// let unique = Unique::by(vec!["id"]).keep_last();
414/// ```
415#[derive(Debug, Clone)]
416pub struct Unique {
417    columns: Option<Vec<String>>,
418    keep_last: bool,
419}
420
421impl Unique {
422    /// Creates a Unique transform that considers all columns.
423    pub fn all() -> Self {
424        Self {
425            columns: None,
426            keep_last: false,
427        }
428    }
429
430    /// Creates a Unique transform that considers specific columns.
431    pub fn by<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
432        Self {
433            columns: Some(columns.into_iter().map(Into::into).collect()),
434            keep_last: false,
435        }
436    }
437
438    /// Keep the first occurrence of duplicates (default).
439    #[must_use]
440    pub fn keep_first(mut self) -> Self {
441        self.keep_last = false;
442        self
443    }
444
445    /// Keep the last occurrence of duplicates.
446    #[must_use]
447    pub fn keep_last(mut self) -> Self {
448        self.keep_last = true;
449        self
450    }
451
452    /// Returns the columns used for uniqueness check.
453    pub fn columns(&self) -> Option<&[String]> {
454        self.columns.as_deref()
455    }
456
457    fn row_key(batch: &RecordBatch, row_idx: usize, key_indices: &[usize]) -> u64 {
458        use std::hash::{DefaultHasher, Hash, Hasher};
459
460        use arrow::array::{
461            BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
462        };
463
464        let mut hasher = DefaultHasher::new();
465
466        for &col_idx in key_indices {
467            let col = batch.column(col_idx);
468            if col.is_null(row_idx) {
469                // Sentinel: hash a tag byte that cannot collide with real values
470                0u8.hash(&mut hasher);
471            } else if let Some(arr) = col.as_any().downcast_ref::<Int32Array>() {
472                1u8.hash(&mut hasher);
473                arr.value(row_idx).hash(&mut hasher);
474            } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
475                2u8.hash(&mut hasher);
476                arr.value(row_idx).hash(&mut hasher);
477            } else if let Some(arr) = col.as_any().downcast_ref::<Float32Array>() {
478                // Use bits for exact comparison (same semantics as before)
479                3u8.hash(&mut hasher);
480                arr.value(row_idx).to_bits().hash(&mut hasher);
481            } else if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
482                4u8.hash(&mut hasher);
483                arr.value(row_idx).to_bits().hash(&mut hasher);
484            } else if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
485                5u8.hash(&mut hasher);
486                arr.value(row_idx).hash(&mut hasher);
487            } else if let Some(arr) = col.as_any().downcast_ref::<BooleanArray>() {
488                6u8.hash(&mut hasher);
489                arr.value(row_idx).hash(&mut hasher);
490            } else {
491                // Fallback: hash the data type debug repr
492                7u8.hash(&mut hasher);
493                format!("{:?}", col.data_type()).hash(&mut hasher);
494            }
495        }
496
497        hasher.finish()
498    }
499}
500
501impl Transform for Unique {
502    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
503        use std::collections::HashMap;
504
505        let num_rows = batch.num_rows();
506        if num_rows <= 1 {
507            return Ok(batch);
508        }
509
510        let schema = batch.schema();
511
512        // Determine which columns to use for uniqueness
513        let key_indices: Vec<usize> = match &self.columns {
514            Some(cols) => cols
515                .iter()
516                .map(|name| {
517                    schema
518                        .column_with_name(name)
519                        .map(|(idx, _)| idx)
520                        .ok_or_else(|| Error::column_not_found(name))
521                })
522                .collect::<Result<Vec<_>>>()?,
523            None => (0..schema.fields().len()).collect(),
524        };
525
526        // Build a hash of each row's key columns (u64 hash keys to avoid String
527        // allocation storm)
528        let mut seen: HashMap<u64, usize> = HashMap::new();
529        let mut keep_indices: Vec<usize> = Vec::new();
530
531        let row_iter: Box<dyn Iterator<Item = usize>> = if self.keep_last {
532            Box::new((0..num_rows).rev())
533        } else {
534            Box::new(0..num_rows)
535        };
536
537        for row_idx in row_iter {
538            let key = Self::row_key(&batch, row_idx, &key_indices);
539
540            if let std::collections::hash_map::Entry::Vacant(e) = seen.entry(key) {
541                e.insert(row_idx);
542                keep_indices.push(row_idx);
543            }
544        }
545
546        if self.keep_last {
547            keep_indices.reverse();
548        }
549
550        if keep_indices.len() == num_rows {
551            return Ok(batch);
552        }
553
554        // Build new batch with only unique rows
555        let indices_array =
556            arrow::array::UInt64Array::from_iter_values(keep_indices.iter().map(|&i| i as u64));
557
558        let new_columns: Vec<Arc<dyn Array>> = (0..batch.num_columns())
559            .map(|col_idx| {
560                let col = batch.column(col_idx);
561                arrow::compute::take(col.as_ref(), &indices_array, None)
562                    .map_err(Error::Arrow)
563                    .map(Arc::from)
564            })
565            .collect::<Result<Vec<_>>>()?;
566
567        RecordBatch::try_new(schema, new_columns).map_err(Error::Arrow)
568    }
569}
570
571#[cfg(test)]
572#[allow(
573    clippy::float_cmp,
574    clippy::cast_precision_loss,
575    clippy::redundant_closure
576)]
577mod tests {
578    use arrow::{
579        array::{Int32Array, StringArray},
580        datatypes::{DataType, Field, Schema},
581    };
582
583    use super::*;
584
585    fn create_test_batch() -> RecordBatch {
586        let schema = Arc::new(Schema::new(vec![
587            Field::new("id", DataType::Int32, false),
588            Field::new("name", DataType::Utf8, false),
589            Field::new("value", DataType::Int32, false),
590        ]));
591
592        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
593        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
594        let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);
595
596        RecordBatch::try_new(
597            schema,
598            vec![
599                Arc::new(id_array),
600                Arc::new(name_array),
601                Arc::new(value_array),
602            ],
603        )
604        .ok()
605        .unwrap_or_else(|| panic!("Should create batch"))
606    }
607
608    #[cfg(feature = "shuffle")]
609    #[test]
610    fn test_shuffle_transform_deterministic() {
611        let batch = create_test_batch();
612        let transform = Shuffle::with_seed(42);
613
614        let result1 = transform.apply(batch.clone());
615        let result2 = transform.apply(batch);
616
617        assert!(result1.is_ok());
618        assert!(result2.is_ok());
619
620        let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
621        let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
622
623        // Same seed should produce same shuffle
624        let col1 = result1
625            .column(0)
626            .as_any()
627            .downcast_ref::<Int32Array>()
628            .unwrap_or_else(|| panic!("Should be Int32Array"));
629        let col2 = result2
630            .column(0)
631            .as_any()
632            .downcast_ref::<Int32Array>()
633            .unwrap_or_else(|| panic!("Should be Int32Array"));
634
635        for i in 0..col1.len() {
636            assert_eq!(col1.value(i), col2.value(i));
637        }
638    }
639
640    #[cfg(feature = "shuffle")]
641    #[test]
642    fn test_shuffle_preserves_row_integrity() {
643        let batch = create_test_batch();
644        let transform = Shuffle::with_seed(42);
645
646        let result = transform.apply(batch);
647        assert!(result.is_ok());
648        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
649
650        // Check that id-name-value relationships are preserved
651        let ids = result
652            .column(0)
653            .as_any()
654            .downcast_ref::<Int32Array>()
655            .unwrap_or_else(|| panic!("Should be Int32Array"));
656        let values = result
657            .column(2)
658            .as_any()
659            .downcast_ref::<Int32Array>()
660            .unwrap_or_else(|| panic!("Should be Int32Array"));
661
662        // Original: id=1 -> value=10, id=2 -> value=20, etc.
663        for i in 0..ids.len() {
664            let id = ids.value(i);
665            let value = values.value(i);
666            assert_eq!(value, id * 10);
667        }
668    }
669
670    #[cfg(feature = "shuffle")]
671    #[test]
672    fn test_shuffle_single_row() {
673        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
674        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
675            .ok()
676            .unwrap_or_else(|| panic!("Should create batch"));
677
678        let transform = Shuffle::new();
679        let result = transform.apply(batch);
680        assert!(result.is_ok());
681    }
682
683    #[cfg(feature = "shuffle")]
684    #[test]
685    fn test_shuffle_default() {
686        let shuffle = Shuffle::default();
687        let batch = create_test_batch();
688        let result = shuffle.apply(batch);
689        assert!(result.is_ok());
690    }
691
692    #[cfg(feature = "shuffle")]
693    #[test]
694    fn test_shuffle_debug() {
695        let shuffle = Shuffle::new();
696        let debug_str = format!("{:?}", shuffle);
697        assert!(debug_str.contains("Shuffle"));
698    }
699
700    #[cfg(feature = "shuffle")]
701    #[test]
702    fn test_shuffle_with_seed() {
703        let batch = create_test_batch();
704        let shuffle = Shuffle::with_seed(12345);
705
706        let result1 = shuffle.apply(batch.clone());
707        let result2 = shuffle.apply(batch);
708
709        assert!(result1.is_ok());
710        assert!(result2.is_ok());
711
712        let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
713        let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
714
715        let ids1 = result1
716            .column(0)
717            .as_any()
718            .downcast_ref::<Int32Array>()
719            .unwrap_or_else(|| panic!("Should be Int32Array"));
720        let ids2 = result2
721            .column(0)
722            .as_any()
723            .downcast_ref::<Int32Array>()
724            .unwrap_or_else(|| panic!("Should be Int32Array"));
725
726        for i in 0..ids1.len() {
727            assert_eq!(ids1.value(i), ids2.value(i));
728        }
729    }
730
731    // Sample transform tests
732
733    #[cfg(feature = "shuffle")]
734    #[test]
735    fn test_sample_by_count() {
736        let batch = create_test_batch();
737        let transform = Sample::new(3).with_seed(42);
738
739        let result = transform.apply(batch);
740        assert!(result.is_ok());
741        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
742        assert_eq!(result.num_rows(), 3);
743    }
744
745    #[cfg(feature = "shuffle")]
746    #[test]
747    fn test_sample_by_fraction() {
748        let batch = create_test_batch();
749        let transform = Sample::fraction(0.4).with_seed(42);
750
751        let result = transform.apply(batch);
752        assert!(result.is_ok());
753        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
754        assert_eq!(result.num_rows(), 2); // 5 * 0.4 = 2
755    }
756
757    #[cfg(feature = "shuffle")]
758    #[test]
759    fn test_sample_deterministic() {
760        let batch = create_test_batch();
761        let transform = Sample::new(3).with_seed(42);
762
763        let result1 = transform.apply(batch.clone());
764        let result2 = transform.apply(batch);
765
766        assert!(result1.is_ok());
767        assert!(result2.is_ok());
768
769        let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
770        let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
771
772        let col1 = result1
773            .column(0)
774            .as_any()
775            .downcast_ref::<Int32Array>()
776            .unwrap_or_else(|| panic!("Should be Int32Array"));
777        let col2 = result2
778            .column(0)
779            .as_any()
780            .downcast_ref::<Int32Array>()
781            .unwrap_or_else(|| panic!("Should be Int32Array"));
782
783        for i in 0..col1.len() {
784            assert_eq!(col1.value(i), col2.value(i));
785        }
786    }
787
788    #[cfg(feature = "shuffle")]
789    #[test]
790    fn test_sample_preserves_row_integrity() {
791        let batch = create_test_batch();
792        let transform = Sample::new(3).with_seed(42);
793
794        let result = transform.apply(batch);
795        assert!(result.is_ok());
796        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
797
798        let ids = result
799            .column(0)
800            .as_any()
801            .downcast_ref::<Int32Array>()
802            .unwrap_or_else(|| panic!("Should be Int32Array"));
803        let values = result
804            .column(2)
805            .as_any()
806            .downcast_ref::<Int32Array>()
807            .unwrap_or_else(|| panic!("Should be Int32Array"));
808
809        // Check that id-value relationships are preserved
810        for i in 0..ids.len() {
811            let id = ids.value(i);
812            let value = values.value(i);
813            assert_eq!(value, id * 10);
814        }
815    }
816
817    #[cfg(feature = "shuffle")]
818    #[test]
819    fn test_sample_count_larger_than_batch() {
820        let batch = create_test_batch();
821        let transform = Sample::new(100);
822
823        let result = transform.apply(batch.clone());
824        assert!(result.is_ok());
825        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
826        assert_eq!(result.num_rows(), batch.num_rows());
827    }
828
829    #[cfg(feature = "shuffle")]
830    #[test]
831    fn test_sample_getters() {
832        let sample = Sample::new(10).with_seed(42);
833        assert_eq!(sample.count(), Some(10));
834        assert!(sample.sample_fraction().is_none());
835
836        let sample2 = Sample::fraction(0.5);
837        assert!(sample2.count().is_none());
838        assert_eq!(sample2.sample_fraction(), Some(0.5));
839    }
840
841    #[cfg(feature = "shuffle")]
842    #[test]
843    fn test_sample_debug() {
844        let sample = Sample::new(10);
845        let debug_str = format!("{:?}", sample);
846        assert!(debug_str.contains("Sample"));
847    }
848
849    #[cfg(feature = "shuffle")]
850    #[test]
851    fn test_sample_with_seed() {
852        let batch = create_test_batch();
853        let sample = Sample::new(3).with_seed(42);
854
855        let result1 = sample.apply(batch.clone());
856        let result2 = sample.apply(batch);
857
858        assert!(result1.is_ok());
859        assert!(result2.is_ok());
860
861        let result1 = result1.ok().unwrap_or_else(|| panic!("Should succeed"));
862        let result2 = result2.ok().unwrap_or_else(|| panic!("Should succeed"));
863
864        assert_eq!(result1.num_rows(), result2.num_rows());
865    }
866
867    #[cfg(feature = "shuffle")]
868    #[test]
869    fn test_sample_zero_count() {
870        let batch = create_test_batch();
871        let sample = Sample::new(0);
872        let result = sample.apply(batch);
873        assert!(result.is_ok());
874        let result = result.ok().unwrap();
875        assert_eq!(result.num_rows(), 0);
876    }
877
878    #[cfg(feature = "shuffle")]
879    #[test]
880    fn test_sample_fraction_zero() {
881        let batch = create_test_batch();
882        let sample = Sample::fraction(0.0);
883        let result = sample.apply(batch);
884        assert!(result.is_ok());
885        let result = result.ok().unwrap();
886        assert_eq!(result.num_rows(), 0);
887    }
888
889    #[cfg(feature = "shuffle")]
890    #[test]
891    fn test_sample_fraction_full() {
892        let batch = create_test_batch();
893        let sample = Sample::fraction(1.0);
894        let result = sample.apply(batch.clone());
895        assert!(result.is_ok());
896        let result = result.ok().unwrap();
897        assert_eq!(result.num_rows(), batch.num_rows());
898    }
899
900    #[cfg(feature = "shuffle")]
901    #[test]
902    fn test_sample_fraction_negative_clamped() {
903        let sample = Sample::fraction(-0.5);
904        // Negative fraction should be clamped to 0.0
905        assert_eq!(sample.sample_fraction(), Some(0.0));
906
907        let batch = create_test_batch();
908        let result = sample.apply(batch);
909        assert!(result.is_ok());
910        let result = result.ok().unwrap();
911        assert_eq!(result.num_rows(), 0);
912    }
913
914    #[cfg(feature = "shuffle")]
915    #[test]
916    fn test_sample_fraction_over_one_clamped() {
917        let sample = Sample::fraction(1.5);
918        // Fraction > 1.0 should be clamped to 1.0
919        assert_eq!(sample.sample_fraction(), Some(1.0));
920
921        let batch = create_test_batch();
922        let result = sample.apply(batch.clone());
923        assert!(result.is_ok());
924        let result = result.ok().unwrap();
925        assert_eq!(result.num_rows(), batch.num_rows());
926    }
927
928    // Take transform tests
929
930    #[test]
931    fn test_take_transform() {
932        let batch = create_test_batch();
933        let transform = Take::new(3);
934
935        let result = transform.apply(batch);
936        assert!(result.is_ok());
937        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
938        assert_eq!(result.num_rows(), 3);
939
940        let ids = result
941            .column(0)
942            .as_any()
943            .downcast_ref::<Int32Array>()
944            .unwrap_or_else(|| panic!("Should be Int32Array"));
945        assert_eq!(ids.value(0), 1);
946        assert_eq!(ids.value(1), 2);
947        assert_eq!(ids.value(2), 3);
948    }
949
950    #[test]
951    fn test_take_more_than_available() {
952        let batch = create_test_batch();
953        let transform = Take::new(100);
954
955        let result = transform.apply(batch.clone());
956        assert!(result.is_ok());
957        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
958        assert_eq!(result.num_rows(), batch.num_rows());
959    }
960
961    #[test]
962    fn test_take_count_getter() {
963        let take = Take::new(42);
964        assert_eq!(take.count(), 42);
965    }
966
967    #[test]
968    fn test_take_debug() {
969        let take = Take::new(10);
970        let debug_str = format!("{:?}", take);
971        assert!(debug_str.contains("Take"));
972    }
973
974    #[test]
975    fn test_take_zero_rows() {
976        let batch = create_test_batch();
977        let take = Take::new(0);
978        let result = take.apply(batch);
979        assert!(result.is_ok());
980        let result = result.ok().unwrap();
981        assert_eq!(result.num_rows(), 0);
982    }
983
984    #[test]
985    fn test_take_beyond_bounds() {
986        let batch = create_test_batch(); // 5 rows
987        let take = Take::new(100); // Request more than available
988        let result = take.apply(batch);
989        assert!(result.is_ok());
990        let result = result.ok().unwrap();
991        assert_eq!(result.num_rows(), 5); // Should return all rows
992    }
993
994    // Skip transform tests
995
996    #[test]
997    fn test_skip_transform() {
998        let batch = create_test_batch();
999        let transform = Skip::new(2);
1000
1001        let result = transform.apply(batch);
1002        assert!(result.is_ok());
1003        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1004        assert_eq!(result.num_rows(), 3);
1005
1006        let ids = result
1007            .column(0)
1008            .as_any()
1009            .downcast_ref::<Int32Array>()
1010            .unwrap_or_else(|| panic!("Should be Int32Array"));
1011        assert_eq!(ids.value(0), 3);
1012        assert_eq!(ids.value(1), 4);
1013        assert_eq!(ids.value(2), 5);
1014    }
1015
1016    #[test]
1017    fn test_skip_all_rows() {
1018        let batch = create_test_batch();
1019        let transform = Skip::new(10);
1020
1021        let result = transform.apply(batch);
1022        assert!(result.is_ok());
1023        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1024        assert_eq!(result.num_rows(), 0);
1025    }
1026
1027    #[test]
1028    fn test_skip_count_getter() {
1029        let skip = Skip::new(5);
1030        assert_eq!(skip.count(), 5);
1031    }
1032
1033    #[test]
1034    fn test_skip_debug() {
1035        let skip = Skip::new(5);
1036        let debug_str = format!("{:?}", skip);
1037        assert!(debug_str.contains("Skip"));
1038    }
1039
1040    #[test]
1041    fn test_skip_more_than_batch_size() {
1042        let batch = create_test_batch();
1043        let skip = Skip::new(100);
1044        let result = skip.apply(batch);
1045        assert!(result.is_ok());
1046        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1047        assert_eq!(result.num_rows(), 0);
1048    }
1049
1050    #[test]
1051    fn test_skip_beyond_bounds() {
1052        let batch = create_test_batch(); // 5 rows
1053        let skip = Skip::new(100); // Skip more than available
1054        let result = skip.apply(batch);
1055        assert!(result.is_ok());
1056        let result = result.ok().unwrap();
1057        assert_eq!(result.num_rows(), 0); // Should return empty
1058    }
1059
1060    #[test]
1061    fn test_skip_zero_rows() {
1062        let batch = create_test_batch();
1063        let original_rows = batch.num_rows();
1064        let skip = Skip::new(0);
1065        let result = skip.apply(batch);
1066        assert!(result.is_ok());
1067        let result = result.ok().unwrap();
1068        // Skipping 0 should return all rows
1069        assert_eq!(result.num_rows(), original_rows);
1070    }
1071
1072    // Sort transform tests
1073
1074    #[test]
1075    fn test_sort_ascending() {
1076        let schema = Arc::new(Schema::new(vec![Field::new(
1077            "value",
1078            DataType::Int32,
1079            false,
1080        )]));
1081        let values = Int32Array::from(vec![3, 1, 4, 1, 5]);
1082        let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])
1083            .ok()
1084            .unwrap_or_else(|| panic!("Should create batch"));
1085
1086        let transform = Sort::by("value");
1087        let result = transform.apply(batch);
1088        assert!(result.is_ok());
1089        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1090
1091        let col = result
1092            .column(0)
1093            .as_any()
1094            .downcast_ref::<Int32Array>()
1095            .unwrap_or_else(|| panic!("Should be Int32Array"));
1096        assert_eq!(col.value(0), 1);
1097        assert_eq!(col.value(1), 1);
1098        assert_eq!(col.value(2), 3);
1099        assert_eq!(col.value(3), 4);
1100        assert_eq!(col.value(4), 5);
1101    }
1102
1103    #[test]
1104    fn test_sort_descending() {
1105        let schema = Arc::new(Schema::new(vec![Field::new(
1106            "value",
1107            DataType::Int32,
1108            false,
1109        )]));
1110        let values = Int32Array::from(vec![3, 1, 4, 1, 5]);
1111        let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])
1112            .ok()
1113            .unwrap_or_else(|| panic!("Should create batch"));
1114
1115        let transform = Sort::by("value").order(SortOrder::Descending);
1116        let result = transform.apply(batch);
1117        assert!(result.is_ok());
1118        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1119
1120        let col = result
1121            .column(0)
1122            .as_any()
1123            .downcast_ref::<Int32Array>()
1124            .unwrap_or_else(|| panic!("Should be Int32Array"));
1125        assert_eq!(col.value(0), 5);
1126        assert_eq!(col.value(1), 4);
1127        assert_eq!(col.value(2), 3);
1128        assert_eq!(col.value(3), 1);
1129        assert_eq!(col.value(4), 1);
1130    }
1131
1132    #[test]
1133    fn test_sort_preserves_row_integrity() {
1134        let batch = create_test_batch();
1135        let transform = Sort::by("id").order(SortOrder::Descending);
1136
1137        let result = transform.apply(batch);
1138        assert!(result.is_ok());
1139        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1140
1141        let ids = result
1142            .column(0)
1143            .as_any()
1144            .downcast_ref::<Int32Array>()
1145            .unwrap_or_else(|| panic!("Should be Int32Array"));
1146        let values = result
1147            .column(2)
1148            .as_any()
1149            .downcast_ref::<Int32Array>()
1150            .unwrap_or_else(|| panic!("Should be Int32Array"));
1151
1152        // Verify rows are still correlated
1153        for i in 0..ids.len() {
1154            let id = ids.value(i);
1155            let value = values.value(i);
1156            assert_eq!(value, id * 10);
1157        }
1158
1159        // Verify descending order
1160        assert_eq!(ids.value(0), 5);
1161        assert_eq!(ids.value(4), 1);
1162    }
1163
1164    #[test]
1165    fn test_sort_multiple_columns() {
1166        let schema = Arc::new(Schema::new(vec![
1167            Field::new("group", DataType::Int32, false),
1168            Field::new("value", DataType::Int32, false),
1169        ]));
1170        let groups = Int32Array::from(vec![1, 2, 1, 2, 1]);
1171        let values = Int32Array::from(vec![30, 10, 10, 20, 20]);
1172        let batch = RecordBatch::try_new(schema, vec![Arc::new(groups), Arc::new(values)])
1173            .ok()
1174            .unwrap_or_else(|| panic!("Should create batch"));
1175
1176        let transform = Sort::by_columns(vec![
1177            ("group", SortOrder::Ascending),
1178            ("value", SortOrder::Ascending),
1179        ]);
1180        let result = transform.apply(batch);
1181        assert!(result.is_ok());
1182        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1183
1184        let groups = result
1185            .column(0)
1186            .as_any()
1187            .downcast_ref::<Int32Array>()
1188            .unwrap_or_else(|| panic!("Should be Int32Array"));
1189        let values = result
1190            .column(1)
1191            .as_any()
1192            .downcast_ref::<Int32Array>()
1193            .unwrap_or_else(|| panic!("Should be Int32Array"));
1194
1195        // Group 1 first, sorted by value
1196        assert_eq!(groups.value(0), 1);
1197        assert_eq!(values.value(0), 10);
1198        assert_eq!(groups.value(1), 1);
1199        assert_eq!(values.value(1), 20);
1200        assert_eq!(groups.value(2), 1);
1201        assert_eq!(values.value(2), 30);
1202        // Then group 2
1203        assert_eq!(groups.value(3), 2);
1204        assert_eq!(values.value(3), 10);
1205        assert_eq!(groups.value(4), 2);
1206        assert_eq!(values.value(4), 20);
1207    }
1208
1209    #[test]
1210    fn test_sort_column_not_found() {
1211        let batch = create_test_batch();
1212        let transform = Sort::by("nonexistent");
1213
1214        let result = transform.apply(batch);
1215        assert!(result.is_err());
1216    }
1217
1218    #[test]
1219    fn test_sort_columns_getter() {
1220        let sort = Sort::by("value").order(SortOrder::Descending);
1221        let cols = sort.columns();
1222        assert_eq!(cols.len(), 1);
1223        assert_eq!(cols[0].0, "value");
1224        assert_eq!(cols[0].1, SortOrder::Descending);
1225    }
1226
1227    #[test]
1228    fn test_sort_order_default() {
1229        let order = SortOrder::default();
1230        assert_eq!(order, SortOrder::Ascending);
1231    }
1232
1233    #[test]
1234    fn test_sort_single_row() {
1235        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1236        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
1237            .ok()
1238            .unwrap_or_else(|| panic!("Should create batch"));
1239
1240        let transform = Sort::by("id");
1241        let result = transform.apply(batch);
1242        assert!(result.is_ok());
1243    }
1244
1245    #[test]
1246    fn test_sort_debug() {
1247        let sort = Sort::by("col");
1248        let debug_str = format!("{:?}", sort);
1249        assert!(debug_str.contains("Sort"));
1250    }
1251
1252    #[test]
1253    fn test_sort_empty_batch() {
1254        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1255        let empty_batch = RecordBatch::new_empty(schema);
1256
1257        let sort = Sort::by("id");
1258        let result = sort.apply(empty_batch);
1259        assert!(result.is_ok());
1260    }
1261
1262    #[test]
1263    fn test_sort_empty_columns_vector() {
1264        let batch = create_test_batch();
1265        let sort = Sort::by_columns::<String>(vec![]);
1266        let result = sort.apply(batch.clone());
1267        assert!(result.is_ok());
1268        let result = result.ok().unwrap();
1269        // Empty sort columns should return unchanged batch
1270        assert_eq!(result.num_rows(), batch.num_rows());
1271    }
1272
1273    #[test]
1274    fn test_sort_multi_column_one_missing() {
1275        let batch = create_test_batch();
1276        let sort = Sort::by_columns(vec![
1277            ("value".to_string(), SortOrder::Ascending),
1278            ("nonexistent".to_string(), SortOrder::Ascending),
1279        ]);
1280        let result = sort.apply(batch);
1281        // Should error because second column doesn't exist
1282        assert!(result.is_err());
1283    }
1284
1285    #[test]
1286    fn test_sort_single_row_unchanged() {
1287        // Create single-row batch
1288        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1289        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![42]))])
1290            .ok()
1291            .unwrap();
1292
1293        let sort = Sort::by_columns(vec![("id".to_string(), SortOrder::Ascending)]);
1294        let result = sort.apply(batch.clone());
1295        assert!(result.is_ok());
1296        let result = result.ok().unwrap();
1297        // Single row should be returned unchanged
1298        assert_eq!(result.num_rows(), 1);
1299    }
1300
1301    // Unique tests
1302
1303    #[test]
1304    fn test_unique_all_columns() {
1305        let schema = Arc::new(Schema::new(vec![
1306            Field::new("id", DataType::Int32, false),
1307            Field::new("value", DataType::Int32, false),
1308        ]));
1309        let ids = Int32Array::from(vec![1, 2, 1, 2, 1]); // Duplicates
1310        let values = Int32Array::from(vec![10, 20, 10, 20, 30]); // Row 0 == Row 2, Row 1 == Row 3
1311        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(values)])
1312            .ok()
1313            .unwrap_or_else(|| panic!("Should create batch"));
1314
1315        let transform = Unique::all();
1316        let result = transform.apply(batch);
1317        assert!(result.is_ok());
1318        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1319
1320        assert_eq!(result.num_rows(), 3); // Only 3 unique rows
1321    }
1322
1323    #[test]
1324    fn test_unique_by_column() {
1325        let schema = Arc::new(Schema::new(vec![
1326            Field::new("id", DataType::Int32, false),
1327            Field::new("value", DataType::Int32, false),
1328        ]));
1329        let ids = Int32Array::from(vec![1, 2, 1, 2, 3]);
1330        let values = Int32Array::from(vec![10, 20, 30, 40, 50]);
1331        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(values)])
1332            .ok()
1333            .unwrap_or_else(|| panic!("Should create batch"));
1334
1335        let transform = Unique::by(vec!["id"]);
1336        let result = transform.apply(batch);
1337        assert!(result.is_ok());
1338        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1339
1340        assert_eq!(result.num_rows(), 3); // ids 1, 2, 3
1341
1342        let ids = result
1343            .column(0)
1344            .as_any()
1345            .downcast_ref::<Int32Array>()
1346            .unwrap_or_else(|| panic!("Should be Int32Array"));
1347        let values = result
1348            .column(1)
1349            .as_any()
1350            .downcast_ref::<Int32Array>()
1351            .unwrap_or_else(|| panic!("Should be Int32Array"));
1352
1353        // Keep first occurrences by default
1354        assert_eq!(ids.value(0), 1);
1355        assert_eq!(values.value(0), 10); // First occurrence of id=1
1356        assert_eq!(ids.value(1), 2);
1357        assert_eq!(values.value(1), 20); // First occurrence of id=2
1358    }
1359
1360    #[test]
1361    fn test_unique_keep_last() {
1362        let schema = Arc::new(Schema::new(vec![
1363            Field::new("id", DataType::Int32, false),
1364            Field::new("value", DataType::Int32, false),
1365        ]));
1366        let ids = Int32Array::from(vec![1, 2, 1, 2, 3]);
1367        let values = Int32Array::from(vec![10, 20, 30, 40, 50]);
1368        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(values)])
1369            .ok()
1370            .unwrap_or_else(|| panic!("Should create batch"));
1371
1372        let transform = Unique::by(vec!["id"]).keep_last();
1373        let result = transform.apply(batch);
1374        assert!(result.is_ok());
1375        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1376
1377        assert_eq!(result.num_rows(), 3);
1378
1379        let ids = result
1380            .column(0)
1381            .as_any()
1382            .downcast_ref::<Int32Array>()
1383            .unwrap_or_else(|| panic!("Should be Int32Array"));
1384        let values = result
1385            .column(1)
1386            .as_any()
1387            .downcast_ref::<Int32Array>()
1388            .unwrap_or_else(|| panic!("Should be Int32Array"));
1389
1390        // Keep last occurrences
1391        assert_eq!(ids.value(0), 1);
1392        assert_eq!(values.value(0), 30); // Last occurrence of id=1
1393        assert_eq!(ids.value(1), 2);
1394        assert_eq!(values.value(1), 40); // Last occurrence of id=2
1395    }
1396
1397    #[test]
1398    fn test_unique_no_duplicates() {
1399        let batch = create_test_batch();
1400        let transform = Unique::all();
1401
1402        let result = transform.apply(batch.clone());
1403        assert!(result.is_ok());
1404        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1405
1406        assert_eq!(result.num_rows(), batch.num_rows()); // All unique
1407    }
1408
1409    #[test]
1410    fn test_unique_column_not_found() {
1411        let batch = create_test_batch();
1412        let transform = Unique::by(vec!["nonexistent"]);
1413
1414        let result = transform.apply(batch);
1415        assert!(result.is_err());
1416    }
1417
1418    #[test]
1419    fn test_unique_columns_getter() {
1420        let unique = Unique::by(vec!["a", "b"]);
1421        assert!(unique.columns().is_some());
1422        assert_eq!(
1423            unique
1424                .columns()
1425                .unwrap_or_else(|| panic!("Should have columns")),
1426            &["a", "b"]
1427        );
1428
1429        let unique2 = Unique::all();
1430        assert!(unique2.columns().is_none());
1431    }
1432
1433    #[test]
1434    fn test_unique_single_row() {
1435        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1436        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))])
1437            .ok()
1438            .unwrap_or_else(|| panic!("Should create batch"));
1439
1440        let transform = Unique::all();
1441        let result = transform.apply(batch);
1442        assert!(result.is_ok());
1443        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1444        assert_eq!(result.num_rows(), 1);
1445    }
1446
1447    #[test]
1448    fn test_unique_debug() {
1449        let unique = Unique::all();
1450        let debug_str = format!("{:?}", unique);
1451        assert!(debug_str.contains("Unique"));
1452    }
1453
1454    #[test]
1455    fn test_unique_with_int64_column() {
1456        use arrow::array::Int64Array;
1457
1458        let schema = Arc::new(Schema::new(vec![
1459            Field::new("id", DataType::Int64, false),
1460            Field::new("name", DataType::Utf8, false),
1461        ]));
1462
1463        let id_arr = Int64Array::from(vec![1i64, 1i64, 2i64, 2i64, 3i64]);
1464        let name_arr = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1465
1466        let batch = RecordBatch::try_new(schema, vec![Arc::new(id_arr), Arc::new(name_arr)])
1467            .ok()
1468            .unwrap_or_else(|| panic!("batch"));
1469
1470        let unique = Unique::by(vec!["id"]);
1471        let result = unique.apply(batch);
1472        assert!(result.is_ok());
1473        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1474        assert_eq!(result.num_rows(), 3);
1475    }
1476
1477    #[test]
1478    fn test_unique_with_float64_column() {
1479        use arrow::array::Float64Array;
1480
1481        let schema = Arc::new(Schema::new(vec![
1482            Field::new("val", DataType::Float64, false),
1483            Field::new("name", DataType::Utf8, false),
1484        ]));
1485
1486        let val_arr = Float64Array::from(vec![1.0f64, 1.0f64, 2.0f64, 2.0f64, 3.0f64]);
1487        let name_arr = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1488
1489        let batch = RecordBatch::try_new(schema, vec![Arc::new(val_arr), Arc::new(name_arr)])
1490            .ok()
1491            .unwrap_or_else(|| panic!("batch"));
1492
1493        let unique = Unique::by(vec!["val"]);
1494        let result = unique.apply(batch);
1495        assert!(result.is_ok());
1496        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1497        assert_eq!(result.num_rows(), 3);
1498    }
1499
1500    #[test]
1501    fn test_unique_with_float32_column() {
1502        use arrow::array::Float32Array;
1503
1504        let schema = Arc::new(Schema::new(vec![
1505            Field::new("val", DataType::Float32, false),
1506            Field::new("name", DataType::Utf8, false),
1507        ]));
1508
1509        let val_arr = Float32Array::from(vec![1.0f32, 1.0f32, 2.0f32, 2.0f32, 3.0f32]);
1510        let name_arr = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1511
1512        let batch = RecordBatch::try_new(schema, vec![Arc::new(val_arr), Arc::new(name_arr)])
1513            .ok()
1514            .unwrap_or_else(|| panic!("batch"));
1515
1516        let unique = Unique::by(vec!["val"]);
1517        let result = unique.apply(batch);
1518        assert!(result.is_ok());
1519        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1520        assert_eq!(result.num_rows(), 3);
1521    }
1522
1523    #[test]
1524    fn test_unique_with_bool_column() {
1525        use arrow::array::BooleanArray;
1526
1527        let schema = Arc::new(Schema::new(vec![
1528            Field::new("flag", DataType::Boolean, false),
1529            Field::new("name", DataType::Utf8, false),
1530        ]));
1531
1532        let flag_arr = BooleanArray::from(vec![true, true, false, false, true]);
1533        let name_arr = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1534
1535        let batch = RecordBatch::try_new(schema, vec![Arc::new(flag_arr), Arc::new(name_arr)])
1536            .ok()
1537            .unwrap_or_else(|| panic!("batch"));
1538
1539        let unique = Unique::by(vec!["flag"]);
1540        let result = unique.apply(batch);
1541        assert!(result.is_ok());
1542        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1543        assert_eq!(result.num_rows(), 2);
1544    }
1545
1546    #[test]
1547    fn test_unique_with_null_values() {
1548        let schema = Arc::new(Schema::new(vec![
1549            Field::new("id", DataType::Int32, true),
1550            Field::new("name", DataType::Utf8, false),
1551        ]));
1552
1553        let id_arr = Int32Array::from(vec![Some(1), None, Some(1), None, Some(2)]);
1554        let name_arr = StringArray::from(vec!["a", "b", "c", "d", "e"]);
1555
1556        let batch = RecordBatch::try_new(schema, vec![Arc::new(id_arr), Arc::new(name_arr)])
1557            .ok()
1558            .unwrap_or_else(|| panic!("batch"));
1559
1560        let unique = Unique::by(vec!["id"]);
1561        let result = unique.apply(batch);
1562        assert!(result.is_ok());
1563        let result = result.ok().unwrap_or_else(|| panic!("Should succeed"));
1564        assert_eq!(result.num_rows(), 3); // 1, NULL, 2
1565    }
1566
1567    #[test]
1568    fn test_unique_empty_batch() {
1569        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1570        let arr = Int32Array::from(Vec::<i32>::new());
1571        let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)])
1572            .ok()
1573            .unwrap_or_else(|| panic!("batch"));
1574
1575        let unique = Unique::by(["id"]);
1576        let result = unique.apply(batch);
1577        assert!(result.is_ok());
1578        let result = result.ok().unwrap();
1579        assert_eq!(result.num_rows(), 0);
1580    }
1581}