cjc_repro/lib.rs
1//! Deterministic computation primitives for CJC.
2//!
3//! This crate provides the foundational building blocks that guarantee
4//! **bit-identical** results across runs, platforms, and thread counts:
5//!
6//! - [`Rng`] -- a SplitMix64 PRNG with explicit seed threading. Same seed
7//! produces the identical sequence on every platform.
8//! - [`KahanAccumulatorF64`] / [`KahanAccumulatorF32`] -- incremental
9//! compensated-summation accumulators (re-exported from the [`kahan`] module).
10//! - [`kahan_sum_f64`] / [`kahan_sum_f32`] -- one-shot compensated summation
11//! over slices.
12//! - [`pairwise_sum_f64`] -- recursive pairwise summation that falls back to
13//! Kahan summation for leaves of 32 elements or fewer.
14//! - [`ReproConfig`] -- a lightweight toggle that carries the reproducibility
15//! seed through the compiler pipeline.
16//!
17//! # Determinism contract
18//!
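//! All primitives in this crate are **serial and deterministic**. When the
//! same inputs are provided in the same order, the output is bit-for-bit
//! identical regardless of the host platform, compiler version, or OS.
//! The one caveat is [`Rng::next_normal_f64`] (and its `f32` wrapper): it
//! calls `ln` and `cos`, whose precision the Rust standard library leaves
//! unspecified, so the last bits of normal samples may vary by platform.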
22//!
23//! No `HashMap`, no FMA, no non-deterministic SIMD reductions.
24
25pub mod kahan;
26pub use kahan::{
27 KahanAccumulatorF32, KahanAccumulatorF64, KahanAccumulatorF64x4, KahanAccumulatorF64x8,
28};
29
/// Deterministic pseudo-random number generator using the SplitMix64 algorithm.
///
/// Guarantees identical integer and uniform-float sequences for the same seed
/// across all platforms. SplitMix64 has a period of 2^64 and passes BigCrush
/// statistical tests.
///
/// # Determinism
///
/// Two [`Rng`] instances created with the same seed will always produce the
/// exact same sequence of values from [`next_u64`](Self::next_u64),
/// [`next_f64`](Self::next_f64) and [`next_f32`](Self::next_f32), because those
/// use only wrapping integer arithmetic, shifts, and exact power-of-two
/// divisions. This is the backbone of CJC's reproducible computation model.
///
/// # Examples
///
/// ```
/// use cjc_repro::Rng;
///
/// let mut rng = Rng::seeded(42);
/// let a = rng.next_f64(); // deterministic value in [0, 1)
/// let b = rng.next_u64(); // deterministic u64
/// ```
#[derive(Debug, Clone)]
pub struct Rng {
    /// Current SplitMix64 counter; advanced by a fixed odd increment per draw.
    state: u64,
}

impl Rng {
    /// Creates a new [`Rng`] initialized with the given seed.
    ///
    /// # Arguments
    ///
    /// * `seed` -- The initial state. Seed `0` is valid and produces a
    ///   well-defined sequence.
    ///
    /// # Examples
    ///
    /// ```
    /// use cjc_repro::Rng;
    /// let mut rng = Rng::seeded(0);
    /// assert_eq!(rng.next_u64(), Rng::seeded(0).next_u64());
    /// ```
    pub fn seeded(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Generates the next `u64` using the SplitMix64 mixing function.
    ///
    /// Advances the internal state by one step (a wrapping add of the
    /// golden-ratio increment) and returns the mixed 64-bit value.
    ///
    /// # Returns
    ///
    /// A deterministic `u64` drawn from the full `0..=u64::MAX` range.
    pub fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9e3779b97f4a7c15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }

    /// Generates a uniformly distributed `f64` in the half-open interval `[0, 1)`.
    ///
    /// Uses the upper 53 bits of [`next_u64`](Self::next_u64) to fill the
    /// 53-bit mantissa of an IEEE-754 double, then divides by 2^53. The result
    /// is always a multiple of 2^-53 and can be exactly `0.0`.
    ///
    /// # Returns
    ///
    /// A deterministic `f64` satisfying `0.0 <= value < 1.0`.
    pub fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Generates a uniformly distributed `f32` in the half-open interval `[0, 1)`.
    ///
    /// Uses the upper 24 bits of [`next_u64`](Self::next_u64) to fill the
    /// 24-bit mantissa of an IEEE-754 single, then divides by 2^24.
    ///
    /// # Returns
    ///
    /// A deterministic `f32` satisfying `0.0 <= value < 1.0`.
    pub fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }

    /// Generates a sample from the standard normal distribution (mean 0, variance 1)
    /// using the Box-Muller transform.
    ///
    /// Consumes exactly **two** uniform samples from [`next_f64`](Self::next_f64)
    /// per call. The transform is: `sqrt(-2 ln(u1)) * cos(2 pi u2)`.
    ///
    /// Because `next_f64` can return exactly `0.0`, `u1` is clamped up to 2^-53
    /// (the smallest non-zero value `next_f64` can produce) before taking the
    /// logarithm. Without the clamp `ln(0)` is `-inf` and the sample would be
    /// infinite. Every non-zero `u1` is left untouched, so all other outputs
    /// are unchanged by the clamp.
    ///
    /// # Returns
    ///
    /// A finite, deterministic `f64` drawn from N(0, 1).
    ///
    /// # Determinism
    ///
    /// NOTE(review): `ln` and `cos` are evaluated by the platform math library,
    /// whose precision the Rust standard library does not specify; the last
    /// bits of the result may therefore differ between platforms.
    pub fn next_normal_f64(&mut self) -> f64 {
        // 2^-53: grid spacing of `next_f64`, i.e. its smallest non-zero output.
        const MIN_U1: f64 = 1.0 / (1u64 << 53) as f64;

        // Draw order (u1 first, then u2) is part of the reproducible stream.
        let u1 = self.next_f64().max(MIN_U1);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Generates a standard-normal `f32` sample.
    ///
    /// Delegates to [`next_normal_f64`](Self::next_normal_f64) and narrows the
    /// result to `f32`. Consumes two uniform draws from the underlying state.
    ///
    /// # Returns
    ///
    /// A finite, deterministic `f32` drawn from N(0, 1).
    pub fn next_normal_f32(&mut self) -> f32 {
        self.next_normal_f64() as f32
    }

    /// Forks the RNG into a separate sub-stream.
    ///
    /// The returned [`Rng`] is seeded with the next `u64` drawn from `self`,
    /// so both the parent and the child advance deterministically. This is
    /// the standard mechanism for giving each CJC closure or parallel lane
    /// its own reproducible random stream.
    ///
    /// The child runs the same SplitMix64 recurrence as the parent, started
    /// from a different state. The two streams are therefore different
    /// offsets into the same 2^64-long state cycle rather than provably
    /// disjoint sequences.
    ///
    /// # Returns
    ///
    /// A new [`Rng`] whose state is derived from the current generator.
    ///
    /// # Examples
    ///
    /// ```
    /// use cjc_repro::Rng;
    /// let mut parent = Rng::seeded(7);
    /// let mut child = parent.fork();
    /// // parent and child now produce separate but deterministic sequences
    /// let _ = child.next_f64();
    /// ```
    pub fn fork(&mut self) -> Rng {
        Rng {
            state: self.next_u64(),
        }
    }
}
167
/// Sums a slice of `f64` values with Kahan compensated summation.
///
/// A running compensation term captures the low-order bits lost by each
/// addition and feeds them back into the next one. This gives an error bound
/// of O(epsilon) for *n* summands, compared to O(*n* * epsilon) for naive
/// left-to-right addition. Only two scalars of state are carried (the running
/// total and the compensation); nothing is heap-allocated.
///
/// # Arguments
///
/// * `values` -- The slice of `f64` values to sum.
///
/// # Returns
///
/// The compensated sum as `f64`. An empty slice yields `0.0`.
///
/// Non-finite inputs are not special-cased: once the running total is
/// infinite the compensation term becomes NaN, so an infinity followed by any
/// further element yields NaN.
///
/// # Determinism
///
/// The result is deterministic for a given input slice. Different orderings of
/// the same values may yield different (but equally stable) results.
///
/// # Examples
///
/// ```
/// use cjc_repro::kahan_sum_f64;
/// let vals: Vec<f64> = (0..10_000).map(|_| 0.0001).collect();
/// let sum = kahan_sum_f64(&vals);
/// assert!((sum - 1.0).abs() < 1e-10);
/// ```
pub fn kahan_sum_f64(values: &[f64]) -> f64 {
    // State is (running total, error carried over from the previous step).
    let (total, _carry) = values
        .iter()
        .fold((0.0f64, 0.0f64), |(total, carry), &term| {
            let corrected = term - carry;
            let next_total = total + corrected;
            // Algebraically zero; in floating point it recovers what
            // `total + corrected` rounded away.
            let next_carry = (next_total - total) - corrected;
            (next_total, next_carry)
        });
    total
}
206
/// Sums a slice of `f32` values with Kahan compensated summation.
///
/// Single-precision counterpart to [`kahan_sum_f64`]: the same
/// total-plus-compensation recurrence, carried out entirely in `f32`. The
/// error bound is O(epsilon) relative to `f32` machine epsilon, and nothing
/// is heap-allocated.
///
/// # Arguments
///
/// * `values` -- The slice of `f32` values to sum.
///
/// # Returns
///
/// The compensated sum as `f32`. An empty slice yields `0.0`.
///
/// # Determinism
///
/// Deterministic for a given input slice ordering.
pub fn kahan_sum_f32(values: &[f32]) -> f32 {
    let mut state = (0.0f32, 0.0f32); // (running total, carried rounding error)
    for term in values.iter().copied() {
        let (total, carry) = state;
        let corrected = term - carry;
        let next_total = total + corrected;
        // Recovers the low-order bits dropped by `total + corrected`.
        state = (next_total, (next_total - total) - corrected);
    }
    state.0
}
235
236/// Computes the sum of a slice of `f64` values using recursive pairwise summation.
237///
238/// Recursively splits the slice in half and sums each half independently.
239/// Leaves of 32 elements or fewer are reduced with [`kahan_sum_f64`]. This
240/// yields an error bound of O(epsilon * log2(*n*)) with good cache locality.
241///
242/// # Arguments
243///
244/// * `values` -- The slice of `f64` values to sum.
245///
246/// # Returns
247///
248/// The pairwise-compensated sum as `f64`.
249///
250/// # Determinism
251///
252/// Deterministic for a given input slice. The recursive split point is always
253/// `len / 2`, so the tree structure is fully determined by the length.
254///
255/// # Examples
256///
257/// ```
258/// use cjc_repro::pairwise_sum_f64;
259/// let vals: Vec<f64> = (0..10_000).map(|_| 0.0001).collect();
260/// let sum = pairwise_sum_f64(&vals);
261/// assert!((sum - 1.0).abs() < 1e-10);
262/// ```
263pub fn pairwise_sum_f64(values: &[f64]) -> f64 {
264 if values.len() <= 32 {
265 return kahan_sum_f64(values);
266 }
267 let mid = values.len() / 2;
268 pairwise_sum_f64(&values[..mid]) + pairwise_sum_f64(&values[mid..])
269}
270
/// Switch plus seed that controls deterministic reproducibility.
///
/// This type only carries the two values. When `enabled` is `true`, the
/// runtime is expected to seed every [`Rng`] from
/// [`seed`](ReproConfig::seed) and to enforce deterministic reduction
/// ordering. When `enabled` is `false`, the seed is ignored and the runtime
/// may use a non-deterministic source.
///
/// # Examples
///
/// ```
/// use cjc_repro::ReproConfig;
///
/// let cfg = ReproConfig::enabled(42);
/// assert!(cfg.enabled);
/// assert_eq!(cfg.seed, 42);
///
/// let off = ReproConfig::disabled();
/// assert!(!off.enabled);
/// ```
#[derive(Debug, Clone)]
pub struct ReproConfig {
    /// `true` when reproducibility mode is switched on.
    pub enabled: bool,
    /// Global seed from which all [`Rng`] instances are initialized while
    /// reproducibility is on.
    pub seed: u64,
}

impl ReproConfig {
    /// Builds a [`ReproConfig`] with reproducibility switched **off**.
    ///
    /// The seed is stored as `0`; it is a placeholder the runtime does not
    /// read in this mode.
    pub fn disabled() -> Self {
        let placeholder_seed = 0;
        Self {
            enabled: false,
            seed: placeholder_seed,
        }
    }

    /// Builds a [`ReproConfig`] with reproducibility switched **on**.
    ///
    /// # Arguments
    ///
    /// * `seed` -- The global seed that will be threaded through the runtime.
    pub fn enabled(seed: u64) -> Self {
        let enabled = true;
        Self { enabled, seed }
    }
}

impl Default for ReproConfig {
    /// The default configuration is [`ReproConfig::disabled()`].
    fn default() -> Self {
        ReproConfig::disabled()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pulls `count` raw draws out of `rng` so whole streams can be compared.
    fn draw_u64s(rng: &mut Rng, count: usize) -> Vec<u64> {
        (0..count).map(|_| rng.next_u64()).collect()
    }

    /// Two generators built from the same seed must emit the same stream.
    #[test]
    fn test_rng_deterministic() {
        let first = draw_u64s(&mut Rng::seeded(42), 100);
        let second = draw_u64s(&mut Rng::seeded(42), 100);
        assert_eq!(first, second);
    }

    /// Every uniform `f64` sample must land in `[0, 1)`.
    #[test]
    fn test_rng_f64_range() {
        let mut rng = Rng::seeded(123);
        assert!((0..1000).all(|_| (0.0..1.0).contains(&rng.next_f64())));
    }

    /// Forking identically seeded parents must give identical children.
    #[test]
    fn test_rng_fork_deterministic() {
        let mut child_a = Rng::seeded(42).fork();
        let mut child_b = Rng::seeded(42).fork();
        assert_eq!(draw_u64s(&mut child_a, 50), draw_u64s(&mut child_b, 50));
    }

    /// Many small addends: the case where naive summation drifts.
    #[test]
    fn test_kahan_sum() {
        let addends = vec![0.0001f64; 10000];
        let total = kahan_sum_f64(&addends);
        assert!((total - 1.0).abs() < 1e-10);
    }

    /// Single-precision variant of the small-addend check.
    #[test]
    fn test_kahan_sum_f32() {
        let addends = vec![0.0001f32; 10000];
        let total = kahan_sum_f32(&addends);
        assert!((total - 1.0).abs() < 1e-4);
    }

    /// Pairwise summation must meet the same tolerance on the same data.
    #[test]
    fn test_pairwise_sum() {
        let addends = vec![0.0001f64; 10000];
        let total = pairwise_sum_f64(&addends);
        assert!((total - 1.0).abs() < 1e-10);
    }
}