datafusion_remote_table/
transform.rs

use crate::{DFResult, RemoteField, RemoteSchema};
use datafusion::arrow::array::{
    ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
    Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, RecordBatchOptions,
    StringArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
    UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DataFusionError;
use std::fmt::Debug;
use std::sync::Arc;
13
14pub trait Transform: Debug + Send + Sync {
15    fn transform_boolean(
16        &self,
17        array: &BooleanArray,
18        remote_field: &RemoteField,
19    ) -> DFResult<(ArrayRef, Field)> {
20        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
21    }
22
23    fn transform_int8(
24        &self,
25        array: &Int8Array,
26        remote_field: &RemoteField,
27    ) -> DFResult<(ArrayRef, Field)> {
28        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
29    }
30
31    fn transform_int16(
32        &self,
33        array: &Int16Array,
34        remote_field: &RemoteField,
35    ) -> DFResult<(ArrayRef, Field)> {
36        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
37    }
38
39    fn transform_int32(
40        &self,
41        array: &Int32Array,
42        remote_field: &RemoteField,
43    ) -> DFResult<(ArrayRef, Field)> {
44        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
45    }
46
47    fn transform_int64(
48        &self,
49        array: &Int64Array,
50        remote_field: &RemoteField,
51    ) -> DFResult<(ArrayRef, Field)> {
52        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
53    }
54
55    fn transform_uint8(
56        &self,
57        array: &UInt8Array,
58        remote_field: &RemoteField,
59    ) -> DFResult<(ArrayRef, Field)> {
60        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
61    }
62
63    fn transform_uint16(
64        &self,
65        array: &UInt16Array,
66        remote_field: &RemoteField,
67    ) -> DFResult<(ArrayRef, Field)> {
68        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
69    }
70
71    fn transform_uint32(
72        &self,
73        array: &UInt32Array,
74        remote_field: &RemoteField,
75    ) -> DFResult<(ArrayRef, Field)> {
76        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
77    }
78
79    fn transform_uint64(
80        &self,
81        array: &UInt64Array,
82        remote_field: &RemoteField,
83    ) -> DFResult<(ArrayRef, Field)> {
84        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
85    }
86
87    fn transform_float16(
88        &self,
89        array: &Float16Array,
90        remote_field: &RemoteField,
91    ) -> DFResult<(ArrayRef, Field)> {
92        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
93    }
94
95    fn transform_float32(
96        &self,
97        array: &Float32Array,
98        remote_field: &RemoteField,
99    ) -> DFResult<(ArrayRef, Field)> {
100        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
101    }
102
103    fn transform_float64(
104        &self,
105        array: &Float64Array,
106        remote_field: &RemoteField,
107    ) -> DFResult<(ArrayRef, Field)> {
108        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
109    }
110
111    fn transform_utf8(
112        &self,
113        array: &StringArray,
114        remote_field: &RemoteField,
115    ) -> DFResult<(ArrayRef, Field)> {
116        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
117    }
118
119    fn transform_binary(
120        &self,
121        array: &BinaryArray,
122        remote_field: &RemoteField,
123    ) -> DFResult<(ArrayRef, Field)> {
124        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
125    }
126
127    fn transform_timestamp_second(
128        &self,
129        array: &TimestampSecondArray,
130        remote_field: &RemoteField,
131    ) -> DFResult<(ArrayRef, Field)> {
132        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
133    }
134
135    fn transform_timestamp_millisecond(
136        &self,
137        array: &TimestampMillisecondArray,
138        remote_field: &RemoteField,
139    ) -> DFResult<(ArrayRef, Field)> {
140        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
141    }
142
143    fn transform_timestamp_microsecond(
144        &self,
145        array: &TimestampMicrosecondArray,
146        remote_field: &RemoteField,
147    ) -> DFResult<(ArrayRef, Field)> {
148        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
149    }
150
151    fn transform_timestamp_nanosecond(
152        &self,
153        array: &TimestampNanosecondArray,
154        remote_field: &RemoteField,
155    ) -> DFResult<(ArrayRef, Field)> {
156        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
157    }
158
159    fn transform_time64_nanosecond(
160        &self,
161        array: &Time64NanosecondArray,
162        remote_field: &RemoteField,
163    ) -> DFResult<(ArrayRef, Field)> {
164        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
165    }
166
167    fn transform_date32(
168        &self,
169        array: &Date32Array,
170        remote_field: &RemoteField,
171    ) -> DFResult<(ArrayRef, Field)> {
172        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
173    }
174
175    fn transform_list(
176        &self,
177        array: &ListArray,
178        remote_field: &RemoteField,
179    ) -> DFResult<(ArrayRef, Field)> {
180        Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
181    }
182}
183
184pub(crate) fn transform_batch(
185    batch: RecordBatch,
186    transform: &dyn Transform,
187    remote_schema: &RemoteSchema,
188) -> DFResult<RecordBatch> {
189    let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(remote_schema.fields.len());
190    let mut new_fields: Vec<Field> = Vec::with_capacity(remote_schema.fields.len());
191    for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
192        let (new_array, new_field) = match &remote_field.remote_type.to_arrow_type() {
193            DataType::Boolean => {
194                let array = batch
195                    .column(idx)
196                    .as_any()
197                    .downcast_ref::<BooleanArray>()
198                    .expect("Failed to downcast to BooleanArray");
199                transform.transform_boolean(array, remote_field)?
200            }
201            DataType::Int8 => {
202                let array = batch
203                    .column(idx)
204                    .as_any()
205                    .downcast_ref::<Int8Array>()
206                    .expect("Failed to downcast to Int8Array");
207                transform.transform_int8(array, remote_field)?
208            }
209            DataType::Int16 => {
210                let array = batch
211                    .column(idx)
212                    .as_any()
213                    .downcast_ref::<Int16Array>()
214                    .expect("Failed to downcast to Int16Array");
215                transform.transform_int16(array, remote_field)?
216            }
217            DataType::Int32 => {
218                let array = batch
219                    .column(idx)
220                    .as_any()
221                    .downcast_ref::<Int32Array>()
222                    .expect("Failed to downcast to Int32Array");
223                transform.transform_int32(array, remote_field)?
224            }
225            DataType::Int64 => {
226                let array = batch
227                    .column(idx)
228                    .as_any()
229                    .downcast_ref::<Int64Array>()
230                    .expect("Failed to downcast to Int64Array");
231                transform.transform_int64(array, remote_field)?
232            }
233            DataType::UInt8 => {
234                let array = batch
235                    .column(idx)
236                    .as_any()
237                    .downcast_ref::<UInt8Array>()
238                    .expect("Failed to downcast to UInt8Array");
239                transform.transform_uint8(array, remote_field)?
240            }
241            DataType::UInt16 => {
242                let array = batch
243                    .column(idx)
244                    .as_any()
245                    .downcast_ref::<UInt16Array>()
246                    .expect("Failed to downcast to UInt16Array");
247                transform.transform_uint16(array, remote_field)?
248            }
249            DataType::UInt32 => {
250                let array = batch
251                    .column(idx)
252                    .as_any()
253                    .downcast_ref::<UInt32Array>()
254                    .expect("Failed to downcast to UInt32Array");
255                transform.transform_uint32(array, remote_field)?
256            }
257            DataType::UInt64 => {
258                let array = batch
259                    .column(idx)
260                    .as_any()
261                    .downcast_ref::<UInt64Array>()
262                    .expect("Failed to downcast to UInt64Array");
263                transform.transform_uint64(array, remote_field)?
264            }
265            DataType::Float16 => {
266                let array = batch
267                    .column(idx)
268                    .as_any()
269                    .downcast_ref::<Float16Array>()
270                    .expect("Failed to downcast to Float16Array");
271                transform.transform_float16(array, remote_field)?
272            }
273            DataType::Float32 => {
274                let array = batch
275                    .column(idx)
276                    .as_any()
277                    .downcast_ref::<Float32Array>()
278                    .expect("Failed to downcast to Float32Array");
279                transform.transform_float32(array, remote_field)?
280            }
281            DataType::Float64 => {
282                let array = batch
283                    .column(idx)
284                    .as_any()
285                    .downcast_ref::<Float64Array>()
286                    .expect("Failed to downcast to Float64Array");
287                transform.transform_float64(array, remote_field)?
288            }
289            DataType::Utf8 => {
290                let array = batch
291                    .column(idx)
292                    .as_any()
293                    .downcast_ref::<StringArray>()
294                    .expect("Failed to downcast to StringArray");
295                transform.transform_utf8(array, remote_field)?
296            }
297            DataType::Binary => {
298                let array = batch
299                    .column(idx)
300                    .as_any()
301                    .downcast_ref::<BinaryArray>()
302                    .expect("Failed to downcast to BinaryArray");
303                transform.transform_binary(array, remote_field)?
304            }
305            DataType::Date32 => {
306                let array = batch
307                    .column(idx)
308                    .as_any()
309                    .downcast_ref::<Date32Array>()
310                    .expect("Failed to downcast to Date32Array");
311                transform.transform_date32(array, remote_field)?
312            }
313            DataType::Timestamp(TimeUnit::Second, _) => {
314                let array = batch
315                    .column(idx)
316                    .as_any()
317                    .downcast_ref::<TimestampSecondArray>()
318                    .expect("Failed to downcast to TimestampSecondArray");
319                transform.transform_timestamp_second(array, remote_field)?
320            }
321            DataType::Timestamp(TimeUnit::Millisecond, _) => {
322                let array = batch
323                    .column(idx)
324                    .as_any()
325                    .downcast_ref::<TimestampMillisecondArray>()
326                    .expect("Failed to downcast to TimestampMillisecondArray");
327                transform.transform_timestamp_millisecond(array, remote_field)?
328            }
329            DataType::Timestamp(TimeUnit::Microsecond, _) => {
330                let array = batch
331                    .column(idx)
332                    .as_any()
333                    .downcast_ref::<TimestampMicrosecondArray>()
334                    .expect("Failed to downcast to TimestampMicrosecondArray");
335                transform.transform_timestamp_microsecond(array, remote_field)?
336            }
337            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
338                let array = batch
339                    .column(idx)
340                    .as_any()
341                    .downcast_ref::<TimestampNanosecondArray>()
342                    .expect("Failed to downcast to TimestampNanosecondArray");
343                transform.transform_timestamp_nanosecond(array, remote_field)?
344            }
345            DataType::Time64(TimeUnit::Nanosecond) => {
346                let array = batch
347                    .column(idx)
348                    .as_any()
349                    .downcast_ref::<Time64NanosecondArray>()
350                    .expect("Failed to downcast to Time64NanosecondArray");
351                transform.transform_time64_nanosecond(array, remote_field)?
352            }
353            DataType::List(_field) => {
354                let array = batch
355                    .column(idx)
356                    .as_any()
357                    .downcast_ref::<ListArray>()
358                    .expect("Failed to downcast to ListArray");
359                transform.transform_list(array, remote_field)?
360            }
361            data_type => {
362                return Err(DataFusionError::NotImplemented(format!(
363                    "Unsupported arrow type {data_type:?}",
364                )))
365            }
366        };
367        new_arrays.push(new_array);
368        new_fields.push(new_field);
369    }
370    let new_schema = Arc::new(Schema::new(new_fields));
371    Ok(RecordBatch::try_new(new_schema, new_arrays)?)
372}