Skip to main content

adele_ring/
rns.rs

1//! Level 0 — ℤ. RNS integers and the core CRT reconstruction.
2//!
3//! # Representation
4//!
5//! A number lives as a tuple of residues, one per prime channel. By the Chinese
6//! Remainder Theorem any integer in `[0, M)` (where `M = ∏ moduli`) is uniquely
7//! identified by its residue tuple. We use the **symmetric** residue system:
8//! residues are stored in `[0, m)`, and on reconstruction a value `U` with
9//! `2U > M` is interpreted as the negative number `U - M`. This lets the signed
//! range be `(-M/2, M/2]` while keeping per-channel arithmetic a pure modular op
11//! — exactly the property that makes RNS embarrassingly parallel.
12//!
13//! # CPU vs GPU
14//!
15//! Single `RnsInt` operations parallelize over channels with rayon only when
16//! `k >= RAYON_CHANNEL_THRESHOLD`; below that the task overhead dominates and we
17//! stay sequential. For *batches* of `RnsInt`, always go through
18//! [`crate::backend::executor`], which auto-selects CPU rayon or the GPU based on
19//! the batch size. Never call a backend directly from the math layers.
20
21use std::sync::Arc;
22
23use num_bigint::{BigInt, BigUint, Sign};
24use num_traits::Zero;
25
26use crate::primes::{first_n_primes, gcd, mod_inverse};
27use crate::RAYON_CHANNEL_THRESHOLD;
28
29/// A shared, immutable set of pairwise-coprime moduli (the RNS "channels").
30///
31/// Cheap to clone — it is just an `Arc<Vec<u64>>` behind a newtype.
32#[derive(Clone, Debug)]
33pub struct Channels(pub Arc<Vec<u64>>);
34
35impl Channels {
36    /// Build channels from explicit moduli.
37    ///
38    /// In debug builds this asserts the moduli are pairwise coprime (the CRT
39    /// requirement); in release builds the check is skipped for speed.
40    pub fn new(moduli: Vec<u64>) -> Self {
41        debug_assert!(
42            Self::pairwise_coprime(&moduli),
43            "RNS channels must be pairwise coprime"
44        );
45        Channels(Arc::new(moduli))
46    }
47
48    /// The first `n` primes as channels — the standard configuration.
49    pub fn standard(n: usize) -> Self {
50        Channels(Arc::new(first_n_primes(n)))
51    }
52
53    fn pairwise_coprime(moduli: &[u64]) -> bool {
54        for i in 0..moduli.len() {
55            for j in (i + 1)..moduli.len() {
56                if gcd(moduli[i], moduli[j]) != 1 {
57                    return false;
58                }
59            }
60        }
61        true
62    }
63
64    /// Number of channels `k`.
65    #[inline]
66    pub fn len(&self) -> usize {
67        self.0.len()
68    }
69
70    /// Whether there are no channels.
71    #[inline]
72    pub fn is_empty(&self) -> bool {
73        self.0.is_empty()
74    }
75
76    /// The modulus of channel `c`.
77    #[inline]
78    pub fn modulus(&self, c: usize) -> u64 {
79        self.0[c]
80    }
81
82    /// The moduli as a slice.
83    #[inline]
84    pub fn moduli(&self) -> &[u64] {
85        &self.0
86    }
87
88    /// Total dynamic range `M = ∏ moduli`.
89    pub fn capacity(&self) -> BigUint {
90        self.0.iter().map(|&m| BigUint::from(m)).product()
91    }
92
93    /// Signed range bound `⌊M/2⌋`: values in `(-bound, bound]` are representable.
94    pub fn signed_capacity(&self) -> BigInt {
95        BigInt::from(self.capacity() / BigUint::from(2u8))
96    }
97}
98
99impl PartialEq for Channels {
100    fn eq(&self, other: &Self) -> bool {
101        Arc::ptr_eq(&self.0, &other.0) || self.0 == other.0
102    }
103}
104impl Eq for Channels {}
105
106/// An exact integer in RNS form (Level 0 of the tower).
107#[derive(Clone, Debug)]
108pub struct RnsInt {
109    /// `residues[i] = value mod channels[i]`, stored in `[0, m)`.
110    pub residues: Vec<u64>,
111    pub channels: Channels,
112    /// Sign hint: `true` when the represented (symmetric) value is negative.
113    pub negative: bool,
114}
115
116impl RnsInt {
117    /// Construct from an arbitrary `BigInt`.
118    pub fn from_bigint(n: &BigInt, channels: Channels) -> Self {
119        let negative = n.sign() == Sign::Minus;
120        let residues = channels
121            .moduli()
122            .iter()
123            .map(|&m| {
124                let mm = BigInt::from(m);
125                // Euclidean remainder in [0, m).
126                let r = ((n % &mm) + &mm) % &mm;
127                r.to_biguint().unwrap().try_into().unwrap()
128            })
129            .collect();
130        RnsInt {
131            residues,
132            channels,
133            negative,
134        }
135    }
136
137    /// Construct from a machine integer.
138    pub fn from_i64(n: i64, channels: Channels) -> Self {
139        Self::from_bigint(&BigInt::from(n), channels)
140    }
141
142    /// Additive identity.
143    pub fn zero(channels: Channels) -> Self {
144        RnsInt {
145            residues: vec![0; channels.len()],
146            channels,
147            negative: false,
148        }
149    }
150
151    /// Build directly from raw channel residues, recomputing the sign hint.
152    /// The residues must already be reduced into `[0, m)` for each channel.
153    pub fn from_residues(residues: Vec<u64>, channels: Channels) -> Self {
154        Self::finish(residues, channels)
155    }
156
157    /// Reconstruct the exact signed value via Garner CRT + symmetric folding.
158    pub fn to_bigint(&self) -> BigInt {
159        let u = garner_crt(&self.residues, self.channels.moduli());
160        let m = self.channels.capacity();
161        // Symmetric range: if 2u > M, the value is u - M (negative).
162        if &u * 2u8 > m {
163            BigInt::from_biguint(Sign::Plus, u) - BigInt::from_biguint(Sign::Plus, m)
164        } else {
165            BigInt::from_biguint(Sign::Plus, u)
166        }
167    }
168
169    /// `true` iff every residue is zero.
170    pub fn is_zero(&self) -> bool {
171        self.residues.iter().all(|&r| r == 0)
172    }
173
174    /// Channel-wise modular addition.
175    pub fn add(&self, other: &Self) -> Self {
176        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
177            (a + b) % m
178        });
179        Self::finish(out, self.channels.clone())
180    }
181
182    /// Channel-wise modular subtraction.
183    pub fn sub(&self, other: &Self) -> Self {
184        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
185            (a + m - b) % m
186        });
187        Self::finish(out, self.channels.clone())
188    }
189
190    /// Channel-wise modular multiplication (uses `u128` intermediate).
191    pub fn mul(&self, other: &Self) -> Self {
192        let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
193            gpu_mul_channel(a, b, m)
194        });
195        Self::finish(out, self.channels.clone())
196    }
197
198    /// Additive inverse.
199    pub fn neg(&self) -> Self {
200        Self::zero(self.channels.clone()).sub(self)
201    }
202
203    /// Build from raw residues and recompute the sign hint.
204    fn finish(residues: Vec<u64>, channels: Channels) -> Self {
205        let mut v = RnsInt {
206            residues,
207            channels,
208            negative: false,
209        };
210        v.negative = v.to_bigint().sign() == Sign::Minus;
211        v
212    }
213}
214
215/// Apply `f(a, b, m)` channel-wise, parallelizing only above the threshold.
216fn channel_map(
217    a: &[u64],
218    b: &[u64],
219    moduli: &[u64],
220    f: impl Fn(u64, u64, u64) -> u64 + Sync + Send,
221) -> Vec<u64> {
222    use rayon::prelude::*;
223    if a.len() >= RAYON_CHANNEL_THRESHOLD {
224        a.par_iter()
225            .zip(b.par_iter())
226            .zip(moduli.par_iter())
227            .map(|((&av, &bv), &m)| f(av, bv, m))
228            .collect()
229    } else {
230        a.iter()
231            .zip(b.iter())
232            .zip(moduli.iter())
233            .map(|((&av, &bv), &m)| f(av, bv, m))
234            .collect()
235    }
236}
237
238/// Garner's algorithm: reconstruct the unsigned integer in `[0, M)` from its
239/// residues. Never materializes the large basis elements `M/m_i`; all
240/// intermediates stay small (each fits within its own modulus).
241pub fn garner_crt(residues: &[u64], moduli: &[u64]) -> BigUint {
242    let k = residues.len();
243    assert_eq!(k, moduli.len(), "residue/moduli length mismatch");
244    if k == 0 {
245        return BigUint::zero();
246    }
247
248    // Step 1 — mixed-radix coefficients via forward substitution.
249    let mut c: Vec<u64> = residues.to_vec();
250    for i in 0..k {
251        for j in 0..i {
252            let mi = moduli[i];
253            // c[i] = (c[i] - c[j]) * inv(m[j], m[i])  (mod m[i])
254            let inv = mod_inverse(moduli[j] % mi, mi)
255                .expect("channels must be pairwise coprime for CRT");
256            let diff = (c[i] + mi - (c[j] % mi)) % mi;
257            c[i] = ((diff as u128 * inv as u128) % mi as u128) as u64;
258        }
259    }
260
261    // Step 2 — Horner reconstruction into a single BigUint.
262    let mut result = BigUint::from(c[k - 1]);
263    for i in (0..k - 1).rev() {
264        result = result * BigUint::from(moduli[i]) + BigUint::from(c[i]);
265    }
266    result
267}
268
/// Reference implementation of one GPU thread's add: `(a + b) % m`.
///
/// The sum is formed in `u128`, as [`gpu_mul_channel`] does for products, so
/// the result is exact for every `u64` input. A plain `u64` `a + b` overflows
/// once the operands approach `2^63`.
///
/// # Panics
///
/// Panics if `m == 0`.
#[inline]
pub fn gpu_add_channel(a: u64, b: u64, m: u64) -> u64 {
    ((u128::from(a) + u128::from(b)) % u128::from(m)) as u64
}
274
/// Reference implementation of one GPU thread's multiply: `(a * b) % m`.
///
/// The product is widened to `u128`, which holds any product of two `u64`
/// values exactly. Reducing modulo a `u64` brings it back into `u64` range.
///
/// # Panics
///
/// Panics if `m == 0`.
#[inline]
pub fn gpu_mul_channel(a: u64, b: u64, m: u64) -> u64 {
    let wide_product = u128::from(a) * u128::from(b);
    let reduced = wide_product % u128::from(m);
    reduced as u64
}
280
/// Largest modulus below `2^16` (`u16::MAX`). With `m <= 65535`, every
/// residue is at most `65534`, so a product of two residues is at most
/// `65534^2 < 2^32` and fits in a `u32`.
pub const MAX_SAFE_MODULUS: u64 = u16::MAX as u64;
283
#[cfg(test)]
mod tests {
    use super::*;

