1use std::sync::Arc;
18
19use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder;
20use arrow::{
21 array::{
22 ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Float32Builder, Float64Builder,
23 Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, NullBuilder,
24 RecordBatch, RecordBatchOptions, StringBuilder, UInt16Builder, UInt32Builder,
25 UInt64Builder, UInt8Builder,
26 },
27 datatypes::{DataType, Field, Schema, SchemaRef},
28};
29use rusqlite::{types::Type, Row, Rows};
30use snafu::prelude::*;
31
/// Errors produced while converting SQLite query results into Arrow data.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Assembling the final `RecordBatch` from the finished column arrays failed.
    #[snafu(display("Failed to build record batch: {source}"))]
    FailedToBuildRecordBatch {
        source: datafusion::arrow::error::ArrowError,
    },

    /// No column builder exists at the given column index.
    #[snafu(display("No builder found for index {index}"))]
    NoBuilderForIndex { index: usize },

    /// A column builder is not the concrete builder type expected for its Arrow type.
    /// `sqlite_type` is the display name of the SQLite storage class being appended.
    #[snafu(display("Failed to downcast builder for {sqlite_type}"))]
    FailedToDowncastBuilder { sqlite_type: String },

    /// A `rusqlite` call reading row data failed (e.g. a value could not be converted to the
    /// requested Rust type).
    #[snafu(display("Failed to extract row value: {source}"))]
    FailedToExtractRowValue { source: rusqlite::Error },

    /// Reading a column name from the prepared statement failed.
    #[snafu(display("Failed to extract column name: {source}"))]
    FailedToExtractColumnName { source: rusqlite::Error },
}

