1#![allow(clippy::needless_range_loop)]
6use rayon::prelude::*;
7
8use super::types::GpuCellList;
9
10pub(super) fn expand_bits(mut x: u32) -> u32 {
14 x &= 0x000003FF;
15 x = (x | (x << 16)) & 0x030000FF;
16 x = (x | (x << 8)) & 0x0300F00F;
17 x = (x | (x << 4)) & 0x030C30C3;
18 x = (x | (x << 2)) & 0x09249249;
19 x
20}
21pub fn morton_encode(x: u32, y: u32, z: u32) -> u32 {
25 expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)
26}
27pub fn morton_decode(code: u32) -> (u32, u32, u32) {
29 (
30 compact_bits(code),
31 compact_bits(code >> 1),
32 compact_bits(code >> 2),
33 )
34}
35pub(super) fn compact_bits(mut x: u32) -> u32 {
37 x &= 0x09249249;
38 x = (x | (x >> 2)) & 0x030C30C3;
39 x = (x | (x >> 4)) & 0x0300F00F;
40 x = (x | (x >> 8)) & 0x030000FF;
41 x = (x | (x >> 16)) & 0x000003FF;
42 x
43}
44pub fn morton_sort(
53 positions: &[[f64; 3]],
54 box_min: [f64; 3],
55 box_max: [f64; 3],
56) -> (Vec<usize>, Vec<u32>) {
57 let range = [
58 (box_max[0] - box_min[0]).max(1e-10),
59 (box_max[1] - box_min[1]).max(1e-10),
60 (box_max[2] - box_min[2]).max(1e-10),
61 ];
62 let mut codes: Vec<(u32, usize)> = positions
63 .par_iter()
64 .enumerate()
65 .map(|(i, p)| {
66 let x = (((p[0] - box_min[0]) / range[0] * 1023.0) as u32).min(1023);
67 let y = (((p[1] - box_min[1]) / range[1] * 1023.0) as u32).min(1023);
68 let z = (((p[2] - box_min[2]) / range[2] * 1023.0) as u32).min(1023);
69 (morton_encode(x, y, z), i)
70 })
71 .collect();
72 codes.sort_by_key(|&(code, _)| code);
73 let sorted_indices: Vec<usize> = codes.iter().map(|&(_, idx)| idx).collect();
74 let morton_codes: Vec<u32> = codes.iter().map(|&(code, _)| code).collect();
75 (sorted_indices, morton_codes)
76}
/// Computes the exclusive prefix sum of `counts`.
///
/// Element `i` of the result is the sum of `counts[..i]`, so the first
/// element is always `0` and the output has the same length as the input.
/// The scan itself runs sequentially.
pub fn parallel_prefix_sum(counts: &[usize]) -> Vec<usize> {
    counts
        .iter()
        .scan(0usize, |running, &count| {
            // Emit the total *before* this element, then fold the element in.
            let before = *running;
            *running += count;
            Some(before)
        })
        .collect()
}
/// Returns the axis-aligned bounding box `(min, max)` of `positions`.
///
/// An empty slice yields `([0.0; 3], [0.0; 3])`. Otherwise the box starts at
/// the first position and is widened, axis by axis, by every later position
/// that compares strictly below the current minimum or strictly above the
/// current maximum.
pub fn compute_bounding_box(positions: &[[f64; 3]]) -> ([f64; 3], [f64; 3]) {
    match positions.split_first() {
        None => ([0.0; 3], [0.0; 3]),
        Some((&first, rest)) => rest.iter().fold((first, first), |(mut lo, mut hi), p| {
            for axis in 0..3 {
                if p[axis] < lo[axis] {
                    lo[axis] = p[axis];
                }
                if p[axis] > hi[axis] {
                    hi[axis] = p[axis];
                }
            }
            (lo, hi)
        }),
    }
}
/// Builds a new vector whose `k`-th element is a clone of `data[perm[k]]`.
///
/// The output has `perm.len()` elements; indices may repeat or be omitted.
///
/// # Panics
///
/// Panics if any entry of `perm` is not a valid index into `data`.
pub fn reorder_by_permutation<T: Clone>(data: &[T], perm: &[usize]) -> Vec<T> {
    let mut reordered = Vec::with_capacity(perm.len());
    for &source in perm {
        reordered.push(data[source].clone());
    }
    reordered
}
/// Sorts `keys` in ascending order and reports where each key came from.
///
/// Returns `(sorted_keys, sorted_indices)` such that
/// `sorted_keys[k] == keys[sorted_indices[k]]`. Despite the name, this uses
/// the standard library's stable comparison sort, so equal keys keep their
/// original relative order. An empty input yields two empty vectors.
pub fn radix_sort_mock(keys: &[u32]) -> (Vec<u32>, Vec<usize>) {
    // Sort a permutation of the indices instead of (key, index) pairs.
    let mut order: Vec<usize> = (0..keys.len()).collect();
    order.sort_by_key(|&idx| keys[idx]);

    let sorted_keys: Vec<u32> = order.iter().map(|&idx| keys[idx]).collect();
    (sorted_keys, order)
}
/// Computes the exclusive prefix sum of `counts` on the CPU.
///
/// `result[i]` is the sum of `counts[..i]`; the result has the same length
/// as `counts` and starts with `0` whenever it is non-empty.
pub fn gpu_prefix_sum(counts: &[usize]) -> Vec<usize> {
    let mut offsets = vec![0usize; counts.len()];
    let mut total = 0usize;
    for (slot, &count) in offsets.iter_mut().zip(counts) {
        // Store the total accumulated so far, then include this element.
        *slot = total;
        total += count;
    }
    offsets
}
/// Counts how many particles fall into each cell of a regular grid.
///
/// A particle at `p` is assigned to cell `(ix, iy, iz)` where each index is
/// `p[axis] / cell_size` truncated to an integer and clamped to
/// `0..n_cells[axis]`, so particles outside the grid are counted in the
/// nearest boundary cell. The returned vector has `nx * ny * nz` entries and
/// is laid out x-fastest: the count for `(ix, iy, iz)` lives at
/// `ix + nx * (iy + ny * iz)`.
///
/// If any grid dimension is zero there are no cells to count into, and an
/// empty vector is returned regardless of `positions`.
///
/// Despite the name, the counting loop runs sequentially.
pub fn parallel_count_particles(
    positions: &[[f64; 3]],
    n_cells: [usize; 3],
    cell_size: f64,
) -> Vec<usize> {
    let [nx, ny, nz] = n_cells;
    let mut counts = vec![0usize; nx * ny * nz];

    // With a zero-sized axis the clamp below would be asked for the range
    // `0..=-1`, which panics; there is nothing to count in that case anyway.
    if counts.is_empty() {
        return counts;
    }

    // Truncating cast first, then clamp into the valid index range of the axis.
    let axis_index = |coord: f64, n: usize| -> usize {
        ((coord / cell_size) as isize).clamp(0, n as isize - 1) as usize
    };

    for p in positions {
        let ix = axis_index(p[0], nx);
        let iy = axis_index(p[1], ny);
        let iz = axis_index(p[2], nz);
        counts[ix + nx * (iy + ny * iz)] += 1;
    }
    counts
}
/// Splits the cell range `0..n_cells` into `n_gpus` contiguous chunks.
///
/// Every chunk gets `n_cells / n_gpus` cells and the first
/// `n_cells % n_gpus` chunks get one extra, so chunk sizes differ by at most
/// one. Returns an empty vector when either argument is zero. When there are
/// more GPUs than cells, the trailing ranges are empty.
pub fn distribute_cells_to_gpus(n_cells: usize, n_gpus: usize) -> Vec<std::ops::Range<usize>> {
    if n_gpus == 0 || n_cells == 0 {
        return Vec::new();
    }

    let (base, remainder) = (n_cells / n_gpus, n_cells % n_gpus);
    (0..n_gpus)
        .scan(0usize, |cursor, gpu| {
            let begin = *cursor;
            // The first `remainder` GPUs absorb the leftover cells, one each.
            *cursor = begin + base + usize::from(gpu < remainder);
            Some(begin..*cursor)
        })
        .collect()
}
184pub fn gpu_neighbor_search_kernel(
189 cl: &GpuCellList,
190 positions: &[[f64; 3]],
191 cutoff: f64,
192) -> Vec<(usize, usize)> {
193 let mut pairs = Vec::new();
194 cl.for_each_pair(positions, cutoff, |i, j, _d2| {
195 let (a, b) = if i < j { (i, j) } else { (j, i) };
196 pairs.push((a, b));
197 });
198 pairs.sort_unstable();
199 pairs.dedup();
200 pairs
201}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cell_list::{CellList, GpuCellList, SpatialHash};

    /// Builds the eight points `{0.5, 1.5}^3`, with x varying fastest and z slowest.
    fn lattice_2x2x2() -> Vec<[f64; 3]> {
        let mut points = Vec::with_capacity(8);
        for z in [0.5, 1.5] {
            for y in [0.5, 1.5] {
                for x in [0.5, 1.5] {
                    points.push([x, y, z]);
                }
            }
        }
        points
    }

    // ---- prefix sum -------------------------------------------------------

    /// An empty count slice produces an empty scan.
    #[test]
    fn test_prefix_sum_empty() {
        let scanned = parallel_prefix_sum(&[]);
        assert_eq!(scanned, Vec::<usize>::new());
    }

    /// The scan is exclusive: each output is the sum of the preceding inputs.
    #[test]
    fn test_prefix_sum_basic() {
        let scanned = parallel_prefix_sum(&[1usize, 2, 3, 4]);
        assert_eq!(scanned, vec![0, 1, 3, 6]);
    }

    // ---- GpuCellList basics -----------------------------------------------

    /// A point beyond the box on every axis is expected in the last cell (3, 3, 3).
    #[test]
    fn test_cell_index_clamp() {
        let grid = GpuCellList::new([4, 4, 4], 1.0, [4.0, 4.0, 4.0]);
        let clamped = grid.cell_index([4.5; 3]);
        assert_eq!(clamped, 3 * 4 * 4 + 3 * 4 + 3);
    }

    /// The total cell count is the product of the per-axis cell counts.
    #[test]
    fn test_total_cells() {
        let grid = GpuCellList::new([3, 4, 5], 1.0, [3.0, 4.0, 5.0]);
        assert_eq!(grid.total_cells(), 60);
    }

    /// Eight lattice points are expected to land in eight distinct cells.
    #[test]
    fn test_build_parallel_counts() {
        let points = lattice_2x2x2();
        let grid = GpuCellList::build_parallel(&points);
        assert_eq!(grid.sorted_indices.len(), 8);
        for cell in 0..grid.total_cells() {
            assert_eq!(grid.cell_counts[cell], 1);
        }
    }

    /// A radius query returns the two close particles and not the distant one.
    #[test]
    fn test_neighbors_in_radius() {
        let points: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let grid = GpuCellList::build_parallel(&points);
        let mut hits = grid.neighbors_in_radius(&points, [0.5, 0.5, 0.5], 0.5);
        hits.sort_unstable();
        assert!(hits.contains(&0));
        assert!(hits.contains(&1));
        assert!(!hits.contains(&2));
    }

    // ---- CellList wrapper -------------------------------------------------

    /// `find_neighbors` returns the query particle and both nearby particles only.
    #[test]
    fn cell_list_find_neighbors_all_pairs() {
        let particles: Vec<[f64; 3]> = vec![
            [1.0, 1.0, 1.0],
            [1.2, 1.0, 1.0],
            [1.0, 1.3, 1.0],
            [9.0, 9.0, 9.0],
        ];
        let list = CellList::build(&particles);
        let mut neighbours = list.find_neighbors([1.0, 1.0, 1.0], 0.5);
        neighbours.sort_unstable();
        assert!(neighbours.contains(&0), "should find self: {neighbours:?}");
        assert!(
            neighbours.contains(&1),
            "should find particle 1: {neighbours:?}"
        );
        assert!(
            neighbours.contains(&2),
            "should find particle 2: {neighbours:?}"
        );
        assert!(
            !neighbours.contains(&3),
            "particle 3 is far: {neighbours:?}"
        );
    }

    /// A 10-unit box with cell size 2 is expected to hold 5 x 5 x 5 cells.
    #[test]
    fn cell_list_new_compiles() {
        let list = CellList::new([10.0, 10.0, 10.0], 2.0);
        assert_eq!(list.inner.total_cells(), 125);
    }

    // ---- Morton codes -----------------------------------------------------

    /// Decoding an encoded triple gives back the original coordinates.
    #[test]
    fn test_morton_roundtrip() {
        let samples = [
            (0, 0, 0),
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (7, 3, 5),
            (1023, 1023, 1023),
            (512, 256, 128),
        ];
        for (x, y, z) in samples {
            let decoded = morton_decode(morton_encode(x, y, z));
            assert_eq!(decoded.0, x, "x mismatch for ({x},{y},{z})");
            assert_eq!(decoded.1, y, "y mismatch for ({x},{y},{z})");
            assert_eq!(decoded.2, z, "z mismatch for ({x},{y},{z})");
        }
    }

    /// A neighbouring grid point has a closer code than a distant grid point.
    #[test]
    fn test_morton_locality() {
        let origin = morton_encode(1, 1, 1);
        let d_near = origin.abs_diff(morton_encode(2, 1, 1));
        let d_far = origin.abs_diff(morton_encode(100, 100, 100));
        assert!(
            d_near < d_far,
            "near distance {d_near} should be less than far {d_far}"
        );
    }

    /// `morton_sort` returns ascending codes and a permutation of all indices.
    #[test]
    fn test_morton_sort_permutation() {
        let points = vec![[5.0, 5.0, 5.0], [1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
        let (indices, codes) = morton_sort(&points, [0.0; 3], [10.0, 10.0, 10.0]);
        assert_eq!(indices.len(), 3);
        assert_eq!(codes.len(), 3);
        for (i, pair) in codes.windows(2).enumerate() {
            assert!(pair[0] <= pair[1], "codes not sorted at {i}");
        }
        let mut seen = indices.clone();
        seen.sort();
        assert_eq!(seen, vec![0, 1, 2]);
    }

    // ---- SpatialHash ------------------------------------------------------

    /// A radius query on the hash finds the close pair and skips the far particle.
    #[test]
    fn test_spatial_hash_query() {
        let points = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let mut table = SpatialHash::new(64, 1.0);
        table.build(&points);
        assert_eq!(table.len(), 3);
        let mut hits = table.query_radius(&points, [0.5, 0.5, 0.5], 0.5);
        hits.sort_unstable();
        hits.dedup();
        assert!(hits.contains(&0));
        assert!(hits.contains(&1));
        assert!(!hits.contains(&2));
    }

    /// A freshly created hash holds nothing.
    #[test]
    fn test_spatial_hash_empty() {
        let table = SpatialHash::new(64, 1.0);
        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
    }

    /// `clear` empties a hash that previously held an entry.
    #[test]
    fn test_spatial_hash_clear() {
        let mut table = SpatialHash::new(64, 1.0);
        table.insert(0, [0.5, 0.5, 0.5]);
        assert!(!table.is_empty());
        table.clear();
        assert!(table.is_empty());
    }

    // ---- small helpers ----------------------------------------------------

    /// The bounding box takes the per-axis extremes across all points.
    #[test]
    fn test_bounding_box() {
        let points = vec![[1.0, 2.0, 3.0], [4.0, 0.0, 1.0], [2.0, 5.0, 2.0]];
        let (lo, hi) = compute_bounding_box(&points);
        assert_eq!(lo, [1.0, 0.0, 1.0]);
        assert_eq!(hi, [4.0, 5.0, 3.0]);
    }

    /// An empty input yields the all-zero box.
    #[test]
    fn test_bounding_box_empty() {
        let (lo, hi) = compute_bounding_box(&[]);
        assert_eq!(lo, [0.0; 3]);
        assert_eq!(hi, [0.0; 3]);
    }

    /// Output element `k` is the input element at `perm[k]`.
    #[test]
    fn test_reorder() {
        let values = vec![10, 20, 30, 40];
        let order = vec![3, 1, 0, 2];
        assert_eq!(
            reorder_by_permutation(&values, &order),
            vec![40, 20, 10, 30]
        );
    }

    /// Two particles 0.1 apart are expected to share a cell.
    #[test]
    fn test_max_cell_occupancy() {
        let points: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let grid = GpuCellList::build_parallel(&points);
        let max = grid.max_cell_occupancy();
        assert!(max >= 2, "max occupancy should be at least 2, got {max}");
    }

    /// Two well-separated particles occupy exactly two cells.
    #[test]
    fn test_nonempty_cells() {
        let points: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let grid = GpuCellList::build_parallel(&points);
        let ne = grid.num_nonempty_cells();
        assert_eq!(ne, 2, "should have 2 non-empty cells, got {ne}");
    }

    /// Pair iteration reports the close pair and nothing involving the far particle.
    #[test]
    fn test_for_each_pair() {
        let points: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [0.6, 0.5, 0.5], [5.0, 5.0, 5.0]];
        let grid = GpuCellList::build_parallel(&points);
        let mut pairs = Vec::new();
        grid.for_each_pair(&points, 0.5, |i, j, _d2| {
            let ordered = if i < j { (i, j) } else { (j, i) };
            pairs.push(ordered);
        });
        pairs.sort_unstable();
        pairs.dedup();
        assert!(
            pairs.contains(&(0, 1)),
            "should find pair (0,1), got {pairs:?}"
        );
        assert!(
            pairs.iter().all(|&(a, b)| a != 2 && b != 2),
            "should not find pairs with particle 2"
        );
    }

    // ---- radix_sort_mock --------------------------------------------------

    /// Keys come back ascending and every reported index is in range.
    #[test]
    fn test_radix_sort_sorted_output() {
        let keys = vec![5u32, 1, 9, 3, 7, 2];
        let (sorted_keys, sorted_indices) = radix_sort_mock(&keys);
        for (i, pair) in sorted_keys.windows(2).enumerate() {
            assert!(pair[0] <= pair[1], "radix sort not sorted at {i}");
        }
        for &idx in &sorted_indices {
            assert!(idx < keys.len(), "invalid index {idx}");
        }
    }

    /// The index vector records where each sorted key originally lived.
    #[test]
    fn test_radix_sort_permutation_correct() {
        let (sorted_keys, sorted_indices) = radix_sort_mock(&[30u32, 10, 20]);
        assert_eq!(&sorted_keys[..3], &[10, 20, 30]);
        assert_eq!(&sorted_indices[..3], &[1, 2, 0]);
    }

    /// Sorting nothing returns two empty vectors.
    #[test]
    fn test_radix_sort_empty() {
        let no_keys: Vec<u32> = Vec::new();
        let (sorted_keys, sorted_indices) = radix_sort_mock(&no_keys);
        assert!(sorted_keys.is_empty());
        assert!(sorted_indices.is_empty());
    }

    /// Ten identical keys stay ten identical keys with ten indices.
    #[test]
    fn test_radix_sort_all_equal() {
        let (sorted_keys, sorted_indices) = radix_sort_mock(&[7u32; 10]);
        assert_eq!(sorted_keys.len(), 10);
        assert!(sorted_keys.iter().all(|&key| key == 7));
        assert_eq!(sorted_indices.len(), 10);
    }

    // ---- gpu_prefix_sum ---------------------------------------------------

    /// Zero counts leave the running total unchanged for the following slot.
    #[test]
    fn test_gpu_prefix_sum_basic() {
        let offsets = gpu_prefix_sum(&[0usize, 1, 3, 0, 2, 5]);
        assert_eq!(offsets, vec![0, 0, 1, 4, 4, 6]);
    }

    /// All-zero counts scan to all-zero offsets.
    #[test]
    fn test_gpu_prefix_sum_all_zeros() {
        let offsets = gpu_prefix_sum(&[0usize; 5]);
        assert_eq!(offsets, vec![0, 0, 0, 0, 0]);
    }

    /// A single count scans to a single zero offset.
    #[test]
    fn test_gpu_prefix_sum_single() {
        let offsets = gpu_prefix_sum(&[7usize]);
        assert_eq!(offsets, vec![0]);
    }

    // ---- parallel_count_particles -----------------------------------------

    /// Every particle is counted exactly once.
    #[test]
    fn test_parallel_cell_counting() {
        let points = vec![[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [2.5, 0.5, 0.5]];
        let counts = parallel_count_particles(&points, [4usize, 4, 4], 1.0);
        let total: usize = counts.iter().sum();
        assert_eq!(total, 3, "total count should equal number of particles");
    }

    /// Three particles inside the first unit cell are counted together.
    #[test]
    fn test_parallel_cell_counting_all_in_one_cell() {
        let points: Vec<[f64; 3]> = vec![[0.1, 0.1, 0.1], [0.2, 0.1, 0.1], [0.1, 0.2, 0.1]];
        let counts = parallel_count_particles(&points, [4, 4, 4], 1.0);
        let max = counts.iter().copied().max().unwrap_or(0);
        assert!(max >= 3, "all particles in one cell: max_count={max}");
    }

    // ---- distribute_cells_to_gpus -----------------------------------------

    /// Two GPUs cover the whole cell range contiguously.
    #[test]
    fn test_multi_gpu_distribution_two_gpus() {
        let (n_cells, n_gpus) = (100, 2);
        let ranges = distribute_cells_to_gpus(n_cells, n_gpus);
        assert_eq!(ranges.len(), n_gpus);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[n_gpus - 1].end, n_cells);
        for (i, pair) in ranges.windows(2).enumerate() {
            assert_eq!(pair[0].end, pair[1].start, "gap at gpu {i}");
        }
    }

    /// Chunk sizes add up to the cell count even when it does not divide evenly.
    #[test]
    fn test_multi_gpu_distribution_odd_cells() {
        let ranges = distribute_cells_to_gpus(7, 3);
        assert_eq!(ranges.len(), 3);
        let covered: usize = ranges.iter().map(|chunk| chunk.end - chunk.start).sum();
        assert_eq!(covered, 7);
    }

    /// A single GPU receives the full range.
    #[test]
    fn test_multi_gpu_distribution_single_gpu() {
        let ranges = distribute_cells_to_gpus(50, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[0].end, 50);
    }

    // ---- gpu_neighbor_search_kernel ---------------------------------------

    /// The kernel reports the close pair and nothing involving the far particle.
    #[test]
    fn test_neighbor_search_kernel_finds_close_pair() {
        let points: Vec<[f64; 3]> = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [5.0, 5.0, 5.0]];
        let grid = GpuCellList::build_parallel(&points);
        let pairs = gpu_neighbor_search_kernel(&grid, &points, 0.5);
        assert!(
            [(0, 1), (1, 0)].iter().any(|wanted| pairs.contains(wanted)),
            "should find pair (0,1), got {pairs:?}"
        );
        assert!(
            pairs.iter().all(|&(a, b)| a != 2 && b != 2),
            "particle 2 should not appear in pairs"
        );
    }

    /// Particles 10 units apart produce no pairs at a 0.5 cutoff.
    #[test]
    fn test_neighbor_search_kernel_no_pairs() {
        let points: Vec<[f64; 3]> = vec![[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [20.0, 0.0, 0.0]];
        let grid = GpuCellList::build_parallel(&points);
        let pairs = gpu_neighbor_search_kernel(&grid, &points, 0.5);
        assert!(pairs.is_empty(), "well-separated particles → no pairs");
    }

    // ---- SpatialHash rebuild / scale --------------------------------------

    /// Building twice keeps only the second data set.
    #[test]
    fn test_spatial_hash_rebuild() {
        let mut table = SpatialHash::new(128, 1.0);
        let first_batch = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
        table.build(&first_batch);
        assert_eq!(table.len(), 2);
        let second_batch = vec![[0.1, 0.1, 0.1]];
        table.build(&second_batch);
        assert_eq!(table.len(), 1, "rebuild should replace old data");
    }

    /// All 200 particles along a line are stored.
    #[test]
    fn test_spatial_hash_large_number_of_particles() {
        let points: Vec<[f64; 3]> = (0..200).map(|i| [i as f64 * 0.1, 0.0, 0.0]).collect();
        let mut table = SpatialHash::new(256, 1.0);
        table.build(&points);
        assert_eq!(table.len(), 200);
    }
}
564pub fn parallel_morton_sort(
571 positions: &[[f64; 3]],
572 box_min: [f64; 3],
573 box_max: [f64; 3],
574) -> (Vec<usize>, Vec<u32>) {
575 let range = [
576 (box_max[0] - box_min[0]).max(1e-10),
577 (box_max[1] - box_min[1]).max(1e-10),
578 (box_max[2] - box_min[2]).max(1e-10),
579 ];
580 let mut code_index_pairs: Vec<(u32, usize)> = positions
581 .par_iter()
582 .enumerate()
583 .map(|(i, p)| {
584 let xi = (((p[0] - box_min[0]) / range[0]) * 1023.0) as u32;
585 let yi = (((p[1] - box_min[1]) / range[1]) * 1023.0) as u32;
586 let zi = (((p[2] - box_min[2]) / range[2]) * 1023.0) as u32;
587 let x = xi.min(1023);
588 let y = yi.min(1023);
589 let z = zi.min(1023);
590 (morton_encode(x, y, z), i)
591 })
592 .collect();
593 code_index_pairs.sort_by_key(|&(code, _)| code);
594 let sorted_indices: Vec<usize> = code_index_pairs.iter().map(|&(_, i)| i).collect();
595 let sorted_codes: Vec<u32> = code_index_pairs.iter().map(|&(c, _)| c).collect();
596 (sorted_indices, sorted_codes)
597}
598pub fn position_to_morton(pos: [f64; 3], box_min: [f64; 3], box_max: [f64; 3]) -> u32 {
602 let range = [
603 (box_max[0] - box_min[0]).max(1e-10),
604 (box_max[1] - box_min[1]).max(1e-10),
605 (box_max[2] - box_min[2]).max(1e-10),
606 ];
607 let x = (((pos[0] - box_min[0]) / range[0]) * 1023.0) as u32;
608 let y = (((pos[1] - box_min[1]) / range[1]) * 1023.0) as u32;
609 let z = (((pos[2] - box_min[2]) / range[2]) * 1023.0) as u32;
610 morton_encode(x.min(1023), y.min(1023), z.min(1023))
611}
612pub fn insert_particles(cl: &mut GpuCellList, new_positions: &[[f64; 3]]) -> usize {
621 let old_n = cl.sorted_indices.len();
622 let mut inserted = 0usize;
623 for (i, &pos) in new_positions.iter().enumerate() {
624 let cell = cl.cell_index(pos);
625 cl.sorted_indices.push(old_n + i);
626 cl.cell_counts[cell] += 1;
627 inserted += 1;
628 }
629 let new_starts = parallel_prefix_sum(
630 &cl.cell_counts
631 .iter()
632 .map(|&c| c as usize)
633 .collect::<Vec<_>>(),
634 );
635 cl.cell_starts = new_starts.iter().map(|&s| s as i32).collect();
636 inserted
637}
638pub fn query_neighbors(
644 cl: &GpuCellList,
645 positions: &[[f64; 3]],
646 query_pos: [f64; 3],
647 radius: f64,
648) -> Vec<usize> {
649 cl.neighbors_in_radius(positions, query_pos, radius)
650}
#[cfg(test)]
mod extended_cell_tests {
    use crate::cell_list::{
        insert_particles, parallel_morton_sort, position_to_morton, query_neighbors,
    };
    use crate::cell_list::{CellList, GhostCellManager, GpuCellList, GridResizer, OccupancyStats};

