Skip to main content

lance_datagen/
generator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, iter, marker::PhantomData, sync::Arc, sync::LazyLock};
5
6use arrow::{
7    array::{ArrayData, AsArray, Float32Builder, GenericBinaryBuilder, GenericStringBuilder},
8    buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer},
9    datatypes::{
10        ArrowPrimitiveType, Float32Type, Int32Type, Int64Type, IntervalDayTime,
11        IntervalMonthDayNano, UInt32Type,
12    },
13};
14use arrow_array::{
15    make_array,
16    types::{ArrowDictionaryKeyType, BinaryType, ByteArrayType, Utf8Type},
17    Array, BinaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, LargeListArray,
18    LargeStringArray, ListArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, RecordBatch,
19    RecordBatchOptions, RecordBatchReader, StringArray, StructArray,
20};
21use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef};
22use futures::{stream::BoxStream, StreamExt};
23use rand::{distr::Uniform, Rng, RngCore, SeedableRng};
24use rand_distr::Zipf;
25use random_word;
26
27use self::array::rand_with_distribution;
28
/// A number of rows (array elements).
#[derive(Copy, Clone, Debug, Default)]
pub struct RowCount(u64);
/// A number of record batches.
#[derive(Copy, Clone, Debug, Default)]
pub struct BatchCount(u32);
/// A size measured in bytes.
#[derive(Copy, Clone, Debug, Default)]
pub struct ByteCount(u64);
/// A dimension, e.g. the width of a fixed-size list or a bound on a list's length.
#[derive(Copy, Clone, Debug, Default)]
pub struct Dimension(u32);

/// Implements `From<$inner>` for each listed single-field newtype by simply wrapping the value.
macro_rules! impl_newtype_from {
    ($($newtype:ident => $inner:ty),* $(,)?) => {
        $(
            impl From<$inner> for $newtype {
                fn from(value: $inner) -> Self {
                    Self(value)
                }
            }
        )*
    };
}

