1use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::storage::StorageWriteProvider;
10use diskann_utils::{io::Metadata, sampling::random::WithApproximateNorm};
11use diskann_vector::Half;
12
13use crate::utils::{CMDResult, CMDToolError, DataType};
14
15type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
16
17#[allow(clippy::panic)]
23pub fn write_random_data<StorageProvider: StorageWriteProvider>(
24 storage_provider: &StorageProvider,
25 output_file: &str,
26 data_type: DataType,
27 number_of_dimensions: usize,
28 number_of_vectors: u64,
29 radius: f32,
30) -> CMDResult<()> {
31 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
32 && (radius > 127.0 || radius <= 0.0)
33 {
34 return Err(CMDToolError {
35 details:
36 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
37 .to_string(),
38 });
39 }
40
41 let file = storage_provider.create_for_write(output_file)?;
42 let writer = BufWriter::new(file);
43
44 write_random_data_writer(
45 writer,
46 data_type,
47 number_of_dimensions,
48 number_of_vectors,
49 radius,
50 )
51}
52
53#[allow(clippy::panic)]
59fn write_random_data_writer<T: Sized + Write>(
60 mut writer: BufWriter<T>,
61 data_type: DataType,
62 number_of_dimensions: usize,
63 number_of_vectors: u64,
64 radius: f32,
65) -> CMDResult<()> {
66 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
67 && (radius > 127.0 || radius <= 0.0)
68 {
69 return Err(CMDToolError {
70 details:
71 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
72 .to_string(),
73 });
74 }
75
76 Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;
77
78 let block_size = 131072;
79 let nblks = u64::div_ceil(number_of_vectors, block_size);
80 println!("# blks: {}", nblks);
81
82 for i in 0..nblks {
83 let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
84
85 let write_method: WriteVectorMethodType<T> = match data_type {
88 DataType::Float => Box::new(
89 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
90 let mut found_nonzero = false;
91 for value in vector {
92 writer.write_f32::<LittleEndian>(*value)?;
93 found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
94 }
95 Ok(found_nonzero)
96 },
97 ),
98 DataType::Uint8 => Box::new(
99 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
100 let mut found_nonempty = false;
101 for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
104 writer.write_u8(value)?;
105
106 found_nonempty = found_nonempty || (value != 128u8);
109 }
110 Ok(found_nonempty)
111 },
112 ),
113 DataType::Int8 => Box::new(
114 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
115 let mut found_nonzero = false;
116 for value in vector.iter().map(|&item| item.round() as i8) {
117 writer.write_i8(value)?;
118 found_nonzero = found_nonzero || (value != 0i8);
119 }
120 Ok(found_nonzero)
121 },
122 ),
123 DataType::Fp16 => Box::new(
124 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
125 let mut found_nonzero = false;
126 for value in vector.iter().map(|&item| Half::from_f32(item)) {
127 let mut buf = [0; 2];
128 buf.clone_from_slice(value.to_le_bytes().as_slice());
129 writer.write_all(&buf)?;
130 found_nonzero =
131 found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
132 }
133 Ok(found_nonzero)
134 },
135 ),
136 };
137
138 write_random_vector_block(
140 write_method,
141 &mut writer,
142 number_of_dimensions,
143 cblk_size,
144 radius,
145 )?;
146 }
147
148 writer.flush()?;
151
152 Ok(())
153}
154
155fn write_random_vector_block<
163 F: Sized + Write,
164 T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
165>(
166 mut write_method: T,
167 writer: &mut BufWriter<F>,
168 number_of_dimensions: usize,
169 number_of_points: u64,
170 radius: f32,
171) -> CMDResult<()> {
172 let mut found_nonzero = false;
173
174 let mut rng = diskann_providers::utils::create_rnd_from_seed(0);
175 for _ in 0..number_of_points {
176 let vector = f32::with_approximate_norm(number_of_dimensions, radius, &mut rng);
177 found_nonzero |= write_method(writer, &vector)?;
180 }
181
182 if found_nonzero {
183 Ok(())
184 } else {
185 Err(CMDToolError {
186 details: format!(
187 "Generated all-zero vectors with radius {}. Try increasing radius",
188 radius
189 ),
190 })
191 }
192}
193
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// Expected result when the int8/uint8 radius validation rejects the request.
    fn invalid_radius_error() -> CMDResult<()> {
        Err(CMDToolError {
            details:
                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                    .to_string(),
        })
    }

    /// Expected result when every generated vector was encoded as zero.
    fn all_zero_error(radius: f32) -> CMDResult<()> {
        Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        })
    }

    /// Valid radii succeed for float and 8-bit integer types and create the output file.
    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let path = "/mydatafile.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &storage_provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            10000,
            norm,
        );

        assert!(outcome.is_ok(), "write_random_data should succeed");
        assert!(
            storage_provider.exists(path),
            "Random data file should exist"
        );
    }

    /// Radii that are too small are rejected, either by validation or by the all-zero check.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let path = "/mydatafile.bin";

        // Only 8-bit integer types with a non-positive radius fail the up-front validation;
        // every other case here is expected to fail the all-zero check.
        let is_8bit_integer = matches!(data_type, DataType::Int8 | DataType::Uint8);
        let expected = if is_8bit_integer && radius <= 0.0 {
            invalid_radius_error()
        } else {
            all_zero_error(radius)
        };

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let outcome = write_random_data(
            &storage_provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, outcome);
    }

    /// Half-precision output is supported.
    #[test]
    fn test_fp16_data_type() {
        let path = "/fp16_data.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &storage_provider,
            path,
            DataType::Fp16,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            100,
            50.0,
        );

        assert!(outcome.is_ok(), "write_random_data with Fp16 should succeed");
        assert!(storage_provider.exists(path));
    }

    /// A radius above 127 is rejected for `Int8`.
    #[test]
    fn test_invalid_radius_for_int8() {
        let path = "/invalid_int8.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&storage_provider, path, DataType::Int8, 10, 100, 128.0);

        assert_eq!(invalid_radius_error(), outcome);
    }

    /// A radius above 127 is rejected for `Uint8`.
    #[test]
    fn test_invalid_radius_for_uint8() {
        let path = "/invalid_uint8.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&storage_provider, path, DataType::Uint8, 10, 100, 150.0);

        assert_eq!(invalid_radius_error(), outcome);
    }

    /// A handful of low-dimensional float vectors can be written.
    #[test]
    fn test_small_dataset() {
        let path = "/small_data.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&storage_provider, path, DataType::Float, 5, 10, 100.0);

        assert!(outcome.is_ok());
        assert!(storage_provider.exists(path));
    }

    /// 200000 vectors exceed the 131072-vector block size, so more than one block is written.
    #[test]
    fn test_large_block_size() {
        let path = "/large_blocks.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &storage_provider,
            path,
            DataType::Float,
            10,
            200000,
            100.0,
        );

        assert!(outcome.is_ok());
        assert!(storage_provider.exists(path));
    }
}