use super::device::CpuDevice;
use super::runtime::CpuRuntime;
use crate::runtime::{DefaultAllocator, RuntimeClient};
use std::alloc::{Layout as AllocLayout, alloc_zeroed, dealloc};
#[cfg(feature = "rayon")]
use std::sync::Arc;
/// Parallelism settings carried by a CPU client.
///
/// Both fields are optional; the `Default` value leaves both as `None`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ParallelismConfig {
    /// Requested number of threads for a dedicated pool, if any.
    pub max_threads: Option<usize>,
    /// Requested chunk size used as a work-splitting hint, if any.
    pub chunk_size: Option<usize>,
}

impl ParallelismConfig {
    /// Builds a configuration from an optional thread count and an optional chunk size.
    ///
    /// The values are stored as given; no validation or normalisation happens here.
    #[must_use]
    pub const fn new(max_threads: Option<usize>, chunk_size: Option<usize>) -> Self {
        Self { max_threads, chunk_size }
    }
}
/// CPU runtime client bundling a device, the allocator built for it, and parallelism settings.
///
/// `Clone` is derived, so cloning copies every field; the optional thread pool is held
/// behind an `Arc`, so clones refer to the same pool.
#[derive(Clone)]
pub struct CpuClient {
/// Device this client was created for.
pub(crate) device: CpuDevice,
/// Allocator handed out through `RuntimeClient::allocator`.
allocator: CpuAllocator,
/// Parallelism settings currently in effect for this client.
parallelism: ParallelismConfig,
/// Dedicated rayon pool, present only when one was requested and successfully built.
#[cfg(feature = "rayon")]
thread_pool: Option<Arc<rayon::ThreadPool>>,
}
/// Hand-written `Debug` that prints only the `device` and `parallelism` fields;
/// the allocator and the thread pool are left out of the output.
impl std::fmt::Debug for CpuClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CpuClient");
        out.field("device", &self.device);
        out.field("parallelism", &self.parallelism);
        out.finish()
    }
}
impl CpuClient {
    /// Creates a client for `device` with default parallelism settings.
    ///
    /// The allocator is built from a clone of `device`; no dedicated thread pool is created.
    pub fn new(device: CpuDevice) -> Self {
        Self {
            // Listed first so the clone is taken before `device` is moved into the struct.
            allocator: create_cpu_allocator(device.clone()),
            device,
            parallelism: ParallelismConfig::default(),
            #[cfg(feature = "rayon")]
            thread_pool: None,
        }
    }

    /// Returns a new client that shares this client's device and allocator (both cloned)
    /// but uses `config` as its parallelism settings.
    ///
    /// With the `rayon` feature enabled, a thread pool is built from `config.max_threads`
    /// on every call; the pool of `self` is not carried over.
    #[must_use]
    pub fn with_parallelism(&self, config: ParallelismConfig) -> Self {
        let device = self.device.clone();
        let allocator = self.allocator.clone();
        Self {
            device,
            allocator,
            parallelism: config,
            #[cfg(feature = "rayon")]
            thread_pool: build_thread_pool(config.max_threads),
        }
    }

    /// Returns a copy of the parallelism settings currently in effect.
    #[must_use]
    pub const fn parallelism(&self) -> ParallelismConfig {
        self.parallelism
    }

    /// Runs `f` and returns its result.
    ///
    /// When this client owns a dedicated pool, `f` is executed through `ThreadPool::install`;
    /// otherwise it is simply called in place.
    #[cfg(feature = "rayon")]
    pub(crate) fn install_parallelism<F, T>(&self, f: F) -> T
    where
        F: FnOnce() -> T + Send,
        T: Send,
    {
        match &self.thread_pool {
            Some(pool) => pool.install(f),
            None => f(),
        }
    }

    /// Runs `f` and returns its result.
    ///
    /// Without the `rayon` feature there is no pool, so `f` is called in place.
    #[cfg(not(feature = "rayon"))]
    pub(crate) fn install_parallelism<F, T>(&self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        f()
    }