impl_newtype_from!(
    BatchCount => u32,
    RowCount => u64,
    ByteCount => u64,
    Dimension => u32,
);
61
62/// A trait for anything that can generate arrays of data
63pub trait ArrayGenerator: Send + Sync + std::fmt::Debug {
64    /// Generate an array of the given length
65    ///
66    /// # Arguments
67    ///
68    /// * `length` - The number of elements to generate
69    /// * `rng` - The random number generator to use
70    ///
71    /// # Returns
72    ///
73    /// An array of the given length
74    ///
75    /// Note: Not every generator needs an rng.  However, it is passed here because many do and this
76    /// lets us manage RNGs at the batch level instead of the array level.
77    fn generate(
78        &mut self,
79        length: RowCount,
80        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
81    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError>;
82
83    /// Generate an array of the given length using a new RNG with the default seed
84    ///
85    /// # Arguments
86    ///
87    /// * `length` - The number of elements to generate
88    ///
89    /// # Returns
90    ///
91    /// An array of the given length
92    fn generate_default(
93        &mut self,
94        length: RowCount,
95    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
96        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
97        Self::generate(self, length, &mut rng)
98    }
99    /// Get the data type of the array that this generator produces
100    ///
101    /// # Returns
102    ///
103    /// The data type of the array that this generator produces
104    fn data_type(&self) -> &DataType;
105    /// Gets metadata that should be associated with the field generated by this generator
106    fn metadata(&self) -> Option<HashMap<String, String>> {
107        None
108    }
109    /// Get the size of each element in bytes
110    ///
111    /// # Returns
112    ///
113    /// The size of each element in bytes.  Will be None if the size varies by element.
114    fn element_size_bytes(&self) -> Option<ByteCount>;
115}
116
117#[derive(Debug)]
118pub struct CycleNullGenerator {
119    generator: Box<dyn ArrayGenerator>,
120    validity: Vec<bool>,
121    idx: usize,
122}
123#[derive(Debug)]
124pub struct CycleNanGenerator {
125    generator: Box<dyn ArrayGenerator>,
126    nan_pattern: Vec<bool>,
127    idx: usize,
128}
129
130impl ArrayGenerator for CycleNanGenerator {
131    fn generate(
132        &mut self,
133        length: RowCount,
134        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
135    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
136        let array = self.generator.generate(length, rng)?;
137
138        // Only apply NaN pattern to float types
139        match array.data_type() {
140            DataType::Float16 => {
141                let float_array = array
142                    .as_any()
143                    .downcast_ref::<arrow_array::Float16Array>()
144                    .unwrap();
145                let mut values: Vec<half::f16> = float_array.values().to_vec();
146
147                for (i, &should_be_nan) in self
148                    .nan_pattern
149                    .iter()
150                    .cycle()
151                    .skip(self.idx)
152                    .take(length.0 as usize)
153                    .enumerate()
154                {
155                    if should_be_nan {
156                        values[i] = half::f16::NAN;
157                    }
158                }
159
160                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
161                Ok(Arc::new(arrow_array::Float16Array::from(values)))
162            }
163            DataType::Float32 => {
164                let float_array = array
165                    .as_any()
166                    .downcast_ref::<arrow_array::Float32Array>()
167                    .unwrap();
168                let mut values: Vec<f32> = float_array.values().to_vec();
169
170                for (i, &should_be_nan) in self
171                    .nan_pattern
172                    .iter()
173                    .cycle()
174                    .skip(self.idx)
175                    .take(length.0 as usize)
176                    .enumerate()
177                {
178                    if should_be_nan {
179                        values[i] = f32::NAN;
180                    }
181                }
182
183                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
184                Ok(Arc::new(arrow_array::Float32Array::from(values)))
185            }
186            DataType::Float64 => {
187                let float_array = array
188                    .as_any()
189                    .downcast_ref::<arrow_array::Float64Array>()
190                    .unwrap();
191                let mut values: Vec<f64> = float_array.values().to_vec();
192
193                for (i, &should_be_nan) in self
194                    .nan_pattern
195                    .iter()
196                    .cycle()
197                    .skip(self.idx)
198                    .take(length.0 as usize)
199                    .enumerate()
200                {
201                    if should_be_nan {
202                        values[i] = f64::NAN;
203                    }
204                }
205
206                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
207                Ok(Arc::new(arrow_array::Float64Array::from(values)))
208            }
209            _ => {
210                // For non-float types, just return the original array unchanged
211                Ok(array)
212            }
213        }
214    }
215
216    fn data_type(&self) -> &DataType {
217        self.generator.data_type()
218    }
219
220    fn element_size_bytes(&self) -> Option<ByteCount> {
221        self.generator.element_size_bytes()
222    }
223}
224
225impl ArrayGenerator for CycleNullGenerator {
226    fn generate(
227        &mut self,
228        length: RowCount,
229        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
230    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
231        let array = self.generator.generate(length, rng)?;
232        let data = array.to_data();
233        let validity_itr = self
234            .validity
235            .iter()
236            .cycle()
237            .skip(self.idx)
238            .take(length.0 as usize)
239            .copied();
240        let validity_bitmap = BooleanBuffer::from_iter(validity_itr);
241
242        self.idx = (self.idx + (length.0 as usize)) % self.validity.len();
243        unsafe {
244            let new_data = ArrayData::new_unchecked(
245                data.data_type().clone(),
246                data.len(),
247                None,
248                Some(validity_bitmap.into_inner()),
249                data.offset(),
250                data.buffers().to_vec(),
251                data.child_data().into(),
252            );
253            Ok(make_array(new_data))
254        }
255    }
256
257    fn data_type(&self) -> &DataType {
258        self.generator.data_type()
259    }
260
261    fn element_size_bytes(&self) -> Option<ByteCount> {
262        self.generator.element_size_bytes()
263    }
264}
265
266#[derive(Debug)]
267pub struct MetadataGenerator {
268    generator: Box<dyn ArrayGenerator>,
269    metadata: HashMap<String, String>,
270}
271
272impl ArrayGenerator for MetadataGenerator {
273    fn generate(
274        &mut self,
275        length: RowCount,
276        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
277    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
278        self.generator.generate(length, rng)
279    }
280
281    fn metadata(&self) -> Option<HashMap<String, String>> {
282        Some(self.metadata.clone())
283    }
284
285    fn data_type(&self) -> &DataType {
286        self.generator.data_type()
287    }
288
289    fn element_size_bytes(&self) -> Option<ByteCount> {
290        self.generator.element_size_bytes()
291    }
292}
293
294#[derive(Debug)]
295pub struct NullGenerator {
296    generator: Box<dyn ArrayGenerator>,
297    null_probability: f64,
298}
299
300impl ArrayGenerator for NullGenerator {
301    fn generate(
302        &mut self,
303        length: RowCount,
304        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
305    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
306        let array = self.generator.generate(length, rng)?;
307        let data = array.to_data();
308
309        if self.null_probability < 0.0 || self.null_probability > 1.0 {
310            return Err(ArrowError::InvalidArgumentError(format!(
311                "null_probability must be between 0 and 1, got {}",
312                self.null_probability
313            )));
314        }
315
316        let (null_count, new_validity) = if self.null_probability == 0.0 {
317            if data.null_count() == 0 {
318                return Ok(array);
319            } else {
320                (0_usize, None)
321            }
322        } else if self.null_probability == 1.0 {
323            if data.null_count() == data.len() {
324                return Ok(array);
325            } else {
326                let all_nulls = BooleanBuffer::new_unset(array.len());
327                (array.len(), Some(all_nulls.into_inner()))
328            }
329        } else {
330            let array_len = array.len();
331            let num_validity_bytes = array_len.div_ceil(8);
332            let mut null_count = 0;
333            // Sampling the RNG once per bit is kind of slow so we do this to sample once
334            // per byte.  We only get 8 bits of RNG resolution but that should be good enough.
335            let threshold = (self.null_probability * u8::MAX as f64) as u8;
336            let bytes = (0..num_validity_bytes)
337                .map(|byte_idx| {
338                    let mut sample = rng.random::<u64>();
339                    let mut byte: u8 = 0;
340                    for bit_idx in 0..8 {
341                        // We could probably overshoot and fill in extra bits with random data but
342                        // this is cleaner and that would mess up the null count
343                        byte <<= 1;
344                        let pos = byte_idx * 8 + (7 - bit_idx);
345                        if pos < array_len {
346                            let sample_piece = sample & 0xFF;
347                            let is_null = (sample_piece as u8) < threshold;
348                            byte |= (!is_null) as u8;
349                            null_count += is_null as usize;
350                        }
351                        sample >>= 8;
352                    }
353                    byte
354                })
355                .collect::<Vec<_>>();
356            let new_validity = Buffer::from_iter(bytes);
357            (null_count, Some(new_validity))
358        };
359
360        unsafe {
361            let new_data = ArrayData::new_unchecked(
362                data.data_type().clone(),
363                data.len(),
364                Some(null_count),
365                new_validity,
366                data.offset(),
367                data.buffers().to_vec(),
368                data.child_data().into(),
369            );
370            Ok(make_array(new_data))
371        }
372    }
373
374    fn metadata(&self) -> Option<HashMap<String, String>> {
375        self.generator.metadata()
376    }
377
378    fn data_type(&self) -> &DataType {
379        self.generator.data_type()
380    }
381
382    fn element_size_bytes(&self) -> Option<ByteCount> {
383        self.generator.element_size_bytes()
384    }
385}
386
387pub trait ArrayGeneratorExt {
388    /// Replaces the validity bitmap of generated arrays, inserting nulls with a given probability
389    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator>;
390    /// Replaces the validity bitmap of generated arrays with the inverse of `nulls`, cycling if needed
391    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
392    /// Replaces the values of generated arrays with NaN values, cycling if needed
393    ///
394    /// Will have no effect if the data type is not a floating point data type
395    fn with_nans(self, nans: &[bool]) -> Box<dyn ArrayGenerator>;
396    /// Replaces the validity bitmap of generated arrays with `validity`, cycling if needed
397    fn with_validity(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
398    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator>;
399}
400
401impl ArrayGeneratorExt for Box<dyn ArrayGenerator> {
402    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator> {
403        Box::new(NullGenerator {
404            generator: self,
405            null_probability,
406        })
407    }
408
409    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator> {
410        Box::new(CycleNullGenerator {
411            generator: self,
412            validity: nulls.iter().map(|v| !*v).collect(),
413            idx: 0,
414        })
415    }
416
417    fn with_nans(self, nans: &[bool]) -> Box<dyn ArrayGenerator> {
418        Box::new(CycleNanGenerator {
419            generator: self,
420            nan_pattern: nans.to_vec(),
421            idx: 0,
422        })
423    }
424
425    fn with_validity(self, validity: &[bool]) -> Box<dyn ArrayGenerator> {
426        Box::new(CycleNullGenerator {
427            generator: self,
428            validity: validity.to_vec(),
429            idx: 0,
430        })
431    }
432
433    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator> {
434        Box::new(MetadataGenerator {
435            generator: self,
436            metadata,
437        })
438    }
439}
440
441pub struct NTimesIter<I: Iterator>
442where
443    I::Item: Copy,
444{
445    iter: I,
446    n: u32,
447    cur: I::Item,
448    count: u32,
449}
450
451// Note: if this is used then there is a performance hit as the
452// inner loop cannot experience vectorization
453//
454// TODO: maybe faster to build the vec and then repeat it into
455// the destination array?
456impl<I: Iterator> Iterator for NTimesIter<I>
457where
458    I::Item: Copy,
459{
460    type Item = I::Item;
461
462    fn next(&mut self) -> Option<Self::Item> {
463        if self.count == 0 {
464            self.count = self.n - 1;
465            self.cur = self.iter.next()?;
466        } else {
467            self.count -= 1;
468        }
469        Some(self.cur)
470    }
471
472    fn size_hint(&self) -> (usize, Option<usize>) {
473        let (lower, upper) = self.iter.size_hint();
474        let lower = lower * self.n as usize;
475        let upper = upper.map(|u| u * self.n as usize);
476        (lower, upper)
477    }
478}
479
480pub struct FnGen<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T>
481where
482    T: Copy + Default,
483    ArrayType: arrow_array::Array + From<Vec<T>>,
484{
485    data_type: DataType,
486    generator: F,
487    array_type: PhantomData<ArrayType>,
488    repeat: u32,
489    leftover: T,
490    leftover_count: u32,
491    element_size_bytes: Option<ByteCount>,
492}
493
494impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> std::fmt::Debug
495    for FnGen<T, ArrayType, F>
496where
497    T: Copy + Default,
498    ArrayType: arrow_array::Array + From<Vec<T>>,
499{
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        f.debug_struct("FnGen")
502            .field("data_type", &self.data_type)
503            .field("array_type", &self.array_type)
504            .field("repeat", &self.repeat)
505            .field("leftover_count", &self.leftover_count)
506            .field("element_size_bytes", &self.element_size_bytes)
507            .finish()
508    }
509}
510
511impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> FnGen<T, ArrayType, F>
512where
513    T: Copy + Default,
514    ArrayType: arrow_array::Array + From<Vec<T>>,
515{
516    fn new_known_size(
517        data_type: DataType,
518        generator: F,
519        repeat: u32,
520        element_size_bytes: ByteCount,
521    ) -> Self {
522        Self {
523            data_type,
524            generator,
525            array_type: PhantomData,
526            repeat,
527            leftover: T::default(),
528            leftover_count: 0,
529            element_size_bytes: Some(element_size_bytes),
530        }
531    }
532
533    fn new_unknown_size(data_type: DataType, generator: F, repeat: u32) -> Self {
534        Self {
535            data_type,
536            generator,
537            array_type: PhantomData,
538            repeat,
539            leftover: T::default(),
540            leftover_count: 0,
541            element_size_bytes: None,
542        }
543    }
544}
545
546impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> ArrayGenerator
547    for FnGen<T, ArrayType, F>
548where
549    T: Copy + Default + Send + Sync,
550    ArrayType: arrow_array::Array + From<Vec<T>> + 'static,
551    F: Send + Sync,
552{
553    fn generate(
554        &mut self,
555        length: RowCount,
556        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
557    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
558        let iter = (0..length.0).map(|_| (self.generator)(rng));
559        let values = if self.repeat > 1 {
560            Vec::from_iter(
561                NTimesIter {
562                    iter,
563                    n: self.repeat,
564                    cur: self.leftover,
565                    count: self.leftover_count,
566                }
567                .take(length.0 as usize),
568            )
569        } else {
570            Vec::from_iter(iter)
571        };
572        self.leftover_count = ((self.leftover_count as u64 + length.0) % self.repeat as u64) as u32;
573        self.leftover = values.last().copied().unwrap_or(T::default());
574        Ok(Arc::new(ArrayType::from(values)))
575    }
576
577    fn data_type(&self) -> &DataType {
578        &self.data_type
579    }
580
581    fn element_size_bytes(&self) -> Option<ByteCount> {
582        self.element_size_bytes
583    }
584}
585
586#[derive(Copy, Clone, Debug)]
587pub struct Seed(pub u64);
588pub const DEFAULT_SEED: Seed = Seed(42);
589
590impl From<u64> for Seed {
591    fn from(n: u64) -> Self {
592        Self(n)
593    }
594}
595
596#[derive(Debug)]
597pub struct CycleVectorGenerator {
598    underlying_gen: Box<dyn ArrayGenerator>,
599    dimension: Dimension,
600    data_type: DataType,
601}
602
603impl CycleVectorGenerator {
604    pub fn new(underlying_gen: Box<dyn ArrayGenerator>, dimension: Dimension) -> Self {
605        let data_type = DataType::FixedSizeList(
606            Arc::new(Field::new("item", underlying_gen.data_type().clone(), true)),
607            dimension.0 as i32,
608        );
609        Self {
610            underlying_gen,
611            dimension,
612            data_type,
613        }
614    }
615}
616
617impl ArrayGenerator for CycleVectorGenerator {
618    fn generate(
619        &mut self,
620        length: RowCount,
621        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
622    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
623        let values = self
624            .underlying_gen
625            .generate(RowCount::from(length.0 * self.dimension.0 as u64), rng)?;
626        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
627        let values = Arc::new(values);
628
629        let array = FixedSizeListArray::try_new(field, self.dimension.0 as i32, values, None)?;
630
631        Ok(Arc::new(array))
632    }
633
634    fn data_type(&self) -> &DataType {
635        &self.data_type
636    }
637
638    fn element_size_bytes(&self) -> Option<ByteCount> {
639        self.underlying_gen
640            .element_size_bytes()
641            .map(|byte_count| ByteCount::from(byte_count.0 * self.dimension.0 as u64))
642    }
643}
644
645#[derive(Debug)]
646pub struct CycleListGenerator {
647    underlying_gen: Box<dyn ArrayGenerator>,
648    lengths_gen: Box<dyn ArrayGenerator>,
649    data_type: DataType,
650}
651
652impl CycleListGenerator {
653    pub fn new(
654        underlying_gen: Box<dyn ArrayGenerator>,
655        min_list_size: Dimension,
656        max_list_size: Dimension,
657    ) -> Self {
658        let data_type = DataType::List(Arc::new(Field::new(
659            "item",
660            underlying_gen.data_type().clone(),
661            true,
662        )));
663        let lengths_dist = Uniform::new(min_list_size.0, max_list_size.0).unwrap();
664        let lengths_gen = rand_with_distribution::<UInt32Type, Uniform<u32>>(lengths_dist);
665        Self {
666            underlying_gen,
667            lengths_gen,
668            data_type,
669        }
670    }
671}
672
673impl ArrayGenerator for CycleListGenerator {
674    fn generate(
675        &mut self,
676        length: RowCount,
677        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
678    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
679        let lengths = self.lengths_gen.generate(length, rng)?;
680        let lengths = lengths.as_primitive::<UInt32Type>();
681        let total_length = lengths.values().iter().map(|i| *i as u64).sum::<u64>();
682        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
683        let values = self
684            .underlying_gen
685            .generate(RowCount::from(total_length), rng)?;
686        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
687        let values = Arc::new(values);
688
689        let array = ListArray::try_new(field, offsets, values, None)?;
690
691        Ok(Arc::new(array))
692    }
693
694    fn data_type(&self) -> &DataType {
695        &self.data_type
696    }
697
698    fn element_size_bytes(&self) -> Option<ByteCount> {
699        None
700    }
701}
702
703#[derive(Debug, Default)]
704pub struct PseudoUuidGenerator {}
705
706impl ArrayGenerator for PseudoUuidGenerator {
707    fn generate(
708        &mut self,
709        length: RowCount,
710        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
711    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
712        Ok(Arc::new(FixedSizeBinaryArray::try_from_iter(
713            (0..length.0).map(|_| {
714                let mut data = vec![0; 16];
715                rng.fill_bytes(&mut data);
716                data
717            }),
718        )?))
719    }
720
721    fn data_type(&self) -> &DataType {
722        &DataType::FixedSizeBinary(16)
723    }
724
725    fn element_size_bytes(&self) -> Option<ByteCount> {
726        Some(ByteCount::from(16))
727    }
728}
729
730#[derive(Debug, Default)]
731pub struct PseudoUuidHexGenerator {}
732
733impl ArrayGenerator for PseudoUuidHexGenerator {
734    fn generate(
735        &mut self,
736        length: RowCount,
737        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
738    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
739        let mut data = vec![0; 16 * length.0 as usize];
740        rng.fill_bytes(&mut data);
741        let data_hex = hex::encode(data);
742
743        Ok(Arc::new(StringArray::from_iter_values(
744            (0..length.0 as usize).map(|i| data_hex.get(i * 32..(i + 1) * 32).unwrap()),
745        )))
746    }
747
748    fn data_type(&self) -> &DataType {
749        &DataType::Utf8
750    }
751
752    fn element_size_bytes(&self) -> Option<ByteCount> {
753        Some(ByteCount::from(16))
754    }
755}
756
757#[derive(Debug, Default)]
758pub struct RandomBooleanGenerator {}
759
760impl ArrayGenerator for RandomBooleanGenerator {
761    fn generate(
762        &mut self,
763        length: RowCount,
764        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
765    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
766        let num_bytes = length.0.div_ceil(8);
767        let mut bytes = vec![0; num_bytes as usize];
768        rng.fill_bytes(&mut bytes);
769        let bytes = BooleanBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
770        Ok(Arc::new(arrow_array::BooleanArray::new(bytes, None)))
771    }
772
773    fn data_type(&self) -> &DataType {
774        &DataType::Boolean
775    }
776
777    fn element_size_bytes(&self) -> Option<ByteCount> {
778        // We can't say 1/8th of a byte and 1 byte would be a pretty extreme over-count so let's leave
779        // it at None until someone needs this.  Then we can probably special case this (e.g. make a ByteCount::ONE_BIT)
780        None
781    }
782}
783
784// Instead of using the "standard distribution" and generating values there are some cases (e.g. f16 / decimal)
785// where we just generate random bytes because there is no rand support
786pub struct RandomBytesGenerator<T: ArrowPrimitiveType + Send + Sync> {
787    phantom: PhantomData<T>,
788    data_type: DataType,
789}
790
791impl<T: ArrowPrimitiveType + Send + Sync> std::fmt::Debug for RandomBytesGenerator<T> {
792    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
793        f.debug_struct("RandomBytesGenerator")
794            .field("data_type", &self.data_type)
795            .finish()
796    }
797}
798
799impl<T: ArrowPrimitiveType + Send + Sync> RandomBytesGenerator<T> {
800    fn new(data_type: DataType) -> Self {
801        Self {
802            phantom: Default::default(),
803            data_type,
804        }
805    }
806
807    fn byte_width() -> Result<u64, ArrowError> {
808        T::DATA_TYPE.primitive_width().ok_or_else(|| ArrowError::InvalidArgumentError(format!("Cannot generate the data type {} with the RandomBytesGenerator because it is not a fixed-width bytes type", T::DATA_TYPE))).map(|val| val as u64)
809    }
810}
811
812impl<T: ArrowPrimitiveType + Send + Sync> ArrayGenerator for RandomBytesGenerator<T> {
813    fn generate(
814        &mut self,
815        length: RowCount,
816        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
817    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
818        let num_bytes = length.0 * Self::byte_width()?;
819        let mut bytes = vec![0; num_bytes as usize];
820        rng.fill_bytes(&mut bytes);
821        let bytes = ScalarBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
822        Ok(Arc::new(
823            PrimitiveArray::<T>::new(bytes, None).with_data_type(self.data_type.clone()),
824        ))
825    }
826
827    fn data_type(&self) -> &DataType {
828        &self.data_type
829    }
830
831    fn element_size_bytes(&self) -> Option<ByteCount> {
832        Self::byte_width().map(ByteCount::from).ok()
833    }
834}
835
836// This is pretty much the same thing as RandomBinaryGenerator but we can't use that
837// because there is no ArrowPrimitiveType for FixedSizeBinary
838#[derive(Debug)]
839pub struct RandomFixedSizeBinaryGenerator {
840    data_type: DataType,
841    size: i32,
842}
843
844impl RandomFixedSizeBinaryGenerator {
845    fn new(size: i32) -> Self {
846        Self {
847            size,
848            data_type: DataType::FixedSizeBinary(size),
849        }
850    }
851}
852
853impl ArrayGenerator for RandomFixedSizeBinaryGenerator {
854    fn generate(
855        &mut self,
856        length: RowCount,
857        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
858    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
859        let num_bytes = length.0 * self.size as u64;
860        let mut bytes = vec![0; num_bytes as usize];
861        rng.fill_bytes(&mut bytes);
862        Ok(Arc::new(FixedSizeBinaryArray::new(
863            self.size,
864            Buffer::from(bytes),
865            None,
866        )))
867    }
868
869    fn data_type(&self) -> &DataType {
870        &self.data_type
871    }
872
873    fn element_size_bytes(&self) -> Option<ByteCount> {
874        Some(ByteCount::from(self.size as u64))
875    }
876}
877
878#[derive(Debug)]
879pub struct RandomIntervalGenerator {
880    unit: IntervalUnit,
881    data_type: DataType,
882}
883
884impl RandomIntervalGenerator {
885    pub fn new(unit: IntervalUnit) -> Self {
886        Self {
887            unit,
888            data_type: DataType::Interval(unit),
889        }
890    }
891}
892
893impl ArrayGenerator for RandomIntervalGenerator {
894    fn generate(
895        &mut self,
896        length: RowCount,
897        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
898    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
899        match self.unit {
900            IntervalUnit::YearMonth => {
901                let months = (0..length.0)
902                    .map(|_| rng.random::<i32>())
903                    .collect::<Vec<_>>();
904                Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months)))
905            }
906            IntervalUnit::MonthDayNano => {
907                let day_time_array = (0..length.0)
908                    .map(|_| IntervalMonthDayNano::new(rng.random(), rng.random(), rng.random()))
909                    .collect::<Vec<_>>();
910                Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
911                    day_time_array,
912                )))
913            }
914            IntervalUnit::DayTime => {
915                let day_time_array = (0..length.0)
916                    .map(|_| IntervalDayTime::new(rng.random(), rng.random()))
917                    .collect::<Vec<_>>();
918                Ok(Arc::new(arrow_array::IntervalDayTimeArray::from(
919                    day_time_array,
920                )))
921            }
922        }
923    }
924
925    fn data_type(&self) -> &DataType {
926        &self.data_type
927    }
928
929    fn element_size_bytes(&self) -> Option<ByteCount> {
930        Some(ByteCount::from(12))
931    }
932}
933#[derive(Debug)]
934pub struct RandomBinaryGenerator {
935    bytes_per_element: ByteCount,
936    scale_to_utf8: bool,
937    is_large: bool,
938    data_type: DataType,
939}
940
941impl RandomBinaryGenerator {
942    pub fn new(bytes_per_element: ByteCount, scale_to_utf8: bool, is_large: bool) -> Self {
943        Self {
944            bytes_per_element,
945            scale_to_utf8,
946            is_large,
947            data_type: match (scale_to_utf8, is_large) {
948                (false, false) => DataType::Binary,
949                (false, true) => DataType::LargeBinary,
950                (true, false) => DataType::Utf8,
951                (true, true) => DataType::LargeUtf8,
952            },
953        }
954    }
955}
956
957impl ArrayGenerator for RandomBinaryGenerator {
958    fn generate(
959        &mut self,
960        length: RowCount,
961        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
962    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
963        let mut bytes = vec![0; (self.bytes_per_element.0 * length.0) as usize];
964        rng.fill_bytes(&mut bytes);
965        if self.scale_to_utf8 {
966            // This doesn't give us the full UTF-8 range and it isn't statistically correct but
967            // it's fast and probably good enough for most cases
968            bytes = bytes.into_iter().map(|val| (val % 95) + 32).collect();
969        }
970        let bytes = Buffer::from(bytes);
971        if self.is_large {
972            let offsets = OffsetBuffer::from_lengths(iter::repeat_n(
973                self.bytes_per_element.0 as usize,
974                length.0 as usize,
975            ));
976            if self.scale_to_utf8 {
977                // This is safe because we are only using printable characters
978                unsafe {
979                    Ok(Arc::new(arrow_array::LargeStringArray::new_unchecked(
980                        offsets, bytes, None,
981                    )))
982                }
983            } else {
984                unsafe {
985                    Ok(Arc::new(arrow_array::LargeBinaryArray::new_unchecked(
986                        offsets, bytes, None,
987                    )))
988                }
989            }
990        } else {
991            let offsets = OffsetBuffer::from_lengths(iter::repeat_n(
992                self.bytes_per_element.0 as usize,
993                length.0 as usize,
994            ));
995            if self.scale_to_utf8 {
996                // This is safe because we are only using printable characters
997                unsafe {
998                    Ok(Arc::new(arrow_array::StringArray::new_unchecked(
999                        offsets, bytes, None,
1000                    )))
1001                }
1002            } else {
1003                unsafe {
1004                    Ok(Arc::new(arrow_array::BinaryArray::new_unchecked(
1005                        offsets, bytes, None,
1006                    )))
1007                }
1008            }
1009        }
1010    }
1011
1012    fn data_type(&self) -> &DataType {
1013        &self.data_type
1014    }
1015
1016    fn element_size_bytes(&self) -> Option<ByteCount> {
1017        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1018        Some(ByteCount::from(
1019            self.bytes_per_element.0 + std::mem::size_of::<i32>() as u64,
1020        ))
1021    }
1022}
1023
1024/// Generate a sequence of strings with a prefix and a counter
1025///
1026/// For example, if the prefix is "user_" the strings will be "user_0", "user_1", ...
1027#[derive(Debug)]
1028pub struct PrefixPlusCounterGenerator {
1029    prefix: String,
1030    is_large: bool,
1031    data_type: DataType,
1032    current_counter: u64,
1033}
1034
1035impl PrefixPlusCounterGenerator {
1036    pub fn new(prefix: String, is_large: bool) -> Self {
1037        Self {
1038            prefix,
1039            is_large,
1040            data_type: if is_large {
1041                DataType::LargeUtf8
1042            } else {
1043                DataType::Utf8
1044            },
1045            current_counter: 0,
1046        }
1047    }
1048
1049    fn generate_values<T: OffsetSizeTrait>(
1050        &self,
1051        start: u64,
1052        num_values: u64,
1053    ) -> Result<Arc<dyn Array>, ArrowError> {
1054        let max_counter = start + num_values;
1055        let max_digits_per_counter = (max_counter as f64).log10().ceil() as u64;
1056        let max_bytes_per_str = max_digits_per_counter + self.prefix.len() as u64;
1057        let max_bytes = max_bytes_per_str * num_values;
1058        let mut builder =
1059            GenericStringBuilder::<T>::with_capacity(num_values as usize, max_bytes as usize);
1060        let mut word = String::with_capacity(max_bytes_per_str as usize);
1061        word.push_str(&self.prefix);
1062        for i in 0..num_values {
1063            let counter = start + i;
1064            word.truncate(self.prefix.len());
1065            word.push_str(&counter.to_string());
1066            builder.append_value(&word);
1067        }
1068        Ok(Arc::new(builder.finish()))
1069    }
1070}
1071
1072impl ArrayGenerator for PrefixPlusCounterGenerator {
1073    fn generate(
1074        &mut self,
1075        length: RowCount,
1076        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1077    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1078        let start = self.current_counter;
1079        self.current_counter += length.0;
1080        if self.is_large {
1081            self.generate_values::<i64>(start, length.0)
1082        } else {
1083            self.generate_values::<i32>(start, length.0)
1084        }
1085    }
1086
1087    fn data_type(&self) -> &DataType {
1088        &self.data_type
1089    }
1090
1091    fn element_size_bytes(&self) -> Option<ByteCount> {
1092        // It's not consistent
1093        None
1094    }
1095}
1096
1097/// Generate a sequence of binary strings with a prefix and a counter
1098///
1099/// The counter will be encoded (little-endian) as a u8, u16, u32, or u64 and added to the prefix
1100/// As long as more than 256 values are generated then the resulting array will have
1101/// variable width
1102#[derive(Debug)]
1103pub struct BinaryPrefixPlusCounterGenerator {
1104    prefix: Arc<[u8]>,
1105    is_large: bool,
1106    data_type: DataType,
1107    current_counter: u64,
1108}
1109
1110impl BinaryPrefixPlusCounterGenerator {
1111    pub fn new(prefix: Arc<[u8]>, is_large: bool) -> Self {
1112        Self {
1113            prefix,
1114            is_large,
1115            data_type: if is_large {
1116                DataType::LargeBinary
1117            } else {
1118                DataType::Binary
1119            },
1120            current_counter: 0,
1121        }
1122    }
1123
1124    fn generate_values<T: OffsetSizeTrait>(
1125        &self,
1126        start: u64,
1127        num_values: u64,
1128    ) -> Result<Arc<dyn Array>, ArrowError> {
1129        let max_bytes = (self.prefix.len() + std::mem::size_of::<u64>()) * num_values as usize;
1130        let mut builder = GenericBinaryBuilder::<T>::with_capacity(num_values as usize, max_bytes);
1131        let mut word = Vec::with_capacity(self.prefix.len() + std::mem::size_of::<u64>());
1132        word.extend_from_slice(&self.prefix);
1133        for i in 0..num_values {
1134            let counter = start + i;
1135            word.truncate(self.prefix.len());
1136            if counter < u8::MAX as u64 {
1137                word.push(counter as u8);
1138            } else if counter < u16::MAX as u64 {
1139                word.extend_from_slice(&(counter as u16).to_le_bytes());
1140            } else if counter < u32::MAX as u64 {
1141                word.extend_from_slice(&(counter as u32).to_le_bytes());
1142            } else {
1143                word.extend_from_slice(&counter.to_le_bytes());
1144            }
1145            builder.append_value(&word);
1146        }
1147        Ok(Arc::new(builder.finish()))
1148    }
1149}
1150
1151impl ArrayGenerator for BinaryPrefixPlusCounterGenerator {
1152    fn generate(
1153        &mut self,
1154        length: RowCount,
1155        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1156    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1157        let start = self.current_counter;
1158        self.current_counter += length.0;
1159        if self.is_large {
1160            self.generate_values::<i64>(start, length.0)
1161        } else {
1162            self.generate_values::<i32>(start, length.0)
1163        }
1164    }
1165
1166    fn data_type(&self) -> &DataType {
1167        &self.data_type
1168    }
1169
1170    fn element_size_bytes(&self) -> Option<ByteCount> {
1171        // It's not consistent
1172        None
1173    }
1174}
1175
1176// Common English stop words placed at the front to be sampled more frequently
1177const STOP_WORDS: &[&str] = &[
1178    "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it",
1179    "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these",
1180    "they", "this", "to", "was", "will", "with",
1181];
1182
1183/// Word list with stop words at the front for Zipf sampling, computed once.
1184static SENTENCE_WORDS: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
1185    let all_words = random_word::all(random_word::Lang::En);
1186    let mut words = Vec::with_capacity(STOP_WORDS.len() + all_words.len());
1187    words.extend(STOP_WORDS.iter().copied());
1188    words.extend(
1189        all_words
1190            .iter()
1191            .filter(|w| !STOP_WORDS.contains(w))
1192            .copied(),
1193    );
1194    words
1195});
1196
1197struct RandomSentenceGenerator {
1198    min_words: usize,
1199    max_words: usize,
1200    /// Zipf distribution for word selection (favors lower indices)
1201    zipf: Zipf<f64>,
1202    is_large: bool,
1203}
1204
1205impl std::fmt::Debug for RandomSentenceGenerator {
1206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1207        f.debug_struct("RandomSentenceGenerator")
1208            .field("min_words", &self.min_words)
1209            .field("max_words", &self.max_words)
1210            .field("num_words", &SENTENCE_WORDS.len())
1211            .field("is_large", &self.is_large)
1212            .finish()
1213    }
1214}
1215
1216impl RandomSentenceGenerator {
1217    pub fn new(min_words: usize, max_words: usize, is_large: bool) -> Self {
1218        // Zipf distribution with exponent ~1.0 approximates natural language
1219        let zipf = Zipf::new(SENTENCE_WORDS.len() as f64, 1.0).unwrap();
1220
1221        Self {
1222            min_words,
1223            max_words,
1224            zipf,
1225            is_large,
1226        }
1227    }
1228}
1229
1230impl ArrayGenerator for RandomSentenceGenerator {
1231    fn generate(
1232        &mut self,
1233        length: RowCount,
1234        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1235    ) -> Result<Arc<dyn Array>, ArrowError> {
1236        let mut values = Vec::with_capacity(length.0 as usize);
1237
1238        for _ in 0..length.0 {
1239            let num_words = rng.random_range(self.min_words..=self.max_words);
1240            let sentence: String = (0..num_words)
1241                .map(|_| {
1242                    // Zipf returns 1-indexed values, subtract 1 for 0-indexed array
1243                    let idx = rng.sample(self.zipf) as usize - 1;
1244                    SENTENCE_WORDS[idx]
1245                })
1246                .collect::<Vec<_>>()
1247                .join(" ");
1248            values.push(sentence);
1249        }
1250
1251        if self.is_large {
1252            Ok(Arc::new(LargeStringArray::from(values)))
1253        } else {
1254            Ok(Arc::new(StringArray::from(values)))
1255        }
1256    }
1257
1258    fn data_type(&self) -> &DataType {
1259        if self.is_large {
1260            &DataType::LargeUtf8
1261        } else {
1262            &DataType::Utf8
1263        }
1264    }
1265
1266    fn element_size_bytes(&self) -> Option<ByteCount> {
1267        // Estimate average word length as 5, plus space
1268        // See https://arxiv.org/pdf/1208.6109
1269        let avg_word_length = 6;
1270        let avg_words = (self.min_words + self.max_words) / 2;
1271        Some(ByteCount::from((avg_word_length * avg_words) as u64))
1272    }
1273}
1274
1275#[derive(Debug)]
1276struct RandomWordGenerator {
1277    words: &'static [&'static str],
1278    is_large: bool,
1279}
1280
1281impl RandomWordGenerator {
1282    pub fn new(is_large: bool) -> Self {
1283        let words = random_word::all(random_word::Lang::En);
1284        Self { words, is_large }
1285    }
1286}
1287
1288impl ArrayGenerator for RandomWordGenerator {
1289    fn generate(
1290        &mut self,
1291        length: RowCount,
1292        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1293    ) -> Result<Arc<dyn Array>, ArrowError> {
1294        let mut values = Vec::with_capacity(length.0 as usize);
1295
1296        for _ in 0..length.0 {
1297            let word = self.words[rng.random_range(0..self.words.len())];
1298            values.push(word.to_string());
1299        }
1300
1301        if self.is_large {
1302            Ok(Arc::new(LargeStringArray::from(values)))
1303        } else {
1304            Ok(Arc::new(StringArray::from(values)))
1305        }
1306    }
1307
1308    fn data_type(&self) -> &DataType {
1309        if self.is_large {
1310            &DataType::LargeUtf8
1311        } else {
1312            &DataType::Utf8
1313        }
1314    }
1315
1316    fn element_size_bytes(&self) -> Option<ByteCount> {
1317        // Average English word length is ~5 characters
1318        Some(ByteCount::from(5))
1319    }
1320}
1321
1322#[derive(Debug)]
1323pub struct VariableRandomBinaryGenerator {
1324    lengths_gen: Box<dyn ArrayGenerator>,
1325    data_type: DataType,
1326}
1327
1328impl VariableRandomBinaryGenerator {
1329    pub fn new(min_bytes_per_element: ByteCount, max_bytes_per_element: ByteCount) -> Self {
1330        let lengths_dist = Uniform::new_inclusive(
1331            min_bytes_per_element.0 as i32,
1332            max_bytes_per_element.0 as i32,
1333        )
1334        .unwrap();
1335        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
1336
1337        Self {
1338            lengths_gen,
1339            data_type: DataType::Binary,
1340        }
1341    }
1342}
1343
1344impl ArrayGenerator for VariableRandomBinaryGenerator {
1345    fn generate(
1346        &mut self,
1347        length: RowCount,
1348        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1349    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1350        let lengths = self.lengths_gen.generate(length, rng)?;
1351        let lengths = lengths.as_primitive::<Int32Type>();
1352        let total_length = lengths.values().iter().map(|i| *i as usize).sum::<usize>();
1353        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1354        let mut bytes = vec![0; total_length];
1355        rng.fill_bytes(&mut bytes);
1356        let bytes = Buffer::from(bytes);
1357        Ok(Arc::new(BinaryArray::try_new(offsets, bytes, None)?))
1358    }
1359
1360    fn data_type(&self) -> &DataType {
1361        &self.data_type
1362    }
1363
1364    fn element_size_bytes(&self) -> Option<ByteCount> {
1365        None
1366    }
1367}
1368
1369pub struct CycleBinaryGenerator<T: ByteArrayType> {
1370    values: Vec<u8>,
1371    lengths: Vec<usize>,
1372    data_type: DataType,
1373    array_type: PhantomData<T>,
1374    width: Option<ByteCount>,
1375    idx: usize,
1376}
1377
1378impl<T: ByteArrayType> std::fmt::Debug for CycleBinaryGenerator<T> {
1379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1380        f.debug_struct("CycleBinaryGenerator")
1381            .field("values", &self.values)
1382            .field("lengths", &self.lengths)
1383            .field("data_type", &self.data_type)
1384            .field("width", &self.width)
1385            .field("idx", &self.idx)
1386            .finish()
1387    }
1388}
1389
1390impl<T: ByteArrayType> CycleBinaryGenerator<T> {
1391    pub fn from_strings(values: &[&str]) -> Self {
1392        if values.is_empty() {
1393            panic!("Attempt to create a cycle generator with no values");
1394        }
1395        let lengths = values.iter().map(|s| s.len()).collect::<Vec<_>>();
1396        let typical_length = lengths[0];
1397        let width = if lengths.iter().all(|item| *item == typical_length) {
1398            Some(ByteCount::from(
1399                typical_length as u64 + std::mem::size_of::<i32>() as u64,
1400            ))
1401        } else {
1402            None
1403        };
1404        let values = values
1405            .iter()
1406            .flat_map(|s| s.as_bytes().iter().copied())
1407            .collect::<Vec<_>>();
1408        Self {
1409            values,
1410            lengths,
1411            data_type: T::DATA_TYPE,
1412            array_type: PhantomData,
1413            width,
1414            idx: 0,
1415        }
1416    }
1417}
1418
1419impl<T: ByteArrayType> ArrayGenerator for CycleBinaryGenerator<T> {
1420    fn generate(
1421        &mut self,
1422        length: RowCount,
1423        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1424    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1425        let lengths = self
1426            .lengths
1427            .iter()
1428            .copied()
1429            .cycle()
1430            .skip(self.idx)
1431            .take(length.0 as usize);
1432        let num_bytes = lengths.clone().sum();
1433        let byte_offset = self.lengths[0..self.idx].iter().sum();
1434        let bytes = self
1435            .values
1436            .iter()
1437            .cycle()
1438            .skip(byte_offset)
1439            .copied()
1440            .take(num_bytes)
1441            .collect::<Vec<_>>();
1442        let bytes = Buffer::from(bytes);
1443        let offsets = OffsetBuffer::from_lengths(lengths);
1444        self.idx = (self.idx + length.0 as usize) % self.lengths.len();
1445        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
1446            offsets, bytes, None,
1447        )))
1448    }
1449
1450    fn data_type(&self) -> &DataType {
1451        &self.data_type
1452    }
1453
1454    fn element_size_bytes(&self) -> Option<ByteCount> {
1455        self.width
1456    }
1457}
1458
1459pub struct FixedBinaryGenerator<T: ByteArrayType> {
1460    value: Vec<u8>,
1461    data_type: DataType,
1462    array_type: PhantomData<T>,
1463}
1464
1465impl<T: ByteArrayType> std::fmt::Debug for FixedBinaryGenerator<T> {
1466    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1467        f.debug_struct("FixedBinaryGenerator")
1468            .field("value", &self.value)
1469            .field("data_type", &self.data_type)
1470            .finish()
1471    }
1472}
1473
1474impl<T: ByteArrayType> FixedBinaryGenerator<T> {
1475    pub fn new(value: Vec<u8>) -> Self {
1476        Self {
1477            value,
1478            data_type: T::DATA_TYPE,
1479            array_type: PhantomData,
1480        }
1481    }
1482}
1483
1484impl<T: ByteArrayType> ArrayGenerator for FixedBinaryGenerator<T> {
1485    fn generate(
1486        &mut self,
1487        length: RowCount,
1488        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1489    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1490        let bytes = Buffer::from(Vec::from_iter(
1491            self.value
1492                .iter()
1493                .cycle()
1494                .take((length.0 * self.value.len() as u64) as usize)
1495                .copied(),
1496        ));
1497        let offsets =
1498            OffsetBuffer::from_lengths(iter::repeat_n(self.value.len(), length.0 as usize));
1499        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
1500            offsets, bytes, None,
1501        )))
1502    }
1503
1504    fn data_type(&self) -> &DataType {
1505        &self.data_type
1506    }
1507
1508    fn element_size_bytes(&self) -> Option<ByteCount> {
1509        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1510        Some(ByteCount::from(
1511            self.value.len() as u64 + std::mem::size_of::<i32>() as u64,
1512        ))
1513    }
1514}
1515
1516pub struct DictionaryGenerator<K: ArrowDictionaryKeyType> {
1517    generator: Box<dyn ArrayGenerator>,
1518    data_type: DataType,
1519    key_type: PhantomData<K>,
1520    key_width: u64,
1521}
1522
1523impl<K: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryGenerator<K> {
1524    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1525        f.debug_struct("DictionaryGenerator")
1526            .field("generator", &self.generator)
1527            .field("data_type", &self.data_type)
1528            .field("key_width", &self.key_width)
1529            .finish()
1530    }
1531}
1532
1533impl<K: ArrowDictionaryKeyType> DictionaryGenerator<K> {
1534    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1535        let key_type = Box::new(K::DATA_TYPE);
1536        let key_width = key_type
1537            .primitive_width()
1538            .expect("dictionary key types should have a known width")
1539            as u64;
1540        let val_type = Box::new(generator.data_type().clone());
1541        let dict_type = DataType::Dictionary(key_type, val_type);
1542        Self {
1543            generator,
1544            data_type: dict_type,
1545            key_type: PhantomData,
1546            key_width,
1547        }
1548    }
1549}
1550
1551impl<K: ArrowDictionaryKeyType + Send + Sync> ArrayGenerator for DictionaryGenerator<K> {
1552    fn generate(
1553        &mut self,
1554        length: RowCount,
1555        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1556    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1557        let underlying = self.generator.generate(length, rng)?;
1558        arrow_cast::cast::cast(&underlying, &self.data_type)
1559    }
1560
1561    fn data_type(&self) -> &DataType {
1562        &self.data_type
1563    }
1564
1565    fn element_size_bytes(&self) -> Option<ByteCount> {
1566        self.generator
1567            .element_size_bytes()
1568            .map(|size_bytes| ByteCount::from(size_bytes.0 + self.key_width))
1569    }
1570}
1571
1572/// Generator that produces low-cardinality data by generating a fixed set of
1573/// unique values and then randomly selecting from them.
1574struct LowCardinalityGenerator {
1575    inner: Box<dyn ArrayGenerator>,
1576    cardinality: usize,
1577    /// Cached unique values, generated on first call
1578    unique_values: Option<Arc<dyn Array>>,
1579}
1580
1581impl std::fmt::Debug for LowCardinalityGenerator {
1582    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1583        f.debug_struct("LowCardinalityGenerator")
1584            .field("inner", &self.inner)
1585            .field("cardinality", &self.cardinality)
1586            .field("initialized", &self.unique_values.is_some())
1587            .finish()
1588    }
1589}
1590
1591impl LowCardinalityGenerator {
1592    fn new(inner: Box<dyn ArrayGenerator>, cardinality: usize) -> Self {
1593        Self {
1594            inner,
1595            cardinality,
1596            unique_values: None,
1597        }
1598    }
1599}
1600
1601impl ArrayGenerator for LowCardinalityGenerator {
1602    fn generate(
1603        &mut self,
1604        length: RowCount,
1605        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1606    ) -> Result<Arc<dyn Array>, ArrowError> {
1607        // Generate unique values on first call
1608        if self.unique_values.is_none() {
1609            self.unique_values = Some(
1610                self.inner
1611                    .generate(RowCount::from(self.cardinality as u64), rng)?,
1612            );
1613        }
1614
1615        let unique_values = self.unique_values.as_ref().unwrap();
1616
1617        // Generate random indices into the unique values
1618        let indices: Vec<usize> = (0..length.0)
1619            .map(|_| rng.random_range(0..self.cardinality))
1620            .collect();
1621
1622        // Use arrow's take to select values
1623        let indices_array =
1624            arrow_array::UInt32Array::from(indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
1625        arrow::compute::take(unique_values.as_ref(), &indices_array, None)
1626            .map(|arr| arr as Arc<dyn Array>)
1627    }
1628
1629    fn data_type(&self) -> &DataType {
1630        self.inner.data_type()
1631    }
1632
1633    fn element_size_bytes(&self) -> Option<ByteCount> {
1634        self.inner.element_size_bytes()
1635    }
1636}
1637
1638#[derive(Debug)]
1639struct RandomListGenerator {
1640    field: Arc<Field>,
1641    child_field: Arc<Field>,
1642    items_gen: Box<dyn ArrayGenerator>,
1643    lengths_gen: Box<dyn ArrayGenerator>,
1644    is_large: bool,
1645}
1646
1647impl RandomListGenerator {
1648    // Creates a list generator that generates random lists with lengths between 0 and 10 (inclusive)
1649    fn new(items_gen: Box<dyn ArrayGenerator>, is_large: bool) -> Self {
1650        let child_field = Arc::new(Field::new("item", items_gen.data_type().clone(), true));
1651        let list_type = if is_large {
1652            DataType::LargeList(child_field.clone())
1653        } else {
1654            DataType::List(child_field.clone())
1655        };
1656        let field = Field::new("", list_type, true);
1657        let lengths_gen = if is_large {
1658            let lengths_dist = Uniform::new_inclusive(0, 10).unwrap();
1659            rand_with_distribution::<Int64Type, Uniform<i64>>(lengths_dist)
1660        } else {
1661            let lengths_dist = Uniform::new_inclusive(0, 10).unwrap();
1662            rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist)
1663        };
1664        Self {
1665            field: Arc::new(field),
1666            child_field,
1667            items_gen,
1668            lengths_gen,
1669            is_large,
1670        }
1671    }
1672}
1673
1674impl ArrayGenerator for RandomListGenerator {
1675    fn generate(
1676        &mut self,
1677        length: RowCount,
1678        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1679    ) -> Result<Arc<dyn Array>, ArrowError> {
1680        let lengths = self.lengths_gen.generate(length, rng)?;
1681        if self.is_large {
1682            let lengths = lengths.as_primitive::<Int64Type>();
1683            let total_length = lengths.values().iter().sum::<i64>() as u64;
1684            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1685            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1686            Ok(Arc::new(LargeListArray::try_new(
1687                self.child_field.clone(),
1688                offsets,
1689                items,
1690                None,
1691            )?))
1692        } else {
1693            let lengths = lengths.as_primitive::<Int32Type>();
1694            let total_length = lengths.values().iter().sum::<i32>() as u64;
1695            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1696            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1697            Ok(Arc::new(ListArray::try_new(
1698                self.child_field.clone(),
1699                offsets,
1700                items,
1701                None,
1702            )?))
1703        }
1704    }
1705
1706    fn data_type(&self) -> &DataType {
1707        self.field.data_type()
1708    }
1709
1710    fn element_size_bytes(&self) -> Option<ByteCount> {
1711        None
1712    }
1713}
1714
1715/// Generates random map arrays where each map has 0-4 entries.
1716#[derive(Debug)]
1717struct RandomMapGenerator {
1718    field: Arc<Field>,
1719    entries_field: Arc<Field>,
1720    keys_gen: Box<dyn ArrayGenerator>,
1721    values_gen: Box<dyn ArrayGenerator>,
1722    lengths_gen: Box<dyn ArrayGenerator>,
1723}
1724
1725impl RandomMapGenerator {
1726    fn new(keys_gen: Box<dyn ArrayGenerator>, values_gen: Box<dyn ArrayGenerator>) -> Self {
1727        let entries_fields = Fields::from(vec![
1728            Field::new("keys", keys_gen.data_type().clone(), false),
1729            Field::new("values", values_gen.data_type().clone(), true),
1730        ]);
1731        let entries_field = Arc::new(Field::new(
1732            "entries",
1733            DataType::Struct(entries_fields),
1734            false,
1735        ));
1736        let map_type = DataType::Map(entries_field.clone(), false);
1737        let field = Arc::new(Field::new("", map_type, true));
1738        let lengths_dist = Uniform::new_inclusive(0_i32, 4).unwrap();
1739        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
1740
1741        Self {
1742            field,
1743            entries_field,
1744            keys_gen,
1745            values_gen,
1746            lengths_gen,
1747        }
1748    }
1749}
1750
1751impl ArrayGenerator for RandomMapGenerator {
1752    fn generate(
1753        &mut self,
1754        length: RowCount,
1755        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1756    ) -> Result<Arc<dyn Array>, ArrowError> {
1757        let lengths = self.lengths_gen.generate(length, rng)?;
1758        let lengths = lengths.as_primitive::<Int32Type>();
1759        let total_entries = lengths.values().iter().sum::<i32>() as u64;
1760        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1761
1762        let keys = self.keys_gen.generate(RowCount::from(total_entries), rng)?;
1763        let values = self
1764            .values_gen
1765            .generate(RowCount::from(total_entries), rng)?;
1766
1767        let entries = StructArray::new(
1768            Fields::from(vec![
1769                Field::new("keys", keys.data_type().clone(), false),
1770                Field::new("values", values.data_type().clone(), true),
1771            ]),
1772            vec![keys, values],
1773            None,
1774        );
1775
1776        Ok(Arc::new(MapArray::try_new(
1777            self.entries_field.clone(),
1778            offsets,
1779            entries,
1780            None,
1781            false,
1782        )?))
1783    }
1784
1785    fn data_type(&self) -> &DataType {
1786        self.field.data_type()
1787    }
1788
1789    fn element_size_bytes(&self) -> Option<ByteCount> {
1790        None
1791    }
1792}
1793
1794#[derive(Debug)]
1795struct NullArrayGenerator {}
1796
1797impl ArrayGenerator for NullArrayGenerator {
1798    fn generate(
1799        &mut self,
1800        length: RowCount,
1801        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1802    ) -> Result<Arc<dyn Array>, ArrowError> {
1803        Ok(Arc::new(NullArray::new(length.0 as usize)))
1804    }
1805
1806    fn data_type(&self) -> &DataType {
1807        &DataType::Null
1808    }
1809
1810    fn element_size_bytes(&self) -> Option<ByteCount> {
1811        None
1812    }
1813}
1814
1815/// Generates 2 dimensional vectors along the unit circle, with a configurable number of steps per circle.
1816#[derive(Debug)]
1817struct RadialStepGenerator {
1818    num_steps_per_circle: u32,
1819    data_field: Arc<Field>,
1820    data_type: DataType,
1821    current_step: u32,
1822}
1823
1824impl RadialStepGenerator {
1825    fn new(num_steps_per_circle: u32) -> Self {
1826        let data_field = Arc::new(Field::new("item", DataType::Float32, false));
1827        let data_type = DataType::FixedSizeList(data_field.clone(), 2);
1828        Self {
1829            num_steps_per_circle,
1830            data_field,
1831            data_type,
1832            current_step: 0,
1833        }
1834    }
1835}
1836
1837impl ArrayGenerator for RadialStepGenerator {
1838    fn generate(
1839        &mut self,
1840        length: RowCount,
1841        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1842    ) -> Result<Arc<dyn Array>, ArrowError> {
1843        let mut values_builder = Float32Builder::with_capacity(length.0 as usize * 2);
1844        for _ in 0..length.0 {
1845            let angle = (self.current_step as f32) / (self.num_steps_per_circle as f32)
1846                * 2.0
1847                * std::f32::consts::PI;
1848            values_builder.append_value(angle.cos());
1849            values_builder.append_value(angle.sin());
1850            self.current_step = (self.current_step + 1) % self.num_steps_per_circle;
1851        }
1852        let values = values_builder.finish();
1853        let vectors =
1854            FixedSizeListArray::try_new(self.data_field.clone(), 2, Arc::new(values), None)?;
1855        Ok(Arc::new(vectors))
1856    }
1857
1858    fn data_type(&self) -> &DataType {
1859        &self.data_type
1860    }
1861
1862    fn element_size_bytes(&self) -> Option<ByteCount> {
1863        Some(ByteCount::from(8))
1864    }
1865}
1866
1867/// Cycles through a set of centroids, adding noise to each point
1868#[derive(Debug)]
1869struct JitterCentroidsGenerator {
1870    centroids: Float32Array,
1871    dimension: u32,
1872    noise_level: f32,
1873    data_type: DataType,
1874    data_field: Arc<Field>,
1875
1876    offset: usize,
1877}
1878
1879impl JitterCentroidsGenerator {
1880    fn try_new(centroids: Arc<dyn Array>, noise_level: f32) -> Result<Self, ArrowError> {
1881        let DataType::FixedSizeList(values_field, dimension) = centroids.data_type() else {
1882            return Err(ArrowError::InvalidArgumentError(
1883                "Centroids must be a FixedSizeList".to_string(),
1884            ));
1885        };
1886        if values_field.data_type() != &DataType::Float32 {
1887            return Err(ArrowError::InvalidArgumentError(
1888                "Centroids values must be a Float32".to_string(),
1889            ));
1890        }
1891        let data_type = DataType::FixedSizeList(values_field.clone(), *dimension);
1892        Ok(Self {
1893            centroids: centroids
1894                .as_fixed_size_list()
1895                .values()
1896                .as_primitive::<Float32Type>()
1897                .clone(),
1898            dimension: *dimension as u32,
1899            noise_level,
1900            data_type,
1901            data_field: values_field.clone(),
1902            offset: 0,
1903        })
1904    }
1905}
1906
1907impl ArrayGenerator for JitterCentroidsGenerator {
1908    fn generate(
1909        &mut self,
1910        length: RowCount,
1911        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1912    ) -> Result<Arc<dyn Array>, ArrowError> {
1913        let mut values_builder =
1914            Float32Builder::with_capacity(length.0 as usize * self.dimension as usize);
1915        for _ in 0..length.0 {
1916            // Generate random N dimensional point
1917            let mut noise = (0..self.dimension as usize)
1918                .map(|_| rng.random::<f32>())
1919                .collect::<Vec<_>>();
1920            // Scale point to noise_level length
1921            let scale = self.noise_level / noise.iter().map(|v| v * v).sum::<f32>().sqrt();
1922            noise.iter_mut().for_each(|v| *v *= scale);
1923
1924            // Add noise to centroid and store in values
1925            for (i, noise) in noise.into_iter().enumerate() {
1926                let centroid_val = self.centroids.value(self.offset + i);
1927                let jittered_val = centroid_val + noise;
1928                values_builder.append_value(jittered_val);
1929            }
1930            // Advance to next centroid
1931            self.offset = (self.offset + self.dimension as usize) % self.centroids.len();
1932        }
1933        let values = values_builder.finish();
1934        let vectors = FixedSizeListArray::try_new(
1935            self.data_field.clone(),
1936            self.dimension as i32,
1937            Arc::new(values),
1938            None,
1939        )?;
1940        Ok(Arc::new(vectors))
1941    }
1942
1943    fn data_type(&self) -> &DataType {
1944        &self.data_type
1945    }
1946
1947    fn element_size_bytes(&self) -> Option<ByteCount> {
1948        Some(ByteCount::from(self.dimension as u64 * 4))
1949    }
1950}
1951#[derive(Debug)]
1952struct RandomStructGenerator {
1953    fields: Fields,
1954    data_type: DataType,
1955    child_gens: Vec<Box<dyn ArrayGenerator>>,
1956}
1957
1958impl RandomStructGenerator {
1959    fn new(fields: Fields, child_gens: Vec<Box<dyn ArrayGenerator>>) -> Self {
1960        let data_type = DataType::Struct(fields.clone());
1961        Self {
1962            fields,
1963            data_type,
1964            child_gens,
1965        }
1966    }
1967}
1968
1969impl ArrayGenerator for RandomStructGenerator {
1970    fn generate(
1971        &mut self,
1972        length: RowCount,
1973        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1974    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1975        if self.child_gens.is_empty() {
1976            // Have to create empty struct arrays specially to ensure they have the correct
1977            // row count
1978            let struct_arr = StructArray::new_empty_fields(length.0 as usize, None);
1979            return Ok(Arc::new(struct_arr));
1980        }
1981        let child_arrays = self
1982            .child_gens
1983            .iter_mut()
1984            .map(|genn| genn.generate(length, rng))
1985            .collect::<Result<Vec<_>, ArrowError>>()?;
1986        let struct_arr = StructArray::new(self.fields.clone(), child_arrays, None);
1987        Ok(Arc::new(struct_arr))
1988    }
1989
1990    fn data_type(&self) -> &DataType {
1991        &self.data_type
1992    }
1993
1994    fn element_size_bytes(&self) -> Option<ByteCount> {
1995        let mut sum = 0;
1996        for child_gen in &self.child_gens {
1997            sum += child_gen.element_size_bytes()?.0;
1998        }
1999        Some(ByteCount::from(sum))
2000    }
2001}
2002
/// A RecordBatchReader that generates batches of the given size from the given array generators
pub struct FixedSizeBatchGenerator {
    // RNG shared by all column generators; each batch advances it in column order.
    rng: rand_xoshiro::Xoshiro256PlusPlus,
    // One generator per column, in schema order.
    generators: Vec<Box<dyn ArrayGenerator>>,
    // Number of rows in every emitted batch.
    batch_size: RowCount,
    // Batches still to be emitted; iteration stops once this reaches zero.
    num_batches: BatchCount,
    // Schema built from the generators' data types (every field nullable).
    schema: SchemaRef,
}
2011
2012impl FixedSizeBatchGenerator {
2013    fn new(
2014        generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
2015        batch_size: RowCount,
2016        num_batches: BatchCount,
2017        seed: Option<Seed>,
2018        default_null_probability: Option<f64>,
2019    ) -> Self {
2020        let mut fields = Vec::with_capacity(generators.len());
2021        for (field_index, field_gen) in generators.iter().enumerate() {
2022            let (name, genn) = field_gen;
2023            let default_name = format!("field_{}", field_index);
2024            let name = name.clone().unwrap_or(default_name);
2025            let mut field = Field::new(name, genn.data_type().clone(), true);
2026            if let Some(metadata) = genn.metadata() {
2027                field = field.with_metadata(metadata);
2028            }
2029            fields.push(field);
2030        }
2031        let mut generators = generators
2032            .into_iter()
2033            .map(|(_, genn)| genn)
2034            .collect::<Vec<_>>();
2035        if let Some(null_probability) = default_null_probability {
2036            generators = generators
2037                .into_iter()
2038                .map(|genn| genn.with_random_nulls(null_probability))
2039                .collect();
2040        }
2041        let schema = Arc::new(Schema::new(fields));
2042        Self {
2043            rng: rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
2044                seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
2045            ),
2046            generators,
2047            batch_size,
2048            num_batches,
2049            schema,
2050        }
2051    }
2052
2053    fn gen_next(&mut self) -> Result<RecordBatch, ArrowError> {
2054        let mut arrays = Vec::with_capacity(self.generators.len());
2055        for genn in self.generators.iter_mut() {
2056            let arr = genn.generate(self.batch_size, &mut self.rng)?;
2057            arrays.push(arr);
2058        }
2059        self.num_batches.0 -= 1;
2060        Ok(RecordBatch::try_new_with_options(
2061            self.schema.clone(),
2062            arrays,
2063            &RecordBatchOptions::new().with_row_count(Some(self.batch_size.0 as usize)),
2064        )
2065        .unwrap())
2066    }
2067}
2068
2069impl Iterator for FixedSizeBatchGenerator {
2070    type Item = Result<RecordBatch, ArrowError>;
2071
2072    fn next(&mut self) -> Option<Self::Item> {
2073        if self.num_batches.0 == 0 {
2074            return None;
2075        }
2076        Some(self.gen_next())
2077    }
2078}
2079
2080impl RecordBatchReader for FixedSizeBatchGenerator {
2081    fn schema(&self) -> SchemaRef {
2082        self.schema.clone()
2083    }
2084}
2085
/// A builder to create a record batch reader with generated data
///
/// This type is meant to be used in a fluent builder style to define the schema and generators
/// for a record batch reader.
#[derive(Default)]
pub struct BatchGeneratorBuilder {
    // Columns in insertion order; a `None` name becomes `field_<index>` when the reader is built.
    generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
    // When set, every column is wrapped with random nulls at this probability.
    default_null_probability: Option<f64>,
    // RNG seed; `None` falls back to `DEFAULT_SEED`.
    seed: Option<Seed>,
}
2096
/// How [`BatchGeneratorBuilder::into_reader_bytes`] turns a byte budget into a row count when
/// the budget is not an exact multiple of the row size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingBehavior {
    /// Return an error unless the byte budget is an exact multiple of the row size.
    ExactOrErr,
    /// Add one more row, so each batch is at least the requested size.
    RoundUp,
    /// Use the largest row count that fits within the requested size.
    RoundDown,
}
2102
2103impl BatchGeneratorBuilder {
2104    /// Create a new BatchGeneratorBuilder with a default random seed
2105    pub fn new() -> Self {
2106        Default::default()
2107    }
2108
2109    /// Create a new BatchGeneratorBuilder with the given seed
2110    pub fn new_with_seed(seed: Seed) -> Self {
2111        Self {
2112            seed: Some(seed),
2113            ..Default::default()
2114        }
2115    }
2116
2117    /// Adds a new column to the generator
2118    ///
2119    /// See [`crate::generator::array`] for methods to create generators
2120    pub fn col(mut self, name: impl Into<String>, genn: Box<dyn ArrayGenerator>) -> Self {
2121        self.generators.push((Some(name.into()), genn));
2122        self
2123    }
2124
2125    /// Adds a new column to the generator with a generated unique name
2126    ///
2127    /// See [`crate::generator::array`] for methods to create generators
2128    pub fn anon_col(mut self, genn: Box<dyn ArrayGenerator>) -> Self {
2129        self.generators.push((None, genn));
2130        self
2131    }
2132
2133    pub fn into_batch_rows(self, batch_size: RowCount) -> Result<RecordBatch, ArrowError> {
2134        let mut reader = self.into_reader_rows(batch_size, BatchCount::from(1));
2135        reader
2136            .next()
2137            .expect("Asked for 1 batch but reader was empty")
2138    }
2139
2140    pub fn into_batch_bytes(
2141        self,
2142        batch_size: ByteCount,
2143        rounding: RoundingBehavior,
2144    ) -> Result<RecordBatch, ArrowError> {
2145        let mut reader = self.into_reader_bytes(batch_size, BatchCount::from(1), rounding)?;
2146        reader
2147            .next()
2148            .expect("Asked for 1 batch but reader was empty")
2149    }
2150
2151    /// Create a RecordBatchReader that generates batches of the given size (in rows)
2152    pub fn into_reader_rows(
2153        self,
2154        batch_size: RowCount,
2155        num_batches: BatchCount,
2156    ) -> impl RecordBatchReader {
2157        FixedSizeBatchGenerator::new(
2158            self.generators,
2159            batch_size,
2160            num_batches,
2161            self.seed,
2162            self.default_null_probability,
2163        )
2164    }
2165
2166    pub fn into_reader_stream(
2167        self,
2168        batch_size: RowCount,
2169        num_batches: BatchCount,
2170    ) -> (
2171        BoxStream<'static, Result<RecordBatch, ArrowError>>,
2172        Arc<Schema>,
2173    ) {
2174        // TODO: this is pretty lazy and could be optimized
2175        let reader = self.into_reader_rows(batch_size, num_batches);
2176        let schema = reader.schema();
2177        let batches = reader.collect::<Vec<_>>();
2178        (futures::stream::iter(batches).boxed(), schema)
2179    }
2180
2181    /// Create a RecordBatchReader that generates batches of the given size (in bytes)
2182    pub fn into_reader_bytes(
2183        self,
2184        batch_size_bytes: ByteCount,
2185        num_batches: BatchCount,
2186        rounding: RoundingBehavior,
2187    ) -> Result<impl RecordBatchReader, ArrowError> {
2188        let bytes_per_row = self
2189            .generators
2190            .iter()
2191            .map(|genn| genn.1.element_size_bytes().map(|byte_count| byte_count.0).ok_or(
2192                        ArrowError::NotYetImplemented("The function into_reader_bytes currently requires each array generator to have a fixed element size".to_string())
2193                )
2194            )
2195            .sum::<Result<u64, ArrowError>>()?;
2196        let mut num_rows = RowCount::from(batch_size_bytes.0 / bytes_per_row);
2197        if !batch_size_bytes.0.is_multiple_of(bytes_per_row) {
2198            match rounding {
2199                RoundingBehavior::ExactOrErr => {
2200                    return Err(ArrowError::NotYetImplemented(
2201                        format!("Exact rounding requested but not possible.  Batch size requested {}, row size: {}", batch_size_bytes.0, bytes_per_row))
2202                    );
2203                }
2204                RoundingBehavior::RoundUp => {
2205                    num_rows = RowCount::from(num_rows.0 + 1);
2206                }
2207                RoundingBehavior::RoundDown => (),
2208            }
2209        }
2210        Ok(self.into_reader_rows(num_rows, num_batches))
2211    }
2212
2213    /// Set the seed for the generator
2214    pub fn with_seed(mut self, seed: Seed) -> Self {
2215        self.seed = Some(seed);
2216        self
2217    }
2218
2219    /// Adds nulls (with the given probability) to all columns
2220    pub fn with_random_nulls(&mut self, default_null_probability: f64) {
2221        self.default_null_probability = Some(default_null_probability);
2222    }
2223}
2224
/// Factory for creating a single random array
pub struct ArrayGeneratorBuilder {
    // Generator invoked once by `into_array_rows`.
    generator: Box<dyn ArrayGenerator>,
    // RNG seed; `None` falls back to `DEFAULT_SEED`.
    seed: Option<Seed>,
}
2230
2231impl ArrayGeneratorBuilder {
2232    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
2233        Self {
2234            generator,
2235            seed: None,
2236        }
2237    }
2238
2239    /// Use the given seed for the generator
2240    pub fn with_seed(mut self, seed: Seed) -> Self {
2241        self.seed = Some(seed);
2242        self
2243    }
2244
2245    /// Generate a single array with the given length
2246    pub fn into_array_rows(
2247        mut self,
2248        length: RowCount,
2249    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
2250        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
2251            self.seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
2252        );
2253        self.generator.generate(length, &mut rng)
2254    }
2255}
2256
/// Milliseconds in one day (24 h * 60 min * 60 s * 1000 ms), used to convert epoch
/// milliseconds into day counts.
const MS_PER_DAY: i64 = 24 * 60 * 60 * 1_000;
2258
2259pub mod array {
2260
2261    use arrow::datatypes::{Int16Type, Int64Type, Int8Type};
2262    use arrow_array::types::{
2263        Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
2264        DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type,
2265        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
2266    };
2267    use arrow_array::{
2268        ArrowNativeTypeOp, BooleanArray, Date32Array, Date64Array, Time32MillisecondArray,
2269        Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
2270        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
2271        TimestampSecondArray,
2272    };
2273    use arrow_schema::{IntervalUnit, TimeUnit};
2274    use chrono::Utc;
2275    use rand::prelude::Distribution;
2276
2277    use super::*;
2278
2279    /// Create a generator of vectors by continuously calling the given generator
2280    ///
2281    /// For example, given a step generator and a dimension of 3 this will generate vectors like
2282    /// [0, 1, 2], [3, 4, 5], [6, 7, 8], ...
2283    pub fn cycle_vec(
2284        generator: Box<dyn ArrayGenerator>,
2285        dimension: Dimension,
2286    ) -> Box<dyn ArrayGenerator> {
2287        Box::new(CycleVectorGenerator::new(generator, dimension))
2288    }
2289
2290    /// Create a generator of list vectors by continuously calling the given generator
2291    ///
2292    /// The lists will have lengths uniformly distributed between `min_list_size` (inclusive) and
2293    /// `max_list_size` (exclusive).
2294    pub fn cycle_vec_var(
2295        generator: Box<dyn ArrayGenerator>,
2296        min_list_size: Dimension,
2297        max_list_size: Dimension,
2298    ) -> Box<dyn ArrayGenerator> {
2299        Box::new(CycleListGenerator::new(
2300            generator,
2301            min_list_size,
2302            max_list_size,
2303        ))
2304    }
2305
2306    /// Create a generator of vectors around unit circle
2307    ///
2308    /// Vectors will be equally spaced around the unit circle so that there are num_steps
2309    /// vectors per circle.
2310    pub fn cycle_unit_circle(num_steps: u32) -> Box<dyn ArrayGenerator> {
2311        Box::new(RadialStepGenerator::new(num_steps))
2312    }
2313
2314    /// Create a generator of vectors by cycling through a given set of vectors
2315    ///
2316    /// Each value will be spaced in slightly away from the previous value on a ball of radius jitter
2317    pub fn jitter_centroids(centroids: Arc<dyn Array>, jitter: f32) -> Box<dyn ArrayGenerator> {
2318        Box::new(JitterCentroidsGenerator::try_new(centroids, jitter).unwrap())
2319    }
2320
2321    /// Create a generator from a vector of values
2322    ///
2323    /// If more rows are requested than the length of values then it will restart
2324    /// from the beginning of the vector.
2325    pub fn cycle<DataType>(values: Vec<DataType::Native>) -> Box<dyn ArrayGenerator>
2326    where
2327        DataType::Native: Copy + 'static,
2328        DataType: ArrowPrimitiveType,
2329        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2330    {
2331        let mut values_idx = 0;
2332        Box::new(
2333            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2334                DataType::DATA_TYPE,
2335                move |_| {
2336                    let y = values[values_idx];
2337                    values_idx = (values_idx + 1) % values.len();
2338                    y
2339                },
2340                1,
2341                DataType::DATA_TYPE
2342                    .primitive_width()
2343                    .map(|width| ByteCount::from(width as u64))
2344                    .expect("Primitive types should have a fixed width"),
2345            ),
2346        )
2347    }
2348
2349    /// Create a generator from a vector of booleans
2350    ///
2351    /// If more rows are requested than the length of values then it will restart from
2352    /// the beginning of the vector
2353    pub fn cycle_bool(values: Vec<bool>) -> Box<dyn ArrayGenerator> {
2354        let mut values_idx = 0;
2355        Box::new(FnGen::<bool, BooleanArray, _>::new_unknown_size(
2356            DataType::Boolean,
2357            move |_| {
2358                let val = values[values_idx];
2359                values_idx = (values_idx + 1) % values.len();
2360                val
2361            },
2362            1,
2363        ))
2364    }
2365
2366    /// Create a generator that starts at 0 and increments by 1 for each element
2367    pub fn step<DataType>() -> Box<dyn ArrayGenerator>
2368    where
2369        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
2370        DataType: ArrowPrimitiveType,
2371        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2372    {
2373        let mut x = DataType::Native::default();
2374        Box::new(
2375            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2376                DataType::DATA_TYPE,
2377                move |_| {
2378                    let y = x;
2379                    x += DataType::Native::ONE;
2380                    y
2381                },
2382                1,
2383                DataType::DATA_TYPE
2384                    .primitive_width()
2385                    .map(|width| ByteCount::from(width as u64))
2386                    .expect("Primitive types should have a fixed width"),
2387            ),
2388        )
2389    }
2390
2391    pub fn blob() -> Box<dyn ArrayGenerator> {
2392        let mut blob_meta = HashMap::new();
2393        blob_meta.insert("lance-encoding:blob".to_string(), "true".to_string());
2394        rand_fixedbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta)
2395    }
2396
2397    /// Create a generator that starts at a given value and increments by a given step for each element
2398    pub fn step_custom<DataType>(
2399        start: DataType::Native,
2400        step: DataType::Native,
2401    ) -> Box<dyn ArrayGenerator>
2402    where
2403        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
2404        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2405        DataType: ArrowPrimitiveType,
2406    {
2407        let mut x = start;
2408        Box::new(
2409            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2410                DataType::DATA_TYPE,
2411                move |_| {
2412                    let y = x;
2413                    x += step;
2414                    y
2415                },
2416                1,
2417                DataType::DATA_TYPE
2418                    .primitive_width()
2419                    .map(|width| ByteCount::from(width as u64))
2420                    .expect("Primitive types should have a fixed width"),
2421            ),
2422        )
2423    }
2424
2425    /// Create a generator that fills each element with the given primitive value
2426    pub fn fill<DataType>(value: DataType::Native) -> Box<dyn ArrayGenerator>
2427    where
2428        DataType::Native: Copy + 'static,
2429        DataType: ArrowPrimitiveType,
2430        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2431    {
2432        Box::new(
2433            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2434                DataType::DATA_TYPE,
2435                move |_| value,
2436                1,
2437                DataType::DATA_TYPE
2438                    .primitive_width()
2439                    .map(|width| ByteCount::from(width as u64))
2440                    .expect("Primitive types should have a fixed width"),
2441            ),
2442        )
2443    }
2444
2445    /// Create a generator that fills each element with the given binary value
2446    pub fn fill_varbin(value: Vec<u8>) -> Box<dyn ArrayGenerator> {
2447        Box::new(FixedBinaryGenerator::<BinaryType>::new(value))
2448    }
2449
2450    /// Create a generator that fills each element with the given string value
2451    pub fn fill_utf8(value: String) -> Box<dyn ArrayGenerator> {
2452        Box::new(FixedBinaryGenerator::<Utf8Type>::new(value.into_bytes()))
2453    }
2454
2455    pub fn cycle_utf8_literals(values: &[&'static str]) -> Box<dyn ArrayGenerator> {
2456        Box::new(CycleBinaryGenerator::<Utf8Type>::from_strings(values))
2457    }
2458
2459    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
2460    pub fn rand<DataType>() -> Box<dyn ArrayGenerator>
2461    where
2462        DataType::Native: Copy + 'static,
2463        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2464        DataType: ArrowPrimitiveType,
2465        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2466    {
2467        Box::new(
2468            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2469                DataType::DATA_TYPE,
2470                move |rng| rng.random(),
2471                1,
2472                DataType::DATA_TYPE
2473                    .primitive_width()
2474                    .map(|width| ByteCount::from(width as u64))
2475                    .expect("Primitive types should have a fixed width"),
2476            ),
2477        )
2478    }
2479
2480    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
2481    pub fn rand_with_distribution<
2482        DataType,
2483        Dist: rand::distr::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
2484    >(
2485        dist: Dist,
2486    ) -> Box<dyn ArrayGenerator>
2487    where
2488        DataType::Native: Copy + 'static,
2489        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2490        DataType: ArrowPrimitiveType,
2491    {
2492        Box::new(
2493            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2494                DataType::DATA_TYPE,
2495                move |rng| rng.sample(dist.clone()),
2496                1,
2497                DataType::DATA_TYPE
2498                    .primitive_width()
2499                    .map(|width| ByteCount::from(width as u64))
2500                    .expect("Primitive types should have a fixed width"),
2501            ),
2502        )
2503    }
2504
2505    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled primitive values
2506    pub fn rand_vec<DataType>(dimension: Dimension) -> Box<dyn ArrayGenerator>
2507    where
2508        DataType::Native: Copy + 'static,
2509        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2510        DataType: ArrowPrimitiveType,
2511        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2512    {
2513        let underlying = rand::<DataType>();
2514        cycle_vec(underlying, dimension)
2515    }
2516
2517    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled nullable values
2518    pub fn rand_vec_nullable<DataType>(
2519        dimension: Dimension,
2520        null_probability: f64,
2521    ) -> Box<dyn ArrayGenerator>
2522    where
2523        DataType::Native: Copy + 'static,
2524        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2525        DataType: ArrowPrimitiveType,
2526        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2527    {
2528        let underlying = rand::<DataType>().with_random_nulls(null_probability);
2529        cycle_vec(underlying, dimension)
2530    }
2531
2532    /// Create a generator of randomly sampled time32 values covering the entire
2533    /// range of 1 day
2534    pub fn rand_time32(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
2535        let start = 0;
2536        let end = match resolution {
2537            TimeUnit::Second => 86_400,
2538            TimeUnit::Millisecond => 86_400_000,
2539            _ => panic!(),
2540        };
2541
2542        let data_type = DataType::Time32(*resolution);
2543        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
2544        let dist = Uniform::new(start, end).unwrap();
2545        let sample_fn = move |rng: &mut _| dist.sample(rng);
2546
2547        match resolution {
2548            TimeUnit::Second => Box::new(FnGen::<i32, Time32SecondArray, _>::new_known_size(
2549                data_type, sample_fn, 1, size,
2550            )),
2551            TimeUnit::Millisecond => {
2552                Box::new(FnGen::<i32, Time32MillisecondArray, _>::new_known_size(
2553                    data_type, sample_fn, 1, size,
2554                ))
2555            }
2556            _ => panic!(),
2557        }
2558    }
2559
2560    /// Create a generator of randomly sampled time64 values covering the entire
2561    /// range of 1 day
2562    pub fn rand_time64(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
2563        let start = 0_i64;
2564        let end: i64 = match resolution {
2565            TimeUnit::Microsecond => 86_400_000,
2566            TimeUnit::Nanosecond => 86_400_000_000,
2567            _ => panic!(),
2568        };
2569
2570        let data_type = DataType::Time64(*resolution);
2571        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
2572        let dist = Uniform::new(start, end).unwrap();
2573        let sample_fn = move |rng: &mut _| dist.sample(rng);
2574
2575        match resolution {
2576            TimeUnit::Microsecond => {
2577                Box::new(FnGen::<i64, Time64MicrosecondArray, _>::new_known_size(
2578                    data_type, sample_fn, 1, size,
2579                ))
2580            }
2581            TimeUnit::Nanosecond => {
2582                Box::new(FnGen::<i64, Time64NanosecondArray, _>::new_known_size(
2583                    data_type, sample_fn, 1, size,
2584                ))
2585            }
2586            _ => panic!(),
2587        }
2588    }
2589
2590    /// Create a generator of random UUIDs, stored as fixed size binary values
2591    ///
2592    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
2593    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
2594    /// for speed.
2595    pub fn rand_pseudo_uuid() -> Box<dyn ArrayGenerator> {
2596        Box::<PseudoUuidGenerator>::default()
2597    }
2598
2599    /// Create a generator of random UUIDs, stored as 32-character strings (hex encoding
2600    /// of the 16-byte binary value)
2601    ///
2602    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
2603    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
2604    /// for speed.
2605    pub fn rand_pseudo_uuid_hex() -> Box<dyn ArrayGenerator> {
2606        Box::<PseudoUuidHexGenerator>::default()
2607    }
2608
2609    pub fn rand_primitive<T: ArrowPrimitiveType + Send + Sync>(
2610        data_type: DataType,
2611    ) -> Box<dyn ArrayGenerator> {
2612        Box::new(RandomBytesGenerator::<T>::new(data_type))
2613    }
2614
2615    pub fn rand_fsb(size: i32) -> Box<dyn ArrayGenerator> {
2616        Box::new(RandomFixedSizeBinaryGenerator::new(size))
2617    }
2618
2619    pub fn rand_interval(unit: IntervalUnit) -> Box<dyn ArrayGenerator> {
2620        Box::new(RandomIntervalGenerator::new(unit))
2621    }
2622
2623    /// Create a generator of randomly sampled date32 values
2624    ///
2625    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2626    /// is a more common use pattern
2627    pub fn rand_date32() -> Box<dyn ArrayGenerator> {
2628        let now = chrono::Utc::now();
2629        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try days");
2630        rand_date32_in_range(one_year_ago, now)
2631    }
2632
2633    /// Create a generator of randomly sampled date32 values in the given range
2634    pub fn rand_date32_in_range(
2635        start: chrono::DateTime<Utc>,
2636        end: chrono::DateTime<Utc>,
2637    ) -> Box<dyn ArrayGenerator> {
2638        let data_type = DataType::Date32;
2639        let end_ms = end.timestamp_millis();
2640        let end_days = (end_ms / MS_PER_DAY) as i32;
2641        let start_ms = start.timestamp_millis();
2642        let start_days = (start_ms / MS_PER_DAY) as i32;
2643        let dist = Uniform::new(start_days, end_days).unwrap();
2644
2645        Box::new(FnGen::<i32, Date32Array, _>::new_known_size(
2646            data_type,
2647            move |rng| dist.sample(rng),
2648            1,
2649            DataType::Date32
2650                .primitive_width()
2651                .map(|width| ByteCount::from(width as u64))
2652                .expect("Date32 should have a fixed width"),
2653        ))
2654    }
2655
2656    /// Create a generator of randomly sampled date64 values
2657    ///
2658    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2659    /// is a more common use pattern
2660    pub fn rand_date64() -> Box<dyn ArrayGenerator> {
2661        let now = chrono::Utc::now();
2662        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try_days");
2663        rand_date64_in_range(one_year_ago, now)
2664    }
2665
2666    /// Create a generator of randomly sampled timestamp values in the given range
2667    ///
2668    /// Currently just samples the entire range of u64 values and casts to timestamp
2669    pub fn rand_timestamp_in_range(
2670        start: chrono::DateTime<Utc>,
2671        end: chrono::DateTime<Utc>,
2672        data_type: &DataType,
2673    ) -> Box<dyn ArrayGenerator> {
2674        let end_ms = end.timestamp_millis();
2675        let start_ms = start.timestamp_millis();
2676        let (start_ticks, end_ticks) = match data_type {
2677            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
2678                (start_ms * 1000 * 1000, end_ms * 1000 * 1000)
2679            }
2680            DataType::Timestamp(TimeUnit::Microsecond, _) => (start_ms * 1000, end_ms * 1000),
2681            DataType::Timestamp(TimeUnit::Millisecond, _) => (start_ms, end_ms),
2682            DataType::Timestamp(TimeUnit::Second, _) => (start.timestamp(), end.timestamp()),
2683            _ => panic!(),
2684        };
2685        let dist = Uniform::new(start_ticks, end_ticks).unwrap();
2686
2687        let data_type = data_type.clone();
2688        let sample_fn = move |rng: &mut _| dist.sample(rng);
2689        let width = data_type
2690            .primitive_width()
2691            .map(|width| ByteCount::from(width as u64))
2692            .unwrap();
2693
2694        match data_type {
2695            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
2696                Box::new(FnGen::<i64, TimestampNanosecondArray, _>::new_known_size(
2697                    data_type, sample_fn, 1, width,
2698                ))
2699            }
2700            DataType::Timestamp(TimeUnit::Microsecond, _) => {
2701                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
2702                    data_type, sample_fn, 1, width,
2703                ))
2704            }
2705            DataType::Timestamp(TimeUnit::Millisecond, _) => {
2706                Box::new(FnGen::<i64, TimestampMillisecondArray, _>::new_known_size(
2707                    data_type, sample_fn, 1, width,
2708                ))
2709            }
2710            DataType::Timestamp(TimeUnit::Second, _) => {
2711                Box::new(FnGen::<i64, TimestampSecondArray, _>::new_known_size(
2712                    data_type, sample_fn, 1, width,
2713                ))
2714            }
2715            _ => panic!(),
2716        }
2717    }
2718
2719    pub fn rand_timestamp(data_type: &DataType) -> Box<dyn ArrayGenerator> {
2720        let now = chrono::Utc::now();
2721        let one_year_ago = now - chrono::Duration::try_days(365).unwrap();
2722        rand_timestamp_in_range(one_year_ago, now, data_type)
2723    }
2724
2725    /// Create a generator of randomly sampled date64 values
2726    ///
2727    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2728    /// is a more common use pattern
2729    pub fn rand_date64_in_range(
2730        start: chrono::DateTime<Utc>,
2731        end: chrono::DateTime<Utc>,
2732    ) -> Box<dyn ArrayGenerator> {
2733        let data_type = DataType::Date64;
2734        let end_ms = end.timestamp_millis();
2735        let end_days = end_ms / MS_PER_DAY;
2736        let start_ms = start.timestamp_millis();
2737        let start_days = start_ms / MS_PER_DAY;
2738        let dist = Uniform::new(start_days, end_days).unwrap();
2739
2740        Box::new(FnGen::<i64, Date64Array, _>::new_known_size(
2741            data_type,
2742            move |rng| (dist.sample(rng)) * MS_PER_DAY,
2743            1,
2744            DataType::Date64
2745                .primitive_width()
2746                .map(|width| ByteCount::from(width as u64))
2747                .expect("Date64 should have a fixed width"),
2748        ))
2749    }
2750
2751    /// Create a generator of random binary values where each value has a fixed number of bytes
2752    pub fn rand_fixedbin(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
2753        Box::new(RandomBinaryGenerator::new(
2754            bytes_per_element,
2755            false,
2756            is_large,
2757        ))
2758    }
2759
2760    /// Create a generator of random binary values where each value has a variable number of bytes
2761    ///
2762    /// The number of bytes per element will be randomly sampled from the given (inclusive) range
2763    pub fn rand_varbin(
2764        min_bytes_per_element: ByteCount,
2765        max_bytes_per_element: ByteCount,
2766    ) -> Box<dyn ArrayGenerator> {
2767        Box::new(VariableRandomBinaryGenerator::new(
2768            min_bytes_per_element,
2769            max_bytes_per_element,
2770        ))
2771    }
2772
2773    /// Create a generator of random strings
2774    ///
2775    /// All strings will consist entirely of printable ASCII characters
2776    pub fn rand_utf8(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
2777        Box::new(RandomBinaryGenerator::new(
2778            bytes_per_element,
2779            true,
2780            is_large,
2781        ))
2782    }
2783
2784    /// Creates a generator of strings with a prefix and a counter
2785    ///
2786    /// For example, if the prefix is "user_" the strings will be "user_0", "user_1", ...
2787    pub fn utf8_prefix_plus_counter(
2788        prefix: impl Into<String>,
2789        is_large: bool,
2790    ) -> Box<dyn ArrayGenerator> {
2791        Box::new(PrefixPlusCounterGenerator::new(prefix.into(), is_large))
2792    }
2793
2794    pub fn binary_prefix_plus_counter(
2795        prefix: Arc<[u8]>,
2796        is_large: bool,
2797    ) -> Box<dyn ArrayGenerator> {
2798        Box::new(BinaryPrefixPlusCounterGenerator::new(prefix, is_large))
2799    }
2800
2801    /// Create a random generator of boolean values
2802    pub fn rand_boolean() -> Box<dyn ArrayGenerator> {
2803        Box::<RandomBooleanGenerator>::default()
2804    }
2805
2806    /// Create a generator of random sentences
2807    ///
2808    /// Generates strings containing between min_words and max_words random English words joined by spaces
2809    pub fn random_sentence(
2810        min_words: usize,
2811        max_words: usize,
2812        is_large: bool,
2813    ) -> Box<dyn ArrayGenerator> {
2814        Box::new(RandomSentenceGenerator::new(min_words, max_words, is_large))
2815    }
2816
2817    /// Create a generator of random words (one word per row)
2818    ///
2819    /// Generates strings containing a single random English word per row
2820    pub fn random_word(is_large: bool) -> Box<dyn ArrayGenerator> {
2821        Box::new(RandomWordGenerator::new(is_large))
2822    }
2823
2824    pub fn rand_list(item_type: &DataType, is_large: bool) -> Box<dyn ArrayGenerator> {
2825        let child_gen = rand_type(item_type);
2826        Box::new(RandomListGenerator::new(child_gen, is_large))
2827    }
2828
2829    pub fn rand_list_any(
2830        item_gen: Box<dyn ArrayGenerator>,
2831        is_large: bool,
2832    ) -> Box<dyn ArrayGenerator> {
2833        Box::new(RandomListGenerator::new(item_gen, is_large))
2834    }
2835
2836    /// Generates random map arrays where each map has 0-4 entries.
2837    pub fn rand_map(key_type: &DataType, value_type: &DataType) -> Box<dyn ArrayGenerator> {
2838        let keys_gen = rand_type(key_type);
2839        let values_gen = rand_type(value_type);
2840        Box::new(RandomMapGenerator::new(keys_gen, values_gen))
2841    }
2842
2843    pub fn rand_struct(fields: Fields) -> Box<dyn ArrayGenerator> {
2844        let child_gens = fields
2845            .iter()
2846            .map(|f| rand_type(f.data_type()))
2847            .collect::<Vec<_>>();
2848        Box::new(RandomStructGenerator::new(fields, child_gens))
2849    }
2850
2851    pub fn null_type() -> Box<dyn ArrayGenerator> {
2852        Box::new(NullArrayGenerator {})
2853    }
2854
2855    /// Create a generator of random values
2856    pub fn rand_type(data_type: &DataType) -> Box<dyn ArrayGenerator> {
2857        match data_type {
2858            DataType::Boolean => rand_boolean(),
2859            DataType::Int8 => rand::<Int8Type>(),
2860            DataType::Int16 => rand::<Int16Type>(),
2861            DataType::Int32 => rand::<Int32Type>(),
2862            DataType::Int64 => rand::<Int64Type>(),
2863            DataType::UInt8 => rand::<UInt8Type>(),
2864            DataType::UInt16 => rand::<UInt16Type>(),
2865            DataType::UInt32 => rand::<UInt32Type>(),
2866            DataType::UInt64 => rand::<UInt64Type>(),
2867            DataType::Float16 => rand_primitive::<Float16Type>(data_type.clone()),
2868            DataType::Float32 => rand::<Float32Type>(),
2869            DataType::Float64 => rand::<Float64Type>(),
2870            DataType::Decimal128(_, _) => rand_primitive::<Decimal128Type>(data_type.clone()),
2871            DataType::Decimal256(_, _) => rand_primitive::<Decimal256Type>(data_type.clone()),
2872            DataType::Utf8 => rand_utf8(ByteCount::from(12), false),
2873            DataType::LargeUtf8 => rand_utf8(ByteCount::from(12), true),
2874            DataType::Binary => rand_fixedbin(ByteCount::from(12), false),
2875            DataType::LargeBinary => rand_fixedbin(ByteCount::from(12), true),
2876            DataType::Dictionary(key_type, value_type) => {
2877                dict_type(rand_type(value_type), key_type)
2878            }
2879            DataType::FixedSizeList(child, dimension) => cycle_vec(
2880                rand_type(child.data_type()),
2881                Dimension::from(*dimension as u32),
2882            ),
2883            DataType::FixedSizeBinary(size) => rand_fsb(*size),
2884            DataType::List(child) => rand_list(child.data_type(), false),
2885            DataType::LargeList(child) => rand_list(child.data_type(), true),
2886            DataType::Map(entries_field, _) => {
2887                let DataType::Struct(fields) = entries_field.data_type() else {
2888                    panic!("Map entries field must be a struct");
2889                };
2890                let key_type = fields[0].data_type();
2891                let value_type = fields[1].data_type();
2892                rand_map(key_type, value_type)
2893            }
2894            DataType::Duration(unit) => match unit {
2895                TimeUnit::Second => rand::<DurationSecondType>(),
2896                TimeUnit::Millisecond => rand::<DurationMillisecondType>(),
2897                TimeUnit::Microsecond => rand::<DurationMicrosecondType>(),
2898                TimeUnit::Nanosecond => rand::<DurationNanosecondType>(),
2899            },
2900            DataType::Interval(unit) => rand_interval(*unit),
2901            DataType::Date32 => rand_date32(),
2902            DataType::Date64 => rand_date64(),
2903            DataType::Time32(resolution) => rand_time32(resolution),
2904            DataType::Time64(resolution) => rand_time64(resolution),
2905            DataType::Timestamp(_, _) => rand_timestamp(data_type),
2906            DataType::Struct(fields) => rand_struct(fields.clone()),
2907            DataType::Null => null_type(),
2908            _ => unimplemented!("random generation of {}", data_type),
2909        }
2910    }
2911
2912    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2913    ///
2914    /// Note that this may not be very realistic if the underlying generator is something like a random
2915    /// generator since most of the underlying values will be unique and the common case for dictionary
2916    /// encoding is when there is a small set of possible values.
2917    pub fn dict<K: ArrowDictionaryKeyType + Send + Sync>(
2918        generator: Box<dyn ArrayGenerator>,
2919    ) -> Box<dyn ArrayGenerator> {
2920        Box::new(DictionaryGenerator::<K>::new(generator))
2921    }
2922
2923    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2924    pub fn dict_type(
2925        generator: Box<dyn ArrayGenerator>,
2926        key_type: &DataType,
2927    ) -> Box<dyn ArrayGenerator> {
2928        match key_type {
2929            DataType::Int8 => dict::<Int8Type>(generator),
2930            DataType::Int16 => dict::<Int16Type>(generator),
2931            DataType::Int32 => dict::<Int32Type>(generator),
2932            DataType::Int64 => dict::<Int64Type>(generator),
2933            DataType::UInt8 => dict::<UInt8Type>(generator),
2934            DataType::UInt16 => dict::<UInt16Type>(generator),
2935            DataType::UInt32 => dict::<UInt32Type>(generator),
2936            DataType::UInt64 => dict::<UInt64Type>(generator),
2937            _ => unimplemented!(),
2938        }
2939    }
2940
2941    /// Wraps a generator to produce low-cardinality data.
2942    ///
2943    /// Generates `cardinality` unique values on first call, then randomly
2944    /// selects from them for all subsequent rows.
2945    pub fn low_cardinality(
2946        generator: Box<dyn ArrayGenerator>,
2947        cardinality: usize,
2948    ) -> Box<dyn ArrayGenerator> {
2949        Box::new(LowCardinalityGenerator::new(generator, cardinality))
2950    }
2951}
2952
2953/// Create a BatchGeneratorBuilder to start generating batch data
2954pub fn gen_batch() -> BatchGeneratorBuilder {
2955    BatchGeneratorBuilder::default()
2956}
2957
2958/// Create an ArrayGeneratorBuilder to start generating array data
2959pub fn gen_array(genn: Box<dyn ArrayGenerator>) -> ArrayGeneratorBuilder {
2960    ArrayGeneratorBuilder::new(genn)
2961}
2962
/// Metadata key to specify content type for string generation.
/// Set to "sentence" to use the sentence generator with Zipf distribution.
///
/// Read by [`rand_field`]: the hint is only honored for `Utf8` and `LargeUtf8` fields,
/// where it produces values of 1 to 10 words.  Any other value, or any other data type,
/// falls back to the default random generator for the field's type.
pub const CONTENT_TYPE_KEY: &str = "lance-datagen:content-type";

