Skip to main content

dynamo_runtime/compute/
mod.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Compute module for CPU-intensive operations using Rayon
5//!
6//! This module provides a dedicated compute thread pool for CPU-bound work,
7//! integrating Rayon's fork-join parallelism with Tokio's async runtime.
8//!
9//! Key features:
10//! - Dedicated Rayon thread pool for compute operations
11//! - Seamless async-to-sync bridging via tokio-rayon
12//! - Scope-based parallelism for complex computational graphs
13//! - Metrics and monitoring for compute operations
14//!
15#![doc = include_str!("../../docs/rayon-tokio-strategy.md")]
16
17use anyhow::Result;
18use rayon::ThreadPoolBuilder;
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Instant;
22
23pub mod macros;
24pub mod metrics;
25pub mod pool;
26pub mod thread_local;
27#[cfg(feature = "compute-validation")]
28pub mod validation;
29
30pub use metrics::ComputeMetrics;
31pub use pool::{ComputeHandle, ComputePool, ComputePoolExt};
32
/// Settings used to build the dedicated Rayon compute thread pool.
///
/// Every `Option` field may be left as `None`; see the individual fields for
/// what happens in that case.
#[derive(Clone, Debug)]
pub struct ComputeConfig {
    /// Number of worker threads in the Rayon pool.
    ///
    /// When `None`, `build_pool` uses half of the available parallelism,
    /// clamped to the range 2..=16 (falling back to 2 if detection fails).
    pub num_threads: Option<usize>,

    /// Stack size, in bytes, for each compute thread.
    ///
    /// `Default` sets this to 2MB. When `None`, `build_pool` does not set a
    /// stack size on the builder.
    pub stack_size: Option<usize>,

    /// Prefix for worker thread names; threads are named `<prefix>-<n>`.
    ///
    /// `Default` sets this to "compute".
    pub thread_prefix: String,

    /// Whether compute threads should be pinned to CPU cores.
    ///
    /// NOTE: `build_pool` currently only carries a TODO for pinning, so this
    /// flag has no effect yet.
    pub pin_threads: bool,
}
48
49impl Default for ComputeConfig {
50    fn default() -> Self {
51        Self {
52            num_threads: None,                 // Will use num_cpus / 2
53            stack_size: Some(2 * 1024 * 1024), // 2MB
54            thread_prefix: "compute".to_string(),
55            pin_threads: false,
56        }
57    }
58}
59
60impl ComputeConfig {
61    /// Validate the configuration
62    pub fn validate(&self) -> Result<()> {
63        if let Some(num_threads) = self.num_threads
64            && num_threads == 0
65        {
66            return Err(anyhow::anyhow!(
67                "Number of compute threads cannot be 0. Use None to disable compute pool entirely."
68            ));
69        }
70
71        if let Some(stack_size) = self.stack_size
72            && stack_size < 128 * 1024
73        {
74            return Err(anyhow::anyhow!(
75                "Stack size too small: {}KB. Minimum recommended: 128KB",
76                stack_size / 1024
77            ));
78        }
79
80        Ok(())
81    }
82
83    /// Create a ThreadPoolBuilder from this configuration
84    pub(crate) fn build_pool(&self) -> Result<rayon::ThreadPool> {
85        // Validate configuration first
86        self.validate()?;
87
88        let mut builder = ThreadPoolBuilder::new();
89
90        // Set number of threads with better logic for minimum parallelism
91        let num_threads = self.num_threads.unwrap_or_else(|| {
92            std::thread::available_parallelism()
93                .map(|n| {
94                    let total_cores = n.get();
95                    // Use half the cores, but ensure we have at least 2 threads
96                    // for meaningful parallelism, and cap at 16 for efficiency
97                    (total_cores / 2).clamp(2, 16)
98                })
99                .unwrap_or(2) // Fallback to 2 threads if detection fails
100        });
101        builder = builder.num_threads(num_threads);
102
103        // Set stack size if specified
104        if let Some(stack_size) = self.stack_size {
105            builder = builder.stack_size(stack_size);
106        }
107
108        // Set thread name prefix
109        let prefix = self.thread_prefix.clone();
110        let thread_counter = Arc::new(AtomicU64::new(0));
111        builder = builder.thread_name(move |_| {
112            let id = thread_counter.fetch_add(1, Ordering::SeqCst);
113            format!("{}-{}", prefix, id)
114        });
115
116        // TODO: Add CPU pinning if requested
117        // if self.pin_threads {
118        //     builder = builder.start_handler(|idx| {
119        //         // Pin thread to CPU core
120        //     });
121        // }
122
123        builder
124            .build()
125            .map_err(|e| anyhow::anyhow!("Failed to create Rayon thread pool: {}", e))
126    }
127}
128
129/// Helper trait for scope-based operations
130pub trait ScopeExecutor {
131    /// Execute a function within a Rayon scope
132    fn execute_in_scope<F, R>(&self, f: F) -> R
133    where
134        F: FnOnce(&rayon::Scope) -> R + Send,
135        R: Send;
136}
137
138/// Helper functions for common parallel patterns
139pub mod patterns {
140    use super::*;
141
142    /// Execute two functions in parallel and return both results
143    pub async fn parallel_join<F1, F2, R1, R2>(
144        pool: &ComputePool,
145        f1: F1,
146        f2: F2,
147    ) -> Result<(R1, R2)>
148    where
149        F1: FnOnce() -> R1 + Send + 'static,
150        F2: FnOnce() -> R2 + Send + 'static,
151        R1: Send + 'static,
152        R2: Send + 'static,
153    {
154        pool.execute(move || rayon::join(f1, f2)).await
155    }
156
157    /// Execute multiple functions in parallel using scope
158    pub async fn parallel_map<F, T, R>(pool: &ComputePool, items: Vec<T>, f: F) -> Result<Vec<R>>
159    where
160        F: Fn(T) -> R + Sync + Send + 'static,
161        T: Send + 'static,
162        R: Send + 'static,
163    {
164        use rayon::prelude::*;
165        pool.execute(move || items.into_par_iter().map(f).collect())
166            .await
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented prefix, stack size
    /// and pinning flag.
    #[test]
    fn test_compute_config_default() {
        let ComputeConfig {
            thread_prefix,
            stack_size,
            pin_threads,
            ..
        } = ComputeConfig::default();

        assert_eq!(thread_prefix, "compute");
        assert_eq!(stack_size, Some(2 * 1024 * 1024));
        assert!(!pin_threads);
    }

    /// An explicit thread count is reflected in the built pool.
    #[test]
    fn test_build_pool() {
        let two_threads = ComputeConfig {
            num_threads: Some(2),
            ..ComputeConfig::default()
        };

        let built = two_threads.build_pool().unwrap();
        assert_eq!(built.current_num_threads(), 2);
    }
}