alimentar/transform/
mod.rs1use std::sync::Arc;
7
8use arrow::{
9 array::{BooleanArray, RecordBatch},
10 compute::filter_record_batch,
11};
12
13use crate::error::{Error, Result};
14
15#[cfg(feature = "shuffle")]
16mod fim;
17mod numeric;
18mod row_ops;
19mod selection;
20
21#[cfg(feature = "shuffle")]
22pub use fim::{Fim, FimFormat, FimTokens};
23pub use numeric::{Cast, FillNull, FillStrategy, NormMethod, Normalize};
24#[cfg(feature = "shuffle")]
25pub use row_ops::{Sample, Shuffle};
26pub use row_ops::{Skip, Sort, SortOrder, Take, Unique};
27pub use selection::{Drop, Rename, Select};
28
29pub trait Transform: Send + Sync {
40 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch>;
46}
47
48pub struct Map<F>
61where
62 F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
63{
64 func: F,
65}
66
67impl<F> Map<F>
68where
69 F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
70{
71 pub fn new(func: F) -> Self {
73 Self { func }
74 }
75}
76
77impl<F> Transform for Map<F>
78where
79 F: Fn(RecordBatch) -> Result<RecordBatch> + Send + Sync,
80{
81 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
82 (self.func)(batch)
83 }
84}
85
86pub struct Filter<F>
104where
105 F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
106{
107 predicate: F,
108}
109
110impl<F> Filter<F>
111where
112 F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
113{
114 pub fn new(predicate: F) -> Self {
116 Self { predicate }
117 }
118}
119
120impl<F> Transform for Filter<F>
121where
122 F: Fn(&RecordBatch) -> Result<BooleanArray> + Send + Sync,
123{
124 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
125 let mask = (self.predicate)(&batch)?;
126 filter_record_batch(&batch, &mask).map_err(Error::Arrow)
127 }
128}
129
130pub struct Chain {
142 transforms: Vec<Box<dyn Transform>>,
143}
144
145impl Chain {
146 pub fn new() -> Self {
148 Self {
149 transforms: Vec::new(),
150 }
151 }
152
153 #[must_use]
155 pub fn then<T: Transform + 'static>(mut self, transform: T) -> Self {
156 self.transforms.push(Box::new(transform));
157 self
158 }
159
160 pub fn len(&self) -> usize {
162 self.transforms.len()
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.transforms.is_empty()
168 }
169}
170
171impl Default for Chain {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl Transform for Chain {
178 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
179 let mut result = batch;
180 for transform in &self.transforms {
181 result = transform.apply(result)?;
182 }
183 Ok(result)
184 }
185}
186
187impl Transform for Box<dyn Transform> {
189 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
190 (**self).apply(batch)
191 }
192}
193
194impl Transform for Arc<dyn Transform> {
196 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
197 (**self).apply(batch)
198 }
199}
200
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;

    /// Builds the shared fixture: five rows with the columns `id` (Int32),
    /// `name` (Utf8) and `value` (Int32).
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Int32, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(value_array),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"))
    }

    /// Unwraps a transform result, failing the calling test when the
    /// transform returned an error.
    #[track_caller]
    fn expect_batch(result: Result<RecordBatch>) -> RecordBatch {
        match result.ok() {
            Some(batch) => batch,
            None => panic!("Should succeed"),
        }
    }

    /// Collects the values of the Int32 column at position `index`.
    #[track_caller]
    fn int32_values(batch: &RecordBatch, index: usize) -> Vec<i32> {
        match batch.column(index).as_any().downcast_ref::<Int32Array>() {
            Some(column) => (0..column.len()).map(|i| column.value(i)).collect(),
            None => panic!("Expected Int32Array"),
        }
    }

    /// Predicate shared by the filter tests: keeps rows whose first column
    /// (Int32) is greater than 2; a column of another type selects nothing.
    fn id_greater_than_two(batch: &RecordBatch) -> Result<BooleanArray> {
        let id_col = batch.column(0).as_any().downcast_ref::<Int32Array>();
        let mask: Vec<bool> = match id_col {
            Some(arr) => (0..arr.len()).map(|i| arr.value(i) > 2).collect(),
            None => vec![false; batch.num_rows()],
        };
        Ok(BooleanArray::from(mask))
    }

    #[test]
    fn test_map_transform() {
        let batch = create_test_batch();
        // `Ok` serves as the identity transform.
        let transform = Map::new(Ok);

        let result = expect_batch(transform.apply(batch.clone()));
        assert_eq!(result.num_rows(), batch.num_rows());
        assert_eq!(result.num_columns(), batch.num_columns());
    }

    #[test]
    fn test_filter_transform() {
        let batch = create_test_batch();
        let transform = Filter::new(|b| {
            let col = b
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .ok_or_else(|| Error::transform("Expected Int32Array"))?;
            let mask: Vec<bool> = (0..col.len()).map(|i| col.value(i) > 2).collect();
            Ok(BooleanArray::from(mask))
        });

        let result = expect_batch(transform.apply(batch));
        // Only ids 3, 4 and 5 are greater than 2.
        assert_eq!(result.num_rows(), 3);
        assert_eq!(int32_values(&result, 0), vec![3, 4, 5]);
    }

    #[test]
    fn test_chain_transform() {
        let batch = create_test_batch();
        let chain = Chain::new()
            .then(Select::new(vec!["id", "value"]))
            .then(Take::new(3));

        assert_eq!(chain.len(), 2);
        assert!(!chain.is_empty());

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_columns(), 2);
        assert_eq!(result.num_rows(), 3);
    }

    #[test]
    fn test_empty_chain() {
        let batch = create_test_batch();
        let chain = Chain::new();

        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        let result = expect_batch(chain.apply(batch.clone()));
        assert_eq!(result.num_rows(), batch.num_rows());
    }

    #[test]
    fn test_filter_empty_result() {
        let batch = create_test_batch();
        let filter = Filter::new(|batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])));

        let result = expect_batch(filter.apply(batch));
        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_map_with_error() {
        let batch = create_test_batch();
        let map = Map::new(|_batch| Err(crate::Error::transform("intentional error")));

        let result = map.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_filter_closure() {
        let batch = create_test_batch();
        let filter = Filter::new(|batch: &RecordBatch| id_greater_than_two(batch));

        let result = expect_batch(filter.apply(batch));
        // Only ids 3, 4 and 5 are greater than 2.
        assert_eq!(result.num_rows(), 3);
        assert_eq!(int32_values(&result, 0), vec![3, 4, 5]);
    }

    #[test]
    fn test_filter_all_rows_filtered() {
        let batch = create_test_batch();
        // Size the mask from the batch rather than hard-coding the fixture length.
        let filter = Filter::new(|batch: &RecordBatch| {
            Ok(BooleanArray::from(vec![false; batch.num_rows()]))
        });

        let result = expect_batch(filter.apply(batch));
        assert_eq!(result.num_rows(), 0);
    }

    #[test]
    fn test_map_error_propagation() {
        let batch = create_test_batch();
        let map = Map::new(|_batch: RecordBatch| Err(crate::Error::transform("Test error")));

        let result = map.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_chain_empty_transforms() {
        let batch = create_test_batch();
        let chain: Chain = Chain::default();

        let result = expect_batch(chain.apply(batch.clone()));
        assert_eq!(result.num_rows(), batch.num_rows());
    }

    #[test]
    fn test_chain_propagates_error() {
        let batch = create_test_batch();
        // The failing first step must short-circuit the chain.
        let chain = Chain::new()
            .then(Map::new(|_batch: RecordBatch| {
                Err(crate::Error::transform("Test error"))
            }))
            .then(Take::new(1));

        let result = chain.apply(batch);
        assert!(result.is_err());
    }

    #[test]
    fn test_boxed_transform_delegation() {
        let batch = create_test_batch();
        let boxed: Box<dyn Transform> = Box::new(Take::new(2));

        let result = expect_batch(boxed.apply(batch));
        assert_eq!(result.num_rows(), 2);
    }

    #[test]
    fn test_arc_transform_delegation() {
        let batch = create_test_batch();
        let arced: Arc<dyn Transform> = Arc::new(Take::new(3));

        let result = expect_batch(arced.apply(batch));
        assert_eq!(result.num_rows(), 3);
    }

    #[test]
    fn test_chain_single_transform() {
        let batch = create_test_batch();
        let chain = Chain::new().then(Take::new(2));

        assert_eq!(chain.len(), 1);

        let result = expect_batch(chain.apply(batch));
        assert_eq!(result.num_rows(), 2);
    }

    #[test]
    fn test_chain_with_multiple_transforms() {
        let batch = create_test_batch();

        let chain = Chain::new()
            .then(Select::new(vec!["id", "name"]))
            .then(Rename::from_pairs([("id", "identifier")]));

        let result = expect_batch(chain.apply(batch));
        assert!(result.schema().field_with_name("identifier").is_ok());
    }
}