Skip to main content

lance_testing/
datagen.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Data generation utilities for unit tests
5
6use std::collections::HashSet;
7use std::sync::Arc;
8use std::{iter::repeat_with, ops::Range};
9
10use arrow_array::types::ArrowPrimitiveType;
11use arrow_array::{
12    Float32Array, Int8Array, Int32Array, PrimitiveArray, RecordBatch, RecordBatchIterator,
13    RecordBatchReader,
14};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, fixed_size_list_type};
17use num_traits::{FromPrimitive, real::Real};
18use rand::distr::uniform::SampleUniform;
19use rand::{
20    Rng, SeedableRng, distr::Uniform, prelude::Distribution, rngs::StdRng, seq::SliceRandom,
21};
22
23pub trait ArrayGenerator {
24    fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array>;
25    fn data_type(&self) -> &DataType;
26    fn name(&self) -> Option<&str>;
27}
28
/// Generator state for an arithmetic sequence of `i32` values.
///
/// A fresh instance starts at `0`, advances by `1`, and has no column name.
pub struct IncrementingInt32 {
    /// Optional column name.
    name: Option<String>,
    /// The next value that will be emitted.
    current: i32,
    /// Amount added to `current` after each emitted value.
    step: i32,
}

impl Default for IncrementingInt32 {
    /// Same as [`IncrementingInt32::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl IncrementingInt32 {
    /// Creates an unnamed generator that starts at `0` and advances by `1`.
    pub fn new() -> Self {
        Self {
            name: None,
            current: 0,
            step: 1,
        }
    }

    /// Sets the first value of the sequence.
    pub fn start(self, start: i32) -> Self {
        Self {
            current: start,
            ..self
        }
    }

    /// Sets the difference between consecutive values.
    pub fn step(self, step: i32) -> Self {
        Self { step, ..self }
    }

