diskann-tools 0.51.0

DiskANN is a fast approximate nearest neighbor search library for high dimensional data
Documentation
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use std::io::{BufWriter, Write};

use byteorder::{LittleEndian, WriteBytesExt};
use diskann_providers::storage::StorageWriteProvider;
use diskann_utils::{io::Metadata, sampling::random::WithApproximateNorm};
use diskann_vector::Half;

use crate::utils::{CMDResult, CMDToolError, DataType};

type WriteVectorMethodType<T> = Box<dyn Fn(&mut BufWriter<T>, &Vec<f32>) -> CMDResult<bool>>;

/**
Generate random points around a sphere with the specified radius and write them to a file

When data_type is int8 or uint8 radius must be <= 127.0
 */
#[allow(clippy::panic)]
pub fn write_random_data<StorageProvider: StorageWriteProvider>(
    storage_provider: &StorageProvider,
    output_file: &str,
    data_type: DataType,
    number_of_dimensions: usize,
    number_of_vectors: u64,
    radius: f32,
) -> CMDResult<()> {
    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
        && (radius > 127.0 || radius <= 0.0)
    {
        return Err(CMDToolError {
            details:
            "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                .to_string(),
        });
    }

    let file = storage_provider.create_for_write(output_file)?;
    let writer = BufWriter::new(file);

    write_random_data_writer(
        writer,
        data_type,
        number_of_dimensions,
        number_of_vectors,
        radius,
    )
}

/**
Generate random points around a sphere with the specified radius and write them to a file

When data_type is int8 or uint8 radius must be <= 127.0
*/
#[allow(clippy::panic)]
fn write_random_data_writer<T: Sized + Write>(
    mut writer: BufWriter<T>,
    data_type: DataType,
    number_of_dimensions: usize,
    number_of_vectors: u64,
    radius: f32,
) -> CMDResult<()> {
    if (data_type == DataType::Int8 || data_type == DataType::Uint8)
        && (radius > 127.0 || radius <= 0.0)
    {
        return Err(CMDToolError {
            details:
                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                    .to_string(),
        });
    }

    Metadata::new(number_of_vectors, number_of_dimensions)?.write(&mut writer)?;

    let block_size = 131072;
    let nblks = u64::div_ceil(number_of_vectors, block_size);
    println!("# blks: {}", nblks);

    for i in 0..nblks {
        let cblk_size = std::cmp::min(number_of_vectors - i * block_size, block_size);

        // Each data has special code to write it out.  These methods convert the random data
        // from the input vector into the specific datatype and writes it out to the data file.
        let write_method: WriteVectorMethodType<T> = match data_type {
            DataType::Float => Box::new(
                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
                    let mut found_nonzero = false;
                    for value in vector {
                        writer.write_f32::<LittleEndian>(*value)?;
                        found_nonzero = found_nonzero || ((*value != 0f32) && value.is_finite());
                    }
                    Ok(found_nonzero)
                },
            ),
            DataType::Uint8 => Box::new(
                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
                    let mut found_nonempty = false;
                    // Since u8 is unsigned, add 128 to ensure non-negative before
                    // rounding and casting
                    for value in vector.iter().map(|&item| (item + 128.0).round() as u8) {
                        writer.write_u8(value)?;

                        // Since we add 128 to the random number to prevent negative values,
                        // 'empty' is a vector where all indices hold 128u8.
                        found_nonempty = found_nonempty || (value != 128u8);
                    }
                    Ok(found_nonempty)
                },
            ),
            DataType::Int8 => Box::new(
                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
                    let mut found_nonzero = false;
                    for value in vector.iter().map(|&item| item.round() as i8) {
                        writer.write_i8(value)?;
                        found_nonzero = found_nonzero || (value != 0i8);
                    }
                    Ok(found_nonzero)
                },
            ),
            DataType::Fp16 => Box::new(
                |writer: &mut BufWriter<T>, vector: &Vec<f32>| -> CMDResult<bool> {
                    let mut found_nonzero = false;
                    for value in vector.iter().map(|&item| Half::from_f32(item)) {
                        let mut buf = [0; 2];
                        buf.clone_from_slice(value.to_le_bytes().as_slice());
                        writer.write_all(&buf)?;
                        found_nonzero =
                            found_nonzero || (value != Half::from_f32(0.0) && value.is_finite());
                    }
                    Ok(found_nonzero)
                },
            ),
        };

        // Propagate errors if there are any
        write_random_vector_block(
            write_method,
            &mut writer,
            number_of_dimensions,
            cblk_size,
            radius,
        )?;
    }

    // writer flushes the inner file object as part of it's flush.  File object moved
    // to writer scope so we cannot manually call flush on it here.
    writer.flush()?;

    Ok(())
}