    /// Builds the eight points `{0.5, 1.5}^3`, with x varying fastest and z slowest.
    fn lattice_2x2x2() -> Vec<[f64; 3]> {
        let mut points = Vec::with_capacity(8);
        for z in [0.5, 1.5] {
            for y in [0.5, 1.5] {
                for x in [0.5, 1.5] {
                    points.push([x, y, z]);
                }
            }
        }
        points
    }

    // ---- OccupancyStats ---------------------------------------------------

    /// One particle per lattice site: max occupancy 1 and perfectly spread.
    #[test]
    fn test_occupancy_stats_uniform() {
        let points = lattice_2x2x2();
        let grid = GpuCellList::build_parallel(&points);
        let summary = OccupancyStats::compute(&grid);
        assert_eq!(summary.total_particles, 8);
        assert_eq!(summary.max_occupancy, 1);
        assert!(summary.is_perfectly_spread());
    }

    /// A tight cluster plus one outlier gives two occupied cells.
    #[test]
    fn test_occupancy_stats_clustered() {
        let points: Vec<[f64; 3]> = vec![
            [0.1, 0.1, 0.1],
            [0.2, 0.1, 0.1],
            [0.1, 0.2, 0.1],
            [10.0, 10.0, 10.0],
        ];
        let grid = GpuCellList::build_parallel(&points);
        let summary = OccupancyStats::compute(&grid);
        assert_eq!(summary.total_particles, 4);
        assert!(
            summary.max_occupancy >= 2,
            "clustered particles should share a cell"
        );
        assert_eq!(summary.nonempty_cells, 2);
    }

