1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3 Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array,
4 Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, ListArray, NullArray, RecordBatch,
5 StringArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6 TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array,
7 UInt64Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use datafusion::common::{DataFusionError, project_schema};
11use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
12use futures::{Stream, StreamExt};
13use std::any::Any;
14use std::fmt::Debug;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18
19pub struct TransformArgs<'a> {
20 pub col_index: usize,
21 pub field: &'a Field,
22 pub remote_field: Option<&'a RemoteField>,
23}
24
25pub trait Transform: Debug + Send + Sync {
26 fn as_any(&self) -> &dyn Any;
27
28 fn transform_null(
29 &self,
30 array: &NullArray,
31 args: TransformArgs,
32 ) -> DFResult<(ArrayRef, Field)> {
33 Ok((Arc::new(array.clone()), args.field.clone()))
34 }
35
36 fn transform_boolean(
37 &self,
38 array: &BooleanArray,
39 args: TransformArgs,
40 ) -> DFResult<(ArrayRef, Field)> {
41 Ok((Arc::new(array.clone()), args.field.clone()))
42 }
43
44 fn transform_int8(
45 &self,
46 array: &Int8Array,
47 args: TransformArgs,
48 ) -> DFResult<(ArrayRef, Field)> {
49 Ok((Arc::new(array.clone()), args.field.clone()))
50 }
51
52 fn transform_int16(
53 &self,
54 array: &Int16Array,
55 args: TransformArgs,
56 ) -> DFResult<(ArrayRef, Field)> {
57 Ok((Arc::new(array.clone()), args.field.clone()))
58 }
59
60 fn transform_int32(
61 &self,
62 array: &Int32Array,
63 args: TransformArgs,
64 ) -> DFResult<(ArrayRef, Field)> {
65 Ok((Arc::new(array.clone()), args.field.clone()))
66 }
67
68 fn transform_int64(
69 &self,
70 array: &Int64Array,
71 args: TransformArgs,
72 ) -> DFResult<(ArrayRef, Field)> {
73 Ok((Arc::new(array.clone()), args.field.clone()))
74 }
75
76 fn transform_uint8(
77 &self,
78 array: &UInt8Array,
79 args: TransformArgs,
80 ) -> DFResult<(ArrayRef, Field)> {
81 Ok((Arc::new(array.clone()), args.field.clone()))
82 }
83
84 fn transform_uint16(
85 &self,
86 array: &UInt16Array,
87 args: TransformArgs,
88 ) -> DFResult<(ArrayRef, Field)> {
89 Ok((Arc::new(array.clone()), args.field.clone()))
90 }
91
92 fn transform_uint32(
93 &self,
94 array: &UInt32Array,
95 args: TransformArgs,
96 ) -> DFResult<(ArrayRef, Field)> {
97 Ok((Arc::new(array.clone()), args.field.clone()))
98 }
99
100 fn transform_uint64(
101 &self,
102 array: &UInt64Array,
103 args: TransformArgs,
104 ) -> DFResult<(ArrayRef, Field)> {
105 Ok((Arc::new(array.clone()), args.field.clone()))
106 }
107
108 fn transform_float16(
109 &self,
110 array: &Float16Array,
111 args: TransformArgs,
112 ) -> DFResult<(ArrayRef, Field)> {
113 Ok((Arc::new(array.clone()), args.field.clone()))
114 }
115
116 fn transform_float32(
117 &self,
118 array: &Float32Array,
119 args: TransformArgs,
120 ) -> DFResult<(ArrayRef, Field)> {
121 Ok((Arc::new(array.clone()), args.field.clone()))
122 }
123
124 fn transform_float64(
125 &self,
126 array: &Float64Array,
127 args: TransformArgs,
128 ) -> DFResult<(ArrayRef, Field)> {
129 Ok((Arc::new(array.clone()), args.field.clone()))
130 }
131
132 fn transform_utf8(
133 &self,
134 array: &StringArray,
135 args: TransformArgs,
136 ) -> DFResult<(ArrayRef, Field)> {
137 Ok((Arc::new(array.clone()), args.field.clone()))
138 }
139
140 fn transform_binary(
141 &self,
142 array: &BinaryArray,
143 args: TransformArgs,
144 ) -> DFResult<(ArrayRef, Field)> {
145 Ok((Arc::new(array.clone()), args.field.clone()))
146 }
147
148 fn transform_timestamp_second(
149 &self,
150 array: &TimestampSecondArray,
151 args: TransformArgs,
152 ) -> DFResult<(ArrayRef, Field)> {
153 Ok((Arc::new(array.clone()), args.field.clone()))
154 }
155
156 fn transform_timestamp_millisecond(
157 &self,
158 array: &TimestampMillisecondArray,
159 args: TransformArgs,
160 ) -> DFResult<(ArrayRef, Field)> {
161 Ok((Arc::new(array.clone()), args.field.clone()))
162 }
163
164 fn transform_timestamp_microsecond(
165 &self,
166 array: &TimestampMicrosecondArray,
167 args: TransformArgs,
168 ) -> DFResult<(ArrayRef, Field)> {
169 Ok((Arc::new(array.clone()), args.field.clone()))
170 }
171
172 fn transform_timestamp_nanosecond(
173 &self,
174 array: &TimestampNanosecondArray,
175 args: TransformArgs,
176 ) -> DFResult<(ArrayRef, Field)> {
177 Ok((Arc::new(array.clone()), args.field.clone()))
178 }
179
180 fn transform_time64_nanosecond(
181 &self,
182 array: &Time64NanosecondArray,
183 args: TransformArgs,
184 ) -> DFResult<(ArrayRef, Field)> {
185 Ok((Arc::new(array.clone()), args.field.clone()))
186 }
187
188 fn transform_date32(
189 &self,
190 array: &Date32Array,
191 args: TransformArgs,
192 ) -> DFResult<(ArrayRef, Field)> {
193 Ok((Arc::new(array.clone()), args.field.clone()))
194 }
195
196 fn transform_list(
197 &self,
198 array: &ListArray,
199 args: TransformArgs,
200 ) -> DFResult<(ArrayRef, Field)> {
201 Ok((Arc::new(array.clone()), args.field.clone()))
202 }
203}
204
205pub(crate) struct TransformStream {
206 input: SendableRecordBatchStream,
207 transform: Arc<dyn Transform>,
208 table_schema: SchemaRef,
209 projection: Option<Vec<usize>>,
210 projected_schema: SchemaRef,
211 remote_schema: Option<RemoteSchemaRef>,
212}
213
214impl TransformStream {
215 pub fn try_new(
216 input: SendableRecordBatchStream,
217 transform: Arc<dyn Transform>,
218 table_schema: SchemaRef,
219 projection: Option<Vec<usize>>,
220 remote_schema: Option<RemoteSchemaRef>,
221 ) -> DFResult<Self> {
222 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
223 Ok(Self {
224 input,
225 transform,
226 table_schema,
227 projection,
228 projected_schema,
229 remote_schema,
230 })
231 }
232}
233
234impl Stream for TransformStream {
235 type Item = DFResult<RecordBatch>;
236 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 match self.input.poll_next_unpin(cx) {
238 Poll::Ready(Some(Ok(batch))) => {
239 match transform_batch(
240 batch,
241 self.transform.as_ref(),
242 &self.table_schema,
243 self.projection.as_ref(),
244 self.remote_schema.as_ref(),
245 ) {
246 Ok(result) => Poll::Ready(Some(Ok(result))),
247 Err(e) => Poll::Ready(Some(Err(e))),
248 }
249 }
250 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
251 Poll::Ready(None) => Poll::Ready(None),
252 Poll::Pending => Poll::Pending,
253 }
254 }
255}
256
257impl RecordBatchStream for TransformStream {
258 fn schema(&self) -> SchemaRef {
259 self.projected_schema.clone()
260 }
261}
262
263pub(crate) fn transform_batch(
264 batch: RecordBatch,
265 transform: &dyn Transform,
266 table_schema: &SchemaRef,
267 projection: Option<&Vec<usize>>,
268 remote_schema: Option<&RemoteSchemaRef>,
269) -> DFResult<RecordBatch> {
270 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
271 let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
272 let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
273 let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
274
275 for (idx, col_index) in projected_col_indexes.iter().enumerate() {
276 let field = table_schema.field(*col_index);
277 let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
278 let args = TransformArgs {
279 col_index: *col_index,
280 field,
281 remote_field,
282 };
283
284 let (new_array, new_field) = match &field.data_type() {
285 DataType::Null => {
286 let array = batch
287 .column(idx)
288 .as_any()
289 .downcast_ref::<NullArray>()
290 .expect("Failed to downcast to NullArray");
291 transform.transform_null(array, args)?
292 }
293 DataType::Boolean => {
294 let array = batch
295 .column(idx)
296 .as_any()
297 .downcast_ref::<BooleanArray>()
298 .expect("Failed to downcast to BooleanArray");
299 transform.transform_boolean(array, args)?
300 }
301 DataType::Int8 => {
302 let array = batch
303 .column(idx)
304 .as_any()
305 .downcast_ref::<Int8Array>()
306 .expect("Failed to downcast to Int8Array");
307 transform.transform_int8(array, args)?
308 }
309 DataType::Int16 => {
310 let array = batch
311 .column(idx)
312 .as_any()
313 .downcast_ref::<Int16Array>()
314 .expect("Failed to downcast to Int16Array");
315 transform.transform_int16(array, args)?
316 }
317 DataType::Int32 => {
318 let array = batch
319 .column(idx)
320 .as_any()
321 .downcast_ref::<Int32Array>()
322 .expect("Failed to downcast to Int32Array");
323 transform.transform_int32(array, args)?
324 }
325 DataType::Int64 => {
326 let array = batch
327 .column(idx)
328 .as_any()
329 .downcast_ref::<Int64Array>()
330 .expect("Failed to downcast to Int64Array");
331 transform.transform_int64(array, args)?
332 }
333 DataType::UInt8 => {
334 let array = batch
335 .column(idx)
336 .as_any()
337 .downcast_ref::<UInt8Array>()
338 .expect("Failed to downcast to UInt8Array");
339 transform.transform_uint8(array, args)?
340 }
341 DataType::UInt16 => {
342 let array = batch
343 .column(idx)
344 .as_any()
345 .downcast_ref::<UInt16Array>()
346 .expect("Failed to downcast to UInt16Array");
347 transform.transform_uint16(array, args)?
348 }
349 DataType::UInt32 => {
350 let array = batch
351 .column(idx)
352 .as_any()
353 .downcast_ref::<UInt32Array>()
354 .expect("Failed to downcast to UInt32Array");
355 transform.transform_uint32(array, args)?
356 }
357 DataType::UInt64 => {
358 let array = batch
359 .column(idx)
360 .as_any()
361 .downcast_ref::<UInt64Array>()
362 .expect("Failed to downcast to UInt64Array");
363 transform.transform_uint64(array, args)?
364 }
365 DataType::Float16 => {
366 let array = batch
367 .column(idx)
368 .as_any()
369 .downcast_ref::<Float16Array>()
370 .expect("Failed to downcast to Float16Array");
371 transform.transform_float16(array, args)?
372 }
373 DataType::Float32 => {
374 let array = batch
375 .column(idx)
376 .as_any()
377 .downcast_ref::<Float32Array>()
378 .expect("Failed to downcast to Float32Array");
379 transform.transform_float32(array, args)?
380 }
381 DataType::Float64 => {
382 let array = batch
383 .column(idx)
384 .as_any()
385 .downcast_ref::<Float64Array>()
386 .expect("Failed to downcast to Float64Array");
387 transform.transform_float64(array, args)?
388 }
389 DataType::Utf8 => {
390 let array = batch
391 .column(idx)
392 .as_any()
393 .downcast_ref::<StringArray>()
394 .expect("Failed to downcast to StringArray");
395 transform.transform_utf8(array, args)?
396 }
397 DataType::Binary => {
398 let array = batch
399 .column(idx)
400 .as_any()
401 .downcast_ref::<BinaryArray>()
402 .expect("Failed to downcast to BinaryArray");
403 transform.transform_binary(array, args)?
404 }
405 DataType::Date32 => {
406 let array = batch
407 .column(idx)
408 .as_any()
409 .downcast_ref::<Date32Array>()
410 .expect("Failed to downcast to Date32Array");
411 transform.transform_date32(array, args)?
412 }
413 DataType::Timestamp(TimeUnit::Second, _) => {
414 let array = batch
415 .column(idx)
416 .as_any()
417 .downcast_ref::<TimestampSecondArray>()
418 .expect("Failed to downcast to TimestampSecondArray");
419 transform.transform_timestamp_second(array, args)?
420 }
421 DataType::Timestamp(TimeUnit::Millisecond, _) => {
422 let array = batch
423 .column(idx)
424 .as_any()
425 .downcast_ref::<TimestampMillisecondArray>()
426 .expect("Failed to downcast to TimestampMillisecondArray");
427 transform.transform_timestamp_millisecond(array, args)?
428 }
429 DataType::Timestamp(TimeUnit::Microsecond, _) => {
430 let array = batch
431 .column(idx)
432 .as_any()
433 .downcast_ref::<TimestampMicrosecondArray>()
434 .expect("Failed to downcast to TimestampMicrosecondArray");
435 transform.transform_timestamp_microsecond(array, args)?
436 }
437 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
438 let array = batch
439 .column(idx)
440 .as_any()
441 .downcast_ref::<TimestampNanosecondArray>()
442 .expect("Failed to downcast to TimestampNanosecondArray");
443 transform.transform_timestamp_nanosecond(array, args)?
444 }
445 DataType::Time64(TimeUnit::Nanosecond) => {
446 let array = batch
447 .column(idx)
448 .as_any()
449 .downcast_ref::<Time64NanosecondArray>()
450 .expect("Failed to downcast to Time64NanosecondArray");
451 transform.transform_time64_nanosecond(array, args)?
452 }
453 DataType::List(_field) => {
454 let array = batch
455 .column(idx)
456 .as_any()
457 .downcast_ref::<ListArray>()
458 .expect("Failed to downcast to ListArray");
459 transform.transform_list(array, args)?
460 }
461 data_type => {
462 return Err(DataFusionError::NotImplemented(format!(
463 "Unsupported arrow type {data_type:?}",
464 )));
465 }
466 };
467 new_arrays.push(new_array);
468 new_fields.push(new_field);
469 }
470 let new_schema = Arc::new(Schema::new(new_fields));
471 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
472}
473
474pub fn transform_schema(
475 schema: SchemaRef,
476 transform: &dyn Transform,
477 remote_schema: Option<&RemoteSchemaRef>,
478) -> DFResult<SchemaRef> {
479 let empty_record = RecordBatch::new_empty(schema.clone());
480 transform_batch(empty_record, transform, &schema, None, remote_schema)
481 .map(|batch| batch.schema())
482}