    /// Sets the column name.
    pub fn named(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
65
66impl ArrayGenerator for IncrementingInt32 {
67    fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
68        let mut values = Vec::with_capacity(length);
69        for _ in 0..length {
70            values.push(self.current);
71            self.current += self.step;
72        }
73        Arc::new(Int32Array::from(values))
74    }
75
76    fn name(&self) -> Option<&str> {
77        self.name.as_deref()
78    }
79
80    fn data_type(&self) -> &DataType {
81        &DataType::Int32
82    }
83}
84
85pub struct RandomVector {
86    name: Option<String>,
87    vec_width: i32,
88    data_type: DataType,
89}
90
91impl Default for RandomVector {
92    fn default() -> Self {
93        Self {
94            name: None,
95            vec_width: 4,
96            data_type: fixed_size_list_type(4, DataType::Float32),
97        }
98    }
99}
100
101impl RandomVector {
102    pub fn new() -> Self {
103        Default::default()
104    }
105
106    pub fn vec_width(mut self, vec_width: i32) -> Self {
107        self.vec_width = vec_width;
108        self.data_type = fixed_size_list_type(self.vec_width, DataType::Float32);
109        self
110    }
111
112    pub fn named(mut self, name: String) -> Self {
113        self.name = Some(name);
114        self
115    }
116}
117
118impl ArrayGenerator for RandomVector {
119    fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
120        let values = generate_random_array(length * (self.vec_width as usize));
121        Arc::new(
122            <arrow_array::FixedSizeListArray as FixedSizeListArrayExt>::try_new_from_values(
123                values,
124                self.vec_width,
125            )
126            .expect("Create fixed size list"),
127        )
128    }
129
130    fn name(&self) -> Option<&str> {
131        self.name.as_deref()
132    }
133
134    fn data_type(&self) -> &DataType {
135        &self.data_type
136    }
137}
138
139#[derive(Default)]
140pub struct BatchGenerator {
141    generators: Vec<Box<dyn ArrayGenerator>>,
142}
143
144impl BatchGenerator {
145    pub fn new() -> Self {
146        Default::default()
147    }
148
149    pub fn col(mut self, genn: Box<dyn ArrayGenerator>) -> Self {
150        self.generators.push(genn);
151        self
152    }
153
154    fn gen_batch(&mut self, num_rows: u32) -> RecordBatch {
155        let mut fields = Vec::with_capacity(self.generators.len());
156        let mut arrays = Vec::with_capacity(self.generators.len());
157        for (field_index, genn) in self.generators.iter_mut().enumerate() {
158            let arr = genn.generate(num_rows as usize);
159            let default_name = format!("field_{}", field_index);
160            let name = genn.name().unwrap_or(&default_name);
161            fields.push(Field::new(name, arr.data_type().clone(), true));
162            arrays.push(arr);
163        }
164        let schema = Arc::new(ArrowSchema::new(fields));
165        RecordBatch::try_new(schema, arrays).unwrap()
166    }
167
168    pub fn batch(&mut self, num_rows: i32) -> impl RecordBatchReader + use<> {
169        let batch = self.gen_batch(num_rows as u32);
170        let schema = batch.schema();
171        RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema)
172    }
173
174    pub fn batches(
175        &mut self,
176        num_batches: u32,
177        rows_per_batch: u32,
178    ) -> impl RecordBatchReader + use<> {
179        let batches = (0..num_batches)
180            .map(|_| self.gen_batch(rows_per_batch))
181            .collect::<Vec<_>>();
182        let schema = batches[0].schema();
183        RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
184    }
185}
186
187/// Returns a batch of data that has a column that can be used to create an ANN index
188///
189/// The indexable column will be named "indexable"
190/// The batch will not be empty
191/// There will only be one batch
192///
193/// There are no other assumptions it is safe to make about the returned reader
194pub fn some_indexable_batch() -> impl RecordBatchReader {
195    let x = Box::new(RandomVector::new().named("indexable".to_string()));
196    BatchGenerator::new().col(x).batch(512)
197}
198
199/// Returns a non-empty batch of data
200///
201/// The batch will not be empty
202/// There will only be one batch
203///
204/// There are no other assumptions it is safe to make about the returned reader
205pub fn some_batch() -> impl RecordBatchReader {
206    some_indexable_batch()
207}
208
209/// Create a random float32 array.
210pub fn generate_random_array_with_seed<T: ArrowFloatType>(n: usize, seed: [u8; 32]) -> T::ArrayType
211where
212    T::Native: Real + FromPrimitive,
213{
214    let mut rng = StdRng::from_seed(seed);
215
216    <T::ArrayType as lance_arrow::FloatArray<T>>::from_iter_values(
217        repeat_with(|| T::Native::from_f32(rng.random::<f32>()).unwrap()).take(n),
218    )
219}
220
221/// Create a random float32 array where each element is uniformly
222/// distributed between [0..1]
223pub fn generate_random_array(n: usize) -> Float32Array {
224    let mut rng = rand::rng();
225    Float32Array::from_iter_values(repeat_with(|| rng.random::<f32>()).take(n))
226}
227
228/// Create a random float32 array where each element is uniformly
229/// distributed between [0..1]
230pub fn generate_random_int8_array(n: usize) -> Int8Array {
231    let mut rng = rand::rng();
232    Int8Array::from_iter_values(repeat_with(|| rng.random::<i8>()).take(n))
233}
234
235/// Create a random primitive array where each element is uniformly distributed a
236/// given range.
237pub fn generate_random_array_with_range<T: ArrowPrimitiveType>(
238    n: usize,
239    range: Range<T::Native>,
240) -> PrimitiveArray<T>
241where
242    T::Native: SampleUniform,
243{
244    let mut rng = StdRng::from_seed([13; 32]);
245    let distribution = Uniform::new(range.start, range.end).unwrap();
246    PrimitiveArray::<T>::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
247}
248
249/// Create a random float32 array where each element is uniformly
250/// distributed across the given range
251pub fn generate_scaled_random_array(n: usize, min: f32, max: f32) -> Float32Array {
252    let mut rng = rand::rng();
253    let distribution = Uniform::new(min, max).unwrap();
254    Float32Array::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
255}
256
257pub fn sample_indices(range: Range<usize>, num_picks: u32) -> Vec<usize> {
258    let mut rng = rand::rng();
259    let dist = Uniform::new(range.start, range.end).unwrap();
260    let ratio = num_picks as f32 / range.len() as f32;
261    if ratio < 0.1_f32 && num_picks > 1000 {
262        // We want to pick a large number of values from a big range.  Better to
263        // use a set and potential retries
264        let mut picked = HashSet::<usize>::with_capacity(num_picks as usize);
265        let mut ordered_picked = Vec::with_capacity(num_picks as usize);
266        while picked.len() < num_picks as usize {
267            let val = dist.sample(&mut rng);
268            if picked.insert(val) {
269                ordered_picked.push(val);
270            }
271        }
272        ordered_picked
273    } else {
274        // We want to pick most of the range, or a small number of values.  Go ahead
275        // and just materialize the range and shuffle
276        let mut values = Vec::from_iter(range);
277        values.partial_shuffle(&mut rng, num_picks as usize);
278        values.truncate(num_picks as usize);
279        values
280    }
281}
282
283pub fn sample_without_replacement<T: Copy>(choices: &[T], num_picks: u32) -> Vec<T> {
284    let mut rng = rand::rng();
285    let mut shuffled = Vec::from(choices);
286    shuffled.partial_shuffle(&mut rng, num_picks as usize);
287    shuffled.truncate(num_picks as usize);
288    shuffled
289}