// datafusion_federation/schema_cast/record_convert.rs
use datafusion::arrow::{
    array::{Array, RecordBatch},
    compute::cast,
    datatypes::{DataType, IntervalUnit, SchemaRef},
};
use std::sync::Arc;

use super::{
    intervals_cast::{
        cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth,
    },
    lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list},
    struct_cast::cast_string_to_struct,
};
15
/// Module-local result alias; the error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors produced while casting a `RecordBatch` to a target schema.
#[derive(Debug)]
pub enum Error {
    /// An underlying arrow cast kernel (or batch reconstruction) failed.
    UnableToConvertRecordBatch {
        source: datafusion::arrow::error::ArrowError,
    },

    /// The input batch's column count does not match the target schema's.
    UnexpectedNumberOfColumns {
        expected: usize,
        found: usize,
    },
}

impl std::error::Error for Error {}
31
32impl std::fmt::Display for Error {
33 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 match self {
35 Error::UnableToConvertRecordBatch { source } => {
36 write!(f, "Unable to convert record batch: {source}")
37 }
38 Error::UnexpectedNumberOfColumns { expected, found } => {
39 write!(
40 f,
41 "Unexpected number of columns. Expected: {expected}, Found: {found}",
42 )
43 }
44 }
45 }
46}
47
48#[allow(clippy::needless_pass_by_value)]
51pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result<RecordBatch> {
52 let actual_schema = record_batch.schema();
53
54 if actual_schema.fields().len() != expected_schema.fields().len() {
55 return Err(Error::UnexpectedNumberOfColumns {
56 expected: expected_schema.fields().len(),
57 found: actual_schema.fields().len(),
58 });
59 }
60
61 let cols = expected_schema
62 .fields()
63 .iter()
64 .enumerate()
65 .map(|(i, expected_field)| {
66 let record_batch_col = record_batch.column(i);
67
68 match (record_batch_col.data_type(), expected_field.data_type()) {
69 (DataType::Utf8, DataType::List(item_type)) => {
70 cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type)
71 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
72 }
73 (DataType::Utf8, DataType::LargeList(item_type)) => {
74 cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type)
75 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
76 }
77 (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => {
78 cast_string_to_fixed_size_list::<i32>(
79 &Arc::clone(record_batch_col),
80 item_type,
81 *value_length,
82 )
83 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
84 }
85 (DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>(
86 &Arc::clone(record_batch_col),
87 expected_field.clone(),
88 )
89 .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
90 (DataType::LargeUtf8, DataType::List(item_type)) => {
91 cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type)
92 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
93 }
94 (DataType::LargeUtf8, DataType::LargeList(item_type)) => {
95 cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type)
96 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
97 }
98 (DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => {
99 cast_string_to_fixed_size_list::<i64>(
100 &Arc::clone(record_batch_col),
101 item_type,
102 *value_length,
103 )
104 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
105 }
106 (DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>(
107 &Arc::clone(record_batch_col),
108 expected_field.clone(),
109 )
110 .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
111 (
112 DataType::Interval(IntervalUnit::MonthDayNano),
113 DataType::Interval(IntervalUnit::YearMonth),
114 ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col))
115 .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
116 (
117 DataType::Interval(IntervalUnit::MonthDayNano),
118 DataType::Interval(IntervalUnit::DayTime),
119 ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col))
120 .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
121 _ => cast(&Arc::clone(record_batch_col), expected_field.data_type())
122 .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
123 }
124 })
125 .collect::<Result<Vec<Arc<dyn Array>>>>()?;
126
127 RecordBatch::try_new(expected_schema, cols)
128 .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
129}
130
#[cfg(test)]
mod test {
    use super::*;
    use datafusion::arrow::array::LargeStringArray;
    use datafusion::arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema, TimeUnit},
    };
    use datafusion::assert_batches_eq;

    /// Source schema for the `Utf8` input batch.
    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Utf8, false),
        ]))
    }

    /// Target schema shared by both tests: widen `a` to `Int64`, widen `b`
    /// to `LargeUtf8`, and parse `c` into a microsecond timestamp.
    fn to_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        ]))
    }

    /// Input batch with timestamp strings at varying fractional precision.
    fn batch_input() -> RecordBatch {
        RecordBatch::try_new(
            schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(StringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_string_to_timestamp_conversion() {
        let result = try_cast_to(batch_input(), to_schema()).expect("converted");
        let expected = vec![
            "+---+-----+---------------------+",
            "| a | b | c |",
            "+---+-----+---------------------+",
            "| 1 | foo | 2024-01-13T03:18:09 |",
            "| 2 | bar | 2024-01-13T03:18:09 |",
            "| 3 | baz | 2024-01-13T03:18:09 |",
            "+---+-----+---------------------+",
        ];

        assert_batches_eq!(expected, &[result]);
    }

    /// Source schema for the `LargeUtf8` input batch.
    fn large_string_from_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::LargeUtf8, false),
        ]))
    }

    /// Target schema for the large-string test. Identical to [`to_schema`];
    /// previously a copy-pasted duplicate, now delegates to it.
    fn large_string_to_schema() -> SchemaRef {
        to_schema()
    }

    /// Same rows as [`batch_input`], but backed by `LargeStringArray`.
    fn large_string_batch_input() -> RecordBatch {
        RecordBatch::try_new(
            large_string_from_schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(LargeStringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(LargeStringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_large_string_to_timestamp_conversion() {
        let result =
            try_cast_to(large_string_batch_input(), large_string_to_schema()).expect("converted");
        let expected = vec![
            "+---+-----+---------------------+",
            "| a | b | c |",
            "+---+-----+---------------------+",
            "| 1 | foo | 2024-01-13T03:18:09 |",
            "| 2 | bar | 2024-01-13T03:18:09 |",
            "| 3 | baz | 2024-01-13T03:18:09 |",
            "+---+-----+---------------------+",
        ];
        assert_batches_eq!(expected, &[result]);
    }
}