datafusion_federation/schema_cast/record_convert.rs

1use datafusion::arrow::{
2    array::{Array, RecordBatch, RecordBatchOptions},
3    compute::cast,
4    datatypes::{DataType, IntervalUnit, SchemaRef},
5};
6use std::sync::Arc;
7
8use super::{
9    intervals_cast::{
10        cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth,
11    },
12    lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list},
13    struct_cast::cast_string_to_struct,
14};
15
16pub type Result<T, E = Error> = std::result::Result<T, E>;
17
18#[derive(Debug)]
19pub enum Error {
20    UnableToConvertRecordBatch {
21        source: datafusion::arrow::error::ArrowError,
22    },
23
24    UnexpectedNumberOfColumns {
25        expected: usize,
26        found: usize,
27    },
28}
29
30impl std::error::Error for Error {}
31
32impl std::fmt::Display for Error {
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        match self {
35            Error::UnableToConvertRecordBatch { source } => {
36                write!(f, "Unable to convert record batch: {source}")
37            }
38            Error::UnexpectedNumberOfColumns { expected, found } => {
39                write!(
40                    f,
41                    "Unexpected number of columns. Expected: {expected}, Found: {found}",
42                )
43            }
44        }
45    }
46}
47
48/// Cast a given record batch into a new record batch with the given schema.
49/// It assumes the record batch columns are correctly ordered.
50#[allow(clippy::needless_pass_by_value)]
51pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result<RecordBatch> {
52    let actual_schema = record_batch.schema();
53
54    if actual_schema.fields().len() != expected_schema.fields().len() {
55        return Err(Error::UnexpectedNumberOfColumns {
56            expected: expected_schema.fields().len(),
57            found: actual_schema.fields().len(),
58        });
59    }
60
61    let cols = expected_schema
62        .fields()
63        .iter()
64        .enumerate()
65        .map(|(i, expected_field)| {
66            let record_batch_col = record_batch.column(i);
67
68            match (record_batch_col.data_type(), expected_field.data_type()) {
69                (DataType::Utf8, DataType::List(item_type)) => {
70                    cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type)
71                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
72                }
73                (DataType::Utf8, DataType::LargeList(item_type)) => {
74                    cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type)
75                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
76                }
77                (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => {
78                    cast_string_to_fixed_size_list::<i32>(
79                        &Arc::clone(record_batch_col),
80                        item_type,
81                        *value_length,
82                    )
83                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
84                }
85                (DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>(
86                    &Arc::clone(record_batch_col),
87                    expected_field.clone(),
88                )
89                .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
90                (DataType::LargeUtf8, DataType::List(item_type)) => {
91                    cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type)
92                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
93                }
94                (DataType::LargeUtf8, DataType::LargeList(item_type)) => {
95                    cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type)
96                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
97                }
98                (DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => {
99                    cast_string_to_fixed_size_list::<i64>(
100                        &Arc::clone(record_batch_col),
101                        item_type,
102                        *value_length,
103                    )
104                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
105                }
106                (DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>(
107                    &Arc::clone(record_batch_col),
108                    expected_field.clone(),
109                )
110                .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
111                (
112                    DataType::Interval(IntervalUnit::MonthDayNano),
113                    DataType::Interval(IntervalUnit::YearMonth),
114                ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col))
115                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
116                (
117                    DataType::Interval(IntervalUnit::MonthDayNano),
118                    DataType::Interval(IntervalUnit::DayTime),
119                ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col))
120                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
121                _ => cast(&Arc::clone(record_batch_col), expected_field.data_type())
122                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
123            }
124        })
125        .collect::<Result<Vec<Arc<dyn Array>>>>()?;
126
127    let options = RecordBatchOptions::new().with_row_count(Some(record_batch.num_rows()));
128    RecordBatch::try_new_with_options(expected_schema, cols, &options)
129        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
130}
131
#[cfg(test)]
mod test {
    use super::*;
    use datafusion::arrow::array::{LargeStringArray, RecordBatchOptions};
    use datafusion::arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema, TimeUnit},
    };
    use datafusion::assert_batches_eq;

    /// Pretty-printed table expected after casting either fixture batch to
    /// [`to_schema`]. The timestamp strings in column `c` differ only in their
    /// fractional-second digits, so they all render the same way.
    const EXPECTED_TIMESTAMP_TABLE: [&str; 7] = [
        "+---+-----+---------------------+",
        "| a | b   | c                   |",
        "+---+-----+---------------------+",
        "| 1 | foo | 2024-01-13T03:18:09 |",
        "| 2 | bar | 2024-01-13T03:18:09 |",
        "| 3 | baz | 2024-01-13T03:18:09 |",
        "+---+-----+---------------------+",
    ];

    /// Source schema whose string columns are `Utf8`.
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Utf8, false),
        ]))
    }

    /// Target schema shared by the `Utf8` and `LargeUtf8` conversion tests.
    fn to_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        ]))
    }

    fn batch_input() -> RecordBatch {
        RecordBatch::try_new(
            schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(StringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_string_to_timestamp_conversion() {
        let result = try_cast_to(batch_input(), to_schema()).expect("converted");
        assert_batches_eq!(EXPECTED_TIMESTAMP_TABLE, &[result]);
    }

    /// Source schema whose string columns are `LargeUtf8`.
    fn large_string_from_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::LargeUtf8, false),
        ]))
    }

    fn large_string_batch_input() -> RecordBatch {
        RecordBatch::try_new(
            large_string_from_schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(LargeStringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(LargeStringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_large_string_to_timestamp_conversion() {
        let result = try_cast_to(large_string_batch_input(), to_schema()).expect("converted");
        assert_batches_eq!(EXPECTED_TIMESTAMP_TABLE, &[result]);
    }

    #[test]
    fn test_convert_empty_batch() {
        let schema = SchemaRef::new(Schema::empty());
        let options = RecordBatchOptions::new().with_row_count(Some(10));
        let batch = RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options)
            .expect("failed to create empty batch");
        let result = try_cast_to(batch, schema).expect("converted");
        // With no columns the row count cannot be inferred, so it must be carried over.
        assert_eq!(result.num_rows(), 10);
        let expected = ["++", "++", "++"];
        assert_batches_eq!(expected, &[result]);
    }

    #[test]
    fn test_column_count_mismatch() {
        let one_column = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let err = try_cast_to(batch_input(), one_column)
            .expect_err("casting 3 columns to a 1-field schema should fail");
        assert!(
            matches!(
                err,
                Error::UnexpectedNumberOfColumns {
                    expected: 1,
                    found: 3
                }
            ),
            "unexpected error: {err}"
        );
    }
}