1#![expect(
5 clippy::cast_possible_truncation,
6 reason = "the necessary conversions are necessary and have been checked"
7)]
8#![expect(
9 clippy::cast_sign_loss,
10 reason = "the necessary conversions are necessary and have been checked"
11)]
12
13use rustc_hash::FxHashMap;
16use serde::{Deserialize, Serialize};
17use serde_with::serde_as;
18use std::{array, cmp::Eq, fmt, hash::Hash, marker::PhantomData};
19
20use hoomd_utility::valid::PositiveReal;
21use hoomd_vector::Cartesian;
22
23use super::{PointUpdate, PointsNearBall, WithSearchRadius, vec_cell};
24
25#[serde_as]
30#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
31pub(crate) struct CellIndex<const D: usize>(#[serde_as(as = "[_; D]")] pub [i64; D]);
32
33#[serde_as]
65#[derive(Clone, Debug, Serialize, Deserialize)]
66pub struct HashCell<K, const D: usize>
67where
68 K: Eq + Hash,
69{
70 cell_width: PositiveReal,
72
73 particle_indices: FxHashMap<CellIndex<D>, Vec<K>>,
75
76 cell_index: FxHashMap<K, CellIndex<D>>,
78
79 #[serde_as(as = "Vec<Vec<[_; D]>>")]
81 stencils: Vec<Vec<[i64; D]>>,
82}
83
84pub struct HashCellBuilder<K, const D: usize> {
100 nominal_search_radius: PositiveReal,
102
103 maximum_search_radius: f64,
105
106 phantom_key: PhantomData<K>,
108}
109
110impl<K, const D: usize> HashCellBuilder<K, D>
111where
112 K: Copy + Eq + Hash,
113{
114 #[inline]
131 #[must_use]
132 pub fn nominal_search_radius(mut self, nominal_search_radius: PositiveReal) -> Self {
133 self.nominal_search_radius = nominal_search_radius;
134 self
135 }
136
137 #[inline]
156 #[must_use]
157 pub fn maximum_search_radius(mut self, maximum_search_radius: f64) -> Self {
158 self.maximum_search_radius = maximum_search_radius;
159 self
160 }
161
162 #[inline]
177 #[must_use]
178 pub fn build(self) -> HashCell<K, D> {
179 let maximum_stencil_radius =
180 (self.maximum_search_radius / self.nominal_search_radius.get()).ceil() as u32;
181
182 HashCell {
183 cell_width: self.nominal_search_radius,
184 particle_indices: FxHashMap::default(),
185 cell_index: FxHashMap::default(),
186 stencils: vec_cell::generate_all_stencils(maximum_stencil_radius.max(1)),
187 }
188 }
189}
190
191impl<K, const D: usize> Default for HashCell<K, D>
192where
193 K: Copy + Eq + Hash,
194{
195 #[inline]
207 fn default() -> Self {
208 Self::builder().build()
209 }
210}
211
212impl<K, const D: usize> WithSearchRadius for HashCell<K, D>
213where
214 K: Copy + Eq + Hash,
215{
216 #[inline]
231 fn with_search_radius(radius: PositiveReal) -> Self {
232 Self::builder().nominal_search_radius(radius).build()
233 }
234}
235
236impl<K, const D: usize> HashCell<K, D>
237where
238 K: Copy + Eq + Hash,
239{
240 #[inline]
242 fn cell_index_from_position(&self, position: &Cartesian<D>) -> [i64; D] {
243 std::array::from_fn(|j| (position.coordinates[j] / self.cell_width.get()).floor() as i64)
244 }
245
246 #[inline]
248 pub fn shrink_to_fit(&mut self) {
249 self.particle_indices.retain(|_, v| !v.is_empty());
250 self.particle_indices.shrink_to_fit();
251 self.cell_index.shrink_to_fit();
252 }
253
254 #[expect(
272 clippy::missing_panics_doc,
273 reason = "hard-coded constant will never panic"
274 )]
275 #[inline]
276 #[must_use]
277 pub fn builder() -> HashCellBuilder<K, D> {
278 HashCellBuilder {
279 nominal_search_radius: 1.0
280 .try_into()
281 .expect("hard-coded constant is a positive real"),
282 maximum_search_radius: 1.0,
283 phantom_key: PhantomData,
284 }
285 }
286}
287
288impl<K, const D: usize> PointUpdate<Cartesian<D>, K> for HashCell<K, D>
289where
290 K: Copy + Eq + Hash,
291{
292 #[inline]
303 fn insert(&mut self, key: K, position: Cartesian<D>) {
304 let cell_idx = CellIndex(self.cell_index_from_position(&position));
305 let old_cell_index = self.cell_index.insert(key, cell_idx);
306 if old_cell_index != Some(cell_idx) {
308 self.particle_indices.entry(cell_idx).or_default().push(key);
310
311 if let Some(old_cell_index) = old_cell_index {
312 self.particle_indices
314 .entry(old_cell_index)
315 .and_modify(|particle_indices| {
316 if let Some(pos) = particle_indices.iter().position(|x| *x == key) {
317 particle_indices.swap_remove(pos);
318 }
319 });
320 }
321 }
322 }
323
324 #[inline]
336 fn remove(&mut self, key: &K) {
337 let cell_idx = self.cell_index.remove(key);
338 if let Some(cell_idx) = cell_idx {
339 self.particle_indices
341 .entry(cell_idx)
342 .and_modify(|particle_indices| {
343 if let Some(idx) = particle_indices.iter().position(|x| x == key) {
345 particle_indices.swap_remove(idx);
347 }
348 });
349 }
350 }
351
352 #[inline]
364 fn len(&self) -> usize {
365 self.cell_index.len()
366 }
367
368 #[inline]
381 fn is_empty(&self) -> bool {
382 self.cell_index.is_empty()
383 }
384
385 #[inline]
395 fn contains_key(&self, key: &K) -> bool {
396 self.cell_index.contains_key(key)
397 }
398
399 #[inline]
411 fn clear(&mut self) {
412 self.cell_index.clear();
413 self.particle_indices.clear();
414 }
415}
416
417struct PointsIterator<'a, K, const D: usize>
419where
420 K: Eq + Hash,
421{
422 keys: Option<&'a Vec<K>>,
424
425 cell_list: &'a HashCell<K, D>,
427
428 index_in_current_cell: usize,
430
431 current_stencil: usize,
433
434 stencil: &'a [[i64; D]],
436
437 center: [i64; D],
439}
440
441impl<K, const D: usize> Iterator for PointsIterator<'_, K, D>
442where
443 K: Copy + Eq + Hash,
444{
445 type Item = K;
446
447 #[inline]
448 fn next(&mut self) -> Option<Self::Item> {
449 loop {
450 if let Some(keys) = self.keys
451 && self.index_in_current_cell < keys.len()
452 {
453 let last_index = self.index_in_current_cell;
454 self.index_in_current_cell += 1;
455 return Some(keys[last_index]);
456 }
457
458 self.index_in_current_cell = 0;
459 self.current_stencil += 1;
460
461 if self.current_stencil >= self.stencil.len() {
462 return None;
463 }
464
465 let cell_index =
466 array::from_fn(|i| self.center[i] + self.stencil[self.current_stencil][i]);
467 self.keys = self.cell_list.particle_indices.get(&CellIndex(cell_index));
468 }
469 }
470}
471
472impl<const D: usize, K> PointsNearBall<Cartesian<D>, K> for HashCell<K, D>
473where
474 K: Copy + Eq + Hash,
475{
476 #[inline]
507 fn points_near_ball(&self, position: &Cartesian<D>, radius: f64) -> impl Iterator<Item = K> {
508 let stencil_index = (radius / self.cell_width.get()).ceil() as usize - 1;
509 assert!(
510 stencil_index < self.stencils.len(),
511 "search radius must be less than or equal to the maximum search radius"
512 );
513
514 let center = self.cell_index_from_position(position);
515 let stencil = &self.stencils[stencil_index];
516
517 PointsIterator {
518 keys: self.particle_indices.get(&CellIndex(center)),
519 cell_list: self,
520 index_in_current_cell: 0,
521 current_stencil: 0,
522 stencil,
523 center,
524 }
525 }
526}
527
528impl<K, const D: usize> fmt::Display for HashCell<K, D>
529where
530 K: Eq + Hash,
531{
532 #[allow(
548 clippy::missing_inline_in_public_items,
549 reason = "no need to inline display"
550 )]
551 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552 let largest_cell_size = self
553 .particle_indices
554 .values()
555 .map(Vec::len)
556 .fold(0, usize::max);
557
558 writeln!(f, "HashCell<K, {D}>:")?;
559 writeln!(f, "- {} total cells.", self.particle_indices.len(),)?;
560 writeln!(f, "- {} points.", self.cell_index.len())?;
561 writeln!(
562 f,
563 "- Nominal, maximum search radii: {}, {}",
564 self.cell_width,
565 self.cell_width.get() * self.stencils.len() as f64
566 )?;
567 write!(f, "- Largest cell length: {largest_cell_size}")
568 }
569}
#[expect(
    clippy::used_underscore_binding,
    reason = "Used for const parameterization."
)]
#[cfg(test)]
mod tests {
    use assert2::{assert, check};
    use rand::{
        RngExt, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
    };
    use rstest::*;

