1#![warn(missing_docs, unused_imports)]
2
3use crate::{zq::Modulus, Error, Result};
6use itertools::{izip, Itertools};
7use ndarray::ArrayView1;
8use num_bigint::BigUint;
9use num_bigint_dig::{BigInt as BigIntDig, BigUint as BigUintDig, ExtendedGcd, ModInverse};
10use num_traits::{cast::ToPrimitive, One, Zero};
11use std::{cmp::Ordering, fmt::Debug};
12
13mod scaler;
14
15pub use scaler::{RnsScaler, ScalingFactor};
16
17#[derive(Default, Clone, PartialEq, Eq)]
19pub struct RnsContext {
20 moduli_u64: Vec<u64>,
21 moduli: Vec<Modulus>,
22 q_tilde: Vec<u64>,
23 q_tilde_shoup: Vec<u64>,
24 q_star: Vec<BigUint>,
25 garner: Vec<BigUint>,
26 product: BigUint,
27}
28
29impl Debug for RnsContext {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("RnsContext")
32 .field("moduli_u64", &self.moduli_u64)
33 .field("product", &self.product)
39 .finish()
40 }
41}
42
43impl RnsContext {
44 pub fn new(moduli_u64: &[u64]) -> Result<Self> {
48 if moduli_u64.is_empty() {
49 Err(Error::Default("The list of moduli is empty".to_string()))
50 } else {
51 let mut product = BigUint::one();
52 let mut product_dig = BigUintDig::one();
53
54 for i in 0..moduli_u64.len() {
55 for j in 0..moduli_u64.len() {
57 if i != j {
58 let (d, _, _) = BigUintDig::from(moduli_u64[i])
59 .extended_gcd(&BigUintDig::from(moduli_u64[j]));
60 if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
61 return Err(Error::Default("The moduli are not coprime".to_string()));
62 }
63 }
64 }
65
66 product *= &BigUint::from(moduli_u64[i]);
67 product_dig *= &BigUintDig::from(moduli_u64[i]);
68 }
69
70 #[allow(clippy::type_complexity)]
71 let (moduli, q_tilde, q_tilde_shoup, q_star, garner): (
72 Vec<Modulus>,
73 Vec<u64>,
74 Vec<u64>,
75 Vec<BigUint>,
76 Vec<BigUint>,
77 ) = moduli_u64
78 .iter()
79 .map(|modulus| {
80 let m = Modulus::new(*modulus)?;
81 let q_star_i = &product / modulus;
82 let q_tilde_i = (&product_dig / modulus)
83 .mod_inverse(&BigUintDig::from(*modulus))
84 .unwrap()
85 .to_u64()
86 .unwrap();
87 let garner_i = &q_star_i * q_tilde_i;
88 let q_tilde_shoup_i = m.shoup(q_tilde_i);
89 Ok((m, q_tilde_i, q_tilde_shoup_i, q_star_i, garner_i))
90 })
91 .collect::<Result<Vec<_>>>()?
92 .into_iter()
93 .multiunzip();
94
95 Ok(Self {
96 moduli_u64: moduli_u64.to_owned(),
97 moduli,
98 q_tilde,
99 q_tilde_shoup,
100 q_star,
101 garner,
102 product,
103 })
104 }
105 }
106
107 pub const fn modulus(&self) -> &BigUint {
109 &self.product
110 }
111
112 pub fn project(&self, a: &BigUint) -> Vec<u64> {
114 self.moduli_u64
115 .iter()
116 .map(|modulus| (a % modulus).to_u64().unwrap())
117 .collect()
118 }
119
120 pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
125 let mut result = BigUint::zero();
126 izip!(rests.iter(), self.garner.iter())
127 .for_each(|(r_i, garner_i)| result += garner_i * *r_i);
128 result % &self.product
129 }
130
131 pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
133 self.garner.get(i)
134 }
135}
136
#[cfg(test)]
mod tests {

    use std::error::Error;

    use super::RnsContext;
    use ndarray::ArrayView1;
    use num_bigint::BigUint;
    use rand::RngCore;

    #[test]
    fn constructor() {
        assert!(RnsContext::new(&[2]).is_ok());
        assert!(RnsContext::new(&[2, 3]).is_ok());
        assert!(RnsContext::new(&[4, 15, 1153]).is_ok());

        let e = RnsContext::new(&[]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
        let e = RnsContext::new(&[2, 4]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
        let e = RnsContext::new(&[2, 3, 5, 30]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
        // Repeated moduli share a non-trivial factor with themselves.
        let e = RnsContext::new(&[3, 3]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
    }

    #[test]
    fn garner() -> Result<(), Box<dyn Error>> {
        let moduli = [4u64, 15, 1153];
        let rns = RnsContext::new(&moduli)?;

        for i in 0..moduli.len() {
            let gi = rns.get_garner(i).expect("index is within the number of moduli");
            assert_eq!(gi, &rns.garner[i]);
            // Garner coefficients satisfy g_i = 1 (mod q_i) and g_i = 0 (mod q_j), j != i.
            for (j, qj) in moduli.iter().enumerate() {
                assert_eq!(gi % qj, BigUint::from(u64::from(i == j)));
            }
        }
        assert!(rns.get_garner(moduli.len()).is_none());

        Ok(())
    }

    #[test]
    fn modulus() -> Result<(), Box<dyn Error>> {
        // `assert_eq!` (not `debug_assert_eq!`) so the checks also run in release test builds.
        let mut rns = RnsContext::new(&[2])?;
        assert_eq!(rns.modulus(), &BigUint::from(2u64));

        rns = RnsContext::new(&[2, 5])?;
        assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));

        rns = RnsContext::new(&[4, 15, 1153])?;
        assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));

        Ok(())
    }

    #[test]
    fn project_lift() -> Result<(), Box<dyn Error>> {
        let ntests = 100;
        let rns = RnsContext::new(&[4, 15, 1153])?;
        let product = 4u64 * 15 * 1153;

        let mut rests = rns.project(&BigUint::from(0u64));
        assert_eq!(&rests, &[0u64, 0, 0]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));

        rests = rns.project(&BigUint::from(4u64));
        assert_eq!(&rests, &[0u64, 4, 4]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));

        rests = rns.project(&BigUint::from(15u64));
        assert_eq!(&rests, &[3u64, 0, 15]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));

        rests = rns.project(&BigUint::from(1153u64));
        assert_eq!(&rests, &[1u64, 13, 0]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));

        rests = rns.project(&BigUint::from(product - 1));
        assert_eq!(&rests, &[3u64, 14, 1152]);
        assert_eq!(
            rns.lift(ArrayView1::from(&rests)),
            BigUint::from(product - 1)
        );

        let mut rng = rand::rng();

        for _ in 0..ntests {
            let b = BigUint::from(rng.next_u64() % product);
            rests = rns.project(&b);
            assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
        }

        Ok(())
    }
}