/// Result type for this module; the error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
53
54pub fn rows_to_arrow(
61 mut rows: Rows,
62 num_cols: usize,
63 projected_schema: Option<SchemaRef>,
64) -> Result<RecordBatch> {
65 let mut arrow_fields: Vec<Field> = Vec::new();
66 let mut arrow_columns_builders: Vec<Box<dyn ArrayBuilder>> = Vec::new();
67 let mut arrow_types: Vec<DataType> = Vec::new();
68 let mut row_count = 0;
69
70 if let Ok(Some(row)) = rows.next() {
71 for i in 0..num_cols {
72 let mut column_type = row
73 .get_ref(i)
74 .context(FailedToExtractRowValueSnafu)?
75 .data_type();
76 let column_name = row
77 .as_ref()
78 .column_name(i)
79 .context(FailedToExtractColumnNameSnafu)?
80 .to_string();
81
82 if column_type == Type::Integer {
95 if let Some(projected_schema) = projected_schema.as_ref() {
96 match projected_schema.fields[i].data_type() {
97 DataType::Decimal128(..)
98 | DataType::Float16
99 | DataType::Float32
100 | DataType::Float64 => {
101 column_type = Type::Real;
102 }
103 _ => {}
104 }
105 }
106 }
107
108 let data_type = match &projected_schema {
109 Some(schema) => {
110 to_sqlite_decoding_type(schema.fields()[i].data_type(), &column_type)
111 }
112 None => map_column_type_to_data_type(column_type),
113 };
114
115 arrow_types.push(data_type.clone());
116 arrow_columns_builders.push(map_data_type_to_array_builder(&data_type));
117 arrow_fields.push(Field::new(column_name, data_type, true));
118 }
119
120 add_row_to_builders(row, &arrow_types, &mut arrow_columns_builders)?;
121 row_count += 1;
122 };
123
124 while let Ok(Some(row)) = rows.next() {
125 add_row_to_builders(row, &arrow_types, &mut arrow_columns_builders)?;
126 row_count += 1;
127 }
128
129 let columns = arrow_columns_builders
130 .into_iter()
131 .map(|mut b| b.finish())
132 .collect::<Vec<ArrayRef>>();
133
134 let options = &RecordBatchOptions::new().with_row_count(Some(row_count));
135 match RecordBatch::try_new_with_options(Arc::new(Schema::new(arrow_fields)), columns, options) {
136 Ok(record_batch) => Ok(record_batch),
137 Err(e) => Err(e).context(FailedToBuildRecordBatchSnafu),
138 }
139}
140
141fn to_sqlite_decoding_type(data_type: &DataType, sqlite_type: &Type) -> DataType {
142 if *sqlite_type == Type::Text {
143 return DataType::Utf8;
146 }
147 match data_type {
149 DataType::Null => DataType::Null,
150 DataType::Int8 => DataType::Int8,
151 DataType::Int16 => DataType::Int16,
152 DataType::Int32 => DataType::Int32,
153 DataType::Int64 => DataType::Int64,
154 DataType::UInt8 => DataType::UInt8,
155 DataType::UInt16 => DataType::UInt16,
156 DataType::UInt32 => DataType::UInt32,
157 DataType::UInt64 => DataType::UInt64,
158 DataType::Boolean => DataType::Boolean,
159 DataType::Float16 => DataType::Float16,
160 DataType::Float32 => DataType::Float32,
161 DataType::Float64 => DataType::Float64,
162 DataType::Utf8 => DataType::Utf8,
163 DataType::LargeUtf8 => DataType::LargeUtf8,
164 DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => DataType::Binary,
165 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => DataType::Float64,
166 DataType::Duration(_) => DataType::Int64,
167
168 _ => DataType::Utf8,
170 }
171}
172
// Appends the value at column `$index` of `$row` to `$builder`, after downcasting the builder
// to `$builder_type`.
//
// The cell is read as `Option<$type>`, so SQL NULL becomes a null entry. A failed downcast
// yields `FailedToDowncastBuilder` tagged with `$sqlite_type`. A failed read is wrapped in
// `FailedToExtractRowValue`. Both are propagated with `?` from the enclosing function.
macro_rules! append_value {
    ($builder:expr, $row:expr, $index:expr, $type:ty, $builder_type:ty, $sqlite_type:expr) => {{
        match $builder.as_any_mut().downcast_mut::<$builder_type>() {
            None => FailedToDowncastBuilderSnafu {
                sqlite_type: format!("{}", $sqlite_type),
            }
            .fail()?,
            Some(typed_builder) => {
                let cell: Option<$type> =
                    $row.get($index).context(FailedToExtractRowValueSnafu)?;
                typed_builder.append_option(cell);
            }
        }
    }};
}
188
189fn add_row_to_builders(
190 row: &Row,
191 arrow_types: &[DataType],
192 arrow_columns_builders: &mut [Box<dyn ArrayBuilder>],
193) -> Result<()> {
194 for (i, arrow_type) in arrow_types.iter().enumerate() {
195 let Some(builder) = arrow_columns_builders.get_mut(i) else {
196 return NoBuilderForIndexSnafu { index: i }.fail();
197 };
198
199 match *arrow_type {
200 DataType::Null => {
201 let Some(builder) = builder.as_any_mut().downcast_mut::<NullBuilder>() else {
202 return FailedToDowncastBuilderSnafu {
203 sqlite_type: format!("{}", Type::Null),
204 }
205 .fail();
206 };
207 builder.append_null();
208 }
209 DataType::Int8 => append_value!(builder, row, i, i8, Int8Builder, Type::Integer),
210 DataType::Int16 => append_value!(builder, row, i, i16, Int16Builder, Type::Integer),
211 DataType::Int32 => append_value!(builder, row, i, i32, Int32Builder, Type::Integer),
212 DataType::Int64 => append_value!(builder, row, i, i64, Int64Builder, Type::Integer),
213 DataType::UInt8 => append_value!(builder, row, i, u8, UInt8Builder, Type::Integer),
214 DataType::UInt16 => append_value!(builder, row, i, u16, UInt16Builder, Type::Integer),
215 DataType::UInt32 => append_value!(builder, row, i, u32, UInt32Builder, Type::Integer),
216 DataType::UInt64 => append_value!(builder, row, i, u64, UInt64Builder, Type::Integer),
217
218 DataType::Boolean => {
219 append_value!(builder, row, i, bool, BooleanBuilder, Type::Integer)
220 }
221
222 DataType::Float32 => append_value!(builder, row, i, f32, Float32Builder, Type::Real),
223 DataType::Float64 => append_value!(builder, row, i, f64, Float64Builder, Type::Real),
224
225 DataType::Utf8 => append_value!(builder, row, i, String, StringBuilder, Type::Text),
226 DataType::LargeUtf8 => {
227 append_value!(builder, row, i, String, LargeStringBuilder, Type::Text)
228 }
229
230 DataType::Binary => append_value!(builder, row, i, Vec<u8>, BinaryBuilder, Type::Blob),
231 _ => {
232 unimplemented!("Unsupported data type {arrow_type} for column index {i}")
233 }
234 }
235 }
236
237 Ok(())
238}
239
240fn map_column_type_to_data_type(column_type: Type) -> DataType {
241 match column_type {
242 Type::Null => DataType::Null,
243 Type::Integer => DataType::Int64,
244 Type::Real => DataType::Float64,
245 Type::Text => DataType::Utf8,
246 Type::Blob => DataType::Binary,
247 }
248}