lance_datagen/
generator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, iter, marker::PhantomData, sync::Arc};
5
6use arrow::{
7    array::{ArrayData, AsArray},
8    buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer},
9    datatypes::{ArrowPrimitiveType, Int32Type, Int64Type, IntervalDayTime, IntervalMonthDayNano},
10};
11use arrow_array::{
12    make_array,
13    types::{ArrowDictionaryKeyType, BinaryType, ByteArrayType, Utf8Type},
14    Array, BinaryArray, FixedSizeBinaryArray, FixedSizeListArray, LargeListArray, ListArray,
15    NullArray, PrimitiveArray, RecordBatch, RecordBatchOptions, RecordBatchReader, StringArray,
16    StructArray,
17};
18use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef};
19use futures::{stream::BoxStream, StreamExt};
20use rand::{distributions::Uniform, Rng, RngCore, SeedableRng};
21
22use self::array::rand_with_distribution;
23
/// A number of rows
#[derive(Copy, Clone, Debug, Default)]
pub struct RowCount(u64);
/// A number of batches
#[derive(Copy, Clone, Debug, Default)]
pub struct BatchCount(u32);
/// A number of bytes
#[derive(Copy, Clone, Debug, Default)]
pub struct ByteCount(u64);
/// The number of values in each generated fixed-size vector
#[derive(Copy, Clone, Debug, Default)]
pub struct Dimension(u32);

impl From<u64> for RowCount {
    fn from(rows: u64) -> Self {
        RowCount(rows)
    }
}

impl From<u32> for BatchCount {
    fn from(batches: u32) -> Self {
        BatchCount(batches)
    }
}

impl From<u64> for ByteCount {
    fn from(bytes: u64) -> Self {
        ByteCount(bytes)
    }
}

impl From<u32> for Dimension {
    fn from(dim: u32) -> Self {
        Dimension(dim)
    }
}
56
57/// A trait for anything that can generate arrays of data
58pub trait ArrayGenerator: Send + Sync + std::fmt::Debug {
59    /// Generate an array of the given length
60    ///
61    /// # Arguments
62    ///
63    /// * `length` - The number of elements to generate
64    /// * `rng` - The random number generator to use
65    ///
66    /// # Returns
67    ///
68    /// An array of the given length
69    ///
70    /// Note: Not every generator needs an rng.  However, it is passed here because many do and this
71    /// lets us manage RNGs at the batch level instead of the array level.
72    fn generate(
73        &mut self,
74        length: RowCount,
75        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
76    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError>;
77    /// Get the data type of the array that this generator produces
78    ///
79    /// # Returns
80    ///
81    /// The data type of the array that this generator produces
82    fn data_type(&self) -> &DataType;
83    /// Gets metadata that should be associated with the field generated by this generator
84    fn metadata(&self) -> Option<HashMap<String, String>> {
85        None
86    }
87    /// Get the size of each element in bytes
88    ///
89    /// # Returns
90    ///
91    /// The size of each element in bytes.  Will be None if the size varies by element.
92    fn element_size_bytes(&self) -> Option<ByteCount>;
93}
94
95#[derive(Debug)]
96pub struct CycleNullGenerator {
97    generator: Box<dyn ArrayGenerator>,
98    validity: Vec<bool>,
99    idx: usize,
100}
101
102impl ArrayGenerator for CycleNullGenerator {
103    fn generate(
104        &mut self,
105        length: RowCount,
106        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
107    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
108        let array = self.generator.generate(length, rng)?;
109        let data = array.to_data();
110        let validity_itr = self
111            .validity
112            .iter()
113            .cycle()
114            .skip(self.idx)
115            .take(length.0 as usize)
116            .copied();
117        let validity_bitmap = BooleanBuffer::from_iter(validity_itr);
118
119        self.idx = (self.idx + (length.0 as usize)) % self.validity.len();
120        unsafe {
121            let new_data = ArrayData::new_unchecked(
122                data.data_type().clone(),
123                data.len(),
124                None,
125                Some(validity_bitmap.into_inner()),
126                data.offset(),
127                data.buffers().to_vec(),
128                data.child_data().into(),
129            );
130            Ok(make_array(new_data))
131        }
132    }
133
134    fn data_type(&self) -> &DataType {
135        self.generator.data_type()
136    }
137
138    fn element_size_bytes(&self) -> Option<ByteCount> {
139        self.generator.element_size_bytes()
140    }
141}
142
143#[derive(Debug)]
144pub struct MetadataGenerator {
145    generator: Box<dyn ArrayGenerator>,
146    metadata: HashMap<String, String>,
147}
148
149impl ArrayGenerator for MetadataGenerator {
150    fn generate(
151        &mut self,
152        length: RowCount,
153        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
154    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
155        self.generator.generate(length, rng)
156    }
157
158    fn metadata(&self) -> Option<HashMap<String, String>> {
159        Some(self.metadata.clone())
160    }
161
162    fn data_type(&self) -> &DataType {
163        self.generator.data_type()
164    }
165
166    fn element_size_bytes(&self) -> Option<ByteCount> {
167        self.generator.element_size_bytes()
168    }
169}
170
171#[derive(Debug)]
172pub struct NullGenerator {
173    generator: Box<dyn ArrayGenerator>,
174    null_probability: f64,
175}
176
177impl ArrayGenerator for NullGenerator {
178    fn generate(
179        &mut self,
180        length: RowCount,
181        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
182    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
183        let array = self.generator.generate(length, rng)?;
184        let data = array.to_data();
185
186        if self.null_probability < 0.0 || self.null_probability > 1.0 {
187            return Err(ArrowError::InvalidArgumentError(format!(
188                "null_probability must be between 0 and 1, got {}",
189                self.null_probability
190            )));
191        }
192
193        let (null_count, new_validity) = if self.null_probability == 0.0 {
194            if data.null_count() == 0 {
195                return Ok(array);
196            } else {
197                (0_usize, None)
198            }
199        } else if self.null_probability == 1.0 {
200            if data.null_count() == data.len() {
201                return Ok(array);
202            } else {
203                let all_nulls = BooleanBuffer::new_unset(array.len());
204                (array.len(), Some(all_nulls.into_inner()))
205            }
206        } else {
207            let array_len = array.len();
208            let num_validity_bytes = (array_len + 7) / 8;
209            let mut null_count = 0;
210            // Sampling the RNG once per bit is kind of slow so we do this to sample once
211            // per byte.  We only get 8 bits of RNG resolution but that should be good enough.
212            let threshold = (self.null_probability * u8::MAX as f64) as u8;
213            let bytes = (0..num_validity_bytes)
214                .map(|byte_idx| {
215                    let mut sample = rng.gen::<u64>();
216                    let mut byte: u8 = 0;
217                    for bit_idx in 0..8 {
218                        // We could probably overshoot and fill in extra bits with random data but
219                        // this is cleaner and that would mess up the null count
220                        byte <<= 1;
221                        let pos = byte_idx * 8 + (7 - bit_idx);
222                        if pos < array_len {
223                            let sample_piece = sample & 0xFF;
224                            let is_null = (sample_piece as u8) < threshold;
225                            byte |= (!is_null) as u8;
226                            null_count += is_null as usize;
227                        }
228                        sample >>= 8;
229                    }
230                    byte
231                })
232                .collect::<Vec<_>>();
233            let new_validity = Buffer::from_iter(bytes);
234            (null_count, Some(new_validity))
235        };
236
237        unsafe {
238            let new_data = ArrayData::new_unchecked(
239                data.data_type().clone(),
240                data.len(),
241                Some(null_count),
242                new_validity,
243                data.offset(),
244                data.buffers().to_vec(),
245                data.child_data().into(),
246            );
247            Ok(make_array(new_data))
248        }
249    }
250
251    fn metadata(&self) -> Option<HashMap<String, String>> {
252        self.generator.metadata()
253    }
254
255    fn data_type(&self) -> &DataType {
256        self.generator.data_type()
257    }
258
259    fn element_size_bytes(&self) -> Option<ByteCount> {
260        self.generator.element_size_bytes()
261    }
262}
263
264pub trait ArrayGeneratorExt {
265    /// Replaces the validity bitmap of generated arrays, inserting nulls with a given probability
266    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator>;
267    /// Replaces the validity bitmap of generated arrays with the inverse of `nulls`, cycling if needed
268    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
269    /// Replaces the validity bitmap of generated arrays with `validity`, cycling if needed
270    fn with_validity(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
271    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator>;
272}
273
274impl ArrayGeneratorExt for Box<dyn ArrayGenerator> {
275    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator> {
276        Box::new(NullGenerator {
277            generator: self,
278            null_probability,
279        })
280    }
281
282    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator> {
283        Box::new(CycleNullGenerator {
284            generator: self,
285            validity: nulls.iter().map(|v| !*v).collect(),
286            idx: 0,
287        })
288    }
289
290    fn with_validity(self, validity: &[bool]) -> Box<dyn ArrayGenerator> {
291        Box::new(CycleNullGenerator {
292            generator: self,
293            validity: validity.to_vec(),
294            idx: 0,
295        })
296    }
297
298    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator> {
299        Box::new(MetadataGenerator {
300            generator: self,
301            metadata,
302        })
303    }
304}
305
/// Repeats every item of `iter` `n` times in a row.
///
/// `cur` is the item currently being repeated and `count` is how many more copies of it
/// remain.  Seeding these fields lets a run that was cut off be resumed.  `n` must be at
/// least 1 (`n - 1` is computed when a new item is fetched).
pub struct NTimesIter<I: Iterator>
where
    I::Item: Copy,
{
    iter: I,
    n: u32,
    cur: I::Item,
    count: u32,
}

// Note: if this is used then there is a performance hit as the
// inner loop cannot experience vectorization
//
// TODO: maybe faster to build the vec and then repeat it into
// the destination array?
impl<I: Iterator> Iterator for NTimesIter<I>
where
    I::Item: Copy,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            // Fetch first and only then reset the counter.  Otherwise, once the inner
            // iterator is exhausted, the stale `cur` would be yielded again `n - 1` times.
            self.cur = self.iter.next()?;
            self.count = self.n - 1;
        } else {
            self.count -= 1;
        }
        Some(self.cur)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.iter.size_hint();
        let n = self.n as usize;
        // Copies of `cur` still owed count towards the remaining length too
        let pending = self.count as usize;
        let lower = lower.saturating_mul(n).saturating_add(pending);
        let upper = upper
            .and_then(|u| u.checked_mul(n))
            .and_then(|u| u.checked_add(pending));
        (lower, upper)
    }
}
344
345pub struct FnGen<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T>
346where
347    T: Copy + Default,
348    ArrayType: arrow_array::Array + From<Vec<T>>,
349{
350    data_type: DataType,
351    generator: F,
352    array_type: PhantomData<ArrayType>,
353    repeat: u32,
354    leftover: T,
355    leftover_count: u32,
356    element_size_bytes: Option<ByteCount>,
357}
358
359impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> std::fmt::Debug
360    for FnGen<T, ArrayType, F>
361where
362    T: Copy + Default,
363    ArrayType: arrow_array::Array + From<Vec<T>>,
364{
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        f.debug_struct("FnGen")
367            .field("data_type", &self.data_type)
368            .field("array_type", &self.array_type)
369            .field("repeat", &self.repeat)
370            .field("leftover_count", &self.leftover_count)
371            .field("element_size_bytes", &self.element_size_bytes)
372            .finish()
373    }
374}
375
376impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> FnGen<T, ArrayType, F>
377where
378    T: Copy + Default,
379    ArrayType: arrow_array::Array + From<Vec<T>>,
380{
381    fn new_known_size(
382        data_type: DataType,
383        generator: F,
384        repeat: u32,
385        element_size_bytes: ByteCount,
386    ) -> Self {
387        Self {
388            data_type,
389            generator,
390            array_type: PhantomData,
391            repeat,
392            leftover: T::default(),
393            leftover_count: 0,
394            element_size_bytes: Some(element_size_bytes),
395        }
396    }
397}
398
399impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> ArrayGenerator
400    for FnGen<T, ArrayType, F>
401where
402    T: Copy + Default + Send + Sync,
403    ArrayType: arrow_array::Array + From<Vec<T>> + 'static,
404    F: Send + Sync,
405{
406    fn generate(
407        &mut self,
408        length: RowCount,
409        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
410    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
411        let iter = (0..length.0).map(|_| (self.generator)(rng));
412        let values = if self.repeat > 1 {
413            Vec::from_iter(
414                NTimesIter {
415                    iter,
416                    n: self.repeat,
417                    cur: self.leftover,
418                    count: self.leftover_count,
419                }
420                .take(length.0 as usize),
421            )
422        } else {
423            Vec::from_iter(iter)
424        };
425        self.leftover_count = ((self.leftover_count as u64 + length.0) % self.repeat as u64) as u32;
426        self.leftover = values.last().copied().unwrap_or(T::default());
427        Ok(Arc::new(ArrayType::from(values)))
428    }
429
430    fn data_type(&self) -> &DataType {
431        &self.data_type
432    }
433
434    fn element_size_bytes(&self) -> Option<ByteCount> {
435        self.element_size_bytes
436    }
437}
438
/// The seed used to initialise random number generation
#[derive(Copy, Clone, Debug)]
pub struct Seed(pub u64);

