Skip to main content

lance_datagen/
generator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, iter, marker::PhantomData, sync::Arc, sync::LazyLock};
5
6use arrow::{
7    array::{ArrayData, AsArray, Float32Builder, GenericBinaryBuilder, GenericStringBuilder},
8    buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer},
9    datatypes::{
10        ArrowPrimitiveType, Float32Type, Int32Type, Int64Type, IntervalDayTime,
11        IntervalMonthDayNano, UInt32Type,
12    },
13};
14use arrow_array::{
15    Array, BinaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, LargeListArray,
16    LargeStringArray, ListArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, RecordBatch,
17    RecordBatchOptions, RecordBatchReader, StringArray, StructArray, make_array,
18    types::{ArrowDictionaryKeyType, BinaryType, ByteArrayType, Utf8Type},
19};
20use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef};
21use futures::{StreamExt, stream::BoxStream};
22use rand::{Rng, RngCore, SeedableRng, distr::Uniform};
23use rand_distr::Zipf;
24use random_word;
25
26use self::array::rand_with_distribution;
27
/// A number of rows (array elements) to generate.
#[derive(Copy, Clone, Debug, Default)]
pub struct RowCount(u64);

/// A number of batches.
#[derive(Copy, Clone, Debug, Default)]
pub struct BatchCount(u32);

/// A size measured in bytes.
#[derive(Copy, Clone, Debug, Default)]
pub struct ByteCount(u64);

/// A vector dimension or list size (number of values per list).
#[derive(Copy, Clone, Debug, Default)]
pub struct Dimension(u32);

/// Implements `From<inner>` for single-field tuple newtypes by wrapping the value.
macro_rules! impl_newtype_from {
    ($($newtype:ident($inner:ty)),+ $(,)?) => {
        $(
            impl From<$inner> for $newtype {
                fn from(value: $inner) -> Self {
                    $newtype(value)
                }
            }
        )+
    };
}

impl_newtype_from!(
    BatchCount(u32),
    RowCount(u64),
    ByteCount(u64),
    Dimension(u32),
);
60
61/// A trait for anything that can generate arrays of data
62pub trait ArrayGenerator: Send + Sync + std::fmt::Debug {
63    /// Generate an array of the given length
64    ///
65    /// # Arguments
66    ///
67    /// * `length` - The number of elements to generate
68    /// * `rng` - The random number generator to use
69    ///
70    /// # Returns
71    ///
72    /// An array of the given length
73    ///
74    /// Note: Not every generator needs an rng.  However, it is passed here because many do and this
75    /// lets us manage RNGs at the batch level instead of the array level.
76    fn generate(
77        &mut self,
78        length: RowCount,
79        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
80    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError>;
81
82    /// Generate an array of the given length using a new RNG with the default seed
83    ///
84    /// # Arguments
85    ///
86    /// * `length` - The number of elements to generate
87    ///
88    /// # Returns
89    ///
90    /// An array of the given length
91    fn generate_default(
92        &mut self,
93        length: RowCount,
94    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
95        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
96        Self::generate(self, length, &mut rng)
97    }
98    /// Get the data type of the array that this generator produces
99    ///
100    /// # Returns
101    ///
102    /// The data type of the array that this generator produces
103    fn data_type(&self) -> &DataType;
104    /// Gets metadata that should be associated with the field generated by this generator
105    fn metadata(&self) -> Option<HashMap<String, String>> {
106        None
107    }
108    /// Get the size of each element in bytes
109    ///
110    /// # Returns
111    ///
112    /// The size of each element in bytes.  Will be None if the size varies by element.
113    fn element_size_bytes(&self) -> Option<ByteCount>;
114}
115
116#[derive(Debug)]
117pub struct CycleNullGenerator {
118    generator: Box<dyn ArrayGenerator>,
119    validity: Vec<bool>,
120    idx: usize,
121}
122#[derive(Debug)]
123pub struct CycleNanGenerator {
124    generator: Box<dyn ArrayGenerator>,
125    nan_pattern: Vec<bool>,
126    idx: usize,
127}
128
129impl ArrayGenerator for CycleNanGenerator {
130    fn generate(
131        &mut self,
132        length: RowCount,
133        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
134    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
135        let array = self.generator.generate(length, rng)?;
136
137        // Only apply NaN pattern to float types
138        match array.data_type() {
139            DataType::Float16 => {
140                let float_array = array
141                    .as_any()
142                    .downcast_ref::<arrow_array::Float16Array>()
143                    .unwrap();
144                let mut values: Vec<half::f16> = float_array.values().to_vec();
145
146                for (i, &should_be_nan) in self
147                    .nan_pattern
148                    .iter()
149                    .cycle()
150                    .skip(self.idx)
151                    .take(length.0 as usize)
152                    .enumerate()
153                {
154                    if should_be_nan {
155                        values[i] = half::f16::NAN;
156                    }
157                }
158
159                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
160                Ok(Arc::new(arrow_array::Float16Array::from(values)))
161            }
162            DataType::Float32 => {
163                let float_array = array
164                    .as_any()
165                    .downcast_ref::<arrow_array::Float32Array>()
166                    .unwrap();
167                let mut values: Vec<f32> = float_array.values().to_vec();
168
169                for (i, &should_be_nan) in self
170                    .nan_pattern
171                    .iter()
172                    .cycle()
173                    .skip(self.idx)
174                    .take(length.0 as usize)
175                    .enumerate()
176                {
177                    if should_be_nan {
178                        values[i] = f32::NAN;
179                    }
180                }
181
182                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
183                Ok(Arc::new(arrow_array::Float32Array::from(values)))
184            }
185            DataType::Float64 => {
186                let float_array = array
187                    .as_any()
188                    .downcast_ref::<arrow_array::Float64Array>()
189                    .unwrap();
190                let mut values: Vec<f64> = float_array.values().to_vec();
191
192                for (i, &should_be_nan) in self
193                    .nan_pattern
194                    .iter()
195                    .cycle()
196                    .skip(self.idx)
197                    .take(length.0 as usize)
198                    .enumerate()
199                {
200                    if should_be_nan {
201                        values[i] = f64::NAN;
202                    }
203                }
204
205                self.idx = (self.idx + (length.0 as usize)) % self.nan_pattern.len();
206                Ok(Arc::new(arrow_array::Float64Array::from(values)))
207            }
208            _ => {
209                // For non-float types, just return the original array unchanged
210                Ok(array)
211            }
212        }
213    }
214
215    fn data_type(&self) -> &DataType {
216        self.generator.data_type()
217    }
218
219    fn element_size_bytes(&self) -> Option<ByteCount> {
220        self.generator.element_size_bytes()
221    }
222}
223
224impl ArrayGenerator for CycleNullGenerator {
225    fn generate(
226        &mut self,
227        length: RowCount,
228        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
229    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
230        let array = self.generator.generate(length, rng)?;
231        let data = array.to_data();
232        let validity_itr = self
233            .validity
234            .iter()
235            .cycle()
236            .skip(self.idx)
237            .take(length.0 as usize)
238            .copied();
239        let validity_bitmap = BooleanBuffer::from_iter(validity_itr);
240
241        self.idx = (self.idx + (length.0 as usize)) % self.validity.len();
242        unsafe {
243            let new_data = ArrayData::new_unchecked(
244                data.data_type().clone(),
245                data.len(),
246                None,
247                Some(validity_bitmap.into_inner()),
248                data.offset(),
249                data.buffers().to_vec(),
250                data.child_data().into(),
251            );
252            Ok(make_array(new_data))
253        }
254    }
255
256    fn data_type(&self) -> &DataType {
257        self.generator.data_type()
258    }
259
260    fn element_size_bytes(&self) -> Option<ByteCount> {
261        self.generator.element_size_bytes()
262    }
263}
264
265#[derive(Debug)]
266pub struct MetadataGenerator {
267    generator: Box<dyn ArrayGenerator>,
268    metadata: HashMap<String, String>,
269}
270
271impl ArrayGenerator for MetadataGenerator {
272    fn generate(
273        &mut self,
274        length: RowCount,
275        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
276    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
277        self.generator.generate(length, rng)
278    }
279
280    fn metadata(&self) -> Option<HashMap<String, String>> {
281        Some(self.metadata.clone())
282    }
283
284    fn data_type(&self) -> &DataType {
285        self.generator.data_type()
286    }
287
288    fn element_size_bytes(&self) -> Option<ByteCount> {
289        self.generator.element_size_bytes()
290    }
291}
292
293#[derive(Debug)]
294pub struct NullGenerator {
295    generator: Box<dyn ArrayGenerator>,
296    null_probability: f64,
297}
298
299impl ArrayGenerator for NullGenerator {
300    fn generate(
301        &mut self,
302        length: RowCount,
303        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
304    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
305        let array = self.generator.generate(length, rng)?;
306        let data = array.to_data();
307
308        if self.null_probability < 0.0 || self.null_probability > 1.0 {
309            return Err(ArrowError::InvalidArgumentError(format!(
310                "null_probability must be between 0 and 1, got {}",
311                self.null_probability
312            )));
313        }
314
315        let (null_count, new_validity) = if self.null_probability == 0.0 {
316            if data.null_count() == 0 {
317                return Ok(array);
318            } else {
319                (0_usize, None)
320            }
321        } else if self.null_probability == 1.0 {
322            if data.null_count() == data.len() {
323                return Ok(array);
324            } else {
325                let all_nulls = BooleanBuffer::new_unset(array.len());
326                (array.len(), Some(all_nulls.into_inner()))
327            }
328        } else {
329            let array_len = array.len();
330            let num_validity_bytes = array_len.div_ceil(8);
331            let mut null_count = 0;
332            // Sampling the RNG once per bit is kind of slow so we do this to sample once
333            // per byte.  We only get 8 bits of RNG resolution but that should be good enough.
334            let threshold = (self.null_probability * u8::MAX as f64) as u8;
335            let bytes = (0..num_validity_bytes)
336                .map(|byte_idx| {
337                    let mut sample = rng.random::<u64>();
338                    let mut byte: u8 = 0;
339                    for bit_idx in 0..8 {
340                        // We could probably overshoot and fill in extra bits with random data but
341                        // this is cleaner and that would mess up the null count
342                        byte <<= 1;
343                        let pos = byte_idx * 8 + (7 - bit_idx);
344                        if pos < array_len {
345                            let sample_piece = sample & 0xFF;
346                            let is_null = (sample_piece as u8) < threshold;
347                            byte |= (!is_null) as u8;
348                            null_count += is_null as usize;
349                        }
350                        sample >>= 8;
351                    }
352                    byte
353                })
354                .collect::<Vec<_>>();
355            let new_validity = Buffer::from_iter(bytes);
356            (null_count, Some(new_validity))
357        };
358
359        unsafe {
360            let new_data = ArrayData::new_unchecked(
361                data.data_type().clone(),
362                data.len(),
363                Some(null_count),
364                new_validity,
365                data.offset(),
366                data.buffers().to_vec(),
367                data.child_data().into(),
368            );
369            Ok(make_array(new_data))
370        }
371    }
372
373    fn metadata(&self) -> Option<HashMap<String, String>> {
374        self.generator.metadata()
375    }
376
377    fn data_type(&self) -> &DataType {
378        self.generator.data_type()
379    }
380
381    fn element_size_bytes(&self) -> Option<ByteCount> {
382        self.generator.element_size_bytes()
383    }
384}
385
386pub trait ArrayGeneratorExt {
387    /// Replaces the validity bitmap of generated arrays, inserting nulls with a given probability
388    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator>;
389    /// Replaces the validity bitmap of generated arrays with the inverse of `nulls`, cycling if needed
390    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
391    /// Replaces the values of generated arrays with NaN values, cycling if needed
392    ///
393    /// Will have no effect if the data type is not a floating point data type
394    fn with_nans(self, nans: &[bool]) -> Box<dyn ArrayGenerator>;
395    /// Replaces the validity bitmap of generated arrays with `validity`, cycling if needed
396    fn with_validity(self, nulls: &[bool]) -> Box<dyn ArrayGenerator>;
397    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator>;
398}
399
400impl ArrayGeneratorExt for Box<dyn ArrayGenerator> {
401    fn with_random_nulls(self, null_probability: f64) -> Box<dyn ArrayGenerator> {
402        Box::new(NullGenerator {
403            generator: self,
404            null_probability,
405        })
406    }
407
408    fn with_nulls(self, nulls: &[bool]) -> Box<dyn ArrayGenerator> {
409        Box::new(CycleNullGenerator {
410            generator: self,
411            validity: nulls.iter().map(|v| !*v).collect(),
412            idx: 0,
413        })
414    }
415
416    fn with_nans(self, nans: &[bool]) -> Box<dyn ArrayGenerator> {
417        Box::new(CycleNanGenerator {
418            generator: self,
419            nan_pattern: nans.to_vec(),
420            idx: 0,
421        })
422    }
423
424    fn with_validity(self, validity: &[bool]) -> Box<dyn ArrayGenerator> {
425        Box::new(CycleNullGenerator {
426            generator: self,
427            validity: validity.to_vec(),
428            idx: 0,
429        })
430    }
431
432    fn with_metadata(self, metadata: HashMap<String, String>) -> Box<dyn ArrayGenerator> {
433        Box::new(MetadataGenerator {
434            generator: self,
435            metadata,
436        })
437    }
438}
439
/// Iterator adapter that yields every item of `iter` `n` times in a row.
///
/// `cur` is the item currently being repeated. `count` is how many more times it will be
/// yielded before the next item is pulled from `iter`. Seeding `cur`/`count` lets a caller
/// resume a run that was cut off at the end of a previous batch.
///
/// `n` must be at least 1.
pub struct NTimesIter<I: Iterator>
where
    I::Item: Copy,
{
    iter: I,
    n: u32,
    cur: I::Item,
    count: u32,
}

// Note: if this is used then there is a performance hit as the
// inner loop cannot experience vectorization
//
// TODO: maybe faster to build the vec and then repeat it into
// the destination array?
impl<I: Iterator> Iterator for NTimesIter<I>
where
    I::Item: Copy,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if self.count == 0 {
            // Pull the next item before resetting `count`. Otherwise an exhausted inner
            // iterator would leave `count > 0`, and later calls would yield a stale `cur`.
            self.cur = self.iter.next()?;
            self.count = self.n - 1;
        } else {
            self.count -= 1;
        }
        Some(self.cur)
    }

    /// Counts the `count` pending repeats of `cur` plus `n` copies of every item still left in
    /// `iter`, saturating (lower) or giving up (upper) on overflow.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.iter.size_hint();
        let n = self.n as usize;
        let pending = self.count as usize;
        let lower = lower.saturating_mul(n).saturating_add(pending);
        let upper = upper
            .and_then(|u| u.checked_mul(n))
            .and_then(|u| u.checked_add(pending));
        (lower, upper)
    }
}
478
479pub struct FnGen<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T>
480where
481    T: Copy + Default,
482    ArrayType: arrow_array::Array + From<Vec<T>>,
483{
484    data_type: DataType,
485    generator: F,
486    array_type: PhantomData<ArrayType>,
487    repeat: u32,
488    leftover: T,
489    leftover_count: u32,
490    element_size_bytes: Option<ByteCount>,
491}
492
493impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> std::fmt::Debug
494    for FnGen<T, ArrayType, F>
495where
496    T: Copy + Default,
497    ArrayType: arrow_array::Array + From<Vec<T>>,
498{
499    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500        f.debug_struct("FnGen")
501            .field("data_type", &self.data_type)
502            .field("array_type", &self.array_type)
503            .field("repeat", &self.repeat)
504            .field("leftover_count", &self.leftover_count)
505            .field("element_size_bytes", &self.element_size_bytes)
506            .finish()
507    }
508}
509
510impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> FnGen<T, ArrayType, F>
511where
512    T: Copy + Default,
513    ArrayType: arrow_array::Array + From<Vec<T>>,
514{
515    fn new_known_size(
516        data_type: DataType,
517        generator: F,
518        repeat: u32,
519        element_size_bytes: ByteCount,
520    ) -> Self {
521        Self {
522            data_type,
523            generator,
524            array_type: PhantomData,
525            repeat,
526            leftover: T::default(),
527            leftover_count: 0,
528            element_size_bytes: Some(element_size_bytes),
529        }
530    }
531
532    fn new_unknown_size(data_type: DataType, generator: F, repeat: u32) -> Self {
533        Self {
534            data_type,
535            generator,
536            array_type: PhantomData,
537            repeat,
538            leftover: T::default(),
539            leftover_count: 0,
540            element_size_bytes: None,
541        }
542    }
543}
544
545impl<T, ArrayType, F: FnMut(&mut rand_xoshiro::Xoshiro256PlusPlus) -> T> ArrayGenerator
546    for FnGen<T, ArrayType, F>
547where
548    T: Copy + Default + Send + Sync,
549    ArrayType: arrow_array::Array + From<Vec<T>> + 'static,
550    F: Send + Sync,
551{
552    fn generate(
553        &mut self,
554        length: RowCount,
555        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
556    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
557        let iter = (0..length.0).map(|_| (self.generator)(rng));
558        let values = if self.repeat > 1 {
559            Vec::from_iter(
560                NTimesIter {
561                    iter,
562                    n: self.repeat,
563                    cur: self.leftover,
564                    count: self.leftover_count,
565                }
566                .take(length.0 as usize),
567            )
568        } else {
569            Vec::from_iter(iter)
570        };
571        self.leftover_count = ((self.leftover_count as u64 + length.0) % self.repeat as u64) as u32;
572        self.leftover = values.last().copied().unwrap_or(T::default());
573        Ok(Arc::new(ArrayType::from(values)))
574    }
575
576    fn data_type(&self) -> &DataType {
577        &self.data_type
578    }
579
580    fn element_size_bytes(&self) -> Option<ByteCount> {
581        self.element_size_bytes
582    }
583}
584
/// Seed for the random number generators that drive data generation.
#[derive(Copy, Clone, Debug)]
pub struct Seed(pub u64);

