datafusion_remote_table/
transform.rs

use crate::{DFResult, RemoteField, RemoteSchemaRef};
use datafusion::arrow::array::{
    Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float16Array,
    Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray,
    LargeStringArray, ListArray, NullArray, RecordBatch, RecordBatchOptions, StringArray,
    Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
    TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
use futures::{Stream, StreamExt};
use std::any::Any;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
19
20pub struct TransformArgs<'a> {
21    pub col_index: usize,
22    pub field: &'a Field,
23    pub remote_field: Option<&'a RemoteField>,
24}
25
26pub trait Transform: Debug + Send + Sync {
27    fn as_any(&self) -> &dyn Any;
28
29    fn transform_null(
30        &self,
31        array: &NullArray,
32        args: TransformArgs,
33    ) -> DFResult<(ArrayRef, Field)> {
34        Ok((Arc::new(array.clone()), args.field.clone()))
35    }
36
37    fn transform_boolean(
38        &self,
39        array: &BooleanArray,
40        args: TransformArgs,
41    ) -> DFResult<(ArrayRef, Field)> {
42        Ok((Arc::new(array.clone()), args.field.clone()))
43    }
44
45    fn transform_int8(
46        &self,
47        array: &Int8Array,
48        args: TransformArgs,
49    ) -> DFResult<(ArrayRef, Field)> {
50        Ok((Arc::new(array.clone()), args.field.clone()))
51    }
52
53    fn transform_int16(
54        &self,
55        array: &Int16Array,
56        args: TransformArgs,
57    ) -> DFResult<(ArrayRef, Field)> {
58        Ok((Arc::new(array.clone()), args.field.clone()))
59    }
60
61    fn transform_int32(
62        &self,
63        array: &Int32Array,
64        args: TransformArgs,
65    ) -> DFResult<(ArrayRef, Field)> {
66        Ok((Arc::new(array.clone()), args.field.clone()))
67    }
68
69    fn transform_int64(
70        &self,
71        array: &Int64Array,
72        args: TransformArgs,
73    ) -> DFResult<(ArrayRef, Field)> {
74        Ok((Arc::new(array.clone()), args.field.clone()))
75    }
76
77    fn transform_uint8(
78        &self,
79        array: &UInt8Array,
80        args: TransformArgs,
81    ) -> DFResult<(ArrayRef, Field)> {
82        Ok((Arc::new(array.clone()), args.field.clone()))
83    }
84
85    fn transform_uint16(
86        &self,
87        array: &UInt16Array,
88        args: TransformArgs,
89    ) -> DFResult<(ArrayRef, Field)> {
90        Ok((Arc::new(array.clone()), args.field.clone()))
91    }
92
93    fn transform_uint32(
94        &self,
95        array: &UInt32Array,
96        args: TransformArgs,
97    ) -> DFResult<(ArrayRef, Field)> {
98        Ok((Arc::new(array.clone()), args.field.clone()))
99    }
100
101    fn transform_uint64(
102        &self,
103        array: &UInt64Array,
104        args: TransformArgs,
105    ) -> DFResult<(ArrayRef, Field)> {
106        Ok((Arc::new(array.clone()), args.field.clone()))
107    }
108
109    fn transform_float16(
110        &self,
111        array: &Float16Array,
112        args: TransformArgs,
113    ) -> DFResult<(ArrayRef, Field)> {
114        Ok((Arc::new(array.clone()), args.field.clone()))
115    }
116
117    fn transform_float32(
118        &self,
119        array: &Float32Array,
120        args: TransformArgs,
121    ) -> DFResult<(ArrayRef, Field)> {
122        Ok((Arc::new(array.clone()), args.field.clone()))
123    }
124
125    fn transform_float64(
126        &self,
127        array: &Float64Array,
128        args: TransformArgs,
129    ) -> DFResult<(ArrayRef, Field)> {
130        Ok((Arc::new(array.clone()), args.field.clone()))
131    }
132
133    fn transform_utf8(
134        &self,
135        array: &StringArray,
136        args: TransformArgs,
137    ) -> DFResult<(ArrayRef, Field)> {
138        Ok((Arc::new(array.clone()), args.field.clone()))
139    }
140
141    fn transform_large_utf8(
142        &self,
143        array: &LargeStringArray,
144        args: TransformArgs,
145    ) -> DFResult<(ArrayRef, Field)> {
146        Ok((Arc::new(array.clone()), args.field.clone()))
147    }
148
149    fn transform_binary(
150        &self,
151        array: &BinaryArray,
152        args: TransformArgs,
153    ) -> DFResult<(ArrayRef, Field)> {
154        Ok((Arc::new(array.clone()), args.field.clone()))
155    }
156
157    fn transform_large_binary(
158        &self,
159        array: &LargeBinaryArray,
160        args: TransformArgs,
161    ) -> DFResult<(ArrayRef, Field)> {
162        Ok((Arc::new(array.clone()), args.field.clone()))
163    }
164
165    fn transform_timestamp_second(
166        &self,
167        array: &TimestampSecondArray,
168        args: TransformArgs,
169    ) -> DFResult<(ArrayRef, Field)> {
170        Ok((Arc::new(array.clone()), args.field.clone()))
171    }
172
173    fn transform_timestamp_millisecond(
174        &self,
175        array: &TimestampMillisecondArray,
176        args: TransformArgs,
177    ) -> DFResult<(ArrayRef, Field)> {
178        Ok((Arc::new(array.clone()), args.field.clone()))
179    }
180
181    fn transform_timestamp_microsecond(
182        &self,
183        array: &TimestampMicrosecondArray,
184        args: TransformArgs,
185    ) -> DFResult<(ArrayRef, Field)> {
186        Ok((Arc::new(array.clone()), args.field.clone()))
187    }
188
189    fn transform_timestamp_nanosecond(
190        &self,
191        array: &TimestampNanosecondArray,
192        args: TransformArgs,
193    ) -> DFResult<(ArrayRef, Field)> {
194        Ok((Arc::new(array.clone()), args.field.clone()))
195    }
196
197    fn transform_date32(
198        &self,
199        array: &Date32Array,
200        args: TransformArgs,
201    ) -> DFResult<(ArrayRef, Field)> {
202        Ok((Arc::new(array.clone()), args.field.clone()))
203    }
204
205    fn transform_date64(
206        &self,
207        array: &Date64Array,
208        args: TransformArgs,
209    ) -> DFResult<(ArrayRef, Field)> {
210        Ok((Arc::new(array.clone()), args.field.clone()))
211    }
212
213    fn transform_time32_second(
214        &self,
215        array: &Time32SecondArray,
216        args: TransformArgs,
217    ) -> DFResult<(ArrayRef, Field)> {
218        Ok((Arc::new(array.clone()), args.field.clone()))
219    }
220
221    fn transform_time32_millisecond(
222        &self,
223        array: &Time32MillisecondArray,
224        args: TransformArgs,
225    ) -> DFResult<(ArrayRef, Field)> {
226        Ok((Arc::new(array.clone()), args.field.clone()))
227    }
228
229    fn transform_time64_microsecond(
230        &self,
231        array: &Time64MicrosecondArray,
232        args: TransformArgs,
233    ) -> DFResult<(ArrayRef, Field)> {
234        Ok((Arc::new(array.clone()), args.field.clone()))
235    }
236
237    fn transform_time64_nanosecond(
238        &self,
239        array: &Time64NanosecondArray,
240        args: TransformArgs,
241    ) -> DFResult<(ArrayRef, Field)> {
242        Ok((Arc::new(array.clone()), args.field.clone()))
243    }
244
245    fn transform_list(
246        &self,
247        array: &ListArray,
248        args: TransformArgs,
249    ) -> DFResult<(ArrayRef, Field)> {
250        Ok((Arc::new(array.clone()), args.field.clone()))
251    }
252}
253
/// Stream adapter that applies a [`Transform`] to every batch yielded by an
/// inner [`SendableRecordBatchStream`].
pub(crate) struct TransformStream {
    /// Source of the untransformed batches.
    input: SendableRecordBatchStream,
    /// Column transform applied to each batch.
    transform: Arc<dyn Transform>,
    /// Full table schema; its field data types decide which `Transform`
    /// method each column is routed to.
    table_schema: SchemaRef,
    /// Table-schema column indexes present in each input batch, in order;
    /// `None` means every column.
    projection: Option<Vec<usize>>,
    /// Schema reported by [`RecordBatchStream::schema`].
    projected_schema: SchemaRef,
    /// Optional remote schema; its fields are forwarded to the transform as
    /// `TransformArgs::remote_field`.
    remote_schema: Option<RemoteSchemaRef>,
}
262
263impl TransformStream {
264    pub fn try_new(
265        input: SendableRecordBatchStream,
266        transform: Arc<dyn Transform>,
267        table_schema: SchemaRef,
268        projection: Option<Vec<usize>>,
269        remote_schema: Option<RemoteSchemaRef>,
270    ) -> DFResult<Self> {
271        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
272        Ok(Self {
273            input,
274            transform,
275            table_schema,
276            projection,
277            projected_schema,
278            remote_schema,
279        })
280    }
281}
282
283impl Stream for TransformStream {
284    type Item = DFResult<RecordBatch>;
285    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
286        match self.input.poll_next_unpin(cx) {
287            Poll::Ready(Some(Ok(batch))) => {
288                match transform_batch(
289                    batch,
290                    self.transform.as_ref(),
291                    &self.table_schema,
292                    self.projection.as_ref(),
293                    self.remote_schema.as_ref(),
294                ) {
295                    Ok(result) => Poll::Ready(Some(Ok(result))),
296                    Err(e) => Poll::Ready(Some(Err(e))),
297                }
298            }
299            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
300            Poll::Ready(None) => Poll::Ready(None),
301            Poll::Pending => Poll::Pending,
302        }
303    }
304}
305
306impl RecordBatchStream for TransformStream {
307    fn schema(&self) -> SchemaRef {
308        self.projected_schema.clone()
309    }
310}
311
/// Downcasts column `$idx` of `$batch` to `$array_ty` and passes it to
/// `$transform.$transform_method`, propagating any error with `?`.
///
/// Must be expanded inside a function that returns `DFResult<_>`. A missing
/// column, or a column whose runtime type does not match `$array_ty` (the
/// batch disagrees with the table schema), is reported as
/// `DataFusionError::Execution` instead of panicking.
macro_rules! handle_transform {
    ($batch:expr, $idx:expr, $array_ty:ty, $transform:expr, $transform_method:ident, $transform_args:expr) => {{
        let batch_ref = &$batch;
        let col_idx = $idx;
        let column = batch_ref.columns().get(col_idx).ok_or_else(|| {
            datafusion::common::DataFusionError::Execution(format!(
                "Column index {} out of bounds for batch with {} columns",
                col_idx,
                batch_ref.num_columns()
            ))
        })?;
        let array = column
            .as_any()
            .downcast_ref::<$array_ty>()
            .ok_or_else(|| {
                datafusion::common::DataFusionError::Execution(format!(
                    concat!(
                        "Failed to downcast to ",
                        stringify!($array_ty),
                        ": column {} has data type {:?}"
                    ),
                    col_idx,
                    column.data_type()
                ))
            })?;
        $transform.$transform_method(array, $transform_args)?
    }};
}
322
323pub(crate) fn transform_batch(
324    batch: RecordBatch,
325    transform: &dyn Transform,
326    table_schema: &SchemaRef,
327    projection: Option<&Vec<usize>>,
328    remote_schema: Option<&RemoteSchemaRef>,
329) -> DFResult<RecordBatch> {
330    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
331    let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
332    let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
333    let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
334
335    for (idx, col_index) in projected_col_indexes.iter().enumerate() {
336        let field = table_schema.field(*col_index);
337        let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
338        let args = TransformArgs {
339            col_index: *col_index,
340            field,
341            remote_field,
342        };
343
344        let (new_array, new_field) = match &field.data_type() {
345            DataType::Null => {
346                handle_transform!(batch, idx, NullArray, transform, transform_null, args)
347            }
348            DataType::Boolean => {
349                handle_transform!(batch, idx, BooleanArray, transform, transform_boolean, args)
350            }
351            DataType::Int8 => {
352                handle_transform!(batch, idx, Int8Array, transform, transform_int8, args)
353            }
354            DataType::Int16 => {
355                handle_transform!(batch, idx, Int16Array, transform, transform_int16, args)
356            }
357            DataType::Int32 => {
358                handle_transform!(batch, idx, Int32Array, transform, transform_int32, args)
359            }
360            DataType::Int64 => {
361                handle_transform!(batch, idx, Int64Array, transform, transform_int64, args)
362            }
363            DataType::UInt8 => {
364                handle_transform!(batch, idx, UInt8Array, transform, transform_uint8, args)
365            }
366            DataType::UInt16 => {
367                handle_transform!(batch, idx, UInt16Array, transform, transform_uint16, args)
368            }
369            DataType::UInt32 => {
370                handle_transform!(batch, idx, UInt32Array, transform, transform_uint32, args)
371            }
372            DataType::UInt64 => {
373                handle_transform!(batch, idx, UInt64Array, transform, transform_uint64, args)
374            }
375            DataType::Float16 => {
376                handle_transform!(batch, idx, Float16Array, transform, transform_float16, args)
377            }
378            DataType::Float32 => {
379                handle_transform!(batch, idx, Float32Array, transform, transform_float32, args)
380            }
381            DataType::Float64 => {
382                handle_transform!(batch, idx, Float64Array, transform, transform_float64, args)
383            }
384            DataType::Timestamp(TimeUnit::Second, _) => {
385                handle_transform!(
386                    batch,
387                    idx,
388                    TimestampSecondArray,
389                    transform,
390                    transform_timestamp_second,
391                    args
392                )
393            }
394            DataType::Timestamp(TimeUnit::Millisecond, _) => {
395                handle_transform!(
396                    batch,
397                    idx,
398                    TimestampMillisecondArray,
399                    transform,
400                    transform_timestamp_millisecond,
401                    args
402                )
403            }
404            DataType::Timestamp(TimeUnit::Microsecond, _) => {
405                handle_transform!(
406                    batch,
407                    idx,
408                    TimestampMicrosecondArray,
409                    transform,
410                    transform_timestamp_microsecond,
411                    args
412                )
413            }
414            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
415                handle_transform!(
416                    batch,
417                    idx,
418                    TimestampNanosecondArray,
419                    transform,
420                    transform_timestamp_nanosecond,
421                    args
422                )
423            }
424            DataType::Date32 => {
425                handle_transform!(batch, idx, Date32Array, transform, transform_date32, args)
426            }
427            DataType::Date64 => {
428                handle_transform!(batch, idx, Date64Array, transform, transform_date64, args)
429            }
430            DataType::Time32(TimeUnit::Second) => {
431                handle_transform!(
432                    batch,
433                    idx,
434                    Time32SecondArray,
435                    transform,
436                    transform_time32_second,
437                    args
438                )
439            }
440            DataType::Time32(TimeUnit::Millisecond) => {
441                handle_transform!(
442                    batch,
443                    idx,
444                    Time32MillisecondArray,
445                    transform,
446                    transform_time32_millisecond,
447                    args
448                )
449            }
450            DataType::Time32(TimeUnit::Microsecond) => unreachable!(),
451            DataType::Time32(TimeUnit::Nanosecond) => unreachable!(),
452            DataType::Time64(TimeUnit::Second) => unreachable!(),
453            DataType::Time64(TimeUnit::Millisecond) => unreachable!(),
454            DataType::Time64(TimeUnit::Microsecond) => {
455                handle_transform!(
456                    batch,
457                    idx,
458                    Time64MicrosecondArray,
459                    transform,
460                    transform_time64_microsecond,
461                    args
462                )
463            }
464            DataType::Time64(TimeUnit::Nanosecond) => {
465                handle_transform!(
466                    batch,
467                    idx,
468                    Time64NanosecondArray,
469                    transform,
470                    transform_time64_nanosecond,
471                    args
472                )
473            }
474            DataType::Utf8 => {
475                handle_transform!(batch, idx, StringArray, transform, transform_utf8, args)
476            }
477            DataType::LargeUtf8 => {
478                handle_transform!(
479                    batch,
480                    idx,
481                    LargeStringArray,
482                    transform,
483                    transform_large_utf8,
484                    args
485                )
486            }
487            DataType::Binary => {
488                handle_transform!(batch, idx, BinaryArray, transform, transform_binary, args)
489            }
490            DataType::LargeBinary => {
491                handle_transform!(
492                    batch,
493                    idx,
494                    LargeBinaryArray,
495                    transform,
496                    transform_large_binary,
497                    args
498                )
499            }
500            DataType::List(_field) => {
501                handle_transform!(batch, idx, ListArray, transform, transform_list, args)
502            }
503            data_type => {
504                return Err(DataFusionError::NotImplemented(format!(
505                    "Unsupported arrow type {data_type:?}",
506                )));
507            }
508        };
509        new_arrays.push(new_array);
510        new_fields.push(new_field);
511    }
512    let new_schema = Arc::new(Schema::new(new_fields));
513    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
514}
515
516pub(crate) fn transform_schema(
517    schema: SchemaRef,
518    transform: Option<&Arc<dyn Transform>>,
519    remote_schema: Option<&RemoteSchemaRef>,
520) -> DFResult<SchemaRef> {
521    if let Some(transform) = transform {
522        let empty_record = RecordBatch::new_empty(schema.clone());
523        transform_batch(
524            empty_record,
525            transform.as_ref(),
526            &schema,
527            None,
528            remote_schema,
529        )
530        .map(|batch| batch.schema())
531    } else {
532        Ok(schema)
533    }
534}