use rayon::prelude::*;
use crate::engines::soa::phys_obj::{AttrsError, PhysObj};
use crate::models::particles::attrs::{ATTR_A, ATTR_R, ATTR_V, ParticleSelection};
use crate::models::particles::state::{
ParticleMasks, ParticleStateError, gather_masks, validate_vector_attr_f64,
};
/// Errors reported by the integrators in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum IntegratorError {
    /// The time step was NaN, infinite, zero, or negative.
    InvalidDt {
        /// The rejected time step.
        dt: f64,
    },
    /// An attribute lookup/access failed.
    Attrs(AttrsError),
    /// Converted from `ParticleStateError::InvalidAttrShape`: the labelled
    /// attribute's vector dimension did not match the expected one.
    InvalidAttrShape {
        label: &'static str,
        expected_dim: usize,
        got_dim: usize,
    },
    /// Converted from `ParticleStateError::InconsistentParticleCount`: the
    /// labelled attribute's particle count did not match the expected one.
    InconsistentParticleCount {
        label: &'static str,
        expected: usize,
        got: usize,
    },
}

/// Human-readable description of an integration failure.
impl std::fmt::Display for IntegratorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidDt { dt } => {
                write!(f, "invalid time step {dt}: must be finite and > 0")
            }
            // `AttrsError` is only known to implement `Debug` here, so that is
            // what is used to render it.
            Self::Attrs(err) => write!(f, "attribute error: {err:?}"),
            Self::InvalidAttrShape {
                label,
                expected_dim,
                got_dim,
            } => write!(
                f,
                "attribute `{label}` has dimension {got_dim}, expected {expected_dim}"
            ),
            Self::InconsistentParticleCount {
                label,
                expected,
                got,
            } => write!(
                f,
                "attribute `{label}` has {got} particles, expected {expected}"
            ),
        }
    }
}

