datafusion_remote_table/
transform.rs

1use crate::{DFResult, RemoteField, RemoteSchema};
2use datafusion::arrow::array::{
3    ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
4    Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, StringArray,
5    Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
7    UInt8Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
10use datafusion::common::DataFusionError;
11use std::fmt::Debug;
12use std::sync::Arc;
13
14pub trait Transform: Debug + Send + Sync {
15    fn transform_boolean(
16        &self,
17        array: &BooleanArray,
18        remote_field: &RemoteField,
19    ) -> DFResult<(ArrayRef, Field)> {
20        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
21    }
22
23    fn transform_int8(
24        &self,
25        array: &Int8Array,
26        remote_field: &RemoteField,
27    ) -> DFResult<(ArrayRef, Field)> {
28        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
29    }
30
31    fn transform_int16(
32        &self,
33        array: &Int16Array,
34        remote_field: &RemoteField,
35    ) -> DFResult<(ArrayRef, Field)> {
36        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
37    }
38
39    fn transform_int32(
40        &self,
41        array: &Int32Array,
42        remote_field: &RemoteField,
43    ) -> DFResult<(ArrayRef, Field)> {
44        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
45    }
46
47    fn transform_int64(
48        &self,
49        array: &Int64Array,
50        remote_field: &RemoteField,
51    ) -> DFResult<(ArrayRef, Field)> {
52        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
53    }
54
55    fn transform_uint8(
56        &self,
57        array: &UInt8Array,
58        remote_field: &RemoteField,
59    ) -> DFResult<(ArrayRef, Field)> {
60        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
61    }
62
63    fn transform_uint16(
64        &self,
65        array: &UInt16Array,
66        remote_field: &RemoteField,
67    ) -> DFResult<(ArrayRef, Field)> {
68        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
69    }
70
71    fn transform_uint32(
72        &self,
73        array: &UInt32Array,
74        remote_field: &RemoteField,
75    ) -> DFResult<(ArrayRef, Field)> {
76        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
77    }
78
79    fn transform_uint64(
80        &self,
81        array: &UInt64Array,
82        remote_field: &RemoteField,
83    ) -> DFResult<(ArrayRef, Field)> {
84        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
85    }
86
87    fn transform_float16(
88        &self,
89        array: &Float16Array,
90        remote_field: &RemoteField,
91    ) -> DFResult<(ArrayRef, Field)> {
92        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
93    }
94
95    fn transform_float32(
96        &self,
97        array: &Float32Array,
98        remote_field: &RemoteField,
99    ) -> DFResult<(ArrayRef, Field)> {
100        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
101    }
102
103    fn transform_float64(
104        &self,
105        array: &Float64Array,
106        remote_field: &RemoteField,
107    ) -> DFResult<(ArrayRef, Field)> {
108        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
109    }
110
111    fn transform_utf8(
112        &self,
113        array: &StringArray,
114        remote_field: &RemoteField,
115    ) -> DFResult<(ArrayRef, Field)> {
116        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
117    }
118
119    fn transform_binary(
120        &self,
121        array: &BinaryArray,
122        remote_field: &RemoteField,
123    ) -> DFResult<(ArrayRef, Field)> {
124        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
125    }
126
127    fn transform_timestamp_second(
128        &self,
129        array: &TimestampSecondArray,
130        remote_field: &RemoteField,
131    ) -> DFResult<(ArrayRef, Field)> {
132        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
133    }
134
135    fn transform_timestamp_millisecond(
136        &self,
137        array: &TimestampMillisecondArray,
138        remote_field: &RemoteField,
139    ) -> DFResult<(ArrayRef, Field)> {
140        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
141    }
142
143    fn transform_timestamp_microsecond(
144        &self,
145        array: &TimestampMicrosecondArray,
146        remote_field: &RemoteField,
147    ) -> DFResult<(ArrayRef, Field)> {
148        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
149    }
150
151    fn transform_timestamp_nanosecond(
152        &self,
153        array: &TimestampNanosecondArray,
154        remote_field: &RemoteField,
155    ) -> DFResult<(ArrayRef, Field)> {
156        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
157    }
158
159    fn transform_time64_nanosecond(
160        &self,
161        array: &Time64NanosecondArray,
162        remote_field: &RemoteField,
163    ) -> DFResult<(ArrayRef, Field)> {
164        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
165    }
166
167    fn transform_date32(
168        &self,
169        array: &Date32Array,
170        remote_field: &RemoteField,
171    ) -> DFResult<(ArrayRef, Field)> {
172        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
173    }
174
175    fn transform_list(
176        &self,
177        array: &ListArray,
178        remote_field: &RemoteField,
179    ) -> DFResult<(ArrayRef, Field)> {
180        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
181    }
182}
183
184pub(crate) fn transform_batch(
185    batch: RecordBatch,
186    transform: &dyn Transform,
187    remote_schema: &RemoteSchema,
188) -> DFResult<RecordBatch> {
189    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(remote_schema.fields.len());
190    let mut new_fields: Vec<Field> = Vec::with_capacity(remote_schema.fields.len());
191    for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
192        let (new_array, new_field) = match &remote_field.remote_type.to_arrow_type() {
193            // TODO use a macro to reduce boilerplate
194            DataType::Boolean => {
195                let array = batch
196                    .column(idx)
197                    .as_any()
198                    .downcast_ref::<BooleanArray>()
199                    .expect("Failed to downcast to BooleanArray");
200                transform.transform_boolean(array, remote_field)?
201            }
202            DataType::Int8 => {
203                let array = batch
204                    .column(idx)
205                    .as_any()
206                    .downcast_ref::<Int8Array>()
207                    .expect("Failed to downcast to Int8Array");
208                transform.transform_int8(array, remote_field)?
209            }
210            DataType::Int16 => {
211                let array = batch
212                    .column(idx)
213                    .as_any()
214                    .downcast_ref::<Int16Array>()
215                    .expect("Failed to downcast to Int16Array");
216                transform.transform_int16(array, remote_field)?
217            }
218            DataType::Int32 => {
219                let array = batch
220                    .column(idx)
221                    .as_any()
222                    .downcast_ref::<Int32Array>()
223                    .expect("Failed to downcast to Int32Array");
224                transform.transform_int32(array, remote_field)?
225            }
226            DataType::Int64 => {
227                let array = batch
228                    .column(idx)
229                    .as_any()
230                    .downcast_ref::<Int64Array>()
231                    .expect("Failed to downcast to Int64Array");
232                transform.transform_int64(array, remote_field)?
233            }
234            DataType::UInt8 => {
235                let array = batch
236                    .column(idx)
237                    .as_any()
238                    .downcast_ref::<UInt8Array>()
239                    .expect("Failed to downcast to UInt8Array");
240                transform.transform_uint8(array, remote_field)?
241            }
242            DataType::UInt16 => {
243                let array = batch
244                    .column(idx)
245                    .as_any()
246                    .downcast_ref::<UInt16Array>()
247                    .expect("Failed to downcast to UInt16Array");
248                transform.transform_uint16(array, remote_field)?
249            }
250            DataType::UInt32 => {
251                let array = batch
252                    .column(idx)
253                    .as_any()
254                    .downcast_ref::<UInt32Array>()
255                    .expect("Failed to downcast to UInt32Array");
256                transform.transform_uint32(array, remote_field)?
257            }
258            DataType::UInt64 => {
259                let array = batch
260                    .column(idx)
261                    .as_any()
262                    .downcast_ref::<UInt64Array>()
263                    .expect("Failed to downcast to UInt64Array");
264                transform.transform_uint64(array, remote_field)?
265            }
266            DataType::Float16 => {
267                let array = batch
268                    .column(idx)
269                    .as_any()
270                    .downcast_ref::<Float16Array>()
271                    .expect("Failed to downcast to Float16Array");
272                transform.transform_float16(array, remote_field)?
273            }
274            DataType::Float32 => {
275                let array = batch
276                    .column(idx)
277                    .as_any()
278                    .downcast_ref::<Float32Array>()
279                    .expect("Failed to downcast to Float32Array");
280                transform.transform_float32(array, remote_field)?
281            }
282            DataType::Float64 => {
283                let array = batch
284                    .column(idx)
285                    .as_any()
286                    .downcast_ref::<Float64Array>()
287                    .expect("Failed to downcast to Float64Array");
288                transform.transform_float64(array, remote_field)?
289            }
290            DataType::Utf8 => {
291                let array = batch
292                    .column(idx)
293                    .as_any()
294                    .downcast_ref::<StringArray>()
295                    .expect("Failed to downcast to StringArray");
296                transform.transform_utf8(array, remote_field)?
297            }
298            DataType::Binary => {
299                let array = batch
300                    .column(idx)
301                    .as_any()
302                    .downcast_ref::<BinaryArray>()
303                    .expect("Failed to downcast to BinaryArray");
304                transform.transform_binary(array, remote_field)?
305            }
306            DataType::Date32 => {
307                let array = batch
308                    .column(idx)
309                    .as_any()
310                    .downcast_ref::<Date32Array>()
311                    .expect("Failed to downcast to Date32Array");
312                transform.transform_date32(array, remote_field)?
313            }
314            DataType::Timestamp(TimeUnit::Second, _) => {
315                let array = batch
316                    .column(idx)
317                    .as_any()
318                    .downcast_ref::<TimestampSecondArray>()
319                    .expect("Failed to downcast to TimestampSecondArray");
320                transform.transform_timestamp_second(array, remote_field)?
321            }
322            DataType::Timestamp(TimeUnit::Millisecond, _) => {
323                let array = batch
324                    .column(idx)
325                    .as_any()
326                    .downcast_ref::<TimestampMillisecondArray>()
327                    .expect("Failed to downcast to TimestampMillisecondArray");
328                transform.transform_timestamp_millisecond(array, remote_field)?
329            }
330            DataType::Timestamp(TimeUnit::Microsecond, _) => {
331                let array = batch
332                    .column(idx)
333                    .as_any()
334                    .downcast_ref::<TimestampMicrosecondArray>()
335                    .expect("Failed to downcast to TimestampMicrosecondArray");
336                transform.transform_timestamp_microsecond(array, remote_field)?
337            }
338            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
339                let array = batch
340                    .column(idx)
341                    .as_any()
342                    .downcast_ref::<TimestampNanosecondArray>()
343                    .expect("Failed to downcast to TimestampNanosecondArray");
344                transform.transform_timestamp_nanosecond(array, remote_field)?
345            }
346            DataType::Time64(TimeUnit::Nanosecond) => {
347                let array = batch
348                    .column(idx)
349                    .as_any()
350                    .downcast_ref::<Time64NanosecondArray>()
351                    .expect("Failed to downcast to Time64NanosecondArray");
352                transform.transform_time64_nanosecond(array, remote_field)?
353            }
354            DataType::List(_field) => {
355                let array = batch
356                    .column(idx)
357                    .as_any()
358                    .downcast_ref::<ListArray>()
359                    .expect("Failed to downcast to ListArray");
360                transform.transform_list(array, remote_field)?
361            }
362            data_type => {
363                return Err(DataFusionError::NotImplemented(format!(
364                    "Unsupported arrow type {data_type:?}",
365                )))
366            }
367        };
368        new_arrays.push(new_array);
369        new_fields.push(new_field);
370    }
371    let new_schema = Arc::new(Schema::new(new_fields));
372    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
373}