use std::sync::Arc;
use parking_lot::{Condvar, Mutex, RwLock};
use rayon::ThreadPool;
use super::config::WorkerPoolConfig;
/// A worker pool whose effective parallelism can be adjusted while it is running.
///
/// The rayon pool is built with `max_threads` threads. A semaphore, whose target starts at
/// `min_threads`, limits how many work items run concurrently in `process_batch` and
/// `map_owned`.
pub struct AdaptiveWorkerPool {
    /// Resolved configuration, shared behind a read/write lock.
    pub(crate) config: Arc<RwLock<WorkerPoolConfig>>,
    /// Thread pool sized to the configured maximum thread count.
    rayon_pool: ThreadPool,
    /// Throttle that limits concurrently executing work items to its current target.
    pub(crate) semaphore: Arc<Semaphore>,
    /// Optional memory guard, attached with `set_memory_guard`.
    #[cfg(feature = "memory")]
    pub(crate) memory_guard: parking_lot::Mutex<Option<Arc<crate::memory::MemoryGuard>>>,
    /// Optional scaling-pressure source, attached with `set_scaling_pressure`.
    #[cfg(feature = "scaling")]
    pub(crate) scaling_pressure: parking_lot::Mutex<Option<Arc<crate::scaling::ScalingPressure>>>,
}
/// Counting semaphore whose permit target can be raised or lowered while permits are held.
///
/// The state is protected by a mutex. Acquirers park on a condition variable until a
/// permit is free.
pub(crate) struct Semaphore {
    /// Current target and number of outstanding leases.
    state: Mutex<SemState>,
    /// Acquirers wait here while `leased >= target`.
    available: Condvar,
    /// Upper bound that every target is clamped to.
    max_permits: usize,
}
/// Mutable bookkeeping of a `Semaphore`, always accessed under its mutex.
struct SemState {
    /// Maximum number of leases currently allowed.
    target: usize,
    /// Number of leases currently held. This can exceed `target` right after a downscale.
    leased: usize,
}
impl Semaphore {
    /// Creates a semaphore with an initial target of `initial_target`.
    ///
    /// `max_permits` is floored at 1, and `initial_target` is clamped into
    /// `1..=max_permits`.
    fn new(initial_target: usize, max_permits: usize) -> Self {
        let cap = std::cmp::max(max_permits, 1);
        let state = SemState {
            target: initial_target.clamp(1, cap),
            leased: 0,
        };
        Self {
            state: Mutex::new(state),
            available: Condvar::new(),
            max_permits: cap,
        }
    }

