Skip to main content

dynamo_runtime/compute/
pool.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Compute pool implementation with tokio-rayon integration
5//!
6//! The `ComputePool` allows multiple async tasks to concurrently submit different
7//! types of parallel work to a shared Rayon thread pool. This enables efficient
8//! CPU utilization without manual thread management.
9//!
10//! # Concurrent Usage Example
11//!
//! ```ignore
//! use std::sync::Arc;
//! use dynamo_runtime::compute::ComputePool;
//! use rayon::prelude::*;
//!
//! async fn concurrent_processing(pool: Arc<ComputePool>) {
//!     // Task 1: Using scope for dynamic task generation
//!     let task1 = tokio::spawn({
//!         let pool = pool.clone();
//!         async move {
//!             pool.execute_scoped(|scope| {
//!                 // Dynamically spawn tasks based on runtime conditions
//!                 for i in 0..100 {
//!                     scope.spawn(move |_| {
//!                         // CPU-intensive work. Scope tasks must return `()`, so any
//!                         // result has to leave through a channel or shared state.
//!                         let mut sum = 0u64;
//!                         for j in 0..1000 {
//!                             sum += (i * j) as u64;
//!                         }
//!                         std::hint::black_box(sum);
//!                     });
//!                 }
//!             }).await
//!         }
//!     });
//!
//!     // Task 2: Using parallel iterators for batch processing
//!     let task2 = tokio::spawn({
//!         let pool = pool.clone();
//!         async move {
//!             let data: Vec<u32> = (0..10000).collect();
//!             // `move` is required: the closure must be `'static`, so it has to own `data`.
//!             pool.install(move || {
//!                 data.par_chunks(100)
//!                     .map(|chunk| chunk.iter().sum::<u32>())
//!                     .collect::<Vec<_>>()
//!             }).await
//!         }
//!     });
//!
//!     // Both tasks run concurrently, sharing the same thread pool
//!     let (result1, result2) = tokio::join!(task1, task2);
//! }
//! ```
55//!
56//! # Thread Pool Sharing
57//!
58//! The Rayon thread pool uses work-stealing to efficiently distribute work from
59//! multiple concurrent sources:
60//!
61//! - Tasks from `scope.spawn()` are pushed to thread-local deques
62//! - Parallel iterators distribute work across all threads
63//! - Idle threads steal work from busy threads
64//! - No coordination needed between different parallelization patterns
65
66use super::{ComputeConfig, ComputeMetrics};
67use anyhow::Result;
68use async_trait::async_trait;
69use std::future::Future;
70use std::pin::Pin;
71use std::sync::Arc;
72use std::task::{Context, Poll};
73
74/// A compute pool that manages CPU-intensive operations
75#[derive(Clone)]
76pub struct ComputePool {
77    /// The underlying Rayon thread pool
78    pool: Arc<rayon::ThreadPool>,
79
80    /// Metrics for monitoring compute operations
81    metrics: Arc<ComputeMetrics>,
82
83    /// Configuration used to create this pool
84    config: ComputeConfig,
85}
86
87impl std::fmt::Debug for ComputePool {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("ComputePool")
90            .field("num_threads", &self.pool.current_num_threads())
91            .field("metrics", &self.metrics)
92            .field("config", &self.config)
93            .finish()
94    }
95}
96
97impl ComputePool {
98    /// Create a new compute pool with the given configuration
99    pub fn new(config: ComputeConfig) -> Result<Self> {
100        let pool = config.build_pool()?;
101        let metrics = Arc::new(ComputeMetrics::new());
102
103        Ok(Self {
104            pool: Arc::new(pool),
105            metrics,
106            config,
107        })
108    }
109
110    /// Create a compute pool with default configuration
111    pub fn with_defaults() -> Result<Self> {
112        Self::new(ComputeConfig::default())
113    }
114
115    /// Execute a synchronous computation on the thread pool
116    ///
117    /// This method is designed to be called from within `spawn_blocking` or other
118    /// synchronous contexts. It has minimal overhead as it directly uses Rayon
119    /// without the async bridge.
120    ///
121    /// # Example
122    /// ```ignore
123    /// # use dynamo_runtime::compute::ComputePool;
124    /// # let pool = ComputePool::new(Default::default()).unwrap();
125    /// tokio::task::spawn_blocking(move || {
126    ///     pool.execute_sync(|| {
127    ///         // CPU-intensive work
128    ///         expensive_computation()
129    ///     })
130    /// });
131    /// ```
132    pub fn execute_sync<F, R>(&self, f: F) -> R
133    where
134        F: FnOnce() -> R + Send,
135        R: Send,
136    {
137        self.pool.install(f)
138    }
139
140    /// Execute a compute task in the Rayon pool
141    ///
142    /// This bridges from async context to the Rayon thread pool,
143    /// allowing CPU-intensive work to run without blocking Tokio workers.
144    ///
145    /// Note: This method has ~25μs overhead for small tasks due to the async
146    /// channel communication. For very small computations (<100μs), consider
147    /// running directly on Tokio or using `spawn_blocking` with `execute_sync`.
148    pub async fn execute<F, R>(&self, f: F) -> Result<R>
149    where
150        F: FnOnce() -> R + Send + 'static,
151        R: Send + 'static,
152    {
153        self.metrics.record_task_start();
154        let start = std::time::Instant::now();
155
156        // Use tokio-rayon to bridge to the compute pool
157        let pool = self.pool.clone();
158        let result = tokio_rayon::spawn(move || pool.install(f)).await;
159
160        self.metrics.record_task_completion(start.elapsed());
161        Ok(result)
162    }
163
164    /// Execute a function with a Rayon scope
165    ///
166    /// This allows spawning multiple parallel tasks within the scope,
167    /// with the guarantee that all tasks complete before returning.
168    pub async fn execute_scoped<F, R>(&self, f: F) -> Result<R>
169    where
170        F: FnOnce(&rayon::Scope) -> R + Send + 'static,
171        R: Send + 'static,
172    {
173        self.metrics.record_task_start();
174        let start = std::time::Instant::now();
175
176        let pool = self.pool.clone();
177        let result = tokio_rayon::spawn(move || {
178            pool.install(|| {
179                let mut result = None;
180                rayon::scope(|s| {
181                    result = Some(f(s));
182                });
183                result.unwrap()
184            })
185        })
186        .await;
187
188        self.metrics.record_task_completion(start.elapsed());
189        Ok(result)
190    }
191
192    /// Execute a function with a FIFO scope
193    ///
194    /// Similar to execute_scoped, but tasks are prioritized in FIFO order
195    /// rather than the default LIFO order.
196    pub async fn execute_scoped_fifo<F, R>(&self, f: F) -> Result<R>
197    where
198        F: FnOnce(&rayon::ScopeFifo) -> R + Send + 'static,
199        R: Send + 'static,
200    {
201        self.metrics.record_task_start();
202        let start = std::time::Instant::now();
203
204        let pool = self.pool.clone();
205        let result = tokio_rayon::spawn(move || {
206            pool.install(|| {
207                let mut result = None;
208                rayon::scope_fifo(|s| {
209                    result = Some(f(s));
210                });
211                result.unwrap()
212            })
213        })
214        .await;
215
216        self.metrics.record_task_completion(start.elapsed());
217        Ok(result)
218    }
219
220    /// Join two computations in parallel
221    pub async fn join<F1, F2, R1, R2>(&self, f1: F1, f2: F2) -> Result<(R1, R2)>
222    where
223        F1: FnOnce() -> R1 + Send + 'static,
224        F2: FnOnce() -> R2 + Send + 'static,
225        R1: Send + 'static,
226        R2: Send + 'static,
227    {
228        self.execute(move || rayon::join(f1, f2)).await
229    }
230
231    /// Get metrics for this compute pool
232    pub fn metrics(&self) -> &ComputeMetrics {
233        &self.metrics
234    }
235
236    /// Get the number of threads in the pool
237    pub fn num_threads(&self) -> usize {
238        self.pool.current_num_threads()
239    }
240
241    /// Install this pool as the Rayon pool for the given closure
242    ///
243    /// This method is essential for using Rayon's parallel iterators (like `par_iter`,
244    /// `par_chunks`, etc.) with this specific thread pool. Any parallel iterator
245    /// operations within the closure will execute on this pool's threads.
246    ///
247    /// # Example
248    ///
249    /// ```ignore
250    /// use rayon::prelude::*;
251    ///
252    /// // Process data using parallel iterators
253    /// let result = pool.install(|| {
254    ///     data.par_chunks(100)
255    ///         .map(|chunk| process_chunk(chunk))
256    ///         .collect::<Vec<_>>()
257    /// }).await?;
258    /// ```
259    ///
260    /// # Concurrent Usage
261    ///
262    /// Multiple async tasks can call `install()` concurrently on the same pool.
263    /// The Rayon work-stealing scheduler will efficiently distribute work from
264    /// all concurrent operations:
265    ///
266    /// ```ignore
267    /// // These can run concurrently without interference
268    /// let task1 = pool.install(|| data1.par_iter().map(f1).collect());
269    /// let task2 = pool.install(|| data2.par_chunks(50).map(f2).collect());
270    /// ```
271    pub async fn install<F, R>(&self, f: F) -> Result<R>
272    where
273        F: FnOnce() -> R + Send + 'static,
274        R: Send + 'static,
275    {
276        let pool = self.pool.clone();
277        self.metrics.record_task_start();
278        let start = std::time::Instant::now();
279
280        let result = tokio_rayon::spawn(move || pool.install(f)).await;
281
282        self.metrics.record_task_completion(start.elapsed());
283        Ok(result)
284    }
285}
286
/// Handle to a compute task that is in flight.
///
/// Wraps a boxed, type-erased `Send` future; awaiting the handle polls that future and
/// yields its output.
pub struct ComputeHandle<T> {
    inner: Pin<Box<dyn Future<Output = T> + Send + 'static>>,
}

