use rayon::prelude::*;
/// Errors reported when constructing a boundary or applying it to vectors.
#[derive(Debug, Clone, PartialEq)]
pub enum BoundaryError {
    /// The bounds of axis `axis` are unusable: a bound is non-finite or `max <= min`.
    InvalidBounds {
        axis: usize,
        min: f64,
        max: f64,
    },
    /// The vector identified by `label` has `got` components where `expected` were required.
    InvalidVectorDimension {
        label: &'static str,
        expected: usize,
        got: usize,
    },
    /// The flat buffer identified by `label` (vectors stored back to back) has length `len`,
    /// which cannot be split evenly into `dim`-component vectors.
    InvalidFlatVectorListLength {
        label: &'static str,
        dim: usize,
        len: usize,
    },
    /// Two flat buffers that must hold the same number of vectors (e.g. positions and
    /// velocities) have different lengths.
    InconsistentFlatVectorListLength {
        expected: usize,
        got: usize,
    },
}

impl std::fmt::Display for BoundaryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidBounds { axis, min, max } => write!(
                f,
                "invalid bounds on axis {}: min = {}, max = {} (expected finite min < max)",
                axis, min, max
            ),
            Self::InvalidVectorDimension {
                label,
                expected,
                got,
            } => write!(f, "{} has dimension {}, expected {}", label, got, expected),
            Self::InvalidFlatVectorListLength { label, dim, len } => write!(
                f,
                "{} has length {}, which cannot be split into vectors of dimension {}",
                label, len, dim
            ),
            Self::InconsistentFlatVectorListLength { expected, got } => write!(
                f,
                "flat vector lists have inconsistent lengths: expected {}, got {}",
                expected, got
            ),
        }
    }
}

/// Lets `BoundaryError` flow into `Box<dyn Error>` and other generic error types via `?`.
impl std::error::Error for BoundaryError {}
/// A boundary condition acting on points of a continuous, `dim()`-dimensional space.
///
/// Implementors must provide [`dim`](Self::dim) and [`apply_position`](Self::apply_position).
/// Every other method has a default built on those two. Boundaries that change the direction of
/// motion should also override
/// [`apply_position_with_velocity_flip_mask`](Self::apply_position_with_velocity_flip_mask)
/// so that velocities are updated consistently with positions.
///
/// The `Sync` supertrait allows the batch methods to share `&self` across rayon worker threads.
pub trait ContinuousBoundary: Sync {
    /// Number of components every position / velocity vector must have.
    fn dim(&self) -> usize;

    /// Applies the boundary to a single position vector `r`, in place.
    fn apply_position(&self, r: &mut [f64]) -> Result<(), BoundaryError>;

    /// Applies the boundary to one position `r` and negates the components of the velocity `v`
    /// on every axis the boundary reports as flipped.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `v.len() != self.dim()`. Errors from
    /// [`apply_position_with_velocity_flip_mask`](Self::apply_position_with_velocity_flip_mask)
    /// (e.g. a wrong-length `r`) are propagated.
    fn apply_position_velocity(&self, r: &mut [f64], v: &mut [f64]) -> Result<(), BoundaryError> {
        let dim = self.dim();
        validate_vector_len("velocity", dim, v.len())?;
        // Called once per particle from the parallel batch loop: keep the scratch mask on the
        // stack for the common low-dimensional case and only heap-allocate beyond that.
        let mut stack_mask = [0u8; 8];
        let mut heap_mask;
        let flip_mask: &mut [u8] = if dim <= stack_mask.len() {
            &mut stack_mask[..dim]
        } else {
            heap_mask = vec![0u8; dim];
            &mut heap_mask
        };
        self.apply_position_with_velocity_flip_mask(r, flip_mask)?;
        for (velocity, &flip) in v.iter_mut().zip(flip_mask.iter()) {
            // Any nonzero entry marks a flipped axis.
            if flip != 0 {
                *velocity = -*velocity;
            }
        }
        Ok(())
    }

    /// Applies [`apply_position`](Self::apply_position) to every `dim()`-sized chunk of the
    /// flat buffer `positions`, in parallel.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidFlatVectorListLength`] if `positions` cannot be split
    /// into `dim()`-sized vectors; otherwise an error from one of the per-position calls.
    /// Positions processed before a failure stay modified.
    fn apply_positions(&self, positions: &mut [f64]) -> Result<(), BoundaryError> {
        validate_flat_vector_list("positions", self.dim(), positions.len())?;
        positions
            .par_chunks_mut(self.dim())
            .try_for_each(|r| self.apply_position(r))
    }

