datafusion_functions_table/
generate_series.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::timezone::Tz;
19use arrow::array::types::TimestampNanosecondType;
20use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray};
21use arrow::datatypes::{
22    DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit,
23};
24use arrow::record_batch::RecordBatch;
25use async_trait::async_trait;
26use datafusion_catalog::Session;
27use datafusion_catalog::TableFunctionImpl;
28use datafusion_catalog::TableProvider;
29use datafusion_common::{Result, ScalarValue, plan_err};
30use datafusion_expr::{Expr, TableType};
31use datafusion_physical_plan::ExecutionPlan;
32use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
33use parking_lot::RwLock;
34use std::any::Any;
35use std::fmt;
36use std::str::FromStr;
37use std::sync::Arc;
38
/// A generator that never yields a batch.
///
/// It stands in for a real series generator when at least one of the series
/// arguments is NULL, in which case no rows are produced.
#[derive(Debug, Clone)]
pub struct Empty {
    /// Name of the table function this generator was created for; it is the
    /// prefix of the `Display` output.
    name: &'static str,
}

impl Empty {
    /// Returns the name of the table function this generator was created for.
    pub fn name(&self) -> &'static str {
        let Self { name } = *self;
        name
    }
}
50
51impl LazyBatchGenerator for Empty {
52    fn as_any(&self) -> &dyn Any {
53        self
54    }
55
56    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
57        Ok(None)
58    }
59
60    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
61        Arc::new(RwLock::new(Empty { name: self.name }))
62    }
63}
64
65impl fmt::Display for Empty {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        write!(f, "{}: empty", self.name)
68    }
69}
70
/// Trait for values that can be generated in a series.
///
/// An implementor acts as the cursor of a series: the batch generator asks it
/// whether the series is finished (`should_stop`), records its current value
/// (`to_value_type`), moves it forward (`advance`), and finally turns the
/// recorded values into an Arrow array (`create_array`).
pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
    /// Type of the increment passed to `advance` and `should_stop`.
    type StepType: fmt::Debug + Clone + Send + Sync;
    /// Element type buffered for each generated row and handed to
    /// `create_array`.
    type ValueType: fmt::Debug + Clone + Send + Sync;

    /// Check if we've reached the end of the series.
    ///
    /// `self` is the current position. Returns `true` when no further value
    /// should be emitted given `end`, the direction implied by `step`, and
    /// whether `end` itself belongs to the series (`include_end`).
    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool;

    /// Advance to the next value in the series by adding `step` in place.
    ///
    /// Returns an error when the next value cannot be computed.
    fn advance(&mut self, step: &Self::StepType) -> Result<()>;

    /// Create an Arrow array from a vector of values previously obtained
    /// through `to_value_type`.
    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef>;

    /// Convert self to ValueType for array creation
    fn to_value_type(&self) -> Self::ValueType;

    /// Display the value for debugging
    fn display_value(&self) -> String;
}
91
92impl SeriesValue for i64 {
93    type StepType = i64;
94    type ValueType = i64;
95
96    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
97        reach_end_int64(*self, end, *step, include_end)
98    }
99
100    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
101        *self += step;
102        Ok(())
103    }
104
105    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
106        Ok(Arc::new(Int64Array::from(values)))
107    }
108
109    fn to_value_type(&self) -> Self::ValueType {
110        *self
111    }
112
113    fn display_value(&self) -> String {
114        self.to_string()
115    }
116}
117
/// Cursor of a timestamp series.
///
/// The timestamp is kept as a raw `i64` that is emitted into a
/// `TimestampNanosecondArray`; the time zone is kept both parsed (for interval
/// arithmetic) and as the original string (for tagging the output array).
#[derive(Debug, Clone)]
pub struct TimestampValue {
    /// Raw timestamp value.
    value: i64,
    /// Parsed time zone used when adding an interval; when `None`, `advance`
    /// falls back to the `+00:00` offset.
    parsed_tz: Option<Tz>,
    /// Time zone string attached to generated arrays, if any.
    tz_str: Option<Arc<str>>,
}

