1use std::collections::HashSet;
7use std::sync::Arc;
8use std::{iter::repeat_with, ops::Range};
9
10use arrow_array::types::ArrowPrimitiveType;
11use arrow_array::{
12 Float32Array, Int32Array, PrimitiveArray, RecordBatch, RecordBatchIterator, RecordBatchReader,
13};
14use arrow_schema::{DataType, Field, Schema as ArrowSchema};
15use lance_arrow::{fixed_size_list_type, ArrowFloatType, FixedSizeListArrayExt};
16use num_traits::{real::Real, FromPrimitive};
17use rand::distributions::uniform::SampleUniform;
18use rand::{
19 distributions::Uniform, prelude::Distribution, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng,
20};
21
22pub trait ArrayGenerator {
23 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array>;
24 fn data_type(&self) -> &DataType;
25 fn name(&self) -> Option<&str>;
26}
27
/// State for an arithmetic sequence of `i32` values.
///
/// Configure it with the chained builder methods [`IncrementingInt32::start`],
/// [`IncrementingInt32::step`] and [`IncrementingInt32::named`].
pub struct IncrementingInt32 {
    /// Optional name attached to the generator.
    name: Option<String>,
    /// The value the sequence is currently positioned at.
    current: i32,
    /// The increment between consecutive values.
    step: i32,
}

impl Default for IncrementingInt32 {
    /// An unnamed sequence positioned at `0` with an increment of `1`.
    fn default() -> Self {
        IncrementingInt32 {
            name: None,
            current: 0,
            step: 1,
        }
    }
}

impl IncrementingInt32 {
    /// Creates a generator with the [`Default`] configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the generator repositioned so that its current value is `start`.
    pub fn start(self, start: i32) -> Self {
        Self {
            current: start,
            ..self
        }
    }

    /// Returns the generator with its increment replaced by `step`.
    pub fn step(self, step: i32) -> Self {
        Self { step, ..self }
    }

