1use std::collections::HashSet;
7use std::sync::Arc;
8use std::{iter::repeat_with, ops::Range};
9
10use arrow_array::types::ArrowPrimitiveType;
11use arrow_array::{
12 Float32Array, Int32Array, Int8Array, PrimitiveArray, RecordBatch, RecordBatchIterator,
13 RecordBatchReader,
14};
15use arrow_schema::{DataType, Field, Schema as ArrowSchema};
16use lance_arrow::{fixed_size_list_type, ArrowFloatType, FixedSizeListArrayExt};
17use num_traits::{real::Real, FromPrimitive};
18use rand::distr::uniform::SampleUniform;
19use rand::{
20 distr::Uniform, prelude::Distribution, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng,
21};
22
23pub trait ArrayGenerator {
24 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array>;
25 fn data_type(&self) -> &DataType;
26 fn name(&self) -> Option<&str>;
27}
28
/// Generator state for an `Int32` column whose values advance by a fixed step.
pub struct IncrementingInt32 {
    /// Optional column name.
    name: Option<String>,
    /// Next value that will be emitted.
    current: i32,
    /// Amount added to `current` after each emitted value.
    step: i32,
}

impl Default for IncrementingInt32 {
    /// Same as [`IncrementingInt32::new`]: unnamed, starting at 0, step 1.
    fn default() -> Self {
        Self::new()
    }
}

impl IncrementingInt32 {
    /// Creates an unnamed generator that starts at 0 and advances by 1.
    pub fn new() -> Self {
        Self {
            name: None,
            current: 0,
            step: 1,
        }
    }

    /// Sets the first value to emit.
    pub fn start(self, start: i32) -> Self {
        Self {
            current: start,
            ..self
        }
    }

    /// Sets the increment applied after each emitted value.
    pub fn step(self, step: i32) -> Self {
        Self { step, ..self }
    }