/**
Writes random vectors to the specified writer.  Function generates random floats.  It is the
responsibility of the "write_method" method argument to convert the random floats into other
datatypes.

NOTE: This generates random points on a sphere that has the specified radius
*/
fn write_random_vector_block<
    F: Sized + Write,
    T: FnMut(&mut BufWriter<F>, &Vec<f32>) -> CMDResult<bool>,
>(
    mut write_method: T,
    writer: &mut BufWriter<F>,
    number_of_dimensions: usize,
    number_of_points: u64,
    radius: f32,
) -> CMDResult<()> {
    let mut found_nonzero = false;

    let mut rng = diskann_providers::utils::create_rnd_from_seed(0);
    for _ in 0..number_of_points {
        let vector = f32::with_approximate_norm(number_of_dimensions, radius, &mut rng);
        // Check for non-zero after casting to final numeric types.  Do not short-circuit
        // evaluate to ensure we always write the data.
        found_nonzero |= write_method(writer, &vector)?;
    }

    if found_nonzero {
        Ok(())
    } else {
        Err(CMDToolError {
            details: format!(
                "Generated all-zero vectors with radius {}. Try increasing radius",
                radius
            ),
        })
    }
}

#[cfg(test)]
mod tests {
    use diskann_providers::storage::VirtualStorageProvider;
    use rstest::rstest;

    use super::*;
    use crate::utils::size_constants::{TEST_DATASET_SIZE_SMALL, TEST_NUM_DIMENSIONS_RECOMMENDED};

    #[rstest]
    fn random_data_write_success(
        #[values(DataType::Float, DataType::Uint8, DataType::Int8)] data_type: DataType,
        #[values(100.0, 127.0)] norm: f32,
    ) {
        let random_data_path = "/mydatafile.bin";
        let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            data_type,
            num_dimensions,
            10000,
            norm,
        );

        assert!(result.is_ok(), "write_random_data should succeed");
        assert!(
            storage_provider.exists(random_data_path),
            "Random data file should exist"
        );
    }

    /// Very low values of "radius" cause the random data to all be zero.
    /// Ensure that an appropriate error is returned when invalid radius is used.
    #[rstest]
    #[case(DataType::Float, 0.0)]
    #[case(DataType::Int8, 0.0)]
    #[case(DataType::Int8, 0.1)]
    #[case(DataType::Int8, 1.0)]
    #[case(DataType::Uint8, 0.0)]
    #[case(DataType::Uint8, 0.1)]
    #[case(DataType::Uint8, 1.0)]
    fn random_data_write_too_low_norm(#[case] data_type: DataType, #[case] radius: f32) {
        let random_data_path = "/mydatafile.bin";
        let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;

        let expected = if (data_type == DataType::Int8 || data_type == DataType::Uint8)
            && radius <= 0.0
        {
            Err(CMDToolError {
                details:
                    "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                        .to_string(),
            })
        } else {
            Err(CMDToolError {
                details: format!(
                    "Generated all-zero vectors with radius {}. Try increasing radius",
                    radius
                ),
            })
        };

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            data_type,
            num_dimensions,
            TEST_DATASET_SIZE_SMALL,
            radius,
        );

        assert_eq!(expected, result);
    }

    #[test]
    fn test_fp16_data_type() {
        let random_data_path = "/fp16_data.bin";
        let num_dimensions = TEST_NUM_DIMENSIONS_RECOMMENDED;

        let storage_provider = VirtualStorageProvider::new_overlay(".");
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            DataType::Fp16,
            num_dimensions,
            100,
            50.0,
        );

        assert!(result.is_ok(), "write_random_data with Fp16 should succeed");
        assert!(storage_provider.exists(random_data_path));
    }

    #[test]
    fn test_invalid_radius_for_int8() {
        let random_data_path = "/invalid_int8.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let expected = Err(CMDToolError {
            details:
                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                    .to_string(),
        });
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            DataType::Int8,
            10,
            100,
            128.0,
        );

        assert_eq!(expected, result);
    }

    #[test]
    fn test_invalid_radius_for_uint8() {
        let random_data_path = "/invalid_uint8.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let expected = Err(CMDToolError {
            details:
                "Error: for int8/uint8 datatypes, radius (L2 norm) cannot be greater than 127 and less than or equal to 0"
                    .to_string(),
        });
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            DataType::Uint8,
            10,
            100,
            150.0,
        );

        assert_eq!(expected, result);
    }

    #[test]
    fn test_small_dataset() {
        let random_data_path = "/small_data.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        // Test with very small dataset
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            DataType::Float,
            5,
            10,
            100.0,
        );

        assert!(result.is_ok());
        assert!(storage_provider.exists(random_data_path));
    }

    #[test]
    fn test_large_block_size() {
        let random_data_path = "/large_blocks.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        // Test with more than one block
        let result = write_random_data(
            &storage_provider,
            random_data_path,
            DataType::Float,
            10,
            200000, // More than block_size (131072)
            100.0,
        );

        assert!(result.is_ok());
        assert!(storage_provider.exists(random_data_path));
    }
}