1use std::sync::Arc;
4
5use arrow::{
6 array::RecordBatch,
7 datatypes::{Field, Schema},
8};
9
10use super::Transform;
11use crate::error::{Error, Result};
12
/// Transform that keeps only a chosen set of columns.
///
/// The output columns follow the order in which the names were supplied to
/// [`Select::new`], not the order in which they appear in the input batch.
#[derive(Debug, Clone)]
pub struct Select {
    /// Names of the columns to keep, in output order.
    columns: Vec<String>,
}

impl Select {
    /// Creates a selection from any iterable of string-like column names.
    pub fn new<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
        let columns: Vec<String> = columns.into_iter().map(|name| name.into()).collect();
        Self { columns }
    }

    /// Returns the selected column names, in the order they were supplied.
    pub fn columns(&self) -> &[String] {
        self.columns.as_slice()
    }
}
40
41impl Transform for Select {
42 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
43 let schema = batch.schema();
44 let mut fields = Vec::with_capacity(self.columns.len());
45 let mut arrays = Vec::with_capacity(self.columns.len());
46
47 for col_name in &self.columns {
48 let (idx, field) = schema
49 .column_with_name(col_name)
50 .ok_or_else(|| Error::column_not_found(col_name))?;
51
52 fields.push(field.clone());
53 arrays.push(Arc::clone(batch.column(idx)));
54 }
55
56 let new_schema = Arc::new(Schema::new(fields));
57 RecordBatch::try_new(new_schema, arrays).map_err(Error::Arrow)
58 }
59}
60
/// Transform that renames columns according to an old-name → new-name mapping.
///
/// Mapping entries whose old name does not occur in a batch are ignored when
/// the transform is applied.
#[derive(Debug, Clone)]
pub struct Rename {
    /// Maps an existing column name to the name it should receive.
    mapping: std::collections::HashMap<String, String>,
}

impl Rename {
    /// Creates a rename transform from a ready-made old → new name map.
    pub fn new(mapping: std::collections::HashMap<String, String>) -> Self {
        Rename { mapping }
    }

    /// Creates a rename transform from `(old, new)` name pairs.
    ///
    /// If the same old name appears more than once, the last pair wins.
    pub fn from_pairs<S: Into<String>>(pairs: impl IntoIterator<Item = (S, S)>) -> Self {
        let mut mapping = std::collections::HashMap::new();
        for (old, new) in pairs {
            mapping.insert(old.into(), new.into());
        }
        Self::new(mapping)
    }
}
93
94impl Transform for Rename {
95 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
96 let schema = batch.schema();
97 let new_fields: Vec<Field> = schema
98 .fields()
99 .iter()
100 .map(|field| {
101 let name = field.name();
102 match self.mapping.get(name) {
103 Some(new_name) => {
104 Field::new(new_name, field.data_type().clone(), field.is_nullable())
105 }
106 None => field.as_ref().clone(),
107 }
108 })
109 .collect();
110
111 let new_schema = Arc::new(Schema::new(new_fields));
112 RecordBatch::try_new(new_schema, batch.columns().to_vec()).map_err(Error::Arrow)
113 }
114}
115
/// Transform that removes the named columns from a batch.
///
/// Note that within this module the name `Drop` refers to this struct, not to
/// the standard library's `Drop` trait.
#[derive(Debug, Clone)]
pub struct Drop {
    /// Names of the columns to remove.
    columns: Vec<String>,
}

impl Drop {
    /// Creates a drop transform from any iterable of string-like column names.
    pub fn new<S: Into<String>>(columns: impl IntoIterator<Item = S>) -> Self {
        let mut names = Vec::new();
        for column in columns {
            names.push(column.into());
        }
        Drop { columns: names }
    }