    /// Applies [`apply_position_velocity`](Self::apply_position_velocity) to matching
    /// `dim()`-sized chunks of the flat buffers `positions` and `velocities`, in parallel.
    ///
    /// # Errors
    /// * [`BoundaryError::InvalidFlatVectorListLength`] if either buffer cannot be split into
    ///   `dim()`-sized vectors.
    /// * [`BoundaryError::InconsistentFlatVectorListLength`] if the buffers differ in length.
    /// * Otherwise an error from one of the per-particle calls. Particles processed before a
    ///   failure stay modified.
    fn apply_positions_velocities(
        &self,
        positions: &mut [f64],
        velocities: &mut [f64],
    ) -> Result<(), BoundaryError> {
        validate_flat_vector_list("positions", self.dim(), positions.len())?;
        validate_flat_vector_list("velocities", self.dim(), velocities.len())?;
        if positions.len() != velocities.len() {
            return Err(BoundaryError::InconsistentFlatVectorListLength {
                expected: positions.len(),
                got: velocities.len(),
            });
        }
        positions
            .par_chunks_mut(self.dim())
            .zip(velocities.par_chunks_mut(self.dim()))
            .try_for_each(|(r, v)| self.apply_position_velocity(r, v))
    }

    /// Applies the boundary to `r` and records, per axis, whether the velocity component should
    /// be negated (nonzero entry) or left unchanged (`0`).
    ///
    /// The default implementation zeroes the mask and delegates to
    /// [`apply_position`](Self::apply_position), i.e. it never reports a flip.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `flip_mask.len() != self.dim()`.
    /// Errors from `apply_position` are propagated.
    fn apply_position_with_velocity_flip_mask(
        &self,
        r: &mut [f64],
        flip_mask: &mut [u8],
    ) -> Result<(), BoundaryError> {
        validate_vector_len("velocity_flip_mask", self.dim(), flip_mask.len())?;
        flip_mask.fill(0);
        self.apply_position(r)
    }
}
/// Checks that `min` / `max` describe a valid axis-aligned box.
///
/// # Errors
/// * [`BoundaryError::InvalidVectorDimension`] (label `"bounds"`, `expected = min.len()`,
///   `got = max.len()`) if the two slices differ in length.
/// * [`BoundaryError::InvalidBounds`] for the first axis where either bound is non-finite or
///   `max <= min`.
fn validate_bounds(min: &[f64], max: &[f64]) -> Result<(), BoundaryError> {
    let (n_min, n_max) = (min.len(), max.len());
    if n_min != n_max {
        return Err(BoundaryError::InvalidVectorDimension {
            label: "bounds",
            expected: n_min,
            got: n_max,
        });
    }
    // With both bounds finite (hence not NaN), `lo < hi` is exactly the negation of `hi <= lo`.
    min.iter()
        .zip(max)
        .enumerate()
        .find(|&(_, (&lo, &hi))| !(lo.is_finite() && hi.is_finite() && lo < hi))
        .map_or(Ok(()), |(axis, (&lo, &hi))| {
            Err(BoundaryError::InvalidBounds {
                axis,
                min: lo,
                max: hi,
            })
        })
}
/// Ensures that a single vector argument has the dimension the boundary expects.
///
/// # Errors
/// Returns [`BoundaryError::InvalidVectorDimension`] tagged with `label` when `got != expected`.
#[inline]
fn validate_vector_len(
    label: &'static str,
    expected: usize,
    got: usize,
) -> Result<(), BoundaryError> {
    if got == expected {
        Ok(())
    } else {
        Err(BoundaryError::InvalidVectorDimension {
            label,
            expected,
            got,
        })
    }
}
/// Ensures a flat buffer of back-to-back vectors can be split into `dim`-component vectors.
///
/// # Errors
/// Returns [`BoundaryError::InvalidFlatVectorListLength`] tagged with `label` when `len` is not
/// a multiple of `dim`, or when `dim == 0`. A zero-dimensional vector cannot tile a buffer,
/// and chunking by `0` would panic downstream.
#[inline]
fn validate_flat_vector_list(
    label: &'static str,
    dim: usize,
    len: usize,
) -> Result<(), BoundaryError> {
    // `checked_rem` yields `None` for `dim == 0` instead of panicking on division by zero.
    match len.checked_rem(dim) {
        Some(0) => Ok(()),
        _ => Err(BoundaryError::InvalidFlatVectorListLength { label, dim, len }),
    }
}
/// Axis-aligned box with periodic faces: a coordinate leaving through one face re-enters
/// through the opposite one.
///
/// Stores one `(min, max)` pair per axis. The pairs are validated in [`PeriodicBox::new`].
#[derive(Debug, Clone)]
pub struct PeriodicBox {
    /// Lower bound of each axis.
    min: Vec<f64>,
    /// Upper bound of each axis.
    max: Vec<f64>,
}

