Skip to main content

diskann_providers/utils/
rayon_util.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::ops::Range;
6
7use diskann::{ANNError, ANNResult};
8use rayon::prelude::{IntoParallelIterator, ParallelIterator};
9
10/// based on thread_num, execute the task in parallel using Rayon or serial
11#[inline]
12pub fn execute_with_rayon<F>(range: Range<usize>, num_threads: usize, f: F) -> ANNResult<()>
13where
14    F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy,
15{
16    if num_threads == 1 {
17        for i in range {
18            f(i)?;
19        }
20        Ok(())
21    } else {
22        let pool = create_thread_pool(num_threads)?;
23        range.into_par_iter().try_for_each_in_pool(&pool, f)
24    }
25}
26
27/// Creates a new thread pool with the specified number of threads.
28/// If `num_threads` is 0, it defaults to the number of logical CPUs.
29pub fn create_thread_pool(num_threads: usize) -> ANNResult<RayonThreadPool> {
30    let pool = rayon::ThreadPoolBuilder::new()
31        .num_threads(num_threads)
32        .build()
33        .map_err(|err| ANNError::log_thread_pool_error(err.to_string()))?;
34    Ok(RayonThreadPool(pool))
35}
36
37/// Creates a thread pool with a configurable number of threads for testing purposes.
38/// The number of threads can be set using the environment variable `DISKANN_TEST_POOL_THREADS`.
39/// If the environment variable is not set or cannot be parsed, it defaults to 3 threads.
40#[allow(clippy::unwrap_used)]
41pub fn create_thread_pool_for_test() -> RayonThreadPool {
42    use std::env;
43
44    let num_threads = env::var("DISKANN_TEST_POOL_THREADS")
45        .ok()
46        .and_then(|val| val.parse().ok())
47        .unwrap_or(3);
48
49    create_thread_pool(num_threads).unwrap()
50}
51/// Creates a thread pool for benchmarking purposes without specifying the number of threads.
52/// The Rayon runtime will automatically determine the optimal number of threads to use.
53/// It uses the `RAYON_NUM_THREADS` environment variable if set,
54/// or defaults to the number of logical CPUs otherwise
55#[allow(clippy::unwrap_used)]
56pub fn create_thread_pool_for_bench() -> RayonThreadPool {
57    let pool = rayon::ThreadPoolBuilder::new()
58        .build()
59        .map_err(|err| ANNError::log_thread_pool_error(err.to_string()))
60        .unwrap();
61    RayonThreadPool(pool)
62}
63
64pub struct RayonThreadPool(rayon::ThreadPool);
65
66impl RayonThreadPool {
67    pub fn install<OP, R>(&self, op: OP) -> R
68    where
69        OP: FnOnce() -> R + Send,
70        R: Send,
71    {
72        self.0.install(op)
73    }
74}
75
/// Private home of the marker trait that seals `AsThreadPool`.
///
/// `sealed` is not `pub`, so code outside this module cannot name `Sealed` and therefore
/// cannot implement the traits that require it.
mod sealed {
    /// Marker supertrait; only types given an impl in the parent module qualify.
    pub trait Sealed {}
}
79
80/// This allows either an integer to be provided or an explicit `&RayonThreadPool`.
81/// If an integer is provided, we create a new thread-pool with the requested number of
82/// threads.
83///
84/// This trait should be "sealed" to avoid external users being able to implement it.
85/// See [as_threadpool_tests] for examples of how to use this trait.
86pub trait AsThreadPool: sealed::Sealed + Send + Sync {
87    type Returns: std::ops::Deref<Target = RayonThreadPool>;
88    fn as_threadpool(&self) -> ANNResult<Self::Returns>;
89}
90
91impl sealed::Sealed for usize {}
92impl sealed::Sealed for &RayonThreadPool {}
93
94impl AsThreadPool for usize {
95    type Returns = diskann_utils::reborrow::Place<RayonThreadPool>;
96    fn as_threadpool(&self) -> ANNResult<Self::Returns> {
97        create_thread_pool(*self).map(diskann_utils::reborrow::Place)
98    }
99}
100
101impl<'a> AsThreadPool for &'a RayonThreadPool {
102    type Returns = &'a RayonThreadPool;
103    fn as_threadpool(&self) -> ANNResult<Self::Returns> {
104        Ok(self)
105    }
106}
107
/// Binds a local `&RayonThreadPool` obtained from a value implementing `AsThreadPool`.
///
/// * `forward_threadpool!(out = input)` lets the compiler infer the type of `input`.
/// * `forward_threadpool!(out = input: Type)` names that type explicitly.
///
/// The expansion is a `let` statement that applies `?` to the result of
/// `AsThreadPool::as_threadpool`. It can therefore only be used inside a function whose
/// error type can be produced from an `ANNError` by `?`. The bound `out` stays usable for
/// the rest of the enclosing block.
///
/// ```ignore
/// fn run<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
///     forward_threadpool!(pool = pool);
///     Ok((0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool))
/// }
/// ```
#[macro_export]
macro_rules! forward_threadpool {
    // Type left to inference.
    ($out:ident = $in:ident) => {
        let $out = &*<_ as $crate::utils::AsThreadPool>::as_threadpool(&$in)?;
    };
    // Type given explicitly.
    ($out:ident = $in:ident: $type:ty) => {
        let $out = &*<$type as $crate::utils::AsThreadPool>::as_threadpool(&$in)?;
    };
}
119
120// Allow use of disallowed methods within this trait to provide custom
121// implementations of common parallel operations that enforce execution
122// within a specified thread pool.
123#[allow(clippy::disallowed_methods)]
124pub trait ParallelIteratorInPool: ParallelIterator + Sized {
125    fn for_each_in_pool<OP>(self, pool: &RayonThreadPool, op: OP)
126    where
127        OP: Fn(Self::Item) + Sync + Send,
128    {
129        pool.install(|| self.for_each(op));
130    }
131
132    fn for_each_with_in_pool<OP, T>(self, pool: &RayonThreadPool, init: T, op: OP)
133    where
134        OP: Fn(&mut T, Self::Item) + Sync + Send,
135        T: Send + Clone,
136    {
137        pool.install(|| self.for_each_with(init, op))
138    }
139
140    fn for_each_init_in_pool<OP, INIT, T>(self, pool: &RayonThreadPool, init: INIT, op: OP)
141    where
142        OP: Fn(&mut T, Self::Item) + Sync + Send,
143        INIT: Fn() -> T + Sync + Send,
144    {
145        pool.install(|| self.for_each_init(init, op))
146    }
147
148    fn try_for_each_in_pool<OP, E>(self, pool: &RayonThreadPool, op: OP) -> Result<(), E>
149    where
150        OP: Fn(Self::Item) -> Result<(), E> + Sync + Send,
151        E: Send,
152    {
153        pool.install(|| self.try_for_each(op))
154    }
155
156    fn try_for_each_with_in_pool<OP, T, E>(
157        self,
158        pool: &RayonThreadPool,
159        init: T,
160        op: OP,
161    ) -> Result<(), E>
162    where
163        OP: Fn(&mut T, Self::Item) -> Result<(), E> + Sync + Send,
164        E: Send,
165        T: Send + Clone,
166    {
167        pool.install(|| self.try_for_each_with(init, op))
168    }
169
170    fn try_for_each_init_in_pool<OP, INIT, T, E>(
171        self,
172        pool: &RayonThreadPool,
173        init: INIT,
174        op: OP,
175    ) -> Result<(), E>
176    where
177        OP: Fn(&mut T, Self::Item) -> Result<(), E> + Sync + Send,
178        INIT: Fn() -> T + Sync + Send,
179        E: Send,
180    {
181        pool.install(|| self.try_for_each_init(init, op))
182    }
183
184    fn count_in_pool(self, pool: &RayonThreadPool) -> usize {
185        pool.install(|| self.count())
186    }
187
188    fn collect_in_pool<C>(self, pool: &RayonThreadPool) -> C
189    where
190        C: rayon::iter::FromParallelIterator<Self::Item> + Send,
191    {
192        pool.install(|| self.collect())
193    }
194
195    fn sum_in_pool<S>(self, pool: &RayonThreadPool) -> S
196    where
197        S: Send + std::iter::Sum<Self::Item> + std::iter::Sum<S>,
198    {
199        pool.install(|| self.sum())
200    }
201}
202
203// Implement the `ParallelIteratorInPool` trait for any type that implements `ParallelIterator`.
204impl<T> ParallelIteratorInPool for T where T: ParallelIterator {}
205
#[cfg(test)]
mod tests {
    use std::sync::{
        Mutex, MutexGuard,
        atomic::{AtomicUsize, Ordering},
        mpsc::channel,
    };

