Skip to main content

diskann_tools/utils/
random_data_generator.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::{storage::StorageWriteProvider, utils::math_util};
10use diskann_utils::io::Metadata;
11use diskann_vector::Half;
12
13use crate::utils::{CMDResult, CMDToolError, DataType};
14
15type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
16
17/**
18Generate random points around a sphere with the specified radius and write them to a file
19
20When data_type is int8 or uint8 radius must be <= 127.0
21 */
22#[allow(clippy::panic)]
23pub fn write_random_data<StorageProvider: StorageWriteProvider>(
24    storage_provider: &StorageProvider,
25    output_file: &str,
26    data_type: DataType,
27    number_of_dimensions: usize,
28    number_of_vectors: u64,
29    radius: f32,
30) -> CMDResult<()> {
31    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
32        && (radius > 127.0 || radius <= 0.0)
33    {
34        return Err(CMDToolError {
35            details:
36            "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
37                .to_string(),
38        });
39    }
40
41    let file = storage_provider.create_for_write(output_file)?;
42    let writer = BufWriter::new(file);
43
44    write_random_data_writer(
45        writer,
46        data_type,
47        number_of_dimensions,
48        number_of_vectors,
49        radius,
50    )
51}
52
53/**
54Generate random points around a sphere with the specified radius and write them to a file
55
56When data_type is int8 or uint8 radius must be <= 127.0
57*/
58#[allow(clippy::panic)]
59fn write_random_data_writer<T: Sized + Write>(
60    mut writer: BufWriter<T>,
61    data_type: DataType,
62    number_of_dimensions: usize,
63    number_of_vectors: u64,
64    radius: f32,
65) -> CMDResult<()> {
66    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
67        && (radius > 127.0 || radius <= 0.0)
68    {
69        return Err(CMDToolError {
70            details:
71                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
72                    .to_string(),
73        });
74    }
75
76    Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;
77
78    let block_size = 131072;
79    let nblks = u64::div_ceil(number_of_vectors, block_size);
80    println!("# blks: {}", nblks);
81
82    for i in 0..nblks {
83        let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
84
85        // Each data has special code to write it out.  These methods convert the random data
86        // from the input vector into the specific datatype and writes it out to the data file.
87        let write_method: WriteVectorMethodType<T> = match data_type {
88            DataType::Float => Box::new(
89                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
90                    let mut found_nonzero = false;
91                    for value in vector {
92                        writer.write_f32::<LittleEndian>(*value)?;
93                        found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
94                    }
95                    Ok(found_nonzero)
96                },
97            ),
98            DataType::Uint8 => Box::new(
99                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
100                    let mut found_nonempty = false;
101                    // Since u8 is unsigned, add 128 to ensure non-negative before
102                    // rounding and casting
103                    for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
104                        writer.write_u8(value)?;
105
106                        // Since we add 128 to the random number to prevent negative values,
107                        // 'empty' is a vector where all indices hold 128u8.
108                        found_nonempty = found_nonempty || (value != 128u8);
109                    }
110                    Ok(found_nonempty)
111                },
112            ),
113            DataType::Int8 => Box::new(
114                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
115                    let mut found_nonzero = false;
116                    for value in vector.iter().map(|&item| item.round() as i8) {
117                        writer.write_i8(value)?;
118                        found_nonzero = found_nonzero || (value != 0i8);
119                    }
120                    Ok(found_nonzero)
121                },
122            ),
123            DataType::Fp16 => Box::new(
124                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
125                    let mut found_nonzero = false;
126                    for value in vector.iter().map(|&item| Half::from_f32(item)) {
127                        let mut buf = [0; 2];
128                        buf.clone_from_slice(value.to_le_bytes().as_slice());
129                        writer.write_all(&buf)?;
130                        found_nonzero =
131                            found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
132                    }
133                    Ok(found_nonzero)
134                },
135            ),
136        };
137
138        // Propagate errors if there are any
139        write_random_vector_block(
140            write_method,
141            &mut writer,
142            number_of_dimensions,
143            cblk_size,
144            radius,
145        )?;
146    }
147
148    // writer flushes the inner file object as part of it's flush.  File object moved
149    // to writer scope so we cannot manually call flush on it here.
150    writer.flush()?;
151
152    Ok(())
153}
154
155/**
156Writes random vectors to the specified writer.  Function generates random floats.  It is the
157responsibility of the "write_method" method argument to convert the random floats into other
158datatypes.
159
160NOTE: This generates random points on a sphere that has the specified radius
161*/
162fn write_random_vector_block<
163    F: Sized + Write,
164    T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
165>(
166    mut write_method: T,
167    writer: &mut BufWriter<F>,
168    number_of_dimensions: usize,
169    number_of_points: u64,
170    radius: f32,
171) -> CMDResult<()> {
172    let mut found_nonzero = false;
173
174    let vectors = math_util::generate_vectors_with_norm(
175        number_of_points as usize,
176        number_of_dimensions,
177        radius,
178        &mut diskann_providers::utils::create_rnd_from_seed(0),
179    )?;
180    for vector in vectors {
181        // Check for non-zero after casting to final numeric types.  Do not short-circuit
182        // evaluate to ensure we always write the data.
183        found_nonzero |= write_method(writer, &vector)?;
184    }
185
186    if found_nonzero {
187        Ok(())
188    } else {
189        Err(CMDToolError {
190            details: format!(
191                "Generated all-zero vectors with radius {}. Try increasing radius",
192                radius
193            ),
194        })
195    }
196}
197
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// Message produced by the up-front radius validation for int8/uint8 data.
    const RADIUS_OUT_OF_RANGE: &str = "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0";

    /// Expected result when the radius fails the int8/uint8 range validation.
    fn radius_out_of_range() -> CMDResult<()> {
        Err(CMDToolError {
            details: RADIUS_OUT_OF_RANGE.to_string(),
        })
    }

    /// Expected result when every generated vector collapses to zero after encoding.
    fn all_zero_vectors(radius: f32) -> CMDResult<()> {
        Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        })
    }

    /// Writing succeeds for the float and byte data types at valid radii.
    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let path = "/mydatafile.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            10000,
            norm,
        );

        assert!(outcome.is_ok(), "write_random_data should succeed");
        assert!(provider.exists(path), "Random data file should exist");
    }

    /// Very low values of "radius" cause the random data to all be zero.
    /// Ensure that an appropriate error is returned when invalid radius is used.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let path = "/mydatafile.bin";

        // Byte types with a non-positive radius are rejected by validation; everything else
        // in this table gets as far as generation and fails the all-zero check.
        let rejected_up_front =
            radius <= 0.0 && matches!(data_type, DataType::Int8 | DataType::Uint8);
        let expected = if rejected_up_front {
            radius_out_of_range()
        } else {
            all_zero_vectors(radius)
        };

        let provider = VirtualStorageProvider::new_overlay(".");
        let actual = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, actual);
    }

    /// Half-precision output is supported.
    #[test]
    fn test_fp16_data_type() {
        let path = "/fp16_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            DataType::Fp16,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            100,
            50.0,
        );

        assert!(outcome.is_ok(), "write_random_data with Fp16 should succeed");
        assert!(provider.exists(path));
    }

    /// A radius above 127 is rejected for int8 data.
    #[test]
    fn test_invalid_radius_for_int8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let actual = write_random_data(
            &provider,
            "/invalid_int8.bin",
            DataType::Int8,
            10,
            100,
            128.0,
        );

        assert_eq!(radius_out_of_range(), actual);
    }

    /// A radius above 127 is rejected for uint8 data.
    #[test]
    fn test_invalid_radius_for_uint8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let actual = write_random_data(
            &provider,
            "/invalid_uint8.bin",
            DataType::Uint8,
            10,
            100,
            150.0,
        );

        assert_eq!(radius_out_of_range(), actual);
    }

    /// A very small dataset is written successfully.
    #[test]
    fn test_small_dataset() {
        let path = "/small_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&provider, path, DataType::Float, 5, 10, 100.0);

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }

    /// A dataset spanning more than one block is written successfully.
    #[test]
    fn test_large_block_size() {
        let path = "/large_blocks.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            DataType::Float,
            10,
            200000, // More than block_size (131072)
            100.0,
        );

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }
}