    /// The standard 16-prime channel set, ample for every value used here.
    fn ch() -> Channels {
        Channels::standard(16)
    }

    #[test]
    fn roundtrip_positive() {
        let value: i64 = 123_456_789;
        let encoded = RnsInt::from_i64(value, ch());
        assert_eq!(encoded.to_bigint(), BigInt::from(value));
    }

    #[test]
    fn roundtrip_negative() {
        let encoded = RnsInt::from_i64(-42, ch());
        assert!(encoded.negative);
        assert_eq!(encoded.to_bigint(), BigInt::from(-42));
    }

    #[test]
    fn add_sub_mul() {
        let a = RnsInt::from_i64(1000, ch());
        let b = RnsInt::from_i64(337, ch());
        let cases = [
            (a.add(&b), 1337),
            (a.sub(&b), 663),
            (b.sub(&a), -663),
            (a.mul(&b), 337_000),
        ];
        for (got, want) in &cases {
            assert_eq!(got.to_bigint(), BigInt::from(*want));
        }
    }

    #[test]
    fn garner_classic() {
        // x ≡ 2 (3), 3 (5), 2 (7)  =>  x = 23
        let x = garner_crt(&[2, 3, 2], &[3, 5, 7]);
        assert_eq!(x, BigUint::from(23u8));
        // x ≡ 0 (2), 1 (3), 0 (5)  =>  x = 10
        let y = garner_crt(&[0, 1, 0], &[2, 3, 5]);
        assert_eq!(y, BigUint::from(10u8));
    }

    #[test]
    fn is_zero_works() {
        let zero = RnsInt::zero(ch());
        let one = RnsInt::from_i64(1, ch());
        assert!(zero.is_zero());
        assert!(!one.is_zero());
    }
}