Skip to main content

cjc_repro/
lib.rs

1//! Deterministic computation primitives for CJC.
2//!
3//! This crate provides the foundational building blocks that guarantee
4//! **bit-identical** results across runs, platforms, and thread counts:
5//!
6//! - [`Rng`] -- a SplitMix64 PRNG with explicit seed threading.  Same seed
7//!   produces the identical sequence on every platform.
8//! - [`KahanAccumulatorF64`] / [`KahanAccumulatorF32`] -- incremental
9//!   compensated-summation accumulators (re-exported from the [`kahan`] module).
10//! - [`kahan_sum_f64`] / [`kahan_sum_f32`] -- one-shot compensated summation
11//!   over slices.
12//! - [`pairwise_sum_f64`] -- recursive pairwise summation that falls back to
13//!   Kahan summation for leaves of 32 elements or fewer.
14//! - [`ReproConfig`] -- a lightweight toggle that carries the reproducibility
15//!   seed through the compiler pipeline.
16//!
17//! # Determinism contract
18//!
19//! All primitives in this crate are **serial and deterministic**.  When the
20//! same inputs are provided in the same order, the output is bit-for-bit
21//! identical regardless of the host platform, compiler version, or OS.
22//!
23//! No `HashMap`, no FMA, no non-deterministic SIMD reductions.
24
25pub mod kahan;
26pub use kahan::{
27    KahanAccumulatorF32, KahanAccumulatorF64, KahanAccumulatorF64x4, KahanAccumulatorF64x8,
28};
29
/// Deterministic pseudo-random number generator using the SplitMix64 algorithm.
///
/// The generator keeps a single 64-bit counter.  Every draw adds a fixed odd
/// constant to the counter (with wrap-around) and passes the result through
/// the SplitMix64 bit mixer, so the integer stream depends only on the seed.
/// SplitMix64 has a period of 2^64 and passes BigCrush statistical tests.
///
/// # Determinism
///
/// Two [`Rng`] instances created with the same seed produce the exact same
/// sequence of integers: [`next_u64`](Self::next_u64) uses only wrapping
/// integer arithmetic, shifts and xors.  [`next_f64`](Self::next_f64) and
/// [`next_f32`](Self::next_f32) add one exact integer-to-float conversion and
/// one division by a power of two.
///
/// # Examples
///
/// ```
/// use cjc_repro::Rng;
///
/// let mut rng = Rng::seeded(42);
/// let a = rng.next_f64(); // deterministic value in [0, 1)
/// let b = rng.next_u64(); // deterministic u64
/// ```
#[derive(Debug, Clone)]
pub struct Rng {
    /// SplitMix64 counter; advanced by a fixed increment on every draw.
    state: u64,
}

