Skip to main content

oxiphysics_gpu/
compute_pipeline.rs

1#![allow(clippy::ptr_arg)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! CPU-side simulation of a wgpu compute pipeline abstraction.
6//!
7//! Provides data layout and dispatch logic that mirrors a typical GPU compute
8//! pipeline, with all computation running on the CPU.
9
10// ── silence unused-item lints for the enum variants / fields used only in
11//    tests or future GPU back-ends ──────────────────────────────────────────
12#![allow(dead_code)]
13
14// ─────────────────────────────────────────────────────────────────────────────
15// BufferUsage
16// ─────────────────────────────────────────────────────────────────────────────
17
/// Role that a [`ComputeBuffer`] is expected to play.
///
/// The value is carried alongside the buffer as a hint only.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BufferUsage {
    /// Per-vertex data consumed by the rasteriser.
    Vertex,
    /// Indices used by indexed draw calls.
    Index,
    /// Small uniform / constant data that is updated frequently.
    Uniform,
    /// Read-write storage for general-purpose use.
    Storage,
    /// Staging area for transfers between CPU and GPU.
    Staging,
}
32
33// ─────────────────────────────────────────────────────────────────────────────
34// ComputeBuffer
35// ─────────────────────────────────────────────────────────────────────────────
36
37/// A CPU-resident buffer that mirrors a GPU buffer.
38#[derive(Debug, Clone)]
39pub struct ComputeBuffer {
40    /// Raw `f32` elements.
41    pub data: Vec<f32>,
42    /// Intended GPU usage hint.
43    pub usage: BufferUsage,
44    /// Human-readable label (useful for debugging).
45    pub label: String,
46}
47
48impl ComputeBuffer {
49    /// Allocate a zero-initialised buffer of `size` `f32` elements.
50    pub fn new(size: usize, usage: BufferUsage, label: &str) -> Self {
51        Self {
52            data: vec![0.0_f32; size],
53            usage,
54            label: label.to_owned(),
55        }
56    }
57
58    /// Write `values` into the buffer starting at element `offset`.
59    ///
60    /// # Panics
61    /// Panics if `offset + values.len() > self.data.len()`.
62    pub fn write_f32(&mut self, offset: usize, values: &[f32]) {
63        let end = offset + values.len();
64        assert!(
65            end <= self.data.len(),
66            "write_f32: out-of-bounds write (offset={offset}, len={}, capacity={})",
67            values.len(),
68            self.data.len()
69        );
70        self.data[offset..end].copy_from_slice(values);
71    }
72
73    /// Read `count` `f32` elements starting at element `offset`.
74    ///
75    /// # Panics
76    /// Panics if `offset + count > self.data.len()`.
77    pub fn read_f32(&self, offset: usize, count: usize) -> Vec<f32> {
78        let end = offset + count;
79        assert!(
80            end <= self.data.len(),
81            "read_f32: out-of-bounds read (offset={offset}, count={count}, capacity={})",
82            self.data.len()
83        );
84        self.data[offset..end].to_vec()
85    }
86
87    /// Total size of the buffer in bytes (`len * 4`).
88    pub fn byte_size(&self) -> usize {
89        self.data.len() * std::mem::size_of::<f32>()
90    }
91}
92
93// ─────────────────────────────────────────────────────────────────────────────
94// WorkgroupSize
95// ─────────────────────────────────────────────────────────────────────────────
96
/// Three-dimensional workgroup extent used for a compute dispatch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WorkgroupSize {
    /// Extent along X.
    pub x: u32,
    /// Extent along Y.
    pub y: u32,
    /// Extent along Z.
    pub z: u32,
}

impl WorkgroupSize {
    /// Number of workgroups required so that `total` items are covered when
    /// one workgroup processes `workgroup` items, i.e. `ceil(total / workgroup)`.
    ///
    /// # Panics
    /// Panics if `workgroup` is zero.
    pub fn dispatch_count(total: u32, workgroup: u32) -> u32 {
        if workgroup == 0 {
            panic!("workgroup size must be > 0");
        }
        let whole_groups = total / workgroup;
        // A non-zero remainder means one extra, partially filled workgroup.
        if total % workgroup == 0 {
            whole_groups
        } else {
            whole_groups + 1
        }
    }
}

impl Default for WorkgroupSize {
    /// A 64 × 1 × 1 workgroup.
    fn default() -> Self {
        WorkgroupSize { x: 64, y: 1, z: 1 }
    }
}
122
123// ─────────────────────────────────────────────────────────────────────────────
124// ComputeKernelKind
125// ─────────────────────────────────────────────────────────────────────────────
126
/// Identifies which compute kernel a dispatch refers to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ComputeKernelKind {
    /// Velocity / position integration step.
    VelocityUpdate,
    /// One Jacobi iteration of the pressure solve.
    PressureJacobi,
    /// Evaluation of Lennard-Jones forces between particles.
    ParticleForce,
    /// Neighbour search (spatial hashing).
    NeighborSearch,
    /// A user-supplied kernel, identified by its string tag.
    Custom(String),
}
141
142// ─────────────────────────────────────────────────────────────────────────────
143// CpuComputeDispatch
144// ─────────────────────────────────────────────────────────────────────────────
145
146/// Dispatcher that executes compute kernels on the CPU.
147pub struct CpuComputeDispatch {
148    /// Which kernel this dispatcher is configured for.
149    pub kernel: ComputeKernelKind,
150    /// Workgroup size hint (informational on the CPU path).
151    pub workgroup_size: WorkgroupSize,
152}
153
154impl CpuComputeDispatch {
155    /// Create a new dispatcher.
156    pub fn new(kernel: ComputeKernelKind, wg: WorkgroupSize) -> Self {
157        Self {
158            kernel,
159            workgroup_size: wg,
160        }
161    }
162
163    /// Semi-implicit Euler integration for `n` particles.
164    ///
165    /// `pos[i] += vel[i] * dt` then `vel[i] += force[i] / mass[i] * dt`.
166    pub fn dispatch_velocity_update(
167        &self,
168        pos: &mut ComputeBuffer,
169        vel: &mut ComputeBuffer,
170        force: &ComputeBuffer,
171        mass: &ComputeBuffer,
172        dt: f32,
173        n: usize,
174    ) {
175        for i in 0..n {
176            pos.data[i] += vel.data[i] * dt;
177            vel.data[i] += force.data[i] / mass.data[i] * dt;
178        }
179    }
180
181    /// Single Jacobi pressure iteration over an `nx × ny` grid.
182    ///
183    /// Interior points only; boundary cells are left unchanged.
184    ///
185    /// `p[i,j] = (p_old[i+1,j] + p_old[i-1,j] + p_old[i,j+1] + p_old[i,j-1]
186    ///            - dx² * rhs[i,j]) / 4`
187    pub fn dispatch_pressure_jacobi(
188        &self,
189        p: &mut ComputeBuffer,
190        p_old: &ComputeBuffer,
191        rhs: &ComputeBuffer,
192        nx: usize,
193        ny: usize,
194        dx: f32,
195    ) {
196        let dx2 = dx * dx;
197        for j in 1..ny - 1 {
198            for i in 1..nx - 1 {
199                let idx = j * nx + i;
200                p.data[idx] = (p_old.data[idx + 1]
201                    + p_old.data[idx - 1]
202                    + p_old.data[idx + nx]
203                    + p_old.data[idx - nx]
204                    - dx2 * rhs.data[idx])
205                    / 4.0;
206            }
207        }
208    }
209
210    /// O(n²) Lennard-Jones force accumulation.
211    ///
212    /// Positions are stored as interleaved `[x0, y0, x1, y1, …]`.
213    /// Forces are accumulated in-place (`force` is zeroed first).
214    pub fn dispatch_particle_force(
215        &self,
216        pos: &ComputeBuffer,
217        force: &mut ComputeBuffer,
218        eps: f32,
219        sigma: f32,
220        n: usize,
221    ) {
222        // Zero forces first.
223        for v in force.data[..2 * n].iter_mut() {
224            *v = 0.0;
225        }
226
227        for i in 0..n {
228            for j in (i + 1)..n {
229                let dx = pos.data[2 * j] - pos.data[2 * i];
230                let dy = pos.data[2 * j + 1] - pos.data[2 * i + 1];
231                let r2 = dx * dx + dy * dy;
232                if r2 < 1e-12 {
233                    continue;
234                }
235                let sr2 = (sigma * sigma) / r2;
236                let sr6 = sr2 * sr2 * sr2;
237                let sr12 = sr6 * sr6;
238                // F = 24ε/r² * (2(σ/r)^12 - (σ/r)^6)
239                let fmag = 24.0 * eps / r2 * (2.0 * sr12 - sr6);
240                force.data[2 * i] -= fmag * dx;
241                force.data[2 * i + 1] -= fmag * dy;
242                force.data[2 * j] += fmag * dx;
243                force.data[2 * j + 1] += fmag * dy;
244            }
245        }
246    }
247}
248
249// ─────────────────────────────────────────────────────────────────────────────
250// GpuStats
251// ─────────────────────────────────────────────────────────────────────────────
252
/// Running totals collected across compute dispatches.
#[derive(Clone, Debug, Default)]
pub struct GpuStats {
    /// How many dispatches have been recorded.
    pub dispatch_count: u64,
    /// Sum of bytes transferred (reads + writes).
    pub bytes_transferred: u64,
    /// Sum of kernel wall-clock time, in milliseconds.
    pub kernel_time_ms: f64,
}

