fhe_math/rq/
scaler.rs

1#![warn(missing_docs, unused_imports)]
2
3//! Polynomial scaler.
4
5use super::{Context, Poly, Representation};
6use crate::{
7    rns::{RnsScaler, ScalingFactor},
8    Error, Result,
9};
10use itertools::izip;
11use ndarray::{s, Array2, Axis};
12use std::borrow::Cow;
13use std::sync::Arc;
14
15/// Context extender.
16#[derive(Default, Debug, Clone, PartialEq, Eq)]
17pub struct Scaler {
18    from: Arc<Context>,
19    to: Arc<Context>,
20    number_common_moduli: usize,
21    scaler: RnsScaler,
22}
23
24impl Scaler {
25    /// Create a scaler from a context `from` to a context `to`.
26    pub fn new(from: &Arc<Context>, to: &Arc<Context>, factor: ScalingFactor) -> Result<Self> {
27        if from.degree != to.degree {
28            return Err(Error::Default("Incompatible degrees".to_string()));
29        }
30
31        let number_common_moduli = if factor.is_one {
32            from.q
33                .iter()
34                .zip(to.q.iter())
35                .take_while(|(qi, pi)| qi == pi)
36                .count()
37        } else {
38            0
39        };
40        let scaler = RnsScaler::new(&from.rns, &to.rns, factor);
41
42        Ok(Self {
43            from: from.clone(),
44            to: to.clone(),
45            number_common_moduli,
46            scaler,
47        })
48    }
49
50    /// Scale a polynomial
51    pub(crate) fn scale(&self, p: &Poly) -> Result<Poly> {
52        if p.ctx.as_ref() != self.from.as_ref() {
53            Err(Error::Default(
54                "The input polynomial does not have the correct context".to_string(),
55            ))
56        } else {
57            let mut representation = p.representation;
58            if representation == Representation::NttShoup {
59                representation = Representation::Ntt;
60            }
61
62            let mut new_coefficients = Array2::<u64>::zeros((self.to.q.len(), self.to.degree));
63
64            if self.number_common_moduli > 0 {
65                new_coefficients
66                    .slice_mut(s![..self.number_common_moduli, ..])
67                    .assign(&p.coefficients.slice(s![..self.number_common_moduli, ..]));
68            }
69
70            if self.number_common_moduli < self.to.q.len() {
71                let needs_transform = p.representation != Representation::PowerBasis;
72                let p_coefficients_powerbasis: Cow<'_, Array2<u64>> = if needs_transform {
73                    let mut owned = p.coefficients.clone();
74                    // Backward NTT
75                    if p.allow_variable_time_computations {
76                        izip!(owned.outer_iter_mut(), p.ctx.ops.iter())
77                            .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
78                    } else {
79                        izip!(owned.outer_iter_mut(), p.ctx.ops.iter())
80                            .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
81                    }
82                    Cow::Owned(owned)
83                } else {
84                    Cow::Borrowed(&p.coefficients)
85                };
86
87                // Conversion
88                izip!(
89                    new_coefficients
90                        .slice_mut(s![self.number_common_moduli.., ..])
91                        .axis_iter_mut(Axis(1)),
92                    p_coefficients_powerbasis.axis_iter(Axis(1))
93                )
94                .for_each(|(new_column, column)| {
95                    self.scaler
96                        .scale(column, new_column, self.number_common_moduli)
97                });
98
99                // Forward NTT on the second half when the source required a transform
100                if needs_transform {
101                    if p.allow_variable_time_computations {
102                        izip!(
103                            new_coefficients
104                                .slice_mut(s![self.number_common_moduli.., ..])
105                                .outer_iter_mut(),
106                            &self.to.ops[self.number_common_moduli..]
107                        )
108                        .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
109                    } else {
110                        izip!(
111                            new_coefficients
112                                .slice_mut(s![self.number_common_moduli.., ..])
113                                .outer_iter_mut(),
114                            &self.to.ops[self.number_common_moduli..]
115                        )
116                        .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
117                    }
118                }
119            }
120
121            Ok(Poly {
122                ctx: self.to.clone(),
123                representation,
124                allow_variable_time_computations: p.allow_variable_time_computations,
125                coefficients: new_coefficients,
126                coefficients_shoup: None,
127                has_lazy_coefficients: false,
128            })
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::{Scaler, ScalingFactor};
    use crate::rq::{Context, Poly, Representation};
    use itertools::Itertools;
    use num_bigint::BigUint;
    use num_traits::{One, Zero};
    use rand::rng;
    use std::error::Error;

    // Moduli to be used in tests. The first moduli of `Q` and `P` coincide, so
    // the common-moduli shortcut is exercised when the scaling factor is one.
    static Q: &[u64; 3] = &[
        4611686018282684417,
        4611686018326724609,
        4611686018309947393,
    ];

    static P: &[u64; 3] = &[
        4611686018282684417,
        4611686018309947393,
        4611686018257518593,
    ];

    /// Reference result of scaling `i` (an element of Z_q) by `n / d` with
    /// rounding, reduced modulo `p`.
    ///
    /// Values `i >= q / 2` are treated as the negative number `i - q`. Exact
    /// ties are rounded towards +infinity in both branches.
    fn expected_scaled(
        i: &BigUint,
        q: &BigUint,
        p: &BigUint,
        n: &BigUint,
        d: &BigUint,
    ) -> BigUint {
        if i >= &(q >> 1usize) {
            // Round the magnitude (q - i) * n / d half-down (ties only exist
            // for even `d`), which rounds the negative value half-up.
            let bias = if (d & BigUint::one()).is_zero() {
                (d >> 1usize) - 1u64
            } else {
                d >> 1usize
            };
            let magnitude = ((q - i) * n + bias) / d % p;
            // Reduce again so that a zero magnitude maps to 0 rather than `p`.
            (p - magnitude) % p
        } else {
            (i * n + (d >> 1usize)) / d % p
        }
    }

    /// Checks `Scaler::scale` against `expected_scaled` for power-basis and
    /// NTT inputs, over a range of numerators and denominators.
    #[test]
    fn scaler() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        let ntests = 100;
        let from = Context::new_arc(Q, 16)?;
        let to = Context::new_arc(P, 16)?;
        let q = from.modulus().clone();
        let p = to.modulus().clone();

        for numerator in &[1u64, 2, 3, 100, 1000, 4611686018326724610] {
            for denominator in &[1u64, 2, 3, 4, 100, 101, 1000, 1001, 4611686018326724610] {
                let n = BigUint::from(*numerator);
                let d = BigUint::from(*denominator);

                let scaler = Scaler::new(&from, &to, ScalingFactor::new(&n, &d))?;

                for _ in 0..ntests {
                    let mut poly = Poly::random(&from, Representation::PowerBasis, &mut rng);
                    let expected = Vec::<BigUint>::from(&poly)
                        .iter()
                        .map(|i| expected_scaled(i, &q, &p, &n, &d))
                        .collect_vec();

                    // Power-basis input.
                    let scaled_poly = scaler.scale(&poly)?;
                    assert_eq!(expected, Vec::<BigUint>::from(&scaled_poly));

                    // NTT input: the output is in NTT form and is converted back.
                    poly.change_representation(Representation::Ntt);
                    let mut scaled_poly = scaler.scale(&poly)?;
                    scaled_poly.change_representation(Representation::PowerBasis);
                    assert_eq!(expected, Vec::<BigUint>::from(&scaled_poly));
                }
            }
        }

        Ok(())
    }
}