impl Rng {
    /// Creates a new [`Rng`] initialized with the given seed.
    ///
    /// # Arguments
    ///
    /// * `seed` -- The initial state.  Seed `0` is valid and produces a
    ///   well-defined sequence.
    ///
    /// # Examples
    ///
    /// ```
    /// use cjc_repro::Rng;
    /// let mut rng = Rng::seeded(0);
    /// assert_eq!(rng.next_u64(), Rng::seeded(0).next_u64());
    /// ```
    pub fn seeded(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Generates the next `u64` using the SplitMix64 mixing function.
    ///
    /// Advances the internal state by one step and returns the mixed value.
    ///
    /// # Returns
    ///
    /// A deterministic `u64` drawn from the full `0..=u64::MAX` range.
    pub fn next_u64(&mut self) -> u64 {
        // Wrapping add: the counter is meant to roll over at 2^64.
        self.state = self.state.wrapping_add(0x9e3779b97f4a7c15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }

    /// Generates a uniformly distributed `f64` in the half-open interval `[0, 1)`.
    ///
    /// Keeps the upper 53 bits of [`next_u64`](Self::next_u64) and divides by
    /// 2^53.  A 53-bit integer converts to `f64` exactly, so the result is a
    /// multiple of 2^-53 and can be exactly `0.0` but never `1.0`.
    ///
    /// # Returns
    ///
    /// A deterministic `f64` satisfying `0.0 <= value < 1.0`.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Generates a uniformly distributed `f32` in the half-open interval `[0, 1)`.
    ///
    /// Keeps the upper 24 bits of [`next_u64`](Self::next_u64) and divides by
    /// 2^24.  A 24-bit integer converts to `f32` exactly.
    ///
    /// # Returns
    ///
    /// A deterministic `f32` satisfying `0.0 <= value < 1.0`.
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Generates a sample from the standard normal distribution (mean 0, variance 1)
    /// using the Box-Muller transform.
    ///
    /// Consumes **two** uniform samples from [`next_f64`](Self::next_f64) per call.
    /// The transform is: `sqrt(-2 ln(u1)) * cos(2 pi u2)`.
    ///
    /// `next_f64` can return exactly `0.0`, and `ln(0.0)` is negative infinity,
    /// which would turn the sample into an infinity.  That single case is
    /// replaced by `2^-54` (half of the `2^-53` spacing of `next_f64`), so the
    /// result is always finite.  Every other `u1` is used unchanged, so the
    /// output for all other states is the same as without the guard.
    ///
    /// # Returns
    ///
    /// A deterministic, finite `f64` drawn from (approximately) N(0, 1).
    ///
    /// # Determinism
    ///
    /// The uniform inputs are bit-identical everywhere, but this method also
    /// calls `f64::ln` and `f64::cos`.  The standard library documents the
    /// precision of those functions as unspecified, so the lowest bits of the
    /// result may differ between platforms or toolchain versions.
    pub fn next_normal_f64(&mut self) -> f64 {
        // Midpoint of the zero cell on the 2^-53 grid produced by `next_f64`.
        const MIN_U1: f64 = 1.0 / (1u64 << 54) as f64;

        // Only u1 == 0.0 is below MIN_U1; `max` leaves every other value as is.
        let u1 = self.next_f64().max(MIN_U1);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Generates a standard-normal `f32` sample.
    ///
    /// Delegates to [`next_normal_f64`](Self::next_normal_f64) and narrows the
    /// result to `f32`.  Consumes two uniform draws from the underlying state.
    ///
    /// # Returns
    ///
    /// A deterministic, finite `f32` drawn from (approximately) N(0, 1).
    pub fn next_normal_f32(&mut self) -> f32 {
        self.next_normal_f64() as f32
    }

    /// Forks the RNG into a separate sub-stream.
    ///
    /// The returned [`Rng`] is seeded with the next `u64` drawn from `self`,
    /// so the parent advances by one step and the child starts from a state
    /// derived from the parent.  Both remain fully determined by the original
    /// seed.
    ///
    /// # Returns
    ///
    /// A new [`Rng`] whose state is derived from the current generator.
    ///
    /// # Examples
    ///
    /// ```
    /// use cjc_repro::Rng;
    /// let mut parent = Rng::seeded(7);
    /// let mut child = parent.fork();
    /// // parent and child now produce separate but deterministic sequences
    /// let _ = child.next_f64();
    /// ```
    pub fn fork(&mut self) -> Rng {
        Rng {
            state: self.next_u64(),
        }
    }
}
167
/// Sums a slice of `f64` values with Kahan (compensated) summation.
///
/// Alongside the running total, a second scalar tracks the low-order bits
/// that were lost by the previous addition and feeds them back into the next
/// one.  This keeps the rounding error far below that of a plain
/// left-to-right sum and needs no heap allocation.
///
/// # Arguments
///
/// * `values` -- The slice of `f64` values to sum.
///
/// # Returns
///
/// The compensated sum as `f64`; `0.0` for an empty slice.
///
/// # Determinism
///
/// Elements are processed strictly front to back, so the result is fixed for
/// a given slice.  A different ordering of the same values may round
/// differently.
///
/// # Examples
///
/// ```
/// use cjc_repro::kahan_sum_f64;
/// let vals: Vec<f64> = (0..10_000).map(|_| 0.0001).collect();
/// let sum = kahan_sum_f64(&vals);
/// assert!((sum - 1.0).abs() < 1e-10);
/// ```
pub fn kahan_sum_f64(values: &[f64]) -> f64 {
    // Accumulator is (running total, error carried over from the last add).
    let (total, _carry) = values
        .iter()
        .fold((0.0f64, 0.0f64), |(total, carry), &x| {
            let adjusted = x - carry;
            let next = total + adjusted;
            // (next - total) is what was actually added; the difference to
            // `adjusted` is the rounding error to correct on the next step.
            (next, (next - total) - adjusted)
        });
    total
}
206
/// Sums a slice of `f32` values with Kahan (compensated) summation.
///
/// Single-precision counterpart of [`kahan_sum_f64`]: a running total plus
/// one correction term, both kept in `f32`, with no heap allocation.
///
/// # Arguments
///
/// * `values` -- The slice of `f32` values to sum.
///
/// # Returns
///
/// The compensated sum as `f32`; `0.0` for an empty slice.
///
/// # Determinism
///
/// Elements are processed strictly front to back, so the result is fixed for
/// a given slice ordering.
pub fn kahan_sum_f32(values: &[f32]) -> f32 {
    // Accumulator is (running total, error carried over from the last add).
    let (total, _carry) = values
        .iter()
        .fold((0.0f32, 0.0f32), |(total, carry), &x| {
            let adjusted = x - carry;
            let next = total + adjusted;
            (next, (next - total) - adjusted)
        });
    total
}
235
236/// Computes the sum of a slice of `f64` values using recursive pairwise summation.
237///
238/// Recursively splits the slice in half and sums each half independently.
239/// Leaves of 32 elements or fewer are reduced with [`kahan_sum_f64`].  This
240/// yields an error bound of O(epsilon * log2(*n*)) with good cache locality.
241///
242/// # Arguments
243///
244/// * `values` -- The slice of `f64` values to sum.
245///
246/// # Returns
247///
248/// The pairwise-compensated sum as `f64`.
249///
250/// # Determinism
251///
252/// Deterministic for a given input slice.  The recursive split point is always
253/// `len / 2`, so the tree structure is fully determined by the length.
254///
255/// # Examples
256///
257/// ```
258/// use cjc_repro::pairwise_sum_f64;
259/// let vals: Vec<f64> = (0..10_000).map(|_| 0.0001).collect();
260/// let sum = pairwise_sum_f64(&vals);
261/// assert!((sum - 1.0).abs() < 1e-10);
262/// ```
263pub fn pairwise_sum_f64(values: &[f64]) -> f64 {
264    if values.len() <= 32 {
265        return kahan_sum_f64(values);
266    }
267    let mid = values.len() / 2;
268    pairwise_sum_f64(&values[..mid]) + pairwise_sum_f64(&values[mid..])
269}
270
/// Configuration that controls whether deterministic reproducibility is active.
///
/// When `enabled` is `true`, the runtime seeds all [`Rng`] instances from
/// [`seed`](ReproConfig::seed) and enforces deterministic reduction ordering.
/// When `enabled` is `false`, the seed field is ignored and the runtime may
/// use a non-deterministic source.
///
/// The type is plain data and implements `PartialEq`/`Eq`, so two
/// configurations can be compared directly (for example in tests).
///
/// # Examples
///
/// ```
/// use cjc_repro::ReproConfig;
///
/// let cfg = ReproConfig::enabled(42);
/// assert!(cfg.enabled);
/// assert_eq!(cfg.seed, 42);
///
/// let off = ReproConfig::disabled();
/// assert!(!off.enabled);
/// assert_eq!(off, ReproConfig::default());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReproConfig {
    /// Whether reproducibility mode is active.
    pub enabled: bool,
    /// The global seed used to initialize all [`Rng`] instances when
    /// reproducibility is enabled.
    pub seed: u64,
}

impl ReproConfig {
    /// Creates a [`ReproConfig`] with reproducibility **disabled**.
    ///
    /// The seed is set to `0` but will not be used by the runtime.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            seed: 0,
        }
    }

