#[cfg(feature = "rayon")]
use std::sync::OnceLock;
#[cfg(feature = "rayon")]
use rayon::ThreadPoolBuilder;
use crate::algebra::parallel_cfg::ParallelTune;
use std::collections::BTreeMap;
/// Default cutoff value (4096).
///
/// NOTE(review): judging by the name this is a problem-size threshold used
/// when deciding whether to run in parallel, but nothing in this part of the
/// file reads it — confirm the exact meaning against its callers.
pub const DEFAULT_PAR_CUTOFF: usize = 4096;
/// Reads the environment variable `key` as a `usize`.
///
/// Returns `default` when the variable is missing, is not valid Unicode, or
/// does not parse as a `usize`. A value of `0` is returned as is.
pub fn env_usize(key: &str, default: usize) -> usize {
    match std::env::var(key) {
        Ok(raw) => raw.parse::<usize>().unwrap_or(default),
        Err(_) => default,
    }
}
/// Reads the environment variable `key` as a strictly positive `usize`.
///
/// Returns `None` when the variable is missing, is not valid Unicode, does
/// not parse as a `usize`, or parses to `0`.
fn env_usize_opt(key: &str) -> Option<usize> {
    let raw = std::env::var(key).ok()?;
    match raw.parse::<usize>() {
        Ok(value) if value > 0 => Some(value),
        _ => None,
    }
}
fn detect_local_mpi_size(global_mpi_size: usize) -> usize {
env_usize_opt("KRYST_MPI_LOCAL_SIZE")
.or_else(|| env_usize_opt("OMPI_COMM_WORLD_LOCAL_SIZE"))
.or_else(|| env_usize_opt("MPI_LOCALNRANKS"))
.or_else(|| env_usize_opt("SLURM_NTASKS_PER_NODE"))
.unwrap_or(global_mpi_size.max(1))
}
/// Thread count recorded by whichever pool-initialisation function runs first.
///
/// Written at most once through `get_or_init` in `init_global_rayon_pool` and
/// `init_global_rayon_pool_with_threads`; `current_rayon_threads` reads the
/// stored value back when it is present.
#[cfg(feature = "rayon")]
static EFFECTIVE_THREADS: OnceLock<usize> = OnceLock::new();
/// Configures the global Rayon pool once, sized for a single MPI rank, and
/// returns the chosen thread count.
///
/// Without the `rayon` feature this does nothing and returns `1`.
///
/// With the feature enabled the count is resolved on the first call only and
/// cached in `EFFECTIVE_THREADS`; every later call (including calls to
/// `init_global_rayon_pool_with_threads`) returns the cached value.
/// Resolution order:
///
/// 1. `KRYST_THREADS`, then `RAYON_NUM_THREADS` — divided by the number of
///    ranks on the node as reported by `detect_local_mpi_size`.
/// 2. `SLURM_CPUS_PER_TASK` — used undivided, because Slurm defines it as the
///    CPU count of one task (one rank) rather than of the whole node.
/// 3. `num_cpus::get()` divided by the number of ranks on the node.
///
/// Values that are zero or fail to parse count as "not set" for all three
/// variables, matching `env_usize_opt` and Rayon's own reading of
/// `RAYON_NUM_THREADS=0`. The result is never below 1.
///
/// `mpi_size` is the size of the MPI communicator; it is only used as the
/// fallback local rank count when no launcher variable is available.
pub fn init_global_rayon_pool(mpi_size: usize) -> usize {
    #[cfg(not(feature = "rayon"))]
    {
        let _ = mpi_size;
        1
    }
    #[cfg(feature = "rayon")]
    {
        *EFFECTIVE_THREADS.get_or_init(|| {
            // Always >= 1, so the divisions below cannot divide by zero.
            let ranks_on_node = detect_local_mpi_size(mpi_size.max(1));
            // Explicit overrides are shared between the ranks on the node.
            let shared_budget =
                env_usize_opt("KRYST_THREADS").or_else(|| env_usize_opt("RAYON_NUM_THREADS"));
            let per_rank = match shared_budget {
                Some(total) => total / ranks_on_node,
                // SLURM_CPUS_PER_TASK is already a per-rank figure; dividing it
                // again would leave most of the allocated cores idle.
                None => env_usize_opt("SLURM_CPUS_PER_TASK")
                    .unwrap_or_else(|| num_cpus::get() / ranks_on_node),
            };
            let threads = per_rank.max(1);
            // Best effort: the build result is deliberately ignored.
            // NOTE(review): if the global pool was already built elsewhere, the
            // value cached here may not match the real pool size — confirm.
            let _ = ThreadPoolBuilder::new().num_threads(threads).build_global();
            threads
        })
    }
}
/// Configures the global Rayon pool once with an explicit thread count and
/// returns the count that is in effect.
///
/// Without the `rayon` feature this does nothing and returns `1`.
///
/// With the feature enabled, the first initialisation call stores its count
/// in `EFFECTIVE_THREADS`; if a value is already stored, `threads` is ignored
/// and the stored value is returned. A request of `0` is raised to `1`.
pub fn init_global_rayon_pool_with_threads(threads: usize) -> usize {
    #[cfg(feature = "rayon")]
    {
        let effective = EFFECTIVE_THREADS.get_or_init(|| {
            let requested = threads.max(1);
            // Best effort: the build result is deliberately ignored.
            let _ = ThreadPoolBuilder::new()
                .num_threads(requested)
                .build_global();
            requested
        });
        *effective
    }
    #[cfg(not(feature = "rayon"))]
    {
        let _ = threads;
        1
    }
}
/// Returns the thread count this module considers to be in effect.
///
/// Without the `rayon` feature the answer is always `1`. With it, the value
/// cached in `EFFECTIVE_THREADS` is returned when present; otherwise Rayon is
/// asked directly via `rayon::current_num_threads`.
pub fn current_rayon_threads() -> usize {
    #[cfg(not(feature = "rayon"))]
    {
        1
    }
    #[cfg(feature = "rayon")]
    {
        match EFFECTIVE_THREADS.get() {
            Some(&recorded) => recorded,
            None => rayon::current_num_threads(),
        }
    }
}
/// Holds the thread count obtained when the per-rank pool was initialised.
///
/// Constructed through `ThreadPoolGuard::new_per_rank`; the stored count is
/// exposed by `threads()`.
pub struct ThreadPoolGuard {
/// Value returned by `init_global_rayon_pool` at construction time.
threads: usize,
}
/// Execution stage that a thread limit can be attached to.
///
/// Stages come in an outer/inner pair for each of setup and apply. The
/// derived ordering follows the declaration order below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum KspExecStage {
    /// Setup stage of the outer level.
    OuterSetup,
    /// Apply stage of the outer level.
    OuterApply,
    /// Setup stage of the inner level.
    InnerSetup,
    /// Apply stage of the inner level.
    InnerApply,
}

