datafusion_functions_table/
generate_series.rs1use arrow::array::Int64Array;
19use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
20use arrow::record_batch::RecordBatch;
21use async_trait::async_trait;
22use datafusion_catalog::Session;
23use datafusion_catalog::TableFunctionImpl;
24use datafusion_catalog::TableProvider;
25use datafusion_common::{plan_err, Result, ScalarValue};
26use datafusion_expr::{Expr, TableType};
27use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
28use datafusion_physical_plan::ExecutionPlan;
29use parking_lot::RwLock;
30use std::fmt;
31use std::sync::Arc;
32
/// Validated arguments of a `generate_series` / `range` call.
///
/// All payloads are plain `Copy` values, so the enum is `Copy` and comparable too.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GenSeriesArgs {
    /// At least one argument was a NULL literal; the series produces no rows.
    ContainsNull {
        /// Whether `end` would have been inclusive (`generate_series`) or not (`range`).
        include_end: bool,
        /// Name of the table function that was called.
        name: &'static str,
    },
    /// Every argument was a non-NULL integer literal.
    AllNotNullArgs {
        /// First value of the series.
        start: i64,
        /// Bound of the series; see `include_end`.
        end: i64,
        /// Increment between consecutive values; validated to be non-zero.
        step: i64,
        /// Whether `end` itself belongs to the series.
        include_end: bool,
        /// Name of the table function that was called.
        name: &'static str,
    },
}
51
/// Table provider returned by the `generate_series` / `range` table functions.
///
/// No rows are materialized here; `scan` turns `args` into a lazily evaluated
/// batch generator.
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
    /// Output schema. `GenerateSeriesFuncImpl::call` builds it as a single
    /// non-nullable `Int64` column named `value`.
    schema: SchemaRef,
    /// Validated call arguments describing which series to produce.
    args: GenSeriesArgs,
}
58
/// Cursor that emits the series batch by batch through `LazyBatchGenerator`.
#[derive(Debug, Clone)]
struct GenerateSeriesState {
    /// Schema of the emitted batches (`scan` sets this to the full table schema).
    schema: SchemaRef,
    /// First value of the series; only reported by `Display`, iteration uses `current`.
    start: i64,
    /// Bound of the series; inclusive or exclusive depending on `include_end`.
    end: i64,
    /// Amount added to `current` after each emitted value.
    step: i64,
    /// Maximum number of rows placed in a single generated batch.
    batch_size: usize,