    /// Blocks the calling thread until `leased < target`, then leases one permit.
    ///
    /// The permit is returned when the guard is dropped.
    fn acquire(&self) -> SemaphoreGuard<'_> {
        let mut guard = self.state.lock();
        // Re-check after every wake-up: another thread may have taken the freed permit
        // first, or the target may have been lowered in the meantime.
        while guard.target <= guard.leased {
            self.available.wait(&mut guard);
        }
        guard.leased += 1;
        SemaphoreGuard { semaphore: self }
    }

    /// Sets the permit target, clamped into `1..=max_permits`.
    ///
    /// Raising the target wakes every parked acquirer so the new headroom gets used.
    /// Lowering it never revokes leases already held; it only delays new acquisitions
    /// until enough guards are dropped.
    pub(crate) fn set_target(&self, target: usize) {
        let new_target = target.clamp(1, self.max_permits);
        let grew = {
            let mut guard = self.state.lock();
            let previous = std::mem::replace(&mut guard.target, new_target);
            new_target > previous
        };
        // Notify after the lock is released so woken threads do not immediately contend on it.
        if grew {
            self.available.notify_all();
        }
    }

    /// Snapshot of the current permit target.
    pub(crate) fn target(&self) -> usize {
        let guard = self.state.lock();
        guard.target
    }

    /// Snapshot of the number of permits currently leased.
    pub(crate) fn leased(&self) -> usize {
        let guard = self.state.lock();
        guard.leased
    }

    /// Snapshot of how many more permits could be leased right now.
    ///
    /// This is 0 while `leased` exceeds `target` after a downscale.
    pub(crate) fn available(&self) -> usize {
        let (target, leased) = {
            let guard = self.state.lock();
            (guard.target, guard.leased)
        };
        target.saturating_sub(leased)
    }
}
/// RAII lease on one [`Semaphore`] permit; the permit is returned when the guard is dropped.
///
/// Marked `#[must_use]` because discarding the guard (e.g. `let _ = sem.acquire();`)
/// releases the permit immediately and silently defeats the throttle.
#[must_use = "dropping the guard immediately releases the permit"]
struct SemaphoreGuard<'a> {
    /// The semaphore the permit was leased from.
    semaphore: &'a Semaphore,
}
impl Drop for SemaphoreGuard<'_> {
    /// Returns the leased permit to the semaphore.
    ///
    /// Wakes one parked acquirer only when that leaves `leased` below the current target.
    fn drop(&mut self) {
        let mut st = self.semaphore.state.lock();
        // A guard is only created after `acquire` incremented `leased`, so it is >= 1 here.
        debug_assert!(st.leased > 0, "SemaphoreGuard dropped with no outstanding lease");
        st.leased = st.leased.saturating_sub(1);
        // After a downscale `leased` may still be >= `target`. Waking a waiter then is
        // futile: it would re-check the condition and park again. Raising the target
        // notifies all waiters separately, so skipping here cannot strand anyone.
        let has_headroom = st.leased < st.target;
        drop(st);
        if has_headroom {
            self.semaphore.available.notify_one();
        }
    }
}
/// Per-call options for `AdaptiveWorkerPool::fan_out_async_with_policy`.
///
/// The `Default` value has no per-item timeout and no cancellation token.
#[derive(Debug, Clone, Default)]
pub struct FanOutPolicy {
    /// Maximum time a single item's future may run; `None` disables the timeout.
    pub per_item_timeout: Option<std::time::Duration>,
    /// When this token is cancelled, no further items are started.
    pub cancel: Option<tokio_util::sync::CancellationToken>,
}
/// Outcome of one item processed by `AdaptiveWorkerPool::fan_out_async_with_policy`.
#[derive(Debug)]
pub enum FanOutResult<R, E> {
    /// The item's future completed successfully.
    Ok(R),
    /// The item's future completed with an error.
    Err(E),
    /// The per-item timeout elapsed before the future completed.
    TimedOut,
    /// The spawned task failed to join (e.g. it panicked).
    Panicked,
    /// The item was never started because cancellation was requested first.
    Cancelled,
}
impl AdaptiveWorkerPool {
    /// Creates a pool from `config`, reporting an invalid configuration as an error.
    ///
    /// `max_threads` is resolved with `WorkerPoolConfig::resolve_max_threads` before
    /// `WorkerPoolConfig::validate` runs, so validation sees the resolved value.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `WorkerPoolConfig::validate`.
    ///
    /// # Panics
    ///
    /// Panics if the rayon thread pool cannot be built.
    pub fn try_new(config: WorkerPoolConfig) -> Result<Self, crate::config::ConfigError> {
        let mut resolved = config;
        resolved.resolve_max_threads();
        resolved.validate()?;
        Ok(Self::build(resolved))
    }

    /// Creates a pool from `config`.
    ///
    /// # Panics
    ///
    /// Panics if `config` fails validation (use [`Self::try_new`] to handle that case), or if
    /// the rayon thread pool cannot be built.
    #[must_use]
    pub fn new(config: WorkerPoolConfig) -> Self {
        Self::try_new(config).expect("invalid WorkerPoolConfig (use try_new to handle the error)")
    }