    use super::*;
    use hoomd_vector::{Metric, distribution::Ball};

    /// Build a default 2D cell list holding keys 0 and 1 in cell `[0, 0]` and
    /// key 2 in cell `[-1, 3]`.
    fn three_point_cell_list() -> HashCell<i32, 2> {
        let mut cells = HashCell::default();
        let points = [(0, [0.125, 0.25]), (1, [0.995, 0.897]), (2, [-0.125, 3.25])];
        for (key, position) in points {
            cells.insert(key, Cartesian::from(position));
        }
        cells
    }

    /// Positions map to cells by flooring each coordinate over the cell width.
    #[test]
    fn test_cell_index() {
        let width: PositiveReal = 2.0
            .try_into()
            .expect("hard-coded constant is a positive real");
        let cells = HashCell::<usize, 3>::builder()
            .nominal_search_radius(width)
            .build();

        let cases: [([f64; 3], [i64; 3]); 5] = [
            ([0.0, 0.0, 0.0], [0, 0, 0]),
            ([2.0, 0.0, 0.0], [1, 0, 0]),
            ([0.0, 2.0, 0.0], [0, 1, 0]),
            ([0.0, 0.0, 2.0], [0, 0, 1]),
            ([-41.5, 18.5, -0.125], [-21, 9, -1]),
        ];
        for (position, expected) in cases {
            check!(cells.cell_index_from_position(&Cartesian::from(position)) == expected);
        }
    }