    /// Two particles in two cells give a load imbalance of exactly 1.
    #[test]
    fn test_occupancy_stats_load_imbalance_uniform() {
        let points: Vec<[f64; 3]> = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
        let grid = GpuCellList::build_parallel(&points);
        let summary = OccupancyStats::compute(&grid);
        assert!(
            (summary.load_imbalance - 1.0).abs() < 1e-10,
            "load_imbalance = {}",
            summary.load_imbalance
        );
    }

    /// Three close particles are either flagged unbalanced or share a cell.
    #[test]
    fn test_occupancy_stats_completely_unbalanced() {
        let points: Vec<[f64; 3]> = vec![[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.1, 0.1]];
        let grid = GpuCellList::build_parallel(&points);
        let summary = OccupancyStats::compute(&grid);
        assert!(summary.is_completely_unbalanced() || summary.max_occupancy >= 2);
    }

    // ---- GridResizer ------------------------------------------------------

    /// The first `update` builds a cell list.
    #[test]
    fn test_grid_resizer_initial_build() {
        let mut sizer = GridResizer::new(1.0, 0.5);
        let points = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
        sizer.update(&points);
        assert!(sizer.get().is_some(), "cell list should be built");
    }

    /// Re-checking the positions the grid was built from needs no resize.
    #[test]
    fn test_grid_resizer_no_resize_needed() {
        let mut sizer = GridResizer::new(1.0, 1.0);
        let points = vec![[2.0, 2.0, 2.0]];
        sizer.update(&points);
        assert!(
            !sizer.needs_resize(&points),
            "same positions should not need resize"
        );
    }