/// Seed used when the caller does not provide one, e.g. by
/// [`ArrayGenerator::generate_default`].
pub const DEFAULT_SEED: Seed = Seed(42);

impl From<u64> for Seed {
    fn from(value: u64) -> Self {
        Seed(value)
    }
}
594
595#[derive(Debug)]
596pub struct CycleVectorGenerator {
597    underlying_gen: Box<dyn ArrayGenerator>,
598    dimension: Dimension,
599    data_type: DataType,
600}
601
602impl CycleVectorGenerator {
603    pub fn new(underlying_gen: Box<dyn ArrayGenerator>, dimension: Dimension) -> Self {
604        let data_type = DataType::FixedSizeList(
605            Arc::new(Field::new("item", underlying_gen.data_type().clone(), true)),
606            dimension.0 as i32,
607        );
608        Self {
609            underlying_gen,
610            dimension,
611            data_type,
612        }
613    }
614}
615
616impl ArrayGenerator for CycleVectorGenerator {
617    fn generate(
618        &mut self,
619        length: RowCount,
620        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
621    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
622        let values = self
623            .underlying_gen
624            .generate(RowCount::from(length.0 * self.dimension.0 as u64), rng)?;
625        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
626        let values = Arc::new(values);
627
628        let array = FixedSizeListArray::try_new(field, self.dimension.0 as i32, values, None)?;
629
630        Ok(Arc::new(array))
631    }
632
633    fn data_type(&self) -> &DataType {
634        &self.data_type
635    }
636
637    fn element_size_bytes(&self) -> Option<ByteCount> {
638        self.underlying_gen
639            .element_size_bytes()
640            .map(|byte_count| ByteCount::from(byte_count.0 * self.dimension.0 as u64))
641    }
642}
643
644#[derive(Debug)]
645pub struct CycleListGenerator {
646    underlying_gen: Box<dyn ArrayGenerator>,
647    lengths_gen: Box<dyn ArrayGenerator>,
648    data_type: DataType,
649}
650
651impl CycleListGenerator {
652    pub fn new(
653        underlying_gen: Box<dyn ArrayGenerator>,
654        min_list_size: Dimension,
655        max_list_size: Dimension,
656    ) -> Self {
657        let data_type = DataType::List(Arc::new(Field::new(
658            "item",
659            underlying_gen.data_type().clone(),
660            true,
661        )));
662        let lengths_dist = Uniform::new(min_list_size.0, max_list_size.0).unwrap();
663        let lengths_gen = rand_with_distribution::<UInt32Type, Uniform<u32>>(lengths_dist);
664        Self {
665            underlying_gen,
666            lengths_gen,
667            data_type,
668        }
669    }
670}
671
672impl ArrayGenerator for CycleListGenerator {
673    fn generate(
674        &mut self,
675        length: RowCount,
676        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
677    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
678        let lengths = self.lengths_gen.generate(length, rng)?;
679        let lengths = lengths.as_primitive::<UInt32Type>();
680        let total_length = lengths.values().iter().map(|i| *i as u64).sum::<u64>();
681        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
682        let values = self
683            .underlying_gen
684            .generate(RowCount::from(total_length), rng)?;
685        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
686        let values = Arc::new(values);
687
688        let array = ListArray::try_new(field, offsets, values, None)?;
689
690        Ok(Arc::new(array))
691    }
692
693    fn data_type(&self) -> &DataType {
694        &self.data_type
695    }
696
697    fn element_size_bytes(&self) -> Option<ByteCount> {
698        None
699    }
700}
701
702#[derive(Debug, Default)]
703pub struct PseudoUuidGenerator {}
704
705impl ArrayGenerator for PseudoUuidGenerator {
706    fn generate(
707        &mut self,
708        length: RowCount,
709        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
710    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
711        Ok(Arc::new(FixedSizeBinaryArray::try_from_iter(
712            (0..length.0).map(|_| {
713                let mut data = vec![0; 16];
714                rng.fill_bytes(&mut data);
715                data
716            }),
717        )?))
718    }
719
720    fn data_type(&self) -> &DataType {
721        &DataType::FixedSizeBinary(16)
722    }
723
724    fn element_size_bytes(&self) -> Option<ByteCount> {
725        Some(ByteCount::from(16))
726    }
727}
728
729#[derive(Debug, Default)]
730pub struct PseudoUuidHexGenerator {}
731
732impl ArrayGenerator for PseudoUuidHexGenerator {
733    fn generate(
734        &mut self,
735        length: RowCount,
736        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
737    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
738        let mut data = vec![0; 16 * length.0 as usize];
739        rng.fill_bytes(&mut data);
740        let data_hex = hex::encode(data);
741
742        Ok(Arc::new(StringArray::from_iter_values(
743            (0..length.0 as usize).map(|i| data_hex.get(i * 32..(i + 1) * 32).unwrap()),
744        )))
745    }
746
747    fn data_type(&self) -> &DataType {
748        &DataType::Utf8
749    }
750
751    fn element_size_bytes(&self) -> Option<ByteCount> {
752        Some(ByteCount::from(16))
753    }
754}
755
756#[derive(Debug, Default)]
757pub struct RandomBooleanGenerator {}
758
759impl ArrayGenerator for RandomBooleanGenerator {
760    fn generate(
761        &mut self,
762        length: RowCount,
763        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
764    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
765        let num_bytes = length.0.div_ceil(8);
766        let mut bytes = vec![0; num_bytes as usize];
767        rng.fill_bytes(&mut bytes);
768        let bytes = BooleanBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
769        Ok(Arc::new(arrow_array::BooleanArray::new(bytes, None)))
770    }
771
772    fn data_type(&self) -> &DataType {
773        &DataType::Boolean
774    }
775
776    fn element_size_bytes(&self) -> Option<ByteCount> {
777        // We can't say 1/8th of a byte and 1 byte would be a pretty extreme over-count so let's leave
778        // it at None until someone needs this.  Then we can probably special case this (e.g. make a ByteCount::ONE_BIT)
779        None
780    }
781}
782
783// Instead of using the "standard distribution" and generating values there are some cases (e.g. f16 / decimal)
784// where we just generate random bytes because there is no rand support
785pub struct RandomBytesGenerator<T: ArrowPrimitiveType + Send + Sync> {
786    phantom: PhantomData<T>,
787    data_type: DataType,
788}
789
790impl<T: ArrowPrimitiveType + Send + Sync> std::fmt::Debug for RandomBytesGenerator<T> {
791    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
792        f.debug_struct("RandomBytesGenerator")
793            .field("data_type", &self.data_type)
794            .finish()
795    }
796}
797
798impl<T: ArrowPrimitiveType + Send + Sync> RandomBytesGenerator<T> {
799    fn new(data_type: DataType) -> Self {
800        Self {
801            phantom: Default::default(),
802            data_type,
803        }
804    }
805
806    fn byte_width() -> Result<u64, ArrowError> {
807        T::DATA_TYPE.primitive_width().ok_or_else(|| ArrowError::InvalidArgumentError(format!("Cannot generate the data type {} with the RandomBytesGenerator because it is not a fixed-width bytes type", T::DATA_TYPE))).map(|val| val as u64)
808    }
809}
810
811impl<T: ArrowPrimitiveType + Send + Sync> ArrayGenerator for RandomBytesGenerator<T> {
812    fn generate(
813        &mut self,
814        length: RowCount,
815        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
816    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
817        let num_bytes = length.0 * Self::byte_width()?;
818        let mut bytes = vec![0; num_bytes as usize];
819        rng.fill_bytes(&mut bytes);
820        let bytes = ScalarBuffer::new(Buffer::from(bytes), 0, length.0 as usize);
821        Ok(Arc::new(
822            PrimitiveArray::<T>::new(bytes, None).with_data_type(self.data_type.clone()),
823        ))
824    }
825
826    fn data_type(&self) -> &DataType {
827        &self.data_type
828    }
829
830    fn element_size_bytes(&self) -> Option<ByteCount> {
831        Self::byte_width().map(ByteCount::from).ok()
832    }
833}
834
835// This is pretty much the same thing as RandomBinaryGenerator but we can't use that
836// because there is no ArrowPrimitiveType for FixedSizeBinary
837#[derive(Debug)]
838pub struct RandomFixedSizeBinaryGenerator {
839    data_type: DataType,
840    size: i32,
841}
842
843impl RandomFixedSizeBinaryGenerator {
844    fn new(size: i32) -> Self {
845        Self {
846            size,
847            data_type: DataType::FixedSizeBinary(size),
848        }
849    }
850}
851
852impl ArrayGenerator for RandomFixedSizeBinaryGenerator {
853    fn generate(
854        &mut self,
855        length: RowCount,
856        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
857    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
858        let num_bytes = length.0 * self.size as u64;
859        let mut bytes = vec![0; num_bytes as usize];
860        rng.fill_bytes(&mut bytes);
861        Ok(Arc::new(FixedSizeBinaryArray::new(
862            self.size,
863            Buffer::from(bytes),
864            None,
865        )))
866    }
867
868    fn data_type(&self) -> &DataType {
869        &self.data_type
870    }
871
872    fn element_size_bytes(&self) -> Option<ByteCount> {
873        Some(ByteCount::from(self.size as u64))
874    }
875}
876
877#[derive(Debug)]
878pub struct RandomIntervalGenerator {
879    unit: IntervalUnit,
880    data_type: DataType,
881}
882
883impl RandomIntervalGenerator {
884    pub fn new(unit: IntervalUnit) -> Self {
885        Self {
886            unit,
887            data_type: DataType::Interval(unit),
888        }
889    }
890}
891
892impl ArrayGenerator for RandomIntervalGenerator {
893    fn generate(
894        &mut self,
895        length: RowCount,
896        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
897    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
898        match self.unit {
899            IntervalUnit::YearMonth => {
900                let months = (0..length.0)
901                    .map(|_| rng.random::<i32>())
902                    .collect::<Vec<_>>();
903                Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months)))
904            }
905            IntervalUnit::MonthDayNano => {
906                let day_time_array = (0..length.0)
907                    .map(|_| IntervalMonthDayNano::new(rng.random(), rng.random(), rng.random()))
908                    .collect::<Vec<_>>();
909                Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
910                    day_time_array,
911                )))
912            }
913            IntervalUnit::DayTime => {
914                let day_time_array = (0..length.0)
915                    .map(|_| IntervalDayTime::new(rng.random(), rng.random()))
916                    .collect::<Vec<_>>();
917                Ok(Arc::new(arrow_array::IntervalDayTimeArray::from(
918                    day_time_array,
919                )))
920            }
921        }
922    }
923
924    fn data_type(&self) -> &DataType {
925        &self.data_type
926    }
927
928    fn element_size_bytes(&self) -> Option<ByteCount> {
929        Some(ByteCount::from(12))
930    }
931}
932#[derive(Debug)]
933pub struct RandomBinaryGenerator {
934    bytes_per_element: ByteCount,
935    scale_to_utf8: bool,
936    is_large: bool,
937    data_type: DataType,
938}
939
940impl RandomBinaryGenerator {
941    pub fn new(bytes_per_element: ByteCount, scale_to_utf8: bool, is_large: bool) -> Self {
942        Self {
943            bytes_per_element,
944            scale_to_utf8,
945            is_large,
946            data_type: match (scale_to_utf8, is_large) {
947                (false, false) => DataType::Binary,
948                (false, true) => DataType::LargeBinary,
949                (true, false) => DataType::Utf8,
950                (true, true) => DataType::LargeUtf8,
951            },
952        }
953    }
954}
955
956impl ArrayGenerator for RandomBinaryGenerator {
957    fn generate(
958        &mut self,
959        length: RowCount,
960        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
961    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
962        let mut bytes = vec![0; (self.bytes_per_element.0 * length.0) as usize];
963        rng.fill_bytes(&mut bytes);
964        if self.scale_to_utf8 {
965            // This doesn't give us the full UTF-8 range and it isn't statistically correct but
966            // it's fast and probably good enough for most cases
967            bytes = bytes.into_iter().map(|val| (val % 95) + 32).collect();
968        }
969        let bytes = Buffer::from(bytes);
970        if self.is_large {
971            let offsets = OffsetBuffer::from_lengths(iter::repeat_n(
972                self.bytes_per_element.0 as usize,
973                length.0 as usize,
974            ));
975            if self.scale_to_utf8 {
976                // This is safe because we are only using printable characters
977                unsafe {
978                    Ok(Arc::new(arrow_array::LargeStringArray::new_unchecked(
979                        offsets, bytes, None,
980                    )))
981                }
982            } else {
983                unsafe {
984                    Ok(Arc::new(arrow_array::LargeBinaryArray::new_unchecked(
985                        offsets, bytes, None,
986                    )))
987                }
988            }
989        } else {
990            let offsets = OffsetBuffer::from_lengths(iter::repeat_n(
991                self.bytes_per_element.0 as usize,
992                length.0 as usize,
993            ));
994            if self.scale_to_utf8 {
995                // This is safe because we are only using printable characters
996                unsafe {
997                    Ok(Arc::new(arrow_array::StringArray::new_unchecked(
998                        offsets, bytes, None,
999                    )))
1000                }
1001            } else {
1002                unsafe {
1003                    Ok(Arc::new(arrow_array::BinaryArray::new_unchecked(
1004                        offsets, bytes, None,
1005                    )))
1006                }
1007            }
1008        }
1009    }
1010
1011    fn data_type(&self) -> &DataType {
1012        &self.data_type
1013    }
1014
1015    fn element_size_bytes(&self) -> Option<ByteCount> {
1016        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1017        Some(ByteCount::from(
1018            self.bytes_per_element.0 + std::mem::size_of::<i32>() as u64,
1019        ))
1020    }
1021}
1022
1023/// Generate a sequence of strings with a prefix and a counter
1024///
1025/// For example, if the prefix is "user_" the strings will be "user_0", "user_1", ...
1026#[derive(Debug)]
1027pub struct PrefixPlusCounterGenerator {
1028    prefix: String,
1029    is_large: bool,
1030    data_type: DataType,
1031    current_counter: u64,
1032}
1033
1034impl PrefixPlusCounterGenerator {
1035    pub fn new(prefix: String, is_large: bool) -> Self {
1036        Self {
1037            prefix,
1038            is_large,
1039            data_type: if is_large {
1040                DataType::LargeUtf8
1041            } else {
1042                DataType::Utf8
1043            },
1044            current_counter: 0,
1045        }
1046    }
1047
1048    fn generate_values<T: OffsetSizeTrait>(
1049        &self,
1050        start: u64,
1051        num_values: u64,
1052    ) -> Result<Arc<dyn Array>, ArrowError> {
1053        let max_counter = start + num_values;
1054        let max_digits_per_counter = (max_counter as f64).log10().ceil() as u64;
1055        let max_bytes_per_str = max_digits_per_counter + self.prefix.len() as u64;
1056        let max_bytes = max_bytes_per_str * num_values;
1057        let mut builder =
1058            GenericStringBuilder::<T>::with_capacity(num_values as usize, max_bytes as usize);
1059        let mut word = String::with_capacity(max_bytes_per_str as usize);
1060        word.push_str(&self.prefix);
1061        for i in 0..num_values {
1062            let counter = start + i;
1063            word.truncate(self.prefix.len());
1064            word.push_str(&counter.to_string());
1065            builder.append_value(&word);
1066        }
1067        Ok(Arc::new(builder.finish()))
1068    }
1069}
1070
1071impl ArrayGenerator for PrefixPlusCounterGenerator {
1072    fn generate(
1073        &mut self,
1074        length: RowCount,
1075        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1076    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1077        let start = self.current_counter;
1078        self.current_counter += length.0;
1079        if self.is_large {
1080            self.generate_values::<i64>(start, length.0)
1081        } else {
1082            self.generate_values::<i32>(start, length.0)
1083        }
1084    }
1085
1086    fn data_type(&self) -> &DataType {
1087        &self.data_type
1088    }
1089
1090    fn element_size_bytes(&self) -> Option<ByteCount> {
1091        // It's not consistent
1092        None
1093    }
1094}
1095
1096/// Generate a sequence of binary strings with a prefix and a counter
1097///
1098/// The counter will be encoded (little-endian) as a u8, u16, u32, or u64 and added to the prefix
1099/// As long as more than 256 values are generated then the resulting array will have
1100/// variable width
1101#[derive(Debug)]
1102pub struct BinaryPrefixPlusCounterGenerator {
1103    prefix: Arc<[u8]>,
1104    is_large: bool,
1105    data_type: DataType,
1106    current_counter: u64,
1107}
1108
1109impl BinaryPrefixPlusCounterGenerator {
1110    pub fn new(prefix: Arc<[u8]>, is_large: bool) -> Self {
1111        Self {
1112            prefix,
1113            is_large,
1114            data_type: if is_large {
1115                DataType::LargeBinary
1116            } else {
1117                DataType::Binary
1118            },
1119            current_counter: 0,
1120        }
1121    }
1122
1123    fn generate_values<T: OffsetSizeTrait>(
1124        &self,
1125        start: u64,
1126        num_values: u64,
1127    ) -> Result<Arc<dyn Array>, ArrowError> {
1128        let max_bytes = (self.prefix.len() + std::mem::size_of::<u64>()) * num_values as usize;
1129        let mut builder = GenericBinaryBuilder::<T>::with_capacity(num_values as usize, max_bytes);
1130        let mut word = Vec::with_capacity(self.prefix.len() + std::mem::size_of::<u64>());
1131        word.extend_from_slice(&self.prefix);
1132        for i in 0..num_values {
1133            let counter = start + i;
1134            word.truncate(self.prefix.len());
1135            if counter < u8::MAX as u64 {
1136                word.push(counter as u8);
1137            } else if counter < u16::MAX as u64 {
1138                word.extend_from_slice(&(counter as u16).to_le_bytes());
1139            } else if counter < u32::MAX as u64 {
1140                word.extend_from_slice(&(counter as u32).to_le_bytes());
1141            } else {
1142                word.extend_from_slice(&counter.to_le_bytes());
1143            }
1144            builder.append_value(&word);
1145        }
1146        Ok(Arc::new(builder.finish()))
1147    }
1148}
1149
1150impl ArrayGenerator for BinaryPrefixPlusCounterGenerator {
1151    fn generate(
1152        &mut self,
1153        length: RowCount,
1154        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1155    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1156        let start = self.current_counter;
1157        self.current_counter += length.0;
1158        if self.is_large {
1159            self.generate_values::<i64>(start, length.0)
1160        } else {
1161            self.generate_values::<i32>(start, length.0)
1162        }
1163    }
1164
1165    fn data_type(&self) -> &DataType {
1166        &self.data_type
1167    }
1168
1169    fn element_size_bytes(&self) -> Option<ByteCount> {
1170        // It's not consistent
1171        None
1172    }
1173}
1174
1175// Common English stop words placed at the front to be sampled more frequently
1176const STOP_WORDS: &[&str] = &[
1177    "a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it",
1178    "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these",
1179    "they", "this", "to", "was", "will", "with",
1180];
1181
1182/// Word list with stop words at the front for Zipf sampling, computed once.
1183static SENTENCE_WORDS: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
1184    let all_words = random_word::all(random_word::Lang::En);
1185    let mut words = Vec::with_capacity(STOP_WORDS.len() + all_words.len());
1186    words.extend(STOP_WORDS.iter().copied());
1187    words.extend(
1188        all_words
1189            .iter()
1190            .filter(|w| !STOP_WORDS.contains(w))
1191            .copied(),
1192    );
1193    words
1194});
1195
1196struct RandomSentenceGenerator {
1197    min_words: usize,
1198    max_words: usize,
1199    /// Zipf distribution for word selection (favors lower indices)
1200    zipf: Zipf<f64>,
1201    is_large: bool,
1202}
1203
1204impl std::fmt::Debug for RandomSentenceGenerator {
1205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1206        f.debug_struct("RandomSentenceGenerator")
1207            .field("min_words", &self.min_words)
1208            .field("max_words", &self.max_words)
1209            .field("num_words", &SENTENCE_WORDS.len())
1210            .field("is_large", &self.is_large)
1211            .finish()
1212    }
1213}
1214
1215impl RandomSentenceGenerator {
1216    pub fn new(min_words: usize, max_words: usize, is_large: bool) -> Self {
1217        // Zipf distribution with exponent ~1.0 approximates natural language
1218        let zipf = Zipf::new(SENTENCE_WORDS.len() as f64, 1.0).unwrap();
1219
1220        Self {
1221            min_words,
1222            max_words,
1223            zipf,
1224            is_large,
1225        }
1226    }
1227}
1228
1229impl ArrayGenerator for RandomSentenceGenerator {
1230    fn generate(
1231        &mut self,
1232        length: RowCount,
1233        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1234    ) -> Result<Arc<dyn Array>, ArrowError> {
1235        let mut values = Vec::with_capacity(length.0 as usize);
1236
1237        for _ in 0..length.0 {
1238            let num_words = rng.random_range(self.min_words..=self.max_words);
1239            let sentence: String = (0..num_words)
1240                .map(|_| {
1241                    // Zipf returns 1-indexed values, subtract 1 for 0-indexed array
1242                    let idx = rng.sample(self.zipf) as usize - 1;
1243                    SENTENCE_WORDS[idx]
1244                })
1245                .collect::<Vec<_>>()
1246                .join(" ");
1247            values.push(sentence);
1248        }
1249
1250        if self.is_large {
1251            Ok(Arc::new(LargeStringArray::from(values)))
1252        } else {
1253            Ok(Arc::new(StringArray::from(values)))
1254        }
1255    }
1256
1257    fn data_type(&self) -> &DataType {
1258        if self.is_large {
1259            &DataType::LargeUtf8
1260        } else {
1261            &DataType::Utf8
1262        }
1263    }
1264
1265    fn element_size_bytes(&self) -> Option<ByteCount> {
1266        // Estimate average word length as 5, plus space
1267        // See https://arxiv.org/pdf/1208.6109
1268        let avg_word_length = 6;
1269        let avg_words = (self.min_words + self.max_words) / 2;
1270        Some(ByteCount::from((avg_word_length * avg_words) as u64))
1271    }
1272}
1273
1274#[derive(Debug)]
1275struct RandomWordGenerator {
1276    words: &'static [&'static str],
1277    is_large: bool,
1278}
1279
1280impl RandomWordGenerator {
1281    pub fn new(is_large: bool) -> Self {
1282        let words = random_word::all(random_word::Lang::En);
1283        Self { words, is_large }
1284    }
1285}
1286
1287impl ArrayGenerator for RandomWordGenerator {
1288    fn generate(
1289        &mut self,
1290        length: RowCount,
1291        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1292    ) -> Result<Arc<dyn Array>, ArrowError> {
1293        let mut values = Vec::with_capacity(length.0 as usize);
1294
1295        for _ in 0..length.0 {
1296            let word = self.words[rng.random_range(0..self.words.len())];
1297            values.push(word.to_string());
1298        }
1299
1300        if self.is_large {
1301            Ok(Arc::new(LargeStringArray::from(values)))
1302        } else {
1303            Ok(Arc::new(StringArray::from(values)))
1304        }
1305    }
1306
1307    fn data_type(&self) -> &DataType {
1308        if self.is_large {
1309            &DataType::LargeUtf8
1310        } else {
1311            &DataType::Utf8
1312        }
1313    }
1314
1315    fn element_size_bytes(&self) -> Option<ByteCount> {
1316        // Average English word length is ~5 characters
1317        Some(ByteCount::from(5))
1318    }
1319}
1320
1321#[derive(Debug)]
1322pub struct VariableRandomBinaryGenerator {
1323    lengths_gen: Box<dyn ArrayGenerator>,
1324    data_type: DataType,
1325}
1326
1327impl VariableRandomBinaryGenerator {
1328    pub fn new(min_bytes_per_element: ByteCount, max_bytes_per_element: ByteCount) -> Self {
1329        let lengths_dist = Uniform::new_inclusive(
1330            min_bytes_per_element.0 as i32,
1331            max_bytes_per_element.0 as i32,
1332        )
1333        .unwrap();
1334        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
1335
1336        Self {
1337            lengths_gen,
1338            data_type: DataType::Binary,
1339        }
1340    }
1341}
1342
1343impl ArrayGenerator for VariableRandomBinaryGenerator {
1344    fn generate(
1345        &mut self,
1346        length: RowCount,
1347        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1348    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1349        let lengths = self.lengths_gen.generate(length, rng)?;
1350        let lengths = lengths.as_primitive::<Int32Type>();
1351        let total_length = lengths.values().iter().map(|i| *i as usize).sum::<usize>();
1352        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1353        let mut bytes = vec![0; total_length];
1354        rng.fill_bytes(&mut bytes);
1355        let bytes = Buffer::from(bytes);
1356        Ok(Arc::new(BinaryArray::try_new(offsets, bytes, None)?))
1357    }
1358
1359    fn data_type(&self) -> &DataType {
1360        &self.data_type
1361    }
1362
1363    fn element_size_bytes(&self) -> Option<ByteCount> {
1364        None
1365    }
1366}
1367
1368pub struct CycleBinaryGenerator<T: ByteArrayType> {
1369    values: Vec<u8>,
1370    lengths: Vec<usize>,
1371    data_type: DataType,
1372    array_type: PhantomData<T>,
1373    width: Option<ByteCount>,
1374    idx: usize,
1375}
1376
1377impl<T: ByteArrayType> std::fmt::Debug for CycleBinaryGenerator<T> {
1378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1379        f.debug_struct("CycleBinaryGenerator")
1380            .field("values", &self.values)
1381            .field("lengths", &self.lengths)
1382            .field("data_type", &self.data_type)
1383            .field("width", &self.width)
1384            .field("idx", &self.idx)
1385            .finish()
1386    }
1387}
1388
1389impl<T: ByteArrayType> CycleBinaryGenerator<T> {
1390    pub fn from_strings(values: &[&str]) -> Self {
1391        if values.is_empty() {
1392            panic!("Attempt to create a cycle generator with no values");
1393        }
1394        let lengths = values.iter().map(|s| s.len()).collect::<Vec<_>>();
1395        let typical_length = lengths[0];
1396        let width = if lengths.iter().all(|item| *item == typical_length) {
1397            Some(ByteCount::from(
1398                typical_length as u64 + std::mem::size_of::<i32>() as u64,
1399            ))
1400        } else {
1401            None
1402        };
1403        let values = values
1404            .iter()
1405            .flat_map(|s| s.as_bytes().iter().copied())
1406            .collect::<Vec<_>>();
1407        Self {
1408            values,
1409            lengths,
1410            data_type: T::DATA_TYPE,
1411            array_type: PhantomData,
1412            width,
1413            idx: 0,
1414        }
1415    }
1416}
1417
1418impl<T: ByteArrayType> ArrayGenerator for CycleBinaryGenerator<T> {
1419    fn generate(
1420        &mut self,
1421        length: RowCount,
1422        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1423    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1424        let lengths = self
1425            .lengths
1426            .iter()
1427            .copied()
1428            .cycle()
1429            .skip(self.idx)
1430            .take(length.0 as usize);
1431        let num_bytes = lengths.clone().sum();
1432        let byte_offset = self.lengths[0..self.idx].iter().sum();
1433        let bytes = self
1434            .values
1435            .iter()
1436            .cycle()
1437            .skip(byte_offset)
1438            .copied()
1439            .take(num_bytes)
1440            .collect::<Vec<_>>();
1441        let bytes = Buffer::from(bytes);
1442        let offsets = OffsetBuffer::from_lengths(lengths);
1443        self.idx = (self.idx + length.0 as usize) % self.lengths.len();
1444        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
1445            offsets, bytes, None,
1446        )))
1447    }
1448
1449    fn data_type(&self) -> &DataType {
1450        &self.data_type
1451    }
1452
1453    fn element_size_bytes(&self) -> Option<ByteCount> {
1454        self.width
1455    }
1456}
1457
1458pub struct FixedBinaryGenerator<T: ByteArrayType> {
1459    value: Vec<u8>,
1460    data_type: DataType,
1461    array_type: PhantomData<T>,
1462}
1463
1464impl<T: ByteArrayType> std::fmt::Debug for FixedBinaryGenerator<T> {
1465    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1466        f.debug_struct("FixedBinaryGenerator")
1467            .field("value", &self.value)
1468            .field("data_type", &self.data_type)
1469            .finish()
1470    }
1471}
1472
1473impl<T: ByteArrayType> FixedBinaryGenerator<T> {
1474    pub fn new(value: Vec<u8>) -> Self {
1475        Self {
1476            value,
1477            data_type: T::DATA_TYPE,
1478            array_type: PhantomData,
1479        }
1480    }
1481}
1482
1483impl<T: ByteArrayType> ArrayGenerator for FixedBinaryGenerator<T> {
1484    fn generate(
1485        &mut self,
1486        length: RowCount,
1487        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1488    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1489        let bytes = Buffer::from(Vec::from_iter(
1490            self.value
1491                .iter()
1492                .cycle()
1493                .take((length.0 * self.value.len() as u64) as usize)
1494                .copied(),
1495        ));
1496        let offsets =
1497            OffsetBuffer::from_lengths(iter::repeat_n(self.value.len(), length.0 as usize));
1498        Ok(Arc::new(arrow_array::GenericByteArray::<T>::new(
1499            offsets, bytes, None,
1500        )))
1501    }
1502
1503    fn data_type(&self) -> &DataType {
1504        &self.data_type
1505    }
1506
1507    fn element_size_bytes(&self) -> Option<ByteCount> {
1508        // Not exactly correct since there are N + 1 4-byte offsets and this only counts N
1509        Some(ByteCount::from(
1510            self.value.len() as u64 + std::mem::size_of::<i32>() as u64,
1511        ))
1512    }
1513}
1514
1515pub struct DictionaryGenerator<K: ArrowDictionaryKeyType> {
1516    generator: Box<dyn ArrayGenerator>,
1517    data_type: DataType,
1518    key_type: PhantomData<K>,
1519    key_width: u64,
1520}
1521
1522impl<K: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryGenerator<K> {
1523    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1524        f.debug_struct("DictionaryGenerator")
1525            .field("generator", &self.generator)
1526            .field("data_type", &self.data_type)
1527            .field("key_width", &self.key_width)
1528            .finish()
1529    }
1530}
1531
1532impl<K: ArrowDictionaryKeyType> DictionaryGenerator<K> {
1533    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
1534        let key_type = Box::new(K::DATA_TYPE);
1535        let key_width = key_type
1536            .primitive_width()
1537            .expect("dictionary key types should have a known width")
1538            as u64;
1539        let val_type = Box::new(generator.data_type().clone());
1540        let dict_type = DataType::Dictionary(key_type, val_type);
1541        Self {
1542            generator,
1543            data_type: dict_type,
1544            key_type: PhantomData,
1545            key_width,
1546        }
1547    }
1548}
1549
1550impl<K: ArrowDictionaryKeyType + Send + Sync> ArrayGenerator for DictionaryGenerator<K> {
1551    fn generate(
1552        &mut self,
1553        length: RowCount,
1554        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1555    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1556        let underlying = self.generator.generate(length, rng)?;
1557        arrow_cast::cast::cast(&underlying, &self.data_type)
1558    }
1559
1560    fn data_type(&self) -> &DataType {
1561        &self.data_type
1562    }
1563
1564    fn element_size_bytes(&self) -> Option<ByteCount> {
1565        self.generator
1566            .element_size_bytes()
1567            .map(|size_bytes| ByteCount::from(size_bytes.0 + self.key_width))
1568    }
1569}
1570
1571/// Generator that produces low-cardinality data by generating a fixed set of
1572/// unique values and then randomly selecting from them.
1573struct LowCardinalityGenerator {
1574    inner: Box<dyn ArrayGenerator>,
1575    cardinality: usize,
1576    /// Cached unique values, generated on first call
1577    unique_values: Option<Arc<dyn Array>>,
1578}
1579
1580impl std::fmt::Debug for LowCardinalityGenerator {
1581    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1582        f.debug_struct("LowCardinalityGenerator")
1583            .field("inner", &self.inner)
1584            .field("cardinality", &self.cardinality)
1585            .field("initialized", &self.unique_values.is_some())
1586            .finish()
1587    }
1588}
1589
1590impl LowCardinalityGenerator {
1591    fn new(inner: Box<dyn ArrayGenerator>, cardinality: usize) -> Self {
1592        Self {
1593            inner,
1594            cardinality,
1595            unique_values: None,
1596        }
1597    }
1598}
1599
1600impl ArrayGenerator for LowCardinalityGenerator {
1601    fn generate(
1602        &mut self,
1603        length: RowCount,
1604        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1605    ) -> Result<Arc<dyn Array>, ArrowError> {
1606        // Generate unique values on first call
1607        if self.unique_values.is_none() {
1608            self.unique_values = Some(
1609                self.inner
1610                    .generate(RowCount::from(self.cardinality as u64), rng)?,
1611            );
1612        }
1613
1614        let unique_values = self.unique_values.as_ref().unwrap();
1615
1616        // Generate random indices into the unique values
1617        let indices: Vec<usize> = (0..length.0)
1618            .map(|_| rng.random_range(0..self.cardinality))
1619            .collect();
1620
1621        // Use arrow's take to select values
1622        let indices_array =
1623            arrow_array::UInt32Array::from(indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
1624        arrow::compute::take(unique_values.as_ref(), &indices_array, None)
1625            .map(|arr| arr as Arc<dyn Array>)
1626    }
1627
1628    fn data_type(&self) -> &DataType {
1629        self.inner.data_type()
1630    }
1631
1632    fn element_size_bytes(&self) -> Option<ByteCount> {
1633        self.inner.element_size_bytes()
1634    }
1635}
1636
1637#[derive(Debug)]
1638struct RandomListGenerator {
1639    field: Arc<Field>,
1640    child_field: Arc<Field>,
1641    items_gen: Box<dyn ArrayGenerator>,
1642    lengths_gen: Box<dyn ArrayGenerator>,
1643    is_large: bool,
1644}
1645
1646impl RandomListGenerator {
1647    // Creates a list generator that generates random lists with lengths between 0 and 10 (inclusive)
1648    fn new(items_gen: Box<dyn ArrayGenerator>, is_large: bool) -> Self {
1649        let child_field = Arc::new(Field::new("item", items_gen.data_type().clone(), true));
1650        let list_type = if is_large {
1651            DataType::LargeList(child_field.clone())
1652        } else {
1653            DataType::List(child_field.clone())
1654        };
1655        let field = Field::new("", list_type, true);
1656        let lengths_gen = if is_large {
1657            let lengths_dist = Uniform::new_inclusive(0, 10).unwrap();
1658            rand_with_distribution::<Int64Type, Uniform<i64>>(lengths_dist)
1659        } else {
1660            let lengths_dist = Uniform::new_inclusive(0, 10).unwrap();
1661            rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist)
1662        };
1663        Self {
1664            field: Arc::new(field),
1665            child_field,
1666            items_gen,
1667            lengths_gen,
1668            is_large,
1669        }
1670    }
1671}
1672
1673impl ArrayGenerator for RandomListGenerator {
1674    fn generate(
1675        &mut self,
1676        length: RowCount,
1677        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1678    ) -> Result<Arc<dyn Array>, ArrowError> {
1679        let lengths = self.lengths_gen.generate(length, rng)?;
1680        if self.is_large {
1681            let lengths = lengths.as_primitive::<Int64Type>();
1682            let total_length = lengths.values().iter().sum::<i64>() as u64;
1683            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1684            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1685            Ok(Arc::new(LargeListArray::try_new(
1686                self.child_field.clone(),
1687                offsets,
1688                items,
1689                None,
1690            )?))
1691        } else {
1692            let lengths = lengths.as_primitive::<Int32Type>();
1693            let total_length = lengths.values().iter().sum::<i32>() as u64;
1694            let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1695            let items = self.items_gen.generate(RowCount::from(total_length), rng)?;
1696            Ok(Arc::new(ListArray::try_new(
1697                self.child_field.clone(),
1698                offsets,
1699                items,
1700                None,
1701            )?))
1702        }
1703    }
1704
1705    fn data_type(&self) -> &DataType {
1706        self.field.data_type()
1707    }
1708
1709    fn element_size_bytes(&self) -> Option<ByteCount> {
1710        None
1711    }
1712}
1713
1714/// Generates random map arrays where each map has 0-4 entries.
1715#[derive(Debug)]
1716struct RandomMapGenerator {
1717    field: Arc<Field>,
1718    entries_field: Arc<Field>,
1719    keys_gen: Box<dyn ArrayGenerator>,
1720    values_gen: Box<dyn ArrayGenerator>,
1721    lengths_gen: Box<dyn ArrayGenerator>,
1722}
1723
1724impl RandomMapGenerator {
1725    fn new(keys_gen: Box<dyn ArrayGenerator>, values_gen: Box<dyn ArrayGenerator>) -> Self {
1726        let entries_fields = Fields::from(vec![
1727            Field::new("keys", keys_gen.data_type().clone(), false),
1728            Field::new("values", values_gen.data_type().clone(), true),
1729        ]);
1730        let entries_field = Arc::new(Field::new(
1731            "entries",
1732            DataType::Struct(entries_fields),
1733            false,
1734        ));
1735        let map_type = DataType::Map(entries_field.clone(), false);
1736        let field = Arc::new(Field::new("", map_type, true));
1737        let lengths_dist = Uniform::new_inclusive(0_i32, 4).unwrap();
1738        let lengths_gen = rand_with_distribution::<Int32Type, Uniform<i32>>(lengths_dist);
1739
1740        Self {
1741            field,
1742            entries_field,
1743            keys_gen,
1744            values_gen,
1745            lengths_gen,
1746        }
1747    }
1748}
1749
1750impl ArrayGenerator for RandomMapGenerator {
1751    fn generate(
1752        &mut self,
1753        length: RowCount,
1754        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1755    ) -> Result<Arc<dyn Array>, ArrowError> {
1756        let lengths = self.lengths_gen.generate(length, rng)?;
1757        let lengths = lengths.as_primitive::<Int32Type>();
1758        let total_entries = lengths.values().iter().sum::<i32>() as u64;
1759        let offsets = OffsetBuffer::from_lengths(lengths.values().iter().map(|v| *v as usize));
1760
1761        let keys = self.keys_gen.generate(RowCount::from(total_entries), rng)?;
1762        let values = self
1763            .values_gen
1764            .generate(RowCount::from(total_entries), rng)?;
1765
1766        let entries = StructArray::new(
1767            Fields::from(vec![
1768                Field::new("keys", keys.data_type().clone(), false),
1769                Field::new("values", values.data_type().clone(), true),
1770            ]),
1771            vec![keys, values],
1772            None,
1773        );
1774
1775        Ok(Arc::new(MapArray::try_new(
1776            self.entries_field.clone(),
1777            offsets,
1778            entries,
1779            None,
1780            false,
1781        )?))
1782    }
1783
1784    fn data_type(&self) -> &DataType {
1785        self.field.data_type()
1786    }
1787
1788    fn element_size_bytes(&self) -> Option<ByteCount> {
1789        None
1790    }
1791}
1792
1793#[derive(Debug)]
1794struct NullArrayGenerator {}
1795
1796impl ArrayGenerator for NullArrayGenerator {
1797    fn generate(
1798        &mut self,
1799        length: RowCount,
1800        _: &mut rand_xoshiro::Xoshiro256PlusPlus,
1801    ) -> Result<Arc<dyn Array>, ArrowError> {
1802        Ok(Arc::new(NullArray::new(length.0 as usize)))
1803    }
1804
1805    fn data_type(&self) -> &DataType {
1806        &DataType::Null
1807    }
1808
1809    fn element_size_bytes(&self) -> Option<ByteCount> {
1810        None
1811    }
1812}
1813
1814/// Generates 2 dimensional vectors along the unit circle, with a configurable number of steps per circle.
1815#[derive(Debug)]
1816struct RadialStepGenerator {
1817    num_steps_per_circle: u32,
1818    data_field: Arc<Field>,
1819    data_type: DataType,
1820    current_step: u32,
1821}
1822
1823impl RadialStepGenerator {
1824    fn new(num_steps_per_circle: u32) -> Self {
1825        let data_field = Arc::new(Field::new("item", DataType::Float32, false));
1826        let data_type = DataType::FixedSizeList(data_field.clone(), 2);
1827        Self {
1828            num_steps_per_circle,
1829            data_field,
1830            data_type,
1831            current_step: 0,
1832        }
1833    }
1834}
1835
1836impl ArrayGenerator for RadialStepGenerator {
1837    fn generate(
1838        &mut self,
1839        length: RowCount,
1840        _rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1841    ) -> Result<Arc<dyn Array>, ArrowError> {
1842        let mut values_builder = Float32Builder::with_capacity(length.0 as usize * 2);
1843        for _ in 0..length.0 {
1844            let angle = (self.current_step as f32) / (self.num_steps_per_circle as f32)
1845                * 2.0
1846                * std::f32::consts::PI;
1847            values_builder.append_value(angle.cos());
1848            values_builder.append_value(angle.sin());
1849            self.current_step = (self.current_step + 1) % self.num_steps_per_circle;
1850        }
1851        let values = values_builder.finish();
1852        let vectors =
1853            FixedSizeListArray::try_new(self.data_field.clone(), 2, Arc::new(values), None)?;
1854        Ok(Arc::new(vectors))
1855    }
1856
1857    fn data_type(&self) -> &DataType {
1858        &self.data_type
1859    }
1860
1861    fn element_size_bytes(&self) -> Option<ByteCount> {
1862        Some(ByteCount::from(8))
1863    }
1864}
1865
1866/// Cycles through a set of centroids, adding noise to each point
1867#[derive(Debug)]
1868struct JitterCentroidsGenerator {
1869    centroids: Float32Array,
1870    dimension: u32,
1871    noise_level: f32,
1872    data_type: DataType,
1873    data_field: Arc<Field>,
1874
1875    offset: usize,
1876}
1877
1878impl JitterCentroidsGenerator {
1879    fn try_new(centroids: Arc<dyn Array>, noise_level: f32) -> Result<Self, ArrowError> {
1880        let DataType::FixedSizeList(values_field, dimension) = centroids.data_type() else {
1881            return Err(ArrowError::InvalidArgumentError(
1882                "Centroids must be a FixedSizeList".to_string(),
1883            ));
1884        };
1885        if values_field.data_type() != &DataType::Float32 {
1886            return Err(ArrowError::InvalidArgumentError(
1887                "Centroids values must be a Float32".to_string(),
1888            ));
1889        }
1890        let data_type = DataType::FixedSizeList(values_field.clone(), *dimension);
1891        Ok(Self {
1892            centroids: centroids
1893                .as_fixed_size_list()
1894                .values()
1895                .as_primitive::<Float32Type>()
1896                .clone(),
1897            dimension: *dimension as u32,
1898            noise_level,
1899            data_type,
1900            data_field: values_field.clone(),
1901            offset: 0,
1902        })
1903    }
1904}
1905
1906impl ArrayGenerator for JitterCentroidsGenerator {
1907    fn generate(
1908        &mut self,
1909        length: RowCount,
1910        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1911    ) -> Result<Arc<dyn Array>, ArrowError> {
1912        let mut values_builder =
1913            Float32Builder::with_capacity(length.0 as usize * self.dimension as usize);
1914        for _ in 0..length.0 {
1915            // Generate random N dimensional point
1916            let mut noise = (0..self.dimension as usize)
1917                .map(|_| rng.random::<f32>())
1918                .collect::<Vec<_>>();
1919            // Scale point to noise_level length
1920            let scale = self.noise_level / noise.iter().map(|v| v * v).sum::<f32>().sqrt();
1921            noise.iter_mut().for_each(|v| *v *= scale);
1922
1923            // Add noise to centroid and store in values
1924            for (i, noise) in noise.into_iter().enumerate() {
1925                let centroid_val = self.centroids.value(self.offset + i);
1926                let jittered_val = centroid_val + noise;
1927                values_builder.append_value(jittered_val);
1928            }
1929            // Advance to next centroid
1930            self.offset = (self.offset + self.dimension as usize) % self.centroids.len();
1931        }
1932        let values = values_builder.finish();
1933        let vectors = FixedSizeListArray::try_new(
1934            self.data_field.clone(),
1935            self.dimension as i32,
1936            Arc::new(values),
1937            None,
1938        )?;
1939        Ok(Arc::new(vectors))
1940    }
1941
1942    fn data_type(&self) -> &DataType {
1943        &self.data_type
1944    }
1945
1946    fn element_size_bytes(&self) -> Option<ByteCount> {
1947        Some(ByteCount::from(self.dimension as u64 * 4))
1948    }
1949}
1950#[derive(Debug)]
1951struct RandomStructGenerator {
1952    fields: Fields,
1953    data_type: DataType,
1954    child_gens: Vec<Box<dyn ArrayGenerator>>,
1955}
1956
1957impl RandomStructGenerator {
1958    fn new(fields: Fields, child_gens: Vec<Box<dyn ArrayGenerator>>) -> Self {
1959        let data_type = DataType::Struct(fields.clone());
1960        Self {
1961            fields,
1962            data_type,
1963            child_gens,
1964        }
1965    }
1966}
1967
1968impl ArrayGenerator for RandomStructGenerator {
1969    fn generate(
1970        &mut self,
1971        length: RowCount,
1972        rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
1973    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
1974        if self.child_gens.is_empty() {
1975            // Have to create empty struct arrays specially to ensure they have the correct
1976            // row count
1977            let struct_arr = StructArray::new_empty_fields(length.0 as usize, None);
1978            return Ok(Arc::new(struct_arr));
1979        }
1980        let child_arrays = self
1981            .child_gens
1982            .iter_mut()
1983            .map(|genn| genn.generate(length, rng))
1984            .collect::<Result<Vec<_>, ArrowError>>()?;
1985        let struct_arr = StructArray::new(self.fields.clone(), child_arrays, None);
1986        Ok(Arc::new(struct_arr))
1987    }
1988
1989    fn data_type(&self) -> &DataType {
1990        &self.data_type
1991    }
1992
1993    fn element_size_bytes(&self) -> Option<ByteCount> {
1994        let mut sum = 0;
1995        for child_gen in &self.child_gens {
1996            sum += child_gen.element_size_bytes()?.0;
1997        }
1998        Some(ByteCount::from(sum))
1999    }
2000}
2001
2002/// A RecordBatchReader that generates batches of the given size from the given array generators
2003pub struct FixedSizeBatchGenerator {
2004    rng: rand_xoshiro::Xoshiro256PlusPlus,
2005    generators: Vec<Box<dyn ArrayGenerator>>,
2006    batch_size: RowCount,
2007    num_batches: BatchCount,
2008    schema: SchemaRef,
2009}
2010
2011impl FixedSizeBatchGenerator {
2012    fn new(
2013        generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
2014        batch_size: RowCount,
2015        num_batches: BatchCount,
2016        seed: Option<Seed>,
2017        default_null_probability: Option<f64>,
2018    ) -> Self {
2019        let mut fields = Vec::with_capacity(generators.len());
2020        for (field_index, field_gen) in generators.iter().enumerate() {
2021            let (name, genn) = field_gen;
2022            let default_name = format!("field_{}", field_index);
2023            let name = name.clone().unwrap_or(default_name);
2024            let mut field = Field::new(name, genn.data_type().clone(), true);
2025            if let Some(metadata) = genn.metadata() {
2026                field = field.with_metadata(metadata);
2027            }
2028            fields.push(field);
2029        }
2030        let mut generators = generators
2031            .into_iter()
2032            .map(|(_, genn)| genn)
2033            .collect::<Vec<_>>();
2034        if let Some(null_probability) = default_null_probability {
2035            generators = generators
2036                .into_iter()
2037                .map(|genn| genn.with_random_nulls(null_probability))
2038                .collect();
2039        }
2040        let schema = Arc::new(Schema::new(fields));
2041        Self {
2042            rng: rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
2043                seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
2044            ),
2045            generators,
2046            batch_size,
2047            num_batches,
2048            schema,
2049        }
2050    }
2051
2052    fn gen_next(&mut self) -> Result<RecordBatch, ArrowError> {
2053        let mut arrays = Vec::with_capacity(self.generators.len());
2054        for genn in self.generators.iter_mut() {
2055            let arr = genn.generate(self.batch_size, &mut self.rng)?;
2056            arrays.push(arr);
2057        }
2058        self.num_batches.0 -= 1;
2059        Ok(RecordBatch::try_new_with_options(
2060            self.schema.clone(),
2061            arrays,
2062            &RecordBatchOptions::new().with_row_count(Some(self.batch_size.0 as usize)),
2063        )
2064        .unwrap())
2065    }
2066}
2067
2068impl Iterator for FixedSizeBatchGenerator {
2069    type Item = Result<RecordBatch, ArrowError>;
2070
2071    fn next(&mut self) -> Option<Self::Item> {
2072        if self.num_batches.0 == 0 {
2073            return None;
2074        }
2075        Some(self.gen_next())
2076    }
2077}
2078
2079impl RecordBatchReader for FixedSizeBatchGenerator {
2080    fn schema(&self) -> SchemaRef {
2081        self.schema.clone()
2082    }
2083}
2084
/// A builder to create a record batch reader with generated data
///
/// This type is meant to be used in a fluent builder style to define the schema and generators
/// for a record batch reader. Columns are added with `col` / `anon_col`, and the data is
/// produced by one of the `into_*` methods.
#[derive(Default)]
pub struct BatchGeneratorBuilder {
    /// Column generators in schema order; a `None` name becomes `field_<index>` when the
    /// reader is built.
    generators: Vec<(Option<String>, Box<dyn ArrayGenerator>)>,
    /// When set, every column is wrapped with random nulls at this probability.
    default_null_probability: Option<f64>,
    /// RNG seed; `None` means `DEFAULT_SEED` is used.
    seed: Option<Seed>,
}
2095
/// How [`BatchGeneratorBuilder::into_reader_bytes`] turns a byte budget into a row count
/// when the budget is not an exact multiple of the row size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingBehavior {
    /// Return an error unless the budget is an exact multiple of the row size.
    ExactOrErr,
    /// Use one more row than fits, so batches are at least the requested size.
    RoundUp,
    /// Use only the whole rows that fit, so batches are at most the requested size.
    RoundDown,
}
2101
2102impl BatchGeneratorBuilder {
2103    /// Create a new BatchGeneratorBuilder with a default random seed
2104    pub fn new() -> Self {
2105        Default::default()
2106    }
2107
2108    /// Create a new BatchGeneratorBuilder with the given seed
2109    pub fn new_with_seed(seed: Seed) -> Self {
2110        Self {
2111            seed: Some(seed),
2112            ..Default::default()
2113        }
2114    }
2115
2116    /// Adds a new column to the generator
2117    ///
2118    /// See [`crate::generator::array`] for methods to create generators
2119    pub fn col(mut self, name: impl Into<String>, genn: Box<dyn ArrayGenerator>) -> Self {
2120        self.generators.push((Some(name.into()), genn));
2121        self
2122    }
2123
2124    /// Adds a new column to the generator with a generated unique name
2125    ///
2126    /// See [`crate::generator::array`] for methods to create generators
2127    pub fn anon_col(mut self, genn: Box<dyn ArrayGenerator>) -> Self {
2128        self.generators.push((None, genn));
2129        self
2130    }
2131
2132    pub fn into_batch_rows(self, batch_size: RowCount) -> Result<RecordBatch, ArrowError> {
2133        let mut reader = self.into_reader_rows(batch_size, BatchCount::from(1));
2134        reader
2135            .next()
2136            .expect("Asked for 1 batch but reader was empty")
2137    }
2138
2139    pub fn into_batch_bytes(
2140        self,
2141        batch_size: ByteCount,
2142        rounding: RoundingBehavior,
2143    ) -> Result<RecordBatch, ArrowError> {
2144        let mut reader = self.into_reader_bytes(batch_size, BatchCount::from(1), rounding)?;
2145        reader
2146            .next()
2147            .expect("Asked for 1 batch but reader was empty")
2148    }
2149
2150    /// Create a RecordBatchReader that generates batches of the given size (in rows)
2151    pub fn into_reader_rows(
2152        self,
2153        batch_size: RowCount,
2154        num_batches: BatchCount,
2155    ) -> impl RecordBatchReader {
2156        FixedSizeBatchGenerator::new(
2157            self.generators,
2158            batch_size,
2159            num_batches,
2160            self.seed,
2161            self.default_null_probability,
2162        )
2163    }
2164
2165    pub fn into_reader_stream(
2166        self,
2167        batch_size: RowCount,
2168        num_batches: BatchCount,
2169    ) -> (
2170        BoxStream<'static, Result<RecordBatch, ArrowError>>,
2171        Arc<Schema>,
2172    ) {
2173        // TODO: this is pretty lazy and could be optimized
2174        let reader = self.into_reader_rows(batch_size, num_batches);
2175        let schema = reader.schema();
2176        let batches = reader.collect::<Vec<_>>();
2177        (futures::stream::iter(batches).boxed(), schema)
2178    }
2179
2180    /// Create a RecordBatchReader that generates batches of the given size (in bytes)
2181    pub fn into_reader_bytes(
2182        self,
2183        batch_size_bytes: ByteCount,
2184        num_batches: BatchCount,
2185        rounding: RoundingBehavior,
2186    ) -> Result<impl RecordBatchReader, ArrowError> {
2187        let bytes_per_row = self
2188            .generators
2189            .iter()
2190            .map(|genn| genn.1.element_size_bytes().map(|byte_count| byte_count.0).ok_or(
2191                        ArrowError::NotYetImplemented("The function into_reader_bytes currently requires each array generator to have a fixed element size".to_string())
2192                )
2193            )
2194            .sum::<Result<u64, ArrowError>>()?;
2195        let mut num_rows = RowCount::from(batch_size_bytes.0 / bytes_per_row);
2196        if !batch_size_bytes.0.is_multiple_of(bytes_per_row) {
2197            match rounding {
2198                RoundingBehavior::ExactOrErr => {
2199                    return Err(ArrowError::NotYetImplemented(format!(
2200                        "Exact rounding requested but not possible.  Batch size requested {}, row size: {}",
2201                        batch_size_bytes.0, bytes_per_row
2202                    )));
2203                }
2204                RoundingBehavior::RoundUp => {
2205                    num_rows = RowCount::from(num_rows.0 + 1);
2206                }
2207                RoundingBehavior::RoundDown => (),
2208            }
2209        }
2210        Ok(self.into_reader_rows(num_rows, num_batches))
2211    }
2212
2213    /// Set the seed for the generator
2214    pub fn with_seed(mut self, seed: Seed) -> Self {
2215        self.seed = Some(seed);
2216        self
2217    }
2218
2219    /// Adds nulls (with the given probability) to all columns
2220    pub fn with_random_nulls(&mut self, default_null_probability: f64) {
2221        self.default_null_probability = Some(default_null_probability);
2222    }
2223}
2224
2225/// Factory for creating a single random array
2226pub struct ArrayGeneratorBuilder {
2227    generator: Box<dyn ArrayGenerator>,
2228    seed: Option<Seed>,
2229}
2230
2231impl ArrayGeneratorBuilder {
2232    fn new(generator: Box<dyn ArrayGenerator>) -> Self {
2233        Self {
2234            generator,
2235            seed: None,
2236        }
2237    }
2238
2239    /// Use the given seed for the generator
2240    pub fn with_seed(mut self, seed: Seed) -> Self {
2241        self.seed = Some(seed);
2242        self
2243    }
2244
2245    /// Generate a single array with the given length
2246    pub fn into_array_rows(
2247        mut self,
2248        length: RowCount,
2249    ) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
2250        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(
2251            self.seed.map(|s| s.0).unwrap_or(DEFAULT_SEED.0),
2252        );
2253        self.generator.generate(length, &mut rng)
2254    }
2255}
2256
/// Milliseconds in one day (24 h * 60 min * 60 s * 1000 ms).
const MS_PER_DAY: i64 = 24 * 60 * 60 * 1_000;
2258
2259pub mod array {
2260
2261    use arrow::datatypes::{Int8Type, Int16Type, Int64Type};
2262    use arrow_array::types::{
2263        Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
2264        DurationNanosecondType, DurationSecondType, Float16Type, Float32Type, Float64Type,
2265        UInt8Type, UInt16Type, UInt32Type, UInt64Type,
2266    };
2267    use arrow_array::{
2268        ArrowNativeTypeOp, BooleanArray, Date32Array, Date64Array, Time32MillisecondArray,
2269        Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
2270        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
2271        TimestampSecondArray,
2272    };
2273    use arrow_schema::{IntervalUnit, TimeUnit};
2274    use chrono::Utc;
2275    use rand::prelude::Distribution;
2276
2277    use super::*;
2278
2279    /// Create a generator of vectors by continuously calling the given generator
2280    ///
2281    /// For example, given a step generator and a dimension of 3 this will generate vectors like
2282    /// [0, 1, 2], [3, 4, 5], [6, 7, 8], ...
2283    pub fn cycle_vec(
2284        generator: Box<dyn ArrayGenerator>,
2285        dimension: Dimension,
2286    ) -> Box<dyn ArrayGenerator> {
2287        Box::new(CycleVectorGenerator::new(generator, dimension))
2288    }
2289
2290    /// Create a generator of list vectors by continuously calling the given generator
2291    ///
2292    /// The lists will have lengths uniformly distributed between `min_list_size` (inclusive) and
2293    /// `max_list_size` (exclusive).
2294    pub fn cycle_vec_var(
2295        generator: Box<dyn ArrayGenerator>,
2296        min_list_size: Dimension,
2297        max_list_size: Dimension,
2298    ) -> Box<dyn ArrayGenerator> {
2299        Box::new(CycleListGenerator::new(
2300            generator,
2301            min_list_size,
2302            max_list_size,
2303        ))
2304    }
2305
2306    /// Create a generator of vectors around unit circle
2307    ///
2308    /// Vectors will be equally spaced around the unit circle so that there are num_steps
2309    /// vectors per circle.
2310    pub fn cycle_unit_circle(num_steps: u32) -> Box<dyn ArrayGenerator> {
2311        Box::new(RadialStepGenerator::new(num_steps))
2312    }
2313
2314    /// Create a generator of vectors by cycling through a given set of vectors
2315    ///
2316    /// Each value will be spaced in slightly away from the previous value on a ball of radius jitter
2317    pub fn jitter_centroids(centroids: Arc<dyn Array>, jitter: f32) -> Box<dyn ArrayGenerator> {
2318        Box::new(JitterCentroidsGenerator::try_new(centroids, jitter).unwrap())
2319    }
2320
2321    /// Create a generator from a vector of values
2322    ///
2323    /// If more rows are requested than the length of values then it will restart
2324    /// from the beginning of the vector.
2325    pub fn cycle<DataType>(values: Vec<DataType::Native>) -> Box<dyn ArrayGenerator>
2326    where
2327        DataType::Native: Copy + 'static,
2328        DataType: ArrowPrimitiveType,
2329        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2330    {
2331        let mut values_idx = 0;
2332        Box::new(
2333            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2334                DataType::DATA_TYPE,
2335                move |_| {
2336                    let y = values[values_idx];
2337                    values_idx = (values_idx + 1) % values.len();
2338                    y
2339                },
2340                1,
2341                DataType::DATA_TYPE
2342                    .primitive_width()
2343                    .map(|width| ByteCount::from(width as u64))
2344                    .expect("Primitive types should have a fixed width"),
2345            ),
2346        )
2347    }
2348
2349    /// Create a generator from a vector of booleans
2350    ///
2351    /// If more rows are requested than the length of values then it will restart from
2352    /// the beginning of the vector
2353    pub fn cycle_bool(values: Vec<bool>) -> Box<dyn ArrayGenerator> {
2354        let mut values_idx = 0;
2355        Box::new(FnGen::<bool, BooleanArray, _>::new_unknown_size(
2356            DataType::Boolean,
2357            move |_| {
2358                let val = values[values_idx];
2359                values_idx = (values_idx + 1) % values.len();
2360                val
2361            },
2362            1,
2363        ))
2364    }
2365
2366    /// Create a generator that starts at 0 and increments by 1 for each element
2367    pub fn step<DataType>() -> Box<dyn ArrayGenerator>
2368    where
2369        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
2370        DataType: ArrowPrimitiveType,
2371        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2372    {
2373        let mut x = DataType::Native::default();
2374        Box::new(
2375            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2376                DataType::DATA_TYPE,
2377                move |_| {
2378                    let y = x;
2379                    x += DataType::Native::ONE;
2380                    y
2381                },
2382                1,
2383                DataType::DATA_TYPE
2384                    .primitive_width()
2385                    .map(|width| ByteCount::from(width as u64))
2386                    .expect("Primitive types should have a fixed width"),
2387            ),
2388        )
2389    }
2390
2391    pub fn blob() -> Box<dyn ArrayGenerator> {
2392        let mut blob_meta = HashMap::new();
2393        blob_meta.insert("lance-encoding:blob".to_string(), "true".to_string());
2394        rand_fixedbin(ByteCount::from(4 * 1024 * 1024), true).with_metadata(blob_meta)
2395    }
2396
2397    /// Create a generator that starts at a given value and increments by a given step for each element
2398    pub fn step_custom<DataType>(
2399        start: DataType::Native,
2400        step: DataType::Native,
2401    ) -> Box<dyn ArrayGenerator>
2402    where
2403        DataType::Native: Copy + Default + std::ops::AddAssign<DataType::Native> + 'static,
2404        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2405        DataType: ArrowPrimitiveType,
2406    {
2407        let mut x = start;
2408        Box::new(
2409            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2410                DataType::DATA_TYPE,
2411                move |_| {
2412                    let y = x;
2413                    x += step;
2414                    y
2415                },
2416                1,
2417                DataType::DATA_TYPE
2418                    .primitive_width()
2419                    .map(|width| ByteCount::from(width as u64))
2420                    .expect("Primitive types should have a fixed width"),
2421            ),
2422        )
2423    }
2424
2425    /// Create a generator that fills each element with the given primitive value
2426    pub fn fill<DataType>(value: DataType::Native) -> Box<dyn ArrayGenerator>
2427    where
2428        DataType::Native: Copy + 'static,
2429        DataType: ArrowPrimitiveType,
2430        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2431    {
2432        Box::new(
2433            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2434                DataType::DATA_TYPE,
2435                move |_| value,
2436                1,
2437                DataType::DATA_TYPE
2438                    .primitive_width()
2439                    .map(|width| ByteCount::from(width as u64))
2440                    .expect("Primitive types should have a fixed width"),
2441            ),
2442        )
2443    }
2444
2445    /// Create a generator that fills each element with the given binary value
2446    pub fn fill_varbin(value: Vec<u8>) -> Box<dyn ArrayGenerator> {
2447        Box::new(FixedBinaryGenerator::<BinaryType>::new(value))
2448    }
2449
2450    /// Create a generator that fills each element with the given string value
2451    pub fn fill_utf8(value: String) -> Box<dyn ArrayGenerator> {
2452        Box::new(FixedBinaryGenerator::<Utf8Type>::new(value.into_bytes()))
2453    }
2454
2455    pub fn cycle_utf8_literals(values: &[&'static str]) -> Box<dyn ArrayGenerator> {
2456        Box::new(CycleBinaryGenerator::<Utf8Type>::from_strings(values))
2457    }
2458
2459    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
2460    pub fn rand<DataType>() -> Box<dyn ArrayGenerator>
2461    where
2462        DataType::Native: Copy + 'static,
2463        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2464        DataType: ArrowPrimitiveType,
2465        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2466    {
2467        Box::new(
2468            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2469                DataType::DATA_TYPE,
2470                move |rng| rng.random(),
2471                1,
2472                DataType::DATA_TYPE
2473                    .primitive_width()
2474                    .map(|width| ByteCount::from(width as u64))
2475                    .expect("Primitive types should have a fixed width"),
2476            ),
2477        )
2478    }
2479
2480    /// Create a generator of primitive values that are randomly sampled from the entire range available for the value
2481    pub fn rand_with_distribution<
2482        DataType,
2483        Dist: rand::distr::Distribution<DataType::Native> + Clone + Send + Sync + 'static,
2484    >(
2485        dist: Dist,
2486    ) -> Box<dyn ArrayGenerator>
2487    where
2488        DataType::Native: Copy + 'static,
2489        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2490        DataType: ArrowPrimitiveType,
2491    {
2492        Box::new(
2493            FnGen::<DataType::Native, PrimitiveArray<DataType>, _>::new_known_size(
2494                DataType::DATA_TYPE,
2495                move |rng| rng.sample(dist.clone()),
2496                1,
2497                DataType::DATA_TYPE
2498                    .primitive_width()
2499                    .map(|width| ByteCount::from(width as u64))
2500                    .expect("Primitive types should have a fixed width"),
2501            ),
2502        )
2503    }
2504
2505    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled primitive values
2506    pub fn rand_vec<DataType>(dimension: Dimension) -> Box<dyn ArrayGenerator>
2507    where
2508        DataType::Native: Copy + 'static,
2509        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2510        DataType: ArrowPrimitiveType,
2511        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2512    {
2513        let underlying = rand::<DataType>();
2514        cycle_vec(underlying, dimension)
2515    }
2516
2517    /// Create a generator of 1d vectors (of a primitive type) consisting of randomly sampled nullable values
2518    pub fn rand_vec_nullable<DataType>(
2519        dimension: Dimension,
2520        null_probability: f64,
2521    ) -> Box<dyn ArrayGenerator>
2522    where
2523        DataType::Native: Copy + 'static,
2524        PrimitiveArray<DataType>: From<Vec<DataType::Native>> + 'static,
2525        DataType: ArrowPrimitiveType,
2526        rand::distr::StandardUniform: rand::distr::Distribution<DataType::Native>,
2527    {
2528        let underlying = rand::<DataType>().with_random_nulls(null_probability);
2529        cycle_vec(underlying, dimension)
2530    }
2531
2532    /// Create a generator of randomly sampled time32 values covering the entire
2533    /// range of 1 day
2534    pub fn rand_time32(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
2535        let start = 0;
2536        let end = match resolution {
2537            TimeUnit::Second => 86_400,
2538            TimeUnit::Millisecond => 86_400_000,
2539            _ => panic!(),
2540        };
2541
2542        let data_type = DataType::Time32(*resolution);
2543        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
2544        let dist = Uniform::new(start, end).unwrap();
2545        let sample_fn = move |rng: &mut _| dist.sample(rng);
2546
2547        match resolution {
2548            TimeUnit::Second => Box::new(FnGen::<i32, Time32SecondArray, _>::new_known_size(
2549                data_type, sample_fn, 1, size,
2550            )),
2551            TimeUnit::Millisecond => {
2552                Box::new(FnGen::<i32, Time32MillisecondArray, _>::new_known_size(
2553                    data_type, sample_fn, 1, size,
2554                ))
2555            }
2556            _ => panic!(),
2557        }
2558    }
2559
2560    /// Create a generator of randomly sampled time64 values covering the entire
2561    /// range of 1 day
2562    pub fn rand_time64(resolution: &TimeUnit) -> Box<dyn ArrayGenerator> {
2563        let start = 0_i64;
2564        let end: i64 = match resolution {
2565            TimeUnit::Microsecond => 86_400_000,
2566            TimeUnit::Nanosecond => 86_400_000_000,
2567            _ => panic!(),
2568        };
2569
2570        let data_type = DataType::Time64(*resolution);
2571        let size = ByteCount::from(data_type.primitive_width().unwrap() as u64);
2572        let dist = Uniform::new(start, end).unwrap();
2573        let sample_fn = move |rng: &mut _| dist.sample(rng);
2574
2575        match resolution {
2576            TimeUnit::Microsecond => {
2577                Box::new(FnGen::<i64, Time64MicrosecondArray, _>::new_known_size(
2578                    data_type, sample_fn, 1, size,
2579                ))
2580            }
2581            TimeUnit::Nanosecond => {
2582                Box::new(FnGen::<i64, Time64NanosecondArray, _>::new_known_size(
2583                    data_type, sample_fn, 1, size,
2584                ))
2585            }
2586            _ => panic!(),
2587        }
2588    }
2589
2590    /// Create a generator of random UUIDs, stored as fixed size binary values
2591    ///
2592    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
2593    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
2594    /// for speed.
2595    pub fn rand_pseudo_uuid() -> Box<dyn ArrayGenerator> {
2596        Box::<PseudoUuidGenerator>::default()
2597    }
2598
2599    /// Create a generator of random UUIDs, stored as 32-character strings (hex encoding
2600    /// of the 16-byte binary value)
2601    ///
2602    /// Note, these are "pseudo UUIDs".  They are 16-byte randomish values but they
2603    /// are not guaranteed to be unique.  We use a simplistic RNG that trades uniqueness
2604    /// for speed.
2605    pub fn rand_pseudo_uuid_hex() -> Box<dyn ArrayGenerator> {
2606        Box::<PseudoUuidHexGenerator>::default()
2607    }
2608
2609    pub fn rand_primitive<T: ArrowPrimitiveType + Send + Sync>(
2610        data_type: DataType,
2611    ) -> Box<dyn ArrayGenerator> {
2612        Box::new(RandomBytesGenerator::<T>::new(data_type))
2613    }
2614
2615    pub fn rand_fsb(size: i32) -> Box<dyn ArrayGenerator> {
2616        Box::new(RandomFixedSizeBinaryGenerator::new(size))
2617    }
2618
2619    pub fn rand_interval(unit: IntervalUnit) -> Box<dyn ArrayGenerator> {
2620        Box::new(RandomIntervalGenerator::new(unit))
2621    }
2622
2623    /// Create a generator of randomly sampled date32 values
2624    ///
2625    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2626    /// is a more common use pattern
2627    pub fn rand_date32() -> Box<dyn ArrayGenerator> {
2628        let now = chrono::Utc::now();
2629        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try days");
2630        rand_date32_in_range(one_year_ago, now)
2631    }
2632
2633    /// Create a generator of randomly sampled date32 values in the given range
2634    pub fn rand_date32_in_range(
2635        start: chrono::DateTime<Utc>,
2636        end: chrono::DateTime<Utc>,
2637    ) -> Box<dyn ArrayGenerator> {
2638        let data_type = DataType::Date32;
2639        let end_ms = end.timestamp_millis();
2640        let end_days = (end_ms / MS_PER_DAY) as i32;
2641        let start_ms = start.timestamp_millis();
2642        let start_days = (start_ms / MS_PER_DAY) as i32;
2643        let dist = Uniform::new(start_days, end_days).unwrap();
2644
2645        Box::new(FnGen::<i32, Date32Array, _>::new_known_size(
2646            data_type,
2647            move |rng| dist.sample(rng),
2648            1,
2649            DataType::Date32
2650                .primitive_width()
2651                .map(|width| ByteCount::from(width as u64))
2652                .expect("Date32 should have a fixed width"),
2653        ))
2654    }
2655
2656    /// Create a generator of randomly sampled date64 values
2657    ///
2658    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2659    /// is a more common use pattern
2660    pub fn rand_date64() -> Box<dyn ArrayGenerator> {
2661        let now = chrono::Utc::now();
2662        let one_year_ago = now - chrono::TimeDelta::try_days(365).expect("TimeDelta try_days");
2663        rand_date64_in_range(one_year_ago, now)
2664    }
2665
2666    /// Create a generator of randomly sampled timestamp values in the given range
2667    ///
2668    /// Currently just samples the entire range of u64 values and casts to timestamp
2669    pub fn rand_timestamp_in_range(
2670        start: chrono::DateTime<Utc>,
2671        end: chrono::DateTime<Utc>,
2672        data_type: &DataType,
2673    ) -> Box<dyn ArrayGenerator> {
2674        let end_ms = end.timestamp_millis();
2675        let start_ms = start.timestamp_millis();
2676        let (start_ticks, end_ticks) = match data_type {
2677            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
2678                (start_ms * 1000 * 1000, end_ms * 1000 * 1000)
2679            }
2680            DataType::Timestamp(TimeUnit::Microsecond, _) => (start_ms * 1000, end_ms * 1000),
2681            DataType::Timestamp(TimeUnit::Millisecond, _) => (start_ms, end_ms),
2682            DataType::Timestamp(TimeUnit::Second, _) => (start.timestamp(), end.timestamp()),
2683            _ => panic!(),
2684        };
2685        let dist = Uniform::new(start_ticks, end_ticks).unwrap();
2686
2687        let data_type = data_type.clone();
2688        let sample_fn = move |rng: &mut _| dist.sample(rng);
2689        let width = data_type
2690            .primitive_width()
2691            .map(|width| ByteCount::from(width as u64))
2692            .unwrap();
2693
2694        match data_type {
2695            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
2696                Box::new(FnGen::<i64, TimestampNanosecondArray, _>::new_known_size(
2697                    data_type, sample_fn, 1, width,
2698                ))
2699            }
2700            DataType::Timestamp(TimeUnit::Microsecond, _) => {
2701                Box::new(FnGen::<i64, TimestampMicrosecondArray, _>::new_known_size(
2702                    data_type, sample_fn, 1, width,
2703                ))
2704            }
2705            DataType::Timestamp(TimeUnit::Millisecond, _) => {
2706                Box::new(FnGen::<i64, TimestampMillisecondArray, _>::new_known_size(
2707                    data_type, sample_fn, 1, width,
2708                ))
2709            }
2710            DataType::Timestamp(TimeUnit::Second, _) => {
2711                Box::new(FnGen::<i64, TimestampSecondArray, _>::new_known_size(
2712                    data_type, sample_fn, 1, width,
2713                ))
2714            }
2715            _ => panic!(),
2716        }
2717    }
2718
2719    pub fn rand_timestamp(data_type: &DataType) -> Box<dyn ArrayGenerator> {
2720        let now = chrono::Utc::now();
2721        let one_year_ago = now - chrono::Duration::try_days(365).unwrap();
2722        rand_timestamp_in_range(one_year_ago, now, data_type)
2723    }
2724
2725    /// Create a generator of randomly sampled date64 values
2726    ///
2727    /// Instead of sampling the entire range, all values will be drawn from the last year as this
2728    /// is a more common use pattern
2729    pub fn rand_date64_in_range(
2730        start: chrono::DateTime<Utc>,
2731        end: chrono::DateTime<Utc>,
2732    ) -> Box<dyn ArrayGenerator> {
2733        let data_type = DataType::Date64;
2734        let end_ms = end.timestamp_millis();
2735        let end_days = end_ms / MS_PER_DAY;
2736        let start_ms = start.timestamp_millis();
2737        let start_days = start_ms / MS_PER_DAY;
2738        let dist = Uniform::new(start_days, end_days).unwrap();
2739
2740        Box::new(FnGen::<i64, Date64Array, _>::new_known_size(
2741            data_type,
2742            move |rng| (dist.sample(rng)) * MS_PER_DAY,
2743            1,
2744            DataType::Date64
2745                .primitive_width()
2746                .map(|width| ByteCount::from(width as u64))
2747                .expect("Date64 should have a fixed width"),
2748        ))
2749    }
2750
2751    /// Create a generator of random binary values where each value has a fixed number of bytes
2752    pub fn rand_fixedbin(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
2753        Box::new(RandomBinaryGenerator::new(
2754            bytes_per_element,
2755            false,
2756            is_large,
2757        ))
2758    }
2759
2760    /// Create a generator of random binary values where each value has a variable number of bytes
2761    ///
2762    /// The number of bytes per element will be randomly sampled from the given (inclusive) range
2763    pub fn rand_varbin(
2764        min_bytes_per_element: ByteCount,
2765        max_bytes_per_element: ByteCount,
2766    ) -> Box<dyn ArrayGenerator> {
2767        Box::new(VariableRandomBinaryGenerator::new(
2768            min_bytes_per_element,
2769            max_bytes_per_element,
2770        ))
2771    }
2772
2773    /// Create a generator of random strings
2774    ///
2775    /// All strings will consist entirely of printable ASCII characters
2776    pub fn rand_utf8(bytes_per_element: ByteCount, is_large: bool) -> Box<dyn ArrayGenerator> {
2777        Box::new(RandomBinaryGenerator::new(
2778            bytes_per_element,
2779            true,
2780            is_large,
2781        ))
2782    }
2783
2784    /// Creates a generator of strings with a prefix and a counter
2785    ///
2786    /// For example, if the prefix is "user_" the strings will be "user_0", "user_1", ...
2787    pub fn utf8_prefix_plus_counter(
2788        prefix: impl Into<String>,
2789        is_large: bool,
2790    ) -> Box<dyn ArrayGenerator> {
2791        Box::new(PrefixPlusCounterGenerator::new(prefix.into(), is_large))
2792    }
2793
2794    pub fn binary_prefix_plus_counter(
2795        prefix: Arc<[u8]>,
2796        is_large: bool,
2797    ) -> Box<dyn ArrayGenerator> {
2798        Box::new(BinaryPrefixPlusCounterGenerator::new(prefix, is_large))
2799    }
2800
2801    /// Create a random generator of boolean values
2802    pub fn rand_boolean() -> Box<dyn ArrayGenerator> {
2803        Box::<RandomBooleanGenerator>::default()
2804    }
2805
2806    /// Create a generator of random sentences
2807    ///
2808    /// Generates strings containing between min_words and max_words random English words joined by spaces
2809    pub fn random_sentence(
2810        min_words: usize,
2811        max_words: usize,
2812        is_large: bool,
2813    ) -> Box<dyn ArrayGenerator> {
2814        Box::new(RandomSentenceGenerator::new(min_words, max_words, is_large))
2815    }
2816
2817    /// Create a generator of random words (one word per row)
2818    ///
2819    /// Generates strings containing a single random English word per row
2820    pub fn random_word(is_large: bool) -> Box<dyn ArrayGenerator> {
2821        Box::new(RandomWordGenerator::new(is_large))
2822    }
2823
2824    pub fn rand_list(item_type: &DataType, is_large: bool) -> Box<dyn ArrayGenerator> {
2825        let child_gen = rand_type(item_type);
2826        Box::new(RandomListGenerator::new(child_gen, is_large))
2827    }
2828
2829    pub fn rand_list_any(
2830        item_gen: Box<dyn ArrayGenerator>,
2831        is_large: bool,
2832    ) -> Box<dyn ArrayGenerator> {
2833        Box::new(RandomListGenerator::new(item_gen, is_large))
2834    }
2835
2836    /// Generates random map arrays where each map has 0-4 entries.
2837    pub fn rand_map(key_type: &DataType, value_type: &DataType) -> Box<dyn ArrayGenerator> {
2838        let keys_gen = rand_type(key_type);
2839        let values_gen = rand_type(value_type);
2840        Box::new(RandomMapGenerator::new(keys_gen, values_gen))
2841    }
2842
2843    pub fn rand_struct(fields: Fields) -> Box<dyn ArrayGenerator> {
2844        let child_gens = fields
2845            .iter()
2846            .map(|f| rand_type(f.data_type()))
2847            .collect::<Vec<_>>();
2848        Box::new(RandomStructGenerator::new(fields, child_gens))
2849    }
2850
2851    pub fn null_type() -> Box<dyn ArrayGenerator> {
2852        Box::new(NullArrayGenerator {})
2853    }
2854
2855    /// Create a generator of random values
2856    pub fn rand_type(data_type: &DataType) -> Box<dyn ArrayGenerator> {
2857        match data_type {
2858            DataType::Boolean => rand_boolean(),
2859            DataType::Int8 => rand::<Int8Type>(),
2860            DataType::Int16 => rand::<Int16Type>(),
2861            DataType::Int32 => rand::<Int32Type>(),
2862            DataType::Int64 => rand::<Int64Type>(),
2863            DataType::UInt8 => rand::<UInt8Type>(),
2864            DataType::UInt16 => rand::<UInt16Type>(),
2865            DataType::UInt32 => rand::<UInt32Type>(),
2866            DataType::UInt64 => rand::<UInt64Type>(),
2867            DataType::Float16 => rand_primitive::<Float16Type>(data_type.clone()),
2868            DataType::Float32 => rand::<Float32Type>(),
2869            DataType::Float64 => rand::<Float64Type>(),
2870            DataType::Decimal128(_, _) => rand_primitive::<Decimal128Type>(data_type.clone()),
2871            DataType::Decimal256(_, _) => rand_primitive::<Decimal256Type>(data_type.clone()),
2872            DataType::Utf8 => rand_utf8(ByteCount::from(12), false),
2873            DataType::LargeUtf8 => rand_utf8(ByteCount::from(12), true),
2874            DataType::Binary => rand_fixedbin(ByteCount::from(12), false),
2875            DataType::LargeBinary => rand_fixedbin(ByteCount::from(12), true),
2876            DataType::Dictionary(key_type, value_type) => {
2877                dict_type(rand_type(value_type), key_type)
2878            }
2879            DataType::FixedSizeList(child, dimension) => cycle_vec(
2880                rand_type(child.data_type()),
2881                Dimension::from(*dimension as u32),
2882            ),
2883            DataType::FixedSizeBinary(size) => rand_fsb(*size),
2884            DataType::List(child) => rand_list(child.data_type(), false),
2885            DataType::LargeList(child) => rand_list(child.data_type(), true),
2886            DataType::Map(entries_field, _) => {
2887                let DataType::Struct(fields) = entries_field.data_type() else {
2888                    panic!("Map entries field must be a struct");
2889                };
2890                let key_type = fields[0].data_type();
2891                let value_type = fields[1].data_type();
2892                rand_map(key_type, value_type)
2893            }
2894            DataType::Duration(unit) => match unit {
2895                TimeUnit::Second => rand::<DurationSecondType>(),
2896                TimeUnit::Millisecond => rand::<DurationMillisecondType>(),
2897                TimeUnit::Microsecond => rand::<DurationMicrosecondType>(),
2898                TimeUnit::Nanosecond => rand::<DurationNanosecondType>(),
2899            },
2900            DataType::Interval(unit) => rand_interval(*unit),
2901            DataType::Date32 => rand_date32(),
2902            DataType::Date64 => rand_date64(),
2903            DataType::Time32(resolution) => rand_time32(resolution),
2904            DataType::Time64(resolution) => rand_time64(resolution),
2905            DataType::Timestamp(_, _) => rand_timestamp(data_type),
2906            DataType::Struct(fields) => rand_struct(fields.clone()),
2907            DataType::Null => null_type(),
2908            _ => unimplemented!("random generation of {}", data_type),
2909        }
2910    }
2911
2912    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2913    ///
2914    /// Note that this may not be very realistic if the underlying generator is something like a random
2915    /// generator since most of the underlying values will be unique and the common case for dictionary
2916    /// encoding is when there is a small set of possible values.
2917    pub fn dict<K: ArrowDictionaryKeyType + Send + Sync>(
2918        generator: Box<dyn ArrayGenerator>,
2919    ) -> Box<dyn ArrayGenerator> {
2920        Box::new(DictionaryGenerator::<K>::new(generator))
2921    }
2922
2923    /// Encodes arrays generated by the underlying generator as dictionaries with the given key type
2924    pub fn dict_type(
2925        generator: Box<dyn ArrayGenerator>,
2926        key_type: &DataType,
2927    ) -> Box<dyn ArrayGenerator> {
2928        match key_type {
2929            DataType::Int8 => dict::<Int8Type>(generator),
2930            DataType::Int16 => dict::<Int16Type>(generator),
2931            DataType::Int32 => dict::<Int32Type>(generator),
2932            DataType::Int64 => dict::<Int64Type>(generator),
2933            DataType::UInt8 => dict::<UInt8Type>(generator),
2934            DataType::UInt16 => dict::<UInt16Type>(generator),
2935            DataType::UInt32 => dict::<UInt32Type>(generator),
2936            DataType::UInt64 => dict::<UInt64Type>(generator),
2937            _ => unimplemented!(),
2938        }
2939    }
2940
2941    /// Wraps a generator to produce low-cardinality data.
2942    ///
2943    /// Generates `cardinality` unique values on first call, then randomly
2944    /// selects from them for all subsequent rows.
2945    pub fn low_cardinality(
2946        generator: Box<dyn ArrayGenerator>,
2947        cardinality: usize,
2948    ) -> Box<dyn ArrayGenerator> {
2949        Box::new(LowCardinalityGenerator::new(generator, cardinality))
2950    }
2951}
2952
2953/// Create a BatchGeneratorBuilder to start generating batch data
2954pub fn gen_batch() -> BatchGeneratorBuilder {
2955    BatchGeneratorBuilder::default()
2956}
2957
2958/// Create an ArrayGeneratorBuilder to start generating array data
2959pub fn gen_array(genn: Box<dyn ArrayGenerator>) -> ArrayGeneratorBuilder {
2960    ArrayGeneratorBuilder::new(genn)
2961}
2962
2963/// Metadata key to specify content type for string generation.
2964/// Set to "sentence" to use the sentence generator with Zipf distribution.
2965pub const CONTENT_TYPE_KEY: &str = "lance-datagen:content-type";
2966
2967/// Metadata key to specify cardinality for low-cardinality data generation.
2968/// Set to a numeric string (e.g., "100") to limit unique values.
2969pub const CARDINALITY_KEY: &str = "lance-datagen:cardinality";
2970
2971/// Create a generator for a field, checking metadata for content type hints.
2972///
2973/// Supported metadata keys:
2974/// - `lance-datagen:content-type`: Set to "sentence" for Utf8/LargeUtf8 fields
2975///   to use the sentence generator with Zipf distribution.
2976/// - `lance-datagen:cardinality`: Set to a number to limit unique values.
2977///   The generator will produce only that many unique values and randomly
2978///   select from them.
2979pub fn rand_field(field: &Field) -> Box<dyn ArrayGenerator> {
2980    let mut generator = if let Some(content_type) = field.metadata().get(CONTENT_TYPE_KEY) {
2981        match (content_type.as_str(), field.data_type()) {
2982            ("sentence", DataType::Utf8) => array::random_sentence(1, 10, false),
2983            ("sentence", DataType::LargeUtf8) => array::random_sentence(1, 10, true),
2984            _ => array::rand_type(field.data_type()),
2985        }
2986    } else {
2987        array::rand_type(field.data_type())
2988    };
2989
2990    if let Some(cardinality_str) = field.metadata().get(CARDINALITY_KEY)
2991        && let Ok(cardinality) = cardinality_str.parse::<usize>()
2992        && cardinality > 0
2993    {
2994        generator = array::low_cardinality(generator, cardinality);
2995    }
2996
2997    generator
2998}
2999
3000/// Create a BatchGeneratorBuilder with the given schema
3001///
3002/// You can add more columns or convert this into a reader immediately.
3003///
3004/// Supported field metadata:
3005/// - `lance-datagen:content-type` = `"sentence"`: Use sentence generator with
3006///   Zipf distribution for more realistic text (Utf8/LargeUtf8 only).
3007/// - `lance-datagen:cardinality` = `"<number>"`: Limit to N unique values.
3008pub fn rand(schema: &Schema) -> BatchGeneratorBuilder {
3009    let mut builder = BatchGeneratorBuilder::default();
3010    for field in schema.fields() {
3011        builder = builder.col(field.name(), rand_field(field));
3012    }
3013    builder
3014}
3015
3016#[cfg(test)]
3017mod tests {
3018
3019    use arrow::datatypes::{Float32Type, Int8Type, Int16Type, UInt32Type};
3020    use arrow_array::{BooleanArray, Float32Array, Int8Array, Int16Array, Int32Array, UInt32Array};
3021
3022    use super::*;
3023
3024    #[test]
3025    fn test_step() {
3026        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3027        let mut genn = array::step::<Int32Type>();
3028        assert_eq!(
3029            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
3030            Int32Array::from_iter([0, 1, 2, 3, 4])
3031        );
3032        assert_eq!(
3033            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
3034            Int32Array::from_iter([5, 6, 7, 8, 9])
3035        );
3036
3037        let mut genn = array::step::<Int8Type>();
3038        assert_eq!(
3039            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3040            Int8Array::from_iter([0, 1, 2])
3041        );
3042
3043        let mut genn = array::step::<Float32Type>();
3044        assert_eq!(
3045            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3046            Float32Array::from_iter([0.0, 1.0, 2.0])
3047        );
3048
3049        let mut genn = array::step_custom::<Int16Type>(4, 8);
3050        assert_eq!(
3051            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3052            Int16Array::from_iter([4, 12, 20])
3053        );
3054        assert_eq!(
3055            *genn.generate(RowCount::from(2), &mut rng).unwrap(),
3056            Int16Array::from_iter([28, 36])
3057        );
3058    }
3059
3060    #[test]
3061    fn test_cycle() {
3062        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3063        let mut genn = array::cycle::<Int32Type>(vec![1, 2, 3]);
3064        assert_eq!(
3065            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
3066            Int32Array::from_iter([1, 2, 3, 1, 2])
3067        );
3068
3069        let mut genn = array::cycle_utf8_literals(&["abc", "def", "xyz"]);
3070        assert_eq!(
3071            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
3072            StringArray::from_iter_values(["abc", "def", "xyz", "abc", "def"])
3073        );
3074        assert_eq!(
3075            *genn.generate(RowCount::from(1), &mut rng).unwrap(),
3076            StringArray::from_iter_values(["xyz"])
3077        );
3078
3079        let mut genn = array::cycle_bool(vec![false, false, true]);
3080        assert_eq!(
3081            *genn.generate(RowCount::from(5), &mut rng).unwrap(),
3082            BooleanArray::from_iter(vec![false, false, true, false, false].into_iter().map(Some))
3083        );
3084        assert_eq!(
3085            *genn.generate(RowCount::from(1), &mut rng).unwrap(),
3086            BooleanArray::from_iter(vec![Some(true)])
3087        )
3088    }
3089
3090    #[test]
3091    fn test_fill() {
3092        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3093        let mut genn = array::fill::<Int32Type>(42);
3094        assert_eq!(
3095            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3096            Int32Array::from_iter([42, 42, 42])
3097        );
3098        assert_eq!(
3099            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3100            Int32Array::from_iter([42, 42, 42])
3101        );
3102
3103        let mut genn = array::fill_varbin(vec![0, 1, 2]);
3104        assert_eq!(
3105            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3106            arrow_array::BinaryArray::from_iter_values([
3107                "\x00\x01\x02",
3108                "\x00\x01\x02",
3109                "\x00\x01\x02"
3110            ])
3111        );
3112
3113        let mut genn = array::fill_utf8("xyz".to_string());
3114        assert_eq!(
3115            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3116            arrow_array::StringArray::from_iter_values(["xyz", "xyz", "xyz"])
3117        );
3118    }
3119
3120    #[test]
3121    fn test_utf8_prefix_plus_counter() {
3122        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3123        let mut genn = array::utf8_prefix_plus_counter("user_", false);
3124        assert_eq!(
3125            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3126            arrow_array::StringArray::from_iter_values(["user_0", "user_1", "user_2"])
3127        );
3128
3129        let mut genn = array::utf8_prefix_plus_counter("user_", true);
3130        assert_eq!(
3131            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3132            arrow_array::LargeStringArray::from_iter_values(["user_0", "user_1", "user_2"])
3133        );
3134    }
3135
3136    #[test]
3137    fn test_rng() {
3138        // Note: these tests are heavily dependent on the default seed.
3139        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3140        let mut genn = array::rand::<Int32Type>();
3141        assert_eq!(
3142            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3143            Int32Array::from_iter([-797553329, 1369325940, -69174021])
3144        );
3145
3146        let mut genn = array::rand_fixedbin(ByteCount::from(3), false);
3147        assert_eq!(
3148            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3149            arrow_array::BinaryArray::from_iter_values([
3150                [184, 53, 216],
3151                [12, 96, 159],
3152                [125, 179, 56]
3153            ])
3154        );
3155
3156        let mut genn = array::rand_utf8(ByteCount::from(3), false);
3157        assert_eq!(
3158            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3159            arrow_array::StringArray::from_iter_values([">@p", "n `", "NWa"])
3160        );
3161
3162        let mut genn = array::random_sentence(1, 5, false);
3163        let words = genn.generate(RowCount::from(10), &mut rng).unwrap();
3164        assert_eq!(words.data_type(), &DataType::Utf8);
3165        let words_array = words.as_any().downcast_ref::<StringArray>().unwrap();
3166        // Verify each string contains 1-5 words
3167        for i in 0..10 {
3168            let sentence = words_array.value(i);
3169            let word_count = sentence.split_whitespace().count();
3170            assert!((1..=5).contains(&word_count));
3171        }
3172
3173        let mut genn = array::rand_date32();
3174        let days_32 = genn.generate(RowCount::from(3), &mut rng).unwrap();
3175        assert_eq!(days_32.data_type(), &DataType::Date32);
3176
3177        let mut genn = array::rand_date64();
3178        let days_64 = genn.generate(RowCount::from(3), &mut rng).unwrap();
3179        assert_eq!(days_64.data_type(), &DataType::Date64);
3180
3181        let mut genn = array::rand_boolean();
3182        let bools = genn.generate(RowCount::from(1024), &mut rng).unwrap();
3183        assert_eq!(bools.data_type(), &DataType::Boolean);
3184        let bools = bools.as_any().downcast_ref::<BooleanArray>().unwrap();
3185        // Sanity check to ensure we're getting at least some rng
3186        assert!(bools.false_count() > 100);
3187        assert!(bools.true_count() > 100);
3188
3189        let mut genn = array::rand_varbin(ByteCount::from(2), ByteCount::from(4));
3190        assert_eq!(
3191            *genn.generate(RowCount::from(3), &mut rng).unwrap(),
3192            arrow_array::BinaryArray::from_iter_values([
3193                vec![174, 178],
3194                vec![64, 122, 207, 248],
3195                vec![124, 3, 58]
3196            ])
3197        );
3198    }
3199
3200    #[test]
3201    fn test_rng_list() {
3202        // Note: these tests are heavily dependent on the default seed.
3203        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3204        let mut genn = array::rand_list(&DataType::Int32, false);
3205        let arr = genn.generate(RowCount::from(100), &mut rng).unwrap();
3206        // Make sure we can generate empty lists (note, test is dependent on seed)
3207        let arr = arr.as_list::<i32>();
3208        assert!(arr.iter().any(|l| l.unwrap().is_empty()));
3209        // Shouldn't generate any giant lists (don't kill performance in normal datagen)
3210        assert!(arr.iter().any(|l| l.unwrap().len() < 11));
3211    }
3212
3213    #[test]
3214    fn test_rng_distribution() {
3215        // Sanity test to make sure we our RNG is giving us well distributed values
3216        // We generates some 4-byte integers, histogram them into 8 buckets, and make
3217        // sure each bucket has a good # of values
3218        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3219        let mut genn = array::rand::<UInt32Type>();
3220        for _ in 0..10 {
3221            let arr = genn.generate(RowCount::from(10000), &mut rng).unwrap();
3222            let int_arr = arr.as_any().downcast_ref::<UInt32Array>().unwrap();
3223            let mut buckets = vec![0_u32; 256];
3224            for val in int_arr.values() {
3225                buckets[(*val >> 24) as usize] += 1;
3226            }
3227            for bucket in buckets {
3228                // Perfectly even distribution would have 10000 / 256 values (~40) per bucket
3229                // We test for 15 which should be "good enough" and statistically unlikely to fail
3230                assert!(bucket > 15);
3231            }
3232        }
3233    }
3234
3235    #[test]
3236    fn test_nulls() {
3237        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3238        let mut genn = array::rand::<Int32Type>().with_random_nulls(0.3);
3239
3240        let arr = genn.generate(RowCount::from(1000), &mut rng).unwrap();
3241
3242        // This assert depends on the default seed
3243        assert_eq!(arr.null_count(), 297);
3244
3245        for len in 0..100 {
3246            let arr = genn.generate(RowCount::from(len), &mut rng).unwrap();
3247            // Make sure the null count we came up with matches the actual # of unset bits
3248            assert_eq!(
3249                arr.null_count(),
3250                arr.nulls()
3251                    .map(|nulls| (len as usize)
3252                        - nulls.buffer().count_set_bits_offset(0, len as usize))
3253                    .unwrap_or(0)
3254            );
3255        }
3256
3257        let mut genn = array::rand::<Int32Type>().with_random_nulls(0.0);
3258        let arr = genn.generate(RowCount::from(10), &mut rng).unwrap();
3259
3260        assert_eq!(arr.null_count(), 0);
3261
3262        let mut genn = array::rand::<Int32Type>().with_random_nulls(1.0);
3263        let arr = genn.generate(RowCount::from(10), &mut rng).unwrap();
3264
3265        assert_eq!(arr.null_count(), 10);
3266        assert!((0..10).all(|idx| arr.is_null(idx)));
3267
3268        let mut genn = array::rand::<Int32Type>().with_nulls(&[false, false, true]);
3269        let arr = genn.generate(RowCount::from(7), &mut rng).unwrap();
3270        assert!((0..2).all(|idx| arr.is_valid(idx)));
3271        assert!(arr.is_null(2));
3272        assert!((3..5).all(|idx| arr.is_valid(idx)));
3273        assert!(arr.is_null(5));
3274        assert!(arr.is_valid(6));
3275    }
3276
3277    #[test]
3278    fn test_unit_circle() {
3279        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3280        let mut genn = array::cycle_unit_circle(4);
3281        let arr = genn.generate(RowCount::from(6), &mut rng).unwrap();
3282
3283        let arr_values = arr
3284            .as_fixed_size_list()
3285            .values()
3286            .as_primitive::<Float32Type>()
3287            .values()
3288            .to_vec();
3289        assert_eq!(arr_values.len(), 12);
3290        let expected_values = [1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 1.0];
3291        for (actual, expected) in arr_values.iter().zip(expected_values.iter()) {
3292            assert!((actual - expected).abs() < 0.0001);
3293        }
3294    }
3295
3296    #[test]
3297    fn test_jitter_centroids() {
3298        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(DEFAULT_SEED.0);
3299        let mut centroids_gen = array::cycle_unit_circle(4);
3300        let centroids = centroids_gen.generate(RowCount::from(4), &mut rng).unwrap();
3301
3302        let centroid_values = centroids
3303            .as_fixed_size_list()
3304            .values()
3305            .as_primitive::<Float32Type>()
3306            .values()
3307            .to_vec();
3308
3309        let mut jitter_jen = array::jitter_centroids(centroids, 0.001);
3310        let jittered = jitter_jen.generate(RowCount::from(100), &mut rng).unwrap();
3311
3312        let values = jittered
3313            .as_fixed_size_list()
3314            .values()
3315            .as_primitive::<Float32Type>()
3316            .values()
3317            .to_vec();
3318
3319        for i in 0..100 {
3320            let centroid = i % 4;
3321            let centroid_x = centroid_values[centroid * 2];
3322            let centroid_y = centroid_values[centroid * 2 + 1];
3323            let value_x = values[i * 2];
3324            let value_y = values[i * 2 + 1];
3325
3326            let l2_dist = ((value_x - centroid_x).powi(2) + (value_y - centroid_y).powi(2)).sqrt();
3327            assert!(l2_dist < 0.001001);
3328            assert!(l2_dist > 0.000999);
3329        }
3330    }
3331
3332    #[test]
3333    fn test_rand_schema() {
3334        let schema = Schema::new(vec![
3335            Field::new("a", DataType::Int32, true),
3336            Field::new("b", DataType::Utf8, true),
3337            Field::new("c", DataType::Float32, true),
3338            Field::new("d", DataType::Int32, true),
3339            Field::new("e", DataType::Int32, true),
3340        ]);
3341        let rbr = rand(&schema)
3342            .into_reader_bytes(
3343                ByteCount::from(1024 * 1024),
3344                BatchCount::from(8),
3345                RoundingBehavior::ExactOrErr,
3346            )
3347            .unwrap();
3348        assert_eq!(*rbr.schema(), schema);
3349
3350        let batches = rbr.map(|val| val.unwrap()).collect::<Vec<_>>();
3351        assert_eq!(batches.len(), 8);
3352
3353        for batch in batches {
3354            assert_eq!(batch.num_rows(), 1024 * 1024 / 32);
3355            assert_eq!(batch.num_columns(), 5);
3356        }
3357    }
3358}