// diskann_disk/utils/partition.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
use diskann_providers::{
    forward_threadpool,
    utils::{
        compute_closest_centers, gen_random_slice, k_meanspp_selecting_pivots, run_lloyds,
        AsThreadPool, RayonThreadPool, READ_WRITE_BLOCK_SIZE,
    },
};
use rand::Rng;
use tracing::info;

use crate::{
    disk_index_build_parameter::BYTES_IN_GB,
    storage::{CachedReader, CachedWriter, DiskIndexWriter},
};

22/// Block size for reading/processing large files and matrices in blocks
23const BLOCK_SIZE_LARGE_FILE: u32 = 10_000;
24
#[allow(clippy::too_many_arguments)]
/// Partitions `dataset_file` into shards whose estimated in-RAM index size fits
/// within `ram_budget_in_bytes`, then writes one id-map file per shard (point
/// ids only, no vectors) under `merged_index_prefix`.
///
/// * `dim` - dimensionality of the vectors stored in `dataset_file`.
/// * `sampling_rate` - fraction of points sampled when training k-means pivots.
/// * `k_base` - replication factor: each point is assigned to its `k_base`
///   closest pivots.
/// * `ram_estimator` - maps `(num_points, dim)` to an estimated RAM footprint
///   in bytes for one shard's index.
///
/// Returns the number of shards created.
pub fn partition_with_ram_budget<T, StorageProvider, Pool, F>(
    dataset_file: &str,
    dim: usize,
    sampling_rate: f64,
    ram_budget_in_bytes: f64,
    k_base: usize,
    merged_index_prefix: &str,
    storage_provider: &StorageProvider,
    rng: &mut impl Rng,
    pool: Pool,
    ram_estimator: F,
) -> ANNResult<usize>
where
    T: VectorRepr,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
    Pool: AsThreadPool,
    F: Fn(u64, u64) -> f64,
{
    // Rebinds `pool` to the concrete thread-pool reference the helpers expect.
    forward_threadpool!(pool = pool);
    // Find partition size and get pivot data
    let (num_parts, pivot_data, train_dim) = find_partition_size::<T, StorageProvider, F>(
        dataset_file,
        sampling_rate,
        ram_budget_in_bytes,
        k_base,
        storage_provider,
        rng,
        pool,
        &ram_estimator,
    )?;

    info!("Saving shard data into clusters, with only ids");

    // Stream the full dataset and write the per-shard id maps.
    shard_data_into_clusters_only_ids::<T, StorageProvider>(
        dataset_file,
        &pivot_data,
        num_parts,
        dim,
        train_dim,
        k_base,
        merged_index_prefix,
        storage_provider,
        pool,
    )?;

    Ok(num_parts)
}

