use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
use crate::parallel::scheduler::{SchedulerConfigBuilder, WorkStealingScheduler};
use rayon::iter::ParallelIterator;
use rayon::prelude::*;
use std::cell::RefCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
thread_local! {
    // Current nesting depth of the executing thread (0 = top level).
    // NestedPolicy::Sequential temporarily sets this to usize::MAX.
    static NESTING_LEVEL: RefCell<usize> = const { RefCell::new(0) };
    // Context of the innermost active nested scope on this thread, if any.
    static PARENT_CONTEXT: RefCell<Option<Arc<NestedContext>>> = const { RefCell::new(None) };
}
// Lazily-initialized, process-wide resource manager shared by all contexts.
static GLOBAL_RESOURCE_MANAGER: std::sync::OnceLock<Arc<ResourceManager>> =
    std::sync::OnceLock::new();

/// Returns the global [`ResourceManager`], creating it on first use.
#[allow(dead_code)]
fn get_resource_manager() -> Arc<ResourceManager> {
    GLOBAL_RESOURCE_MANAGER
        .get_or_init(|| Arc::new(ResourceManager::new()))
        .clone()
}
#[derive(Debug, Clone)]
pub struct ResourceLimits {
pub max_total_threads: usize,
pub max_nesting_depth: usize,
pub threads_per_level: Vec<usize>,
pub max_memory_bytes: usize,
pub max_cpu_usage: f64,
pub enable_thread_pooling: bool,
pub enable_cross_level_stealing: bool,
}
impl Default for ResourceLimits {
    /// Builds limits scaled to the machine: up to `2 * num_cpus` total
    /// threads, three nesting levels, and a 4 GiB memory budget.
    fn default() -> Self {
        let num_cpus = num_cpus::get();
        Self {
            max_total_threads: num_cpus * 2,
            max_nesting_depth: 3,
            // Guarantee at least one thread at every configured level:
            // on a single-core machine `num_cpus / 2` would be 0, which
            // would starve every level-1 scope of threads.
            threads_per_level: vec![num_cpus, (num_cpus / 2).max(1), 1],
            // 4 GiB. NOTE(review): the `as usize` cast truncates on 32-bit
            // targets — confirm those are out of scope.
            max_memory_bytes: (4u64 * 1024 * 1024 * 1024) as usize,
            max_cpu_usage: 0.9,
            enable_thread_pooling: true,
            enable_cross_level_stealing: false,
        }
    }
}
/// Execution context for one nesting level of parallel work.
pub struct NestedContext {
    // Nesting depth of this context (0 = root).
    level: usize,
    // Parent context; `None` for the root.
    parent: Option<Arc<NestedContext>>,
    // Limits this context (and its children) operate under.
    limits: ResourceLimits,
    // Threads currently reserved by this context.
    active_threads: AtomicUsize,
    // Optional pre-installed scheduler; `get_scheduler` builds a fresh
    // one when this is `None`.
    scheduler: Option<Arc<Mutex<WorkStealingScheduler>>>,
}
impl NestedContext {
    /// Creates a root (level-0) context governed by `limits`.
    pub fn new(limits: ResourceLimits) -> Self {
        Self {
            level: 0,
            parent: None,
            limits,
            active_threads: AtomicUsize::new(0),
            scheduler: None,
        }
    }

