use core::sync::atomic::{AtomicUsize, Ordering};
/// Counts in-flight decodes and decides, per decode, whether it may use
/// parallelism.
///
/// Each call to `acquire` increments the counter and returns a guard; the
/// guard decrements it again when dropped.
pub struct DecodePool {
    /// Number of guards currently outstanding.
    active: AtomicUsize,
    /// How many concurrently outstanding guards are handed `num_threads == 0`
    /// before later ones are handed `num_threads == 1`.
    parallel_threshold: usize,
}
impl DecodePool {
    /// Creates a pool with no active decodes and a parallel threshold of 4.
    ///
    /// This is a `const fn`, so a pool shared by many decoders can be stored
    /// in a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            active: AtomicUsize::new(0),
            parallel_threshold: 4,
        }
    }

    /// Sets how many concurrently outstanding guards receive
    /// `num_threads == 0`.
    ///
    /// Once `threshold` guards are outstanding, later calls to `acquire`
    /// return guards with `num_threads == 1`. A threshold of `0` makes every
    /// guard report `1`.
    #[must_use]
    pub fn parallel_threshold(mut self, threshold: usize) -> Self {
        self.parallel_threshold = threshold;
        self
    }

    /// Returns the number of guards currently outstanding.
    ///
    /// The value is a relaxed snapshot. Concurrent `acquire` calls or guard
    /// drops may change it as soon as it has been read.
    #[must_use]
    pub fn active_count(&self) -> usize {
        self.active.load(Ordering::Relaxed)
    }

    /// Registers one more in-flight decode and returns a guard that
    /// unregisters it when dropped.
    ///
    /// The guard's `num_threads` depends on how many other guards were
    /// outstanding at the time of the call:
    /// - `0` if fewer than `parallel_threshold`;
    /// - `1` otherwise.
    #[must_use = "the pool slot is released as soon as the guard is dropped"]
    pub(super) fn acquire(&self) -> PoolGuard<'_> {
        // Relaxed is enough: the counter is only a load heuristic and does not
        // publish or protect any other memory.
        let prev = self.active.fetch_add(1, Ordering::Relaxed);
        let num_threads = if prev < self.parallel_threshold { 0 } else { 1 };
        PoolGuard {
            pool: self,
            num_threads,
        }
    }
}
impl Default for DecodePool {
fn default() -> Self {
Self::new()
}
}
impl core::fmt::Debug for DecodePool {
    /// Shows a relaxed snapshot of the active count together with the
    /// configured threshold.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let active_now = self.active.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("DecodePool");
        builder.field("active", &active_now);
        builder.field("parallel_threshold", &self.parallel_threshold);
        builder.finish()
    }
}
/// Guard returned by `DecodePool::acquire`.
///
/// While it is alive it counts as one active decode. Dropping it decrements
/// the pool's active count.
#[must_use = "dropping the guard immediately releases its pool slot"]
pub(super) struct PoolGuard<'a> {
    /// Pool whose active counter this guard decrements on drop.
    pool: &'a DecodePool,
    /// `0` if the pool was below its parallel threshold when this guard was
    /// acquired, `1` otherwise.
    ///
    /// NOTE(review): presumably `0` lets the decoder choose its default thread
    /// count and `1` forces sequential decoding. Confirm against the decoder.
    pub(super) num_threads: usize,
}
impl Drop for PoolGuard<'_> {
    /// Gives this guard's slot back by decrementing the pool's active counter.
    fn drop(&mut self) {
        let counter = &self.pool.active;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_threshold() {
        let pool = DecodePool::new();
        assert_eq!(pool.parallel_threshold, 4);
        assert_eq!(pool.active_count(), 0);
    }

    /// `Default` must produce the same configuration as `new`.
    #[test]
    fn default_matches_new() {
        let pool = DecodePool::default();
        assert_eq!(pool.parallel_threshold, 4);
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn acquire_increments_and_drop_decrements() {
        let pool = DecodePool::new();
        assert_eq!(pool.active_count(), 0);
        let g1 = pool.acquire();
        assert_eq!(pool.active_count(), 1);
        assert_eq!(g1.num_threads, 0);
        let g2 = pool.acquire();
        assert_eq!(pool.active_count(), 2);
        assert_eq!(g2.num_threads, 0);
        drop(g1);
        assert_eq!(pool.active_count(), 1);
        drop(g2);
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn threshold_triggers_sequential() {
        let pool = DecodePool::new().parallel_threshold(2);
        let g1 = pool.acquire();
        assert_eq!(g1.num_threads, 0);
        let g2 = pool.acquire();
        assert_eq!(g2.num_threads, 0);
        let g3 = pool.acquire();
        assert_eq!(g3.num_threads, 1);
        let g4 = pool.acquire();
        assert_eq!(g4.num_threads, 1);
        drop(g3);
        drop(g4);
        drop(g1);
        // Only g2 is outstanding, so the next guard is back under the threshold.
        let g5 = pool.acquire();
        assert_eq!(g5.num_threads, 0);
        drop(g2);
        drop(g5);
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn custom_threshold() {
        let pool = DecodePool::new().parallel_threshold(1);
        let g1 = pool.acquire();
        assert_eq!(g1.num_threads, 0);
        let g2 = pool.acquire();
        assert_eq!(g2.num_threads, 1);
        drop(g1);
        drop(g2);
    }

    /// With a threshold of zero no guard is ever under the threshold.
    #[test]
    fn zero_threshold_is_always_sequential() {
        let pool = DecodePool::new().parallel_threshold(0);
        let g1 = pool.acquire();
        assert_eq!(g1.num_threads, 1);
        let g2 = pool.acquire();
        assert_eq!(g2.num_threads, 1);
        drop(g1);
        drop(g2);
        assert_eq!(pool.active_count(), 0);
    }
}