1use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22 DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{Result, ScalarValue, plan_err};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
33use parking_lot::RwLock;
34use std::any::Any;
35use std::fmt;
36use std::str::FromStr;
37use std::sync::Arc;
38
/// Batch generator that never produces any rows.
///
/// It only carries the name of the table function it stands in for, which is
/// used when the generator is displayed.
#[derive(Clone, Debug)]
pub struct Empty {
    /// Name of the table function this generator was created for.
    name: &'static str,
}
44
45impl Empty {
46 pub fn name(&self) -> &'static str {
47 self.name
48 }
49}
50
51impl LazyBatchGenerator for Empty {
52 fn as_any(&self) -> &dyn Any {
53 self
54 }
55
56 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
57 Ok(None)
58 }
59
60 fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
61 Arc::new(RwLock::new(Empty { name: self.name }))
62 }
63}
64
65impl fmt::Display for Empty {
66 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67 write!(f, "{}: empty", self.name)
68 }
69}
70
71pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
73 type StepType: fmt::Debug + Clone + Send + Sync;
74 type ValueType: fmt::Debug + Clone + Send + Sync;
75
76 fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;
78
79 fn advance(&mut self, step: &Self::StepType) -> Result<()>;
81
82 fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;
84
85 fn to_value_type(&self) -> Self::ValueType;
87
88 fn display_value(&self) -> String;
90}
91
92impl SeriesValue for i64 {
93 type StepType = i64;
94 type ValueType = i64;
95
96 fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
97 reach_end_int64(*self, end, *step, include_end)
98 }
99
100 fn advance(&mut self, step: &Self::StepType) -> Result<()> {
101 *self += step;
102 Ok(())
103 }
104
105 fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
106 Ok(Arc::new(Int64Array::from(values)))
107 }
108
109 fn to_value_type(&self) -> Self::ValueType {
110 *self
111 }
112
113 fn display_value(&self) -> String {
114 self.to_string()
115 }
116}
117
118#[derive(Debug, Clone)]
119pub struct TimestampValue {
120 value: i64,
121 parsed_tz: Option<Tz>,
122 tz_str: Option<Arc<str>>,
123}
124
125impl TimestampValue {
126 pub fn value(&self) -> i64 {
127 self.value
128 }
129
130 pub fn tz_str(&self) -> Option<&Arc<str>> {
131 self.tz_str.as_ref()
132 }
133}
134
135impl SeriesValue for TimestampValue {
136 type StepType = IntervalMonthDayNano;
137 type ValueType = i64;
138
139 fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
140 let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
141
142 if include_end {
143 if step_negative {
144 self.value < end.value
145 } else {
146 self.value > end.value
147 }
148 } else if step_negative {
149 self.value <= end.value
150 } else {
151 self.value >= end.value
152 }
153 }
154
155 fn advance(&mut self, step: &Self::StepType) -> Result<()> {
156 let tz = self
157 .parsed_tz
158 .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
159 let Some(next_ts) =
160 TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
161 else {
162 return plan_err!(
163 "Failed to add interval {:?} to timestamp {}",
164 step,
165 self.value
166 );
167 };
168 self.value = next_ts;
169 Ok(())
170 }
171
172 fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
173 let array = TimestampNanosecondArray::from(values);
174
175 let array = match self.tz_str.as_ref() {
177 Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
178 None => array,
179 };
180
181 Ok(Arc::new(array))
182 }
183
184 fn to_value_type(&self) -> Self::ValueType {
185 self.value
186 }
187
188 fn display_value(&self) -> String {
189 self.value.to_string()
190 }
191}
192
193#[derive(Debug, Clone)]
195pub enum GenSeriesArgs {
196 ContainsNull { name: &'static str },
198 Int64Args {
200 start: i64,
201 end: i64,
202 step: i64,
203 include_end: bool,
205 name: &'static str,
206 },
207 TimestampArgs {
209 start: i64,
210 end: i64,
211 step: IntervalMonthDayNano,
212 tz: Option<Arc<str>>,
213 include_end: bool,
215 name: &'static str,
216 },
217 DateArgs {
220 start: i64,
221 end: i64,
222 step: IntervalMonthDayNano,
223 include_end: bool,
225 name: &'static str,
226 },
227}
228
229#[derive(Debug, Clone)]
231pub struct GenerateSeriesTable {
232 schema: SchemaRef,
233 args: GenSeriesArgs,
234}
235
236impl GenerateSeriesTable {
237 pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
238 Self { schema, args }
239 }
240
241 pub fn as_generator(
242 &self,
243 batch_size: usize,
244 ) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
245 let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
246 GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
247 GenSeriesArgs::Int64Args {
248 start,
249 end,
250 step,
251 include_end,
252 name,
253 } => Arc::new(RwLock::new(GenericSeriesState {
254 schema: self.schema(),
255 start: *start,
256 end: *end,
257 step: *step,
258 current: *start,
259 batch_size,
260 include_end: *include_end,
261 name,
262 })),
263 GenSeriesArgs::TimestampArgs {
264 start,
265 end,
266 step,
267 tz,
268 include_end,
269 name,
270 } => {
271 let parsed_tz = tz
272 .as_ref()
273 .map(|s| Tz::from_str(s.as_ref()))
274 .transpose()
275 .map_err(|e| {
276 datafusion_common::internal_datafusion_err!(
277 "Failed to parse timezone: {e}"
278 )
279 })?
280 .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
281 Arc::new(RwLock::new(GenericSeriesState {
282 schema: self.schema(),
283 start: TimestampValue {
284 value: *start,
285 parsed_tz: Some(parsed_tz),
286 tz_str: tz.clone(),
287 },
288 end: TimestampValue {
289 value: *end,
290 parsed_tz: Some(parsed_tz),
291 tz_str: tz.clone(),
292 },
293 step: *step,
294 current: TimestampValue {
295 value: *start,
296 parsed_tz: Some(parsed_tz),
297 tz_str: tz.clone(),
298 },
299 batch_size,
300 include_end: *include_end,
301 name,
302 }))
303 }
304 GenSeriesArgs::DateArgs {
305 start,
306 end,
307 step,
308 include_end,
309 name,
310 } => Arc::new(RwLock::new(GenericSeriesState {
311 schema: self.schema(),
312 start: TimestampValue {
313 value: *start,
314 parsed_tz: None,
315 tz_str: None,
316 },
317 end: TimestampValue {
318 value: *end,
319 parsed_tz: None,
320 tz_str: None,
321 },
322 step: *step,
323 current: TimestampValue {
324 value: *start,
325 parsed_tz: None,
326 tz_str: None,
327 },
328 batch_size,
329 include_end: *include_end,
330 name,
331 })),
332 };
333
334 Ok(generator)
335 }
336}
337
338#[derive(Debug, Clone)]
339pub struct GenericSeriesState<T: SeriesValue> {
340 schema: SchemaRef,
341 start: T,
342 end: T,
343 step: T::StepType,
344 batch_size: usize,
345 current: T,
346 include_end: bool,
347 name: &'static str,
348}
349
350impl<T: SeriesValue> GenericSeriesState<T> {
351 pub fn name(&self) -> &'static str {
352 self.name
353 }
354
355 pub fn batch_size(&self) -> usize {
356 self.batch_size
357 }
358
359 pub fn include_end(&self) -> bool {
360 self.include_end
361 }
362
363 pub fn start(&self) -> &T {
364 &self.start
365 }
366
367 pub fn end(&self) -> &T {
368 &self.end
369 }
370
371 pub fn step(&self) -> &T::StepType {
372 &self.step
373 }
374
375 pub fn current(&self) -> &T {
376 &self.current
377 }
378}
379
380impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
381 fn as_any(&self) -> &dyn Any {
382 self
383 }
384
385 fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
386 let mut buf = Vec::with_capacity(self.batch_size);
387
388 while buf.len() < self.batch_size
389 && !self
390 .current
391 .should_stop(self.end.clone(), &self.step, self.include_end)
392 {
393 buf.push(self.current.to_value_type());
394 self.current.advance(&self.step)?;
395 }
396
397 if buf.is_empty() {
398 return Ok(None);
399 }
400
401 let array = self.current.create_array(buf)?;
402 let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
403 Ok(Some(batch))
404 }
405
406 fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
407 let mut new = self.clone();
408 new.current = new.start.clone();
409 Arc::new(RwLock::new(new))
410 }
411}
412
413impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
414 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
415 write!(
416 f,
417 "{}: start={}, end={}, batch_size={}",
418 self.name,
419 self.start.display_value(),
420 self.end.display_value(),
421 self.batch_size
422 )
423 }
424}
425
/// Returns `true` when `val` lies past `end` for the direction given by `step`.
///
/// A positive `step` means an ascending series; any other step (including
/// zero) is treated as descending. With `include_end`, a value equal to `end`
/// is still part of the series.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    let ascending = step > 0;
    match (ascending, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
435
436fn validate_interval_step(
437 step: IntervalMonthDayNano,
438 start: i64,
439 end: i64,
440) -> Result<()> {
441 if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
442 return plan_err!("Step interval cannot be zero");
443 }
444
445 let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
446 let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
447
448 if start > end && step_is_positive {
449 return plan_err!(
450 "Start is bigger than end, but increment is positive: Cannot generate infinite series"
451 );
452 }
453
454 if start < end && step_is_negative {
455 return plan_err!(
456 "Start is smaller than end, but increment is negative: Cannot generate infinite series"
457 );
458 }
459
460 Ok(())
461}
462
463#[async_trait]
464impl TableProvider for GenerateSeriesTable {
465 fn as_any(&self) -> &dyn Any {
466 self
467 }
468
469 fn schema(&self) -> SchemaRef {
470 Arc::clone(&self.schema)
471 }
472
473 fn table_type(&self) -> TableType {
474 TableType::Base
475 }
476
477 async fn scan(
478 &self,
479 state: &dyn Session,
480 projection: Option<&Vec<usize>>,
481 _filters: &[Expr],
482 _limit: Option<usize>,
483 ) -> Result<Arc<dyn ExecutionPlan>> {
484 let batch_size = state.config_options().execution.batch_size;
485 let generator = self.as_generator(batch_size)?;
486
487 Ok(Arc::new(
488 LazyMemoryExec::try_new(self.schema(), vec![generator])?
489 .with_projection(projection.cloned()),
490 ))
491 }
492}
493
/// Shared implementation behind the `generate_series` and `range` table
/// functions.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and passed on to the generator.
    name: &'static str,
    /// Whether the end bound itself belongs to the series.
    include_end: bool,
}
499
500impl TableFunctionImpl for GenerateSeriesFuncImpl {
501 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
502 if exprs.is_empty() || exprs.len() > 3 {
503 return plan_err!("{} function requires 1 to 3 arguments", self.name);
504 }
505
506 match &exprs[0] {
508 Expr::Literal(
509 ScalarValue::Null | ScalarValue::Int64(_),
511 _,
512 ) => self.call_int64(exprs),
513 Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
514 self.call_timestamp(exprs)
515 }
516 Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
517 self.call_date(exprs)
518 }
519 Expr::Literal(scalar, _) => {
520 plan_err!(
521 "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
522 scalar.data_type()
523 )
524 }
525 _ => plan_err!("Arguments must be literals"),
526 }
527 }
528}
529
530impl GenerateSeriesFuncImpl {
531 fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
532 let mut normalize_args = Vec::new();
533 for (expr_index, expr) in exprs.iter().enumerate() {
534 match expr {
535 Expr::Literal(ScalarValue::Null, _) => {}
536 Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
537 other => {
538 return plan_err!(
539 "Argument #{} must be an INTEGER or NULL, got {:?}",
540 expr_index + 1,
541 other
542 );
543 }
544 };
545 }
546
547 let schema = Arc::new(Schema::new(vec![Field::new(
548 "value",
549 DataType::Int64,
550 false,
551 )]));
552
553 if normalize_args.len() != exprs.len() {
554 return Ok(Arc::new(GenerateSeriesTable {
556 schema,
557 args: GenSeriesArgs::ContainsNull { name: self.name },
558 }));
559 }
560
561 let (start, end, step) = match &normalize_args[..] {
562 [end] => (0, *end, 1),
563 [start, end] => (*start, *end, 1),
564 [start, end, step] => (*start, *end, *step),
565 _ => {
566 return plan_err!("{} function requires 1 to 3 arguments", self.name);
567 }
568 };
569
570 if start > end && step > 0 {
571 return plan_err!(
572 "Start is bigger than end, but increment is positive: Cannot generate infinite series"
573 );
574 }
575
576 if start < end && step < 0 {
577 return plan_err!(
578 "Start is smaller than end, but increment is negative: Cannot generate infinite series"
579 );
580 }
581
582 if step == 0 {
583 return plan_err!("Step cannot be zero");
584 }
585
586 Ok(Arc::new(GenerateSeriesTable {
587 schema,
588 args: GenSeriesArgs::Int64Args {
589 start,
590 end,
591 step,
592 include_end: self.include_end,
593 name: self.name,
594 },
595 }))
596 }
597
598 fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
599 if exprs.len() != 3 {
600 return plan_err!(
601 "{} function with timestamps requires exactly 3 arguments",
602 self.name
603 );
604 }
605
606 let (start_ts, tz) = match &exprs[0] {
608 Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
609 (*ts, tz.clone())
610 }
611 other => {
612 return plan_err!(
613 "First argument must be a timestamp or NULL, got {:?}",
614 other
615 );
616 }
617 };
618
619 let end_ts = match &exprs[1] {
621 Expr::Literal(ScalarValue::Null, _) => None,
622 Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
623 other => {
624 return plan_err!(
625 "Second argument must be a timestamp or NULL, got {:?}",
626 other
627 );
628 }
629 };
630
631 let step_interval = match &exprs[2] {
633 Expr::Literal(ScalarValue::Null, _) => None,
634 Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
635 other => {
636 return plan_err!(
637 "Third argument must be an interval or NULL, got {:?}",
638 other
639 );
640 }
641 };
642
643 let schema = Arc::new(Schema::new(vec![Field::new(
644 "value",
645 DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
646 false,
647 )]));
648
649 let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
651 else {
652 return Ok(Arc::new(GenerateSeriesTable {
653 schema,
654 args: GenSeriesArgs::ContainsNull { name: self.name },
655 }));
656 };
657
658 validate_interval_step(step, start, end)?;
660
661 Ok(Arc::new(GenerateSeriesTable {
662 schema,
663 args: GenSeriesArgs::TimestampArgs {
664 start,
665 end,
666 step,
667 tz,
668 include_end: self.include_end,
669 name: self.name,
670 },
671 }))
672 }
673
674 fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
675 if exprs.len() != 3 {
676 return plan_err!(
677 "{} function with dates requires exactly 3 arguments",
678 self.name
679 );
680 }
681
682 let schema = Arc::new(Schema::new(vec![Field::new(
683 "value",
684 DataType::Timestamp(TimeUnit::Nanosecond, None),
685 false,
686 )]));
687
688 let start_date = match &exprs[0] {
690 Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
691 Expr::Literal(ScalarValue::Date32(None), _)
692 | Expr::Literal(ScalarValue::Null, _) => {
693 return Ok(Arc::new(GenerateSeriesTable {
694 schema,
695 args: GenSeriesArgs::ContainsNull { name: self.name },
696 }));
697 }
698 other => {
699 return plan_err!(
700 "First argument must be a date or NULL, got {:?}",
701 other
702 );
703 }
704 };
705
706 let end_date = match &exprs[1] {
708 Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
709 Expr::Literal(ScalarValue::Date32(None), _)
710 | Expr::Literal(ScalarValue::Null, _) => {
711 return Ok(Arc::new(GenerateSeriesTable {
712 schema,
713 args: GenSeriesArgs::ContainsNull { name: self.name },
714 }));
715 }
716 other => {
717 return plan_err!(
718 "Second argument must be a date or NULL, got {:?}",
719 other
720 );
721 }
722 };
723
724 let step_interval = match &exprs[2] {
726 Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
727 *interval
728 }
729 Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
730 | Expr::Literal(ScalarValue::Null, _) => {
731 return Ok(Arc::new(GenerateSeriesTable {
732 schema,
733 args: GenSeriesArgs::ContainsNull { name: self.name },
734 }));
735 }
736 other => {
737 return plan_err!(
738 "Third argument must be an interval or NULL, got {:?}",
739 other
740 );
741 }
742 };
743
744 const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
747
748 let start_ts = start_date as i64 * NANOS_PER_DAY;
749 let end_ts = end_date as i64 * NANOS_PER_DAY;
750
751 validate_interval_step(step_interval, start_ts, end_ts)?;
753
754 Ok(Arc::new(GenerateSeriesTable {
755 schema,
756 args: GenSeriesArgs::DateArgs {
757 start: start_ts,
758 end: end_ts,
759 step: step_interval,
760 include_end: self.include_end,
761 name: self.name,
762 },
763 }))
764 }
765}
766
/// The `generate_series` table function; its end bound is inclusive.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
769
770impl TableFunctionImpl for GenerateSeriesFunc {
771 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
772 let impl_func = GenerateSeriesFuncImpl {
773 name: "generate_series",
774 include_end: true,
775 };
776 impl_func.call(exprs)
777 }
778}
779
/// The `range` table function; its end bound is exclusive.
#[derive(Debug)]
pub struct RangeFunc {}
782
783impl TableFunctionImpl for RangeFunc {
784 fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
785 let impl_func = GenerateSeriesFuncImpl {
786 name: "range",
787 include_end: false,
788 };
789 impl_func.call(exprs)
790 }
791}
792
#[cfg(test)]
mod generate_series_tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_physical_plan::memory::LazyBatchGenerator;

    use crate::generate_series::GenericSeriesState;

    /// A generator obtained through `reset_state` must replay the series from
    /// `start` and therefore yield the same first batch as the original.
    #[test]
    fn test_generic_series_state_reset() -> Result<()> {
        let field = Field::new("a", DataType::Int64, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let mut original = GenericSeriesState::<i64> {
            schema,
            start: 1,
            end: 5,
            step: 1,
            current: 1,
            batch_size: 8192,
            include_end: true,
            name: "test",
        };
        // Consume the whole series so the cursor has moved past `start`.
        let first_batch = original.generate_next_batch()?.expect("missing batch");

        let rewound = original.reset_state();
        let replayed_batch = rewound
            .write()
            .generate_next_batch()?
            .expect("missing reset batch");

        assert_eq!(first_batch, replayed_batch);

        Ok(())
    }
}