1use crate::{DFResult, RemoteField, RemoteSchemaRef};
2use datafusion::arrow::array::{
3 ArrayRef, BinaryArray, BooleanArray, Date32Array, Float16Array, Float32Array, Float64Array,
4 Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch, StringArray,
5 Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
6 TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array,
7 UInt8Array,
8};
9use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
10use datafusion::common::{project_schema, DataFusionError};
11use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
12use futures::{Stream, StreamExt};
13use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
/// Per-column context handed to every [`Transform`] method.
pub struct TransformArgs<'a> {
    /// Index of the column in the table schema (not its position in the projected batch).
    pub col_index: usize,
    /// Table-schema field describing the column.
    pub field: &'a Field,
    /// Field at `col_index` in the remote schema, or `None` when no remote schema was supplied.
    pub remote_field: Option<&'a RemoteField>,
}
23
24pub trait Transform: Debug + Send + Sync {
25 fn transform_boolean(
26 &self,
27 array: &BooleanArray,
28 args: TransformArgs,
29 ) -> DFResult<(ArrayRef, Field)> {
30 Ok((Arc::new(array.clone()), args.field.clone()))
31 }
32
33 fn transform_int8(
34 &self,
35 array: &Int8Array,
36 args: TransformArgs,
37 ) -> DFResult<(ArrayRef, Field)> {
38 Ok((Arc::new(array.clone()), args.field.clone()))
39 }
40
41 fn transform_int16(
42 &self,
43 array: &Int16Array,
44 args: TransformArgs,
45 ) -> DFResult<(ArrayRef, Field)> {
46 Ok((Arc::new(array.clone()), args.field.clone()))
47 }
48
49 fn transform_int32(
50 &self,
51 array: &Int32Array,
52 args: TransformArgs,
53 ) -> DFResult<(ArrayRef, Field)> {
54 Ok((Arc::new(array.clone()), args.field.clone()))
55 }
56
57 fn transform_int64(
58 &self,
59 array: &Int64Array,
60 args: TransformArgs,
61 ) -> DFResult<(ArrayRef, Field)> {
62 Ok((Arc::new(array.clone()), args.field.clone()))
63 }
64
65 fn transform_uint8(
66 &self,
67 array: &UInt8Array,
68 args: TransformArgs,
69 ) -> DFResult<(ArrayRef, Field)> {
70 Ok((Arc::new(array.clone()), args.field.clone()))
71 }
72
73 fn transform_uint16(
74 &self,
75 array: &UInt16Array,
76 args: TransformArgs,
77 ) -> DFResult<(ArrayRef, Field)> {
78 Ok((Arc::new(array.clone()), args.field.clone()))
79 }
80
81 fn transform_uint32(
82 &self,
83 array: &UInt32Array,
84 args: TransformArgs,
85 ) -> DFResult<(ArrayRef, Field)> {
86 Ok((Arc::new(array.clone()), args.field.clone()))
87 }
88
89 fn transform_uint64(
90 &self,
91 array: &UInt64Array,
92 args: TransformArgs,
93 ) -> DFResult<(ArrayRef, Field)> {
94 Ok((Arc::new(array.clone()), args.field.clone()))
95 }
96
97 fn transform_float16(
98 &self,
99 array: &Float16Array,
100 args: TransformArgs,
101 ) -> DFResult<(ArrayRef, Field)> {
102 Ok((Arc::new(array.clone()), args.field.clone()))
103 }
104
105 fn transform_float32(
106 &self,
107 array: &Float32Array,
108 args: TransformArgs,
109 ) -> DFResult<(ArrayRef, Field)> {
110 Ok((Arc::new(array.clone()), args.field.clone()))
111 }
112
113 fn transform_float64(
114 &self,
115 array: &Float64Array,
116 args: TransformArgs,
117 ) -> DFResult<(ArrayRef, Field)> {
118 Ok((Arc::new(array.clone()), args.field.clone()))
119 }
120
121 fn transform_utf8(
122 &self,
123 array: &StringArray,
124 args: TransformArgs,
125 ) -> DFResult<(ArrayRef, Field)> {
126 Ok((Arc::new(array.clone()), args.field.clone()))
127 }
128
129 fn transform_binary(
130 &self,
131 array: &BinaryArray,
132 args: TransformArgs,
133 ) -> DFResult<(ArrayRef, Field)> {
134 Ok((Arc::new(array.clone()), args.field.clone()))
135 }
136
137 fn transform_timestamp_second(
138 &self,
139 array: &TimestampSecondArray,
140 args: TransformArgs,
141 ) -> DFResult<(ArrayRef, Field)> {
142 Ok((Arc::new(array.clone()), args.field.clone()))
143 }
144
145 fn transform_timestamp_millisecond(
146 &self,
147 array: &TimestampMillisecondArray,
148 args: TransformArgs,
149 ) -> DFResult<(ArrayRef, Field)> {
150 Ok((Arc::new(array.clone()), args.field.clone()))
151 }
152
153 fn transform_timestamp_microsecond(
154 &self,
155 array: &TimestampMicrosecondArray,
156 args: TransformArgs,
157 ) -> DFResult<(ArrayRef, Field)> {
158 Ok((Arc::new(array.clone()), args.field.clone()))
159 }
160
161 fn transform_timestamp_nanosecond(
162 &self,
163 array: &TimestampNanosecondArray,
164 args: TransformArgs,
165 ) -> DFResult<(ArrayRef, Field)> {
166 Ok((Arc::new(array.clone()), args.field.clone()))
167 }
168
169 fn transform_time64_nanosecond(
170 &self,
171 array: &Time64NanosecondArray,
172 args: TransformArgs,
173 ) -> DFResult<(ArrayRef, Field)> {
174 Ok((Arc::new(array.clone()), args.field.clone()))
175 }
176
177 fn transform_date32(
178 &self,
179 array: &Date32Array,
180 args: TransformArgs,
181 ) -> DFResult<(ArrayRef, Field)> {
182 Ok((Arc::new(array.clone()), args.field.clone()))
183 }
184
185 fn transform_list(
186 &self,
187 array: &ListArray,
188 args: TransformArgs,
189 ) -> DFResult<(ArrayRef, Field)> {
190 Ok((Arc::new(array.clone()), args.field.clone()))
191 }
192}
193
/// Record-batch stream adapter that runs every batch of `input` through `transform_batch`.
pub(crate) struct TransformStream {
    /// Upstream stream that is polled for batches.
    input: SendableRecordBatchStream,
    /// Column hooks applied to each batch.
    transform: Arc<dyn Transform>,
    /// Full table schema; the indexes in `projection` refer to it.
    table_schema: SchemaRef,
    /// Table-schema indexes of the columns present in each batch; `None` means all columns.
    projection: Option<Vec<usize>>,
    /// `table_schema` projected by `projection`; reported by `RecordBatchStream::schema`.
    projected_schema: SchemaRef,
    /// Optional remote schema from which each column's `remote_field` is taken.
    remote_schema: Option<RemoteSchemaRef>,
}
202
203impl TransformStream {
204 pub fn try_new(
205 input: SendableRecordBatchStream,
206 transform: Arc<dyn Transform>,
207 table_schema: SchemaRef,
208 projection: Option<Vec<usize>>,
209 remote_schema: Option<RemoteSchemaRef>,
210 ) -> DFResult<Self> {
211 let projected_schema = project_schema(&table_schema, projection.as_ref())?;
212 Ok(Self {
213 input,
214 transform,
215 table_schema,
216 projection,
217 projected_schema,
218 remote_schema,
219 })
220 }
221}
222
223impl Stream for TransformStream {
224 type Item = DFResult<RecordBatch>;
225 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
226 match self.input.poll_next_unpin(cx) {
227 Poll::Ready(Some(Ok(batch))) => {
228 match transform_batch(
229 batch,
230 self.transform.as_ref(),
231 &self.table_schema,
232 self.projection.as_ref(),
233 self.remote_schema.as_ref(),
234 ) {
235 Ok(result) => Poll::Ready(Some(Ok(result))),
236 Err(e) => Poll::Ready(Some(Err(e))),
237 }
238 }
239 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
240 Poll::Ready(None) => Poll::Ready(None),
241 Poll::Pending => Poll::Pending,
242 }
243 }
244}
245
246impl RecordBatchStream for TransformStream {
247 fn schema(&self) -> SchemaRef {
248 self.projected_schema.clone()
249 }
250}
251
252pub(crate) fn transform_batch(
253 batch: RecordBatch,
254 transform: &dyn Transform,
255 table_schema: &SchemaRef,
256 projection: Option<&Vec<usize>>,
257 remote_schema: Option<&RemoteSchemaRef>,
258) -> DFResult<RecordBatch> {
259 let mut new_arrays: Vec<ArrayRef> = Vec::with_capacity(batch.schema().fields.len());
260 let mut new_fields: Vec<Field> = Vec::with_capacity(batch.schema().fields.len());
261 let all_col_indexes = (0..table_schema.fields.len()).collect::<Vec<usize>>();
262 let projected_col_indexes = projection.unwrap_or(&all_col_indexes);
263
264 for (idx, col_index) in projected_col_indexes.iter().enumerate() {
265 let field = table_schema.field(*col_index);
266 let remote_field = remote_schema.map(|schema| &schema.fields[*col_index]);
267 let args = TransformArgs {
268 col_index: *col_index,
269 field,
270 remote_field,
271 };
272
273 let (new_array, new_field) = match &field.data_type() {
274 DataType::Boolean => {
275 let array = batch
276 .column(idx)
277 .as_any()
278 .downcast_ref::<BooleanArray>()
279 .expect("Failed to downcast to BooleanArray");
280 transform.transform_boolean(array, args)?
281 }
282 DataType::Int8 => {
283 let array = batch
284 .column(idx)
285 .as_any()
286 .downcast_ref::<Int8Array>()
287 .expect("Failed to downcast to Int8Array");
288 transform.transform_int8(array, args)?
289 }
290 DataType::Int16 => {
291 let array = batch
292 .column(idx)
293 .as_any()
294 .downcast_ref::<Int16Array>()
295 .expect("Failed to downcast to Int16Array");
296 transform.transform_int16(array, args)?
297 }
298 DataType::Int32 => {
299 let array = batch
300 .column(idx)
301 .as_any()
302 .downcast_ref::<Int32Array>()
303 .expect("Failed to downcast to Int32Array");
304 transform.transform_int32(array, args)?
305 }
306 DataType::Int64 => {
307 let array = batch
308 .column(idx)
309 .as_any()
310 .downcast_ref::<Int64Array>()
311 .expect("Failed to downcast to Int64Array");
312 transform.transform_int64(array, args)?
313 }
314 DataType::UInt8 => {
315 let array = batch
316 .column(idx)
317 .as_any()
318 .downcast_ref::<UInt8Array>()
319 .expect("Failed to downcast to UInt8Array");
320 transform.transform_uint8(array, args)?
321 }
322 DataType::UInt16 => {
323 let array = batch
324 .column(idx)
325 .as_any()
326 .downcast_ref::<UInt16Array>()
327 .expect("Failed to downcast to UInt16Array");
328 transform.transform_uint16(array, args)?
329 }
330 DataType::UInt32 => {
331 let array = batch
332 .column(idx)
333 .as_any()
334 .downcast_ref::<UInt32Array>()
335 .expect("Failed to downcast to UInt32Array");
336 transform.transform_uint32(array, args)?
337 }
338 DataType::UInt64 => {
339 let array = batch
340 .column(idx)
341 .as_any()
342 .downcast_ref::<UInt64Array>()
343 .expect("Failed to downcast to UInt64Array");
344 transform.transform_uint64(array, args)?
345 }
346 DataType::Float16 => {
347 let array = batch
348 .column(idx)
349 .as_any()
350 .downcast_ref::<Float16Array>()
351 .expect("Failed to downcast to Float16Array");
352 transform.transform_float16(array, args)?
353 }
354 DataType::Float32 => {
355 let array = batch
356 .column(idx)
357 .as_any()
358 .downcast_ref::<Float32Array>()
359 .expect("Failed to downcast to Float32Array");
360 transform.transform_float32(array, args)?
361 }
362 DataType::Float64 => {
363 let array = batch
364 .column(idx)
365 .as_any()
366 .downcast_ref::<Float64Array>()
367 .expect("Failed to downcast to Float64Array");
368 transform.transform_float64(array, args)?
369 }
370 DataType::Utf8 => {
371 let array = batch
372 .column(idx)
373 .as_any()
374 .downcast_ref::<StringArray>()
375 .expect("Failed to downcast to StringArray");
376 transform.transform_utf8(array, args)?
377 }
378 DataType::Binary => {
379 let array = batch
380 .column(idx)
381 .as_any()
382 .downcast_ref::<BinaryArray>()
383 .expect("Failed to downcast to BinaryArray");
384 transform.transform_binary(array, args)?
385 }
386 DataType::Date32 => {
387 let array = batch
388 .column(idx)
389 .as_any()
390 .downcast_ref::<Date32Array>()
391 .expect("Failed to downcast to Date32Array");
392 transform.transform_date32(array, args)?
393 }
394 DataType::Timestamp(TimeUnit::Second, _) => {
395 let array = batch
396 .column(idx)
397 .as_any()
398 .downcast_ref::<TimestampSecondArray>()
399 .expect("Failed to downcast to TimestampSecondArray");
400 transform.transform_timestamp_second(array, args)?
401 }
402 DataType::Timestamp(TimeUnit::Millisecond, _) => {
403 let array = batch
404 .column(idx)
405 .as_any()
406 .downcast_ref::<TimestampMillisecondArray>()
407 .expect("Failed to downcast to TimestampMillisecondArray");
408 transform.transform_timestamp_millisecond(array, args)?
409 }
410 DataType::Timestamp(TimeUnit::Microsecond, _) => {
411 let array = batch
412 .column(idx)
413 .as_any()
414 .downcast_ref::<TimestampMicrosecondArray>()
415 .expect("Failed to downcast to TimestampMicrosecondArray");
416 transform.transform_timestamp_microsecond(array, args)?
417 }
418 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
419 let array = batch
420 .column(idx)
421 .as_any()
422 .downcast_ref::<TimestampNanosecondArray>()
423 .expect("Failed to downcast to TimestampNanosecondArray");
424 transform.transform_timestamp_nanosecond(array, args)?
425 }
426 DataType::Time64(TimeUnit::Nanosecond) => {
427 let array = batch
428 .column(idx)
429 .as_any()
430 .downcast_ref::<Time64NanosecondArray>()
431 .expect("Failed to downcast to Time64NanosecondArray");
432 transform.transform_time64_nanosecond(array, args)?
433 }
434 DataType::List(_field) => {
435 let array = batch
436 .column(idx)
437 .as_any()
438 .downcast_ref::<ListArray>()
439 .expect("Failed to downcast to ListArray");
440 transform.transform_list(array, args)?
441 }
442 data_type => {
443 return Err(DataFusionError::NotImplemented(format!(
444 "Unsupported arrow type {data_type:?}",
445 )))
446 }
447 };
448 new_arrays.push(new_array);
449 new_fields.push(new_field);
450 }
451 let new_schema = Arc::new(Schema::new(new_fields));
452 Ok(RecordBatch::try_new(new_schema, new_arrays)?)
453}