fhe_math/rns/
scaler.rs

1#![warn(missing_docs, unused_imports)]
2
3//! RNS scaler inspired from Remark 3.2 of <https://eprint.iacr.org/2021/204.pdf>.
4
5use super::RnsContext;
6use ethnum::{u256, U256};
7use itertools::{izip, Itertools};
8use ndarray::{ArrayView1, ArrayViewMut1};
9use num_bigint::BigUint;
10use num_traits::{One, ToPrimitive, Zero};
11use std::{cmp::min, sync::Arc};
12
13/// Scaling factor when performing a RNS scaling.
14#[derive(Default, Debug, Clone, PartialEq, Eq)]
15pub struct ScalingFactor {
16    numerator: BigUint,
17    denominator: BigUint,
18    pub(crate) is_one: bool,
19}
20
21impl ScalingFactor {
22    /// Create a new scaling factor. Aborts if the denominator is 0.
23    pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
24        assert_ne!(denominator, &BigUint::zero());
25        Self {
26            numerator: numerator.clone(),
27            denominator: denominator.clone(),
28            is_one: numerator == denominator,
29        }
30    }
31
32    /// Returns the identity element of `Self`.
33    pub fn one() -> Self {
34        Self {
35            numerator: BigUint::one(),
36            denominator: BigUint::one(),
37            is_one: true,
38        }
39    }
40}
41
/// Scaler for a RNS context.
/// This is a helper struct to perform RNS scaling.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct RnsScaler {
    // RNS context in which the input rests are expressed.
    from: Arc<RnsContext>,
    // RNS context in which the scaled output is expressed.
    to: Arc<RnsContext>,
    // Factor numerator / denominator applied by the scaler.
    scaling_factor: ScalingFactor,

    // gamma = round(numerator * from.product / denominator), projected on each
    // modulus of `to`, and the matching Shoup precomputations.
    gamma: Box<[u64]>,
    gamma_shoup: Box<[u64]>,
    // Rounding error of gamma, scaled by 2^127 and stored as
    // (-1)^sign * (lo + 2^64 * hi).
    theta_gamma_lo: u64,
    theta_gamma_hi: u64,
    theta_gamma_sign: bool,

    // omega[j][i] = round(from.garner_i * numerator / denominator) reduced
    // modulo the j-th modulus of `to`, and the matching Shoup precomputations.
    omega: Box<[Box<[u64]>]>,
    omega_shoup: Box<[Box<[u64]>]>,
    // Rounding errors of each omega_i, with the same encoding as theta_gamma.
    theta_omega_lo: Box<[u64]>,
    theta_omega_hi: Box<[u64]>,
    theta_omega_sign: Box<[bool]>,

    // round(from.garner_i * 2^theta_garner_shift / from.product), split into
    // low and high 64-bit words.
    theta_garner_lo: Box<[u64]>,
    theta_garner_hi: Box<[u64]>,
    theta_garner_shift: usize,
}
66
67impl RnsScaler {
68    /// Create a RNS scaler by numerator / denominator.
69    ///
70    /// Aborts if denominator is equal to 0.
71    pub fn new(
72        from: &Arc<RnsContext>,
73        to: &Arc<RnsContext>,
74        scaling_factor: ScalingFactor,
75    ) -> Self {
76        // Let's define gamma = round(numerator * from.product / denominator)
77        let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
78            Self::extract_projection_and_theta(
79                to,
80                &from.product,
81                &scaling_factor.numerator,
82                &scaling_factor.denominator,
83                false,
84            );
85        let gamma_shoup = izip!(&gamma, &to.moduli)
86            .map(|(wi, q)| q.shoup(*wi))
87            .collect_vec();
88
89        // Let's define omega_i = round(from.garner_i * numerator / denominator)
90        let mut omega = vec![vec![0u64; from.moduli.len()].into_boxed_slice(); to.moduli.len()];
91        let mut omega_shoup =
92            vec![vec![0u64; from.moduli.len()].into_boxed_slice(); to.moduli.len()];
93        let (omegas_i, theta_omega_lo, theta_omega_hi, theta_omega_sign): (
94            Vec<Vec<u64>>,
95            Vec<u64>,
96            Vec<u64>,
97            Vec<bool>,
98        ) = from
99            .garner
100            .iter()
101            .map(|garner_i| {
102                Self::extract_projection_and_theta(
103                    to,
104                    garner_i,
105                    &scaling_factor.numerator,
106                    &scaling_factor.denominator,
107                    true,
108                )
109            })
110            .multiunzip();
111
112        for (i, omega_i) in omegas_i.iter().enumerate() {
113            for j in 0..to.moduli.len() {
114                let qj = &to.moduli[j];
115                omega[j][i] = qj.reduce(omega_i[j]);
116                omega_shoup[j][i] = qj.shoup(omega[j][i]);
117            }
118        }
119
120        // Determine the shift so that the sum of the scaled theta_garner fit on an U192
121        // (shift + 1) + log(q * n) <= 192
122        let theta_garner_shift = min(
123            from.moduli_u64
124                .iter()
125                .map(|qi| {
126                    192 - 1
127                        - ((*qi as u128) * (from.moduli_u64.len() as u128))
128                            .next_power_of_two()
129                            .ilog2()
130                })
131                .min()
132                .unwrap(),
133            127,
134        );
135        // Finally, define theta_garner_i = from.garner_i / product, also scaled by
136        // 2^127.
137        let (theta_garner_lo, theta_garner_hi): (Vec<u64>, Vec<u64>) = from
138            .garner
139            .iter()
140            .map(|garner_i| {
141                let mut theta: BigUint =
142                    ((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
143                let theta_hi: BigUint = &theta >> 64;
144                theta -= &theta_hi << 64;
145                (theta.to_u64().unwrap(), theta_hi.to_u64().unwrap())
146            })
147            .unzip();
148
149        Self {
150            from: from.clone(),
151            to: to.clone(),
152            scaling_factor,
153            gamma: gamma.into_boxed_slice(),
154            gamma_shoup: gamma_shoup.into_boxed_slice(),
155            theta_gamma_lo,
156            theta_gamma_hi,
157            theta_gamma_sign,
158            omega: omega.into_boxed_slice(),
159            omega_shoup: omega_shoup.into_boxed_slice(),
160            theta_omega_lo: theta_omega_lo.into_boxed_slice(),
161            theta_omega_hi: theta_omega_hi.into_boxed_slice(),
162            theta_omega_sign: theta_omega_sign.into_boxed_slice(),
163            theta_garner_lo: theta_garner_lo.into_boxed_slice(),
164            theta_garner_hi: theta_garner_hi.into_boxed_slice(),
165            theta_garner_shift: theta_garner_shift as usize,
166        }
167    }
168
169    // Let's define gamma = round(numerator * input / denominator)
170    // and theta_gamma such that theta_gamma = numerator * input / denominator -
171    // gamma. This function projects gamma in the RNS context, and scales
172    // theta_gamma by 2**127 and rounds. It outputs the projection of gamma in the
173    // RNS context, and theta_lo, theta_hi, theta_sign such that theta_gamma =
174    // (-1)**theta_sign * (theta_lo + 2^64 * theta_hi).
175    fn extract_projection_and_theta(
176        ctx: &RnsContext,
177        input: &BigUint,
178        numerator: &BigUint,
179        denominator: &BigUint,
180        round_up: bool,
181    ) -> (Vec<u64>, u64, u64, bool) {
182        let gamma = (numerator * input + (denominator >> 1)) / denominator;
183        let projected = ctx.project(&gamma);
184
185        let mut theta = (numerator * input) % denominator;
186        let mut theta_sign = false;
187        if denominator > &BigUint::one() {
188            // If denominator is odd, flip theta if theta > (denominator >> 1)
189            if denominator & BigUint::one() == BigUint::one() {
190                if theta > (denominator >> 1) {
191                    theta_sign = true;
192                    theta = denominator - theta;
193                }
194            } else {
195                // denominator is even, flip if theta >= (denominator >> 1)
196                if theta >= (denominator >> 1) {
197                    theta_sign = true;
198                    theta = denominator - theta;
199                }
200            }
201        }
202        // theta = ((theta << 127) + (denominator >> 1)) / denominator;
203        // We can now split theta into two u64 words.
204        if round_up {
205            if theta_sign {
206                theta = (theta << 127) / denominator;
207            } else {
208                theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
209            }
210        } else if theta_sign {
211            theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
212        } else {
213            theta = (theta << 127) / denominator;
214        }
215        let theta_hi_biguint: BigUint = &theta >> 64;
216        theta -= &theta_hi_biguint << 64;
217        let theta_lo = theta.to_u64().unwrap();
218        let theta_hi = theta_hi_biguint.to_u64().unwrap();
219
220        (projected, theta_lo, theta_hi, theta_sign)
221    }
222
223    /// Output the RNS representation of the rests scaled by numerator *
224    /// denominator, and either rounded or floored.
225    ///
226    /// Aborts if the number of rests is different than the number of moduli in
227    /// debug mode, or if the size is not in [1, ..., rests.len()].
228    pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
229        let mut out = vec![0; size];
230        self.scale(rests, (&mut out).into(), 0);
231        out
232    }
233
234    /// Compute the RNS representation of the rests scaled by numerator *
235    /// denominator, and either rounded or floored, and store the result in
236    /// `out`.
237    ///
238    /// Aborts if the number of rests is different than the number of moduli in
239    /// debug mode, or if the size of out is not in [1, ..., rests.len()].
240    pub fn scale(
241        &self,
242        rests: ArrayView1<u64>,
243        mut out: ArrayViewMut1<u64>,
244        starting_index: usize,
245    ) {
246        debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
247        debug_assert!(!out.is_empty());
248        debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
249
250        // First, let's compute the inner product of the rests with theta_omega.
251        let mut sum_theta_garner = u256::ZERO;
252        for (thetag_lo, thetag_hi, ri) in izip!(
253            self.theta_garner_lo.iter(),
254            self.theta_garner_hi.iter(),
255            rests
256        ) {
257            sum_theta_garner = sum_theta_garner.wrapping_add(
258                U256::from(*ri) * U256::from((*thetag_lo as u128) | ((*thetag_hi as u128) << 64)),
259            );
260        }
261        // Let's compute v = round(sum_theta_garner / 2^theta_garner_shift)
262        sum_theta_garner >>= self.theta_garner_shift - 1;
263        let v = sum_theta_garner.as_u128().div_ceil(2);
264
265        // If the scaling factor is not 1, compute the inner product with the
266        // theta_omega
267        let mut w_sign = false;
268        let mut w = 0u128;
269        if !self.scaling_factor.is_one {
270            let mut sum_theta_omega = u256::ZERO;
271            for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
272                self.theta_omega_lo.iter(),
273                self.theta_omega_hi.iter(),
274                self.theta_omega_sign.iter(),
275                rests
276            ) {
277                let product = U256::from(*ri)
278                    * U256::from((*thetao_lo as u128) | ((*thetao_hi as u128) << 64));
279                if *thetao_sign {
280                    sum_theta_omega = sum_theta_omega.wrapping_sub(product);
281                } else {
282                    sum_theta_omega = sum_theta_omega.wrapping_add(product);
283                }
284            }
285
286            // Let's subtract v * theta_gamma to sum_theta_omega.
287            let v_theta_gamma = U256::from(v)
288                * U256::from((self.theta_gamma_lo as u128) | ((self.theta_gamma_hi as u128) << 64));
289            if self.theta_gamma_sign {
290                sum_theta_omega = sum_theta_omega.wrapping_add(v_theta_gamma);
291            } else {
292                sum_theta_omega = sum_theta_omega.wrapping_sub(v_theta_gamma);
293            }
294
295            // Let's compute w = round(sum_theta_omega / 2^(192)).
296            w_sign = (sum_theta_omega >> (63 + 128)) > u256::ZERO;
297
298            if w_sign {
299                w = ((!sum_theta_omega) >> 126isize).as_u128() + 1;
300                w /= 2;
301            } else {
302                w = (sum_theta_omega >> 126isize).as_u128();
303                w = w.div_ceil(2)
304            }
305        }
306
307        unsafe {
308            for i in 0..out.len() {
309                debug_assert!(starting_index + i < self.to.moduli.len());
310                debug_assert!(starting_index + i < self.omega.len());
311                debug_assert!(starting_index + i < self.omega_shoup.len());
312                debug_assert!(starting_index + i < self.gamma.len());
313                debug_assert!(starting_index + i < self.gamma_shoup.len());
314                let out_i = out.get_mut(i).unwrap();
315                let qi = self.to.moduli.get_unchecked(starting_index + i);
316                let omega_i = self.omega.get_unchecked(starting_index + i);
317                let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
318                let gamma_i = self.gamma.get_unchecked(starting_index + i);
319                let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
320
321                let mut yi = (**qi * 2
322                    - qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
323                    as u128;
324
325                if !self.scaling_factor.is_one {
326                    let wi = qi.lazy_reduce_u128(w);
327                    yi += if w_sign { **qi * 2 - wi } else { wi } as u128;
328                }
329
330                debug_assert!(rests.len() <= omega_i.len());
331                debug_assert!(rests.len() <= omega_shoup_i.len());
332                for j in 0..rests.len() {
333                    yi += qi.lazy_mul_shoup(
334                        *rests.get(j).unwrap(),
335                        *omega_i.get_unchecked(j),
336                        *omega_shoup_i.get_unchecked(j),
337                    ) as u128;
338                }
339
340                *out_i = qi.reduce_u128(yi)
341            }
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use std::{error::Error, panic::catch_unwind, sync::Arc};

    use super::RnsScaler;
    use crate::rns::{scaler::ScalingFactor, RnsContext};
    use ndarray::ArrayView1;
    use num_bigint::BigUint;
    use num_traits::{ToPrimitive, Zero};
    use rand::{rng, RngCore};

    /// Numerators exercised by the scaling tests.
    const NUMERATORS: [u64; 6] = [1, 2, 3, 100, 1000, 4611686018326724610];
    /// Denominators exercised by the scaling tests.
    const DENOMINATORS: [u64; 9] = [1, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610];

    /// Draws one random rest per modulus of `ctx`.
    fn random_rests(ctx: &RnsContext, generator: &mut impl RngCore) -> Vec<u64> {
        ctx.moduli_u64
            .iter()
            .map(|qi| generator.next_u64() % qi)
            .collect()
    }

    /// Lifts `rests` from `ctx` and returns the magnitude of the centered
    /// representative together with its sign (`true` when negative).
    fn centered_lift(ctx: &RnsContext, rests: &[u64]) -> (BigUint, bool) {
        let lifted = ctx.lift(ArrayView1::from(rests));
        if lifted >= (ctx.modulus() >> 1) {
            (ctx.modulus() - lifted, true)
        } else {
            (lifted, false)
        }
    }

    /// Reference value of `round(x * n / d)` modulo `modulus`, where `x` has
    /// magnitude `x_abs` and is negative when `x_negative` is set.
    fn expected_scaled(
        x_abs: &BigUint,
        x_negative: bool,
        n: &BigUint,
        d: &BigUint,
        modulus: &BigUint,
    ) -> BigUint {
        if !x_negative {
            return &(x_abs * n + (d >> 1)) / d;
        }
        // Ties round towards +infinity, so for a negative value with an even
        // denominator the magnitude uses one less than half as offset.
        let offset = if d.to_u64().unwrap() % 2 == 0 {
            (d >> 1usize) - 1u64
        } else {
            d >> 1
        };
        modulus - (&(x_abs * n + offset) / d) % modulus
    }

    #[test]
    fn constructor() -> Result<(), Box<dyn Error>> {
        let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);

        let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
        assert_eq!(scaler.from, q);

        let zero_denominator =
            catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero()));
        assert!(zero_denominator.is_err());
        Ok(())
    }

    #[test]
    fn scale_same_context() -> Result<(), Box<dyn Error>> {
        let ntests = 1000;
        let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
        let mut generator = rng();

        for numerator in NUMERATORS {
            for denominator in DENOMINATORS {
                let (n, d) = (BigUint::from(numerator), BigUint::from(denominator));
                let scaler = RnsScaler::new(&q, &q, ScalingFactor::new(&n, &d));

                for _ in 0..ntests {
                    let x = random_rests(&q, &mut generator);
                    let (x_abs, x_negative) = centered_lift(&q, &x);

                    let z = scaler.scale_new((&x).into(), x.len());
                    let expected = expected_scaled(&x_abs, x_negative, &n, &d, q.modulus());
                    assert_eq!(z, q.project(&expected));
                }
            }
        }
        Ok(())
    }

    #[test]
    fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
        let ntests = 100;
        let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
        let r = Arc::new(RnsContext::new(&[
            4u64,
            4611686018326724609,
            1153,
            4611686018309947393,
            4611686018282684417,
            4611686018257518593,
            4611686018232352769,
            4611686018171535361,
            4611686018106523649,
            4611686018058289153,
        ])?);
        let mut generator = rng();

        for numerator in NUMERATORS {
            for denominator in DENOMINATORS {
                let (n, d) = (BigUint::from(numerator), BigUint::from(denominator));
                let scaler = RnsScaler::new(&q, &r, ScalingFactor::new(&n, &d));

                for _ in 0..ntests {
                    let x = random_rests(&q, &mut generator);
                    let (x_abs, x_negative) = centered_lift(&q, &x);

                    let y = scaler.scale_new((&x).into(), r.moduli.len());
                    let expected = expected_scaled(&x_abs, x_negative, &n, &d, r.modulus());
                    assert_eq!(y, r.project(&expected));
                }
            }
        }
        Ok(())
    }
}