datafusion_datasource_avro/avro_to_arrow/reader.rs

use super::arrow_array_reader::AvroArrowArrayReader;
use arrow::datatypes::{Fields, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use std::io::{Read, Seek};
use std::sync::Arc;

/// Avro file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Optional explicit schema; when `None`, the schema is read from the Avro file
    schema: Option<SchemaRef>,
    /// Number of records to read per batch
    batch_size: usize,
    /// Optional projection: names of the columns to read
    projection: Option<Vec<String>>,
}

impl Default for ReaderBuilder {
    fn default() -> Self {
        Self {
            schema: None,
            batch_size: 1024,
            projection: None,
        }
    }
}

impl ReaderBuilder {
    /// Create a new `ReaderBuilder` with default settings (no schema, batch size
    /// 1024, no projection)
    pub fn new() -> Self {
        Self::default()
    }

    /// Set an explicit schema for the Avro file
    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
        self.schema = Some(schema);
        self
    }

    /// Clear any explicit schema so it is read from the Avro file when building
    pub fn read_schema(mut self) -> Self {
        self.schema = None;
        self
    }

    /// Set the number of records to read per batch
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set the columns to read (by name)
    pub fn with_projection(mut self, projection: Vec<String>) -> Self {
        self.projection = Some(projection);
        self
    }

    /// Build a `Reader` from this builder and a seekable source
    pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, R>>
    where
        R: Read + Seek,
    {
        let mut source = source;

        // Use the explicit schema if one was provided, otherwise read it from the file
        let schema = match self.schema {
            Some(schema) => schema,
            None => Arc::new(super::read_avro_schema_from_reader(&mut source)?),
        };
        // Rewind so the reader starts from the beginning of the file
        source.rewind()?;
        Reader::try_new(source, schema, self.batch_size, self.projection)
    }
}
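
// Illustrative usage sketch (not part of the original file): building a reader
// with `ReaderBuilder`. `example.avro` is a hypothetical path; the schema is
// inferred from the file because `read_schema()` clears any explicit schema.
//
//     use std::fs::File;
//
//     let file = File::open("example.avro")?;
//     let mut reader = ReaderBuilder::new()
//         .read_schema()
//         .with_batch_size(1024)
//         .with_projection(vec!["id".to_string()])
//         .build(file)?;
//     let first_batch = reader.next(); // Option<ArrowResult<RecordBatch>>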

/// Avro file record batch reader
pub struct Reader<'a, R: Read> {
    array_reader: AvroArrowArrayReader<'a, R>,
    schema: SchemaRef,
    batch_size: usize,
}

impl<R: Read> Reader<'_, R> {
    /// Create a new Avro `Reader` from any value that implements the `Read` trait.
    ///
    /// If `projection` is `Some` and non-empty, only the named columns are read;
    /// names that do not exist in `schema` are ignored.
    pub fn try_new(
        reader: R,
        schema: SchemaRef,
        batch_size: usize,
        projection: Option<Vec<String>>,
    ) -> Result<Self> {
        let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else(
            || Arc::clone(&schema),
            |proj| {
                Arc::new(arrow::datatypes::Schema::new(
                    proj.iter()
                        .filter_map(|name| {
                            schema.column_with_name(name).map(|(_, f)| f.clone())
                        })
                        .collect::<Fields>(),
                ))
            },
        );

        Ok(Self {
            array_reader: AvroArrowArrayReader::try_new(
                reader,
                Arc::clone(&projected_schema),
            )?,
            schema: projected_schema,
            batch_size,
        })
    }

    /// Returns the schema of the reader, useful for getting the schema without
    /// reading record batches
    pub fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
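
// Illustrative sketch (assumption, not original code): constructing a `Reader`
// directly with an explicit schema and projection. `file` and `schema` are
// hypothetical values. Projection names missing from the schema are silently
// dropped when the projected schema is built.
//
//     let reader = Reader::try_new(
//         file,
//         Arc::new(schema),
//         1024,
//         Some(vec!["id".to_string()]),
//     )?;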

impl<R: Read> Iterator for Reader<'_, R> {
    type Item = ArrowResult<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        // Read the next batch of at most `batch_size` records
        self.array_reader.next_batch(self.batch_size)
    }
}
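
// Illustrative sketch (assumption, not original code): because `Reader`
// implements `Iterator`, batches can be drained with a plain `for` loop.
// Each item is an `ArrowResult<RecordBatch>` with at most `batch_size` rows.
//
//     let mut total_rows = 0;
//     for batch in reader {
//         total_rows += batch?.num_rows();
//     }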

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use arrow::datatypes::{DataType, Field, TimeUnit};
    use std::fs::File;

    // Builds a `Reader` for a file under `arrow_test_data()/avro`, reading the
    // schema from the file, with an optional projection and a batch size of 64
    fn build_reader(name: &str, projection: Option<Vec<String>>) -> Reader<File> {
        let testdata = datafusion_common::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64);
        if let Some(projection) = projection {
            builder = builder.with_projection(projection);
        }
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    // Downcasts the column at index `col.0` to the concrete array type `T`
    fn get_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        col: (usize, &Field),
    ) -> Option<&'a T> {
        batch.column(col.0).as_any().downcast_ref::<T>()
    }

    #[test]
    fn test_avro_basic() {
        let mut reader = build_reader("alltypes_dictionary.avro", None);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let id = schema.column_with_name("id").unwrap();
        assert_eq!(0, id.0);
        assert_eq!(&DataType::Int32, id.1.data_type());
        let col = get_col::<Int32Array>(&batch, id).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let bool_col = schema.column_with_name("bool_col").unwrap();
        assert_eq!(1, bool_col.0);
        assert_eq!(&DataType::Boolean, bool_col.1.data_type());
        let col = get_col::<BooleanArray>(&batch, bool_col).unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
        let tinyint_col = schema.column_with_name("tinyint_col").unwrap();
        assert_eq!(2, tinyint_col.0);
        assert_eq!(&DataType::Int32, tinyint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, tinyint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let smallint_col = schema.column_with_name("smallint_col").unwrap();
        assert_eq!(3, smallint_col.0);
        assert_eq!(&DataType::Int32, smallint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, smallint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let int_col = schema.column_with_name("int_col").unwrap();
        assert_eq!(4, int_col.0);
        assert_eq!(&DataType::Int32, int_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, int_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let bigint_col = schema.column_with_name("bigint_col").unwrap();
        assert_eq!(5, bigint_col.0);
        assert_eq!(&DataType::Int64, bigint_col.1.data_type());
        let col = get_col::<Int64Array>(&batch, bigint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(10, col.value(1));
        let float_col = schema.column_with_name("float_col").unwrap();
        assert_eq!(6, float_col.0);
        assert_eq!(&DataType::Float32, float_col.1.data_type());
        let col = get_col::<Float32Array>(&batch, float_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(1.1, col.value(1));
        let double_col = schema.column_with_name("double_col").unwrap();
        assert_eq!(7, double_col.0);
        assert_eq!(&DataType::Float64, double_col.1.data_type());
        let col = get_col::<Float64Array>(&batch, double_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));
        let date_string_col = schema.column_with_name("date_string_col").unwrap();
        assert_eq!(8, date_string_col.0);
        assert_eq!(&DataType::Binary, date_string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, date_string_col).unwrap();
        assert_eq!("01/01/09".as_bytes(), col.value(0));
        assert_eq!("01/01/09".as_bytes(), col.value(1));
        let string_col = schema.column_with_name("string_col").unwrap();
        assert_eq!(9, string_col.0);
        assert_eq!(&DataType::Binary, string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, string_col).unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));
        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
        assert_eq!(10, timestamp_col.0);
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            timestamp_col.1.data_type()
        );
        let col = get_col::<TimestampMicrosecondArray>(&batch, timestamp_col).unwrap();
        assert_eq!(1230768000000000, col.value(0));
        assert_eq!(1230768060000000, col.value(1));
    }

    #[test]
    fn test_avro_with_projection() {
        let projection = Some(vec![
            "string_col".to_string(),
            "double_col".to_string(),
            "bool_col".to_string(),
        ]);
        let mut reader = build_reader("alltypes_dictionary.avro", projection);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(3, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        assert_eq!("string_col", schema.field(0).name());
        assert_eq!(&DataType::Binary, schema.field(0).data_type());
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        assert_eq!("double_col", schema.field(1).name());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        let col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        assert_eq!("bool_col", schema.field(2).name());
        assert_eq!(&DataType::Boolean, schema.field(2).data_type());
        let col = batch
            .column(2)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
    }
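
    // Illustrative sketch, not part of the original test suite: drains the
    // reader through its `Iterator` impl and checks that the total row count
    // matches the 2 rows asserted by the tests above.
    #[test]
    fn test_avro_iterator_sketch() {
        let reader = build_reader("alltypes_dictionary.avro", None);
        let total_rows: usize = reader.map(|batch| batch.unwrap().num_rows()).sum();
        assert_eq!(2, total_rows);
    }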
}