1use crate::{DFResult, RemoteField, RemoteSchema};
2use datafusion::arrow::array::{
3 ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
4 Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, StringArray,
5 Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6 TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
7 UInt8Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
10use datafusion::common::DataFusionError;
11use std::fmt::Debug;
12use std::sync::Arc;
13
14pub trait Transform: Debug + Send + Sync {
15 fn transform_boolean(
16 &self,
17 array: &BooleanArray,
18 remote_field: &RemoteField,
19 ) -> DFResult<(ArrayRef, Field)> {
20 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
21 }
22
23 fn transform_int8(
24 &self,
25 array: &Int8Array,
26 remote_field: &RemoteField,
27 ) -> DFResult<(ArrayRef, Field)> {
28 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
29 }
30
31 fn transform_int16(
32 &self,
33 array: &Int16Array,
34 remote_field: &RemoteField,
35 ) -> DFResult<(ArrayRef, Field)> {
36 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
37 }
38
39 fn transform_int32(
40 &self,
41 array: &Int32Array,
42 remote_field: &RemoteField,
43 ) -> DFResult<(ArrayRef, Field)> {
44 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
45 }
46
47 fn transform_int64(
48 &self,
49 array: &Int64Array,
50 remote_field: &RemoteField,
51 ) -> DFResult<(ArrayRef, Field)> {
52 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
53 }
54
55 fn transform_uint8(
56 &self,
57 array: &UInt8Array,
58 remote_field: &RemoteField,
59 ) -> DFResult<(ArrayRef, Field)> {
60 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
61 }
62
63 fn transform_uint16(
64 &self,
65 array: &UInt16Array,
66 remote_field: &RemoteField,
67 ) -> DFResult<(ArrayRef, Field)> {
68 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
69 }
70
71 fn transform_uint32(
72 &self,
73 array: &UInt32Array,
74 remote_field: &RemoteField,
75 ) -> DFResult<(ArrayRef, Field)> {
76 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
77 }
78
79 fn transform_uint64(
80 &self,
81 array: &UInt64Array,
82 remote_field: &RemoteField,
83 ) -> DFResult<(ArrayRef, Field)> {
84 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
85 }
86
87 fn transform_float16(
88 &self,
89 array: &Float16Array,
90 remote_field: &RemoteField,
91 ) -> DFResult<(ArrayRef, Field)> {
92 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
93 }
94
95 fn transform_float32(
96 &self,
97 array: &Float32Array,
98 remote_field: &RemoteField,
99 ) -> DFResult<(ArrayRef, Field)> {
100 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
101 }
102
103 fn transform_float64(
104 &self,
105 array: &Float64Array,
106 remote_field: &RemoteField,
107 ) -> DFResult<(ArrayRef, Field)> {
108 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
109 }
110
111 fn transform_utf8(
112 &self,
113 array: &StringArray,
114 remote_field: &RemoteField,
115 ) -> DFResult<(ArrayRef, Field)> {
116 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
117 }
118
119 fn transform_binary(
120 &self,
121 array: &BinaryArray,
122 remote_field: &RemoteField,
123 ) -> DFResult<(ArrayRef, Field)> {
124 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
125 }
126
127 fn transform_timestamp_second(
128 &self,
129 array: &TimestampSecondArray,
130 remote_field: &RemoteField,
131 ) -> DFResult<(ArrayRef, Field)> {
132 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
133 }
134
135 fn transform_timestamp_millisecond(
136 &self,
137 array: &TimestampMillisecondArray,
138 remote_field: &RemoteField,
139 ) -> DFResult<(ArrayRef, Field)> {
140 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
141 }
142
143 fn transform_timestamp_microsecond(
144 &self,
145 array: &TimestampMicrosecondArray,
146 remote_field: &RemoteField,
147 ) -> DFResult<(ArrayRef, Field)> {
148 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
149 }
150
151 fn transform_timestamp_nanosecond(
152 &self,
153 array: &TimestampNanosecondArray,
154 remote_field: &RemoteField,
155 ) -> DFResult<(ArrayRef, Field)> {
156 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
157 }
158
159 fn transform_time64_nanosecond(
160 &self,
161 array: &Time64NanosecondArray,
162 remote_field: &RemoteField,
163 ) -> DFResult<(ArrayRef, Field)> {
164 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
165 }
166
167 fn transform_date32(
168 &self,
169 array: &Date32Array,
170 remote_field: &RemoteField,
171 ) -> DFResult<(ArrayRef, Field)> {
172 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
173 }
174
175 fn transform_list(
176 &self,
177 array: &ListArray,
178 remote_field: &RemoteField,
179 ) -> DFResult<(ArrayRef, Field)> {
180 Ok((Arc::new(array.clone()), remote_field.to_arrow_field()))
181 }
182}
183
184pub(crate) fn transform_batch(
185 batch: RecordBatch,
186 transform: &dyn Transform,
187 remote_schema: &RemoteSchema,
188) -> DFResult<RecordBatch> {
189 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(remote_schema.fields.len());
190 let mut new_fields: Vec<Field> = Vec::with_capacity(remote_schema.fields.len());
191 for (idx, remote_field) in remote_schema.fields.iter().enumerate() {
192 let (new_array, new_field) = match &remote_field.remote_type.to_arrow_type() {
193 DataType::Boolean => {
195 let array = batch
196 .column(idx)
197 .as_any()
198 .downcast_ref::<BooleanArray>()
199 .expect("Failed to downcast to BooleanArray");
200 transform.transform_boolean(array, remote_field)?
201 }
202 DataType::Int8 => {
203 let array = batch
204 .column(idx)
205 .as_any()
206 .downcast_ref::<Int8Array>()
207 .expect("Failed to downcast to Int8Array");
208 transform.transform_int8(array, remote_field)?
209 }
210 DataType::Int16 => {
211 let array = batch
212 .column(idx)
213 .as_any()
214 .downcast_ref::<Int16Array>()
215 .expect("Failed to downcast to Int16Array");
216 transform.transform_int16(array, remote_field)?
217 }
218 DataType::Int32 => {
219 let array = batch
220 .column(idx)
221 .as_any()
222 .downcast_ref::<Int32Array>()
223 .expect("Failed to downcast to Int32Array");
224 transform.transform_int32(array, remote_field)?
225 }
226 DataType::Int64 => {
227 let array = batch
228 .column(idx)
229 .as_any()
230 .downcast_ref::<Int64Array>()
231 .expect("Failed to downcast to Int64Array");
232 transform.transform_int64(array, remote_field)?
233 }
234 DataType::UInt8 => {
235 let array = batch
236 .column(idx)
237 .as_any()
238 .downcast_ref::<UInt8Array>()
239 .expect("Failed to downcast to UInt8Array");
240 transform.transform_uint8(array, remote_field)?
241 }
242 DataType::UInt16 => {
243 let array = batch
244 .column(idx)
245 .as_any()
246 .downcast_ref::<UInt16Array>()
247 .expect("Failed to downcast to UInt16Array");
248 transform.transform_uint16(array, remote_field)?
249 }
250 DataType::UInt32 => {
251 let array = batch
252 .column(idx)
253 .as_any()
254 .downcast_ref::<UInt32Array>()
255 .expect("Failed to downcast to UInt32Array");
256 transform.transform_uint32(array, remote_field)?
257 }
258 DataType::UInt64 => {
259 let array = batch
260 .column(idx)
261 .as_any()
262 .downcast_ref::<UInt64Array>()
263 .expect("Failed to downcast to UInt64Array");
264 transform.transform_uint64(array, remote_field)?
265 }
266 DataType::Float16 => {
267 let array = batch
268 .column(idx)
269 .as_any()
270 .downcast_ref::<Float16Array>()
271 .expect("Failed to downcast to Float16Array");
272 transform.transform_float16(array, remote_field)?
273 }
274 DataType::Float32 => {
275 let array = batch
276 .column(idx)
277 .as_any()
278 .downcast_ref::<Float32Array>()
279 .expect("Failed to downcast to Float32Array");
280 transform.transform_float32(array, remote_field)?
281 }
282 DataType::Float64 => {
283 let array = batch
284 .column(idx)
285 .as_any()
286 .downcast_ref::<Float64Array>()
287 .expect("Failed to downcast to Float64Array");
288 transform.transform_float64(array, remote_field)?
289 }
290 DataType::Utf8 => {
291 let array = batch
292 .column(idx)
293 .as_any()
294 .downcast_ref::<StringArray>()
295 .expect("Failed to downcast to StringArray");
296 transform.transform_utf8(array, remote_field)?
297 }
298 DataType::Binary => {
299 let array = batch
300 .column(idx)
301 .as_any()
302 .downcast_ref::<BinaryArray>()
303 .expect("Failed to downcast to BinaryArray");
304 transform.transform_binary(array, remote_field)?
305 }
306 DataType::Date32 => {
307 let array = batch
308 .column(idx)
309 .as_any()
310 .downcast_ref::<Date32Array>()
311 .expect("Failed to downcast to Date32Array");
312 transform.transform_date32(array, remote_field)?
313 }
314 DataType::Timestamp(TimeUnit::Second, _) => {
315 let array = batch
316 .column(idx)
317 .as_any()
318 .downcast_ref::<TimestampSecondArray>()
319 .expect("Failed to downcast to TimestampSecondArray");
320 transform.transform_timestamp_second(array, remote_field)?
321 }
322 DataType::Timestamp(TimeUnit::Millisecond, _) => {
323 let array = batch
324 .column(idx)
325 .as_any()
326 .downcast_ref::<TimestampMillisecondArray>()
327 .expect("Failed to downcast to TimestampMillisecondArray");
328 transform.transform_timestamp_millisecond(array, remote_field)?
329 }
330 DataType::Timestamp(TimeUnit::Microsecond, _) => {
331 let array = batch
332 .column(idx)
333 .as_any()
334 .downcast_ref::<TimestampMicrosecondArray>()
335 .expect("Failed to downcast to TimestampMicrosecondArray");
336 transform.transform_timestamp_microsecond(array, remote_field)?
337 }
338 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
339 let array = batch
340 .column(idx)
341 .as_any()
342 .downcast_ref::<TimestampNanosecondArray>()
343 .expect("Failed to downcast to TimestampNanosecondArray");
344 transform.transform_timestamp_nanosecond(array, remote_field)?
345 }
346 DataType::Time64(TimeUnit::Nanosecond) => {
347 let array = batch
348 .column(idx)
349 .as_any()
350 .downcast_ref::<Time64NanosecondArray>()
351 .expect("Failed to downcast to Time64NanosecondArray");
352 transform.transform_time64_nanosecond(array, remote_field)?
353 }
354 DataType::List(_field) => {
355 let array = batch
356 .column(idx)
357 .as_any()
358 .downcast_ref::<ListArray>()
359 .expect("Failed to downcast to ListArray");
360 transform.transform_list(array, remote_field)?
361 }
362 data_type => {
363 return Err(DataFusionError::NotImplemented(format!(
364 "Unsupported arrow type {data_type:?}",
365 )))
366 }
367 };
368 new_arrays.push(new_array);
369 new_fields.push(new_field);
370 }
371 let new_schema = Arc::new(Schema::new(new_fields));
372 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
373}