// dynamo_runtime/compute/pool.rs
1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Compute pool implementation with tokio-rayon integration
5//!
6//! The `ComputePool` allows multiple async tasks to concurrently submit different
7//! types of parallel work to a shared Rayon thread pool. This enables efficient
8//! CPU utilization without manual thread management.
9//!
10//! # Concurrent Usage Example
11//!
12//! ```ignore
13//! use std::sync::Arc;
14//! use dynamo_runtime::compute::ComputePool;
15//! use rayon::prelude::*;
16//!
17//! async fn concurrent_processing(pool: Arc<ComputePool>) {
18//! // Task 1: Using scope for dynamic task generation
19//! let task1 = tokio::spawn({
20//! let pool = pool.clone();
21//! async move {
22//! pool.execute_scoped(|scope| {
23//! // Dynamically spawn tasks based on runtime conditions
24//! for i in 0..100 {
25//! scope.spawn(move |_| {
26//! // CPU-intensive work
27//! let mut sum = 0u64;
28//! for j in 0..1000 {
29//! sum += (i * j) as u64;
30//! }
31//! sum
32//! });
33//! }
34//! }).await
35//! }
36//! });
37//!
38//! // Task 2: Using parallel iterators for batch processing
39//! let task2 = tokio::spawn({
40//! let pool = pool.clone();
41//! async move {
42//! let data: Vec<u32> = (0..10000).collect();
43//! pool.install(|| {
44//! data.par_chunks(100)
45//! .map(|chunk| chunk.iter().sum::<u32>())
46//! .collect::<Vec<_>>()
47//! }).await
48//! }
49//! });
50//!
51//! // Both tasks run concurrently, sharing the same thread pool
52//! let (result1, result2) = tokio::join!(task1, task2);
53//! }
54//! ```
55//!
56//! # Thread Pool Sharing
57//!
58//! The Rayon thread pool uses work-stealing to efficiently distribute work from
59//! multiple concurrent sources:
60//!
61//! - Tasks from `scope.spawn()` are pushed to thread-local deques
62//! - Parallel iterators distribute work across all threads
63//! - Idle threads steal work from busy threads
64//! - No coordination needed between different parallelization patterns
65
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use anyhow::Result;
use async_trait::async_trait;
use tokio_rayon::AsyncThreadPool;

use super::{ComputeConfig, ComputeMetrics};
73
/// A compute pool that manages CPU-intensive operations
///
/// Cloning is cheap with respect to the pool itself: all clones share the
/// same underlying Rayon thread pool and metrics through `Arc` (the
/// `ComputeConfig` is cloned by value).
#[derive(Clone)]
pub struct ComputePool {
    /// The underlying Rayon thread pool; shared by all clones of this handle
    pool: Arc<rayon::ThreadPool>,

    /// Metrics for monitoring compute operations (task counts, durations)
    metrics: Arc<ComputeMetrics>,