/// Metadata key to specify cardinality for low-cardinality data generation.
/// Set to a numeric string (e.g., "100") to limit unique values.
///
/// Read by [`rand_field`]: values that do not parse as a `usize`, or that parse to `0`,
/// are silently ignored and the field keeps its unrestricted generator.
pub const CARDINALITY_KEY: &str = "lance-datagen:cardinality";
2970
2971/// Create a generator for a field, checking metadata for content type hints.
2972///
2973/// Supported metadata keys:
2974/// - `lance-datagen:content-type`: Set to "sentence" for Utf8/LargeUtf8 fields
2975///   to use the sentence generator with Zipf distribution.
2976/// - `lance-datagen:cardinality`: Set to a number to limit unique values.
2977///   The generator will produce only that many unique values and randomly
2978///   select from them.
2979pub fn rand_field(field: &Field) -> Box<dyn ArrayGenerator> {
2980    let mut generator = if let Some(content_type) = field.metadata().get(CONTENT_TYPE_KEY) {
2981        match (content_type.as_str(), field.data_type()) {
2982            ("sentence", DataType::Utf8) => array::random_sentence(1, 10, false),
2983            ("sentence", DataType::LargeUtf8) => array::random_sentence(1, 10, true),
2984            _ => array::rand_type(field.data_type()),
2985        }
2986    } else {
2987        array::rand_type(field.data_type())
2988    };
2989
2990    if let Some(cardinality_str) = field.metadata().get(CARDINALITY_KEY) {
2991        if let Ok(cardinality) = cardinality_str.parse::<usize>() {
2992            if cardinality > 0 {
2993                generator = array::low_cardinality(generator, cardinality);
2994            }
2995        }
2996    }
2997
2998    generator
2999}
3000
3001/// Create a BatchGeneratorBuilder with the given schema
3002///
3003/// You can add more columns or convert this into a reader immediately.
3004///
3005/// Supported field metadata:
3006/// - `lance-datagen:content-type` = `"sentence"`: Use sentence generator with
3007///   Zipf distribution for more realistic text (Utf8/LargeUtf8 only).
3008/// - `lance-datagen:cardinality` = `"<number>"`: Limit to N unique values.
3009pub fn rand(schema: &Schema) -> BatchGeneratorBuilder {
3010    let mut builder = BatchGeneratorBuilder::default();
3011    for field in schema.fields() {
3012        builder = builder.col(field.name(), rand_field(field));
3013    }
3014    builder
3015}
3016
#[cfg(test)]
mod tests {

