datafusion_remote_table/
transform.rs

1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3    Array, ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
4    Decimal128Array, Decimal256Array, FixedSizeBinaryArray, FixedSizeListArray, Float16Array,
5    Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
6    IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
7    LargeListArray, LargeListViewArray, LargeStringArray, ListArray, ListViewArray, NullArray,
8    RecordBatch, StringArray, StringViewArray, Time32MillisecondArray, Time32SecondArray,
9    Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
10    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
11    UInt16Array, UInt32Array, UInt64Array,
12};
13use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit};
14use datafusion::common::{DataFusionError, project_schema};
15use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
16use futures::{Stream, StreamExt};
17use std::any::Any;
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
23pub struct TransformArgs<'a> {
24    pub col_index: usize,
25    pub field: &'a Field,
26    pub remote_field: Option<&'a RemoteField>,
27}
28
29pub trait Transform: Debug + Send + Sync {
30    fn as_any(&self) -> &dyn Any;
31
32    fn transform_null(
33        &self,
34        array: &NullArray,
35        args: TransformArgs,
36    ) -> DFResult<(ArrayRef, Field)> {
37        Ok((Arc::new(array.clone()), args.field.clone()))
38    }
39
40    fn transform_boolean(
41        &self,
42        array: &BooleanArray,
43        args: TransformArgs,
44    ) -> DFResult<(ArrayRef, Field)> {
45        Ok((Arc::new(array.clone()), args.field.clone()))
46    }
47
48    fn transform_int8(
49        &self,
50        array: &Int8Array,
51        args: TransformArgs,
52    ) -> DFResult<(ArrayRef, Field)> {
53        Ok((Arc::new(array.clone()), args.field.clone()))
54    }
55
56    fn transform_int16(
57        &self,
58        array: &Int16Array,
59        args: TransformArgs,
60    ) -> DFResult<(ArrayRef, Field)> {
61        Ok((Arc::new(array.clone()), args.field.clone()))
62    }
63
64    fn transform_int32(
65        &self,
66        array: &Int32Array,
67        args: TransformArgs,
68    ) -> DFResult<(ArrayRef, Field)> {
69        Ok((Arc::new(array.clone()), args.field.clone()))
70    }
71
72    fn transform_int64(
73        &self,
74        array: &Int64Array,
75        args: TransformArgs,
76    ) -> DFResult<(ArrayRef, Field)> {
77        Ok((Arc::new(array.clone()), args.field.clone()))
78    }
79
80    fn transform_uint8(
81        &self,
82        array: &UInt8Array,
83        args: TransformArgs,
84    ) -> DFResult<(ArrayRef, Field)> {
85        Ok((Arc::new(array.clone()), args.field.clone()))
86    }
87
88    fn transform_uint16(
89        &self,
90        array: &UInt16Array,
91        args: TransformArgs,
92    ) -> DFResult<(ArrayRef, Field)> {
93        Ok((Arc::new(array.clone()), args.field.clone()))
94    }
95
96    fn transform_uint32(
97        &self,
98        array: &UInt32Array,
99        args: TransformArgs,
100    ) -> DFResult<(ArrayRef, Field)> {
101        Ok((Arc::new(array.clone()), args.field.clone()))
102    }
103
104    fn transform_uint64(
105        &self,
106        array: &UInt64Array,
107        args: TransformArgs,
108    ) -> DFResult<(ArrayRef, Field)> {
109        Ok((Arc::new(array.clone()), args.field.clone()))
110    }
111
112    fn transform_float16(
113        &self,
114        array: &Float16Array,
115        args: TransformArgs,
116    ) -> DFResult<(ArrayRef, Field)> {
117        Ok((Arc::new(array.clone()), args.field.clone()))
118    }
119
120    fn transform_float32(
121        &self,
122        array: &Float32Array,
123        args: TransformArgs,
124    ) -> DFResult<(ArrayRef, Field)> {
125        Ok((Arc::new(array.clone()), args.field.clone()))
126    }
127
128    fn transform_float64(
129        &self,
130        array: &Float64Array,
131        args: TransformArgs,
132    ) -> DFResult<(ArrayRef, Field)> {
133        Ok((Arc::new(array.clone()), args.field.clone()))
134    }
135
136    fn transform_binary(
137        &self,
138        array: &BinaryArray,
139        args: TransformArgs,
140    ) -> DFResult<(ArrayRef, Field)> {
141        Ok((Arc::new(array.clone()), args.field.clone()))
142    }
143
144    fn transform_fixed_size_binary(
145        &self,
146        array: &FixedSizeBinaryArray,
147        args: TransformArgs,
148    ) -> DFResult<(ArrayRef, Field)> {
149        Ok((Arc::new(array.clone()), args.field.clone()))
150    }
151
152    fn transform_large_binary(
153        &self,
154        array: &LargeBinaryArray,
155        args: TransformArgs,
156    ) -> DFResult<(ArrayRef, Field)> {
157        Ok((Arc::new(array.clone()), args.field.clone()))
158    }
159
160    fn transform_binary_view(
161        &self,
162        array: &BinaryViewArray,
163        args: TransformArgs,
164    ) -> DFResult<(ArrayRef, Field)> {
165        Ok((Arc::new(array.clone()), args.field.clone()))
166    }
167
168    fn transform_utf8(
169        &self,
170        array: &StringArray,
171        args: TransformArgs,
172    ) -> DFResult<(ArrayRef, Field)> {
173        Ok((Arc::new(array.clone()), args.field.clone()))
174    }
175
176    fn transform_large_utf8(
177        &self,
178        array: &LargeStringArray,
179        args: TransformArgs,
180    ) -> DFResult<(ArrayRef, Field)> {
181        Ok((Arc::new(array.clone()), args.field.clone()))
182    }
183
184    fn transform_utf8_view(
185        &self,
186        array: &StringViewArray,
187        args: TransformArgs,
188    ) -> DFResult<(ArrayRef, Field)> {
189        Ok((Arc::new(array.clone()), args.field.clone()))
190    }
191
192    fn transform_timestamp_second(
193        &self,
194        array: &TimestampSecondArray,
195        args: TransformArgs,
196    ) -> DFResult<(ArrayRef, Field)> {
197        Ok((Arc::new(array.clone()), args.field.clone()))
198    }
199
200    fn transform_timestamp_millisecond(
201        &self,
202        array: &TimestampMillisecondArray,
203        args: TransformArgs,
204    ) -> DFResult<(ArrayRef, Field)> {
205        Ok((Arc::new(array.clone()), args.field.clone()))
206    }
207
208    fn transform_timestamp_microsecond(
209        &self,
210        array: &TimestampMicrosecondArray,
211        args: TransformArgs,
212    ) -> DFResult<(ArrayRef, Field)> {
213        Ok((Arc::new(array.clone()), args.field.clone()))
214    }
215
216    fn transform_timestamp_nanosecond(
217        &self,
218        array: &TimestampNanosecondArray,
219        args: TransformArgs,
220    ) -> DFResult<(ArrayRef, Field)> {
221        Ok((Arc::new(array.clone()), args.field.clone()))
222    }
223
224    fn transform_date32(
225        &self,
226        array: &Date32Array,
227        args: TransformArgs,
228    ) -> DFResult<(ArrayRef, Field)> {
229        Ok((Arc::new(array.clone()), args.field.clone()))
230    }
231
232    fn transform_date64(
233        &self,
234        array: &Date64Array,
235        args: TransformArgs,
236    ) -> DFResult<(ArrayRef, Field)> {
237        Ok((Arc::new(array.clone()), args.field.clone()))
238    }
239
240    fn transform_time32_second(
241        &self,
242        array: &Time32SecondArray,
243        args: TransformArgs,
244    ) -> DFResult<(ArrayRef, Field)> {
245        Ok((Arc::new(array.clone()), args.field.clone()))
246    }
247
248    fn transform_time32_millisecond(
249        &self,
250        array: &Time32MillisecondArray,
251        args: TransformArgs,
252    ) -> DFResult<(ArrayRef, Field)> {
253        Ok((Arc::new(array.clone()), args.field.clone()))
254    }
255
256    fn transform_time64_microsecond(
257        &self,
258        array: &Time64MicrosecondArray,
259        args: TransformArgs,
260    ) -> DFResult<(ArrayRef, Field)> {
261        Ok((Arc::new(array.clone()), args.field.clone()))
262    }
263
264    fn transform_time64_nanosecond(
265        &self,
266        array: &Time64NanosecondArray,
267        args: TransformArgs,
268    ) -> DFResult<(ArrayRef, Field)> {
269        Ok((Arc::new(array.clone()), args.field.clone()))
270    }
271
272    fn transform_interval_year_month(
273        &self,
274        array: &IntervalYearMonthArray,
275        args: TransformArgs,
276    ) -> DFResult<(ArrayRef, Field)> {
277        Ok((Arc::new(array.clone()), args.field.clone()))
278    }
279
280    fn transform_interval_day_time(
281        &self,
282        array: &IntervalDayTimeArray,
283        args: TransformArgs,
284    ) -> DFResult<(ArrayRef, Field)> {
285        Ok((Arc::new(array.clone()), args.field.clone()))
286    }
287
288    fn transform_interval_month_day_nano(
289        &self,
290        array: &IntervalMonthDayNanoArray,
291        args: TransformArgs,
292    ) -> DFResult<(ArrayRef, Field)> {
293        Ok((Arc::new(array.clone()), args.field.clone()))
294    }
295
296    fn transform_list(
297        &self,
298        array: &ListArray,
299        args: TransformArgs,
300    ) -> DFResult<(ArrayRef, Field)> {
301        Ok((Arc::new(array.clone()), args.field.clone()))
302    }
303
304    fn transform_list_view(
305        &self,
306        array: &ListViewArray,
307        args: TransformArgs,
308    ) -> DFResult<(ArrayRef, Field)> {
309        Ok((Arc::new(array.clone()), args.field.clone()))
310    }
311
312    fn transform_fixed_size_list(
313        &self,
314        array: &FixedSizeListArray,
315        args: TransformArgs,
316    ) -> DFResult<(ArrayRef, Field)> {
317        Ok((Arc::new(array.clone()), args.field.clone()))
318    }
319
320    fn transform_large_list(
321        &self,
322        array: &LargeListArray,
323        args: TransformArgs,
324    ) -> DFResult<(ArrayRef, Field)> {
325        Ok((Arc::new(array.clone()), args.field.clone()))
326    }
327
328    fn transform_large_list_view(
329        &self,
330        array: &LargeListViewArray,
331        args: TransformArgs,
332    ) -> DFResult<(ArrayRef, Field)> {
333        Ok((Arc::new(array.clone()), args.field.clone()))
334    }
335
336    fn transform_decimal128(
337        &self,
338        array: &Decimal128Array,
339        args: TransformArgs,
340    ) -> DFResult<(ArrayRef, Field)> {
341        Ok((Arc::new(array.clone()), args.field.clone()))
342    }
343
344    fn transform_decimal256(
345        &self,
346        array: &Decimal256Array,
347        args: TransformArgs,
348    ) -> DFResult<(ArrayRef, Field)> {
349        Ok((Arc::new(array.clone()), args.field.clone()))
350    }
351}
352
353#[derive(Debug)]
354pub struct DefaultTransform {}
355
356impl Transform for DefaultTransform {
357    fn as_any(&self) -> &dyn Any {
358        self
359    }
360}
361
362pub(crate) struct TransformStream {
363    input: SendableRecordBatchStream,
364    transform: Arc<dyn Transform>,
365    table_schema: SchemaRef,
366    projection: Option<Vec<usize>>,
367    projected_schema: SchemaRef,
368    remote_schema: Option<RemoteSchemaRef>,
369}
370
371impl TransformStream {
372    pub fn try_new(
373        input: SendableRecordBatchStream,
374        transform: Arc<dyn Transform>,
375        table_schema: SchemaRef,
376        projection: Option<Vec<usize>>,
377        remote_schema: Option<RemoteSchemaRef>,
378    ) -> DFResult<Self> {
379        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
380        Ok(Self {
381            input,
382            transform,
383            table_schema,
384            projection,
385            projected_schema,
386            remote_schema,
387        })
388    }
389}
390
391impl Stream for TransformStream {
392    type Item = DFResult<RecordBatch>;
393    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
394        match self.input.poll_next_unpin(cx) {
395            Poll::Ready(Some(Ok(batch))) => {
396                match transform_batch(
397                    batch,
398                    self.transform.as_ref(),
399                    &self.table_schema,
400                    self.projection.as_ref(),
401                    self.remote_schema.as_ref(),
402                ) {
403                    Ok(result) => Poll::Ready(Some(Ok(result))),
404                    Err(e) => Poll::Ready(Some(Err(e))),
405                }
406            }
407            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
408            Poll::Ready(None) => Poll::Ready(None),
409            Poll::Pending => Poll::Pending,
410        }
411    }
412}
413
414impl RecordBatchStream for TransformStream {
415    fn schema(&self) -> SchemaRef {
416        self.projected_schema.clone()
417    }
418}
419
/// Downcasts column `$idx` of `$batch` to `$array_ty` and forwards it to
/// `$transform.$transform_method`, evaluating to the returned `(ArrayRef, Field)`.
///
/// Must be expanded inside a function returning `DFResult`. A missing column, a
/// column whose runtime type is not `$array_ty` (e.g. the remote source returned
/// data that disagrees with the declared schema), or an error from the hook is
/// propagated with `?` instead of panicking.
macro_rules! handle_transform {
    ($batch:expr, $idx:expr, $array_ty:ty, $transform:expr, $transform_method:ident, $transform_args:expr) => {{
        let batch_ref = &$batch;
        let column_idx: usize = $idx;
        let column = batch_ref.columns().get(column_idx).ok_or_else(|| {
            DataFusionError::Execution(format!(
                "Column index {} out of bounds for batch with {} columns",
                column_idx,
                batch_ref.num_columns()
            ))
        })?;
        let array = column
            .as_any()
            .downcast_ref::<$array_ty>()
            .ok_or_else(|| {
                DataFusionError::Execution(format!(
                    "Failed to downcast column {} of type {:?} to {}",
                    column_idx,
                    column.data_type(),
                    stringify!($array_ty)
                ))
            })?;
        $transform.$transform_method(array, $transform_args)?
    }};
}
430
431pub(crate) fn transform_batch(
432    batch: RecordBatch,
433    transform: &dyn Transform,
434    table_schema: &SchemaRef,
435    projection: Option<&Vec<usize>>,
436    remote_schema: Option<&RemoteSchemaRef>,
437) -> DFResult<RecordBatch> {
438    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
439    let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
440    let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
441    let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
442
443    for (idx, col_index) in projected_col_indexes.iter().enumerate() {
444        let field = table_schema.field(*col_index);
445        let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
446        let args = TransformArgs {
447            col_index: *col_index,
448            field,
449            remote_field,
450        };
451
452        let (new_array, new_field) = match &field.data_type() {
453            DataType::Null => {
454                handle_transform!(batch, idx, NullArray, transform, transform_null, args)
455            }
456            DataType::Boolean => {
457                handle_transform!(batch, idx, BooleanArray, transform, transform_boolean, args)
458            }
459            DataType::Int8 => {
460                handle_transform!(batch, idx, Int8Array, transform, transform_int8, args)
461            }
462            DataType::Int16 => {
463                handle_transform!(batch, idx, Int16Array, transform, transform_int16, args)
464            }
465            DataType::Int32 => {
466                handle_transform!(batch, idx, Int32Array, transform, transform_int32, args)
467            }
468            DataType::Int64 => {
469                handle_transform!(batch, idx, Int64Array, transform, transform_int64, args)
470            }
471            DataType::UInt8 => {
472                handle_transform!(batch, idx, UInt8Array, transform, transform_uint8, args)
473            }
474            DataType::UInt16 => {
475                handle_transform!(batch, idx, UInt16Array, transform, transform_uint16, args)
476            }
477            DataType::UInt32 => {
478                handle_transform!(batch, idx, UInt32Array, transform, transform_uint32, args)
479            }
480            DataType::UInt64 => {
481                handle_transform!(batch, idx, UInt64Array, transform, transform_uint64, args)
482            }
483            DataType::Float16 => {
484                handle_transform!(batch, idx, Float16Array, transform, transform_float16, args)
485            }
486            DataType::Float32 => {
487                handle_transform!(batch, idx, Float32Array, transform, transform_float32, args)
488            }
489            DataType::Float64 => {
490                handle_transform!(batch, idx, Float64Array, transform, transform_float64, args)
491            }
492            DataType::Timestamp(TimeUnit::Second, _) => {
493                handle_transform!(
494                    batch,
495                    idx,
496                    TimestampSecondArray,
497                    transform,
498                    transform_timestamp_second,
499                    args
500                )
501            }
502            DataType::Timestamp(TimeUnit::Millisecond, _) => {
503                handle_transform!(
504                    batch,
505                    idx,
506                    TimestampMillisecondArray,
507                    transform,
508                    transform_timestamp_millisecond,
509                    args
510                )
511            }
512            DataType::Timestamp(TimeUnit::Microsecond, _) => {
513                handle_transform!(
514                    batch,
515                    idx,
516                    TimestampMicrosecondArray,
517                    transform,
518                    transform_timestamp_microsecond,
519                    args
520                )
521            }
522            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
523                handle_transform!(
524                    batch,
525                    idx,
526                    TimestampNanosecondArray,
527                    transform,
528                    transform_timestamp_nanosecond,
529                    args
530                )
531            }
532            DataType::Date32 => {
533                handle_transform!(batch, idx, Date32Array, transform, transform_date32, args)
534            }
535            DataType::Date64 => {
536                handle_transform!(batch, idx, Date64Array, transform, transform_date64, args)
537            }
538            DataType::Time32(TimeUnit::Second) => {
539                handle_transform!(
540                    batch,
541                    idx,
542                    Time32SecondArray,
543                    transform,
544                    transform_time32_second,
545                    args
546                )
547            }
548            DataType::Time32(TimeUnit::Millisecond) => {
549                handle_transform!(
550                    batch,
551                    idx,
552                    Time32MillisecondArray,
553                    transform,
554                    transform_time32_millisecond,
555                    args
556                )
557            }
558            DataType::Time32(TimeUnit::Microsecond) => unreachable!(),
559            DataType::Time32(TimeUnit::Nanosecond) => unreachable!(),
560            DataType::Time64(TimeUnit::Second) => unreachable!(),
561            DataType::Time64(TimeUnit::Millisecond) => unreachable!(),
562            DataType::Time64(TimeUnit::Microsecond) => {
563                handle_transform!(
564                    batch,
565                    idx,
566                    Time64MicrosecondArray,
567                    transform,
568                    transform_time64_microsecond,
569                    args
570                )
571            }
572            DataType::Time64(TimeUnit::Nanosecond) => {
573                handle_transform!(
574                    batch,
575                    idx,
576                    Time64NanosecondArray,
577                    transform,
578                    transform_time64_nanosecond,
579                    args
580                )
581            }
582            DataType::Interval(IntervalUnit::YearMonth) => {
583                handle_transform!(
584                    batch,
585                    idx,
586                    IntervalYearMonthArray,
587                    transform,
588                    transform_interval_year_month,
589                    args
590                )
591            }
592            DataType::Interval(IntervalUnit::DayTime) => {
593                handle_transform!(
594                    batch,
595                    idx,
596                    IntervalDayTimeArray,
597                    transform,
598                    transform_interval_day_time,
599                    args
600                )
601            }
602            DataType::Interval(IntervalUnit::MonthDayNano) => {
603                handle_transform!(
604                    batch,
605                    idx,
606                    IntervalMonthDayNanoArray,
607                    transform,
608                    transform_interval_month_day_nano,
609                    args
610                )
611            }
612            DataType::Binary => {
613                handle_transform!(batch, idx, BinaryArray, transform, transform_binary, args)
614            }
615            DataType::FixedSizeBinary(_) => {
616                handle_transform!(
617                    batch,
618                    idx,
619                    FixedSizeBinaryArray,
620                    transform,
621                    transform_fixed_size_binary,
622                    args
623                )
624            }
625            DataType::LargeBinary => {
626                handle_transform!(
627                    batch,
628                    idx,
629                    LargeBinaryArray,
630                    transform,
631                    transform_large_binary,
632                    args
633                )
634            }
635            DataType::BinaryView => {
636                handle_transform!(
637                    batch,
638                    idx,
639                    BinaryViewArray,
640                    transform,
641                    transform_binary_view,
642                    args
643                )
644            }
645            DataType::Utf8 => {
646                handle_transform!(batch, idx, StringArray, transform, transform_utf8, args)
647            }
648            DataType::LargeUtf8 => {
649                handle_transform!(
650                    batch,
651                    idx,
652                    LargeStringArray,
653                    transform,
654                    transform_large_utf8,
655                    args
656                )
657            }
658            DataType::Utf8View => {
659                handle_transform!(
660                    batch,
661                    idx,
662                    StringViewArray,
663                    transform,
664                    transform_utf8_view,
665                    args
666                )
667            }
668            DataType::List(_field) => {
669                handle_transform!(batch, idx, ListArray, transform, transform_list, args)
670            }
671            DataType::ListView(_field) => {
672                handle_transform!(
673                    batch,
674                    idx,
675                    ListViewArray,
676                    transform,
677                    transform_list_view,
678                    args
679                )
680            }
681            DataType::FixedSizeList(_, _) => {
682                handle_transform!(
683                    batch,
684                    idx,
685                    FixedSizeListArray,
686                    transform,
687                    transform_fixed_size_list,
688                    args
689                )
690            }
691            DataType::LargeList(_field) => {
692                handle_transform!(
693                    batch,
694                    idx,
695                    LargeListArray,
696                    transform,
697                    transform_large_list,
698                    args
699                )
700            }
701            DataType::LargeListView(_field) => {
702                handle_transform!(
703                    batch,
704                    idx,
705                    LargeListViewArray,
706                    transform,
707                    transform_large_list_view,
708                    args
709                )
710            }
711            DataType::Decimal128(_, _) => {
712                handle_transform!(
713                    batch,
714                    idx,
715                    Decimal128Array,
716                    transform,
717                    transform_decimal128,
718                    args
719                )
720            }
721            DataType::Decimal256(_, _) => {
722                handle_transform!(
723                    batch,
724                    idx,
725                    Decimal256Array,
726                    transform,
727                    transform_decimal256,
728                    args
729                )
730            }
731            data_type => {
732                return Err(DataFusionError::NotImplemented(format!(
733                    "Unsupported transform arrow type {data_type:?}",
734                )));
735            }
736        };
737        new_arrays.push(new_array);
738        new_fields.push(new_field);
739    }
740    let new_schema = Arc::new(Schema::new(new_fields));
741    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
742}
743
744pub(crate) fn transform_schema(
745    schema: SchemaRef,
746    transform: &dyn Transform,
747    remote_schema: Option<&RemoteSchemaRef>,
748) -> DFResult<SchemaRef> {
749    let empty_record = RecordBatch::new_empty(schema.clone());
750    transform_batch(empty_record, transform, &schema, None, remote_schema)
751        .map(|batch| batch.schema())
752}