    /// Derives a child context one nesting level deeper, inheriting this
    /// context's limits.
    ///
    /// # Errors
    /// Returns a `ConfigError` when `max_nesting_depth` would be exceeded.
    pub fn create_child(&self) -> CoreResult<Arc<NestedContext>> {
        if self.level >= self.limits.max_nesting_depth {
            return Err(CoreError::ConfigError(
                ErrorContext::new(format!(
                    "Maximum nesting depth {} exceeded",
                    self.limits.max_nesting_depth
                ))
                .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        let child = NestedContext {
            level: self.level + 1,
            // The child keeps a point-in-time clone of this context.
            parent: Some(Arc::new(self.clone())),
            limits: self.limits.clone(),
            active_threads: AtomicUsize::new(0),
            scheduler: None,
        };
        Ok(Arc::new(child))
    }

    /// Thread budget for this nesting level; levels past the end of the
    /// `threads_per_level` table fall back to a single thread.
    pub fn max_threads_at_level(&self) -> usize {
        self.limits
            .threads_per_level
            .get(self.level)
            .copied()
            .unwrap_or(1)
    }

    /// Tries to reserve up to `requested` threads, bounded by both the
    /// global pool and this level's budget. Returns the number granted.
    pub fn try_acquire_threads(&self, requested: usize) -> usize {
        let max_at_level = self.max_threads_at_level();
        let resource_manager = get_resource_manager();
        // Reserve from the global pool first, then clamp to this level.
        let available_global = resource_manager.try_acquire_threads(requested);
        let current = self.active_threads.load(Ordering::Relaxed);
        let available_at_level = max_at_level.saturating_sub(current);
        // `available_global <= requested`, so no extra `min(requested)`.
        let granted = available_global.min(available_at_level);
        if granted > 0 {
            self.active_threads.fetch_add(granted, Ordering::Relaxed);
        }
        // Return any surplus global reservation. The previous version only
        // released it when `granted == 0`, leaking global thread slots
        // whenever the per-level budget was the binding constraint.
        let surplus = available_global - granted;
        if surplus > 0 {
            resource_manager.release_threads(surplus);
        }
        granted
    }

    /// Releases `count` previously acquired threads back to this context
    /// and the global pool. Callers must not release more than they
    /// acquired, or the counters wrap.
    pub fn release_threads(&self, count: usize) {
        self.active_threads.fetch_sub(count, Ordering::Relaxed);
        get_resource_manager().release_threads(count);
    }

    /// Returns this context's scheduler, or builds a new work-stealing
    /// scheduler sized to this level's thread budget.
    ///
    /// NOTE(review): a freshly built scheduler is NOT cached back into
    /// `self.scheduler` (no interior mutability on the field), so repeated
    /// calls on a context without a pre-set scheduler construct a new one
    /// each time — confirm whether caching was intended.
    pub fn get_scheduler(&self) -> CoreResult<Arc<Mutex<WorkStealingScheduler>>> {
        if let Some(ref scheduler) = self.scheduler {
            return Ok(scheduler.clone());
        }
        let config = SchedulerConfigBuilder::new()
            .workers(self.max_threads_at_level())
            .adaptive(true)
            .enable_stealing_heuristics(true)
            .enable_priorities(true)
            .build();
        let scheduler = WorkStealingScheduler::new(config);
        Ok(Arc::new(Mutex::new(scheduler)))
    }
}
impl Clone for NestedContext {
    // Manual impl because `AtomicUsize` is not `Clone`; the clone takes a
    // point-in-time snapshot of the active-thread counter.
    fn clone(&self) -> Self {
        Self {
            level: self.level,
            parent: self.parent.clone(),
            limits: self.limits.clone(),
            active_threads: AtomicUsize::new(self.active_threads.load(Ordering::Relaxed)),
            scheduler: self.scheduler.clone(),
        }
    }
}
/// Process-wide tracker of thread, memory, and CPU usage.
pub struct ResourceManager {
    // Threads currently reserved across the whole process.
    total_threads: AtomicUsize,
    // Tracked memory usage in bytes.
    memory_used: AtomicUsize,
    // Last recorded CPU usage fraction. NOTE(review): never written in
    // this file — confirm where it is updated.
    cpu_usage: RwLock<f64>,
    // Active context count per nesting level. NOTE(review): never updated
    // in this file — confirm intended use.
    active_contexts: RwLock<Vec<usize>>,
}
impl Default for ResourceManager {
    /// Equivalent to [`ResourceManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ResourceManager {
    /// Creates a manager with zeroed usage counters and bookkeeping slots
    /// for up to ten nesting levels.
    pub fn new() -> Self {
        let max_levels = 10;
        Self {
            total_threads: AtomicUsize::new(0),
            memory_used: AtomicUsize::new(0),
            cpu_usage: RwLock::new(0.0),
            active_contexts: RwLock::new(vec![0; max_levels]),
        }
    }

    /// Attempts to reserve up to `requested` threads against the global
    /// budget of `2 * num_cpus`; returns how many were actually reserved.
    pub fn try_acquire_threads(&self, requested: usize) -> usize {
        // Hoisted out of the loop: the CPU count does not change.
        let max_threads = num_cpus::get() * 2;
        let mut acquired = 0;
        while acquired < requested {
            let current = self.total_threads.load(Ordering::Relaxed);
            if current >= max_threads {
                break;
            }
            // On CAS contention we retry the same reservation; the
            // previous version's `continue` silently consumed one of the
            // requested slots per lost race.
            if self
                .total_threads
                .compare_exchange(current, current + 1, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                acquired += 1;
            }
        }
        acquired
    }

    /// Returns `count` previously acquired threads to the global budget.
    pub fn release_threads(&self, count: usize) {
        self.total_threads.fetch_sub(count, Ordering::Release);
    }

    /// Adjusts the tracked memory usage by `bytes` (may be negative).
    pub fn update_memory_usage(&self, bytes: isize) {
        if bytes >= 0 {
            self.memory_used
                .fetch_add(bytes as usize, Ordering::Relaxed);
        } else {
            // `unsigned_abs` avoids the overflow `(-bytes)` would hit
            // for `isize::MIN`.
            self.memory_used
                .fetch_sub(bytes.unsigned_abs(), Ordering::Relaxed);
        }
    }

    /// Snapshot of current resource usage. The fields are read separately,
    /// so the snapshot is not atomic across fields.
    pub fn get_usage_stats(&self) -> ResourceUsageStats {
        ResourceUsageStats {
            total_threads: self.total_threads.load(Ordering::Relaxed),
            memory_bytes: self.memory_used.load(Ordering::Relaxed),
            cpu_usage: *self.cpu_usage.read().expect("cpu_usage lock poisoned"),
            active_contexts_per_level: self
                .active_contexts
                .read()
                .expect("active_contexts lock poisoned")
                .clone(),
        }
    }
}
/// Point-in-time snapshot of global resource usage.
#[derive(Debug, Clone)]
pub struct ResourceUsageStats {
    /// Threads currently reserved.
    pub total_threads: usize,
    /// Tracked memory usage in bytes.
    pub memory_bytes: usize,
    /// Last recorded CPU usage fraction.
    pub cpu_usage: f64,
    /// Active context count per nesting level.
    pub active_contexts_per_level: Vec<usize>,
}
/// RAII handle for one nested parallel scope; releases its reserved
/// threads when dropped.
pub struct NestedScope<'a> {
    // Context this scope executes under.
    context: Arc<NestedContext>,
    // Threads reserved at construction; returned on drop.
    acquired_threads: usize,
    // Ties the scope to the caller's lifetime without storing a reference.
    phantom: std::marker::PhantomData<&'a ()>,
}
impl NestedScope<'_> {
    /// Runs `f` with this scope's context installed in the thread-locals,
    /// restoring the previous thread-local state afterwards.
    pub fn execute<F, R>(&self, f: F) -> CoreResult<R>
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        // Save the caller's thread-local state so it can be restored.
        // (The previous version reset PARENT_CONTEXT to `None` instead of
        // the prior value and never restored NESTING_LEVEL, leaking this
        // scope's depth into the caller.)
        let old_context =
            PARENT_CONTEXT.with(|ctx| ctx.borrow_mut().replace(self.context.clone()));
        let old_level = NESTING_LEVEL
            .with(|level| std::mem::replace(&mut *level.borrow_mut(), self.context.level));
        let result = f();
        NESTING_LEVEL.with(|level| {
            *level.borrow_mut() = old_level;
        });
        PARENT_CONTEXT.with(|ctx| {
            *ctx.borrow_mut() = old_context;
        });
        Ok(result)
    }

