1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3 Array, ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
4 Decimal128Array, Decimal256Array, FixedSizeBinaryArray, FixedSizeListArray, Float16Array,
5 Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
6 IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
7 LargeListArray, LargeListViewArray, LargeStringArray, ListArray, ListViewArray, NullArray,
8 RecordBatch, StringArray, StringViewArray, Time32MillisecondArray, Time32SecondArray,
9 Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
10 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt8Array,
11 UInt16Array, UInt32Array, UInt64Array,
12};
13use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit};
14use datafusion::common::{DataFusionError, project_schema};
15use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
16use futures::{Stream, StreamExt};
17use std::any::Any;
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22
/// Per-column context handed to every [`Transform`] method.
pub struct TransformArgs<'a> {
    /// Index of the column in the table schema (not its position in the projected batch).
    pub col_index: usize,
    /// Field of the table schema found at `col_index`.
    pub field: &'a Field,
    /// Field of the remote schema found at `col_index`, when a remote schema was supplied.
    pub remote_field: Option<&'a RemoteField>,
}
28
29pub trait Transform: Debug + Send + Sync {
30 fn as_any(&self) -> &dyn Any;
31
32 fn transform_null(
33 &self,
34 array: &NullArray,
35 args: TransformArgs,
36 ) -> DFResult<(ArrayRef, Field)> {
37 Ok((Arc::new(array.clone()), args.field.clone()))
38 }
39
40 fn transform_boolean(
41 &self,
42 array: &BooleanArray,
43 args: TransformArgs,
44 ) -> DFResult<(ArrayRef, Field)> {
45 Ok((Arc::new(array.clone()), args.field.clone()))
46 }
47
48 fn transform_int8(
49 &self,
50 array: &Int8Array,
51 args: TransformArgs,
52 ) -> DFResult<(ArrayRef, Field)> {
53 Ok((Arc::new(array.clone()), args.field.clone()))
54 }
55
56 fn transform_int16(
57 &self,
58 array: &Int16Array,
59 args: TransformArgs,
60 ) -> DFResult<(ArrayRef, Field)> {
61 Ok((Arc::new(array.clone()), args.field.clone()))
62 }
63
64 fn transform_int32(
65 &self,
66 array: &Int32Array,
67 args: TransformArgs,
68 ) -> DFResult<(ArrayRef, Field)> {
69 Ok((Arc::new(array.clone()), args.field.clone()))
70 }
71
72 fn transform_int64(
73 &self,
74 array: &Int64Array,
75 args: TransformArgs,
76 ) -> DFResult<(ArrayRef, Field)> {
77 Ok((Arc::new(array.clone()), args.field.clone()))
78 }
79
80 fn transform_uint8(
81 &self,
82 array: &UInt8Array,
83 args: TransformArgs,
84 ) -> DFResult<(ArrayRef, Field)> {
85 Ok((Arc::new(array.clone()), args.field.clone()))
86 }
87
88 fn transform_uint16(
89 &self,
90 array: &UInt16Array,
91 args: TransformArgs,
92 ) -> DFResult<(ArrayRef, Field)> {
93 Ok((Arc::new(array.clone()), args.field.clone()))
94 }
95
96 fn transform_uint32(
97 &self,
98 array: &UInt32Array,
99 args: TransformArgs,
100 ) -> DFResult<(ArrayRef, Field)> {
101 Ok((Arc::new(array.clone()), args.field.clone()))
102 }
103
104 fn transform_uint64(
105 &self,
106 array: &UInt64Array,
107 args: TransformArgs,
108 ) -> DFResult<(ArrayRef, Field)> {
109 Ok((Arc::new(array.clone()), args.field.clone()))
110 }
111
112 fn transform_float16(
113 &self,
114 array: &Float16Array,
115 args: TransformArgs,
116 ) -> DFResult<(ArrayRef, Field)> {
117 Ok((Arc::new(array.clone()), args.field.clone()))
118 }
119
120 fn transform_float32(
121 &self,
122 array: &Float32Array,
123 args: TransformArgs,
124 ) -> DFResult<(ArrayRef, Field)> {
125 Ok((Arc::new(array.clone()), args.field.clone()))
126 }
127
128 fn transform_float64(
129 &self,
130 array: &Float64Array,
131 args: TransformArgs,
132 ) -> DFResult<(ArrayRef, Field)> {
133 Ok((Arc::new(array.clone()), args.field.clone()))
134 }
135
136 fn transform_binary(
137 &self,
138 array: &BinaryArray,
139 args: TransformArgs,
140 ) -> DFResult<(ArrayRef, Field)> {
141 Ok((Arc::new(array.clone()), args.field.clone()))
142 }
143
144 fn transform_fixed_size_binary(
145 &self,
146 array: &FixedSizeBinaryArray,
147 args: TransformArgs,
148 ) -> DFResult<(ArrayRef, Field)> {
149 Ok((Arc::new(array.clone()), args.field.clone()))
150 }
151
152 fn transform_large_binary(
153 &self,
154 array: &LargeBinaryArray,
155 args: TransformArgs,
156 ) -> DFResult<(ArrayRef, Field)> {
157 Ok((Arc::new(array.clone()), args.field.clone()))
158 }
159
160 fn transform_binary_view(
161 &self,
162 array: &BinaryViewArray,
163 args: TransformArgs,
164 ) -> DFResult<(ArrayRef, Field)> {
165 Ok((Arc::new(array.clone()), args.field.clone()))
166 }
167
168 fn transform_utf8(
169 &self,
170 array: &StringArray,
171 args: TransformArgs,
172 ) -> DFResult<(ArrayRef, Field)> {
173 Ok((Arc::new(array.clone()), args.field.clone()))
174 }
175
176 fn transform_large_utf8(
177 &self,
178 array: &LargeStringArray,
179 args: TransformArgs,
180 ) -> DFResult<(ArrayRef, Field)> {
181 Ok((Arc::new(array.clone()), args.field.clone()))
182 }
183
184 fn transform_utf8_view(
185 &self,
186 array: &StringViewArray,
187 args: TransformArgs,
188 ) -> DFResult<(ArrayRef, Field)> {
189 Ok((Arc::new(array.clone()), args.field.clone()))
190 }
191
192 fn transform_timestamp_second(
193 &self,
194 array: &TimestampSecondArray,
195 args: TransformArgs,
196 ) -> DFResult<(ArrayRef, Field)> {
197 Ok((Arc::new(array.clone()), args.field.clone()))
198 }
199
200 fn transform_timestamp_millisecond(
201 &self,
202 array: &TimestampMillisecondArray,
203 args: TransformArgs,
204 ) -> DFResult<(ArrayRef, Field)> {
205 Ok((Arc::new(array.clone()), args.field.clone()))
206 }
207
208 fn transform_timestamp_microsecond(
209 &self,
210 array: &TimestampMicrosecondArray,
211 args: TransformArgs,
212 ) -> DFResult<(ArrayRef, Field)> {
213 Ok((Arc::new(array.clone()), args.field.clone()))
214 }
215
216 fn transform_timestamp_nanosecond(
217 &self,
218 array: &TimestampNanosecondArray,
219 args: TransformArgs,
220 ) -> DFResult<(ArrayRef, Field)> {
221 Ok((Arc::new(array.clone()), args.field.clone()))
222 }
223
224 fn transform_date32(
225 &self,
226 array: &Date32Array,
227 args: TransformArgs,
228 ) -> DFResult<(ArrayRef, Field)> {
229 Ok((Arc::new(array.clone()), args.field.clone()))
230 }
231
232 fn transform_date64(
233 &self,
234 array: &Date64Array,
235 args: TransformArgs,
236 ) -> DFResult<(ArrayRef, Field)> {
237 Ok((Arc::new(array.clone()), args.field.clone()))
238 }
239
240 fn transform_time32_second(
241 &self,
242 array: &Time32SecondArray,
243 args: TransformArgs,
244 ) -> DFResult<(ArrayRef, Field)> {
245 Ok((Arc::new(array.clone()), args.field.clone()))
246 }
247
248 fn transform_time32_millisecond(
249 &self,
250 array: &Time32MillisecondArray,
251 args: TransformArgs,
252 ) -> DFResult<(ArrayRef, Field)> {
253 Ok((Arc::new(array.clone()), args.field.clone()))
254 }
255
256 fn transform_time64_microsecond(
257 &self,
258 array: &Time64MicrosecondArray,
259 args: TransformArgs,
260 ) -> DFResult<(ArrayRef, Field)> {
261 Ok((Arc::new(array.clone()), args.field.clone()))
262 }
263
264 fn transform_time64_nanosecond(
265 &self,
266 array: &Time64NanosecondArray,
267 args: TransformArgs,
268 ) -> DFResult<(ArrayRef, Field)> {
269 Ok((Arc::new(array.clone()), args.field.clone()))
270 }
271
272 fn transform_interval_year_month(
273 &self,
274 array: &IntervalYearMonthArray,
275 args: TransformArgs,
276 ) -> DFResult<(ArrayRef, Field)> {
277 Ok((Arc::new(array.clone()), args.field.clone()))
278 }
279
280 fn transform_interval_day_time(
281 &self,
282 array: &IntervalDayTimeArray,
283 args: TransformArgs,
284 ) -> DFResult<(ArrayRef, Field)> {
285 Ok((Arc::new(array.clone()), args.field.clone()))
286 }
287
288 fn transform_interval_month_day_nano(
289 &self,
290 array: &IntervalMonthDayNanoArray,
291 args: TransformArgs,
292 ) -> DFResult<(ArrayRef, Field)> {
293 Ok((Arc::new(array.clone()), args.field.clone()))
294 }
295
296 fn transform_list(
297 &self,
298 array: &ListArray,
299 args: TransformArgs,
300 ) -> DFResult<(ArrayRef, Field)> {
301 Ok((Arc::new(array.clone()), args.field.clone()))
302 }
303
304 fn transform_list_view(
305 &self,
306 array: &ListViewArray,
307 args: TransformArgs,
308 ) -> DFResult<(ArrayRef, Field)> {
309 Ok((Arc::new(array.clone()), args.field.clone()))
310 }
311
312 fn transform_fixed_size_list(
313 &self,
314 array: &FixedSizeListArray,
315 args: TransformArgs,
316 ) -> DFResult<(ArrayRef, Field)> {
317 Ok((Arc::new(array.clone()), args.field.clone()))
318 }
319
320 fn transform_large_list(
321 &self,
322 array: &LargeListArray,
323 args: TransformArgs,
324 ) -> DFResult<(ArrayRef, Field)> {
325 Ok((Arc::new(array.clone()), args.field.clone()))
326 }
327
328 fn transform_large_list_view(
329 &self,
330 array: &LargeListViewArray,
331 args: TransformArgs,
332 ) -> DFResult<(ArrayRef, Field)> {
333 Ok((Arc::new(array.clone()), args.field.clone()))
334 }
335
336 fn transform_decimal128(
337 &self,
338 array: &Decimal128Array,
339 args: TransformArgs,
340 ) -> DFResult<(ArrayRef, Field)> {
341 Ok((Arc::new(array.clone()), args.field.clone()))
342 }
343
344 fn transform_decimal256(
345 &self,
346 array: &Decimal256Array,
347 args: TransformArgs,
348 ) -> DFResult<(ArrayRef, Field)> {
349 Ok((Arc::new(array.clone()), args.field.clone()))
350 }
351}
352
/// [`Transform`] that overrides none of the per-type hooks, so every column
/// goes through the trait's default implementations.
#[derive(Debug)]
pub struct DefaultTransform {}