impl<T> ComputeHandle<T> {
    /// Box and pin `future` so it can be stored behind the uniform handle type.
    pub(crate) fn new<F>(future: F) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        let inner: Pin<Box<dyn Future<Output = T> + Send>> = Box::pin(future);
        Self { inner }
    }
}

impl<T> Future for ComputeHandle<T> {
    type Output = T;

    /// Forward the poll to the wrapped future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The handle's only field is a `Pin<Box<_>>`, which is `Unpin`, so the handle is
        // `Unpin` too and a plain `&mut` to it can be taken safely.
        let this = self.get_mut();
        this.inner.as_mut().poll(cx)
    }
}
311
312/// Extension trait for ComputePool with additional patterns
313#[async_trait]
314pub trait ComputePoolExt {
315    /// Process items in parallel batches
316    async fn parallel_batch<T, F, R>(
317        &self,
318        items: Vec<T>,
319        batch_size: usize,
320        f: F,
321    ) -> Result<Vec<R>>
322    where
323        T: Send + Sync + 'static,
324        F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
325        R: Send + 'static;
326
327    /// Map over items in parallel using Rayon's par_iter
328    async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
329    where
330        T: Send + Sync + 'static,
331        F: Fn(T) -> R + Send + Sync + 'static,
332        R: Send + 'static;
333}
334
335#[async_trait]
336impl ComputePoolExt for ComputePool {
337    async fn parallel_batch<T, F, R>(
338        &self,
339        items: Vec<T>,
340        batch_size: usize,
341        f: F,
342    ) -> Result<Vec<R>>
343    where
344        T: Send + Sync + 'static,
345        F: Fn(&[T]) -> Vec<R> + Send + Sync + 'static,
346        R: Send + 'static,
347    {
348        use rayon::prelude::*;
349
350        self.install(move || items.par_chunks(batch_size).flat_map(f).collect())
351            .await
352    }
353
354    async fn parallel_map<T, F, R>(&self, items: Vec<T>, f: F) -> Result<Vec<R>>
355    where
356        T: Send + Sync + 'static,
357        F: Fn(T) -> R + Send + Sync + 'static,
358        R: Send + 'static,
359    {
360        use rayon::prelude::*;
361
362        self.install(move || items.into_par_iter().map(f).collect())
363            .await
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_compute_pool_execute() {
        let pool = ComputePool::with_defaults().unwrap();

        let result = pool
            .execute(|| {
                // Simulate CPU-intensive work
                let mut sum = 0u64;
                for i in 0..1000 {
                    sum += i;
                }
                sum
            })
            .await
            .unwrap();

        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_join() {
        let pool = ComputePool::with_defaults().unwrap();

        let (a, b) = pool.join(|| 2 + 2, || 3 * 3).await.unwrap();

        assert_eq!(a, 4);
        assert_eq!(b, 9);
    }

    #[tokio::test]
    async fn test_compute_pool_execute_sync() {
        let pool = Arc::new(ComputePool::with_defaults().unwrap());

        // Test using execute_sync from spawn_blocking
        let pool_clone = pool.clone();
        let result = tokio::task::spawn_blocking(move || {
            pool_clone.execute_sync(|| {
                let mut sum = 0u64;
                for i in 0..1000 {
                    sum += i;
                }
                sum
            })
        })
        .await
        .unwrap();

        assert_eq!(result, 499500);
    }

    #[tokio::test]
    async fn test_compute_pool_scoped() {
        let pool = ComputePool::with_defaults().unwrap();

        // The scope closure must not block waiting for its own spawned tasks: on a
        // single-threaded pool nobody else could run them. Instead it hands back a shared
        // buffer, which is complete once `execute_scoped` resolves (the scope waits for
        // every spawned task before returning).
        let results = pool
            .execute_scoped(|scope| {
                let results = Arc::new(Mutex::new(vec![0usize; 4]));
                for i in 0..4 {
                    let results = Arc::clone(&results);
                    scope.spawn(move |_| {
                        results.lock()[i] = i * 2;
                    });
                }
                results
            })
            .await
            .unwrap();

        // Each task writes its own slot, so the order is deterministic.
        assert_eq!(*results.lock(), vec![0, 2, 4, 6]);
    }

    #[tokio::test]
    async fn test_compute_pool_scoped_fifo() {
        let pool = ComputePool::with_defaults().unwrap();

        let counter = pool
            .execute_scoped_fifo(|scope| {
                let counter = Arc::new(AtomicUsize::new(0));
                for _ in 0..8 {
                    let counter = Arc::clone(&counter);
                    scope.spawn_fifo(move |_| {
                        counter.fetch_add(1, Ordering::SeqCst);
                    });
                }
                counter
            })
            .await
            .unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 8);
    }

    #[tokio::test]
    async fn test_compute_pool_install_par_iter() {
        use rayon::prelude::*;

        let pool = ComputePool::with_defaults().unwrap();
        let data: Vec<u64> = (1..=100).collect();

        let sum = pool
            .install(move || data.par_iter().sum::<u64>())
            .await
            .unwrap();

        assert_eq!(sum, 5050);
    }

    #[tokio::test]
    async fn test_parallel_map_preserves_order() {
        let pool = ComputePool::with_defaults().unwrap();
        let items: Vec<u32> = (0..10).collect();

        let squares = pool.parallel_map(items, |x| x * x).await.unwrap();

        assert_eq!(squares, (0..10u32).map(|x| x * x).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn test_parallel_batch_concatenates_chunks() {
        let pool = ComputePool::with_defaults().unwrap();
        let items: Vec<u32> = (1..=10).collect();

        let sums = pool
            .parallel_batch(items, 3, |chunk: &[u32]| vec![chunk.iter().sum::<u32>()])
            .await
            .unwrap();

        // Chunks: [1,2,3] [4,5,6] [7,8,9] [10]
        assert_eq!(sums, vec![6, 15, 24, 10]);
    }
}