74#[allow(clippy::too_many_arguments)]
75fn find_partition_size<T, StorageProvider, F>(
76    dataset_file: &str,
77    sampling_rate: f64,
78    ram_budget_in_bytes: f64,
79    k_base: usize,
80    storage_provider: &StorageProvider,
81    rng: &mut impl Rng,
82    pool: &RayonThreadPool,
83    ram_estimator: &F,
84) -> ANNResult<(usize, Vec<f32>, usize)>
85where
86    T: VectorRepr,
87    StorageProvider: StorageReadProvider + StorageWriteProvider,
88    F: Fn(u64, u64) -> f64,
89{
90    const MAX_K_MEANS_REPS: usize = 10;
91
92    let (train_data_float, num_train, train_dim) =
93        gen_random_slice::<T, StorageProvider>(dataset_file, sampling_rate, storage_provider, rng)?;
94    info!("Loaded {} points for train, dim: {}", num_train, train_dim);
95
96    let (test_data_float, num_test, test_dim) =
97        gen_random_slice::<T, StorageProvider>(dataset_file, sampling_rate, storage_provider, rng)?;
98    info!("Loaded {} points for test, dim: {}", num_test, test_dim);
99
100    // Calculate total points accounting for sampling rate
101    let total_points = (num_train as f64 / sampling_rate) as u64;
102    // Get initial partition count estimate
103    let initial_num_parts = estimate_initial_partition_count::<F>(
104        total_points,
105        train_dim as u64,
106        k_base,
107        ram_budget_in_bytes,
108        ram_estimator,
109    );
110
111    let mut num_parts = initial_num_parts;
112    let mut fit_in_ram = false;
113    let mut pivot_data = Vec::new();
114    // Iteratively find the right number of parts, kmeans_partitioning on training data
115    while !fit_in_ram {
116        fit_in_ram = true;
117
118        let mut max_ram_usage_in_bytes = 0.0;
119
120        pivot_data = vec![0.0; num_parts * train_dim];
121
122        // Process Global k-means for kmeans_partitioning Step
123        info!("Processing global k-means (kmeans_partitioning Step)");
124        k_meanspp_selecting_pivots(
125            &train_data_float,
126            num_train,
127            train_dim,
128            &mut pivot_data,
129            num_parts,
130            rng,
131            &mut (false),
132            pool,
133        )?;
134
135        run_lloyds(
136            &train_data_float,
137            num_train,
138            train_dim,
139            &mut pivot_data,
140            num_parts,
141            MAX_K_MEANS_REPS,
142            &mut (false),
143            pool,
144        )?;
145
146        // now pivots are ready. need to stream base points and assign them to closest clusters.
147
148        let mut cluster_sizes = Vec::new();
149        estimate_cluster_sizes(
150            &test_data_float,
151            num_test,
152            &pivot_data,
153            num_parts,
154            test_dim,
155            k_base,
156            &mut cluster_sizes,
157            pool,
158        )?;
159
160        let mut partition_stats = Vec::with_capacity(num_parts);
161        for p in &cluster_sizes {
162            // to account for the fact that p is the size of the shard over the testing sample.
163            let p = (*p as f64 / sampling_rate) as u64;
164            let cur_shard_ram_estimate_in_bytes = ram_estimator(p, train_dim as u64);
165            partition_stats.push((p, cur_shard_ram_estimate_in_bytes));
166
167            if cur_shard_ram_estimate_in_bytes > max_ram_usage_in_bytes {
168                max_ram_usage_in_bytes = cur_shard_ram_estimate_in_bytes;
169            }
170        }
171
172        info!(
173            "Partition RAM estimates (GB): {}",
174            partition_stats
175                .iter()
176                .map(|(size, ram)| format!("#{}: {:.2}", size, ram / BYTES_IN_GB))
177                .collect::<Vec<_>>()
178                .join(", ")
179        );
180
181        info!(
182            "With {} parts, max estimated RAM usage: {:.2} GB, budget given is {:.2} GB",
183            num_parts,
184            max_ram_usage_in_bytes / BYTES_IN_GB,
185            ram_budget_in_bytes / BYTES_IN_GB
186        );
187        if max_ram_usage_in_bytes > ram_budget_in_bytes {
188            fit_in_ram = false;
189            num_parts += 2;
190        } else {
191            info!(
192                "Found optimal partition count: [parts={}, initial={}, max_ram={:.2}GB, budget={:.2}GB]",
193                num_parts,
194                initial_num_parts,
195                max_ram_usage_in_bytes / BYTES_IN_GB,
196                ram_budget_in_bytes / BYTES_IN_GB
197            );
198        }
199    }
200
201    Ok((num_parts, pivot_data, train_dim))
202}
203
204/// Initial estimation of partition count based on dataset characteristics and RAM budget
205fn estimate_initial_partition_count<F>(
206    total_points: u64,
207    dimension: u64,
208    k_base: usize,
209    ram_budget_in_bytes: f64,
210    ram_estimator: &F,
211) -> usize
212where
213    F: Fn(u64, u64) -> f64,
214{
215    // Calculate total RAM needed without partitioning
216    let total_ram_estimate = ram_estimator(total_points * k_base as u64, dimension);
217
218    let mut partition_count = (total_ram_estimate / ram_budget_in_bytes).ceil() as usize;
219
220    // Ensure minimum of 3 partitions and odd number for balanced splitting
221    partition_count = std::cmp::max(3, partition_count);
222    if partition_count.is_multiple_of(2) {
223        partition_count += 1;
224    }
225
226    info!(
227        "Estimated initial partition count: {} (total points: {}, dimension: {}, k_base: {}, total_ram_estimate: {:.2} GB, ram_budget: {:.2} GB)",
228        partition_count,
229        total_points,
230        dimension,
231        k_base,
232        total_ram_estimate / BYTES_IN_GB,
233        ram_budget_in_bytes / BYTES_IN_GB
234    );
235
236    partition_count
237}
238
#[allow(clippy::too_many_arguments)]
/// Streams `dataset_file` block-by-block, assigns every point to its `k_base`
/// closest pivots, and writes one id-map file per shard under
/// `merged_index_prefix` containing only the original point ids (no vectors).
///
/// Id-map layout: `[count: u32][1: u32][id: u32]...`; the count is written as
/// a placeholder first and patched via `reset()` once sharding is complete.
///
/// * `dim` - dimensionality of the raw vectors in the dataset file (must match
///   the file's header).
/// * `full_dim` - dimensionality of the float buffers the pivots live in;
///   may differ from `dim` (presumably a padded dimension — TODO confirm).
fn shard_data_into_clusters_only_ids<T, StorageProvider>(
    dataset_file: &str,
    pivot_data: &[f32],
    num_parts: usize,
    dim: usize,
    full_dim: usize,
    k_base: usize,
    merged_index_prefix: &str,
    storage_provider: &StorageProvider,
    pool: &RayonThreadPool,
) -> ANNResult<()>
where
    T: VectorRepr,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    let mut dataset_reader = CachedReader::<StorageProvider>::new(
        dataset_file,
        READ_WRITE_BLOCK_SIZE,
        storage_provider,
    )?;
    // Dataset header: point count then dimensionality, both little-endian u32.
    let num_points = dataset_reader.read_u32()?;
    let base_dim = dataset_reader.read_u32()?;
    if base_dim != dim as u32 {
        return Err(ANNError::log_index_error(
            "dimensions dont match for train set and base set",
        ));
    }

    let mut shard_counts = vec![0; num_parts];
    let shard_idmaps_names = (0..num_parts)
        .map(|shard| {
            DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard)
        })
        .collect::<Vec<String>>();

    // 8KB cache for small ID map files - matches default BufWriter size
    const WRITE_ID_CACHE_SIZE: u64 = 8 * 1024;
    let mut shard_idmap_cached_writers = Vec::new();
    for name in &shard_idmaps_names {
        let writer = storage_provider.create_for_write(name)?;
        let cached_writer =
            CachedWriter::<StorageProvider>::new(name, WRITE_ID_CACHE_SIZE, writer)?;
        shard_idmap_cached_writers.push(cached_writer);
    }

    // Header per id-map: a placeholder count (patched at the end) and a
    // constant `1` (one u32 per record).
    let dummy_size: u32 = 0;
    let const_one: u32 = 1;
    for writer in shard_idmap_cached_writers.iter_mut() {
        writer.write(&dummy_size.to_le_bytes())?;
        writer.write(&const_one.to_le_bytes())?;
    }

    let block_size = if num_points <= BLOCK_SIZE_LARGE_FILE {
        num_points
    } else {
        BLOCK_SIZE_LARGE_FILE
    };

    let num_blocks = num_points.div_ceil(block_size);

    // Scratch buffers sized for a full block and reused across iterations.
    let mut block_closest_centers = vec![0u32; block_size as usize * k_base];
    let mut block_data_t: Vec<u8> = vec![0; block_size as usize * dim * std::mem::size_of::<T>()];
    let mut block_data_float: Vec<f32> = vec![0.0; full_dim * block_size as usize];

    for block in 0..num_blocks {
        let start_id = (block * block_size) as usize;
        // Last block may be short.
        let end_id = std::cmp::min((block + 1) * block_size, num_points) as usize;
        let cur_blk_size = end_id - start_id;

        dataset_reader.read(&mut block_data_t[..cur_blk_size * dim * std::mem::size_of::<T>()])?;

        // convert data from type T to f32
        let cur_vector_t: &[T] =
            bytemuck::cast_slice(&block_data_t[..cur_blk_size * dim * std::mem::size_of::<T>()]);

        // NOTE(review): each source chunk has `dim` elements while each
        // destination chunk has `full_dim`; `as_f32_into` presumably handles
        // the size difference (padding) — confirm against its contract.
        for (v, dst) in cur_vector_t
            .chunks_exact(dim)
            .zip(block_data_float.chunks_exact_mut(full_dim))
        {
            T::as_f32_into(v, dst).into_ann_result()?;
        }

        compute_closest_centers(
            &block_data_float[..full_dim * cur_blk_size],
            cur_blk_size,
            full_dim,
            pivot_data,
            num_parts,
            k_base,
            &mut block_closest_centers,
            None,
            None,
            pool,
        )?;

        // Append each point's original id to every shard it was assigned to.
        for p in 0..cur_blk_size {
            for p1 in 0..k_base {
                let shard_id = block_closest_centers[p * k_base + p1] as usize;
                let original_point_map_id = (start_id + p) as u32;
                shard_idmap_cached_writers[shard_id].write(&original_point_map_id.to_le_bytes())?;
                shard_counts[shard_id] += 1;
            }
        }
    }

    let mut total_count = 0;

    // Patch each id-map header with the real per-shard count, then flush.
    for i in 0..num_parts {
        let cur_shard_count = shard_counts[i] as u32;
        info!(" shard_{} with npts : {} ", i, cur_shard_count);
        total_count += cur_shard_count;
        shard_idmap_cached_writers[i].reset()?;
        shard_idmap_cached_writers[i].write(&cur_shard_count.to_le_bytes())?;
        shard_idmap_cached_writers[i].flush()?;
    }

    info!(
        "Partitioned {} with replication factor {} to get {} points across {} shards",
        num_points, k_base, total_count, num_parts
    );

    Ok(())
}

