Skip to main content

rlx_ir/
rng.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Counter-based and ONNX Runtime–compatible RNG for in-graph random ops.
17//!
18//! # Behavioral contract
19//!
20//! [`Op::RngNormal`] / [`Op::RngUniform`] take an optional shape-template input
21//! (ONNX `Random*Like`) or no inputs when the output shape is fixed at import
22//! time (ONNX `Random*` with a `shape` attribute). The output tensor shape is
23//! always the node's assigned shape; the template input is not copied into the
24//! output.
25//!
26//! | Backend | Semantics |
27//! |---------|-----------|
28//! | [`RngBackend::Philox`] | Deterministic Philox4×32-10 stream keyed by [`RngOptions::seed`] + per-node `key`. Default for RLX-native runs. |
29//! | [`RngBackend::Ort`] | Matches ONNX Runtime CPU `Random*` (`minstd_rand0` + polar normal / uniform). Use for import parity tests. Per-op ONNX `seed` (f32) overrides the mixed engine seed when set. |
30//! | [`RngBackend::Zero`] | Writes zeros — useful when comparing against a stochastic reference without re-seeding ORT. |
31//!
32//! Policy is set at compile time via [`CompileOptions::rng`] and can be overridden
33//! per session through [`rlx_runtime::CompiledGraph::set_rng`] without
34//! recompiling. Each execute re-seeds from the current policy (ORT session state
35//! is not advanced across runs today).
36
/// Selects the generator that backs [`Op::RngNormal`] / [`Op::RngUniform`].
///
/// [`RngBackend::Philox`] is the variant returned by `Default`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum RngBackend {
    /// RLX-native default: a sequential Philox4×32-10 stream.
    #[default]
    Philox,
    /// Emulation of the ONNX Runtime CPU `Random*Like` kernels
    /// (`minstd_rand0` engine feeding `std::normal_distribution`).
    Ort,
    /// Every element is written as zero, giving deterministic parity against
    /// stochastic reference runs.
    Zero,
}
49
50/// Compile-time / execute-time RNG policy for graphs containing random ops.
51#[derive(Debug, Clone, Copy, PartialEq, Eq)]
52#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
53pub struct RngOptions {
54    /// Global seed mixed into per-node keys (maps to ORT session seed).
55    pub seed: u64,
56    pub backend: RngBackend,
57}
58
59impl Default for RngOptions {
60    fn default() -> Self {
61        Self {
62            seed: 42,
63            backend: RngBackend::Philox,
64        }
65    }
66}
67
68impl RngOptions {
69    pub const fn new(seed: u64, backend: RngBackend) -> Self {
70        Self { seed, backend }
71    }
72
73    pub fn philox(seed: u64) -> Self {
74        Self {
75            seed,
76            backend: RngBackend::Philox,
77        }
78    }
79
80    pub fn ort(seed: u64) -> Self {
81        Self {
82            seed,
83            backend: RngBackend::Ort,
84        }
85    }
86
87    pub fn zero() -> Self {
88        Self {
89            seed: 0,
90            backend: RngBackend::Zero,
91        }
92    }
93}
94
/// Derives a per-node seed from the global compile seed and the node's key
/// (ONNX node name hash).
///
/// The key is multiplied by a fixed odd 64-bit constant and added to `global`;
/// both steps wrap on overflow, so the function never panics.
pub fn combine_seed(global: u64, key: u64) -> u64 {
    // The widely used 64-bit golden-ratio multiplier.
    const KEY_MULTIPLIER: u64 = 0x9E37_79B9_7F4A_7C15;
    let spread_key = key.wrapping_mul(KEY_MULTIPLIER);
    global.wrapping_add(spread_key)
}
99
/// Seed handed to the ORT-compatible CPU engine.
///
/// * `Some(op_seed)` — the node's explicit ONNX `seed` attribute wins; it is
///   converted with an `as u32` cast (which saturates: negative values and NaN
///   become `0`, values above `u32::MAX` become `u32::MAX`).
/// * `None` — the low 32 bits of `global + key` (wrapping) are used.
pub fn ort_engine_seed(global: u64, key: u64, op_seed: Option<f32>) -> u32 {
    match op_seed {
        Some(explicit) => explicit as u32,
        None => {
            let mixed = global.wrapping_add(key);
            // Truncation to the low word is intentional.
            mixed as u32
        }
    }
}
108
109/// Fill `out` with `mean + scale * N(0,1)` samples.
110pub fn fill_normal_like(
111    out: &mut [f32],
112    mean: f32,
113    scale: f32,
114    opts: RngOptions,
115    key: u64,
116    op_seed: Option<f32>,
117) {
118    match opts.backend {
119        RngBackend::Zero => out.fill(0.0),
120        RngBackend::Philox => {
121            let mut rng = Philox4x32::new(combine_seed(opts.seed, key));
122            for v in out.iter_mut() {
123                *v = mean + scale * rng.normal();
124            }
125        }
126        RngBackend::Ort => {
127            let mut eng = MinstdRand0::new(ort_engine_seed(opts.seed, key, op_seed));
128            let mut dist = StdNormalDist::new(mean, scale);
129            for v in out.iter_mut() {
130                *v = dist.sample(&mut eng);
131            }
132        }
133    }
134}
135
136/// Fill `out` with uniform samples in `[low, high)`.
137pub fn fill_uniform_like(
138    out: &mut [f32],
139    low: f32,
140    high: f32,
141    opts: RngOptions,
142    key: u64,
143    op_seed: Option<f32>,
144) {
145    match opts.backend {
146        RngBackend::Zero => out.fill(0.0),
147        RngBackend::Philox => {
148            let mut rng = Philox4x32::new(combine_seed(opts.seed, key));
149            for v in out.iter_mut() {
150                *v = rng.uniform(low, high);
151            }
152        }
153        RngBackend::Ort => {
154            let mut eng = MinstdRand0::new(ort_engine_seed(opts.seed, key, op_seed));
155            for v in out.iter_mut() {
156                *v = low + (high - low) * eng.unit_f32();
157            }
158        }
159    }
160}
161
/// Philox4×32-10 counter-based RNG.
///
/// One pass of the 10-round block function turns the current 128-bit counter
/// into four `u32` words. The words are buffered and handed out one per
/// [`Philox4x32::next_u32`] call; the counter advances by one each time the
/// buffer is refilled.
#[derive(Debug, Clone, Copy)]
pub struct Philox4x32 {
    /// 64-bit key as `[low word, high word]`.
    seed: [u32; 2],
    /// 128-bit block counter, least-significant word first.
    counter: [u32; 4],
    /// Cached output buffer + cursor into it.
    buffer: [u32; 4],
    /// Index of the next unread word in `buffer`; `4` means "buffer empty".
    cursor: u8,
}

