#![doc = include_str!("../../docs/rayon-tokio-strategy.md")]
use anyhow::Result;
use rayon::ThreadPoolBuilder;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
pub mod macros;
pub mod metrics;
pub mod pool;
pub mod thread_local;
#[cfg(feature = "compute-validation")]
pub mod validation;
pub use metrics::ComputeMetrics;
pub use pool::{ComputeHandle, ComputePool, ComputePoolExt};
/// Configuration for the Rayon-backed compute thread pool.
///
/// Construct via [`Default`] and override fields with struct-update
/// syntax; call [`ComputeConfig::validate`] (or build the pool, which
/// validates implicitly) before use.
#[derive(Debug, Clone)]
pub struct ComputeConfig {
    /// Worker thread count. `None` auto-detects from available
    /// parallelism (half the cores, clamped to [2, 16] in `build_pool`).
    /// `Some(0)` is rejected by `validate`.
    pub num_threads: Option<usize>,
    /// Per-thread stack size in bytes. Values below 128 KiB are
    /// rejected by `validate`; `None` uses Rayon's default.
    pub stack_size: Option<usize>,
    /// Prefix for worker thread names (threads are named "{prefix}-{index}").
    pub thread_prefix: String,
    /// Whether to pin worker threads to CPUs.
    /// NOTE(review): not consumed in `build_pool` — presumably handled
    /// elsewhere (e.g. the `pool` module); confirm.
    pub pin_threads: bool,
}
impl Default for ComputeConfig {
fn default() -> Self {
Self {
num_threads: None, stack_size: Some(2 * 1024 * 1024), thread_prefix: "compute".to_string(),
pin_threads: false,
}
}
}
impl ComputeConfig {
    /// Checks the configuration for invalid values.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_threads` is `Some(0)`, or if
    /// `stack_size` is below the 128 KiB minimum.
    pub fn validate(&self) -> Result<()> {
        if let Some(num_threads) = self.num_threads
            && num_threads == 0
        {
            return Err(anyhow::anyhow!(
                "Number of compute threads cannot be 0. Use None to disable compute pool entirely."
            ));
        }
        if let Some(stack_size) = self.stack_size
            && stack_size < 128 * 1024
        {
            return Err(anyhow::anyhow!(
                "Stack size too small: {}KB. Minimum recommended: 128KB",
                stack_size / 1024
            ));
        }
        Ok(())
    }

    /// Builds a Rayon thread pool from this configuration.
    ///
    /// When `num_threads` is `None`, half the available cores are used,
    /// clamped to [2, 16]; if parallelism cannot be detected, 2 threads.
    /// Worker threads are named `"{thread_prefix}-{index}"`.
    ///
    /// # Errors
    ///
    /// Returns an error if validation fails or Rayon cannot create the pool.
    pub(crate) fn build_pool(&self) -> Result<rayon::ThreadPool> {
        self.validate()?;
        let mut builder = ThreadPoolBuilder::new();
        let num_threads = self.num_threads.unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|n| (n.get() / 2).clamp(2, 16))
                .unwrap_or(2)
        });
        builder = builder.num_threads(num_threads);
        if let Some(stack_size) = self.stack_size {
            builder = builder.stack_size(stack_size);
        }
        // Rayon passes each worker's index (0..num_threads) to the
        // `thread_name` closure, so the former Arc<AtomicU64> counter
        // was redundant — use the supplied index directly.
        let prefix = self.thread_prefix.clone();
        builder = builder.thread_name(move |idx| format!("{}-{}", prefix, idx));
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to create Rayon thread pool: {}", e))
    }
}
/// Abstraction for running a closure inside a Rayon scope.
///
/// Implementors provide the scope; the closure receives a reference to
/// it and may spawn scoped work before returning a result.
pub trait ScopeExecutor {
    /// Executes `f` within a Rayon scope and returns its result.
    fn execute_in_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&rayon::Scope) -> R + Send,
        R: Send;
}
/// High-level helpers for common parallel-compute patterns built on
/// [`ComputePool`].
pub mod patterns {
    use super::*;

    /// Runs `f1` and `f2` in parallel on `pool` via `rayon::join` and
    /// awaits the pair of results.
    ///
    /// # Errors
    ///
    /// Propagates any error from `pool.execute`.
    pub async fn parallel_join<F1, F2, R1, R2>(
        pool: &ComputePool,
        f1: F1,
        f2: F2,
    ) -> Result<(R1, R2)>
    where
        F1: FnOnce() -> R1 + Send + 'static,
        F2: FnOnce() -> R2 + Send + 'static,
        R1: Send + 'static,
        R2: Send + 'static,
    {
        let task = move || rayon::join(f1, f2);
        pool.execute(task).await
    }

    /// Applies `f` to each element of `items` in parallel on `pool`,
    /// returning results in the input order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `pool.execute`.
    pub async fn parallel_map<F, T, R>(pool: &ComputePool, items: Vec<T>, f: F) -> Result<Vec<R>>
    where
        F: Fn(T) -> R + Sync + Send + 'static,
        T: Send + 'static,
        R: Send + 'static,
    {
        use rayon::prelude::*;
        let task = move || -> Vec<R> { items.into_par_iter().map(f).collect() };
        pool.execute(task).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config carries the documented prefix, stack size,
    /// and pinning flag.
    #[test]
    fn test_compute_config_default() {
        let cfg = ComputeConfig::default();
        assert_eq!(cfg.thread_prefix, "compute");
        assert_eq!(cfg.stack_size, Some(2 * 1024 * 1024));
        assert!(!cfg.pin_threads);
    }

    /// An explicit thread count is honored by the built pool.
    #[test]
    fn test_build_pool() {
        let cfg = ComputeConfig {
            num_threads: Some(2),
            ..ComputeConfig::default()
        };
        assert_eq!(cfg.build_pool().unwrap().current_num_threads(), 2);
    }
}