364#[allow(clippy::too_many_arguments)]
365fn estimate_cluster_sizes(
366    data_float: &[f32],
367    num_pts: usize,
368    pivot_data: &[f32],
369    num_centers: usize,
370    dim: usize,
371    k_base: usize,
372    cluster_sizes: &mut Vec<u32>,
373    pool: &RayonThreadPool,
374) -> ANNResult<()> {
375    cluster_sizes.clear();
376    let mut shard_counts = vec![0; num_centers];
377
378    let block_size = if num_pts <= BLOCK_SIZE_LARGE_FILE as usize {
379        num_pts
380    } else {
381        BLOCK_SIZE_LARGE_FILE as usize
382    };
383
384    let mut block_closest_centers = vec![0; block_size * k_base];
385
386    let num_blocks = num_pts.div_ceil(block_size);
387
388    for block in 0..num_blocks {
389        let start_id = block * block_size;
390        let end_id = std::cmp::min((block + 1) * block_size, num_pts);
391        let cur_blk_size = end_id - start_id;
392
393        let block_data_float = &data_float[start_id * dim..(start_id + cur_blk_size) * dim];
394
395        compute_closest_centers(
396            block_data_float,
397            cur_blk_size,
398            dim,
399            pivot_data,
400            num_centers,
401            k_base,
402            &mut block_closest_centers,
403            None,
404            None,
405            pool,
406        )?;
407
408        for p in 0..cur_blk_size {
409            for p1 in 0..k_base {
410                let shard_id = block_closest_centers[p * k_base + p1] as usize;
411                shard_counts[shard_id] += 1;
412            }
413        }
414    }
415
416    (0..num_centers).for_each(|i| {
417        let cur_shard_count = shard_counts[i] as u32;
418        cluster_sizes.push(cur_shard_count);
419    });
420    info!("Estimated cluster sizes: {:?}", cluster_sizes);
421    Ok(())
422}
423
#[cfg(test)]
mod partition_test {
    use std::io::Read;

    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_providers::utils::create_thread_pool_for_test;
    use diskann_utils::test_data_root;
    use vfs::{MemoryFS, OverlayFS};

