fhe_math/rns/
mod.rs

1#![warn(missing_docs, unused_imports)]
2
3//! Residue-Number System operations.
4
5use crate::{zq::Modulus, Error, Result};
6use itertools::{izip, Itertools};
7use ndarray::ArrayView1;
8use num_bigint::BigUint;
9use num_bigint_dig::{BigInt as BigIntDig, BigUint as BigUintDig, ExtendedGcd, ModInverse};
10use num_traits::{cast::ToPrimitive, One, Zero};
11use std::{cmp::Ordering, fmt::Debug};
12
13mod scaler;
14
15pub use scaler::{RnsScaler, ScalingFactor};
16
17/// Context for a Residue Number System.
18#[derive(Default, Clone, PartialEq, Eq)]
19pub struct RnsContext {
20    moduli_u64: Vec<u64>,
21    moduli: Vec<Modulus>,
22    q_tilde: Vec<u64>,
23    q_tilde_shoup: Vec<u64>,
24    q_star: Vec<BigUint>,
25    garner: Vec<BigUint>,
26    product: BigUint,
27}
28
29impl Debug for RnsContext {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("RnsContext")
32            .field("moduli_u64", &self.moduli_u64)
33            // .field("moduli", &self.moduli)
34            // .field("q_tilde", &self.q_tilde)
35            // .field("q_tilde_shoup", &self.q_tilde_shoup)
36            // .field("q_star", &self.q_star)
37            // .field("garner", &self.garner)
38            .field("product", &self.product)
39            .finish()
40    }
41}
42
43impl RnsContext {
44    /// Create a RNS context from a list of moduli.
45    ///
46    /// Returns an error if the list is empty, or if the moduli are no coprime.
47    pub fn new(moduli_u64: &[u64]) -> Result<Self> {
48        if moduli_u64.is_empty() {
49            Err(Error::Default("The list of moduli is empty".to_string()))
50        } else {
51            let mut product = BigUint::one();
52            let mut product_dig = BigUintDig::one();
53
54            for i in 0..moduli_u64.len() {
55                // Return an error if the moduli are not coprime.
56                for j in 0..moduli_u64.len() {
57                    if i != j {
58                        let (d, _, _) = BigUintDig::from(moduli_u64[i])
59                            .extended_gcd(&BigUintDig::from(moduli_u64[j]));
60                        if d.cmp(&BigIntDig::from(1)) != Ordering::Equal {
61                            return Err(Error::Default("The moduli are not coprime".to_string()));
62                        }
63                    }
64                }
65
66                product *= &BigUint::from(moduli_u64[i]);
67                product_dig *= &BigUintDig::from(moduli_u64[i]);
68            }
69
70            #[allow(clippy::type_complexity)]
71            let (moduli, q_tilde, q_tilde_shoup, q_star, garner): (
72                Vec<Modulus>,
73                Vec<u64>,
74                Vec<u64>,
75                Vec<BigUint>,
76                Vec<BigUint>,
77            ) = moduli_u64
78                .iter()
79                .map(|modulus| {
80                    let m = Modulus::new(*modulus)?;
81                    let q_star_i = &product / modulus;
82                    let q_tilde_i = (&product_dig / modulus)
83                        .mod_inverse(&BigUintDig::from(*modulus))
84                        .unwrap()
85                        .to_u64()
86                        .unwrap();
87                    let garner_i = &q_star_i * q_tilde_i;
88                    let q_tilde_shoup_i = m.shoup(q_tilde_i);
89                    Ok((m, q_tilde_i, q_tilde_shoup_i, q_star_i, garner_i))
90                })
91                .collect::<Result<Vec<_>>>()?
92                .into_iter()
93                .multiunzip();
94
95            Ok(Self {
96                moduli_u64: moduli_u64.to_owned(),
97                moduli,
98                q_tilde,
99                q_tilde_shoup,
100                q_star,
101                garner,
102                product,
103            })
104        }
105    }
106
107    /// Returns the product of the moduli used when creating the RNS context.
108    pub const fn modulus(&self) -> &BigUint {
109        &self.product
110    }
111
112    /// Project a BigUint into its rests.
113    pub fn project(&self, a: &BigUint) -> Vec<u64> {
114        self.moduli_u64
115            .iter()
116            .map(|modulus| (a % modulus).to_u64().unwrap())
117            .collect()
118    }
119
120    /// Lift rests into a BigUint.
121    ///
122    /// Aborts if the number of rests is different than the number of moduli in
123    /// debug mode.
124    pub fn lift(&self, rests: ArrayView1<u64>) -> BigUint {
125        let mut result = BigUint::zero();
126        izip!(rests.iter(), self.garner.iter())
127            .for_each(|(r_i, garner_i)| result += garner_i * *r_i);
128        result % &self.product
129    }
130
131    /// Getter for the i-th garner coefficient.
132    pub fn get_garner(&self, i: usize) -> Option<&BigUint> {
133        self.garner.get(i)
134    }
135}
136
#[cfg(test)]
mod tests {