    use super::*;

    /// Environment variable read by [`create_thread_pool_for_test`].
    const TEST_POOL_THREADS_VAR: &str = "DISKANN_TEST_POOL_THREADS";

    /// Serializes the tests below that modify `DISKANN_TEST_POOL_THREADS`.
    ///
    /// The test harness runs tests on several threads by default. Without this lock, the
    /// "default" test could observe the `"5"` written by the "from_env" test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Scope guard giving a test exclusive control of `DISKANN_TEST_POOL_THREADS`.
    ///
    /// On drop the variable is removed while the lock is still held, so a failing assertion
    /// cannot leak a value into later tests.
    struct TestPoolThreadsEnv {
        _lock: MutexGuard<'static, ()>,
    }

    impl TestPoolThreadsEnv {
        /// Takes `ENV_LOCK`, then sets the variable to `value`, or removes it for `None`.
        fn new(value: Option<&str>) -> Self {
            // The lock only protects `()`, so a guard poisoned by a panicking test is still
            // perfectly usable.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // SAFETY: every test in this module that writes this variable holds `ENV_LOCK`.
            // NOTE(review): other code in the same test binary could still touch the
            // environment concurrently; this lock cannot exclude that.
            unsafe {
                match value {
                    Some(value) => std::env::set_var(TEST_POOL_THREADS_VAR, value),
                    None => std::env::remove_var(TEST_POOL_THREADS_VAR),
                }
            }
            Self { _lock: lock }
        }
    }