/// The seed used when the caller doesn't pick one
pub const DEFAULT_SEED: Seed = Seed(42);

impl From<u64> for Seed {
    fn from(value: u64) -> Self {
        Seed(value)
    }
}
448
449#[derive(Debug)]
450pub struct CycleVectorGenerator {
451    underlying_gen: Box<dyn ArrayGenerator>,
452    dimension: Dimension,
453    data_type: DataType,
454}
455
456impl CycleVectorGenerator {
457    pub fn new(underlying_gen: Box<dyn ArrayGenerator>, dimension: Dimension) -> Self {
458        let data_type = DataType::FixedSizeList(
459            Arc::new(Field::new("item", underlying_gen.data_type().clone(), true)),
460            dimension.0 as i32,
461        );
462        Self {
463            underlying_gen,
464            dimension,
465            data_type,
466        }
467    }
468}
469
470impl ArrayGenerator for CycleVectorGenerator {
471    fn generate(
472        &mut self,
473        length: RowCount,
474        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
475    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
476        let values = self
477            .underlying_gen
478            .generate(RowCount::from(length.0 * self.dimension.0 as u64), rng)?;
479        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
480        let values = Arc::new(values);
481
482        let array = FixedSizeListArray::try_new(field, self.dimension.0 as i32, values, None)?;
483
484        Ok(Arc::new(array))
485    }
486
487    fn data_type(&self) -> &DataType {
488        &self.data_type
489    }
490
491    fn element_size_bytes(&self) -> Option<ByteCount> {
492        self.underlying_gen
493            .element_size_bytes()
494            .map(|byte_count| ByteCount::from(byte_count.0 * self.dimension.0 as u64))
495    }
496}
497
498#[derive(Debug, Default)]
499pub struct PseudoUuidGenerator {}
500
501impl ArrayGenerator for PseudoUuidGenerator {
502    fn generate(
503        &mut self,
504        length: RowCount,
505        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
506    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
507        Ok(Arc::new(FixedSizeBinaryArray::try_from_iter(
508            (0..length.0).map(|_| {
509                let mut data = vec![0; 16];
510                rng.fill_bytes(&mut data);
511                data
512            }),
513        )?))
514    }
515
516    fn data_type(&self) -> &DataType {
517        &DataType::FixedSizeBinary(16)
518    }
519
520    fn element_size_bytes(&self) -> Option<ByteCount> {
521        Some(ByteCount::from(16))
522    }
523}
524
525#[derive(Debug, Default)]
526pub struct PseudoUuidHexGenerator {}
527
528impl ArrayGenerator for PseudoUuidHexGenerator {
529    fn generate(
530        &mut self,
531        length: RowCount,
532        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
533    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
534        let mut data = vec![0; 16 * length.0 as usize];
535        rng.fill_bytes(&mut data);
536        let data_hex = hex::encode(data);
537
538        Ok(Arc::new(StringArray::from_iter_values(
539            (0..length.0 as usize).map(|i| data_hex.get(i * 32..(i + 1) * 32).unwrap()),
540        )))
541    }
542
543    fn data_type(&self) -> &DataType {
544        &DataType::Utf8
545    }
546
547    fn element_size_bytes(&self) -> Option<ByteCount> {
548        Some(ByteCount::from(16))
549    }
550}
551
552#[derive(Debug, Default)]
553pub struct RandomBooleanGenerator {}
554
555impl ArrayGenerator for RandomBooleanGenerator {
556    fn generate(
557        &mut self,
558        length: RowCount,
559        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
560    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
561        let num_bytes = (length.0 + 7) / 8;
562        let mut bytes = vec![0; num_bytes as usize];
563        rng.fill_bytes(&mut bytes);
564        let bytes = BooleanBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
565        Ok(Arc::new(arrow_array::BooleanArray::new(bytes, None)))
566    }
567
568    fn data_type(&self) -> &DataType {
569        &DataType::Boolean
570    }
571
572    fn element_size_bytes(&self) -> Option<ByteCount> {
573        // We can't say 1/8th of a byte and 1 byte would be a pretty extreme over-count so let's leave
574        // it at None until someone needs this.  Then we can probably special case this (e.g. make a ByteCount::ONE_BIT)
575        None
576    }
577}
578
579// Instead of using the "standard distribution" and generating values there are some cases (e.g. f16 / decimal)
580// where we just generate random bytes because there is no rand support
581pub struct RandomBytesGenerator<T: ArrowPrimitiveType + Send + Sync> {
582    phantom: PhantomData<T>,
583    data_type: DataType,
584}
585
586impl<T: ArrowPrimitiveType + Send + Sync> std::fmt::Debug for RandomBytesGenerator<T> {
587    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
588        f.debug_struct("RandomBytesGenerator")
589            .field("data_type", &self.data_type)
590            .finish()
591    }
592}
593
594impl<T: ArrowPrimitiveType + Send + Sync> RandomBytesGenerator<T> {
595    fn new(data_type: DataType) -> Self {
596        Self {
597            phantom: Default::default(),
598            data_type,
599        }
600    }
601
602    fn byte_width() -> Result<u64, ArrowError> {
603        T::DATA_TYPE.primitive_width().ok_or_else(|| ArrowError::InvalidArgumentError(format!("Cannot generate the data type {} with the RandomBytesGenerator because it is not a fixed-width bytes type", T::DATA_TYPE))).map(|val| val as u64)
604    }
605}
606
607impl<T: ArrowPrimitiveType + Send + Sync> ArrayGenerator for RandomBytesGenerator<T> {
608    fn generate(
609        &mut self,
610        length: RowCount,
611        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
612    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
613        let num_bytes = length.0 * Self::byte_width()?;
614        let mut bytes = vec![0; num_bytes as usize];
615        rng.fill_bytes(&mut bytes);
616        let bytes = ScalarBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
617        Ok(Arc::new(
618            PrimitiveArray::<T>::new(bytes, None).with_data_type(self.data_type.clone()),
619        ))
620    }
621
622    fn data_type(&self) -> &DataType {
623        &self.data_type
624    }
625
626    fn element_size_bytes(&self) -> Option<ByteCount> {
627        Self::byte_width().map(ByteCount::from).ok()
628    }
629}
630
631// This is pretty much the same thing as RandomBinaryGenerator but we can't use that
632// because there is no ArrowPrimitiveType for FixedSizeBinary
633#[derive(Debug)]
634pub struct RandomFixedSizeBinaryGenerator {
635    data_type: DataType,
636    size: i32,
637}
638
639impl RandomFixedSizeBinaryGenerator {
640    fn new(size: i32) -> Self {
641        Self {
642            size,
643            data_type: DataType::FixedSizeBinary(size),
644        }
645    }
646}
647
648impl ArrayGenerator for RandomFixedSizeBinaryGenerator {
649    fn generate(
650        &mut self,
651        length: RowCount,
652        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
653    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
654        let num_bytes = length.0 * self.size as u64;
655        let mut bytes = vec![0; num_bytes as usize];
656        rng.fill_bytes(&mut bytes);
657        Ok(Arc::new(FixedSizeBinaryArray::new(
658            self.size,
659            Buffer::from(bytes),
660            None,
661        )))
662    }
663
664    fn data_type(&self) -> &DataType {
665        &self.data_type
666    }
667
668    fn element_size_bytes(&self) -> Option<ByteCount> {
669        Some(ByteCount::from(self.size as u64))
670    }
671}
672
673#[derive(Debug)]
674pub struct RandomIntervalGenerator {
675    unit: IntervalUnit,
676    data_type: DataType,
677}
678
679impl RandomIntervalGenerator {
680    pub fn new(unit: IntervalUnit) -> Self {
681        Self {
682            unit,
683            data_type: DataType::Interval(unit),
684        }
685    }
686}
687
688impl ArrayGenerator for RandomIntervalGenerator {
689    fn generate(
690        &mut self,
691        length: RowCount,
692        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
693    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
694        match self.unit {
695            IntervalUnit::YearMonth => {
696                let months = (0..length.0).map(|_| rng.gen::<i32>()).collect::<Vec<_>>();
697                Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months)))
698            }
699            IntervalUnit::MonthDayNano => {
700                let day_time_array = (0..length.0)
701                    .map(|_| IntervalMonthDayNano::new(rng.gen(), rng.gen(), rng.gen()))
702                    .collect::<Vec<_>>();
703                Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
704                    day_time_array,
705                )))
706            }
707            IntervalUnit::DayTime => {
708                let day_time_array = (0..length.0)
709                    .map(|_| IntervalDayTime::new(rng.gen(), rng.gen()))
710                    .collect::<Vec<_>>();
711                Ok(Arc::new(arrow_array::IntervalDayTimeArray::from(
712                    day_time_array,
713                )))
714            }
715        }
716    }
717
718    fn data_type(&self) -> &DataType {
719        &self.data_type
720    }
721
722    fn element_size_bytes(&self) -> Option<ByteCount> {
723        Some(ByteCount::from(12))
724    }
725}
726#[derive(Debug)]
727pub struct RandomBinaryGenerator {
728    bytes_per_element: ByteCount,
729    scale_to_utf8: bool,
730    is_large: bool,
731    data_type: DataType,
732}
733
734impl RandomBinaryGenerator {
735    pub fn new(bytes_per_element: ByteCount, scale_to_utf8: bool, is_large: bool) -> Self {
736        Self {
737            bytes_per_element,
738            scale_to_utf8,
739            is_large,
740            data_type: match (scale_to_utf8, is_large) {
741                (false, false) => DataType::Binary,
742                (false, true) => DataType::LargeBinary,
743                (true, false) => DataType::Utf8,
744                (true, true) => DataType::LargeUtf8,
745            },
746        }
747    }
748}
749
750impl ArrayGenerator for RandomBinaryGenerator {
751    fn generate(
752        &mut self,
753        length: RowCount,
754        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
755    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
756        let mut bytes = vec![0; (self.bytes_per_element.0 * length.0) as usize];
757        rng.fill_bytes(&mut bytes);
758        if self.scale_to_utf8 {
759            // This doesn't give us the full UTF-8 range and it isn't statistically correct but
760            // it's fast and probably good enough for most cases
761            bytes = bytes.into_iter().map(|val| (val % 95) + 32).collect();
762        }
763        let bytes = Buffer::from(bytes);
764        if self.is_large {
765            let offsets = OffsetBuffer::from_lengths(
766                iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize),
767            );
768            if self.scale_to_utf8 {
769                // This is safe because we are only using printable characters
770                unsafe {
771                    Ok(Arc::new(arrow_array::LargeStringArray::new_unchecked(
772                        offsets, bytes, None,
773                    )))
774                }
775            } else {
776                unsafe {
777                    Ok(Arc::new(arrow_array::LargeBinaryArray::new_unchecked(
778                        offsets, bytes, None,
779                    )))
780                }
781            }
782        } else {
783            let offsets = OffsetBuffer::from_lengths(
784                iter::repeat(self.bytes_per_element.0 as usize).take(length.0 as usize),
785            );
786            if self.scale_to_utf8 {
787                // This is safe because we are only using printable characters
788                unsafe {
789                    Ok(Arc::new(arrow_array::StringArray::new_unchecked(
790                        offsets, bytes, None,
791                    )))
792                }
793            } else {
794                unsafe {
795                    Ok(Arc::new(arrow_array::BinaryArray::new_unchecked(
796                        offsets, bytes, None,
797                    )))
798                }
799            }
800        }
801    }
802
803    fn data_type(&self) -> &DataType {
804        &self.data_type
805    }
806
807    fn element_size_bytes(&self) -> Option<ByteCount> {
808        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
809        Some(ByteCount::from(
810            self.bytes_per_element.0 + std::mem::size_of::<i32>() as u64,
811        ))
812    }
813}
814
815#[derive(Debug)]
816pub struct VariableRandomBinaryGenerator {
817    lengths_gen: Box<dyn ArrayGenerator>,
818    data_type: DataType,
819}
820
821impl VariableRandomBinaryGenerator {
822    pub fn new(min_bytes_per_element: ByteCount, max_bytes_per_element: ByteCount) -> Self {
823        let lengths_dist = Uniform::new_inclusive(
824            min_bytes_per_element.0 as i32,
825            max_bytes_per_element.0 as i32,
826        );
827        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
828
829        Self {
830            lengths_gen,
831            data_type: DataType::Binary,
832        }
833    }
834}
835
836impl ArrayGenerator for VariableRandomBinaryGenerator {
837    fn generate(
838        &mut self,
839        length: RowCount,
840        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
841    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
842        let lengths = self.lengths_gen.generate(length, rng)?;
843        let lengths = lengths.as_primitive::<Int32Type>();
844        let total_length = lengths.values().iter().map(|i| *i as usize).sum::<usize>();
845        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
846        let mut bytes = vec![0; total_length];
847        rng.fill_bytes(&mut bytes);
848        let bytes = Buffer::from(bytes);
849        Ok(Arc::new(BinaryArray::try_new(offsets, bytes, None)?))
850    }
851
852    fn data_type(&self) -> &DataType {
853        &self.data_type
854    }
855
856    fn element_size_bytes(&self) -> Option<ByteCount> {
857        None
858    }
859}
860
861pub struct CycleBinaryGenerator<T: ByteArrayType> {
862    values: Vec<u8>,
863    lengths: Vec<usize>,
864    data_type: DataType,
865    array_type: PhantomData<T>,
866    width: Option<ByteCount>,
867    idx: usize,
868}
869
870impl<T: ByteArrayType> std::fmt::Debug for CycleBinaryGenerator<T> {
871    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
872        f.debug_struct("CycleBinaryGenerator")
873            .field("values", &self.values)
874            .field("lengths", &self.lengths)
875            .field("data_type", &self.data_type)
876            .field("width", &self.width)
877            .field("idx", &self.idx)
878            .finish()
879    }
880}
881
882impl<T: ByteArrayType> CycleBinaryGenerator<T> {
883    pub fn from_strings(values: &[&str]) -> Self {
884        if values.is_empty() {
885            panic!("Attempt to create a cycle generator with no values");
886        }
887        let lengths = values.iter().map(|s| s.len()).collect::<Vec<_>>();
888        let typical_length = lengths[0];
889        let width = if lengths.iter().all(|item| *item == typical_length) {
890            Some(ByteCount::from(
891                typical_length as u64 + std::mem::size_of::<i32>() as u64,
892            ))
893        } else {
894            None
895        };
896        let values = values
897            .iter()
898            .flat_map(|s| s.as_bytes().iter().copied())
899            .collect::<Vec<_>>();
900        Self {
901            values,
902            lengths,
903            data_type: T::DATA_TYPE,
904            array_type: PhantomData,
905            width,
906            idx: 0,
907        }
908    }
909}
910
911impl<T: ByteArrayType> ArrayGenerator for CycleBinaryGenerator<T> {
912    fn generate(
913        &mut self,
914        length: RowCount,
915        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
916    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
917        let lengths = self
918            .lengths
919            .iter()
920            .copied()
921            .cycle()
922            .skip(self.idx)
923            .take(length.0 as usize);
924        let num_bytes = lengths.clone().sum();
925        let byte_offset = self.lengths[0..self.idx].iter().sum();
926        let bytes = self
927            .values
928            .iter()
929            .cycle()
930            .skip(byte_offset)
931            .copied()
932            .take(num_bytes)
933            .collect::<Vec<_>>();
934        let bytes = Buffer::from(bytes);
935        let offsets = OffsetBuffer::from_lengths(lengths);
936        self.idx = (self.idx + length.0 as usize) % self.lengths.len();
937        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
938            offsets, bytes, None,
939        )))
940    }
941
942    fn data_type(&self) -> &DataType {
943        &self.data_type
944    }
945
946    fn element_size_bytes(&self) -> Option<ByteCount> {
947        self.width
948    }
949}
950
/// Generates byte arrays of type `T` in which every element is the same value.
pub struct FixedBinaryGenerator<T: ByteArrayType> {
    /// The bytes repeated for every generated element
    value: Vec<u8>,
    /// The arrow type reported by this generator (`new` sets it to `T::DATA_TYPE`)
    data_type: DataType,
    /// Ties the generator to `T` without storing a value of it
    array_type: PhantomData<T>,
}
956
957impl<T: ByteArrayType> std::fmt::Debug for FixedBinaryGenerator<T> {
958    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
959        f.debug_struct("FixedBinaryGenerator")
960            .field("value", &self.value)
961            .field("data_type", &self.data_type)
962            .finish()
963    }
964}
965
966impl<T: ByteArrayType> FixedBinaryGenerator<T> {
967    pub fn new(value: Vec<u8>) -> Self {
968        Self {
969            value,
970            data_type: T::DATA_TYPE,
971            array_type: PhantomData,
972        }
973    }
974}
975
976impl<T: ByteArrayType> ArrayGenerator for FixedBinaryGenerator<T> {
977    fn generate(
978        &mut self,
979        length: RowCount,
980        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
981    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
982        let bytes = Buffer::from(Vec::from_iter(
983            self.value
984                .iter()
985                .cycle()
986                .take((length.0 * self.value.len() as u64) as usize)
987                .copied(),
988        ));
989        let offsets =
990            OffsetBuffer::from_lengths(iter::repeat(self.value.len()).take(length.0 as usize));
991        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
992            offsets, bytes, None,
993        )))
994    }
995
996    fn data_type(&self) -> &DataType {
997        &self.data_type
998    }
999
1000    fn element_size_bytes(&self) -> Option<ByteCount> {
1001        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1002        Some(ByteCount::from(
1003            self.value.len() as u64 + std::mem::size_of::<i32>() as u64,
1004        ))
1005    }
1006}
1007
/// Wraps another generator and dictionary-encodes its output using key type `K`.
pub struct DictionaryGenerator<K: ArrowDictionaryKeyType> {
    /// Generator for the values before dictionary encoding
    generator: Box<dyn ArrayGenerator>,
    /// The `Dictionary(K, value type)` type this generator produces
    data_type: DataType,
    /// Ties the generator to the key type `K`
    key_type: PhantomData<K>,
    /// Width in bytes of one key; added to the wrapped generator's element size
    key_width: u64,
}
1014
1015impl<K: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryGenerator<K> {
1016    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        f.debug_struct("DictionaryGenerator")
1018            .field("generator", &self.generator)
1019            .field("data_type", &self.data_type)
1020            .field("key_width", &self.key_width)
1021            .finish()
1022    }
1023}
1024
1025impl<K: ArrowDictionaryKeyType> DictionaryGenerator<K> {
1026    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1027        let key_type = Box::new(K::DATA_TYPE);
1028        let key_width = key_type
1029            .primitive_width()
1030            .expect("dictionary key types should have a known width")
1031            as u64;
1032        let val_type = Box::new(generator.data_type().clone());
1033        let dict_type = DataType::Dictionary(key_type, val_type);
1034        Self {
1035            generator,
1036            data_type: dict_type,
1037            key_type: PhantomData,
1038            key_width,
1039        }
1040    }
1041}
1042
1043impl<K: ArrowDictionaryKeyType + Send + Sync> ArrayGenerator for DictionaryGenerator<K> {
1044    fn generate(
1045        &mut self,
1046        length: RowCount,
1047        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1048    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1049        let underlying = self.generator.generate(length, rng)?;
1050        arrow_cast::cast::cast(&underlying, &self.data_type)
1051    }
1052
1053    fn data_type(&self) -> &DataType {
1054        &self.data_type
1055    }
1056
1057    fn element_size_bytes(&self) -> Option<ByteCount> {
1058        self.generator
1059            .element_size_bytes()
1060            .map(|size_bytes| ByteCount::from(size_bytes.0 + self.key_width))
1061    }
1062}
1063
/// Generates `List` / `LargeList` arrays whose lists have random lengths.
#[derive(Debug)]
struct RandomListGenerator {
    /// The list field; its data type is the type this generator reports
    field: Arc<Field>,
    /// The `item` field describing the list children
    child_field: Arc<Field>,
    /// Generator for the child items
    items_gen: Box<dyn ArrayGenerator>,
    /// Generator for the length of each list
    lengths_gen: Box<dyn ArrayGenerator>,
    /// `true` for `LargeList` (64-bit offsets), `false` for `List` (32-bit offsets)
    is_large: bool,
}
1072
1073impl RandomListGenerator {
1074    // Creates a list generator that generates random lists with lengths between 0 and 10 (inclusive)
1075    fn new(items_gen: Box<dyn ArrayGenerator>, is_large: bool) -> Self {
1076        let child_field = Arc::new(Field::new("item", items_gen.data_type().clone(), true));
1077        let list_type = if is_large {
1078            DataType::LargeList(child_field.clone())
1079        } else {
1080            DataType::List(child_field.clone())
1081        };
1082        let field = Field::new("", list_type, true);
1083        let lengths_gen = if is_large {
1084            let lengths_dist = Uniform::new_inclusive(0, 10);
1085            rand_with_distribution::<Int64Type, Uniform<i64>>(lengths_dist)
1086        } else {
1087            let lengths_dist = Uniform::new_inclusive(0, 10);
1088            rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist)
1089        };
1090        Self {
1091            field: Arc::new(field),
1092            child_field,
1093            items_gen,
1094            lengths_gen,
1095            is_large,
1096        }
1097    }
1098}
1099
1100impl ArrayGenerator for RandomListGenerator {
1101    fn generate(
1102        &mut self,
1103        length: RowCount,
1104        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1105    ) -> Result<Arc<dyn Array>, ArrowError> {
1106        let lengths = self.lengths_gen.generate(length, rng)?;
1107        if self.is_large {
1108            let lengths = lengths.as_primitive::<Int64Type>();
1109            let total_length = lengths.values().iter().sum::<i64>() as u64;
1110            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1111            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1112            Ok(Arc::new(LargeListArray::try_new(
1113                self.child_field.clone(),
1114                offsets,
1115                items,
1116                None,
1117            )?))
1118        } else {
1119            let lengths = lengths.as_primitive::<Int32Type>();
1120            let total_length = lengths.values().iter().sum::<i32>() as u64;
1121            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1122            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1123            Ok(Arc::new(ListArray::try_new(
1124                self.child_field.clone(),
1125                offsets,
1126                items,
1127                None,
1128            )?))
1129        }
1130    }
1131
1132    fn data_type(&self) -> &DataType {
1133        self.field.data_type()
1134    }
1135
1136    fn element_size_bytes(&self) -> Option<ByteCount> {
1137        None
1138    }
1139}
1140
/// Generates arrays of type `Null`.
#[derive(Debug)]
struct NullArrayGenerator {}
1143
1144impl ArrayGenerator for NullArrayGenerator {
1145    fn generate(
1146        &mut self,
1147        length: RowCount,
1148        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1149    ) -> Result<Arc<dyn Array>, ArrowError> {
1150        Ok(Arc::new(NullArray::new(length.0 as usize)))
1151    }
1152
1153    fn data_type(&self) -> &DataType {
1154        &DataType::Null
1155    }
1156
1157    fn element_size_bytes(&self) -> Option<ByteCount> {
1158        None
1159    }
1160}
1161
/// Generates `Struct` arrays by running one generator per child field.
#[derive(Debug)]
struct RandomStructGenerator {
    /// The struct's child fields, in order
    fields: Fields,
    /// `DataType::Struct(fields)`, cached so `data_type` can return a reference
    data_type: DataType,
    /// One generator per entry of `fields`, in the same order
    child_gens: Vec<Box<dyn ArrayGenerator>>,
}
1168
1169impl RandomStructGenerator {
1170    fn new(fields: Fields, child_gens: Vec<Box<dyn ArrayGenerator>>) -> Self {
1171        let data_type = DataType::Struct(fields.clone());
1172        Self {
1173            fields,
1174            data_type,
1175            child_gens,
1176        }
1177    }
1178}
1179
1180impl ArrayGenerator for RandomStructGenerator {
1181    fn generate(
1182        &mut self,
1183        length: RowCount,
1184        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1185    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1186        if self.child_gens.is_empty() {
1187            // Have to create empty struct arrays specially to ensure they have the correct
1188            // row count
1189            let struct_arr = StructArray::new_empty_fields(length.0 as usize, None);
1190            return Ok(Arc::new(struct_arr));
1191        }
1192        let child_arrays = self
1193            .child_gens
1194            .iter_mut()
1195            .map(|gen| gen.generate(length, rng))
1196            .collect::<Result<Vec<_>, ArrowError>>()?;
1197        let struct_arr = StructArray::new(self.fields.clone(), child_arrays, None);
1198        Ok(Arc::new(struct_arr))
1199    }
1200
1201    fn data_type(&self) -> &DataType {
1202        &self.data_type
1203    }
1204
1205    fn element_size_bytes(&self) -> Option<ByteCount> {
1206        let mut sum = 0;
1207        for child_gen in &self.child_gens {
1208            sum += child_gen.element_size_bytes()?.0;
1209        }
1210        Some(ByteCount::from(sum))
1211    }
1212}
1213
/// A RecordBatchReader that generates batches of the given size from the given array generators
pub struct FixedSizeBatchGenerator {
    /// RNG passed to every column generator, seeded once at construction
    rng: rand_xoshiro::Xoshiro256PlusPlus,
    /// One generator per column, in schema order
    generators: Vec<Box<dyn ArrayGenerator>>,
    /// Number of rows in every batch
    batch_size: RowCount,
    /// Batches still to be produced; decremented for each generated batch
    num_batches: BatchCount,
    /// Schema shared by every produced batch
    schema: SchemaRef,
}
1222
1223impl FixedSizeBatchGenerator {
1224    fn new(
1225        generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
1226        batch_size: RowCount,
1227        num_batches: BatchCount,
1228        seed: Option<Seed>,
1229        default_null_probability: Option<f64>,
1230    ) -> Self {
1231        let mut fields = Vec::with_capacity(generators.len());
1232        for (field_index, field_gen) in generators.iter().enumerate() {
1233            let (name, gen) = field_gen;
1234            let default_name = format!("field_{}", field_index);
1235            let name = name.clone().unwrap_or(default_name);
1236            let mut field = Field::new(name, gen.data_type().clone(), true);
1237            if let Some(metadata) = gen.metadata() {
1238                field = field.with_metadata(metadata);
1239            }
1240            fields.push(field);
1241        }
1242        let mut generators = generators
1243            .into_iter()
1244            .map(|(_, gen)| gen)
1245            .collect::<Vec<_>>();
1246        if let Some(null_probability) = default_null_probability {
1247            generators = generators
1248                .into_iter()
1249                .map(|gen| gen.with_random_nulls(null_probability))
1250                .collect();
1251        }
1252        let schema = Arc::new(Schema::new(fields));
1253        Self {
1254            rng: rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
1255                seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
1256            ),
1257            generators,
1258            batch_size,
1259            num_batches,
1260            schema,
1261        }
1262    }
1263
1264    fn gen_next(&mut self) -> Result<RecordBatch, ArrowError> {
1265        let mut arrays = Vec::with_capacity(self.generators.len());
1266        for gen in self.generators.iter_mut() {
1267            let arr = gen.generate(self.batch_size, &mut self.rng)?;
1268            arrays.push(arr);
1269        }
1270        self.num_batches.0 -= 1;
1271        Ok(RecordBatch::try_new_with_options(
1272            self.schema.clone(),
1273            arrays,
1274            &RecordBatchOptions::new().with_row_count(Some(self.batch_size.0 as usize)),
1275        )
1276        .unwrap())
1277    }
1278}
1279
1280impl Iterator for FixedSizeBatchGenerator {
1281    type Item = Result<RecordBatch, ArrowError>;
1282
1283    fn next(&mut self) -> Option<Self::Item> {
1284        if self.num_batches.0 == 0 {
1285            return None;
1286        }
1287        Some(self.gen_next())
1288    }
1289}
1290
1291impl RecordBatchReader for FixedSizeBatchGenerator {
1292    fn schema(&self) -> SchemaRef {
1293        self.schema.clone()
1294    }
1295}
1296
/// A builder to create a record batch reader with generated data
///
/// This type is meant to be used in a fluent builder style to define the schema and generators
/// for a record batch reader.
#[derive(Default)]
pub struct BatchGeneratorBuilder {
    /// Column generators in column order; a `None` name becomes `field_{index}`
    generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
    /// If set, nulls are injected into every column with this probability
    default_null_probability: Option<f64>,
    /// RNG seed; `None` means the default seed is used
    seed: Option<Seed>,
}
1307
/// How a byte budget is turned into a row count when the budget is not an exact
/// multiple of the row size (see `BatchGeneratorBuilder::into_reader_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundingBehavior {
    /// Return an error unless the budget divides evenly into rows
    ExactOrErr,
    /// Use one row more than fits, slightly exceeding the budget
    RoundUp,
    /// Use only as many whole rows as fit, staying within the budget
    RoundDown,
}
1313
1314impl BatchGeneratorBuilder {
1315    /// Create a new BatchGeneratorBuilder with a default random seed
1316    pub fn new() -> Self {
1317        Default::default()
1318    }
1319
1320    /// Create a new BatchGeneratorBuilder with the given seed
1321    pub fn new_with_seed(seed: Seed) -> Self {
1322        Self {
1323            seed: Some(seed),
1324            ..Default::default()
1325        }
1326    }
1327
1328    /// Adds a new column to the generator
1329    ///
1330    /// See [`crate::generator::array`] for methods to create generators
1331    pub fn col(mut self, name: impl Into<String>, gen: Box<dyn ArrayGenerator>) -> Self {
1332        self.generators.push((Some(name.into()), gen));
1333        self
1334    }
1335
1336    /// Adds a new column to the generator with a generated unique name
1337    ///
1338    /// See [`crate::generator::array`] for methods to create generators
1339    pub fn anon_col(mut self, gen: Box<dyn ArrayGenerator>) -> Self {
1340        self.generators.push((None, gen));
1341        self
1342    }
1343
1344    pub fn into_batch_rows(self, batch_size: RowCount) -> Result<RecordBatch, ArrowError> {
1345        let mut reader = self.into_reader_rows(batch_size, BatchCount::from(1));
1346        reader
1347            .next()
1348            .expect("Asked for 1 batch but reader was empty")
1349    }
1350
1351    pub fn into_batch_bytes(
1352        self,
1353        batch_size: ByteCount,
1354        rounding: RoundingBehavior,
1355    ) -> Result<RecordBatch, ArrowError> {
1356        let mut reader = self.into_reader_bytes(batch_size, BatchCount::from(1), rounding)?;
1357        reader
1358            .next()
1359            .expect("Asked for 1 batch but reader was empty")
1360    }
1361
1362    /// Create a RecordBatchReader that generates batches of the given size (in rows)
1363    pub fn into_reader_rows(
1364        self,
1365        batch_size: RowCount,
1366        num_batches: BatchCount,
1367    ) -> impl RecordBatchReader {
1368        FixedSizeBatchGenerator::new(
1369            self.generators,
1370            batch_size,
1371            num_batches,
1372            self.seed,
1373            self.default_null_probability,
1374        )
1375    }
1376
1377    pub fn into_reader_stream(
1378        self,
1379        batch_size: RowCount,
1380        num_batches: BatchCount,
1381    ) -> (
1382        BoxStream<'static, Result<RecordBatch, ArrowError>>,
1383        Arc<Schema>,
1384    ) {
1385        // TODO: this is pretty lazy and could be optimized
1386        let reader = self.into_reader_rows(batch_size, num_batches);
1387        let schema = reader.schema();
1388        let batches = reader.collect::<Vec<_>>();
1389        (futures::stream::iter(batches).boxed(), schema)
1390    }
1391
1392    /// Create a RecordBatchReader that generates batches of the given size (in bytes)
1393    pub fn into_reader_bytes(
1394        self,
1395        batch_size_bytes: ByteCount,
1396        num_batches: BatchCount,
1397        rounding: RoundingBehavior,
1398    ) -> Result<impl RecordBatchReader, ArrowError> {
1399        let bytes_per_row = self
1400            .generators
1401            .iter()
1402            .map(|gen| gen.1.element_size_bytes().map(|byte_count| byte_count.0).ok_or(
1403                        ArrowError::NotYetImplemented("The function into_reader_bytes currently requires each array generator to have a fixed element size".to_string())
1404                )
1405            )
1406            .sum::<Result<u64, ArrowError>>()?;
1407        let mut num_rows = RowCount::from(batch_size_bytes.0 / bytes_per_row);
1408        if batch_size_bytes.0 % bytes_per_row != 0 {
1409            match rounding {
1410                RoundingBehavior::ExactOrErr => {
1411                    return Err(ArrowError::NotYetImplemented(
1412                        format!("Exact rounding requested but not possible.  Batch size requested {}, row size: {}", batch_size_bytes.0, bytes_per_row))
1413                    );
1414                }
1415                RoundingBehavior::RoundUp => {
1416                    num_rows = RowCount::from(num_rows.0 + 1);
1417                }
1418                RoundingBehavior::RoundDown => (),
1419            }
1420        }
1421        Ok(self.into_reader_rows(num_rows, num_batches))
1422    }
1423
1424    /// Set the seed for the generator
1425    pub fn with_seed(mut self, seed: Seed) -> Self {
1426        self.seed = Some(seed);
1427        self
1428    }
1429
1430    /// Adds nulls (with the given probability) to all columns
1431    pub fn with_random_nulls(&mut self, default_null_probability: f64) {
1432        self.default_null_probability = Some(default_null_probability);
1433    }
1434}
1435
/// Factory for creating a single random array
pub struct ArrayGeneratorBuilder {
    /// The generator that produces the array
    generator: Box<dyn ArrayGenerator>,
    /// RNG seed; the default seed is used when `None`
    seed: Option<Seed>,
}
1441
1442impl ArrayGeneratorBuilder {
1443    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1444        Self {
1445            generator,
1446            seed: None,
1447        }
1448    }
1449
1450    /// Use the given seed for the generator
1451    pub fn with_seed(mut self, seed: Seed) -> Self {
1452        self.seed = Some(seed);
1453        self
1454    }
1455
1456    /// Generate a single array with the given length
1457    pub fn into_array_rows(
1458        mut self,
1459        length: RowCount,
1460    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1461        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
1462            self.seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
1463        );
1464        self.generator.generate(length, &mut rng)
1465    }
1466}
1467
/// Milliseconds in one day (24 h * 60 min * 60 s * 1000 ms).
const MS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
1469
1470pub mod array {
1471
1472    use arrow::datatypes::{Int16Type, Int64Type, Int8Type};
1473    use arrow_array::types::{
1474        Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
1475        DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type,
1476        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
1477    };
1478    use arrow_array::{
1479        ArrowNativeTypeOp, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray,
1480        Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
1481        TimestampNanosecondArray, TimestampSecondArray,
1482    };
1483    use arrow_schema::{IntervalUnit, TimeUnit};
1484    use chrono::Utc;
1485    use rand::prelude::Distribution;
1486
1487    use super::*;
1488
1489    /// Create a generator of vectors by continuously calling the given generator
1490    ///
1491    /// For example, given a step generator and a dimension of 3 this will generate vectors like
1492    /// [0, 1, 2], [3, 4, 5], [6, 7, 8], ...
1493    pub fn cycle_vec(
1494        generator: Box<dyn ArrayGenerator>,
1495        dimension: Dimension,
1496    ) -> Box<dyn ArrayGenerator> {
1497        Box::new(CycleVectorGenerator::new(generator, dimension))
1498    }
1499
1500    /// Create a generator from a vector of values
1501    ///
1502    /// If more rows are requested than the length of values then it will restart
1503    /// from the beginning of the vector.
1504    pub fn cycle<DataType>(values: Vec<DataType::Native>) -> Box<dyn ArrayGenerator>
1505    where
1506        DataType::Native: Copy + 'static,
1507        DataType: ArrowPrimitiveType,
1508        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1509    {
1510        let mut values_idx = 0;
1511        Box::new(
1512            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1513                DataType::DATA_TYPE,
1514                move |_| {
1515                    let y = values[values_idx];
1516                    values_idx = (values_idx + 1) % values.len();
1517                    y
1518                },
1519                1,
1520                DataType::DATA_TYPE
1521                    .primitive_width()
1522                    .map(|width| ByteCount::from(width as u64))
1523                    .expect("Primitive types should have a fixed width"),
1524            ),
1525        )
1526    }
1527
1528    /// Create a generator that starts at 0 and increments by 1 for each element
1529    pub fn step<DataType>() -> Box<dyn ArrayGenerator>
1530    where
1531        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
1532        DataType: ArrowPrimitiveType,
1533        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1534    {
1535        let mut x = DataType::Native::default();
1536        Box::new(
1537            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1538                DataType::DATA_TYPE,
1539                move |_| {
1540                    let y = x;
1541                    x += DataType::Native::ONE;
1542                    y
1543                },
1544                1,
1545                DataType::DATA_TYPE
1546                    .primitive_width()
1547                    .map(|width| ByteCount::from(width as u64))
1548                    .expect("Primitive types should have a fixed width"),
1549            ),
1550        )
1551    }
1552
1553    pub fn blob() -> Box<dyn ArrayGenerator> {
1554        let mut blob_meta = HashMap::new();
1555        blob_meta.insert("lance-encoding:blob".to_string(), "true".to_string());
1556        rand_fixedbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta)
1557    }
1558
1559    /// Create a generator that starts at a given value and increments by a given step for each element
1560    pub fn step_custom<DataType>(
1561        start: DataType::Native,
1562        step: DataType::Native,
1563    ) -> Box<dyn ArrayGenerator>
1564    where
1565        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
1566        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1567        DataType: ArrowPrimitiveType,
1568    {
1569        let mut x = start;
1570        Box::new(
1571            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1572                DataType::DATA_TYPE,
1573                move |_| {
1574                    let y = x;
1575                    x += step;
1576                    y
1577                },
1578                1,
1579                DataType::DATA_TYPE
1580                    .primitive_width()
1581                    .map(|width| ByteCount::from(width as u64))
1582                    .expect("Primitive types should have a fixed width"),
1583            ),
1584        )
1585    }
1586
1587    /// Create a generator that fills each element with the given primitive value
1588    pub fn fill<DataType>(value: DataType::Native) -> Box<dyn ArrayGenerator>
1589    where
1590        DataType::Native: Copy + 'static,
1591        DataType: ArrowPrimitiveType,
1592        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1593    {
1594        Box::new(
1595            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1596                DataType::DATA_TYPE,
1597                move |_| value,
1598                1,
1599                DataType::DATA_TYPE
1600                    .primitive_width()
1601                    .map(|width| ByteCount::from(width as u64))
1602                    .expect("Primitive types should have a fixed width"),
1603            ),
1604        )
1605    }
1606
1607    /// Create a generator that fills each element with the given binary value
1608    pub fn fill_varbin(value: Vec<u8>) -> Box<dyn ArrayGenerator> {
1609        Box::new(FixedBinaryGenerator::<BinaryType>::new(value))
1610    }
1611
1612    /// Create a generator that fills each element with the given string value
1613    pub fn fill_utf8(value: String) -> Box<dyn ArrayGenerator> {
1614        Box::new(FixedBinaryGenerator::<Utf8Type>::new(value.into_bytes()))
1615    }
1616
1617    pub fn cycle_utf8_literals(values: &[&'static str]) -> Box<dyn ArrayGenerator> {
1618        Box::new(CycleBinaryGenerator::<Utf8Type>::from_strings(values))
1619    }
1620
1621    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
1622    pub fn rand<DataType>() -> Box<dyn ArrayGenerator>
1623    where
1624        DataType::Native: Copy + 'static,
1625        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1626        DataType: ArrowPrimitiveType,
1627        rand::distributions::Standard: rand::distributions::Distribution<DataType::Native>,
1628    {
1629        Box::new(
1630            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1631                DataType::DATA_TYPE,
1632                move |rng| rng.gen(),
1633                1,
1634                DataType::DATA_TYPE
1635                    .primitive_width()
1636                    .map(|width| ByteCount::from(width as u64))
1637                    .expect("Primitive types should have a fixed width"),
1638            ),
1639        )
1640    }
1641
1642    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
1643    pub fn rand_with_distribution<
1644        DataType,
1645        Dist: rand::distributions::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
1646    >(
1647        dist: Dist,
1648    ) -> Box<dyn ArrayGenerator>
1649    where
1650        DataType::Native: Copy + 'static,
1651        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1652        DataType: ArrowPrimitiveType,
1653    {
1654        Box::new(
1655            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
1656                DataType::DATA_TYPE,
1657                move |rng| rng.sample(dist.clone()),
1658                1,
1659                DataType::DATA_TYPE
1660                    .primitive_width()
1661                    .map(|width| ByteCount::from(width as u64))
1662                    .expect("Primitive types should have a fixed width"),
1663            ),
1664        )
1665    }
1666
1667    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled primitive values
1668    pub fn rand_vec<DataType>(dimension: Dimension) -> Box<dyn ArrayGenerator>
1669    where
1670        DataType::Native: Copy + 'static,
1671        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
1672        DataType: ArrowPrimitiveType,
1673        rand::distributions::Standard: rand::distributions::Distribution<DataType::Native>,
1674    {
1675        let underlying = rand::<DataType>();
1676        cycle_vec(underlying, dimension)
1677    }
1678
1679    /// Create a generator of randomly sampled time32 values covering the entire
1680    /// range of 1 day
1681    pub fn rand_time32(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
1682        let start = 0;
1683        let end = match resolution {
1684            TimeUnit::Second => 86_400,
1685            TimeUnit::Millisecond => 86_400_000,
1686            _ => panic!(),
1687        };
1688
1689        let data_type = DataType::Time32(*resolution);
1690        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
1691        let dist = Uniform::new(start, end);
1692        let sample_fn = move |rng: &mut _| dist.sample(rng);
1693
1694        match resolution {
1695            TimeUnit::Second => Box::new(FnGen::<i32, Time32SecondArray, _>::new_known_size(
1696                data_type, sample_fn, 1, size,
1697            )),
1698            TimeUnit::Millisecond => {
1699                Box::new(FnGen::<i32, Time32MillisecondArray, _>::new_known_size(
1700                    data_type, sample_fn, 1, size,
1701                ))
1702            }
1703            _ => panic!(),
1704        }
1705    }
1706
1707    /// Create a generator of randomly sampled time64 values covering the entire
1708    /// range of 1 day
1709    pub fn rand_time64(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
1710        let start = 0_i64;
1711        let end: i64 = match resolution {
1712            TimeUnit::Microsecond => 86_400_000,
1713            TimeUnit::Nanosecond => 86_400_000_000,
1714            _ => panic!(),
1715        };
1716
1717        let data_type = DataType::Time64(*resolution);
1718        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
1719        let dist = Uniform::new(start, end);
1720        let sample_fn = move |rng: &mut _| dist.sample(rng);
1721
1722        match resolution {
1723            TimeUnit::Microsecond => {
1724                Box::new(FnGen::<i64, Time64MicrosecondArray, _>::new_known_size(
1725                    data_type, sample_fn, 1, size,
1726                ))
1727            }
1728            TimeUnit::Nanosecond => {
1729                Box::new(FnGen::<i64, Time64NanosecondArray, _>::new_known_size(
1730                    data_type, sample_fn, 1, size,
1731                ))
1732            }
1733            _ => panic!(),
1734        }
1735    }
1736
    /// Create a generator of random UUIDs, stored as fixed size binary values
    ///
    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
    /// for speed.
    ///
    /// Returns a boxed, default-constructed `PseudoUuidGenerator`.
    pub fn rand_pseudo_uuid() -> Box<dyn ArrayGenerator> {
        Box::<PseudoUuidGenerator>::default()
    }
