datafusion_functions_table/
generate_series.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22    DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{plan_err, Result, ScalarValue};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
32use datafusion_physical_plan::ExecutionPlan;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::fmt;
36use std::str::FromStr;
37use std::sync::Arc;
38
/// A batch generator that never yields any rows.
///
/// Used when at least one of the series arguments is NULL, in which case the
/// table function produces an empty result.
#[derive(Clone, Debug)]
pub struct Empty {
    /// Name of the table function this generator was created for; returned by
    /// [`Empty::name`] and printed by the `Display` implementation.
    name: &'static str,
}
44
45impl Empty {
46    pub fn name(&self) -> &'static str {
47        self.name
48    }
49}
50
51impl LazyBatchGenerator for Empty {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
57        Ok(None)
58    }
59}
60
61impl fmt::Display for Empty {
62    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
63        write!(f, "{}: empty", self.name)
64    }
65}
66
67/// Trait for values that can be generated in a series
68pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
69    type StepType: fmt::Debug + Clone + Send + Sync;
70    type ValueType: fmt::Debug + Clone + Send + Sync;
71
72    /// Check if we've reached the end of the series
73    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
74
75    /// Advance to the next value in the series
76    fn advance(&mut self, step: &Self::StepType) -> Result<()>;
77
78    /// Create an Arrow array from a vector of values
79    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
80
81    /// Convert self to ValueType for array creation
82    fn to_value_type(&self) -> Self::ValueType;
83
84    /// Display the value for debugging
85    fn display_value(&self) -> String;
86}
87
88impl SeriesValue for i64 {
89    type StepType = i64;
90    type ValueType = i64;
91
92    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
93        reach_end_int64(*self, end, *step, include_end)
94    }
95
96    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
97        *self += step;
98        Ok(())
99    }
100
101    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
102        Ok(Arc::new(Int64Array::from(values)))
103    }
104
105    fn to_value_type(&self) -> Self::ValueType {
106        *self
107    }
108
109    fn display_value(&self) -> String {
110        self.to_string()
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct TimestampValue {
116    value: i64,
117    parsed_tz: Option<Tz>,
118    tz_str: Option<Arc<str>>,
119}
120
121impl TimestampValue {
122    pub fn value(&self) -> i64 {
123        self.value
124    }
125
126    pub fn tz_str(&self) -> Option<&Arc<str>> {
127        self.tz_str.as_ref()
128    }
129}
130
131impl SeriesValue for TimestampValue {
132    type StepType = IntervalMonthDayNano;
133    type ValueType = i64;
134
135    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
136        let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
137
138        if include_end {
139            if step_negative {
140                self.value < end.value
141            } else {
142                self.value > end.value
143            }
144        } else if step_negative {
145            self.value <= end.value
146        } else {
147            self.value >= end.value
148        }
149    }
150
151    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
152        let tz = self
153            .parsed_tz
154            .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
155        let Some(next_ts) =
156            TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
157        else {
158            return plan_err!(
159                "Failed to add interval {:?} to timestamp {}",
160                step,
161                self.value
162            );
163        };
164        self.value = next_ts;
165        Ok(())
166    }
167
168    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
169        let array = TimestampNanosecondArray::from(values);
170
171        // Use timezone from self (now we have access to tz through &self)
172        let array = match self.tz_str.as_ref() {
173            Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
174            None => array,
175        };
176
177        Ok(Arc::new(array))
178    }
179
180    fn to_value_type(&self) -> Self::ValueType {
181        self.value
182    }
183
184    fn display_value(&self) -> String {
185        self.value.to_string()
186    }
187}
188
189/// Indicates the arguments used for generating a series.
190#[derive(Debug, Clone)]
191pub enum GenSeriesArgs {
192    /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated.
193    ContainsNull { name: &'static str },
194    /// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null.
195    Int64Args {
196        start: i64,
197        end: i64,
198        step: i64,
199        /// Indicates whether the end value should be included in the series.
200        include_end: bool,
201        name: &'static str,
202    },
203    /// TimestampArgs holds the start, end, and step values for generating timestamp series when all arguments are not null.
204    TimestampArgs {
205        start: i64,
206        end: i64,
207        step: IntervalMonthDayNano,
208        tz: Option<Arc<str>>,
209        /// Indicates whether the end value should be included in the series.
210        include_end: bool,
211        name: &'static str,
212    },
213    /// DateArgs holds the start, end, and step values for generating date series when all arguments are not null.
214    /// Internally, dates are converted to timestamps and use the timestamp logic.
215    DateArgs {
216        start: i64,
217        end: i64,
218        step: IntervalMonthDayNano,
219        /// Indicates whether the end value should be included in the series.
220        include_end: bool,
221        name: &'static str,
222    },
223}
224
225/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step
226#[derive(Debug, Clone)]
227pub struct GenerateSeriesTable {
228    schema: SchemaRef,
229    args: GenSeriesArgs,
230}
231
232impl GenerateSeriesTable {
233    pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
234        Self { schema, args }
235    }
236
237    pub fn as_generator(
238        &self,
239        batch_size: usize,
240    ) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
241        let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
242            GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
243            GenSeriesArgs::Int64Args {
244                start,
245                end,
246                step,
247                include_end,
248                name,
249            } => Arc::new(RwLock::new(GenericSeriesState {
250                schema: self.schema(),
251                start: *start,
252                end: *end,
253                step: *step,
254                current: *start,
255                batch_size,
256                include_end: *include_end,
257                name,
258            })),
259            GenSeriesArgs::TimestampArgs {
260                start,
261                end,
262                step,
263                tz,
264                include_end,
265                name,
266            } => {
267                let parsed_tz = tz
268                    .as_ref()
269                    .map(|s| Tz::from_str(s.as_ref()))
270                    .transpose()
271                    .map_err(|e| {
272                        datafusion_common::DataFusionError::Internal(format!(
273                            "Failed to parse timezone: {e}"
274                        ))
275                    })?
276                    .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
277                Arc::new(RwLock::new(GenericSeriesState {
278                    schema: self.schema(),
279                    start: TimestampValue {
280                        value: *start,
281                        parsed_tz: Some(parsed_tz),
282                        tz_str: tz.clone(),
283                    },
284                    end: TimestampValue {
285                        value: *end,
286                        parsed_tz: Some(parsed_tz),
287                        tz_str: tz.clone(),
288                    },
289                    step: *step,
290                    current: TimestampValue {
291                        value: *start,
292                        parsed_tz: Some(parsed_tz),
293                        tz_str: tz.clone(),
294                    },
295                    batch_size,
296                    include_end: *include_end,
297                    name,
298                }))
299            }
300            GenSeriesArgs::DateArgs {
301                start,
302                end,
303                step,
304                include_end,
305                name,
306            } => Arc::new(RwLock::new(GenericSeriesState {
307                schema: self.schema(),
308                start: TimestampValue {
309                    value: *start,
310                    parsed_tz: None,
311                    tz_str: None,
312                },
313                end: TimestampValue {
314                    value: *end,
315                    parsed_tz: None,
316                    tz_str: None,
317                },
318                step: *step,
319                current: TimestampValue {
320                    value: *start,
321                    parsed_tz: None,
322                    tz_str: None,
323                },
324                batch_size,
325                include_end: *include_end,
326                name,
327            })),
328        };
329
330        Ok(generator)
331    }
332}
333
334#[derive(Debug, Clone)]
335pub struct GenericSeriesState<T: SeriesValue> {
336    schema: SchemaRef,
337    start: T,
338    end: T,
339    step: T::StepType,
340    batch_size: usize,
341    current: T,
342    include_end: bool,
343    name: &'static str,
344}
345
346impl<T: SeriesValue> GenericSeriesState<T> {
347    pub fn name(&self) -> &'static str {
348        self.name
349    }
350
351    pub fn batch_size(&self) -> usize {
352        self.batch_size
353    }
354
355    pub fn include_end(&self) -> bool {
356        self.include_end
357    }
358
359    pub fn start(&self) -> &T {
360        &self.start
361    }
362
363    pub fn end(&self) -> &T {
364        &self.end
365    }
366
367    pub fn step(&self) -> &T::StepType {
368        &self.step
369    }
370
371    pub fn current(&self) -> &T {
372        &self.current
373    }
374}
375
376impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
377    fn as_any(&self) -> &dyn Any {
378        self
379    }
380
381    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
382        let mut buf = Vec::with_capacity(self.batch_size);
383
384        while buf.len() < self.batch_size
385            && !self
386                .current
387                .should_stop(self.end.clone(), &self.step, self.include_end)
388        {
389            buf.push(self.current.to_value_type());
390            self.current.advance(&self.step)?;
391        }
392
393        if buf.is_empty() {
394            return Ok(None);
395        }
396
397        let array = self.current.create_array(buf)?;
398        let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
399        Ok(Some(batch))
400    }
401}
402
403impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
404    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
405        write!(
406            f,
407            "{}: start={}, end={}, batch_size={}",
408            self.name,
409            self.start.display_value(),
410            self.end.display_value(),
411            self.batch_size
412        )
413    }
414}
415
/// Returns `true` when `val` lies beyond the end of an integer series.
///
/// A strictly positive `step` means an ascending series; any other step is
/// treated as descending. With `include_end` the bound `end` itself still
/// belongs to the series, otherwise the series stops just before it.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    let ascending = step > 0;
    match (ascending, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
429
430fn validate_interval_step(
431    step: IntervalMonthDayNano,
432    start: i64,
433    end: i64,
434) -> Result<()> {
435    if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
436        return plan_err!("Step interval cannot be zero");
437    }
438
439    let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
440    let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
441
442    if start > end && step_is_positive {
443        return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
444    }
445
446    if start < end && step_is_negative {
447        return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
448    }
449
450    Ok(())
451}
452
453#[async_trait]
454impl TableProvider for GenerateSeriesTable {
455    fn as_any(&self) -> &dyn Any {
456        self
457    }
458
459    fn schema(&self) -> SchemaRef {
460        Arc::clone(&self.schema)
461    }
462
463    fn table_type(&self) -> TableType {
464        TableType::Base
465    }
466
467    async fn scan(
468        &self,
469        state: &dyn Session,
470        projection: Option<&Vec<usize>>,
471        _filters: &[Expr],
472        _limit: Option<usize>,
473    ) -> Result<Arc<dyn ExecutionPlan>> {
474        let batch_size = state.config_options().execution.batch_size;
475        let schema = match projection {
476            Some(projection) => Arc::new(self.schema.project(projection)?),
477            None => self.schema(),
478        };
479
480        let generator = self.as_generator(batch_size)?;
481
482        Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?))
483    }
484}
485
/// Shared implementation behind the `generate_series` and `range` table functions.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and passed on to the generator.
    name: &'static str,
    /// Whether the end value itself is part of the produced series.
    include_end: bool,
}
491
492impl TableFunctionImpl for GenerateSeriesFuncImpl {
493    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
494        if exprs.is_empty() || exprs.len() > 3 {
495            return plan_err!("{} function requires 1 to 3 arguments", self.name);
496        }
497
498        // Determine the data type from the first argument
499        match &exprs[0] {
500            Expr::Literal(
501                // Default to int64 for null
502                ScalarValue::Null | ScalarValue::Int64(_),
503                _,
504            ) => self.call_int64(exprs),
505            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
506                self.call_timestamp(exprs)
507            }
508            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
509                self.call_date(exprs)
510            }
511            Expr::Literal(scalar, _) => {
512                plan_err!(
513                    "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
514                    scalar.data_type()
515                )
516            }
517            _ => plan_err!("Arguments must be literals"),
518        }
519    }
520}
521
522impl GenerateSeriesFuncImpl {
523    fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
524        let mut normalize_args = Vec::new();
525        for (expr_index, expr) in exprs.iter().enumerate() {
526            match expr {
527                Expr::Literal(ScalarValue::Null, _) => {}
528                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
529                other => {
530                    return plan_err!(
531                        "Argument #{} must be an INTEGER or NULL, got {:?}",
532                        expr_index + 1,
533                        other
534                    )
535                }
536            };
537        }
538
539        let schema = Arc::new(Schema::new(vec![Field::new(
540            "value",
541            DataType::Int64,
542            false,
543        )]));
544
545        if normalize_args.len() != exprs.len() {
546            // contain null
547            return Ok(Arc::new(GenerateSeriesTable {
548                schema,
549                args: GenSeriesArgs::ContainsNull { name: self.name },
550            }));
551        }
552
553        let (start, end, step) = match &normalize_args[..] {
554            [end] => (0, *end, 1),
555            [start, end] => (*start, *end, 1),
556            [start, end, step] => (*start, *end, *step),
557            _ => {
558                return plan_err!("{} function requires 1 to 3 arguments", self.name);
559            }
560        };
561
562        if start > end && step > 0 {
563            return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
564        }
565
566        if start < end && step < 0 {
567            return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
568        }
569
570        if step == 0 {
571            return plan_err!("Step cannot be zero");
572        }
573
574        Ok(Arc::new(GenerateSeriesTable {
575            schema,
576            args: GenSeriesArgs::Int64Args {
577                start,
578                end,
579                step,
580                include_end: self.include_end,
581                name: self.name,
582            },
583        }))
584    }
585
586    fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
587        if exprs.len() != 3 {
588            return plan_err!(
589                "{} function with timestamps requires exactly 3 arguments",
590                self.name
591            );
592        }
593
594        // Parse start timestamp
595        let (start_ts, tz) = match &exprs[0] {
596            Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
597                (*ts, tz.clone())
598            }
599            other => {
600                return plan_err!(
601                    "First argument must be a timestamp or NULL, got {:?}",
602                    other
603                )
604            }
605        };
606
607        // Parse end timestamp
608        let end_ts = match &exprs[1] {
609            Expr::Literal(ScalarValue::Null, _) => None,
610            Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
611            other => {
612                return plan_err!(
613                    "Second argument must be a timestamp or NULL, got {:?}",
614                    other
615                )
616            }
617        };
618
619        // Parse step interval
620        let step_interval = match &exprs[2] {
621            Expr::Literal(ScalarValue::Null, _) => None,
622            Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
623            other => {
624                return plan_err!(
625                    "Third argument must be an interval or NULL, got {:?}",
626                    other
627                )
628            }
629        };
630
631        let schema = Arc::new(Schema::new(vec![Field::new(
632            "value",
633            DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
634            false,
635        )]));
636
637        // Check if any argument is null
638        let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
639        else {
640            return Ok(Arc::new(GenerateSeriesTable {
641                schema,
642                args: GenSeriesArgs::ContainsNull { name: self.name },
643            }));
644        };
645
646        // Validate step interval
647        validate_interval_step(step, start, end)?;
648
649        Ok(Arc::new(GenerateSeriesTable {
650            schema,
651            args: GenSeriesArgs::TimestampArgs {
652                start,
653                end,
654                step,
655                tz,
656                include_end: self.include_end,
657                name: self.name,
658            },
659        }))
660    }
661
662    fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
663        if exprs.len() != 3 {
664            return plan_err!(
665                "{} function with dates requires exactly 3 arguments",
666                self.name
667            );
668        }
669
670        let schema = Arc::new(Schema::new(vec![Field::new(
671            "value",
672            DataType::Timestamp(TimeUnit::Nanosecond, None),
673            false,
674        )]));
675
676        // Parse start date
677        let start_date = match &exprs[0] {
678            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
679            Expr::Literal(ScalarValue::Date32(None), _)
680            | Expr::Literal(ScalarValue::Null, _) => {
681                return Ok(Arc::new(GenerateSeriesTable {
682                    schema,
683                    args: GenSeriesArgs::ContainsNull { name: self.name },
684                }));
685            }
686            other => {
687                return plan_err!(
688                    "First argument must be a date or NULL, got {:?}",
689                    other
690                )
691            }
692        };
693
694        // Parse end date
695        let end_date = match &exprs[1] {
696            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
697            Expr::Literal(ScalarValue::Date32(None), _)
698            | Expr::Literal(ScalarValue::Null, _) => {
699                return Ok(Arc::new(GenerateSeriesTable {
700                    schema,
701                    args: GenSeriesArgs::ContainsNull { name: self.name },
702                }));
703            }
704            other => {
705                return plan_err!(
706                    "Second argument must be a date or NULL, got {:?}",
707                    other
708                )
709            }
710        };
711
712        // Parse step interval
713        let step_interval = match &exprs[2] {
714            Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
715                *interval
716            }
717            Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
718            | Expr::Literal(ScalarValue::Null, _) => {
719                return Ok(Arc::new(GenerateSeriesTable {
720                    schema,
721                    args: GenSeriesArgs::ContainsNull { name: self.name },
722                }));
723            }
724            other => {
725                return plan_err!(
726                    "Third argument must be an interval or NULL, got {:?}",
727                    other
728                )
729            }
730        };
731
732        // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch)
733        // Date32 is days since 1970-01-01, so multiply by nanoseconds per day
734        const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
735
736        let start_ts = start_date as i64 * NANOS_PER_DAY;
737        let end_ts = end_date as i64 * NANOS_PER_DAY;
738
739        // Validate step interval
740        validate_interval_step(step_interval, start_ts, end_ts)?;
741
742        Ok(Arc::new(GenerateSeriesTable {
743            schema,
744            args: GenSeriesArgs::DateArgs {
745                start: start_ts,
746                end: end_ts,
747                step: step_interval,
748                include_end: self.include_end,
749                name: self.name,
750            },
751        }))
752    }
753}
754
/// The `generate_series` table function: produces a series whose end value is
/// included.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
757
758impl TableFunctionImpl for GenerateSeriesFunc {
759    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
760        let impl_func = GenerateSeriesFuncImpl {
761            name: "generate_series",
762            include_end: true,
763        };
764        impl_func.call(exprs)
765    }
766}
767
/// The `range` table function: produces a series whose end value is excluded.
#[derive(Debug)]
pub struct RangeFunc {}
770
771impl TableFunctionImpl for RangeFunc {
772    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
773        let impl_func = GenerateSeriesFuncImpl {
774            name: "range",
775            include_end: false,
776        };
777        impl_func.call(exprs)
778    }
779}