Skip to main content

diskann_tools/utils/
build_pq.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann::ANNResult;
7use diskann_providers::storage::StorageReadProvider;
8use diskann_providers::{
9    model::{
10        graph::traits::GraphDataType, GeneratePivotArguments, MAX_PQ_TRAINING_SET_SIZE,
11        NUM_KMEANS_REPS_PQ, NUM_PQ_CENTROIDS,
12    },
13    storage::{
14        get_disk_index_compressed_pq_file, get_disk_index_pq_pivot_file, FileStorageProvider,
15        PQStorage,
16    },
17    utils::{load_metadata_from_file, Timer},
18};
19use diskann_vector::distance::Metric;
20use tracing::info;
21
22pub struct BuildPQParameters<'a> {
23    pub metric: Metric,
24    pub data_path: &'a str,
25    pub index_path_prefix: &'a str,
26    pub num_threads: usize,
27    pub p_val: f64,
28    pub pq_bytes: f64,
29}
30
31pub fn build_pq<Data: GraphDataType>(
32    storage_provider: &impl StorageReadProvider,
33    parameters: BuildPQParameters,
34) -> ANNResult<()> {
35    let num_pq_chunks = parameters.pq_bytes as usize;
36
37    let data_path = parameters.data_path;
38    let disk_pq_pivot_path = get_disk_index_pq_pivot_file(parameters.index_path_prefix);
39    let disk_pq_compressed_data_path =
40        get_disk_index_compressed_pq_file(parameters.index_path_prefix);
41
42    let mut pq_storage = PQStorage::new(
43        &disk_pq_pivot_path,
44        &disk_pq_compressed_data_path,
45        Some(data_path),
46    );
47
48    let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?;
49    info!(
50        "Compressing dim-{} data into {} chunks(bytes) for PQ",
51        metadata.ndims, num_pq_chunks
52    );
53
54    let p_val = MAX_PQ_TRAINING_SET_SIZE / (metadata.npoints as f64);
55
56    let timer = Timer::new();
57    let storage_provider = FileStorageProvider;
58    let random_provider = diskann_providers::utils::create_rnd_provider_from_seed(42);
59
60    let (mut train_data_vector, num_train, train_dim) = pq_storage
61        .get_random_train_data_slice::<Data::VectorDataType, _>(
62            p_val,
63            &storage_provider,
64            &mut random_provider.create_rnd(),
65        )?;
66
67    diskann_providers::model::pq::generate_pq_pivots(
68        GeneratePivotArguments::new(
69            num_train,
70            train_dim,
71            NUM_PQ_CENTROIDS,
72            num_pq_chunks,
73            NUM_KMEANS_REPS_PQ,
74            false,
75        )?,
76        &mut train_data_vector,
77        &pq_storage,
78        &storage_provider,
79        random_provider,
80        parameters.num_threads,
81    )?;
82
83    diskann_providers::model::pq::generate_pq_data_from_pivots::<f32, _, _>(
84        NUM_PQ_CENTROIDS,
85        num_pq_chunks,
86        &mut pq_storage,
87        &storage_provider,
88        false,
89        0,
90        parameters.num_threads,
91    )?;
92
93    info!(
94         "PQ build completed in {:.3} seconds, {:.3}B cycles, {:.3}% CPU time, peak memory {:.3} GBs for {} chunks, using {} threads",
95         timer.elapsed_seconds(),
96         timer.elapsed_gcycles(),
97         timer.get_average_cpu_time_in_percents(),
98         timer.get_peak_memory_usage(),
99         num_pq_chunks,
100         parameters.num_threads
101     );
102
103    Ok(())
104}