datafusion_datasource_avro/avro_to_arrow/reader.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use super::arrow_array_reader::AvroArrowArrayReader;
19use arrow::datatypes::{Fields, SchemaRef};
20use arrow::error::Result as ArrowResult;
21use arrow::record_batch::RecordBatch;
22use datafusion_common::Result;
23use std::io::{Read, Seek};
24use std::sync::Arc;
25
/// Avro file reader builder
#[derive(Debug)]
pub struct ReaderBuilder {
    /// Optional schema for the Avro file
    ///
    /// If the schema is not supplied, the reader will try to read the schema.
    schema: Option<SchemaRef>,
    /// Batch size (number of records to load each time)
    ///
    /// The default batch size when using the `ReaderBuilder` is 1024 records
    batch_size: usize,
    /// Optional projection for which columns to load (list of column names).
    /// Columns are emitted in the order listed here; names not present in the
    /// schema are silently skipped.
    projection: Option<Vec<String>>,
}
40
41impl Default for ReaderBuilder {
42    fn default() -> Self {
43        Self {
44            schema: None,
45            batch_size: 1024,
46            projection: None,
47        }
48    }
49}
50
51impl ReaderBuilder {
52    /// Create a new builder for configuring Avro parsing options.
53    ///
54    /// To convert a builder into a reader, call `Reader::from_builder`
55    ///
56    /// # Example
57    ///
58    /// ```
59    /// use std::fs::File;
60    ///
61    /// use datafusion_datasource_avro::avro_to_arrow::{Reader, ReaderBuilder};
62    ///
63    /// fn example() -> Reader<'static, File> {
64    ///     let file = File::open("test/data/basic.avro").unwrap();
65    ///
66    ///     // create a builder, inferring the schema with the first 100 records
67    ///     let builder = ReaderBuilder::new()
68    ///       .read_schema()
69    ///       .with_batch_size(100);
70    ///
71    ///     let reader = builder
72    ///       .build::<File>(file)
73    ///       .unwrap();
74    ///
75    ///     reader
76    /// }
77    /// ```
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Set the Avro file's schema
83    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
84        self.schema = Some(schema);
85        self
86    }
87
88    /// Set the Avro reader to infer the schema of the file
89    pub fn read_schema(mut self) -> Self {
90        // remove any schema that is set
91        self.schema = None;
92        self
93    }
94
95    /// Set the batch size (number of records to load at one time)
96    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
97        self.batch_size = batch_size;
98        self
99    }
100
101    /// Set the reader's column projection
102    pub fn with_projection(mut self, projection: Vec<String>) -> Self {
103        self.projection = Some(projection);
104        self
105    }
106
107    /// Create a new `Reader` from the `ReaderBuilder`
108    pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, R>>
109    where
110        R: Read + Seek,
111    {
112        let mut source = source;
113
114        // check if schema should be inferred
115        let schema = match self.schema {
116            Some(schema) => schema,
117            None => Arc::new(super::read_avro_schema_from_reader(&mut source)?),
118        };
119        source.rewind()?;
120        Reader::try_new(source, schema, self.batch_size, self.projection)
121    }
122}
123
/// Avro file record reader
///
/// Yields the decoded records as Arrow `RecordBatch`es of at most
/// `batch_size` rows via its `Iterator` implementation.
pub struct Reader<'a, R: Read> {
    // Decodes Avro values from the underlying reader into Arrow arrays
    array_reader: AvroArrowArrayReader<'a, R>,
    // Schema of the emitted batches (the projected schema when a projection
    // was supplied to `try_new`)
    schema: SchemaRef,
    // Maximum number of records per emitted batch
    batch_size: usize,
}
130
131impl<R: Read> Reader<'_, R> {
132    /// Create a new Avro Reader from any value that implements the `Read` trait.
133    ///
134    /// If reading a `File`, you can customise the Reader, such as to enable schema
135    /// inference, use `ReaderBuilder`.
136    ///
137    /// If projection is provided, it uses a schema with only the fields in the projection, respecting their order.
138    /// Only the first level of projection is handled. No further projection currently occurs, but would be
139    /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`.
140    pub fn try_new(
141        reader: R,
142        schema: SchemaRef,
143        batch_size: usize,
144        projection: Option<Vec<String>>,
145    ) -> Result<Self> {
146        let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else(
147            || Arc::clone(&schema),
148            |proj| {
149                Arc::new(arrow::datatypes::Schema::new(
150                    proj.iter()
151                        .filter_map(|name| {
152                            schema.column_with_name(name).map(|(_, f)| f.clone())
153                        })
154                        .collect::<Fields>(),
155                ))
156            },
157        );
158
159        Ok(Self {
160            array_reader: AvroArrowArrayReader::try_new(
161                reader,
162                Arc::clone(&projected_schema),
163            )?,
164            schema: projected_schema,
165            batch_size,
166        })
167    }
168
169    /// Returns the schema of the reader, useful for getting the schema without reading
170    /// record batches
171    pub fn schema(&self) -> SchemaRef {
172        Arc::clone(&self.schema)
173    }
174}
175
176impl<R: Read> Iterator for Reader<'_, R> {
177    type Item = ArrowResult<RecordBatch>;
178
179    /// Returns the next batch of results (defined by `self.batch_size`), or `None` if there
180    /// are no more results.
181    fn next(&mut self) -> Option<Self::Item> {
182        self.array_reader.next_batch(self.batch_size)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    // The glob import covers all concrete array types used below
    // (Int32Array, BooleanArray, BinaryArray, Float32Array, ...).
    use arrow::array::*;
    use arrow::datatypes::{DataType, Field, TimeUnit};
    use std::fs::File;

    /// Open `name` from the arrow test-data avro directory with an inferred
    /// schema, a batch size of 64, and an optional column projection.
    fn build_reader(name: &str, projection: Option<Vec<String>>) -> Reader<File> {
        let testdata = datafusion_common::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64);
        if let Some(projection) = projection {
            builder = builder.with_projection(projection);
        }
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    /// Downcast the column at index `col.0` of `batch` to the concrete
    /// array type `T`, returning `None` if the type does not match.
    fn get_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        col: (usize, &Field),
    ) -> Option<&'a T> {
        batch.column(col.0).as_any().downcast_ref::<T>()
    }

    #[test]
    fn test_avro_basic() {
        let mut reader = build_reader("alltypes_dictionary.avro", None);
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let id = schema.column_with_name("id").unwrap();
        assert_eq!(0, id.0);
        assert_eq!(&DataType::Int32, id.1.data_type());
        let col = get_col::<Int32Array>(&batch, id).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let bool_col = schema.column_with_name("bool_col").unwrap();
        assert_eq!(1, bool_col.0);
        assert_eq!(&DataType::Boolean, bool_col.1.data_type());
        let col = get_col::<BooleanArray>(&batch, bool_col).unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
        let tinyint_col = schema.column_with_name("tinyint_col").unwrap();
        assert_eq!(2, tinyint_col.0);
        assert_eq!(&DataType::Int32, tinyint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, tinyint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let smallint_col = schema.column_with_name("smallint_col").unwrap();
        assert_eq!(3, smallint_col.0);
        assert_eq!(&DataType::Int32, smallint_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, smallint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let int_col = schema.column_with_name("int_col").unwrap();
        assert_eq!(4, int_col.0);
        assert_eq!(&DataType::Int32, int_col.1.data_type());
        let col = get_col::<Int32Array>(&batch, int_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));
        let bigint_col = schema.column_with_name("bigint_col").unwrap();
        assert_eq!(5, bigint_col.0);
        assert_eq!(&DataType::Int64, bigint_col.1.data_type());
        let col = get_col::<Int64Array>(&batch, bigint_col).unwrap();
        assert_eq!(0, col.value(0));
        assert_eq!(10, col.value(1));
        let float_col = schema.column_with_name("float_col").unwrap();
        assert_eq!(6, float_col.0);
        assert_eq!(&DataType::Float32, float_col.1.data_type());
        let col = get_col::<Float32Array>(&batch, float_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(1.1, col.value(1));
        let double_col = schema.column_with_name("double_col").unwrap();
        assert_eq!(7, double_col.0);
        assert_eq!(&DataType::Float64, double_col.1.data_type());
        let col = get_col::<Float64Array>(&batch, double_col).unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));
        let date_string_col = schema.column_with_name("date_string_col").unwrap();
        assert_eq!(8, date_string_col.0);
        assert_eq!(&DataType::Binary, date_string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, date_string_col).unwrap();
        assert_eq!("01/01/09".as_bytes(), col.value(0));
        assert_eq!("01/01/09".as_bytes(), col.value(1));
        let string_col = schema.column_with_name("string_col").unwrap();
        assert_eq!(9, string_col.0);
        assert_eq!(&DataType::Binary, string_col.1.data_type());
        let col = get_col::<BinaryArray>(&batch, string_col).unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));
        let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
        assert_eq!(10, timestamp_col.0);
        assert_eq!(
            &DataType::Timestamp(TimeUnit::Microsecond, None),
            timestamp_col.1.data_type()
        );
        let col = get_col::<TimestampMicrosecondArray>(&batch, timestamp_col).unwrap();
        assert_eq!(1230768000000000, col.value(0));
        assert_eq!(1230768060000000, col.value(1));
    }

    #[test]
    fn test_avro_with_projection() {
        // Test projection to filter and reorder columns
        let projection = Some(vec![
            "string_col".to_string(),
            "double_col".to_string(),
            "bool_col".to_string(),
        ]);
        let mut reader = build_reader("alltypes_dictionary.avro", projection);
        let batch = reader.next().unwrap().unwrap();

        // Only 3 columns should be present (not all 11)
        assert_eq!(3, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        // Verify columns are in the order specified in projection
        // First column should be string_col (was at index 9 in original)
        assert_eq!("string_col", schema.field(0).name());
        assert_eq!(&DataType::Binary, schema.field(0).data_type());
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        // Second column should be double_col (was at index 7 in original)
        assert_eq!("double_col", schema.field(1).name());
        assert_eq!(&DataType::Float64, schema.field(1).data_type());
        let col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        // Third column should be bool_col (was at index 1 in original)
        assert_eq!("bool_col", schema.field(2).name());
        assert_eq!(&DataType::Boolean, schema.field(2).data_type());
        let col = batch
            .column(2)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.value(0));
        assert!(!col.value(1));
    }
}