datafusion_functions_table/
generate_series.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22    DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{plan_err, Result, ScalarValue};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
32use datafusion_physical_plan::ExecutionPlan;
33use parking_lot::RwLock;
34use std::fmt;
35use std::str::FromStr;
36use std::sync::Arc;
37
38/// Empty generator that produces no rows - used when series arguments contain null values
39#[derive(Debug, Clone)]
40struct Empty {
41    name: &'static str,
42}
43
44impl LazyBatchGenerator for Empty {
45    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
46        Ok(None)
47    }
48}
49
50impl fmt::Display for Empty {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{}: empty", self.name)
53    }
54}
55
56/// Trait for values that can be generated in a series
57trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
58    type StepType: fmt::Debug + Clone + Send + Sync;
59    type ValueType: fmt::Debug + Clone + Send + Sync;
60
61    /// Check if we've reached the end of the series
62    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
63
64    /// Advance to the next value in the series
65    fn advance(&mut self, step: &Self::StepType) -> Result<()>;
66
67    /// Create an Arrow array from a vector of values
68    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
69
70    /// Convert self to ValueType for array creation
71    fn to_value_type(&self) -> Self::ValueType;
72
73    /// Display the value for debugging
74    fn display_value(&self) -> String;
75}
76
77impl SeriesValue for i64 {
78    type StepType = i64;
79    type ValueType = i64;
80
81    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
82        reach_end_int64(*self, end, *step, include_end)
83    }
84
85    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
86        *self += step;
87        Ok(())
88    }
89
90    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
91        Ok(Arc::new(Int64Array::from(values)))
92    }
93
94    fn to_value_type(&self) -> Self::ValueType {
95        *self
96    }
97
98    fn display_value(&self) -> String {
99        self.to_string()
100    }
101}
102
103#[derive(Debug, Clone)]
104struct TimestampValue {
105    value: i64,
106    parsed_tz: Option<Tz>,
107    tz_str: Option<Arc<str>>,
108}
109
110impl SeriesValue for TimestampValue {
111    type StepType = IntervalMonthDayNano;
112    type ValueType = i64;
113
114    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
115        let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
116
117        if include_end {
118            if step_negative {
119                self.value < end.value
120            } else {
121                self.value > end.value
122            }
123        } else if step_negative {
124            self.value <= end.value
125        } else {
126            self.value >= end.value
127        }
128    }
129
130    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
131        let tz = self
132            .parsed_tz
133            .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
134        let Some(next_ts) =
135            TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
136        else {
137            return plan_err!(
138                "Failed to add interval {:?} to timestamp {}",
139                step,
140                self.value
141            );
142        };
143        self.value = next_ts;
144        Ok(())
145    }
146
147    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
148        let array = TimestampNanosecondArray::from(values);
149
150        // Use timezone from self (now we have access to tz through &self)
151        let array = match self.tz_str.as_ref() {
152            Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
153            None => array,
154        };
155
156        Ok(Arc::new(array))
157    }
158
159    fn to_value_type(&self) -> Self::ValueType {
160        self.value
161    }
162
163    fn display_value(&self) -> String {
164        self.value.to_string()
165    }
166}
167
/// Indicates the arguments used for generating a series.
///
/// Built by `GenerateSeriesFuncImpl` at planning time. Every variant except
/// `ContainsNull` has already passed argument validation there.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
    /// At least one argument (start, end or step) was NULL, so no series is generated.
    /// `name` is the calling table function's name, used in `Display` output.
    ContainsNull { name: &'static str },
    /// Start, end and step of an integer series (all arguments non-NULL).
    /// `step` is non-zero and does not point away from `end`.
    Int64Args {
        start: i64,
        end: i64,
        step: i64,
        /// Indicates whether the end value should be included in the series
        /// (`true` for `generate_series`, `false` for `range`).
        include_end: bool,
        /// Name of the calling table function, used in `Display` output.
        name: &'static str,
    },
    /// Start, end and step of a timestamp series (all arguments non-NULL).
    TimestampArgs {
        /// Start timestamp, in nanoseconds since the Unix epoch.
        start: i64,
        /// End bound, in nanoseconds since the Unix epoch.
        end: i64,
        step: IntervalMonthDayNano,
        /// Timezone of the `start` argument; also used for the output column.
        tz: Option<Arc<str>>,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        /// Name of the calling table function, used in `Display` output.
        name: &'static str,
    },
    /// Start, end and step of a date series (all arguments non-NULL).
    /// Dates are converted to timestamps and reuse the timestamp logic.
    DateArgs {
        /// Start date converted to nanoseconds since the Unix epoch.
        start: i64,
        /// End date converted to nanoseconds since the Unix epoch.
        end: i64,
        step: IntervalMonthDayNano,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        /// Name of the calling table function, used in `Display` output.
        name: &'static str,
    },
}
203
/// Table that generates a series of integers/timestamps from `start` (inclusive)
/// to `end`, incrementing by `step`.
///
/// `end` is inclusive for `generate_series` and exclusive for `range`. Date
/// series are emitted as nanosecond timestamps.
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
    /// Schema with a single `value` column.
    schema: SchemaRef,
    /// Series arguments captured at planning time.
    args: GenSeriesArgs,
}
210
211#[derive(Debug, Clone)]
212struct GenericSeriesState<T: SeriesValue> {
213    schema: SchemaRef,
214    start: T,
215    end: T,
216    step: T::StepType,
217    batch_size: usize,
218    current: T,
219    include_end: bool,
220    name: &'static str,
221}
222
223impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
224    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
225        let mut buf = Vec::with_capacity(self.batch_size);
226
227        while buf.len() < self.batch_size
228            && !self
229                .current
230                .should_stop(self.end.clone(), &self.step, self.include_end)
231        {
232            buf.push(self.current.to_value_type());
233            self.current.advance(&self.step)?;
234        }
235
236        if buf.is_empty() {
237            return Ok(None);
238        }
239
240        let array = self.current.create_array(buf)?;
241        let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
242        Ok(Some(batch))
243    }
244}
245
246impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
247    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248        write!(
249            f,
250            "{}: start={}, end={}, batch_size={}",
251            self.name,
252            self.start.display_value(),
253            self.end.display_value(),
254            self.batch_size
255        )
256    }
257}
258
/// Returns `true` when `val` is no longer part of an integer series that
/// moves towards `end` by `step`.
///
/// For a positive `step`, the series ends once `val` passes `end` (or reaches
/// it, when `include_end` is `false`). For any other `step`, the comparisons
/// are mirrored.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    match (step > 0, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
272
273fn validate_interval_step(
274    step: IntervalMonthDayNano,
275    start: i64,
276    end: i64,
277) -> Result<()> {
278    if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
279        return plan_err!("Step interval cannot be zero");
280    }
281
282    let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
283    let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
284
285    if start > end && step_is_positive {
286        return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
287    }
288
289    if start < end && step_is_negative {
290        return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
291    }
292
293    Ok(())
294}
295
296#[async_trait]
297impl TableProvider for GenerateSeriesTable {
298    fn as_any(&self) -> &dyn std::any::Any {
299        self
300    }
301
302    fn schema(&self) -> SchemaRef {
303        Arc::clone(&self.schema)
304    }
305
306    fn table_type(&self) -> TableType {
307        TableType::Base
308    }
309
310    async fn scan(
311        &self,
312        state: &dyn Session,
313        projection: Option<&Vec<usize>>,
314        _filters: &[Expr],
315        _limit: Option<usize>,
316    ) -> Result<Arc<dyn ExecutionPlan>> {
317        let batch_size = state.config_options().execution.batch_size;
318        let schema = match projection {
319            Some(projection) => Arc::new(self.schema.project(projection)?),
320            None => self.schema(),
321        };
322        let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
323            GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
324            GenSeriesArgs::Int64Args {
325                start,
326                end,
327                step,
328                include_end,
329                name,
330            } => Arc::new(RwLock::new(GenericSeriesState {
331                schema: self.schema(),
332                start: *start,
333                end: *end,
334                step: *step,
335                current: *start,
336                batch_size,
337                include_end: *include_end,
338                name,
339            })),
340            GenSeriesArgs::TimestampArgs {
341                start,
342                end,
343                step,
344                tz,
345                include_end,
346                name,
347            } => {
348                let parsed_tz = tz
349                    .as_ref()
350                    .map(|s| Tz::from_str(s.as_ref()))
351                    .transpose()
352                    .map_err(|e| {
353                        datafusion_common::DataFusionError::Internal(format!(
354                            "Failed to parse timezone: {e}"
355                        ))
356                    })?
357                    .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
358                Arc::new(RwLock::new(GenericSeriesState {
359                    schema: self.schema(),
360                    start: TimestampValue {
361                        value: *start,
362                        parsed_tz: Some(parsed_tz),
363                        tz_str: tz.clone(),
364                    },
365                    end: TimestampValue {
366                        value: *end,
367                        parsed_tz: Some(parsed_tz),
368                        tz_str: tz.clone(),
369                    },
370                    step: *step,
371                    current: TimestampValue {
372                        value: *start,
373                        parsed_tz: Some(parsed_tz),
374                        tz_str: tz.clone(),
375                    },
376                    batch_size,
377                    include_end: *include_end,
378                    name,
379                }))
380            }
381            GenSeriesArgs::DateArgs {
382                start,
383                end,
384                step,
385                include_end,
386                name,
387            } => Arc::new(RwLock::new(GenericSeriesState {
388                schema: self.schema(),
389                start: TimestampValue {
390                    value: *start,
391                    parsed_tz: None,
392                    tz_str: None,
393                },
394                end: TimestampValue {
395                    value: *end,
396                    parsed_tz: None,
397                    tz_str: None,
398                },
399                step: *step,
400                current: TimestampValue {
401                    value: *start,
402                    parsed_tz: None,
403                    tz_str: None,
404                },
405                batch_size,
406                include_end: *include_end,
407                name,
408            })),
409        };
410
411        Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?))
412    }
413}
414
415#[derive(Debug)]
416struct GenerateSeriesFuncImpl {
417    name: &'static str,
418    include_end: bool,
419}
420
421impl TableFunctionImpl for GenerateSeriesFuncImpl {
422    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
423        if exprs.is_empty() || exprs.len() > 3 {
424            return plan_err!("{} function requires 1 to 3 arguments", self.name);
425        }
426
427        // Determine the data type from the first argument
428        match &exprs[0] {
429            Expr::Literal(
430                // Default to int64 for null
431                ScalarValue::Null | ScalarValue::Int64(_),
432                _,
433            ) => self.call_int64(exprs),
434            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
435                self.call_timestamp(exprs)
436            }
437            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
438                self.call_date(exprs)
439            }
440            Expr::Literal(scalar, _) => {
441                plan_err!(
442                    "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
443                    scalar.data_type()
444                )
445            }
446            _ => plan_err!("Arguments must be literals"),
447        }
448    }
449}
450
451impl GenerateSeriesFuncImpl {
452    fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
453        let mut normalize_args = Vec::new();
454        for (expr_index, expr) in exprs.iter().enumerate() {
455            match expr {
456                Expr::Literal(ScalarValue::Null, _) => {}
457                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
458                other => {
459                    return plan_err!(
460                        "Argument #{} must be an INTEGER or NULL, got {:?}",
461                        expr_index + 1,
462                        other
463                    )
464                }
465            };
466        }
467
468        let schema = Arc::new(Schema::new(vec![Field::new(
469            "value",
470            DataType::Int64,
471            false,
472        )]));
473
474        if normalize_args.len() != exprs.len() {
475            // contain null
476            return Ok(Arc::new(GenerateSeriesTable {
477                schema,
478                args: GenSeriesArgs::ContainsNull { name: self.name },
479            }));
480        }
481
482        let (start, end, step) = match &normalize_args[..] {
483            [end] => (0, *end, 1),
484            [start, end] => (*start, *end, 1),
485            [start, end, step] => (*start, *end, *step),
486            _ => {
487                return plan_err!("{} function requires 1 to 3 arguments", self.name);
488            }
489        };
490
491        if start > end && step > 0 {
492            return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
493        }
494
495        if start < end && step < 0 {
496            return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
497        }
498
499        if step == 0 {
500            return plan_err!("Step cannot be zero");
501        }
502
503        Ok(Arc::new(GenerateSeriesTable {
504            schema,
505            args: GenSeriesArgs::Int64Args {
506                start,
507                end,
508                step,
509                include_end: self.include_end,
510                name: self.name,
511            },
512        }))
513    }
514
515    fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
516        if exprs.len() != 3 {
517            return plan_err!(
518                "{} function with timestamps requires exactly 3 arguments",
519                self.name
520            );
521        }
522
523        // Parse start timestamp
524        let (start_ts, tz) = match &exprs[0] {
525            Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
526                (*ts, tz.clone())
527            }
528            other => {
529                return plan_err!(
530                    "First argument must be a timestamp or NULL, got {:?}",
531                    other
532                )
533            }
534        };
535
536        // Parse end timestamp
537        let end_ts = match &exprs[1] {
538            Expr::Literal(ScalarValue::Null, _) => None,
539            Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
540            other => {
541                return plan_err!(
542                    "Second argument must be a timestamp or NULL, got {:?}",
543                    other
544                )
545            }
546        };
547
548        // Parse step interval
549        let step_interval = match &exprs[2] {
550            Expr::Literal(ScalarValue::Null, _) => None,
551            Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
552            other => {
553                return plan_err!(
554                    "Third argument must be an interval or NULL, got {:?}",
555                    other
556                )
557            }
558        };
559
560        let schema = Arc::new(Schema::new(vec![Field::new(
561            "value",
562            DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
563            false,
564        )]));
565
566        // Check if any argument is null
567        let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
568        else {
569            return Ok(Arc::new(GenerateSeriesTable {
570                schema,
571                args: GenSeriesArgs::ContainsNull { name: self.name },
572            }));
573        };
574
575        // Validate step interval
576        validate_interval_step(step, start, end)?;
577
578        Ok(Arc::new(GenerateSeriesTable {
579            schema,
580            args: GenSeriesArgs::TimestampArgs {
581                start,
582                end,
583                step,
584                tz,
585                include_end: self.include_end,
586                name: self.name,
587            },
588        }))
589    }
590
591    fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
592        if exprs.len() != 3 {
593            return plan_err!(
594                "{} function with dates requires exactly 3 arguments",
595                self.name
596            );
597        }
598
599        let schema = Arc::new(Schema::new(vec![Field::new(
600            "value",
601            DataType::Timestamp(TimeUnit::Nanosecond, None),
602            false,
603        )]));
604
605        // Parse start date
606        let start_date = match &exprs[0] {
607            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
608            Expr::Literal(ScalarValue::Date32(None), _)
609            | Expr::Literal(ScalarValue::Null, _) => {
610                return Ok(Arc::new(GenerateSeriesTable {
611                    schema,
612                    args: GenSeriesArgs::ContainsNull { name: self.name },
613                }));
614            }
615            other => {
616                return plan_err!(
617                    "First argument must be a date or NULL, got {:?}",
618                    other
619                )
620            }
621        };
622
623        // Parse end date
624        let end_date = match &exprs[1] {
625            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
626            Expr::Literal(ScalarValue::Date32(None), _)
627            | Expr::Literal(ScalarValue::Null, _) => {
628                return Ok(Arc::new(GenerateSeriesTable {
629                    schema,
630                    args: GenSeriesArgs::ContainsNull { name: self.name },
631                }));
632            }
633            other => {
634                return plan_err!(
635                    "Second argument must be a date or NULL, got {:?}",
636                    other
637                )
638            }
639        };
640
641        // Parse step interval
642        let step_interval = match &exprs[2] {
643            Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
644                *interval
645            }
646            Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
647            | Expr::Literal(ScalarValue::Null, _) => {
648                return Ok(Arc::new(GenerateSeriesTable {
649                    schema,
650                    args: GenSeriesArgs::ContainsNull { name: self.name },
651                }));
652            }
653            other => {
654                return plan_err!(
655                    "Third argument must be an interval or NULL, got {:?}",
656                    other
657                )
658            }
659        };
660
661        // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch)
662        // Date32 is days since 1970-01-01, so multiply by nanoseconds per day
663        const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
664
665        let start_ts = start_date as i64 * NANOS_PER_DAY;
666        let end_ts = end_date as i64 * NANOS_PER_DAY;
667
668        // Validate step interval
669        validate_interval_step(step_interval, start_ts, end_ts)?;
670
671        Ok(Arc::new(GenerateSeriesTable {
672            schema,
673            args: GenSeriesArgs::DateArgs {
674                start: start_ts,
675                end: end_ts,
676                step: step_interval,
677                include_end: self.include_end,
678                name: self.name,
679            },
680        }))
681    }
682}
683
684#[derive(Debug)]
685pub struct GenerateSeriesFunc {}
686
687impl TableFunctionImpl for GenerateSeriesFunc {
688    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
689        let impl_func = GenerateSeriesFuncImpl {
690            name: "generate_series",
691            include_end: true,
692        };
693        impl_func.call(exprs)
694    }
695}
696
697#[derive(Debug)]
698pub struct RangeFunc {}
699
700impl TableFunctionImpl for RangeFunc {
701    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
702        let impl_func = GenerateSeriesFuncImpl {
703            name: "range",
704            include_end: false,
705        };
706        impl_func.call(exprs)
707    }
708}