use std::sync::OnceLock;
static CHUNKING_CONFIG: OnceLock<ChunkingConfig> = OnceLock::new();
#[derive(Debug, Clone)]
pub struct ChunkingConfig {
pub l1_cache_size: usize,
pub l2_cache_size: usize,
pub l3_cache_size: usize,
pub num_cores: usize,
pub cache_line_size: usize,
pub numa_aware: bool,
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, num_cores: num_cpus::get(),
cache_line_size: 64,
numa_aware: false,
}
}
}
impl ChunkingConfig {
pub fn compute_intensive() -> Self {
let mut config = Self::default();
config.l2_cache_size = config.l2_cache_size.max(512 * 1024);
config
}
pub fn memory_intensive() -> Self {
let mut config = Self::default();
config.l1_cache_size = config.l1_cache_size.min(32 * 1024);
config
}
pub fn cache_friendly() -> Self {
let mut config = Self::default();
config.cache_line_size = 64;
config
}
pub fn optimal_elementwise_chunk(&self, element_size: usize) -> usize {
let target_bytes = (self.l1_cache_size * 4) / 5;
let chunk_size = target_bytes / element_size;
let elements_per_cache_line = self.cache_line_size / element_size;
(chunk_size / elements_per_cache_line) * elements_per_cache_line
}
pub fn optimal_matrix_chunk(&self, element_size: usize) -> (usize, usize) {
let target_bytes = (self.l2_cache_size * 7) / 10;
let total_elements = target_bytes / element_size;
let side = (total_elements as f64).sqrt() as usize;
let elements_per_cache_line = self.cache_line_size / element_size;
let aligned_side = (side / elements_per_cache_line) * elements_per_cache_line;
(
aligned_side.max(elements_per_cache_line),
aligned_side.max(elements_per_cache_line),
)
}
pub fn optimal_reduction_chunk(&self, element_size: usize) -> usize {
let target_bytes = self.l2_cache_size / 2;
let chunk_size = target_bytes / element_size;
chunk_size.max(1024)
}
pub fn optimal_parallel_chunks(&self, total_size: usize, chunk_size: usize) -> usize {
let chunks_from_size = (total_size + chunk_size - 1) / chunk_size;
let ideal_chunks = self.num_cores * 3;
chunks_from_size.min(ideal_chunks).max(self.num_cores)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
Elementwise,
Matrix,
Reduction,
Convolution,
Custom,
}
pub struct ChunkingUtils;
impl ChunkingUtils {
pub fn global_config() -> &'static ChunkingConfig {
CHUNKING_CONFIG.get_or_init(ChunkingConfig::default)
}
pub fn set_global_config(config: ChunkingConfig) -> Result<(), ChunkingConfig> {
CHUNKING_CONFIG.set(config)
}
pub fn optimal_chunk_size(
workload: WorkloadType,
element_size: usize,
total_elements: usize,
) -> usize {
let config = Self::global_config();
let chunk_size = match workload {
WorkloadType::Elementwise => config.optimal_elementwise_chunk(element_size),
WorkloadType::Reduction => config.optimal_reduction_chunk(element_size),
WorkloadType::Matrix => {
let (rows, _) = config.optimal_matrix_chunk(element_size);
rows * rows }
WorkloadType::Convolution => {
(config.l2_cache_size / element_size) / 4
}
WorkloadType::Custom => {
config.l2_cache_size / element_size
}
};
chunk_size
.max(64) .min(total_elements) }
pub fn chunk_range(
start: usize,
end: usize,
workload: WorkloadType,
element_size: usize,
) -> Vec<(usize, usize)> {
let total_elements = end - start;
if total_elements == 0 {
return Vec::new();
}
let chunk_size = Self::optimal_chunk_size(workload, element_size, total_elements);
let config = Self::global_config();
let num_chunks = config.optimal_parallel_chunks(total_elements, chunk_size);
let actual_chunk_size = (total_elements + num_chunks - 1) / num_chunks;
let mut chunks = Vec::with_capacity(num_chunks);
let mut current = start;
while current < end {
let chunk_end = (current + actual_chunk_size).min(end);
chunks.push((current, chunk_end));
current = chunk_end;
}
chunks
}
pub fn matrix_blocks(
m: usize,
n: usize,
k: usize,
element_size: usize,
) -> (usize, usize, usize) {
let config = Self::global_config();
let (block_m, block_n) = config.optimal_matrix_chunk(element_size);
let block_k = config.optimal_elementwise_chunk(element_size) / block_m.max(1);
(
block_m.min(m).max(1),
block_n.min(n).max(1),
block_k.min(k).max(1),
)
}
pub fn is_cache_aligned(chunk_size: usize, element_size: usize) -> bool {
let config = Self::global_config();
let bytes = chunk_size * element_size;
bytes % config.cache_line_size == 0
}
pub fn align_to_cache_line(chunk_size: usize, element_size: usize) -> usize {
let config = Self::global_config();
let elements_per_line = config.cache_line_size / element_size.max(1);
((chunk_size + elements_per_line - 1) / elements_per_line) * elements_per_line
}
}
pub struct ChunkingStrategy {
workload: WorkloadType,
element_size: usize,
prefer_alignment: bool,
}
impl ChunkingStrategy {
pub fn new(workload: WorkloadType, element_size: usize) -> Self {
Self {
workload,
element_size,
prefer_alignment: true,
}
}
pub fn with_alignment(mut self, prefer: bool) -> Self {
self.prefer_alignment = prefer;
self
}
pub fn chunk_size(&self, total_size: usize) -> usize {
let size = ChunkingUtils::optimal_chunk_size(self.workload, self.element_size, total_size);
if self.prefer_alignment {
ChunkingUtils::align_to_cache_line(size, self.element_size)
} else {
size
}
}
pub fn split_range(&self, start: usize, end: usize) -> Vec<(usize, usize)> {
ChunkingUtils::chunk_range(start, end, self.workload, self.element_size)
}
}
pub mod prelude {
pub use super::{ChunkingConfig, ChunkingStrategy, ChunkingUtils, WorkloadType};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chunking_config_default() {
let config = ChunkingConfig::default();
assert!(config.l1_cache_size > 0);
assert!(config.l2_cache_size > config.l1_cache_size);
assert!(config.l3_cache_size > config.l2_cache_size);
assert!(config.num_cores > 0);
assert_eq!(config.cache_line_size, 64);
}
#[test]
fn test_chunking_config_compute_intensive() {
let config = ChunkingConfig::compute_intensive();
assert!(config.l2_cache_size >= 512 * 1024);
}
#[test]
fn test_chunking_config_memory_intensive() {
let config = ChunkingConfig::memory_intensive();
assert!(config.l1_cache_size <= 32 * 1024);
}
#[test]
fn test_optimal_elementwise_chunk() {
let config = ChunkingConfig::default();
let chunk_size = config.optimal_elementwise_chunk(4); assert!(chunk_size > 0);
assert!(chunk_size * 4 <= config.l1_cache_size);
}
#[test]
fn test_optimal_matrix_chunk() {
let config = ChunkingConfig::default();
let (rows, cols) = config.optimal_matrix_chunk(4); assert!(rows > 0);
assert!(cols > 0);
assert!(rows * cols * 4 <= config.l2_cache_size);
}
#[test]
fn test_optimal_reduction_chunk() {
let config = ChunkingConfig::default();
let chunk_size = config.optimal_reduction_chunk(4); assert!(chunk_size >= 1024);
}
#[test]
fn test_optimal_parallel_chunks() {
let config = ChunkingConfig::default();
let num_chunks = config.optimal_parallel_chunks(10000, 100);
assert!(num_chunks >= config.num_cores);
assert!(num_chunks <= config.num_cores * 4);
}
#[test]
fn test_chunking_utils_optimal_chunk_size() {
let chunk_size = ChunkingUtils::optimal_chunk_size(WorkloadType::Elementwise, 4, 10000);
assert!(chunk_size > 0);
assert!(chunk_size <= 10000);
}
#[test]
fn test_chunking_utils_chunk_range() {
let chunks = ChunkingUtils::chunk_range(0, 1000, WorkloadType::Elementwise, 4);
assert!(!chunks.is_empty());
assert_eq!(chunks.first().expect("collection should be non-empty").0, 0);
assert_eq!(
chunks.last().expect("collection should be non-empty").1,
1000
);
for window in chunks.windows(2) {
assert_eq!(window[0].1, window[1].0);
}
}
#[test]
fn test_chunking_utils_matrix_blocks() {
let (block_m, block_n, block_k) = ChunkingUtils::matrix_blocks(1000, 1000, 1000, 4);
assert!(block_m > 0 && block_m <= 1000);
assert!(block_n > 0 && block_n <= 1000);
assert!(block_k > 0 && block_k <= 1000);
}
#[test]
fn test_is_cache_aligned() {
let config = ChunkingConfig::default();
let aligned_size = config.cache_line_size / 4; assert!(ChunkingUtils::is_cache_aligned(aligned_size, 4));
assert!(!ChunkingUtils::is_cache_aligned(aligned_size + 1, 4));
}
#[test]
fn test_align_to_cache_line() {
let unaligned = 100;
let aligned = ChunkingUtils::align_to_cache_line(unaligned, 4);
assert!(ChunkingUtils::is_cache_aligned(aligned, 4));
assert!(aligned >= unaligned);
}
#[test]
fn test_chunking_strategy() {
let strategy = ChunkingStrategy::new(WorkloadType::Elementwise, 4);
let chunk_size = strategy.chunk_size(10000);
assert!(chunk_size > 0);
assert!(chunk_size <= 10000);
}
#[test]
fn test_chunking_strategy_split_range() {
let strategy = ChunkingStrategy::new(WorkloadType::Elementwise, 4);
let chunks = strategy.split_range(0, 1000);
assert!(!chunks.is_empty());
assert_eq!(chunks.first().expect("collection should be non-empty").0, 0);
assert_eq!(
chunks.last().expect("collection should be non-empty").1,
1000
);
}
#[test]
fn test_chunking_strategy_with_alignment() {
let strategy = ChunkingStrategy::new(WorkloadType::Elementwise, 4).with_alignment(true);
let chunk_size = strategy.chunk_size(10000);
assert!(ChunkingUtils::is_cache_aligned(chunk_size, 4));
}
#[test]
fn test_workload_types() {
for workload in &[
WorkloadType::Elementwise,
WorkloadType::Matrix,
WorkloadType::Reduction,
WorkloadType::Convolution,
WorkloadType::Custom,
] {
let chunk_size = ChunkingUtils::optimal_chunk_size(*workload, 4, 10000);
assert!(chunk_size > 0);
}
}
}