datafusion_functions_table/
generate_series.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22    DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{plan_err, Result, ScalarValue};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
32use datafusion_physical_plan::ExecutionPlan;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::fmt;
36use std::str::FromStr;
37use std::sync::Arc;
38
39/// Empty generator that produces no rows - used when series arguments contain null values
40#[derive(Debug, Clone)]
41pub struct Empty {
42    name: &'static str,
43}
44
45impl Empty {
46    pub fn name(&self) -> &'static str {
47        self.name
48    }
49}
50
51impl LazyBatchGenerator for Empty {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
57        Ok(None)
58    }
59}
60
61impl fmt::Display for Empty {
62    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
63        write!(f, "{}: empty", self.name)
64    }
65}
66
67/// Trait for values that can be generated in a series
68pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
69    type StepType: fmt::Debug + Clone + Send + Sync;
70    type ValueType: fmt::Debug + Clone + Send + Sync;
71
72    /// Check if we've reached the end of the series
73    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
74
75    /// Advance to the next value in the series
76    fn advance(&mut self, step: &Self::StepType) -> Result<()>;
77
78    /// Create an Arrow array from a vector of values
79    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
80
81    /// Convert self to ValueType for array creation
82    fn to_value_type(&self) -> Self::ValueType;
83
84    /// Display the value for debugging
85    fn display_value(&self) -> String;
86}
87
88impl SeriesValue for i64 {
89    type StepType = i64;
90    type ValueType = i64;
91
92    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
93        reach_end_int64(*self, end, *step, include_end)
94    }
95
96    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
97        *self += step;
98        Ok(())
99    }
100
101    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
102        Ok(Arc::new(Int64Array::from(values)))
103    }
104
105    fn to_value_type(&self) -> Self::ValueType {
106        *self
107    }
108
109    fn display_value(&self) -> String {
110        self.to_string()
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct TimestampValue {
116    value: i64,
117    parsed_tz: Option<Tz>,
118    tz_str: Option<Arc<str>>,
119}
120
121impl TimestampValue {
122    pub fn value(&self) -> i64 {
123        self.value
124    }
125
126    pub fn tz_str(&self) -> Option<&Arc<str>> {
127        self.tz_str.as_ref()
128    }
129}
130
131impl SeriesValue for TimestampValue {
132    type StepType = IntervalMonthDayNano;
133    type ValueType = i64;
134
135    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
136        let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
137
138        if include_end {
139            if step_negative {
140                self.value < end.value
141            } else {
142                self.value > end.value
143            }
144        } else if step_negative {
145            self.value <= end.value
146        } else {
147            self.value >= end.value
148        }
149    }
150
151    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
152        let tz = self
153            .parsed_tz
154            .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
155        let Some(next_ts) =
156            TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
157        else {
158            return plan_err!(
159                "Failed to add interval {:?} to timestamp {}",
160                step,
161                self.value
162            );
163        };
164        self.value = next_ts;
165        Ok(())
166    }
167
168    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
169        let array = TimestampNanosecondArray::from(values);
170
171        // Use timezone from self (now we have access to tz through &self)
172        let array = match self.tz_str.as_ref() {
173            Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
174            None => array,
175        };
176
177        Ok(Arc::new(array))
178    }
179
180    fn to_value_type(&self) -> Self::ValueType {
181        self.value
182    }
183
184    fn display_value(&self) -> String {
185        self.value.to_string()
186    }
187}
188
/// Indicates the arguments used for generating a series.
///
/// Every variant carries the `name` of the table function (e.g. `generate_series`
/// or `range`) that produced it.
#[derive(Debug, Clone)]
pub enum GenSeriesArgs {
    /// At least one argument (start, end, or step) was NULL, so no series will be generated.
    ContainsNull { name: &'static str },
    /// Start, end, and step values for an integer series when no argument is NULL.
    Int64Args {
        /// First value of the series.
        start: i64,
        /// Bound of the series; whether it is emitted depends on `include_end`.
        end: i64,
        /// Increment added to move from one value to the next.
        step: i64,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        /// Name of the table function, used when displaying the generator.
        name: &'static str,
    },
    /// Start, end, and step values for a timestamp series when no argument is NULL.
    TimestampArgs {
        /// First timestamp of the series, in nanoseconds.
        start: i64,
        /// Bound of the series, in nanoseconds.
        end: i64,
        /// Interval added to move from one timestamp to the next.
        step: IntervalMonthDayNano,
        /// Timezone string taken from the start argument; also used for the output type.
        tz: Option<Arc<str>>,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        /// Name of the table function, used when displaying the generator.
        name: &'static str,
    },
    /// Start, end, and step values for a date series when no argument is NULL.
    /// Internally, dates are converted to nanosecond timestamps and use the timestamp logic.
    DateArgs {
        /// First date of the series, already converted to nanoseconds.
        start: i64,
        /// Bound of the series, already converted to nanoseconds.
        end: i64,
        /// Interval added to move from one value to the next.
        step: IntervalMonthDayNano,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        /// Name of the table function, used when displaying the generator.
        name: &'static str,
    },
}
224
225/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step
226#[derive(Debug, Clone)]
227pub struct GenerateSeriesTable {
228    schema: SchemaRef,
229    args: GenSeriesArgs,
230}
231
232impl GenerateSeriesTable {
233    pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
234        Self { schema, args }
235    }
236
237    pub fn as_generator(
238        &self,
239        batch_size: usize,
240    ) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
241        let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
242            GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
243            GenSeriesArgs::Int64Args {
244                start,
245                end,
246                step,
247                include_end,
248                name,
249            } => Arc::new(RwLock::new(GenericSeriesState {
250                schema: self.schema(),
251                start: *start,
252                end: *end,
253                step: *step,
254                current: *start,
255                batch_size,
256                include_end: *include_end,
257                name,
258            })),
259            GenSeriesArgs::TimestampArgs {
260                start,
261                end,
262                step,
263                tz,
264                include_end,
265                name,
266            } => {
267                let parsed_tz = tz
268                    .as_ref()
269                    .map(|s| Tz::from_str(s.as_ref()))
270                    .transpose()
271                    .map_err(|e| {
272                        datafusion_common::internal_datafusion_err!(
273                            "Failed to parse timezone: {e}"
274                        )
275                    })?
276                    .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
277                Arc::new(RwLock::new(GenericSeriesState {
278                    schema: self.schema(),
279                    start: TimestampValue {
280                        value: *start,
281                        parsed_tz: Some(parsed_tz),
282                        tz_str: tz.clone(),
283                    },
284                    end: TimestampValue {
285                        value: *end,
286                        parsed_tz: Some(parsed_tz),
287                        tz_str: tz.clone(),
288                    },
289                    step: *step,
290                    current: TimestampValue {
291                        value: *start,
292                        parsed_tz: Some(parsed_tz),
293                        tz_str: tz.clone(),
294                    },
295                    batch_size,
296                    include_end: *include_end,
297                    name,
298                }))
299            }
300            GenSeriesArgs::DateArgs {
301                start,
302                end,
303                step,
304                include_end,
305                name,
306            } => Arc::new(RwLock::new(GenericSeriesState {
307                schema: self.schema(),
308                start: TimestampValue {
309                    value: *start,
310                    parsed_tz: None,
311                    tz_str: None,
312                },
313                end: TimestampValue {
314                    value: *end,
315                    parsed_tz: None,
316                    tz_str: None,
317                },
318                step: *step,
319                current: TimestampValue {
320                    value: *start,
321                    parsed_tz: None,
322                    tz_str: None,
323                },
324                batch_size,
325                include_end: *include_end,
326                name,
327            })),
328        };
329
330        Ok(generator)
331    }
332}
333
334#[derive(Debug, Clone)]
335pub struct GenericSeriesState<T: SeriesValue> {
336    schema: SchemaRef,
337    start: T,
338    end: T,
339    step: T::StepType,
340    batch_size: usize,
341    current: T,
342    include_end: bool,
343    name: &'static str,
344}
345
346impl<T: SeriesValue> GenericSeriesState<T> {
347    pub fn name(&self) -> &'static str {
348        self.name
349    }
350
351    pub fn batch_size(&self) -> usize {
352        self.batch_size
353    }
354
355    pub fn include_end(&self) -> bool {
356        self.include_end
357    }
358
359    pub fn start(&self) -> &T {
360        &self.start
361    }
362
363    pub fn end(&self) -> &T {
364        &self.end
365    }
366
367    pub fn step(&self) -> &T::StepType {
368        &self.step
369    }
370
371    pub fn current(&self) -> &T {
372        &self.current
373    }
374}
375
376impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
377    fn as_any(&self) -> &dyn Any {
378        self
379    }
380
381    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
382        let mut buf = Vec::with_capacity(self.batch_size);
383
384        while buf.len() < self.batch_size
385            && !self
386                .current
387                .should_stop(self.end.clone(), &self.step, self.include_end)
388        {
389            buf.push(self.current.to_value_type());
390            self.current.advance(&self.step)?;
391        }
392
393        if buf.is_empty() {
394            return Ok(None);
395        }
396
397        let array = self.current.create_array(buf)?;
398        let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
399        Ok(Some(batch))
400    }
401}
402
403impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
404    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
405        write!(
406            f,
407            "{}: start={}, end={}, batch_size={}",
408            self.name,
409            self.start.display_value(),
410            self.end.display_value(),
411            self.batch_size
412        )
413    }
414}
415
/// Returns `true` when `val` must not be emitted by an integer series ending at `end`.
///
/// A strictly positive `step` counts up and stops above `end`. Any other `step` counts
/// down and stops below `end`. When `include_end` is `false`, reaching `end` exactly
/// also stops the series.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    let ascending = step > 0;
    match (ascending, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
429
430fn validate_interval_step(
431    step: IntervalMonthDayNano,
432    start: i64,
433    end: i64,
434) -> Result<()> {
435    if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
436        return plan_err!("Step interval cannot be zero");
437    }
438
439    let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
440    let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
441
442    if start > end && step_is_positive {
443        return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
444    }
445
446    if start < end && step_is_negative {
447        return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
448    }
449
450    Ok(())
451}
452
453#[async_trait]
454impl TableProvider for GenerateSeriesTable {
455    fn as_any(&self) -> &dyn Any {
456        self
457    }
458
459    fn schema(&self) -> SchemaRef {
460        Arc::clone(&self.schema)
461    }
462
463    fn table_type(&self) -> TableType {
464        TableType::Base
465    }
466
467    async fn scan(
468        &self,
469        state: &dyn Session,
470        projection: Option<&Vec<usize>>,
471        _filters: &[Expr],
472        _limit: Option<usize>,
473    ) -> Result<Arc<dyn ExecutionPlan>> {
474        let batch_size = state.config_options().execution.batch_size;
475        let generator = self.as_generator(batch_size)?;
476
477        Ok(Arc::new(
478            LazyMemoryExec::try_new(self.schema(), vec![generator])?
479                .with_projection(projection.cloned()),
480        ))
481    }
482}
483
484#[derive(Debug)]
485struct GenerateSeriesFuncImpl {
486    name: &'static str,
487    include_end: bool,
488}
489
490impl TableFunctionImpl for GenerateSeriesFuncImpl {
491    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
492        if exprs.is_empty() || exprs.len() > 3 {
493            return plan_err!("{} function requires 1 to 3 arguments", self.name);
494        }
495
496        // Determine the data type from the first argument
497        match &exprs[0] {
498            Expr::Literal(
499                // Default to int64 for null
500                ScalarValue::Null | ScalarValue::Int64(_),
501                _,
502            ) => self.call_int64(exprs),
503            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
504                self.call_timestamp(exprs)
505            }
506            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
507                self.call_date(exprs)
508            }
509            Expr::Literal(scalar, _) => {
510                plan_err!(
511                    "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
512                    scalar.data_type()
513                )
514            }
515            _ => plan_err!("Arguments must be literals"),
516        }
517    }
518}
519
520impl GenerateSeriesFuncImpl {
521    fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
522        let mut normalize_args = Vec::new();
523        for (expr_index, expr) in exprs.iter().enumerate() {
524            match expr {
525                Expr::Literal(ScalarValue::Null, _) => {}
526                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
527                other => {
528                    return plan_err!(
529                        "Argument #{} must be an INTEGER or NULL, got {:?}",
530                        expr_index + 1,
531                        other
532                    )
533                }
534            };
535        }
536
537        let schema = Arc::new(Schema::new(vec![Field::new(
538            "value",
539            DataType::Int64,
540            false,
541        )]));
542
543        if normalize_args.len() != exprs.len() {
544            // contain null
545            return Ok(Arc::new(GenerateSeriesTable {
546                schema,
547                args: GenSeriesArgs::ContainsNull { name: self.name },
548            }));
549        }
550
551        let (start, end, step) = match &normalize_args[..] {
552            [end] => (0, *end, 1),
553            [start, end] => (*start, *end, 1),
554            [start, end, step] => (*start, *end, *step),
555            _ => {
556                return plan_err!("{} function requires 1 to 3 arguments", self.name);
557            }
558        };
559
560        if start > end && step > 0 {
561            return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
562        }
563
564        if start < end && step < 0 {
565            return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
566        }
567
568        if step == 0 {
569            return plan_err!("Step cannot be zero");
570        }
571
572        Ok(Arc::new(GenerateSeriesTable {
573            schema,
574            args: GenSeriesArgs::Int64Args {
575                start,
576                end,
577                step,
578                include_end: self.include_end,
579                name: self.name,
580            },
581        }))
582    }
583
584    fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
585        if exprs.len() != 3 {
586            return plan_err!(
587                "{} function with timestamps requires exactly 3 arguments",
588                self.name
589            );
590        }
591
592        // Parse start timestamp
593        let (start_ts, tz) = match &exprs[0] {
594            Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
595                (*ts, tz.clone())
596            }
597            other => {
598                return plan_err!(
599                    "First argument must be a timestamp or NULL, got {:?}",
600                    other
601                )
602            }
603        };
604
605        // Parse end timestamp
606        let end_ts = match &exprs[1] {
607            Expr::Literal(ScalarValue::Null, _) => None,
608            Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
609            other => {
610                return plan_err!(
611                    "Second argument must be a timestamp or NULL, got {:?}",
612                    other
613                )
614            }
615        };
616
617        // Parse step interval
618        let step_interval = match &exprs[2] {
619            Expr::Literal(ScalarValue::Null, _) => None,
620            Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
621            other => {
622                return plan_err!(
623                    "Third argument must be an interval or NULL, got {:?}",
624                    other
625                )
626            }
627        };
628
629        let schema = Arc::new(Schema::new(vec![Field::new(
630            "value",
631            DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
632            false,
633        )]));
634
635        // Check if any argument is null
636        let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
637        else {
638            return Ok(Arc::new(GenerateSeriesTable {
639                schema,
640                args: GenSeriesArgs::ContainsNull { name: self.name },
641            }));
642        };
643
644        // Validate step interval
645        validate_interval_step(step, start, end)?;
646
647        Ok(Arc::new(GenerateSeriesTable {
648            schema,
649            args: GenSeriesArgs::TimestampArgs {
650                start,
651                end,
652                step,
653                tz,
654                include_end: self.include_end,
655                name: self.name,
656            },
657        }))
658    }
659
660    fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
661        if exprs.len() != 3 {
662            return plan_err!(
663                "{} function with dates requires exactly 3 arguments",
664                self.name
665            );
666        }
667
668        let schema = Arc::new(Schema::new(vec![Field::new(
669            "value",
670            DataType::Timestamp(TimeUnit::Nanosecond, None),
671            false,
672        )]));
673
674        // Parse start date
675        let start_date = match &exprs[0] {
676            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
677            Expr::Literal(ScalarValue::Date32(None), _)
678            | Expr::Literal(ScalarValue::Null, _) => {
679                return Ok(Arc::new(GenerateSeriesTable {
680                    schema,
681                    args: GenSeriesArgs::ContainsNull { name: self.name },
682                }));
683            }
684            other => {
685                return plan_err!(
686                    "First argument must be a date or NULL, got {:?}",
687                    other
688                )
689            }
690        };
691
692        // Parse end date
693        let end_date = match &exprs[1] {
694            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
695            Expr::Literal(ScalarValue::Date32(None), _)
696            | Expr::Literal(ScalarValue::Null, _) => {
697                return Ok(Arc::new(GenerateSeriesTable {
698                    schema,
699                    args: GenSeriesArgs::ContainsNull { name: self.name },
700                }));
701            }
702            other => {
703                return plan_err!(
704                    "Second argument must be a date or NULL, got {:?}",
705                    other
706                )
707            }
708        };
709
710        // Parse step interval
711        let step_interval = match &exprs[2] {
712            Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
713                *interval
714            }
715            Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
716            | Expr::Literal(ScalarValue::Null, _) => {
717                return Ok(Arc::new(GenerateSeriesTable {
718                    schema,
719                    args: GenSeriesArgs::ContainsNull { name: self.name },
720                }));
721            }
722            other => {
723                return plan_err!(
724                    "Third argument must be an interval or NULL, got {:?}",
725                    other
726                )
727            }
728        };
729
730        // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch)
731        // Date32 is days since 1970-01-01, so multiply by nanoseconds per day
732        const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
733
734        let start_ts = start_date as i64 * NANOS_PER_DAY;
735        let end_ts = end_date as i64 * NANOS_PER_DAY;
736
737        // Validate step interval
738        validate_interval_step(step_interval, start_ts, end_ts)?;
739
740        Ok(Arc::new(GenerateSeriesTable {
741            schema,
742            args: GenSeriesArgs::DateArgs {
743                start: start_ts,
744                end: end_ts,
745                step: step_interval,
746                include_end: self.include_end,
747                name: self.name,
748            },
749        }))
750    }
751}
752
753#[derive(Debug)]
754pub struct GenerateSeriesFunc {}
755
756impl TableFunctionImpl for GenerateSeriesFunc {
757    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
758        let impl_func = GenerateSeriesFuncImpl {
759            name: "generate_series",
760            include_end: true,
761        };
762        impl_func.call(exprs)
763    }
764}
765
766#[derive(Debug)]
767pub struct RangeFunc {}
768
769impl TableFunctionImpl for RangeFunc {
770    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
771        let impl_func = GenerateSeriesFuncImpl {
772            name: "range",
773            include_end: false,
774        };
775        impl_func.call(exprs)
776    }
777}