    /// Assembles the pool from an already resolved and validated configuration.
    ///
    /// The rayon pool always has `max_threads` threads. The semaphore target starts at
    /// `min_threads` (clamped by the semaphore into `1..=max_threads`), and that target is
    /// what limits concurrency in `process_batch` / `map_owned`.
    ///
    /// # Panics
    ///
    /// Panics if the rayon thread pool cannot be built.
    fn build(resolved: WorkerPoolConfig) -> Self {
        let max_threads = resolved.max_threads;
        let min_threads = resolved.min_threads;
        let rayon_pool = rayon::ThreadPoolBuilder::new()
            .num_threads(max_threads)
            .thread_name(|i| format!("worker-{i}"))
            .build()
            .expect("Failed to create rayon thread pool");
        let semaphore = Arc::new(Semaphore::new(min_threads, max_threads));
        Self {
            config: Arc::new(RwLock::new(resolved)),
            rayon_pool,
            semaphore,
            #[cfg(feature = "memory")]
            memory_guard: parking_lot::Mutex::new(None),
            #[cfg(feature = "scaling")]
            scaling_pressure: parking_lot::Mutex::new(None),
        }
    }

    /// Loads a configuration with `WorkerPoolConfig::from_cascade(key)` and builds a pool from it.
    ///
    /// # Errors
    ///
    /// Returns any error from loading the configuration or from [`Self::try_new`].
    pub fn from_cascade(key: &str) -> Result<Self, crate::config::ConfigError> {
        let config = WorkerPoolConfig::from_cascade(key)?;
        Self::try_new(config)
    }

    /// Applies `f` to every item in parallel on this pool and returns the results in input order.
    ///
    /// Each item leases a semaphore permit before `f` runs, so at most `target_threads()`
    /// items execute at once. The call blocks until every item has been processed.
    ///
    /// NOTE(review): if `f` itself runs rayon work inside this pool, a worker that steals an
    /// outer item while holding a permit will block in `acquire`. With a small target this
    /// can deadlock; confirm before nesting.
    pub fn process_batch<T, R, E, F>(&self, items: &[T], f: F) -> Vec<Result<R, E>>
    where
        T: Sync,
        R: Send,
        E: Send,
        F: Fn(&T) -> Result<R, E> + Sync,
    {
        let sem = &self.semaphore;
        self.rayon_pool.install(|| {
            use rayon::prelude::*;
            items
                .par_iter()
                .map(|item| {
                    let _permit = sem.acquire();
                    f(item)
                })
                .collect()
        })
    }

    /// Runs `f` over `items` as Tokio tasks, in consecutive chunks of `async_concurrency`.
    ///
    /// `async_concurrency` is treated as at least 1. All tasks of a chunk are spawned before
    /// any of them is awaited, and the next chunk starts only after the whole chunk has
    /// finished. The returned vector has one entry per item, in input order:
    /// - `Some(result)` for tasks that completed;
    /// - `None` for tasks whose join failed (e.g. a panic), which is also logged.
    ///
    /// If this future is dropped early, tasks that were already spawned keep running,
    /// because dropping a `JoinHandle` detaches its task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, since it uses `tokio::spawn`.
    pub async fn fan_out_async<T, R, E, F, Fut>(
        &self,
        items: &[T],
        f: F,
    ) -> Vec<Option<Result<R, E>>>
    where
        T: Sync + Send,
        R: Send + 'static,
        E: Send + 'static,
        F: Fn(&T) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
    {
        // Clamp to 1: a chunk size of 0 makes `slice::chunks` panic (the previous
        // `step_by(0)` did too, even for empty input). This matches
        // `fan_out_async_with_policy`. The read guard is a temporary dropped at the end of
        // this statement, so it is never held across an `.await`.
        let concurrency = self.config.read().async_concurrency.max(1);
        let mut results: Vec<Option<Result<R, E>>> = Vec::with_capacity(items.len());
        for chunk in items.chunks(concurrency) {
            // Spawn the whole chunk before awaiting anything so it runs concurrently.
            let handles: Vec<_> = chunk.iter().map(|item| tokio::spawn(f(item))).collect();
            for handle in handles {
                // Handles are awaited in spawn order, so `results.len()` is this item's index.
                let idx = results.len();
                let outcome = match handle.await {
                    Ok(result) => Some(result),
                    Err(join_err) => {
                        tracing::error!(error = %join_err, idx, "fan_out_async task panicked");
                        None
                    }
                };
                results.push(outcome);
            }
        }
        results
    }

