1use std::ops::Range;
6
7use diskann::{ANNError, ANNResult};
8use rayon::prelude::{IntoParallelIterator, ParallelIterator};
9
10#[inline]
12pub fn execute_with_rayon<F>(range: Range<usize>, num_threads: usize, f: F) -> ANNResult<()>
13where
14 F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy,
15{
16 if num_threads == 1 {
17 for i in range {
18 f(i)?;
19 }
20 Ok(())
21 } else {
22 let pool = create_thread_pool(num_threads)?;
23 range.into_par_iter().try_for_each_in_pool(&pool, f)
24 }
25}
26
27pub fn create_thread_pool(num_threads: usize) -> ANNResult<RayonThreadPool> {
30 let pool = rayon::ThreadPoolBuilder::new()
31 .num_threads(num_threads)
32 .build()
33 .map_err(|err| ANNError::log_thread_pool_error(err.to_string()))?;
34 Ok(RayonThreadPool(pool))
35}
36
37#[allow(clippy::unwrap_used)]
41pub fn create_thread_pool_for_test() -> RayonThreadPool {
42 use std::env;
43
44 let num_threads = env::var("DISKANN_TEST_POOL_THREADS")
45 .ok()
46 .and_then(|val| val.parse().ok())
47 .unwrap_or(3);
48
49 create_thread_pool(num_threads).unwrap()
50}
51#[allow(clippy::unwrap_used)]
56pub fn create_thread_pool_for_bench() -> RayonThreadPool {
57 let pool = rayon::ThreadPoolBuilder::new()
58 .build()
59 .map_err(|err| ANNError::log_thread_pool_error(err.to_string()))
60 .unwrap();
61 RayonThreadPool(pool)
62}
63
64pub struct RayonThreadPool(rayon::ThreadPool);
65
66impl RayonThreadPool {
67 pub fn install<OP, R>(&self, op: OP) -> R
68 where
69 OP: FnOnce() -> R + Send,
70 R: Send,
71 {
72 self.0.install(op)
73 }
74}
75
// Private module holding the supertrait that seals `AsThreadPool`: since `sealed` is not
// public, code outside this module cannot name `sealed::Sealed`, and therefore cannot
// implement `AsThreadPool` for additional types.
mod sealed {
    /// Marker supertrait of `AsThreadPool`; only implemented by the parent module.
    pub trait Sealed {}
}
79
/// Something that can provide a [`RayonThreadPool`].
///
/// This trait is sealed; the implementations in this module are:
/// * `usize`: interpreted as a thread count. `as_threadpool` builds a new pool with
///   [`create_thread_pool`] and returns it wrapped in `diskann_utils::reborrow::Place`, which
///   holds the pool by value.
/// * `&RayonThreadPool`: an existing pool, handed back by reference without building anything.
///
/// Functions generic over `P: AsThreadPool` can therefore accept either a thread count or an
/// existing pool; the `forward_threadpool!` macro turns such an argument into a
/// `&RayonThreadPool` local.
pub trait AsThreadPool: sealed::Sealed + Send + Sync {
    /// Handle returned by [`AsThreadPool::as_threadpool`]; dereferences to the pool.
    type Returns: std::ops::Deref<Target = RayonThreadPool>;