impl Transform for DefaultTransform {
    /// Returns `self` as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
361
/// Record batch stream that runs each batch of `input` through `transform_batch`.
pub(crate) struct TransformStream {
    /// Upstream stream whose batches are transformed.
    input: SendableRecordBatchStream,
    /// Transform applied to each projected column.
    transform: Arc<dyn Transform>,
    /// Full (unprojected) table schema; projection indices refer to it.
    table_schema: SchemaRef,
    /// Optional table-schema column indices; `None` means every column.
    projection: Option<Vec<usize>>,
    /// `table_schema` with `projection` applied; reported by `RecordBatchStream::schema`.
    projected_schema: SchemaRef,
    /// Optional remote schema whose fields are passed to the transform.
    remote_schema: Option<RemoteSchemaRef>,
}
370
371impl TransformStream {
372 pub fn try_new(
373 input: SendableRecordBatchStream,
374 transform: Arc<dyn Transform>,
375 table_schema: SchemaRef,
376 projection: Option<Vec<usize>>,
377 remote_schema: Option<RemoteSchemaRef>,
378 ) -> DFResult<Self> {
379 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
380 Ok(Self {
381 input,
382 transform,
383 table_schema,
384 projection,
385 projected_schema,
386 remote_schema,
387 })
388 }
389}
390
391impl Stream for TransformStream {
392 type Item = DFResult<RecordBatch>;
393 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
394 match self.input.poll_next_unpin(cx) {
395 Poll::Ready(Some(Ok(batch))) => {
396 match transform_batch(
397 batch,
398 self.transform.as_ref(),
399 &self.table_schema,
400 self.projection.as_ref(),
401 self.remote_schema.as_ref(),
402 ) {
403 Ok(result) => Poll::Ready(Some(Ok(result))),
404 Err(e) => Poll::Ready(Some(Err(e))),
405 }
406 }
407 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
408 Poll::Ready(None) => Poll::Ready(None),
409 Poll::Pending => Poll::Pending,
410 }
411 }
412}
413
414impl RecordBatchStream for TransformStream {
415 fn schema(&self) -> SchemaRef {
416 self.projected_schema.clone()
417 }
418}
419
/// Downcasts column `$idx` of `$batch` to `$array_ty` and passes it, together
/// with `$transform_args`, to `$transform.$transform_method`.
///
/// Evaluates to the `(ArrayRef, Field)` pair returned by the transform method.
/// Must be expanded inside a function whose error type is `DataFusionError`,
/// because both the downcast and the transform call use `?`.
///
/// A column whose concrete array type is not `$array_ty` produces a
/// `DataFusionError::Internal` rather than a panic, so a mismatch between
/// the batch and the table schema surfaces as an ordinary query error.
macro_rules! handle_transform {
    ($batch:expr, $idx:expr, $array_ty:ty, $transform:expr, $transform_method:ident, $transform_args:expr) => {{
        let column = $batch.column($idx);
        let array = column
            .as_any()
            .downcast_ref::<$array_ty>()
            .ok_or_else(|| {
                DataFusionError::Internal(format!(
                    concat!(
                        "Failed to downcast to ",
                        stringify!($array_ty),
                        ": column has data type {:?}"
                    ),
                    column.data_type()
                ))
            })?;
        $transform.$transform_method(array, $transform_args)?
    }};
}
430
431pub(crate) fn transform_batch(
432 batch: RecordBatch,
433 transform: &dyn Transform,
434 table_schema: &SchemaRef,
435 projection: Option<&Vec<usize>>,
436 remote_schema: Option<&RemoteSchemaRef>,
437) -> DFResult<RecordBatch> {
438 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
439 let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
440 let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
441 let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
442
443 for (idx, col_index) in projected_col_indexes.iter().enumerate() {
444 let field = table_schema.field(*col_index);
445 let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
446 let args = TransformArgs {
447 col_index: *col_index,
448 field,
449 remote_field,
450 };
451
452 let (new_array, new_field) = match &field.data_type() {
453 DataType::Null => {
454 handle_transform!(batch, idx, NullArray, transform, transform_null, args)
455 }
456 DataType::Boolean => {
457 handle_transform!(batch, idx, BooleanArray, transform, transform_boolean, args)
458 }
459 DataType::Int8 => {
460 handle_transform!(batch, idx, Int8Array, transform, transform_int8, args)
461 }
462 DataType::Int16 => {
463 handle_transform!(batch, idx, Int16Array, transform, transform_int16, args)
464 }
465 DataType::Int32 => {
466 handle_transform!(batch, idx, Int32Array, transform, transform_int32, args)
467 }
468 DataType::Int64 => {
469 handle_transform!(batch, idx, Int64Array, transform, transform_int64, args)
470 }
471 DataType::UInt8 => {
472 handle_transform!(batch, idx, UInt8Array, transform, transform_uint8, args)
473 }
474 DataType::UInt16 => {
475 handle_transform!(batch, idx, UInt16Array, transform, transform_uint16, args)
476 }
477 DataType::UInt32 => {
478 handle_transform!(batch, idx, UInt32Array, transform, transform_uint32, args)
479 }
480 DataType::UInt64 => {
481 handle_transform!(batch, idx, UInt64Array, transform, transform_uint64, args)
482 }
483 DataType::Float16 => {
484 handle_transform!(batch, idx, Float16Array, transform, transform_float16, args)
485 }
486 DataType::Float32 => {
487 handle_transform!(batch, idx, Float32Array, transform, transform_float32, args)
488 }
489 DataType::Float64 => {
490 handle_transform!(batch, idx, Float64Array, transform, transform_float64, args)
491 }
492 DataType::Timestamp(TimeUnit::Second, _) => {
493 handle_transform!(
494 batch,
495 idx,
496 TimestampSecondArray,
497 transform,
498 transform_timestamp_second,
499 args
500 )
501 }
502 DataType::Timestamp(TimeUnit::Millisecond, _) => {
503 handle_transform!(
504 batch,
505 idx,
506 TimestampMillisecondArray,
507 transform,
508 transform_timestamp_millisecond,
509 args
510 )
511 }
512 DataType::Timestamp(TimeUnit::Microsecond, _) => {
513 handle_transform!(
514 batch,
515 idx,
516 TimestampMicrosecondArray,
517 transform,
518 transform_timestamp_microsecond,
519 args
520 )
521 }
522 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
523 handle_transform!(
524 batch,
525 idx,
526 TimestampNanosecondArray,
527 transform,
528 transform_timestamp_nanosecond,
529 args
530 )
531 }
532 DataType::Date32 => {
533 handle_transform!(batch, idx, Date32Array, transform, transform_date32, args)
534 }
535 DataType::Date64 => {
536 handle_transform!(batch, idx, Date64Array, transform, transform_date64, args)
537 }
538 DataType::Time32(TimeUnit::Second) => {
539 handle_transform!(
540 batch,
541 idx,
542 Time32SecondArray,
543 transform,
544 transform_time32_second,
545 args
546 )
547 }
548 DataType::Time32(TimeUnit::Millisecond) => {
549 handle_transform!(
550 batch,
551 idx,
552 Time32MillisecondArray,
553 transform,
554 transform_time32_millisecond,
555 args
556 )
557 }
558 DataType::Time32(TimeUnit::Microsecond) => unreachable!(),
559 DataType::Time32(TimeUnit::Nanosecond) => unreachable!(),
560 DataType::Time64(TimeUnit::Second) => unreachable!(),
561 DataType::Time64(TimeUnit::Millisecond) => unreachable!(),
562 DataType::Time64(TimeUnit::Microsecond) => {
563 handle_transform!(
564 batch,
565 idx,
566 Time64MicrosecondArray,
567 transform,
568 transform_time64_microsecond,
569 args
570 )
571 }
572 DataType::Time64(TimeUnit::Nanosecond) => {
573 handle_transform!(
574 batch,
575 idx,
576 Time64NanosecondArray,
577 transform,
578 transform_time64_nanosecond,
579 args
580 )
581 }
582 DataType::Interval(IntervalUnit::YearMonth) => {
583 handle_transform!(
584 batch,
585 idx,
586 IntervalYearMonthArray,
587 transform,
588 transform_interval_year_month,
589 args
590 )
591 }
592 DataType::Interval(IntervalUnit::DayTime) => {
593 handle_transform!(
594 batch,
595 idx,
596 IntervalDayTimeArray,
597 transform,
598 transform_interval_day_time,
599 args
600 )
601 }
602 DataType::Interval(IntervalUnit::MonthDayNano) => {
603 handle_transform!(
604 batch,
605 idx,
606 IntervalMonthDayNanoArray,
607 transform,
608 transform_interval_month_day_nano,
609 args
610 )
611 }
612 DataType::Binary => {
613 handle_transform!(batch, idx, BinaryArray, transform, transform_binary, args)
614 }
615 DataType::FixedSizeBinary(_) => {
616 handle_transform!(
617 batch,
618 idx,
619 FixedSizeBinaryArray,
620 transform,
621 transform_fixed_size_binary,
622 args
623 )
624 }
625 DataType::LargeBinary => {
626 handle_transform!(
627 batch,
628 idx,
629 LargeBinaryArray,
630 transform,
631 transform_large_binary,
632 args
633 )
634 }
635 DataType::BinaryView => {
636 handle_transform!(
637 batch,
638 idx,
639 BinaryViewArray,
640 transform,
641 transform_binary_view,
642 args
643 )
644 }
645 DataType::Utf8 => {
646 handle_transform!(batch, idx, StringArray, transform, transform_utf8, args)
647 }
648 DataType::LargeUtf8 => {
649 handle_transform!(
650 batch,
651 idx,
652 LargeStringArray,
653 transform,
654 transform_large_utf8,
655 args
656 )
657 }
658 DataType::Utf8View => {
659 handle_transform!(
660 batch,
661 idx,
662 StringViewArray,
663 transform,
664 transform_utf8_view,
665 args
666 )
667 }
668 DataType::List(_field) => {
669 handle_transform!(batch, idx, ListArray, transform, transform_list, args)
670 }
671 DataType::ListView(_field) => {
672 handle_transform!(
673 batch,
674 idx,
675 ListViewArray,
676 transform,
677 transform_list_view,
678 args
679 )
680 }
681 DataType::FixedSizeList(_, _) => {
682 handle_transform!(
683 batch,
684 idx,
685 FixedSizeListArray,
686 transform,
687 transform_fixed_size_list,
688 args
689 )
690 }
691 DataType::LargeList(_field) => {
692 handle_transform!(
693 batch,
694 idx,
695 LargeListArray,
696 transform,
697 transform_large_list,
698 args
699 )
700 }
701 DataType::LargeListView(_field) => {
702 handle_transform!(
703 batch,
704 idx,
705 LargeListViewArray,
706 transform,
707 transform_large_list_view,
708 args
709 )
710 }
711 DataType::Decimal128(_, _) => {
712 handle_transform!(
713 batch,
714 idx,
715 Decimal128Array,
716 transform,
717 transform_decimal128,
718 args
719 )
720 }
721 DataType::Decimal256(_, _) => {
722 handle_transform!(
723 batch,
724 idx,
725 Decimal256Array,
726 transform,
727 transform_decimal256,
728 args
729 )
730 }
731 data_type => {
732 return Err(DataFusionError::NotImplemented(format!(
733 "Unsupported transform arrow type {data_type:?}",
734 )));
735 }
736 };
737 new_arrays.push(new_array);
738 new_fields.push(new_field);
739 }
740 let new_schema = Arc::new(Schema::new(new_fields));
741 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
742}
743
744pub(crate) fn transform_schema(
745 schema: SchemaRef,
746 transform: &dyn Transform,
747 remote_schema: Option<&RemoteSchemaRef>,
748) -> DFResult<SchemaRef> {
749 let empty_record = RecordBatch::new_empty(schema.clone());
750 transform_batch(empty_record, transform, &schema, None, remote_schema)
751 .map(|batch| batch.schema())
752}