// diskann_tools/utils/random_data_generator.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::storage::StorageWriteProvider;
10use diskann_utils::{io::Metadata, sampling::random::WithApproximateNorm};
11use diskann_vector::Half;
12
13use crate::utils::{CMDResult, CMDToolError, DataType};
14
15type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
16
17/**
18Generate random points around a sphere with the specified radius and write them to a file
19
20When data_type is int8 or uint8 radius must be <= 127.0
21 */
22#[allow(clippy::panic)]
23pub fn write_random_data<StorageProvider: StorageWriteProvider>(
24    storage_provider: &StorageProvider,
25    output_file: &str,
26    data_type: DataType,
27    number_of_dimensions: usize,
28    number_of_vectors: u64,
29    radius: f32,
30) -> CMDResult<()> {
31    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
32        && (radius > 127.0 || radius <= 0.0)
33    {
34        return Err(CMDToolError {
35            details:
36            "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
37                .to_string(),
38        });
39    }
40
41    let file = storage_provider.create_for_write(output_file)?;
42    let writer = BufWriter::new(file);
43
44    write_random_data_writer(
45        writer,
46        data_type,
47        number_of_dimensions,
48        number_of_vectors,
49        radius,
50    )
51}
52
53/**
54Generate random points around a sphere with the specified radius and write them to a file
55
56When data_type is int8 or uint8 radius must be <= 127.0
57*/
58#[allow(clippy::panic)]
59fn write_random_data_writer<T: Sized + Write>(
60    mut writer: BufWriter<T>,
61    data_type: DataType,
62    number_of_dimensions: usize,
63    number_of_vectors: u64,
64    radius: f32,
65) -> CMDResult<()> {
66    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
67        && (radius > 127.0 || radius <= 0.0)
68    {
69        return Err(CMDToolError {
70            details:
71                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
72                    .to_string(),
73        });
74    }
75
76    Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;
77
78    let block_size = 131072;
79    let nblks = u64::div_ceil(number_of_vectors, block_size);
80    println!("# blks: {}", nblks);
81
82    for i in 0..nblks {
83        let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
84
85        // Each data has special code to write it out.  These methods convert the random data
86        // from the input vector into the specific datatype and writes it out to the data file.
87        let write_method: WriteVectorMethodType<T> = match data_type {
88            DataType::Float => Box::new(
89                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
90                    let mut found_nonzero = false;
91                    for value in vector {
92                        writer.write_f32::<LittleEndian>(*value)?;
93                        found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
94                    }
95                    Ok(found_nonzero)
96                },
97            ),
98            DataType::Uint8 => Box::new(
99                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
100                    let mut found_nonempty = false;
101                    // Since u8 is unsigned, add 128 to ensure non-negative before
102                    // rounding and casting
103                    for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
104                        writer.write_u8(value)?;
105
106                        // Since we add 128 to the random number to prevent negative values,
107                        // 'empty' is a vector where all indices hold 128u8.
108                        found_nonempty = found_nonempty || (value != 128u8);
109                    }
110                    Ok(found_nonempty)
111                },
112            ),
113            DataType::Int8 => Box::new(
114                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
115                    let mut found_nonzero = false;
116                    for value in vector.iter().map(|&item| item.round() as i8) {
117                        writer.write_i8(value)?;
118                        found_nonzero = found_nonzero || (value != 0i8);
119                    }
120                    Ok(found_nonzero)
121                },
122            ),
123            DataType::Fp16 => Box::new(
124                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
125                    let mut found_nonzero = false;
126                    for value in vector.iter().map(|&item| Half::from_f32(item)) {
127                        let mut buf = [0; 2];
128                        buf.clone_from_slice(value.to_le_bytes().as_slice());
129                        writer.write_all(&buf)?;
130                        found_nonzero =
131                            found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
132                    }
133                    Ok(found_nonzero)
134                },
135            ),
136        };
137
138        // Propagate errors if there are any
139        write_random_vector_block(
140            write_method,
141            &mut writer,
142            number_of_dimensions,
143            cblk_size,
144            radius,
145        )?;
146    }
147
148    // writer flushes the inner file object as part of it's flush.  File object moved
149    // to writer scope so we cannot manually call flush on it here.
150    writer.flush()?;
151
152    Ok(())
153}
154
155/**
156Writes random vectors to the specified writer.  Function generates random floats.  It is the
157responsibility of the "write_method" method argument to convert the random floats into other
158datatypes.
159
160NOTE: This generates random points on a sphere that has the specified radius
161*/
162fn write_random_vector_block<
163    F: Sized + Write,
164    T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
165>(
166    mut write_method: T,
167    writer: &mut BufWriter<F>,
168    number_of_dimensions: usize,
169    number_of_points: u64,
170    radius: f32,
171) -> CMDResult<()> {
172    let mut found_nonzero = false;
173
174    let mut rng = diskann_providers::utils::create_rnd_from_seed(0);
175    for _ in 0..number_of_points {
176        let vector = f32::with_approximate_norm(number_of_dimensions, radius, &mut rng);
177        // Check for non-zero after casting to final numeric types.  Do not short-circuit
178        // evaluate to ensure we always write the data.
179        found_nonzero |= write_method(writer, &vector)?;
180    }
181
182    if found_nonzero {
183        Ok(())
184    } else {
185        Err(CMDToolError {
186            details: format!(
187                "Generated all-zero vectors with radius {}. Try increasing radius",
188                radius
189            ),
190        })
191    }
192}
193
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// Message reported when an int8/uint8 radius is outside the accepted range.
    const BYTE_RADIUS_ERROR: &str =
        "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0";

    /// Expected result for an out-of-range int8/uint8 radius.
    fn byte_radius_error() -> CMDResult<()> {
        Err(CMDToolError {
            details: BYTE_RADIUS_ERROR.to_string(),
        })
    }

    /// Expected result when the generated vectors all encode to zero.
    fn all_zero_error(radius: f32) -> CMDResult<()> {
        Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        })
    }

    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let path = "/mydatafile.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            10000,
            norm,
        );

        assert!(outcome.is_ok(), "write_random_data should succeed");
        assert!(provider.exists(path), "Random data file should exist");
    }

    /// A very small radius makes all the random data zero.
    /// Check that the right error comes back for each such radius.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let path = "/mydatafile.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        // Byte types reject a non-positive radius up front; everything else reaches the
        // generator and fails the all-zero check.
        let rejected_up_front =
            matches!(data_type, DataType::Int8 | DataType::Uint8) && radius <= 0.0;
        let expected = if rejected_up_front {
            byte_radius_error()
        } else {
            all_zero_error(radius)
        };

        let outcome = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, outcome);
    }

    #[test]
    fn test_fp16_data_type() {
        let path = "/fp16_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            DataType::Fp16,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            100,
            50.0,
        );

        assert!(outcome.is_ok(), "write_random_data with Fp16 should succeed");
        assert!(provider.exists(path));
    }

    #[test]
    fn test_invalid_radius_for_int8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            "/invalid_int8.bin",
            DataType::Int8,
            10,
            100,
            128.0,
        );

        assert_eq!(byte_radius_error(), outcome);
    }

    #[test]
    fn test_invalid_radius_for_uint8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            "/invalid_uint8.bin",
            DataType::Uint8,
            10,
            100,
            150.0,
        );

        assert_eq!(byte_radius_error(), outcome);
    }

    #[test]
    fn test_small_dataset() {
        let path = "/small_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        // A very small dataset: 10 vectors of 5 dimensions.
        let outcome = write_random_data(&provider, path, DataType::Float, 5, 10, 100.0);

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }

    #[test]
    fn test_large_block_size() {
        let path = "/large_blocks.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        // 200000 vectors exceeds the 131072-vector block size, forcing multiple blocks.
        let outcome = write_random_data(&provider, path, DataType::Float, 10, 200000, 100.0);

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }
}