use rayon::prelude::*;
/// Applies `f` to every element of `column` using rayon's parallel iterators.
///
/// `f` is called once per element with the element's index in `column` and a
/// mutable reference to the element. Calls may run concurrently, so no
/// ordering between them should be assumed.
pub fn par_apply_column<F>(column: &mut [f32], f: F)
where
    F: Fn(usize, &mut f32) + Send + Sync,
{
    column
        .par_iter_mut()
        .enumerate()
        .for_each(|(index, value)| {
            f(index, value);
        });
}
/// Splits `column` into consecutive chunks of at most `chunk_size` elements
/// and runs `f` on each chunk in parallel.
///
/// When `column.len()` is not a multiple of `chunk_size`, the final chunk is
/// shorter than the others.
///
/// # Panics
///
/// Panics if `chunk_size` is zero.
pub fn par_apply_chunks<F>(column: &mut [f32], chunk_size: usize, f: F)
where
    F: Fn(&mut [f32]) + Send + Sync,
{
    assert!(chunk_size != 0, "chunk_size must be positive");
    let chunks = column.par_chunks_mut(chunk_size);
    chunks.for_each(f);
}
pub fn par_apply_rows<const N: usize, F>(columns: &mut [&mut [f32]; N], f: F)
where
F: Fn(usize, &mut [&mut f32; N]) + Send + Sync,
{
let n = columns[0].len();
for col in columns.iter() {
assert_eq!(col.len(), n, "all columns must have equal length");
}
#[repr(transparent)]
struct Row<const N: usize>([*mut f32; N]);
unsafe impl<const N: usize> Send for Row<N> {}
unsafe impl<const N: usize> Sync for Row<N> {}
let rows: Vec<Row<N>> = (0..n)
.map(|i| {
let mut row: [*mut f32; N] = [std::ptr::null_mut(); N];
for (c, col) in columns.iter_mut().enumerate() {
row[c] = &mut col[i] as *mut f32;
}
Row(row)
})
.collect();
rows.par_iter().enumerate().for_each(|(i, row)| {
let mut refs: [&mut f32; N] = std::array::from_fn(|c| {
let ptr = row.0[c];
unsafe { &mut *ptr }
});
f(i, &mut refs);
});
}
/// Runs `f` in parallel over aligned chunks of several equally long columns.
///
/// All columns are cut at the same boundaries into chunks of at most
/// `chunk_size` elements. For each chunk, `f` receives the index of the
/// chunk's first element and one mutable sub-slice per column (in column
/// order), all covering the same element range.
///
/// Returns without calling `f` when `columns` is empty or the columns have
/// length zero.
///
/// # Panics
///
/// Panics if `chunk_size` is zero or if the columns differ in length.
pub fn par_apply_chunks_multi<F>(columns: &mut [&mut [f32]], chunk_size: usize, f: F)
where
    F: Fn(usize, &mut [&mut [f32]]) + Send + Sync,
{
    assert!(chunk_size > 0, "chunk_size must be positive");
    let Some(first) = columns.first() else {
        return;
    };
    let n = first.len();
    columns
        .iter()
        .for_each(|col| assert_eq!(col.len(), n, "all columns must have equal length"));
    if n == 0 {
        return;
    }

    /// Base pointer of one column, shareable between rayon workers.
    #[repr(transparent)]
    struct BasePtr(*mut f32);
    // SAFETY: each pointer comes from a distinct `&mut [f32]` that stays
    // exclusively borrowed for the whole call, and the parallel loop below only
    // ever derives non-overlapping sub-slices from it.
    unsafe impl Send for BasePtr {}
    unsafe impl Sync for BasePtr {}

    let bases: Vec<BasePtr> = columns
        .iter_mut()
        .map(|col| BasePtr(col.as_mut_ptr()))
        .collect();
    let chunk_count = n.div_ceil(chunk_size);

    (0..chunk_count).into_par_iter().for_each(|chunk_index| {
        let start = chunk_index * chunk_size;
        // `start < n` for every chunk index, so the subtraction cannot wrap;
        // the last chunk is simply shorter.
        let len = chunk_size.min(n - start);
        let mut views: Vec<&mut [f32]> = Vec::with_capacity(bases.len());
        for base in &bases {
            // SAFETY: `start + len <= n` and every column holds `n` elements,
            // so the range is in bounds. Chunk indices are unique, hence the
            // ranges of different invocations never overlap, and distinct
            // columns never alias each other.
            views.push(unsafe { std::slice::from_raw_parts_mut(base.0.add(start), len) });
        }
        f(start, &mut views);
    });
}
/// Runs `f` in parallel over consecutive chunks of `schedule`.
///
/// Each call receives the offset of the chunk's first entry within `schedule`
/// together with the chunk itself. The final chunk holds fewer than
/// `chunk_size` ids when the schedule length is not a multiple of `chunk_size`.
///
/// # Panics
///
/// Panics if `chunk_size` is zero.
pub fn par_chunked_schedule<F>(schedule: &[rustsim_core::types::AgentId], chunk_size: usize, f: F)
where
    F: Fn(usize, &[rustsim_core::types::AgentId]) + Send + Sync,
{
    assert!(chunk_size != 0, "chunk_size must be positive");
    let numbered_chunks = schedule.par_chunks(chunk_size).enumerate();
    numbered_chunks.for_each(|(chunk_index, ids)| {
        let offset = chunk_index * chunk_size;
        f(offset, ids);
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every element must end up holding its own index.
    #[test]
    fn par_apply_column_writes_all_indices() {
        let mut col = vec![0.0_f32; 1000];
        par_apply_column(&mut col, |i, v| *v = i as f32);
        let expected: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        assert_eq!(col, expected);
    }

    /// 1003 is not a multiple of 128, so the short trailing chunk must be
    /// visited as well.
    #[test]
    fn par_apply_chunks_processes_tail() {
        let mut col = vec![1.0_f32; 1003];
        par_apply_chunks(&mut col, 128, |chunk| chunk.fill(2.0));
        assert!(col.iter().all(|&v| v == 2.0));
    }

    /// Each row callback sees one element from each of the two columns.
    #[test]
    fn par_apply_rows_two_columns() {
        let mut a = vec![1.0_f32; 500];
        let mut b = vec![2.0_f32; 500];
        {
            let mut cols: [&mut [f32]; 2] = [a.as_mut_slice(), b.as_mut_slice()];
            par_apply_rows::<2, _>(&mut cols, |_i, row| {
                *row[0] += 10.0;
                *row[1] *= 3.0;
            });
        }
        assert!(a.iter().all(|&v| v == 11.0));
        assert!(b.iter().all(|&v| v == 6.0));
    }

    /// Chunks of both columns must cover the same range, and `start` must be
    /// the index of the chunk's first element.
    #[test]
    fn par_apply_chunks_multi_processes_aligned_chunks() {
        let mut a = vec![1.0_f32; 1003];
        let mut b = vec![2.0_f32; 1003];
        {
            let mut cols: [&mut [f32]; 2] = [a.as_mut_slice(), b.as_mut_slice()];
            par_apply_chunks_multi(&mut cols, 128, |start, slices| {
                let (left, right) = slices.split_at_mut(1);
                let pairs = left[0].iter_mut().zip(right[0].iter_mut());
                for (offset, (va, vb)) in pairs.enumerate() {
                    *va = (start + offset) as f32;
                    *vb = *va * 2.0;
                }
            });
        }
        for (i, (va, vb)) in a.iter().zip(b.iter()).enumerate() {
            assert_eq!(*va, i as f32);
            assert_eq!(*vb, (i as f32) * 2.0);
        }
    }

    /// Every scheduled id must be handed to the callback exactly once, at the
    /// offset reported for its chunk.
    #[test]
    fn par_chunked_schedule_covers_every_id_once() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let schedule: Vec<rustsim_core::types::AgentId> = (0..1000).collect();
        let counters: Vec<AtomicUsize> = std::iter::repeat_with(|| AtomicUsize::new(0))
            .take(1000)
            .collect();
        par_chunked_schedule(&schedule, 64, |start, chunk| {
            for (offset, id) in chunk.iter().enumerate() {
                let position = start + offset;
                assert_eq!(schedule[position], *id);
                counters[*id as usize].fetch_add(1, Ordering::Relaxed);
            }
        });
        for counter in counters.iter() {
            assert_eq!(counter.load(Ordering::Relaxed), 1);
        }
    }
}