impl TimestampValue {
    /// Returns the raw timestamp value.
    pub fn value(&self) -> i64 {
        self.value
    }

    /// Returns the time zone string attached to generated arrays, if any.
    pub fn tz_str(&self) -> Option<&Arc<str>> {
        self.tz_str.as_ref()
    }
}
134
135impl SeriesValue for TimestampValue {
136    type StepType = IntervalMonthDayNano;
137    type ValueType = i64;
138
139    fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool {
140        let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
141
142        if include_end {
143            if step_negative {
144                self.value < end.value
145            } else {
146                self.value > end.value
147            }
148        } else if step_negative {
149            self.value <= end.value
150        } else {
151            self.value >= end.value
152        }
153    }
154
155    fn advance(&mut self, step: &Self::StepType) -> Result<()> {
156        let tz = self
157            .parsed_tz
158            .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
159        let Some(next_ts) =
160            TimestampNanosecondType::add_month_day_nano(self.value, *step, tz)
161        else {
162            return plan_err!(
163                "Failed to add interval {:?} to timestamp {}",
164                step,
165                self.value
166            );
167        };
168        self.value = next_ts;
169        Ok(())
170    }
171
172    fn create_array(&self, values: Vec<Self::ValueType>) -> Result<ArrayRef> {
173        let array = TimestampNanosecondArray::from(values);
174
175        // Use timezone from self (now we have access to tz through &self)
176        let array = match self.tz_str.as_ref() {
177            Some(tz_str) => array.with_timezone(Arc::clone(tz_str)),
178            None => array,
179        };
180
181        Ok(Arc::new(array))
182    }
183
184    fn to_value_type(&self) -> Self::ValueType {
185        self.value
186    }
187
188    fn display_value(&self) -> String {
189        self.value.to_string()
190    }
191}
192
/// Indicates the arguments used for generating a series.
///
/// Every variant carries `name`, the name of the table function that was
/// invoked, which is used in the generators' `Display` output.
#[derive(Debug, Clone)]
pub enum GenSeriesArgs {
    /// ContainsNull signifies that at least one argument (start, end, step) was null, thus no series will be generated.
    ContainsNull { name: &'static str },
    /// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null.
    Int64Args {
        start: i64,
        end: i64,
        step: i64,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        name: &'static str,
    },
    /// TimestampArgs holds the start, end, and step values for generating timestamp series when all arguments are not null.
    ///
    /// `start` and `end` are raw nanosecond timestamp values; `tz` is the
    /// optional time zone string taken from the start argument.
    TimestampArgs {
        start: i64,
        end: i64,
        step: IntervalMonthDayNano,
        tz: Option<Arc<str>>,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        name: &'static str,
    },
    /// DateArgs holds the start, end, and step values for generating date series when all arguments are not null.
    /// Internally, dates are converted to timestamps and use the timestamp logic,
    /// so `start` and `end` are nanoseconds since the epoch rather than day counts.
    DateArgs {
        start: i64,
        end: i64,
        step: IntervalMonthDayNano,
        /// Indicates whether the end value should be included in the series.
        include_end: bool,
        name: &'static str,
    },
}
228
/// Table that generates a series of integers/timestamps from `start` (inclusive) to `end`, incrementing by step.
///
/// Whether `end` itself is part of the series is decided by the
/// `include_end` flag stored in `args`.
#[derive(Debug, Clone)]
pub struct GenerateSeriesTable {
    /// Schema of the generated batches (a single column).
    schema: SchemaRef,
    /// Validated series arguments.
    args: GenSeriesArgs,
}
235
236impl GenerateSeriesTable {
237    pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
238        Self { schema, args }
239    }
240
241    pub fn as_generator(
242        &self,
243        batch_size: usize,
244    ) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
245        let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
246            GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
247            GenSeriesArgs::Int64Args {
248                start,
249                end,
250                step,
251                include_end,
252                name,
253            } => Arc::new(RwLock::new(GenericSeriesState {
254                schema: self.schema(),
255                start: *start,
256                end: *end,
257                step: *step,
258                current: *start,
259                batch_size,
260                include_end: *include_end,
261                name,
262            })),
263            GenSeriesArgs::TimestampArgs {
264                start,
265                end,
266                step,
267                tz,
268                include_end,
269                name,
270            } => {
271                let parsed_tz = tz
272                    .as_ref()
273                    .map(|s| Tz::from_str(s.as_ref()))
274                    .transpose()
275                    .map_err(|e| {
276                        datafusion_common::internal_datafusion_err!(
277                            "Failed to parse timezone: {e}"
278                        )
279                    })?
280                    .unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
281                Arc::new(RwLock::new(GenericSeriesState {
282                    schema: self.schema(),
283                    start: TimestampValue {
284                        value: *start,
285                        parsed_tz: Some(parsed_tz),
286                        tz_str: tz.clone(),
287                    },
288                    end: TimestampValue {
289                        value: *end,
290                        parsed_tz: Some(parsed_tz),
291                        tz_str: tz.clone(),
292                    },
293                    step: *step,
294                    current: TimestampValue {
295                        value: *start,
296                        parsed_tz: Some(parsed_tz),
297                        tz_str: tz.clone(),
298                    },
299                    batch_size,
300                    include_end: *include_end,
301                    name,
302                }))
303            }
304            GenSeriesArgs::DateArgs {
305                start,
306                end,
307                step,
308                include_end,
309                name,
310            } => Arc::new(RwLock::new(GenericSeriesState {
311                schema: self.schema(),
312                start: TimestampValue {
313                    value: *start,
314                    parsed_tz: None,
315                    tz_str: None,
316                },
317                end: TimestampValue {
318                    value: *end,
319                    parsed_tz: None,
320                    tz_str: None,
321                },
322                step: *step,
323                current: TimestampValue {
324                    value: *start,
325                    parsed_tz: None,
326                    tz_str: None,
327                },
328                batch_size,
329                include_end: *include_end,
330                name,
331            })),
332        };
333
334        Ok(generator)
335    }
336}
337
/// Iteration state of a series generator over any [`SeriesValue`].
#[derive(Debug, Clone)]
pub struct GenericSeriesState<T: SeriesValue> {
    /// Schema of the generated batches.
    schema: SchemaRef,
    /// First value of the series; the cursor is reset to it by `reset_state`.
    start: T,
    /// Bound of the series; included or excluded according to `include_end`.
    end: T,
    /// Increment applied after each generated value.
    step: T::StepType,
    /// Maximum number of rows per generated batch.
    batch_size: usize,
    /// Next value to be emitted.
    current: T,
    /// Whether `end` itself is part of the series.
    include_end: bool,
    /// Name of the table function, used as the prefix of the `Display` output.
    name: &'static str,
}

