Skip to main content

oxiphysics_gpu/parallel/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::needless_range_loop)]
6use rayon::prelude::*;
7
8use super::types::{LoadBalancePlan, LoadBalanceStrategy, WorkStealQueue};
9
10#[inline]
11pub(super) fn dist3(a: [f64; 3], b: [f64; 3]) -> f64 {
12    let dx = a[0] - b[0];
13    let dy = a[1] - b[1];
14    let dz = a[2] - b[2];
15    (dx * dx + dy * dy + dz * dz).sqrt()
16}
17/// Parallel SPH density computation.
18///
19/// For each particle `i`, computes `rho_i = sum_j m_j * W(|r_i - r_j|, h)`.
20/// The outer loop runs in parallel via Rayon.
21///
22/// # Arguments
23/// * `positions`  - slice of 3-D particle positions.
24/// * `masses`     - per-particle masses (same length as `positions`).
25/// * `h`          - smoothing length.
26/// * `kernel_fn`  - smoothing kernel `W(r, h)` callable from multiple threads.
27pub fn parallel_sph_density(
28    positions: &[[f64; 3]],
29    masses: &[f64],
30    h: f64,
31    kernel_fn: impl Fn(f64, f64) -> f64 + Sync,
32) -> Vec<f64> {
33    positions
34        .par_iter()
35        .map(|&pi| {
36            positions
37                .iter()
38                .zip(masses.iter())
39                .map(|(&pj, &mj)| mj * kernel_fn(dist3(pi, pj), h))
40                .sum()
41        })
42        .collect()
43}
44/// Parallel pairwise force accumulation.
45///
46/// Computes the net force on every particle.  The outer loop (over particle
47/// `i`) runs in parallel; each thread independently sums the contributions
48/// from all `j != i`.
49///
50/// # Arguments
51/// * `positions` - particle positions.
52/// * `n`         - number of particles (must equal `positions.len()`).
53/// * `force_fn`  - `force_fn(i, j, r_ij) -> force_on_i_from_j`.
54pub fn parallel_pairwise_forces(
55    positions: &[[f64; 3]],
56    n: usize,
57    force_fn: impl Fn(usize, usize, [f64; 3]) -> [f64; 3] + Sync,
58) -> Vec<[f64; 3]> {
59    (0..n)
60        .into_par_iter()
61        .map(|i| {
62            let mut f = [0.0f64; 3];
63            for j in 0..n {
64                if i == j {
65                    continue;
66                }
67                let r_ij = [
68                    positions[j][0] - positions[i][0],
69                    positions[j][1] - positions[i][1],
70                    positions[j][2] - positions[i][2],
71                ];
72                let fij = force_fn(i, j, r_ij);
73                f[0] += fij[0];
74                f[1] += fij[1];
75                f[2] += fij[2];
76            }
77            f
78        })
79        .collect()
80}
81/// Parallel Lennard-Jones (12-6) force computation.
82///
83/// For each particle `i`, accumulates contributions from all `j` within
84/// `cutoff`.  The potential is `U = 4*eps*[(sig/r)^12 - (sig/r)^6]`, giving
85/// `F_i = sum_{j!=i} 4*eps [12(sig/r)^12 - 6(sig/r)^6] / r * r_hat_ij`.
86///
87/// Interactions beyond `cutoff` are skipped.
88pub fn parallel_lj_forces(
89    positions: &[[f64; 3]],
90    epsilon: f64,
91    sigma: f64,
92    cutoff: f64,
93) -> Vec<[f64; 3]> {
94    let n = positions.len();
95    (0..n)
96        .into_par_iter()
97        .map(|i| {
98            let mut f = [0.0f64; 3];
99            for j in 0..n {
100                if i == j {
101                    continue;
102                }
103                let dx = positions[j][0] - positions[i][0];
104                let dy = positions[j][1] - positions[i][1];
105                let dz = positions[j][2] - positions[i][2];
106                let r2 = dx * dx + dy * dy + dz * dz;
107                let r = r2.sqrt();
108                if r >= cutoff || r < 1e-12 {
109                    continue;
110                }
111                let sr = sigma / r;
112                let sr6 = sr.powi(6);
113                let sr12 = sr6 * sr6;
114                let mag = 4.0 * epsilon * (12.0 * sr12 - 6.0 * sr6) / r2;
115                f[0] -= mag * dx;
116                f[1] -= mag * dy;
117                f[2] -= mag * dz;
118            }
119            f
120        })
121        .collect()
122}
123/// Parallel velocity-Verlet position and velocity half-update.
124///
125/// Updates positions with `x += v*dt + 0.5*a*dt^2` and velocities with
126/// `v += 0.5*a*dt` (first half of the Verlet velocity update; call again
127/// after recomputing forces for the second half).
128///
129/// The loop runs in parallel via `par_iter_mut`.
130pub fn parallel_verlet_step(
131    positions: &mut Vec<[f64; 3]>,
132    velocities: &mut Vec<[f64; 3]>,
133    forces: &[[f64; 3]],
134    masses: &[f64],
135    dt: f64,
136) {
137    positions
138        .par_iter_mut()
139        .zip(velocities.par_iter_mut())
140        .zip(forces.par_iter())
141        .zip(masses.par_iter())
142        .for_each(|(((pos, vel), force), &mass)| {
143            let inv_m = 1.0 / mass;
144            for k in 0..3 {
145                let a = force[k] * inv_m;
146                pos[k] += vel[k] * dt + 0.5 * a * dt * dt;
147                vel[k] += 0.5 * a * dt;
148            }
149        });
150}
151/// Parallel AABB overlap detection.
152///
153/// Returns all pairs `(i, j)` with `i < j` whose axis-aligned bounding boxes
154/// overlap.  The outer loop runs in parallel; each thread contributes matching
155/// pairs into a local vector that is then flattened.
156pub fn parallel_aabb_pairs(aabbs: &[([f64; 3], [f64; 3])]) -> Vec<(usize, usize)> {
157    let n = aabbs.len();
158    (0..n)
159        .into_par_iter()
160        .flat_map(|i| {
161            let mut local = Vec::new();
162            let (min_i, max_i) = aabbs[i];
163            for j in (i + 1)..n {
164                let (min_j, max_j) = aabbs[j];
165                let overlap = (0..3).all(|k| min_i[k] <= max_j[k] && min_j[k] <= max_i[k]);
166                if overlap {
167                    local.push((i, j));
168                }
169            }
170            local
171        })
172        .collect()
173}
/// Call `f(i)` for every `i` in `0..n`, in ascending order, grouped into chunks
/// of `chunk_size` indices (`0` is treated as `1`).
///
/// Despite the name, this runs sequentially on the calling thread; prefer
/// [`parallel_sph_density`] and the other Rayon-backed kernels for
/// performance-critical code.
pub fn parallel_for(n: usize, chunk_size: usize, f: impl Fn(usize)) {
    // A zero step would never advance, so clamp it to one.
    let step = chunk_size.max(1);
    let mut chunk_start = 0;
    while chunk_start < n {
        let chunk_end = (chunk_start + step).min(n);
        (chunk_start..chunk_end).for_each(&f);
        chunk_start = chunk_end;
    }
}
187/// Parallel sum reduction using Rayon.
188#[allow(dead_code)]
189pub fn parallel_reduce_sum(data: &[f64]) -> f64 {
190    data.par_iter().copied().sum()
191}
192/// Parallel max reduction using Rayon.
193#[allow(dead_code)]
194pub fn parallel_reduce_max(data: &[f64]) -> f64 {
195    data.par_iter()
196        .copied()
197        .reduce(|| f64::NEG_INFINITY, f64::max)
198}
199/// Parallel min reduction using Rayon.
200#[allow(dead_code)]
201pub fn parallel_reduce_min(data: &[f64]) -> f64 {
202    data.par_iter().copied().reduce(|| f64::INFINITY, f64::min)
203}
204/// Parallel dot product of two slices.
205#[allow(dead_code)]
206pub fn parallel_dot_product(a: &[f64], b: &[f64]) -> f64 {
207    a.par_iter()
208        .zip(b.par_iter())
209        .map(|(&ai, &bi)| ai * bi)
210        .sum()
211}
212/// Parallel L2 norm (Euclidean norm).
213#[allow(dead_code)]
214pub fn parallel_norm2(data: &[f64]) -> f64 {
215    let sum_sq: f64 = data.par_iter().map(|&x| x * x).sum();
216    sum_sq.sqrt()
217}
218/// Parallel mean.
219#[allow(dead_code)]
220pub fn parallel_mean(data: &[f64]) -> f64 {
221    if data.is_empty() {
222        return 0.0;
223    }
224    let sum: f64 = data.par_iter().copied().sum();
225    sum / data.len() as f64
226}
227/// Parallel variance (population variance).
228#[allow(dead_code)]
229pub fn parallel_variance(data: &[f64]) -> f64 {
230    if data.is_empty() {
231        return 0.0;
232    }
233    let mean = parallel_mean(data);
234    let sum_sq: f64 = data.par_iter().map(|&x| (x - mean) * (x - mean)).sum();
235    sum_sq / data.len() as f64
236}
237/// Two-pass parallel reduction: compute both sum and count in one pass.
238#[allow(dead_code)]
239pub fn parallel_sum_count(data: &[f64]) -> (f64, usize) {
240    data.par_iter()
241        .copied()
242        .fold(|| (0.0f64, 0usize), |(s, c), x| (s + x, c + 1))
243        .reduce(|| (0.0, 0), |(s1, c1), (s2, c2)| (s1 + s2, c1 + c2))
244}
245/// Parallel reduction with a custom binary operator.
246///
247/// `identity` is the identity element for the operator (e.g. 0.0 for add).
248/// `op` must be associative and commutative for correctness.
249#[allow(dead_code)]
250pub fn parallel_reduce_custom(
251    data: &[f64],
252    identity: f64,
253    op: impl Fn(f64, f64) -> f64 + Sync + Send,
254) -> f64 {
255    data.par_iter().copied().reduce(|| identity, op)
256}
257/// Parallel exclusive prefix sum using a two-pass algorithm.
258///
259/// Phase 1: compute partial sums in chunks (parallel).
260/// Phase 2: propagate offsets (sequential).
261/// Phase 3: apply offsets within chunks (parallel).
262#[allow(dead_code)]
263pub fn parallel_exclusive_scan(data: &[f64]) -> Vec<f64> {
264    let n = data.len();
265    if n == 0 {
266        return Vec::new();
267    }
268    let chunk_size = (n / rayon::current_num_threads().max(1)).max(64);
269    let chunks: Vec<&[f64]> = data.chunks(chunk_size).collect();
270    let chunk_sums: Vec<f64> = chunks
271        .par_iter()
272        .map(|chunk| chunk.iter().copied().sum())
273        .collect();
274    let mut offsets = Vec::with_capacity(chunks.len());
275    let mut acc = 0.0;
276    for &cs in &chunk_sums {
277        offsets.push(acc);
278        acc += cs;
279    }
280    let result = vec![0.0; n];
281    chunks.par_iter().enumerate().for_each(|(ci, chunk)| {
282        let base = ci * chunk_size;
283        let offset = offsets[ci];
284        let mut local_acc = offset;
285        let result_ptr = result.as_ptr() as *mut f64;
286        for (k, &v) in chunk.iter().enumerate() {
287            unsafe {
288                *result_ptr.add(base + k) = local_acc;
289            }
290            local_acc += v;
291        }
292    });
293    result
294}
295/// Parallel inclusive prefix sum.
296#[allow(dead_code)]
297pub fn parallel_inclusive_scan(data: &[f64]) -> Vec<f64> {
298    let n = data.len();
299    if n == 0 {
300        return Vec::new();
301    }
302    let chunk_size = (n / rayon::current_num_threads().max(1)).max(64);
303    let chunks: Vec<&[f64]> = data.chunks(chunk_size).collect();
304    let chunk_sums: Vec<f64> = chunks
305        .par_iter()
306        .map(|chunk| chunk.iter().copied().sum())
307        .collect();
308    let mut offsets = Vec::with_capacity(chunks.len());
309    let mut acc = 0.0;
310    for &cs in &chunk_sums {
311        offsets.push(acc);
312        acc += cs;
313    }
314    let result = vec![0.0; n];
315    chunks.par_iter().enumerate().for_each(|(ci, chunk)| {
316        let base = ci * chunk_size;
317        let offset = offsets[ci];
318        let mut local_acc = offset;
319        let result_ptr = result.as_ptr() as *mut f64;
320        for (k, &v) in chunk.iter().enumerate() {
321            local_acc += v;
322            unsafe {
323                *result_ptr.add(base + k) = local_acc;
324            }
325        }
326    });
327    result
328}
/// Segmented exclusive prefix sum.
///
/// `segment_ids[i]` labels element `i`. The running total restarts at `0.0`
/// whenever the label differs from the previous element's, so each segment
/// must occupy one contiguous run. `out[i]` is the sum of the earlier elements
/// in the same run.
///
/// Panics if `segment_ids` is shorter than `data` (when `data` is non-empty).
#[allow(dead_code)]
pub fn segmented_exclusive_scan(data: &[f64], segment_ids: &[usize]) -> Vec<f64> {
    let mut out = vec![0.0; data.len()];
    if data.is_empty() {
        return out;
    }
    let mut running = 0.0;
    let mut active = segment_ids[0];
    for (i, (&value, slot)) in data.iter().zip(out.iter_mut()).enumerate() {
        let id = segment_ids[i];
        if id != active {
            active = id;
            running = 0.0;
        }
        *slot = running;
        running += value;
    }
    out
}
352/// Parallel sort of f64 values (ascending).
353///
354/// Uses Rayon's parallel sort. NaN values are placed at the end.
355#[allow(dead_code)]
356pub fn parallel_sort_f64(data: &mut [f64]) {
357    data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358}
359/// Parallel argsort: returns indices that would sort `data` in ascending order.
360#[allow(dead_code)]
361pub fn parallel_argsort(data: &[f64]) -> Vec<usize> {
362    let mut indices: Vec<usize> = (0..data.len()).collect();
363    indices.par_sort_unstable_by(|&a, &b| {
364        data[a]
365            .partial_cmp(&data[b])
366            .unwrap_or(std::cmp::Ordering::Equal)
367    });
368    indices
369}
370/// Parallel sort by key: sorts `items` based on `key_fn`.
371#[allow(dead_code)]
372pub fn parallel_sort_by_key<T: Send>(items: &mut [T], key_fn: impl Fn(&T) -> f64 + Sync + Send) {
373    items.par_sort_unstable_by(|a, b| {
374        let ka = key_fn(a);
375        let kb = key_fn(b);
376        ka.partial_cmp(&kb).unwrap_or(std::cmp::Ordering::Equal)
377    });
378}
379/// Parallel partition: split data into two groups based on a predicate.
380///
381/// Returns `(true_group, false_group)`.
382#[allow(dead_code)]
383pub fn parallel_partition<T: Send + Sync + Clone>(
384    data: &[T],
385    predicate: impl Fn(&T) -> bool + Sync + Send,
386) -> (Vec<T>, Vec<T>) {
387    let (left, right): (Vec<_>, Vec<_>) =
388        data.par_iter().cloned().partition(|item| predicate(item));
389    (left, right)
390}
391/// Parallel rank: compute the rank of each element (0-based) in sorted order.
392#[allow(dead_code)]
393pub fn parallel_rank(data: &[f64]) -> Vec<usize> {
394    let sorted_indices = parallel_argsort(data);
395    let n = data.len();
396    let mut ranks = vec![0usize; n];
397    for (rank, &idx) in sorted_indices.iter().enumerate() {
398        ranks[idx] = rank;
399    }
400    ranks
401}
402/// Compute a load balance plan for `n` items across `num_workers` workers.
403///
404/// For `Static` strategy, `item_weights` is ignored.
405/// For `Weighted` strategy, items are assigned sequentially to workers to
406/// balance the total weight per worker.
407/// For `Guided` strategy, chunks start large and shrink.
408#[allow(dead_code)]
409pub fn compute_load_balance(
410    n: usize,
411    num_workers: usize,
412    strategy: LoadBalanceStrategy,
413    item_weights: Option<&[f64]>,
414) -> LoadBalancePlan {
415    let nw = num_workers.max(1);
416    match strategy {
417        LoadBalanceStrategy::Static => {
418            let chunk = n.div_ceil(nw);
419            let mut ranges = Vec::with_capacity(nw);
420            let mut weights = Vec::with_capacity(nw);
421            for w in 0..nw {
422                let start = w * chunk;
423                let end = ((w + 1) * chunk).min(n);
424                if start < n {
425                    let weight = if let Some(wts) = item_weights {
426                        wts[start..end].iter().sum()
427                    } else {
428                        (end - start) as f64
429                    };
430                    ranges.push(start..end);
431                    weights.push(weight);
432                }
433            }
434            LoadBalancePlan { ranges, weights }
435        }
436        LoadBalanceStrategy::Weighted => {
437            let wts = match item_weights {
438                Some(w) => w,
439                None => {
440                    return compute_load_balance(n, num_workers, LoadBalanceStrategy::Static, None);
441                }
442            };
443            let total_weight: f64 = wts.iter().sum();
444            let target_per_worker = total_weight / nw as f64;
445            let mut ranges = Vec::with_capacity(nw);
446            let mut weights = Vec::with_capacity(nw);
447            let mut start = 0;
448            let mut current_weight = 0.0;
449            for i in 0..n {
450                current_weight += wts[i];
451                let workers_remaining = nw - ranges.len();
452                let at_last_worker = workers_remaining == 1;
453                let exceeded_target = current_weight >= target_per_worker && !at_last_worker;
454                if exceeded_target {
455                    ranges.push(start..(i + 1));
456                    weights.push(current_weight);
457                    start = i + 1;
458                    current_weight = 0.0;
459                }
460            }
461            if start < n || ranges.is_empty() {
462                ranges.push(start..n);
463                weights.push(current_weight);
464            }
465            LoadBalancePlan { ranges, weights }
466        }
467        LoadBalanceStrategy::Guided => {
468            let mut ranges = Vec::new();
469            let mut weights = Vec::new();
470            let mut pos = 0;
471            let mut remaining = n;
472            while remaining > 0 {
473                let min_chunk = 1usize;
474                let chunk = (remaining / nw).max(min_chunk).min(remaining);
475                let end = pos + chunk;
476                let weight = if let Some(wts) = item_weights {
477                    wts[pos..end].iter().sum()
478                } else {
479                    chunk as f64
480                };
481                ranges.push(pos..end);
482                weights.push(weight);
483                pos = end;
484                remaining -= chunk;
485            }
486            LoadBalancePlan { ranges, weights }
487        }
488    }
489}
490/// Execute a function in parallel with load-balanced ranges.
491///
492/// Each range is processed as one Rayon task. The function receives
493/// `(worker_id, range)`.
494#[allow(dead_code)]
495pub fn execute_balanced(
496    plan: &LoadBalancePlan,
497    f: impl Fn(usize, std::ops::Range<usize>) + Sync + Send,
498) {
499    plan.ranges
500        .par_iter()
501        .enumerate()
502        .for_each(|(worker_id, range)| {
503            f(worker_id, range.clone());
504        });
505}
506/// Parallel map-reduce: map each element, then reduce results.
507///
508/// Combines mapping and reduction in a single parallel pass.
509#[allow(dead_code)]
510pub fn parallel_map_reduce<T: Send + Sync>(
511    data: &[T],
512    map_fn: impl Fn(&T) -> f64 + Sync + Send,
513    identity: f64,
514    reduce_fn: impl Fn(f64, f64) -> f64 + Sync + Send,
515) -> f64 {
516    data.par_iter().map(map_fn).reduce(|| identity, reduce_fn)
517}
518/// Parallel histogram: count elements falling into each bin.
519///
520/// Bins are `[min, min+step), [min+step, min+2*step), ...`.
521/// Returns a vector of length `num_bins`.
522#[allow(dead_code)]
523pub fn parallel_histogram(data: &[f64], min: f64, max: f64, num_bins: usize) -> Vec<usize> {
524    if num_bins == 0 || max <= min {
525        return vec![0; num_bins];
526    }
527    let step = (max - min) / num_bins as f64;
528    data.par_iter()
529        .fold(
530            || vec![0usize; num_bins],
531            |mut hist, &v| {
532                if v >= min && v < max {
533                    let bin = ((v - min) / step) as usize;
534                    let bin = bin.min(num_bins - 1);
535                    hist[bin] += 1;
536                } else if (v - max).abs() < 1e-15 {
537                    hist[num_bins - 1] += 1;
538                }
539                hist
540            },
541        )
542        .reduce(
543            || vec![0usize; num_bins],
544            |mut a, b| {
545                for (ai, bi) in a.iter_mut().zip(b.iter()) {
546                    *ai += bi;
547                }
548                a
549            },
550        )
551}
/// Stream compaction: keep only the elements satisfying `pred`, together with
/// a map back to their original positions.
///
/// Returns `(compacted, scatter_map)`, where:
/// * `compacted` holds, in input order, every `data[i]` for which `pred(&data[i])` is true;
/// * `scatter_map[j]` is the original index `i` of `compacted[j]`.
///
/// This is the sequential CPU counterpart of a GPU prefix-sum/scatter compaction pass.
#[allow(dead_code)]
pub fn stream_compaction<T: Clone>(data: &[T], pred: impl Fn(&T) -> bool) -> (Vec<T>, Vec<usize>) {
    data.iter()
        .enumerate()
        .filter(|&(_, item)| pred(item))
        .map(|(i, item)| (item.clone(), i))
        .unzip()
}
572/// Parallel stream compaction via Rayon.
573///
574/// Each thread builds a local (value, original_index) list and then the
575/// lists are merged in order to preserve a deterministic output.
576#[allow(dead_code)]
577pub fn parallel_stream_compaction<T: Clone + Send + Sync>(
578    data: &[T],
579    pred: impl Fn(&T) -> bool + Sync,
580) -> (Vec<T>, Vec<usize>) {
581    use rayon::iter::IndexedParallelIterator;
582    let pairs: Vec<(T, usize)> = data
583        .par_iter()
584        .enumerate()
585        .filter_map(|(i, item)| {
586            if pred(item) {
587                Some((item.clone(), i))
588            } else {
589                None
590            }
591        })
592        .collect();
593    let compacted: Vec<T> = pairs.iter().map(|(v, _)| v.clone()).collect();
594    let scatter_map: Vec<usize> = pairs.iter().map(|(_, i)| *i).collect();
595    (compacted, scatter_map)
596}
/// Segmented reduction: sum the values that share a segment id.
///
/// `segment_ids[i]` is the segment of `data[i]`; ids need not be sorted or
/// contiguous. The result has `max(segment_ids) + 1` entries, indexed by
/// segment id, and ids that never occur hold `0.0`. Only the common prefix of
/// `data` and `segment_ids` is read. Returns an empty vector when `data` is empty.
///
/// Example:
/// ```text
/// data          = [1, 2, 3, 4, 5, 6]
/// segment_ids   = [0, 0, 1, 1, 1, 2]
/// output        = [3, 12, 6]
/// ```
#[allow(dead_code)]
pub fn segmented_reduce_sum(data: &[f64], segment_ids: &[usize]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let slots = segment_ids.iter().copied().max().map_or(1, |top| top + 1);
    let mut totals = vec![0.0f64; slots];
    data.iter()
        .zip(segment_ids)
        .for_each(|(&value, &seg)| totals[seg] += value);
    totals
}
/// Segmented reduction: maximum value for each segment id.
///
/// The result has `max(segment_ids) + 1` entries, indexed by segment id; ids
/// that never occur (or whose values are all NaN) hold `f64::NEG_INFINITY`.
/// Returns an empty vector when `data` is empty.
#[allow(dead_code)]
pub fn segmented_reduce_max(data: &[f64], segment_ids: &[usize]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let slots = segment_ids.iter().copied().max().map_or(1, |top| top + 1);
    let mut peaks = vec![f64::NEG_INFINITY; slots];
    for (&value, &seg) in data.iter().zip(segment_ids) {
        let best = &mut peaks[seg];
        if value > *best {
            *best = value;
        }
    }
    peaks
}
/// Segmented reduction: minimum value for each segment id.
///
/// The result has `max(segment_ids) + 1` entries, indexed by segment id; ids
/// that never occur (or whose values are all NaN) hold `f64::INFINITY`.
/// Returns an empty vector when `data` is empty.
#[allow(dead_code)]
pub fn segmented_reduce_min(data: &[f64], segment_ids: &[usize]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let slots = segment_ids.iter().copied().max().map_or(1, |top| top + 1);
    let mut floors = vec![f64::INFINITY; slots];
    for (&value, &seg) in data.iter().zip(segment_ids) {
        let best = &mut floors[seg];
        if value < *best {
            *best = value;
        }
    }
    floors
}
650/// Stable merge sort for f64 values (CPU reference implementation).
651///
652/// Returns a new sorted vector leaving the input unchanged.
653/// NaN values are placed at the end (treated as greater than any finite value).
654#[allow(dead_code)]
655pub fn merge_sort_f64(data: &[f64]) -> Vec<f64> {
656    let mut buf = data.to_vec();
657    merge_sort_recurse(&mut buf);
658    buf
659}
660pub(super) fn merge_sort_recurse(data: &mut [f64]) {
661    let n = data.len();
662    if n <= 1 {
663        return;
664    }
665    let mid = n / 2;
666    merge_sort_recurse(&mut data[..mid]);
667    merge_sort_recurse(&mut data[mid..]);
668    let left: Vec<f64> = data[..mid].to_vec();
669    let right: Vec<f64> = data[mid..].to_vec();
670    let (mut l, mut r, mut i) = (0, 0, 0);
671    while l < left.len() && r < right.len() {
672        if left[l]
673            .partial_cmp(&right[r])
674            .unwrap_or(std::cmp::Ordering::Greater)
675            != std::cmp::Ordering::Greater
676        {
677            data[i] = left[l];
678            l += 1;
679        } else {
680            data[i] = right[r];
681            r += 1;
682        }
683        i += 1;
684    }
685    while l < left.len() {
686        data[i] = left[l];
687        l += 1;
688        i += 1;
689    }
690    while r < right.len() {
691        data[i] = right[r];
692        r += 1;
693        i += 1;
694    }
695}
696/// Merge sort returning the sorted permutation (argsort, stable).
697///
698/// `result[k]` is the original index of the k-th smallest element.
699#[allow(dead_code)]
700pub fn merge_sort_argsort(data: &[f64]) -> Vec<usize> {
701    let mut indices: Vec<usize> = (0..data.len()).collect();
702    merge_argsort_recurse(data, &mut indices);
703    indices
704}
705pub(super) fn merge_argsort_recurse(data: &[f64], indices: &mut [usize]) {
706    let n = indices.len();
707    if n <= 1 {
708        return;
709    }
710    let mid = n / 2;
711    let (left_idx, right_idx) = indices.split_at_mut(mid);
712    merge_argsort_recurse(data, left_idx);
713    merge_argsort_recurse(data, right_idx);
714    let left: Vec<usize> = left_idx.to_vec();
715    let right: Vec<usize> = right_idx.to_vec();
716    let (mut l, mut r, mut i) = (0, 0, 0);
717    while l < left.len() && r < right.len() {
718        let cmp = data[left[l]]
719            .partial_cmp(&data[right[r]])
720            .unwrap_or(std::cmp::Ordering::Greater);
721        if cmp != std::cmp::Ordering::Greater {
722            indices[i] = left[l];
723            l += 1;
724        } else {
725            indices[i] = right[r];
726            r += 1;
727        }
728        i += 1;
729    }
730    while l < left.len() {
731        indices[i] = left[l];
732        l += 1;
733        i += 1;
734    }
735    while r < right.len() {
736        indices[i] = right[r];
737        r += 1;
738        i += 1;
739    }
740}
/// Bitonic sort in ascending order (CPU reference of a GPU bitonic network).
///
/// The network needs a power-of-two length, so the input is padded with
/// `f64::INFINITY` and trimmed back afterwards. It performs `O(n log² n)`
/// compare-and-swap steps.
///
/// NaN is unordered under `<`/`>`, so a compare-and-swap network cannot place
/// it. NaN values are removed before sorting and appended at the end, in
/// their original order. The output is therefore always a permutation of the
/// input, sorted, with NaN last.
#[allow(dead_code)]
pub fn bitonic_sort(data: &[f64]) -> Vec<f64> {
    if data.is_empty() {
        return Vec::new();
    }
    let mut buf: Vec<f64> = data.iter().copied().filter(|v| !v.is_nan()).collect();
    let m = buf.len();
    let p = m.next_power_of_two();
    buf.resize(p, f64::INFINITY);
    // k: size of the bitonic sequences being merged; j: compare distance.
    let mut k = 2usize;
    while k <= p {
        let mut j = k >> 1;
        while j >= 1 {
            for i in 0..p {
                let l = i ^ j;
                if l > i {
                    let ascending = (i & k) == 0;
                    if (ascending && buf[i] > buf[l]) || (!ascending && buf[i] < buf[l]) {
                        buf.swap(i, l);
                    }
                }
            }
            j >>= 1;
        }
        k <<= 1;
    }
    buf.truncate(m);
    buf.extend(data.iter().copied().filter(|v| v.is_nan()));
    buf
}
/// Bitonic sort that returns the original indices (argsort variant).
///
/// Pads with `(f64::INFINITY, usize::MAX)` pairs up to a power of two and
/// trims back. Elements are compared by a strict total order:
/// 1. padding entries come after every real element (so truncation removes
///    exactly the padding, even when `data` contains `INFINITY`);
/// 2. then by value, with NaN after every non-NaN value;
/// 3. then by original index, so equal values keep their input order (stable).
#[allow(dead_code)]
pub fn bitonic_argsort(data: &[f64]) -> Vec<usize> {
    use std::cmp::Ordering;

    const PAD: usize = usize::MAX;
    let n = data.len();
    if n == 0 {
        return Vec::new();
    }
    let p = n.next_power_of_two();
    let mut buf: Vec<(f64, usize)> = data
        .iter()
        .copied()
        .enumerate()
        .map(|(i, v)| (v, i))
        .collect();
    buf.resize(p, (f64::INFINITY, PAD));

    // Strict total order on (value, index); see the doc comment.
    let key_cmp = |a: &(f64, usize), b: &(f64, usize)| -> Ordering {
        (a.1 == PAD)
            .cmp(&(b.1 == PAD))
            .then_with(|| match (a.0.is_nan(), b.0.is_nan()) {
                (false, false) => a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal),
                (a_nan, b_nan) => a_nan.cmp(&b_nan),
            })
            .then(a.1.cmp(&b.1))
    };

    let mut k = 2usize;
    while k <= p {
        let mut j = k >> 1;
        while j >= 1 {
            for i in 0..p {
                let l = i ^ j;
                if l > i {
                    let ord = key_cmp(&buf[i], &buf[l]);
                    let should_swap = if (i & k) == 0 {
                        ord == Ordering::Greater
                    } else {
                        ord == Ordering::Less
                    };
                    if should_swap {
                        buf.swap(i, l);
                    }
                }
            }
            j >>= 1;
        }
        k <<= 1;
    }
    buf.truncate(n);
    buf.into_iter().map(|(_, idx)| idx).collect()
}
822/// Simulate a work-stealing dispatcher across `num_workers` queues.
823///
824/// `tasks` is divided evenly among workers.  Any worker that finishes early
825/// steals from the most loaded remaining worker.
826/// Returns a `Vec`usize` of length `num_workers` recording the tasks each
827/// worker processed.
828#[allow(dead_code)]
829pub fn work_steal_queue<T: Send + Clone>(
830    tasks: Vec<T>,
831    num_workers: usize,
832    _process: impl Fn(&T) + Sync,
833) -> Vec<usize> {
834    let nw = num_workers.max(1);
835    let mut queues: Vec<WorkStealQueue<T>> = (0..nw).map(|_| WorkStealQueue::new()).collect();
836    for (i, task) in tasks.into_iter().enumerate() {
837        queues[i % nw].push(task);
838    }
839    let mut processed = vec![0usize; nw];
840    loop {
841        let mut did_work = false;
842        for w in 0..nw {
843            while let Some(task) = queues[w].pop() {
844                _process(&task);
845                processed[w] += 1;
846                did_work = true;
847            }
848        }
849        let max_len = queues.iter().map(|q| q.len()).max().unwrap_or(0);
850        if max_len == 0 {
851            break;
852        }
853        if did_work {
854            continue;
855        }
856        let victim = queues
857            .iter()
858            .enumerate()
859            .max_by_key(|(_, q)| q.len())
860            .map(|(i, _)| i);
861        let thief = queues
862            .iter()
863            .enumerate()
864            .find(|(_, q)| q.is_empty())
865            .map(|(i, _)| i);
866        if let (Some(v), Some(t)) = (victim, thief) {
867            if v != t {
868                if let Some(task) = queues[v].steal() {
869                    queues[t].push(task);
870                }
871            } else {
872                break;
873            }
874        } else {
875            break;
876        }
877    }
878    processed
879}
/// Load-balance efficiency of per-worker task counts: `avg_load / max_load`.
///
/// The result lies in `[0, 1]`; `1.0` means perfect balance and smaller values
/// mean more imbalance. An empty slice, or one where every worker has zero
/// load, is reported as perfectly balanced (`1.0`).
#[allow(dead_code)]
pub fn compute_load_balance_metric(worker_loads: &[usize]) -> f64 {
    let Some(&peak) = worker_loads.iter().max() else {
        return 1.0;
    };
    if peak == 0 {
        return 1.0;
    }
    let total: usize = worker_loads.iter().sum();
    let mean_load = total as f64 / worker_loads.len() as f64;
    mean_load / peak as f64
}
/// Suggest a chunk size for `n` work items across `num_workers` workers
/// (`0` is treated as `1`), aiming for at least `min_chunks_per_worker`
/// chunks per worker.
///
/// Returns `ceil(n / (workers * min_chunks_per_worker))`, never less than 1.
/// The target chunk count saturates instead of overflowing: an enormous
/// target just means "as small as possible", i.e. a chunk size of 1.
#[allow(dead_code)]
pub fn suggest_chunk_size(n: usize, num_workers: usize, min_chunks_per_worker: usize) -> usize {
    let nw = num_workers.max(1);
    let chunks = nw.saturating_mul(min_chunks_per_worker).max(1);
    n.div_ceil(chunks).max(1)
}
906/// Parallel merge sort for `f64` slices using Rayon.
907///
908/// Splits the array recursively.  Below `SERIAL_THRESHOLD` elements the
909/// standard library sort is used.  Above that the two halves are sorted in
910/// parallel and then merged sequentially.
911#[allow(dead_code)]
912pub fn merge_sort_parallel(data: &[f64]) -> Vec<f64> {
913    pub(super) const SERIAL_THRESHOLD: usize = 256;
914    let n = data.len();
915    if n <= 1 {
916        return data.to_vec();
917    }
918    if n <= SERIAL_THRESHOLD {
919        let mut v = data.to_vec();
920        v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
921        return v;
922    }
923    let mid = n / 2;
924    let (left_slice, right_slice) = data.split_at(mid);
925    let (left_sorted, right_sorted) = rayon::join(
926        || merge_sort_parallel(left_slice),
927        || merge_sort_parallel(right_slice),
928    );
929    merge_two_sorted(&left_sorted, &right_sorted)
930}
/// Merge two ascending-sorted `f64` slices into one sorted `Vec<f64>`.
///
/// Inputs are expected to have any NaN values at their ends. NaN is treated
/// as greater than every other value, so it stays at the end of the merged
/// output. On ties, elements from `a` come first (stable merge).
#[allow(dead_code)]
pub fn merge_two_sorted(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut result = Vec::with_capacity(a.len() + b.len());
    let (mut i, mut j) = (0, 0);
    while i < a.len() && j < b.len() {
        // Take from `a` unless b[j] is strictly smaller. A NaN in `b` is never
        // smaller, and a NaN in `a` fails `<=` against any non-NaN value.
        if a[i] <= b[j] || b[j].is_nan() {
            result.push(a[i]);
            i += 1;
        } else {
            result.push(b[j]);
            j += 1;
        }
    }
    result.extend_from_slice(&a[i..]);
    result.extend_from_slice(&b[j..]);
    result
}