use crate::engines::soa::phys_obj::{AttrsError, PhysObj};
use crate::models::particles::attrs::{ATTR_M_INV, ATTR_V, ParticleSelection};
use crate::models::particles::state::{
ParticleMasks, ParticleStateError, gather_alive_flags, gather_inverse_mass,
};
use rayon::prelude::*;
/// Failure raised while gathering particle state or evaluating an [`Observer`].
#[derive(Debug, Clone, PartialEq)]
pub enum ObserveError {
    /// Wraps an `AttrsError`, e.g. from looking up an attribute on the `PhysObj`.
    Attrs(AttrsError),
    /// A gathered value cannot be used: for instance a non-finite velocity
    /// component, or an inverse mass that is non-finite or not strictly positive.
    /// `field` names the offending attribute and `value` is what was found.
    InvalidState {
        field: &'static str,
        value: f64,
    },
    /// Carried over from `ParticleStateError::InvalidAttrShape`: the attribute
    /// named by `label` reported dimension `got_dim` instead of `expected_dim`.
    InvalidAttrShape {
        label: &'static str,
        expected_dim: usize,
        got_dim: usize,
    },
    /// Carried over from `ParticleStateError::InconsistentParticleCount`: the
    /// attribute named by `label` held `got` entries where `expected` were needed.
    InconsistentParticleCount {
        label: &'static str,
        expected: usize,
        got: usize,
    },
}
impl From<AttrsError> for ObserveError {
fn from(value: AttrsError) -> Self {
Self::Attrs(value)
}
}
impl From<ParticleStateError> for ObserveError {
fn from(value: ParticleStateError) -> Self {
match value {
ParticleStateError::Attrs(err) => Self::Attrs(err),
ParticleStateError::InvalidAttrShape {
label,
expected_dim,
got_dim,
} => Self::InvalidAttrShape {
label,
expected_dim,
got_dim,
},
ParticleStateError::InconsistentParticleCount {
label,
expected,
got,
} => Self::InconsistentParticleCount {
label,
expected,
got,
},
}
}
}
/// Owned snapshot of the particle state used by the kinetic observers.
#[derive(Debug, Clone)]
struct KineticContext {
    /// Components per velocity vector (`dim()` of the `ATTR_V` attribute).
    dim: usize,
    /// Number of velocity vectors (`num_vectors()` of the `ATTR_V` attribute).
    n: usize,
    /// Copy of the flat velocity tensor data; particle `i` occupies
    /// `v_data[i * dim..(i + 1) * dim]`.
    v_data: Vec<f64>,
    /// Per-particle inverse masses as returned by `gather_inverse_mass`.
    m_inv_values: Vec<f64>,
    /// Inclusion masks; `gather_kinetic_context` sets `rigid` to `None`.
    masks: ParticleMasks,
    /// Which particles the observer counts.
    selection: ParticleSelection,
}
/// Snapshots the state the kinetic observers need from `objects`.
///
/// Copies the `ATTR_V` tensor data. It then gathers the per-particle inverse
/// masses and alive flags, sized to the number of velocity vectors.
///
/// # Errors
/// Propagates, converted into [`ObserveError`], the first failure among the
/// `ATTR_V` lookup, `gather_inverse_mass`, and `gather_alive_flags`, checked
/// in that order.
fn gather_kinetic_context(
    objects: &PhysObj,
    selection: ParticleSelection,
) -> Result<KineticContext, ObserveError> {
    // The velocity handle only lives inside this scope, so it is dropped
    // before the remaining gathers run.
    let (dim, n, v_data) = {
        let velocities = objects.core.get::<f64>(ATTR_V)?;
        let dim = velocities.dim();
        let count = velocities.num_vectors();
        (dim, count, velocities.as_tensor().data.clone())
    };
    let m_inv_values = gather_inverse_mass(objects, n)?;
    let alive = gather_alive_flags(objects, n, selection)?;
    let masks = ParticleMasks { alive, rigid: None };
    Ok(KineticContext {
        dim,
        n,
        v_data,
        m_inv_values,
        masks,
        selection,
    })
}
/// Number of consecutive particles summed sequentially by one parallel task.
///
/// Chunk sums are combined in index order, so the total does not depend on how
/// rayon schedules or splits the work.
const KINETIC_CHUNK_LEN: usize = 1024;

/// Kinetic energy `0.5 * |v|^2 / m_inv` of particle `i`.
///
/// Returns `Ok(0.0)` for particles excluded by `ctx.masks` under `ctx.selection`.
///
/// # Errors
/// - [`ObserveError::InvalidState`] if the particle's inverse mass is non-finite
///   or `<= 0.0`, or if any of its velocity components is non-finite.
/// - [`ObserveError::InconsistentParticleCount`] if `ctx.m_inv_values` or
///   `ctx.v_data` is too short to contain particle `i`. This is reported as an
///   error instead of panicking on an out-of-bounds index.
fn particle_kinetic_energy(ctx: &KineticContext, i: usize) -> Result<f64, ObserveError> {
    if !ctx.masks.is_included(ctx.selection, i) {
        return Ok(0.0);
    }
    let m_inv_i = *ctx
        .m_inv_values
        .get(i)
        .ok_or_else(|| ObserveError::InconsistentParticleCount {
            label: ATTR_M_INV,
            expected: ctx.n,
            got: ctx.m_inv_values.len(),
        })?;
    if !m_inv_i.is_finite() || m_inv_i <= 0.0 {
        return Err(ObserveError::InvalidState {
            field: ATTR_M_INV,
            value: m_inv_i,
        });
    }
    let row = ctx
        .v_data
        .get(i * ctx.dim..(i + 1) * ctx.dim)
        .ok_or_else(|| ObserveError::InconsistentParticleCount {
            label: ATTR_V,
            expected: ctx.n * ctx.dim,
            got: ctx.v_data.len(),
        })?;
    let mut v2 = 0.0;
    for &component in row {
        if !component.is_finite() {
            return Err(ObserveError::InvalidState {
                field: ATTR_V,
                value: component,
            });
        }
        v2 += component * component;
    }
    Ok(0.5 * v2 / m_inv_i)
}

