use crate::{DFResult, RemoteField, RemoteSchema};
use datafusion::arrow::array::{
    ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
    Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, RecordBatchOptions,
    StringArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
    TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
    UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DataFusionError;
use std::fmt::Debug;
use std::sync::Arc;
13
14pub trait Transform: Debug + Send + Sync {
15 fn transform_boolean(
16 &self,
17 array: &BooleanArray,
18 remote_field: &RemoteField,
19 ) -> DFResult<(ArrayRef, Field)> {
20 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
21 }
22
23 fn transform_int8(
24 &self,
25 array: &Int8Array,
26 remote_field: &RemoteField,
27 ) -> DFResult<(ArrayRef, Field)> {
28 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
29 }
30
31 fn transform_int16(
32 &self,
33 array: &Int16Array,
34 remote_field: &RemoteField,
35 ) -> DFResult<(ArrayRef, Field)> {
36 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
37 }
38
39 fn transform_int32(
40 &self,
41 array: &Int32Array,
42 remote_field: &RemoteField,
43 ) -> DFResult<(ArrayRef, Field)> {
44 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
45 }
46
47 fn transform_int64(
48 &self,
49 array: &Int64Array,
50 remote_field: &RemoteField,
51 ) -> DFResult<(ArrayRef, Field)> {
52 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
53 }
54
55 fn transform_uint8(
56 &self,
57 array: &UInt8Array,
58 remote_field: &RemoteField,
59 ) -> DFResult<(ArrayRef, Field)> {
60 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
61 }
62
63 fn transform_uint16(
64 &self,
65 array: &UInt16Array,
66 remote_field: &RemoteField,
67 ) -> DFResult<(ArrayRef, Field)> {
68 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
69 }
70
71 fn transform_uint32(
72 &self,
73 array: &UInt32Array,
74 remote_field: &RemoteField,
75 ) -> DFResult<(ArrayRef, Field)> {
76 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
77 }
78
79 fn transform_uint64(
80 &self,
81 array: &UInt64Array,
82 remote_field: &RemoteField,
83 ) -> DFResult<(ArrayRef, Field)> {
84 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
85 }
86
87 fn transform_float16(
88 &self,
89 array: &Float16Array,
90 remote_field: &RemoteField,
91 ) -> DFResult<(ArrayRef, Field)> {
92 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
93 }
94
95 fn transform_float32(
96 &self,
97 array: &Float32Array,
98 remote_field: &RemoteField,
99 ) -> DFResult<(ArrayRef, Field)> {
100 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
101 }
102
103 fn transform_float64(
104 &self,
105 array: &Float64Array,
106 remote_field: &RemoteField,
107 ) -> DFResult<(ArrayRef, Field)> {
108 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
109 }
110
111 fn transform_utf8(
112 &self,
113 array: &StringArray,
114 remote_field: &RemoteField,
115 ) -> DFResult<(ArrayRef, Field)> {
116 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
117 }
118
119 fn transform_binary(
120 &self,
121 array: &BinaryArray,
122 remote_field: &RemoteField,
123 ) -> DFResult<(ArrayRef, Field)> {
124 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
125 }
126
127 fn transform_timestamp_second(
128 &self,
129 array: &TimestampSecondArray,
130 remote_field: &RemoteField,
131 ) -> DFResult<(ArrayRef, Field)> {
132 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
133 }
134
135 fn transform_timestamp_millisecond(
136 &self,
137 array: &TimestampMillisecondArray,
138 remote_field: &RemoteField,
139 ) -> DFResult<(ArrayRef, Field)> {
140 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
141 }
142
143 fn transform_timestamp_microsecond(
144 &self,
145 array: &TimestampMicrosecondArray,
146 remote_field: &RemoteField,
147 ) -> DFResult<(ArrayRef, Field)> {
148 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
149 }
150
151 fn transform_timestamp_nanosecond(
152 &self,
153 array: &TimestampNanosecondArray,
154 remote_field: &RemoteField,
155 ) -> DFResult<(ArrayRef, Field)> {
156 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
157 }
158
159 fn transform_time64_nanosecond(
160 &self,
161 array: &Time64NanosecondArray,
162 remote_field: &RemoteField,
163 ) -> DFResult<(ArrayRef, Field)> {
164 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
165 }
166
167 fn transform_date32(
168 &self,
169 array: &Date32Array,
170 remote_field: &RemoteField,
171 ) -> DFResult<(ArrayRef, Field)> {
172 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
173 }
174
175 fn transform_list(
176 &self,
177 array: &ListArray,
178 remote_field: &RemoteField,
179 ) -> DFResult<(ArrayRef, Field)> {
180 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
181 }
182}
183
184pub(crate) fn transform_batch(
185 batch: RecordBatch,
186 transform: &dyn Transform,
187 remote_schema: &RemoteSchema,
188) -> DFResult<RecordBatch> {
189 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(remote_schema.fields.len());
190 let mut new_fields: Vec<Field> = Vec::with_capacity(remote_schema.fields.len());
191 for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
192 let (new_array, new_field) = match &remote_field.remote_type.to_arrow_type() {
193 DataType::Boolean => {
194 let array = batch
195 .column(idx)
196 .as_any()
197 .downcast_ref::<BooleanArray>()
198 .expect("Failed to downcast to BooleanArray");
199 transform.transform_boolean(array, remote_field)?
200 }
201 DataType::Int8 => {
202 let array = batch
203 .column(idx)
204 .as_any()
205 .downcast_ref::<Int8Array>()
206 .expect("Failed to downcast to Int8Array");
207 transform.transform_int8(array, remote_field)?
208 }
209 DataType::Int16 => {
210 let array = batch
211 .column(idx)
212 .as_any()
213 .downcast_ref::<Int16Array>()
214 .expect("Failed to downcast to Int16Array");
215 transform.transform_int16(array, remote_field)?
216 }
217 DataType::Int32 => {
218 let array = batch
219 .column(idx)
220 .as_any()
221 .downcast_ref::<Int32Array>()
222 .expect("Failed to downcast to Int32Array");
223 transform.transform_int32(array, remote_field)?
224 }
225 DataType::Int64 => {
226 let array = batch
227 .column(idx)
228 .as_any()
229 .downcast_ref::<Int64Array>()
230 .expect("Failed to downcast to Int64Array");
231 transform.transform_int64(array, remote_field)?
232 }
233 DataType::UInt8 => {
234 let array = batch
235 .column(idx)
236 .as_any()
237 .downcast_ref::<UInt8Array>()
238 .expect("Failed to downcast to UInt8Array");
239 transform.transform_uint8(array, remote_field)?
240 }
241 DataType::UInt16 => {
242 let array = batch
243 .column(idx)
244 .as_any()
245 .downcast_ref::<UInt16Array>()
246 .expect("Failed to downcast to UInt16Array");
247 transform.transform_uint16(array, remote_field)?
248 }
249 DataType::UInt32 => {
250 let array = batch
251 .column(idx)
252 .as_any()
253 .downcast_ref::<UInt32Array>()
254 .expect("Failed to downcast to UInt32Array");
255 transform.transform_uint32(array, remote_field)?
256 }
257 DataType::UInt64 => {
258 let array = batch
259 .column(idx)
260 .as_any()
261 .downcast_ref::<UInt64Array>()
262 .expect("Failed to downcast to UInt64Array");
263 transform.transform_uint64(array, remote_field)?
264 }
265 DataType::Float16 => {
266 let array = batch
267 .column(idx)
268 .as_any()
269 .downcast_ref::<Float16Array>()
270 .expect("Failed to downcast to Float16Array");
271 transform.transform_float16(array, remote_field)?
272 }
273 DataType::Float32 => {
274 let array = batch
275 .column(idx)
276 .as_any()
277 .downcast_ref::<Float32Array>()
278 .expect("Failed to downcast to Float32Array");
279 transform.transform_float32(array, remote_field)?
280 }
281 DataType::Float64 => {
282 let array = batch
283 .column(idx)
284 .as_any()
285 .downcast_ref::<Float64Array>()
286 .expect("Failed to downcast to Float64Array");
287 transform.transform_float64(array, remote_field)?
288 }
289 DataType::Utf8 => {
290 let array = batch
291 .column(idx)
292 .as_any()
293 .downcast_ref::<StringArray>()
294 .expect("Failed to downcast to StringArray");
295 transform.transform_utf8(array, remote_field)?
296 }
297 DataType::Binary => {
298 let array = batch
299 .column(idx)
300 .as_any()
301 .downcast_ref::<BinaryArray>()
302 .expect("Failed to downcast to BinaryArray");
303 transform.transform_binary(array, remote_field)?
304 }
305 DataType::Date32 => {
306 let array = batch
307 .column(idx)
308 .as_any()
309 .downcast_ref::<Date32Array>()
310 .expect("Failed to downcast to Date32Array");
311 transform.transform_date32(array, remote_field)?
312 }
313 DataType::Timestamp(TimeUnit::Second, _) => {
314 let array = batch
315 .column(idx)
316 .as_any()
317 .downcast_ref::<TimestampSecondArray>()
318 .expect("Failed to downcast to TimestampSecondArray");
319 transform.transform_timestamp_second(array, remote_field)?
320 }
321 DataType::Timestamp(TimeUnit::Millisecond, _) => {
322 let array = batch
323 .column(idx)
324 .as_any()
325 .downcast_ref::<TimestampMillisecondArray>()
326 .expect("Failed to downcast to TimestampMillisecondArray");
327 transform.transform_timestamp_millisecond(array, remote_field)?
328 }
329 DataType::Timestamp(TimeUnit::Microsecond, _) => {
330 let array = batch
331 .column(idx)
332 .as_any()
333 .downcast_ref::<TimestampMicrosecondArray>()
334 .expect("Failed to downcast to TimestampMicrosecondArray");
335 transform.transform_timestamp_microsecond(array, remote_field)?
336 }
337 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
338 let array = batch
339 .column(idx)
340 .as_any()
341 .downcast_ref::<TimestampNanosecondArray>()
342 .expect("Failed to downcast to TimestampNanosecondArray");
343 transform.transform_timestamp_nanosecond(array, remote_field)?
344 }
345 DataType::Time64(TimeUnit::Nanosecond) => {
346 let array = batch
347 .column(idx)
348 .as_any()
349 .downcast_ref::<Time64NanosecondArray>()
350 .expect("Failed to downcast to Time64NanosecondArray");
351 transform.transform_time64_nanosecond(array, remote_field)?
352 }
353 DataType::List(_field) => {
354 let array = batch
355 .column(idx)
356 .as_any()
357 .downcast_ref::<ListArray>()
358 .expect("Failed to downcast to ListArray");
359 transform.transform_list(array, remote_field)?
360 }
361 data_type => {
362 return Err(DataFusionError::NotImplemented(format!(
363 "Unsupported arrow type {data_type:?}",
364 )))
365 }
366 };
367 new_arrays.push(new_array);
368 new_fields.push(new_field);
369 }
370 let new_schema = Arc::new(Schema::new(new_fields));
371 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
372}