use std::ops::{Add, AddAssign};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use anyhow::{bail, Result};
use async_trait::async_trait;
use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result as DFResult;
use datafusion::datasource::streaming::StreamingTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use futures::Stream;
use num_traits::Zero;
use crate::udtf::{FromFuncParamValue, FuncParamValue, TableFuncContextProvider, TableUDF};
/// Table function exposed under the name `generate_series`.
///
/// Accepts either two (`start`, `stop`) or three (`start`, `stop`, `step`)
/// integer arguments; when `step` is omitted it is `1`.
#[derive(Debug, Clone, Copy)]
pub struct GenerateSeries;
#[async_trait]
impl TableUDF for GenerateSeries {
    /// Name under which this table function is looked up.
    fn name(&self) -> &str {
        "generate_series"
    }

    /// Builds the table provider for a `generate_series(start, stop[, step])` call.
    ///
    /// # Errors
    ///
    /// Fails when the number of arguments is not 2 or 3, when any argument is
    /// not a valid integer parameter, when a parameter conversion fails, or
    /// when the streaming table itself cannot be created.
    async fn create_provider(
        &self,
        _: &dyn TableFuncContextProvider,
        args: Vec<FuncParamValue>,
    ) -> Result<Arc<dyn TableProvider>> {
        // Pull at most four items so that both "too few" and "too many" fall
        // through to the same arity error.
        let mut params = args.into_iter();
        let (start, stop, step) =
            match (params.next(), params.next(), params.next(), params.next()) {
                (Some(start), Some(stop), step, None) => (start, stop, step),
                _ => bail!("'generate_series' must have 2 or 3 arguments"),
            };

        let bounds_are_ints = i64::is_param_valid(&start) && i64::is_param_valid(&stop);

        match step {
            // Two-argument form: the step is implicitly 1.
            None => {
                if !bounds_are_ints {
                    bail!("'start' and 'stop' must be integers")
                }
                create_straming_table::<GenerateSeriesTypeInt>(
                    GenerateSeriesTypeInt,
                    start.param_into()?,
                    stop.param_into()?,
                    1,
                )
            }
            // Three-argument form: the caller supplies the step.
            Some(step) => {
                if !(bounds_are_ints && i64::is_param_valid(&step)) {
                    bail!("'start', 'stop' and 'step' must be integers")
                }
                create_straming_table::<GenerateSeriesTypeInt>(
                    GenerateSeriesTypeInt,
                    start.param_into()?,
                    stop.param_into()?,
                    step.param_into()?,
                )
            }
        }
    }

    /// Declares the accepted call shapes: two or three `Int64` arguments.
    fn signature(&self) -> Option<Signature> {
        let int_args = |count| TypeSignature::Uniform(count, vec![DataType::Int64]);
        let accepted = TypeSignature::OneOf(vec![int_args(2), int_args(3)]);
        Some(Signature::new(accepted, Volatility::Immutable))
    }
}
fn create_straming_table<T: GenerateSeriesType>(
gen_series_type: T, start: T::PrimType, stop: T::PrimType, step: T::PrimType,
) -> Result<Arc<dyn TableProvider>> {
if step.is_zero() {
bail!("'step' may not be zero")
}
let partition: GenerateSeriesPartition<T> =
GenerateSeriesPartition::new(gen_series_type, start, stop, step);
let table = StreamingTable::try_new(partition.schema().clone(), vec![Arc::new(partition)])?;
Ok(Arc::new(table))
}
/// Describes one element type a generated series can be made of.
trait GenerateSeriesType: Send + Sync + 'static {
/// Native value type used for `start`, `stop`, `step` and every generated element.
type PrimType: Send + Sync + PartialOrd + AddAssign + Add + Zero + Copy + Unpin;
/// Arrow data type of the single output column.
fn arrow_type(&self) -> DataType;
/// Packs the generated values into an Arrow array.
///
/// The stream asserts that the returned array's data type equals
/// `arrow_type()`.
fn collect_array(&self, series: Vec<Self::PrimType>) -> ArrayRef;
}
/// Series element type backed by `i64` values and an Arrow `Int64` column.
struct GenerateSeriesTypeInt;

impl GenerateSeriesType for GenerateSeriesTypeInt {
    type PrimType = i64;

