Skip to main content

rlx_cpu/
pool.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Rayon-backed parallel for: `par_for(total, grain, |off, cnt| …)`.
17//!
18//! Replaces the old per-worker Condvar pool with Rayon's work-stealing
19//! scheduler. Same `(offset, count)` chunk API so all existing call
20//! sites (BLAS tiling, SDPA, LayerNorm, …) pick up Rayon without
21//! changes.
22
23use rayon::prelude::*;
24use std::sync::Once;
25
26static POOL_INIT: Once = Once::new();
27
28fn ensure_pool() {
29    POOL_INIT.call_once(|| {
30        let cfg = crate::config::RuntimeConfig::global();
31        let n = cfg.pool_workers.max(1);
32        let _ = rayon::ThreadPoolBuilder::new()
33            .num_threads(n)
34            .thread_name(|i| format!("rlx-rayon-{i}"))
35            .build_global();
36    });
37}
38
39/// Total Rayon worker count (configured from [`RuntimeConfig::pool_workers`]).
40pub fn num_threads() -> usize {
41    ensure_pool();
42    rayon::current_num_threads()
43}
44
45/// Parallel for: split `total` items across threads. `f(off, cnt)` is
46/// called once per chunk with disjoint regions.
47///
48/// SAFETY: caller must ensure `f` accesses disjoint memory regions for
49/// different `(offset, count)` pairs.
50#[inline]
51pub fn par_for<F: Fn(usize, usize) + Sync>(total: usize, min_per_thread: usize, f: &F) {
52    if total == 0 {
53        return;
54    }
55    ensure_pool();
56    let grain = min_per_thread.max(1);
57    let n_threads = (total / grain).max(1).min(num_threads());
58    if n_threads <= 1 {
59        f(0, total);
60        return;
61    }
62    let chunk = total.div_ceil(n_threads);
63    (0..n_threads).into_par_iter().for_each(|t| {
64        let off = t * chunk;
65        if off < total {
66            f(off, (off + chunk).min(total) - off);
67        }
68    });
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

    #[test]
    fn par_for_sums_correctly() {
        let data = vec![1.0f32; 10_000];
        let total = AtomicU64::new(0);

        par_for(data.len(), 100, &|off, cnt| {
            let partial: f32 = data[off..off + cnt].iter().sum();
            // Each partial is a whole number <= 10_000 < 2^24, so the f32 sum
            // is exact and the cast to u64 is lossless.
            total.fetch_add(partial as u64, Ordering::Relaxed);
        });

        assert_eq!(total.load(Ordering::Relaxed), 10_000);
    }

    #[test]
    fn par_for_small_is_sequential() {
        let sum = AtomicUsize::new(0);
        par_for(10, 100, &|off, cnt| {
            sum.fetch_add(cnt, Ordering::Relaxed);
            assert_eq!(off + cnt, 10);
        });
        assert_eq!(sum.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn par_for_zero_total_never_calls() {
        let calls = AtomicUsize::new(0);
        par_for(0, 1, &|_, _| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn par_for_covers_each_index_exactly_once() {
        // Includes grain 0 (treated as 1) and totals not divisible by the
        // chunk count.
        let cases: &[(usize, usize)] = &[(1, 1), (5, 1), (10, 0), (10, 3), (1000, 7), (4097, 256)];
        for &(n, grain) in cases {
            let hits: Vec<AtomicUsize> = (0..n).map(|_| AtomicUsize::new(0)).collect();
            par_for(n, grain, &|off, cnt| {
                assert!(cnt > 0, "empty chunk at offset {off}");
                for h in &hits[off..off + cnt] {
                    h.fetch_add(1, Ordering::Relaxed);
                }
            });
            assert!(
                hits.iter().all(|h| h.load(Ordering::Relaxed) == 1),
                "n={n} grain={grain}: some index not visited exactly once"
            );
        }
    }

    #[test]
    fn par_for_exact_sum_many_dispatches() {
        for &n in &[256usize, 1024, 4097] {
            let sum = AtomicUsize::new(0);
            par_for(n, 256, &|off, cnt| {
                sum.fetch_add(cnt, Ordering::Relaxed);
                assert!(off + cnt <= n);
            });
            assert_eq!(sum.load(Ordering::Relaxed), n);
        }
    }

    #[test]
    fn par_for_concurrent_callers_isolated() {
        std::thread::scope(|s| {
            for t in 0..4 {
                s.spawn(move || {
                    let n = 4096 + t * 17;
                    let sum = AtomicUsize::new(0);
                    par_for(n, 128, &|off, cnt| {
                        sum.fetch_add(cnt, Ordering::Relaxed);
                        assert!(off + cnt <= n);
                    });
                    assert_eq!(sum.load(Ordering::Relaxed), n);
                });
            }
        });
    }
}