datafusion_functions_table/
generate_series.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::Int64Array;
19use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
20use arrow::record_batch::RecordBatch;
21use async_trait::async_trait;
22use datafusion_catalog::Session;
23use datafusion_catalog::TableFunctionImpl;
24use datafusion_catalog::TableProvider;
25use datafusion_common::{plan_err, Result, ScalarValue};
26use datafusion_expr::{Expr, TableType};
27use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
28use datafusion_physical_plan::ExecutionPlan;
29use parking_lot::RwLock;
30use std::fmt;
31use std::sync::Arc;
32
/// Arguments of a `generate_series` / `range` call after the literal arguments have been
/// evaluated.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
    /// At least one of `start`, `end` or `step` was `NULL`; the resulting series is empty.
    ContainsNull {
        /// `true` when `end` belongs to the series (`generate_series`), `false` for `range`.
        include_end: bool,
        /// Name of the table function, shown in `EXPLAIN` output.
        name: &'static str,
    },
    /// Every argument was a non-`NULL` integer, so a concrete series can be produced.
    AllNotNullArgs {
        /// First value of the series.
        start: i64,
        /// Bound of the series; inclusive or exclusive depending on `include_end`.
        end: i64,
        /// Increment between consecutive values.
        step: i64,
        /// `true` when `end` itself belongs to the series.
        include_end: bool,
        /// Name of the table function, shown in `EXPLAIN` output.
        name: &'static str,
    },
}
51
52/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
53#[derive(Debug, Clone)]
54struct GenerateSeriesTable {
55    schema: SchemaRef,
56    args: GenSeriesArgs,
57}
58
59/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
60#[derive(Debug, Clone)]
61struct GenerateSeriesState {
62    schema: SchemaRef,
63    start: i64, // Kept for display
64    end: i64,
65    step: i64,
66    batch_size: usize,
67
68    /// Tracks current position when generating table
69    current: i64,
70    /// Indicates whether the end value should be included in the series.
71    include_end: bool,
72    name: &'static str,
73}
74
75impl GenerateSeriesState {
76    fn reach_end(&self, val: i64) -> bool {
77        if self.step > 0 {
78            if self.include_end {
79                return val > self.end;
80            } else {
81                return val >= self.end;
82            }
83        }
84
85        if self.include_end {
86            val < self.end
87        } else {
88            val <= self.end
89        }
90    }
91}
92
93/// Detail to display for 'Explain' plan
94impl fmt::Display for GenerateSeriesState {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        write!(
97            f,
98            "{}: start={}, end={}, batch_size={}",
99            self.name, self.start, self.end, self.batch_size
100        )
101    }
102}
103
104impl LazyBatchGenerator for GenerateSeriesState {
105    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
106        let mut buf = Vec::with_capacity(self.batch_size);
107        while buf.len() < self.batch_size && !self.reach_end(self.current) {
108            buf.push(self.current);
109            self.current += self.step;
110        }
111        let array = Int64Array::from(buf);
112
113        if array.is_empty() {
114            return Ok(None);
115        }
116
117        let batch =
118            RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?;
119
120        Ok(Some(batch))
121    }
122}
123
124#[async_trait]
125impl TableProvider for GenerateSeriesTable {
126    fn as_any(&self) -> &dyn std::any::Any {
127        self
128    }
129
130    fn schema(&self) -> SchemaRef {
131        Arc::clone(&self.schema)
132    }
133
134    fn table_type(&self) -> TableType {
135        TableType::Base
136    }
137
138    async fn scan(
139        &self,
140        state: &dyn Session,
141        projection: Option<&Vec<usize>>,
142        _filters: &[Expr],
143        _limit: Option<usize>,
144    ) -> Result<Arc<dyn ExecutionPlan>> {
145        let batch_size = state.config_options().execution.batch_size;
146        let schema = match projection {
147            Some(projection) => Arc::new(self.schema.project(projection)?),
148            None => self.schema(),
149        };
150        let state = match self.args {
151            // if args have null, then return 0 row
152            GenSeriesArgs::ContainsNull { include_end, name } => GenerateSeriesState {
153                schema: self.schema(),
154                start: 0,
155                end: 0,
156                step: 1,
157                current: 1,
158                batch_size,
159                include_end,
160                name,
161            },
162            GenSeriesArgs::AllNotNullArgs {
163                start,
164                end,
165                step,
166                include_end,
167                name,
168            } => GenerateSeriesState {
169                schema: self.schema(),
170                start,
171                end,
172                step,
173                current: start,
174                batch_size,
175                include_end,
176                name,
177            },
178        };
179
180        Ok(Arc::new(LazyMemoryExec::try_new(
181            schema,
182            vec![Arc::new(RwLock::new(state))],
183        )?))
184    }
185}
186
/// Shared implementation behind the `generate_series` and `range` table functions.
///
/// The two functions differ only in their name and in whether `end` is part of the
/// series.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// SQL-visible function name, used in error messages and `EXPLAIN` output.
    name: &'static str,
    /// `true` when `end` belongs to the series (`generate_series`), `false` for `range`.
    include_end: bool,
}
192
193impl TableFunctionImpl for GenerateSeriesFuncImpl {
194    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
195        if exprs.is_empty() || exprs.len() > 3 {
196            return plan_err!("{} function requires 1 to 3 arguments", self.name);
197        }
198
199        let mut normalize_args = Vec::new();
200        for expr in exprs {
201            match expr {
202                Expr::Literal(ScalarValue::Null, _) => {}
203                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
204                _ => return plan_err!("First argument must be an integer literal"),
205            };
206        }
207
208        let schema = Arc::new(Schema::new(vec![Field::new(
209            "value",
210            DataType::Int64,
211            false,
212        )]));
213
214        if normalize_args.len() != exprs.len() {
215            // contain null
216            return Ok(Arc::new(GenerateSeriesTable {
217                schema,
218                args: GenSeriesArgs::ContainsNull {
219                    include_end: self.include_end,
220                    name: self.name,
221                },
222            }));
223        }
224
225        let (start, end, step) = match &normalize_args[..] {
226            [end] => (0, *end, 1),
227            [start, end] => (*start, *end, 1),
228            [start, end, step] => (*start, *end, *step),
229            _ => {
230                return plan_err!("{} function requires 1 to 3 arguments", self.name);
231            }
232        };
233
234        if start > end && step > 0 {
235            return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
236        }
237
238        if start < end && step < 0 {
239            return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
240        }
241
242        if step == 0 {
243            return plan_err!("step cannot be zero");
244        }
245
246        Ok(Arc::new(GenerateSeriesTable {
247            schema,
248            args: GenSeriesArgs::AllNotNullArgs {
249                start,
250                end,
251                step,
252                include_end: self.include_end,
253                name: self.name,
254            },
255        }))
256    }
257}
258
/// Table function `generate_series([start,] end [, step])`.
///
/// Yields a single `Int64` column named `value` running from `start` to `end`, both
/// inclusive, in increments of `step`. `start` defaults to `0` and `step` to `1`. A `NULL`
/// argument produces an empty table.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
261
262impl TableFunctionImpl for GenerateSeriesFunc {
263    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
264        let impl_func = GenerateSeriesFuncImpl {
265            name: "generate_series",
266            include_end: true,
267        };
268        impl_func.call(exprs)
269    }
270}
271
/// Table function `range([start,] end [, step])`.
///
/// Same as `generate_series`, except that `end` is exclusive. It yields a single `Int64`
/// column named `value` from `start` (inclusive) up to but not including `end`, in
/// increments of `step`. `start` defaults to `0` and `step` to `1`. A `NULL` argument
/// produces an empty table.
#[derive(Debug)]
pub struct RangeFunc {}
274
275impl TableFunctionImpl for RangeFunc {
276    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
277        let impl_func = GenerateSeriesFuncImpl {
278            name: "range",
279            include_end: false,
280        };
281        impl_func.call(exprs)
282    }
283}