use std::num::NonZeroUsize;
/// Tuning knobs for splitting work across threads.
///
/// Values are assembled with the chained `with_*` setters; the only logic in
/// this type is [`ParallelConfig::optimal_chunk_size`].
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Lower bound applied to the result of
    /// [`ParallelConfig::optimal_chunk_size`] (default `1000`).
    pub min_chunk_size: usize,
    /// Explicit thread count used when sizing chunks. `None` (the default)
    /// means "use the parallelism detected on this machine".
    pub max_threads: Option<NonZeroUsize>,
    /// Work-stealing flag (default `true`). Only stored here; the consumer of
    /// the config decides what it means.
    pub work_stealing: bool,
    /// Selected load-balancing strategy (default [`LoadBalancing::Dynamic`]).
    pub load_balancing: LoadBalancing,
}

/// Load-balancing strategy selector.
///
/// This file only stores and forwards the chosen variant; the scheduling
/// behaviour behind each variant is defined by whatever consumes the config.
#[derive(Debug, Clone)]
pub enum LoadBalancing {
    /// Static strategy.
    Static,
    /// Dynamic strategy.
    Dynamic,
    /// Adaptive strategy.
    Adaptive,
}

impl Default for ParallelConfig {
    /// Defaults: `min_chunk_size = 1000`, no thread limit, work stealing
    /// enabled, [`LoadBalancing::Dynamic`].
    fn default() -> Self {
        Self {
            min_chunk_size: 1000,
            max_threads: None,
            work_stealing: true,
            load_balancing: LoadBalancing::Dynamic,
        }
    }
}

impl ParallelConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum chunk size returned by
    /// [`ParallelConfig::optimal_chunk_size`].
    pub fn with_min_chunk_size(mut self, size: usize) -> Self {
        self.min_chunk_size = size;
        self
    }

    /// Sets the thread count used for chunk sizing.
    ///
    /// Passing `0` clears the limit (stored as `None`), because a thread
    /// count of zero is not representable as `NonZeroUsize`.
    pub fn with_max_threads(mut self, threads: usize) -> Self {
        self.max_threads = NonZeroUsize::new(threads);
        self
    }

    /// Enables or disables the work-stealing flag.
    pub fn with_work_stealing(mut self, enabled: bool) -> Self {
        self.work_stealing = enabled;
        self
    }

    /// Selects the load-balancing strategy.
    pub fn with_load_balancing(mut self, strategy: LoadBalancing) -> Self {
        self.load_balancing = strategy;
        self
    }

    /// Returns the chunk size for `data_size` items.
    ///
    /// The result is `ceil(data_size / thread_count)`, raised to
    /// `min_chunk_size` when it would be smaller. `thread_count` is
    /// `max_threads` when set, otherwise the parallelism reported by the
    /// standard library (falling back to `1` if it cannot be determined).
    ///
    /// Because of the `min_chunk_size` floor the result can exceed
    /// `data_size` (for example when `data_size` is `0`).
    pub fn optimal_chunk_size(&self, data_size: usize) -> usize {
        let thread_count = match self.max_threads {
            Some(limit) => limit.get(),
            None => std::thread::available_parallelism().map_or(1, NonZeroUsize::get),
        };

        // Ceiling division written as quotient + (remainder != 0) so that it
        // cannot overflow; `data_size + thread_count - 1` does for inputs
        // close to `usize::MAX`. `thread_count` is always >= 1, and the `+ 1`
        // only happens when `thread_count >= 2`, so the sum stays in range.
        let quotient = data_size / thread_count;
        let base_chunk_size = if data_size % thread_count == 0 {
            quotient
        } else {
            quotient + 1
        };

        base_chunk_size.max(self.min_chunk_size)
    }
}
/// Settings that decide whether a vectorized code path should be taken.
///
/// Only `enabled` and `min_simd_size` are interpreted in this file (by
/// [`SIMDConfig::should_use_simd`]); the remaining fields are stored for the
/// consumer of the config.
#[derive(Debug, Clone)]
pub struct SIMDConfig {
    /// Master switch (default `true`).
    pub enabled: bool,
    /// Vector width (default `32`). Units are not stated in this file —
    /// presumably bytes; TODO confirm against the consumer.
    pub vector_width: usize,
    /// Memory alignment (default `32`). Units are not stated in this file —
    /// presumably bytes; TODO confirm against the consumer.
    pub alignment: usize,
    /// Smallest `data_size` for which [`SIMDConfig::should_use_simd`] can
    /// return `true` (default `64`).
    pub min_simd_size: usize,
}

impl Default for SIMDConfig {
    /// Defaults: enabled, `vector_width = 32`, `alignment = 32`,
    /// `min_simd_size = 64`.
    fn default() -> Self {
        SIMDConfig {
            enabled: true,
            vector_width: 32,
            alignment: 32,
            min_simd_size: 64,
        }
    }
}