/// Total kinetic energy `sum_i 0.5 * |v_i|^2 / m_inv_i` over the particles
/// included by `ctx.masks` under `ctx.selection`.
///
/// The work is split into fixed chunks of `KINETIC_CHUNK_LEN` particle indices.
/// Each chunk is summed sequentially, with different chunks possibly running in
/// parallel. The chunk results are then folded in index order.
///
/// A plain parallel `reduce` combines floating-point partial sums in an order
/// that depends on thread scheduling. This scheme avoids that, so both the
/// result and the error reported when several particles are invalid are
/// reproducible from run to run.
///
/// # Errors
/// The first error in particle-index order, as produced by
/// `particle_kinetic_energy` (invalid inverse mass or velocity, or
/// undersized buffers).
fn kinetic_energy_from_context(ctx: &KineticContext) -> Result<f64, ObserveError> {
    // Rayon's `collect` into a `Vec` keeps the chunks in index order.
    let chunk_sums: Vec<Result<f64, ObserveError>> = (0..ctx.n.div_ceil(KINETIC_CHUNK_LEN))
        .into_par_iter()
        .map(|chunk| {
            let start = chunk * KINETIC_CHUNK_LEN;
            let end = (start + KINETIC_CHUNK_LEN).min(ctx.n);
            (start..end).try_fold(0.0, |acc, i| {
                particle_kinetic_energy(ctx, i).map(|energy| acc + energy)
            })
        })
        .collect();
    chunk_sums
        .into_iter()
        .try_fold(0.0, |acc, partial| partial.map(|sum| acc + sum))
}
/// A measurement computed from the current state of a [`PhysObj`].
///
/// Implementors receive both themselves and the objects by shared reference.
pub trait Observer {
    /// The value produced by one observation.
    type Output;

    /// Evaluates the observation on `objects`.
    ///
    /// # Errors
    /// Returns an [`ObserveError`] when the required state cannot be gathered
    /// or is invalid.
    fn observe(&self, objects: &PhysObj) -> Result<Self::Output, ObserveError>;
}
/// Observer reporting the total kinetic energy `sum 0.5 * |v|^2 / m_inv` of
/// the particles picked out by `selection`.
#[derive(Debug, Clone, Copy)]
pub struct KineticEnergyObserver {
    /// Which particles contribute to the sum.
    pub selection: ParticleSelection,
}

impl KineticEnergyObserver {
    /// Builds an observer over the given particle selection.
    pub fn new(selection: ParticleSelection) -> Self {
        KineticEnergyObserver { selection }
    }
}

impl Default for KineticEnergyObserver {
    /// Observes `ParticleSelection::AliveOnly`.
    fn default() -> Self {
        Self::new(ParticleSelection::AliveOnly)
    }
}

impl Observer for KineticEnergyObserver {
    type Output = f64;

    /// Snapshots velocities, inverse masses, and masks from `objects`, then
    /// sums the kinetic energy of the selected particles.
    fn observe(&self, objects: &PhysObj) -> Result<Self::Output, ObserveError> {
        gather_kinetic_context(objects, self.selection)
            .and_then(|context| kinetic_energy_from_context(&context))
    }
}
/// Observer reporting a kinetic temperature `2 * KE / (count * dim)`.
///
/// Here `count` is the number of included particles and `dim` is the velocity
/// dimension. This is the equipartition estimate with one degree of freedom
/// per velocity component; no Boltzmann constant is applied.
#[derive(Debug, Clone, Copy)]
pub struct TemperatureObserver {
    /// Which particles contribute to the kinetic energy and the particle count.
    pub selection: ParticleSelection,
}

impl TemperatureObserver {
    /// Builds an observer over the given particle selection.
    pub fn new(selection: ParticleSelection) -> Self {
        TemperatureObserver { selection }
    }
}

impl Default for TemperatureObserver {
    /// Observes `ParticleSelection::AliveOnly`.
    fn default() -> Self {
        Self::new(ParticleSelection::AliveOnly)
    }
}

impl Observer for TemperatureObserver {
    type Output = f64;

    /// Computes the kinetic temperature of the selected particles.
    ///
    /// Returns `0.0` when no particle is included or the velocity dimension is
    /// zero. The kinetic energy is evaluated first, so invalid state of
    /// included particles is still reported in those cases.
    fn observe(&self, objects: &PhysObj) -> Result<Self::Output, ObserveError> {
        let context = gather_kinetic_context(objects, self.selection)?;
        let kinetic = kinetic_energy_from_context(&context)?;
        let included = context.masks.included_count(context.selection, context.n);
        if included == 0 || context.dim == 0 {
            return Ok(0.0);
        }
        let degrees_of_freedom = (included * context.dim) as f64;
        Ok(2.0 * kinetic / degrees_of_freedom)
    }
}