    /// A single inserted key is recorded in both maps.
    #[test]
    fn test_insert_one() {
        let mut cells = HashCell::default();

        cells.insert(0, Cartesian::from([0.125, 0.25]));

        check!(cells.cell_index.get(&0) == Some(&CellIndex([0, 0])));

        let occupants = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(occupants) = occupants);
        check!(occupants.len() == 1);
        check!(occupants.contains(&0));
    }

    /// Several keys land in the cells that contain their positions.
    #[test]
    fn test_insert_many() {
        let cells = three_point_cell_list();

        check!(cells.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cells.cell_index.get(&1) == Some(&CellIndex([0, 0])));
        check!(cells.cell_index.get(&2) == Some(&CellIndex([-1, 3])));

        let shared = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(shared) = shared);
        check!(shared.len() == 2);
        check!(shared.contains(&0));
        check!(shared.contains(&1));

        let alone = cells.particle_indices.get(&CellIndex([-1, 3]));
        assert2::assert!(let Some(alone) = alone);
        check!(alone.len() == 1);
        check!(alone.contains(&2));
    }

    /// Re-inserting a key within the same cell does not duplicate it.
    #[test]
    fn test_insert_again_same() {
        let mut cells = HashCell::default();

        for position in [[0.125, 0.25], [0.25, 0.5], [0.5, 0.75]] {
            cells.insert(0, Cartesian::from(position));
        }

        check!(cells.cell_index.get(&0) == Some(&CellIndex([0, 0])));

        let occupants = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(occupants) = occupants);
        check!(occupants.len() == 1);
        check!(occupants.contains(&0));
    }

    /// Re-inserting a key in another cell moves it out of the old cell.
    #[test]
    fn test_insert_again_different() {
        let mut cells = HashCell::default();

        let moves = [(0, [0.125, 0.25]), (1, [0.25, 0.5]), (1, [-0.5, -0.75])];
        for (key, position) in moves {
            cells.insert(key, Cartesian::from(position));
        }

        check!(cells.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cells.cell_index.get(&1) == Some(&CellIndex([-1, -1])));

        let stayed = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(stayed) = stayed);
        check!(stayed.len() == 1);
        check!(stayed.contains(&0));

        let moved = cells.particle_indices.get(&CellIndex([-1, -1]));
        assert2::assert!(let Some(moved) = moved);
        check!(moved.len() == 1);
        check!(moved.contains(&1));
    }

    /// Removed keys vanish from both maps; emptied cells stay in place.
    #[test]
    fn test_remove() {
        let mut cells = three_point_cell_list();

        cells.remove(&1);
        cells.remove(&2);

        check!(cells.cell_index.get(&0) == Some(&CellIndex([0, 0])));
        check!(cells.cell_index.get(&1) == None);
        check!(cells.cell_index.get(&2) == None);

        let remaining = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(remaining) = remaining);
        check!(remaining.len() == 1);
        check!(remaining.contains(&0));

        let emptied = cells.particle_indices.get(&CellIndex([-1, 3]));
        assert2::assert!(let Some(emptied) = emptied);
        assert!(emptied.len() == 0);
    }

    /// Clearing removes every key and every cell.
    #[test]
    fn test_clear() {
        let mut cells = three_point_cell_list();

        cells.clear();

        check!(cells.cell_index.len() == 0);
        check!(cells.particle_indices.len() == 0);
    }

    /// Shrinking prunes the cells that removals left empty.
    #[test]
    fn test_shrink_to_fit() {
        let mut cells = three_point_cell_list();

        cells.remove(&1);
        cells.remove(&2);

        cells.shrink_to_fit();
        check!(cells.particle_indices.len() == 1);

        let remaining = cells.particle_indices.get(&CellIndex([0, 0]));
        assert2::assert!(let Some(remaining) = remaining);
        check!(remaining.len() == 1);
        check!(remaining.contains(&0));
    }

    /// After a long random sequence of inserts and removes, both maps agree
    /// with an independently maintained reference.
    #[test]
    fn consistency() {
        const N_STEPS: usize = 65_536;
        let mut rng = StdRng::seed_from_u64(0);
        let mut expected_cells = FxHashMap::default();

        let mut cells = HashCell::builder()
            .nominal_search_radius(
                0.5.try_into()
                    .expect("hard-coded value should be positive"),
            )
            .build();
        let positions = Ball {
            radius: 20.0.try_into().expect("hardcoded value should be positive"),
        };
        let keys = Uniform::new(0, N_STEPS / 4).expect("hardcoded distribution should be valid");

        for _ in 0..N_STEPS {
            let inserting = rng.random_bool(0.7);
            if inserting {
                // Sample the position before the key so the seeded sequence
                // is consumed in a fixed order.
                let position: Cartesian<3> = positions.sample(&mut rng);
                let key = keys.sample(&mut rng);
                let cell = cells.cell_index_from_position(&position);

                cells.insert(key, position);
                expected_cells.insert(key, cell);
            } else {
                let key = keys.sample(&mut rng);
                cells.remove(&key);
                expected_cells.remove(&key);
            }
        }

        assert!(cells.cell_index.len() == expected_cells.len());
        for (key, cell) in &expected_cells {
            let recorded = cells.cell_index.get(key);
            assert!(recorded == Some(&CellIndex(*cell)));

            let occupants = cells.particle_indices.get(&CellIndex(*cell));
            assert2::assert!(let Some(occupants) = occupants);
            check!(occupants.contains(key));
        }

        // Every key must be listed in exactly one cell.
        let total = cells
            .particle_indices
            .values()
            .map(Vec::len)
            .sum::<usize>();
        check!(cells.cell_index.len() == total);
        check!(total > 2000);
    }

    /// A query centered outside the occupied cells still finds the point in
    /// the neighboring cell.
    #[test]
    fn test_outside() {
        let mut cells = HashCell::default();

        let points = [(0, [0.125, 0.25]), (1, [0.995, 0.897]), (2, [8.125, 0.0])];
        for (key, position) in points {
            cells.insert(key, Cartesian::from(position));
        }

        let candidates: Vec<_> = cells
            .points_near_ball(&[9.125, 0.0].into(), 1.0)
            .collect();
        assert!(candidates.len() == 1);
        check!(candidates[0] == 2);
    }

    /// Every point within the search radius appears among the candidates,
    /// for several cell widths in 2D and 3D.
    #[rstest]
    #[case::d_2(PhantomData::<HashCell<usize, 2>>)]
    #[case::d_3(PhantomData::<HashCell<usize, 3>>)]
    fn test_points_near_ball<const D: usize>(
        #[case] _d: PhantomData<HashCell<usize, D>>,
        #[values(1.0, 0.5, 0.25)] nominal_search_radius: f64,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let mut placed = Vec::new();

        let search_radius = 1.0;
        let mut cells = HashCell::builder()
            .nominal_search_radius(
                nominal_search_radius
                    .try_into()
                    .expect("hardcoded value should be positive"),
            )
            .maximum_search_radius(1.0)
            .build();
        let positions = Ball {
            radius: 12.0.try_into().expect("hardcoded value should be positive"),
        };

        let n = 2048;

        for key in 0..n {
            let position: Cartesian<D> = positions.sample(&mut rng);

            cells.insert(key, position);
            placed.push(position);
        }

        let mut n_neighbors = 0;
        for p_i in &placed {
            let candidates: Vec<_> = cells.points_near_ball(p_i, search_radius).collect();

            for (j, p_j) in placed.iter().enumerate() {
                if p_i.distance(p_j) <= search_radius {
                    check!(candidates.contains(&j));
                    n_neighbors += 1;
                }
            }
        }
        check!(n_neighbors >= n * 2);
    }
}