    /// Creates a [`ReproConfig`] with reproducibility **enabled** using the
    /// given `seed`.
    ///
    /// # Arguments
    ///
    /// * `seed` -- The global seed that will be threaded through the runtime.
    pub fn enabled(seed: u64) -> Self {
        Self {
            enabled: true,
            seed,
        }
    }
}

impl Default for ReproConfig {
    /// Returns [`ReproConfig::disabled()`].
    fn default() -> Self {
        Self::disabled()
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed must emit the same integers.
    #[test]
    fn test_rng_deterministic() {
        let mut left = Rng::seeded(42);
        let mut right = Rng::seeded(42);

        (0..100).for_each(|_| assert_eq!(left.next_u64(), right.next_u64()));
    }

    /// Every `next_f64` draw must fall inside `[0, 1)`.
    #[test]
    fn test_rng_f64_range() {
        let mut rng = Rng::seeded(123);
        let all_in_range = (0..1000)
            .map(|_| rng.next_f64())
            .all(|sample| (0.0..1.0).contains(&sample));
        assert!(all_in_range);
    }

    /// Forks taken from identically seeded parents must agree with each other.
    #[test]
    fn test_rng_fork_deterministic() {
        let mut child_a = Rng::seeded(42).fork();
        let mut child_b = Rng::seeded(42).fork();

        (0..50).for_each(|_| assert_eq!(child_a.next_u64(), child_b.next_u64()));
    }

    /// 10,000 copies of 0.0001 should sum to 1.0 within 1e-10.
    #[test]
    fn test_kahan_sum() {
        // Sum of many small values where naive sum would lose precision
        let inputs = vec![0.0001f64; 10000];
        let total = kahan_sum_f64(&inputs);
        assert!((total - 1.0).abs() < 1e-10);
    }

    /// Single-precision variant of the test above, with a looser tolerance.
    #[test]
    fn test_kahan_sum_f32() {
        let inputs = vec![0.0001f32; 10000];
        let total = kahan_sum_f32(&inputs);
        assert!((total - 1.0).abs() < 1e-4);
    }

    /// Pairwise summation must reach the same accuracy on the same fixture.
    #[test]
    fn test_pairwise_sum() {
        let inputs = vec![0.0001f64; 10000];
        let total = pairwise_sum_f64(&inputs);
        assert!((total - 1.0).abs() < 1e-10);
    }
}