#[cfg(feature = "neural-network")]
use krnl::scalar::Scalar;
use ndarray::{ArrayViewMut, Dimension, RawArrayViewMut};
#[cfg(feature = "neural-network")]
use ndarray::{Axis, Ix4, Ix5, RemoveAxis};
use std::marker::PhantomData;
#[cfg(target_family = "x86")]
use std::sync::OnceLock;
pub(crate) fn parallel_size() -> usize {
const L1_CACHE_SIZE_DEFAULT: usize = 1 << 15;
let l1_cache_size: usize = {
#[cfg(target_family = "x86")]
{
static L1_CACHE_SIZE: OnceLock<usize> = std::sync::OnceLock::new();
*L1_CACHE_SIZE
.get_or_init(|| cache_size::l1_cache_size().unwrap_or(L1_CACHE_SIZE_DEFAULT))
}
#[cfg(not(target_family = "x86"))]
{
L1_CACHE_SIZE_DEFAULT
}
};
let simd_width = if cfg!(target_feature = "avx") {
256
} else {
32
};
2 * simd_width * l1_cache_size
}
#[cfg(feature = "neural-network")]
pub(crate) fn array_par_outer_iter_mut_for_each<T: Scalar, D: RemoveAxis, F>(
mut array: ArrayViewMut<T, D>,
f: F,
) where
F: Fn(usize, ArrayViewMut<T, D::Smaller>) + Send + Sync,
{
if rayon::current_num_threads() == 1 {
array
.outer_iter_mut()
.enumerate()
.for_each(|(i, array)| f(i, array));
return;
}
let items = array.shape().first().copied().unwrap_or(1);
let sync_array = SyncRawArrayViewMut::try_from(array).unwrap();
rayon::scope(|scope| {
scope.spawn_broadcast(move |_scope, context| {
let _ = &sync_array;
let item_id = context.index();
let threads = context.num_threads();
(item_id..items).step_by(threads).for_each(|item_id| {
let item = sync_array.inner.clone().index_axis_move(Axis(0), item_id);
let item = unsafe { item.deref_into_view_mut() };
f(item_id, item);
});
});
});
}
#[derive(Clone)]
pub(crate) struct SyncRawArrayViewMut<'a, T, D: Dimension> {
#[allow(unused)]
inner: RawArrayViewMut<T, D>,
_m: PhantomData<&'a T>,
}
#[cfg(feature = "neural-network")]
pub(crate) type SyncRawArrayViewMut4<'a, T> = SyncRawArrayViewMut<'a, T, Ix4>;
#[cfg(feature = "neural-network")]
pub(crate) type SyncRawArrayViewMut5<'a, T> = SyncRawArrayViewMut<'a, T, Ix5>;
impl<'a, T, D: Dimension> TryFrom<ArrayViewMut<'a, T, D>> for SyncRawArrayViewMut<'a, T, D> {
type Error = ();
fn try_from(mut array: ArrayViewMut<T, D>) -> Result<Self, ()> {
if array.is_standard_layout() {
Ok(Self {
inner: unsafe {
RawArrayViewMut::from_shape_ptr(array.raw_dim(), array.as_mut_ptr())
},
_m: PhantomData,
})
} else {
Err(())
}
}
}
#[cfg(feature = "neural-network")]
impl<'a, T, D: Dimension> SyncRawArrayViewMut<'a, T, D> {
pub(crate) fn as_mut_ptr(&mut self) -> *mut T {
self.inner.as_mut_ptr()
}
pub(crate) fn dim(&self) -> D::Pattern {
self.inner.dim()
}
}
unsafe impl<T: Send + Sync + 'static, D: Dimension> Send for SyncRawArrayViewMut<'_, T, D> {}
unsafe impl<T: Send + Sync + 'static, D: Dimension> Sync for SyncRawArrayViewMut<'_, T, D> {}
#[cfg(feature = "neural-network")]
pub(crate) fn broadcast(threads: Option<usize>, f: impl Fn(usize, usize) + Send + Sync) {
let threads = threads
.unwrap_or(usize::MAX)
.min(rayon::current_num_threads());
if threads == 1 {
f(0, 1);
} else {
rayon::in_place_scope(|scope| {
scope.spawn_broadcast(|_scope, context| {
let thread_id = context.index();
debug_assert!(threads <= context.num_threads());
if thread_id < threads {
f(thread_id, threads);
}
});
});
}
}