impl GpuStats {
    /// Stats with every counter at zero.
    pub fn new() -> Self {
        GpuStats {
            dispatch_count: 0,
            bytes_transferred: 0,
            kernel_time_ms: 0.0,
        }
    }

    /// Add one dispatch of `bytes` bytes that took `time_ms` milliseconds.
    pub fn record_dispatch(&mut self, bytes: u64, time_ms: f64) {
        let Self {
            dispatch_count,
            bytes_transferred,
            kernel_time_ms,
        } = self;
        *dispatch_count += 1;
        *bytes_transferred += bytes;
        *kernel_time_ms += time_ms;
    }
}
277
278// ─────────────────────────────────────────────────────────────────────────────
279// Free-standing solver helpers
280// ─────────────────────────────────────────────────────────────────────────────
281
/// Single Jacobi sweep over an `nx × ny` pressure grid (row-major, index
/// `j * nx + i`).
///
/// Writes every interior cell of `p_new` from the four neighbours in `p_old`
/// and the source term `rhs`:
/// `p_new = (east + west + north + south - dx² * rhs) / 4`.
/// Boundary cells of `p_new` are not touched. A grid with fewer than three
/// cells along either axis has no interior, so the call is a no-op.
///
/// # Panics
/// Panics if a slice is too short for an interior index of the grid.
pub fn jacobi_step_2d(
    p_new: &mut Vec<f32>,
    p_old: &[f32],
    rhs: &[f32],
    nx: usize,
    ny: usize,
    dx: f32,
) {
    let dx2 = dx * dx;
    // `saturating_sub` avoids the usize underflow of `0 - 1` for empty grids.
    for j in 1..ny.saturating_sub(1) {
        let row = j * nx;
        for i in 1..nx.saturating_sub(1) {
            let idx = row + i;
            p_new[idx] = (p_old[idx + 1] + p_old[idx - 1] + p_old[idx + nx] + p_old[idx - nx]
                - dx2 * rhs[idx])
                / 4.0;
        }
    }
}
303
304/// Iterative Jacobi pressure-Poisson solver.
305///
306/// Runs `n_iter` Jacobi sweeps and returns the final L∞ residual.
307pub fn pressure_poisson_solve(
308    p: &mut Vec<f32>,
309    rhs: &[f32],
310    nx: usize,
311    ny: usize,
312    dx: f32,
313    n_iter: usize,
314) -> f32 {
315    let size = nx * ny;
316    let mut p_old = p.clone();
317
318    for _ in 0..n_iter {
319        jacobi_step_2d(p, &p_old, rhs, nx, ny, dx);
320        p_old.copy_from_slice(&p[..size]);
321    }
322
323    // Compute L∞ residual on interior cells.
324    let dx2 = dx * dx;
325    let mut residual = 0.0_f32;
326    for j in 1..ny - 1 {
327        for i in 1..nx - 1 {
328            let idx = j * nx + i;
329            let lap = (p[idx + 1] + p[idx - 1] + p[idx + nx] + p[idx - nx] - 4.0 * p[idx]) / dx2;
330            let r = (lap - rhs[idx]).abs();
331            if r > residual {
332                residual = r;
333            }
334        }
335    }
336    residual
337}
338
339// ─────────────────────────────────────────────────────────────────────────────
340// PipelineCache – LRU-style pipeline caching
341// ─────────────────────────────────────────────────────────────────────────────
342
343/// A simple LRU-eviction cache for compiled compute pipelines.
344///
345/// Keyed by a string label; evicts the oldest entry when the capacity is reached.
346pub struct PipelineCache {
347    /// Maximum number of entries.
348    capacity: usize,
349    /// Entries in insertion order (oldest first).
350    entries: Vec<(String, CpuComputeDispatch)>,
351}
352
353impl PipelineCache {
354    /// Create a new cache with the given capacity.
355    pub fn new(capacity: usize) -> Self {
356        Self {
357            capacity,
358            entries: Vec::new(),
359        }
360    }
361
362    /// Insert (or replace) a pipeline under `key`.
363    pub fn insert(&mut self, key: &str, pipeline: CpuComputeDispatch) {
364        // Remove existing entry with the same key
365        self.entries.retain(|(k, _)| k != key);
366        // Evict oldest if at capacity
367        while self.entries.len() >= self.capacity {
368            self.entries.remove(0);
369        }
370        self.entries.push((key.to_owned(), pipeline));
371    }
372
373    /// Look up a cached pipeline by key.
374    pub fn get(&self, key: &str) -> Option<&CpuComputeDispatch> {
375        self.entries.iter().find(|(k, _)| k == key).map(|(_, v)| v)
376    }
377
378    /// Number of cached entries.
379    pub fn len(&self) -> usize {
380        self.entries.len()
381    }
382
383    /// Whether the cache is empty.
384    pub fn is_empty(&self) -> bool {
385        self.entries.is_empty()
386    }
387
388    /// Clear all cached pipelines.
389    pub fn clear(&mut self) {
390        self.entries.clear();
391    }
392}
393
394// ─────────────────────────────────────────────────────────────────────────────
395// PipelineStats – dispatch-level statistics
396// ─────────────────────────────────────────────────────────────────────────────
397
398/// Fine-grained statistics for compute pipeline usage.
399#[derive(Debug, Clone, Default)]
400pub struct PipelineStats {
401    /// Total number of dispatches.
402    pub total_dispatches: u64,
403    /// Total number of workgroups launched.
404    pub total_workgroups: u64,
405    /// Total number of invocations (workgroups × workgroup_size).
406    pub total_invocations: u64,
407    /// Number of times a cached pipeline was re-used.
408    pub cache_hits: u64,
409    /// Number of times a pipeline had to be compiled (or was not in cache).
410    pub cache_misses: u64,
411}
412
413impl PipelineStats {
414    /// Record a single dispatch event.
415    pub fn record_dispatch(&mut self, num_workgroups: u64, wg_size: WorkgroupSize) {
416        self.total_dispatches += 1;
417        self.total_workgroups += num_workgroups;
418        self.total_invocations +=
419            num_workgroups * (wg_size.x as u64) * (wg_size.y as u64) * (wg_size.z as u64);
420    }
421
422    /// Cache hit ratio (0.0–1.0). Returns NaN when no lookups occurred.
423    pub fn cache_hit_ratio(&self) -> f64 {
424        let total = self.cache_hits + self.cache_misses;
425        if total == 0 {
426            return f64::NAN;
427        }
428        self.cache_hits as f64 / total as f64
429    }
430}
431
432// ─────────────────────────────────────────────────────────────────────────────
433// MultiPassPipeline – chained compute passes
434// ─────────────────────────────────────────────────────────────────────────────
435
436/// A single compute pass in a multi-pass pipeline.
437#[derive(Debug, Clone)]
438pub struct ComputePass {
439    /// Human-readable label.
440    pub label: String,
441    /// The kernel to dispatch.
442    pub kernel: ComputeKernelKind,
443    /// Workgroup size for this pass.
444    pub workgroup_size: WorkgroupSize,
445    /// Indices of buffers bound to this pass.
446    pub buffer_bindings: Vec<usize>,
447}
448
449/// A sequence of compute passes that execute in order.
450#[derive(Debug)]
451pub struct MultiPassPipeline {
452    /// Human-readable label.
453    pub label: String,
454    /// Ordered list of passes.
455    pub passes: Vec<ComputePass>,
456}
457
458impl MultiPassPipeline {
459    /// Create a new empty multi-pass pipeline.
460    pub fn new(label: &str) -> Self {
461        Self {
462            label: label.to_owned(),
463            passes: Vec::new(),
464        }
465    }
466
467    /// Append a compute pass.
468    pub fn add_pass(&mut self, pass: ComputePass) {
469        self.passes.push(pass);
470    }
471
472    /// Number of passes.
473    pub fn num_passes(&self) -> usize {
474        self.passes.len()
475    }
476}
477
478// ─────────────────────────────────────────────────────────────────────────────
479// Pipeline validation
480// ─────────────────────────────────────────────────────────────────────────────
481
482/// Validate resource bindings for a single compute pass.
483///
484/// Returns a list of error messages (empty if valid).
485pub fn validate_resource_bindings(pass: &ComputePass, buffers: &[ComputeBuffer]) -> Vec<String> {
486    let mut errors = Vec::new();
487    let mut seen = std::collections::HashSet::new();
488    for &idx in &pass.buffer_bindings {
489        if idx >= buffers.len() {
490            errors.push(format!(
491                "Pass '{}': buffer binding {} is out of range (have {} buffers)",
492                pass.label,
493                idx,
494                buffers.len()
495            ));
496        }
497        if !seen.insert(idx) {
498            errors.push(format!(
499                "Pass '{}': Duplicate buffer binding {}",
500                pass.label, idx
501            ));
502        }
503    }
504    errors
505}
506
507/// Validate all passes in a multi-pass pipeline.
508pub fn validate_pipeline(pipeline: &MultiPassPipeline, buffers: &[ComputeBuffer]) -> Vec<String> {
509    let mut errors = Vec::new();
510    for pass in &pipeline.passes {
511        errors.extend(validate_resource_bindings(pass, buffers));
512    }
513    errors
514}
515
516// ─────────────────────────────────────────────────────────────────────────────
517// Additional solver helpers
518// ─────────────────────────────────────────────────────────────────────────────
519
/// Single relaxed sweep over an `nx × ny` grid (row-major, index `j * nx + i`).
///
/// For every interior cell the Jacobi value
/// `gs = (east + west + north + south - dx² * rhs) / 4` is computed from
/// `p_old`, and the result `p = (1 - omega) * p_old + omega * gs` is written
/// to `p`. Boundary cells of `p` are not touched.
///
/// All neighbour values are read from `p_old`, never from cells already
/// written to `p` in this sweep. `omega = 1.0` therefore reproduces a plain
/// Jacobi step; other values over- or under-relax that step.
///
/// A grid with fewer than three cells along either axis has no interior, so
/// the call is a no-op.
///
/// # Panics
/// Panics if a slice is too short for an interior index of the grid.
pub fn sor_step_2d(
    p: &mut Vec<f32>,
    p_old: &[f32],
    rhs: &[f32],
    nx: usize,
    ny: usize,
    dx: f32,
    omega: f32,
) {
    let dx2 = dx * dx;
    // `saturating_sub` avoids the usize underflow of `0 - 1` for empty grids.
    for j in 1..ny.saturating_sub(1) {
        for i in 1..nx.saturating_sub(1) {
            let idx = j * nx + i;
            let gs = (p_old[idx + 1] + p_old[idx - 1] + p_old[idx + nx] + p_old[idx - nx]
                - dx2 * rhs[idx])
                / 4.0;
            p[idx] = (1.0 - omega) * p_old[idx] + omega * gs;
        }
    }
}
543
/// Red-black Gauss-Seidel sweep (in-place) on an `nx × ny` grid (row-major,
/// index `j * nx + i`).
///
/// Updates "red" interior cells (`i + j` even) first, then "black" interior
/// cells (`i + j` odd), each with
/// `p = (east + west + north + south - dx² * rhs) / 4`. Because the update is
/// in place, the black pass reads the red values written just before it.
/// Boundary cells are not touched.
///
/// A grid with fewer than three cells along either axis has no interior, so
/// the call is a no-op.
///
/// # Panics
/// Panics if a slice is too short for an interior index of the grid.
pub fn red_black_gauss_seidel_step(p: &mut Vec<f32>, rhs: &[f32], nx: usize, ny: usize, dx: f32) {
    let dx2 = dx * dx;
    // `saturating_sub` avoids the usize underflow of `0 - 1` for empty grids.
    let i_end = nx.saturating_sub(1);
    let j_end = ny.saturating_sub(1);

    // parity 0 = red sweep (i + j even), parity 1 = black sweep (i + j odd).
    for parity in 0..2_usize {
        for j in 1..j_end {
            // First interior column of this row whose (i + j) has `parity`:
            // column 1 if (1 + j) already matches, otherwise column 2.
            let first_i = 1 + (1 + j + parity) % 2;
            for i in (first_i..i_end).step_by(2) {
                let idx = j * nx + i;
                p[idx] =
                    (p[idx + 1] + p[idx - 1] + p[idx + nx] + p[idx - nx] - dx2 * rhs[idx]) / 4.0;
            }
        }
    }
}
570
/// Compute the L∞ residual of a 2D Poisson discretisation on an `nx × ny`
/// grid (row-major, index `j * nx + i`).
///
/// For every interior cell the five-point Laplacian
/// `(east + west + north + south - 4 * p) / dx²` is compared with `rhs`, and
/// the largest absolute difference is returned. Cells whose difference is NaN
/// never win the `>` comparison and are therefore ignored.
///
/// Returns `0.0` when the grid has no interior cells (fewer than three cells
/// along either axis).
///
/// # Panics
/// Panics if a slice is too short for an interior index of the grid.
pub fn compute_linf_residual(p: &[f32], rhs: &[f32], nx: usize, ny: usize, dx: f32) -> f32 {
    let dx2 = dx * dx;
    let mut residual = 0.0_f32;
    // `saturating_sub` avoids the usize underflow of `0 - 1` for empty grids.
    for j in 1..ny.saturating_sub(1) {
        for i in 1..nx.saturating_sub(1) {
            let idx = j * nx + i;
            let lap = (p[idx + 1] + p[idx - 1] + p[idx + nx] + p[idx - nx] - 4.0 * p[idx]) / dx2;
            let r = (lap - rhs[idx]).abs();
            if r > residual {
                residual = r;
            }
        }
    }
    residual
}
587
/// Brute-force O(n²) neighbour search over 2D positions stored interleaved as
/// `[x0, y0, x1, y1, …]`.
///
/// Returns a `Vec<Vec<usize>>` with one list per particle: `result[i]` holds
/// the indices of every other particle whose squared distance to particle `i`
/// is strictly less than `cutoff²`. Each qualifying pair is recorded in both
/// particles' lists.
pub fn dispatch_neighbor_search(positions: &[f32], n: usize, cutoff: f32) -> Vec<Vec<usize>> {
    let limit_sq = cutoff * cutoff;
    let mut adjacency: Vec<Vec<usize>> = (0..n).map(|_| Vec::new()).collect();

    for a in 0..n {
        // Only pairs with b > a are visited; both directions are pushed below.
        for b in (a + 1)..n {
            let sep_x = positions[2 * b] - positions[2 * a];
            let sep_y = positions[2 * b + 1] - positions[2 * a + 1];
            let dist_sq = sep_x * sep_x + sep_y * sep_y;
            if dist_sq >= limit_sq {
                continue;
            }
            adjacency[a].push(b);
            adjacency[b].push(a);
        }
    }

    adjacency
}
608
609// ─────────────────────────────────────────────────────────────────────────────
610// Tests
611// ─────────────────────────────────────────────────────────────────────────────
612
613#[cfg(test)]
614mod tests {
615    use super::*;
616
617    // ── BufferUsage ──────────────────────────────────────────────────────────
618
619    #[test]
620    fn buffer_usage_eq() {
621        assert_eq!(BufferUsage::Storage, BufferUsage::Storage);
622        assert_ne!(BufferUsage::Vertex, BufferUsage::Index);
623    }
624
625    #[test]
626    fn buffer_usage_clone() {
627        let u = BufferUsage::Uniform;
628        assert_eq!(u.clone(), BufferUsage::Uniform);
629    }
630
631    // ── ComputeBuffer ────────────────────────────────────────────────────────
632
633    #[test]
634    fn compute_buffer_new_zeroed() {
635        let buf = ComputeBuffer::new(8, BufferUsage::Storage, "test");
636        assert_eq!(buf.data.len(), 8);
637        assert!(buf.data.iter().all(|&v| v == 0.0));
638        assert_eq!(buf.label, "test");
639    }
640
641    #[test]
642    fn compute_buffer_byte_size() {
643        let buf = ComputeBuffer::new(4, BufferUsage::Uniform, "u");
644        assert_eq!(buf.byte_size(), 16);
645    }
646
647    #[test]
648    fn compute_buffer_write_read_roundtrip() {
649        let mut buf = ComputeBuffer::new(8, BufferUsage::Storage, "rw");
650        buf.write_f32(2, &[1.0, 2.0, 3.0]);
651        let out = buf.read_f32(2, 3);
652        assert_eq!(out, vec![1.0, 2.0, 3.0]);
653    }
654
655    #[test]
656    fn compute_buffer_write_at_offset_zero() {
657        let mut buf = ComputeBuffer::new(4, BufferUsage::Storage, "s");
658        buf.write_f32(0, &[9.0, 8.0, 7.0, 6.0]);
659        assert_eq!(buf.data, vec![9.0, 8.0, 7.0, 6.0]);
660    }
661
662    #[test]
663    #[should_panic(expected = "out-of-bounds write")]
664    fn compute_buffer_write_oob_panics() {
665        let mut buf = ComputeBuffer::new(4, BufferUsage::Storage, "oob");
666        buf.write_f32(3, &[1.0, 2.0]); // 3+2 > 4
667    }
668
669    #[test]
670    #[should_panic(expected = "out-of-bounds read")]
671    fn compute_buffer_read_oob_panics() {
672        let buf = ComputeBuffer::new(4, BufferUsage::Storage, "oob");
673        let _ = buf.read_f32(3, 2);
674    }
675
676    // ── WorkgroupSize ────────────────────────────────────────────────────────
677
678    #[test]
679    fn workgroup_dispatch_count_exact() {
680        assert_eq!(WorkgroupSize::dispatch_count(64, 64), 1);
681    }
682
683    #[test]
684    fn workgroup_dispatch_count_ceil() {
685        assert_eq!(WorkgroupSize::dispatch_count(65, 64), 2);
686        assert_eq!(WorkgroupSize::dispatch_count(1, 64), 1);
687    }
688
689    #[test]
690    fn workgroup_dispatch_count_zero_total() {
691        assert_eq!(WorkgroupSize::dispatch_count(0, 64), 0);
692    }
693
694    #[test]
695    fn workgroup_default() {
696        let wg = WorkgroupSize::default();
697        assert_eq!(wg.x, 64);
698        assert_eq!(wg.y, 1);
699        assert_eq!(wg.z, 1);
700    }
701
702    // ── ComputeKernelKind ────────────────────────────────────────────────────
703
704    #[test]
705    fn kernel_kind_custom_eq() {
706        let a = ComputeKernelKind::Custom("foo".into());
707        let b = ComputeKernelKind::Custom("foo".into());
708        assert_eq!(a, b);
709    }
710
711    #[test]
712    fn kernel_kind_variants_neq() {
713        assert_ne!(
714            ComputeKernelKind::VelocityUpdate,
715            ComputeKernelKind::PressureJacobi
716        );
717    }
718
719    // ── CpuComputeDispatch – velocity update ─────────────────────────────────
720
721    #[test]
722    fn velocity_update_basic() {
723        let disp =
724            CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, WorkgroupSize::default());
725        let n = 3;
726        let mut pos = ComputeBuffer::new(n, BufferUsage::Storage, "pos");
727        let mut vel = ComputeBuffer::new(n, BufferUsage::Storage, "vel");
728        let mut force = ComputeBuffer::new(n, BufferUsage::Storage, "force");
729        let mut mass = ComputeBuffer::new(n, BufferUsage::Storage, "mass");
730
731        pos.write_f32(0, &[0.0, 1.0, 2.0]);
732        vel.write_f32(0, &[1.0, 0.5, -1.0]);
733        force.write_f32(0, &[0.0, 1.0, 0.0]);
734        mass.write_f32(0, &[1.0, 2.0, 1.0]);
735
736        let dt = 0.1_f32;
737        disp.dispatch_velocity_update(&mut pos, &mut vel, &force, &mass, dt, n);
738
739        // pos[0] = 0.0 + 1.0*0.1 = 0.1,  vel[0] = 1.0 + 0/1*0.1 = 1.0
740        assert!((pos.data[0] - 0.1).abs() < 1e-6);
741        assert!((vel.data[0] - 1.0).abs() < 1e-6);
742        // pos[1] = 1.0 + 0.5*0.1 = 1.05, vel[1] = 0.5 + 1/2*0.1 = 0.55
743        assert!((pos.data[1] - 1.05).abs() < 1e-6);
744        assert!((vel.data[1] - 0.55).abs() < 1e-6);
745    }
746
747    #[test]
748    fn velocity_update_zero_force() {
749        let disp =
750            CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, WorkgroupSize::default());
751        let n = 2;
752        let mut pos = ComputeBuffer::new(n, BufferUsage::Storage, "pos");
753        let mut vel = ComputeBuffer::new(n, BufferUsage::Storage, "vel");
754        let force = ComputeBuffer::new(n, BufferUsage::Storage, "force");
755        let mut mass = ComputeBuffer::new(n, BufferUsage::Storage, "mass");
756
757        pos.write_f32(0, &[0.0, 0.0]);
758        vel.write_f32(0, &[2.0, -3.0]);
759        mass.write_f32(0, &[1.0, 1.0]);
760
761        disp.dispatch_velocity_update(&mut pos, &mut vel, &force, &mass, 0.5, n);
762
763        assert!((pos.data[0] - 1.0).abs() < 1e-6);
764        assert!((pos.data[1] - (-1.5)).abs() < 1e-6);
765        // velocities unchanged (zero force)
766        assert!((vel.data[0] - 2.0).abs() < 1e-6);
767        assert!((vel.data[1] - (-3.0)).abs() < 1e-6);
768    }
769
770    // ── CpuComputeDispatch – pressure Jacobi ─────────────────────────────────
771
772    #[test]
773    fn pressure_jacobi_interior_update() {
774        let disp =
775            CpuComputeDispatch::new(ComputeKernelKind::PressureJacobi, WorkgroupSize::default());
776        let nx = 4;
777        let ny = 4;
778        let mut p = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "p");
779        let mut p_old = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "p_old");
780        let rhs = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "rhs");
781
782        // Set p_old to known values; neighbours of (1,1) are all 1.0
783        for v in p_old.data.iter_mut() {
784            *v = 1.0;
785        }
786
787        disp.dispatch_pressure_jacobi(&mut p, &p_old, &rhs, nx, ny, 1.0);
788
789        // p[1*4+1] = (1+1+1+1 - 0)/4 = 1.0
790        let idx = nx + 1;
791        assert!((p.data[idx] - 1.0).abs() < 1e-6);
792    }
793
794    #[test]
795    fn pressure_jacobi_boundary_unchanged() {
796        let disp =
797            CpuComputeDispatch::new(ComputeKernelKind::PressureJacobi, WorkgroupSize::default());
798        let nx = 5;
799        let ny = 5;
800        let mut p = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "p");
801        let p_old = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "p_old");
802        let rhs = ComputeBuffer::new(nx * ny, BufferUsage::Storage, "rhs");
803
804        // Boundary should remain zero.
805        disp.dispatch_pressure_jacobi(&mut p, &p_old, &rhs, nx, ny, 1.0);
806        assert_eq!(p.data[0], 0.0); // corner
807        assert_eq!(p.data[4], 0.0); // top-right corner
808    }
809
810    // ── CpuComputeDispatch – LJ particle force ───────────────────────────────
811
812    #[test]
813    fn particle_force_zero_at_large_sep() {
814        let disp =
815            CpuComputeDispatch::new(ComputeKernelKind::ParticleForce, WorkgroupSize::default());
816        let n = 2;
817        // Two particles very far apart → tiny force.
818        let mut pos = ComputeBuffer::new(2 * n, BufferUsage::Storage, "pos");
819        let mut force = ComputeBuffer::new(2 * n, BufferUsage::Storage, "force");
820        pos.write_f32(0, &[0.0, 0.0, 1000.0, 0.0]);
821
822        disp.dispatch_particle_force(&pos, &mut force, 1.0, 1.0, n);
823        // Force should be negligible at r=1000σ
824        assert!(force.data[0].abs() < 1e-10);
825    }
826
827    #[test]
828    fn particle_force_newton3() {
829        let disp =
830            CpuComputeDispatch::new(ComputeKernelKind::ParticleForce, WorkgroupSize::default());
831        let n = 2;
832        let mut pos = ComputeBuffer::new(2 * n, BufferUsage::Storage, "pos");
833        let mut force = ComputeBuffer::new(2 * n, BufferUsage::Storage, "force");
834        pos.write_f32(0, &[0.0, 0.0, 1.5, 0.0]);
835
836        disp.dispatch_particle_force(&pos, &mut force, 1.0, 1.0, n);
837        // Newton's third law: f0 + f1 == 0
838        assert!((force.data[0] + force.data[2]).abs() < 1e-5);
839        assert!((force.data[1] + force.data[3]).abs() < 1e-5);
840    }
841
842    // ── GpuStats ─────────────────────────────────────────────────────────────
843
844    #[test]
845    fn gpu_stats_initial_zero() {
846        let s = GpuStats::new();
847        assert_eq!(s.dispatch_count, 0);
848        assert_eq!(s.bytes_transferred, 0);
849        assert_eq!(s.kernel_time_ms, 0.0);
850    }
851
852    #[test]
853    fn gpu_stats_accumulate() {
854        let mut s = GpuStats::new();
855        s.record_dispatch(128, 0.5);
856        s.record_dispatch(256, 1.0);
857        assert_eq!(s.dispatch_count, 2);
858        assert_eq!(s.bytes_transferred, 384);
859        assert!((s.kernel_time_ms - 1.5).abs() < 1e-9);
860    }
861
862    // ── jacobi_step_2d ───────────────────────────────────────────────────────
863
864    #[test]
865    fn jacobi_step_2d_uniform_field() {
866        let nx = 4;
867        let ny = 4;
868        let size = nx * ny;
869        let mut p_new = vec![0.0_f32; size];
870        let p_old = vec![1.0_f32; size];
871        let rhs = vec![0.0_f32; size];
872
873        jacobi_step_2d(&mut p_new, &p_old, &rhs, nx, ny, 1.0);
874
875        // Uniform field → interior stays 1.0.
876        for j in 1..ny - 1 {
877            for i in 1..nx - 1 {
878                assert!((p_new[j * nx + i] - 1.0).abs() < 1e-6);
879            }
880        }
881    }
882
883    #[test]
884    fn jacobi_step_2d_rhs_effect() {
885        let nx = 4;
886        let ny = 4;
887        let size = nx * ny;
888        let mut p_new = vec![0.0_f32; size];
889        let p_old = vec![4.0_f32; size];
890        // rhs = 4 at every interior point
891        let rhs = vec![4.0_f32; size];
892
893        jacobi_step_2d(&mut p_new, &p_old, &rhs, nx, ny, 1.0);
894        // p[i,j] = (4+4+4+4 - 1²*4)/4 = (16-4)/4 = 3
895        for j in 1..ny - 1 {
896            for i in 1..nx - 1 {
897                assert!((p_new[j * nx + i] - 3.0).abs() < 1e-6);
898            }
899        }
900    }
901
902    // ── pressure_poisson_solve ───────────────────────────────────────────────
903
904    #[test]
905    fn pressure_poisson_zero_rhs_zero_bc() {
906        // Zero RHS + zero BCs → solution stays zero → zero residual.
907        let nx = 5;
908        let ny = 5;
909        let mut p = vec![0.0_f32; nx * ny];
910        let rhs = vec![0.0_f32; nx * ny];
911        let residual = pressure_poisson_solve(&mut p, &rhs, nx, ny, 0.1, 50);
912        assert!(residual < 1e-6, "residual={residual}");
913    }
914
915    #[test]
916    fn pressure_poisson_residual_decreases() {
917        let nx = 6;
918        let ny = 6;
919        let mut p1 = vec![0.0_f32; nx * ny];
920        let mut p2 = p1.clone();
921        let rhs: Vec<f32> = (0..(nx * ny)).map(|k| (k as f32).sin()).collect();
922        let dx = 0.1;
923
924        let r1 = pressure_poisson_solve(&mut p1, &rhs, nx, ny, dx, 10);
925        let r2 = pressure_poisson_solve(&mut p2, &rhs, nx, ny, dx, 200);
926        assert!(
927            r2 <= r1 + 1e-4,
928            "more iterations should not increase residual (r1={r1}, r2={r2})"
929        );
930    }
931
932    // ── PipelineCache ──────────────────────────────────────────────────────
933
934    #[test]
935    fn pipeline_cache_insert_and_get() {
936        let mut cache = PipelineCache::new(4);
937        let disp =
938            CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, WorkgroupSize::default());
939        cache.insert("vel_update", disp);
940        assert!(cache.get("vel_update").is_some());
941        assert!(cache.get("nonexistent").is_none());
942    }
943
944    #[test]
945    fn pipeline_cache_eviction() {
946        let mut cache = PipelineCache::new(2);
947        let d1 =
948            CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, WorkgroupSize::default());
949        let d2 =
950            CpuComputeDispatch::new(ComputeKernelKind::PressureJacobi, WorkgroupSize::default());
951        let d3 =
952            CpuComputeDispatch::new(ComputeKernelKind::ParticleForce, WorkgroupSize::default());
953        cache.insert("a", d1);
954        cache.insert("b", d2);
955        cache.insert("c", d3); // should evict "a"
956        assert!(cache.get("a").is_none());
957        assert!(cache.get("b").is_some());
958        assert!(cache.get("c").is_some());
959    }
960
961    #[test]
962    fn pipeline_cache_replace() {
963        let mut cache = PipelineCache::new(4);
964        let d1 =
965            CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, WorkgroupSize::default());
966        let d2 = CpuComputeDispatch::new(
967            ComputeKernelKind::ParticleForce,
968            WorkgroupSize { x: 128, y: 1, z: 1 },
969        );
970        cache.insert("key", d1);
971        cache.insert("key", d2);
972        let entry = cache.get("key").unwrap();
973        assert_eq!(entry.kernel, ComputeKernelKind::ParticleForce);
974    }
975
976    // ── PipelineStats ──────────────────────────────────────────────────────
977
978    #[test]
979    fn pipeline_stats_default() {
980        let stats = PipelineStats::default();
981        assert_eq!(stats.total_dispatches, 0);
982        assert_eq!(stats.total_workgroups, 0);
983        assert_eq!(stats.total_invocations, 0);
984        assert_eq!(stats.cache_hits, 0);
985        assert_eq!(stats.cache_misses, 0);
986    }
987
988    #[test]
989    fn pipeline_stats_record() {
990        let mut stats = PipelineStats::default();
991        stats.record_dispatch(4, WorkgroupSize { x: 64, y: 1, z: 1 });
992        assert_eq!(stats.total_dispatches, 1);
993        assert_eq!(stats.total_workgroups, 4);
994        assert_eq!(stats.total_invocations, 4 * 64);
995    }
996
997    #[test]
998    fn pipeline_stats_record_3d_workgroup() {
999        let mut stats = PipelineStats::default();
1000        stats.record_dispatch(2, WorkgroupSize { x: 8, y: 8, z: 4 });
1001        assert_eq!(stats.total_dispatches, 1);
1002        assert_eq!(stats.total_workgroups, 2);
1003        assert_eq!(stats.total_invocations, 2 * 8 * 8 * 4);
1004    }
1005
1006    #[test]
1007    fn pipeline_stats_cache_ratio() {
1008        let mut stats = PipelineStats::default();
1009        assert!(stats.cache_hit_ratio().is_nan() || stats.cache_hit_ratio() == 0.0);
1010        stats.cache_hits = 3;
1011        stats.cache_misses = 1;
1012        assert!((stats.cache_hit_ratio() - 0.75).abs() < 1e-6);
1013    }
1014
1015    // ── MultiPassPipeline ──────────────────────────────────────────────────
1016
1017    #[test]
1018    fn multi_pass_empty() {
1019        let mp = MultiPassPipeline::new("empty");
1020        assert_eq!(mp.passes.len(), 0);
1021        assert_eq!(mp.label, "empty");
1022    }
1023
1024    #[test]
1025    fn multi_pass_execute_add_scale() {
1026        // pass 0: fill buffer with [1, 2, 3, 4]
1027        // pass 1: scale by 2 → [2, 4, 6, 8]
1028        let mut mp = MultiPassPipeline::new("add_scale");
1029        mp.add_pass(ComputePass {
1030            label: "fill".into(),
1031            kernel: ComputeKernelKind::Custom("fill".into()),
1032            workgroup_size: WorkgroupSize::default(),
1033            buffer_bindings: vec![0],
1034        });
1035        mp.add_pass(ComputePass {
1036            label: "scale".into(),
1037            kernel: ComputeKernelKind::Custom("scale".into()),
1038            workgroup_size: WorkgroupSize::default(),
1039            buffer_bindings: vec![0],
1040        });
1041        assert_eq!(mp.passes.len(), 2);
1042        assert_eq!(mp.passes[0].label, "fill");
1043        assert_eq!(mp.passes[1].label, "scale");
1044    }
1045
1046    #[test]
1047    fn multi_pass_dispatch_velocity_chain() {
1048        // Test chaining two velocity update passes
1049        let mut mp = MultiPassPipeline::new("vel_chain");
1050        mp.add_pass(ComputePass {
1051            label: "step1".into(),
1052            kernel: ComputeKernelKind::VelocityUpdate,
1053            workgroup_size: WorkgroupSize::default(),
1054            buffer_bindings: vec![0, 1, 2, 3],
1055        });
1056        mp.add_pass(ComputePass {
1057            label: "step2".into(),
1058            kernel: ComputeKernelKind::VelocityUpdate,
1059            workgroup_size: WorkgroupSize::default(),
1060            buffer_bindings: vec![0, 1, 2, 3],
1061        });
1062
1063        let n = 2;
1064        let mut pos = ComputeBuffer::new(n, BufferUsage::Storage, "pos");
1065        let mut vel = ComputeBuffer::new(n, BufferUsage::Storage, "vel");
1066        let force = ComputeBuffer::new(n, BufferUsage::Storage, "force");
1067        let mut mass = ComputeBuffer::new(n, BufferUsage::Storage, "mass");
1068
1069        pos.write_f32(0, &[0.0, 0.0]);
1070        vel.write_f32(0, &[1.0, 2.0]);
1071        mass.write_f32(0, &[1.0, 1.0]);
1072
1073        let dt = 0.1_f32;
1074
1075        // Execute passes manually (since we simulate CPU-side)
1076        for pass in &mp.passes {
1077            if pass.kernel == ComputeKernelKind::VelocityUpdate {
1078                let disp =
1079                    CpuComputeDispatch::new(ComputeKernelKind::VelocityUpdate, pass.workgroup_size);
1080                disp.dispatch_velocity_update(&mut pos, &mut vel, &force, &mass, dt, n);
1081            }
1082        }
1083        // After 2 steps with zero force: pos = vel*dt*2
1084        assert!((pos.data[0] - 0.2).abs() < 1e-5);
1085        assert!((pos.data[1] - 0.4).abs() < 1e-5);
1086    }
1087
1088    // ── Pipeline validation ────────────────────────────────────────────────
1089
1090    #[test]
1091    fn validate_binding_valid() {
1092        let buffers = vec![
1093            ComputeBuffer::new(16, BufferUsage::Storage, "buf0"),
1094            ComputeBuffer::new(16, BufferUsage::Uniform, "buf1"),
1095        ];
1096        let pass = ComputePass {
1097            label: "test".into(),
1098            kernel: ComputeKernelKind::VelocityUpdate,
1099            workgroup_size: WorkgroupSize::default(),
1100            buffer_bindings: vec![0, 1],
1101        };
1102        let errors = validate_resource_bindings(&pass, &buffers);
1103        assert!(errors.is_empty());
1104    }
1105
1106    #[test]
1107    fn validate_binding_out_of_range() {
1108        let buffers = vec![ComputeBuffer::new(16, BufferUsage::Storage, "buf0")];
1109        let pass = ComputePass {
1110            label: "test".into(),
1111            kernel: ComputeKernelKind::VelocityUpdate,
1112            workgroup_size: WorkgroupSize::default(),
1113            buffer_bindings: vec![0, 5],
1114        };
1115        let errors = validate_resource_bindings(&pass, &buffers);
1116        assert_eq!(errors.len(), 1);
1117        assert!(errors[0].contains("out of range"));
1118    }
1119
1120    #[test]
1121    fn validate_binding_duplicate() {
1122        let buffers = vec![ComputeBuffer::new(16, BufferUsage::Storage, "buf0")];
1123        let pass = ComputePass {
1124            label: "test".into(),
1125            kernel: ComputeKernelKind::VelocityUpdate,
1126            workgroup_size: WorkgroupSize::default(),
1127            buffer_bindings: vec![0, 0],
1128        };
1129        let errors = validate_resource_bindings(&pass, &buffers);
1130        assert_eq!(errors.len(), 1);
1131        assert!(errors[0].contains("Duplicate"));
1132    }
1133
1134    #[test]
1135    fn validate_pipeline_all_passes() {
1136        let buffers = vec![ComputeBuffer::new(16, BufferUsage::Storage, "buf0")];
1137        let mut mp = MultiPassPipeline::new("test");
1138        mp.add_pass(ComputePass {
1139            label: "good".into(),
1140            kernel: ComputeKernelKind::VelocityUpdate,
1141            workgroup_size: WorkgroupSize::default(),
1142            buffer_bindings: vec![0],
1143        });
1144        mp.add_pass(ComputePass {
1145            label: "bad".into(),
1146            kernel: ComputeKernelKind::PressureJacobi,
1147            workgroup_size: WorkgroupSize::default(),
1148            buffer_bindings: vec![0, 3],
1149        });
1150        let errors = validate_pipeline(&mp, &buffers);
1151        assert_eq!(errors.len(), 1); // only the bad pass has errors
1152    }
1153
1154    // ── ComputeBuffer additional tests ─────────────────────────────────────
1155
1156    #[test]
1157    fn compute_buffer_clone() {
1158        let mut buf = ComputeBuffer::new(4, BufferUsage::Storage, "orig");
1159        buf.write_f32(0, &[1.0, 2.0, 3.0, 4.0]);
1160        let cloned = buf.clone();
1161        assert_eq!(buf.data, cloned.data);
1162        assert_eq!(buf.label, cloned.label);
1163    }
1164
1165    #[test]
1166    fn compute_buffer_staging_usage() {
1167        let buf = ComputeBuffer::new(8, BufferUsage::Staging, "staging");
1168        assert_eq!(buf.usage, BufferUsage::Staging);
1169        assert_eq!(buf.byte_size(), 32);
1170    }
1171
1172    // ── Workgroup additional tests ─────────────────────────────────────────
1173
1174    #[test]
1175    fn workgroup_dispatch_count_large() {
1176        assert_eq!(WorkgroupSize::dispatch_count(1024, 256), 4);
1177        assert_eq!(WorkgroupSize::dispatch_count(1025, 256), 5);
1178    }
1179
1180    // ── GpuStats additional tests ──────────────────────────────────────────
1181
1182    #[test]
1183    fn gpu_stats_clone() {
1184        let mut s = GpuStats::new();
1185        s.record_dispatch(100, 1.5);
1186        let s2 = s.clone();
1187        assert_eq!(s.dispatch_count, s2.dispatch_count);
1188        assert_eq!(s.bytes_transferred, s2.bytes_transferred);
1189        assert!((s.kernel_time_ms - s2.kernel_time_ms).abs() < 1e-12);
1190    }
1191
1192    // ── dispatch_particle_force additional tests ──────────────────────────
1193
1194    #[test]
1195    fn particle_force_repulsive_at_close_range() {
1196        let disp =
1197            CpuComputeDispatch::new(ComputeKernelKind::ParticleForce, WorkgroupSize::default());
1198        let n = 2;
1199        let mut pos = ComputeBuffer::new(2 * n, BufferUsage::Storage, "pos");
1200        let mut force = ComputeBuffer::new(2 * n, BufferUsage::Storage, "force");
1201        // Two particles at distance 0.9σ (< σ → repulsive region)
1202        pos.write_f32(0, &[0.0, 0.0, 0.9, 0.0]);
1203        disp.dispatch_particle_force(&pos, &mut force, 1.0, 1.0, n);
1204        // Force on particle 0 should push it away from particle 1 (negative x)
1205        assert!(
1206            force.data[0] < 0.0,
1207            "expected repulsive force, got {}",
1208            force.data[0]
1209        );
1210    }
1211
1212    #[test]
1213    fn particle_force_three_particles() {
1214        let disp =
1215            CpuComputeDispatch::new(ComputeKernelKind::ParticleForce, WorkgroupSize::default());
1216        let n = 3;
1217        let mut pos = ComputeBuffer::new(2 * n, BufferUsage::Storage, "pos");
1218        let mut force = ComputeBuffer::new(2 * n, BufferUsage::Storage, "force");
1219        // Triangle arrangement
1220        pos.write_f32(0, &[0.0, 0.0, 2.0, 0.0, 1.0, 1.732]);
1221        disp.dispatch_particle_force(&pos, &mut force, 1.0, 1.0, n);
1222
1223        // Total momentum conservation: sum of all forces should be zero
1224        let fx_total = force.data[0] + force.data[2] + force.data[4];
1225        let fy_total = force.data[1] + force.data[3] + force.data[5];
1226        assert!(fx_total.abs() < 1e-5, "fx_total={fx_total}");
1227        assert!(fy_total.abs() < 1e-5, "fy_total={fy_total}");
1228    }
1229
1230    // ── pressure_poisson_solve additional tests ───────────────────────────
1231
1232    #[test]
1233    fn pressure_poisson_uniform_rhs() {
1234        let nx = 8;
1235        let ny = 8;
1236        let mut p = vec![0.0_f32; nx * ny];
1237        let rhs = vec![1.0_f32; nx * ny];
1238        let residual = pressure_poisson_solve(&mut p, &rhs, nx, ny, 0.1, 500);
1239        // After many iterations residual should decrease significantly
1240        assert!(residual < 10.0, "residual={residual}");
1241    }
1242
1243    // ── SOR solver ────────────────────────────────────────────────────────
1244
1245    #[test]
1246    fn sor_step_uniform_field() {
1247        let nx = 4;
1248        let ny = 4;
1249        let mut p = vec![0.0_f32; nx * ny];
1250        let rhs = vec![0.0_f32; nx * ny];
1251        let p_ref = vec![1.0_f32; nx * ny];
1252        sor_step_2d(&mut p, &p_ref, &rhs, nx, ny, 1.0, 1.0);
1253        // With omega=1.0 this is the same as Jacobi
1254        for j in 1..ny - 1 {
1255            for i in 1..nx - 1 {
1256                assert!((p[j * nx + i] - 1.0).abs() < 1e-6);
1257            }
1258        }
1259    }
1260
1261    #[test]
1262    fn sor_step_over_relaxation() {
1263        // SOR with omega > 1 should differ from standard Jacobi (omega=1)
1264        let nx = 6;
1265        let ny = 6;
1266        let rhs = vec![0.0_f32; nx * ny];
1267
1268        // Non-uniform reference: set boundary to 1, interior p_old to 0
1269        let mut p_ref = vec![0.0_f32; nx * ny];
1270        for i in 0..nx {
1271            p_ref[i] = 1.0;
1272            p_ref[(ny - 1) * nx + i] = 1.0;
1273        }
1274        for j in 0..ny {
1275            p_ref[j * nx] = 1.0;
1276            p_ref[j * nx + nx - 1] = 1.0;
1277        }
1278
1279        let mut p_jac = vec![0.0_f32; nx * ny];
1280        let mut p_sor = vec![0.0_f32; nx * ny];
1281        sor_step_2d(&mut p_jac, &p_ref, &rhs, nx, ny, 1.0, 1.0);
1282        sor_step_2d(&mut p_sor, &p_ref, &rhs, nx, ny, 1.0, 1.5);
1283
1284        // SOR result should differ from Jacobi for interior nodes next to boundary
1285        let idx = nx + 1; // has 2 boundary neighbors
1286        // Jacobi: (1+0+1+0)/4 = 0.5, SOR: (1-1.5)*0 + 1.5*0.5 = 0.75
1287        assert!(
1288            (p_sor[idx] - p_jac[idx]).abs() > 0.01,
1289            "SOR and Jacobi should differ: SOR={}, Jac={}",
1290            p_sor[idx],
1291            p_jac[idx]
1292        );
1293    }
1294
1295    // ── Red-black Gauss-Seidel ────────────────────────────────────────────
1296
1297    #[test]
1298    fn red_black_gs_uniform() {
1299        let nx = 6;
1300        let ny = 6;
1301        let mut p = vec![1.0_f32; nx * ny];
1302        let rhs = vec![0.0_f32; nx * ny];
1303        red_black_gauss_seidel_step(&mut p, &rhs, nx, ny, 1.0);
1304        // Uniform field is a fixed point with zero rhs
1305        for j in 1..ny - 1 {
1306            for i in 1..nx - 1 {
1307                assert!((p[j * nx + i] - 1.0).abs() < 1e-6);
1308            }
1309        }
1310    }
1311
1312    #[test]
1313    fn red_black_gs_converges() {
1314        let nx = 8;
1315        let ny = 8;
1316        let mut p = vec![0.0_f32; nx * ny];
1317        let rhs = vec![0.0_f32; nx * ny];
1318        // Set boundary to 1
1319        for i in 0..nx {
1320            p[i] = 1.0;
1321            p[(ny - 1) * nx + i] = 1.0;
1322        }
1323        for j in 0..ny {
1324            p[j * nx] = 1.0;
1325            p[j * nx + nx - 1] = 1.0;
1326        }
1327        // Multiple sweeps should converge interior toward 1.0
1328        for _ in 0..200 {
1329            red_black_gauss_seidel_step(&mut p, &rhs, nx, ny, 1.0);
1330        }
1331        let center = p[(ny / 2) * nx + nx / 2];
1332        assert!((center - 1.0).abs() < 0.01, "center={center}");
1333    }
1334
1335    // ── compute_linf_residual ─────────────────────────────────────────────
1336
1337    #[test]
1338    fn linf_residual_zero_for_exact() {
1339        // Uniform field with zero rhs is an exact solution
1340        let nx = 4;
1341        let ny = 4;
1342        let p = vec![1.0_f32; nx * ny];
1343        let rhs = vec![0.0_f32; nx * ny];
1344        let res = compute_linf_residual(&p, &rhs, nx, ny, 1.0);
1345        assert!(res < 1e-6, "res={res}");
1346    }
1347
1348    #[test]
1349    fn linf_residual_nonzero_for_wrong() {
1350        let nx = 4;
1351        let ny = 4;
1352        let mut p = vec![0.0_f32; nx * ny];
1353        p[nx + 1] = 100.0; // big spike
1354        let rhs = vec![0.0_f32; nx * ny];
1355        let res = compute_linf_residual(&p, &rhs, nx, ny, 1.0);
1356        assert!(res > 1.0, "expected large residual, got {res}");
1357    }
1358
1359    // ── Dispatch neighbor search ──────────────────────────────────────────
1360
1361    #[test]
1362    fn dispatch_neighbor_search_basic() {
1363        let n = 4;
1364        let positions = vec![
1365            0.0_f32, 0.0, // particle 0
1366            0.5, 0.0, // particle 1 (close to 0)
1367            5.0, 5.0, // particle 2 (far)
1368            0.3, 0.3, // particle 3 (close to 0 and 1)
1369        ];
1370        let neighbors = dispatch_neighbor_search(&positions, n, 1.0);
1371        // particle 0 should have neighbors 1 and 3
1372        assert!(neighbors[0].contains(&1));
1373        assert!(neighbors[0].contains(&3));
1374        // particle 2 should have no neighbors
1375        assert!(neighbors[2].is_empty());
1376    }
1377
1378    #[test]
1379    fn dispatch_neighbor_search_all_close() {
1380        let n = 3;
1381        let positions = vec![0.0_f32, 0.0, 0.1, 0.0, 0.0, 0.1];
1382        let neighbors = dispatch_neighbor_search(&positions, n, 1.0);
1383        // All particles within cutoff of each other
1384        assert_eq!(neighbors[0].len(), 2);
1385        assert_eq!(neighbors[1].len(), 2);
1386        assert_eq!(neighbors[2].len(), 2);
1387    }
1388
1389    #[test]
1390    fn dispatch_neighbor_search_none() {
1391        let n = 2;
1392        let positions = vec![0.0_f32, 0.0, 100.0, 100.0];
1393        let neighbors = dispatch_neighbor_search(&positions, n, 1.0);
1394        assert!(neighbors[0].is_empty());
1395        assert!(neighbors[1].is_empty());
1396    }
1397}