//! duckquill 0.2.1
//!
//! Parquet-backed text2sql engine and CLI for schema-first querying workflows.
use std::sync::Arc;

use datafusion::{
    arrow::{
        array::Array,
        datatypes::DataType,
        record_batch::RecordBatch,
    },
    datasource::file_format::parquet::ParquetFormat,
    execution::{
        SessionStateBuilder,
        context::SessionContext,
        object_store::ObjectStoreUrl,
    },
    prelude::ParquetReadOptions,
    scalar::ScalarValue,
};
use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl};
use object_store::aws::AmazonS3Builder;
use serde_json::{Map, Number, Value};
use url::Url;

use crate::{
    config::{DatasetConfig, FieldSummary, QueryRequest, QueryResponse, StorageConfig},
    error::AppError,
};

#[derive(Debug, Clone, Default)]
pub struct QueryEngine;

impl QueryEngine {
    pub fn new() -> Self {
        Self
    }

    pub async fn execute(&self, request: QueryRequest) -> Result<QueryResponse, AppError> {
        if request.sql.trim().is_empty() {
            return Err(AppError::Validation("sql must not be empty".to_string()));
        }

        if request.datasets.is_empty() {
            return Err(AppError::Validation(
                "at least one dataset must be provided".to_string(),
            ));
        }

        let ctx = new_context();
        for dataset in &request.datasets {
            register_dataset(&ctx, dataset).await?;
        }

        let dataframe = ctx.sql(&request.sql).await?;
        let schema = dataframe.schema();
        let batches = dataframe.collect().await?;
        let rows = batches_to_rows(&batches)?;
        let fields = schema
            .fields()
            .iter()
            .map(|field| FieldSummary {
                name: field.name().to_string(),
                data_type: field.data_type().to_string(),
                nullable: field.is_nullable(),
            })
            .collect();

        Ok(QueryResponse {
            row_count: rows.len(),
            rows,
            fields,
        })
    }
}

/// Builds an isolated DataFusion session with the default feature set enabled
/// (`SessionStateBuilder::with_default_features`).
fn new_context() -> SessionContext {
    SessionContext::new_with_state(SessionStateBuilder::new().with_default_features().build())
}

async fn register_dataset(ctx: &SessionContext, dataset: &DatasetConfig) -> Result<(), AppError> {
    match &dataset.storage {
        StorageConfig::Local => {
            ctx.register_parquet(
                dataset.table_name.as_str(),
                dataset.uri.as_str(),
                ParquetReadOptions::default(),
            )
            .await?;
        }
        StorageConfig::S3(config) => {
            let store = build_s3_store(config)?;
            let store_url = format!("s3://{}", config.bucket);
            ctx.register_object_store(&ObjectStoreUrl::parse(store_url.as_str())?, Arc::new(store));
            register_listing_table(ctx, dataset).await?;
        }
        StorageConfig::Minio(config) => {
            let store = build_minio_store(config)?;
            let store_url = format!("s3://{}", config.bucket);
            ctx.register_object_store(&ObjectStoreUrl::parse(store_url.as_str())?, Arc::new(store));
            register_listing_table(ctx, dataset).await?;
        }
    }

    Ok(())
}

/// Registers `dataset` as a Parquet [`ListingTable`] named `dataset.table_name`.
///
/// The schema is inferred up front from the `*.parquet` objects under `dataset.uri`, which must
/// resolve to an object store already registered on `ctx`. `ListingOptions::new` would default
/// scans to a single partition. Using the session's `target_partitions` instead matches what
/// `SessionContext::register_parquet` does for local datasets.
///
/// # Errors
///
/// Propagates URL parsing, schema inference and table registration failures.
async fn register_listing_table(ctx: &SessionContext, dataset: &DatasetConfig) -> Result<(), AppError> {
    let table_url = ListingTableUrl::parse(dataset.uri.as_str())?;
    // `SessionContext::state` returns an owned snapshot; schema inference only needs a borrow.
    let state = ctx.state();
    let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default()))
        .with_file_extension(".parquet")
        .with_target_partitions(state.config().target_partitions());
    let resolved_schema = listing_options.infer_schema(&state, &table_url).await?;
    let config = ListingTableConfig::new(table_url)
        .with_listing_options(listing_options)
        .with_schema(resolved_schema);
    let table = ListingTable::try_new(config)?;
    ctx.register_table(dataset.table_name.as_str(), Arc::new(table))?;
    Ok(())
}