    impl Drop for TestPoolThreadsEnv {
        fn drop(&mut self) {
            // SAFETY: as in `TestPoolThreadsEnv::new`. `_lock` is released only after this
            // method returns, because fields are dropped after `Drop::drop` runs.
            unsafe { std::env::remove_var(TEST_POOL_THREADS_VAR) };
        }
    }

    fn get_num_cpus() -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap()
    }

    #[test]
    fn test_create_thread_pool_for_test_default() {
        let _env = TestPoolThreadsEnv::new(None);
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 3);
    }

    #[test]
    fn test_create_thread_pool_for_test_from_env() {
        let _env = TestPoolThreadsEnv::new(Some("5"));
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 5);
    }

    #[test]
    fn test_create_thread_pool_for_test_invalid_env() {
        let _env = TestPoolThreadsEnv::new(Some("invalid"));
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 3);
    }

    #[test]
    fn test_create_thread_pool_for_bench() {
        // Rayon honours `RAYON_NUM_THREADS` (when it is a positive integer) before falling
        // back to the number of logical CPUs.
        let expected = std::env::var("RAYON_NUM_THREADS")
            .ok()
            .and_then(|val| val.parse::<usize>().ok())
            .filter(|&n| n > 0)
            .unwrap_or_else(get_num_cpus);
        let pool = create_thread_pool_for_bench();
        assert_eq!(pool.0.current_num_threads(), expected);
    }

    fn assert_run_in_rayon_thread() {
        println!(
            "Thread name: {:?}, Thread id: {:?}, Rayon thread index: {:?}, Rayon num_threads: {:?}",
            std::thread::current().name(),
            std::thread::current().id(),
            rayon::current_thread_index(),
            rayon::current_num_threads()
        );
        assert!(rayon::current_thread_index().is_some());
    }

    #[test]
    fn test_for_each_in_pool() {
        let pool = create_thread_pool(4).unwrap();

        let res = Mutex::new(Vec::new());
        (0..5).into_par_iter().for_each_in_pool(&pool, |x| {
            let mut res = res.lock().unwrap();
            res.push(x);
            assert_run_in_rayon_thread();
        });

        let mut res = res.lock().unwrap();
        res.sort();

        assert_eq!(&res[..], &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_for_each_with_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let (sender, receiver) = channel();

        (0..5)
            .into_par_iter()
            .for_each_with_in_pool(&pool, sender, |s, x| s.send(x).unwrap());

        let mut res: Vec<_> = receiver.iter().collect();

        res.sort();

        assert_eq!(&res[..], &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_for_each_init_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let visited = AtomicUsize::new(0);
        (0..100).into_par_iter().for_each_init_in_pool(
            &pool,
            || 0,
            |s, i| {
                assert_run_in_rayon_thread();
                *s += i;
                visited.fetch_add(1, Ordering::Relaxed);
            },
        );
        // Every item must have been handed to `op` exactly once.
        assert_eq!(visited.into_inner(), 100);
    }

    #[test]
    fn test_map_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let mapped_iter = iter.map(|i| {
            assert_run_in_rayon_thread();
            i as f32
        });
        let list = mapped_iter.collect_in_pool::<Vec<f32>>(&pool);
        assert_eq!(list.len(), 100);
    }

    #[test]
    fn test_try_for_each_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_in_pool(&pool, |i| {
            assert_run_in_rayon_thread();
            if i < 50 { Ok(()) } else { Err("Error") }
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_try_for_each_init_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_init_in_pool(
            &pool,
            || 0,
            |_, i| {
                assert_run_in_rayon_thread();
                if i < 50 { Ok(()) } else { Err("Error") }
            },
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_try_for_each_with_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_with_in_pool(&pool, 0, |acc, i| {
            assert_run_in_rayon_thread();
            if i < 50 {
                *acc += i;
                Ok(())
            } else {
                Err("Error")
            }
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_count_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let count = iter.count_in_pool(&pool);
        assert_eq!(count, 100);
    }

    #[test]
    fn test_collect_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let vec = iter.collect_in_pool::<Vec<_>>(&pool);
        assert_eq!(vec.len(), 100);
    }

    #[test]
    fn test_sum_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let sum: i32 = iter.sum_in_pool(&pool);
        assert_eq!(sum, (0..100).sum::<i32>());
    }
}
407
#[cfg(test)]
mod as_threadpool_tests {
    use super::*;

    /// Exact value of `(0..100).map(|i| i as f32).sum()`.
    ///
    /// Every partial sum is an integer below 2^24 and hence exactly representable in `f32`.
    /// The result therefore does not depend on the order in which Rayon reduces the items.
    const SUM_0_TO_100: f32 = 4950.0;

    fn some_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool);

        let ret = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        Ok(ret)
    }

    fn another_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool);
        let ret = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        Ok(ret)
    }

    fn execute_single_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        // Directly pass the thread pool to the function.
        some_parallel_op(pool)
    }

    fn execute_two_parallel_ops<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        // Need a reference to the thread pool to share it with multiple functions.
        forward_threadpool!(pool = pool);

        let ret1 = some_parallel_op(pool)?;
        let ret2 = another_parallel_op(pool)?;
        Ok(ret1 + ret2)
    }

    fn execute_combined_parallel_ops<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        // Need a Threadpool reference to execute the operations.
        forward_threadpool!(pool = pool);

        let ret1: f32 = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        let ret2 = some_parallel_op(pool)?;
        Ok(ret1 + ret2)
    }

    #[test]
    fn test_execute_single_parallel_op_with_usize() {
        let num_threads: usize = 4;
        let result = execute_single_parallel_op(num_threads).unwrap();
        assert_eq!(result, SUM_0_TO_100);
    }

    #[test]
    fn test_execute_single_parallel_op_with_rayon_default_threads() {
        // `0` asks Rayon to pick the thread count.
        let result = execute_single_parallel_op(0usize).unwrap();
        assert_eq!(result, SUM_0_TO_100);
    }

    #[test]
    fn test_execute_single_parallel_op_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_single_parallel_op(&pool).unwrap();
        assert_eq!(result, SUM_0_TO_100);
    }

    #[test]
    fn test_execute_two_parallel_ops_with_usize() {
        let num_threads: usize = 4;
        let result = execute_two_parallel_ops(num_threads).unwrap();
        assert_eq!(result, 2.0 * SUM_0_TO_100);
    }

    #[test]
    fn test_execute_two_parallel_ops_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_two_parallel_ops(&pool).unwrap();
        assert_eq!(result, 2.0 * SUM_0_TO_100);
    }

    #[test]
    fn test_execute_combined_parallel_ops_with_usize() {
        let num_threads: usize = 4;
        let result = execute_combined_parallel_ops(num_threads).unwrap();
        assert_eq!(result, 2.0 * SUM_0_TO_100);
    }

    #[test]
    fn test_execute_combined_parallel_ops_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_combined_parallel_ops(&pool).unwrap();
        assert_eq!(result, 2.0 * SUM_0_TO_100);
    }
}