/// Thresholds that decide whether work is worth splitting across threads and
/// how large each thread's share should be.
///
/// The type is a plain value: it stores thresholds and answers sizing
/// questions from the `num_threads` the caller passes in.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Smallest `n` for which [`ParallelConfig::should_parallelize_fft`]
    /// returns `true`.
    pub min_fft_size: usize,
    /// Lower bound applied by [`ParallelConfig::batch_chunk_size`], and the
    /// per-thread amount required by
    /// [`ParallelConfig::should_parallelize_batch`].
    pub min_batch_chunk: usize,
    /// Lower bound applied by [`ParallelConfig::rows_per_thread`], and the
    /// per-thread amount required by
    /// [`ParallelConfig::should_parallelize_rows`].
    pub min_rows_per_thread: usize,
    /// Master switch. When `false`, every `should_parallelize_*` query returns
    /// `false` and the sizing methods return their input unchanged.
    pub enabled: bool,
}

impl Default for ParallelConfig {
    /// Parallelism enabled, `min_fft_size = 4096`, `min_batch_chunk = 4`,
    /// `min_rows_per_thread = 4`.
    fn default() -> Self {
        Self {
            min_fft_size: 4096,
            min_batch_chunk: 4,
            min_rows_per_thread: 4,
            enabled: true,
        }
    }
}

impl ParallelConfig {
    /// Creates a configuration with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `min_fft_size` replaced by `size`.
    #[must_use]
    pub fn with_min_fft_size(mut self, size: usize) -> Self {
        self.min_fft_size = size;
        self
    }

    /// Returns `self` with `min_batch_chunk` replaced by `chunk`.
    #[must_use]
    pub fn with_min_batch_chunk(mut self, chunk: usize) -> Self {
        self.min_batch_chunk = chunk;
        self
    }

    /// Returns `self` with `min_rows_per_thread` replaced by `rows`.
    #[must_use]
    pub fn with_min_rows_per_thread(mut self, rows: usize) -> Self {
        self.min_rows_per_thread = rows;
        self
    }

    /// Returns `self` with the `enabled` switch set to `enabled`.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Creates a configuration with default thresholds but parallelism
    /// switched off.
    #[must_use]
    pub fn serial() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Returns `true` when parallelism is enabled and `n` is at least
    /// `min_fft_size`.
    #[inline]
    #[must_use]
    pub fn should_parallelize_fft(&self, n: usize) -> bool {
        self.enabled && n >= self.min_fft_size
    }

    /// Returns the chunk size to use when splitting `batch_size` items over
    /// `num_threads` threads.
    ///
    /// When parallelism is disabled or `num_threads <= 1` the whole batch is
    /// returned as a single chunk. Otherwise the result is
    /// `ceil(batch_size / num_threads)`, raised to at least `min_batch_chunk`
    /// and then capped at `batch_size` (so an empty batch yields `0`).
    #[inline]
    #[must_use]
    pub fn batch_chunk_size(&self, batch_size: usize, num_threads: usize) -> usize {
        self.per_thread_share(batch_size, num_threads, self.min_batch_chunk)
    }

    /// Returns how many rows each thread should process when splitting
    /// `total_rows` over `num_threads` threads.
    ///
    /// When parallelism is disabled or `num_threads <= 1` all rows are
    /// returned as a single share. Otherwise the result is
    /// `ceil(total_rows / num_threads)`, raised to at least
    /// `min_rows_per_thread` and then capped at `total_rows`.
    #[inline]
    #[must_use]
    pub fn rows_per_thread(&self, total_rows: usize, num_threads: usize) -> usize {
        self.per_thread_share(total_rows, num_threads, self.min_rows_per_thread)
    }

    /// Returns `true` when parallelism is enabled, more than one thread is
    /// available, and `batch_size` gives every thread at least
    /// `min_batch_chunk` items.
    #[inline]
    #[must_use]
    pub fn should_parallelize_batch(&self, batch_size: usize, num_threads: usize) -> bool {
        self.feeds_every_thread(batch_size, num_threads, self.min_batch_chunk)
    }

    /// Returns `true` when parallelism is enabled, more than one thread is
    /// available, and `total_rows` gives every thread at least
    /// `min_rows_per_thread` rows.
    #[inline]
    #[must_use]
    pub fn should_parallelize_rows(&self, total_rows: usize, num_threads: usize) -> bool {
        self.feeds_every_thread(total_rows, num_threads, self.min_rows_per_thread)
    }

    /// Shared sizing rule: an even (rounded-up) split of `total`, clamped to
    /// the range `min_share..=total` with the upper cap taking precedence.
    #[inline]
    fn per_thread_share(&self, total: usize, num_threads: usize, min_share: usize) -> usize {
        // The guard also keeps `num_threads == 0` away from the division below.
        if !self.enabled || num_threads <= 1 {
            return total;
        }
        // Ceiling division written without the `total + num_threads - 1`
        // intermediate, which overflows for inputs near `usize::MAX`.
        let even_share = total / num_threads + usize::from(total % num_threads != 0);
        even_share.max(min_share).min(total)
    }