    /// Returns the generator with its name set to `name`.
    pub fn named(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
64
65impl ArrayGenerator for IncrementingInt32 {
66 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
67 let mut values = Vec::with_capacity(length);
68 for _ in 0..length {
69 values.push(self.current);
70 self.current += self.step;
71 }
72 Arc::new(Int32Array::from(values))
73 }
74
75 fn name(&self) -> Option<&str> {
76 self.name.as_deref()
77 }
78
79 fn data_type(&self) -> &DataType {
80 &DataType::Int32
81 }
82}
83
84pub struct RandomVector {
85 name: Option<String>,
86 vec_width: i32,
87 data_type: DataType,
88}
89
90impl Default for RandomVector {
91 fn default() -> Self {
92 Self {
93 name: None,
94 vec_width: 4,
95 data_type: fixed_size_list_type(4, DataType::Float32),
96 }
97 }
98}
99
100impl RandomVector {
101 pub fn new() -> Self {
102 Default::default()
103 }
104
105 pub fn vec_width(mut self, vec_width: i32) -> Self {
106 self.vec_width = vec_width;
107 self.data_type = fixed_size_list_type(self.vec_width, DataType::Float32);
108 self
109 }
110
111 pub fn named(mut self, name: String) -> Self {
112 self.name = Some(name);
113 self
114 }
115}
116
117impl ArrayGenerator for RandomVector {
118 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
119 let values = generate_random_array(length * (self.vec_width as usize));
120 Arc::new(
121 <arrow_array::FixedSizeListArray as FixedSizeListArrayExt>::try_new_from_values(
122 values,
123 self.vec_width,
124 )
125 .expect("Create fixed size list"),
126 )
127 }
128
129 fn name(&self) -> Option<&str> {
130 self.name.as_deref()
131 }
132
133 fn data_type(&self) -> &DataType {
134 &self.data_type
135 }
136}
137
138#[derive(Default)]
139pub struct BatchGenerator {
140 generators: Vec<Box<dyn ArrayGenerator>>,
141}
142
143impl BatchGenerator {
144 pub fn new() -> Self {
145 Default::default()
146 }
147
148 pub fn col(mut self, gen: Box<dyn ArrayGenerator>) -> Self {
149 self.generators.push(gen);
150 self
151 }
152
153 fn gen_batch(&mut self, num_rows: u32) -> RecordBatch {
154 let mut fields = Vec::with_capacity(self.generators.len());
155 let mut arrays = Vec::with_capacity(self.generators.len());
156 for (field_index, gen) in self.generators.iter_mut().enumerate() {
157 let arr = gen.generate(num_rows as usize);
158 let default_name = format!("field_{}", field_index);
159 let name = gen.name().unwrap_or(&default_name);
160 fields.push(Field::new(name, arr.data_type().clone(), true));
161 arrays.push(arr);
162 }
163 let schema = Arc::new(ArrowSchema::new(fields));
164 RecordBatch::try_new(schema, arrays).unwrap()
165 }
166
167 pub fn batch(&mut self, num_rows: i32) -> impl RecordBatchReader {
168 let batch = self.gen_batch(num_rows as u32);
169 let schema = batch.schema();
170 RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema)
171 }
172
173 pub fn batches(&mut self, num_batches: u32, rows_per_batch: u32) -> impl RecordBatchReader {
174 let batches = (0..num_batches)
175 .map(|_| self.gen_batch(rows_per_batch))
176 .collect::<Vec<_>>();
177 let schema = batches[0].schema();
178 RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
179 }
180}
181
182pub fn some_indexable_batch() -> impl RecordBatchReader {
190 let x = Box::new(RandomVector::new().named("indexable".to_string()));
191 BatchGenerator::new().col(x).batch(512)
192}
193
194pub fn some_batch() -> impl RecordBatchReader {
201 some_indexable_batch()
202}
203
204pub fn generate_random_array_with_seed<T: ArrowFloatType>(n: usize, seed: [u8; 32]) -> T::ArrayType
206where
207 T::Native: Real + FromPrimitive,
208{
209 let mut rng = StdRng::from_seed(seed);
210
211 T::ArrayType::from(
212 repeat_with(|| T::Native::from_f32(rng.gen::<f32>()).unwrap())
213 .take(n)
214 .collect::<Vec<_>>(),
215 )
216}
217
218pub fn generate_random_array(n: usize) -> Float32Array {
221 let mut rng = rand::thread_rng();
222 Float32Array::from_iter_values(repeat_with(|| rng.gen::<f32>()).take(n))
223}
224
225pub fn generate_random_array_with_range<T: ArrowPrimitiveType>(
228 n: usize,
229 range: Range<T::Native>,
230) -> PrimitiveArray<T>
231where
232 T::Native: SampleUniform,
233{
234 let mut rng = StdRng::from_seed([13; 32]);
235 let distribution = Uniform::new(range.start, range.end);
236 PrimitiveArray::<T>::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
237}
238
239pub fn generate_scaled_random_array(n: usize, min: f32, max: f32) -> Float32Array {
242 let mut rng = rand::thread_rng();
243 let distribution = Uniform::new(min, max);
244 Float32Array::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
245}
246
247pub fn sample_indices(range: Range<usize>, num_picks: u32) -> Vec<usize> {
248 let mut rng = rand::thread_rng();
249 let dist = Uniform::new(range.start, range.end);
250 let ratio = num_picks as f32 / range.len() as f32;
251 if ratio < 0.1_f32 && num_picks > 1000 {
252 let mut picked = HashSet::<usize>::with_capacity(num_picks as usize);
255 let mut ordered_picked = Vec::with_capacity(num_picks as usize);
256 while picked.len() < num_picks as usize {
257 let val = dist.sample(&mut rng);
258 if picked.insert(val) {
259 ordered_picked.push(val);
260 }
261 }
262 ordered_picked
263 } else {
264 let mut values = Vec::from_iter(range);
267 values.partial_shuffle(&mut rng, num_picks as usize);
268 values.truncate(num_picks as usize);
269 values
270 }
271}
272
273pub fn sample_without_replacement<T: Copy>(choices: &[T], num_picks: u32) -> Vec<T> {
274 let mut rng = rand::thread_rng();
275 let mut shuffled = Vec::from(choices);
276 shuffled.partial_shuffle(&mut rng, num_picks as usize);
277 shuffled.truncate(num_picks as usize);
278 shuffled
279}