#![expect(
clippy::cast_possible_truncation,
reason = "the necessary conversions are necessary and have been checked"
)]
#![expect(
clippy::cast_sign_loss,
reason = "the necessary conversions are necessary and have been checked"
)]
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::{array, cmp::Eq, fmt, hash::Hash, iter, marker::PhantomData, mem};
use log::trace;
use rustc_hash::FxHashMap;
use hoomd_utility::valid::PositiveReal;
use hoomd_vector::Cartesian;
use super::{PointUpdate, PointsNearBall, WithSearchRadius};
use crate::{IndexFromPosition, hash_cell::CellIndex};
/// A spatial cell list that stores keys in a dense, origin-centered grid of cells.
///
/// Each key is placed in the cell that contains its position; the grid is
/// enlarged on demand when a key falls outside of it.
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VecCell<K, const D: usize>
where
K: Eq + Hash,
{
/// Side length of every cell; set from the builder's nominal search radius.
cell_width: PositiveReal,
/// Keys stored in each cell, flattened in row-major order (last axis varies
/// fastest). Holds `(2 * half_extent + 1)^D` cells.
keys_map: Vec<Vec<K>>,
/// The cell that currently holds each key.
cell_index: FxHashMap<K, CellIndex<D>>,
/// The grid covers cell indices `-half_extent..=half_extent` on every axis.
half_extent: u32,
/// `stencils[r - 1]` lists the cell offsets within `r` cells of the center on
/// every axis, beginning with the zero offset.
#[serde_as(as = "Vec<Vec<[_; D]>>")]
stencils: Vec<Vec<[i64; D]>>,
}
/// Iterator over every integer cell index in the hypercube
/// `[-half_extent, half_extent]^D`.
///
/// Indices are produced in row-major (lexicographic) order: the last axis varies
/// fastest, like an odometer.
struct CellIndexIterator<const D: usize> {
    /// The most recently produced index. Before the first call to `next`, this is
    /// one step before the first index of the cube.
    cell_index: [i64; D],
    /// Half the side length of the cube, in cells.
    half_extent: u32,
}

impl<const D: usize> CellIndexIterator<D> {
    /// Create an iterator over all cell indices in `[-half_extent, half_extent]^D`.
    fn cube(half_extent: u32) -> Self {
        let lower = -i64::from(half_extent);
        let mut cell_index = [lower; D];
        // Start one step before the first index so that the first increment yields it.
        cell_index[D - 1] = lower - 1;
        Self {
            cell_index,
            half_extent,
        }
    }

    /// Advance to the next cell index, carrying into more significant axes as needed.
    ///
    /// Returns `None` once the first axis has run past `half_extent`; further calls
    /// keep returning `None`.
    #[inline]
    fn increment_cell_index(&mut self) -> Option<[i64; D]> {
        let upper = i64::from(self.half_extent);
        self.cell_index[D - 1] += 1;
        // Propagate carries from the last axis toward the first.
        for axis in (1..D).rev() {
            if self.cell_index[axis] > upper {
                self.cell_index[axis] = -upper;
                self.cell_index[axis - 1] += 1;
            }
        }
        (self.cell_index[0] <= upper).then_some(self.cell_index)
    }
}

impl<const D: usize> Iterator for CellIndexIterator<D> {
    type Item = [i64; D];

    fn next(&mut self) -> Option<Self::Item> {
        self.increment_cell_index()
    }
}

/// Offsets of all cells within `radius` cells of the center along every axis,
/// ordered by increasing squared Euclidean distance from the center cell.
///
/// Offsets at equal distance keep the row-major order of `CellIndexIterator`, so
/// the zero offset is always first.
///
/// # Panics
///
/// Panics when `radius` is 0.
fn generate_stencil<const D: usize>(radius: u32) -> Vec<[i64; D]> {
    assert!(radius >= 1, "cell list must have a minimum radius of 1");
    let mut result = CellIndexIterator::cube(radius).collect::<Vec<_>>();
    // Sort by |offset|^2. Comparing the per-axis squares lexicographically (as an
    // iterator `cmp` would) does not order cells by distance: e.g. [0, 2] would
    // come before [1, 0]. `sort_by_key` is stable, so ties stay in row-major order.
    result.sort_by_key(|offset| offset.iter().map(|x| x * x).sum::<i64>());
    result
}