    use super::*;

    // Verifies that cluster-size estimation counts k_base assignments per
    // point across the given centers.
    #[test]
    fn test_estimate_cluster_sizes() {
        let data_float = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_pts = 3;
        let pivot_data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_centers = 3;
        let dim = 2;
        let k_base = 2;
        let mut cluster_sizes = vec![];
        let pool = create_thread_pool_for_test();

        estimate_cluster_sizes(
            &data_float,
            num_pts,
            pivot_data,
            num_centers,
            dim,
            k_base,
            &mut cluster_sizes,
            &pool,
        )
        .unwrap();

        // 3 points x k_base=2 assignments = 6 total, split 2/3/1.
        assert_eq!(cluster_sizes.len(), num_centers);
        assert_eq!(cluster_sizes, &[2, 3, 1]);
    }

    // Shards a synthetic dataset and compares each generated id-map file
    // byte-for-byte against checked-in expected fixtures.
    #[test]
    fn test_shard_data_into_clusters_only_ids() {
        // create a temporary file for the dataset
        let dataset_path = "/dataset_file";
        // write some dummy data to the dataset file
        let mut data_float = Vec::new();
        let num_points: u32 = 100;
        let dim: usize = 10;

        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        {
            let writer = storage_provider.create_for_write(dataset_path).unwrap();
            let mut dataset_writer = CachedWriter::<VirtualStorageProvider<MemoryFS>>::new(
                dataset_path,
                READ_WRITE_BLOCK_SIZE,
                writer,
            )
            .unwrap();
            dataset_writer.write(&num_points.to_le_bytes()).unwrap();
            // NOTE(review): `dim` is usize, so this writes 8 bytes where the
            // reader expects a 4-byte u32 dim; the data read back is shifted
            // by 4 zero bytes. The expected fixture files appear to encode
            // this behavior — confirm before changing to `(dim as u32)`.
            dataset_writer.write(&dim.to_le_bytes()).unwrap();
            for i in 0..num_points {
                for j in 0..dim {
                    let val = (i * dim as u32 + j as u32) as f32;
                    data_float.push(val);
                    dataset_writer.write(&val.to_le_bytes()).unwrap();
                }
            }
        }

        // create some dummy pivot data
        let k_base: usize = 2;
        let num_parts = 3;

        // generate pivot data
        let pivot_data: [f32; 30] = [
            820.0, 821.0, 822.0, 823.0, 824.0, 825.0, 826.0, 827.0, 828.0, 829.0, 155.0, 156.0,
            157.0, 158.0, 159.0, 160.0, 161.0, 162.0, 163.0, 164.0, 480.0, 481.0, 482.0, 483.0,
            484.0, 485.0, 486.0, 487.0, 488.0, 489.0,
        ];

        // create a temporary prefix for the merged index prefix
        let merged_index_prefix = "/merged_index";
        let pool = create_thread_pool_for_test();
        // call the function being tested
        shard_data_into_clusters_only_ids::<f32, VirtualStorageProvider<OverlayFS>>(
            dataset_path,
            &pivot_data,
            num_parts,
            dim,
            dim,
            k_base,
            merged_index_prefix,
            &storage_provider,
            &pool,
        )
        .unwrap();

        // check that the output is as expected
        let expected_prefix = "/partition/id_maps/merged_index_expected";
        for shard in 0..num_parts {
            let path1 =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
            let path2 =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(expected_prefix, shard);
            let file1 =
                load_file_to_vec::<VirtualStorageProvider<OverlayFS>>(&path1, &storage_provider);
            let file2 =
                load_file_to_vec::<VirtualStorageProvider<OverlayFS>>(&path2, &storage_provider);

            assert_eq!(file1.len(), file2.len());
            assert_eq!(file1[..], file2[..]);

            // clean up the temporary files and directory
            storage_provider.delete(&path1).unwrap();
        }

        storage_provider.delete(dataset_path).unwrap();
    }