    /// Next value to emit.
    current: i64,
    /// Whether `end` itself is emitted (`generate_series`) or not (`range`).
    include_end: bool,
    /// Name of the table function, shown by `Display`.
    name: &'static str,
}
74
75impl GenerateSeriesState {
76 fn reach_end(&self, val: i64) -> bool {
77 if self.step > 0 {
78 if self.include_end {
79 return val > self.end;
80 } else {
81 return val >= self.end;
82 }
83 }
84
85 if self.include_end {
86 val < self.end
87 } else {
88 val <= self.end
89 }
90 }
91}
92
93impl fmt::Display for GenerateSeriesState {
95 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96 write!(
97 f,
98 "{}: start={}, end={}, batch_size={}",
99 self.name, self.start, self.end, self.batch_size
100 )
101 }
102}
103
104impl LazyBatchGenerator for GenerateSeriesState {
105 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
106 let mut buf = Vec::with_capacity(self.batch_size);
107 while buf.len() < self.batch_size && !self.reach_end(self.current) {
108 buf.push(self.current);
109 self.current += self.step;
110 }
111 let array = Int64Array::from(buf);
112
113 if array.is_empty() {
114 return Ok(None);
115 }
116
117 let batch =
118 RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?;
119
120 Ok(Some(batch))
121 }
122}
123
124#[async_trait]
125impl TableProvider for GenerateSeriesTable {
126 fn as_any(&self) -> &dyn std::any::Any {
127 self
128 }
129
130 fn schema(&self) -> SchemaRef {
131 Arc::clone(&self.schema)
132 }
133
134 fn table_type(&self) -> TableType {
135 TableType::Base
136 }
137
138 async fn scan(
139 &self,
140 state: &dyn Session,
141 projection: Option<&Vec<usize>>,
142 _filters: &[Expr],
143 _limit: Option<usize>,
144 ) -> Result<Arc<dyn ExecutionPlan>> {
145 let batch_size = state.config_options().execution.batch_size;
146 let schema = match projection {
147 Some(projection) => Arc::new(self.schema.project(projection)?),
148 None => self.schema(),
149 };
150 let state = match self.args {
151 GenSeriesArgs::ContainsNull { include_end, name } => GenerateSeriesState {
153 schema: self.schema(),
154 start: 0,
155 end: 0,
156 step: 1,
157 current: 1,
158 batch_size,
159 include_end,
160 name,
161 },
162 GenSeriesArgs::AllNotNullArgs {
163 start,
164 end,
165 step,
166 include_end,
167 name,
168 } => GenerateSeriesState {
169 schema: self.schema(),
170 start,
171 end,
172 step,
173 current: start,
174 batch_size,
175 include_end,
176 name,
177 },
178 };
179
180 Ok(Arc::new(LazyMemoryExec::try_new(
181 schema,
182 vec![Arc::new(RwLock::new(state))],
183 )?))
184 }
185}
186
/// Shared implementation behind the `generate_series` and `range` table functions,
/// which differ only in their name and in whether `end` is inclusive.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and carried into the generated table.
    name: &'static str,
    /// `true` if `end` belongs to the series (`generate_series`), `false` otherwise (`range`).
    include_end: bool,
}
192
193impl TableFunctionImpl for GenerateSeriesFuncImpl {
194 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
195 if exprs.is_empty() || exprs.len() > 3 {
196 return plan_err!("{} function requires 1 to 3 arguments", self.name);
197 }
198
199 let mut normalize_args = Vec::new();
200 for expr in exprs {
201 match expr {
202 Expr::Literal(ScalarValue::Null, _) => {}
203 Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
204 _ => return plan_err!("First argument must be an integer literal"),
205 };
206 }
207
208 let schema = Arc::new(Schema::new(vec![Field::new(
209 "value",
210 DataType::Int64,
211 false,
212 )]));
213
214 if normalize_args.len() != exprs.len() {
215 return Ok(Arc::new(GenerateSeriesTable {
217 schema,
218 args: GenSeriesArgs::ContainsNull {
219 include_end: self.include_end,
220 name: self.name,
221 },
222 }));
223 }
224
225 let (start, end, step) = match &normalize_args[..] {
226 [end] => (0, *end, 1),
227 [start, end] => (*start, *end, 1),
228 [start, end, step] => (*start, *end, *step),
229 _ => {
230 return plan_err!("{} function requires 1 to 3 arguments", self.name);
231 }
232 };
233
234 if start > end && step > 0 {
235 return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
236 }
237
238 if start < end && step < 0 {
239 return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
240 }
241
242 if step == 0 {
243 return plan_err!("step cannot be zero");
244 }
245
246 Ok(Arc::new(GenerateSeriesTable {
247 schema,
248 args: GenSeriesArgs::AllNotNullArgs {
249 start,
250 end,
251 step,
252 include_end: self.include_end,
253 name: self.name,
254 },
255 }))
256 }
257}
258
/// Table function `generate_series([start,] end [, step])`: the integers from
/// `start` (default 0) up to and **including** `end`, advancing by `step` (default 1).
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
261
262impl TableFunctionImpl for GenerateSeriesFunc {
263 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
264 let impl_func = GenerateSeriesFuncImpl {
265 name: "generate_series",
266 include_end: true,
267 };
268 impl_func.call(exprs)
269 }
270}
271
/// Table function `range([start,] end [, step])`: the integers from `start`
/// (default 0) up to but **excluding** `end`, advancing by `step` (default 1).
#[derive(Debug)]
pub struct RangeFunc {}
274
275impl TableFunctionImpl for RangeFunc {
276 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
277 let impl_func = GenerateSeriesFuncImpl {
278 name: "range",
279 include_end: false,
280 };
281 impl_func.call(exprs)
282 }
283}