1use std::io::{BufWriter, Write};
7
8use byteorder::{LittleEndian, WriteBytesExt};
9use diskann_providers::{
10 storage::StorageWriteProvider,
11 utils::{math_util, write_metadata},
12};
13use diskann_vector::Half;
14
15use crate::utils::{CMDResult, CMDToolError, DataType};
16
/// Boxed per-vector encoder: serializes one `f32` vector into `writer` in a target element
/// encoding. The implementations in this file return `Ok(true)` when at least one encoded
/// element is non-zero, which callers use to detect all-zero output.
type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
18
19#[allow(clippy::panic)]
25pub fn write_random_data<StorageProvider: StorageWriteProvider>(
26 storage_provider: &StorageProvider,
27 output_file: &str,
28 data_type: DataType,
29 number_of_dimensions: usize,
30 number_of_vectors: u64,
31 radius: f32,
32) -> CMDResult<()> {
33 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
34 && radius > 127.0
35 && radius <= 0.0
36 {
37 return Err(CMDToolError {
38 details:
39 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
40 .to_string(),
41 });
42 }
43
44 let file = storage_provider.create_for_write(output_file)?;
45 let writer = BufWriter::new(file);
46
47 write_random_data_writer(
48 writer,
49 data_type,
50 number_of_dimensions,
51 number_of_vectors,
52 radius,
53 )
54}
55
56#[allow(clippy::panic)]
62pub fn write_random_data_writer<T: Sized + Write>(
63 mut writer: BufWriter<T>,
64 data_type: DataType,
65 number_of_dimensions: usize,
66 number_of_vectors: u64,
67 radius: f32,
68) -> CMDResult<()> {
69 if (data_type == DataType::Int8 || data_type == DataType::Uint8)
70 && radius > 127.0
71 && radius <= 0.0
72 {
73 return Err(CMDToolError {
74 details:
75 "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
76 .to_string(),
77 });
78 }
79
80 write_metadata(&mut writer, number_of_vectors, number_of_dimensions)?;
81
82 let block_size = 131072;
83 let nblks = u64::div_ceil(number_of_vectors, block_size);
84 println!("# blks: {}", nblks);
85
86 for i in 0..nblks {
87 let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
88
89 let write_method: WriteVectorMethodType<T> = match data_type {
92 DataType::Float => Box::new(
93 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
94 let mut found_nonzero = false;
95 for value in vector {
96 writer.write_f32::<LittleEndian>(*value)?;
97 found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
98 }
99 Ok(found_nonzero)
100 },
101 ),
102 DataType::Uint8 => Box::new(
103 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
104 let mut found_nonempty = false;
105 for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
108 writer.write_u8(value)?;
109
110 found_nonempty = found_nonempty || (value != 128u8);
113 }
114 Ok(found_nonempty)
115 },
116 ),
117 DataType::Int8 => Box::new(
118 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
119 let mut found_nonzero = false;
120 for value in vector.iter().map(|&item| item.round() as i8) {
121 writer.write_i8(value)?;
122 found_nonzero = found_nonzero || (value != 0i8);
123 }
124 Ok(found_nonzero)
125 },
126 ),
127 DataType::Fp16 => Box::new(
128 |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
129 let mut found_nonzero = false;
130 for value in vector.iter().map(|&item| Half::from_f32(item)) {
131 let mut buf = [0; 2];
132 buf.clone_from_slice(value.to_le_bytes().as_slice());
133 writer.write_all(&buf)?;
134 found_nonzero =
135 found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
136 }
137 Ok(found_nonzero)
138 },
139 ),
140 };
141
142 write_random_vector_block(
144 write_method,
145 &mut writer,
146 number_of_dimensions,
147 cblk_size,
148 radius,
149 )?;
150 }
151
152 writer.flush()?;
155
156 Ok(())
157}
158
159fn write_random_vector_block<
167 F: Sized + Write,
168 T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
169>(
170 mut write_method: T,
171 writer: &mut BufWriter<F>,
172 number_of_dimensions: usize,
173 number_of_points: u64,
174 radius: f32,
175) -> CMDResult<()> {
176 let mut found_nonzero = false;
177
178 let vectors = math_util::generate_vectors_with_norm(
179 number_of_points as usize,
180 number_of_dimensions,
181 radius,
182 &mut diskann_providers::utils::create_rnd_from_seed(0),
183 )?;
184 for vector in vectors {
185 found_nonzero |= write_method(writer, &vector)?;
188 }
189
190 if found_nonzero {
191 Ok(())
192 } else {
193 Err(CMDToolError {
194 details: format!(
195 "Generated all-zero vectors with radius {}. Try increasing radius",
196 radius
197 ),
198 })
199 }
200}
201
#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    /// Writing with a radius that every encoding can represent succeeds and creates the file.
    /// All four encoders are exercised, including `Fp16`.
    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8, DataType::Fp16)]
        data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let random_data_path = "/mydatafile.bin";
        let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            data_type,
            num_dimensions,
            10000,
            norm,
        );

        assert!(result.is_ok(), "write_random_data should succeed");
        assert!(
            storage_provider.exists(random_data_path),
            "Random data file should exist"
        );
    }

    /// A radius too small for any element to encode to a non-zero value is reported as an
    /// all-zero error rather than silently producing an empty dataset.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let random_data_path = "/mydatafile.bin";
        let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;

        let expected = Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        });

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            data_type,
            num_dimensions,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, result);
    }
}