1use std::sync::Arc;
2
3use arrow::array::{
4 ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
5 NullArray, StringArray, TimestampMicrosecondArray,
6};
7use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
8use arrow::record_batch::RecordBatch;
9
10use alopex_dataframe::{DataFrame, DataFrameError};
11use alopex_sql::{ColumnInfo, ExecutionResult, QueryResult, ResolvedType, SqlValue};
12
13use crate::{Database, Result, SqlResult, Transaction};
14
/// Result alias for the conversion helpers in this file, whose error type is [`DataFrameError`].
type DfResult<T> = std::result::Result<T, DataFrameError>;
16
17impl Database {
18 pub fn query_df(&self, sql: &str) -> Result<DataFrame> {
20 let result = self.execute_sql(sql)?;
21 let df = sql_result_to_dataframe(result)?;
22 Ok(df)
23 }
24}
25
26impl<'a> Transaction<'a> {
27 pub fn query_df(&mut self, sql: &str) -> Result<DataFrame> {
29 let result = self.execute_sql(sql)?;
30 let df = sql_result_to_dataframe(result)?;
31 Ok(df)
32 }
33}
34
35fn sql_result_to_dataframe(result: SqlResult) -> DfResult<DataFrame> {
36 match result {
37 ExecutionResult::Query(query) => query_result_to_dataframe(query),
38 ExecutionResult::Success | ExecutionResult::RowsAffected(_) => Err(
39 DataFrameError::invalid_operation("query_df requires a SELECT query that returns rows"),
40 ),
41 }
42}
43
44fn query_result_to_dataframe(query: QueryResult) -> DfResult<DataFrame> {
45 let row_count = query.rows.len();
46 let mut fields = Vec::with_capacity(query.columns.len());
47 let mut builders = Vec::with_capacity(query.columns.len());
48
49 for ColumnInfo { name, data_type } in query.columns {
50 let arrow_type = arrow_type_for(&data_type)?;
51 fields.push(Field::new(&name, arrow_type, true));
52 builders.push(ColumnBuilder::new(name, data_type, row_count)?);
53 }
54
55 for row in query.rows {
56 if row.len() != builders.len() {
57 return Err(DataFrameError::schema_mismatch(format!(
58 "row has {} columns, expected {}",
59 row.len(),
60 builders.len()
61 )));
62 }
63
64 for (value, builder) in row.into_iter().zip(builders.iter_mut()) {
65 builder.push(value)?;
66 }
67 }
68
69 let schema = Arc::new(Schema::new(fields));
70 let arrays = builders
71 .into_iter()
72 .map(ColumnBuilder::finish)
73 .collect::<DfResult<Vec<_>>>()?;
74 let batch = RecordBatch::try_new(schema, arrays).map_err(|e| {
75 DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
76 })?;
77
78 DataFrame::from_batches(vec![batch])
79}
80
81fn arrow_type_for(ty: &ResolvedType) -> DfResult<DataType> {
82 match ty {
83 ResolvedType::Integer => Ok(DataType::Int32),
84 ResolvedType::BigInt => Ok(DataType::Int64),
85 ResolvedType::Float => Ok(DataType::Float32),
86 ResolvedType::Double => Ok(DataType::Float64),
87 ResolvedType::Text => Ok(DataType::Utf8),
88 ResolvedType::Blob => Ok(DataType::Binary),
89 ResolvedType::Boolean => Ok(DataType::Boolean),
90 ResolvedType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
91 ResolvedType::Null => Ok(DataType::Null),
92 ResolvedType::Vector { .. } => Err(DataFrameError::invalid_operation(
93 "vector columns are not supported for DataFrame conversion",
94 )),
95 }
96}
97
/// Accumulates the values of one result column until they are turned into an Arrow array.
struct ColumnBuilder {
    /// Column name; included in type-mismatch errors raised by `push`.
    name: String,
    /// Resolved SQL type of the column; reported as the expected type in
    /// type-mismatch errors raised by `push`.
    expected: ResolvedType,
    /// Typed storage holding the values pushed so far.
    kind: ColumnBuilderKind,
}
103
/// Per-type value storage behind a [`ColumnBuilder`].
///
/// In the vector-backed variants a `None` entry represents a SQL NULL.
enum ColumnBuilderKind {
    /// Values of a `ResolvedType::Integer` column.
    Int32(Vec<Option<i32>>),
    /// Values of a `ResolvedType::BigInt` column.
    Int64(Vec<Option<i64>>),
    /// Values of a `ResolvedType::Float` column.
    Float32(Vec<Option<f32>>),
    /// Values of a `ResolvedType::Double` column.
    Float64(Vec<Option<f64>>),
    /// Values of a `ResolvedType::Text` column.
    Utf8(Vec<Option<String>>),
    /// Values of a `ResolvedType::Blob` column.
    Binary(Vec<Option<Vec<u8>>>),
    /// Values of a `ResolvedType::Boolean` column.
    Boolean(Vec<Option<bool>>),
    /// Values of a `ResolvedType::Timestamp` column; finished as a
    /// microsecond timestamp array.
    Timestamp(Vec<Option<i64>>),
    /// A `ResolvedType::Null` column: only the number of NULLs pushed is tracked.
    Null(usize),
}
115
116impl ColumnBuilder {
117 fn new(name: String, expected: ResolvedType, row_count: usize) -> DfResult<Self> {
118 let kind = match expected {
119 ResolvedType::Integer => ColumnBuilderKind::Int32(Vec::with_capacity(row_count)),
120 ResolvedType::BigInt => ColumnBuilderKind::Int64(Vec::with_capacity(row_count)),
121 ResolvedType::Float => ColumnBuilderKind::Float32(Vec::with_capacity(row_count)),
122 ResolvedType::Double => ColumnBuilderKind::Float64(Vec::with_capacity(row_count)),
123 ResolvedType::Text => ColumnBuilderKind::Utf8(Vec::with_capacity(row_count)),
124 ResolvedType::Blob => ColumnBuilderKind::Binary(Vec::with_capacity(row_count)),
125 ResolvedType::Boolean => ColumnBuilderKind::Boolean(Vec::with_capacity(row_count)),
126 ResolvedType::Timestamp => ColumnBuilderKind::Timestamp(Vec::with_capacity(row_count)),
127 ResolvedType::Null => ColumnBuilderKind::Null(0),
128 ResolvedType::Vector { .. } => {
129 return Err(DataFrameError::invalid_operation(
130 "vector columns are not supported for DataFrame conversion",
131 ))
132 }
133 };
134
135 Ok(Self {
136 name,
137 expected,
138 kind,
139 })
140 }
141
142 fn push(&mut self, value: SqlValue) -> DfResult<()> {
143 match (&mut self.kind, value) {
144 (ColumnBuilderKind::Int32(values), SqlValue::Integer(v)) => {
145 values.push(Some(v));
146 Ok(())
147 }
148 (ColumnBuilderKind::Int64(values), SqlValue::BigInt(v)) => {
149 values.push(Some(v));
150 Ok(())
151 }
152 (ColumnBuilderKind::Float32(values), SqlValue::Float(v)) => {
153 values.push(Some(v));
154 Ok(())
155 }
156 (ColumnBuilderKind::Float64(values), SqlValue::Double(v)) => {
157 values.push(Some(v));
158 Ok(())
159 }
160 (ColumnBuilderKind::Utf8(values), SqlValue::Text(v)) => {
161 values.push(Some(v));
162 Ok(())
163 }
164 (ColumnBuilderKind::Binary(values), SqlValue::Blob(v)) => {
165 values.push(Some(v));
166 Ok(())
167 }
168 (ColumnBuilderKind::Boolean(values), SqlValue::Boolean(v)) => {
169 values.push(Some(v));
170 Ok(())
171 }
172 (ColumnBuilderKind::Timestamp(values), SqlValue::Timestamp(v)) => {
173 values.push(Some(v));
174 Ok(())
175 }
176 (ColumnBuilderKind::Int32(values), SqlValue::Null) => {
177 values.push(None);
178 Ok(())
179 }
180 (ColumnBuilderKind::Int64(values), SqlValue::Null) => {
181 values.push(None);
182 Ok(())
183 }
184 (ColumnBuilderKind::Float32(values), SqlValue::Null) => {
185 values.push(None);
186 Ok(())
187 }
188 (ColumnBuilderKind::Float64(values), SqlValue::Null) => {
189 values.push(None);
190 Ok(())
191 }
192 (ColumnBuilderKind::Utf8(values), SqlValue::Null) => {
193 values.push(None);
194 Ok(())
195 }
196 (ColumnBuilderKind::Binary(values), SqlValue::Null) => {
197 values.push(None);
198 Ok(())
199 }
200 (ColumnBuilderKind::Boolean(values), SqlValue::Null) => {
201 values.push(None);
202 Ok(())
203 }
204 (ColumnBuilderKind::Timestamp(values), SqlValue::Null) => {
205 values.push(None);
206 Ok(())
207 }
208 (ColumnBuilderKind::Null(count), SqlValue::Null) => {
209 *count += 1;
210 Ok(())
211 }
212 (_, other) => Err(DataFrameError::type_mismatch(
213 Some(self.name.clone()),
214 self.expected.to_string(),
215 other.type_name().to_string(),
216 )),
217 }
218 }
219
220 fn finish(self) -> DfResult<ArrayRef> {
221 let array: ArrayRef = match self.kind {
222 ColumnBuilderKind::Int32(values) => Arc::new(Int32Array::from(values)),
223 ColumnBuilderKind::Int64(values) => Arc::new(Int64Array::from(values)),
224 ColumnBuilderKind::Float32(values) => Arc::new(Float32Array::from(values)),
225 ColumnBuilderKind::Float64(values) => Arc::new(Float64Array::from(values)),
226 ColumnBuilderKind::Utf8(values) => Arc::new(StringArray::from(values)),
227 ColumnBuilderKind::Binary(values) => {
228 let slices: Vec<Option<&[u8]>> = values.iter().map(|v| v.as_deref()).collect();
229 Arc::new(BinaryArray::from(slices))
230 }
231 ColumnBuilderKind::Boolean(values) => Arc::new(BooleanArray::from(values)),
232 ColumnBuilderKind::Timestamp(values) => {
233 Arc::new(TimestampMicrosecondArray::from(values))
234 }
235 ColumnBuilderKind::Null(len) => Arc::new(NullArray::new(len)),
236 };
237
238 Ok(array)
239 }
240}