dynamo_runtime/compute/
mod.rs1#![doc = include_str!("../../docs/rayon-tokio-strategy.md")]
16
17use anyhow::Result;
18use rayon::ThreadPoolBuilder;
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Instant;
22
23pub mod macros;
24pub mod metrics;
25pub mod pool;
26pub mod thread_local;
27#[cfg(feature = "compute-validation")]
28pub mod validation;
29
30pub use metrics::ComputeMetrics;
31pub use pool::{ComputeHandle, ComputePool, ComputePoolExt};
32
/// Settings used to construct the Rayon-backed compute thread pool.
///
/// Check a value with [`ComputeConfig::validate`]; `build_pool` turns it into
/// a `rayon::ThreadPool`.
#[derive(Clone, Debug)]
pub struct ComputeConfig {
    /// Number of worker threads.
    ///
    /// `None` lets `build_pool` choose: half of the available parallelism,
    /// clamped to the range 2..=16 (2 if the parallelism cannot be queried).
    /// `Some(0)` is rejected by `validate`.
    pub num_threads: Option<usize>,

    /// Stack size, in bytes, for each worker thread.
    ///
    /// `None` leaves the pool builder's stack size unset. Values below
    /// 128 KiB are rejected by `validate`.
    pub stack_size: Option<usize>,

    /// Prefix for worker thread names; threads are named `"{prefix}-{n}"`.
    pub thread_prefix: String,

    /// Whether worker threads should be pinned to CPU cores.
    ///
    /// NOTE(review): `build_pool` in this file never reads this flag, so it
    /// currently has no effect there — confirm whether pinning is handled
    /// elsewhere.
    pub pin_threads: bool,
}
48
49impl Default for ComputeConfig {
50 fn default() -> Self {
51 Self {
52 num_threads: None, stack_size: Some(2 * 1024 * 1024), thread_prefix: "compute".to_string(),
55 pin_threads: false,
56 }
57 }
58}
59
60impl ComputeConfig {
61 pub fn validate(&self) -> Result<()> {
63 if let Some(num_threads) = self.num_threads
64 && num_threads == 0
65 {
66 return Err(anyhow::anyhow!(
67 "Number of compute threads cannot be 0. Use None to disable compute pool entirely."
68 ));
69 }
70
71 if let Some(stack_size) = self.stack_size
72 && stack_size < 128 * 1024
73 {
74 return Err(anyhow::anyhow!(
75 "Stack size too small: {}KB. Minimum recommended: 128KB",
76 stack_size / 1024
77 ));
78 }
79
80 Ok(())
81 }
82
83 pub(crate) fn build_pool(&self) -> Result<rayon::ThreadPool> {
85 self.validate()?;
87
88 let mut builder = ThreadPoolBuilder::new();
89
90 let num_threads = self.num_threads.unwrap_or_else(|| {
92 std::thread::available_parallelism()
93 .map(|n| {
94 let total_cores = n.get();
95 (total_cores / 2).clamp(2, 16)
98 })
99 .unwrap_or(2) });
101 builder = builder.num_threads(num_threads);
102
103 if let Some(stack_size) = self.stack_size {
105 builder = builder.stack_size(stack_size);
106 }
107
108 let prefix = self.thread_prefix.clone();
110 let thread_counter = Arc::new(AtomicU64::new(0));
111 builder = builder.thread_name(move |_| {
112 let id = thread_counter.fetch_add(1, Ordering::SeqCst);
113 format!("{}-{}", prefix, id)
114 });
115
116 builder
124 .build()
125 .map_err(|e| anyhow::anyhow!("Failed to create Rayon thread pool: {}", e))
126 }
127}
128
129pub trait ScopeExecutor {
131 fn execute_in_scope<F, R>(&self, f: F) -> R
133 where
134 F: FnOnce(&rayon::Scope) -> R + Send,
135 R: Send;
136}
137
138pub mod patterns {
140 use super::*;
141
142 pub async fn parallel_join<F1, F2, R1, R2>(
144 pool: &ComputePool,
145 f1: F1,
146 f2: F2,
147 ) -> Result<(R1, R2)>
148 where
149 F1: FnOnce() -> R1 + Send + 'static,
150 F2: FnOnce() -> R2 + Send + 'static,
151 R1: Send + 'static,
152 R2: Send + 'static,
153 {
154 pool.execute(move || rayon::join(f1, f2)).await
155 }
156
157 pub async fn parallel_map<F, T, R>(pool: &ComputePool, items: Vec<T>, f: F) -> Result<Vec<R>>
159 where
160 F: Fn(T) -> R + Sync + Send + 'static,
161 T: Send + 'static,
162 R: Send + 'static,
163 {
164 use rayon::prelude::*;
165 pool.execute(move || items.into_par_iter().map(f).collect())
166 .await
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented values and passes
    /// validation.
    #[test]
    fn test_compute_config_default() {
        let config = ComputeConfig::default();
        assert_eq!(config.num_threads, None);
        assert_eq!(config.thread_prefix, "compute");
        assert_eq!(config.stack_size, Some(2 * 1024 * 1024));
        assert!(!config.pin_threads);
        assert!(config.validate().is_ok());
    }

    /// An explicit thread count is honoured by the built pool.
    #[test]
    fn test_build_pool() {
        let config = ComputeConfig {
            num_threads: Some(2),
            ..Default::default()
        };

        let pool = config.build_pool().unwrap();
        assert_eq!(pool.current_num_threads(), 2);
    }

    /// A zero thread count is rejected by `validate`, and therefore by
    /// `build_pool`, which validates first.
    #[test]
    fn test_validate_rejects_zero_threads() {
        let config = ComputeConfig {
            num_threads: Some(0),
            ..Default::default()
        };

        assert!(config.validate().is_err());
        assert!(config.build_pool().is_err());
    }

    /// Stack sizes below 128 KiB are rejected; exactly 128 KiB and an unset
    /// stack size are both accepted.
    #[test]
    fn test_validate_stack_size_floor() {
        let too_small = ComputeConfig {
            stack_size: Some(128 * 1024 - 1),
            ..Default::default()
        };
        assert!(too_small.validate().is_err());

        let at_floor = ComputeConfig {
            stack_size: Some(128 * 1024),
            ..Default::default()
        };
        assert!(at_floor.validate().is_ok());

        let unset = ComputeConfig {
            stack_size: None,
            ..Default::default()
        };
        assert!(unset.validate().is_ok());
    }
}