use std::sync::Arc;
use datafusion::{
arrow::{
array::Array,
datatypes::DataType,
record_batch::RecordBatch,
},
datasource::file_format::parquet::ParquetFormat,
execution::{
SessionStateBuilder,
context::SessionContext,
object_store::ObjectStoreUrl,
},
prelude::ParquetReadOptions,
scalar::ScalarValue,
};
use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl};
use object_store::aws::AmazonS3Builder;
use serde_json::{Map, Number, Value};
use url::Url;
use crate::{
config::{DatasetConfig, FieldSummary, QueryRequest, QueryResponse, StorageConfig},
error::AppError,
};
/// Stateless entry point for running SQL over Parquet datasets.
///
/// Holds no configuration of its own; every `execute` call builds a fresh session.
#[derive(Clone, Debug, Default)]
pub struct QueryEngine;
impl QueryEngine {
    /// Creates a new query engine. The engine carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Runs `request.sql` against the datasets described in `request`.
    ///
    /// A new session is created for every call. Each dataset is registered under its
    /// `table_name`, and the full result set is collected into memory. The result is
    /// returned as JSON rows together with a summary of the output schema.
    ///
    /// # Errors
    ///
    /// Returns [`AppError::Validation`] when the SQL is blank or no datasets are supplied.
    /// Dataset registration, planning, execution, and row-serialization failures are
    /// propagated.
    pub async fn execute(&self, request: QueryRequest) -> Result<QueryResponse, AppError> {
        if request.sql.trim().is_empty() {
            return Err(AppError::Validation("sql must not be empty".to_string()));
        }
        if request.datasets.is_empty() {
            return Err(AppError::Validation(
                "at least one dataset must be provided".to_string(),
            ));
        }

        let ctx = new_context();
        for dataset in &request.datasets {
            register_dataset(&ctx, dataset).await?;
        }

        let dataframe = ctx.sql(&request.sql).await?;
        // `DataFrame::schema` borrows the frame while `collect` consumes it. The field
        // summaries must therefore be built before the frame is collected.
        let fields = dataframe
            .schema()
            .fields()
            .iter()
            .map(|field| FieldSummary {
                name: field.name().to_string(),
                data_type: field.data_type().to_string(),
                nullable: field.is_nullable(),
            })
            .collect();
        let batches = dataframe.collect().await?;
        let rows = batches_to_rows(&batches)?;

        Ok(QueryResponse {
            row_count: rows.len(),
            rows,
            fields,
        })
    }
}
/// Builds a session context whose state carries DataFusion's default features.
fn new_context() -> SessionContext {
    SessionContext::new_with_state(
        SessionStateBuilder::new()
            .with_default_features()
            .build(),
    )
}
async fn register_dataset(ctx: &SessionContext, dataset: &DatasetConfig) -> Result<(), AppError> {
match &dataset.storage {
StorageConfig::Local => {
ctx.register_parquet(
dataset.table_name.as_str(),
dataset.uri.as_str(),
ParquetReadOptions::default(),
)
.await?;
}
StorageConfig::S3(config) => {
let store = build_s3_store(config)?;
let store_url = format!("s3://{}", config.bucket);
ctx.register_object_store(&ObjectStoreUrl::parse(store_url.as_str())?, Arc::new(store));
register_listing_table(ctx, dataset).await?;
}
StorageConfig::Minio(config) => {
let store = build_minio_store(config)?;
let store_url = format!("s3://{}", config.bucket);
ctx.register_object_store(&ObjectStoreUrl::parse(store_url.as_str())?, Arc::new(store));
register_listing_table(ctx, dataset).await?;
}
}
Ok(())
}
/// Registers a Parquet listing table named `dataset.table_name` over `dataset.uri`.
///
/// Only files with the `.parquet` extension are listed. The schema is inferred up front
/// from the listed files, so any object store the URI needs must already be registered on
/// `ctx`.
///
/// # Errors
///
/// Propagates URL parsing, schema inference, table construction, and registration failures.
async fn register_listing_table(ctx: &SessionContext, dataset: &DatasetConfig) -> Result<(), AppError> {
    let table_url = ListingTableUrl::parse(dataset.uri.as_str())?;
    let listing_options =
        ListingOptions::new(Arc::new(ParquetFormat::default())).with_file_extension(".parquet");
    // `SessionContext::state` returns an owned snapshot; `infer_schema` borrows it.
    let state = ctx.state();
    let resolved_schema = listing_options.infer_schema(&state, &table_url).await?;
    let config = ListingTableConfig::new(table_url)
        .with_listing_options(listing_options)
        .with_schema(resolved_schema);
    let table = ListingTable::try_new(config)?;
    ctx.register_table(dataset.table_name.as_str(), Arc::new(table))?;
    Ok(())
}
/// Builds an S3 client for `config.bucket`.
///
/// The region, HTTP allowance, and addressing style are always applied. The endpoint,
/// access key, secret key, and session token are applied only when set. Builder failures
/// are reported as [`AppError::Validation`].
fn build_s3_store(config: &crate::config::S3StorageConfig) -> Result<object_store::aws::AmazonS3, AppError> {
    let builder = AmazonS3Builder::new()
        .with_bucket_name(config.bucket.as_str())
        .with_region(config.region.as_str())
        .with_allow_http(config.allow_http)
        .with_virtual_hosted_style_request(!config.force_path_style);
    let builder = match &config.endpoint {
        Some(url) => builder.with_endpoint(url.as_str()),
        None => builder,
    };
    let builder = match &config.access_key_id {
        Some(key_id) => builder.with_access_key_id(key_id.as_str()),
        None => builder,
    };
    let builder = match &config.secret_access_key {
        Some(secret) => builder.with_secret_access_key(secret.as_str()),
        None => builder,
    };
    let builder = match &config.session_token {
        Some(token) => builder.with_token(token.as_str()),
        None => builder,
    };
    builder
        .build()
        .map_err(|err| AppError::Validation(err.to_string()))
}
/// Builds an S3-compatible client for a MinIO deployment.
///
/// MinIO is addressed through an explicit endpoint with static credentials. Virtual-hosted
/// style is always disabled, so requests are path-style. Builder failures are reported as
/// [`AppError::Validation`].
fn build_minio_store(config: &crate::config::MinioStorageConfig) -> Result<object_store::aws::AmazonS3, AppError> {
    let builder = AmazonS3Builder::new()
        .with_bucket_name(config.bucket.as_str())
        .with_region(config.region.as_str())
        .with_endpoint(config.endpoint.as_str())
        .with_access_key_id(config.access_key_id.as_str())
        .with_secret_access_key(config.secret_access_key.as_str())
        .with_allow_http(config.allow_http)
        .with_virtual_hosted_style_request(false);
    builder
        .build()
        .map_err(|err| AppError::Validation(err.to_string()))
}
fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<Value>, AppError> {
let mut rows = Vec::new();
for batch in batches {
let schema = batch.schema();
for row_index in 0..batch.num_rows() {
let mut row = Map::with_capacity(batch.num_columns());
for (column_index, field) in schema.fields().iter().enumerate() {
let array = batch.column(column_index);
row.insert(
field.name().to_string(),
scalar_value_to_json(array.as_ref(), row_index)?,
);
}
rows.push(Value::Object(row));
}
}
Ok(rows)
}
fn scalar_value_to_json(array: &dyn Array, row_index: usize) -> Result<Value, AppError> {
let scalar = ScalarValue::try_from_array(array, row_index)
.map_err(|error| AppError::Serialization(error.to_string()))?;
let value = match scalar {
ScalarValue::Null => Value::Null,
ScalarValue::Boolean(value) => value.map(Value::Bool).unwrap_or(Value::Null),
ScalarValue::Float32(value) => value
.and_then(Number::from_f64)
.map(Value::Number)
.unwrap_or(Value::Null),
ScalarValue::Float64(value) => value
.and_then(Number::from_f64)
.map(Value::Number)
.unwrap_or(Value::Null),
ScalarValue::Int8(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::Int16(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::Int32(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::Int64(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::UInt8(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::UInt16(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::UInt32(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
ScalarValue::UInt64(value) => value
.map(Number::from)
.map(Value::Number)
.unwrap_or(Value::Null),
ScalarValue::Utf8(value) | ScalarValue::LargeUtf8(value) => {
value.map(Value::String).unwrap_or(Value::Null)
}
ScalarValue::Binary(value) | ScalarValue::LargeBinary(value) => value
.map(|bytes| Value::String(format!("0x{}", hex::encode(bytes))))
.unwrap_or(Value::Null),
ScalarValue::Date32(value)
| ScalarValue::Date64(value)
| ScalarValue::TimestampSecond(value, _)
| ScalarValue::TimestampMillisecond(value, _)
| ScalarValue::TimestampMicrosecond(value, _)
| ScalarValue::TimestampNanosecond(value, _)
| ScalarValue::Time32Second(value)
| ScalarValue::Time32Millisecond(value)
| ScalarValue::Time64Microsecond(value)
| ScalarValue::Time64Nanosecond(value) => value
.map(|v| Value::String(v.to_string()))
.unwrap_or(Value::Null),
ScalarValue::Decimal128(value, precision, scale) => value
.map(|v| Value::String(format_decimal(v, scale as u32, precision as usize)))
.unwrap_or(Value::Null),
other => Value::String(other.to_string()),
};
Ok(value)
}
/// Renders the unscaled decimal `value` with `scale` fractional digits.
///
/// For example, `format_decimal(12345, 2, _)` is `"123.45"` and `format_decimal(-5, 2, _)`
/// is `"-0.05"`. The integer part gets exactly one leading zero when it would otherwise be
/// empty, and never more.
///
/// `_precision` is kept for signature compatibility. It does not affect the output: padding
/// the digits to the column precision would emit spurious leading zeros.
fn format_decimal(value: i128, scale: u32, _precision: usize) -> String {
    let sign = if value < 0 { "-" } else { "" };
    // `unsigned_abs` cannot overflow, unlike `abs` on `i128::MIN`.
    let digits = value.unsigned_abs().to_string();
    let scale = scale as usize;
    if scale == 0 {
        return format!("{sign}{digits}");
    }
    let width = scale + 1;
    let padded = format!("{digits:0>width$}");
    let (integer_part, fraction_part) = padded.split_at(padded.len() - scale);
    format!("{sign}{integer_part}.{fraction_part}")
}
impl TryFrom<&str> for ObjectStoreUrl {
type Error = AppError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let url = Url::parse(value).map_err(|error| AppError::Validation(error.to_string()))?;
ObjectStoreUrl::parse(url.as_str()).map_err(Into::into)
}
}
impl From<object_store::Error> for AppError {
fn from(value: object_store::Error) -> Self {
AppError::Query(value.to_string())
}
}
/// Maps object-store URL errors to [`AppError::Validation`].
///
/// NOTE(review): `ObjectStoreUrl::parse` appears to return DataFusion's own `Result`
/// (`DataFusionError`), not a dedicated URL error type. Confirm that this type exists in the
/// DataFusion version in use; if it does not, this impl should be removed.
impl From<datafusion::execution::object_store::ObjectStoreUrlError> for AppError {
    fn from(err: datafusion::execution::object_store::ObjectStoreUrlError) -> Self {
        Self::Validation(err.to_string())
    }
}
impl From<url::ParseError> for AppError {
fn from(value: url::ParseError) -> Self {
AppError::Validation(value.to_string())
}
}
#[cfg(test)]
mod tests {
    use std::{fs::File, path::Path, sync::Arc};

    use datafusion::arrow::{
        array::{Int64Array, StringArray},
        datatypes::{Field, Schema},
        record_batch::RecordBatch,
    };
    use parquet::arrow::ArrowWriter;
    use tempfile::tempdir;

    use super::*;
    use crate::config::{DatasetConfig, QueryRequest, StorageConfig};

    #[tokio::test]
    async fn queries_local_parquet_without_full_download_bootstrap() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sample.parquet");
        write_parquet(&path).unwrap();
        let request = QueryRequest {
            sql: "SELECT city, trips FROM trips WHERE trips > 15 ORDER BY trips DESC".to_string(),
            datasets: vec![local_dataset("trips", &path)],
        };
        let response = QueryEngine::new().execute(request).await.unwrap();
        assert_eq!(response.row_count, 2);
        assert_eq!(response.rows[0]["city"], Value::String("seattle".to_string()));
        assert_eq!(response.rows[1]["trips"], Value::Number(18.into()));
    }

    #[tokio::test]
    async fn rejects_blank_sql() {
        // Validation runs before any dataset is registered, so the file need not exist.
        let request = QueryRequest {
            sql: "   ".to_string(),
            datasets: vec![local_dataset("trips", Path::new("unused.parquet"))],
        };
        let result = QueryEngine::new().execute(request).await;
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    #[tokio::test]
    async fn rejects_request_without_datasets() {
        let request = QueryRequest {
            sql: "SELECT 1".to_string(),
            datasets: Vec::new(),
        };
        let result = QueryEngine::new().execute(request).await;
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    /// Describes the local Parquet file at `path`, to be registered as `table_name`.
    fn local_dataset(table_name: &str, path: &Path) -> DatasetConfig {
        DatasetConfig {
            table_name: table_name.to_string(),
            uri: path.to_string_lossy().to_string(),
            storage: StorageConfig::Local,
        }
    }

    /// Writes a three-row `(city: Utf8, trips: Int64)` Parquet file to `path`.
    fn write_parquet(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("trips", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["vancouver", "seattle", "portland"])) as Arc<dyn Array>,
                Arc::new(Int64Array::from(vec![12, 22, 18])) as Arc<dyn Array>,
            ],
        )?;
        let file = File::create(path)?;
        let mut writer = ArrowWriter::try_new(file, schema, None)?;
        writer.write(&batch)?;
        writer.close()?;
        Ok(())
    }
}