impl Philox4x32 {
    /// Round multipliers of the reference Philox4×32 generator. `M0` must be
    /// `0xD2511F53` for the output to match other Philox4×32-10
    /// implementations.
    const M0: u64 = 0xD251_1F53;
    const M1: u64 = 0xCD9E_8D57;
    /// Per-round key increments (Weyl sequence) of the reference generator.
    const W0: u32 = 0x9E37_79B9;
    const W1: u32 = 0xBB67_AE85;
    /// Number of rounds applied per block.
    const ROUNDS: usize = 10;

    /// Creates a generator whose key is `seed` (split into two 32-bit words)
    /// and whose counter starts at zero.
    pub const fn new(seed: u64) -> Self {
        let lo = (seed & 0xFFFF_FFFF) as u32;
        let hi = (seed >> 32) as u32;
        Self {
            seed: [lo, hi],
            counter: [0, 0, 0, 0],
            buffer: [0; 4],
            cursor: 4, // empty — the first `next_u32` fills the buffer
        }
    }

    /// Applies a single Philox round to `state` under `key`.
    fn round(state: &mut [u32; 4], key: [u32; 2]) {
        // Widen to 64 bits so both halves of each 32×32 product are available.
        let prod0 = u64::from(state[0]) * Self::M0;
        let prod1 = u64::from(state[2]) * Self::M1;
        let (hi0, lo0) = ((prod0 >> 32) as u32, prod0 as u32);
        let (hi1, lo1) = ((prod1 >> 32) as u32, prod1 as u32);
        *state = [
            hi1 ^ state[1] ^ key[0],
            lo1,
            hi0 ^ state[3] ^ key[1],
            lo0,
        ];
    }

