1use std::sync::Arc;
22
23use num_bigint::{BigInt, BigUint, Sign};
24use num_traits::Zero;
25
26use crate::primes::{first_n_primes, gcd, mod_inverse};
27use crate::RAYON_CHANNEL_THRESHOLD;
28
/// An ordered list of RNS channel moduli.
///
/// The moduli sit behind an [`Arc`], so cloning a `Channels` shares the
/// underlying vector instead of copying it.
#[derive(Debug, Clone)]
pub struct Channels(pub Arc<Vec<u64>>);
34
35impl Channels {
36 pub fn new(moduli: Vec<u64>) -> Self {
41 debug_assert!(
42 Self::pairwise_coprime(&moduli),
43 "RNS channels must be pairwise coprime"
44 );
45 Channels(Arc::new(moduli))
46 }
47
48 pub fn standard(n: usize) -> Self {
50 Channels(Arc::new(first_n_primes(n)))
51 }
52
53 fn pairwise_coprime(moduli: &[u64]) -> bool {
54 for i in 0..moduli.len() {
55 for j in (i + 1)..moduli.len() {
56 if gcd(moduli[i], moduli[j]) != 1 {
57 return false;
58 }
59 }
60 }
61 true
62 }
63
64 #[inline]
66 pub fn len(&self) -> usize {
67 self.0.len()
68 }
69
70 #[inline]
72 pub fn is_empty(&self) -> bool {
73 self.0.is_empty()
74 }
75
76 #[inline]
78 pub fn modulus(&self, c: usize) -> u64 {
79 self.0[c]
80 }
81
82 #[inline]
84 pub fn moduli(&self) -> &[u64] {
85 &self.0
86 }
87
88 pub fn capacity(&self) -> BigUint {
90 self.0.iter().map(|&m| BigUint::from(m)).product()
91 }
92
93 pub fn signed_capacity(&self) -> BigInt {
95 BigInt::from(self.capacity() / BigUint::from(2u8))
96 }
97}
98
99impl PartialEq for Channels {
100 fn eq(&self, other: &Self) -> bool {
101 Arc::ptr_eq(&self.0, &other.0) || self.0 == other.0
102 }
103}
104impl Eq for Channels {}
105
106#[derive(Clone, Debug)]
108pub struct RnsInt {
109 pub residues: Vec<u64>,
111 pub channels: Channels,
112 pub negative: bool,
114}
115
116impl RnsInt {
117 pub fn from_bigint(n: &BigInt, channels: Channels) -> Self {
119 let negative = n.sign() == Sign::Minus;
120 let residues = channels
121 .moduli()
122 .iter()
123 .map(|&m| {
124 let mm = BigInt::from(m);
125 let r = ((n % &mm) + &mm) % &mm;
127 r.to_biguint().unwrap().try_into().unwrap()
128 })
129 .collect();
130 RnsInt {
131 residues,
132 channels,
133 negative,
134 }
135 }
136
137 pub fn from_i64(n: i64, channels: Channels) -> Self {
139 Self::from_bigint(&BigInt::from(n), channels)
140 }
141
142 pub fn zero(channels: Channels) -> Self {
144 RnsInt {
145 residues: vec![0; channels.len()],
146 channels,
147 negative: false,
148 }
149 }
150
151 pub fn from_residues(residues: Vec<u64>, channels: Channels) -> Self {
154 Self::finish(residues, channels)
155 }
156
157 pub fn to_bigint(&self) -> BigInt {
159 let u = garner_crt(&self.residues, self.channels.moduli());
160 let m = self.channels.capacity();
161 if &u * 2u8 > m {
163 BigInt::from_biguint(Sign::Plus, u) - BigInt::from_biguint(Sign::Plus, m)
164 } else {
165 BigInt::from_biguint(Sign::Plus, u)
166 }
167 }
168
169 pub fn is_zero(&self) -> bool {
171 self.residues.iter().all(|&r| r == 0)
172 }
173
174 pub fn add(&self, other: &Self) -> Self {
176 let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
177 (a + b) % m
178 });
179 Self::finish(out, self.channels.clone())
180 }
181
182 pub fn sub(&self, other: &Self) -> Self {
184 let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
185 (a + m - b) % m
186 });
187 Self::finish(out, self.channels.clone())
188 }
189
190 pub fn mul(&self, other: &Self) -> Self {
192 let out = channel_map(&self.residues, &other.residues, self.channels.moduli(), |a, b, m| {
193 gpu_mul_channel(a, b, m)
194 });
195 Self::finish(out, self.channels.clone())
196 }
197
198 pub fn neg(&self) -> Self {
200 Self::zero(self.channels.clone()).sub(self)
201 }
202
203 fn finish(residues: Vec<u64>, channels: Channels) -> Self {
205 let mut v = RnsInt {
206 residues,
207 channels,
208 negative: false,
209 };
210 v.negative = v.to_bigint().sign() == Sign::Minus;
211 v
212 }
213}
214
215fn channel_map(
217 a: &[u64],
218 b: &[u64],
219 moduli: &[u64],
220 f: impl Fn(u64, u64, u64) -> u64 + Sync + Send,
221) -> Vec<u64> {
222 use rayon::prelude::*;
223 if a.len() >= RAYON_CHANNEL_THRESHOLD {
224 a.par_iter()
225 .zip(b.par_iter())
226 .zip(moduli.par_iter())
227 .map(|((&av, &bv), &m)| f(av, bv, m))
228 .collect()
229 } else {
230 a.iter()
231 .zip(b.iter())
232 .zip(moduli.iter())
233 .map(|((&av, &bv), &m)| f(av, bv, m))
234 .collect()
235 }
236}
237
238pub fn garner_crt(residues: &[u64], moduli: &[u64]) -> BigUint {
242 let k = residues.len();
243 assert_eq!(k, moduli.len(), "residue/moduli length mismatch");
244 if k == 0 {
245 return BigUint::zero();
246 }
247
248 let mut c: Vec<u64> = residues.to_vec();
250 for i in 0..k {
251 for j in 0..i {
252 let mi = moduli[i];
253 let inv = mod_inverse(moduli[j] % mi, mi)
255 .expect("channels must be pairwise coprime for CRT");
256 let diff = (c[i] + mi - (c[j] % mi)) % mi;
257 c[i] = ((diff as u128 * inv as u128) % mi as u128) as u64;
258 }
259 }
260
261 let mut result = BigUint::from(c[k - 1]);
263 for i in (0..k - 1).rev() {
264 result = result * BigUint::from(moduli[i]) + BigUint::from(c[i]);
265 }
266 result
267}
268
/// Adds two residues within one channel: `(a + b) mod m`.
///
/// The sum is formed in 128 bits, so it cannot wrap even when `a`, `b` and
/// `m` are close to `u64::MAX`. Inputs do not need to be reduced beforehand.
///
/// # Panics
///
/// Panics if `m` is zero.
#[inline]
pub fn gpu_add_channel(a: u64, b: u64, m: u64) -> u64 {
    let sum = u128::from(a) + u128::from(b);
    // The remainder is below `m`, so narrowing back to u64 is lossless.
    (sum % u128::from(m)) as u64
}
274
/// Multiplies two residues within one channel: `(a * b) mod m`.
///
/// The product is formed in 128 bits before reduction, so it cannot wrap.
///
/// # Panics
///
/// Panics if `m` is zero.
#[inline]
pub fn gpu_mul_channel(a: u64, b: u64, m: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    let reduced = product % u128::from(m);
    // `reduced` is below `m`, so it always fits back into a u64.
    reduced as u64
}
280
/// Largest modulus treated as "safe"; equal to `2^16 - 1`.
///
/// NOTE(review): presumably the bound that keeps the product of two residues
/// within 32 bits on the GPU path — confirm against the kernel code.
pub const MAX_SAFE_MODULUS: u64 = (1 << 16) - 1;
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Sixteen-channel set used by every test below.
    fn ch() -> Channels {
        Channels::standard(16)
    }

    /// Encoding then decoding a positive value returns it unchanged.
    #[test]
    fn roundtrip_positive() {
        let value: i64 = 123_456_789;
        let encoded = RnsInt::from_i64(value, ch());
        assert_eq!(encoded.to_bigint(), BigInt::from(value));
    }

    /// A negative value keeps both its sign flag and its magnitude.
    #[test]
    fn roundtrip_negative() {
        let value: i64 = -42;
        let encoded = RnsInt::from_i64(value, ch());
        assert!(encoded.negative);
        assert_eq!(encoded.to_bigint(), BigInt::from(value));
    }

    /// Basic arithmetic, including a subtraction that goes negative.
    #[test]
    fn add_sub_mul() {
        let lhs = RnsInt::from_i64(1000, ch());
        let rhs = RnsInt::from_i64(337, ch());
        let expect = |got: RnsInt, want: i64| {
            assert_eq!(got.to_bigint(), BigInt::from(want));
        };
        expect(lhs.add(&rhs), 1337);
        expect(lhs.sub(&rhs), 663);
        expect(rhs.sub(&lhs), -663);
        expect(lhs.mul(&rhs), 337_000);
    }

    /// Two small systems with known solutions.
    #[test]
    fn garner_classic() {
        let first = garner_crt(&[2, 3, 2], &[3, 5, 7]);
        assert_eq!(first, BigUint::from(23u8));

        let second = garner_crt(&[0, 1, 0], &[2, 3, 5]);
        assert_eq!(second, BigUint::from(10u8));
    }

    /// `is_zero` is true for zero and false for one.
    #[test]
    fn is_zero_works() {
        let zero = RnsInt::zero(ch());
        assert!(zero.is_zero());

        let one = RnsInt::from_i64(1, ch());
        assert!(!one.is_zero());
    }
}