datafusion_federation/schema_cast/record_convert.rs

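//! Conversion of `RecordBatch`es to a target schema, covering casts that the
//! generic Arrow `cast` kernel is not used for here: string-encoded lists,
//! large lists, fixed-size lists, and structs, plus interval unit changes.
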
use datafusion::arrow::{
    array::{Array, RecordBatch},
    compute::cast,
    datatypes::{DataType, IntervalUnit, SchemaRef},
};
use std::sync::Arc;

use super::{
    intervals_cast::{
        cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth,
    },
    lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list},
    struct_cast::cast_string_to_struct,
};

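/// Module-level result type, with [`Error`] as the default error.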
pub type Result<T, E = Error> = std::result::Result<T, E>;

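/// Errors that can arise while casting a record batch to a new schema.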
#[derive(Debug)]
pub enum Error {
    UnableToConvertRecordBatch {
        source: datafusion::arrow::error::ArrowError,
    },

    UnexpectedNumberOfColumns {
        expected: usize,
        found: usize,
    },
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Error::UnableToConvertRecordBatch { source } => {
                write!(f, "Unable to convert record batch: {}", source)
            }
            Error::UnexpectedNumberOfColumns { expected, found } => {
                write!(
                    f,
                    "Unexpected number of columns. Expected: {}, Found: {}",
                    expected, found
                )
            }
        }
    }
}

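/// Attempt to cast `record_batch` so that it conforms to `expected_schema`.
///
/// Columns are matched to the expected fields by position rather than by name,
/// so the batch's columns must already be ordered like the schema's fields;
/// a mismatched column count yields [`Error::UnexpectedNumberOfColumns`].
///
/// A minimal usage sketch (not compiled as a doctest; `batch` and
/// `expected_schema` are assumed to be in scope):
///
/// ```ignore
/// let converted = try_cast_to(batch, Arc::clone(&expected_schema))?;
/// assert_eq!(converted.schema(), expected_schema);
/// ```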
#[allow(clippy::needless_pass_by_value)]
pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result<RecordBatch> {
    let actual_schema = record_batch.schema();

    if actual_schema.fields().len() != expected_schema.fields().len() {
        return Err(Error::UnexpectedNumberOfColumns {
            expected: expected_schema.fields().len(),
            found: actual_schema.fields().len(),
        });
    }

    let cols = expected_schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, expected_field)| {
            let record_batch_col = record_batch.column(i);

            match (record_batch_col.data_type(), expected_field.data_type()) {
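                // String columns destined for nested types are parsed by the
                // helpers from `lists_cast` / `struct_cast` rather than sent
                // through the generic `cast` fallback below.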
                (DataType::Utf8, DataType::List(item_type)) => {
                    cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type)
                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::Utf8, DataType::LargeList(item_type)) => {
                    cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type)
                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => {
                    cast_string_to_fixed_size_list::<i32>(
                        &Arc::clone(record_batch_col),
                        item_type,
                        *value_length,
                    )
                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>(
                    &Arc::clone(record_batch_col),
                    expected_field.clone(),
                )
                .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
                (DataType::LargeUtf8, DataType::List(item_type)) => {
                    cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type)
                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::LargeUtf8, DataType::LargeList(item_type)) => {
                    cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type)
                        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => {
                    cast_string_to_fixed_size_list::<i64>(
                        &Arc::clone(record_batch_col),
                        item_type,
                        *value_length,
                    )
                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
                }
                (DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>(
                    &Arc::clone(record_batch_col),
                    expected_field.clone(),
                )
                .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
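                // MonthDayNano intervals are narrowed to YearMonth or DayTime
                // by the `intervals_cast` helpers instead of the `cast` kernel.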
                (
                    DataType::Interval(IntervalUnit::MonthDayNano),
                    DataType::Interval(IntervalUnit::YearMonth),
                ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col))
                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
                (
                    DataType::Interval(IntervalUnit::MonthDayNano),
                    DataType::Interval(IntervalUnit::DayTime),
                ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col))
                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
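                // Every other type pair is delegated to Arrow's `cast` kernel.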
                _ => cast(&Arc::clone(record_batch_col), expected_field.data_type())
                    .map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
            }
        })
        .collect::<Result<Vec<Arc<dyn Array>>>>()?;

    RecordBatch::try_new(expected_schema, cols)
        .map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}

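// The tests below exercise the fallback `cast` arm: Int32 -> Int64,
// Utf8/LargeUtf8 -> LargeUtf8, and string -> Timestamp conversions.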
#[cfg(test)]
mod test {
    use super::*;
    use datafusion::arrow::array::LargeStringArray;
    use datafusion::arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema, TimeUnit},
    };
    use datafusion::assert_batches_eq;

    fn schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
            Field::new("c", DataType::Utf8, false),
        ]))
    }

    fn to_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        ]))
    }

    fn batch_input() -> RecordBatch {
        RecordBatch::try_new(
            schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(StringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_string_to_timestamp_conversion() {
        let result = try_cast_to(batch_input(), to_schema()).expect("converted");
        let expected = vec![
            "+---+-----+---------------------+",
            "| a | b   | c                   |",
            "+---+-----+---------------------+",
            "| 1 | foo | 2024-01-13T03:18:09 |",
            "| 2 | bar | 2024-01-13T03:18:09 |",
            "| 3 | baz | 2024-01-13T03:18:09 |",
            "+---+-----+---------------------+",
        ];

        assert_batches_eq!(expected, &[result]);
    }

    fn large_string_from_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::LargeUtf8, false),
        ]))
    }

    fn large_string_to_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::LargeUtf8, false),
            Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        ]))
    }

    fn large_string_batch_input() -> RecordBatch {
        RecordBatch::try_new(
            large_string_from_schema(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(LargeStringArray::from(vec!["foo", "bar", "baz"])),
                Arc::new(LargeStringArray::from(vec![
                    "2024-01-13 03:18:09.000000",
                    "2024-01-13 03:18:09",
                    "2024-01-13 03:18:09.000",
                ])),
            ],
        )
        .expect("record batch should not panic")
    }

    #[test]
    fn test_large_string_to_timestamp_conversion() {
        let result =
            try_cast_to(large_string_batch_input(), large_string_to_schema()).expect("converted");
        let expected = vec![
            "+---+-----+---------------------+",
            "| a | b   | c                   |",
            "+---+-----+---------------------+",
            "| 1 | foo | 2024-01-13T03:18:09 |",
            "| 2 | bar | 2024-01-13T03:18:09 |",
            "| 3 | baz | 2024-01-13T03:18:09 |",
            "+---+-----+---------------------+",
        ];
        assert_batches_eq!(expected, &[result]);
    }
}