Skip to main content

diskann_tools/utils/
random_data_generator.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::{
10    storage::StorageWriteProvider,
11    utils::{math_util, write_metadata},
12};
13use diskann_vector::Half;
14
15use crate::utils::{CMDResult, CMDToolError, DataType};
16
/// Boxed callback that converts one vector of `f32` components into the target on-disk
/// data type and writes it to the buffered writer.
///
/// The returned `bool` reports whether the vector, after conversion, held at least one
/// value that the callback considers non-zero / non-empty.
type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
18
19/**
20Generate random points around a sphere with the specified radius and write them to a file
21
22When data_type is int8 or uint8 radius must be <= 127.0
23 */
24#[allow(clippy::panic)]
25pub fn write_random_data<StorageProvider: StorageWriteProvider>(
26    storage_provider: &StorageProvider,
27    output_file: &str,
28    data_type: DataType,
29    number_of_dimensions: usize,
30    number_of_vectors: u64,
31    radius: f32,
32) -> CMDResult<()> {
33    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
34        && radius > 127.0
35        && radius <= 0.0
36    {
37        return Err(CMDToolError {
38            details:
39            "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
40                .to_string(),
41        });
42    }
43
44    let file = storage_provider.create_for_write(output_file)?;
45    let writer = BufWriter::new(file);
46
47    write_random_data_writer(
48        writer,
49        data_type,
50        number_of_dimensions,
51        number_of_vectors,
52        radius,
53    )
54}
55
56/**
57Generate random points around a sphere with the specified radius and write them to a file
58
59When data_type is int8 or uint8 radius must be <= 127.0
60*/
61#[allow(clippy::panic)]
62pub fn write_random_data_writer<T: Sized + Write>(
63    mut writer: BufWriter<T>,
64    data_type: DataType,
65    number_of_dimensions: usize,
66    number_of_vectors: u64,
67    radius: f32,
68) -> CMDResult<()> {
69    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
70        && radius > 127.0
71        && radius <= 0.0
72    {
73        return Err(CMDToolError {
74            details:
75                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
76                    .to_string(),
77        });
78    }
79
80    write_metadata(&mut writer, number_of_vectors, number_of_dimensions)?;
81
82    let block_size = 131072;
83    let nblks = u64::div_ceil(number_of_vectors, block_size);
84    println!("# blks: {}", nblks);
85
86    for i in 0..nblks {
87        let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
88
89        // Each data has special code to write it out.  These methods convert the random data
90        // from the input vector into the specific datatype and writes it out to the data file.
91        let write_method: WriteVectorMethodType<T> = match data_type {
92            DataType::Float => Box::new(
93                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
94                    let mut found_nonzero = false;
95                    for value in vector {
96                        writer.write_f32::<LittleEndian>(*value)?;
97                        found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
98                    }
99                    Ok(found_nonzero)
100                },
101            ),
102            DataType::Uint8 => Box::new(
103                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
104                    let mut found_nonempty = false;
105                    // Since u8 is unsigned, add 128 to ensure non-negative before
106                    // rounding and casting
107                    for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
108                        writer.write_u8(value)?;
109
110                        // Since we add 128 to the random number to prevent negative values,
111                        // 'empty' is a vector where all indices hold 128u8.
112                        found_nonempty = found_nonempty || (value != 128u8);
113                    }
114                    Ok(found_nonempty)
115                },
116            ),
117            DataType::Int8 => Box::new(
118                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
119                    let mut found_nonzero = false;
120                    for value in vector.iter().map(|&item| item.round() as i8) {
121                        writer.write_i8(value)?;
122                        found_nonzero = found_nonzero || (value != 0i8);
123                    }
124                    Ok(found_nonzero)
125                },
126            ),
127            DataType::Fp16 => Box::new(
128                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
129                    let mut found_nonzero = false;
130                    for value in vector.iter().map(|&item| Half::from_f32(item)) {
131                        let mut buf = [0; 2];
132                        buf.clone_from_slice(value.to_le_bytes().as_slice());
133                        writer.write_all(&buf)?;
134                        found_nonzero =
135                            found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
136                    }
137                    Ok(found_nonzero)
138                },
139            ),
140        };
141
142        // Propagate errors if there are any
143        write_random_vector_block(
144            write_method,
145            &mut writer,
146            number_of_dimensions,
147            cblk_size,
148            radius,
149        )?;
150    }
151
152    // writer flushes the inner file object as part of it's flush.  File object moved
153    // to writer scope so we cannot manually call flush on it here.
154    writer.flush()?;
155
156    Ok(())
157}
158
159/**
160Writes random vectors to the specified writer.  Function generates random floats.  It is the
161responsibility of the "write_method" method argument to convert the random floats into other
162datatypes.
163
164NOTE: This generates random points on a sphere that has the specified radius
165*/
166fn write_random_vector_block<
167    F: Sized + Write,
168    T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
169>(
170    mut write_method: T,
171    writer: &mut BufWriter<F>,
172    number_of_dimensions: usize,
173    number_of_points: u64,
174    radius: f32,
175) -> CMDResult<()> {
176    let mut found_nonzero = false;
177
178    let vectors = math_util::generate_vectors_with_norm(
179        number_of_points as usize,
180        number_of_dimensions,
181        radius,
182        &mut diskann_providers::utils::create_rnd_from_seed(0),
183    )?;
184    for vector in vectors {
185        // Check for non-zero after casting to final numeric types.  Do not short-circuit
186        // evaluate to ensure we always write the data.
187        found_nonzero |= write_method(writer, &vector)?;
188    }
189
190    if found_nonzero {
191        Ok(())
192    } else {
193        Err(CMDToolError {
194            details: format!(
195                "Generated all-zero vectors with radius {}. Try increasing radius",
196                radius
197            ),
198        })
199    }
200}
201
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// Path of the data file that every test asks `write_random_data` to create.
    const RANDOM_DATA_PATH: &str = "/mydatafile.bin";

    /// Writing 10000 vectors with a sufficiently large norm succeeds and leaves the
    /// output file behind in the storage provider.
    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &storage_provider,
            RANDOM_DATA_PATH,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            10000,
            norm,
        );

        assert!(outcome.is_ok(), "write_random_data should succeed");
        assert!(
            storage_provider.exists(RANDOM_DATA_PATH),
            "Random data file should exist"
        );
    }

    /// Very low values of "radius" cause the random data to all be zero.
    /// Ensure that an appropriate error is returned when invalid radius is used.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let result = write_random_data(
            &storage_provider,
            RANDOM_DATA_PATH,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        // The failure must be the all-zero report, carrying the radius that was requested.
        let expected = Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        });

        assert_eq!(expected, result);
    }
}