    /// Obtains the pool, building a new one when `self` is a thread count.
    ///
    /// # Errors
    ///
    /// Returns an error if a pool has to be built and building it fails.
    fn as_threadpool(&self) -> ANNResult<Self::Returns>;
}
90
91impl sealed::Sealed for usize {}
92impl sealed::Sealed for &RayonThreadPool {}
93
94impl AsThreadPool for usize {
95 type Returns = diskann_utils::reborrow::Place<RayonThreadPool>;
96 fn as_threadpool(&self) -> ANNResult<Self::Returns> {
97 create_thread_pool(*self).map(diskann_utils::reborrow::Place)
98 }
99}
100
101impl<'a> AsThreadPool for &'a RayonThreadPool {
102 type Returns = &'a RayonThreadPool;
103 fn as_threadpool(&self) -> ANNResult<Self::Returns> {
104 Ok(self)
105 }
106}
107
/// Binds a `&RayonThreadPool` obtained from an `AsThreadPool` argument.
///
/// `forward_threadpool!(out = input)` expands to
/// `let out = &*<_ as AsThreadPool>::as_threadpool(&input)?;`, so it may only be used inside a
/// function whose return type accepts the error via `?`. The second form,
/// `forward_threadpool!(out = input: Type)`, names the implementing type explicitly instead of
/// leaving it to inference. The handle returned by `as_threadpool` is kept alive by the `let`
/// binding, so `out` can be passed on to several further calls.
#[macro_export]
macro_rules! forward_threadpool {
    ($out:ident = $in:ident) => {
        let $out = &*<_ as $crate::utils::AsThreadPool>::as_threadpool(&$in)?;
    };
    ($out:ident = $in:ident: $type:ty) => {
        let $out = &*<$type as $crate::utils::AsThreadPool>::as_threadpool(&$in)?;
    };
}
119
120#[allow(clippy::disallowed_methods)]
124pub trait ParallelIteratorInPool: ParallelIterator + Sized {
125 fn for_each_in_pool<OP>(self, pool: &RayonThreadPool, op: OP)
126 where
127 OP: Fn(Self::Item) + Sync + Send,
128 {
129 pool.install(|| self.for_each(op));
130 }
131
132 fn for_each_with_in_pool<OP, T>(self, pool: &RayonThreadPool, init: T, op: OP)
133 where
134 OP: Fn(&mut T, Self::Item) + Sync + Send,
135 T: Send + Clone,
136 {
137 pool.install(|| self.for_each_with(init, op))
138 }
139
140 fn for_each_init_in_pool<OP, INIT, T>(self, pool: &RayonThreadPool, init: INIT, op: OP)
141 where
142 OP: Fn(&mut T, Self::Item) + Sync + Send,
143 INIT: Fn() -> T + Sync + Send,
144 {
145 pool.install(|| self.for_each_init(init, op))
146 }
147
148 fn try_for_each_in_pool<OP, E>(self, pool: &RayonThreadPool, op: OP) -> Result<(), E>
149 where
150 OP: Fn(Self::Item) -> Result<(), E> + Sync + Send,
151 E: Send,
152 {
153 pool.install(|| self.try_for_each(op))
154 }
155
156 fn try_for_each_with_in_pool<OP, T, E>(
157 self,
158 pool: &RayonThreadPool,
159 init: T,
160 op: OP,
161 ) -> Result<(), E>
162 where
163 OP: Fn(&mut T, Self::Item) -> Result<(), E> + Sync + Send,
164 E: Send,
165 T: Send + Clone,
166 {
167 pool.install(|| self.try_for_each_with(init, op))
168 }
169
170 fn try_for_each_init_in_pool<OP, INIT, T, E>(
171 self,
172 pool: &RayonThreadPool,
173 init: INIT,
174 op: OP,
175 ) -> Result<(), E>
176 where
177 OP: Fn(&mut T, Self::Item) -> Result<(), E> + Sync + Send,
178 INIT: Fn() -> T + Sync + Send,
179 E: Send,
180 {
181 pool.install(|| self.try_for_each_init(init, op))
182 }
183
184 fn count_in_pool(self, pool: &RayonThreadPool) -> usize {
185 pool.install(|| self.count())
186 }
187
188 fn collect_in_pool<C>(self, pool: &RayonThreadPool) -> C
189 where
190 C: rayon::iter::FromParallelIterator<Self::Item> + Send,
191 {
192 pool.install(|| self.collect())
193 }
194
195 fn sum_in_pool<S>(self, pool: &RayonThreadPool) -> S
196 where
197 S: Send + std::iter::Sum<Self::Item> + std::iter::Sum<S>,
198 {
199 pool.install(|| self.sum())
200 }
201}
202
203impl<T> ParallelIteratorInPool for T where T: ParallelIterator {}
205
#[cfg(test)]
mod tests {
    use std::sync::{
        Mutex, MutexGuard,
        atomic::{AtomicUsize, Ordering},
        mpsc::channel,
    };

    use super::*;

    /// Environment variable consulted by `create_thread_pool_for_test`.
    const TEST_POOL_THREADS_VAR: &str = "DISKANN_TEST_POOL_THREADS";