impl<T: SeriesValue> GenericSeriesState<T> {
    /// Returns the name of the table function this state belongs to.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Returns the maximum number of rows per generated batch.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns whether the end value is part of the series.
    pub fn include_end(&self) -> bool {
        self.include_end
    }

    /// Returns the first value of the series.
    pub fn start(&self) -> &T {
        &self.start
    }

    /// Returns the bound of the series.
    pub fn end(&self) -> &T {
        &self.end
    }

    /// Returns the increment between consecutive values.
    pub fn step(&self) -> &T::StepType {
        &self.step
    }

    /// Returns the next value that will be emitted.
    pub fn current(&self) -> &T {
        &self.current
    }
}
379
380impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
381    fn as_any(&self) -> &dyn Any {
382        self
383    }
384
385    fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
386        let mut buf = Vec::with_capacity(self.batch_size);
387
388        while buf.len() < self.batch_size
389            && !self
390                .current
391                .should_stop(self.end.clone(), &self.step, self.include_end)
392        {
393            buf.push(self.current.to_value_type());
394            self.current.advance(&self.step)?;
395        }
396
397        if buf.is_empty() {
398            return Ok(None);
399        }
400
401        let array = self.current.create_array(buf)?;
402        let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?;
403        Ok(Some(batch))
404    }
405
406    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
407        let mut new = self.clone();
408        new.current = new.start.clone();
409        Arc::new(RwLock::new(new))
410    }
411}
412
413impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
414    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
415        write!(
416            f,
417            "{}: start={}, end={}, batch_size={}",
418            self.name,
419            self.start.display_value(),
420            self.end.display_value(),
421            self.batch_size
422        )
423    }
424}
425
/// Tells whether an integer series positioned at `val` is finished.
///
/// A strictly positive `step` means the series ascends towards `end`; any
/// other step is handled as descending. With `include_end` the series only
/// stops once `val` has moved past `end`; otherwise it already stops when
/// `val` reaches `end`.
fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    let ascending = step > 0;
    match (ascending, include_end) {
        (true, true) => val > end,
        (true, false) => val >= end,
        (false, true) => val < end,
        (false, false) => val <= end,
    }
}
435
436fn validate_interval_step(
437    step: IntervalMonthDayNano,
438    start: i64,
439    end: i64,
440) -> Result<()> {
441    if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
442        return plan_err!("Step interval cannot be zero");
443    }
444
445    let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0;
446    let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0;
447
448    if start > end && step_is_positive {
449        return plan_err!(
450            "Start is bigger than end, but increment is positive: Cannot generate infinite series"
451        );
452    }
453
454    if start < end && step_is_negative {
455        return plan_err!(
456            "Start is smaller than end, but increment is negative: Cannot generate infinite series"
457        );
458    }
459
460    Ok(())
461}
462
463#[async_trait]
464impl TableProvider for GenerateSeriesTable {
465    fn as_any(&self) -> &dyn Any {
466        self
467    }
468
469    fn schema(&self) -> SchemaRef {
470        Arc::clone(&self.schema)
471    }
472
473    fn table_type(&self) -> TableType {
474        TableType::Base
475    }
476
477    async fn scan(
478        &self,
479        state: &dyn Session,
480        projection: Option<&Vec<usize>>,
481        _filters: &[Expr],
482        _limit: Option<usize>,
483    ) -> Result<Arc<dyn ExecutionPlan>> {
484        let batch_size = state.config_options().execution.batch_size;
485        let generator = self.as_generator(batch_size)?;
486
487        Ok(Arc::new(
488            LazyMemoryExec::try_new(self.schema(), vec![generator])?
489                .with_projection(projection.cloned()),
490        ))
491    }
492}
493
/// Shared implementation behind the `generate_series` and `range` table
/// functions, which differ only in name and in whether the end value is
/// included.
#[derive(Debug)]
struct GenerateSeriesFuncImpl {
    /// Function name used in error messages and generator display output.
    name: &'static str,
    /// Whether the end value is part of the generated series.
    include_end: bool,
}
499
500impl TableFunctionImpl for GenerateSeriesFuncImpl {
501    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
502        if exprs.is_empty() || exprs.len() > 3 {
503            return plan_err!("{} function requires 1 to 3 arguments", self.name);
504        }
505
506        // Determine the data type from the first argument
507        match &exprs[0] {
508            Expr::Literal(
509                // Default to int64 for null
510                ScalarValue::Null | ScalarValue::Int64(_),
511                _,
512            ) => self.call_int64(exprs),
513            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => {
514                self.call_timestamp(exprs)
515            }
516            Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => {
517                self.call_date(exprs)
518            }
519            Expr::Literal(scalar, _) => {
520                plan_err!(
521                    "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}",
522                    scalar.data_type()
523                )
524            }
525            _ => plan_err!("Arguments must be literals"),
526        }
527    }
528}
529
530impl GenerateSeriesFuncImpl {
531    fn call_int64(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
532        let mut normalize_args = Vec::new();
533        for (expr_index, expr) in exprs.iter().enumerate() {
534            match expr {
535                Expr::Literal(ScalarValue::Null, _) => {}
536                Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n),
537                other => {
538                    return plan_err!(
539                        "Argument #{} must be an INTEGER or NULL, got {:?}",
540                        expr_index + 1,
541                        other
542                    );
543                }
544            };
545        }
546
547        let schema = Arc::new(Schema::new(vec![Field::new(
548            "value",
549            DataType::Int64,
550            false,
551        )]));
552
553        if normalize_args.len() != exprs.len() {
554            // contain null
555            return Ok(Arc::new(GenerateSeriesTable {
556                schema,
557                args: GenSeriesArgs::ContainsNull { name: self.name },
558            }));
559        }
560
561        let (start, end, step) = match &normalize_args[..] {
562            [end] => (0, *end, 1),
563            [start, end] => (*start, *end, 1),
564            [start, end, step] => (*start, *end, *step),
565            _ => {
566                return plan_err!("{} function requires 1 to 3 arguments", self.name);
567            }
568        };
569
570        if start > end && step > 0 {
571            return plan_err!(
572                "Start is bigger than end, but increment is positive: Cannot generate infinite series"
573            );
574        }
575
576        if start < end && step < 0 {
577            return plan_err!(
578                "Start is smaller than end, but increment is negative: Cannot generate infinite series"
579            );
580        }
581
582        if step == 0 {
583            return plan_err!("Step cannot be zero");
584        }
585
586        Ok(Arc::new(GenerateSeriesTable {
587            schema,
588            args: GenSeriesArgs::Int64Args {
589                start,
590                end,
591                step,
592                include_end: self.include_end,
593                name: self.name,
594            },
595        }))
596    }
597
598    fn call_timestamp(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
599        if exprs.len() != 3 {
600            return plan_err!(
601                "{} function with timestamps requires exactly 3 arguments",
602                self.name
603            );
604        }
605
606        // Parse start timestamp
607        let (start_ts, tz) = match &exprs[0] {
608            Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => {
609                (*ts, tz.clone())
610            }
611            other => {
612                return plan_err!(
613                    "First argument must be a timestamp or NULL, got {:?}",
614                    other
615                );
616            }
617        };
618
619        // Parse end timestamp
620        let end_ts = match &exprs[1] {
621            Expr::Literal(ScalarValue::Null, _) => None,
622            Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts,
623            other => {
624                return plan_err!(
625                    "Second argument must be a timestamp or NULL, got {:?}",
626                    other
627                );
628            }
629        };
630
631        // Parse step interval
632        let step_interval = match &exprs[2] {
633            Expr::Literal(ScalarValue::Null, _) => None,
634            Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval,
635            other => {
636                return plan_err!(
637                    "Third argument must be an interval or NULL, got {:?}",
638                    other
639                );
640            }
641        };
642
643        let schema = Arc::new(Schema::new(vec![Field::new(
644            "value",
645            DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
646            false,
647        )]));
648
649        // Check if any argument is null
650        let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval)
651        else {
652            return Ok(Arc::new(GenerateSeriesTable {
653                schema,
654                args: GenSeriesArgs::ContainsNull { name: self.name },
655            }));
656        };
657
658        // Validate step interval
659        validate_interval_step(step, start, end)?;
660
661        Ok(Arc::new(GenerateSeriesTable {
662            schema,
663            args: GenSeriesArgs::TimestampArgs {
664                start,
665                end,
666                step,
667                tz,
668                include_end: self.include_end,
669                name: self.name,
670            },
671        }))
672    }
673
674    fn call_date(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
675        if exprs.len() != 3 {
676            return plan_err!(
677                "{} function with dates requires exactly 3 arguments",
678                self.name
679            );
680        }
681
682        let schema = Arc::new(Schema::new(vec![Field::new(
683            "value",
684            DataType::Timestamp(TimeUnit::Nanosecond, None),
685            false,
686        )]));
687
688        // Parse start date
689        let start_date = match &exprs[0] {
690            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
691            Expr::Literal(ScalarValue::Date32(None), _)
692            | Expr::Literal(ScalarValue::Null, _) => {
693                return Ok(Arc::new(GenerateSeriesTable {
694                    schema,
695                    args: GenSeriesArgs::ContainsNull { name: self.name },
696                }));
697            }
698            other => {
699                return plan_err!(
700                    "First argument must be a date or NULL, got {:?}",
701                    other
702                );
703            }
704        };
705
706        // Parse end date
707        let end_date = match &exprs[1] {
708            Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date,
709            Expr::Literal(ScalarValue::Date32(None), _)
710            | Expr::Literal(ScalarValue::Null, _) => {
711                return Ok(Arc::new(GenerateSeriesTable {
712                    schema,
713                    args: GenSeriesArgs::ContainsNull { name: self.name },
714                }));
715            }
716            other => {
717                return plan_err!(
718                    "Second argument must be a date or NULL, got {:?}",
719                    other
720                );
721            }
722        };
723
724        // Parse step interval
725        let step_interval = match &exprs[2] {
726            Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => {
727                *interval
728            }
729            Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _)
730            | Expr::Literal(ScalarValue::Null, _) => {
731                return Ok(Arc::new(GenerateSeriesTable {
732                    schema,
733                    args: GenSeriesArgs::ContainsNull { name: self.name },
734                }));
735            }
736            other => {
737                return plan_err!(
738                    "Third argument must be an interval or NULL, got {:?}",
739                    other
740                );
741            }
742        };
743
744        // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch)
745        // Date32 is days since 1970-01-01, so multiply by nanoseconds per day
746        const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;
747
748        let start_ts = start_date as i64 * NANOS_PER_DAY;
749        let end_ts = end_date as i64 * NANOS_PER_DAY;
750
751        // Validate step interval
752        validate_interval_step(step_interval, start_ts, end_ts)?;
753
754        Ok(Arc::new(GenerateSeriesTable {
755            schema,
756            args: GenSeriesArgs::DateArgs {
757                start: start_ts,
758                end: end_ts,
759                step: step_interval,
760                include_end: self.include_end,
761                name: self.name,
762            },
763        }))
764    }
765}
766
767#[derive(Debug)]
768pub struct GenerateSeriesFunc {}
769
770impl TableFunctionImpl for GenerateSeriesFunc {
771    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
772        let impl_func = GenerateSeriesFuncImpl {
773            name: "generate_series",
774            include_end: true,
775        };
776        impl_func.call(exprs)
777    }
778}
779
780#[derive(Debug)]
781pub struct RangeFunc {}
782
783impl TableFunctionImpl for RangeFunc {
784    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
785        let impl_func = GenerateSeriesFuncImpl {
786            name: "range",
787            include_end: false,
788        };
789        impl_func.call(exprs)
790    }
791}
792
#[cfg(test)]
mod generate_series_tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_physical_plan::memory::LazyBatchGenerator;

    use crate::generate_series::GenericSeriesState;

    /// A generator obtained through `reset_state` must replay the series from
    /// the start, even after the original has already produced a batch.
    #[test]
    fn test_generic_series_state_reset() -> Result<()> {
        let value_field = Field::new("a", DataType::Int64, false);
        let schema = Arc::new(Schema::new(vec![value_field]));

        let mut original = GenericSeriesState::<i64> {
            schema,
            start: 1,
            end: 5,
            step: 1,
            current: 1,
            batch_size: 8192,
            include_end: true,
            name: "test",
        };
        let first_batch = original.generate_next_batch()?.expect("missing batch");

        let rewound = original.reset_state();
        let mut rewound_guard = rewound.write();
        let replayed_batch = rewound_guard
            .generate_next_batch()?
            .expect("missing reset batch");

        assert_eq!(first_batch, replayed_batch);

        Ok(())
    }
}