1745
    /// Create a generator of random UUIDs, stored as 32-character strings (hex encoding
    /// of the 16-byte binary value)
    ///
    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
    /// for speed.
    ///
    /// Returns a boxed, default-constructed `PseudoUuidHexGenerator`.
    pub fn rand_pseudo_uuid_hex() -> Box<dyn ArrayGenerator> {
        Box::<PseudoUuidHexGenerator>::default()
    }
1755
1756    pub fn rand_primitive<T: ArrowPrimitiveType + Send + Sync>(
1757        data_type: DataType,
1758    ) -> Box<dyn ArrayGenerator> {
1759        Box::new(RandomBytesGenerator::<T>::new(data_type))
1760    }
1761
1762    pub fn rand_fsb(size: i32) -> Box<dyn ArrayGenerator> {
1763        Box::new(RandomFixedSizeBinaryGenerator::new(size))
1764    }
1765
1766    pub fn rand_interval(unit: IntervalUnit) -> Box<dyn ArrayGenerator> {
1767        Box::new(RandomIntervalGenerator::new(unit))
1768    }
1769
1770    /// Create a generator of randomly sampled date32 values
1771    ///
1772    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1773    /// is a more common use pattern
1774    pub fn rand_date32() -> Box<dyn ArrayGenerator> {
1775        let now = chrono::Utc::now();
1776        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try days");
1777        rand_date32_in_range(one_year_ago, now)
1778    }
1779
1780    /// Create a generator of randomly sampled date32 values in the given range
1781    pub fn rand_date32_in_range(
1782        start: chrono::DateTime<Utc>,
1783        end: chrono::DateTime<Utc>,
1784    ) -> Box<dyn ArrayGenerator> {
1785        let data_type = DataType::Date32;
1786        let end_ms = end.timestamp_millis();
1787        let end_days = (end_ms / MS_PER_DAY) as i32;
1788        let start_ms = start.timestamp_millis();
1789        let start_days = (start_ms / MS_PER_DAY) as i32;
1790        let dist = Uniform::new(start_days, end_days);
1791
1792        Box::new(FnGen::<i32, Date32Array, _>::new_known_size(
1793            data_type,
1794            move |rng| dist.sample(rng),
1795            1,
1796            DataType::Date32
1797                .primitive_width()
1798                .map(|width| ByteCount::from(width as u64))
1799                .expect("Date32 should have a fixed width"),
1800        ))
1801    }
1802
1803    /// Create a generator of randomly sampled date64 values
1804    ///
1805    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1806    /// is a more common use pattern
1807    pub fn rand_date64() -> Box<dyn ArrayGenerator> {
1808        let now = chrono::Utc::now();
1809        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try_days");
1810        rand_date64_in_range(one_year_ago, now)
1811    }
1812
1813    /// Create a generator of randomly sampled timestamp values in the given range
1814    ///
1815    /// Currently just samples the entire range of u64 values and casts to timestamp
1816    pub fn rand_timestamp_in_range(
1817        start: chrono::DateTime<Utc>,
1818        end: chrono::DateTime<Utc>,
1819        data_type: &DataType,
1820    ) -> Box<dyn ArrayGenerator> {
1821        let end_ms = end.timestamp_millis();
1822        let start_ms = start.timestamp_millis();
1823        let (start_ticks, end_ticks) = match data_type {
1824            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1825                (start_ms * 1000 * 1000, end_ms * 1000 * 1000)
1826            }
1827            DataType::Timestamp(TimeUnit::Microsecond, _) => (start_ms * 1000, end_ms * 1000),
1828            DataType::Timestamp(TimeUnit::Millisecond, _) => (start_ms, end_ms),
1829            DataType::Timestamp(TimeUnit::Second, _) => (start.timestamp(), end.timestamp()),
1830            _ => panic!(),
1831        };
1832        let dist = Uniform::new(start_ticks, end_ticks);
1833
1834        let data_type = data_type.clone();
1835        let sample_fn = move |rng: &mut _| (dist.sample(rng));
1836        let width = data_type
1837            .primitive_width()
1838            .map(|width| ByteCount::from(width as u64))
1839            .unwrap();
1840
1841        match data_type {
1842            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1843                Box::new(FnGen::<i64, TimestampNanosecondArray, _>::new_known_size(
1844                    data_type, sample_fn, 1, width,
1845                ))
1846            }
1847            DataType::Timestamp(TimeUnit::Microsecond, _) => {
1848                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
1849                    data_type, sample_fn, 1, width,
1850                ))
1851            }
1852            DataType::Timestamp(TimeUnit::Millisecond, _) => {
1853                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
1854                    data_type, sample_fn, 1, width,
1855                ))
1856            }
1857            DataType::Timestamp(TimeUnit::Second, _) => {
1858                Box::new(FnGen::<i64, TimestampSecondArray, _>::new_known_size(
1859                    data_type, sample_fn, 1, width,
1860                ))
1861            }
1862            _ => panic!(),
1863        }
1864    }
1865
1866    pub fn rand_timestamp(data_type: &DataType) -> Box<dyn ArrayGenerator> {
1867        let now = chrono::Utc::now();
1868        let one_year_ago = now - chrono::Duration::try_days(365).unwrap();
1869        rand_timestamp_in_range(one_year_ago, now, data_type)
1870    }
1871
1872    /// Create a generator of randomly sampled date64 values
1873    ///
1874    /// Instead of sampling the entire range, all values will be drawn from the last year as this
1875    /// is a more common use pattern
1876    pub fn rand_date64_in_range(
1877        start: chrono::DateTime<Utc>,
1878        end: chrono::DateTime<Utc>,
1879    ) -> Box<dyn ArrayGenerator> {
1880        let data_type = DataType::Date64;
1881        let end_ms = end.timestamp_millis();
1882        let end_days = end_ms / MS_PER_DAY;
1883        let start_ms = start.timestamp_millis();
1884        let start_days = start_ms / MS_PER_DAY;
1885        let dist = Uniform::new(start_days, end_days);
1886
1887        Box::new(FnGen::<i64, Date64Array, _>::new_known_size(
1888            data_type,
1889            move |rng| (dist.sample(rng)) * MS_PER_DAY,
1890            1,
1891            DataType::Date64
1892                .primitive_width()
1893                .map(|width| ByteCount::from(width as u64))
1894                .expect("Date64 should have a fixed width"),
1895        ))
1896    }
1897
1898    /// Create a generator of random binary values where each value has a fixed number of bytes
1899    pub fn rand_fixedbin(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
1900        Box::new(RandomBinaryGenerator::new(
1901            bytes_per_element,
1902            false,
1903            is_large,
1904        ))
1905    }
1906
1907    /// Create a generator of random binary values where each value has a variable number of bytes
1908    ///
1909    /// The number of bytes per element will be randomly sampled from the given (inclusive) range
1910    pub fn rand_varbin(
1911        min_bytes_per_element: ByteCount,
1912        max_bytes_per_element: ByteCount,
1913    ) -> Box<dyn ArrayGenerator> {
1914        Box::new(VariableRandomBinaryGenerator::new(
1915            min_bytes_per_element,
1916            max_bytes_per_element,
1917        ))
1918    }
1919
1920    /// Create a generator of random strings
1921    ///
1922    /// All strings will consist entirely of printable ASCII characters
1923    pub fn rand_utf8(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
1924        Box::new(RandomBinaryGenerator::new(
1925            bytes_per_element,
1926            true,
1927            is_large,
1928        ))
1929    }
1930
1931    /// Create a random generator of boolean values
1932    pub fn rand_boolean() -> Box<dyn ArrayGenerator> {
1933        Box::<RandomBooleanGenerator>::default()
1934    }
1935
1936    pub fn rand_list(item_type: &DataType, is_large: bool) -> Box<dyn ArrayGenerator> {
1937        let child_gen = rand_type(item_type);
1938        Box::new(RandomListGenerator::new(child_gen, is_large))
1939    }
1940
1941    pub fn rand_list_any(
1942        item_gen: Box<dyn ArrayGenerator>,
1943        is_large: bool,
1944    ) -> Box<dyn ArrayGenerator> {
1945        Box::new(RandomListGenerator::new(item_gen, is_large))
1946    }
1947
1948    pub fn rand_struct(fields: Fields) -> Box<dyn ArrayGenerator> {
1949        let child_gens = fields
1950            .iter()
1951            .map(|f| rand_type(f.data_type()))
1952            .collect::<Vec<_>>();
1953        Box::new(RandomStructGenerator::new(fields, child_gens))
1954    }
1955
1956    pub fn null_type() -> Box<dyn ArrayGenerator> {
1957        Box::new(NullArrayGenerator {})
1958    }
1959
1960    /// Create a generator of random values
1961    pub fn rand_type(data_type: &DataType) -> Box<dyn ArrayGenerator> {
1962        match data_type {
1963            DataType::Boolean => rand_boolean(),
1964            DataType::Int8 => rand::<Int8Type>(),
1965            DataType::Int16 => rand::<Int16Type>(),
1966            DataType::Int32 => rand::<Int32Type>(),
1967            DataType::Int64 => rand::<Int64Type>(),
1968            DataType::UInt8 => rand::<UInt8Type>(),
1969            DataType::UInt16 => rand::<UInt16Type>(),
1970            DataType::UInt32 => rand::<UInt32Type>(),
1971            DataType::UInt64 => rand::<UInt64Type>(),
1972            DataType::Float16 => rand_primitive::<Float16Type>(data_type.clone()),
1973            DataType::Float32 => rand::<Float32Type>(),
1974            DataType::Float64 => rand::<Float64Type>(),
1975            DataType::Decimal128(_, _) => rand_primitive::<Decimal128Type>(data_type.clone()),
1976            DataType::Decimal256(_, _) => rand_primitive::<Decimal256Type>(data_type.clone()),
1977            DataType::Utf8 => rand_utf8(ByteCount::from(12), false),
1978            DataType::LargeUtf8 => rand_utf8(ByteCount::from(12), true),
1979            DataType::Binary => rand_fixedbin(ByteCount::from(12), false),
1980            DataType::LargeBinary => rand_fixedbin(ByteCount::from(12), true),
1981            DataType::Dictionary(key_type, value_type) => {
1982                dict_type(rand_type(value_type), key_type)
1983            }
1984            DataType::FixedSizeList(child, dimension) => cycle_vec(
1985                rand_type(child.data_type()),
1986                Dimension::from(*dimension as u32),
1987            ),
1988            DataType::FixedSizeBinary(size) => rand_fsb(*size),
1989            DataType::List(child) => rand_list(child.data_type(), false),
1990            DataType::LargeList(child) => rand_list(child.data_type(), true),
1991            DataType::Duration(unit) => match unit {
1992                TimeUnit::Second => rand::<DurationSecondType>(),
1993                TimeUnit::Millisecond => rand::<DurationMillisecondType>(),
1994                TimeUnit::Microsecond => rand::<DurationMicrosecondType>(),
1995                TimeUnit::Nanosecond => rand::<DurationNanosecondType>(),
1996            },
1997            DataType::Interval(unit) => rand_interval(*unit),
1998            DataType::Date32 => rand_date32(),
1999            DataType::Date64 => rand_date64(),
2000            DataType::Time32(resolution) => rand_time32(resolution),
2001            DataType::Time64(resolution) => rand_time64(resolution),
2002            DataType::Timestamp(_, _) => rand_timestamp(data_type),
2003            DataType::Struct(fields) => rand_struct(fields.clone()),
2004            DataType::Null => null_type(),
2005            _ => unimplemented!("random generation of {}", data_type),
2006        }
2007    }
2008
2009    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2010    ///
2011    /// Note that this may not be very realistic if the underlying generator is something like a random
2012    /// generator since most of the underlying values will be unique and the common case for dictionary
2013    /// encoding is when there is a small set of possible values.
2014    pub fn dict<K: ArrowDictionaryKeyType + Send + Sync>(
2015        generator: Box<dyn ArrayGenerator>,
2016    ) -> Box<dyn ArrayGenerator> {
2017        Box::new(DictionaryGenerator::<K>::new(generator))
2018    }
2019
2020    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2021    pub fn dict_type(
2022        generator: Box<dyn ArrayGenerator>,
2023        key_type: &DataType,
2024    ) -> Box<dyn ArrayGenerator> {
2025        match key_type {
2026            DataType::Int8 => dict::<Int8Type>(generator),
2027            DataType::Int16 => dict::<Int16Type>(generator),
2028            DataType::Int32 => dict::<Int32Type>(generator),
2029            DataType::Int64 => dict::<Int64Type>(generator),
2030            DataType::UInt8 => dict::<UInt8Type>(generator),
2031            DataType::UInt16 => dict::<UInt16Type>(generator),
2032            DataType::UInt32 => dict::<UInt32Type>(generator),
2033            DataType::UInt64 => dict::<UInt64Type>(generator),
2034            _ => unimplemented!(),
2035        }
2036    }
2037}
2038
2039/// Create a BatchGeneratorBuilder to start generating batch data
2040pub fn gen() -> BatchGeneratorBuilder {
2041    BatchGeneratorBuilder::default()
2042}
2043
2044/// Create an ArrayGeneratorBuilder to start generating array data
2045pub fn gen_array(gen: Box<dyn ArrayGenerator>) -> ArrayGeneratorBuilder {
2046    ArrayGeneratorBuilder::new(gen)
2047}
2048
2049/// Create a BatchGeneratorBuilder with the given schema
2050///
2051/// You can add more columns or convert this into a reader immediately
2052pub fn rand(schema: &Schema) -> BatchGeneratorBuilder {
2053    let mut builder = BatchGeneratorBuilder::default();
2054    for field in schema.fields() {
2055        builder = builder.col(field.name(), array::rand_type(field.data_type()));
2056    }
2057    builder
2058}
2059
#[cfg(test)]
mod tests {

