#![expect(
clippy::cast_possible_truncation,
reason = "the necessary conversions are necessary and have been checked"
)]
#![expect(
clippy::cast_sign_loss,
reason = "the necessary conversions are necessary and have been checked"
)]
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::{array, cmp::Eq, fmt, hash::Hash, marker::PhantomData};
use hoomd_utility::valid::PositiveReal;
use hoomd_vector::Cartesian;
use super::{PointUpdate, PointsNearBall, WithSearchRadius, vec_cell};
/// Integer coordinates of one cell of the grid used by [`HashCell`].
///
/// Along each axis the coordinate is the floor of the position divided by the
/// cell width (see `HashCell::cell_index_from_position`). Values of this type
/// are the keys of `HashCell::particle_indices` and the values of
/// `HashCell::cell_index`.
#[serde_as]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
// `serde_as(as = "[_; D]")` lets serde (de)serialize the const-generic array.
pub(crate) struct CellIndex<const D: usize>(#[serde_as(as = "[_; D]")] pub [i64; D]);
/// Spatial search structure that bins keyed points into a grid of equally
/// sized cells stored sparsely in hash maps.
///
/// Only cells that have been inserted into are stored, so positions are not
/// restricted to a bounded box. Each key lives in exactly one cell. Neighbor
/// queries (`points_near_ball`) yield every key in the cells covered by a
/// precomputed stencil of cell offsets.
///
/// Construct one with [`HashCell::builder`], [`Default`], or
/// `WithSearchRadius::with_search_radius`.
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HashCell<K, const D: usize>
where
    K: Eq + Hash,
{
    /// Edge length of every cell (the nominal search radius given to the builder).
    cell_width: PositiveReal,
    /// Keys stored in each cell. A cell's list can be empty after removals or
    /// moves until `shrink_to_fit` (or `clear`) runs.
    particle_indices: FxHashMap<CellIndex<D>, Vec<K>>,
    /// Reverse map from each key to the cell that currently holds it.
    cell_index: FxHashMap<K, CellIndex<D>>,
    /// Precomputed lists of neighbor-cell offsets. `stencils[i]` is used for
    /// searches that reach `i + 1` cell widths (see `points_near_ball`).
    /// NOTE(review): the iterator visits the center cell first and never looks
    /// up offset 0, so the first offset of each stencil is assumed to be the
    /// origin; confirm in `vec_cell::generate_all_stencils`.
    #[serde_as(as = "Vec<Vec<[_; D]>>")]
    stencils: Vec<Vec<[i64; D]>>,
}
/// Configures and constructs a [`HashCell`].
///
/// Obtain one from [`HashCell::builder`], adjust it with
/// [`nominal_search_radius`](Self::nominal_search_radius) and
/// [`maximum_search_radius`](Self::maximum_search_radius), then call
/// [`build`](Self::build).
///
/// `Clone` and `Debug` are derived so a configured builder can be reused and
/// inspected, matching the derives on [`HashCell`] itself.
#[derive(Clone, Debug)]
pub struct HashCellBuilder<K, const D: usize> {
    /// Cell width of the built `HashCell`.
    nominal_search_radius: PositiveReal,
    /// Largest radius the built `HashCell` must accept in `points_near_ball`.
    maximum_search_radius: f64,
    /// Ties the builder to the key type of the `HashCell` it produces.
    phantom_key: PhantomData<K>,
}
impl<K, const D: usize> HashCellBuilder<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Set the nominal search radius, which becomes the cell width of the
    /// built `HashCell`.
    #[inline]
    #[must_use]
    pub fn nominal_search_radius(self, nominal_search_radius: PositiveReal) -> Self {
        Self {
            nominal_search_radius,
            ..self
        }
    }

    /// Set the largest radius that will be passed to `points_near_ball` on
    /// the built `HashCell`.
    #[inline]
    #[must_use]
    pub fn maximum_search_radius(self, maximum_search_radius: f64) -> Self {
        Self {
            maximum_search_radius,
            ..self
        }
    }

    /// Construct an empty `HashCell` from the configured radii.
    ///
    /// Requests `ceil(maximum / nominal)` stencils (at least one) from
    /// `vec_cell::generate_all_stencils`.
    #[inline]
    #[must_use]
    pub fn build(self) -> HashCell<K, D> {
        let widths_per_maximum = self.maximum_search_radius / self.nominal_search_radius.get();
        // Float-to-int `as` saturates and maps NaN to 0; the `max(1)` below
        // guarantees at least one stencil regardless.
        let stencil_radius = widths_per_maximum.ceil() as u32;
        let stencils = vec_cell::generate_all_stencils(stencil_radius.max(1));

        HashCell {
            cell_width: self.nominal_search_radius,
            particle_indices: FxHashMap::default(),
            cell_index: FxHashMap::default(),
            stencils,
        }
    }
}
impl<K, const D: usize> Default for HashCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// An empty `HashCell` built with the builder defaults: a nominal and a
    /// maximum search radius of 1.0.
    #[inline]
    fn default() -> Self {
        let builder = Self::builder();
        builder.build()
    }
}
impl<K, const D: usize> WithSearchRadius for HashCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Build an empty `HashCell` whose cell width (nominal search radius) is
    /// `radius`.
    ///
    /// Only the nominal radius is overridden; the maximum search radius keeps
    /// the builder default of 1.0. NOTE(review): for `radius < 1.0` this
    /// requests `ceil(1.0 / radius)` stencils, which may be costly for very
    /// small radii. Confirm that this is intended rather than a maximum equal
    /// to `radius`.
    #[inline]
    fn with_search_radius(radius: PositiveReal) -> Self {
        let configured = Self::builder().nominal_search_radius(radius);
        configured.build()
    }
}
impl<K, const D: usize> HashCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Compute the integer coordinates of the cell that contains `position`.
    ///
    /// Each coordinate is divided by the cell width and rounded toward
    /// negative infinity, so along every axis cell `i` covers
    /// `[i * width, (i + 1) * width)`.
    #[inline]
    fn cell_index_from_position(&self, position: &Cartesian<D>) -> [i64; D] {
        std::array::from_fn(|j| (position.coordinates[j] / self.cell_width.get()).floor() as i64)
    }

    /// Release memory that is no longer needed.
    ///
    /// Removing or moving keys leaves behind empty per-cell lists and spare
    /// capacity (`swap_remove` never shrinks a `Vec`). This drops every empty
    /// cell and trims the capacity of each remaining cell's key list. It then
    /// shrinks both hash maps.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.particle_indices.retain(|_, keys| {
            keys.shrink_to_fit();
            !keys.is_empty()
        });
        self.particle_indices.shrink_to_fit();
        self.cell_index.shrink_to_fit();
    }

    /// Start configuring a new `HashCell`.
    ///
    /// Both the nominal and the maximum search radius default to 1.0.
    #[expect(
        clippy::missing_panics_doc,
        reason = "hard-coded constant will never panic"
    )]
    #[inline]
    #[must_use]
    pub fn builder() -> HashCellBuilder<K, D> {
        HashCellBuilder {
            nominal_search_radius: 1.0
                .try_into()
                .expect("hard-coded constant is a positive real"),
            maximum_search_radius: 1.0,
            phantom_key: PhantomData,
        }
    }
}
impl<K, const D: usize> PointUpdate<Cartesian<D>, K> for HashCell<K, D>
where
K: Copy + Eq + Hash,
{
#[inline]
fn insert(&mut self, key: K, position: Cartesian<D>) {
let cell_idx = CellIndex(self.cell_index_from_position(&position));
let old_cell_index = self.cell_index.insert(key, cell_idx);
if old_cell_index != Some(cell_idx) {
self.particle_indices.entry(cell_idx).or_default().push(key);
if let Some(old_cell_index) = old_cell_index {
self.particle_indices
.entry(old_cell_index)
.and_modify(|particle_indices| {
if let Some(pos) = particle_indices.iter().position(|x| *x == key) {
particle_indices.swap_remove(pos);
}
});
}
}
}
#[inline]
fn remove(&mut self, key: &K) {
let cell_idx = self.cell_index.remove(key);
if let Some(cell_idx) = cell_idx {
self.particle_indices
.entry(cell_idx)
.and_modify(|particle_indices| {
if let Some(idx) = particle_indices.iter().position(|x| x == key) {
particle_indices.swap_remove(idx);
}
});
}
}
#[inline]
fn len(&self) -> usize {
self.cell_index.len()
}
#[inline]
fn is_empty(&self) -> bool {
self.cell_index.is_empty()
}
#[inline]
fn contains_key(&self, key: &K) -> bool {
self.cell_index.contains_key(key)
}
#[inline]
fn clear(&mut self) {
self.cell_index.clear();
self.particle_indices.clear();
}
}
/// Iterator returned by `HashCell::points_near_ball`.
///
/// It visits the cells `center + stencil[i]` in stencil order and yields every
/// key stored in each. The keys are candidates only; no distance filtering is
/// done here.
struct PointsIterator<'a, K, const D: usize>
where
    K: Eq + Hash,
{
    /// Keys of the cell currently being walked; `None` when that cell has no
    /// entry in `particle_indices`.
    keys: Option<&'a Vec<K>>,
    /// The cell list being searched.
    cell_list: &'a HashCell<K, D>,
    /// Position within `keys` of the next key to yield.
    index_in_current_cell: usize,
    /// Index into `stencil` of the cell currently being walked.
    current_stencil: usize,
    /// Cell offsets relative to `center`. The iterator starts on the center
    /// cell itself and advances before each lookup, so offset 0 is never
    /// looked up (NOTE(review): assumes `stencil[0]` is the zero offset).
    stencil: &'a [[i64; D]],
    /// Cell containing the query position.
    center: [i64; D],
}
impl<K, const D: usize> Iterator for PointsIterator<'_, K, D>
where
    K: Copy + Eq + Hash,
{
    type Item = K;

    /// Yield the next key of the current cell, advancing through the stencil
    /// cells as each one is exhausted.
    ///
    /// Once the last stencil cell has been exhausted this returns `None`, and
    /// keeps returning `None` on every later call. Previously the last cell's
    /// keys were yielded again after the first `None`.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(keys) = self.keys
                && let Some(&key) = keys.get(self.index_in_current_cell)
            {
                self.index_in_current_cell += 1;
                return Some(key);
            }

            // The current cell is exhausted (or absent from the map). Forget it
            // before looking for the next one, so that after the stencil runs
            // out no stale cell remains to be re-yielded.
            self.keys = None;
            self.index_in_current_cell = 0;

            let next_stencil = self.current_stencil + 1;
            let Some(offset) = self.stencil.get(next_stencil) else {
                // Leave `current_stencil` unchanged so repeated calls stay
                // here instead of counting upward forever.
                return None;
            };
            self.current_stencil = next_stencil;

            let neighbor: [i64; D] = array::from_fn(|i| self.center[i] + offset[i]);
            self.keys = self.cell_list.particle_indices.get(&CellIndex(neighbor));
        }
    }
}
impl<const D: usize, K> PointsNearBall<Cartesian<D>, K> for HashCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Iterate over every key stored in the cells that could hold a point
    /// within `radius` of `position`.
    ///
    /// The result is a superset of the true neighbors. It yields all keys in
    /// the cells covered by the stencil at index
    /// `ceil(radius / cell_width) - 1`, centered on the cell containing
    /// `position`. Callers must apply their own distance test. A zero,
    /// negative, or NaN `radius` uses the smallest stencil.
    ///
    /// # Panics
    ///
    /// Panics when `radius` needs a larger stencil than was generated at
    /// construction, i.e. when it exceeds the maximum search radius rounded up
    /// to whole cell widths.
    #[inline]
    fn points_near_ball(&self, position: &Cartesian<D>, radius: f64) -> impl Iterator<Item = K> {
        // Number of cell widths the search must reach. `as usize` saturates,
        // mapping zero, negative and NaN radii to 0.
        let reach = (radius / self.cell_width.get()).ceil() as usize;
        // `stencils[i]` serves reaches of `i + 1` cells, and a reach of 0 still
        // needs the first stencil. Saturating avoids the `0 - 1` underflow,
        // which panicked in debug builds and wrapped to a bogus index in
        // release builds.
        let stencil_index = reach.saturating_sub(1);
        assert!(
            stencil_index < self.stencils.len(),
            "search radius must be less than or equal to the maximum search radius"
        );
        let center = self.cell_index_from_position(position);
        let stencil = &self.stencils[stencil_index];
        PointsIterator {
            keys: self.particle_indices.get(&CellIndex(center)),
            cell_list: self,
            index_in_current_cell: 0,
            current_stencil: 0,
            stencil,
            center,
        }
    }
}
impl<K, const D: usize> fmt::Display for HashCell<K, D>
where
    K: Eq + Hash,
{
    /// Print a multi-line summary. It lists the number of stored cells and
    /// points, the nominal and maximum search radii, and the length of the
    /// fullest cell.
    #[allow(
        clippy::missing_inline_in_public_items,
        reason = "no need to inline display"
    )]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let n_cells = self.particle_indices.len();
        let n_points = self.cell_index.len();
        let nominal_radius = &self.cell_width;
        let maximum_radius = self.cell_width.get() * self.stencils.len() as f64;
        let largest_cell_size = self
            .particle_indices
            .values()
            .map(Vec::len)
            .max()
            .unwrap_or(0);

        writeln!(f, "HashCell<K, {D}>:")?;
        writeln!(f, "- {n_cells} total cells.")?;
        writeln!(f, "- {n_points} points.")?;
        writeln!(f, "- Nominal, maximum search radii: {nominal_radius}, {maximum_radius}")?;
        write!(f, "- Largest cell length: {largest_cell_size}")
    }
}
// Unit tests for `HashCell`: cell indexing, insert/move/remove bookkeeping,
// memory cleanup, and completeness of `points_near_ball` candidates.
#[expect(
    clippy::used_underscore_binding,
    reason = "Used for const parameterization."
)]
#[cfg(test)]
mod tests {
    use assert2::{assert, check};
    use rand::{
        RngExt, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };
    use rstest::*;

