1#![warn(missing_docs, unused_imports)]
2
3use super::RnsContext;
6use ethnum::{u256, U256};
7use itertools::{izip, Itertools};
8use ndarray::{ArrayView1, ArrayViewMut1};
9use num_bigint::BigUint;
10use num_traits::{One, ToPrimitive, Zero};
11use std::{cmp::min, sync::Arc};
12
/// A rational scaling factor `numerator / denominator` applied by [`RnsScaler`].
///
/// NOTE(review): the derived `Default` produces a zero numerator and a zero
/// denominator (BigUint's default) with `is_one == false`, a value that
/// [`ScalingFactor::new`] would reject; prefer `new` or `one`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ScalingFactor {
    /// Numerator of the factor.
    numerator: BigUint,
    /// Denominator of the factor; non-zero when built through `new`.
    denominator: BigUint,
    /// Cached `numerator == denominator`, i.e. whether the factor equals one,
    /// which lets the scaler skip the rounding-error correction.
    pub(crate) is_one: bool,
}
20
21impl ScalingFactor {
22 pub fn new(numerator: &BigUint, denominator: &BigUint) -> Self {
24 assert_ne!(denominator, &BigUint::zero());
25 Self {
26 numerator: numerator.clone(),
27 denominator: denominator.clone(),
28 is_one: numerator == denominator,
29 }
30 }
31
32 pub fn one() -> Self {
34 Self {
35 numerator: BigUint::one(),
36 denominator: BigUint::one(),
37 is_one: true,
38 }
39 }
40}
41
/// Precomputed constants to scale a value given by its residues in one RNS
/// context (`from`) by a [`ScalingFactor`], with rounding, and output its
/// residues in another RNS context (`to`).
///
/// Rounding errors of the integer constants are stored as fixed-point
/// fractions split into low/high 64-bit halves. Where a `*_sign` flag exists,
/// it is `true` when the corresponding integer constant was rounded up.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct RnsScaler {
    /// Source RNS context (moduli of the input residues).
    from: Arc<RnsContext>,
    /// Destination RNS context (moduli of the output residues).
    to: Arc<RnsContext>,
    /// Factor `numerator / denominator` applied during scaling.
    scaling_factor: ScalingFactor,

    /// `round(numerator * Q / denominator)` projected onto each `to` modulus,
    /// where `Q` is the product of the `from` moduli.
    gamma: Box<[u64]>,
    /// Shoup precomputations of `gamma`, one per `to` modulus.
    gamma_shoup: Box<[u64]>,
    /// Low 64 bits of the rounding error of `gamma`, scaled by 2^127.
    theta_gamma_lo: u64,
    /// High 64 bits of the rounding error of `gamma`, scaled by 2^127.
    theta_gamma_hi: u64,
    /// `true` when `gamma` was rounded up.
    theta_gamma_sign: bool,

    /// `omega[j][i] = round(numerator * garner_i / denominator)` reduced modulo
    /// the j-th `to` modulus, where `garner_i` comes from `from.garner`.
    omega: Box<[Box<[u64]>]>,
    /// Shoup precomputations of `omega`, same layout.
    omega_shoup: Box<[Box<[u64]>]>,
    /// Low 64 bits of each `omega_i` rounding error, scaled by 2^127.
    theta_omega_lo: Box<[u64]>,
    /// High 64 bits of each `omega_i` rounding error, scaled by 2^127.
    theta_omega_hi: Box<[u64]>,
    /// `true` for each `omega_i` that was rounded up.
    theta_omega_sign: Box<[bool]>,

    /// Low 64 bits of `round(garner_i * 2^theta_garner_shift / Q)`.
    theta_garner_lo: Box<[u64]>,
    /// High 64 bits of `round(garner_i * 2^theta_garner_shift / Q)`.
    theta_garner_hi: Box<[u64]>,
    /// Fixed-point precision (in bits, at most 127) of `theta_garner_*`.
    theta_garner_shift: usize,
}
66
67impl RnsScaler {
68 pub fn new(
72 from: &Arc<RnsContext>,
73 to: &Arc<RnsContext>,
74 scaling_factor: ScalingFactor,
75 ) -> Self {
76 let (gamma, theta_gamma_lo, theta_gamma_hi, theta_gamma_sign) =
78 Self::extract_projection_and_theta(
79 to,
80 &from.product,
81 &scaling_factor.numerator,
82 &scaling_factor.denominator,
83 false,
84 );
85 let gamma_shoup = izip!(&gamma, &to.moduli)
86 .map(|(wi, q)| q.shoup(*wi))
87 .collect_vec();
88
89 let mut omega = vec![vec![0u64; from.moduli.len()].into_boxed_slice(); to.moduli.len()];
91 let mut omega_shoup =
92 vec![vec![0u64; from.moduli.len()].into_boxed_slice(); to.moduli.len()];
93 let (omegas_i, theta_omega_lo, theta_omega_hi, theta_omega_sign): (
94 Vec<Vec<u64>>,
95 Vec<u64>,
96 Vec<u64>,
97 Vec<bool>,
98 ) = from
99 .garner
100 .iter()
101 .map(|garner_i| {
102 Self::extract_projection_and_theta(
103 to,
104 garner_i,
105 &scaling_factor.numerator,
106 &scaling_factor.denominator,
107 true,
108 )
109 })
110 .multiunzip();
111
112 for (i, omega_i) in omegas_i.iter().enumerate() {
113 for j in 0..to.moduli.len() {
114 let qj = &to.moduli[j];
115 omega[j][i] = qj.reduce(omega_i[j]);
116 omega_shoup[j][i] = qj.shoup(omega[j][i]);
117 }
118 }
119
120 let theta_garner_shift = min(
123 from.moduli_u64
124 .iter()
125 .map(|qi| {
126 192 - 1
127 - ((*qi as u128) * (from.moduli_u64.len() as u128))
128 .next_power_of_two()
129 .ilog2()
130 })
131 .min()
132 .unwrap(),
133 127,
134 );
135 let (theta_garner_lo, theta_garner_hi): (Vec<u64>, Vec<u64>) = from
138 .garner
139 .iter()
140 .map(|garner_i| {
141 let mut theta: BigUint =
142 ((garner_i << theta_garner_shift) + (&from.product >> 1)) / &from.product;
143 let theta_hi: BigUint = &theta >> 64;
144 theta -= &theta_hi << 64;
145 (theta.to_u64().unwrap(), theta_hi.to_u64().unwrap())
146 })
147 .unzip();
148
149 Self {
150 from: from.clone(),
151 to: to.clone(),
152 scaling_factor,
153 gamma: gamma.into_boxed_slice(),
154 gamma_shoup: gamma_shoup.into_boxed_slice(),
155 theta_gamma_lo,
156 theta_gamma_hi,
157 theta_gamma_sign,
158 omega: omega.into_boxed_slice(),
159 omega_shoup: omega_shoup.into_boxed_slice(),
160 theta_omega_lo: theta_omega_lo.into_boxed_slice(),
161 theta_omega_hi: theta_omega_hi.into_boxed_slice(),
162 theta_omega_sign: theta_omega_sign.into_boxed_slice(),
163 theta_garner_lo: theta_garner_lo.into_boxed_slice(),
164 theta_garner_hi: theta_garner_hi.into_boxed_slice(),
165 theta_garner_shift: theta_garner_shift as usize,
166 }
167 }
168
169 fn extract_projection_and_theta(
176 ctx: &RnsContext,
177 input: &BigUint,
178 numerator: &BigUint,
179 denominator: &BigUint,
180 round_up: bool,
181 ) -> (Vec<u64>, u64, u64, bool) {
182 let gamma = (numerator * input + (denominator >> 1)) / denominator;
183 let projected = ctx.project(&gamma);
184
185 let mut theta = (numerator * input) % denominator;
186 let mut theta_sign = false;
187 if denominator > &BigUint::one() {
188 if denominator & BigUint::one() == BigUint::one() {
190 if theta > (denominator >> 1) {
191 theta_sign = true;
192 theta = denominator - theta;
193 }
194 } else {
195 if theta >= (denominator >> 1) {
197 theta_sign = true;
198 theta = denominator - theta;
199 }
200 }
201 }
202 if round_up {
205 if theta_sign {
206 theta = (theta << 127) / denominator;
207 } else {
208 theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
209 }
210 } else if theta_sign {
211 theta = ((theta << 127) + denominator - BigUint::one()) / denominator;
212 } else {
213 theta = (theta << 127) / denominator;
214 }
215 let theta_hi_biguint: BigUint = &theta >> 64;
216 theta -= &theta_hi_biguint << 64;
217 let theta_lo = theta.to_u64().unwrap();
218 let theta_hi = theta_hi_biguint.to_u64().unwrap();
219
220 (projected, theta_lo, theta_hi, theta_sign)
221 }
222
223 pub fn scale_new(&self, rests: ArrayView1<u64>, size: usize) -> Vec<u64> {
229 let mut out = vec![0; size];
230 self.scale(rests, (&mut out).into(), 0);
231 out
232 }
233
234 pub fn scale(
241 &self,
242 rests: ArrayView1<u64>,
243 mut out: ArrayViewMut1<u64>,
244 starting_index: usize,
245 ) {
246 debug_assert_eq!(rests.len(), self.from.moduli_u64.len());
247 debug_assert!(!out.is_empty());
248 debug_assert!(starting_index + out.len() <= self.to.moduli_u64.len());
249
250 let mut sum_theta_garner = u256::ZERO;
252 for (thetag_lo, thetag_hi, ri) in izip!(
253 self.theta_garner_lo.iter(),
254 self.theta_garner_hi.iter(),
255 rests
256 ) {
257 sum_theta_garner = sum_theta_garner.wrapping_add(
258 U256::from(*ri) * U256::from((*thetag_lo as u128) | ((*thetag_hi as u128) << 64)),
259 );
260 }
261 sum_theta_garner >>= self.theta_garner_shift - 1;
263 let v = sum_theta_garner.as_u128().div_ceil(2);
264
265 let mut w_sign = false;
268 let mut w = 0u128;
269 if !self.scaling_factor.is_one {
270 let mut sum_theta_omega = u256::ZERO;
271 for (thetao_lo, thetao_hi, thetao_sign, ri) in izip!(
272 self.theta_omega_lo.iter(),
273 self.theta_omega_hi.iter(),
274 self.theta_omega_sign.iter(),
275 rests
276 ) {
277 let product = U256::from(*ri)
278 * U256::from((*thetao_lo as u128) | ((*thetao_hi as u128) << 64));
279 if *thetao_sign {
280 sum_theta_omega = sum_theta_omega.wrapping_sub(product);
281 } else {
282 sum_theta_omega = sum_theta_omega.wrapping_add(product);
283 }
284 }
285
286 let v_theta_gamma = U256::from(v)
288 * U256::from((self.theta_gamma_lo as u128) | ((self.theta_gamma_hi as u128) << 64));
289 if self.theta_gamma_sign {
290 sum_theta_omega = sum_theta_omega.wrapping_add(v_theta_gamma);
291 } else {
292 sum_theta_omega = sum_theta_omega.wrapping_sub(v_theta_gamma);
293 }
294
295 w_sign = (sum_theta_omega >> (63 + 128)) > u256::ZERO;
297
298 if w_sign {
299 w = ((!sum_theta_omega) >> 126isize).as_u128() + 1;
300 w /= 2;
301 } else {
302 w = (sum_theta_omega >> 126isize).as_u128();
303 w = w.div_ceil(2)
304 }
305 }
306
307 unsafe {
308 for i in 0..out.len() {
309 debug_assert!(starting_index + i < self.to.moduli.len());
310 debug_assert!(starting_index + i < self.omega.len());
311 debug_assert!(starting_index + i < self.omega_shoup.len());
312 debug_assert!(starting_index + i < self.gamma.len());
313 debug_assert!(starting_index + i < self.gamma_shoup.len());
314 let out_i = out.get_mut(i).unwrap();
315 let qi = self.to.moduli.get_unchecked(starting_index + i);
316 let omega_i = self.omega.get_unchecked(starting_index + i);
317 let omega_shoup_i = self.omega_shoup.get_unchecked(starting_index + i);
318 let gamma_i = self.gamma.get_unchecked(starting_index + i);
319 let gamma_shoup_i = self.gamma_shoup.get_unchecked(starting_index + i);
320
321 let mut yi = (**qi * 2
322 - qi.lazy_mul_shoup(qi.reduce_u128(v), *gamma_i, *gamma_shoup_i))
323 as u128;
324
325 if !self.scaling_factor.is_one {
326 let wi = qi.lazy_reduce_u128(w);
327 yi += if w_sign { **qi * 2 - wi } else { wi } as u128;
328 }
329
330 debug_assert!(rests.len() <= omega_i.len());
331 debug_assert!(rests.len() <= omega_shoup_i.len());
332 for j in 0..rests.len() {
333 yi += qi.lazy_mul_shoup(
334 *rests.get(j).unwrap(),
335 *omega_i.get_unchecked(j),
336 *omega_shoup_i.get_unchecked(j),
337 ) as u128;
338 }
339
340 *out_i = qi.reduce_u128(yi)
341 }
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use std::{error::Error, panic::catch_unwind, sync::Arc};

    use super::RnsScaler;
    use crate::rns::{scaler::ScalingFactor, RnsContext};
    use ndarray::ArrayView1;
    use num_bigint::BigUint;
    use num_traits::{ToPrimitive, Zero};
    use rand::{rng, RngCore};

    /// Numerators exercised by the scaling tests.
    const NUMERATORS: [u64; 6] = [1, 2, 3, 100, 1000, 4611686018326724610];
    /// Denominators exercised by the scaling tests: odd and even, small and large.
    const DENOMINATORS: [u64; 9] = [1, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610];

    /// Draws one random residue per modulus of `ctx`, in modulus order.
    fn random_residues(ctx: &RnsContext, rng: &mut impl RngCore) -> Vec<u64> {
        ctx.moduli_u64.iter().map(|qi| rng.next_u64() % qi).collect()
    }

    /// Reference implementation: lift `x` from `from` to a signed value, scale
    /// it by `n / d` with rounding, and project the result onto `to`.
    fn expected_scaled(
        from: &RnsContext,
        to: &RnsContext,
        x: &[u64],
        n: &BigUint,
        d: &BigUint,
    ) -> Vec<u64> {
        let mut magnitude = from.lift(ArrayView1::from(x));
        let is_negative = magnitude >= (from.modulus() >> 1);
        if is_negative {
            magnitude = from.modulus() - magnitude;
        }

        let rounded = if is_negative {
            // For negative values ties go toward zero in magnitude, so an even
            // denominator uses a bias one below half.
            let bias = if d.to_u64().unwrap() % 2 == 0 {
                (d >> 1usize) - 1u64
            } else {
                d >> 1
            };
            to.modulus() - (&(&magnitude * n + bias) / d) % to.modulus()
        } else {
            &(&magnitude * n + (d >> 1)) / d
        };
        to.project(&rounded)
    }

    /// Checks `RnsScaler` from `from` to `to` against the reference on
    /// `ntests` random inputs for every numerator/denominator pair.
    fn check_scaling(from: &Arc<RnsContext>, to: &Arc<RnsContext>, ntests: usize) {
        let mut rng = rng();
        for numerator in NUMERATORS.iter() {
            for denominator in DENOMINATORS.iter() {
                let n = BigUint::from(*numerator);
                let d = BigUint::from(*denominator);
                let scaler = RnsScaler::new(from, to, ScalingFactor::new(&n, &d));

                for _ in 0..ntests {
                    let x = random_residues(from, &mut rng);
                    let scaled = scaler.scale_new((&x).into(), to.moduli.len());
                    assert_eq!(scaled, expected_scaled(from, to, &x, &n, &d));
                }
            }
        }
    }

    #[test]
    fn constructor() -> Result<(), Box<dyn Error>> {
        let q = Arc::new(RnsContext::new(&[4, 4611686018326724609, 1153])?);

        let scaler = RnsScaler::new(&q, &q, ScalingFactor::one());
        assert_eq!(scaler.from, q);

        assert!(
            catch_unwind(|| ScalingFactor::new(&BigUint::from(1u64), &BigUint::zero())).is_err()
        );
        Ok(())
    }

    #[test]
    fn scale_same_context() -> Result<(), Box<dyn Error>> {
        let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
        check_scaling(&q, &q, 1000);
        Ok(())
    }

    #[test]
    fn scale_different_contexts() -> Result<(), Box<dyn Error>> {
        let q = Arc::new(RnsContext::new(&[4u64, 4611686018326724609, 1153])?);
        let r = Arc::new(RnsContext::new(&[
            4u64,
            4611686018326724609,
            1153,
            4611686018309947393,
            4611686018282684417,
            4611686018257518593,
            4611686018232352769,
            4611686018171535361,
            4611686018106523649,
            4611686018058289153,
        ])?);
        check_scaling(&q, &r, 100);
        Ok(())
    }
}