/// Build the stencils for every radius from 1 to `max_radius`; element `r - 1`
/// is the stencil of radius `r`.
///
/// # Panics
///
/// Panics when `max_radius` is 0.
pub(crate) fn generate_all_stencils<const D: usize>(max_radius: u32) -> Vec<Vec<[i64; D]>> {
    assert!(max_radius >= 1, "cell list must have a minimum radius of 1");
    (1..=max_radius).map(generate_stencil::<D>).collect()
}
/// Builder for `VecCell`, obtained from `VecCell::builder()`.
pub struct VecCellBuilder<K, const D: usize> {
    /// Width of every cell in the built `VecCell`.
    nominal_search_radius: PositiveReal,
    /// Largest radius that queries on the built `VecCell` may use.
    maximum_search_radius: f64,
    /// Ties the builder to the key type without storing a key.
    phantom_key: PhantomData<K>,
}

impl<K, const D: usize> VecCellBuilder<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Set the nominal search radius, which becomes the cell width.
    #[inline]
    #[must_use]
    pub fn nominal_search_radius(self, nominal_search_radius: PositiveReal) -> Self {
        Self {
            nominal_search_radius,
            ..self
        }
    }

    /// Set the largest search radius that queries may use.
    #[inline]
    #[must_use]
    pub fn maximum_search_radius(self, maximum_search_radius: f64) -> Self {
        Self {
            maximum_search_radius,
            ..self
        }
    }

    /// Construct the `VecCell`.
    ///
    /// The cell list starts with a `3^D` grid of empty cells around the origin and
    /// one stencil per whole number of cell widths up to
    /// `ceil(maximum_search_radius / nominal_search_radius)` (at least one).
    #[inline]
    #[must_use]
    pub fn build(self) -> VecCell<K, D> {
        let stencil_count =
            (self.maximum_search_radius / self.nominal_search_radius.get()).ceil() as u32;
        let half_extent: u32 = 1;
        let cell_count = (2 * half_extent + 1).pow(D as u32) as usize;

        VecCell {
            cell_width: self.nominal_search_radius,
            keys_map: vec![Vec::new(); cell_count],
            cell_index: FxHashMap::default(),
            half_extent,
            stencils: generate_all_stencils(stencil_count.max(1)),
        }
    }
}
impl<K, const D: usize> Default for VecCell<K, D>
where
K: Copy + Eq + Hash,
{
#[inline]
fn default() -> Self {
Self::builder().build()
}
}
impl<K, const D: usize> WithSearchRadius for VecCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Build a `VecCell` whose cell width and maximum search radius both equal
    /// `radius`, so a single stencil spanning one cell in each direction is used.
    #[inline]
    fn with_search_radius(radius: PositiveReal) -> Self {
        let maximum = radius.get();
        Self::builder()
            .maximum_search_radius(maximum)
            .nominal_search_radius(radius)
            .build()
    }
}
impl<K, const D: usize> VecCell<K, D>
where
    K: Eq + Hash,
{
    /// Integer index of the cell that contains `position`.
    ///
    /// Each coordinate is divided by the cell width and rounded toward negative
    /// infinity, so cell `i` along an axis covers `[i * width, (i + 1) * width)`.
    #[inline]
    fn cell_index_from_position(&self, position: &Cartesian<D>) -> [i64; D] {
        std::array::from_fn(|j| (position.coordinates[j] / self.cell_width.get()).floor() as i64)
    }

    /// Position in the flattened `keys_map` of the cell `cell_index`, for a grid
    /// with the given `half_extent`.
    ///
    /// The grid covers `-half_extent..=half_extent` on every axis and is laid out
    /// in row-major order (last axis varies fastest). Returns `None` when any
    /// component of `cell_index` lies outside that range.
    ///
    /// # Panics
    ///
    /// Panics when `D` is 0.
    #[inline]
    fn map_index_from_cell(half_extent: u32, cell_index: &[i64; D]) -> Option<usize> {
        // Every positive dimension, including D = 1, is handled by the loop below.
        assert!(D > 0, "VecCell requires at least one dimension");
        let half = i64::from(half_extent);
        // Cells along each axis; computed in u64 so `2 * half_extent + 1` cannot overflow.
        let side = (2 * u64::from(half_extent) + 1) as usize;

        let mut vec_index: usize = 0;
        let mut stride: usize = 1;
        for &component in cell_index.iter().rev() {
            if component.unsigned_abs() > u64::from(half_extent) {
                return None;
            }
            // `component` is in [-half, half], so the shifted value is in [0, 2 * half].
            let offset: usize = (component + half)
                .try_into()
                .expect("cell index should be in bounds");
            vec_index += offset * stride;
            stride *= side;
        }
        Some(vec_index)
    }

    /// Keys stored in the cell `cell_index` (test helper).
    ///
    /// # Panics
    ///
    /// Panics when `cell_index` lies outside the current grid.
    #[cfg(test)]
    #[inline]
    fn get_keys(&self, cell_index: &[i64; D]) -> &[K] {
        let index = Self::map_index_from_cell(self.half_extent, cell_index)
            .expect("cell_index should be in bounds");
        &self.keys_map[index]
    }
}
impl<K, const D: usize> VecCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Start building a `VecCell`.
    ///
    /// Both the nominal and the maximum search radius default to 1.0.
    #[expect(
        clippy::missing_panics_doc,
        reason = "hard-coded constant will never panic"
    )]
    #[inline]
    #[must_use]
    pub fn builder() -> VecCellBuilder<K, D> {
        VecCellBuilder {
            nominal_search_radius: 1.0
                .try_into()
                .expect("hard-coded constant is a positive real"),
            maximum_search_radius: 1.0,
            phantom_key: PhantomData,
        }
    }

    /// Release excess capacity held by every cell, the cell grid, and the
    /// key-to-cell map.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        for keys in &mut self.keys_map {
            keys.shrink_to_fit();
        }
        self.keys_map.shrink_to_fit();
        self.cell_index.shrink_to_fit();
    }

    /// Grow the grid so that it covers cell indices `-target..=target` on every axis.
    ///
    /// The half extent is doubled (starting from at least 1) until it reaches
    /// `target`. Keys in every existing cell are moved to the same cell index in
    /// the larger grid. Does nothing when the grid is already large enough.
    ///
    /// # Panics
    ///
    /// Panics if the required half extent exceeds `u32::MAX`, if the total number
    /// of cells does not fit in `usize`, or if `keys_map` is inconsistent with
    /// `half_extent`.
    fn expand_to(&mut self, target: u32) {
        if self.half_extent >= target {
            return;
        }

        // `max(1)` guarantees the doubling loop makes progress even if the current
        // half extent is 0 (e.g. a value produced by deserialization); `min(1)`
        // would start at 0 and loop forever.
        let mut new_half_extent = self.half_extent.max(1);
        while new_half_extent < target {
            new_half_extent = new_half_extent
                .checked_mul(2)
                .expect("half extent cannot exceed u32::MAX");
        }

        // Size the grid in 64-bit arithmetic so large extents cannot silently wrap.
        let cells_per_side = 2 * u64::from(new_half_extent) + 1;
        trace!("Expanding to {}^{} cells", cells_per_side, D);
        let total_cells = usize::try_from(cells_per_side)
            .ok()
            .and_then(|side| side.checked_pow(D as u32))
            .expect("total number of cells should fit in usize");

        let mut new_keys_map: Vec<Vec<K>> = iter::repeat_n(Vec::new(), total_cells).collect();
        let old_half_extent = self.half_extent;
        for old_cell_index in CellIndexIterator::cube(old_half_extent) {
            let old_vec_index = Self::map_index_from_cell(old_half_extent, &old_cell_index)
                .expect("cell_index should be consistent with keys_map");
            let new_vec_index = Self::map_index_from_cell(new_half_extent, &old_cell_index)
                .expect("old_cell_index should be inside the new keys_map");
            new_keys_map[new_vec_index] = mem::take(&mut self.keys_map[old_vec_index]);
        }
        self.half_extent = new_half_extent;
        self.keys_map = new_keys_map;
    }
}
impl<K, const D: usize> PointUpdate<Cartesian<D>, K> for VecCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Insert `key` at `position`, or move it there if it is already present.
    ///
    /// The grid is expanded first when the destination cell lies outside it. The
    /// key-to-cell map is updated only after the destination cell is known to
    /// exist, so a panic during expansion leaves the cell list consistent.
    ///
    /// # Panics
    ///
    /// Panics when the destination cell is more than `u32::MAX` cells from the
    /// origin along any axis.
    #[inline]
    fn insert(&mut self, key: K, position: Cartesian<D>) {
        let cell_index = self.cell_index_from_position(&position);

        // Locate (growing the grid if necessary) the destination cell before any
        // bookkeeping changes.
        let map_index =
            Self::map_index_from_cell(self.half_extent, &cell_index).unwrap_or_else(|| {
                let required_half_extent = cell_index
                    .iter()
                    .map(|x| x.unsigned_abs())
                    .max()
                    .expect("D should be at least 1");
                self.expand_to(
                    required_half_extent
                        .try_into()
                        .expect("max extent cannot exceed u32::MAX"),
                );
                Self::map_index_from_cell(self.half_extent, &cell_index)
                    .expect("cell_index should be in the expanded VecCell")
            });

        let old_cell_index = self.cell_index.insert(key, CellIndex(cell_index));
        if old_cell_index == Some(CellIndex(cell_index)) {
            // The key stays in the same cell; `keys_map` is already correct.
            return;
        }

        self.keys_map[map_index].push(key);
        if let Some(old_cell_index) = old_cell_index {
            // The key moved: drop it from its previous cell.
            let old_map_index = Self::map_index_from_cell(self.half_extent, &old_cell_index.0)
                .expect("cell_index and keys_map should agree");
            let old_keys = &mut self.keys_map[old_map_index];
            if let Some(pos) = old_keys.iter().position(|x| *x == key) {
                old_keys.swap_remove(pos);
            }
        }
    }

    /// Remove `key` from the cell list; does nothing when the key is absent.
    #[inline]
    fn remove(&mut self, key: &K) {
        let cell_index = self.cell_index.remove(key);
        if let Some(cell_index) = cell_index {
            let map_index = Self::map_index_from_cell(self.half_extent, &cell_index.0);
            if let Some(map_index) = map_index {
                let keys = &mut self.keys_map[map_index];
                if let Some(idx) = keys.iter().position(|x| x == key) {
                    keys.swap_remove(idx);
                }
            }
        }
    }

    /// Number of keys stored.
    #[inline]
    fn len(&self) -> usize {
        self.cell_index.len()
    }

    /// `true` when no keys are stored.
    #[inline]
    fn is_empty(&self) -> bool {
        self.cell_index.is_empty()
    }

    /// `true` when `key` is stored.
    #[inline]
    fn contains_key(&self, key: &K) -> bool {
        self.cell_index.contains_key(key)
    }

    /// Remove all keys. The grid keeps its size and each cell keeps its capacity.
    #[inline]
    fn clear(&mut self) {
        self.cell_index.clear();
        for keys in &mut self.keys_map {
            keys.clear();
        }
    }
}
/// Iterator behind `points_near_ball`. It yields every key stored in the cells
/// `center + offset`, for each `offset` in `stencil`, in stencil order.
struct PointsIterator<'a, K, const D: usize>
where
    K: Eq + Hash,
{
    /// Keys of the cell currently being visited, or `None` when that cell lies
    /// outside the grid.
    keys: Option<&'a Vec<K>>,
    /// The cell list being searched.
    cell_list: &'a VecCell<K, D>,
    /// Position of the next key to yield within `keys`.
    index_in_current_cell: usize,
    /// Index into `stencil` of the cell currently being visited.
    current_stencil: usize,
    /// Cell offsets to visit, relative to `center`.
    stencil: &'a [[i64; D]],
    /// Cell index at the center of the search.
    center: [i64; D],
}