// `source()` is left at its default: `AttrsError` is not known to implement
// `std::error::Error`, so it cannot be returned as a `&dyn Error` here.
impl std::error::Error for IntegratorError {}
impl From<AttrsError> for IntegratorError {
fn from(value: AttrsError) -> Self {
Self::Attrs(value)
}
}
impl From<ParticleStateError> for IntegratorError {
fn from(value: ParticleStateError) -> Self {
match value {
ParticleStateError::Attrs(err) => Self::Attrs(err),
ParticleStateError::InvalidAttrShape {
label,
expected_dim,
got_dim,
} => Self::InvalidAttrShape {
label,
expected_dim,
got_dim,
},
ParticleStateError::InconsistentParticleCount {
label,
expected,
got,
} => Self::InconsistentParticleCount {
label,
expected,
got,
},
}
}
}
/// A time-stepping scheme that advances the particle state held in a
/// [`PhysObj`] by one step.
pub trait Integrator {
/// Advances `objects` by a single step of size `dt`.
///
/// # Errors
///
/// Returns an [`IntegratorError`] when the step cannot be performed. The
/// implementations in this module reject a non-finite or non-positive `dt`
/// and attribute shape/count mismatches, and propagate attribute-access errors.
fn apply(&mut self, objects: &mut PhysObj, dt: f64) -> Result<(), IntegratorError>;
}
/// Stateless forward (explicit) Euler integrator.
///
/// Positions advance with the velocity from the start of the step, and
/// velocities advance with the acceleration.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExplicitEuler;
/// Stateless semi-implicit Euler integrator.
///
/// Velocities advance with the acceleration first, then positions advance
/// with the freshly updated velocity.
#[derive(Clone, Copy, Debug, Default)]
pub struct SemiImplicitEuler;
/// Per-step data shared by the integration kernels, built by
/// `validate_core_shapes` once all attribute checks have passed.
#[derive(Clone, Debug)]
struct IntegratorContext {
    /// Components per vector, taken from the velocity attribute.
    dim: usize,
    /// Skip masks gathered with `ParticleSelection::AliveOnly`.
    masks: ParticleMasks,
}
/// Accepts a time step only if it is finite and strictly positive.
///
/// NaN, ±infinity, zero (including `-0.0`) and negative values are rejected
/// with [`IntegratorError::InvalidDt`] carrying the offending value.
fn validate_dt(dt: f64) -> Result<(), IntegratorError> {
    if dt.is_finite() && dt > 0.0 {
        Ok(())
    } else {
        Err(IntegratorError::InvalidDt { dt })
    }
}
/// Validates the core kinematic attributes and builds the per-step context.
///
/// The velocity attribute (`ATTR_V`) is the reference. Its vector dimension
/// and vector count are what the acceleration (`ATTR_A`) and then the position
/// (`ATTR_R`) attributes are checked against. Skip masks are gathered for
/// alive particles only.
///
/// # Errors
///
/// Propagates the first failure from the velocity lookup, either shape check,
/// or mask gathering, in that order.
fn validate_core_shapes(objects: &PhysObj) -> Result<IntegratorContext, IntegratorError> {
    // The value returned by `get` is consumed by `map` and dropped at the end
    // of this statement, before any other attribute accessor runs.
    let (dim, count) = objects
        .core
        .get::<f64>(ATTR_V)
        .map(|vel| (vel.dim(), vel.num_vectors()))?;
    validate_vector_attr_f64(objects, ATTR_A, dim, count)?;
    validate_vector_attr_f64(objects, ATTR_R, dim, count)?;
    let masks = gather_masks(objects, count, ParticleSelection::AliveOnly)?;
    Ok(IntegratorContext { dim, masks })
}
/// Returns `true` when particle `index` is excluded by the context's masks,
/// in which case the integration kernels leave its row untouched.
fn should_skip(ctx: &IntegratorContext, index: usize) -> bool {
    let masks = &ctx.masks;
    masks.should_skip(index)
}
/// Advances every non-skipped particle by one explicit (forward) Euler step.
///
/// For each particle `i` not excluded by the alive-only masks, and each
/// component `k`:
///
/// ```text
/// r[i][k] += v[i][k] * dt   // velocity from the start of the step
/// v[i][k] += a[i][k] * dt
/// ```
///
/// Both updates run in a single parallel sweep over the rows. Each row is
/// independent of the others, and its position is advanced before its
/// velocity. The result therefore matches a full position pass followed by a
/// full velocity pass, but the data is traversed only once.
///
/// # Errors
///
/// Returns [`IntegratorError::InvalidDt`] for a non-finite or non-positive
/// `dt`. Also propagates any error from shape validation, mask gathering, or
/// attribute access. All of these checks run before any attribute is modified,
/// so a failure never leaves a half-applied step.
fn apply_explicit_euler(objects: &mut PhysObj, dt: f64) -> Result<(), IntegratorError> {
    validate_dt(dt)?;
    let ctx = validate_core_shapes(objects)?;
    let dim = ctx.dim;
    // Zero-component vectors have nothing to integrate, and rayon's
    // `par_chunks*` panic on a chunk size of 0.
    if dim == 0 {
        return Ok(());
    }
    let (a, v, r) = objects.core.get_three_mut::<f64>(ATTR_A, ATTR_V, ATTR_R)?;
    let a_data = &a.as_tensor().data;
    let v_data = &mut v.as_tensor_mut().data;
    r.as_tensor_mut()
        .data
        .par_chunks_mut(dim)
        .zip(v_data.par_chunks_mut(dim))
        .zip(a_data.par_chunks(dim))
        .enumerate()
        .for_each(|(i, ((r_row, v_row), a_row))| {
            if should_skip(&ctx, i) {
                return;
            }
            for ((r_k, v_k), a_k) in r_row.iter_mut().zip(v_row.iter_mut()).zip(a_row) {
                // Position first, so it sees the pre-step velocity.
                *r_k += *v_k * dt;
                *v_k += *a_k * dt;
            }
        });
    Ok(())
}
/// Advances every non-skipped particle by one semi-implicit Euler step.
///
/// For each particle `i` not excluded by the alive-only masks, and each
/// component `k`:
///
/// ```text
/// v[i][k] += a[i][k] * dt
/// r[i][k] += v[i][k] * dt   // velocity already updated this step
/// ```
///
/// Both updates run in a single parallel sweep over the rows. Each row is
/// independent of the others, and its velocity is updated before its
/// position. The result therefore matches a full velocity pass followed by a
/// full position pass.
///
/// # Errors
///
/// Returns [`IntegratorError::InvalidDt`] for a non-finite or non-positive
/// `dt`. Also propagates any error from shape validation, mask gathering, or
/// attribute access. All of these checks run before any attribute is modified.
fn apply_semi_implicit_euler(objects: &mut PhysObj, dt: f64) -> Result<(), IntegratorError> {
    validate_dt(dt)?;
    let ctx = validate_core_shapes(objects)?;
    let dim = ctx.dim;
    // Zero-component vectors have nothing to integrate, and rayon's
    // `par_chunks*` panic on a chunk size of 0.
    if dim == 0 {
        return Ok(());
    }
    let (a, v, r) = objects.core.get_three_mut::<f64>(ATTR_A, ATTR_V, ATTR_R)?;
    let a_data = &a.as_tensor().data;
    let v_data = &mut v.as_tensor_mut().data;
    r.as_tensor_mut()
        .data
        .par_chunks_mut(dim)
        .zip(v_data.par_chunks_mut(dim))
        .zip(a_data.par_chunks(dim))
        .enumerate()
        .for_each(|(i, ((r_row, v_row), a_row))| {
            if should_skip(&ctx, i) {
                return;
            }
            for ((r_k, v_k), a_k) in r_row.iter_mut().zip(v_row.iter_mut()).zip(a_row) {
                // Velocity first, so the position uses the updated velocity.
                *v_k += *a_k * dt;
                *r_k += *v_k * dt;
            }
        });
    Ok(())
}
impl Integrator for ExplicitEuler {
    /// Runs one explicit Euler step; the unit struct carries no state, so
    /// this simply forwards to the free kernel.
    fn apply(&mut self, phys: &mut PhysObj, dt: f64) -> Result<(), IntegratorError> {
        apply_explicit_euler(phys, dt)
    }
}
impl Integrator for SemiImplicitEuler {
    /// Runs one semi-implicit Euler step; the unit struct carries no state,
    /// so this simply forwards to the free kernel.
    fn apply(&mut self, phys: &mut PhysObj, dt: f64) -> Result<(), IntegratorError> {
        apply_semi_implicit_euler(phys, dt)
    }
}