    /// Configuration used to create this pool; retained for `Debug` output
    /// and inspection
    config: ComputeConfig,
}
86
87impl std::fmt::Debug for ComputePool {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("ComputePool")
90 .field("num_threads", &self.pool.current_num_threads())
91 .field("metrics", &self.metrics)
92 .field("config", &self.config)
93 .finish()
94 }
95}
96
97impl ComputePool {
98 /// Create a new compute pool with the given configuration
99 pub fn new(config: ComputeConfig) -> Result<Self> {
100 let pool = config.build_pool()?;
101 let metrics = Arc::new(ComputeMetrics::new());
102
103 Ok(Self {
104 pool: Arc::new(pool),
105 metrics,
106 config,
107 })
108 }
109
110 /// Create a compute pool with default configuration
111 pub fn with_defaults() -> Result<Self> {
112 Self::new(ComputeConfig::default())
113 }
114
115 /// Execute a synchronous computation on the thread pool
116 ///
117 /// This method is designed to be called from within `spawn_blocking` or other
118 /// synchronous contexts. It has minimal overhead as it directly uses Rayon
119 /// without the async bridge.
120 ///
121 /// # Example
122 /// ```ignore
123 /// # use dynamo_runtime::compute::ComputePool;
124 /// # let pool = ComputePool::new(Default::default()).unwrap();
125 /// tokio::task::spawn_blocking(move || {
126 /// pool.execute_sync(|| {
127 /// // CPU-intensive work
128 /// expensive_computation()
129 /// })
130 /// });
131 /// ```
132 pub fn execute_sync<F, R>(&self, f: F) -> R
133 where
134 F: FnOnce() -> R + Send,
135 R: Send,
136 {
137 self.pool.install(f)
138 }
139
140 /// Execute a compute task in the Rayon pool
141 ///
142 /// This bridges from async context to the Rayon thread pool,
143 /// allowing CPU-intensive work to run without blocking Tokio workers.
144 ///
145 /// Note: This method has ~25μs overhead for small tasks due to the async
146 /// channel communication. For very small computations (<100μs), consider
147 /// running directly on Tokio or using `spawn_blocking` with `execute_sync`.
148 pub async fn execute<F, R>(&self, f: F) -> Result<R>
149 where
150 F: FnOnce() -> R + Send + 'static,
151 R: Send + 'static,
152 {
153 self.metrics.record_task_start();
154 let start = std::time::Instant::now();
155
156 // Use tokio-rayon to bridge to the compute pool
157 let pool = self.pool.clone();
158 let result = tokio_rayon::spawn(move || pool.install(f)).await;
159
160 self.metrics.record_task_completion(start.elapsed());
161 Ok(result)
162 }
163
164 /// Execute a function with a Rayon scope
165 ///
166 /// This allows spawning multiple parallel tasks within the scope,
167 /// with the guarantee that all tasks complete before returning.
168 pub async fn execute_scoped<F, R>(&self, f: F) -> Result<R>
169 where
170 F: FnOnce(&rayon::Scope) -> R + Send + 'static,
171 R: Send + 'static,
172 {
173 self.metrics.record_task_start();
174 let start = std::time::Instant::now();
175
176 let pool = self.pool.clone();
177 let result = tokio_rayon::spawn(move || {
178 pool.install(|| {
179 let mut result = None;
180 rayon::scope(|s| {
181 result = Some(f(s));
182 });
183 result.unwrap()
184 })
185 })
186 .await;
187
188 self.metrics.record_task_completion(start.elapsed());
189 Ok(result)
190 }
191
192 /// Execute a function with a FIFO scope
193 ///
194 /// Similar to execute_scoped, but tasks are prioritized in FIFO order
195 /// rather than the default LIFO order.
196 pub async fn execute_scoped_fifo<F, R>(&self, f: F) -> Result<R>
197 where
198 F: FnOnce(&rayon::ScopeFifo) -> R + Send + 'static,
199 R: Send + 'static,
200 {
201 self.metrics.record_task_start();
202 let start = std::time::Instant::now();
203
204 let pool = self.pool.clone();
205 let result = tokio_rayon::spawn(move || {
206 pool.install(|| {
207 let mut result = None;
208 rayon::scope_fifo(|s| {
209 result = Some(f(s));
210 });
211 result.unwrap()
212 })
213 })
214 .await;
215
216 self.metrics.record_task_completion(start.elapsed());
217 Ok(result)
218 }
219
220 /// Join two computations in parallel
221 pub async fn join<F1, F2, R1, R2>(&self, f1: F1, f2: F2) -> Result<(R1, R2)>
222 where
223 F1: FnOnce() -> R1 + Send + 'static,
224 F2: FnOnce() -> R2 + Send + 'static,
225 R1: Send + 'static,
226 R2: Send + 'static,
227 {
228 self.execute(move || rayon::join(f1, f2)).await
229 }
230
231 /// Get metrics for this compute pool
232 pub fn metrics(&self) -> &ComputeMetrics {
233 &self.metrics
234 }
235
236 /// Get the number of threads in the pool
237 pub fn num_threads(&self) -> usize {
238 self.pool.current_num_threads()
239 }
240
241 /// Install this pool as the Rayon pool for the given closure
242 ///
243 /// This method is essential for using Rayon's parallel iterators (like `par_iter`,
244 /// `par_chunks`, etc.) with this specific thread pool. Any parallel iterator
245 /// operations within the closure will execute on this pool's threads.
246 ///
247 /// # Example
248 ///
249 /// ```ignore
250 /// use rayon::prelude::*;
251 ///
252 /// // Process data using parallel iterators
253 /// let result = pool.install(|| {
254 /// data.par_chunks(100)
255 /// .map(|chunk| process_chunk(chunk))
256 /// .collect::<Vec<_>>()
257 /// }).await?;
258 /// ```
259 ///
260 /// # Concurrent Usage
261 ///
262 /// Multiple async tasks can call `install()` concurrently on the same pool.
263 /// The Rayon work-stealing scheduler will efficiently distribute work from
264 /// all concurrent operations:
265 ///
266 /// ```ignore
267 /// // These can run concurrently without interference
268 /// let task1 = pool.install(|| data1.par_iter().map(f1).collect());
269 /// let task2 = pool.install(|| data2.par_chunks(50).map(f2).collect());
270 /// ```
271 pub async fn install<F, R>(&self, f: F) -> Result<R>
272 where
273 F: FnOnce() -> R + Send + 'static,
274 R: Send + 'static,
275 {
276 let pool = self.pool.clone();
277 self.metrics.record_task_start();
278 let start = std::time::Instant::now();
279
280 let result = tokio_rayon::spawn(move || pool.install(f)).await;
281
282 self.metrics.record_task_completion(start.elapsed());
283 Ok(result)
284 }
285}
286
/// A handle to a compute task that's currently running
///
/// Thin wrapper around a boxed, type-erased future; awaiting the handle
/// yields the task's output.
pub struct ComputeHandle<T> {
    // The boxed future that will produce the task's result.
    future: Pin<Box<dyn Future<Output = T> + Send>>,
}