impl KspExecStage {
    /// Returns the stable snake_case name of the stage, as used for the keys
    /// of `ScopedThreadPolicy::diagnostics`.
    pub fn as_key(self) -> &'static str {
        match self {
            Self::OuterSetup => "outer_setup",
            Self::OuterApply => "outer_apply",
            Self::InnerSetup => "inner_setup",
            Self::InnerApply => "inner_apply",
        }
    }
}
/// Optional per-stage thread limits.
///
/// A `None` limit means the stage has no explicit setting; the inner stages
/// then fall back to the matching outer limit, and after that to
/// `current_rayon_threads()`. The default value has every limit unset and an
/// empty `per_pc_threads` map.
#[derive(Debug, Clone, Default)]
pub struct ScopedThreadPolicy {
/// Limit for `KspExecStage::OuterSetup`; also the fallback for `InnerSetup`.
pub setup_threads: Option<usize>,
/// Limit for `KspExecStage::OuterApply`; also the fallback for `InnerApply`.
pub apply_threads: Option<usize>,
/// Limit for `KspExecStage::InnerSetup`.
pub inner_setup_threads: Option<usize>,
/// Limit for `KspExecStage::InnerApply`.
pub inner_apply_threads: Option<usize>,
/// Thread counts keyed by name.
/// NOTE(review): presumably keyed by preconditioner name; the methods visible
/// in this file never read it — confirm against callers.
pub per_pc_threads: BTreeMap<String, usize>,
}
impl ScopedThreadPolicy {
    /// Normalises a limit so that `Some(0)` is treated like `None` (unset).
    fn clean_limit(v: Option<usize>) -> Option<usize> {
        v.filter(|&n| n > 0)
    }

