1#![warn(missing_docs, unused_imports)]
2
3use super::{Context, Poly, Representation};
6use crate::{
7 rns::{RnsScaler, ScalingFactor},
8 Error, Result,
9};
10use itertools::izip;
11use ndarray::{s, Array2, Axis};
12use std::borrow::Cow;
13use std::sync::Arc;
14
15#[derive(Default, Debug, Clone, PartialEq, Eq)]
17pub struct Scaler {
18 from: Arc<Context>,
19 to: Arc<Context>,
20 number_common_moduli: usize,
21 scaler: RnsScaler,
22}
23
24impl Scaler {
25 pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
27 if from.degree != to.degree {
28 return Err(Error::Default("Incompatible degrees".to_string()));
29 }
30
31 let number_common_moduli = if factor.is_one {
32 from.q
33 .iter()
34 .zip(to.q.iter())
35 .take_while(|(qi, pi)| qi == pi)
36 .count()
37 } else {
38 0
39 };
40 let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
41
42 Ok(Self {
43 from: from.clone(),
44 to: to.clone(),
45 number_common_moduli,
46 scaler,
47 })
48 }
49
50 pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
52 if p.ctx.as_ref() != self.from.as_ref() {
53 Err(Error::Default(
54 "The input polynomial does not have the correct context".to_string(),
55 ))
56 } else {
57 let mut representation = p.representation;
58 if representation == Representation::NttShoup {
59 representation = Representation::Ntt;
60 }
61
62 let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
63
64 if self.number_common_moduli > 0 {
65 new_coefficients
66 .slice_mut(s![..self.number_common_moduli, ..])
67 .assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
68 }
69
70 if self.number_common_moduli < self.to.q.len() {
71 let needs_transform = p.representation != Representation::PowerBasis;
72 let p_coefficients_powerbasis: Cow<'_, Array2<u64>> = if needs_transform {
73 let mut owned = p.coefficients.clone();
74 if p.allow_variable_time_computations {
76 izip!(owned.outer_iter_mut(), p.ctx.ops.iter())
77 .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
78 } else {
79 izip!(owned.outer_iter_mut(), p.ctx.ops.iter())
80 .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
81 }
82 Cow::Owned(owned)
83 } else {
84 Cow::Borrowed(&p.coefficients)
85 };
86
87 izip!(
89 new_coefficients
90 .slice_mut(s![self.number_common_moduli.., ..])
91 .axis_iter_mut(Axis(1)),
92 p_coefficients_powerbasis.axis_iter(Axis(1))
93 )
94 .for_each(|(new_column, column)| {
95 self.scaler
96 .scale(column, new_column, self.number_common_moduli)
97 });
98
99 if needs_transform {
101 if p.allow_variable_time_computations {
102 izip!(
103 new_coefficients
104 .slice_mut(s![self.number_common_moduli.., ..])
105 .outer_iter_mut(),
106 &self.to.ops[self.number_common_moduli..]
107 )
108 .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
109 } else {
110 izip!(
111 new_coefficients
112 .slice_mut(s![self.number_common_moduli.., ..])
113 .outer_iter_mut(),
114 &self.to.ops[self.number_common_moduli..]
115 )
116 .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
117 }
118 }
119 }
120
121 Ok(Poly {
122 ctx: self.to.clone(),
123 representation,
124 allow_variable_time_computations: p.allow_variable_time_computations,
125 coefficients: new_coefficients,
126 coefficients_shoup: None,
127 has_lazy_coefficients: false,
128 })
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::{Scaler, ScalingFactor};
    use crate::rq::{Context, Poly, Representation};
    use itertools::Itertools;
    use num_bigint::BigUint;
    use num_traits::{One, Zero};
    use rand::rng;
    use std::error::Error;

    // Moduli of the source context used in the test.
    static Q: &[u64; 3] = &[
        4611686018282684417,
        4611686018326724609,
        4611686018309947393,
    ];

    // Moduli of the target context; the first entry matches `Q[0]`.
    static P: &[u64; 3] = &[
        4611686018282684417,
        4611686018309947393,
        4611686018257518593,
    ];

    /// Checks the scaler against a big-integer reference computation for a
    /// range of numerators and denominators, with inputs in both the
    /// power-basis and the NTT representation.
    #[test]
    fn scaler() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        let ntests = 100;
        let from = Context::new_arc(Q, 16)?;
        let to = Context::new_arc(P, 16)?;

        for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
            for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
                let n = BigUint::from(*numerator);
                let d = BigUint::from(*denominator);

                let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;

                // Rounding offsets depend only on the denominator, so compute
                // them once per scaling factor.
                let half_d = &d >> 1usize;
                // Offset applied to the magnitude of inputs in the upper half
                // of the source modulus: one less than `d / 2` for even `d`.
                let upper_half_offset = if (&d & BigUint::one()).is_zero() {
                    (&d >> 1usize) - 1u64
                } else {
                    &d >> 1usize
                };

                for _ in 0..ntests {
                    let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
                    let poly_biguint = Vec::<BigUint>::from(&poly);

                    let scaled_poly = scaler.scale(&poly)?;
                    let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);

                    let expected = poly_biguint
                        .iter()
                        .map(|i| {
                            if i >= &(from.modulus() >> 1usize) {
                                // Scale the magnitude `from.modulus() - i`,
                                // then negate it modulo the target modulus.
                                let magnitude =
                                    (&(from.modulus() - i) * &n + &upper_half_offset) / &d;
                                // The outer reduction maps a zero remainder to
                                // 0 rather than to the target modulus itself.
                                (to.modulus() - magnitude % to.modulus()) % to.modulus()
                            } else {
                                ((i * &n + &half_d) / &d) % to.modulus()
                            }
                        })
                        .collect_vec();
                    assert_eq!(expected, scaled_biguint);

                    // Same input in NTT form must produce the same result.
                    poly.change_representation(Representation::Ntt);
                    let mut scaled_poly = scaler.scale(&poly)?;
                    scaled_poly.change_representation(Representation::PowerBasis);
                    let scaled_biguint = Vec::<BigUint>::from(&scaled_poly);
                    assert_eq!(expected, scaled_biguint);
                }
            }
        }

        Ok(())
    }
}