Skip to main content

datafusion_functions_table/
generate_series.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22    DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{Result, ScalarValue, plan_err};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
33use parking_lot::RwLock;
34use std::any::Any;
35use std::fmt;
36use std::str::FromStr;
37use std::sync::Arc;
38
/// A [`LazyBatchGenerator`] that never yields a batch.
///
/// Stands in for a real series generator whenever one of the series arguments
/// (start, end or step) is NULL.
#[derive(Debug, Clone)]
pub struct Empty {
    /// Name of the table function this generator belongs to.
    name: &'static str,
}

impl Empty {
    /// Returns the table function name this generator was created for.
    pub fn name(&self) -> &'static str {
        let Self { name } = *self;
        name
    }
}
50
51impl LazyBatchGenerator for Empty {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
57        Ok(None)
58    }
59
60    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
61        Arc::new(RwLock::new(Empty { name: self.name }))
62    }
63}
64
65impl fmt::Display for Empty {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        write!(f, "{}: empty", self.name)
68    }
69}
70
71/// Trait for values that can be generated in a series
72pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
73    type StepType: fmt::Debug + Clone + Send + Sync;
74    type ValueType: fmt::Debug + Clone + Send + Sync;
75
76    /// Check if we've reached the end of the series
77    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
78
79    /// Advance to the next value in the series
80    fn advance(&mut self, step: &Self::StepType) -> Result<()>;
81
82    /// Create an Arrow array from a vector of values
83    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
84
85    /// Convert self to ValueType for array creation
86    fn to_value_type(&self) -> Self::ValueType;
87
88    /// Display the value for debugging
89    fn display_value(&self) -> String;
90}
91
92impl SeriesValue for i64 {
93    type StepType = i64;
94    type ValueType = i64;
95
96    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
97        reach_end_int64(*self, end, *step, include_end)
98    }
99
100    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
101        *self += step;
102        Ok(())
103    }
104
105    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
106        Ok(Arc::new(Int64Array::from(values)))
107    }
108
109    fn to_value_type(&self) -> Self::ValueType {
110        *self
111    }
112
113    fn display_value(&self) -> String {
114        self.to_string()
115    }
116}
117
118#[derive(Debug, Clone)]
119pub struct TimestampValue {
120    value: i64,
121    parsed_tz: Option<Tz>,
122    tz_str: Option<Arc<str>>,
123}
124
125impl TimestampValue {
126    pub fn value(&self) -> i64 {
127        self.value
128    }
129
130    pub fn tz_str(&self) -> Option<&Arc<str>> {
131        self.tz_str.as_ref()
132    }
133}
134
135impl SeriesValue for TimestampValue {
136    type StepType = IntervalMonthDayNano;
137    type ValueType = i64;
138
139    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
140        let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
141
142        if include_end {
143            if step_negative {
144                self.value < end.value
145            } else {
146                self.value > end.value
147            }
148        } else if step_negative {
149            self.value <= end.value
150        } else {
151            self.value >= end.value
152        }
153    }
154
155    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
156        let tz = self
157            .parsed_tz
158            .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
159        let Some(next_ts) =
160            TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
161        else {
162            return plan_err!(
163                "Failed to add interval {:?} to timestamp {}",
164                step,
165                self.value
166            );
167        };
168        self.value = next_ts;
169        Ok(())
170    }
171
172    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
173        let array = TimestampNanosecondArray::from(values);
174
175        // Use timezone from self (now we have access to tz through &self)
176        let array = match self.tz_str.as_ref() {
177            Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
178            None => array,
179        };
180
181        Ok(Arc::new(array))
182    }
183
184    fn to_value_type(&self) -> Self::ValueType {
185        self.value
186    }
187
188    fn display_value(&self) -> String {
189        self.value.to_string()
190    }
191}
192
193/// Indicates the arguments used for generating a series.
194#[derive(Debug, Clone)]
195pub enum GenSeriesArgs {
196    /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated.
197    ContainsNull { name: &'static str },
198    /// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null.
199    Int64Args {
200        start: i64,
201        end: i64,
202        step: i64,
203        /// Indicates whether the end value should be included in the series.
204        include_end: bool,
205        name: &'static str,
206    },
207    /// TimestampArgs holds the start, end, and step values for generating timestamp series when all arguments are not null.
208    TimestampArgs {
209        start: i64,
210        end: i64,
211        step: IntervalMonthDayNano,
212        tz: Option<Arc<str>>,
213        /// Indicates whether the end value should be included in the series.
214        include_end: bool,
215        name: &'static str,
216    },
217    /// DateArgs holds the start, end, and step values for generating date series when all arguments are not null.
218    /// Internally, dates are converted to timestamps and use the timestamp logic.
219    DateArgs {
220        start: i64,
221        end: i64,
222        step: IntervalMonthDayNano,
223        /// Indicates whether the end value should be included in the series.
224        include_end: bool,
225        name: &'static str,
226    },
227}
228
229/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step
230#[derive(Debug, Clone)]
231pub struct GenerateSeriesTable {
232    schema: SchemaRef,
233    args: GenSeriesArgs,
234}
235
236impl GenerateSeriesTable {
237    pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
238        Self { schema, args }
239    }
240
241    pub fn as_generator(
242        &self,
243        batch_size: usize,
244    ) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
245        let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
246            GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
247            GenSeriesArgs::Int64Args {
248                start,
249                end,
250                step,
251                include_end,
252                name,
253            } => Arc::new(RwLock::new(GenericSeriesState {
254                schema: self.schema(),
255                start: *start,
256                end: *end,
257                step: *step,
258                current: *start,
259                batch_size,
260                include_end: *include_end,
261                name,
262            })),
263            GenSeriesArgs::TimestampArgs {
264                start,
265                end,
266                step,
267                tz,
268                include_end,
269                name,
270            } => {
271                let parsed_tz = tz
272                    .as_ref()
273                    .map(|s| Tz::from_str(s.as_ref()))
274                    .transpose()
275                    .map_err(|e| {
276                        datafusion_common::internal_datafusion_err!(
277                            "Failed to parse timezone: {e}"
278                        )
279                    })?
280                    .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
281                Arc::new(RwLock::new(GenericSeriesState {
282                    schema: self.schema(),
283                    start: TimestampValue {
284                        value: *start,
285                        parsed_tz: Some(parsed_tz),
286                        tz_str: tz.clone(),
287                    },
288                    end: TimestampValue {
289                        value: *end,
290                        parsed_tz: Some(parsed_tz),
291                        tz_str: tz.clone(),
292                    },
293                    step: *step,
294                    current: TimestampValue {
295                        value: *start,
296                        parsed_tz: Some(parsed_tz),
297                        tz_str: tz.clone(),
298                    },
299                    batch_size,
300                    include_end: *include_end,
301                    name,
302                }))
303            }
304            GenSeriesArgs::DateArgs {
305                start,
306                end,
307                step,
308                include_end,
309                name,
310            } => Arc::new(RwLock::new(GenericSeriesState {
311                schema: self.schema(),
312                start: TimestampValue {
313                    value: *start,
314                    parsed_tz: None,
315                    tz_str: None,
316                },
317                end: TimestampValue {
318                    value: *end,
319                    parsed_tz: None,
320                    tz_str: None,
321                },
322                step: *step,
323                current: TimestampValue {
324                    value: *start,
325                    parsed_tz: None,
326                    tz_str: None,
327                },
328                batch_size,
329                include_end: *include_end,
330                name,
331            })),
332        };
333
334        Ok(generator)
335    }
336}
337
338#[derive(Debug, Clone)]
339pub struct GenericSeriesState<T: SeriesValue> {
340    schema: SchemaRef,
341    start: T,
342    end: T,
343    step: T::StepType,
344    batch_size: usize,
345    current: T,
346    include_end: bool,
347    name: &'static str,
348}
349
350impl<T: SeriesValue> GenericSeriesState<T> {
351    pub fn name(&self) -> &'static str {
352        self.name
353    }
354
355    pub fn batch_size(&self) -> usize {
356        self.batch_size
357    }
358
359    pub fn include_end(&self) -> bool {
360        self.include_end
361    }
362
363    pub fn start(&self) -> &T {
364        &self.start
365    }
366
367    pub fn end(&self) -> &T {
368        &self.end
369    }
370
371    pub fn step(&self) -> &T::StepType {
372        &self.step
373    }
374
375    pub fn current(&self) -> &T {
376        &self.current
377    }
378}
379
380impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
381    fn as_any(&self) -> &dyn Any {
382        self
383    }
384
385    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
386        let mut buf = Vec::with_capacity(self.batch_size);
387
388        while buf.len() < self.batch_size
389            && !self
390                .current
391                .should_stop(self.end.clone(), &self.step, self.include_end)
392        {
393            buf.push(self.current.to_value_type());
394            self.current.advance(&self.step)?;
395        }
396
397        if buf.is_empty() {
398            return Ok(None);
399        }
400
401        let array = self.current.create_array(buf)?;
402        let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
403        Ok(Some(batch))
404    }
405
406    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
407        let mut new = self.clone();
408        new.current = new.start.clone();
409        Arc::new(RwLock::new(new))
410    }
411}
412
413impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
414    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
415        write!(
416            f,
417            "{}: start={}, end={}, batch_size={}",
418            self.name,
419            self.start.display_value(),
420            self.end.display_value(),
421            self.batch_size
422        )
423    }
424}
425
/// Returns `true` when an integer series moving in the direction of `step` has
/// ended at `val`: `val` lies past `end`, or equals it when `include_end` is
/// `false`.
///
/// A non-positive `step` counts as descending. `call_int64` rejects a zero step
/// before a series is built.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    match (step > 0, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
435
436fn validate_interval_step(step: IntervalMonthDayNano) -> Result<()> {
437    if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
438        return plan_err!("Step interval cannot be zero");
439    }
440
441    Ok(())
442}
443
444#[async_trait]
445impl TableProvider for GenerateSeriesTable {
446    fn as_any(&self) -> &dyn Any {
447        self
448    }
449
450    fn schema(&self) -> SchemaRef {
451        Arc::clone(&self.schema)
452    }
453
454    fn table_type(&self) -> TableType {
455        TableType::Base
456    }
457
458    async fn scan(
459        &self,
460        state: &dyn Session,
461        projection: Option<&Vec<usize>>,
462        _filters: &[Expr],
463        _limit: Option<usize>,
464    ) -> Result<Arc<dyn ExecutionPlan>> {
465        let batch_size = state.config_options().execution.batch_size;
466        let generator = self.as_generator(batch_size)?;
467
468        Ok(Arc::new(
469            LazyMemoryExec::try_new(self.schema(), vec![generator])?
470                .with_projection(projection.cloned()),
471        ))
472    }
473}
474
/// Shared implementation behind the `generate_series` and `range` table functions.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and generator display.
    name: &'static str,
    /// `true` for `generate_series` (end included), `false` for `range`.
    include_end: bool,
}
480
481impl TableFunctionImpl for GenerateSeriesFuncImpl {
482    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
483        if exprs.is_empty() || exprs.len() > 3 {
484            return plan_err!("{} function requires 1 to 3 arguments", self.name);
485        }
486
487        // Determine the data type from the first argument
488        match &exprs[0] {
489            Expr::Literal(
490                // Default to int64 for null
491                ScalarValue::Null | ScalarValue::Int64(_),
492                _,
493            ) => self.call_int64(exprs),
494            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
495                self.call_timestamp(exprs)
496            }
497            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
498                self.call_date(exprs)
499            }
500            Expr::Literal(scalar, _) => {
501                plan_err!(
502                    "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
503                    scalar.data_type()
504                )
505            }
506            _ => plan_err!("Arguments must be literals"),
507        }
508    }
509}
510
511impl GenerateSeriesFuncImpl {
512    fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
513        let mut normalize_args = Vec::new();
514        for (expr_index, expr) in exprs.iter().enumerate() {
515            match expr {
516                Expr::Literal(ScalarValue::Null, _) => {}
517                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
518                other => {
519                    return plan_err!(
520                        "Argument #{} must be an INTEGER or NULL, got {:?}",
521                        expr_index + 1,
522                        other
523                    );
524                }
525            };
526        }
527
528        let schema = Arc::new(Schema::new(vec![Field::new(
529            "value",
530            DataType::Int64,
531            false,
532        )]));
533
534        if normalize_args.len() != exprs.len() {
535            // contain null
536            return Ok(Arc::new(GenerateSeriesTable {
537                schema,
538                args: GenSeriesArgs::ContainsNull { name: self.name },
539            }));
540        }
541
542        let (start, end, step) = match &normalize_args[..] {
543            [end] => (0, *end, 1),
544            [start, end] => (*start, *end, 1),
545            [start, end, step] => (*start, *end, *step),
546            _ => {
547                return plan_err!("{} function requires 1 to 3 arguments", self.name);
548            }
549        };
550
551        if step == 0 {
552            return plan_err!("Step cannot be zero");
553        }
554
555        Ok(Arc::new(GenerateSeriesTable {
556            schema,
557            args: GenSeriesArgs::Int64Args {
558                start,
559                end,
560                step,
561                include_end: self.include_end,
562                name: self.name,
563            },
564        }))
565    }
566
567    fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
568        if exprs.len() != 3 {
569            return plan_err!(
570                "{} function with timestamps requires exactly 3 arguments",
571                self.name
572            );
573        }
574
575        // Parse start timestamp
576        let (start_ts, tz) = match &exprs[0] {
577            Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
578                (*ts, tz.clone())
579            }
580            other => {
581                return plan_err!(
582                    "First argument must be a timestamp or NULL, got {:?}",
583                    other
584                );
585            }
586        };
587
588        // Parse end timestamp
589        let end_ts = match &exprs[1] {
590            Expr::Literal(ScalarValue::Null, _) => None,
591            Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
592            other => {
593                return plan_err!(
594                    "Second argument must be a timestamp or NULL, got {:?}",
595                    other
596                );
597            }
598        };
599
600        // Parse step interval
601        let step_interval = match &exprs[2] {
602            Expr::Literal(ScalarValue::Null, _) => None,
603            Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
604            other => {
605                return plan_err!(
606                    "Third argument must be an interval or NULL, got {:?}",
607                    other
608                );
609            }
610        };
611
612        let schema = Arc::new(Schema::new(vec![Field::new(
613            "value",
614            DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
615            false,
616        )]));
617
618        // Check if any argument is null
619        let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
620        else {
621            return Ok(Arc::new(GenerateSeriesTable {
622                schema,
623                args: GenSeriesArgs::ContainsNull { name: self.name },
624            }));
625        };
626
627        // Validate step interval
628        validate_interval_step(step)?;
629
630        Ok(Arc::new(GenerateSeriesTable {
631            schema,
632            args: GenSeriesArgs::TimestampArgs {
633                start,
634                end,
635                step,
636                tz,
637                include_end: self.include_end,
638                name: self.name,
639            },
640        }))
641    }
642
643    fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
644        if exprs.len() != 3 {
645            return plan_err!(
646                "{} function with dates requires exactly 3 arguments",
647                self.name
648            );
649        }
650
651        let schema = Arc::new(Schema::new(vec![Field::new(
652            "value",
653            DataType::Timestamp(TimeUnit::Nanosecond, None),
654            false,
655        )]));
656
657        // Parse start date
658        let start_date = match &exprs[0] {
659            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
660            Expr::Literal(ScalarValue::Date32(None), _)
661            | Expr::Literal(ScalarValue::Null, _) => {
662                return Ok(Arc::new(GenerateSeriesTable {
663                    schema,
664                    args: GenSeriesArgs::ContainsNull { name: self.name },
665                }));
666            }
667            other => {
668                return plan_err!(
669                    "First argument must be a date or NULL, got {:?}",
670                    other
671                );
672            }
673        };
674
675        // Parse end date
676        let end_date = match &exprs[1] {
677            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
678            Expr::Literal(ScalarValue::Date32(None), _)
679            | Expr::Literal(ScalarValue::Null, _) => {
680                return Ok(Arc::new(GenerateSeriesTable {
681                    schema,
682                    args: GenSeriesArgs::ContainsNull { name: self.name },
683                }));
684            }
685            other => {
686                return plan_err!(
687                    "Second argument must be a date or NULL, got {:?}",
688                    other
689                );
690            }
691        };
692
693        // Parse step interval
694        let step_interval = match &exprs[2] {
695            Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
696                *interval
697            }
698            Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
699            | Expr::Literal(ScalarValue::Null, _) => {
700                return Ok(Arc::new(GenerateSeriesTable {
701                    schema,
702                    args: GenSeriesArgs::ContainsNull { name: self.name },
703                }));
704            }
705            other => {
706                return plan_err!(
707                    "Third argument must be an interval or NULL, got {:?}",
708                    other
709                );
710            }
711        };
712
713        // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch)
714        // Date32 is days since 1970-01-01, so multiply by nanoseconds per day
715        const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
716
717        let start_ts = start_date as i64 * NANOS_PER_DAY;
718        let end_ts = end_date as i64 * NANOS_PER_DAY;
719
720        // Validate step interval
721        validate_interval_step(step_interval)?;
722
723        Ok(Arc::new(GenerateSeriesTable {
724            schema,
725            args: GenSeriesArgs::DateArgs {
726                start: start_ts,
727                end: end_ts,
728                step: step_interval,
729                include_end: self.include_end,
730                name: self.name,
731            },
732        }))
733    }
734}
735
736#[derive(Debug)]
737pub struct GenerateSeriesFunc {}
738
739impl TableFunctionImpl for GenerateSeriesFunc {
740    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
741        let impl_func = GenerateSeriesFuncImpl {
742            name: "generate_series",
743            include_end: true,
744        };
745        impl_func.call(exprs)
746    }
747}
748
749#[derive(Debug)]
750pub struct RangeFunc {}
751
752impl TableFunctionImpl for RangeFunc {
753    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
754        let impl_func = GenerateSeriesFuncImpl {
755            name: "range",
756            include_end: false,
757        };
758        impl_func.call(exprs)
759    }
760}
761
#[cfg(test)]
mod generate_series_tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_physical_plan::memory::LazyBatchGenerator;

    use crate::generate_series::{Empty, GenericSeriesState};

    /// Single non-nullable INT64 column shared by the helpers below.
    fn int64_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
    }

    /// Fresh integer series state positioned at `start`.
    fn int64_state(
        start: i64,
        end: i64,
        step: i64,
        batch_size: usize,
        include_end: bool,
    ) -> GenericSeriesState<i64> {
        GenericSeriesState {
            schema: int64_schema(),
            start,
            end,
            step,
            current: start,
            batch_size,
            include_end,
            name: "test",
        }
    }

    /// Expected batch holding `values` in the shared schema.
    fn int64_batch(values: Vec<i64>) -> Result<RecordBatch> {
        let column: ArrayRef = Arc::new(Int64Array::from(values));
        Ok(RecordBatch::try_new(int64_schema(), vec![column])?)
    }

    #[test]
    fn test_generic_series_state_reset() -> Result<()> {
        let mut state = int64_state(1, 5, 1, 8192, true);
        let batch = state.generate_next_batch()?.expect("missing batch");

        let state_reset = state.reset_state();
        let reset_batch = state_reset
            .write()
            .generate_next_batch()?
            .expect("missing reset batch");

        assert_eq!(batch, reset_batch);

        Ok(())
    }

    #[test]
    fn test_generic_series_state_includes_end() -> Result<()> {
        let mut state = int64_state(1, 3, 1, 8192, true);
        assert_eq!(state.generate_next_batch()?, Some(int64_batch(vec![1, 2, 3])?));
        assert_eq!(state.generate_next_batch()?, None);
        Ok(())
    }

    #[test]
    fn test_generic_series_state_batches_descending_exclusive() -> Result<()> {
        // range(5, 0, -2) with two rows per batch: [5, 3], then [1], then done.
        let mut state = int64_state(5, 0, -2, 2, false);
        assert_eq!(state.generate_next_batch()?, Some(int64_batch(vec![5, 3])?));
        assert_eq!(state.generate_next_batch()?, Some(int64_batch(vec![1])?));
        assert_eq!(state.generate_next_batch()?, None);
        Ok(())
    }

    #[test]
    fn test_empty_generator_yields_nothing() -> Result<()> {
        let mut empty = Empty { name: "test" };
        assert_eq!(empty.generate_next_batch()?, None);
        assert_eq!(empty.reset_state().write().generate_next_batch()?, None);
        Ok(())
    }
}