// datafusion_remote_table/transform.rs

1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3    Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array,
4    Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, ListArray, NullArray, RecordBatch,
5    StringArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6    TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array,
7    UInt64Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use datafusion::common::{DataFusionError, project_schema};
11use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
12use futures::{Stream, StreamExt};
13use std::any::Any;
14use std::fmt::Debug;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19pub struct TransformArgs<'a> {
20    pub col_index: usize,
21    pub field: &'a Field,
22    pub remote_field: Option<&'a RemoteField>,
23}
24
25pub trait Transform: Debug + Send + Sync {
26    fn as_any(&self) -> &dyn Any;
27
28    fn transform_null(
29        &self,
30        array: &NullArray,
31        args: TransformArgs,
32    ) -> DFResult<(ArrayRef, Field)> {
33        Ok((Arc::new(array.clone()), args.field.clone()))
34    }
35
36    fn transform_boolean(
37        &self,
38        array: &BooleanArray,
39        args: TransformArgs,
40    ) -> DFResult<(ArrayRef, Field)> {
41        Ok((Arc::new(array.clone()), args.field.clone()))
42    }
43
44    fn transform_int8(
45        &self,
46        array: &Int8Array,
47        args: TransformArgs,
48    ) -> DFResult<(ArrayRef, Field)> {
49        Ok((Arc::new(array.clone()), args.field.clone()))
50    }
51
52    fn transform_int16(
53        &self,
54        array: &Int16Array,
55        args: TransformArgs,
56    ) -> DFResult<(ArrayRef, Field)> {
57        Ok((Arc::new(array.clone()), args.field.clone()))
58    }
59
60    fn transform_int32(
61        &self,
62        array: &Int32Array,
63        args: TransformArgs,
64    ) -> DFResult<(ArrayRef, Field)> {
65        Ok((Arc::new(array.clone()), args.field.clone()))
66    }
67
68    fn transform_int64(
69        &self,
70        array: &Int64Array,
71        args: TransformArgs,
72    ) -> DFResult<(ArrayRef, Field)> {
73        Ok((Arc::new(array.clone()), args.field.clone()))
74    }
75
76    fn transform_uint8(
77        &self,
78        array: &UInt8Array,
79        args: TransformArgs,
80    ) -> DFResult<(ArrayRef, Field)> {
81        Ok((Arc::new(array.clone()), args.field.clone()))
82    }
83
84    fn transform_uint16(
85        &self,
86        array: &UInt16Array,
87        args: TransformArgs,
88    ) -> DFResult<(ArrayRef, Field)> {
89        Ok((Arc::new(array.clone()), args.field.clone()))
90    }
91
92    fn transform_uint32(
93        &self,
94        array: &UInt32Array,
95        args: TransformArgs,
96    ) -> DFResult<(ArrayRef, Field)> {
97        Ok((Arc::new(array.clone()), args.field.clone()))
98    }
99
100    fn transform_uint64(
101        &self,
102        array: &UInt64Array,
103        args: TransformArgs,
104    ) -> DFResult<(ArrayRef, Field)> {
105        Ok((Arc::new(array.clone()), args.field.clone()))
106    }
107
108    fn transform_float16(
109        &self,
110        array: &Float16Array,
111        args: TransformArgs,
112    ) -> DFResult<(ArrayRef, Field)> {
113        Ok((Arc::new(array.clone()), args.field.clone()))
114    }
115
116    fn transform_float32(
117        &self,
118        array: &Float32Array,
119        args: TransformArgs,
120    ) -> DFResult<(ArrayRef, Field)> {
121        Ok((Arc::new(array.clone()), args.field.clone()))
122    }
123
124    fn transform_float64(
125        &self,
126        array: &Float64Array,
127        args: TransformArgs,
128    ) -> DFResult<(ArrayRef, Field)> {
129        Ok((Arc::new(array.clone()), args.field.clone()))
130    }
131
132    fn transform_utf8(
133        &self,
134        array: &StringArray,
135        args: TransformArgs,
136    ) -> DFResult<(ArrayRef, Field)> {
137        Ok((Arc::new(array.clone()), args.field.clone()))
138    }
139
140    fn transform_binary(
141        &self,
142        array: &BinaryArray,
143        args: TransformArgs,
144    ) -> DFResult<(ArrayRef, Field)> {
145        Ok((Arc::new(array.clone()), args.field.clone()))
146    }
147
148    fn transform_timestamp_second(
149        &self,
150        array: &TimestampSecondArray,
151        args: TransformArgs,
152    ) -> DFResult<(ArrayRef, Field)> {
153        Ok((Arc::new(array.clone()), args.field.clone()))
154    }
155
156    fn transform_timestamp_millisecond(
157        &self,
158        array: &TimestampMillisecondArray,
159        args: TransformArgs,
160    ) -> DFResult<(ArrayRef, Field)> {
161        Ok((Arc::new(array.clone()), args.field.clone()))
162    }
163
164    fn transform_timestamp_microsecond(
165        &self,
166        array: &TimestampMicrosecondArray,
167        args: TransformArgs,
168    ) -> DFResult<(ArrayRef, Field)> {
169        Ok((Arc::new(array.clone()), args.field.clone()))
170    }
171
172    fn transform_timestamp_nanosecond(
173        &self,
174        array: &TimestampNanosecondArray,
175        args: TransformArgs,
176    ) -> DFResult<(ArrayRef, Field)> {
177        Ok((Arc::new(array.clone()), args.field.clone()))
178    }
179
180    fn transform_time64_nanosecond(
181        &self,
182        array: &Time64NanosecondArray,
183        args: TransformArgs,
184    ) -> DFResult<(ArrayRef, Field)> {
185        Ok((Arc::new(array.clone()), args.field.clone()))
186    }
187
188    fn transform_date32(
189        &self,
190        array: &Date32Array,
191        args: TransformArgs,
192    ) -> DFResult<(ArrayRef, Field)> {
193        Ok((Arc::new(array.clone()), args.field.clone()))
194    }
195
196    fn transform_list(
197        &self,
198        array: &ListArray,
199        args: TransformArgs,
200    ) -> DFResult<(ArrayRef, Field)> {
201        Ok((Arc::new(array.clone()), args.field.clone()))
202    }
203}
204
205pub(crate) struct TransformStream {
206    input: SendableRecordBatchStream,
207    transform: Arc<dyn Transform>,
208    table_schema: SchemaRef,
209    projection: Option<Vec<usize>>,
210    projected_schema: SchemaRef,
211    remote_schema: Option<RemoteSchemaRef>,
212}
213
214impl TransformStream {
215    pub fn try_new(
216        input: SendableRecordBatchStream,
217        transform: Arc<dyn Transform>,
218        table_schema: SchemaRef,
219        projection: Option<Vec<usize>>,
220        remote_schema: Option<RemoteSchemaRef>,
221    ) -> DFResult<Self> {
222        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
223        Ok(Self {
224            input,
225            transform,
226            table_schema,
227            projection,
228            projected_schema,
229            remote_schema,
230        })
231    }
232}
233
234impl Stream for TransformStream {
235    type Item = DFResult<RecordBatch>;
236    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        match self.input.poll_next_unpin(cx) {
238            Poll::Ready(Some(Ok(batch))) => {
239                match transform_batch(
240                    batch,
241                    self.transform.as_ref(),
242                    &self.table_schema,
243                    self.projection.as_ref(),
244                    self.remote_schema.as_ref(),
245                ) {
246                    Ok(result) => Poll::Ready(Some(Ok(result))),
247                    Err(e) => Poll::Ready(Some(Err(e))),
248                }
249            }
250            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
251            Poll::Ready(None) => Poll::Ready(None),
252            Poll::Pending => Poll::Pending,
253        }
254    }
255}
256
257impl RecordBatchStream for TransformStream {
258    fn schema(&self) -> SchemaRef {
259        self.projected_schema.clone()
260    }
261}
262
263pub(crate) fn transform_batch(
264    batch: RecordBatch,
265    transform: &dyn Transform,
266    table_schema: &SchemaRef,
267    projection: Option<&Vec<usize>>,
268    remote_schema: Option<&RemoteSchemaRef>,
269) -> DFResult<RecordBatch> {
270    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
271    let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
272    let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
273    let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
274
275    for (idx, col_index) in projected_col_indexes.iter().enumerate() {
276        let field = table_schema.field(*col_index);
277        let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
278        let args = TransformArgs {
279            col_index: *col_index,
280            field,
281            remote_field,
282        };
283
284        let (new_array, new_field) = match &field.data_type() {
285            DataType::Null => {
286                let array = batch
287                    .column(idx)
288                    .as_any()
289                    .downcast_ref::<NullArray>()
290                    .expect("Failed to downcast to NullArray");
291                transform.transform_null(array, args)?
292            }
293            DataType::Boolean => {
294                let array = batch
295                    .column(idx)
296                    .as_any()
297                    .downcast_ref::<BooleanArray>()
298                    .expect("Failed to downcast to BooleanArray");
299                transform.transform_boolean(array, args)?
300            }
301            DataType::Int8 => {
302                let array = batch
303                    .column(idx)
304                    .as_any()
305                    .downcast_ref::<Int8Array>()
306                    .expect("Failed to downcast to Int8Array");
307                transform.transform_int8(array, args)?
308            }
309            DataType::Int16 => {
310                let array = batch
311                    .column(idx)
312                    .as_any()
313                    .downcast_ref::<Int16Array>()
314                    .expect("Failed to downcast to Int16Array");
315                transform.transform_int16(array, args)?
316            }
317            DataType::Int32 => {
318                let array = batch
319                    .column(idx)
320                    .as_any()
321                    .downcast_ref::<Int32Array>()
322                    .expect("Failed to downcast to Int32Array");
323                transform.transform_int32(array, args)?
324            }
325            DataType::Int64 => {
326                let array = batch
327                    .column(idx)
328                    .as_any()
329                    .downcast_ref::<Int64Array>()
330                    .expect("Failed to downcast to Int64Array");
331                transform.transform_int64(array, args)?
332            }
333            DataType::UInt8 => {
334                let array = batch
335                    .column(idx)
336                    .as_any()
337                    .downcast_ref::<UInt8Array>()
338                    .expect("Failed to downcast to UInt8Array");
339                transform.transform_uint8(array, args)?
340            }
341            DataType::UInt16 => {
342                let array = batch
343                    .column(idx)
344                    .as_any()
345                    .downcast_ref::<UInt16Array>()
346                    .expect("Failed to downcast to UInt16Array");
347                transform.transform_uint16(array, args)?
348            }
349            DataType::UInt32 => {
350                let array = batch
351                    .column(idx)
352                    .as_any()
353                    .downcast_ref::<UInt32Array>()
354                    .expect("Failed to downcast to UInt32Array");
355                transform.transform_uint32(array, args)?
356            }
357            DataType::UInt64 => {
358                let array = batch
359                    .column(idx)
360                    .as_any()
361                    .downcast_ref::<UInt64Array>()
362                    .expect("Failed to downcast to UInt64Array");
363                transform.transform_uint64(array, args)?
364            }
365            DataType::Float16 => {
366                let array = batch
367                    .column(idx)
368                    .as_any()
369                    .downcast_ref::<Float16Array>()
370                    .expect("Failed to downcast to Float16Array");
371                transform.transform_float16(array, args)?
372            }
373            DataType::Float32 => {
374                let array = batch
375                    .column(idx)
376                    .as_any()
377                    .downcast_ref::<Float32Array>()
378                    .expect("Failed to downcast to Float32Array");
379                transform.transform_float32(array, args)?
380            }
381            DataType::Float64 => {
382                let array = batch
383                    .column(idx)
384                    .as_any()
385                    .downcast_ref::<Float64Array>()
386                    .expect("Failed to downcast to Float64Array");
387                transform.transform_float64(array, args)?
388            }
389            DataType::Utf8 => {
390                let array = batch
391                    .column(idx)
392                    .as_any()
393                    .downcast_ref::<StringArray>()
394                    .expect("Failed to downcast to StringArray");
395                transform.transform_utf8(array, args)?
396            }
397            DataType::Binary => {
398                let array = batch
399                    .column(idx)
400                    .as_any()
401                    .downcast_ref::<BinaryArray>()
402                    .expect("Failed to downcast to BinaryArray");
403                transform.transform_binary(array, args)?
404            }
405            DataType::Date32 => {
406                let array = batch
407                    .column(idx)
408                    .as_any()
409                    .downcast_ref::<Date32Array>()
410                    .expect("Failed to downcast to Date32Array");
411                transform.transform_date32(array, args)?
412            }
413            DataType::Timestamp(TimeUnit::Second, _) => {
414                let array = batch
415                    .column(idx)
416                    .as_any()
417                    .downcast_ref::<TimestampSecondArray>()
418                    .expect("Failed to downcast to TimestampSecondArray");
419                transform.transform_timestamp_second(array, args)?
420            }
421            DataType::Timestamp(TimeUnit::Millisecond, _) => {
422                let array = batch
423                    .column(idx)
424                    .as_any()
425                    .downcast_ref::<TimestampMillisecondArray>()
426                    .expect("Failed to downcast to TimestampMillisecondArray");
427                transform.transform_timestamp_millisecond(array, args)?
428            }
429            DataType::Timestamp(TimeUnit::Microsecond, _) => {
430                let array = batch
431                    .column(idx)
432                    .as_any()
433                    .downcast_ref::<TimestampMicrosecondArray>()
434                    .expect("Failed to downcast to TimestampMicrosecondArray");
435                transform.transform_timestamp_microsecond(array, args)?
436            }
437            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
438                let array = batch
439                    .column(idx)
440                    .as_any()
441                    .downcast_ref::<TimestampNanosecondArray>()
442                    .expect("Failed to downcast to TimestampNanosecondArray");
443                transform.transform_timestamp_nanosecond(array, args)?
444            }
445            DataType::Time64(TimeUnit::Nanosecond) => {
446                let array = batch
447                    .column(idx)
448                    .as_any()
449                    .downcast_ref::<Time64NanosecondArray>()
450                    .expect("Failed to downcast to Time64NanosecondArray");
451                transform.transform_time64_nanosecond(array, args)?
452            }
453            DataType::List(_field) => {
454                let array = batch
455                    .column(idx)
456                    .as_any()
457                    .downcast_ref::<ListArray>()
458                    .expect("Failed to downcast to ListArray");
459                transform.transform_list(array, args)?
460            }
461            data_type => {
462                return Err(DataFusionError::NotImplemented(format!(
463                    "Unsupported arrow type {data_type:?}",
464                )));
465            }
466        };
467        new_arrays.push(new_array);
468        new_fields.push(new_field);
469    }
470    let new_schema = Arc::new(Schema::new(new_fields));
471    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
472}
473
474pub fn transform_schema(
475    schema: SchemaRef,
476    transform: &dyn Transform,
477    remote_schema: Option<&RemoteSchemaRef>,
478) -> DFResult<SchemaRef> {
479    let empty_record = RecordBatch::new_empty(schema.clone());
480    transform_batch(empty_record, transform, &schema, None, remote_schema)
481        .map(|batch| batch.schema())
482}