1use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22 DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{plan_err, Result, ScalarValue};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
32use datafusion_physical_plan::ExecutionPlan;
33use parking_lot::RwLock;
34use std::fmt;
35use std::str::FromStr;
36use std::sync::Arc;
37
38#[derive(Debug, Clone)]
40struct Empty {
41 name: &'static str,
42}
43
44impl LazyBatchGenerator for Empty {
45 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
46 Ok(None)
47 }
48}
49
50impl fmt::Display for Empty {
51 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 write!(f, "{}: empty", self.name)
53 }
54}
55
/// Element type of a generated series.
///
/// `GenericSeriesState` drives an implementation of this trait to fill record
/// batches; this file implements it for `i64` and `TimestampValue`.
trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
    /// Type of the increment passed to `advance`.
    type StepType: fmt::Debug + Clone + Send + Sync;
    /// Primitive type that is collected and handed to `create_array`.
    type ValueType: fmt::Debug + Clone + Send + Sync;

    /// Returns `true` when `self` (the current value) must no longer be
    /// emitted, given the bound `end`, the direction of `step`, and whether the
    /// bound itself belongs to the series (`include_end`).
    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;

    /// Moves `self` to the next value of the series; may fail.
    fn advance(&mut self, step: &Self::StepType) -> Result<()>;

    /// Builds the output column from the values collected for one batch.
    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;

    /// Returns the primitive representation of `self` that is stored in a batch.
    fn to_value_type(&self) -> Self::ValueType;

    /// Returns the text used for this value in the generator's `Display` output.
    fn display_value(&self) -> String;
}
76
77impl SeriesValue for i64 {
78 type StepType = i64;
79 type ValueType = i64;
80
81 fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
82 reach_end_int64(*self, end, *step, include_end)
83 }
84
85 fn advance(&mut self, step: &Self::StepType) -> Result<()> {
86 *self += step;
87 Ok(())
88 }
89
90 fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
91 Ok(Arc::new(Int64Array::from(values)))
92 }
93
94 fn to_value_type(&self) -> Self::ValueType {
95 *self
96 }
97
98 fn display_value(&self) -> String {
99 self.to_string()
100 }
101}
102
/// Series element for timestamp and date series.
#[derive(Debug, Clone)]
struct TimestampValue {
    /// Raw timestamp value; it is emitted into a `TimestampNanosecondArray`.
    value: i64,
    /// Timezone used when adding the interval step; `advance` falls back to
    /// `+00:00` when this is `None`.
    parsed_tz: Option<Tz>,
    /// Timezone string attached to the output array, if any.
    tz_str: Option<Arc<str>>,
}
109
110impl SeriesValue for TimestampValue {
111 type StepType = IntervalMonthDayNano;
112 type ValueType = i64;
113
114 fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
115 let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
116
117 if include_end {
118 if step_negative {
119 self.value < end.value
120 } else {
121 self.value > end.value
122 }
123 } else if step_negative {
124 self.value <= end.value
125 } else {
126 self.value >= end.value
127 }
128 }
129
130 fn advance(&mut self, step: &Self::StepType) -> Result<()> {
131 let tz = self
132 .parsed_tz
133 .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
134 let Some(next_ts) =
135 TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
136 else {
137 return plan_err!(
138 "Failed to add interval {:?} to timestamp {}",
139 step,
140 self.value
141 );
142 };
143 self.value = next_ts;
144 Ok(())
145 }
146
147 fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
148 let array = TimestampNanosecondArray::from(values);
149
150 let array = match self.tz_str.as_ref() {
152 Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
153 None => array,
154 };
155
156 Ok(Arc::new(array))
157 }
158
159 fn to_value_type(&self) -> Self::ValueType {
160 self.value
161 }
162
163 fn display_value(&self) -> String {
164 self.value.to_string()
165 }
166}
167
/// Validated arguments of a `generate_series` / `range` call, stored in
/// `GenerateSeriesTable` and turned into a batch generator by `scan`.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
    /// At least one argument was NULL; scanning yields no rows.
    ContainsNull { name: &'static str },
    /// Integer series.
    Int64Args {
        /// First value of the series.
        start: i64,
        /// Bound of the series; `include_end` decides whether it is emitted.
        end: i64,
        /// Increment between consecutive values; never zero, may be negative.
        step: i64,
        /// `true` for `generate_series` (bound included), `false` for `range`.
        include_end: bool,
        /// Name of the table function, used in the `Display` output.
        name: &'static str,
    },
    /// Timestamp series built from `TimestampNanosecond` literals.
    TimestampArgs {
        /// Value of the first timestamp argument.
        start: i64,
        /// Value of the second timestamp argument.
        end: i64,
        /// Interval added between consecutive values.
        step: IntervalMonthDayNano,
        /// Timezone of the first argument; also used for the output column.
        tz: Option<Arc<str>>,
        /// `true` for `generate_series` (bound included), `false` for `range`.
        include_end: bool,
        /// Name of the table function, used in the `Display` output.
        name: &'static str,
    },
    /// Date series; the `Date32` arguments are converted to timestamps.
    DateArgs {
        /// Start date multiplied by the number of nanoseconds per day.
        start: i64,
        /// End date multiplied by the number of nanoseconds per day.
        end: i64,
        /// Interval added between consecutive values.
        step: IntervalMonthDayNano,
        /// `true` for `generate_series` (bound included), `false` for `range`.
        include_end: bool,
        /// Name of the table function, used in the `Display` output.
        name: &'static str,
    },
}
203
/// Table provider returned by the `generate_series` / `range` table functions.
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
    /// Schema of the table: a single column named `value`.
    schema: SchemaRef,
    /// Validated call arguments that determine which generator `scan` creates.
    args: GenSeriesArgs,
}
210
/// Mutable state of a running series; implements `LazyBatchGenerator`.
#[derive(Debug, Clone)]
struct GenericSeriesState<T: SeriesValue> {
    /// Schema used for every record batch this generator produces.
    schema: SchemaRef,
    /// First value of the series; only read by the `Display` output.
    start: T,
    /// Bound of the series, passed to `should_stop`.
    end: T,
    /// Increment applied after each emitted value.
    step: T::StepType,
    /// Maximum number of rows per generated batch.
    batch_size: usize,
    /// Next value to emit; advanced as batches are generated.
    current: T,
    /// Whether `end` itself belongs to the series.
    include_end: bool,
    /// Name of the table function, used in the `Display` output.
    name: &'static str,
}
222
223impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
224 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
225 let mut buf = Vec::with_capacity(self.batch_size);
226
227 while buf.len() < self.batch_size
228 && !self
229 .current
230 .should_stop(self.end.clone(), &self.step, self.include_end)
231 {
232 buf.push(self.current.to_value_type());
233 self.current.advance(&self.step)?;
234 }
235
236 if buf.is_empty() {
237 return Ok(None);
238 }
239
240 let array = self.current.create_array(buf)?;
241 let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
242 Ok(Some(batch))
243 }
244}
245
246impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
247 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248 write!(
249 f,
250 "{}: start={}, end={}, batch_size={}",
251 self.name,
252 self.start.display_value(),
253 self.end.display_value(),
254 self.batch_size
255 )
256 }
257}
258
/// Returns `true` when `val` lies beyond `end` in the direction of `step`.
///
/// A positive `step` counts upwards; any other `step` (including zero) is
/// treated as counting downwards. With `include_end` the bound itself is still
/// part of the series, otherwise reaching the bound already ends it.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    let ascending = step > 0;
    match (ascending, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
272
273fn validate_interval_step(
274 step: IntervalMonthDayNano,
275 start: i64,
276 end: i64,
277) -> Result<()> {
278 if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
279 return plan_err!("Step interval cannot be zero");
280 }
281
282 let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
283 let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
284
285 if start > end && step_is_positive {
286 return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
287 }
288
289 if start < end && step_is_negative {
290 return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
291 }
292
293 Ok(())
294}
295
296#[async_trait]
297impl TableProvider for GenerateSeriesTable {
298 fn as_any(&self) -> &dyn std::any::Any {
299 self
300 }
301
302 fn schema(&self) -> SchemaRef {
303 Arc::clone(&self.schema)
304 }
305
306 fn table_type(&self) -> TableType {
307 TableType::Base
308 }
309
310 async fn scan(
311 &self,
312 state: &dyn Session,
313 projection: Option<&Vec<usize>>,
314 _filters: &[Expr],
315 _limit: Option<usize>,
316 ) -> Result<Arc<dyn ExecutionPlan>> {
317 let batch_size = state.config_options().execution.batch_size;
318 let schema = match projection {
319 Some(projection) => Arc::new(self.schema.project(projection)?),
320 None => self.schema(),
321 };
322 let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
323 GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
324 GenSeriesArgs::Int64Args {
325 start,
326 end,
327 step,
328 include_end,
329 name,
330 } => Arc::new(RwLock::new(GenericSeriesState {
331 schema: self.schema(),
332 start: *start,
333 end: *end,
334 step: *step,
335 current: *start,
336 batch_size,
337 include_end: *include_end,
338 name,
339 })),
340 GenSeriesArgs::TimestampArgs {
341 start,
342 end,
343 step,
344 tz,
345 include_end,
346 name,
347 } => {
348 let parsed_tz = tz
349 .as_ref()
350 .map(|s| Tz::from_str(s.as_ref()))
351 .transpose()
352 .map_err(|e| {
353 datafusion_common::DataFusionError::Internal(format!(
354 "Failed to parse timezone: {e}"
355 ))
356 })?
357 .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
358 Arc::new(RwLock::new(GenericSeriesState {
359 schema: self.schema(),
360 start: TimestampValue {
361 value: *start,
362 parsed_tz: Some(parsed_tz),
363 tz_str: tz.clone(),
364 },
365 end: TimestampValue {
366 value: *end,
367 parsed_tz: Some(parsed_tz),
368 tz_str: tz.clone(),
369 },
370 step: *step,
371 current: TimestampValue {
372 value: *start,
373 parsed_tz: Some(parsed_tz),
374 tz_str: tz.clone(),
375 },
376 batch_size,
377 include_end: *include_end,
378 name,
379 }))
380 }
381 GenSeriesArgs::DateArgs {
382 start,
383 end,
384 step,
385 include_end,
386 name,
387 } => Arc::new(RwLock::new(GenericSeriesState {
388 schema: self.schema(),
389 start: TimestampValue {
390 value: *start,
391 parsed_tz: None,
392 tz_str: None,
393 },
394 end: TimestampValue {
395 value: *end,
396 parsed_tz: None,
397 tz_str: None,
398 },
399 step: *step,
400 current: TimestampValue {
401 value: *start,
402 parsed_tz: None,
403 tz_str: None,
404 },
405 batch_size,
406 include_end: *include_end,
407 name,
408 })),
409 };
410
411 Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?))
412 }
413}
414
/// Shared implementation behind `GenerateSeriesFunc` and `RangeFunc`.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and generator display output.
    name: &'static str,
    /// Whether the end bound is part of the series (`true` for
    /// `generate_series`, `false` for `range`).
    include_end: bool,
}
420
421impl TableFunctionImpl for GenerateSeriesFuncImpl {
422 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
423 if exprs.is_empty() || exprs.len() > 3 {
424 return plan_err!("{} function requires 1 to 3 arguments", self.name);
425 }
426
427 match &exprs[0] {
429 Expr::Literal(
430 ScalarValue::Null | ScalarValue::Int64(_),
432 _,
433 ) => self.call_int64(exprs),
434 Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
435 self.call_timestamp(exprs)
436 }
437 Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
438 self.call_date(exprs)
439 }
440 Expr::Literal(scalar, _) => {
441 plan_err!(
442 "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
443 scalar.data_type()
444 )
445 }
446 _ => plan_err!("Arguments must be literals"),
447 }
448 }
449}
450
451impl GenerateSeriesFuncImpl {
452 fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
453 let mut normalize_args = Vec::new();
454 for (expr_index, expr) in exprs.iter().enumerate() {
455 match expr {
456 Expr::Literal(ScalarValue::Null, _) => {}
457 Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
458 other => {
459 return plan_err!(
460 "Argument #{} must be an INTEGER or NULL, got {:?}",
461 expr_index + 1,
462 other
463 )
464 }
465 };
466 }
467
468 let schema = Arc::new(Schema::new(vec![Field::new(
469 "value",
470 DataType::Int64,
471 false,
472 )]));
473
474 if normalize_args.len() != exprs.len() {
475 return Ok(Arc::new(GenerateSeriesTable {
477 schema,
478 args: GenSeriesArgs::ContainsNull { name: self.name },
479 }));
480 }
481
482 let (start, end, step) = match &normalize_args[..] {
483 [end] => (0, *end, 1),
484 [start, end] => (*start, *end, 1),
485 [start, end, step] => (*start, *end, *step),
486 _ => {
487 return plan_err!("{} function requires 1 to 3 arguments", self.name);
488 }
489 };
490
491 if start > end && step > 0 {
492 return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series");
493 }
494
495 if start < end && step < 0 {
496 return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series");
497 }
498
499 if step == 0 {
500 return plan_err!("Step cannot be zero");
501 }
502
503 Ok(Arc::new(GenerateSeriesTable {
504 schema,
505 args: GenSeriesArgs::Int64Args {
506 start,
507 end,
508 step,
509 include_end: self.include_end,
510 name: self.name,
511 },
512 }))
513 }
514
515 fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
516 if exprs.len() != 3 {
517 return plan_err!(
518 "{} function with timestamps requires exactly 3 arguments",
519 self.name
520 );
521 }
522
523 let (start_ts, tz) = match &exprs[0] {
525 Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
526 (*ts, tz.clone())
527 }
528 other => {
529 return plan_err!(
530 "First argument must be a timestamp or NULL, got {:?}",
531 other
532 )
533 }
534 };
535
536 let end_ts = match &exprs[1] {
538 Expr::Literal(ScalarValue::Null, _) => None,
539 Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
540 other => {
541 return plan_err!(
542 "Second argument must be a timestamp or NULL, got {:?}",
543 other
544 )
545 }
546 };
547
548 let step_interval = match &exprs[2] {
550 Expr::Literal(ScalarValue::Null, _) => None,
551 Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
552 other => {
553 return plan_err!(
554 "Third argument must be an interval or NULL, got {:?}",
555 other
556 )
557 }
558 };
559
560 let schema = Arc::new(Schema::new(vec![Field::new(
561 "value",
562 DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
563 false,
564 )]));
565
566 let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
568 else {
569 return Ok(Arc::new(GenerateSeriesTable {
570 schema,
571 args: GenSeriesArgs::ContainsNull { name: self.name },
572 }));
573 };
574
575 validate_interval_step(step, start, end)?;
577
578 Ok(Arc::new(GenerateSeriesTable {
579 schema,
580 args: GenSeriesArgs::TimestampArgs {
581 start,
582 end,
583 step,
584 tz,
585 include_end: self.include_end,
586 name: self.name,
587 },
588 }))
589 }
590
591 fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
592 if exprs.len() != 3 {
593 return plan_err!(
594 "{} function with dates requires exactly 3 arguments",
595 self.name
596 );
597 }
598
599 let schema = Arc::new(Schema::new(vec![Field::new(
600 "value",
601 DataType::Timestamp(TimeUnit::Nanosecond, None),
602 false,
603 )]));
604
605 let start_date = match &exprs[0] {
607 Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
608 Expr::Literal(ScalarValue::Date32(None), _)
609 | Expr::Literal(ScalarValue::Null, _) => {
610 return Ok(Arc::new(GenerateSeriesTable {
611 schema,
612 args: GenSeriesArgs::ContainsNull { name: self.name },
613 }));
614 }
615 other => {
616 return plan_err!(
617 "First argument must be a date or NULL, got {:?}",
618 other
619 )
620 }
621 };
622
623 let end_date = match &exprs[1] {
625 Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
626 Expr::Literal(ScalarValue::Date32(None), _)
627 | Expr::Literal(ScalarValue::Null, _) => {
628 return Ok(Arc::new(GenerateSeriesTable {
629 schema,
630 args: GenSeriesArgs::ContainsNull { name: self.name },
631 }));
632 }
633 other => {
634 return plan_err!(
635 "Second argument must be a date or NULL, got {:?}",
636 other
637 )
638 }
639 };
640
641 let step_interval = match &exprs[2] {
643 Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
644 *interval
645 }
646 Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
647 | Expr::Literal(ScalarValue::Null, _) => {
648 return Ok(Arc::new(GenerateSeriesTable {
649 schema,
650 args: GenSeriesArgs::ContainsNull { name: self.name },
651 }));
652 }
653 other => {
654 return plan_err!(
655 "Third argument must be an interval or NULL, got {:?}",
656 other
657 )
658 }
659 };
660
661 const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
664
665 let start_ts = start_date as i64 * NANOS_PER_DAY;
666 let end_ts = end_date as i64 * NANOS_PER_DAY;
667
668 validate_interval_step(step_interval, start_ts, end_ts)?;
670
671 Ok(Arc::new(GenerateSeriesTable {
672 schema,
673 args: GenSeriesArgs::DateArgs {
674 start: start_ts,
675 end: end_ts,
676 step: step_interval,
677 include_end: self.include_end,
678 name: self.name,
679 },
680 }))
681 }
682}
683
684#[derive(Debug)]
685pub struct GenerateSeriesFunc {}
686
687impl TableFunctionImpl for GenerateSeriesFunc {
688 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
689 let impl_func = GenerateSeriesFuncImpl {
690 name: "generate_series",
691 include_end: true,
692 };
693 impl_func.call(exprs)
694 }
695}
696
697#[derive(Debug)]
698pub struct RangeFunc {}
699
700impl TableFunctionImpl for RangeFunc {
701 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
702 let impl_func = GenerateSeriesFuncImpl {
703 name: "range",
704 include_end: false,
705 };
706 impl_func.call(exprs)
707 }
708}