impl<K, const D: usize> Iterator for PointsIterator<'_, K, D>
where
    K: Copy + Eq + Hash,
{
    type Item = K;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Yield the next key of the current cell while any remain.
            let position = self.index_in_current_cell;
            if let Some(&key) = self.keys.and_then(|cell| cell.get(position)) {
                self.index_in_current_cell = position + 1;
                return Some(key);
            }

            // The current cell is exhausted (or outside the grid): move to the
            // next stencil offset, stopping once the stencil is used up.
            self.index_in_current_cell = 0;
            self.current_stencil += 1;
            let offset = *self.stencil.get(self.current_stencil)?;
            let neighbor: [i64; D] = array::from_fn(|axis| self.center[axis] + offset[axis]);
            self.keys = VecCell::<K, D>::map_index_from_cell(self.cell_list.half_extent, &neighbor)
                .map(|index| &self.cell_list.keys_map[index]);
        }
    }
}
impl<const D: usize, K> PointsNearBall<Cartesian<D>, K> for VecCell<K, D>
where
    K: Copy + Eq + Hash,
{
    /// Iterate over the keys in every cell that may contain points within
    /// `radius` of `position`.
    ///
    /// The search visits the precomputed stencil spanning
    /// `ceil(radius / cell_width)` cells in each direction around the cell that
    /// contains `position`. Whole cells are returned, so the result may include
    /// keys farther away than `radius`. A zero, negative, or NaN radius uses the
    /// smallest (one-cell) stencil.
    ///
    /// # Panics
    ///
    /// Panics when `radius` needs a larger stencil than the cell list was built
    /// for, i.e. when it exceeds the maximum search radius.
    #[inline]
    fn points_near_ball(&self, position: &Cartesian<D>, radius: f64) -> impl Iterator<Item = K> {
        // `as usize` saturates (NaN and negative values become 0); `saturating_sub`
        // maps those to the first stencil instead of underflowing.
        let stencil_index = ((radius / self.cell_width.get()).ceil() as usize).saturating_sub(1);
        assert!(
            stencil_index < self.stencils.len(),
            "search radius must be less than or equal to the maximum search radius"
        );
        let center = self.cell_index_from_position(position);
        let stencil = &self.stencils[stencil_index];
        // Resolve the first stencil cell eagerly; the iterator advances through the rest.
        let map_index = Self::map_index_from_cell(
            self.half_extent,
            &array::from_fn(|i| center[i] + stencil[0][i]),
        );
        PointsIterator {
            keys: map_index.map(|index| &self.keys_map[index]),
            cell_list: self,
            index_in_current_cell: 0,
            current_stencil: 0,
            stencil,
            center,
        }
    }
}
impl<K, const D: usize> fmt::Display for VecCell<K, D>
where
    K: Eq + Hash,
{
    /// Summarize the grid size, the number of points, the nominal and maximum
    /// search radii, and the occupancy of the fullest cell.
    #[allow(
        clippy::missing_inline_in_public_items,
        reason = "no need to inline display"
    )]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let largest_cell_size = self
            .keys_map
            .iter()
            .map(Vec::len)
            .max()
            .unwrap_or(0);

        writeln!(f, "VecCell<K, {D}>:")?;
        writeln!(
            f,
            "- {} total cells with {} cells on each side.",
            self.keys_map.len(),
            2 * self.half_extent + 1
        )?;
        writeln!(f, "- {} points.", self.cell_index.len())?;
        // The maximum radius is one cell width per generated stencil.
        writeln!(
            f,
            "- Nominal, maximum search radii: {}, {}",
            self.cell_width,
            self.stencils.len() as f64 * self.cell_width.get()
        )?;
        write!(f, "- Largest cell length: {largest_cell_size}")
    }
}
impl<K, const D: usize> IndexFromPosition<Cartesian<D>> for VecCell<K, D>
where
    K: Eq + Hash,
{
    /// A location is the integer index of the cell that contains a position.
    type Location = [i64; D];

    #[inline]
    fn location_from_position(&self, position: &Cartesian<D>) -> Self::Location {
        Self::cell_index_from_position(self, position)
    }
}
// Unit tests for `VecCell`: cell indexing, insertion/removal bookkeeping, grid
// expansion, and neighbor queries.
#[expect(
clippy::used_underscore_binding,
reason = "Used for const parameterization."
)]
#[cfg(test)]
mod tests {
use assert2::{assert, check};
use rand::{
RngExt, SeedableRng,
distr::{Distribution, Uniform},
rngs::StdRng,
};
use rstest::*;
use super::*;
use hoomd_vector::{Metric, distribution::Ball};
// `CellIndexIterator` visits the cube in row-major order, keeps returning `None`
// once exhausted, and carries correctly from arbitrary starting states.
#[test]
fn test_increment_cell_index() {
let mut cells = CellIndexIterator::cube(1);
check!(cells.next() == Some([-1, -1]));
check!(cells.next() == Some([-1, 0]));
check!(cells.next() == Some([-1, 1]));
check!(cells.next() == Some([0, -1]));
check!(cells.next() == Some([0, 0]));
check!(cells.next() == Some([0, 1]));
check!(cells.next() == Some([1, -1]));
check!(cells.next() == Some([1, 0]));
check!(cells.next() == Some([1, 1]));
check!(cells.next() == None);
check!(cells.next() == None);
check!(cells.next() == None);
check!(cells.next() == None);
let mut c = CellIndexIterator {
cell_index: [1, 2, 2],
half_extent: 2,
};
check!(c.increment_cell_index() == Some([2, -2, -2]));
let mut c = CellIndexIterator {
cell_index: [0, 1, 2],
half_extent: 2,
};
check!(c.increment_cell_index() == Some([0, 2, -2]));
let mut c = CellIndexIterator {
cell_index: [0, 0, -2],
half_extent: 2,
};
check!(c.increment_cell_index() == Some([0, 0, -1]));
let mut c = CellIndexIterator {
cell_index: [2, 2, 2],
half_extent: 2,
};
check!(c.increment_cell_index() == None);
}
// Positions map to cells by flooring `coordinate / cell_width`, including
// negative coordinates.
#[test]
fn test_cell_index() {
let cell_list = VecCell::<usize, 3>::builder()
.nominal_search_radius(
2.0.try_into()
.expect("hard-coded constant should be positive"),
)
.build();
check!(cell_list.cell_index_from_position(&[0.0, 0.0, 0.0].into()) == [0, 0, 0]);
check!(cell_list.cell_index_from_position(&[2.0, 0.0, 0.0].into()) == [1, 0, 0]);
check!(cell_list.cell_index_from_position(&[0.0, 2.0, 0.0].into()) == [0, 1, 0]);
check!(cell_list.cell_index_from_position(&[0.0, 0.0, 2.0].into()) == [0, 0, 1]);
check!(cell_list.cell_index_from_position(&[-41.5, 18.5, -0.125].into()) == [-21, 9, -1]);
}
// A single inserted key is recorded in both the key map and its cell.
#[test]
fn test_insert_one() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
assert!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 1);
assert!(keys.contains(&0));
}
// Several keys, including one outside the initial grid, land in the right cells.
#[test]
fn test_insert_many() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.995, 0.897]));
cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
check!(cell_list.cell_index.get(&1) == Some(&CellIndex([0, 0])));
check!(cell_list.cell_index.get(&2) == Some(&CellIndex([-1, 3])));
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 2);
check!(keys.contains(&0));
check!(keys.contains(&1));
let keys = cell_list.get_keys(&[-1, 3]);
assert!(keys.len() == 1);
check!(keys.contains(&2));
}
// Re-inserting a key within the same cell must not duplicate it.
#[test]
fn test_insert_again_same() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(0, Cartesian::from([0.25, 0.5]));
cell_list.insert(0, Cartesian::from([0.5, 0.75]));
check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 1);
check!(keys.contains(&0));
}
// Re-inserting a key in a different cell moves it out of the old cell.
#[test]
fn test_insert_again_different() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.25, 0.5]));
cell_list.insert(1, Cartesian::from([-0.5, -0.75]));
check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
check!(cell_list.cell_index.get(&1) == Some(&CellIndex([-1, -1])));
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 1);
check!(keys.contains(&0));
let keys = cell_list.get_keys(&[-1, -1]);
assert!(keys.len() == 1);
check!(keys.contains(&1));
}
// Removal clears a key from both the key map and its cell.
#[test]
fn test_remove() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.995, 0.897]));
cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
cell_list.remove(&1);
cell_list.remove(&2);
check!(cell_list.cell_index.get(&0) == Some(&CellIndex([0, 0])));
check!(cell_list.cell_index.get(&1) == None);
check!(cell_list.cell_index.get(&2) == None);
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 1);
check!(keys.contains(&0));
let keys = cell_list.get_keys(&[-1, 3]);
check!(keys.is_empty());
}
// `clear` empties the key map and every cell.
#[test]
fn test_clear() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.995, 0.897]));
cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
cell_list.clear();
check!(cell_list.cell_index.is_empty());
for keys in cell_list.keys_map {
check!(keys.is_empty());
}
}
// After removals and `shrink_to_fit`, only the occupied cell keeps capacity.
#[test]
fn test_shrink_to_fit() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.995, 0.897]));
cell_list.insert(2, Cartesian::from([-0.125, 3.25]));
cell_list.insert(3, Cartesian::from([10.995, -12.897]));
cell_list.insert(4, Cartesian::from([15.125, 19.25]));
cell_list.remove(&1);
cell_list.remove(&2);
cell_list.remove(&3);
cell_list.remove(&4);
cell_list.shrink_to_fit();
let filled_cell_index =
VecCell::<usize, 2>::map_index_from_cell(cell_list.half_extent, &[0, 0])
.expect("hard-coded cell is valid");
for (i, keys) in cell_list.keys_map.iter().enumerate() {
if i == filled_cell_index {
check!(keys.capacity() == 1);
check!(keys.len() == 1);
} else {
check!(keys.capacity() == 0);
check!(keys.len() == 0);
}
}
let keys = cell_list.get_keys(&[0, 0]);
assert!(keys.len() == 1);
check!(keys.contains(&0));
}
// Randomized inserts and removals are compared against a reference map; every
// key must end up in exactly the cell the reference expects.
#[rstest]
fn consistency() {
const N_STEPS: usize = 65_536;
let mut rng = StdRng::seed_from_u64(0);
let mut reference = FxHashMap::default();
let cell_width = 0.5;
let mut cell_list = VecCell::builder()
.nominal_search_radius(
cell_width
.try_into()
.expect("hard-coded cell with should be positive"),
)
.build();
let position_distribution = Ball {
radius: 20.0.try_into().expect("hardcoded value should be positive"),
};
let key_distribution =
Uniform::new(0, N_STEPS / 4).expect("hardcoded distribution should be valid");
for _ in 0..N_STEPS {
if rng.random_bool(0.7) {
let position: Cartesian<3> = position_distribution.sample(&mut rng);
let key = key_distribution.sample(&mut rng);
cell_list.insert(key, position);
reference.insert(key, cell_list.cell_index_from_position(&position));
} else {
let key = key_distribution.sample(&mut rng);
cell_list.remove(&key);
reference.remove(&key);
}
}
assert!(cell_list.cell_index.len() == reference.len());
for (reference_key, reference_value) in reference.drain() {
let value = cell_list.cell_index.get(&reference_key);
check!(value == Some(&CellIndex(reference_value)));
let keys = cell_list.get_keys(&reference_value);
check!(keys.contains(&reference_key));
}
// Each key must appear in exactly one cell.
let total = cell_list.keys_map.iter().map(Vec::len).sum();
check!(cell_list.cell_index.len() == total);
check!(total > 2000);
}
// Inserting outside the initial grid expands it, and queries near the new
// point find it.
#[test]
fn test_outside() {
let mut cell_list = VecCell::default();
cell_list.insert(0, Cartesian::from([0.125, 0.25]));
cell_list.insert(1, Cartesian::from([0.995, 0.897]));
cell_list.insert(2, Cartesian::from([8.125, 0.0]));
check!(cell_list.half_extent == 8);
let potential_neighbors: Vec<_> = cell_list
.points_near_ball(&[9.125, 0.0].into(), 1.0)
.collect();
assert!(potential_neighbors.len() == 1);
check!(potential_neighbors[0] == 2);
}
// Brute force check that `points_near_ball` returns every point within the
// search radius, in 2D and 3D and for several cell widths.
#[rstest]
#[case::d_2(PhantomData::<VecCell<usize, 2>>)]
#[case::d_3(PhantomData::<VecCell<usize, 3>>)]
fn test_points_near_ball<const D: usize>(
#[case] _d: PhantomData<VecCell<usize, D>>,
#[values(1.0, 0.5, 0.25)] nominal_search_radius: f64,
) {
let mut rng = StdRng::seed_from_u64(0);
let mut reference = Vec::new();
let mut cell_list = VecCell::builder()
.nominal_search_radius(
nominal_search_radius
.try_into()
.expect("hardcoded value should be positive"),
)
.maximum_search_radius(1.0)
.build();
let position_distribution = Ball {
radius: 12.0.try_into().expect("hardcoded value should be positive"),
};
let n = 2048;
for key in 0..n {
let position: Cartesian<D> = position_distribution.sample(&mut rng);
cell_list.insert(key, position);
reference.push(position);
}
let mut n_neighbors = 0;
for p_i in &reference {
let potential_neighbors: Vec<_> = cell_list.points_near_ball(p_i, 1.0).collect();
for (j, p_j) in reference.iter().enumerate() {
if p_i.distance(p_j) <= 1.0 {
check!(potential_neighbors.contains(&j));
n_neighbors += 1;
}
}
}
// Sanity check that the sample is dense enough for the test to be meaningful.
check!(n_neighbors >= n * 2);
}
}