datafusion-extra 0.0.2

[WIP] DataClod
Documentation
use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::{exec_err, DataFusionError, ScalarValue};
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::{ExecutionProps, SessionState};
use datafusion::logical_expr::{Expr, TableType};
use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;

/// In-memory description of an arithmetic series exposed as a table.
///
/// The struct only stores the parameters of the series; rows are produced
/// when the table is scanned. No trait bounds are placed on `T` here —
/// bounds belong on the `impl` blocks that actually require them.
#[derive(Debug, Clone, PartialEq)]
struct GenerateSeriesTable<T> {
    /// First value of the series.
    start: T,
    /// Bound at which the series ends (inclusive).
    stop: T,
    /// Increment applied between consecutive values.
    step: T,
}

#[async_trait]
impl TableProvider for GenerateSeriesTable<i64> {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new(
            "generate_series",
            DataType::Int64,
            true,
        )]))
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self, _state: &SessionState, projection: Option<&Vec<usize>>, _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let schema = self.schema();
        let array: Int64Array = (self.start..=self.stop)
            .step_by(self.step as usize)
            .collect();
        let batches = vec![RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()];
        Ok(Arc::new(MemoryExec::try_new(
            &[batches],
            self.schema(),
            projection.cloned(),
        )?))
    }
}

/// Table function `generate_series(start, stop[, step])`.
///
/// The type is stateless; register an instance with a session context to make
/// the function available in SQL.
#[derive(Debug, Default, Clone, Copy)]
pub struct GenerateSeriesUDTF {}

impl TableFunctionImpl for GenerateSeriesUDTF {
    /// Builds the table provider for `generate_series(start, stop[, step])`.
    ///
    /// Each argument is simplified first, so constant expressions such as
    /// `1 + 2` are accepted as long as they fold to an `Int64` literal. When
    /// `step` is omitted it defaults to `1`.
    ///
    /// # Errors
    ///
    /// Returns an execution error when
    /// * the number of arguments is not 2 or 3,
    /// * any argument does not simplify to a non-null `Int64` literal, or
    /// * `step` is zero.
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let execution_props = ExecutionProps::new();
        // A fresh simplifier per argument, all sharing one set of props.
        let simplify = |expr: &Expr| -> Result<Expr> {
            let info = SimplifyContext::new(&execution_props);
            ExprSimplifier::new(info).simplify(expr.clone())
        };

        // Matching on the slice shape validates the arity without indexing,
        // so a wrong argument count is an error instead of a panic.
        let (start, stop, step) = match exprs {
            [start, stop] => (
                simplify(start)?,
                simplify(stop)?,
                Expr::Literal(ScalarValue::Int64(Some(1))),
            ),
            [start, stop, step] => (simplify(start)?, simplify(stop)?, simplify(step)?),
            _ => {
                return exec_err!(
                    "generate_series expects 2 or 3 arguments, got {}",
                    exprs.len()
                )
            }
        };

        match (start, stop, step) {
            (
                Expr::Literal(ScalarValue::Int64(Some(_))),
                Expr::Literal(ScalarValue::Int64(Some(_))),
                Expr::Literal(ScalarValue::Int64(Some(0))),
            ) => exec_err!("generate_series step must not be zero"),
            (
                Expr::Literal(ScalarValue::Int64(Some(start))),
                Expr::Literal(ScalarValue::Int64(Some(stop))),
                Expr::Literal(ScalarValue::Int64(Some(step))),
            ) => Ok(Arc::new(GenerateSeriesTable { start, stop, step })),
            _ => exec_err!("generate_series arguments must be non-null Int64 literals"),
        }
    }
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::Array;
    use datafusion::execution::context::SessionContext;

    use super::*;

    /// Runs `sql` and returns every non-null value of the first column,
    /// which is expected to be an `Int64` column.
    async fn series_values(ctx: &SessionContext, sql: &str) -> Result<Vec<i64>> {
        let batches = ctx.sql(sql).await?.collect().await?;
        let mut values = Vec::new();
        for batch in &batches {
            let column = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .expect("generate_series column should be Int64");
            values.extend(column.iter().flatten());
        }
        Ok(values)
    }

    /// Checks the produced values rather than only printing them.
    #[tokio::test]
    async fn test_generate_series() -> Result<()> {
        let ctx = SessionContext::new();

        ctx.register_udtf("generate_series", Arc::new(GenerateSeriesUDTF {}));

        // Single-element series: start == stop.
        let single = series_values(&ctx, "SELECT * FROM generate_series(1, 1);").await?;
        assert_eq!(single, vec![1]);

        // Default step of 1, inclusive upper bound.
        let default_step = series_values(&ctx, "SELECT * FROM generate_series(1, 4);").await?;
        assert_eq!(default_step, vec![1, 2, 3, 4]);

        // Explicit step.
        let stepped = series_values(&ctx, "SELECT * FROM generate_series(1, 5, 2);").await?;
        assert_eq!(stepped, vec![1, 3, 5]);

        Ok(())
    }
}