Skip to main content

oxiphysics_gpu/cell_list/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::needless_range_loop)]
6use rayon::prelude::*;
7
8use super::types::GpuCellList;
9
10/// Expand a 10-bit integer by inserting 2 zero bits after each bit.
11///
12/// `x` must be in `[0, 1023]`.
13pub(super) fn expand_bits(mut x: u32) -> u32 {
14    x &= 0x000003FF;
15    x = (x | (x << 16)) & 0x030000FF;
16    x = (x | (x << 8)) & 0x0300F00F;
17    x = (x | (x << 4)) & 0x030C30C3;
18    x = (x | (x << 2)) & 0x09249249;
19    x
20}
21/// Compute a 30-bit Morton code from 3D integer coordinates.
22///
23/// Each coordinate must be in `[0, 1023]`.
24pub fn morton_encode(x: u32, y: u32, z: u32) -> u32 {
25    expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)
26}
27/// Decode a 30-bit Morton code back to 3D integer coordinates.
28pub fn morton_decode(code: u32) -> (u32, u32, u32) {
29    (
30        compact_bits(code),
31        compact_bits(code >> 1),
32        compact_bits(code >> 2),
33    )
34}
35/// Compact every third bit (inverse of expand_bits).
36pub(super) fn compact_bits(mut x: u32) -> u32 {
37    x &= 0x09249249;
38    x = (x | (x >> 2)) & 0x030C30C3;
39    x = (x | (x >> 4)) & 0x0300F00F;
40    x = (x | (x >> 8)) & 0x030000FF;
41    x = (x | (x >> 16)) & 0x000003FF;
42    x
43}
44/// Compute Morton codes for a set of positions and sort indices by Morton code.
45///
46/// `positions` - particle positions
47/// `box_min` - minimum corner of the bounding box
48/// `box_max` - maximum corner of the bounding box
49///
50/// Returns `(sorted_indices, morton_codes)` where `sorted_indices[i]` is the
51/// original index of the particle that should be in position `i` after sorting.
52pub fn morton_sort(
53    positions: &[[f64; 3]],
54    box_min: [f64; 3],
55    box_max: [f64; 3],
56) -> (Vec<usize>, Vec<u32>) {
57    let range = [
58        (box_max[0] - box_min[0]).max(1e-10),
59        (box_max[1] - box_min[1]).max(1e-10),
60        (box_max[2] - box_min[2]).max(1e-10),
61    ];
62    let mut codes: Vec<(u32, usize)> = positions
63        .par_iter()
64        .enumerate()
65        .map(|(i, p)| {
66            let x = (((p[0] - box_min[0]) / range[0] * 1023.0) as u32).min(1023);
67            let y = (((p[1] - box_min[1]) / range[1] * 1023.0) as u32).min(1023);
68            let z = (((p[2] - box_min[2]) / range[2] * 1023.0) as u32).min(1023);
69            (morton_encode(x, y, z), i)
70        })
71        .collect();
72    codes.sort_by_key(|&(code, _)| code);
73    let sorted_indices: Vec<usize> = codes.iter().map(|&(_, idx)| idx).collect();
74    let morton_codes: Vec<u32> = codes.iter().map(|&(code, _)| code).collect();
75    (sorted_indices, morton_codes)
76}
/// Exclusive prefix sum of `counts`: element `i` of the result is the sum of
/// `counts[..i]`, so the first element is always `0`.
///
/// Despite the name, this implementation walks the slice sequentially.
pub fn parallel_prefix_sum(counts: &[usize]) -> Vec<usize> {
    counts
        .iter()
        .scan(0usize, |total, &count| {
            // Emit the running total *before* adding the current count.
            let before = *total;
            *total += count;
            Some(before)
        })
        .collect()
}
/// Find the axis-aligned bounding box enclosing all `positions`.
///
/// Returns the `(min, max)` corners.  An empty slice yields two zero corners.
pub fn compute_bounding_box(positions: &[[f64; 3]]) -> ([f64; 3], [f64; 3]) {
    let (seed, rest) = match positions.split_first() {
        Some((first, rest)) => (*first, rest),
        None => return ([0.0; 3], [0.0; 3]),
    };

    // Both corners start at the first point and are widened per axis.
    let mut lo = seed;
    let mut hi = seed;
    for point in rest {
        for ((&coord, lo_axis), hi_axis) in point.iter().zip(lo.iter_mut()).zip(hi.iter_mut()) {
            if coord < *lo_axis {
                *lo_axis = coord;
            }
            if coord > *hi_axis {
                *hi_axis = coord;
            }
        }
    }
    (lo, hi)
}
/// Build a reordered copy of `data` following the permutation `perm`.
///
/// Output position `i` receives a clone of `data[perm[i]]`, so the result has
/// `perm.len()` elements.
///
/// # Panics
///
/// Panics if any entry of `perm` is not a valid index into `data`.
pub fn reorder_by_permutation<T: Clone>(data: &[T], perm: &[usize]) -> Vec<T> {
    let mut reordered = Vec::with_capacity(perm.len());
    for &source in perm {
        reordered.push(data[source].clone());
    }
    reordered
}
/// CPU stand-in for a GPU radix sort: stably sorts `keys` in ascending order.
///
/// Returns `(sorted_keys, sorted_indices)`, where `sorted_indices[i]` is the
/// position in `keys` of the element now at position `i`.  Equal keys keep
/// their original relative order.  An empty input yields two empty vectors.
pub fn radix_sort_mock(keys: &[u32]) -> (Vec<u32>, Vec<usize>) {
    // Sort the index permutation itself; the stable sort keeps ties in
    // input order.
    let mut order: Vec<usize> = (0..keys.len()).collect();
    order.sort_by_key(|&original| keys[original]);

    let sorted_keys = order.iter().map(|&original| keys[original]).collect();
    (sorted_keys, order)
}
/// Exclusive scan of `counts`, as a CPU reference for a GPU prefix-sum kernel.
///
/// Given `[a, b, c, d]` the result is `[0, a, a+b, a+b+c]`: every output slot
/// holds the sum of all inputs strictly before it.
pub fn gpu_prefix_sum(counts: &[usize]) -> Vec<usize> {
    let mut offsets = vec![0usize; counts.len()];
    let mut running = 0usize;
    for (slot, &count) in offsets.iter_mut().zip(counts) {
        *slot = running;
        running += count;
    }
    offsets
}
/// Count how many particles fall into each cell of a regular grid.
///
/// The grid has `n_cells = [nx, ny, nz]` cells, each of side length `cell_size`,
/// with its origin at `[0, 0, 0]`.  Particles outside the grid are clamped to
/// the nearest boundary cell.
///
/// Returns a flat Vec of length `nx * ny * nz`, indexed as
/// `ix + nx * (iy + ny * iz)`.  If any grid dimension is zero the grid has no
/// cells, so the (empty) vector is returned and no particle is counted.
///
/// Despite the name, this reference implementation counts sequentially.
pub fn parallel_count_particles(
    positions: &[[f64; 3]],
    n_cells: [usize; 3],
    cell_size: f64,
) -> Vec<usize> {
    let [nx, ny, nz] = n_cells;
    let mut counts = vec![0usize; nx * ny * nz];
    // With a zero-length axis there is no boundary cell to clamp into, and
    // `clamp(0, -1)` would panic, so bail out before touching any particle.
    if counts.is_empty() {
        return counts;
    }

    // Truncate to a signed cell coordinate first so that negative positions
    // clamp to cell 0 instead of wrapping around.
    let cell_along =
        |coord: f64, n: usize| ((coord / cell_size) as isize).clamp(0, n as isize - 1) as usize;

    for p in positions {
        let ix = cell_along(p[0], nx);
        let iy = cell_along(p[1], ny);
        let iz = cell_along(p[2], nz);
        counts[ix + nx * (iy + ny * iz)] += 1;
    }
    counts
}
/// Split the cell range `0..n_cells` into `n_gpus` contiguous chunks of
/// near-equal size.
///
/// Every GPU receives `n_cells / n_gpus` cells, and the first
/// `n_cells % n_gpus` GPUs receive one extra.  The returned ranges do not
/// overlap and together cover `0..n_cells`; when there are more GPUs than
/// cells the trailing ranges are empty.  If either argument is zero, an empty
/// Vec is returned.
pub fn distribute_cells_to_gpus(n_cells: usize, n_gpus: usize) -> Vec<std::ops::Range<usize>> {
    if n_cells == 0 || n_gpus == 0 {
        return Vec::new();
    }
    let (share, leftover) = (n_cells / n_gpus, n_cells % n_gpus);
    // GPU `g` begins after `g` full shares plus one extra cell for each
    // earlier GPU that absorbed a remainder cell.
    let boundary = |gpu: usize| gpu * share + gpu.min(leftover);
    (0..n_gpus)
        .map(|gpu| boundary(gpu)..boundary(gpu + 1))
        .collect()
}
184/// Return all particle pairs `(i, j)` with `i < j` and `dist(i, j) < cutoff`.
185///
186/// Uses the [`GpuCellList`] for candidate acceleration, filtering by exact
187/// distance afterward.  The returned pairs are in undefined order.
188pub fn gpu_neighbor_search_kernel(
189    cl: &GpuCellList,
190    positions: &[[f64; 3]],
191    cutoff: f64,
192) -> Vec<(usize, usize)> {
193    let mut pairs = Vec::new();
194    cl.for_each_pair(positions, cutoff, |i, j, _d2| {
195        let (a, b) = if i < j { (i, j) } else { (j, i) };
196        pairs.push((a, b));
197    });
198    pairs.sort_unstable();
199    pairs.dedup();
200    pairs
201}
/// Unit tests for the Morton-code helpers, prefix sums, cell counting,
/// GPU work distribution and the cell-list / spatial-hash front ends.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cell_list::CellList;

    use crate::cell_list::GpuCellList;

    use crate::cell_list::SpatialHash;

    /// Prefix sum of an empty slice is empty.
    #[test]
    fn test_prefix_sum_empty() {
        assert_eq!(parallel_prefix_sum(&[]), Vec::<usize>::new());
    }
    /// Prefix sum is exclusive: the first output is always zero.
    #[test]
    fn test_prefix_sum_basic() {
        let counts = [1usize, 2, 3, 4];
        let result = parallel_prefix_sum(&counts);
        assert_eq!(result, vec![0, 1, 3, 6]);
    }
    /// Positions beyond the grid clamp to the last cell on every axis.
    #[test]
    fn test_cell_index_clamp() {
        let list = GpuCellList::new([4, 4, 4], 1.0, [4.0, 4.0, 4.0]);
        let idx = list.cell_index([4.5, 4.5, 4.5]);
        assert_eq!(idx, 3 * 4 * 4 + 3 * 4 + 3);
    }
    /// Total cell count is the product of the grid dimensions.
    #[test]
    fn test_total_cells() {
        let list = GpuCellList::new([3, 4, 5], 1.0, [3.0, 4.0, 5.0]);
        assert_eq!(list.total_cells(), 60);
    }
    /// One particle per cell of a 2x2x2 arrangement gives a count of 1 everywhere.
    #[test]
    fn test_build_parallel_counts() {
        let positions: Vec<[f64; 3]> = vec![
            [0.5, 0.5, 0.5],
            [1.5, 0.5, 0.5],
            [0.5, 1.5, 0.5],
            [1.5, 1.5, 0.5],
            [0.5, 0.5, 1.5],
            [1.5, 0.5, 1.5],
            [0.5, 1.5, 1.5],
            [1.5, 1.5, 1.5],
        ];
        let cl = GpuCellList::build_parallel(&positions);
        assert_eq!(cl.sorted_indices.len(), 8);
        for c in 0..cl.total_cells() {
            assert_eq!(cl.cell_counts[c], 1);
        }
    }
    /// Radius query returns the close particles and excludes the distant one.
    #[test]
    fn test_neighbors_in_radius() {
        let positions: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let mut neighbours = cl.neighbors_in_radius(&positions, [0.5, 0.5, 0.5], 0.5);
        neighbours.sort_unstable();
        assert!(neighbours.contains(&0));
        assert!(neighbours.contains(&1));
        assert!(!neighbours.contains(&2));
    }
    /// `CellList::find_neighbors` finds every particle within the radius, including self.
    #[test]
    fn cell_list_find_neighbors_all_pairs() {
        let positions: Vec<[f64; 3]> = vec![
            [1.0, 1.0, 1.0],
            [1.2, 1.0, 1.0],
            [1.0, 1.3, 1.0],
            [9.0, 9.0, 9.0],
        ];
        let cl = CellList::build(&positions);
        let radius = 0.5;
        let mut neighbours = cl.find_neighbors([1.0, 1.0, 1.0], radius);
        neighbours.sort_unstable();
        assert!(neighbours.contains(&0), "should find self: {neighbours:?}");
        assert!(
            neighbours.contains(&1),
            "should find particle 1: {neighbours:?}"
        );
        assert!(
            neighbours.contains(&2),
            "should find particle 2: {neighbours:?}"
        );
        assert!(
            !neighbours.contains(&3),
            "particle 3 is far: {neighbours:?}"
        );
    }
    /// `CellList::new` sizes the grid from the box and cell size.
    #[test]
    fn cell_list_new_compiles() {
        let cl = CellList::new([10.0, 10.0, 10.0], 2.0);
        assert_eq!(cl.inner.total_cells(), 125);
    }
    /// Morton encode/decode should be inverses.
    #[test]
    fn test_morton_roundtrip() {
        let test_cases = [
            (0, 0, 0),
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (7, 3, 5),
            (1023, 1023, 1023),
            (512, 256, 128),
        ];
        for (x, y, z) in test_cases {
            let code = morton_encode(x, y, z);
            let (dx, dy, dz) = morton_decode(code);
            assert_eq!(dx, x, "x mismatch for ({x},{y},{z})");
            assert_eq!(dy, y, "y mismatch for ({x},{y},{z})");
            assert_eq!(dz, z, "z mismatch for ({x},{y},{z})");
        }
    }
    /// Morton codes should preserve locality: nearby points have nearby codes.
    #[test]
    fn test_morton_locality() {
        let c1 = morton_encode(1, 1, 1);
        let c2 = morton_encode(2, 1, 1);
        let c_far = morton_encode(100, 100, 100);
        let d_near = c1.abs_diff(c2);
        let d_far = c1.abs_diff(c_far);
        assert!(
            d_near < d_far,
            "near distance {d_near} should be less than far {d_far}"
        );
    }
    /// Morton sort should produce a valid permutation.
    #[test]
    fn test_morton_sort_permutation() {
        let positions = vec![[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
        let (indices, codes) = morton_sort(&positions, [0.0; 3], [10.0, 10.0, 10.0]);
        assert_eq!(indices.len(), 3);
        assert_eq!(codes.len(), 3);
        // `windows(2)` stays well-defined even if the output were empty.
        for (i, pair) in codes.windows(2).enumerate() {
            assert!(pair[0] <= pair[1], "codes not sorted at {i}");
        }
        let mut sorted = indices.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2]);
    }
    /// Spatial hash should find nearby particles.
    #[test]
    fn test_spatial_hash_query() {
        let positions = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let mut hash = SpatialHash::new(64, 1.0);
        hash.build(&positions);
        assert_eq!(hash.len(), 3);
        let mut neighbours = hash.query_radius(&positions, [0.5, 0.5, 0.5], 0.5);
        neighbours.sort_unstable();
        neighbours.dedup();
        assert!(neighbours.contains(&0));
        assert!(neighbours.contains(&1));
        assert!(!neighbours.contains(&2));
    }
    /// Empty spatial hash.
    #[test]
    fn test_spatial_hash_empty() {
        let hash = SpatialHash::new(64, 1.0);
        assert!(hash.is_empty());
        assert_eq!(hash.len(), 0);
    }
    /// Spatial hash clear.
    #[test]
    fn test_spatial_hash_clear() {
        let mut hash = SpatialHash::new(64, 1.0);
        hash.insert(0, [0.5, 0.5, 0.5]);
        assert!(!hash.is_empty());
        hash.clear();
        assert!(hash.is_empty());
    }
    /// Bounding box computation.
    #[test]
    fn test_bounding_box() {
        let positions = vec![[1.0, 2.0, 3.0], [4.0, 0.0, 1.0], [2.0, 5.0, 2.0]];
        let (min, max) = compute_bounding_box(&positions);
        assert_eq!(min, [1.0, 0.0, 1.0]);
        assert_eq!(max, [4.0, 5.0, 3.0]);
    }
    /// Empty bounding box.
    #[test]
    fn test_bounding_box_empty() {
        let (min, max) = compute_bounding_box(&[]);
        assert_eq!(min, [0.0; 3]);
        assert_eq!(max, [0.0; 3]);
    }
    /// Reorder by permutation.
    #[test]
    fn test_reorder() {
        let data = vec![10, 20, 30, 40];
        let perm = vec![3, 1, 0, 2];
        let reordered = reorder_by_permutation(&data, &perm);
        assert_eq!(reordered, vec![40, 20, 10, 30]);
    }
    /// Max cell occupancy.
    #[test]
    fn test_max_cell_occupancy() {
        let positions: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let max = cl.max_cell_occupancy();
        assert!(max >= 2, "max occupancy should be at least 2, got {max}");
    }
    /// Non-empty cells count.
    #[test]
    fn test_nonempty_cells() {
        let positions: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let ne = cl.num_nonempty_cells();
        assert_eq!(ne, 2, "should have 2 non-empty cells, got {ne}");
    }
    /// for_each_pair should find all pairs within cutoff.
    #[test]
    fn test_for_each_pair() {
        let positions: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let mut pairs = Vec::new();
        cl.for_each_pair(&positions, 0.5, |i, j, _d2| {
            pairs.push((i.min(j), i.max(j)));
        });
        pairs.sort();
        pairs.dedup();
        assert!(
            pairs.contains(&(0, 1)),
            "should find pair (0,1), got {pairs:?}"
        );
        assert!(
            !pairs.iter().any(|&(a, b)| a == 2 || b == 2),
            "should not find pairs with particle 2"
        );
    }
    /// Radix sort output is ascending and each index maps back to its key.
    #[test]
    fn test_radix_sort_sorted_output() {
        let keys = vec![5u32, 1, 9, 3, 7, 2];
        let (sorted_keys, sorted_indices) = radix_sort_mock(&keys);
        assert_eq!(sorted_keys.len(), keys.len());
        assert_eq!(sorted_indices.len(), keys.len());
        for (i, pair) in sorted_keys.windows(2).enumerate() {
            assert!(pair[0] <= pair[1], "radix sort not sorted at {i}");
        }
        for (pos, &idx) in sorted_indices.iter().enumerate() {
            assert!(idx < keys.len(), "invalid index {idx}");
            assert_eq!(
                sorted_keys[pos], keys[idx],
                "index {idx} does not map to its key"
            );
        }
    }
    /// Radix sort returns the exact permutation for a small input.
    #[test]
    fn test_radix_sort_permutation_correct() {
        let keys = vec![30u32, 10, 20];
        let (sorted_keys, sorted_indices) = radix_sort_mock(&keys);
        assert_eq!(sorted_keys[0], 10);
        assert_eq!(sorted_keys[1], 20);
        assert_eq!(sorted_keys[2], 30);
        assert_eq!(sorted_indices[0], 1);
        assert_eq!(sorted_indices[1], 2);
        assert_eq!(sorted_indices[2], 0);
    }
    /// Radix sort of an empty input yields two empty vectors.
    #[test]
    fn test_radix_sort_empty() {
        let keys: Vec<u32> = vec![];
        let (sk, si) = radix_sort_mock(&keys);
        assert!(sk.is_empty());
        assert!(si.is_empty());
    }
    /// Equal keys keep their original order, since the sort is documented as stable.
    #[test]
    fn test_radix_sort_all_equal() {
        let keys = vec![7u32; 10];
        let (sorted_keys, sorted_indices) = radix_sort_mock(&keys);
        assert_eq!(sorted_keys.len(), 10);
        assert!(sorted_keys.iter().all(|&k| k == 7));
        assert_eq!(sorted_indices, (0..10).collect::<Vec<usize>>());
    }
    /// GPU prefix sum is exclusive and skips over zero counts correctly.
    #[test]
    fn test_gpu_prefix_sum_basic() {
        let counts = vec![0usize, 1, 3, 0, 2, 5];
        let result = gpu_prefix_sum(&counts);
        assert_eq!(result, vec![0, 0, 1, 4, 4, 6]);
    }
    /// All-zero counts give all-zero offsets.
    #[test]
    fn test_gpu_prefix_sum_all_zeros() {
        let counts = vec![0usize; 5];
        let result = gpu_prefix_sum(&counts);
        assert_eq!(result, vec![0, 0, 0, 0, 0]);
    }
    /// A single count yields a single zero offset.
    #[test]
    fn test_gpu_prefix_sum_single() {
        let result = gpu_prefix_sum(&[7usize]);
        assert_eq!(result, vec![0]);
    }
    /// Cell counting places each particle in the expected flat cell.
    #[test]
    fn test_parallel_cell_counting() {
        let positions = vec![[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [2.5, 0.5, 0.5]];
        let n_cells = [4usize, 4, 4];
        let counts = parallel_count_particles(&positions, n_cells, 1.0);
        assert_eq!(counts.len(), 64);
        let total: usize = counts.iter().sum();
        assert_eq!(total, 3, "total count should equal number of particles");
        assert_eq!(counts[0], 2, "two particles share cell (0,0,0)");
        assert_eq!(counts[2], 1, "one particle sits in cell (2,0,0)");
    }
    /// Particles inside the same cell are all counted there.
    #[test]
    fn test_parallel_cell_counting_all_in_one_cell() {
        let positions: Vec<[f64; 3]> = vec![[0.1, 0.1, 0.1], [0.2, 0.1, 0.1], [0.1, 0.2, 0.1]];
        let counts = parallel_count_particles(&positions, [4, 4, 4], 1.0);
        let max = counts.iter().cloned().max().unwrap_or(0);
        assert!(max >= 3, "all particles in one cell: max_count={max}");
    }
    /// Two GPUs get contiguous ranges covering all cells.
    #[test]
    fn test_multi_gpu_distribution_two_gpus() {
        let n_cells = 100;
        let n_gpus = 2;
        let ranges = distribute_cells_to_gpus(n_cells, n_gpus);
        assert_eq!(ranges.len(), n_gpus);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[n_gpus - 1].end, n_cells);
        for (i, pair) in ranges.windows(2).enumerate() {
            assert_eq!(pair[0].end, pair[1].start, "gap at gpu {i}");
        }
    }
    /// Remainder cells go to the first GPUs.
    #[test]
    fn test_multi_gpu_distribution_odd_cells() {
        let ranges = distribute_cells_to_gpus(7, 3);
        assert_eq!(ranges.len(), 3);
        let total: usize = ranges.iter().map(|r| r.end - r.start).sum();
        assert_eq!(total, 7);
        assert_eq!(ranges, vec![0..3, 3..5, 5..7]);
    }
    /// A single GPU owns the whole range.
    #[test]
    fn test_multi_gpu_distribution_single_gpu() {
        let ranges = distribute_cells_to_gpus(50, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[0].end, 50);
    }
    /// The neighbour kernel reports the close pair with the smaller index first.
    #[test]
    fn test_neighbor_search_kernel_finds_close_pair() {
        let positions: Vec<[f64; 3]> = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [5.0, 5.0, 5.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let pairs = gpu_neighbor_search_kernel(&cl, &positions, 0.5);
        // The kernel normalises every pair to (smaller, larger), so only
        // (0, 1) can appear.
        assert!(
            pairs.contains(&(0, 1)),
            "should find pair (0,1), got {pairs:?}"
        );
        assert!(
            !pairs.iter().any(|&(a, b)| a == 2 || b == 2),
            "particle 2 should not appear in pairs"
        );
    }
    /// Well-separated particles produce no neighbour pairs.
    #[test]
    fn test_neighbor_search_kernel_no_pairs() {
        let positions: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [20.0, 0.0, 0.0]];
        let cl = GpuCellList::build_parallel(&positions);
        let pairs = gpu_neighbor_search_kernel(&cl, &positions, 0.5);
        assert!(pairs.is_empty(), "well-separated particles → no pairs");
    }
    /// Rebuilding the spatial hash replaces previously stored particles.
    #[test]
    fn test_spatial_hash_rebuild() {
        let mut hash = SpatialHash::new(128, 1.0);
        let positions1 = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
        hash.build(&positions1);
        assert_eq!(hash.len(), 2);
        let positions2 = vec![[0.1, 0.1, 0.1]];
        hash.build(&positions2);
        assert_eq!(hash.len(), 1, "rebuild should replace old data");
    }
    /// The spatial hash stores every particle of a larger input.
    #[test]
    fn test_spatial_hash_large_number_of_particles() {
        let positions: Vec<[f64; 3]> = (0..200).map(|i| [i as f64 * 0.1, 0.0, 0.0]).collect();
        let mut hash = SpatialHash::new(256, 1.0);
        hash.build(&positions);
        assert_eq!(hash.len(), 200);
    }
}
564/// Sort particle indices by Morton code using Rayon for parallel code generation.
565///
566/// This variant uses Rayon for parallel code generation, then a sequential
567/// sort for the final ordering (radix sort on a GPU would be fully parallel).
568///
569/// Returns `(sorted_indices, sorted_morton_codes)`.
570pub fn parallel_morton_sort(
571    positions: &[[f64; 3]],
572    box_min: [f64; 3],
573    box_max: [f64; 3],
574) -> (Vec<usize>, Vec<u32>) {
575    let range = [
576        (box_max[0] - box_min[0]).max(1e-10),
577        (box_max[1] - box_min[1]).max(1e-10),
578        (box_max[2] - box_min[2]).max(1e-10),
579    ];
580    let mut code_index_pairs: Vec<(u32, usize)> = positions
581        .par_iter()
582        .enumerate()
583        .map(|(i, p)| {
584            let xi = (((p[0] - box_min[0]) / range[0]) * 1023.0) as u32;
585            let yi = (((p[1] - box_min[1]) / range[1]) * 1023.0) as u32;
586            let zi = (((p[2] - box_min[2]) / range[2]) * 1023.0) as u32;
587            let x = xi.min(1023);
588            let y = yi.min(1023);
589            let z = zi.min(1023);
590            (morton_encode(x, y, z), i)
591        })
592        .collect();
593    code_index_pairs.sort_by_key(|&(code, _)| code);
594    let sorted_indices: Vec<usize> = code_index_pairs.iter().map(|&(_, i)| i).collect();
595    let sorted_codes: Vec<u32> = code_index_pairs.iter().map(|&(c, _)| c).collect();
596    (sorted_indices, sorted_codes)
597}
598/// Compute the 30-bit Morton code for a position relative to a bounding box.
599///
600/// Normalises `pos` into `[0, 1023]^3` integer coordinates and encodes them.
601pub fn position_to_morton(pos: [f64; 3], box_min: [f64; 3], box_max: [f64; 3]) -> u32 {
602    let range = [
603        (box_max[0] - box_min[0]).max(1e-10),
604        (box_max[1] - box_min[1]).max(1e-10),
605        (box_max[2] - box_min[2]).max(1e-10),
606    ];
607    let x = (((pos[0] - box_min[0]) / range[0]) * 1023.0) as u32;
608    let y = (((pos[1] - box_min[1]) / range[1]) * 1023.0) as u32;
609    let z = (((pos[2] - box_min[2]) / range[2]) * 1023.0) as u32;
610    morton_encode(x.min(1023), y.min(1023), z.min(1023))
611}
612/// Insert particles into an existing cell list without a full rebuild.
613///
614/// This is a **sequential** incremental insert that updates `cell_counts` and
615/// `sorted_indices` in place.  The `cell_starts` offsets remain valid only if
616/// the new particles land in cells that already have capacity; otherwise a
617/// full rebuild should be used.
618///
619/// Returns the number of particles successfully inserted.
620pub fn insert_particles(cl: &mut GpuCellList, new_positions: &[[f64; 3]]) -> usize {
621    let old_n = cl.sorted_indices.len();
622    let mut inserted = 0usize;
623    for (i, &pos) in new_positions.iter().enumerate() {
624        let cell = cl.cell_index(pos);
625        cl.sorted_indices.push(old_n + i);
626        cl.cell_counts[cell] += 1;
627        inserted += 1;
628    }
629    let new_starts = parallel_prefix_sum(
630        &cl.cell_counts
631            .iter()
632            .map(|&c| c as usize)
633            .collect::<Vec<_>>(),
634    );
635    cl.cell_starts = new_starts.iter().map(|&s| s as i32).collect();
636    inserted
637}
638/// Query all particles within `radius` of `query_pos` from a set of positions,
639/// using a pre-built `GpuCellList`.
640///
641/// This is an alias for `GpuCellList::neighbors_in_radius` exposed at module level
642/// for ergonomic access.
643pub fn query_neighbors(
644    cl: &GpuCellList,
645    positions: &[[f64; 3]],
646    query_pos: [f64; 3],
647    radius: f64,
648) -> Vec<usize> {
649    cl.neighbors_in_radius(positions, query_pos, radius)
650}
651#[cfg(test)]
652mod extended_cell_tests {
653    use crate::cell_list::CellList;
654    use crate::cell_list::GhostCellManager;
655    use crate::cell_list::GpuCellList;
656    use crate::cell_list::GridResizer;
657    use crate::cell_list::OccupancyStats;
658
659    use crate::cell_list::insert_particles;
660    use crate::cell_list::parallel_morton_sort;
661    use crate::cell_list::position_to_morton;
662    use crate::cell_list::query_neighbors;
663    #[test]
664    fn test_occupancy_stats_uniform() {
665        let positions: Vec<[f64; 3]> = vec![
666            [0.5, 0.5, 0.5],
667            [1.5, 0.5, 0.5],
668            [0.5, 1.5, 0.5],
669            [1.5, 1.5, 0.5],
670            [0.5, 0.5, 1.5],
671            [1.5, 0.5, 1.5],
672            [0.5, 1.5, 1.5],
673            [1.5, 1.5, 1.5],
674        ];
675        let cl = GpuCellList::build_parallel(&positions);
676        let stats = OccupancyStats::compute(&cl);
677        assert_eq!(stats.total_particles, 8);
678        assert_eq!(stats.max_occupancy, 1);
679        assert!(stats.is_perfectly_spread());
680    }
681    #[test]
682    fn test_occupancy_stats_clustered() {
683        let positions: Vec<[f64; 3]> = vec![
684            [0.1, 0.1, 0.1],
685            [0.2, 0.1, 0.1],
686            [0.1, 0.2, 0.1],
687            [10.0, 10.0, 10.0],
688        ];
689        let cl = GpuCellList::build_parallel(&positions);
690        let stats = OccupancyStats::compute(&cl);
691        assert_eq!(stats.total_particles, 4);
692        assert!(
693            stats.max_occupancy >= 2,
694            "clustered particles should share a cell"
695        );
696        assert_eq!(stats.nonempty_cells, 2);
697    }
698    #[test]
699    fn test_occupancy_stats_load_imbalance_uniform() {
700        let positions: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
701        let cl = GpuCellList::build_parallel(&positions);
702        let stats = OccupancyStats::compute(&cl);
703        assert!(
704            (stats.load_imbalance - 1.0).abs() < 1e-10,
705            "load_imbalance = {}",
706            stats.load_imbalance
707        );
708    }
709    #[test]
710    fn test_occupancy_stats_completely_unbalanced() {
711        let positions: Vec<[f64; 3]> = vec![[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.1, 0.1]];
712        let cl = GpuCellList::build_parallel(&positions);
713        let stats = OccupancyStats::compute(&cl);
714        assert!(stats.is_completely_unbalanced() || stats.max_occupancy >= 2);
715    }
716    #[test]
717    fn test_grid_resizer_initial_build() {
718        let mut resizer = GridResizer::new(1.0, 0.5);
719        let positions = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
720        resizer.update(&positions);
721        assert!(resizer.get().is_some(), "cell list should be built");
722    }
723    #[test]
724    fn test_grid_resizer_no_resize_needed() {
725        let mut resizer = GridResizer::new(1.0, 1.0);
726        let positions = vec![[2.0, 2.0, 2.0]];
727        resizer.update(&positions);
728        let needs = resizer.needs_resize(&positions);
729        assert!(!needs, "same positions should not need resize");
730    }
731    #[test]
732    fn test_grid_resizer_escaping_particle() {
733        let mut resizer = GridResizer::new(1.0, 0.5);
734        let positions = vec![[1.0, 1.0, 1.0]];
735        resizer.rebuild(&positions);
736        let new_positions = vec![[100.0, 100.0, 100.0]];
737        assert!(
738            resizer.needs_resize(&new_positions),
739            "escaped particle should trigger resize"
740        );
741    }
742    #[test]
743    fn test_grid_resizer_empty_positions() {
744        let mut resizer = GridResizer::new(1.0, 0.5);
745        resizer.rebuild(&[]);
746        assert!(
747            resizer.get().is_some(),
748            "empty rebuild should produce valid list"
749        );
750    }
751    #[test]
752    fn test_ghost_manager_no_ghosts_interior() {
753        let mut mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
754        let positions = vec![[5.0, 5.0, 5.0], [6.0, 6.0, 6.0]];
755        mgr.build_ghosts(&positions);
756        assert_eq!(mgr.num_ghosts(), 0, "interior particles need no ghosts");
757    }
758    #[test]
759    fn test_ghost_manager_near_one_face() {
760        let mut mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
761        let positions = vec![[0.5, 5.0, 5.0]];
762        mgr.build_ghosts(&positions);
763        assert_eq!(mgr.num_ghosts(), 1, "should create 1 ghost on +x side");
764        assert!(
765            (mgr.ghost_positions[0][0] - 10.5).abs() < 1e-10,
766            "ghost x = {}",
767            mgr.ghost_positions[0][0]
768        );
769    }
770    #[test]
771    fn test_ghost_manager_near_two_faces() {
772        let mut mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
773        let positions = vec![[0.5, 0.5, 5.0]];
774        mgr.build_ghosts(&positions);
775        assert_eq!(mgr.num_ghosts(), 2, "particle near two faces → 2 ghosts");
776    }
777    #[test]
778    fn test_ghost_manager_near_corner() {
779        let mut mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
780        let positions = vec![[0.5, 0.5, 0.5]];
781        mgr.build_ghosts(&positions);
782        assert_eq!(mgr.num_ghosts(), 3, "corner particle → 3 primary ghosts");
783    }
784    #[test]
785    fn test_ghost_manager_map_to_real() {
786        let mut mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
787        let positions = vec![[0.5, 5.0, 5.0], [9.5, 5.0, 5.0]];
788        mgr.build_ghosts(&positions);
789        for &ri in &mgr.ghost_to_real {
790            assert!(ri < positions.len(), "real index {ri} out of range");
791        }
792    }
793    #[test]
794    fn test_minimum_image_convention() {
795        let mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
796        let d = mgr.minimum_image([9.0, 0.0, 0.0]);
797        assert!((d[0] - (-1.0)).abs() < 1e-10, "min image x = {}", d[0]);
798        assert!(d[1].abs() < 1e-12);
799        assert!(d[2].abs() < 1e-12);
800    }
801    #[test]
802    fn test_wrap_position_basic() {
803        let mgr = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
804        let p = mgr.wrap_position([11.5, -0.5, 10.0]);
805        assert!((p[0] - 1.5).abs() < 1e-10, "wrapped x = {}", p[0]);
806        assert!((p[1] - 9.5).abs() < 1e-10, "wrapped y = {}", p[1]);
807        assert!(p[2].abs() < 1e-10, "wrapped z = {}", p[2]);
808    }
809    #[test]
810    fn test_wrap_all_in_place() {
811        let mgr = GhostCellManager::new([5.0, 5.0, 5.0], 0.5);
812        let mut positions = vec![[6.0, 7.0, 0.0], [-1.0, 2.5, 11.0]];
813        mgr.wrap_all(&mut positions);
814        for p in &positions {
815            for k in 0..3 {
816                assert!(
817                    p[k] >= 0.0 && p[k] < 5.0,
818                    "wrapped coord out of range: {}",
819                    p[k]
820                );
821            }
822        }
823    }
824    #[test]
825    fn test_parallel_morton_sort_sorted_codes() {
826        let positions = vec![
827            [3.0, 3.0, 3.0],
828            [1.0, 1.0, 1.0],
829            [7.0, 7.0, 7.0],
830            [5.0, 5.0, 5.0],
831        ];
832        let (_idx, codes) = parallel_morton_sort(&positions, [0.0; 3], [10.0; 3]);
833        for i in 0..codes.len() - 1 {
834            assert!(
835                codes[i] <= codes[i + 1],
836                "parallel morton sort codes not sorted at {i}"
837            );
838        }
839    }
840    #[test]
841    fn test_parallel_morton_sort_valid_permutation() {
842        let positions: Vec<[f64; 3]> = (0..10).map(|i| [i as f64, 0.0, 0.0]).collect();
843        let (idx, codes) = parallel_morton_sort(&positions, [0.0; 3], [10.0, 1.0, 1.0]);
844        assert_eq!(idx.len(), 10);
845        assert_eq!(codes.len(), 10);
846        let mut sorted_idx = idx.clone();
847        sorted_idx.sort_unstable();
848        assert_eq!(sorted_idx, (0..10).collect::<Vec<_>>());
849    }
850    #[test]
851    fn test_position_to_morton_corner() {
852        let code = position_to_morton([0.0, 0.0, 0.0], [0.0; 3], [1.0; 3]);
853        assert_eq!(code, 0, "corner should give Morton code 0");
854    }
855    #[test]
856    fn test_position_to_morton_different_positions() {
857        let p1 = position_to_morton([1.0, 0.0, 0.0], [0.0; 3], [10.0; 3]);
858        let p2 = position_to_morton([0.0, 1.0, 0.0], [0.0; 3], [10.0; 3]);
859        let p3 = position_to_morton([5.0, 5.0, 5.0], [0.0; 3], [10.0; 3]);
860        assert_ne!(p1, p3);
861        assert_ne!(p2, p3);
862    }
863    #[test]
864    fn test_insert_particles_increases_count() {
865        let mut cl = GpuCellList::build_parallel(&[[1.0, 1.0, 1.0]]);
866        let original_len = cl.sorted_indices.len();
867        let new_particles = vec![[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]];
868        let inserted = insert_particles(&mut cl, &new_particles);
869        assert_eq!(inserted, 2);
870        assert_eq!(cl.sorted_indices.len(), original_len + 2);
871    }
872    #[test]
873    fn test_insert_particles_empty_grid() {
874        let mut cl = GpuCellList::new([4, 4, 4], 1.0, [4.0, 4.0, 4.0]);
875        let positions = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
876        let inserted = insert_particles(&mut cl, &positions);
877        assert_eq!(inserted, 2);
878    }
879    #[test]
880    fn test_query_neighbors_finds_close_particle() {
881        let positions = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [9.0, 9.0, 9.0]];
882        let cl = GpuCellList::build_parallel(&positions);
883        let mut neighbours = query_neighbors(&cl, &positions, [1.0, 1.0, 1.0], 0.5);
884        neighbours.sort_unstable();
885        assert!(neighbours.contains(&0), "should find self");
886        assert!(neighbours.contains(&1), "should find nearby particle");
887        assert!(!neighbours.contains(&2), "should not find far particle");
888    }
889    #[test]
890    fn test_query_neighbors_empty_result() {
891        let positions = vec![[0.0, 0.0, 0.0], [100.0, 100.0, 100.0]];
892        let cl = GpuCellList::build_parallel(&positions);
893        let neighbours = query_neighbors(&cl, &positions, [50.0, 50.0, 50.0], 0.1);
894        assert!(neighbours.is_empty(), "no particles near middle of box");
895    }
896    #[test]
897    fn test_verlet_list_close_pair_found() {
898        let positions = vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [10.0, 10.0, 10.0]];
899        let cl = CellList::build(&positions);
900        let pairs = cl.build_neighbor_list_verlet(1.0, 0.2);
901        let has_01 = pairs.contains(&(0, 1));
902        assert!(has_01, "pair (0,1) must be in Verlet list");
903    }
904    #[test]
905    fn test_verlet_list_far_pair_excluded() {
906        let positions = vec![[0.0, 0.0, 0.0], [20.0, 20.0, 20.0]];
907        let cl = CellList::build(&positions);
908        let pairs = cl.build_neighbor_list_verlet(1.0, 0.2);
909        assert!(pairs.is_empty(), "far pair must not appear in Verlet list");
910    }
911    #[test]
912    fn test_verlet_list_no_self_pairs() {
913        let positions = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [1.2, 1.0, 1.0]];
914        let cl = CellList::build(&positions);
915        let pairs = cl.build_neighbor_list_verlet(1.0, 0.5);
916        for &(i, j) in &pairs {
917            assert_ne!(i, j, "self-pair found");
918        }
919    }
920    #[test]
921    fn test_verlet_list_pairs_ordered() {
922        let positions: Vec<[f64; 3]> = (0..5).map(|i| [i as f64 * 0.3, 0.0, 0.0]).collect();
923        let cl = CellList::build(&positions);
924        let pairs = cl.build_neighbor_list_verlet(1.0, 0.1);
925        for &(i, j) in &pairs {
926            assert!(i < j, "Verlet pair must have i < j");
927        }
928    }
929    #[test]
930    fn test_update_incremental_no_move() {
931        let positions = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
932        let mut cl = CellList::build(&positions);
933        let relocated = cl.update_incremental(&positions, &positions, 0.1);
934        assert_eq!(relocated, 0, "no particle moved");
935    }
936    #[test]
937    fn test_update_incremental_large_move_counted() {
938        let old = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
939        let new_pos = vec![[1.0, 1.0, 1.0], [6.0, 6.0, 6.0]];
940        let mut cl = CellList::build(&old);
941        let relocated = cl.update_incremental(&new_pos, &old, 0.5);
942        assert!(relocated >= 1, "at least one particle relocated");
943    }
944    #[test]
945    fn test_update_incremental_threshold_respected() {
946        let old = vec![[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]];
947        let new_pos = vec![[0.05, 0.0, 0.0], [8.0, 8.0, 8.0]];
948        let mut cl = CellList::build(&old);
949        let relocated = cl.update_incremental(&new_pos, &old, 1.0);
950        assert_eq!(relocated, 1);
951    }
952    #[test]
953    fn test_pair_density_single_bin() {
954        let positions = vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]];
955        let cl = CellList::build(&positions);
956        let hist = cl.compute_pair_density(2.0, 1.0);
957        assert!(hist[0] >= 1, "pair must appear in bin 0");
958    }
959    #[test]
960    fn test_pair_density_no_pairs_beyond_max_r() {
961        let positions = vec![[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]];
962        let cl = CellList::build(&positions);
963        let hist = cl.compute_pair_density(2.0, 0.5);
964        let total: usize = hist.iter().sum();
965        assert_eq!(total, 0, "pair beyond max_r should not be counted");
966    }
967    #[test]
968    fn test_pair_density_histogram_length() {
969        let positions = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
970        let cl = CellList::build(&positions);
971        let hist = cl.compute_pair_density(5.0, 1.0);
972        assert_eq!(hist.len(), 5, "histogram length = ceil(max_r/dr)");
973    }
974}