    /// Maps `f` over `items` in parallel via rayon and collects the
    /// results in order.
    pub fn par_iter<I, F, R>(&self, items: I, f: F) -> CoreResult<Vec<R>>
    where
        I: IntoParallelIterator,
        I::Item: Send,
        F: Fn(I::Item) -> R + Send + Sync,
        R: Send,
    {
        let results: Vec<R> = items.into_par_iter().map(f).collect();
        Ok(results)
    }
}
impl Drop for NestedScope<'_> {
    /// Returns any threads this scope reserved back to its context.
    fn drop(&mut self) {
        let reserved = self.acquired_threads;
        if reserved == 0 {
            return;
        }
        self.context.release_threads(reserved);
    }
}
/// Runs `f` inside a nested parallel scope using the default
/// [`ResourceLimits`].
#[allow(dead_code)]
pub fn nested_scope<F, R>(f: F) -> CoreResult<R>
where
    F: FnOnce(&NestedScope) -> CoreResult<R>,
{
    let limits = ResourceLimits::default();
    nested_scope_with_limits(limits, f)
}
#[allow(dead_code)]
pub fn nested_scope_with_limits<F, R>(limits: ResourceLimits, f: F) -> CoreResult<R>
where
F: FnOnce(&NestedScope) -> CoreResult<R>,
{
let context = match PARENT_CONTEXT
.with(|ctx| ctx.borrow().as_ref().map(|parent| parent.create_child()))
{
Some(child_result) => child_result?,
None => {
Arc::new(NestedContext::new(limits.clone()))
}
};
let requested_threads = context.max_threads_at_level();
let acquired_threads = context.try_acquire_threads(requested_threads);
let scope = NestedScope {
context: context.clone(),
acquired_threads,
phantom: std::marker::PhantomData,
};
let old_level = NESTING_LEVEL.with(|level| {
let old = *level.borrow();
*level.borrow_mut() = context.level;
old
});
let old_context = PARENT_CONTEXT.with(|ctx| ctx.borrow_mut().replace(context));
let result = f(&scope);
NESTING_LEVEL.with(|level| {
*level.borrow_mut() = old_level;
});
PARENT_CONTEXT.with(|ctx| {
*ctx.borrow_mut() = old_context;
});
result
}
/// Returns the current thread's nesting depth (0 when no scope is active).
#[allow(dead_code)]
pub fn current_nesting_level() -> usize {
    NESTING_LEVEL.with(|depth| depth.borrow().clone())
}
#[allow(dead_code)]
pub fn is_nested_parallelism_allowed() -> bool {
PARENT_CONTEXT.with(|ctx| {
if let Some(ref context) = *ctx.borrow() {
context.level < context.limits.max_nesting_depth
} else {
true }
})
}
/// Applies `f` to every element, in parallel when nested parallelism is
/// currently allowed and sequentially otherwise.
#[allow(dead_code)]
pub fn adaptive_par_for_each<T, F>(data: Vec<T>, f: F) -> CoreResult<()>
where
    T: Send,
    F: Fn(T) + Send + Sync,
{
    match is_nested_parallelism_allowed() {
        true => data.into_par_iter().for_each(f),
        false => data.into_iter().for_each(f),
    }
    Ok(())
}
/// Maps `f` over the elements, in parallel when nested parallelism is
/// currently allowed and sequentially otherwise. Output order matches
/// input order either way.
#[allow(dead_code)]
pub fn adaptive_par_map<T, F, R>(data: Vec<T>, f: F) -> CoreResult<Vec<R>>
where
    T: Send,
    F: Fn(T) -> R + Send + Sync,
    R: Send,
{
    let mapped = match is_nested_parallelism_allowed() {
        true => data.into_par_iter().map(f).collect(),
        false => data.into_iter().map(f).collect(),
    };
    Ok(mapped)
}
/// Policy controlling how nested parallelism requests are handled
/// (see [`with_nested_policy`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NestedPolicy {
    /// Run the work inside a managed nested scope with resource limits.
    Allow,
    /// Mark the thread so nested parallelism is treated as disallowed
    /// while the policy is active.
    Sequential,
    /// Run the closure as-is, reusing whatever context is already active.
    Delegate,
    /// Return an error if called from within an existing nested scope.
    Deny,
}
/// Configuration bundle consumed by [`with_nested_policy`].
#[derive(Debug, Clone)]
pub struct NestedConfig {
    /// How nested parallelism requests are handled.
    pub policy: NestedPolicy,
    /// Resource limits applied when `policy` creates a nested scope.
    pub limits: ResourceLimits,
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub track_usage: bool,
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub adaptive_scheduling: bool,
}
impl Default for NestedConfig {
    /// Permissive defaults: allow nesting under the default resource
    /// limits, with usage tracking and adaptive scheduling enabled.
    fn default() -> Self {
        Self {
            policy: NestedPolicy::Allow,
            limits: ResourceLimits::default(),
            track_usage: true,
            adaptive_scheduling: true,
        }
    }
}
/// Runs `f` under the given nested-parallelism policy.
///
/// # Errors
/// `NestedPolicy::Deny` returns a `ConfigError` when called from inside an
/// active nested scope; otherwise errors from `f` are propagated.
#[allow(dead_code)]
pub fn with_nested_policy<F, R>(config: NestedConfig, f: F) -> CoreResult<R>
where
    F: FnOnce() -> CoreResult<R>,
{
    match config.policy {
        NestedPolicy::Allow => nested_scope_with_limits(config.limits, |_scope| f()),
        NestedPolicy::Sequential => {
            // Install the "no nested parallelism" sentinel, remembering
            // the caller's level. (The previous version reset the level to
            // 0 afterwards, clobbering any enclosing scope's depth.)
            let old_level = NESTING_LEVEL
                .with(|level| std::mem::replace(&mut *level.borrow_mut(), usize::MAX));
            let result = f();
            NESTING_LEVEL.with(|level| {
                *level.borrow_mut() = old_level;
            });
            result
        }
        // Delegate: run as-is under whatever context is already active.
        NestedPolicy::Delegate => f(),
        NestedPolicy::Deny => {
            let is_nested = PARENT_CONTEXT.with(|ctx| ctx.borrow().is_some());
            if is_nested {
                Err(CoreError::ConfigError(
                    ErrorContext::new("Nested parallelism not allowed".to_string())
                        .with_location(ErrorLocation::new(file!(), line!())),
                ))
            } else {
                f()
            }
        }
    }
}
/// Returns the scheduler of the currently active context, if any.
#[allow(dead_code)]
fn get_parent_scheduler() -> Option<Arc<Mutex<WorkStealingScheduler>>> {
    PARENT_CONTEXT.with(|ctx| {
        let guard = ctx.borrow();
        let context = guard.as_ref()?;
        context.scheduler.clone()
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // A scope's `par_iter` should map every element and preserve order.
    #[test]
    fn test_basic_nested_execution() {
        let result = nested_scope(|scope| {
            let data: Vec<i32> = (0..100).collect();
            scope.par_iter(data, |x| x * 2)
        })
        .expect("Operation failed");
        assert_eq!(result.len(), 100);
        assert_eq!(result[0], 0);
        assert_eq!(result[50], 100);
    }

    // Each nested `nested_scope` call should report one level deeper.
    #[test]
    fn test_nesting_levels() {
        nested_scope(|outer_scope| {
            assert_eq!(current_nesting_level(), 0);
            outer_scope.execute(|| {
                nested_scope(|inner_scope| {
                    assert_eq!(current_nesting_level(), 1);
                    inner_scope.execute(|| {
                        nested_scope(|_deepest_scope| {
                            assert_eq!(current_nesting_level(), 2);
                            Ok(())
                        })
                        .expect("Operation failed")
                    })
                })
                .expect("Operation failed")
            })
        })
        .expect("Operation failed");
    }

    // Custom limits should cap the per-level thread budget.
    #[test]
    fn test_resource_limits() {
        let limits = ResourceLimits {
            max_total_threads: 4,
            max_nesting_depth: 2,
            threads_per_level: vec![2, 1],
            ..Default::default()
        };
        let result = nested_scope_with_limits(limits, |scope| {
            let context = &scope.context;
            assert!(context.max_threads_at_level() <= 2);
            Ok(42)
        });
        assert_eq!(result.expect("Operation failed"), 42);
    }

    // The Sequential policy should still produce correct results.
    #[test]
    fn test_sequential_policy() {
        let config = NestedConfig {
            policy: NestedPolicy::Sequential,
            ..Default::default()
        };
        let result = with_nested_policy(config, || {
            let data: Vec<i32> = (0..10).collect();
            let sum: i32 = data.into_par_iter().sum();
            Ok(sum)
        });
        assert_eq!(result.expect("Operation failed"), 45);
    }

    // Deny should succeed at top level but error inside an active scope.
    #[test]
    fn test_deny_policy() {
        let config = NestedConfig {
            policy: NestedPolicy::Deny,
            ..Default::default()
        };
        let result = with_nested_policy(config.clone(), || Ok(1));
        assert!(result.is_ok());
        let result = nested_scope(|_scope| {
            with_nested_policy(config, || Ok(2))
        });
        assert!(result.is_err());
    }
}