Skip to main content

kyu_copy/
parquet_reader.rs

1//! Parquet file reader.
2
3use std::fs::File;
4
5use arrow::array::{
6    Array, AsArray, BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array,
7    Int64Array,
8};
9use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
10
11use kyu_common::{KyuError, KyuResult};
12use kyu_types::{LogicalType, TypedValue};
13
14use crate::DataReader;
15
16/// Reads rows from a Parquet file, converting Arrow arrays to TypedValue rows.
17pub struct ParquetReader {
18    schema: Vec<LogicalType>,
19    /// Pre-loaded rows from all row groups.
20    rows: std::vec::IntoIter<Vec<TypedValue>>,
21}
22
23impl ParquetReader {
24    /// Open a Parquet file and read all row groups into memory.
25    pub fn open(path: &str, schema: &[LogicalType]) -> KyuResult<Self> {
26        let file =
27            File::open(path).map_err(|e| KyuError::Copy(format!("cannot open '{path}': {e}")))?;
28
29        let builder = ParquetRecordBatchReaderBuilder::try_new(file)
30            .map_err(|e| KyuError::Copy(format!("invalid Parquet file '{path}': {e}")))?;
31
32        let reader = builder
33            .build()
34            .map_err(|e| KyuError::Copy(format!("cannot read Parquet '{path}': {e}")))?;
35
36        let mut all_rows = Vec::new();
37
38        for batch_result in reader {
39            let batch =
40                batch_result.map_err(|e| KyuError::Copy(format!("Parquet batch error: {e}")))?;
41
42            let num_rows = batch.num_rows();
43            let num_cols = schema.len().min(batch.num_columns());
44
45            for row_idx in 0..num_rows {
46                let mut row = Vec::with_capacity(schema.len());
47                for (col_idx, col_type) in schema.iter().enumerate().take(num_cols) {
48                    let col = batch.column(col_idx);
49                    let value = extract_value(col.as_ref(), row_idx, col_type)?;
50                    row.push(value);
51                }
52                // Pad with Null if Parquet has fewer columns than schema.
53                for _ in num_cols..schema.len() {
54                    row.push(TypedValue::Null);
55                }
56                all_rows.push(row);
57            }
58        }
59
60        Ok(Self {
61            schema: schema.to_vec(),
62            rows: all_rows.into_iter(),
63        })
64    }
65}
66
67impl DataReader for ParquetReader {
68    fn schema(&self) -> &[LogicalType] {
69        &self.schema
70    }
71}
72
73impl Iterator for ParquetReader {
74    type Item = KyuResult<Vec<TypedValue>>;
75
76    fn next(&mut self) -> Option<Self::Item> {
77        self.rows.next().map(Ok)
78    }
79}
80
81/// Extract a TypedValue from an Arrow array at the given row index.
82fn extract_value(
83    array: &dyn Array,
84    row: usize,
85    target_type: &LogicalType,
86) -> KyuResult<TypedValue> {
87    if array.is_null(row) {
88        return Ok(TypedValue::Null);
89    }
90
91    match target_type {
92        LogicalType::Int8 => {
93            let arr = array
94                .as_any()
95                .downcast_ref::<Int8Array>()
96                .ok_or_else(|| KyuError::Copy("expected Int8 column in Parquet".into()))?;
97            Ok(TypedValue::Int8(arr.value(row)))
98        }
99        LogicalType::Int16 => {
100            let arr = array
101                .as_any()
102                .downcast_ref::<Int16Array>()
103                .ok_or_else(|| KyuError::Copy("expected Int16 column in Parquet".into()))?;
104            Ok(TypedValue::Int16(arr.value(row)))
105        }
106        LogicalType::Int32 => {
107            let arr = array
108                .as_any()
109                .downcast_ref::<Int32Array>()
110                .ok_or_else(|| KyuError::Copy("expected Int32 column in Parquet".into()))?;
111            Ok(TypedValue::Int32(arr.value(row)))
112        }
113        LogicalType::Int64 | LogicalType::Serial => {
114            let arr = array
115                .as_any()
116                .downcast_ref::<Int64Array>()
117                .ok_or_else(|| KyuError::Copy("expected Int64 column in Parquet".into()))?;
118            Ok(TypedValue::Int64(arr.value(row)))
119        }
120        LogicalType::Float => {
121            let arr = array
122                .as_any()
123                .downcast_ref::<Float32Array>()
124                .ok_or_else(|| KyuError::Copy("expected Float32 column in Parquet".into()))?;
125            Ok(TypedValue::Float(arr.value(row)))
126        }
127        LogicalType::Double => {
128            let arr = array
129                .as_any()
130                .downcast_ref::<Float64Array>()
131                .ok_or_else(|| KyuError::Copy("expected Float64 column in Parquet".into()))?;
132            Ok(TypedValue::Double(arr.value(row)))
133        }
134        LogicalType::Bool => {
135            let arr = array
136                .as_any()
137                .downcast_ref::<BooleanArray>()
138                .ok_or_else(|| KyuError::Copy("expected Boolean column in Parquet".into()))?;
139            Ok(TypedValue::Bool(arr.value(row)))
140        }
141        LogicalType::String => {
142            let arr = array.as_string::<i32>();
143            Ok(TypedValue::String(smol_str::SmolStr::new(arr.value(row))))
144        }
145        _ => Err(KyuError::Copy(format!(
146            "unsupported type {} for Parquet import",
147            target_type.type_name()
148        ))),
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;

    /// Scratch directory that is unique per process and per test, and is removed on
    /// drop so a failing assertion does not leave files behind.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates the directory; `tag` keeps tests in the same process apart.
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "kyu_parquet_reader_test_{}_{tag}",
                std::process::id()
            ));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Writes a three-row Parquet file with columns `id: Int64` and `name: Utf8`
    /// into `dir` and returns its path as a string.
    fn write_test_parquet(dir: &Path, name: &str) -> String {
        let path = dir.join(name);
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids = Int64Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec!["Alice", "Bob", "Charlie"]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(names)])
            .unwrap();

        let file = File::create(&path).unwrap();
        let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        path.to_str().unwrap().to_string()
    }

    #[test]
    fn read_parquet_basic() {
        let scratch = ScratchDir::new("basic");
        let path = write_test_parquet(scratch.path(), "test.parquet");

        let schema = vec![LogicalType::Int64, LogicalType::String];
        let reader = ParquetReader::open(&path, &schema).unwrap();
        let rows: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0][0], TypedValue::Int64(1));
        assert_eq!(
            rows[0][1],
            TypedValue::String(smol_str::SmolStr::new("Alice"))
        );
        assert_eq!(rows[2][0], TypedValue::Int64(3));
        assert_eq!(
            rows[2][1],
            TypedValue::String(smol_str::SmolStr::new("Charlie"))
        );
    }

    #[test]
    fn read_parquet_schema_width_mismatch() {
        let scratch = ScratchDir::new("width");
        let path = write_test_parquet(scratch.path(), "test.parquet");

        // Schema wider than the file: the missing trailing column is padded with Null.
        let wide = vec![LogicalType::Int64, LogicalType::String, LogicalType::Int64];
        let rows: Vec<_> = ParquetReader::open(&path, &wide)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(rows.len(), 3);
        for row in &rows {
            assert_eq!(row.len(), 3);
            assert_eq!(row[2], TypedValue::Null);
        }

        // Schema narrower than the file: extra file columns are ignored.
        let narrow = vec![LogicalType::Int64];
        let rows: Vec<_> = ParquetReader::open(&path, &narrow)
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[1].len(), 1);
        assert_eq!(rows[1][0], TypedValue::Int64(2));
    }

    #[test]
    fn open_missing_file_is_error() {
        let scratch = ScratchDir::new("missing");
        let path = scratch.path().join("does_not_exist.parquet");

        let result = ParquetReader::open(path.to_str().unwrap(), &[LogicalType::Int64]);
        assert!(result.is_err());
    }
}