Skip to main content

alimentar/transform/
selection.rs

1//! Column selection and manipulation transforms.
2
3use std::sync::Arc;
4
5use arrow::{
6    array::RecordBatch,
7    datatypes::{Field, Schema},
8};
9
10use super::Transform;
11use crate::error::{Error, Result};
12
/// Keeps only the named columns of a RecordBatch, in the order they were given.
///
/// # Example
///
/// ```ignore
/// use alimentar::Select;
///
/// let select = Select::new(vec!["id", "name"]);
/// ```
#[derive(Debug, Clone)]
pub struct Select {
    /// Names of the columns to keep, in output order.
    columns: Vec<String>,
}

impl Select {
    /// Builds a `Select` from any iterable of names convertible into `String`.
    ///
    /// The supplied order is kept as-is and becomes the output column order.
    pub fn new<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
        let names = columns
            .into_iter()
            .map(|name| name.into())
            .collect::<Vec<String>>();
        Select { columns: names }
    }

    /// The column names this transform keeps, in output order.
    pub fn columns(&self) -> &[String] {
        &self.columns[..]
    }
}
40
41impl Transform for Select {
42    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
43        let schema = batch.schema();
44        let mut fields = Vec::with_capacity(self.columns.len());
45        let mut arrays = Vec::with_capacity(self.columns.len());
46
47        for col_name in &self.columns {
48            let (idx, field) = schema
49                .column_with_name(col_name)
50                .ok_or_else(|| Error::column_not_found(col_name))?;
51
52            fields.push(field.clone());
53            arrays.push(Arc::clone(batch.column(idx)));
54        }
55
56        let new_schema = Arc::new(Schema::new(fields));
57        RecordBatch::try_new(new_schema, arrays).map_err(Error::Arrow)
58    }
59}
60
/// Renames columns of a RecordBatch according to an old-name → new-name map.
///
/// # Example
///
/// ```ignore
/// use alimentar::Rename;
/// use std::collections::HashMap;
///
/// let mut mapping = HashMap::new();
/// mapping.insert("old_name".to_string(), "new_name".to_string());
/// let rename = Rename::new(mapping);
/// ```
#[derive(Debug, Clone)]
pub struct Rename {
    /// Maps existing column names to their replacement names.
    mapping: std::collections::HashMap<String, String>,
}

impl Rename {
    /// Builds a `Rename` from a ready-made `old_name → new_name` map.
    pub fn new(mapping: std::collections::HashMap<String, String>) -> Self {
        Rename { mapping }
    }

    /// Builds a `Rename` from `(old_name, new_name)` pairs.
    ///
    /// If the same old name appears more than once, the last pair wins.
    pub fn from_pairs<S: Into<String>>(pairs: impl IntoIterator<Item = (S, S)>) -> Self {
        let mut mapping: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for (old_name, new_name) in pairs {
            mapping.insert(old_name.into(), new_name.into());
        }
        Rename { mapping }
    }
}
93
94impl Transform for Rename {
95    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
96        let schema = batch.schema();
97        let new_fields: Vec<Field> = schema
98            .fields()
99            .iter()
100            .map(|field| {
101                let name = field.name();
102                match self.mapping.get(name) {
103                    Some(new_name) => {
104                        Field::new(new_name, field.data_type().clone(), field.is_nullable())
105                    }
106                    None => field.as_ref().clone(),
107                }
108            })
109            .collect();
110
111        let new_schema = Arc::new(Schema::new(new_fields));
112        RecordBatch::try_new(new_schema, batch.columns().to_vec()).map_err(Error::Arrow)
113    }
114}
115
/// Removes the named columns from a RecordBatch, keeping all the others.
///
/// # Example
///
/// ```ignore
/// use alimentar::Drop;
///
/// let drop = Drop::new(vec!["temp_column", "debug_info"]);
/// ```
#[derive(Debug, Clone)]
pub struct Drop {
    /// Names of the columns to remove.
    columns: Vec<String>,
}