    /// Sets the column name.
    pub fn named(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }
}
65
66impl ArrayGenerator for IncrementingInt32 {
67 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
68 let mut values = Vec::with_capacity(length);
69 for _ in 0..length {
70 values.push(self.current);
71 self.current += self.step;
72 }
73 Arc::new(Int32Array::from(values))
74 }
75
76 fn name(&self) -> Option<&str> {
77 self.name.as_deref()
78 }
79
80 fn data_type(&self) -> &DataType {
81 &DataType::Int32
82 }
83}
84
85pub struct RandomVector {
86 name: Option<String>,
87 vec_width: i32,
88 data_type: DataType,
89}
90
91impl Default for RandomVector {
92 fn default() -> Self {
93 Self {
94 name: None,
95 vec_width: 4,
96 data_type: fixed_size_list_type(4, DataType::Float32),
97 }
98 }
99}
100
101impl RandomVector {
102 pub fn new() -> Self {
103 Default::default()
104 }
105
106 pub fn vec_width(mut self, vec_width: i32) -> Self {
107 self.vec_width = vec_width;
108 self.data_type = fixed_size_list_type(self.vec_width, DataType::Float32);
109 self
110 }
111
112 pub fn named(mut self, name: String) -> Self {
113 self.name = Some(name);
114 self
115 }
116}
117
118impl ArrayGenerator for RandomVector {
119 fn generate(&mut self, length: usize) -> Arc<dyn arrow_array::Array> {
120 let values = generate_random_array(length * (self.vec_width as usize));
121 Arc::new(
122 <arrow_array::FixedSizeListArray as FixedSizeListArrayExt>::try_new_from_values(
123 values,
124 self.vec_width,
125 )
126 .expect("Create fixed size list"),
127 )
128 }
129
130 fn name(&self) -> Option<&str> {
131 self.name.as_deref()
132 }
133
134 fn data_type(&self) -> &DataType {
135 &self.data_type
136 }
137}
138
139#[derive(Default)]
140pub struct BatchGenerator {
141 generators: Vec<Box<dyn ArrayGenerator>>,
142}
143
144impl BatchGenerator {
145 pub fn new() -> Self {
146 Default::default()
147 }
148
149 pub fn col(mut self, genn: Box<dyn ArrayGenerator>) -> Self {
150 self.generators.push(genn);
151 self
152 }
153
154 fn gen_batch(&mut self, num_rows: u32) -> RecordBatch {
155 let mut fields = Vec::with_capacity(self.generators.len());
156 let mut arrays = Vec::with_capacity(self.generators.len());
157 for (field_index, genn) in self.generators.iter_mut().enumerate() {
158 let arr = genn.generate(num_rows as usize);
159 let default_name = format!("field_{}", field_index);
160 let name = genn.name().unwrap_or(&default_name);
161 fields.push(Field::new(name, arr.data_type().clone(), true));
162 arrays.push(arr);
163 }
164 let schema = Arc::new(ArrowSchema::new(fields));
165 RecordBatch::try_new(schema, arrays).unwrap()
166 }
167
168 pub fn batch(&mut self, num_rows: i32) -> impl RecordBatchReader {
169 let batch = self.gen_batch(num_rows as u32);
170 let schema = batch.schema();
171 RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema)
172 }
173
174 pub fn batches(&mut self, num_batches: u32, rows_per_batch: u32) -> impl RecordBatchReader {
175 let batches = (0..num_batches)
176 .map(|_| self.gen_batch(rows_per_batch))
177 .collect::<Vec<_>>();
178 let schema = batches[0].schema();
179 RecordBatchIterator::new(batches.into_iter().map(Ok), schema)
180 }
181}
182
183pub fn some_indexable_batch() -> impl RecordBatchReader {
191 let x = Box::new(RandomVector::new().named("indexable".to_string()));
192 BatchGenerator::new().col(x).batch(512)
193}
194
195pub fn some_batch() -> impl RecordBatchReader {
202 some_indexable_batch()
203}
204
205pub fn generate_random_array_with_seed<T: ArrowFloatType>(n: usize, seed: [u8; 32]) -> T::ArrayType
207where
208 T::Native: Real + FromPrimitive,
209{
210 let mut rng = StdRng::from_seed(seed);
211
212 T::ArrayType::from(
213 repeat_with(|| T::Native::from_f32(rng.random::<f32>()).unwrap())
214 .take(n)
215 .collect::<Vec<_>>(),
216 )
217}
218
219pub fn generate_random_array(n: usize) -> Float32Array {
222 let mut rng = rand::rng();
223 Float32Array::from_iter_values(repeat_with(|| rng.random::<f32>()).take(n))
224}
225
226pub fn generate_random_int8_array(n: usize) -> Int8Array {
229 let mut rng = rand::rng();
230 Int8Array::from_iter_values(repeat_with(|| rng.random::<i8>()).take(n))
231}
232
233pub fn generate_random_array_with_range<T: ArrowPrimitiveType>(
236 n: usize,
237 range: Range<T::Native>,
238) -> PrimitiveArray<T>
239where
240 T::Native: SampleUniform,
241{
242 let mut rng = StdRng::from_seed([13; 32]);
243 let distribution = Uniform::new(range.start, range.end).unwrap();
244 PrimitiveArray::<T>::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
245}
246
247pub fn generate_scaled_random_array(n: usize, min: f32, max: f32) -> Float32Array {
250 let mut rng = rand::rng();
251 let distribution = Uniform::new(min, max).unwrap();
252 Float32Array::from_iter_values(repeat_with(|| distribution.sample(&mut rng)).take(n))
253}
254
255pub fn sample_indices(range: Range<usize>, num_picks: u32) -> Vec<usize> {
256 let mut rng = rand::rng();
257 let dist = Uniform::new(range.start, range.end).unwrap();
258 let ratio = num_picks as f32 / range.len() as f32;
259 if ratio < 0.1_f32 && num_picks > 1000 {
260 let mut picked = HashSet::<usize>::with_capacity(num_picks as usize);
263 let mut ordered_picked = Vec::with_capacity(num_picks as usize);
264 while picked.len() < num_picks as usize {
265 let val = dist.sample(&mut rng);
266 if picked.insert(val) {
267 ordered_picked.push(val);
268 }
269 }
270 ordered_picked
271 } else {
272 let mut values = Vec::from_iter(range);
275 values.partial_shuffle(&mut rng, num_picks as usize);
276 values.truncate(num_picks as usize);
277 values
278 }
279}
280
281pub fn sample_without_replacement<T: Copy>(choices: &[T], num_picks: u32) -> Vec<T> {
282 let mut rng = rand::rng();
283 let mut shuffled = Vec::from(choices);
284 shuffled.partial_shuffle(&mut rng, num_picks as usize);
285 shuffled.truncate(num_picks as usize);
286 shuffled
287}