use rayon::prelude::*;
use crate::engines::soa::phys_obj::{AttrsError, PhysObj};
use crate::models::particles::attrs::{ATTR_R, ATTR_V, ParticleSelection};
use crate::models::particles::state::{
ParticleMasks, ParticleStateError, gather_masks, validate_vector_attr_f64,
};
use crate::space::continuous::boundary::{
BoundaryError as ContinuousBoundaryError, ContinuousBoundary,
};
/// Errors produced while applying a continuous boundary to particle attributes.
///
/// It aggregates failures from the boundary itself, from the attribute store,
/// and from particle-state validation (see the `From` impls in this module).
#[derive(Debug, Clone, PartialEq)]
pub enum ParticleBoundaryError {
    /// A failure reported by the continuous-space boundary, e.g. a dimension
    /// mismatch between the boundary and the particle vectors.
    Boundary(ContinuousBoundaryError),
    /// A failure reported by the attribute store (`PhysObj` attribute access).
    Attrs(AttrsError),
    /// An attribute's vector dimension did not match what was expected.
    /// Mirrors `ParticleStateError::InvalidAttrShape` field-for-field.
    InvalidAttrShape {
        /// Identifies the offending attribute (presumably its attribute name — confirm).
        label: &'static str,
        /// Dimension the caller required.
        expected_dim: usize,
        /// Dimension actually found.
        got_dim: usize,
    },
    /// An attribute held a different number of entries than expected.
    /// Mirrors `ParticleStateError::InconsistentParticleCount` field-for-field.
    InconsistentParticleCount {
        /// Identifies the offending attribute (presumably its attribute name — confirm).
        label: &'static str,
        /// Count the caller required.
        expected: usize,
        /// Count actually found.
        got: usize,
    },
}
impl From<AttrsError> for ParticleBoundaryError {
#[inline]
fn from(value: AttrsError) -> Self {
Self::Attrs(value)
}
}
impl From<ContinuousBoundaryError> for ParticleBoundaryError {
#[inline]
fn from(value: ContinuousBoundaryError) -> Self {
Self::Boundary(value)
}
}
impl From<ParticleStateError> for ParticleBoundaryError {
fn from(value: ParticleStateError) -> Self {
match value {
ParticleStateError::Attrs(err) => Self::Attrs(err),
ParticleStateError::InvalidAttrShape {
label,
expected_dim,
got_dim,
} => Self::InvalidAttrShape {
label,
expected_dim,
got_dim,
},
ParticleStateError::InconsistentParticleCount {
label,
expected,
got,
} => Self::InconsistentParticleCount {
label,
expected,
got,
},
}
}
}
/// Applies a [`ContinuousBoundary`] to the particles stored in a [`PhysObj`].
///
/// Every `ContinuousBoundary` implementor receives this trait through the
/// blanket `impl<T: ContinuousBoundary> ParticleBoundary for T` in this module.
pub trait ParticleBoundary: ContinuousBoundary {
    /// Applies the boundary to the particles in `objects`, mutating their
    /// attributes in place.
    ///
    /// # Errors
    ///
    /// Returns a [`ParticleBoundaryError`] when attribute access/validation
    /// fails or when the boundary rejects the particle data.
    fn apply_to_particles(&self, objects: &mut PhysObj) -> Result<(), ParticleBoundaryError>;
}
impl<T> ParticleBoundary for T
where
    T: ContinuousBoundary,
{
    /// Applies `self` to every particle selected by
    /// `ParticleSelection::AliveOnly` (see `shape_alive_rigid`).
    ///
    /// The position (`ATTR_R`) and velocity (`ATTR_V`) buffers are flat `f64`
    /// slices processed in rows of `dim` components, row `i` belonging to
    /// particle `i`. Work happens in two parallel passes:
    ///
    /// 1. each selected position row is handed to
    ///    `apply_position_with_velocity_flip_mask` together with a zeroed
    ///    `u8` row of the same width, which the boundary may fill in;
    /// 2. each selected velocity row has every component whose mask entry is
    ///    exactly `1` negated.
    ///
    /// A particle dimension of zero is a no-op (there are no components to
    /// constrain).
    ///
    /// # Errors
    ///
    /// - Attribute lookup/validation failures from `shape_alive_rigid` or from
    ///   fetching `ATTR_R` / `ATTR_V` mutably.
    /// - `ContinuousBoundaryError::InvalidVectorDimension` (wrapped in
    ///   `ParticleBoundaryError::Boundary`) when `self.dim()` differs from the
    ///   particle dimension.
    /// - An error returned by `apply_position_with_velocity_flip_mask`.
    ///   NOTE(review): `try_for_each` stops early on that error, but rows already
    ///   processed keep their modified positions and the velocity pass never
    ///   runs, so positions and velocities may be left out of sync — confirm
    ///   whether callers treat such an error as fatal.
    fn apply_to_particles(&self, objects: &mut PhysObj) -> Result<(), ParticleBoundaryError> {
        let (dim, n, masks) = shape_alive_rigid(objects)?;
        if self.dim() != dim {
            return Err(ContinuousBoundaryError::InvalidVectorDimension {
                label: "bounds",
                expected: dim,
                got: self.dim(),
            }
            .into());
        }
        // Zero components per particle: nothing to constrain, and rayon's
        // `par_chunks_mut(0)` below would panic on a zero chunk size.
        if dim == 0 {
            return Ok(());
        }

        // One flag per position component. The boundary may set entries in
        // pass 1; pass 2 reverses the velocity components flagged with 1.
        let mut flip_mask: Vec<u8> = vec![0; n * dim];

        // Pass 1: constrain positions and record which velocity components flip.
        {
            let r = objects.core.get_mut::<f64>(ATTR_R)?;
            r.as_tensor_mut()
                .data
                .par_chunks_mut(dim)
                .zip(flip_mask.par_chunks_mut(dim))
                .enumerate()
                .try_for_each(|(i, (r_row, mask_row))| {
                    if masks.should_skip(i) {
                        return Ok(());
                    }
                    self.apply_position_with_velocity_flip_mask(r_row, mask_row)
                })?;
        }

        // Pass 2: reverse the flagged velocity components.
        {
            let v = objects.core.get_mut::<f64>(ATTR_V)?;
            v.as_tensor_mut()
                .data
                .par_chunks_mut(dim)
                .zip(flip_mask.par_chunks(dim))
                .enumerate()
                .for_each(|(i, (v_row, mask_row))| {
                    if masks.should_skip(i) {
                        return;
                    }
                    // Zip instead of `0..dim` indexing: no bounds-check panic
                    // path if a trailing row is shorter than `dim`.
                    for (component, &flag) in v_row.iter_mut().zip(mask_row) {
                        if flag == 1 {
                            *component = -*component;
                        }
                    }
                });
        }
        Ok(())
    }
}
#[inline]
fn shape_alive_rigid(
objects: &PhysObj,
) -> Result<(usize, usize, ParticleMasks), ParticleBoundaryError> {
let (dim, n) = {
let r = objects.core.get::<f64>(ATTR_R)?;
(r.dim(), r.num_vectors())
};
validate_vector_attr_f64(objects, ATTR_V, dim, n)?;
let masks = gather_masks(objects, n, ParticleSelection::AliveOnly)?;
Ok((dim, n, masks))
}