    /// Encrypts the current counter into `buffer`, then increments the counter.
    fn fill_buffer(&mut self) {
        let mut state = self.counter;
        let mut key = self.seed;
        for _ in 0..Self::ROUNDS {
            Self::round(&mut state, key);
            // Bump the key after every round (Philox key schedule). The bump
            // following the last round is unused.
            key[0] = key[0].wrapping_add(Self::W0);
            key[1] = key[1].wrapping_add(Self::W1);
        }
        self.buffer = state;
        self.cursor = 0;

        // Increment the 128-bit counter: propagate the carry word by word and
        // stop at the first word that does not overflow. The top word wraps.
        for word in &mut self.counter {
            let (next, carried) = word.overflowing_add(1);
            *word = next;
            if !carried {
                break;
            }
        }
    }

    /// Returns the next 32-bit word of the stream, refilling the buffer when
    /// all four cached words have been consumed.
    pub fn next_u32(&mut self) -> u32 {
        if self.cursor >= 4 {
            self.fill_buffer();
        }
        let word = self.buffer[usize::from(self.cursor)];
        self.cursor += 1;
        word
    }

    /// Uniform `[0, 1)` f32 — the top 24 bits of a u32 give exactly
    /// f32 mantissa precision.
    pub fn next_f32(&mut self) -> f32 {
        let bits = self.next_u32() >> 8;
        bits as f32 / (1u32 << 24) as f32
    }

    /// Uniform f32 computed as `lo + u * (hi - lo)` with `u` from
    /// [`Philox4x32::next_f32`].
    pub fn uniform(&mut self, lo: f32, hi: f32) -> f32 {
        lo + self.next_f32() * (hi - lo)
    }

    /// Standard-normal `f32` via Box-Muller. Consumes two uniforms and returns
    /// one sample; the second is discarded (we don't cache to keep the type
    /// `Copy`-able).
    pub fn normal(&mut self) -> f32 {
        // Clamp away from zero so `ln` stays finite.
        let u1 = self.next_f32().max(f32::MIN_POSITIVE);
        let u2 = self.next_f32();
        let radius = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        radius * theta.cos()
    }

    /// Fill `out` with uniform `[0, 1)` samples. Convenience for
    /// weight init.
    pub fn fill_uniform(&mut self, out: &mut [f32]) {
        for slot in out {
            *slot = self.next_f32();
        }
    }

    /// Fill `out` with N(0, 1) samples.
    pub fn fill_normal(&mut self, out: &mut [f32]) {
        for slot in out {
            *slot = self.normal();
        }
    }
}
275
/// Multiplicative linear congruential engine standing in for the C++11
/// `std::default_random_engine` used by ONNX Runtime's CPU random kernels.
///
/// NOTE(review): despite the type name, the multiplier below (48271) is the
/// one the C++ standard assigns to `minstd_rand`; `minstd_rand0` uses 16807.
/// Which of the two `std::default_random_engine` aliases depends on the
/// standard library ORT was built against — confirm against the parity target
/// before changing it.
#[derive(Debug, Clone, Copy)]
struct MinstdRand0 {
    /// Current state; always in `1..M` so the sequence never sticks at zero.
    state: u32,
}

impl MinstdRand0 {
    /// Multiplier of the recurrence `x <- A * x mod M`.
    const A: u32 = 48_271;
    /// Modulus, `2^31 - 1`.
    const M: u32 = 2_147_483_647;

    /// Seeds the engine with `seed mod M`.
    ///
    /// A seed congruent to zero would pin the state at 0 forever (every output
    /// 0), which in turn makes rejection samplers built on this engine spin
    /// without terminating. As in C++ `linear_congruential_engine::seed`, such
    /// seeds are mapped to state `1` instead.
    fn new(seed: u32) -> Self {
        let reduced = seed % Self::M;
        Self {
            state: if reduced == 0 { 1 } else { reduced },
        }
    }

    /// Advances the state and returns it; the result lies in `1..M`.
    fn next_u32(&mut self) -> u32 {
        // 64-bit intermediate: the product does not fit in 32 bits.
        let product = u64::from(self.state) * u64::from(Self::A);
        self.state = (product % u64::from(Self::M)) as u32;
        self.state
    }