impl PeriodicBox {
    /// Builds a periodic box from per-axis lower and upper bounds.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `min` and `max` differ in length,
    /// or [`BoundaryError::InvalidBounds`] for the first axis whose bounds are non-finite or
    /// not strictly increasing.
    pub fn new(min: &[f64], max: &[f64]) -> Result<Self, BoundaryError> {
        validate_bounds(min, max).map(|()| Self {
            min: Vec::from(min),
            max: Vec::from(max),
        })
    }

    /// Number of axes of the box.
    #[inline]
    pub fn dim(&self) -> usize {
        self.min.len()
    }

    /// Per-axis lower bounds.
    #[inline]
    pub fn min(&self) -> &[f64] {
        self.min.as_slice()
    }

    /// Per-axis upper bounds.
    #[inline]
    pub fn max(&self) -> &[f64] {
        self.max.as_slice()
    }
}
impl ContinuousBoundary for PeriodicBox {
    #[inline]
    fn dim(&self) -> usize {
        self.dim()
    }

    /// Wraps every finite coordinate of `r` into the half-open interval `[min[d], max[d])`.
    ///
    /// Coordinates already inside the box are left bit-for-bit unchanged. Non-finite
    /// coordinates (NaN, ±inf) are passed through untouched because they have no periodic
    /// image.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `r.len() != self.dim()`.
    fn apply_position(&self, r: &mut [f64]) -> Result<(), BoundaryError> {
        validate_vector_len("position", self.dim(), r.len())?;
        for ((x, &lo), &hi) in r.iter_mut().zip(&self.min).zip(&self.max) {
            // Skipping in-range values avoids the `lo + (x - lo)` round trip, which can move
            // `x` by an ulp.
            if !x.is_finite() || (lo <= *x && *x < hi) {
                continue;
            }
            let wrapped = lo + (*x - lo).rem_euclid(hi - lo);
            // `rem_euclid` may return exactly `hi - lo` for tiny negative offsets, and
            // `lo + y` can round up to `hi`. Both are the periodic image of `lo`.
            *x = if wrapped >= hi { lo } else { wrapped };
        }
        Ok(())
    }
}
/// Axis-aligned box that clamps each coordinate onto its closed per-axis interval
/// `[min[d], max[d]]`.
///
/// Stores one `(min, max)` pair per axis. The pairs are validated in [`ClampBox::new`].
#[derive(Debug, Clone)]
pub struct ClampBox {
    /// Lower bound of each axis.
    min: Vec<f64>,
    /// Upper bound of each axis.
    max: Vec<f64>,
}

impl ClampBox {
    /// Builds a clamping box from per-axis lower and upper bounds.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `min` and `max` differ in length,
    /// or [`BoundaryError::InvalidBounds`] for the first axis whose bounds are non-finite or
    /// not strictly increasing.
    pub fn new(min: &[f64], max: &[f64]) -> Result<Self, BoundaryError> {
        validate_bounds(min, max)?;
        let (min, max) = (min.to_owned(), max.to_owned());
        Ok(Self { min, max })
    }

    /// Number of axes of the box.
    #[inline]
    pub fn dim(&self) -> usize {
        self.min.len()
    }

    /// Per-axis lower bounds.
    #[inline]
    pub fn min(&self) -> &[f64] {
        self.min.as_slice()
    }

    /// Per-axis upper bounds.
    #[inline]
    pub fn max(&self) -> &[f64] {
        self.max.as_slice()
    }
}
impl ContinuousBoundary for ClampBox {
    #[inline]
    fn dim(&self) -> usize {
        self.dim()
    }

    /// Clamps each coordinate of `r` into `[min[d], max[d]]` using `f64::clamp`.
    ///
    /// Infinities are pulled onto the nearest bound. NaN coordinates stay NaN.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `r.len() != self.dim()`.
    fn apply_position(&self, r: &mut [f64]) -> Result<(), BoundaryError> {
        validate_vector_len("position", self.dim(), r.len())?;
        let bounds = self.min.iter().zip(self.max.iter());
        r.iter_mut()
            .zip(bounds)
            .for_each(|(coord, (&lo, &hi))| *coord = coord.clamp(lo, hi));
        Ok(())
    }
}
/// Axis-aligned box with reflecting walls: a coordinate that overshoots a face is mirrored
/// back inside.
///
/// Stores one `(min, max)` pair per axis. The pairs are validated in [`ReflectBox::new`].
#[derive(Debug, Clone)]
pub struct ReflectBox {
    /// Lower wall of each axis.
    min: Vec<f64>,
    /// Upper wall of each axis.
    max: Vec<f64>,
}