    use arrow::datatypes::{Float32Type, Int16Type, Int8Type, UInt32Type};
    use arrow_array::{BooleanArray, Float32Array, Int16Array, Int32Array, Int8Array, UInt32Array};

    use super::*;

    #[test]
    fn test_step() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::step::<Int32Type>();
        assert_eq!(
            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([0, 1, 2, 3, 4])
        );
        assert_eq!(
            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([5, 6, 7, 8, 9])
        );

        let mut genn = array::step::<Int8Type>();
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Int8Array::from_iter([0, 1, 2])
        );

        let mut genn = array::step::<Float32Type>();
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Float32Array::from_iter([0.0, 1.0, 2.0])
        );

        let mut genn = array::step_custom::<Int16Type>(4, 8);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Int16Array::from_iter([4, 12, 20])
        );
        assert_eq!(
            *genn.generate(RowCount::from(2), &mut rng).unwrap(),
            Int16Array::from_iter([28, 36])
        );
    }

    #[test]
    fn test_cycle() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::cycle::<Int32Type>(vec![1, 2, 3]);
        assert_eq!(
            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
            Int32Array::from_iter([1, 2, 3, 1, 2])
        );

        let mut genn = array::cycle_utf8_literals(&["abc", "def", "xyz"]);
        assert_eq!(
            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
            StringArray::from_iter_values(["abc", "def", "xyz", "abc", "def"])
        );
        assert_eq!(
            *genn.generate(RowCount::from(1), &mut rng).unwrap(),
            StringArray::from_iter_values(["xyz"])
        );

        let mut genn = array::cycle_bool(vec![false, false, true]);
        assert_eq!(
            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
            BooleanArray::from_iter(vec![false, false, true, false, false].into_iter().map(Some))
        );
        assert_eq!(
            *genn.generate(RowCount::from(1), &mut rng).unwrap(),
            BooleanArray::from_iter(vec![Some(true)])
        )
    }

    #[test]
    fn test_fill() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::fill::<Int32Type>(42);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([42, 42, 42])
        );

        let mut genn = array::fill_varbin(vec![0, 1, 2]);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                "\x00\x01\x02",
                "\x00\x01\x02",
                "\x00\x01\x02"
            ])
        );

        let mut genn = array::fill_utf8("xyz".to_string());
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values(["xyz", "xyz", "xyz"])
        );
    }

    #[test]
    fn test_utf8_prefix_plus_counter() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::utf8_prefix_plus_counter("user_", false);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values(["user_0", "user_1", "user_2"])
        );

        let mut genn = array::utf8_prefix_plus_counter("user_", true);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::LargeStringArray::from_iter_values(["user_0", "user_1", "user_2"])
        );
    }

    #[test]
    fn test_rng() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::rand::<Int32Type>();
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            Int32Array::from_iter([-797553329, 1369325940, -69174021])
        );

        let mut genn = array::rand_fixedbin(ByteCount::from(3), false);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                [184, 53, 216],
                [12, 96, 159],
                [125, 179, 56]
            ])
        );

        let mut genn = array::rand_utf8(ByteCount::from(3), false);
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::StringArray::from_iter_values([">@p", "n `", "NWa"])
        );

        let mut genn = array::random_sentence(1, 5, false);
        let words = genn.generate(RowCount::from(10), &mut rng).unwrap();
        assert_eq!(words.data_type(), &DataType::Utf8);
        let words_array = words.as_any().downcast_ref::<StringArray>().unwrap();
        // Verify each string contains 1-5 words
        for i in 0..10 {
            let sentence = words_array.value(i);
            let word_count = sentence.split_whitespace().count();
            assert!((1..=5).contains(&word_count));
        }

        let mut genn = array::rand_date32();
        let days_32 = genn.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_32.data_type(), &DataType::Date32);

        let mut genn = array::rand_date64();
        let days_64 = genn.generate(RowCount::from(3), &mut rng).unwrap();
        assert_eq!(days_64.data_type(), &DataType::Date64);

        let mut genn = array::rand_boolean();
        let bools = genn.generate(RowCount::from(1024), &mut rng).unwrap();
        assert_eq!(bools.data_type(), &DataType::Boolean);
        let bools = bools.as_any().downcast_ref::<BooleanArray>().unwrap();
        // Sanity check to ensure we're getting at least some rng
        assert!(bools.false_count() > 100);
        assert!(bools.true_count() > 100);

        let mut genn = array::rand_varbin(ByteCount::from(2), ByteCount::from(4));
        assert_eq!(
            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
            arrow_array::BinaryArray::from_iter_values([
                vec![174, 178],
                vec![64, 122, 207, 248],
                vec![124, 3, 58]
            ])
        );
    }

    #[test]
    fn test_rng_list() {
        // Note: these tests are heavily dependent on the default seed.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::rand_list(&DataType::Int32, false);
        let arr = genn.generate(RowCount::from(100), &mut rng).unwrap();
        // Make sure we can generate empty lists (note, test is dependent on seed)
        let arr = arr.as_list::<i32>();
        assert!(arr.iter().any(|l| l.unwrap().is_empty()));
        // Shouldn't generate any giant lists (don't kill performance in normal datagen).
        // This must hold for *every* list; `any` would be satisfied by the empty list above.
        assert!(arr.iter().all(|l| l.unwrap().len() < 11));
    }

    #[test]
    fn test_rng_distribution() {
        // Sanity test to make sure our RNG is giving us well distributed values.
        // We generate some 4-byte integers, histogram them into 256 buckets by their
        // top byte, and make sure each bucket has a good # of values.
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::rand::<UInt32Type>();
        for _ in 0..10 {
            let arr = genn.generate(RowCount::from(10000), &mut rng).unwrap();
            let int_arr = arr.as_any().downcast_ref::<UInt32Array>().unwrap();
            let mut buckets = vec![0_u32; 256];
            for val in int_arr.values() {
                buckets[(*val >> 24) as usize] += 1;
            }
            for bucket in buckets {
                // Perfectly even distribution would have 10000 / 256 values (~40) per bucket
                // We test for 15 which should be "good enough" and statistically unlikely to fail
                assert!(bucket > 15);
            }
        }
    }

    #[test]
    fn test_nulls() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::rand::<Int32Type>().with_random_nulls(0.3);

        let arr = genn.generate(RowCount::from(1000), &mut rng).unwrap();

        // This assert depends on the default seed
        assert_eq!(arr.null_count(), 297);

        for len in 0..100 {
            let arr = genn.generate(RowCount::from(len), &mut rng).unwrap();
            // Make sure the null count we came up with matches the actual # of unset bits
            assert_eq!(
                arr.null_count(),
                arr.nulls()
                    .map(|nulls| (len as usize)
                        - nulls.buffer().count_set_bits_offset(0, len as usize))
                    .unwrap_or(0)
            );
        }

        let mut genn = array::rand::<Int32Type>().with_random_nulls(0.0);
        let arr = genn.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 0);

        let mut genn = array::rand::<Int32Type>().with_random_nulls(1.0);
        let arr = genn.generate(RowCount::from(10), &mut rng).unwrap();

        assert_eq!(arr.null_count(), 10);
        assert!((0..10).all(|idx| arr.is_null(idx)));

        let mut genn = array::rand::<Int32Type>().with_nulls(&[false, false, true]);
        let arr = genn.generate(RowCount::from(7), &mut rng).unwrap();
        assert!((0..2).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(2));
        assert!((3..5).all(|idx| arr.is_valid(idx)));
        assert!(arr.is_null(5));
        assert!(arr.is_valid(6));
    }

    #[test]
    fn test_unit_circle() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut genn = array::cycle_unit_circle(4);
        let arr = genn.generate(RowCount::from(6), &mut rng).unwrap();

        let arr_values = arr
            .as_fixed_size_list()
            .values()
            .as_primitive::<Float32Type>()
            .values()
            .to_vec();
        assert_eq!(arr_values.len(), 12);
        let expected_values = [1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 1.0];
        for (actual, expected) in arr_values.iter().zip(expected_values.iter()) {
            assert!((actual - expected).abs() < 0.0001);
        }
    }

    #[test]
    fn test_jitter_centroids() {
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
        let mut centroids_gen = array::cycle_unit_circle(4);
        let centroids = centroids_gen.generate(RowCount::from(4), &mut rng).unwrap();

        let centroid_values = centroids
            .as_fixed_size_list()
            .values()
            .as_primitive::<Float32Type>()
            .values()
            .to_vec();

        let mut jitter_gen = array::jitter_centroids(centroids, 0.001);
        let jittered = jitter_gen.generate(RowCount::from(100), &mut rng).unwrap();

        let values = jittered
            .as_fixed_size_list()
            .values()
            .as_primitive::<Float32Type>()
            .values()
            .to_vec();

        for i in 0..100 {
            let centroid = i % 4;
            let centroid_x = centroid_values[centroid * 2];
            let centroid_y = centroid_values[centroid * 2 + 1];
            let value_x = values[i * 2];
            let value_y = values[i * 2 + 1];

            let l2_dist = ((value_x - centroid_x).powi(2) + (value_y - centroid_y).powi(2)).sqrt();
            assert!(l2_dist < 0.001001);
            assert!(l2_dist > 0.000999);
        }
    }

    #[test]
    fn test_rand_schema() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
        ]);
        let rbr = rand(&schema)
            .into_reader_bytes(
                ByteCount::from(1024 * 1024),
                BatchCount::from(8),
                RoundingBehavior::ExactOrErr,
            )
            .unwrap();
        assert_eq!(*rbr.schema(), schema);

        let batches = rbr.map(|val| val.unwrap()).collect::<Vec<_>>();
        assert_eq!(batches.len(), 8);

        for batch in batches {
            assert_eq!(batch.num_rows(), 1024 * 1024 / 32);
            assert_eq!(batch.num_columns(), 5);
        }
    }
}