    use arrow::datatypes::{Float32Type, Int16Type, Int8Type, UInt32Type};
    use arrow_array::{BooleanArray, Float32Array, Int16Array, Int32Array, Int8Array, UInt32Array};

    use super::*;

    #[test]
    fn test_step() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::step::<Int32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([0, 1, 2, 3, 4])
        );
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([5, 6, 7, 8, 9])
        );

        let mut gen = array::step::<Int8Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int8Array::from_iter([0, 1, 2])
        );

        let mut gen = array::step::<Float32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Float32Array::from_iter([0.0, 1.0, 2.0])
        );

        let mut gen = array::step_custom::<Int16Type>(4, 8);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int16Array::from_iter([4, 12, 20])
        );
        assert_eq!(
            *gen.generate(RowCount::from(2), &mut rng).unwrap(),
            Int16Array::from_iter([28, 36])
        );
    }

    #[test]
    fn test_cycle() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::cycle::<Int32Type>(vec![1, 2, 3]);
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([1, 2, 3, 1, 2])
        );

        let mut gen = array::cycle_utf8_literals(&["abc", "def", "xyz"]);
        assert_eq!(
            *gen.generate(RowCount::from(5), &mut rng).unwrap(),
            StringArray::from_iter_values(["abc", "def", "xyz", "abc", "def"])
        );
        assert_eq!(
            *gen.generate(RowCount::from(1), &mut rng).unwrap(),
            StringArray::from_iter_values(["xyz"])
        );
    }

    #[test]
    fn test_fill() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::fill::<Int32Type>(42);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );

        let mut gen = array::fill_varbin(vec![0, 1, 2]);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                "\x00\x01\x02",
                "\x00\x01\x02",
                "\x00\x01\x02"
            ])
        );

        let mut gen = array::fill_utf8("xyz".to_string());
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values(["xyz", "xyz", "xyz"])
        );
    }

    #[test]
    fn test_rng() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<Int32Type>();
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([-797553329, 1369325940, -69174021])
        );

        let mut gen = array::rand_fixedbin(ByteCount::from(3), false);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                [184, 53, 216],
                [12, 96, 159],
                [125, 179, 56]
            ])
        );

        let mut gen = array::rand_utf8(ByteCount::from(3), false);
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values([">@p", "n `", "NWa"])
        );

        let mut gen = array::rand_date32();
        let days_32 = gen.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_32.data_type(), &DataType::Date32);

        let mut gen = array::rand_date64();
        let days_64 = gen.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_64.data_type(), &DataType::Date64);

        let mut gen = array::rand_boolean();
        let bools = gen.generate(RowCount::from(1024), &mut rng).unwrap();
        assert_eq!(bools.data_type(), &DataType::Boolean);
        let bools = bools.as_any().downcast_ref::<BooleanArray>().unwrap();
        // Sanity check to ensure we're getting at least some rng
        assert!(bools.false_count() > 100);
        assert!(bools.true_count() > 100);

        let mut gen = array::rand_varbin(ByteCount::from(2), ByteCount::from(4));
        assert_eq!(
            *gen.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                vec![56, 122, 157, 34],
                vec![58, 51],
                vec![41, 184, 125]
            ])
        );
    }

    #[test]
    fn test_rng_list() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand_list(&DataType::Int32, false);
        let arr = gen.generate(RowCount::from(100), &mut rng).unwrap();
        // Make sure we can generate empty lists (note, test is dependent on seed)
        let arr = arr.as_list::<i32>();
        assert!(arr.iter().any(|l| l.unwrap().is_empty()));
        // Shouldn't generate any giant lists (don't kill performance in normal datagen).
        // Every list must be small, so this is an `all` check rather than `any`.
        assert!(arr.iter().all(|l| l.unwrap().len() < 11));
    }

    #[test]
    fn test_rng_distribution() {
        // Sanity test to make sure our RNG gives us well distributed values.
        // We generate some 4-byte integers, histogram them into 256 buckets by their
        // top byte, and make sure each bucket has a good # of values
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<UInt32Type>();
        for _ in 0..10 {
            let arr = gen.generate(RowCount::from(10000), &mut rng).unwrap();
            let int_arr = arr.as_any().downcast_ref::<UInt32Array>().unwrap();
            let mut buckets = vec![0_u32; 256];
            for val in int_arr.values() {
                buckets[(*val >> 24) as usize] += 1;
            }
            for bucket in buckets {
                // Perfectly even distribution would have 10000 / 256 values (~40) per bucket
                // We test for 15 which should be "good enough" and statistically unlikely to fail
                assert!(bucket > 15);
            }
        }
    }

    #[test]
    fn test_nulls() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut gen = array::rand::<Int32Type>().with_random_nulls(0.3);

        let arr = gen.generate(RowCount::from(1000), &mut rng).unwrap();

        // This assert depends on the default seed
        assert_eq!(arr.null_count(), 297);

        for len in 0..100 {
            let arr = gen.generate(RowCount::from(len), &mut rng).unwrap();
            // Make sure the null count we came up with matches the actual # of unset bits
            assert_eq!(
                arr.null_count(),
                arr.nulls()
                    .map(|nulls| (len as usize)
                        - nulls.buffer().count_set_bits_offset(0, len as usize))
                    .unwrap_or(0)
            );
        }

        let mut gen = array::rand::<Int32Type>().with_random_nulls(0.0);
        let arr = gen.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 0);

        let mut gen = array::rand::<Int32Type>().with_random_nulls(1.0);
        let arr = gen.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 10);
        assert!((0..10).all(|idx| arr.is_null(idx)));

        let mut gen = array::rand::<Int32Type>().with_nulls(&[false, false, true]);
        let arr = gen.generate(RowCount::from(7), &mut rng).unwrap();
        assert!((0..2).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(2));
        assert!((3..5).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(5));
        assert!(arr.is_valid(6));
    }

    #[test]
    fn test_rand_schema() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
        ]);
        let rbr = rand(&schema)
            .into_reader_bytes(
                ByteCount::from(1024 * 1024),
                BatchCount::from(8),
                RoundingBehavior::ExactOrErr,
            )
            .unwrap();
        assert_eq!(*rbr.schema(), schema);

        let batches = rbr.map(|val| val.unwrap()).collect::<Vec<_>>();
        assert_eq!(batches.len(), 8);

        for batch in batches {
            assert_eq!(batch.num_rows(), 1024 * 1024 / 32);
            assert_eq!(batch.num_columns(), 5);
        }
    }
}