    /// Returns the configured chunk size, never less than 1.
    ///
    /// A missing chunk size and a chunk size of zero both yield 1.
    #[inline]
    pub(crate) fn chunk_size_hint(&self) -> usize {
        match self.parallelism.chunk_size {
            Some(chunk) if chunk > 0 => chunk,
            _ => 1,
        }
    }

    /// Returns the same value as [`Self::chunk_size_hint`].
    ///
    /// NOTE(review): presumably intended as the argument for rayon's minimum-length
    /// splitting controls; confirm against the call sites.
    #[cfg(feature = "rayon")]
    #[inline]
    pub(crate) fn rayon_min_len(&self) -> usize {
        self.chunk_size_hint()
    }
}
/// `RuntimeClient` implementation that exposes the client's stored device and allocator.
impl RuntimeClient<CpuRuntime> for CpuClient {
/// Returns a reference to the device stored in this client.
fn device(&self) -> &CpuDevice {
&self.device
}
/// Does nothing: the body is intentionally empty.
/// NOTE(review): presumably CPU work has already completed when this is called, so there
/// is nothing to wait for; confirm against the `RuntimeClient` contract.
fn synchronize(&self) {
}
/// Returns a reference to the allocator stored in this client.
fn allocator(&self) -> &CpuAllocator {
&self.allocator
}
}
/// Allocator type used by the CPU client: the crate's `DefaultAllocator` specialised to `CpuDevice`.
pub type CpuAllocator = DefaultAllocator<CpuDevice>;
/// Builds a dedicated rayon thread pool configured with `num_threads(max_threads)`.
///
/// Returns `None` when no thread count is requested, when the requested count is zero,
/// or when rayon fails to build the pool (the build error is discarded).
#[cfg(feature = "rayon")]
fn build_thread_pool(max_threads: Option<usize>) -> Option<Arc<rayon::ThreadPool>> {
    let threads = match max_threads {
        Some(count) if count != 0 => count,
        _ => return None,
    };

    let builder = rayon::ThreadPoolBuilder::new().num_threads(threads);
    match builder.build() {
        Ok(pool) => Some(Arc::new(pool)),
        Err(_) => None,
    }
}
fn create_cpu_allocator(device: CpuDevice) -> CpuAllocator {
DefaultAllocator::new(
device,
|size, _dev| {
if size == 0 {
return Ok(0);
}
let align = 64; let layout =
AllocLayout::from_size_align(size, align).expect("Invalid allocation layout");
let ptr = unsafe { alloc_zeroed(layout) };
if ptr.is_null() {
return Err(crate::error::Error::OutOfMemory { size });
}
Ok(ptr as u64)
},
|ptr, size, _dev| {
if ptr == 0 || size == 0 {
return;
}
let align = 64;
let layout =
AllocLayout::from_size_align(size, align).expect("Invalid allocation layout");
unsafe {
dealloc(ptr as *mut u8, layout);
}
},
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Device;

    /// A client derived via `with_parallelism` keeps the device id and reports the new settings.
    #[test]
    fn test_with_parallelism_preserves_device_and_updates_config() {
        let base = CpuClient::new(CpuDevice::new());
        let requested = ParallelismConfig::new(Some(2), Some(512));
        let derived = base.with_parallelism(requested);

        assert_eq!(derived.device.id(), base.device.id());

        let ParallelismConfig {
            max_threads,
            chunk_size,
        } = derived.parallelism();
        assert_eq!(max_threads, Some(2));
        assert_eq!(chunk_size, Some(512));
    }

    /// `rayon_min_len` is 1 by default, echoes a positive chunk size, and maps zero to 1.
    #[cfg(feature = "rayon")]
    #[test]
    fn test_rayon_min_len_defaults_and_normalizes_zero() {
        let base = CpuClient::new(CpuDevice::new());
        assert_eq!(base.rayon_min_len(), 1);

        // Derive a two-thread client with the given chunk size and read back its hint.
        let min_len_for = |chunk: usize| {
            base.with_parallelism(ParallelismConfig::new(Some(2), Some(chunk)))
                .rayon_min_len()
        };

        assert_eq!(min_len_for(64), 64);
        assert_eq!(min_len_for(0), 1);
    }
}