use super::Engine;
use anyhow::Result;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::prelude::*;
use std::sync::Arc;
/// SQL query engine backed by Apache DataFusion.
///
/// Holds the DataFusion [`SessionContext`] behind an [`Arc`] so the same
/// session (catalog, configuration, registered tables) can be shared with
/// other components via [`DataFusionEngine::session_context_arc`].
pub struct DataFusionEngine {
    // Shared DataFusion session state; all queries run through this context.
    ctx: Arc<SessionContext>,
}
impl DataFusionEngine {
    /// Creates an engine that wraps `ctx` in a fresh [`Arc`].
    pub fn new(ctx: SessionContext) -> Self {
        let ctx = Arc::new(ctx);
        Self { ctx }
    }

    /// Creates an engine from a session context that is already shared.
    pub fn new_with_arc(ctx: Arc<SessionContext>) -> Self {
        Self { ctx }
    }

    /// Borrows the underlying session context.
    pub fn session_context(&self) -> &SessionContext {
        self.ctx.as_ref()
    }

    /// Returns a new shared handle to the session context (refcount bump,
    /// no deep copy).
    pub fn session_context_arc(&self) -> Arc<SessionContext> {
        Arc::clone(&self.ctx)
    }
}
#[async_trait]
impl Engine for DataFusionEngine {
    /// Executes `sql` against the session context and returns the full result
    /// set as a single [`RecordBatch`].
    ///
    /// - Zero collected batches yield an empty batch carrying the planned
    ///   query schema, so callers always receive a schema-bearing result.
    /// - Multiple batches are concatenated into one.
    ///
    /// # Errors
    ///
    /// Returns an error if SQL planning, execution/collection, or batch
    /// concatenation fails.
    async fn execute(&self, sql: &str) -> Result<RecordBatch> {
        let dataframe = self
            .ctx
            .sql(sql)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to execute SQL query: {}", e))?;

        // Capture the planned Arrow schema up front: it is needed both for
        // diagnostics and to synthesize an empty batch when no rows come back.
        let schema = dataframe.schema().inner().clone();
        tracing::debug!("Query schema: {:?}", schema);

        let batches = dataframe.collect().await.map_err(|e| {
            tracing::error!("Failed to collect query results. Schema: {:?}", schema);
            anyhow::anyhow!("Failed to collect query results: {}", e)
        })?;

        tracing::debug!("Collected {} batches", batches.len());
        for (i, batch) in batches.iter().enumerate() {
            tracing::debug!(
                "Batch {}: {} rows, schema: {:?}",
                i,
                batch.num_rows(),
                batch.schema()
            );
        }

        match batches.len() {
            // Empty result: return an empty batch with the planned schema.
            0 => Ok(RecordBatch::new_empty(schema)),
            // Fast path: a single batch needs no concatenation.
            1 => Ok(batches
                .into_iter()
                .next()
                .expect("len == 1 guarantees first element")),
            // Multiple batches: concatenate using the planned schema so the
            // output schema is consistent with the empty-result case above.
            // (Previously this used batches[0].schema(), which can differ
            // from the planned schema, e.g. in metadata.)
            _ => {
                use arrow::compute::concat_batches;
                let concatenated = concat_batches(&schema, &batches)
                    .map_err(|e| anyhow::anyhow!("Failed to concatenate result batches: {}", e))?;
                Ok(concatenated)
            }
        }
    }
}