    /// Shared threshold rule: is `total >= num_threads * min_share`?
    #[inline]
    fn feeds_every_thread(&self, total: usize, num_threads: usize, min_share: usize) -> bool {
        // For integers, `total >= num_threads * min_share` is equivalent to
        // `total / num_threads >= min_share`; the division form cannot
        // overflow. `num_threads > 1` is checked first, so it never divides
        // by zero.
        self.enabled && num_threads > 1 && total / num_threads >= min_share
    }
}
#[cfg(feature = "std")]
use std::sync::OnceLock;
/// Process-wide [`ParallelConfig`] cell. A `OnceLock` can be written at most
/// once; it starts out empty.
#[cfg(feature = "std")]
static GLOBAL_CONFIG: OnceLock<ParallelConfig> = OnceLock::new();
/// Returns the process-wide [`ParallelConfig`].
///
/// If nothing has been stored in `GLOBAL_CONFIG` yet, this call initializes
/// it with the default configuration, after which the stored value can no
/// longer be replaced.
#[cfg(feature = "std")]
#[must_use]
pub fn global_parallel_config() -> &'static ParallelConfig {
    // `ParallelConfig::new` simply forwards to `ParallelConfig::default`.
    GLOBAL_CONFIG.get_or_init(ParallelConfig::new)
}
/// Stores `config` as the process-wide [`ParallelConfig`].
///
/// # Errors
///
/// Returns `Err(config)`, handing the rejected value back, when
/// `GLOBAL_CONFIG` already holds a value — either from an earlier successful
/// call to this function or because [`global_parallel_config`] has already
/// initialized it with the defaults.
#[cfg(feature = "std")]
pub fn set_global_parallel_config(config: ParallelConfig) -> Result<(), ParallelConfig> {
    GLOBAL_CONFIG.set(config)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Thread count shared by the multi-threaded cases below.
    const THREADS: usize = 8;

    /// `Default` enables parallelism and uses the 4096 / 4 / 4 thresholds.
    #[test]
    fn test_defaults() {
        let defaults = ParallelConfig::default();
        assert!(defaults.enabled);
        assert_eq!(defaults.min_fft_size, 4096);
        assert_eq!(defaults.min_batch_chunk, 4);
        assert_eq!(defaults.min_rows_per_thread, 4);
    }

    /// The FFT threshold is inclusive: `n == min_fft_size` parallelizes.
    #[test]
    fn test_should_parallelize_fft() {
        let config = ParallelConfig::new().with_min_fft_size(4096);
        for (n, expected) in [(1024, false), (4095, false), (4096, true), (65536, true)] {
            assert_eq!(config.should_parallelize_fft(n), expected, "n = {n}");
        }
    }

    /// A serial configuration rejects every query and returns sizes unchanged.
    #[test]
    fn test_disabled_never_parallelizes() {
        let serial = ParallelConfig::serial();
        assert!(!serial.should_parallelize_fft(1_000_000));
        assert!(!serial.should_parallelize_batch(1000, THREADS));
        assert!(!serial.should_parallelize_rows(1000, THREADS));
        assert_eq!(serial.batch_chunk_size(1000, THREADS), 1000);
        assert_eq!(serial.rows_per_thread(1000, THREADS), 1000);
    }

    /// Chunk size is the rounded-up even split, clamped by the minimum chunk
    /// and by the batch size itself.
    #[test]
    fn test_batch_chunk_size() {
        let config = ParallelConfig::new().with_min_batch_chunk(4);
        // (batch_size, num_threads, expected chunk)
        let cases = [(100, 8, 13), (8, 8, 4), (2, 8, 2), (100, 1, 100)];
        for (batch_size, num_threads, expected) in cases {
            assert_eq!(
                config.batch_chunk_size(batch_size, num_threads),
                expected,
                "batch_size = {batch_size}, num_threads = {num_threads}"
            );
        }
    }

    /// Rows per thread follow the same split-then-clamp rule.
    #[test]
    fn test_rows_per_thread() {
        let config = ParallelConfig::new().with_min_rows_per_thread(4);
        for (total_rows, expected) in [(64, 8), (8, 4)] {
            assert_eq!(
                config.rows_per_thread(total_rows, THREADS),
                expected,
                "total_rows = {total_rows}"
            );
        }
    }

    /// A batch parallelizes once it reaches `num_threads * min_batch_chunk`.
    #[test]
    fn test_should_parallelize_batch() {
        let config = ParallelConfig::new().with_min_batch_chunk(4);
        for (batch_size, expected) in [(32, true), (31, false)] {
            assert_eq!(
                config.should_parallelize_batch(batch_size, THREADS),
                expected,
                "batch_size = {batch_size}"
            );
        }
    }

    /// Rows parallelize once they reach `num_threads * min_rows_per_thread`.
    #[test]
    fn test_should_parallelize_rows() {
        let config = ParallelConfig::new().with_min_rows_per_thread(4);
        for (total_rows, expected) in [(32, true), (31, false)] {
            assert_eq!(
                config.should_parallelize_rows(total_rows, THREADS),
                expected,
                "total_rows = {total_rows}"
            );
        }
    }

    /// Every builder method stores its argument and the calls can be chained.
    #[test]
    fn test_builder_chain() {
        let built = ParallelConfig::new()
            .with_min_fft_size(8192)
            .with_min_batch_chunk(8)
            .with_min_rows_per_thread(16)
            .with_enabled(true);
        assert_eq!(built.min_fft_size, 8192);
        assert_eq!(built.min_batch_chunk, 8);
        assert_eq!(built.min_rows_per_thread, 16);
        assert!(built.enabled);
    }

    /// The global configuration is reachable and has parallelism enabled.
    #[cfg(feature = "std")]
    #[test]
    fn test_global_config() {
        let global = global_parallel_config();
        assert!(global.enabled);
    }
}