1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3 Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float16Array,
4 Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray,
5 LargeStringArray, ListArray, NullArray, RecordBatch, StringArray, Time32MillisecondArray,
6 Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
7 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
8 UInt16Array, UInt32Array, UInt64Array,
9};
10use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
13use futures::{Stream, StreamExt};
14use std::any::Any;
15use std::fmt::Debug;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll};
19
/// Per-column context handed to every [`Transform`] hook.
///
/// `transform_batch` builds one value per projected column before dispatching
/// to the hook that matches the column's Arrow data type.
pub struct TransformArgs<'a> {
    /// Index of the column in the table schema (not its position in the batch).
    pub col_index: usize,
    /// Table-schema field found at `col_index`.
    pub field: &'a Field,
    /// Entry at `col_index` in the remote schema, or `None` when no remote
    /// schema was supplied.
    pub remote_field: Option<&'a RemoteField>,
}
25
26pub trait Transform: Debug + Send + Sync {
27 fn as_any(&self) -> &dyn Any;
28
29 fn transform_null(
30 &self,
31 array: &NullArray,
32 args: TransformArgs,
33 ) -> DFResult<(ArrayRef, Field)> {
34 Ok((Arc::new(array.clone()), args.field.clone()))
35 }
36
37 fn transform_boolean(
38 &self,
39 array: &BooleanArray,
40 args: TransformArgs,
41 ) -> DFResult<(ArrayRef, Field)> {
42 Ok((Arc::new(array.clone()), args.field.clone()))
43 }
44
45 fn transform_int8(
46 &self,
47 array: &Int8Array,
48 args: TransformArgs,
49 ) -> DFResult<(ArrayRef, Field)> {
50 Ok((Arc::new(array.clone()), args.field.clone()))
51 }
52
53 fn transform_int16(
54 &self,
55 array: &Int16Array,
56 args: TransformArgs,
57 ) -> DFResult<(ArrayRef, Field)> {
58 Ok((Arc::new(array.clone()), args.field.clone()))
59 }
60
61 fn transform_int32(
62 &self,
63 array: &Int32Array,
64 args: TransformArgs,
65 ) -> DFResult<(ArrayRef, Field)> {
66 Ok((Arc::new(array.clone()), args.field.clone()))
67 }
68
69 fn transform_int64(
70 &self,
71 array: &Int64Array,
72 args: TransformArgs,
73 ) -> DFResult<(ArrayRef, Field)> {
74 Ok((Arc::new(array.clone()), args.field.clone()))
75 }
76
77 fn transform_uint8(
78 &self,
79 array: &UInt8Array,
80 args: TransformArgs,
81 ) -> DFResult<(ArrayRef, Field)> {
82 Ok((Arc::new(array.clone()), args.field.clone()))
83 }
84
85 fn transform_uint16(
86 &self,
87 array: &UInt16Array,
88 args: TransformArgs,
89 ) -> DFResult<(ArrayRef, Field)> {
90 Ok((Arc::new(array.clone()), args.field.clone()))
91 }
92
93 fn transform_uint32(
94 &self,
95 array: &UInt32Array,
96 args: TransformArgs,
97 ) -> DFResult<(ArrayRef, Field)> {
98 Ok((Arc::new(array.clone()), args.field.clone()))
99 }
100
101 fn transform_uint64(
102 &self,
103 array: &UInt64Array,
104 args: TransformArgs,
105 ) -> DFResult<(ArrayRef, Field)> {
106 Ok((Arc::new(array.clone()), args.field.clone()))
107 }
108
109 fn transform_float16(
110 &self,
111 array: &Float16Array,
112 args: TransformArgs,
113 ) -> DFResult<(ArrayRef, Field)> {
114 Ok((Arc::new(array.clone()), args.field.clone()))
115 }
116
117 fn transform_float32(
118 &self,
119 array: &Float32Array,
120 args: TransformArgs,
121 ) -> DFResult<(ArrayRef, Field)> {
122 Ok((Arc::new(array.clone()), args.field.clone()))
123 }
124
125 fn transform_float64(
126 &self,
127 array: &Float64Array,
128 args: TransformArgs,
129 ) -> DFResult<(ArrayRef, Field)> {
130 Ok((Arc::new(array.clone()), args.field.clone()))
131 }
132
133 fn transform_utf8(
134 &self,
135 array: &StringArray,
136 args: TransformArgs,
137 ) -> DFResult<(ArrayRef, Field)> {
138 Ok((Arc::new(array.clone()), args.field.clone()))
139 }
140
141 fn transform_large_utf8(
142 &self,
143 array: &LargeStringArray,
144 args: TransformArgs,
145 ) -> DFResult<(ArrayRef, Field)> {
146 Ok((Arc::new(array.clone()), args.field.clone()))
147 }
148
149 fn transform_binary(
150 &self,
151 array: &BinaryArray,
152 args: TransformArgs,
153 ) -> DFResult<(ArrayRef, Field)> {
154 Ok((Arc::new(array.clone()), args.field.clone()))
155 }
156
157 fn transform_large_binary(
158 &self,
159 array: &LargeBinaryArray,
160 args: TransformArgs,
161 ) -> DFResult<(ArrayRef, Field)> {
162 Ok((Arc::new(array.clone()), args.field.clone()))
163 }
164
165 fn transform_timestamp_second(
166 &self,
167 array: &TimestampSecondArray,
168 args: TransformArgs,
169 ) -> DFResult<(ArrayRef, Field)> {
170 Ok((Arc::new(array.clone()), args.field.clone()))
171 }
172
173 fn transform_timestamp_millisecond(
174 &self,
175 array: &TimestampMillisecondArray,
176 args: TransformArgs,
177 ) -> DFResult<(ArrayRef, Field)> {
178 Ok((Arc::new(array.clone()), args.field.clone()))
179 }
180
181 fn transform_timestamp_microsecond(
182 &self,
183 array: &TimestampMicrosecondArray,
184 args: TransformArgs,
185 ) -> DFResult<(ArrayRef, Field)> {
186 Ok((Arc::new(array.clone()), args.field.clone()))
187 }
188
189 fn transform_timestamp_nanosecond(
190 &self,
191 array: &TimestampNanosecondArray,
192 args: TransformArgs,
193 ) -> DFResult<(ArrayRef, Field)> {
194 Ok((Arc::new(array.clone()), args.field.clone()))
195 }
196
197 fn transform_date32(
198 &self,
199 array: &Date32Array,
200 args: TransformArgs,
201 ) -> DFResult<(ArrayRef, Field)> {
202 Ok((Arc::new(array.clone()), args.field.clone()))
203 }
204
205 fn transform_date64(
206 &self,
207 array: &Date64Array,
208 args: TransformArgs,
209 ) -> DFResult<(ArrayRef, Field)> {
210 Ok((Arc::new(array.clone()), args.field.clone()))
211 }
212
213 fn transform_time32_second(
214 &self,
215 array: &Time32SecondArray,
216 args: TransformArgs,
217 ) -> DFResult<(ArrayRef, Field)> {
218 Ok((Arc::new(array.clone()), args.field.clone()))
219 }
220
221 fn transform_time32_millisecond(
222 &self,
223 array: &Time32MillisecondArray,
224 args: TransformArgs,
225 ) -> DFResult<(ArrayRef, Field)> {
226 Ok((Arc::new(array.clone()), args.field.clone()))
227 }
228
229 fn transform_time64_microsecond(
230 &self,
231 array: &Time64MicrosecondArray,
232 args: TransformArgs,
233 ) -> DFResult<(ArrayRef, Field)> {
234 Ok((Arc::new(array.clone()), args.field.clone()))
235 }
236
237 fn transform_time64_nanosecond(
238 &self,
239 array: &Time64NanosecondArray,
240 args: TransformArgs,
241 ) -> DFResult<(ArrayRef, Field)> {
242 Ok((Arc::new(array.clone()), args.field.clone()))
243 }
244
245 fn transform_list(
246 &self,
247 array: &ListArray,
248 args: TransformArgs,
249 ) -> DFResult<(ArrayRef, Field)> {
250 Ok((Arc::new(array.clone()), args.field.clone()))
251 }
252}
253
/// Record-batch stream that runs every batch of an inner stream through a
/// [`Transform`] before yielding it.
pub(crate) struct TransformStream {
    /// Upstream stream whose batches are rewritten.
    input: SendableRecordBatchStream,
    /// Column hooks applied to each batch.
    transform: Arc<dyn Transform>,
    /// Table schema handed to `transform_batch` to pick the hook for each column.
    table_schema: SchemaRef,
    /// Table-schema column indexes present in the incoming batches; `None`
    /// means every column.
    projection: Option<Vec<usize>>,
    /// Schema reported by `RecordBatchStream::schema`; computed in `try_new`.
    projected_schema: SchemaRef,
    /// Optional remote schema forwarded to `transform_batch`.
    remote_schema: Option<RemoteSchemaRef>,
}
262
263impl TransformStream {
264 pub fn try_new(
265 input: SendableRecordBatchStream,
266 transform: Arc<dyn Transform>,
267 table_schema: SchemaRef,
268 projection: Option<Vec<usize>>,
269 remote_schema: Option<RemoteSchemaRef>,
270 ) -> DFResult<Self> {
271 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
272 Ok(Self {
273 input,
274 transform,
275 table_schema,
276 projection,
277 projected_schema,
278 remote_schema,
279 })
280 }
281}
282
283impl Stream for TransformStream {
284 type Item = DFResult<RecordBatch>;
285 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
286 match self.input.poll_next_unpin(cx) {
287 Poll::Ready(Some(Ok(batch))) => {
288 match transform_batch(
289 batch,
290 self.transform.as_ref(),
291 &self.table_schema,
292 self.projection.as_ref(),
293 self.remote_schema.as_ref(),
294 ) {
295 Ok(result) => Poll::Ready(Some(Ok(result))),
296 Err(e) => Poll::Ready(Some(Err(e))),
297 }
298 }
299 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
300 Poll::Ready(None) => Poll::Ready(None),
301 Poll::Pending => Poll::Pending,
302 }
303 }
304}
305
306impl RecordBatchStream for TransformStream {
307 fn schema(&self) -> SchemaRef {
308 self.projected_schema.clone()
309 }
310}
311
/// Downcasts column `$idx` of `$batch` to `$array_ty` and calls
/// `$transform.$transform_method(array, $transform_args)`.
///
/// Evaluates to the value returned by the transform method. Must be used
/// inside a function returning `DFResult`, because both a failed downcast and
/// a failed transform are propagated with `?`. A column whose concrete array
/// type is not `$array_ty` yields `DataFusionError::Internal` rather than a
/// panic.
macro_rules! handle_transform {
    ($batch:expr, $idx:expr, $array_ty:ty, $transform:expr, $transform_method:ident, $transform_args:expr) => {{
        // Bind once so `$idx` is evaluated a single time.
        let column_index = $idx;
        let column = $batch.column(column_index);
        let array = column
            .as_any()
            .downcast_ref::<$array_ty>()
            .ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(format!(
                    "Failed to downcast column {} with data type {:?} to {}",
                    column_index,
                    column.data_type(),
                    stringify!($array_ty),
                ))
            })?;
        $transform.$transform_method(array, $transform_args)?
    }};
}
322
323pub(crate) fn transform_batch(
324 batch: RecordBatch,
325 transform: &dyn Transform,
326 table_schema: &SchemaRef,
327 projection: Option<&Vec<usize>>,
328 remote_schema: Option<&RemoteSchemaRef>,
329) -> DFResult<RecordBatch> {
330 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
331 let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
332 let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
333 let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
334
335 for (idx, col_index) in projected_col_indexes.iter().enumerate() {
336 let field = table_schema.field(*col_index);
337 let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
338 let args = TransformArgs {
339 col_index: *col_index,
340 field,
341 remote_field,
342 };
343
344 let (new_array, new_field) = match &field.data_type() {
345 DataType::Null => {
346 handle_transform!(batch, idx, NullArray, transform, transform_null, args)
347 }
348 DataType::Boolean => {
349 handle_transform!(batch, idx, BooleanArray, transform, transform_boolean, args)
350 }
351 DataType::Int8 => {
352 handle_transform!(batch, idx, Int8Array, transform, transform_int8, args)
353 }
354 DataType::Int16 => {
355 handle_transform!(batch, idx, Int16Array, transform, transform_int16, args)
356 }
357 DataType::Int32 => {
358 handle_transform!(batch, idx, Int32Array, transform, transform_int32, args)
359 }
360 DataType::Int64 => {
361 handle_transform!(batch, idx, Int64Array, transform, transform_int64, args)
362 }
363 DataType::UInt8 => {
364 handle_transform!(batch, idx, UInt8Array, transform, transform_uint8, args)
365 }
366 DataType::UInt16 => {
367 handle_transform!(batch, idx, UInt16Array, transform, transform_uint16, args)
368 }
369 DataType::UInt32 => {
370 handle_transform!(batch, idx, UInt32Array, transform, transform_uint32, args)
371 }
372 DataType::UInt64 => {
373 handle_transform!(batch, idx, UInt64Array, transform, transform_uint64, args)
374 }
375 DataType::Float16 => {
376 handle_transform!(batch, idx, Float16Array, transform, transform_float16, args)
377 }
378 DataType::Float32 => {
379 handle_transform!(batch, idx, Float32Array, transform, transform_float32, args)
380 }
381 DataType::Float64 => {
382 handle_transform!(batch, idx, Float64Array, transform, transform_float64, args)
383 }
384 DataType::Timestamp(TimeUnit::Second, _) => {
385 handle_transform!(
386 batch,
387 idx,
388 TimestampSecondArray,
389 transform,
390 transform_timestamp_second,
391 args
392 )
393 }
394 DataType::Timestamp(TimeUnit::Millisecond, _) => {
395 handle_transform!(
396 batch,
397 idx,
398 TimestampMillisecondArray,
399 transform,
400 transform_timestamp_millisecond,
401 args
402 )
403 }
404 DataType::Timestamp(TimeUnit::Microsecond, _) => {
405 handle_transform!(
406 batch,
407 idx,
408 TimestampMicrosecondArray,
409 transform,
410 transform_timestamp_microsecond,
411 args
412 )
413 }
414 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
415 handle_transform!(
416 batch,
417 idx,
418 TimestampNanosecondArray,
419 transform,
420 transform_timestamp_nanosecond,
421 args
422 )
423 }
424 DataType::Date32 => {
425 handle_transform!(batch, idx, Date32Array, transform, transform_date32, args)
426 }
427 DataType::Date64 => {
428 handle_transform!(batch, idx, Date64Array, transform, transform_date64, args)
429 }
430 DataType::Time32(TimeUnit::Second) => {
431 handle_transform!(
432 batch,
433 idx,
434 Time32SecondArray,
435 transform,
436 transform_time32_second,
437 args
438 )
439 }
440 DataType::Time32(TimeUnit::Millisecond) => {
441 handle_transform!(
442 batch,
443 idx,
444 Time32MillisecondArray,
445 transform,
446 transform_time32_millisecond,
447 args
448 )
449 }
450 DataType::Time32(TimeUnit::Microsecond) => unreachable!(),
451 DataType::Time32(TimeUnit::Nanosecond) => unreachable!(),
452 DataType::Time64(TimeUnit::Second) => unreachable!(),
453 DataType::Time64(TimeUnit::Millisecond) => unreachable!(),
454 DataType::Time64(TimeUnit::Microsecond) => {
455 handle_transform!(
456 batch,
457 idx,
458 Time64MicrosecondArray,
459 transform,
460 transform_time64_microsecond,
461 args
462 )
463 }
464 DataType::Time64(TimeUnit::Nanosecond) => {
465 handle_transform!(
466 batch,
467 idx,
468 Time64NanosecondArray,
469 transform,
470 transform_time64_nanosecond,
471 args
472 )
473 }
474 DataType::Utf8 => {
475 handle_transform!(batch, idx, StringArray, transform, transform_utf8, args)
476 }
477 DataType::LargeUtf8 => {
478 handle_transform!(
479 batch,
480 idx,
481 LargeStringArray,
482 transform,
483 transform_large_utf8,
484 args
485 )
486 }
487 DataType::Binary => {
488 handle_transform!(batch, idx, BinaryArray, transform, transform_binary, args)
489 }
490 DataType::LargeBinary => {
491 handle_transform!(
492 batch,
493 idx,
494 LargeBinaryArray,
495 transform,
496 transform_large_binary,
497 args
498 )
499 }
500 DataType::List(_field) => {
501 handle_transform!(batch, idx, ListArray, transform, transform_list, args)
502 }
503 data_type => {
504 return Err(DataFusionError::NotImplemented(format!(
505 "Unsupported arrow type {data_type:?}",
506 )));
507 }
508 };
509 new_arrays.push(new_array);
510 new_fields.push(new_field);
511 }
512 let new_schema = Arc::new(Schema::new(new_fields));
513 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
514}
515
516pub(crate) fn transform_schema(
517 schema: SchemaRef,
518 transform: Option<&Arc<dyn Transform>>,
519 remote_schema: Option<&RemoteSchemaRef>,
520) -> DFResult<SchemaRef> {
521 if let Some(transform) = transform {
522 let empty_record = RecordBatch::new_empty(schema.clone());
523 transform_batch(
524 empty_record,
525 transform.as_ref(),
526 &schema,
527 None,
528 remote_schema,
529 )
530 .map(|batch| batch.schema())
531 } else {
532 Ok(schema)
533 }
534}