1use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
6use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
7use diskann_providers::{
8 forward_threadpool,
9 utils::{
10 compute_closest_centers, gen_random_slice, k_meanspp_selecting_pivots, run_lloyds,
11 AsThreadPool, RayonThreadPool, READ_WRITE_BLOCK_SIZE,
12 },
13};
14use rand::Rng;
15use tracing::info;
16
17use crate::{
18 disk_index_build_parameter::BYTES_IN_GB,
19 storage::{CachedReader, CachedWriter, DiskIndexWriter},
20};
21
// Number of vectors processed per block when streaming large inputs, bounding
// the scratch memory used during sharding and cluster-size estimation.
const BLOCK_SIZE_LARGE_FILE: u32 = 10_000;
24
/// Partitions `dataset_file` into shards whose estimated RAM usage fits within
/// `ram_budget_in_bytes`, writing one id-map file per shard.
///
/// First searches for a partition count whose largest estimated shard fits the
/// budget (see `find_partition_size`), then assigns every point to its
/// `k_base` closest pivots and persists the resulting id maps (point ids only,
/// no vector data) under `merged_index_prefix`.
///
/// # Arguments
/// * `dim` - dimensionality of the vectors stored in `dataset_file`.
/// * `sampling_rate` - fraction of the dataset sampled for k-means training.
/// * `ram_estimator` - `(num_points, dim) -> bytes` estimate of a shard's
///   in-memory footprint.
///
/// # Returns
/// The number of shards created.
#[allow(clippy::too_many_arguments)]
pub fn partition_with_ram_budget<T, StorageProvider, Pool, F>(
    dataset_file: &str,
    dim: usize,
    sampling_rate: f64,
    ram_budget_in_bytes: f64,
    k_base: usize,
    merged_index_prefix: &str,
    storage_provider: &StorageProvider,
    rng: &mut impl Rng,
    pool: Pool,
    ram_estimator: F,
) -> ANNResult<usize>
where
    T: VectorRepr,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
    Pool: AsThreadPool,
    F: Fn(u64, u64) -> f64,
{
    // Rebinds `pool` so it can be passed as `&RayonThreadPool` to the helpers.
    forward_threadpool!(pool = pool);
    // Step 1: find a partition count (and the trained pivots) whose largest
    // shard fits the RAM budget.
    let (num_parts, pivot_data, train_dim) = find_partition_size::<T, StorageProvider, F>(
        dataset_file,
        sampling_rate,
        ram_budget_in_bytes,
        k_base,
        storage_provider,
        rng,
        pool,
        &ram_estimator,
    )?;

    info!("Saving shard data into clusters, with only ids");

    // Step 2: stream the full dataset and append each point's id to the
    // id-map files of its `k_base` closest pivots.
    shard_data_into_clusters_only_ids::<T, StorageProvider>(
        dataset_file,
        &pivot_data,
        num_parts,
        dim,
        train_dim,
        k_base,
        merged_index_prefix,
        storage_provider,
        pool,
    )?;

    Ok(num_parts)
}
73
74#[allow(clippy::too_many_arguments)]
/// Searches for a partition count such that the largest shard's estimated RAM
/// usage fits within `ram_budget_in_bytes`.
///
/// Draws two independent random samples of the dataset: one to train k-means
/// pivots (k-means++ seeding followed by Lloyd's iterations) and one held-out
/// sample to estimate per-cluster sizes. Starting from an initial estimate,
/// the partition count is increased by 2 (keeping it odd) until every shard's
/// estimate fits the budget.
///
/// # Returns
/// `(num_parts, pivot_data, train_dim)` where `pivot_data` holds
/// `num_parts * train_dim` pivot coordinates in row-major order.
#[allow(clippy::too_many_arguments)]
fn find_partition_size<T, StorageProvider, F>(
    dataset_file: &str,
    sampling_rate: f64,
    ram_budget_in_bytes: f64,
    k_base: usize,
    storage_provider: &StorageProvider,
    rng: &mut impl Rng,
    pool: &RayonThreadPool,
    ram_estimator: &F,
) -> ANNResult<(usize, Vec<f32>, usize)>
where
    T: VectorRepr,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
    F: Fn(u64, u64) -> f64,
{
    // Maximum number of Lloyd's iterations per k-means run.
    const MAX_K_MEANS_REPS: usize = 10;

    // Two independent samples: one to train the pivots, one to estimate the
    // resulting cluster sizes.
    let (train_data_float, num_train, train_dim) =
        gen_random_slice::<T, StorageProvider>(dataset_file, sampling_rate, storage_provider, rng)?;
    info!("Loaded {} points for train, dim: {}", num_train, train_dim);

    let (test_data_float, num_test, test_dim) =
        gen_random_slice::<T, StorageProvider>(dataset_file, sampling_rate, storage_provider, rng)?;
    info!("Loaded {} points for test, dim: {}", num_test, test_dim);

    // Scale the sampled count back up to an estimate of the full dataset size.
    let total_points = (num_train as f64 / sampling_rate) as u64;
    let initial_num_parts = estimate_initial_partition_count::<F>(
        total_points,
        train_dim as u64,
        k_base,
        ram_budget_in_bytes,
        ram_estimator,
    );

    let mut num_parts = initial_num_parts;
    let mut fit_in_ram = false;
    let mut pivot_data = Vec::new();
    while !fit_in_ram {
        // Assume this count fits; cleared below if any shard exceeds the budget.
        fit_in_ram = true;

        let mut max_ram_usage_in_bytes = 0.0;

        // Fresh pivot buffer sized for the current candidate partition count.
        pivot_data = vec![0.0; num_parts * train_dim];

        info!("Processing global k-means (kmeans_partitioning Step)");
        // NOTE(review): `&mut (false)` is presumably a never-set abort/stop
        // flag — confirm against the helper's signature.
        k_meanspp_selecting_pivots(
            &train_data_float,
            num_train,
            train_dim,
            &mut pivot_data,
            num_parts,
            rng,
            &mut (false),
            pool,
        )?;

        // Refine the k-means++ seeds with Lloyd's iterations.
        run_lloyds(
            &train_data_float,
            num_train,
            train_dim,
            &mut pivot_data,
            num_parts,
            MAX_K_MEANS_REPS,
            &mut (false),
            pool,
        )?;

        // Estimate how many sampled points fall into each cluster using the
        // held-out test sample.
        let mut cluster_sizes = Vec::new();
        estimate_cluster_sizes(
            &test_data_float,
            num_test,
            &pivot_data,
            num_parts,
            test_dim,
            k_base,
            &mut cluster_sizes,
            pool,
        )?;

        // Extrapolate each sampled cluster size to the full dataset and track
        // the worst-case shard RAM estimate.
        let mut partition_stats = Vec::with_capacity(num_parts);
        for p in &cluster_sizes {
            let p = (*p as f64 / sampling_rate) as u64;
            let cur_shard_ram_estimate_in_bytes = ram_estimator(p, train_dim as u64);
            partition_stats.push((p, cur_shard_ram_estimate_in_bytes));

            if cur_shard_ram_estimate_in_bytes > max_ram_usage_in_bytes {
                max_ram_usage_in_bytes = cur_shard_ram_estimate_in_bytes;
            }
        }

        info!(
            "Partition RAM estimates (GB): {}",
            partition_stats
                .iter()
                .map(|(size, ram)| format!("#{}: {:.2}", size, ram / BYTES_IN_GB))
                .collect::<Vec<_>>()
                .join(", ")
        );

        info!(
            "With {} parts, max estimated RAM usage: {:.2} GB, budget given is {:.2} GB",
            num_parts,
            max_ram_usage_in_bytes / BYTES_IN_GB,
            ram_budget_in_bytes / BYTES_IN_GB
        );
        if max_ram_usage_in_bytes > ram_budget_in_bytes {
            // Over budget: retry with two more partitions (count stays odd,
            // since the initial estimate is odd).
            fit_in_ram = false;
            num_parts += 2;
        } else {
            info!(
                "Found optimal partition count: [parts={}, initial={}, max_ram={:.2}GB, budget={:.2}GB]",
                num_parts,
                initial_num_parts,
                max_ram_usage_in_bytes / BYTES_IN_GB,
                ram_budget_in_bytes / BYTES_IN_GB
            );
        }
    }

    Ok((num_parts, pivot_data, train_dim))
}
203
204fn estimate_initial_partition_count<F>(
206 total_points: u64,
207 dimension: u64,
208 k_base: usize,
209 ram_budget_in_bytes: f64,
210 ram_estimator: &F,
211) -> usize
212where
213 F: Fn(u64, u64) -> f64,
214{
215 let total_ram_estimate = ram_estimator(total_points * k_base as u64, dimension);
217
218 let mut partition_count = (total_ram_estimate / ram_budget_in_bytes).ceil() as usize;
219
220 partition_count = std::cmp::max(3, partition_count);
222 if partition_count.is_multiple_of(2) {
223 partition_count += 1;
224 }
225
226 info!(
227 "Estimated initial partition count: {} (total points: {}, dimension: {}, k_base: {}, total_ram_estimate: {:.2} GB, ram_budget: {:.2} GB)",
228 partition_count,
229 total_points,
230 dimension,
231 k_base,
232 total_ram_estimate / BYTES_IN_GB,
233 ram_budget_in_bytes / BYTES_IN_GB
234 );
235
236 partition_count
237}
238
/// Streams `dataset_file`, assigns every point to its `k_base` closest pivots,
/// and writes one id-map file per shard containing only the point ids.
///
/// Each id-map file starts with a `(count: u32, dim: u32 = 1)` header; a
/// placeholder count is written first and back-patched once all points have
/// been assigned. `dim` is the dataset's stored dimension and `full_dim` the
/// dimension of each pivot row in `pivot_data` (rows are padded to `full_dim`
/// for the distance computations).
///
/// # Errors
/// Fails if the dataset's stored dimension does not match `dim`, or on any
/// storage/read error.
#[allow(clippy::too_many_arguments)]
fn shard_data_into_clusters_only_ids<T, StorageProvider>(
    dataset_file: &str,
    pivot_data: &[f32],
    num_parts: usize,
    dim: usize,
    full_dim: usize,
    k_base: usize,
    merged_index_prefix: &str,
    storage_provider: &StorageProvider,
    pool: &RayonThreadPool,
) -> ANNResult<()>
where
    T: VectorRepr,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    let mut dataset_reader = CachedReader::<StorageProvider>::new(
        dataset_file,
        READ_WRITE_BLOCK_SIZE,
        storage_provider,
    )?;
    // Dataset header: point count followed by the stored dimension.
    let num_points = dataset_reader.read_u32()?;
    let base_dim = dataset_reader.read_u32()?;
    if base_dim != dim as u32 {
        return Err(ANNError::log_index_error(
            "dimensions dont match for train set and base set",
        ));
    }

    let mut shard_counts = vec![0; num_parts];
    let shard_idmaps_names = (0..num_parts)
        .map(|shard| {
            DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard)
        })
        .collect::<Vec<String>>();

    // One buffered writer per shard id-map file.
    const WRITE_ID_CACHE_SIZE: u64 = 8 * 1024;
    let mut shard_idmap_cached_writers = Vec::new();
    for name in &shard_idmaps_names {
        let writer = storage_provider.create_for_write(name)?;
        let cached_writer =
            CachedWriter::<StorageProvider>::new(name, WRITE_ID_CACHE_SIZE, writer)?;
        shard_idmap_cached_writers.push(cached_writer);
    }

    // Write a placeholder header (count = 0, dim = 1) to every id-map file;
    // the count is overwritten after the assignment pass below.
    let dummy_size: u32 = 0;
    let const_one: u32 = 1;
    for writer in shard_idmap_cached_writers.iter_mut() {
        writer.write(&dummy_size.to_le_bytes())?;
        writer.write(&const_one.to_le_bytes())?;
    }

    // Stream the dataset in bounded-size blocks to cap memory usage.
    let block_size = if num_points <= BLOCK_SIZE_LARGE_FILE {
        num_points
    } else {
        BLOCK_SIZE_LARGE_FILE
    };

    let num_blocks = num_points.div_ceil(block_size);

    // Scratch buffers reused across blocks: closest-center ids, raw T-typed
    // vector bytes, and their f32 conversions (rows of `full_dim`).
    let mut block_closest_centers = vec![0u32; block_size as usize * k_base];
    let mut block_data_t: Vec<u8> = vec![0; block_size as usize * dim * std::mem::size_of::<T>()];
    let mut block_data_float: Vec<f32> = vec![0.0; full_dim * block_size as usize];

    for block in 0..num_blocks {
        let start_id = (block * block_size) as usize;
        let end_id = std::cmp::min((block + 1) * block_size, num_points) as usize;
        let cur_blk_size = end_id - start_id;

        dataset_reader.read(&mut block_data_t[..cur_blk_size * dim * std::mem::size_of::<T>()])?;

        // Reinterpret the raw bytes as `T` vectors without copying.
        let cur_vector_t: &[T] =
            bytemuck::cast_slice(&block_data_t[..cur_blk_size * dim * std::mem::size_of::<T>()]);

        // Convert each `dim`-long vector into the first `dim` slots of a
        // `full_dim`-long f32 row.
        for (v, dst) in cur_vector_t
            .chunks_exact(dim)
            .zip(block_data_float.chunks_exact_mut(full_dim))
        {
            T::as_f32_into(v, dst).into_ann_result()?;
        }

        compute_closest_centers(
            &block_data_float[..full_dim * cur_blk_size],
            cur_blk_size,
            full_dim,
            pivot_data,
            num_parts,
            k_base,
            &mut block_closest_centers,
            None,
            None,
            pool,
        )?;

        // Append each point's global id to the id maps of its `k_base`
        // closest shards.
        for p in 0..cur_blk_size {
            for p1 in 0..k_base {
                let shard_id = block_closest_centers[p * k_base + p1] as usize;
                let original_point_map_id = (start_id + p) as u32;
                shard_idmap_cached_writers[shard_id].write(&original_point_map_id.to_le_bytes())?;
                shard_counts[shard_id] += 1;
            }
        }
    }

    let mut total_count = 0;

    // Back-patch each id-map header with the real point count.
    // NOTE(review): this assumes `reset()` rewinds the cached writer to the
    // start of the file without truncating it, so only the 4-byte count is
    // overwritten — confirm against CachedWriter's semantics.
    for i in 0..num_parts {
        let cur_shard_count = shard_counts[i] as u32;
        info!(" shard_{} with npts : {} ", i, cur_shard_count);
        total_count += cur_shard_count;
        shard_idmap_cached_writers[i].reset()?;
        shard_idmap_cached_writers[i].write(&cur_shard_count.to_le_bytes())?;
        shard_idmap_cached_writers[i].flush()?;
    }

    info!(
        "Partitioned {} with replication factor {} to get {} points across {} shards",
        num_points, k_base, total_count, num_parts
    );

    Ok(())
}
363
364#[allow(clippy::too_many_arguments)]
365fn estimate_cluster_sizes(
366 data_float: &[f32],
367 num_pts: usize,
368 pivot_data: &[f32],
369 num_centers: usize,
370 dim: usize,
371 k_base: usize,
372 cluster_sizes: &mut Vec<u32>,
373 pool: &RayonThreadPool,
374) -> ANNResult<()> {
375 cluster_sizes.clear();
376 let mut shard_counts = vec![0; num_centers];
377
378 let block_size = if num_pts <= BLOCK_SIZE_LARGE_FILE as usize {
379 num_pts
380 } else {
381 BLOCK_SIZE_LARGE_FILE as usize
382 };
383
384 let mut block_closest_centers = vec![0; block_size * k_base];
385
386 let num_blocks = num_pts.div_ceil(block_size);
387
388 for block in 0..num_blocks {
389 let start_id = block * block_size;
390 let end_id = std::cmp::min((block + 1) * block_size, num_pts);
391 let cur_blk_size = end_id - start_id;
392
393 let block_data_float = &data_float[start_id * dim..(start_id + cur_blk_size) * dim];
394
395 compute_closest_centers(
396 block_data_float,
397 cur_blk_size,
398 dim,
399 pivot_data,
400 num_centers,
401 k_base,
402 &mut block_closest_centers,
403 None,
404 None,
405 pool,
406 )?;
407
408 for p in 0..cur_blk_size {
409 for p1 in 0..k_base {
410 let shard_id = block_closest_centers[p * k_base + p1] as usize;
411 shard_counts[shard_id] += 1;
412 }
413 }
414 }
415
416 (0..num_centers).for_each(|i| {
417 let cur_shard_count = shard_counts[i] as u32;
418 cluster_sizes.push(cur_shard_count);
419 });
420 info!("Estimated cluster sizes: {:?}", cluster_sizes);
421 Ok(())
422}
423
424#[cfg(test)]
425mod partition_test {
426 use std::io::Read;
427
428 use diskann_providers::storage::VirtualStorageProvider;
429 use diskann_providers::utils::create_thread_pool_for_test;
430 use diskann_utils::test_data_root;
431 use vfs::{MemoryFS, OverlayFS};
432
433 use super::*;
434
435 #[test]
436 fn test_estimate_cluster_sizes() {
437 let data_float = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
438 let num_pts = 3;
439 let pivot_data = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
440 let num_centers = 3;
441 let dim = 2;
442 let k_base = 2;
443 let mut cluster_sizes = vec![];
444 let pool = create_thread_pool_for_test();
445
446 estimate_cluster_sizes(
447 &data_float,
448 num_pts,
449 pivot_data,
450 num_centers,
451 dim,
452 k_base,
453 &mut cluster_sizes,
454 &pool,
455 )
456 .unwrap();
457
458 assert_eq!(cluster_sizes.len(), num_centers);
459 assert_eq!(cluster_sizes, &[2, 3, 1]);
460 }
461
462 #[test]
463 fn test_shard_data_into_clusters_only_ids() {
464 let dataset_path = "/dataset_file";
466 let mut data_float = Vec::new();
468 let num_points: u32 = 100;
469 let dim: usize = 10;
470
471 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
472 {
473 let writer = storage_provider.create_for_write(dataset_path).unwrap();
474 let mut dataset_writer = CachedWriter::<VirtualStorageProvider<MemoryFS>>::new(
475 dataset_path,
476 READ_WRITE_BLOCK_SIZE,
477 writer,
478 )
479 .unwrap();
480 dataset_writer.write(&num_points.to_le_bytes()).unwrap();
481 dataset_writer.write(&dim.to_le_bytes()).unwrap();
482 for i in 0..num_points {
483 for j in 0..dim {
484 let val = (i * dim as u32 + j as u32) as f32;
485 data_float.push(val);
486 dataset_writer.write(&val.to_le_bytes()).unwrap();
487 }
488 }
489 }
490
491 let k_base: usize = 2;
493 let num_parts = 3;
494
495 let pivot_data: [f32; 30] = [
497 820.0, 821.0, 822.0, 823.0, 824.0, 825.0, 826.0, 827.0, 828.0, 829.0, 155.0, 156.0,
498 157.0, 158.0, 159.0, 160.0, 161.0, 162.0, 163.0, 164.0, 480.0, 481.0, 482.0, 483.0,
499 484.0, 485.0, 486.0, 487.0, 488.0, 489.0,
500 ];
501
502 let merged_index_prefix = "/merged_index";
504 let pool = create_thread_pool_for_test();
505 shard_data_into_clusters_only_ids::<f32, VirtualStorageProvider<OverlayFS>>(
507 dataset_path,
508 &pivot_data,
509 num_parts,
510 dim,
511 dim,
512 k_base,
513 merged_index_prefix,
514 &storage_provider,
515 &pool,
516 )
517 .unwrap();
518
519 let expected_prefix = "/partition/id_maps/merged_index_expected";
521 for shard in 0..num_parts {
522 let path1 =
523 DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
524 let path2 =
525 DiskIndexWriter::get_merged_index_subshard_id_map_file(expected_prefix, shard);
526 let file1 =
527 load_file_to_vec::<VirtualStorageProvider<OverlayFS>>(&path1, &storage_provider);
528 let file2 =
529 load_file_to_vec::<VirtualStorageProvider<OverlayFS>>(&path2, &storage_provider);
530
531 assert_eq!(file1.len(), file2.len());
532 assert_eq!(file1[..], file2[..]);
533
534 storage_provider.delete(&path1).unwrap();
536 }
537
538 storage_provider.delete(dataset_path).unwrap();
539 }
540
541 fn load_file_to_vec<StorageProvider>(
542 file_path: &str,
543 storage_provider: &StorageProvider,
544 ) -> Vec<u8>
545 where
546 StorageProvider: StorageReadProvider,
547 {
548 let mut file = storage_provider.open_reader(file_path).unwrap();
549 let mut buffer = vec![];
550 file.read_to_end(&mut buffer).unwrap();
551 buffer
552 }
553
554 #[test]
555 fn test_partition_with_ram_budget() -> ANNResult<()> {
556 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
557 let dataset_file = "/sift/siftsmall_learn.bin";
558 let mut file = storage_provider.open_reader(dataset_file).unwrap();
559 let mut data = vec![];
560 file.read_to_end(&mut data).unwrap();
561
562 let sampling_rate = 1.0;
563 let ram_budget_in_bytes = 15_000_000.0;
564 let max_degree = 64;
565 let k_base = 2;
566 let merged_index_prefix = "/test_merged_index_prefix";
567 let pool = create_thread_pool_for_test();
568
569 let num_parts = partition_with_ram_budget::<f32, _, _, _>(
570 dataset_file,
571 128, sampling_rate,
573 ram_budget_in_bytes,
574 k_base,
575 merged_index_prefix,
576 &storage_provider,
577 &mut diskann_providers::utils::create_rnd_in_tests(),
578 &pool,
579 |num_points, dim| {
580 use diskann_providers::model::GRAPH_SLACK_FACTOR;
582
583 let datasize = std::mem::size_of::<f32>() as u64;
584 let graph_degree = max_degree as u64;
585 let dataset_size = (num_points * dim.next_multiple_of(8u64) * datasize) as f64;
586 let graph_size = (num_points * graph_degree * 4) as f64 * GRAPH_SLACK_FACTOR;
587 1.1 * (dataset_size + graph_size)
588 },
589 )?;
590
591 assert!(num_parts >= 3);
592
593 for i in 0..num_parts {
594 let idmap_filename =
595 DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, i);
596 storage_provider.delete(&idmap_filename)?;
597 }
598
599 Ok(())
600 }
601}