    /// Returns the names of the columns to remove, in the order supplied.
    pub fn columns(&self) -> &[String] {
        &self.columns[..]
    }
}
143
144impl Transform for Drop {
145 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch> {
146 let schema = batch.schema();
147 let drop_set: std::collections::HashSet<&str> =
148 self.columns.iter().map(String::as_str).collect();
149
150 let mut fields = Vec::new();
151 let mut arrays = Vec::new();
152
153 for (idx, field) in schema.fields().iter().enumerate() {
154 if !drop_set.contains(field.name().as_str()) {
155 fields.push(field.as_ref().clone());
156 arrays.push(Arc::clone(batch.column(idx)));
157 }
158 }
159
160 if fields.is_empty() {
161 return Err(Error::transform("Cannot drop all columns from batch"));
162 }
163
164 let new_schema = Arc::new(Schema::new(fields));
165 RecordBatch::try_new(new_schema, arrays).map_err(Error::Arrow)
166 }
167}
168
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::DataType,
    };

    use super::*;

    /// Builds a five-row batch with non-nullable columns `id` (Int32),
    /// `name` (Utf8) and `value` (Int32), in that order.
    fn create_test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("value", DataType::Int32, false),
        ]));

        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let name_array = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let value_array = Int32Array::from(vec![10, 20, 30, 40, 50]);

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(id_array),
                Arc::new(name_array),
                Arc::new(value_array),
            ],
        )
        .ok()
        .unwrap_or_else(|| panic!("Should create batch"))
    }

    /// Unwraps a transform result, failing the test if the transform errored.
    ///
    /// Centralizes the unwrap style used throughout this module so every test
    /// fails with the same message.
    fn expect_batch<E>(result: std::result::Result<RecordBatch, E>) -> RecordBatch {
        result.ok().unwrap_or_else(|| panic!("Should succeed"))
    }

    /// Returns the column names of `batch` in schema order.
    fn column_names(batch: &RecordBatch) -> Vec<String> {
        batch
            .schema()
            .fields()
            .iter()
            .map(|field| field.name().clone())
            .collect()
    }

    #[test]
    fn test_select_transform() {
        let batch = create_test_batch();
        let transform = Select::new(vec!["id", "value"]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_columns(), 2);
        assert_eq!(column_names(&result), ["id", "value"]);
    }

    #[test]
    fn test_select_column_not_found() {
        let batch = create_test_batch();
        let transform = Select::new(vec!["nonexistent"]);

        assert!(transform.apply(batch).is_err());
    }

    #[test]
    fn test_select_columns_getter() {
        let select = Select::new(vec!["a", "b"]);
        assert_eq!(select.columns(), &["a", "b"]);
    }

    #[test]
    fn test_select_preserves_column_order() {
        let batch = create_test_batch();
        let select = Select::new(vec!["value", "name", "id"]);

        let result = expect_batch(select.apply(batch));
        assert_eq!(column_names(&result), ["value", "name", "id"]);
    }

    #[test]
    fn test_rename_transform() {
        let batch = create_test_batch();
        let transform = Rename::from_pairs([("id", "identifier"), ("name", "label")]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(column_names(&result), ["identifier", "label", "value"]);
    }

    #[test]
    fn test_rename_multiple_columns() {
        let batch = create_test_batch();
        let transform = Rename::from_pairs([("id", "identifier"), ("name", "label")]);

        let result = expect_batch(transform.apply(batch));
        assert!(result.schema().field_with_name("identifier").is_ok());
        assert!(result.schema().field_with_name("label").is_ok());
    }

    #[test]
    fn test_rename_nonexistent_column_is_ok() {
        let batch = create_test_batch();
        let transform = Rename::from_pairs([("nonexistent", "new_name")]);

        let result = expect_batch(transform.apply(batch.clone()));
        assert_eq!(result.num_rows(), batch.num_rows());
    }

    #[test]
    fn test_rename_debug() {
        let rename = Rename::from_pairs([("old", "new")]);
        let debug_str = format!("{:?}", rename);
        assert!(debug_str.contains("Rename"));
    }

    #[test]
    fn test_rename_nonexistent_column() {
        let batch = create_test_batch();
        let rename = Rename::from_pairs([("nonexistent", "new_name")]);

        assert!(rename.apply(batch).is_ok());
    }

    #[test]
    fn test_drop_transform() {
        let batch = create_test_batch();
        let transform = Drop::new(vec!["name"]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_columns(), 2);
        assert_eq!(column_names(&result), ["id", "value"]);
    }

    #[test]
    fn test_drop_multiple_columns() {
        let batch = create_test_batch();
        let transform = Drop::new(vec!["id", "name"]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_columns(), 1);
        assert_eq!(column_names(&result), ["value"]);
    }

    #[test]
    fn test_drop_all_columns_error() {
        let batch = create_test_batch();
        let transform = Drop::new(vec!["id", "name", "value"]);

        assert!(transform.apply(batch).is_err());
    }

    #[test]
    fn test_drop_nonexistent_column_is_ok() {
        let batch = create_test_batch();
        let transform = Drop::new(vec!["nonexistent"]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_columns(), 3);
    }

    #[test]
    fn test_drop_columns_getter() {
        let transform = Drop::new(vec!["a", "b"]);
        assert_eq!(transform.columns(), &["a", "b"]);
    }

    #[test]
    fn test_select_debug() {
        let select = Select::new(vec!["id", "name"]);
        let debug_str = format!("{:?}", select);
        assert!(debug_str.contains("Select"));
    }

    #[test]
    fn test_drop_debug() {
        let drop_t = Drop::new(vec!["col"]);
        let debug_str = format!("{:?}", drop_t);
        assert!(debug_str.contains("Drop"));
    }

    #[test]
    fn test_drop_nonexistent_columns_unchanged() {
        let batch = create_test_batch();
        let original_cols = batch.num_columns();
        let transform = Drop::new(["nonexistent_column", "also_nonexistent"]);

        let result = expect_batch(transform.apply(batch));
        assert_eq!(result.num_columns(), original_cols);
    }
}