// diskann_tools/utils/build_pq.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use diskann::ANNResult;
use diskann_disk::data_model::GraphDataType;
use diskann_providers::storage::StorageReadProvider;
use diskann_providers::{
    model::{
        GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE, NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS,
    },
    storage::{
        get_disk_index_compressed_pq_file, get_disk_index_pq_pivot_file, FileStorageProvider,
        PQStorage,
    },
    utils::{create_thread_pool, load_metadata_from_file, Timer},
};
use diskann_vector::distance::Metric;
use tracing::info;

22pub struct BuildPQParameters<'a> {
23    pub metric: Metric,
24    pub data_path: &'a str,
25    pub index_path_prefix: &'a str,
26    pub num_threads: usize,
27    pub p_val: f64,
28    pub pq_bytes: f64,
29}
31pub fn build_pq<Data: GraphDataType>(
32    storage_provider: &impl StorageReadProvider,
33    parameters: BuildPQParameters,
34) -> ANNResult<()> {
35    let num_pq_chunks = parameters.pq_bytes as usize;
36
37    let data_path = parameters.data_path;
38    let disk_pq_pivot_path = get_disk_index_pq_pivot_file(parameters.index_path_prefix);
39    let disk_pq_compressed_data_path =
40        get_disk_index_compressed_pq_file(parameters.index_path_prefix);
41
42    let mut pq_storage = PQStorage::new(
43        &disk_pq_pivot_path,
44        &disk_pq_compressed_data_path,
45        Some(data_path),
46    );
47
48    let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?;
49    info!(
50        "Compressing dim-{} data into {} chunks(bytes) for PQ",
51        metadata.ndims(),
52        num_pq_chunks
53    );
54
55    let p_val = MAX_PQ_TRAINING_SET_SIZE / (metadata.npoints() as f64);
56
57    let timer = Timer::new();
58    let storage_provider = FileStorageProvider;
59    let random_provider = diskann_providers::utils::create_rnd_provider_from_seed(42);
60
61    let (mut train_data_vector, num_train, train_dim) = pq_storage
62        .get_random_train_data_slice::<Data::VectorDataType, _>(
63            p_val,
64            &storage_provider,
65            &mut random_provider.create_rnd(),
66        )?;
67
68    let pool = create_thread_pool(parameters.num_threads)?;
69
70    diskann_providers::model::pq::generate_pq_pivots(
71        GeneratePivotArguments::new(
72            num_train,
73            train_dim,
74            NUM_PQ_CENTROIDS,
75            num_pq_chunks,
76            NUM_KMEANS_REPS_PQ,
77            false,
78        )?,
79        &mut train_data_vector,
80        &pq_storage,
81        &storage_provider,
82        random_provider,
83        pool.as_ref(),
84    )?;
85
86    diskann_providers::model::pq::generate_pq_data_from_pivots::<f32, _>(
87        NUM_PQ_CENTROIDS,
88        num_pq_chunks,
89        &mut pq_storage,
90        &storage_provider,
91        0,
92        pool.as_ref(),
93    )?;
94
95    info!(
96         "PQ build completed in {:.3} seconds, {:.3}B cycles, {:.3}% CPU time, peak memory {:.3} GBs for {} chunks, using {} threads",
97         timer.elapsed_seconds(),
98         timer.elapsed_gcycles(),
99         timer.get_average_cpu_time_in_percents(),
100         timer.get_peak_memory_usage(),
101         num_pq_chunks,
102         parameters.num_threads
103     );
104
105    Ok(())
106}