    /// Unit-interval sample computed as `g() / (M - 1)`, following ORT's
    /// `RealType(g()) / (g.max() - g.min())`.
    ///
    /// The raw draw is never 0 and may equal `M - 1`, and the `f32` conversion
    /// rounds, so the result lies in `(0, 1]` — it can be exactly `1.0`.
    fn unit_f32(&mut self) -> f32 {
        self.next_u32() as f32 / (Self::M - 1) as f32
    }
}
302
303/// C++ `std::normal_distribution<float>` (polar method, caches spare sample).
304#[derive(Debug, Clone, Copy)]
305struct StdNormalDist {
306    mean: f32,
307    scale: f32,
308    spare: f32,
309    has_spare: bool,
310}
311
312impl StdNormalDist {
313    fn new(mean: f32, scale: f32) -> Self {
314        Self {
315            mean,
316            scale,
317            spare: 0.0,
318            has_spare: false,
319        }
320    }
321
322    fn sample(&mut self, eng: &mut MinstdRand0) -> f32 {
323        if self.has_spare {
324            self.has_spare = false;
325            return self.spare;
326        }
327        loop {
328            let u1 = 2.0 * eng.unit_f32() - 1.0;
329            let u2 = 2.0 * eng.unit_f32() - 1.0;
330            let s = u1 * u1 + u2 * u2;
331            if s >= 1.0 || s == 0.0 {
332                continue;
333            }
334            let factor = (-2.0 * s.ln() / s).sqrt();
335            self.spare = u2 * factor * self.scale + self.mean;
336            self.has_spare = true;
337            return u1 * factor * self.scale + self.mean;
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from one seed must emit identical words.
    #[test]
    fn same_seed_same_sequence() {
        let mut left = Philox4x32::new(0x1234_5678);
        let mut right = Philox4x32::new(0x1234_5678);
        for _ in 0..256 {
            let (x, y) = (left.next_u32(), right.next_u32());
            assert_eq!(x, y);
        }
    }

    /// Distinct seeds should produce (almost) entirely different words.
    #[test]
    fn different_seed_different_sequence() {
        let mut a = Philox4x32::new(1);
        let mut b = Philox4x32::new(2);
        let diffs = (0..16).filter(|_| a.next_u32() != b.next_u32()).count();
        assert!(
            diffs >= 14,
            "two distinct seeds should disagree on >=14/16 samples"
        );
    }

    /// `next_f32` stays inside the half-open unit interval.
    #[test]
    fn next_f32_in_unit_interval() {
        let mut r = Philox4x32::new(42);
        for _ in 0..1000 {
            let v = r.next_f32();
            assert!((0.0..1.0).contains(&v), "{v} not in [0, 1)");
        }
    }

    /// Bulk uniform fill is reproducible for a fixed seed.
    #[test]
    fn fill_uniform_is_deterministic() {
        let fill = |seed: u64| {
            let mut buf = [0f32; 64];
            Philox4x32::new(seed).fill_uniform(&mut buf);
            buf
        };
        assert_eq!(fill(7), fill(7));
    }

    /// The empirical mean of many normal samples is close to zero.
    #[test]
    fn normal_mean_is_near_zero() {
        let mut r = Philox4x32::new(123);
        let n = 10_000;
        let sum = (0..n).fold(0f32, |acc, _| acc + r.normal());
        let mean = sum / n as f32;
        assert!(mean.abs() < 0.1, "mean {mean} too far from 0");
    }

    /// The zero backend overwrites every element with `0.0`.
    #[test]
    fn zero_backend_fills_zeros() {
        let mut out = [1.0f32; 8];
        fill_normal_like(&mut out, 0.0, 1.0, RngOptions::zero(), 0xABC, None);
        assert!(out.iter().all(|&v| v == 0.0));
    }

    /// Same options + key give the same Philox-backed normal samples.
    #[test]
    fn philox_backend_is_deterministic() {
        let opts = RngOptions::philox(99);
        let run = || {
            let mut buf = [0f32; 32];
            fill_normal_like(&mut buf, 0.0, 0.5, opts, 123, None);
            buf
        };
        assert_eq!(run(), run());
    }

    /// Same options + key give the same ORT-backed normal samples.
    #[test]
    fn ort_backend_is_deterministic() {
        let opts = RngOptions::ort(7);
        let run = || {
            let mut buf = [0f32; 64];
            fill_normal_like(&mut buf, 0.1, 2.0, opts, 555, None);
            buf
        };
        assert_eq!(run(), run());
    }

    /// Philox and ORT backends must not produce the same stream.
    #[test]
    fn backends_disagree() {
        let sample = |opts: RngOptions| {
            let mut buf = [0f32; 16];
            fill_normal_like(&mut buf, 0.0, 1.0, opts, 1, None);
            buf
        };
        let philox = sample(RngOptions::philox(42));
        let ort = sample(RngOptions::ort(42));
        assert_ne!(philox, ort);
    }
}