    /// Runs `f` over `items` with a sliding window of Tokio tasks, applying `policy`.
    ///
    /// At most `async_concurrency` tasks (at least 1) are in flight at once, and a new item
    /// starts as soon as a slot frees up. The returned vector has one [`FanOutResult`] per
    /// item, in input order:
    /// - `Ok` / `Err`: the future's own result.
    /// - `TimedOut`: `policy.per_item_timeout` elapsed first.
    /// - `Panicked`: joining the task failed (a panic, or any other `JoinError`); also logged.
    /// - `Cancelled`: `policy.cancel` was triggered before the item was started.
    ///
    /// Cancellation only stops new items from starting. Tasks already in flight are not
    /// aborted and report their actual outcome. With the `metrics` feature enabled, the
    /// in-flight count, timeouts, panics and total batch duration are recorded.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (spawning on a `JoinSet` requires one).
    pub async fn fan_out_async_with_policy<T, R, E, F, Fut>(
        &self,
        items: &[T],
        policy: &FanOutPolicy,
        f: F,
    ) -> Vec<FanOutResult<R, E>>
    where
        T: Sync + Send,
        R: Send + 'static,
        E: Send + 'static,
        F: Fn(&T) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
    {
        let concurrency = self.config.read().async_concurrency.max(1);
        // Items that never get spawned (because of cancellation) keep this default.
        let mut results: Vec<FanOutResult<R, E>> =
            (0..items.len()).map(|_| FanOutResult::Cancelled).collect();
        #[cfg(feature = "metrics")]
        let started = std::time::Instant::now();
        let mut set: tokio::task::JoinSet<(usize, FanOutResult<R, E>)> =
            tokio::task::JoinSet::new();
        // A panicked task yields only a `JoinError`, so map task ids back to item indices.
        let mut id_to_idx: std::collections::HashMap<tokio::task::Id, usize> =
            std::collections::HashMap::new();
        let mut next = 0;
        let cancelled = || policy.cancel.as_ref().is_some_and(|c| c.is_cancelled());
        loop {
            // Refill the window up to `concurrency` unless cancellation was requested.
            while set.len() < concurrency && next < items.len() && !cancelled() {
                let fut = f(&items[next]);
                let timeout = policy.per_item_timeout;
                let idx = next;
                let handle = set.spawn(async move {
                    let outcome = match timeout {
                        Some(d) => match tokio::time::timeout(d, fut).await {
                            Ok(Ok(r)) => FanOutResult::Ok(r),
                            Ok(Err(e)) => FanOutResult::Err(e),
                            Err(_) => FanOutResult::TimedOut,
                        },
                        None => match fut.await {
                            Ok(r) => FanOutResult::Ok(r),
                            Err(e) => FanOutResult::Err(e),
                        },
                    };
                    (idx, outcome)
                });
                id_to_idx.insert(handle.id(), idx);
                next += 1;
            }
            #[cfg(feature = "metrics")]
            ::metrics::gauge!("dfe_fanout_inflight").set(set.len() as f64);
            // `None` means nothing is in flight and nothing more was spawned: we are done.
            let Some(joined) = set.join_next().await else {
                break;
            };
            match joined {
                Ok((idx, outcome)) => {
                    #[cfg(feature = "metrics")]
                    if matches!(outcome, FanOutResult::TimedOut) {
                        ::metrics::counter!("dfe_fanout_timeout_total").increment(1);
                    }
                    results[idx] = outcome;
                }
                Err(join_err) => {
                    if let Some(&idx) = id_to_idx.get(&join_err.id()) {
                        results[idx] = FanOutResult::Panicked;
                    }
                    #[cfg(feature = "metrics")]
                    ::metrics::counter!("dfe_fanout_panic_total").increment(1);
                    tracing::error!(error = %join_err, "fan_out_async_with_policy task panicked");
                }
            }
            if cancelled() && set.is_empty() {
                break;
            }
        }
        #[cfg(feature = "metrics")]
        {
            ::metrics::gauge!("dfe_fanout_inflight").set(0.0);
            ::metrics::histogram!("dfe_fanout_batch_duration_seconds")
                .record(started.elapsed().as_secs_f64());
        }
        results
    }