    use std::error::Error;

    use super::RnsContext;
    use ndarray::ArrayView1;
    use num_bigint::BigUint;
    use rand::RngCore;

    /// Construction succeeds on pairwise-coprime moduli and reports empty or
    /// non-coprime inputs with the expected messages.
    #[test]
    fn constructor() {
        assert!(RnsContext::new(&[2]).is_ok());
        assert!(RnsContext::new(&[2, 3]).is_ok());
        assert!(RnsContext::new(&[4, 15, 1153]).is_ok());

        let e = RnsContext::new(&[]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The list of moduli is empty");
        let e = RnsContext::new(&[2, 4]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
        let e = RnsContext::new(&[2, 3, 5, 30]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
        // A repeated modulus is never coprime with itself.
        let e = RnsContext::new(&[3, 5, 3]);
        assert!(e.is_err());
        assert_eq!(e.unwrap_err().to_string(), "The moduli are not coprime");
    }

    /// `get_garner` exposes the stored coefficients, and each coefficient
    /// satisfies the CRT property: it is 1 mod its own modulus and 0 mod the
    /// others.
    #[test]
    fn garner() -> Result<(), Box<dyn Error>> {
        let moduli = [4u64, 15, 1153];
        let rns = RnsContext::new(&moduli)?;

        for i in 0..moduli.len() {
            assert_eq!(rns.get_garner(i), Some(&rns.garner[i]));
        }
        assert!(rns.get_garner(moduli.len()).is_none());

        for (i, g) in rns.garner.iter().enumerate() {
            for (j, &q) in moduli.iter().enumerate() {
                assert_eq!(g % q, BigUint::from(u64::from(i == j)));
            }
        }

        Ok(())
    }

    /// `modulus()` returns the product of the moduli.
    #[test]
    fn modulus() -> Result<(), Box<dyn Error>> {
        // Plain `assert_eq!` so the checks also run under `cargo test --release`,
        // where `debug_assert_eq!` would be compiled out.
        let rns = RnsContext::new(&[2])?;
        assert_eq!(rns.modulus(), &BigUint::from(2u64));

        let rns = RnsContext::new(&[2, 5])?;
        assert_eq!(rns.modulus(), &BigUint::from(2u64 * 5));

        let rns = RnsContext::new(&[4, 15, 1153])?;
        assert_eq!(rns.modulus(), &BigUint::from(4u64 * 15 * 1153));

        Ok(())
    }

    /// `project` followed by `lift` is the identity on `[0, Q)`.
    #[test]
    fn project_lift() -> Result<(), Box<dyn Error>> {
        let ntests = 100;
        let rns = RnsContext::new(&[4, 15, 1153])?;
        let product = 4u64 * 15 * 1153;

        let mut rests = rns.project(&BigUint::from(0u64));
        assert_eq!(&rests, &[0u64, 0, 0]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(0u64));

        rests = rns.project(&BigUint::from(4u64));
        assert_eq!(&rests, &[0u64, 4, 4]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(4u64));

        rests = rns.project(&BigUint::from(15u64));
        assert_eq!(&rests, &[3u64, 0, 15]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(15u64));

        rests = rns.project(&BigUint::from(1153u64));
        assert_eq!(&rests, &[1u64, 13, 0]);
        assert_eq!(rns.lift(ArrayView1::from(&rests)), BigUint::from(1153u64));

        rests = rns.project(&BigUint::from(product - 1));
        assert_eq!(&rests, &[3u64, 14, 1152]);
        assert_eq!(
            rns.lift(ArrayView1::from(&rests)),
            BigUint::from(product - 1)
        );

        let mut rng = rand::rng();

        for _ in 0..ntests {
            let b = BigUint::from(rng.next_u64() % product);
            rests = rns.project(&b);
            assert_eq!(rns.lift(ArrayView1::from(&rests)), b);
        }

        Ok(())
    }
}