    /// Serializes the tests in this module that read or write the process environment.
    ///
    /// `cargo test` runs tests concurrently on several threads. Without this lock one test can
    /// set or remove `TEST_POOL_THREADS_VAR` between another test's write and its read, which
    /// makes the `create_thread_pool_for_test_*` tests flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`.
    ///
    /// Poisoning is ignored: the protected data is `()`, and a test that panicked while holding
    /// the lock must not make every later test fail too.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Scoped override of `TEST_POOL_THREADS_VAR`.
    ///
    /// Holds `ENV_LOCK` for its whole lifetime and removes the variable on drop, so the
    /// environment is restored even when the owning test panics.
    struct PoolThreadsEnv {
        _lock: MutexGuard<'static, ()>,
    }

    impl PoolThreadsEnv {
        /// Sets the variable to `value`, or removes it when `value` is `None`.
        fn set(value: Option<&str>) -> Self {
            let lock = lock_env();
            // SAFETY: `ENV_LOCK` serializes every test in this module that reads or writes the
            // environment. It cannot stop code elsewhere in the test binary from reading the
            // environment at the same time, which is the residual risk that makes these calls
            // `unsafe`; that risk is accepted for test-only code.
            unsafe {
                match value {
                    Some(value) => std::env::set_var(TEST_POOL_THREADS_VAR, value),
                    None => std::env::remove_var(TEST_POOL_THREADS_VAR),
                }
            }
            Self { _lock: lock }
        }
    }