impl Drop {
    /// Builds a `Drop` that removes each of `columns`; any name convertible
    /// into `String` is accepted.
    pub fn new<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
        let mut names: Vec<String> = Vec::new();
        for name in columns {
            names.push(name.into());
        }
        Drop { columns: names }
    }

    /// The column names this transform removes, in the order they were given.
    pub fn columns(&self) -> &[String] {
        &self.columns[..]
    }
}
143
144impl Transform for Drop {
145    fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
146        let schema = batch.schema();
147        let drop_set: std::collections::HashSet<&str> =
148            self.columns.iter().map(String::as_str).collect();
149
150        let mut fields = Vec::new();
151        let mut arrays = Vec::new();
152
153        for (idx, field) in schema.fields().iter().enumerate() {
154            if !drop_set.contains(field.name().as_str()) {
155                fields.push(field.as_ref().clone());
156                arrays.push(Arc::clone(batch.column(idx)));
157            }
158        }
159
160        if fields.is_empty() {
161            return Err(Error::transform("Cannot drop all columns from batch"));
162        }
163
164        let new_schema = Arc::new(Schema::new(fields));
165        RecordBatch::try_new(new_schema, arrays).map_err(Error::Arrow)
166    }
167}
168
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, Int32Array, StringArray},
        datatypes::DataType,
    };

    use super::*;

    /// Builds a 5-row batch with non-nullable columns `id: Int32`, `name: Utf8`
    /// and `value: Int32`.
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Int32, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(value_array),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"))
    }

    /// Applies `transform` to `batch` and panics if it fails.
    ///
    /// Matching on the result (instead of `expect`) avoids requiring `Debug`
    /// on the crate error type.
    fn apply_ok(transform: &impl Transform, batch: RecordBatch) -> RecordBatch {
        match transform.apply(batch) {
            Ok(out) => out,
            Err(_) => panic!("transform should succeed"),
        }
    }

    /// Returns the field names of `batch` in schema order.
    fn field_names(batch: &RecordBatch) -> Vec<String> {
        let schema = batch.schema();
        schema
            .fields()
            .iter()
            .map(|field| field.name().clone())
            .collect()
    }

    /// Returns a copy of the `Int32` column called `name`; panics if it is
    /// missing or has another type.
    fn int_column(batch: &RecordBatch, name: &str) -> Int32Array {
        batch
            .column_by_name(name)
            .unwrap_or_else(|| panic!("column `{}` should exist", name))
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap_or_else(|| panic!("column `{}` should be Int32", name))
            .clone()
    }

    /// Returns a copy of the `Utf8` column called `name`; panics if it is
    /// missing or has another type.
    fn string_column(batch: &RecordBatch, name: &str) -> StringArray {
        batch
            .column_by_name(name)
            .unwrap_or_else(|| panic!("column `{}` should exist", name))
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap_or_else(|| panic!("column `{}` should be Utf8", name))
            .clone()
    }

    #[test]
    fn test_select_transform() {
        let result = apply_ok(&Select::new(vec!["id", "value"]), create_test_batch());
        assert_eq!(result.num_columns(), 2);
        assert_eq!(result.num_rows(), 5);
        assert_eq!(field_names(&result), ["id", "value"]);
        assert_eq!(
            int_column(&result, "value"),
            Int32Array::from(vec![10, 20, 30, 40, 50])
        );
    }

    #[test]
    fn test_select_column_not_found() {
        let batch = create_test_batch();
        assert!(Select::new(vec!["nonexistent"]).apply(batch.clone()).is_err());
        // A single missing name fails the whole selection, even if others exist.
        assert!(Select::new(vec!["id", "nonexistent"]).apply(batch).is_err());
    }

    #[test]
    fn test_select_columns_getter() {
        let select = Select::new(vec!["a", "b"]);
        assert_eq!(select.columns(), &["a", "b"]);
    }

    #[test]
    fn test_select_preserves_column_order() {
        // Select in reverse order
        let select = Select::new(vec!["value", "name", "id"]);
        let result = apply_ok(&select, create_test_batch());
        assert_eq!(field_names(&result), ["value", "name", "id"]);
        // Data follows its column, not its position.
        assert_eq!(
            int_column(&result, "id"),
            Int32Array::from(vec![1, 2, 3, 4, 5])
        );
    }

    #[test]
    fn test_select_debug() {
        let select = Select::new(vec!["id", "name"]);
        let debug_str = format!("{:?}", select);
        assert!(debug_str.contains("Select"));
    }

    #[test]
    fn test_rename_transform() {
        let transform = Rename::from_pairs([("id", "identifier"), ("name", "label")]);
        let result = apply_ok(&transform, create_test_batch());

        assert_eq!(field_names(&result), ["identifier", "label", "value"]); // `value` unchanged
        assert_eq!(
            int_column(&result, "identifier"),
            Int32Array::from(vec![1, 2, 3, 4, 5])
        );
        assert_eq!(
            string_column(&result, "label"),
            StringArray::from(vec!["a", "b", "c", "d", "e"])
        );
    }

    #[test]
    fn test_rename_multiple_columns() {
        let transform = Rename::from_pairs([("id", "identifier"), ("name", "label")]);
        let result = apply_ok(&transform, create_test_batch());
        let schema = result.schema();

        assert!(schema.field_with_name("identifier").is_ok());
        assert!(schema.field_with_name("label").is_ok());
        // The old names must no longer resolve.
        assert!(schema.field_with_name("id").is_err());
        assert!(schema.field_with_name("name").is_err());
    }

    #[test]
    fn test_rename_preserves_type_and_nullability() {
        let result = apply_ok(&Rename::from_pairs([("id", "identifier")]), create_test_batch());
        let schema = result.schema();
        let field = schema.field(0);
        assert_eq!(field.name(), "identifier");
        assert_eq!(field.data_type(), &DataType::Int32);
        assert!(!field.is_nullable());
    }

    #[test]
    fn test_rename_swap_columns() {
        // Each field is looked up in the mapping independently, so a swap is
        // not applied twice.
        let transform = Rename::from_pairs([("id", "value"), ("value", "id")]);
        let result = apply_ok(&transform, create_test_batch());
        assert_eq!(field_names(&result), ["value", "name", "id"]);
        assert_eq!(
            int_column(&result, "id"),
            Int32Array::from(vec![10, 20, 30, 40, 50])
        );
    }

    #[test]
    fn test_rename_nonexistent_column_is_ok() {
        let batch = create_test_batch();
        let transform = Rename::from_pairs([("nonexistent", "new_name")]);
        // Renaming a nonexistent column should succeed (no-op)
        let result = apply_ok(&transform, batch.clone());
        assert_eq!(result.num_rows(), batch.num_rows());
    }

    #[test]
    fn test_rename_debug() {
        let rename = Rename::from_pairs([("old", "new")]);
        let debug_str = format!("{:?}", rename);
        assert!(debug_str.contains("Rename"));
    }

    #[test]
    fn test_rename_nonexistent_column() {
        let batch = create_test_batch();
        let rename = Rename::from_pairs([("nonexistent", "new_name")]);
        // An unmatched mapping key leaves the schema untouched.
        let result = apply_ok(&rename, batch.clone());
        assert_eq!(field_names(&result), field_names(&batch));
    }

    #[test]
    fn test_drop_transform() {
        let result = apply_ok(&Drop::new(vec!["name"]), create_test_batch());
        assert_eq!(result.num_columns(), 2);
        assert_eq!(field_names(&result), ["id", "value"]);
        assert_eq!(
            int_column(&result, "value"),
            Int32Array::from(vec![10, 20, 30, 40, 50])
        );
    }

    #[test]
    fn test_drop_multiple_columns() {
        let result = apply_ok(&Drop::new(vec!["id", "name"]), create_test_batch());
        assert_eq!(field_names(&result), ["value"]);
        assert_eq!(result.num_rows(), 5);
    }

    #[test]
    fn test_drop_all_columns_error() {
        let transform = Drop::new(vec!["id", "name", "value"]);
        assert!(transform.apply(create_test_batch()).is_err());
    }

    #[test]
    fn test_drop_nonexistent_column_is_ok() {
        let result = apply_ok(&Drop::new(vec!["nonexistent"]), create_test_batch());
        assert_eq!(result.num_columns(), 3); // All columns remain
    }

    #[test]
    fn test_drop_columns_getter() {
        let transform = Drop::new(vec!["a", "b"]);
        assert_eq!(transform.columns(), &["a", "b"]);
    }

    #[test]
    fn test_drop_debug() {
        let drop_t = Drop::new(vec!["col"]);
        let debug_str = format!("{:?}", drop_t);
        assert!(debug_str.contains("Drop"));
    }

    #[test]
    fn test_drop_nonexistent_columns_unchanged() {
        let batch = create_test_batch();
        let drop = Drop::new(["nonexistent_column", "also_nonexistent"]);
        let result = apply_ok(&drop, batch.clone());
        // Dropping nonexistent columns should return unchanged batch
        assert_eq!(field_names(&result), field_names(&batch));
        assert_eq!(result.num_rows(), batch.num_rows());
    }
}