datafusion-util 0.0.2

[WIP] DataClod
Documentation
use std::ops::{Add, AddAssign};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use anyhow::{bail, Result};
use async_trait::async_trait;
use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result as DFResult;
use datafusion::datasource::streaming::StreamingTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use futures::Stream;
use num_traits::Zero;

use crate::udtf::{FromFuncParamValue, FuncParamValue, TableFuncContextProvider, TableUDF};

/// Table function `generate_series(start, stop [, step])`.
///
/// Produces a single integer column that starts at `start` and advances by
/// `step` (default `1`) towards `stop`. The type carries no state, so it is
/// freely copyable, comparable and default-constructible.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GenerateSeries;

#[async_trait]
impl TableUDF for GenerateSeries {
    fn name(&self) -> &str {
        "generate_series"
    }

    /// Builds the table provider for `generate_series(start, stop [, step])`.
    ///
    /// Every argument must be an integer; when `step` is omitted it is `1`.
    ///
    /// # Errors
    ///
    /// Fails when there are not exactly 2 or 3 arguments, when an argument is
    /// not a valid integer parameter or cannot be converted to one, or when
    /// `step` is zero.
    async fn create_provider(
        &self, _: &dyn TableFuncContextProvider, args: Vec<FuncParamValue>,
    ) -> Result<Arc<dyn TableProvider>> {
        let mut params = args.into_iter();
        // Pull one more element than the largest supported arity so that any
        // surplus argument falls through to the error arm.
        match (params.next(), params.next(), params.next(), params.next()) {
            (Some(start), Some(stop), None, None) => {
                let all_ints = i64::is_param_valid(&start) && i64::is_param_valid(&stop);
                if !all_ints {
                    bail!("'start' and 'stop' must be integers")
                }
                create_straming_table::<GenerateSeriesTypeInt>(
                    GenerateSeriesTypeInt,
                    start.param_into()?,
                    stop.param_into()?,
                    1,
                )
            }
            (Some(start), Some(stop), Some(step), None) => {
                let all_ints = i64::is_param_valid(&start)
                    && i64::is_param_valid(&stop)
                    && i64::is_param_valid(&step);
                if !all_ints {
                    bail!("'start', 'stop' and 'step' must be integers")
                }
                create_straming_table::<GenerateSeriesTypeInt>(
                    GenerateSeriesTypeInt,
                    start.param_into()?,
                    stop.param_into()?,
                    step.param_into()?,
                )
            }
            _ => bail!("'generate_series' must have 2 or 3 arguments"),
        }
    }

    /// Accepts either two or three `Int64` arguments.
    fn signature(&self) -> Option<Signature> {
        let integer_arities = (2..=3)
            .map(|arity| TypeSignature::Uniform(arity, vec![DataType::Int64]))
            .collect();
        Some(Signature::new(TypeSignature::OneOf(integer_arities), Volatility::Immutable))
    }
}

/// Wraps a single [`GenerateSeriesPartition`] in a [`StreamingTable`].
///
/// # Errors
///
/// Returns an error if `step` is zero, or if the streaming table cannot be
/// constructed from the partition's schema.
fn create_straming_table<T: GenerateSeriesType>(
    gen_series_type: T, start: T::PrimType, stop: T::PrimType, step: T::PrimType,
) -> Result<Arc<dyn TableProvider>> {
    if step.is_zero() {
        bail!("'step' may not be zero");
    }

    let partition = GenerateSeriesPartition::<T>::new(gen_series_type, start, stop, step);
    let schema = Arc::clone(partition.schema());
    let provider: Arc<dyn TableProvider> =
        Arc::new(StreamingTable::try_new(schema, vec![Arc::new(partition)])?);
    Ok(provider)
}

trait GenerateSeriesType: Send + Sync + 'static {
    type PrimType: Send + Sync + PartialOrd + AddAssign + Add + Zero + Copy + Unpin;

    fn arrow_type(&self) -> DataType;
    fn collect_array(&self, series: Vec<Self::PrimType>) -> ArrayRef;
}

struct GenerateSeriesTypeInt;

impl GenerateSeriesType for GenerateSeriesTypeInt {
    type PrimType = i64;

    fn arrow_type(&self) -> DataType {
        DataType::Int64
    }

    fn collect_array(&self, series: Vec<Self::PrimType>) -> ArrayRef {
        let arr = Int64Array::from_iter_values(series);
        Arc::new(arr)
    }
}

struct GenerateSeriesPartition<T: GenerateSeriesType> {
    gen_series_type: Arc<T>,
    schema: SchemaRef,
    start: T::PrimType,
    stop: T::PrimType,
    step: T::PrimType,
}

impl<T: GenerateSeriesType> GenerateSeriesPartition<T> {
    fn new(gen_series_type: T, start: T::PrimType, stop: T::PrimType, step: T::PrimType) -> Self {
        GenerateSeriesPartition {
            schema: Arc::new(Schema::new([Arc::new(Field::new(
                "generate_series",
                gen_series_type.arrow_type(),
                false,
            ))])),
            start,
            stop,
            step,
            gen_series_type: Arc::new(gen_series_type),
        }
    }
}

impl<T: GenerateSeriesType> PartitionStream for GenerateSeriesPartition<T> {
    fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        Box::pin(GenerateSeriesStream::<T> {
            gen_series_type: Arc::clone(&self.gen_series_type),
            schema: self.schema.clone(),
            exhausted: false,
            curr: self.start,
            stop: self.stop,
            step: self.step,
        })
    }
}

struct GenerateSeriesStream<T: GenerateSeriesType> {
    gen_series_type: Arc<T>,
    schema: Arc<Schema>,
    exhausted: bool,
    curr: T::PrimType,
    stop: T::PrimType,
    step: T::PrimType,
}

impl<T: GenerateSeriesType> GenerateSeriesStream<T> {
    fn generate_next(&mut self) -> Option<RecordBatch> {
        if self.exhausted {
            return None;
        }

        const BATCH_SIZE: usize = 1000;

        let mut series: Vec<_> = Vec::new();
        if self.curr < self.stop && self.step > T::PrimType::zero() {
            // Going up.
            let mut count = 0;
            while self.curr <= self.stop && count < BATCH_SIZE {
                series.push(self.curr);
                self.curr += self.step;
                count += 1;
            }
        } else if self.curr > self.stop && self.step < T::PrimType::zero() {
            // Going down.
            let mut count = 0;
            while self.curr >= self.stop && count < BATCH_SIZE {
                series.push(self.curr);
                self.curr += self.step;
                count += 1;
            }
        }

        if series.len() < BATCH_SIZE {
            self.exhausted = true
        }

        // Calculate the start value for the next iteration.
        if let Some(last) = series.last() {
            self.curr = *last + self.step;
        }

        let arrow_dt = self.gen_series_type.arrow_type();
        let arr = self.gen_series_type.collect_array(series);
        assert_eq!(arr.data_type(), &arrow_dt);
        let batch = RecordBatch::try_new(self.schema.clone(), vec![arr]).unwrap();
        Some(batch)
    }
}

impl<T: GenerateSeriesType> Stream for GenerateSeriesStream<T> {
    type Item = DFResult<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(self.get_mut().generate_next().map(Ok))
    }
}

impl<T: GenerateSeriesType> RecordBatchStream for GenerateSeriesStream<T> {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}