fn build_s3_store(config: &crate::config::S3StorageConfig) -> Result<object_store::aws::AmazonS3, AppError> {
    let mut builder = AmazonS3Builder::new()
        .with_bucket_name(config.bucket.as_str())
        .with_region(config.region.as_str())
        .with_allow_http(config.allow_http)
        .with_virtual_hosted_style_request(!config.force_path_style);

    if let Some(endpoint) = &config.endpoint {
        builder = builder.with_endpoint(endpoint.as_str());
    }
    if let Some(access_key_id) = &config.access_key_id {
        builder = builder.with_access_key_id(access_key_id.as_str());
    }
    if let Some(secret_access_key) = &config.secret_access_key {
        builder = builder.with_secret_access_key(secret_access_key.as_str());
    }
    if let Some(session_token) = &config.session_token {
        builder = builder.with_token(session_token.as_str());
    }

    builder
        .build()
        .map_err(|error| AppError::Validation(error.to_string()))
}

fn build_minio_store(config: &crate::config::MinioStorageConfig) -> Result<object_store::aws::AmazonS3, AppError> {
    AmazonS3Builder::new()
        .with_bucket_name(config.bucket.as_str())
        .with_region(config.region.as_str())
        .with_endpoint(config.endpoint.as_str())
        .with_access_key_id(config.access_key_id.as_str())
        .with_secret_access_key(config.secret_access_key.as_str())
        .with_allow_http(config.allow_http)
        .with_virtual_hosted_style_request(false)
        .build()
        .map_err(|error| AppError::Validation(error.to_string()))
}

fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<Value>, AppError> {
    let mut rows = Vec::new();
    for batch in batches {
        let schema = batch.schema();
        for row_index in 0..batch.num_rows() {
            let mut row = Map::with_capacity(batch.num_columns());
            for (column_index, field) in schema.fields().iter().enumerate() {
                let array = batch.column(column_index);
                row.insert(
                    field.name().to_string(),
                    scalar_value_to_json(array.as_ref(), row_index)?,
                );
            }
            rows.push(Value::Object(row));
        }
    }
    Ok(rows)
}