    /// Sets the limit of both outer stages (setup and apply).
    ///
    /// `None` or `Some(0)` clears both limits.
    pub fn set_outer_threads(&mut self, n: Option<usize>) {
        let limit = Self::clean_limit(n);
        self.setup_threads = limit;
        self.apply_threads = limit;
    }

    /// Sets the limit of both inner stages (setup and apply).
    ///
    /// `None` or `Some(0)` clears both limits.
    pub fn set_inner_threads(&mut self, n: Option<usize>) {
        let limit = Self::clean_limit(n);
        self.inner_setup_threads = limit;
        self.inner_apply_threads = limit;
    }

    /// Sets the limit of a single stage; `None` or `Some(0)` clears it.
    pub fn set_stage_threads(&mut self, stage: KspExecStage, n: Option<usize>) {
        let slot = match stage {
            KspExecStage::OuterSetup => &mut self.setup_threads,
            KspExecStage::OuterApply => &mut self.apply_threads,
            KspExecStage::InnerSetup => &mut self.inner_setup_threads,
            KspExecStage::InnerApply => &mut self.inner_apply_threads,
        };
        *slot = Self::clean_limit(n);
    }

    /// Resolves the thread count to use for `stage` (always at least 1).
    ///
    /// Outer stages use their own limit. Inner stages use their own limit,
    /// then the matching outer limit. When no limit applies, the result is
    /// `current_rayon_threads()`.
    ///
    /// The fields are public and can be assigned without going through the
    /// setters, so a stored `Some(0)` is normalised here as well: it counts as
    /// unset rather than pinning the stage to a single thread.
    pub fn effective_threads(&self, stage: KspExecStage) -> usize {
        let fallback = current_rayon_threads();
        let outer_setup = Self::clean_limit(self.setup_threads);
        let outer_apply = Self::clean_limit(self.apply_threads);
        let limit = match stage {
            KspExecStage::OuterSetup => outer_setup,
            KspExecStage::OuterApply => outer_apply,
            KspExecStage::InnerSetup => {
                Self::clean_limit(self.inner_setup_threads).or(outer_setup)
            }
            KspExecStage::InnerApply => {
                Self::clean_limit(self.inner_apply_threads).or(outer_apply)
            }
        };
        limit.unwrap_or(fallback).max(1)
    }

    /// Returns the resolved thread count of every stage, keyed by
    /// `KspExecStage::as_key`.
    pub fn diagnostics(&self) -> BTreeMap<String, usize> {
        let mut report = BTreeMap::new();
        for stage in [
            KspExecStage::OuterSetup,
            KspExecStage::OuterApply,
            KspExecStage::InnerSetup,
            KspExecStage::InnerApply,
        ] {
            report.insert(stage.as_key().to_string(), self.effective_threads(stage));
        }
        report
    }