    // Reads an entire file from the given storage provider into a byte vector.
    fn load_file_to_vec<StorageProvider>(
        file_path: &str,
        storage_provider: &StorageProvider,
    ) -> Vec<u8>
    where
        StorageProvider: StorageReadProvider,
    {
        let mut file = storage_provider.open_reader(file_path).unwrap();
        let mut buffer = vec![];
        file.read_to_end(&mut buffer).unwrap();
        buffer
    }

    // End-to-end: partitions the SIFT-small learn set under a tight RAM budget
    // and checks that at least the minimum shard count (3) was produced.
    #[test]
    fn test_partition_with_ram_budget() -> ANNResult<()> {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let dataset_file = "/sift/siftsmall_learn.bin";
        let mut file = storage_provider.open_reader(dataset_file).unwrap();
        let mut data = vec![];
        file.read_to_end(&mut data).unwrap();

        let sampling_rate = 1.0;
        let ram_budget_in_bytes = 15_000_000.0;
        let max_degree = 64;
        let k_base = 2;
        let merged_index_prefix = "/test_merged_index_prefix";
        let pool = create_thread_pool_for_test();

        let num_parts = partition_with_ram_budget::<f32, _, _, _>(
            dataset_file,
            128, //sift is 128 dimensions
            sampling_rate,
            ram_budget_in_bytes,
            k_base,
            merged_index_prefix,
            &storage_provider,
            &mut diskann_providers::utils::create_rnd_in_tests(),
            &pool,
            |num_points, dim| {
                // Simple RAM estimation for test - capture datasize and graph_degree from context
                use diskann_providers::model::GRAPH_SLACK_FACTOR;

                let datasize = std::mem::size_of::<f32>() as u64;
                let graph_degree = max_degree as u64;
                let dataset_size = (num_points * dim.next_multiple_of(8u64) * datasize) as f64;
                let graph_size = (num_points * graph_degree * 4) as f64 * GRAPH_SLACK_FACTOR;
                1.1 * (dataset_size + graph_size)
            },
        )?;

        // Minimum partition count enforced by estimate_initial_partition_count.
        assert!(num_parts >= 3);

        // Clean up generated id-map files.
        for i in 0..num_parts {
            let idmap_filename =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, i);
            storage_provider.delete(&idmap_filename)?;
        }

        Ok(())
    }
}