fn scalar_value_to_json(array: &dyn Array, row_index: usize) -> Result<Value, AppError> {
    let scalar = ScalarValue::try_from_array(array, row_index)
        .map_err(|error| AppError::Serialization(error.to_string()))?;

    let value = match scalar {
        ScalarValue::Null => Value::Null,
        ScalarValue::Boolean(value) => value.map(Value::Bool).unwrap_or(Value::Null),
        ScalarValue::Float32(value) => value
            .and_then(Number::from_f64)
            .map(Value::Number)
            .unwrap_or(Value::Null),
        ScalarValue::Float64(value) => value
            .and_then(Number::from_f64)
            .map(Value::Number)
            .unwrap_or(Value::Null),
        ScalarValue::Int8(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::Int16(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::Int32(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::Int64(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::UInt8(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::UInt16(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::UInt32(value) => value.map(|v| Value::Number(v.into())).unwrap_or(Value::Null),
        ScalarValue::UInt64(value) => value
            .map(Number::from)
            .map(Value::Number)
            .unwrap_or(Value::Null),
        ScalarValue::Utf8(value) | ScalarValue::LargeUtf8(value) => {
            value.map(Value::String).unwrap_or(Value::Null)
        }
        ScalarValue::Binary(value) | ScalarValue::LargeBinary(value) => value
            .map(|bytes| Value::String(format!("0x{}", hex::encode(bytes))))
            .unwrap_or(Value::Null),
        ScalarValue::Date32(value)
        | ScalarValue::Date64(value)
        | ScalarValue::TimestampSecond(value, _)
        | ScalarValue::TimestampMillisecond(value, _)
        | ScalarValue::TimestampMicrosecond(value, _)
        | ScalarValue::TimestampNanosecond(value, _)
        | ScalarValue::Time32Second(value)
        | ScalarValue::Time32Millisecond(value)
        | ScalarValue::Time64Microsecond(value)
        | ScalarValue::Time64Nanosecond(value) => value
            .map(|v| Value::String(v.to_string()))
            .unwrap_or(Value::Null),
        ScalarValue::Decimal128(value, precision, scale) => value
            .map(|v| Value::String(format_decimal(v, scale as u32, precision as usize)))
            .unwrap_or(Value::Null),
        other => Value::String(other.to_string()),
    };

    Ok(value)
}

/// Renders the unscaled integer `value` of a decimal with `scale` fractional digits as an exact
/// decimal string, e.g. `(12345, 2)` → `"123.45"` and `(-5, 2)` → `"-0.05"`.
///
/// The third argument (the column's declared total digit count) is accepted for signature
/// compatibility but does not affect the output. Padding to it would emit spurious leading
/// zeros, e.g. `"00000123.45"` for a `DECIMAL(10, 2)` value.
fn format_decimal(value: i128, scale: u32, _precision: usize) -> String {
    let sign = if value < 0 { "-" } else { "" };
    // `unsigned_abs` cannot overflow, unlike `abs` on `i128::MIN`.
    let digits = value.unsigned_abs().to_string();
    if scale == 0 {
        return format!("{sign}{digits}");
    }

    // Left-pad with zeros just enough to keep one digit before the decimal point.
    let width = scale as usize + 1;
    let padded = format!("{digits:0>width$}");
    let (integer, fraction) = padded.split_at(padded.len() - scale as usize);
    format!("{sign}{integer}.{fraction}")
}

impl TryFrom<&str> for ObjectStoreUrl {
    type Error = AppError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let url = Url::parse(value).map_err(|error| AppError::Validation(error.to_string()))?;
        ObjectStoreUrl::parse(url.as_str()).map_err(Into::into)
    }
}

impl From<object_store::Error> for AppError {
    fn from(value: object_store::Error) -> Self {
        AppError::Query(value.to_string())
    }
}

impl From<datafusion::execution::object_store::ObjectStoreUrlError> for AppError {
    fn from(value: datafusion::execution::object_store::ObjectStoreUrlError) -> Self {
        AppError::Validation(value.to_string())
    }
}

impl From<url::ParseError> for AppError {
    fn from(value: url::ParseError) -> Self {
        AppError::Validation(value.to_string())
    }
}

#[cfg(test)]
mod tests {
    use std::{fs::File, path::Path, sync::Arc};

    use datafusion::arrow::{
        array::{Int64Array, StringArray},
        datatypes::{Field, Schema},
        record_batch::RecordBatch,
    };
    use parquet::arrow::ArrowWriter;
    use tempfile::tempdir;

    use super::*;
    use crate::config::{DatasetConfig, QueryRequest, StorageConfig};

    #[tokio::test]
    async fn queries_local_parquet_without_full_download_bootstrap() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("sample.parquet");
        write_parquet(&path).unwrap();

        let request = QueryRequest {
            sql: "SELECT city, trips FROM trips WHERE trips > 15 ORDER BY trips DESC".to_string(),
            datasets: vec![DatasetConfig {
                table_name: "trips".to_string(),
                uri: path.to_string_lossy().to_string(),
                storage: StorageConfig::Local,
            }],
        };

        let response = QueryEngine::new().execute(request).await.unwrap();
        assert_eq!(response.row_count, 2);
        assert_eq!(response.rows[0]["city"], Value::String("seattle".to_string()));
        assert_eq!(response.rows[1]["trips"], Value::Number(18.into()));
    }

    #[tokio::test]
    async fn rejects_blank_sql() {
        let request = QueryRequest {
            sql: "   ".to_string(),
            datasets: Vec::new(),
        };

        let result = QueryEngine::new().execute(request).await;
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    #[tokio::test]
    async fn rejects_request_without_datasets() {
        let request = QueryRequest {
            sql: "SELECT 1".to_string(),
            datasets: Vec::new(),
        };

        let result = QueryEngine::new().execute(request).await;
        assert!(matches!(result, Err(AppError::Validation(_))));
    }

    /// Writes a three-row `(city, trips)` Parquet fixture to `path`.
    fn write_parquet(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("trips", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["vancouver", "seattle", "portland"])) as Arc<dyn Array>,
                Arc::new(Int64Array::from(vec![12, 22, 18])) as Arc<dyn Array>,
            ],
        )?;
        let file = File::create(path)?;
        let mut writer = ArrowWriter::try_new(file, schema, None)?;
        writer.write(&batch)?;
        writer.close()?;
        Ok(())
    }
}