Skip to main content

alopex_embedded/
dataframe_api.rs

1use std::sync::Arc;
2
3use arrow::array::{
4    ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
5    NullArray, StringArray, TimestampMicrosecondArray,
6};
7use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
8use arrow::record_batch::RecordBatch;
9
10use alopex_dataframe::{DataFrame, DataFrameError};
11use alopex_sql::{ColumnInfo, ExecutionResult, QueryResult, ResolvedType, SqlValue};
12
13use crate::{Database, Result, SqlResult, Transaction};
14
15type DfResult<T> = std::result::Result<T, DataFrameError>;
16
17impl Database {
18    /// Execute SQL and return a DataFrame for query results.
19    pub fn query_df(&self, sql: &str) -> Result<DataFrame> {
20        let result = self.execute_sql(sql)?;
21        let df = sql_result_to_dataframe(result)?;
22        Ok(df)
23    }
24}
25
26impl<'a> Transaction<'a> {
27    /// Execute SQL within the transaction and return a DataFrame for query results.
28    pub fn query_df(&mut self, sql: &str) -> Result<DataFrame> {
29        let result = self.execute_sql(sql)?;
30        let df = sql_result_to_dataframe(result)?;
31        Ok(df)
32    }
33}
34
35fn sql_result_to_dataframe(result: SqlResult) -> DfResult<DataFrame> {
36    match result {
37        ExecutionResult::Query(query) => query_result_to_dataframe(query),
38        ExecutionResult::Success | ExecutionResult::RowsAffected(_) => Err(
39            DataFrameError::invalid_operation("query_df requires a SELECT query that returns rows"),
40        ),
41    }
42}
43
44fn query_result_to_dataframe(query: QueryResult) -> DfResult<DataFrame> {
45    let row_count = query.rows.len();
46    let mut fields = Vec::with_capacity(query.columns.len());
47    let mut builders = Vec::with_capacity(query.columns.len());
48
49    for ColumnInfo { name, data_type } in query.columns {
50        let arrow_type = arrow_type_for(&data_type)?;
51        fields.push(Field::new(&name, arrow_type, true));
52        builders.push(ColumnBuilder::new(name, data_type, row_count)?);
53    }
54
55    for row in query.rows {
56        if row.len() != builders.len() {
57            return Err(DataFrameError::schema_mismatch(format!(
58                "row has {} columns, expected {}",
59                row.len(),
60                builders.len()
61            )));
62        }
63
64        for (value, builder) in row.into_iter().zip(builders.iter_mut()) {
65            builder.push(value)?;
66        }
67    }
68
69    let schema = Arc::new(Schema::new(fields));
70    let arrays = builders
71        .into_iter()
72        .map(ColumnBuilder::finish)
73        .collect::<DfResult<Vec<_>>>()?;
74    let batch = RecordBatch::try_new(schema, arrays).map_err(|e| {
75        DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
76    })?;
77
78    DataFrame::from_batches(vec![batch])
79}
80
81fn arrow_type_for(ty: &ResolvedType) -> DfResult<DataType> {
82    match ty {
83        ResolvedType::Integer => Ok(DataType::Int32),
84        ResolvedType::BigInt => Ok(DataType::Int64),
85        ResolvedType::Float => Ok(DataType::Float32),
86        ResolvedType::Double => Ok(DataType::Float64),
87        ResolvedType::Text => Ok(DataType::Utf8),
88        ResolvedType::Blob => Ok(DataType::Binary),
89        ResolvedType::Boolean => Ok(DataType::Boolean),
90        ResolvedType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
91        ResolvedType::Null => Ok(DataType::Null),
92        ResolvedType::Vector { .. } => Err(DataFrameError::invalid_operation(
93            "vector columns are not supported for DataFrame conversion",
94        )),
95    }
96}
97
98struct ColumnBuilder {
99    name: String,
100    expected: ResolvedType,
101    kind: ColumnBuilderKind,
102}
103
/// Typed value buffer backing a [`ColumnBuilder`].
///
/// In the buffered variants, `None` marks a SQL NULL.
enum ColumnBuilderKind {
    /// 32-bit integers (`ResolvedType::Integer`).
    Int32(Vec<Option<i32>>),
    /// 64-bit integers (`ResolvedType::BigInt`).
    Int64(Vec<Option<i64>>),
    /// Single-precision floats (`ResolvedType::Float`).
    Float32(Vec<Option<f32>>),
    /// Double-precision floats (`ResolvedType::Double`).
    Float64(Vec<Option<f64>>),
    /// UTF-8 strings (`ResolvedType::Text`).
    Utf8(Vec<Option<String>>),
    /// Raw byte strings (`ResolvedType::Blob`).
    Binary(Vec<Option<Vec<u8>>>),
    /// Booleans (`ResolvedType::Boolean`).
    Boolean(Vec<Option<bool>>),
    /// Timestamps (`ResolvedType::Timestamp`) as raw `i64` values; they are
    /// handed unchanged to a microsecond-resolution Arrow array.
    // NOTE(review): presumably the SQL layer already stores microseconds — confirm.
    Timestamp(Vec<Option<i64>>),
    /// All-null column (`ResolvedType::Null`); only the number of rows is tracked.
    Null(usize),
}
115
116impl ColumnBuilder {
117    fn new(name: String, expected: ResolvedType, row_count: usize) -> DfResult<Self> {
118        let kind = match expected {
119            ResolvedType::Integer => ColumnBuilderKind::Int32(Vec::with_capacity(row_count)),
120            ResolvedType::BigInt => ColumnBuilderKind::Int64(Vec::with_capacity(row_count)),
121            ResolvedType::Float => ColumnBuilderKind::Float32(Vec::with_capacity(row_count)),
122            ResolvedType::Double => ColumnBuilderKind::Float64(Vec::with_capacity(row_count)),
123            ResolvedType::Text => ColumnBuilderKind::Utf8(Vec::with_capacity(row_count)),
124            ResolvedType::Blob => ColumnBuilderKind::Binary(Vec::with_capacity(row_count)),
125            ResolvedType::Boolean => ColumnBuilderKind::Boolean(Vec::with_capacity(row_count)),
126            ResolvedType::Timestamp => ColumnBuilderKind::Timestamp(Vec::with_capacity(row_count)),
127            ResolvedType::Null => ColumnBuilderKind::Null(0),
128            ResolvedType::Vector { .. } => {
129                return Err(DataFrameError::invalid_operation(
130                    "vector columns are not supported for DataFrame conversion",
131                ))
132            }
133        };
134
135        Ok(Self {
136            name,
137            expected,
138            kind,
139        })
140    }
141
142    fn push(&mut self, value: SqlValue) -> DfResult<()> {
143        match (&mut self.kind, value) {
144            (ColumnBuilderKind::Int32(values), SqlValue::Integer(v)) => {
145                values.push(Some(v));
146                Ok(())
147            }
148            (ColumnBuilderKind::Int64(values), SqlValue::BigInt(v)) => {
149                values.push(Some(v));
150                Ok(())
151            }
152            (ColumnBuilderKind::Float32(values), SqlValue::Float(v)) => {
153                values.push(Some(v));
154                Ok(())
155            }
156            (ColumnBuilderKind::Float64(values), SqlValue::Double(v)) => {
157                values.push(Some(v));
158                Ok(())
159            }
160            (ColumnBuilderKind::Utf8(values), SqlValue::Text(v)) => {
161                values.push(Some(v));
162                Ok(())
163            }
164            (ColumnBuilderKind::Binary(values), SqlValue::Blob(v)) => {
165                values.push(Some(v));
166                Ok(())
167            }
168            (ColumnBuilderKind::Boolean(values), SqlValue::Boolean(v)) => {
169                values.push(Some(v));
170                Ok(())
171            }
172            (ColumnBuilderKind::Timestamp(values), SqlValue::Timestamp(v)) => {
173                values.push(Some(v));
174                Ok(())
175            }
176            (ColumnBuilderKind::Int32(values), SqlValue::Null) => {
177                values.push(None);
178                Ok(())
179            }
180            (ColumnBuilderKind::Int64(values), SqlValue::Null) => {
181                values.push(None);
182                Ok(())
183            }
184            (ColumnBuilderKind::Float32(values), SqlValue::Null) => {
185                values.push(None);
186                Ok(())
187            }
188            (ColumnBuilderKind::Float64(values), SqlValue::Null) => {
189                values.push(None);
190                Ok(())
191            }
192            (ColumnBuilderKind::Utf8(values), SqlValue::Null) => {
193                values.push(None);
194                Ok(())
195            }
196            (ColumnBuilderKind::Binary(values), SqlValue::Null) => {
197                values.push(None);
198                Ok(())
199            }
200            (ColumnBuilderKind::Boolean(values), SqlValue::Null) => {
201                values.push(None);
202                Ok(())
203            }
204            (ColumnBuilderKind::Timestamp(values), SqlValue::Null) => {
205                values.push(None);
206                Ok(())
207            }
208            (ColumnBuilderKind::Null(count), SqlValue::Null) => {
209                *count += 1;
210                Ok(())
211            }
212            (_, other) => Err(DataFrameError::type_mismatch(
213                Some(self.name.clone()),
214                self.expected.to_string(),
215                other.type_name().to_string(),
216            )),
217        }
218    }
219
220    fn finish(self) -> DfResult<ArrayRef> {
221        let array: ArrayRef = match self.kind {
222            ColumnBuilderKind::Int32(values) => Arc::new(Int32Array::from(values)),
223            ColumnBuilderKind::Int64(values) => Arc::new(Int64Array::from(values)),
224            ColumnBuilderKind::Float32(values) => Arc::new(Float32Array::from(values)),
225            ColumnBuilderKind::Float64(values) => Arc::new(Float64Array::from(values)),
226            ColumnBuilderKind::Utf8(values) => Arc::new(StringArray::from(values)),
227            ColumnBuilderKind::Binary(values) => {
228                let slices: Vec<Option<&[u8]>> = values.iter().map(|v| v.as_deref()).collect();
229                Arc::new(BinaryArray::from(slices))
230            }
231            ColumnBuilderKind::Boolean(values) => Arc::new(BooleanArray::from(values)),
232            ColumnBuilderKind::Timestamp(values) => {
233                Arc::new(TimestampMicrosecondArray::from(values))
234            }
235            ColumnBuilderKind::Null(len) => Arc::new(NullArray::new(len)),
236        };
237
238        Ok(array)
239    }
240}