1use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::{storage::StorageWriteProvider, utils::math_util};
10use diskann_utils::io::Metadata;
11use diskann_vector::Half;
12
13use crate::utils::{CMDResult, CMDToolError, DataType};
14
15type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
16
17#[allow(clippy::panic)]
23pub fn write_random_data<StorageProvider: StorageWriteProvider>(
24 storage_provider: &StorageProvider,
25 output_file: &str,
26 data_type: DataType,
27 number_of_dimensions: usize,
28 number_of_vectors: u64,
29 radius: f32,
30) -> CMDResult<()> {
31 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
32 && (radius > 127.0 || radius <= 0.0)
33 {
34 return Err(CMDToolError {
35 details:
36 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
37 .to_string(),
38 });
39 }
40
41 let file = storage_provider.create_for_write(output_file)?;
42 let writer = BufWriter::new(file);
43
44 write_random_data_writer(
45 writer,
46 data_type,
47 number_of_dimensions,
48 number_of_vectors,
49 radius,
50 )
51}
52
53#[allow(clippy::panic)]
59fn write_random_data_writer<T: Sized + Write>(
60 mut writer: BufWriter<T>,
61 data_type: DataType,
62 number_of_dimensions: usize,
63 number_of_vectors: u64,
64 radius: f32,
65) -> CMDResult<()> {
66 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
67 && (radius > 127.0 || radius <= 0.0)
68 {
69 return Err(CMDToolError {
70 details:
71 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
72 .to_string(),
73 });
74 }
75
76 Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;
77
78 let block_size = 131072;
79 let nblks = u64::div_ceil(number_of_vectors, block_size);
80 println!("# blks: {}", nblks);
81
82 for i in 0..nblks {
83 let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
84
85 let write_method: WriteVectorMethodType<T> = match data_type {
88 DataType::Float => Box::new(
89 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
90 let mut found_nonzero = false;
91 for value in vector {
92 writer.write_f32::<LittleEndian>(*value)?;
93 found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
94 }
95 Ok(found_nonzero)
96 },
97 ),
98 DataType::Uint8 => Box::new(
99 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
100 let mut found_nonempty = false;
101 for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
104 writer.write_u8(value)?;
105
106 found_nonempty = found_nonempty || (value != 128u8);
109 }
110 Ok(found_nonempty)
111 },
112 ),
113 DataType::Int8 => Box::new(
114 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
115 let mut found_nonzero = false;
116 for value in vector.iter().map(|&item| item.round() as i8) {
117 writer.write_i8(value)?;
118 found_nonzero = found_nonzero || (value != 0i8);
119 }
120 Ok(found_nonzero)
121 },
122 ),
123 DataType::Fp16 => Box::new(
124 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
125 let mut found_nonzero = false;
126 for value in vector.iter().map(|&item| Half::from_f32(item)) {
127 let mut buf = [0; 2];
128 buf.clone_from_slice(value.to_le_bytes().as_slice());
129 writer.write_all(&buf)?;
130 found_nonzero =
131 found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
132 }
133 Ok(found_nonzero)
134 },
135 ),
136 };
137
138 write_random_vector_block(
140 write_method,
141 &mut writer,
142 number_of_dimensions,
143 cblk_size,
144 radius,
145 )?;
146 }
147
148 writer.flush()?;
151
152 Ok(())
153}
154
155fn write_random_vector_block<
163 F: Sized + Write,
164 T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
165>(
166 mut write_method: T,
167 writer: &mut BufWriter<F>,
168 number_of_dimensions: usize,
169 number_of_points: u64,
170 radius: f32,
171) -> CMDResult<()> {
172 let mut found_nonzero = false;
173
174 let vectors = math_util::generate_vectors_with_norm(
175 number_of_points as usize,
176 number_of_dimensions,
177 radius,
178 &mut diskann_providers::utils::create_rnd_from_seed(0),
179 )?;
180 for vector in vectors {
181 found_nonzero |= write_method(writer, &vector)?;
184 }
185
186 if found_nonzero {
187 Ok(())
188 } else {
189 Err(CMDToolError {
190 details: format!(
191 "Generated all-zero vectors with radius {}. Try increasing radius",
192 radius
193 ),
194 })
195 }
196}
197
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// The error `write_random_data` reports for an int8/uint8 radius outside `(0, 127]`.
    fn radius_out_of_range() -> CMDResult<()> {
        Err(CMDToolError {
            details:
                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                    .to_string(),
        })
    }

    /// The error `write_random_data` reports when a block encodes to all zeros.
    fn all_zero_vectors(radius: f32) -> CMDResult<()> {
        Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        })
    }

    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let path = "/mydatafile.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            10000,
            norm,
        );

        assert!(outcome.is_ok(), "write_random_data should succeed");
        assert!(provider.exists(path), "Random data file should exist");
    }

    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let path = "/mydatafile.bin";

        // Byte types with a non-positive radius fail validation; everything else gets far
        // enough to generate data that rounds to all zeros.
        let rejected_up_front =
            matches!(data_type, DataType::Int8 | DataType::Uint8) && radius <= 0.0;
        let expected = if rejected_up_front {
            radius_out_of_range()
        } else {
            all_zero_vectors(radius)
        };

        let provider = VirtualStorageProvider::new_overlay(".");
        let outcome = write_random_data(
            &provider,
            path,
            data_type,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, outcome);
    }

    #[test]
    fn test_fp16_data_type() {
        let path = "/fp16_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(
            &provider,
            path,
            DataType::Fp16,
            TEST_NUM_DIMENSIONS_RECOMMENDED,
            100,
            50.0,
        );

        assert!(outcome.is_ok(), "write_random_data with Fp16 should succeed");
        assert!(provider.exists(path));
    }

    #[test]
    fn test_invalid_radius_for_int8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome =
            write_random_data(&provider, "/invalid_int8.bin", DataType::Int8, 10, 100, 128.0);

        assert_eq!(radius_out_of_range(), outcome);
    }

    #[test]
    fn test_invalid_radius_for_uint8() {
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome =
            write_random_data(&provider, "/invalid_uint8.bin", DataType::Uint8, 10, 100, 150.0);

        assert_eq!(radius_out_of_range(), outcome);
    }

    #[test]
    fn test_small_dataset() {
        let path = "/small_data.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&provider, path, DataType::Float, 5, 10, 100.0);

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }

    #[test]
    fn test_large_block_size() {
        // More vectors than fit in one generation block, so multiple blocks are written.
        let path = "/large_blocks.bin";
        let provider = VirtualStorageProvider::new_overlay(".");

        let outcome = write_random_data(&provider, path, DataType::Float, 10, 200000, 100.0);

        assert!(outcome.is_ok());
        assert!(provider.exists(path));
    }
}