datafusion_remote_table/
transform.rs

1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3    ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
4    Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, StringArray,
5    Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
7    UInt8Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use datafusion::common::{project_schema, DataFusionError};
11use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
12use futures::{Stream, StreamExt};
13use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18pub struct TransformArgs<'a> {
19    pub col_index: usize,
20    pub field: &'a Field,
21    pub remote_field: Option<&'a RemoteField>,
22}
23
24pub trait Transform: Debug + Send + Sync {
25    fn transform_boolean(
26        &self,
27        array: &BooleanArray,
28        args: TransformArgs,
29    ) -> DFResult<(ArrayRef, Field)> {
30        Ok((Arc::new(array.clone()), args.field.clone()))
31    }
32
33    fn transform_int8(
34        &self,
35        array: &Int8Array,
36        args: TransformArgs,
37    ) -> DFResult<(ArrayRef, Field)> {
38        Ok((Arc::new(array.clone()), args.field.clone()))
39    }
40
41    fn transform_int16(
42        &self,
43        array: &Int16Array,
44        args: TransformArgs,
45    ) -> DFResult<(ArrayRef, Field)> {
46        Ok((Arc::new(array.clone()), args.field.clone()))
47    }
48
49    fn transform_int32(
50        &self,
51        array: &Int32Array,
52        args: TransformArgs,
53    ) -> DFResult<(ArrayRef, Field)> {
54        Ok((Arc::new(array.clone()), args.field.clone()))
55    }
56
57    fn transform_int64(
58        &self,
59        array: &Int64Array,
60        args: TransformArgs,
61    ) -> DFResult<(ArrayRef, Field)> {
62        Ok((Arc::new(array.clone()), args.field.clone()))
63    }
64
65    fn transform_uint8(
66        &self,
67        array: &UInt8Array,
68        args: TransformArgs,
69    ) -> DFResult<(ArrayRef, Field)> {
70        Ok((Arc::new(array.clone()), args.field.clone()))
71    }
72
73    fn transform_uint16(
74        &self,
75        array: &UInt16Array,
76        args: TransformArgs,
77    ) -> DFResult<(ArrayRef, Field)> {
78        Ok((Arc::new(array.clone()), args.field.clone()))
79    }
80
81    fn transform_uint32(
82        &self,
83        array: &UInt32Array,
84        args: TransformArgs,
85    ) -> DFResult<(ArrayRef, Field)> {
86        Ok((Arc::new(array.clone()), args.field.clone()))
87    }
88
89    fn transform_uint64(
90        &self,
91        array: &UInt64Array,
92        args: TransformArgs,
93    ) -> DFResult<(ArrayRef, Field)> {
94        Ok((Arc::new(array.clone()), args.field.clone()))
95    }
96
97    fn transform_float16(
98        &self,
99        array: &Float16Array,
100        args: TransformArgs,
101    ) -> DFResult<(ArrayRef, Field)> {
102        Ok((Arc::new(array.clone()), args.field.clone()))
103    }
104
105    fn transform_float32(
106        &self,
107        array: &Float32Array,
108        args: TransformArgs,
109    ) -> DFResult<(ArrayRef, Field)> {
110        Ok((Arc::new(array.clone()), args.field.clone()))
111    }
112
113    fn transform_float64(
114        &self,
115        array: &Float64Array,
116        args: TransformArgs,
117    ) -> DFResult<(ArrayRef, Field)> {
118        Ok((Arc::new(array.clone()), args.field.clone()))
119    }
120
121    fn transform_utf8(
122        &self,
123        array: &StringArray,
124        args: TransformArgs,
125    ) -> DFResult<(ArrayRef, Field)> {
126        Ok((Arc::new(array.clone()), args.field.clone()))
127    }
128
129    fn transform_binary(
130        &self,
131        array: &BinaryArray,
132        args: TransformArgs,
133    ) -> DFResult<(ArrayRef, Field)> {
134        Ok((Arc::new(array.clone()), args.field.clone()))
135    }
136
137    fn transform_timestamp_second(
138        &self,
139        array: &TimestampSecondArray,
140        args: TransformArgs,
141    ) -> DFResult<(ArrayRef, Field)> {
142        Ok((Arc::new(array.clone()), args.field.clone()))
143    }
144
145    fn transform_timestamp_millisecond(
146        &self,
147        array: &TimestampMillisecondArray,
148        args: TransformArgs,
149    ) -> DFResult<(ArrayRef, Field)> {
150        Ok((Arc::new(array.clone()), args.field.clone()))
151    }
152
153    fn transform_timestamp_microsecond(
154        &self,
155        array: &TimestampMicrosecondArray,
156        args: TransformArgs,
157    ) -> DFResult<(ArrayRef, Field)> {
158        Ok((Arc::new(array.clone()), args.field.clone()))
159    }
160
161    fn transform_timestamp_nanosecond(
162        &self,
163        array: &TimestampNanosecondArray,
164        args: TransformArgs,
165    ) -> DFResult<(ArrayRef, Field)> {
166        Ok((Arc::new(array.clone()), args.field.clone()))
167    }
168
169    fn transform_time64_nanosecond(
170        &self,
171        array: &Time64NanosecondArray,
172        args: TransformArgs,
173    ) -> DFResult<(ArrayRef, Field)> {
174        Ok((Arc::new(array.clone()), args.field.clone()))
175    }
176
177    fn transform_date32(
178        &self,
179        array: &Date32Array,
180        args: TransformArgs,
181    ) -> DFResult<(ArrayRef, Field)> {
182        Ok((Arc::new(array.clone()), args.field.clone()))
183    }
184
185    fn transform_list(
186        &self,
187        array: &ListArray,
188        args: TransformArgs,
189    ) -> DFResult<(ArrayRef, Field)> {
190        Ok((Arc::new(array.clone()), args.field.clone()))
191    }
192}
193
194pub(crate) struct TransformStream {
195    input: SendableRecordBatchStream,
196    transform: Arc<dyn Transform>,
197    table_schema: SchemaRef,
198    projection: Option<Vec<usize>>,
199    projected_schema: SchemaRef,
200    remote_schema: Option<RemoteSchemaRef>,
201}
202
203impl TransformStream {
204    pub fn try_new(
205        input: SendableRecordBatchStream,
206        transform: Arc<dyn Transform>,
207        table_schema: SchemaRef,
208        projection: Option<Vec<usize>>,
209        remote_schema: Option<RemoteSchemaRef>,
210    ) -> DFResult<Self> {
211        let projected_schema = project_schema(&table_schema, projection.as_ref())?;
212        Ok(Self {
213            input,
214            transform,
215            table_schema,
216            projection,
217            projected_schema,
218            remote_schema,
219        })
220    }
221}
222
223impl Stream for TransformStream {
224    type Item = DFResult<RecordBatch>;
225    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
226        match self.input.poll_next_unpin(cx) {
227            Poll::Ready(Some(Ok(batch))) => {
228                match transform_batch(
229                    batch,
230                    self.transform.as_ref(),
231                    &self.table_schema,
232                    self.projection.as_ref(),
233                    self.remote_schema.as_ref(),
234                ) {
235                    Ok(result) => Poll::Ready(Some(Ok(result))),
236                    Err(e) => Poll::Ready(Some(Err(e))),
237                }
238            }
239            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
240            Poll::Ready(None) => Poll::Ready(None),
241            Poll::Pending => Poll::Pending,
242        }
243    }
244}
245
246impl RecordBatchStream for TransformStream {
247    fn schema(&self) -> SchemaRef {
248        self.projected_schema.clone()
249    }
250}
251
252pub(crate) fn transform_batch(
253    batch: RecordBatch,
254    transform: &dyn Transform,
255    table_schema: &SchemaRef,
256    projection: Option<&Vec<usize>>,
257    remote_schema: Option<&RemoteSchemaRef>,
258) -> DFResult<RecordBatch> {
259    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
260    let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
261    let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
262    let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
263
264    for (idx, col_index) in projected_col_indexes.iter().enumerate() {
265        let field = table_schema.field(*col_index);
266        let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
267        let args = TransformArgs {
268            col_index: *col_index,
269            field,
270            remote_field,
271        };
272
273        let (new_array, new_field) = match &field.data_type() {
274            DataType::Boolean => {
275                let array = batch
276                    .column(idx)
277                    .as_any()
278                    .downcast_ref::<BooleanArray>()
279                    .expect("Failed to downcast to BooleanArray");
280                transform.transform_boolean(array, args)?
281            }
282            DataType::Int8 => {
283                let array = batch
284                    .column(idx)
285                    .as_any()
286                    .downcast_ref::<Int8Array>()
287                    .expect("Failed to downcast to Int8Array");
288                transform.transform_int8(array, args)?
289            }
290            DataType::Int16 => {
291                let array = batch
292                    .column(idx)
293                    .as_any()
294                    .downcast_ref::<Int16Array>()
295                    .expect("Failed to downcast to Int16Array");
296                transform.transform_int16(array, args)?
297            }
298            DataType::Int32 => {
299                let array = batch
300                    .column(idx)
301                    .as_any()
302                    .downcast_ref::<Int32Array>()
303                    .expect("Failed to downcast to Int32Array");
304                transform.transform_int32(array, args)?
305            }
306            DataType::Int64 => {
307                let array = batch
308                    .column(idx)
309                    .as_any()
310                    .downcast_ref::<Int64Array>()
311                    .expect("Failed to downcast to Int64Array");
312                transform.transform_int64(array, args)?
313            }
314            DataType::UInt8 => {
315                let array = batch
316                    .column(idx)
317                    .as_any()
318                    .downcast_ref::<UInt8Array>()
319                    .expect("Failed to downcast to UInt8Array");
320                transform.transform_uint8(array, args)?
321            }
322            DataType::UInt16 => {
323                let array = batch
324                    .column(idx)
325                    .as_any()
326                    .downcast_ref::<UInt16Array>()
327                    .expect("Failed to downcast to UInt16Array");
328                transform.transform_uint16(array, args)?
329            }
330            DataType::UInt32 => {
331                let array = batch
332                    .column(idx)
333                    .as_any()
334                    .downcast_ref::<UInt32Array>()
335                    .expect("Failed to downcast to UInt32Array");
336                transform.transform_uint32(array, args)?
337            }
338            DataType::UInt64 => {
339                let array = batch
340                    .column(idx)
341                    .as_any()
342                    .downcast_ref::<UInt64Array>()
343                    .expect("Failed to downcast to UInt64Array");
344                transform.transform_uint64(array, args)?
345            }
346            DataType::Float16 => {
347                let array = batch
348                    .column(idx)
349                    .as_any()
350                    .downcast_ref::<Float16Array>()
351                    .expect("Failed to downcast to Float16Array");
352                transform.transform_float16(array, args)?
353            }
354            DataType::Float32 => {
355                let array = batch
356                    .column(idx)
357                    .as_any()
358                    .downcast_ref::<Float32Array>()
359                    .expect("Failed to downcast to Float32Array");
360                transform.transform_float32(array, args)?
361            }
362            DataType::Float64 => {
363                let array = batch
364                    .column(idx)
365                    .as_any()
366                    .downcast_ref::<Float64Array>()
367                    .expect("Failed to downcast to Float64Array");
368                transform.transform_float64(array, args)?
369            }
370            DataType::Utf8 => {
371                let array = batch
372                    .column(idx)
373                    .as_any()
374                    .downcast_ref::<StringArray>()
375                    .expect("Failed to downcast to StringArray");
376                transform.transform_utf8(array, args)?
377            }
378            DataType::Binary => {
379                let array = batch
380                    .column(idx)
381                    .as_any()
382                    .downcast_ref::<BinaryArray>()
383                    .expect("Failed to downcast to BinaryArray");
384                transform.transform_binary(array, args)?
385            }
386            DataType::Date32 => {
387                let array = batch
388                    .column(idx)
389                    .as_any()
390                    .downcast_ref::<Date32Array>()
391                    .expect("Failed to downcast to Date32Array");
392                transform.transform_date32(array, args)?
393            }
394            DataType::Timestamp(TimeUnit::Second, _) => {
395                let array = batch
396                    .column(idx)
397                    .as_any()
398                    .downcast_ref::<TimestampSecondArray>()
399                    .expect("Failed to downcast to TimestampSecondArray");
400                transform.transform_timestamp_second(array, args)?
401            }
402            DataType::Timestamp(TimeUnit::Millisecond, _) => {
403                let array = batch
404                    .column(idx)
405                    .as_any()
406                    .downcast_ref::<TimestampMillisecondArray>()
407                    .expect("Failed to downcast to TimestampMillisecondArray");
408                transform.transform_timestamp_millisecond(array, args)?
409            }
410            DataType::Timestamp(TimeUnit::Microsecond, _) => {
411                let array = batch
412                    .column(idx)
413                    .as_any()
414                    .downcast_ref::<TimestampMicrosecondArray>()
415                    .expect("Failed to downcast to TimestampMicrosecondArray");
416                transform.transform_timestamp_microsecond(array, args)?
417            }
418            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
419                let array = batch
420                    .column(idx)
421                    .as_any()
422                    .downcast_ref::<TimestampNanosecondArray>()
423                    .expect("Failed to downcast to TimestampNanosecondArray");
424                transform.transform_timestamp_nanosecond(array, args)?
425            }
426            DataType::Time64(TimeUnit::Nanosecond) => {
427                let array = batch
428                    .column(idx)
429                    .as_any()
430                    .downcast_ref::<Time64NanosecondArray>()
431                    .expect("Failed to downcast to Time64NanosecondArray");
432                transform.transform_time64_nanosecond(array, args)?
433            }
434            DataType::List(_field) => {
435                let array = batch
436                    .column(idx)
437                    .as_any()
438                    .downcast_ref::<ListArray>()
439                    .expect("Failed to downcast to ListArray");
440                transform.transform_list(array, args)?
441            }
442            data_type => {
443                return Err(DataFusionError::NotImplemented(format!(
444                    "Unsupported arrow type {data_type:?}",
445                )))
446            }
447        };
448        new_arrays.push(new_array);
449        new_fields.push(new_field);
450    }
451    let new_schema = Arc::new(Schema::new(new_fields));
452    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
453}