impl SIMDConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with the master switch replaced.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Returns the configuration with `vector_width` replaced.
    pub fn with_vector_width(self, width: usize) -> Self {
        Self {
            vector_width: width,
            ..self
        }
    }

    /// Returns the configuration with `alignment` replaced.
    pub fn with_alignment(self, alignment: usize) -> Self {
        Self { alignment, ..self }
    }

    /// Returns the configuration with `min_simd_size` replaced.
    pub fn with_min_simd_size(self, size: usize) -> Self {
        Self {
            min_simd_size: size,
            ..self
        }
    }

    /// Reports whether an input of `data_size` elements qualifies for the
    /// vectorized path: the switch must be on and the input must be at least
    /// `min_simd_size` long.
    pub fn should_use_simd(&self, data_size: usize) -> bool {
        if !self.enabled {
            return false;
        }
        self.min_simd_size <= data_size
    }
}
/// Top-level configuration bundling the parallel and SIMD settings with the
/// JIT-specific switches.
#[derive(Debug, Clone)]
pub struct JITConfig {
    /// Thread/chunking settings.
    pub parallel: ParallelConfig,
    /// Vectorization settings.
    pub simd: SIMDConfig,
    /// JIT on/off flag (default `true`).
    pub jit_enabled: bool,
    /// Optimization level (default `2`). The
    /// [`JITConfig::with_optimization_level`] setter caps it at `3`; writing
    /// the public field directly bypasses that cap.
    pub optimization_level: u8,
    /// Flag for caching compiled output (default `true`).
    pub cache_compiled: bool,
}

impl Default for JITConfig {
    /// Defaults: default parallel and SIMD settings, JIT enabled,
    /// optimization level `2`, caching enabled.
    fn default() -> Self {
        JITConfig {
            parallel: ParallelConfig::default(),
            simd: SIMDConfig::default(),
            jit_enabled: true,
            optimization_level: 2,
            cache_compiled: true,
        }
    }
}

impl JITConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with the parallel settings replaced.
    pub fn with_parallel(self, config: ParallelConfig) -> Self {
        Self {
            parallel: config,
            ..self
        }
    }

    /// Returns the configuration with the SIMD settings replaced.
    pub fn with_simd(self, config: SIMDConfig) -> Self {
        Self {
            simd: config,
            ..self
        }
    }

    /// Returns the configuration with the JIT flag replaced.
    pub fn with_jit_enabled(self, enabled: bool) -> Self {
        Self {
            jit_enabled: enabled,
            ..self
        }
    }

    /// Returns the configuration with the optimization level replaced.
    /// Values above `3` are stored as `3`.
    pub fn with_optimization_level(self, level: u8) -> Self {
        let capped = if level > 3 { 3 } else { level };
        Self {
            optimization_level: capped,
            ..self
        }
    }

    /// Returns the configuration with the compiled-code caching flag replaced.
    pub fn with_cache_compiled(self, enabled: bool) -> Self {
        Self {
            cache_compiled: enabled,
            ..self
        }
    }

    /// Preset: minimum chunk size `10000`, static load balancing, SIMD
    /// threshold `128`, optimization level `3`; everything else default.
    pub fn throughput_optimized() -> Self {
        let parallel = ParallelConfig::new()
            .with_min_chunk_size(10000)
            .with_load_balancing(LoadBalancing::Static);
        let simd = SIMDConfig::new().with_min_simd_size(128);

        Self::new()
            .with_parallel(parallel)
            .with_simd(simd)
            .with_optimization_level(3)
    }

    /// Preset: minimum chunk size `100`, dynamic load balancing, SIMD
    /// threshold `32`, optimization level `2`; everything else default.
    pub fn latency_optimized() -> Self {
        let parallel = ParallelConfig::new()
            .with_min_chunk_size(100)
            .with_load_balancing(LoadBalancing::Dynamic);
        let simd = SIMDConfig::new().with_min_simd_size(32);

        Self::new()
            .with_parallel(parallel)
            .with_simd(simd)
            .with_optimization_level(2)
    }
}
/// Process-wide configuration slot; it can be filled at most once.
static GLOBAL_JIT_CONFIG: std::sync::OnceLock<JITConfig> = std::sync::OnceLock::new();

/// Returns the global configuration.
///
/// If nothing has been stored yet, the slot is filled with
/// `JITConfig::default()` first, after which [`set_global_config`] can no
/// longer succeed.
pub fn get_global_config() -> &'static JITConfig {
    GLOBAL_JIT_CONFIG.get_or_init(JITConfig::default)
}

/// Stores `config` as the global configuration.
///
/// # Errors
///
/// Returns `Err(config)`, handing the value back unchanged, when the slot has
/// already been filled by an earlier call to this function or by
/// [`get_global_config`].
pub fn set_global_config(config: JITConfig) -> Result<(), JITConfig> {
    let slot = &GLOBAL_JIT_CONFIG;
    slot.set(config)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `ParallelConfig` setters store the values they were given.
    #[test]
    fn test_parallel_config() {
        let builder = ParallelConfig::new().with_min_chunk_size(500);
        let config = builder.with_max_threads(4);
        assert_eq!(config.min_chunk_size, 500);

        let thread_limit = config.max_threads.expect("operation should succeed");
        assert_eq!(thread_limit.get(), 4);
    }

    /// A disabled `SIMDConfig` rejects SIMD even for inputs above the
    /// threshold.
    #[test]
    fn test_simd_config() {
        let config = SIMDConfig::new().with_min_simd_size(32).with_enabled(false);
        assert!(!config.enabled);
        assert_eq!(config.min_simd_size, 32);

        let decision = config.should_use_simd(64);
        assert!(!decision);
    }

    /// The throughput preset carries its documented chunk size and
    /// optimization level.
    #[test]
    fn test_jit_config() {
        let JITConfig {
            parallel,
            optimization_level,
            ..
        } = JITConfig::throughput_optimized();
        assert_eq!(parallel.min_chunk_size, 10000);
        assert_eq!(optimization_level, 3);
    }
}