    /// A particle far outside the built grid triggers a resize.
    #[test]
    fn test_grid_resizer_escaping_particle() {
        let mut sizer = GridResizer::new(1.0, 0.5);
        let initial = vec![[1.0, 1.0, 1.0]];
        sizer.rebuild(&initial);
        let escaped = vec![[100.0, 100.0, 100.0]];
        assert!(
            sizer.needs_resize(&escaped),
            "escaped particle should trigger resize"
        );
    }

    /// Rebuilding from no particles still yields a cell list.
    #[test]
    fn test_grid_resizer_empty_positions() {
        let mut sizer = GridResizer::new(1.0, 0.5);
        sizer.rebuild(&[]);
        assert!(
            sizer.get().is_some(),
            "empty rebuild should produce valid list"
        );
    }

    // ---- GhostCellManager: ghost creation ---------------------------------

    /// Particles in the middle of the box create no ghosts.
    #[test]
    fn test_ghost_manager_no_ghosts_interior() {
        let mut ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
        let points = vec![[5.0, 5.0, 5.0], [6.0, 6.0, 6.0]];
        ghosts.build_ghosts(&points);
        assert_eq!(ghosts.num_ghosts(), 0, "interior particles need no ghosts");
    }

    /// A particle near the low-x face gets one ghost shifted by the box length.
    #[test]
    fn test_ghost_manager_near_one_face() {
        let mut ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
        let points = vec![[0.5, 5.0, 5.0]];
        ghosts.build_ghosts(&points);
        assert_eq!(ghosts.num_ghosts(), 1, "should create 1 ghost on +x side");
        let ghost_x = ghosts.ghost_positions[0][0];
        assert!((ghost_x - 10.5).abs() < 1e-10, "ghost x = {}", ghost_x);
    }