    use super::*;
    use hoomd_vector::{Metric, distribution::Ball};

    // With a cell width of 2.0, each coordinate maps to floor(x / 2.0),
    // including for negative coordinates.
    #[test]
    fn test_cell_index() {
        let cell_list = HashCell::<usize, 3>::builder()
            .nominal_search_radius(
                2.0.try_into()
                    .expect("hard-coded constant is a positive real"),
            )
            .build();

        check!(cell_list.cell_index_from_position(&[0.0, 0.0, 0.0].into()) == [0, 0, 0]);
        check!(cell_list.cell_index_from_position(&[2.0, 0.0, 0.0].into()) == [1, 0, 0]);
        check!(cell_list.cell_index_from_position(&[0.0, 2.0, 0.0].into()) == [0, 1, 0]);
        check!(cell_list.cell_index_from_position(&[0.0, 0.0, 2.0].into()) == [0, 0, 1]);
        check!(cell_list.cell_index_from_position(&[-41.5, 18.5, -0.125].into()) == [-21, 9, -1]);
    }

    // A single insert records the key in both maps (default cell width 1.0).
    #[test]
    fn test_insert_one() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));

        check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&0));
    }

    // Several keys, two of them sharing a cell and one in a negative cell.
    #[test]
    fn test_insert_many() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.995, 0.897]));
        cell_list.insert(2, Cartesian::from([-0.125, 3.25]));

        check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cell_list.cell_index.get(&1) == Some(&CellIndex([0, 0])));
        check!(cell_list.cell_index.get(&2) == Some(&CellIndex([-1, 3])));

        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 2);
        check!(keys.contains(&0));
        check!(keys.contains(&1));

        let keys = cell_list.particle_indices.get(&CellIndex([-1, 3]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&2));
    }

    // Re-inserting a key inside its current cell must not duplicate it.
    #[test]
    fn test_insert_again_same() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(0, Cartesian::from([0.25, 0.5]));
        cell_list.insert(0, Cartesian::from([0.5, 0.75]));

        check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&0));
    }

    // Re-inserting a key in another cell moves it out of its old cell.
    #[test]
    fn test_insert_again_different() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.25, 0.5]));
        cell_list.insert(1, Cartesian::from([-0.5, -0.75]));

        check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cell_list.cell_index.get(&1) == Some(&CellIndex([-1, -1])));

        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&0));

        let keys = cell_list.particle_indices.get(&CellIndex([-1, -1]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&1));
    }

    // Removal clears the key from both maps; the emptied cell [-1, 3] keeps
    // an empty list until `shrink_to_fit` runs.
    #[test]
    fn test_remove() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.995, 0.897]));
        cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
        cell_list.remove(&1);
        cell_list.remove(&2);

        check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cell_list.cell_index.get(&1) == None);
        check!(cell_list.cell_index.get(&2) == None);

        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&0));

        let keys = cell_list.particle_indices.get(&CellIndex([-1, 3]));
        assert2::assert!(let Some(keys) = keys);
        assert!(keys.len() == 0);
    }

    // `clear` empties both maps.
    #[test]
    fn test_clear() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.995, 0.897]));
        cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
        cell_list.clear();

        check!(cell_list.cell_index.len() == 0);
        check!(cell_list.particle_indices.len() == 0);
    }

    // `shrink_to_fit` drops cells whose key lists became empty.
    #[test]
    fn test_shrink_to_fit() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.995, 0.897]));
        cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
        cell_list.remove(&1);
        cell_list.remove(&2);
        cell_list.shrink_to_fit();

        check!(cell_list.particle_indices.len() == 1);
        let keys = cell_list.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(keys) = keys);
        check!(keys.len() == 1);
        check!(keys.contains(&0));
    }

    // Randomized stress test: a sequence of inserts (70%) and removes (30%)
    // over N_STEPS / 4 keys, compared against a reference key -> cell map.
    #[test]
    fn consistency() {
        const N_STEPS: usize = 65_536;

        let mut rng = StdRng::seed_from_u64(0);
        let mut reference = FxHashMap::default();

        let cell_width = 0.5;
        let mut cell_list = HashCell::builder()
            .nominal_search_radius(
                cell_width
                    .try_into()
                    .expect("hard-coded value should be positive"),
            )
            .build();

        let position_distribution = Ball {
            radius: 20.0.try_into().expect("hardcoded value should be positive"),
        };
        let key_distribution =
            Uniform::new(0, N_STEPS / 4).expect("hardcoded distribution should be valid");

        for _ in 0..N_STEPS {
            if rng.random_bool(0.7) {
                let position: Cartesian<3> = position_distribution.sample(&mut rng);
                let key = key_distribution.sample(&mut rng);
                cell_list.insert(key, position);
                reference.insert(key, cell_list.cell_index_from_position(&position));
            } else {
                let key = key_distribution.sample(&mut rng);
                cell_list.remove(&key);
                reference.remove(&key);
            }
        }

        // Every reference key must map to the same cell and appear in that
        // cell's key list.
        assert!(cell_list.cell_index.len() == reference.len());
        for (reference_key, reference_value) in reference.drain() {
            let value = cell_list.cell_index.get(&reference_key);
            assert!(value == Some(&CellIndex(reference_value)));

            let keys = cell_list.particle_indices.get(&CellIndex(reference_value));
            assert2::assert!(let Some(keys) = keys);
            check!(keys.contains(&reference_key));
        }

        // Total list entries equal the key count, so no stale duplicates
        // remain in cell lists. The lower bound guards against a degenerate,
        // nearly empty run.
        let total = cell_list.particle_indices.values().map(Vec::len).sum();
        check!(cell_list.cell_index.len() == total);
        check!(total > 2000);
    }

    // Query from cell 9 with radius 1.0: only key 2 (in adjacent cell 8) is a
    // candidate; keys 0 and 1 in cell 0 lie outside the stencil.
    #[test]
    fn test_outside() {
        let mut cell_list = HashCell::default();
        cell_list.insert(0, Cartesian::from([0.125, 0.25]));
        cell_list.insert(1, Cartesian::from([0.995, 0.897]));
        cell_list.insert(2, Cartesian::from([8.125, 0.0]));

        let potential_neighbors: Vec<_> = cell_list
            .points_near_ball(&[9.125, 0.0].into(), 1.0)
            .collect();
        assert!(potential_neighbors.len() == 1);
        check!(potential_neighbors[0] == 2);
    }

    // Completeness check in 2D and 3D with cell widths of 1, 1/2 and 1/4
    // (so stencils reaching 1, 2 and 4 cells): every point within the query
    // radius must be among the candidates.
    #[rstest]
    #[case::d_2(PhantomData::<HashCell<usize, 2>>)]
    #[case::d_3(PhantomData::<HashCell<usize, 3>>)]
    fn test_points_near_ball<const D: usize>(
        #[case] _d: PhantomData<HashCell<usize, D>>,
        #[values(1.0, 0.5, 0.25)] nominal_search_radius: f64,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let mut reference = Vec::new();

        // NOTE: despite its name, `cell_width` is the query radius here; the
        // actual cell width is `nominal_search_radius`.
        let cell_width = 1.0;
        let mut cell_list = HashCell::builder()
            .nominal_search_radius(
                nominal_search_radius
                    .try_into()
                    .expect("hardcoded value should be positive"),
            )
            .maximum_search_radius(1.0)
            .build();

        let position_distribution = Ball {
            radius: 12.0.try_into().expect("hardcoded value should be positive"),
        };

        let n = 2048;
        for key in 0..n {
            let position: Cartesian<D> = position_distribution.sample(&mut rng);
            cell_list.insert(key, position);
            reference.push(position);
        }

        // Brute-force all pairs (including each point with itself) and check
        // that every true neighbor was returned as a candidate.
        let mut n_neighbors = 0;
        for p_i in &reference {
            let potential_neighbors: Vec<_> = cell_list.points_near_ball(p_i, cell_width).collect();

            for (j, p_j) in reference.iter().enumerate() {
                if p_i.distance(p_j) <= cell_width {
                    check!(potential_neighbors.contains(&j));
                    n_neighbors += 1;
                }
            }
        }

        // Sanity check that the sample is dense enough for the test to be
        // meaningful (self-pairs alone contribute n).
        check!(n_neighbors >= n * 2);
    }
}