impl ReflectBox {
    /// Builds a reflecting box from per-axis lower and upper walls.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `min` and `max` differ in length,
    /// or [`BoundaryError::InvalidBounds`] for the first axis whose bounds are non-finite or
    /// not strictly increasing.
    pub fn new(min: &[f64], max: &[f64]) -> Result<Self, BoundaryError> {
        match validate_bounds(min, max) {
            Ok(()) => Ok(Self {
                min: Vec::from(min),
                max: Vec::from(max),
            }),
            Err(err) => Err(err),
        }
    }

    /// Number of axes of the box.
    #[inline]
    pub fn dim(&self) -> usize {
        self.min.len()
    }

    /// Per-axis lower walls.
    #[inline]
    pub fn min(&self) -> &[f64] {
        self.min.as_slice()
    }

    /// Per-axis upper walls.
    #[inline]
    pub fn max(&self) -> &[f64] {
        self.max.as_slice()
    }
}
impl ContinuousBoundary for ReflectBox {
    #[inline]
    fn dim(&self) -> usize {
        self.dim()
    }

    /// Mirrors every finite, out-of-range coordinate of `r` back into `[min[d], max[d]]`.
    ///
    /// Does not allocate, so it is cheap to call per particle from the parallel batch methods.
    /// Non-finite and in-range coordinates are left untouched.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `r.len() != self.dim()`.
    fn apply_position(&self, r: &mut [f64]) -> Result<(), BoundaryError> {
        validate_vector_len("position", self.dim(), r.len())?;
        for ((x, &lo), &hi) in r.iter_mut().zip(&self.min).zip(&self.max) {
            if let Some((reflected, _)) = reflect_into_interval(*x, lo, hi) {
                *x = reflected;
            }
        }
        Ok(())
    }

    /// Reflects `r` like [`apply_position`](Self::apply_position) and records the velocity
    /// flips.
    ///
    /// `flip_mask[d]` is set to `1` when coordinate `d` hit the walls an odd number of times,
    /// meaning its velocity component should be negated. Otherwise it is `0`.
    ///
    /// # Errors
    /// Returns [`BoundaryError::InvalidVectorDimension`] if `r` or `flip_mask` does not have
    /// length `self.dim()`.
    fn apply_position_with_velocity_flip_mask(
        &self,
        r: &mut [f64],
        flip_mask: &mut [u8],
    ) -> Result<(), BoundaryError> {
        validate_vector_len("position", self.dim(), r.len())?;
        validate_vector_len("velocity_flip_mask", self.dim(), flip_mask.len())?;
        flip_mask.fill(0);
        let bounds = self.min.iter().zip(&self.max);
        for ((x, flip), (&lo, &hi)) in r.iter_mut().zip(flip_mask.iter_mut()).zip(bounds) {
            if let Some((reflected, flipped)) = reflect_into_interval(*x, lo, hi) {
                *x = reflected;
                *flip = u8::from(flipped);
            }
        }
        Ok(())
    }
}

/// Reflects one coordinate `x` off the walls `lo < hi`.
///
/// Returns `None` when there is nothing to do (`x` is non-finite or already in `[lo, hi]`).
/// Otherwise returns the reflected coordinate and whether the walls were hit an odd number of
/// times.
fn reflect_into_interval(x: f64, lo: f64, hi: f64) -> Option<(f64, bool)> {
    if !x.is_finite() || (lo <= x && x <= hi) {
        return None;
    }
    let w = hi - lo;
    // Unfold the bouncing motion onto a loop of length 2w: [0, w] is the outbound leg from
    // `lo`, (w, 2w) is the return leg from `hi`.
    let y = (x - lo).rem_euclid(2.0 * w);
    let folded = if y <= w { lo + y } else { hi - (y - w) };
    // Every started width travelled past the violated wall is one more wall contact.
    let overshoot = if x < lo { lo - x } else { x - hi };
    let wall_hits = (overshoot / w).ceil();
    // `lo + y` / `hi - (y - w)` can round an ulp outside the box, so pin the result back in.
    Some((folded.clamp(lo, hi), wall_hits % 2.0 == 1.0))
}