    impl Drop for PoolThreadsEnv {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `PoolThreadsEnv::set`. `_lock` is still held here
            // because struct fields are dropped only after `drop` returns.
            unsafe { std::env::remove_var(TEST_POOL_THREADS_VAR) };
        }
    }

    fn get_num_cpus() -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap()
    }

    #[test]
    fn test_create_thread_pool_for_test_default() {
        let _env = PoolThreadsEnv::set(None);
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 3);
    }

    #[test]
    fn test_create_thread_pool_for_test_from_env() {
        let _env = PoolThreadsEnv::set(Some("5"));
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 5);
    }

    #[test]
    fn test_create_thread_pool_for_test_invalid_env() {
        let _env = PoolThreadsEnv::set(Some("invalid"));
        let pool = create_thread_pool_for_test();
        assert_eq!(pool.0.current_num_threads(), 3);
    }

    #[test]
    fn test_create_thread_pool_for_bench() {
        // Building an automatically sized pool makes rayon read the environment
        // (`RAYON_NUM_THREADS`), so don't run concurrently with the tests above that modify it.
        let _env = lock_env();
        let pool = create_thread_pool_for_bench();
        assert_eq!(pool.0.current_num_threads(), get_num_cpus());
    }

    fn assert_run_in_rayon_thread() {
        println!(
            "Thread name: {:?}, Thread id: {:?}, Rayon thread index: {:?}, Rayon num_threads: {:?}",
            std::thread::current().name(),
            std::thread::current().id(),
            rayon::current_thread_index(),
            rayon::current_num_threads()
        );
        assert!(rayon::current_thread_index().is_some());
    }

    #[test]
    fn test_for_each_in_pool() {
        let pool = create_thread_pool(4).unwrap();

        let res = Mutex::new(Vec::new());
        (0..5).into_par_iter().for_each_in_pool(&pool, |x| {
            let mut res = res.lock().unwrap();
            res.push(x);
            assert_run_in_rayon_thread();
        });

        let mut res = res.lock().unwrap();
        res.sort();

        assert_eq!(&res[..], &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_for_each_with_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let (sender, receiver) = channel();

        (0..5)
            .into_par_iter()
            .for_each_with_in_pool(&pool, sender, |s, x| {
                assert_run_in_rayon_thread();
                s.send(x).unwrap();
            });

        // Every clone of `sender` has been dropped once the call returns, so this terminates.
        let mut res: Vec<_> = receiver.iter().collect();

        res.sort();

        assert_eq!(&res[..], &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_for_each_init_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let visited = AtomicUsize::new(0);
        let iter = (0..100).into_par_iter();
        iter.for_each_init_in_pool(
            &pool,
            || 0,
            |s, i| {
                assert_run_in_rayon_thread();
                *s += i;
                visited.fetch_add(1, Ordering::SeqCst);
            },
        );
        // Every item must have been handed to `op` exactly once.
        assert_eq!(visited.load(Ordering::SeqCst), 100);
    }

    #[test]
    fn test_map_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let mapped_iter = iter.map(|i| {
            assert_run_in_rayon_thread();
            i as f32
        });
        let list = mapped_iter.collect_in_pool::<Vec<f32>>(&pool);
        assert_eq!(list.len(), 100);
    }

    #[test]
    fn test_try_for_each_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_in_pool(&pool, |i| {
            assert_run_in_rayon_thread();
            if i < 50 { Ok(()) } else { Err("Error") }
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_try_for_each_init_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_init_in_pool(
            &pool,
            || 0,
            |_, i| {
                assert_run_in_rayon_thread();
                if i < 50 { Ok(()) } else { Err("Error") }
            },
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_try_for_each_with_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let result = iter.try_for_each_with_in_pool(&pool, 0, |acc, i| {
            assert_run_in_rayon_thread();
            if i < 50 {
                *acc += i;
                Ok(())
            } else {
                Err("Error")
            }
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_count_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let count = iter.count_in_pool(&pool);
        assert_eq!(count, 100);
    }

    #[test]
    fn test_collect_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let vec = iter.collect_in_pool::<Vec<_>>(&pool);
        assert_eq!(vec.len(), 100);
    }

    #[test]
    fn test_sum_in_pool() {
        let pool = create_thread_pool(4).unwrap();
        let iter = (0..100).into_par_iter();
        let sum: i32 = iter.sum_in_pool(&pool);
        assert_eq!(sum, (0..100).sum::<i32>());
    }
}
407
#[cfg(test)]
mod as_threadpool_tests {
    use super::*;

    /// `(0..100).map(|i| i as f32).sum()`. Every partial sum is an integer below 2^24, so the
    /// parallel reduction is exact in `f32` no matter how rayon splits the range.
    const RANGE_SUM: f32 = 4950.0;

    /// Sums `0..100` in the pool obtained through the type-inferred macro form.
    fn some_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool);

        let ret = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        Ok(ret)
    }

    /// Same computation, but through the explicitly typed macro form (`pool = pool: P`) so that
    /// both arms of `forward_threadpool!` are exercised.
    fn another_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool: P);
        let ret = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        Ok(ret)
    }

    fn execute_single_parallel_op<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        some_parallel_op(pool)
    }

    /// Converts `pool` once, then hands the resulting `&RayonThreadPool` to two nested ops.
    fn execute_two_parallel_ops<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool);

        let ret1 = some_parallel_op(pool)?;
        let ret2 = another_parallel_op(pool)?;
        Ok(ret1 + ret2)
    }

    /// Uses the converted pool directly and also forwards it to a nested op.
    fn execute_combined_parallel_ops<P: AsThreadPool>(pool: P) -> ANNResult<f32> {
        forward_threadpool!(pool = pool);

        let ret1: f32 = (0..100).into_par_iter().map(|i| i as f32).sum_in_pool(pool);
        let ret2 = some_parallel_op(pool)?;
        Ok(ret1 + ret2)
    }

    #[test]
    fn test_execute_single_parallel_op_with_usize() {
        let result = execute_single_parallel_op(4_usize).unwrap();
        assert_eq!(result, RANGE_SUM);
    }

    #[test]
    fn test_execute_single_parallel_op_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_single_parallel_op(&pool).unwrap();
        assert_eq!(result, RANGE_SUM);
    }

    #[test]
    fn test_execute_two_parallel_ops_with_usize() {
        let result = execute_two_parallel_ops(4_usize).unwrap();
        assert_eq!(result, 2.0 * RANGE_SUM);
    }

    #[test]
    fn test_execute_two_parallel_ops_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_two_parallel_ops(&pool).unwrap();
        assert_eq!(result, 2.0 * RANGE_SUM);
    }

    #[test]
    fn test_execute_combined_parallel_ops_with_usize() {
        let result = execute_combined_parallel_ops(4_usize).unwrap();
        assert_eq!(result, 2.0 * RANGE_SUM);
    }

    #[test]
    fn test_execute_combined_parallel_ops_with_existing_pool() {
        let pool = create_thread_pool(4).unwrap();
        let result = execute_combined_parallel_ops(&pool).unwrap();
        assert_eq!(result, 2.0 * RANGE_SUM);
    }

    #[test]
    fn test_explicitly_typed_forwarding() {
        // `usize` path: the macro's typed arm must build and own a fresh pool.
        assert_eq!(another_parallel_op(4_usize).unwrap(), RANGE_SUM);

        // `&RayonThreadPool` path: the macro's typed arm must reuse the given pool.
        let pool = create_thread_pool(4).unwrap();
        assert_eq!(another_parallel_op(&pool).unwrap(), RANGE_SUM);
    }
}