    /// Runs `f` under the thread limit resolved for `stage` and returns its
    /// result.
    ///
    /// Without the `rayon` feature `f` is simply called. With it:
    ///
    /// * a limit of 1 calls `f` while a `serial_guard(true)` value is alive;
    /// * a limit at or above `current_rayon_threads()` calls `f` directly;
    /// * otherwise a dedicated pool with that many threads is built for this
    ///   call and `f` runs inside it. If the pool cannot be built, `f` is
    ///   called directly instead.
    pub fn install_stage<T>(&self, stage: KspExecStage, f: impl FnOnce() -> T + Send) -> T
    where
        T: Send,
    {
        let threads = self.effective_threads(stage);
        // The count is only consulted when Rayon is available.
        #[cfg(not(feature = "rayon"))]
        let _ = threads;
        #[cfg(feature = "rayon")]
        {
            if threads <= 1 {
                // NOTE(review): presumably forces serial execution for as long
                // as the guard lives — confirm in `parallel_cfg`.
                let _guard = crate::algebra::parallel_cfg::serial_guard(true);
                return f();
            }
            if threads >= current_rayon_threads() {
                return f();
            }
            // A new pool is created on every call that reaches this point.
            if let Ok(pool) = rayon::ThreadPoolBuilder::new().num_threads(threads).build() {
                return pool.install(f);
            }
        }
        f()
    }
}
/// Execution mode chosen by `suggest_thread_policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadExecFlavor {
/// Run on a single thread; chosen whenever the suggested thread count is 1.
Serial,
/// Run with Rayon; chosen whenever the suggested thread count exceeds 1.
Rayon,
}
/// Result of `suggest_thread_policy`: what to run with, and why.
#[derive(Debug, Clone, Copy)]
pub struct ThreadHeuristicDecision {
/// Whether to execute serially or with Rayon.
pub flavor: ThreadExecFlavor,
/// Suggested number of threads (1 for the serial flavor).
pub threads: usize,
/// Fixed, human-readable explanation of which rule produced the decision.
pub reason: &'static str,
}
/// Suggests a thread count and execution flavor for a local problem.
///
/// * `problem_size` — size of the local problem, compared against
///   `tune.min_len_vec`.
/// * `comm_size` — communicator size; when it is greater than 1, one thread
///   of the pool is held back.
/// * `tune` — tuning parameters; only `min_len_vec` is read here.
///
/// The decision is serial (1 thread) when `problem_size` is at most half of
/// `tune.min_len_vec`, or when `current_rayon_threads()` is 1 or less.
/// Otherwise the count is the smaller of the thread budget and
/// `problem_size / min_len_vec` (the latter raised to at least 2), and the
/// flavor is `Rayon` unless that count comes out as 1.
pub fn suggest_thread_policy(
    problem_size: usize,
    comm_size: usize,
    tune: ParallelTune,
) -> ThreadHeuristicDecision {
    let pool_threads = current_rayon_threads();
    let too_small = problem_size <= tune.min_len_vec.saturating_div(2);
    if too_small || pool_threads <= 1 {
        return ThreadHeuristicDecision {
            flavor: ThreadExecFlavor::Serial,
            threads: 1,
            reason: "small local problem or single worker",
        };
    }

    // With more than one rank, hold one thread back (but never drop below 1).
    let budget = if comm_size > 1 {
        pool_threads.saturating_sub(1).max(1)
    } else {
        pool_threads
    };
    // Whole multiples of the vector threshold that fit into the problem.
    let by_size = (problem_size / tune.min_len_vec.max(1)).max(1);
    let threads = budget.min(by_size.max(2)).max(1);
    let flavor = match threads {
        1 => ThreadExecFlavor::Serial,
        _ => ThreadExecFlavor::Rayon,
    };

    ThreadHeuristicDecision {
        flavor,
        threads,
        reason: "threshold-driven local throughput tuning",
    }
}
impl ThreadPoolGuard {
pub fn new_per_rank(mpi_size: usize) -> Self {
let threads = init_global_rayon_pool(mpi_size);
Self { threads }
}
pub fn threads(&self) -> usize {
self.threads
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::parallel_cfg::ParallelTune;

    /// A problem a quarter the size of the vector threshold must be handled
    /// serially on a single thread.
    #[test]
    fn thread_heuristic_prefers_serial_for_small_problem() {
        let tune = ParallelTune::default();
        let tiny_problem = tune.min_len_vec / 4;
        let decision = suggest_thread_policy(tiny_problem, 1, tune);
        assert_eq!(decision.flavor, ThreadExecFlavor::Serial);
        assert_eq!(decision.threads, 1);
    }
}