alopex_sql/executor/bulk/parquet.rs

use std::fs::File;

use arrow_array::{
    Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
    LargeBinaryArray, StringArray, TimestampMicrosecondArray,
};
use arrow_schema::DataType as ArrowDataType;
use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};

use crate::catalog::TableMetadata;
use crate::executor::{ExecutorError, Result};
use crate::planner::types::ResolvedType;
use crate::storage::SqlValue;

use super::{BulkReader, CopyField, CopySchema};

/// Parquet reader (schema extraction and data loading via Arrow).
pub struct ParquetReader {
    schema: CopySchema,
    target_types: Vec<ResolvedType>,
    reader: ParquetRecordBatchReader,
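    /// Rows left over from a previous Arrow batch that exceeded the caller's
    /// `max_rows` limit; these are served before the next batch is read.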
    buffer: Option<Vec<Vec<SqlValue>>>,
}

impl ParquetReader {
    pub fn open(path: &str, table_meta: &TableMetadata, _header: bool) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| ExecutorError::BulkLoad(format!("failed to open parquet: {e}")))?;

        let builder = ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| {
            ExecutorError::BulkLoad(format!("failed to read parquet metadata: {e}"))
        })?;

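        // Build the COPY schema by mapping each parquet column's Arrow type
        // to a resolved SQL type.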
        let arrow_schema = builder.schema();
        let mut fields = Vec::with_capacity(arrow_schema.fields().len());
        for f in arrow_schema.fields() {
            let ty = map_arrow_type(f.data_type())?;
            fields.push(CopyField {
                name: Some(f.name().clone()),
                data_type: Some(ty),
            });
        }

        let reader = builder
            .with_batch_size(1024)
            .build()
            .map_err(|e| ExecutorError::BulkLoad(format!("failed to build parquet reader: {e}")))?;
        // TODO: accept the batch size as an `open` argument so callers can tune it.

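        // Cache the table's declared column types; every converted value is
        // checked against these in `arrow_value_to_sql`.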
        let target_types: Vec<ResolvedType> = table_meta
            .columns
            .iter()
            .map(|c| c.data_type.clone())
            .collect();

        Ok(Self {
            schema: CopySchema { fields },
            target_types,
            reader,
            buffer: None,
        })
    }
}

impl BulkReader for ParquetReader {
    fn schema(&self) -> &CopySchema {
        &self.schema
    }

    fn next_batch(&mut self, max_rows: usize) -> Result<Option<Vec<Vec<SqlValue>>>> {
        let max_rows = max_rows.max(1);

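        // Serve leftover rows from the previous call first, splitting again
        // if the remainder still exceeds `max_rows`.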
        if let Some(mut buffered) = self.buffer.take() {
            if buffered.len() > max_rows {
                let rest = buffered.split_off(max_rows);
                self.buffer = Some(rest);
            }
            return Ok(Some(buffered));
        }

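        // Pull the next Arrow record batch; `None` from the reader means the
        // file is exhausted.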
        let maybe_batch = self.reader.next();
        let batch = match maybe_batch {
            Some(b) => b.map_err(|e| {
                ExecutorError::BulkLoad(format!("failed to read parquet batch: {e}"))
            })?,
            None => return Ok(None),
        };

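        // Transpose the columnar batch into row-oriented `SqlValue` rows.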
        let mut rows: Vec<Vec<SqlValue>> = Vec::with_capacity(batch.num_rows());
        for row_idx in 0..batch.num_rows() {
            let mut row = Vec::with_capacity(self.schema.fields.len());
            for col_idx in 0..self.schema.fields.len() {
                let value = arrow_value_to_sql(
                    batch.column(col_idx).as_ref(),
                    batch.schema().field(col_idx).data_type(),
                    self.target_types
                        .get(col_idx)
                        .ok_or_else(|| ExecutorError::BulkLoad("missing target type".into()))?,
                    row_idx,
                )?;
                row.push(value);
            }
            rows.push(row);
        }

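        // If the Arrow batch produced more rows than requested, stash the
        // tail for the next `next_batch` call.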
        if rows.len() > max_rows {
            let rest = rows.split_off(max_rows);
            self.buffer = Some(rest);
        }

        Ok(Some(rows))
    }
}

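/// Maps an Arrow `DataType` to the corresponding `ResolvedType`, rejecting
/// anything the loader does not support.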
fn map_arrow_type(dt: &ArrowDataType) -> Result<ResolvedType> {
    match dt {
        ArrowDataType::Int32 => Ok(ResolvedType::Integer),
        ArrowDataType::Int64 => Ok(ResolvedType::BigInt),
        ArrowDataType::Float32 => Ok(ResolvedType::Float),
        ArrowDataType::Float64 => Ok(ResolvedType::Double),
        ArrowDataType::Boolean => Ok(ResolvedType::Boolean),
        ArrowDataType::Utf8 => Ok(ResolvedType::Text),
        ArrowDataType::Binary | ArrowDataType::LargeBinary => Ok(ResolvedType::Blob),
        ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
            Ok(ResolvedType::Timestamp)
        }
        other => Err(ExecutorError::BulkLoad(format!(
            "unsupported parquet/arrow type: {other:?}"
        ))),
    }
}

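/// Converts a single cell of an Arrow array into a `SqlValue`, applying the
/// widening conversions (e.g. Int32 -> BigInt, Float32 -> Double) permitted
/// for the expected column type.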
fn arrow_value_to_sql(
    array: &dyn Array,
    dt: &ArrowDataType,
    expected: &ResolvedType,
    row_idx: usize,
) -> Result<SqlValue> {
    if array.is_null(row_idx) {
        return Ok(SqlValue::Null);
    }

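    // The downcasts below cannot fail: `dt` is the column's actual Arrow
    // type, so each matched arm sees an array of the corresponding concrete
    // type.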
    match (dt, expected) {
        (ArrowDataType::Int32, ResolvedType::Integer) => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            Ok(SqlValue::Integer(arr.value(row_idx)))
        }
        (ArrowDataType::Int32, ResolvedType::BigInt) => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            Ok(SqlValue::BigInt(arr.value(row_idx) as i64))
        }
        (ArrowDataType::Int32, ResolvedType::Float) => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            Ok(SqlValue::Float(arr.value(row_idx) as f32))
        }
        (ArrowDataType::Int32, ResolvedType::Double) => {
            let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
            Ok(SqlValue::Double(arr.value(row_idx) as f64))
        }
        (ArrowDataType::Int64, ResolvedType::BigInt) => {
            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
            Ok(SqlValue::BigInt(arr.value(row_idx)))
        }
        (ArrowDataType::Int64, ResolvedType::Double) => {
            let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
            Ok(SqlValue::Double(arr.value(row_idx) as f64))
        }
        (ArrowDataType::Float32, ResolvedType::Float) => {
            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
            Ok(SqlValue::Float(arr.value(row_idx)))
        }
        (ArrowDataType::Float32, ResolvedType::Double) => {
            let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
            Ok(SqlValue::Double(arr.value(row_idx) as f64))
        }
        (ArrowDataType::Float64, ResolvedType::Double) => {
            let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
            Ok(SqlValue::Double(arr.value(row_idx)))
        }
        (ArrowDataType::Boolean, ResolvedType::Boolean) => {
            let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(SqlValue::Boolean(arr.value(row_idx)))
        }
        (ArrowDataType::Utf8, ResolvedType::Text) => {
            let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
            Ok(SqlValue::Text(arr.value(row_idx).to_string()))
        }
        (ArrowDataType::Binary, ResolvedType::Blob) => {
            let arr = array.as_any().downcast_ref::<BinaryArray>().unwrap();
            Ok(SqlValue::Blob(arr.value(row_idx).to_vec()))
        }
        (ArrowDataType::LargeBinary, ResolvedType::Blob) => {
            let arr = array.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
            Ok(SqlValue::Blob(arr.value(row_idx).to_vec()))
        }
        (
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _),
            ResolvedType::Timestamp,
        ) => {
            let arr = array
                .as_any()
                .downcast_ref::<TimestampMicrosecondArray>()
                .unwrap();
            Ok(SqlValue::Timestamp(arr.value(row_idx)))
        }
        _ => Err(ExecutorError::BulkLoad(format!(
            "parquet field type {:?} does not match expected {:?}",
            dt, expected
        ))),
    }
}
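
// A minimal usage sketch (illustrative, not part of the original module):
// assuming the caller already holds `TableMetadata` for the target table,
// driving the reader looks roughly like this. The file name, the `table_meta`
// binding, and the 4096-row limit are hypothetical.
//
//     let mut reader = ParquetReader::open("data.parquet", &table_meta, false)?;
//     while let Some(rows) = reader.next_batch(4096)? {
//         for row in rows {
//             // each `row` is a Vec<SqlValue> matching the table's columns
//         }
//     }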