    /// A particle near two faces gets two ghosts.
    #[test]
    fn test_ghost_manager_near_two_faces() {
        let mut ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
        let points = vec![[0.5, 0.5, 5.0]];
        ghosts.build_ghosts(&points);
        assert_eq!(ghosts.num_ghosts(), 2, "particle near two faces → 2 ghosts");
    }

    /// A particle in a corner gets three ghosts.
    #[test]
    fn test_ghost_manager_near_corner() {
        let mut ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.5);
        let points = vec![[0.5, 0.5, 0.5]];
        ghosts.build_ghosts(&points);
        assert_eq!(ghosts.num_ghosts(), 3, "corner particle → 3 primary ghosts");
    }

    /// Every ghost maps back to a valid real-particle index.
    #[test]
    fn test_ghost_manager_map_to_real() {
        let mut ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
        let points = vec![[0.5, 5.0, 5.0], [9.5, 5.0, 5.0]];
        ghosts.build_ghosts(&points);
        for &ri in &ghosts.ghost_to_real {
            assert!(ri < points.len(), "real index {ri} out of range");
        }
    }

    // ---- GhostCellManager: periodic helpers -------------------------------

    /// A displacement of 9 in a box of 10 folds to -1.
    #[test]
    fn test_minimum_image_convention() {
        let ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
        let folded = ghosts.minimum_image([9.0, 0.0, 0.0]);
        assert!(
            (folded[0] + 1.0).abs() < 1e-10,
            "min image x = {}",
            folded[0]
        );
        assert!(folded[1].abs() < 1e-12);
        assert!(folded[2].abs() < 1e-12);
    }

    /// Coordinates above, below and exactly on the box edge wrap into the box.
    #[test]
    fn test_wrap_position_basic() {
        let ghosts = GhostCellManager::new([10.0, 10.0, 10.0], 1.0);
        let wrapped = ghosts.wrap_position([11.5, -0.5, 10.0]);
        assert!(
            (wrapped[0] - 1.5).abs() < 1e-10,
            "wrapped x = {}",
            wrapped[0]
        );
        assert!(
            (wrapped[1] - 9.5).abs() < 1e-10,
            "wrapped y = {}",
            wrapped[1]
        );
        assert!(wrapped[2].abs() < 1e-10, "wrapped z = {}", wrapped[2]);
    }

    /// After `wrap_all`, every coordinate lies in `[0, 5)`.
    #[test]
    fn test_wrap_all_in_place() {
        let ghosts = GhostCellManager::new([5.0, 5.0, 5.0], 0.5);
        let mut points = vec![[6.0, 7.0, 0.0], [-1.0, 2.5, 11.0]];
        ghosts.wrap_all(&mut points);
        for &coord in points.iter().flatten() {
            assert!(
                (0.0..5.0).contains(&coord),
                "wrapped coord out of range: {}",
                coord
            );
        }
    }

    // ---- Morton helpers ---------------------------------------------------

    /// Codes returned by the parallel sort are in ascending order.
    #[test]
    fn test_parallel_morton_sort_sorted_codes() {
        let points = vec![
            [3.0, 3.0, 3.0],
            [1.0, 1.0, 1.0],
            [7.0, 7.0, 7.0],
            [5.0, 5.0, 5.0],
        ];
        let (_order, codes) = parallel_morton_sort(&points, [0.0; 3], [10.0; 3]);
        for (i, pair) in codes.windows(2).enumerate() {
            assert!(
                pair[0] <= pair[1],
                "parallel morton sort codes not sorted at {i}"
            );
        }
    }

    /// The returned indices are a permutation of `0..10`.
    #[test]
    fn test_parallel_morton_sort_valid_permutation() {
        let points: Vec<[f64; 3]> = (0..10).map(|step| [step as f64, 0.0, 0.0]).collect();
        let (order, codes) = parallel_morton_sort(&points, [0.0; 3], [10.0, 1.0, 1.0]);
        assert_eq!(order.len(), 10);
        assert_eq!(codes.len(), 10);
        let mut seen = order.clone();
        seen.sort_unstable();
        assert_eq!(seen, (0..10).collect::<Vec<_>>());
    }

    /// The minimum corner of the box encodes to zero.
    #[test]
    fn test_position_to_morton_corner() {
        let corner_code = position_to_morton([0.0, 0.0, 0.0], [0.0; 3], [1.0; 3]);
        assert_eq!(corner_code, 0, "corner should give Morton code 0");
    }

    /// Points near the origin encode differently from the box centre.
    #[test]
    fn test_position_to_morton_different_positions() {
        let (lo, hi) = ([0.0; 3], [10.0; 3]);
        let along_x = position_to_morton([1.0, 0.0, 0.0], lo, hi);
        let along_y = position_to_morton([0.0, 1.0, 0.0], lo, hi);
        let centre = position_to_morton([5.0, 5.0, 5.0], lo, hi);
        assert_ne!(along_x, centre);
        assert_ne!(along_y, centre);
    }

    // ---- insert_particles / query_neighbors -------------------------------

    /// Inserting two particles grows `sorted_indices` by two.
    #[test]
    fn test_insert_particles_increases_count() {
        let mut grid = GpuCellList::build_parallel(&[[1.0, 1.0, 1.0]]);
        let len_before = grid.sorted_indices.len();
        let extra = vec![[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]];
        let inserted = insert_particles(&mut grid, &extra);
        assert_eq!(inserted, 2);
        assert_eq!(grid.sorted_indices.len(), len_before + 2);
    }

    /// Insertion also works on a grid created without any particles.
    #[test]
    fn test_insert_particles_empty_grid() {
        let mut grid = GpuCellList::new([4, 4, 4], 1.0, [4.0, 4.0, 4.0]);
        let extra = vec![[0.5, 0.5, 0.5], [1.5, 0.5, 0.5]];
        assert_eq!(insert_particles(&mut grid, &extra), 2);
    }

    /// The wrapper finds the query particle and its close neighbour only.
    #[test]
    fn test_query_neighbors_finds_close_particle() {
        let points = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [9.0, 9.0, 9.0]];
        let grid = GpuCellList::build_parallel(&points);
        let mut hits = query_neighbors(&grid, &points, [1.0, 1.0, 1.0], 0.5);
        hits.sort_unstable();
        assert!(hits.contains(&0), "should find self");
        assert!(hits.contains(&1), "should find nearby particle");
        assert!(!hits.contains(&2), "should not find far particle");
    }

    /// A query in the empty middle of the box returns nothing.
    #[test]
    fn test_query_neighbors_empty_result() {
        let points = vec![[0.0, 0.0, 0.0], [100.0, 100.0, 100.0]];
        let grid = GpuCellList::build_parallel(&points);
        let hits = query_neighbors(&grid, &points, [50.0, 50.0, 50.0], 0.1);
        assert!(hits.is_empty(), "no particles near middle of box");
    }

    // ---- Verlet neighbour list --------------------------------------------

    /// Two particles 0.5 apart appear as pair `(0, 1)`.
    #[test]
    fn test_verlet_list_close_pair_found() {
        let points = vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [10.0, 10.0, 10.0]];
        let list = CellList::build(&points);
        let pairs = list.build_neighbor_list_verlet(1.0, 0.2);
        assert!(pairs.contains(&(0, 1)), "pair (0,1) must be in Verlet list");
    }

    /// Two distant particles produce an empty list.
    #[test]
    fn test_verlet_list_far_pair_excluded() {
        let points = vec![[0.0, 0.0, 0.0], [20.0, 20.0, 20.0]];
        let list = CellList::build(&points);
        let pairs = list.build_neighbor_list_verlet(1.0, 0.2);
        assert!(pairs.is_empty(), "far pair must not appear in Verlet list");
    }

    /// No particle is ever paired with itself.
    #[test]
    fn test_verlet_list_no_self_pairs() {
        let points = vec![[1.0, 1.0, 1.0], [1.1, 1.0, 1.0], [1.2, 1.0, 1.0]];
        let list = CellList::build(&points);
        let pairs = list.build_neighbor_list_verlet(1.0, 0.5);
        for &(first, second) in &pairs {
            assert_ne!(first, second, "self-pair found");
        }
    }

    /// Every pair is stored with the smaller index first.
    #[test]
    fn test_verlet_list_pairs_ordered() {
        let points: Vec<[f64; 3]> = (0..5).map(|step| [step as f64 * 0.3, 0.0, 0.0]).collect();
        let list = CellList::build(&points);
        let pairs = list.build_neighbor_list_verlet(1.0, 0.1);
        for &(first, second) in &pairs {
            assert!(first < second, "Verlet pair must have i < j");
        }
    }

    // ---- incremental update -----------------------------------------------

    /// Identical old and new positions relocate nothing.
    #[test]
    fn test_update_incremental_no_move() {
        let points = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
        let mut list = CellList::build(&points);
        let relocated = list.update_incremental(&points, &points, 0.1);
        assert_eq!(relocated, 0, "no particle moved");
    }

    /// A particle that moved far beyond the threshold is counted.
    #[test]
    fn test_update_incremental_large_move_counted() {
        let before = vec![[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]];
        let after = vec![[1.0, 1.0, 1.0], [6.0, 6.0, 6.0]];
        let mut list = CellList::build(&before);
        let relocated = list.update_incremental(&after, &before, 0.5);
        assert!(relocated >= 1, "at least one particle relocated");
    }

    /// Only the particle that moved further than the threshold is counted.
    #[test]
    fn test_update_incremental_threshold_respected() {
        let before = vec![[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]];
        let after = vec![[0.05, 0.0, 0.0], [8.0, 8.0, 8.0]];
        let mut list = CellList::build(&before);
        let relocated = list.update_incremental(&after, &before, 1.0);
        assert_eq!(relocated, 1);
    }

    // ---- pair density histogram -------------------------------------------

    /// A pair 0.5 apart falls into the first bin of width 1.
    #[test]
    fn test_pair_density_single_bin() {
        let points = vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]];
        let list = CellList::build(&points);
        let histogram = list.compute_pair_density(2.0, 1.0);
        assert!(histogram[0] >= 1, "pair must appear in bin 0");
    }

    /// A pair further apart than `max_r` is not counted in any bin.
    #[test]
    fn test_pair_density_no_pairs_beyond_max_r() {
        let points = vec![[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]];
        let list = CellList::build(&points);
        let histogram = list.compute_pair_density(2.0, 0.5);
        let counted: usize = histogram.iter().sum();
        assert_eq!(counted, 0, "pair beyond max_r should not be counted");
    }

    /// `max_r = 5` with `dr = 1` gives five bins.
    #[test]
    fn test_pair_density_histogram_length() {
        let points = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
        let list = CellList::build(&points);
        let histogram = list.compute_pair_density(5.0, 1.0);
        assert_eq!(histogram.len(), 5, "histogram length = ceil(max_r/dr)");
    }
}