impl<T> ComputeHandle<T> {
    /// Create a new compute handle from a future
    pub(crate) fn new<F>(fut: F) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        let future = Box::pin(fut);
        Self { future }
    }
}

impl<T> Future for ComputeHandle<T> {
    type Output = T;

    // Polling simply delegates to the wrapped future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.future.as_mut().poll(cx)
    }
}
311
/// Extension trait for ComputePool with additional patterns
#[async_trait]
pub trait ComputePoolExt {
    /// Process items in parallel batches
    ///
    /// `items` is split into chunks of at most `batch_size` elements, `f`
    /// maps each chunk to its output values, and the per-chunk outputs are
    /// concatenated in input order.
    async fn parallel_batch<T, F, R>(
        &self,
        items: Vec<T>,
        batch_size: usize,
        f: F,
    ) -> Result<Vec<R>>
    where
        T: Send + Sync + 'static,
        F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
        R: Send + 'static;

    /// Map over items in parallel using Rayon's par_iter
    ///
    /// Output order matches input order.
    async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
    where
        T: Send + Sync + 'static,
        F: Fn(T) -> R + Send + Sync + 'static,
        R: Send + 'static;
}
334
335#[async_trait]
336impl ComputePoolExt for ComputePool {
337 async fn parallel_batch<T, F, R>(
338 &self,
339 items: Vec<T>,
340 batch_size: usize,
341 f: F,
342 ) -> Result<Vec<R>>
343 where
344 T: Send + Sync + 'static,
345 F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
346 R: Send + 'static,
347 {
348 use rayon::prelude::*;
349
350 self.install(move || items.par_chunks(batch_size).flat_map(f).collect())
351 .await
352 }
353
354 async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
355 where
356 T: Send + Sync + 'static,
357 F: Fn(T) -> R + Send + Sync + 'static,
358 R: Send + 'static,
359 {
360 use rayon::prelude::*;
361
362 self.install(move || items.into_par_iter().map(f).collect())
363 .await
364 }
365}
366
#[cfg(test)]
mod tests {
    // NOTE: removed unused `use parking_lot::Mutex;` — no test in this module
    // touches a mutex, and the unused import triggered a compiler warning.
    use super::*;

    #[tokio::test]
    async fn test_compute_pool_execute() {
        let pool = ComputePool::with_defaults().unwrap();

        let result = pool
            .execute(|| {
                // Simulate CPU-intensive work
                (0u64..1000).sum::<u64>()
            })
            .await
            .unwrap();

        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_join() {
        let pool = ComputePool::with_defaults().unwrap();

        let (a, b) = pool.join(|| 2 + 2, || 3 * 3).await.unwrap();

        assert_eq!(a, 4);
        assert_eq!(b, 9);
    }

    #[tokio::test]
    async fn test_compute_pool_execute_sync() {
        let pool = Arc::new(ComputePool::with_defaults().unwrap());

        // Exercise `execute_sync` from a blocking context, as its docs advise.
        let pool_clone = pool.clone();
        let result = tokio::task::spawn_blocking(move || {
            pool_clone.execute_sync(|| (0u64..1000).sum::<u64>())
        })
        .await
        .unwrap();

        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_scoped() {
        use std::sync::mpsc;

        let pool = ComputePool::with_defaults().unwrap();

        let result = pool
            .execute_scoped(|scope| {
                let (tx, rx) = mpsc::channel();

                for i in 0..4 {
                    let tx = tx.clone();
                    scope.spawn(move |_| {
                        tx.send((i, i * 2)).unwrap();
                    });
                }

                drop(tx); // Close the last sender so the receiver loop ends

                // Tasks finish in arbitrary order, but results are placed by
                // index, so the output vector is deterministic.
                let mut results = vec![0; 4];
                for (i, val) in rx {
                    results[i] = val;
                }
                results
            })
            .await
            .unwrap();

        assert_eq!(result, vec![0, 2, 4, 6]);
    }
}
454}