    /// The output column is always `Int64`.
    fn arrow_type(&self) -> DataType {
        DataType::Int64
    }

    /// Converts the generated values into a null-free `Int64Array`.
    fn collect_array(&self, series: Vec<Self::PrimType>) -> ArrayRef {
        Arc::new(Int64Array::from(series))
    }
}
/// A single partition describing the series `start..=stop` advanced by `step`.
///
/// Each call to `execute` starts a fresh stream from `start`.
struct GenerateSeriesPartition<T: GenerateSeriesType> {
/// Element-type descriptor, shared with every stream spawned from this partition.
gen_series_type: Arc<T>,
/// Single-column, non-nullable schema named `generate_series`.
schema: SchemaRef,
/// First value of the series.
start: T::PrimType,
/// Bound at which generation stops.
stop: T::PrimType,
/// Increment added after each value; may be negative.
step: T::PrimType,
}
impl<T: GenerateSeriesType> GenerateSeriesPartition<T> {
    /// Creates a partition whose schema is a single non-nullable column named
    /// `generate_series`, typed according to `gen_series_type`.
    fn new(gen_series_type: T, start: T::PrimType, stop: T::PrimType, step: T::PrimType) -> Self {
        // The column type must be read before the descriptor is moved into its `Arc`.
        let column = Field::new("generate_series", gen_series_type.arrow_type(), false);
        let schema = Arc::new(Schema::new([Arc::new(column)]));

        Self {
            gen_series_type: Arc::new(gen_series_type),
            schema,
            start,
            stop,
            step,
        }
    }
}
impl<T: GenerateSeriesType> PartitionStream for GenerateSeriesPartition<T> {
fn schema(&self) -> &Arc<Schema> {
&self.schema
}
fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
Box::pin(GenerateSeriesStream::<T> {
gen_series_type: Arc::clone(&self.gen_series_type),
schema: self.schema.clone(),
exhausted: false,
curr: self.start,
stop: self.stop,
step: self.step,
})
}
}
/// Stream state for producing the series in record batches.
struct GenerateSeriesStream<T: GenerateSeriesType> {
/// Element-type descriptor used to build the Arrow arrays.
gen_series_type: Arc<T>,
/// Schema attached to every emitted batch.
schema: Arc<Schema>,
/// Set once a batch shorter than the batch size has been produced;
/// afterwards the stream yields nothing more.
exhausted: bool,
/// Next value to be emitted.
curr: T::PrimType,
/// Bound at which generation stops.
stop: T::PrimType,
/// Increment applied to `curr` after each emitted value.
step: T::PrimType,
}
impl<T: GenerateSeriesType> GenerateSeriesStream<T> {
fn generate_next(&mut self) -> Option<RecordBatch> {
if self.exhausted {
return None;
}
const BATCH_SIZE: usize = 1000;
let mut series: Vec<_> = Vec::new();
if self.curr < self.stop && self.step > T::PrimType::zero() {
let mut count = 0;
while self.curr <= self.stop && count < BATCH_SIZE {
series.push(self.curr);
self.curr += self.step;
count += 1;
}
} else if self.curr > self.stop && self.step < T::PrimType::zero() {
let mut count = 0;
while self.curr >= self.stop && count < BATCH_SIZE {
series.push(self.curr);
self.curr += self.step;
count += 1;
}
}
if series.len() < BATCH_SIZE {
self.exhausted = true
}
if let Some(last) = series.last() {
self.curr = *last + self.step;
}
let arrow_dt = self.gen_series_type.arrow_type();
let arr = self.gen_series_type.collect_array(series);
assert_eq!(arr.data_type(), &arrow_dt);
let batch = RecordBatch::try_new(self.schema.clone(), vec![arr]).unwrap();
Some(batch)
}
}
impl<T: GenerateSeriesType> Stream for GenerateSeriesStream<T> {
    type Item = DFResult<RecordBatch>;

    /// Always ready: batches are computed synchronously, so the waker is never used.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match this.generate_next() {
            Some(batch) => Poll::Ready(Some(Ok(batch))),
            None => Poll::Ready(None),
        }
    }
}
impl<T: GenerateSeriesType> RecordBatchStream for GenerateSeriesStream<T> {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}