use std::io::{BufWriter, Write};
use byteorder::{LittleEndian, WriteBytesExt};
use diskann_providers::storage::StorageWriteProvider;
use diskann_utils::{io::Metadata, sampling::random::WithApproximateNorm};
use diskann_vector::Half;
use crate::utils::{CMDResult, CMDToolError, DataType};
type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;
#[allow(clippy::panic)]
pub fn write_random_data<StorageProvider: StorageWriteProvider>(
storage_provider: &StorageProvider,
output_file: &str,
data_type: DataType,
number_of_dimensions: usize,
number_of_vectors: u64,
radius: f32,
) -> CMDResult<()> {
if (data_type == DataType::Int8 || data_type == DataType::Uint8)
&& (radius > 127.0 || radius <= 0.0)
{
return Err(CMDToolError {
details:
"Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
.to_string(),
});
}
let file = storage_provider.create_for_write(output_file)?;
let writer = BufWriter::new(file);
write_random_data_writer(
writer,
data_type,
number_of_dimensions,
number_of_vectors,
radius,
)
}
#[allow(clippy::panic)]
fn write_random_data_writer<T: Sized + Write>(
mut writer: BufWriter<T>,
data_type: DataType,
number_of_dimensions: usize,
number_of_vectors: u64,
radius: f32,
) -> CMDResult<()> {
if (data_type == DataType::Int8 || data_type == DataType::Uint8)
&& (radius > 127.0 || radius <= 0.0)
{
return Err(CMDToolError {
details:
"Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
.to_string(),
});
}
Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;
let block_size = 131072;
let nblks = u64::div_ceil(number_of_vectors, block_size);
println!("# blks: {}", nblks);
for i in 0..nblks {
let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);
let write_method: WriteVectorMethodType<T> = match data_type {
DataType::Float => Box::new(
|writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
let mut found_nonzero = false;
for value in vector {
writer.write_f32::<LittleEndian>(*value)?;
found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
}
Ok(found_nonzero)
},
),
DataType::Uint8 => Box::new(
|writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
let mut found_nonempty = false;
for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
writer.write_u8(value)?;
found_nonempty = found_nonempty || (value != 128u8);
}
Ok(found_nonempty)
},
),
DataType::Int8 => Box::new(
|writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
let mut found_nonzero = false;
for value in vector.iter().map(|&item| item.round() as i8) {
writer.write_i8(value)?;
found_nonzero = found_nonzero || (value != 0i8);
}
Ok(found_nonzero)
},
),
DataType::Fp16 => Box::new(
|writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
let mut found_nonzero = false;
for value in vector.iter().map(|&item| Half::from_f32(item)) {
let mut buf = [0; 2];
buf.clone_from_slice(value.to_le_bytes().as_slice());
writer.write_all(&buf)?;
found_nonzero =
found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
}
Ok(found_nonzero)
},
),
};
write_random_vector_block(
write_method,
&mut writer,
number_of_dimensions,
cblk_size,
radius,
)?;
}
writer.flush()?;
Ok(())
}
fn write_random_vector_block<
F: Sized + Write,
T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
>(
mut write_method: T,
writer: &mut BufWriter<F>,
number_of_dimensions: usize,
number_of_points: u64,
radius: f32,
) -> CMDResult<()> {
let mut found_nonzero = false;
let mut rng = diskann_providers::utils::create_rnd_from_seed(0);
for _ in 0..number_of_points {
let vector = f32::with_approximate_norm(number_of_dimensions, radius, &mut rng);
found_nonzero |= write_method(writer, &vector)?;
}
if found_nonzero {
Ok(())
} else {
Err(CMDToolError {
details: format!(
"Generated all-zero vectors with radius {}. Try increasing radius",
radius
),
})
}
}
#[cfg(test)]
mod tests {
use diskann_providers::storage::VirtualStorageProvider;
use rstest::rstest;
use super::*;
use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};
#[rstest]
fn random_data_write_success(
#[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
#[values(100.0, 127.0)] norm: f32,
) {
let random_data_path = "/mydatafile.bin";
let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;
let storage_provider = VirtualStorageProvider::new_overlay(".");
let result = write_random_data(
&storage_provider,
random_data_path,
data_type,
num_dimensions,
10000,
norm,
);
assert!(result.is_ok(), "write_random_data should succeed");
assert!(
storage_provider.exists(random_data_path),
"Random data file should exist"
);
}
#[rstest]
#[case(DataType::Float, 0.0)]
#[case(DataType::Int8, 0.0)]
#[case(DataType::Int8, 0.1)]
#[case(DataType::Int8, 1.0)]
#[case(DataType::Uint8, 0.0)]
#[case(DataType::Uint8, 0.1)]
#[case(DataType::Uint8, 1.0)]
fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
let random_data_path = "/mydatafile.bin";
let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;
let expected = if (data_type == DataType::Int8 || data_type == DataType::Uint8)
&& radius <= 0.0
{
Err(CMDToolError {
details:
"Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
.to_string(),
})
} else {
Err(CMDToolError {
details: format!(
"Generated all-zero vectors with radius {}. Try increasing radius",
radius
),
})
};
let storage_provider = VirtualStorageProvider::new_overlay(".");
let result = write_random_data(
&storage_provider,
random_data_path,
data_type,
num_dimensions,
TEST_DATASET_SIZE_SMALL,
radius,
);
assert_eq!(expected, result);
}
#[test]
fn test_fp16_data_type() {
let random_data_path = "/fp16_data.bin";
let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;
let storage_provider = VirtualStorageProvider::new_overlay(".");
let result = write_random_data(
&storage_provider,
random_data_path,
DataType::Fp16,
num_dimensions,
100,
50.0,
);
assert!(result.is_ok(), "write_random_data with Fp16 should succeed");
assert!(storage_provider.exists(random_data_path));
}
#[test]
fn test_invalid_radius_for_int8() {
let random_data_path = "/invalid_int8.bin";
let storage_provider = VirtualStorageProvider::new_overlay(".");
let expected = Err(CMDToolError {
details:
"Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
.to_string(),
});
let result = write_random_data(
&storage_provider,
random_data_path,
DataType::Int8,
10,
100,
128.0,
);
assert_eq!(expected, result);
}
#[test]
fn test_invalid_radius_for_uint8() {
let random_data_path = "/invalid_uint8.bin";
let storage_provider = VirtualStorageProvider::new_overlay(".");
let expected = Err(CMDToolError {
details:
"Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
.to_string(),
});
let result = write_random_data(
&storage_provider,
random_data_path,
DataType::Uint8,
10,
100,
150.0,
);
assert_eq!(expected, result);
}
#[test]
fn test_small_dataset() {
let random_data_path = "/small_data.bin";
let storage_provider = VirtualStorageProvider::new_overlay(".");
let result = write_random_data(
&storage_provider,
random_data_path,
DataType::Float,
5,
10,
100.0,
);
assert!(result.is_ok());
assert!(storage_provider.exists(random_data_path));
}
#[test]
fn test_large_block_size() {
let random_data_path = "/large_blocks.bin";
let storage_provider = VirtualStorageProvider::new_overlay(".");
let result = write_random_data(
&storage_provider,
random_data_path,
DataType::Float,
10,
200000, 100.0,
);
assert!(result.is_ok());
assert!(storage_provider.exists(random_data_path));
}
}