    /// Like [`Self::process_batch`], but consumes `items` and passes each one to `f` by value.
    ///
    /// Results are returned in input order. Each item leases a semaphore permit while `f`
    /// runs.
    pub fn map_owned<T, R, F>(&self, items: Vec<T>, f: F) -> Vec<R>
    where
        T: Send,
        R: Send,
        F: Fn(T) -> R + Sync,
    {
        let sem = &self.semaphore;
        self.rayon_pool.install(|| {
            use rayon::prelude::*;
            items
                .into_par_iter()
                .map(|item| {
                    let _permit = sem.acquire();
                    f(item)
                })
                .collect()
        })
    }

    /// Runs `f` inside this pool's rayon context (see `rayon::ThreadPool::install`).
    ///
    /// Work done here is not throttled by the semaphore.
    pub fn install<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
        self.rayon_pool.install(f)
    }

    /// Registers this pool's metrics with `manager`, based on the current configuration.
    ///
    /// The configuration read lock is held for the duration of the call.
    pub fn register_metrics(&self, manager: &crate::metrics::MetricsManager) {
        let config = self.config.read();
        super::metrics::register(manager, &config);
    }

    /// Spawns a `ScalingController` for this pool onto the Tokio runtime.
    ///
    /// `cancel` is passed to `ScalingController::run`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, since it uses `tokio::spawn`.
    pub fn start_scaling_loop(self: &Arc<Self>, cancel: tokio_util::sync::CancellationToken) {
        let controller = super::scaler::ScalingController::new(self.clone());
        tokio::spawn(controller.run(cancel));
    }

    /// Attaches a memory guard, replacing any previously attached one.
    #[cfg(feature = "memory")]
    pub fn set_memory_guard(&self, guard: Arc<crate::memory::MemoryGuard>) {
        *self.memory_guard.lock() = Some(guard);
    }

    /// Attaches a scaling-pressure source, replacing any previously attached one.
    #[cfg(feature = "scaling")]
    pub fn set_scaling_pressure(&self, pressure: Arc<crate::scaling::ScalingPressure>) {
        *self.scaling_pressure.lock() = Some(pressure);
    }

    /// Number of semaphore permits currently leased.
    ///
    /// This is the number of items executing in `process_batch` / `map_owned` right now.
    #[must_use]
    pub fn active_threads(&self) -> usize {
        self.semaphore.leased()
    }

    /// Current concurrency target of the semaphore.
    #[must_use]
    pub fn target_threads(&self) -> usize {
        self.semaphore.target()
    }

    /// Permits that could be leased right now without blocking (a snapshot).
    #[must_use]
    pub fn available_threads(&self) -> usize {
        self.semaphore.available()
    }

    /// `max_threads` from the current configuration.
    ///
    /// NOTE(review): the rayon pool size is fixed at build time. If the configuration is
    /// mutated later, this may no longer match the real thread count; confirm this is
    /// intended.
    #[must_use]
    pub fn max_threads(&self) -> usize {
        self.config.read().max_threads
    }
}
#[cfg(test)]
mod semaphore_tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use super::Semaphore;

    #[test]
    fn idle_reports_zero_leased() {
        let s = Semaphore::new(2, 8);
        assert_eq!(s.leased(), 0);
        assert_eq!(s.target(), 2);
        assert_eq!(s.available(), 2);
    }

    #[test]
    fn lease_and_drop_track_leased() {
        let s = Semaphore::new(4, 8);
        {
            let _g1 = s.acquire();
            let _g2 = s.acquire();
            assert_eq!(s.leased(), 2);
            assert_eq!(s.available(), 2);
        }
        assert_eq!(s.leased(), 0, "drops release leases");
        assert_eq!(s.available(), 4);
    }

    #[test]
    fn downscale_does_not_overshoot_on_drop() {
        let s = Semaphore::new(8, 8);
        let guards: Vec<_> = (0..8).map(|_| s.acquire()).collect();
        assert_eq!(s.leased(), 8);
        s.set_target(2);
        assert_eq!(s.target(), 2);
        assert_eq!(
            s.available(),
            0,
            "leased (8) exceeds target (2): no headroom"
        );
        drop(guards);
        assert_eq!(s.leased(), 0);
        assert_eq!(
            s.available(),
            2,
            "available equals target after drain, not max_permits"
        );
    }

    #[test]
    fn set_target_clamps_to_one_and_max() {
        let s = Semaphore::new(4, 8);
        s.set_target(0);
        assert_eq!(s.target(), 1, "target floored at 1 to avoid deadlock");
        s.set_target(100);
        assert_eq!(s.target(), 8, "target capped at max_permits");
    }

    #[test]
    fn contention_never_exceeds_target() {
        let s = Arc::new(Semaphore::new(2, 2));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let s = Arc::clone(&s);
                let max_seen = Arc::clone(&max_seen);
                std::thread::spawn(move || {
                    for _ in 0..50 {
                        let _g = s.acquire();
                        max_seen.fetch_max(s.leased(), Ordering::Relaxed);
                        std::thread::yield_now();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert!(
            max_seen.load(Ordering::Relaxed) <= 2,
            "leased never exceeded target=2"
        );
        assert_eq!(s.leased(), 0);
    }

    #[test]
    fn grow_target_wakes_parked_acquirer() {
        let s = Arc::new(Semaphore::new(1, 4));
        let held = s.acquire();
        assert_eq!(s.leased(), 1);
        let s2 = Arc::clone(&s);
        let handle = std::thread::spawn(move || {
            let _g = s2.acquire();
            s2.leased()
        });
        // Give the spawned thread time to park on the full semaphore.
        std::thread::sleep(std::time::Duration::from_millis(50));
        s.set_target(2);
        let observed = handle.join().unwrap();
        // `held` is still outstanding, so the woken acquirer must observe exactly two leases.
        // (A `>= 1` check would be vacuous: the acquirer's own lease already makes it 1.)
        assert_eq!(observed, 2, "parked acquirer proceeded after target grew");
        drop(held);
        assert_eq!(s.leased(), 0);
    }

    #[test]
    fn drain_below_shrunk_target_wakes_parked_acquirer() {
        let s = Arc::new(Semaphore::new(2, 2));
        let g1 = s.acquire();
        let g2 = s.acquire();
        s.set_target(1);
        let s2 = Arc::clone(&s);
        let handle = std::thread::spawn(move || {
            let _g = s2.acquire();
            s2.leased()
        });
        // Give the spawned thread time to park; the outcome is the same either way.
        std::thread::sleep(std::time::Duration::from_millis(50));
        // leased 2 -> 1 == target: still no headroom, the acquirer stays parked.
        drop(g1);
        // leased 1 -> 0 < target: the parked acquirer must be woken.
        drop(g2);
        assert_eq!(handle.join().unwrap(), 1, "acquirer ran once headroom appeared");
        assert_eq!(s.leased(), 0);
    }
}