Skip to main content

alimentar/transform/
mod.rs

1//! Data transforms for alimentar.
2//!
3//! Transforms apply operations to RecordBatches, enabling data preprocessing
4//! pipelines. All transforms are composable and can be chained together.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{BooleanArray, RecordBatch},
10    compute::filter_record_batch,
11};
12
13use crate::error::{Error, Result};
14
15#[cfg(feature = "shuffle")]
16mod fim;
17mod numeric;
18mod row_ops;
19mod selection;
20
21#[cfg(feature = "shuffle")]
22pub use fim::{Fim, FimFormat, FimTokens};
23pub use numeric::{Cast, FillNull, FillStrategy, NormMethod, Normalize};
24#[cfg(feature = "shuffle")]
25pub use row_ops::{Sample, Shuffle};
26pub use row_ops::{Skip, Sort, SortOrder, Take, Unique};
27pub use selection::{Drop, Rename, Select};
28
29/// A transform that can be applied to RecordBatches.
30///
31/// Transforms are the building blocks for data preprocessing pipelines.
32/// They take a RecordBatch and produce a new RecordBatch with the
33/// transformation applied.
34///
35/// # Thread Safety
36///
37/// All transforms must be thread-safe (Send + Sync) to support parallel
38/// data loading in future versions.
39pub trait Transform: Send + Sync {
40    /// Applies the transform to a RecordBatch.
41    ///
42    /// # Errors
43    ///
44    /// Returns an error if the transform cannot be applied to the batch.
45    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch>;
46}
47
48/// A transform that applies a function to each RecordBatch.
49///
50/// # Example
51///
52/// ```ignore
53/// use alimentar::Map;
54///
55/// let transform = Map::new(|batch| {
56///     // Process batch
57///     Ok(batch)
58/// });
59/// ```
60pub struct Map<F>
61where
62    F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
63{
64    func: F,
65}
66
67impl<F> Map<F>
68where
69    F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
70{
71    /// Creates a new Map transform with the given function.
72    pub fn new(func: F) -> Self {
73        Self { func }
74    }
75}
76
77impl<F> Transform for Map<F>
78where
79    F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
80{
81    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
82        (self.func)(batch)
83    }
84}
85
86/// A transform that filters rows based on a predicate.
87///
88/// The predicate function receives a RecordBatch and must return a BooleanArray
89/// with the same number of rows, where `true` indicates the row should be kept.
90///
91/// # Example
92///
93/// ```ignore
94/// use alimentar::Filter;
95/// use arrow::array::{Int32Array, BooleanArray};
96///
97/// let filter = Filter::new(|batch| {
98///     let col = batch.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
99///     let mask: Vec<bool> = (0..col.len()).map(|i| col.value(i) > 5).collect();
100///     Ok(BooleanArray::from(mask))
101/// });
102/// ```
103pub struct Filter<F>
104where
105    F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
106{
107    predicate: F,
108}
109
110impl<F> Filter<F>
111where
112    F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
113{
114    /// Creates a new Filter transform with the given predicate.
115    pub fn new(predicate: F) -> Self {
116        Self { predicate }
117    }
118}
119
120impl<F> Transform for Filter<F>
121where
122    F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
123{
124    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
125        let mask = (self.predicate)(&batch)?;
126        filter_record_batch(&batch, &mask).map_err(Error::Arrow)
127    }
128}
129
130/// A chain of transforms applied in sequence.
131///
132/// # Example
133///
134/// ```ignore
135/// use alimentar::{Chain, Select, Shuffle};
136///
137/// let chain = Chain::new()
138///     .then(Select::new(vec!["id", "value"]))
139///     .then(Shuffle::with_seed(42));
140/// ```
141pub struct Chain {
142    transforms: Vec<Box<dyn Transform>>,
143}
144
145impl Chain {
146    /// Creates a new empty transform chain.
147    pub fn new() -> Self {
148        Self {
149            transforms: Vec::new(),
150        }
151    }
152
153    /// Adds a transform to the chain.
154    #[must_use]
155    pub fn then<T: Transform + 'static>(mut self, transform: T) -> Self {
156        self.transforms.push(Box::new(transform));
157        self
158    }
159
160    /// Returns the number of transforms in the chain.
161    pub fn len(&self) -> usize {
162        self.transforms.len()
163    }
164
165    /// Returns true if the chain has no transforms.
166    pub fn is_empty(&self) -> bool {
167        self.transforms.is_empty()
168    }
169}
170
171impl Default for Chain {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl Transform for Chain {
178    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
179        let mut result = batch;
180        for transform in &self.transforms {
181            result = transform.apply(result)?;
182        }
183        Ok(result)
184    }
185}
186
187// Implement Transform for boxed transforms
188impl Transform for Box<dyn Transform> {
189    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
190        (**self).apply(batch)
191    }
192}
193
194// Implement Transform for Arc<dyn Transform>
195impl Transform for Arc<dyn Transform> {
196    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
197        (**self).apply(batch)
198    }
199}
200
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;

    /// Unwraps a transform result, panicking with the underlying error so a
    /// failing test reports why the transform failed.
    fn expect_batch(result: Result<RecordBatch>) -> RecordBatch {
        result.unwrap_or_else(|err| panic!("Should succeed: {err}"))
    }

    /// Builds a five-row batch with the columns `id` (1 to 5), `name`
    /// ("a" to "e") and `value` (10 to 50).
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Int32, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(value_array),
            ],
        )
        .unwrap_or_else(|err| panic!("Should create batch: {err}"))
    }

    #[test]
    fn test_map_transform() {
        let batch = create_test_batch();
        let expected_rows = batch.num_rows();
        let transform = Map::new(Ok); // Identity transform

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_rows(), expected_rows);
    }

    #[test]
    fn test_filter_transform() {
        let batch = create_test_batch();
        let transform = Filter::new(|b| {
            let col = b
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .ok_or_else(|| Error::transform("Expected Int32Array"))?;
            let mask: Vec<bool> = (0..col.len()).map(|i| col.value(i) > 2).collect();
            Ok(BooleanArray::from(mask))
        });

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_rows(), 3); // Only id > 2: 3, 4, 5
    }

    #[test]
    fn test_chain_transform() {
        let batch = create_test_batch();
        let chain = Chain::new()
            .then(Select::new(vec!["id", "value"]))
            .then(Take::new(3));

        assert_eq!(chain.len(), 2);
        assert!(!chain.is_empty());

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_columns(), 2);
        assert_eq!(result.num_rows(), 3);
    }

    #[test]
    fn test_empty_chain() {
        let batch = create_test_batch();
        let expected_rows = batch.num_rows();
        let chain = Chain::new();

        assert!(chain.is_empty());

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_rows(), expected_rows);
    }

    #[test]
    fn test_filter_empty_result() {
        let batch = create_test_batch();
        let filter = Filter::new(|batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])));

        let result = expect_batch(filter.apply(batch));
        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_map_with_error() {
        let batch = create_test_batch();
        let map = Map::new(|_batch| Err(Error::transform("intentional error")));
        let result = map.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_filter_closure() {
        let batch = create_test_batch();
        // Keeps rows where id > 2; falls back to an all-false mask if the
        // first column is not an Int32Array.
        let filter = Filter::new(|batch: &RecordBatch| {
            let id_col = batch.column(0).as_any().downcast_ref::<Int32Array>();
            let mask: Vec<bool> = match id_col {
                Some(arr) => (0..arr.len()).map(|i| arr.value(i) > 2).collect(),
                None => vec![false; batch.num_rows()],
            };
            Ok(BooleanArray::from(mask))
        });

        let result = expect_batch(filter.apply(batch));
        assert_eq!(result.num_rows(), 3); // rows with id 3, 4, 5
    }

    #[test]
    fn test_filter_all_rows_filtered() {
        let batch = create_test_batch();
        // Filter that removes all rows (5 rows in test batch)
        let filter = Filter::new(|_batch: &RecordBatch| Ok(BooleanArray::from(vec![false; 5])));

        let result = expect_batch(filter.apply(batch));
        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_map_error_propagation() {
        let batch = create_test_batch();
        // Map that returns error
        let map = Map::new(|_batch: RecordBatch| Err(Error::transform("Test error")));
        let result = map.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_chain_empty_transforms() {
        let batch = create_test_batch();
        let expected_rows = batch.num_rows();
        let chain: Chain = Chain::new();

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_rows(), expected_rows);
    }

    #[test]
    fn test_boxed_transform_delegation() {
        let batch = create_test_batch();
        let boxed: Box<dyn Transform> = Box::new(Take::new(2));

        let result = expect_batch(boxed.apply(batch));
        assert_eq!(result.num_rows(), 2);
    }

    #[test]
    fn test_arc_transform_delegation() {
        let batch = create_test_batch();
        let arced: Arc<dyn Transform> = Arc::new(Take::new(3));

        let result = expect_batch(arced.apply(batch));
        assert_eq!(result.num_rows(), 3);
    }

    #[test]
    fn test_chain_single_transform() {
        let batch = create_test_batch();
        let chain = Chain::new().then(Take::new(2));

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_rows(), 2);
    }

    #[test]
    fn test_chain_with_multiple_transforms() {
        let batch = create_test_batch();

        let chain = Chain::new()
            .then(Select::new(vec!["id", "name"]))
            .then(Rename::from_pairs([("id", "identifier")]));

        let result = expect_batch(chain.apply(batch));
        assert!(result.schema().field_with_name("identifier").is_ok());
    }
}