he_ring/bgv/
modswitch.rs

1use core::f64;
2use std::cell::RefCell;
3use std::cmp::min;
4use std::ops::Range;
5
6use feanor_math::homomorphism::Homomorphism;
7use feanor_math::primitive_int::*;
8use feanor_math::ring::*;
9use feanor_math::rings::zn::zn_64::ZnEl;
10use feanor_math::algorithms::matmul::ComputeInnerProduct;
11
12use crate::bgv::noise_estimator::BGVNoiseEstimator;
13use crate::circuit::evaluator::DefaultCircuitEvaluator;
14use crate::circuit::*;
15use crate::cyclotomic::CyclotomicGaloisGroupEl;
16use crate::gadget_product::digits::*;
17use crate::ZZi64;
18
19use super::noise_estimator::AlwaysZeroNoiseEstimator;
20use super::*;
21
22///
23/// A [`Ciphertext`] which additionally stores w.r.t. which ciphertext modulus it is defined,
24/// and which noise level (as measured by some [`BGVModswitchStrategy`]) it is estimated to have.
25///
26pub struct ModulusAwareCiphertext<Params: BGVCiphertextParams, Strategy: ?Sized + BGVModswitchStrategy<Params>> {
27    /// The stored raw ciphertext
28    pub data: Ciphertext<Params>,
29    /// The indices of those RNS components w.r.t. a "master RNS base" (specified by the context)
30    /// that are not used for this ciphertext; in other words, the ciphertext modulus of this ciphertext
31    /// is the product of all RNS factors of the master RNS base that are not mentioned in this list
32    pub dropped_rns_factor_indices: Box<RNSFactorIndexList>,
33    /// Additional information required by the modulus-switching strategy
34    pub info: Strategy::CiphertextInfo
35}
36
37///
38/// Trait for different modulus-switching strategies in BGV, currently WIP.
39///
40/// Basically, a [`BGVModswitchStrategy`] should be able to determine when (and
41/// how) to modulus-switch during the evaluation of an arithmetic circuit.
42/// The most powerful way to do this is by delegating the evaluation of the
43/// circuit completely to the [`BGVModswitchStrategy`], which is our current
44/// approach.
45///
46pub trait BGVModswitchStrategy<Params: BGVCiphertextParams> {
47
48    ///
49    /// Additional information that is associated to a ciphertext and is used
50    /// to determine when and how to modulus-switch. This will most likely be
51    /// some form of estimate of the noise in the ciphertext.
52    /// 
53    type CiphertextInfo;
54
55    ///
56    /// Evaluates the given circuit homomorphically on the given encrypted inputs.
57    /// This includes performing modulus-switches at suitable times.
58    ///
59    /// The parameters are as follows:
60    ///  - `circuit` is the circuit to evaluate, with constants in a ring that supports 
61    ///    plaintext-ciphertext operations, as specified by [`AsBGVPlaintext`]
62    ///  - `ring` is the ring that contains the constants of `circuit`
63    ///  - `P` is the plaintext ring w.r.t. which the inputs are encrypted; `evaluate_circuit()`
64    ///    does not support mixing different plaintext moduli
65    ///  - `C_master` is the ciphertext ring with the largest relevant RNS base, i.e. its RNS
66    ///    base should contain all RNS factors that are referenced by any ciphertext, and may
67    ///    have additional unused RNS factors
68    ///  - `inputs` contains all inputs to the circuit, i.e. must be of the same length as the
69    ///    circuit has input wires. Each entry should be of the form `(drop_rns_factors, info, ctxt)`
70    ///    where `ctxt` is the ciphertext w.r.t. the RNS base that contains all RNS factors of `C_master`
71    ///    except those mentioned in `drop_rns_fctors`, and `info` should store the additional information
72    ///    associated to the ciphertext that is required to determine modulus-switching times.
73    ///  - `rk` should be the relinearization key w.r.t. `C_master`, can be `None` if the circuit
74    ///    contains no multiplication gates.
75    ///  - `gks` should contain all Galois keys used by the circuit (may also contain unused ones);
76    ///    if the circuit has no Galois gates, this may be an empty slice
77    ///
78    /// Note that the [`BGVModswitchStrategy::CiphertextInfo`]s currently cannot be created using
79    /// functions of the trait, but only via functions on the concrete implementation of
80    /// [`BGVModswitchStrategy`].
81    ///
82    fn evaluate_circuit<R>(
83        &self,
84        circuit: &PlaintextCircuit<R::Type>,
85        ring: R,
86        P: &PlaintextRing<Params>,
87        C_master: &CiphertextRing<Params>,
88        inputs: &[ModulusAwareCiphertext<Params, Self>],
89        rk: Option<&RelinKey<Params>>,
90        gks: &[(CyclotomicGaloisGroupEl, KeySwitchKey<Params>)],
91        key_switches: &mut usize,
92        debug_sk: Option<&SecretKey<Params>>
93    ) -> Vec<ModulusAwareCiphertext<Params, Self>>
94        where R: RingStore,
95            R::Type: AsBGVPlaintext<Params>;
96
97    fn info_for_fresh_encryption(&self, P: &PlaintextRing<Params>, C: &CiphertextRing<Params>, sk_hwt: Option<usize>) -> Self::CiphertextInfo;
98
99    fn clone_info(&self, info: &Self::CiphertextInfo) -> Self::CiphertextInfo;
100
101    fn print_info(&self, P: &PlaintextRing<Params>, C_master: &CiphertextRing<Params>, ct: &ModulusAwareCiphertext<Params, Self>);
102
103    fn clone_ct(&self, P: &PlaintextRing<Params>, C_master: &CiphertextRing<Params>, ct: &ModulusAwareCiphertext<Params, Self>) -> ModulusAwareCiphertext<Params, Self> {
104        let C = Params::mod_switch_down_C(C_master, &ct.dropped_rns_factor_indices);
105        ModulusAwareCiphertext {
106            data: Params::clone_ct(P, &C, &ct.data),
107            info: self.clone_info(&ct.info),
108            dropped_rns_factor_indices: ct.dropped_rns_factor_indices.clone()
109        }
110    }
111}
112
113///
114/// Trait for rings whose elements can be used as plaintexts in
115/// plaintext-ciphertext operations in BGV.
116/// 
117/// In particular, this includes
118///  - integers
119///  - plaintext ring elements
120///  - ciphertext ring elements - usually these are plaintext ring
121///    elements that have already been lifted to the ciphertext ring
122///    to avoid the cost of this conversion later
123/// 
124pub trait AsBGVPlaintext<Params: BGVCiphertextParams>: RingBase {
125
126    fn hom_add_to(
127        &self, 
128        P: &PlaintextRing<Params>, 
129        C: &CiphertextRing<Params>, 
130        dropped_factors: &RNSFactorIndexList, 
131        m: &Self::Element, 
132        ct: Ciphertext<Params>
133    ) -> Ciphertext<Params>;
134
135    fn hom_add_to_noise<N: BGVNoiseEstimator<Params>>(
136        &self, 
137        estimator: &N, 
138        P: &PlaintextRing<Params>, 
139        C: &CiphertextRing<Params>, 
140        dropped_factors: &RNSFactorIndexList, 
141        m: &Self::Element, 
142        ct_info: &N::CriticalQuantityLevel, 
143        implicit_scale: &ZnEl
144    ) -> N::CriticalQuantityLevel;
145
146    fn hom_mul_to(
147        &self, 
148        P: &PlaintextRing<Params>, 
149        C: &CiphertextRing<Params>, 
150        dropped_factors: &RNSFactorIndexList, 
151        m: &Self::Element, 
152        ct: Ciphertext<Params>
153    ) -> Ciphertext<Params>;
154
155    fn hom_mul_to_noise<N: BGVNoiseEstimator<Params>>(
156        &self, 
157        estimator: &N, 
158        P: &PlaintextRing<Params>, 
159        C: &CiphertextRing<Params>, 
160        dropped_factors: &RNSFactorIndexList, 
161        m: &Self::Element, 
162        ct_info: &N::CriticalQuantityLevel, 
163        implicit_scale: &ZnEl
164    ) -> N::CriticalQuantityLevel;
165
166    fn hom_inner_product<I>(
167        &self, 
168        P: &PlaintextRing<Params>, 
169        C: &CiphertextRing<Params>, 
170        dropped_factors: &RNSFactorIndexList, 
171        data: I
172    ) -> Ciphertext<Params>
173        where I: Iterator<Item = (Self::Element, Ciphertext<Params>)>
174    {
175        let mut first_implicit_scale = None;
176        data.fold(Params::transparent_zero(P, C), |current, (lhs, rhs)| {
177            if first_implicit_scale.is_none() {
178                first_implicit_scale = Some(rhs.implicit_scale);
179            } else {
180                assert!(P.base_ring().eq_el(&first_implicit_scale.unwrap(), &rhs.implicit_scale));
181            }
182            Params::hom_add(P, C, current, self.hom_mul_to(P, C, dropped_factors, &lhs, rhs))
183        })
184    }
185
186    fn hom_inner_product_noise<'a, 'b, N: BGVNoiseEstimator<Params>, I>(
187        &self, 
188        estimator: &N, 
189        P: &PlaintextRing<Params>, 
190        C: &CiphertextRing<Params>, 
191        dropped_factors: &RNSFactorIndexList, 
192        data: I
193    ) -> N::CriticalQuantityLevel
194        where I: Iterator<Item = (&'a Self::Element, &'b N::CriticalQuantityLevel)>,
195        Self: 'a,
196        N::CriticalQuantityLevel: 'b
197    {
198        data.fold(estimator.transparent_zero(), |current, (lhs, rhs)| {
199            estimator.hom_add(P, C, &current, P.base_ring().one(), &self.hom_mul_to_noise(estimator, P, C, dropped_factors, lhs, &rhs, &P.base_ring().one()), P.base_ring().one())
200        })
201    }
202
203    fn apply_galois_action_plain(
204        &self,
205        P: &PlaintextRing<Params>, 
206        x: &Self::Element,
207        gs: &[CyclotomicGaloisGroupEl]
208    ) -> Vec<Self::Element>;
209}
210
211///
212/// Chooses `drop_prime_count` indices from `0..rns_base_len`. These indices are chosen in a way
213/// that minimizes the noise growth of key-switching (with the given parameters) after we drop the
214/// corresponding RNS factors.
215/// 
216/// This function will never return indices corresponding to special moduli.
217/// 
218/// Note that this function assumes that all RNS factors have approximately the same size. If this
219/// is not the case, their individual size should be considered when choosing which factors to drop.
220///  
221/// # The standard use case 
222/// 
223/// This hopefully becomes clearer once we consider the main use case:
224/// When we do modulus-switching (e.g. during BGV), we remove RNS factors from the ciphertext modulus.
225/// For the ciphertexts itself, it is (almost) irrelevant which of these RNS factors are removed, but it makes
226/// a huge difference when mod-switching key-switching keys (e.g. relinearization keys). This is because
227/// the used gadget vector relies is based on a decomposition of RNS factors into groups, and removing a single
228/// RNS factor from every group will give a very different behavior from removing a single, whole group and
229/// leaving the other groups unchanged.
230/// 
231/// This function will choose the RNS factors to drop with the goal of minimizing noise growth. In particular,
232/// as long as the RNS factor groups (the digits) are larger than the special modulus, this function will remove
233/// RNS factors from each group in a balanced manner.
234/// 
235/// This is probably the desired behavior in most cases, but other behaviors might as well be reasonable in 
236/// certain scenarios. 
237/// 
238/// # Example
239/// ```
240/// # use feanor_math::seq::*;
241/// # use he_ring::gadget_product::*;
242/// # use he_ring::bgv::KeySwitchKeyParams;
243/// # use he_ring::bgv::modswitch::recommended_rns_factors_to_drop;
244/// # use he_ring::gadget_product::digits::*;
245/// let digits = RNSGadgetVectorDigitIndices::from([0..3, 3..5].clone_els());
246/// let params = KeySwitchKeyParams {
247///     digits_without_special: digits,
248///     special_modulus_factor_count: 0
249/// };
250/// // remove the first two indices from 0..3, and the first index from 3..5 - the resulting ranges both have length 1
251/// assert_eq!(&[0usize, 1, 3][..] as &[usize], &*recommended_rns_factors_to_drop(params, 3) as &[usize]);
252/// ```
253/// 
254pub fn recommended_rns_factors_to_drop(key_params: KeySwitchKeyParams, drop_prime_count: usize) -> Box<RNSFactorIndexList> {
255    assert!(drop_prime_count < key_params.digits_without_special.rns_base_len());
256
257    let mut drop_from_digit = (0..key_params.digits_without_special.len()).map(|_| 0).collect::<Vec<_>>();
258
259    let effective_len = |range: Range<usize>| range.end - range.start;
260    for _ in 0..drop_prime_count {
261        let largest_digit_idx = (0..key_params.digits_without_special.len()).max_by_key(|i| effective_len(key_params.digits_without_special.at(*i)) - drop_from_digit[*i]).unwrap();
262        drop_from_digit[largest_digit_idx] += 1;
263    }
264
265    let result = RNSFactorIndexList::from((0..key_params.digits_without_special.len()).flat_map(|i| key_params.digits_without_special.at(i).start..(key_params.digits_without_special.at(i).start + drop_from_digit[i])).collect(), key_params.digits_without_special.rns_base_len());
266    return result;
267}
268
269
270///
271/// Default modulus-switch strategy for BGV, which performs a certain number of modulus-switches
272/// before each multiplication.
273///
274/// The general strategy is as follows:
275///  - only mod-switch before multiplications
276///  - never introduce new RNS factors, only remove current ones
277///  - use the provided [`BGVNoiseEstimator`] to determine when and by how much
278///    we should reduce the ciphertext modulus
279///
280/// These points lead to a relatively simple and generally well-performing modulus switching strategy.
281/// However, there may be situations where deviating from 1. could lead to a lower number of mod-switches
282/// (and thus better performance), and deviating from 2. could be used for a finer-tuned mod-switching,
283/// and thus less noise growth.
284///
285pub struct DefaultModswitchStrategy<Params: BGVCiphertextParams, N: BGVNoiseEstimator<Params>, const LOG: bool> {
286    params: PhantomData<Params>,
287    noise_estimator: N
288}
289
290impl<Params: BGVCiphertextParams> DefaultModswitchStrategy<Params, AlwaysZeroNoiseEstimator, false> {
291
292    ///
293    /// Create a [`DefaultModswitchStrategy`] that never performs modulus switching,
294    /// except when necessary because operands are defined modulo different RNS bases.
295    ///
296    /// Using this is not recommended, except for linear circuits, or circuits with
297    /// very low multiplicative depth.
298    ///
299    pub fn never_modswitch() -> Self {
300        Self {
301            params: PhantomData,
302            noise_estimator: AlwaysZeroNoiseEstimator
303        }
304    }
305}
306
307///
308/// Used internally when evaluating a circuit, since we want to store plaintexts
309/// as plaintexts as long as possible - or rather until we know w.r.t. which RNS
310/// base we should convert them into a ciphertext ring element
311/// 
312enum PlainOrCiphertext<'a, Params: BGVCiphertextParams, Strategy: BGVModswitchStrategy<Params>, R: ?Sized + RingBase> {
313    Plaintext(Coefficient<R>),
314    PlaintextRef(&'a Coefficient<R>),
315    CiphertextRef(&'a ModulusAwareCiphertext<Params, Strategy>),
316    Ciphertext(ModulusAwareCiphertext<Params, Strategy>)
317}
318
319impl<'a, Params: BGVCiphertextParams, Strategy: BGVModswitchStrategy<Params>, R: ?Sized + RingBase> PlainOrCiphertext<'a, Params, Strategy, R> {
320
321    fn as_ciphertext_ref<'b>(&'b self) -> Result<&'b ModulusAwareCiphertext<Params, Strategy>, &'b Coefficient<R>> {
322        match self {
323            PlainOrCiphertext::Plaintext(x) => Err(x),
324            PlainOrCiphertext::PlaintextRef(x) => Err(x),
325            PlainOrCiphertext::Ciphertext(x) => Ok(x),
326            PlainOrCiphertext::CiphertextRef(x) => Ok(x)
327        }
328    }
329
330    fn as_ciphertext<S: RingStore<Type = R>>(self, P: &PlaintextRing<Params>, C_master: &CiphertextRing<Params>, ring: S, strategy: &Strategy) -> Result<(CiphertextRing<Params>, ModulusAwareCiphertext<Params, Strategy>), Coefficient<R>> {
331        match self {
332            PlainOrCiphertext::Plaintext(x) => Err(x),
333            PlainOrCiphertext::PlaintextRef(x) => Err(x.clone(ring)),
334            PlainOrCiphertext::CiphertextRef(x) => {
335                let Cx = Params::mod_switch_down_C(C_master, &x.dropped_rns_factor_indices);
336                let x = ModulusAwareCiphertext {
337                    data: Params::clone_ct(P, &Cx, &x.data),
338                    dropped_rns_factor_indices: x.dropped_rns_factor_indices.clone(),
339                    info: strategy.clone_info(&x.info)
340                };
341                Ok((Cx, x))
342            },
343            PlainOrCiphertext::Ciphertext(x) => {
344                let Cx = Params::mod_switch_down_C(C_master, &x.dropped_rns_factor_indices);
345                Ok((Cx, x))
346            }
347        }
348    }
349}
350
351impl<Params: BGVCiphertextParams> AsBGVPlaintext<Params> for StaticRingBase<i64> {
352
353    fn hom_add_to(
354        &self, 
355        P: &PlaintextRing<Params>, 
356        C: &CiphertextRing<Params>, 
357        _dropped_factors: &RNSFactorIndexList, 
358        m: &Self::Element, 
359        ct: Ciphertext<Params>
360    ) -> Ciphertext<Params> {
361        Params::hom_add_plain_encoded(P, C, &C.inclusion().map(C.base_ring().coerce(&ZZi64, *m)), ct)
362    }
363
364    fn hom_add_to_noise<N: BGVNoiseEstimator<Params>>(
365        &self, 
366        estimator: &N, 
367        P: &PlaintextRing<Params>, 
368        C: &CiphertextRing<Params>, 
369        _dropped_factors: &RNSFactorIndexList, 
370        m: &Self::Element, 
371        ct_info: &N::CriticalQuantityLevel, 
372        implicit_scale: &ZnEl
373    ) -> N::CriticalQuantityLevel {
374        estimator.hom_add_plain_encoded(P, C, &C.inclusion().map(C.base_ring().coerce(&ZZi64, *m)), ct_info, *implicit_scale)
375    }
376
377    fn hom_mul_to(
378        &self, 
379        P: &PlaintextRing<Params>, 
380        C: &CiphertextRing<Params>, 
381        _dropped_factors: &RNSFactorIndexList, 
382        m: &Self::Element, 
383        ct: Ciphertext<Params>
384    ) -> Ciphertext<Params> {
385        Params::hom_mul_plain_i64(P, C, *m, ct)
386    }
387
388    fn hom_mul_to_noise<N: BGVNoiseEstimator<Params>>(
389        &self, 
390        estimator: &N, 
391        P: &PlaintextRing<Params>, 
392        C: &CiphertextRing<Params>, 
393        _dropped_factors: &RNSFactorIndexList, 
394        m: &Self::Element, 
395        ct_info: &N::CriticalQuantityLevel, 
396        implicit_scale: &ZnEl
397    ) -> N::CriticalQuantityLevel {
398        estimator.hom_mul_plain_i64(P, C, *m, ct_info, *implicit_scale)
399    }
400
401    fn apply_galois_action_plain(
402        &self,
403        _P: &PlaintextRing<Params>, 
404        x: &Self::Element,
405        gs: &[CyclotomicGaloisGroupEl]
406    ) -> Vec<Self::Element> {
407        gs.iter().map(|_| self.clone_el(x)).collect()
408    }
409}
410
411impl<Params: BGVCiphertextParams> AsBGVPlaintext<Params> for StaticRingBase<i32> {
412
413    fn hom_add_to(
414        &self, 
415        P: &PlaintextRing<Params>, 
416        C: &CiphertextRing<Params>, 
417        _dropped_factors: &RNSFactorIndexList, 
418        m: &Self::Element, 
419        ct: Ciphertext<Params>
420    ) -> Ciphertext<Params> {
421        Params::hom_add_plain_encoded(P, C, &C.inclusion().map(C.base_ring().coerce(&StaticRing::<i32>::RING, *m)), ct)
422    }
423
424    fn hom_add_to_noise<N: BGVNoiseEstimator<Params>>(
425        &self, 
426        estimator: &N, 
427        P: &PlaintextRing<Params>, 
428        C: &CiphertextRing<Params>, 
429        _dropped_factors: &RNSFactorIndexList, 
430        m: &Self::Element, 
431        ct_info: &N::CriticalQuantityLevel, 
432        implicit_scale: &ZnEl
433    ) -> N::CriticalQuantityLevel {
434        estimator.hom_add_plain_encoded(P, C, &C.inclusion().map(C.base_ring().coerce(&StaticRing::<i32>::RING, *m)), ct_info, *implicit_scale)
435    }
436
437    fn hom_mul_to(
438        &self, 
439        P: &PlaintextRing<Params>, 
440        C: &CiphertextRing<Params>, 
441        _dropped_factors: &RNSFactorIndexList, 
442        m: &Self::Element, 
443        ct: Ciphertext<Params>
444    ) -> Ciphertext<Params> {
445        Params::hom_mul_plain_i64(P, C, *m as i64, ct)
446    }
447
448    fn hom_mul_to_noise<N: BGVNoiseEstimator<Params>>(
449        &self, 
450        estimator: &N, 
451        P: &PlaintextRing<Params>, 
452        C: &CiphertextRing<Params>, 
453        _dropped_factors: &RNSFactorIndexList, 
454        m: &Self::Element, 
455        ct_info: &N::CriticalQuantityLevel, 
456        implicit_scale: &ZnEl
457    ) -> N::CriticalQuantityLevel {
458        estimator.hom_mul_plain_i64(P, C, *m as i64, ct_info, *implicit_scale)
459    }
460
461    fn apply_galois_action_plain(
462        &self,
463        _P: &PlaintextRing<Params>, 
464        x: &Self::Element,
465        gs: &[CyclotomicGaloisGroupEl]
466    ) -> Vec<Self::Element> {
467        gs.iter().map(|_| self.clone_el(x)).collect()
468    }
469}
470
471impl<Params: BGVCiphertextParams> AsBGVPlaintext<Params> for NumberRingQuotientBase<NumberRing<Params>, Zn> {
472
473    fn hom_add_to(
474        &self, 
475        P: &PlaintextRing<Params>, 
476        C: &CiphertextRing<Params>, 
477        _dropped_factors: &RNSFactorIndexList, 
478        m: &Self::Element, 
479        ct: Ciphertext<Params>
480    ) -> Ciphertext<Params> {
481        Params::hom_add_plain(P, C, m, ct)
482    }
483
484    fn hom_add_to_noise<N: BGVNoiseEstimator<Params>>(
485        &self, 
486        estimator: &N, 
487        P: &PlaintextRing<Params>, 
488        C: &CiphertextRing<Params>, 
489        _dropped_factors: &RNSFactorIndexList, 
490        m: &Self::Element, 
491        ct_info: &N::CriticalQuantityLevel, 
492        implicit_scale: &ZnEl
493    ) -> N::CriticalQuantityLevel {
494        estimator.hom_add_plain(P, C, m, ct_info, *implicit_scale)
495    }
496
497    fn hom_mul_to(
498        &self, 
499        P: &PlaintextRing<Params>, 
500        C: &CiphertextRing<Params>, 
501        _dropped_factors: &RNSFactorIndexList, 
502        m: &Self::Element, 
503        ct: Ciphertext<Params>
504    ) -> Ciphertext<Params> {
505        Params::hom_mul_plain(P, C, m, ct)
506    }
507
508    fn hom_mul_to_noise<N: BGVNoiseEstimator<Params>>(
509        &self, 
510        estimator: &N, 
511        P: &PlaintextRing<Params>, 
512        C: &CiphertextRing<Params>, 
513        _dropped_factors: &RNSFactorIndexList, 
514        m: &Self::Element, 
515        ct_info: &N::CriticalQuantityLevel, 
516        implicit_scale: &ZnEl
517    ) -> N::CriticalQuantityLevel {
518        estimator.hom_mul_plain(P, C, m, ct_info, *implicit_scale)
519    }
520
521    fn apply_galois_action_plain(
522        &self,
523        _P: &PlaintextRing<Params>, 
524        x: &Self::Element,
525        gs: &[CyclotomicGaloisGroupEl]
526    ) -> Vec<Self::Element> {
527        self.apply_galois_action_many(x, gs)
528    }
529}
530
531impl<Params: BGVCiphertextParams, A: Allocator + Clone> AsBGVPlaintext<Params> for ManagedDoubleRNSRingBase<NumberRing<Params>, A>
532    where CiphertextRing<Params>: RingStore<Type = ManagedDoubleRNSRingBase<NumberRing<Params>, A>>
533{
534    fn hom_add_to(
535        &self, 
536        P: &PlaintextRing<Params>, 
537        C: &CiphertextRing<Params>, 
538        dropped_factors: &RNSFactorIndexList, 
539        m: &Self::Element, 
540        ct: Ciphertext<Params>
541    ) -> Ciphertext<Params> {
542        Params::hom_add_plain_encoded(P, C, &C.get_ring().drop_rns_factor_element(self, dropped_factors, self.clone_el(m)), ct)
543    }
544
545    fn hom_add_to_noise<N: BGVNoiseEstimator<Params>>(
546        &self, 
547        estimator: &N, 
548        P: &PlaintextRing<Params>, 
549        C: &CiphertextRing<Params>, 
550        dropped_factors: &RNSFactorIndexList, 
551        m: &Self::Element, 
552        ct_info: &N::CriticalQuantityLevel, 
553        implicit_scale: &ZnEl
554    ) -> N::CriticalQuantityLevel {
555        estimator.hom_add_plain_encoded(P, C, &C.get_ring().drop_rns_factor_element(self, dropped_factors, self.clone_el(m)), ct_info, *implicit_scale)
556    }
557
558    fn hom_mul_to(
559        &self, 
560        P: &PlaintextRing<Params>, 
561        C: &CiphertextRing<Params>, 
562        dropped_factors: &RNSFactorIndexList, 
563        m: &Self::Element, 
564        ct: Ciphertext<Params>
565    ) -> Ciphertext<Params> {
566        Params::hom_mul_plain_encoded(P, C, &C.get_ring().drop_rns_factor_element(self, dropped_factors, self.clone_el(m)), ct)
567    }
568
569    fn hom_mul_to_noise<N: BGVNoiseEstimator<Params>>(
570        &self, 
571        estimator: &N, 
572        P: &PlaintextRing<Params>, 
573        C: &CiphertextRing<Params>, 
574        dropped_factors: &RNSFactorIndexList, 
575        m: &Self::Element, 
576        ct_info: &N::CriticalQuantityLevel, 
577        implicit_scale: &ZnEl
578    ) -> N::CriticalQuantityLevel {
579        estimator.hom_mul_plain_encoded(P, C, &C.get_ring().drop_rns_factor_element(self, dropped_factors, self.clone_el(m)), ct_info, *implicit_scale)
580    }
581
582    #[instrument(skip_all)]
583    fn hom_inner_product<I>(
584        &self, 
585        P: &PlaintextRing<Params>, 
586        C: &CiphertextRing<Params>, 
587        dropped_factors: &RNSFactorIndexList, 
588        data: I
589    ) -> Ciphertext<Params>
590        where I: Iterator<Item = (Self::Element, Ciphertext<Params>)>
591    {
592        let mut lhs = Vec::new();
593        let mut rhs_c0 = Vec::new();
594        let mut rhs_c1 = Vec::new();
595        let mut first_implicit_scale = None;
596        for (l, r) in data {
597            if first_implicit_scale.is_none() {
598                first_implicit_scale = Some(r.implicit_scale);
599            } else {
600                assert!(P.base_ring().eq_el(&first_implicit_scale.unwrap(), &r.implicit_scale));
601            }
602            lhs.push(l);
603            rhs_c0.push(r.c0);
604            rhs_c1.push(r.c1);
605        }
606        return Ciphertext {
607            implicit_scale: first_implicit_scale.unwrap_or(P.base_ring().one()),
608            c0: <_ as ComputeInnerProduct>::inner_product(C.get_ring(), lhs.iter().zip(rhs_c0.into_iter()).map(|(lhs, rhs)| (C.get_ring().drop_rns_factor_element(self, dropped_factors, self.clone_el(lhs)), rhs))),
609            c1: <_ as ComputeInnerProduct>::inner_product(C.get_ring(), lhs.into_iter().zip(rhs_c1.into_iter()).map(|(lhs, rhs)| (C.get_ring().drop_rns_factor_element(self, dropped_factors, lhs), rhs))),
610        };
611    }
612
613    fn apply_galois_action_plain(
614        &self,
615        _P: &PlaintextRing<Params>, 
616        x: &Self::Element,
617        gs: &[CyclotomicGaloisGroupEl]
618    ) -> Vec<Self::Element> {
619        self.apply_galois_action_many(x, gs)
620    }
621}
622
623impl<Params: BGVCiphertextParams, N: BGVNoiseEstimator<Params>, const LOG: bool> DefaultModswitchStrategy<Params, N, LOG> {
624
625    pub fn new(noise_estimator: N) -> Self {
626        Self {
627            params: PhantomData,
628            noise_estimator: noise_estimator
629        }
630    }
631
632    pub fn from_noise_level(&self, noise_level: N::CriticalQuantityLevel) -> <Self as BGVModswitchStrategy<Params>>::CiphertextInfo {
633        noise_level
634    }
635
636    ///
637    /// Mod-switches the given ciphertext from its current ciphertext ring
638    /// to `Ctarget`, and adjusts the noise information.
639    /// 
640    fn mod_switch_down(
641        &self, 
642        P: &PlaintextRing<Params>, 
643        C_target: &CiphertextRing<Params>, 
644        C_master: &CiphertextRing<Params>, 
645        dropped_factors_target: &RNSFactorIndexList, 
646        x: ModulusAwareCiphertext<Params, Self>,
647        context: &str,
648        debug_sk: Option<&SecretKey<Params>>
649    ) -> ModulusAwareCiphertext<Params, Self> {
650        let Cx = Params::mod_switch_down_C(C_master, &x.dropped_rns_factor_indices);
651        let drop_x = dropped_factors_target.pushforward(&x.dropped_rns_factor_indices);
652        let x_noise_budget = if let Some(sk) = debug_sk {
653            let sk_x = Params::mod_switch_down_sk(&Cx, C_master, &x.dropped_rns_factor_indices, sk);
654            Some(Params::noise_budget(P, &Cx, &x.data, &sk_x))
655        } else { None };
656        let result = ModulusAwareCiphertext {
657            data: Params::mod_switch_down_ct(P, &C_target, &Cx, &drop_x, x.data),
658            info: self.noise_estimator.mod_switch_down(&P, &C_target, &Cx, &drop_x, &x.info),
659            dropped_rns_factor_indices: dropped_factors_target.to_owned()
660        };
661        if LOG && drop_x.len() > 0 {
662            println!("{}: Dropping RNS factors {} of operand, estimated noise budget {}/{} -> {}/{}",
663                context,
664                drop_x,
665                -self.noise_estimator.estimate_log2_relative_noise_level(P, &Cx, &x.info).round(),
666                ZZbig.abs_log2_ceil(Cx.base_ring().modulus()).unwrap(),
667                -self.noise_estimator.estimate_log2_relative_noise_level(P, C_target, &result.info).round(),
668                ZZbig.abs_log2_ceil(C_target.base_ring().modulus()).unwrap(),
669            );
670            if let Some(sk) = debug_sk {
671                let sk_target = Params::mod_switch_down_sk(C_target, C_master, &dropped_factors_target, sk);
672                println!("  actual noise budget: {} -> {}", x_noise_budget.unwrap(), Params::noise_budget(P, C_target, &result.data, &sk_target));
673            }
674        }
675        return result;
676    }
677
678    ///
679    /// Mod-switches the given ciphertext from its current ciphertext ring
680    /// to `Ctarget`, and adjusts the noise information.
681    /// 
682    fn mod_switch_down_ref(
683        &self, 
684        P: &PlaintextRing<Params>, 
685        C_target: &CiphertextRing<Params>, 
686        C_master: &CiphertextRing<Params>, 
687        dropped_factors_target: &RNSFactorIndexList, 
688        x: &ModulusAwareCiphertext<Params, Self>,
689        context: &str,
690        debug_sk: Option<&SecretKey<Params>>
691    ) -> ModulusAwareCiphertext<Params, Self> {
692        let Cx = Params::mod_switch_down_C(C_master, &x.dropped_rns_factor_indices);
693        let drop_x = dropped_factors_target.pushforward(&x.dropped_rns_factor_indices);
694        let result = ModulusAwareCiphertext {
695            data: Params::mod_switch_down_ct(P, &C_target, &Cx, &drop_x, Params::clone_ct(P, &Cx, &x.data)),
696            info: self.noise_estimator.mod_switch_down(&P, &C_target, &Cx, &drop_x, &x.info),
697            dropped_rns_factor_indices: dropped_factors_target.to_owned()
698        };
699        if LOG && drop_x.len() > 0 {
700            println!("{}: Dropping RNS factors {} of operand, estimated noise budget {}/{} -> {}/{}",
701                context,
702                drop_x,
703                -self.noise_estimator.estimate_log2_relative_noise_level(P, &Cx, &x.info).round(),
704                ZZbig.abs_log2_ceil(Cx.base_ring().modulus()).unwrap(),
705                -self.noise_estimator.estimate_log2_relative_noise_level(P, C_target, &result.info).round(),
706                ZZbig.abs_log2_ceil(C_target.base_ring().modulus()).unwrap(),
707            );
708            if let Some(sk) = debug_sk {
709                let sk_target = Params::mod_switch_down_sk(C_target, C_master, &dropped_factors_target, sk);
710                let sk_x = Params::mod_switch_down_sk(&Cx, C_master, &x.dropped_rns_factor_indices, sk);
711                println!("  actual noise budget: {} -> {}", Params::noise_budget(P, &Cx, &x.data, &sk_x), Params::noise_budget(P, C_target, &result.data, &sk_target));
712            }
713        }
714        return result;
715    }
716
717    ///
718    /// Computes the RNS base we should switch to before multiplication to
719    /// minimize the result noise. The result is returned as the list of RNS
720    /// factors of `C_master` that we want to drop. This list corresponds to
721    /// the RNS factors to drop from the ciphertexts, i.e. they always include
722    /// the RNS factors for the special modulus. Of course, these must not be
723    /// dropped from the relinearization key.
724    /// 
725    #[instrument(skip_all)]
726    fn compute_optimal_mul_modswitch(
727        &self,
728        P: &PlaintextRing<Params>,
729        C_master: &CiphertextRing<Params>,
730        noise_x: &N::CriticalQuantityLevel,
731        dropped_factors_x: &RNSFactorIndexList,
732        noise_y: &N::CriticalQuantityLevel,
733        dropped_factors_y: &RNSFactorIndexList,
734        rk: KeySwitchKeyParams
735    ) -> Box<RNSFactorIndexList> {
736        let special_modulus_factors = RNSFactorIndexList::from(((C_master.base_ring().len() - rk.special_modulus_factor_count)..C_master.base_ring().len()).collect(), C_master.base_ring().len());
737        let Cx = Params::mod_switch_down_C(C_master, dropped_factors_x);
738        let Cy = Params::mod_switch_down_C(C_master, dropped_factors_y);
739
740        // first, we drop all the RNS factors that are required to make the product well-defined;
741        // these are exactly the RNS factors that are missing in either input, and the ones corresponding
742        // to the special modulus
743        let base_drop_without_special = dropped_factors_x.union(&dropped_factors_y).subtract(&special_modulus_factors);
744        let rk_digits_after_base_drop = rk.digits_without_special.remove_indices(&base_drop_without_special);
745
746        // now try every number of additional RNS factors to drop
747        let compute_result_noise = |num_to_drop: usize| {
748            let second_drop_without_special = recommended_rns_factors_to_drop(KeySwitchKeyParams { digits_without_special: rk_digits_after_base_drop.clone(), special_modulus_factor_count: rk.special_modulus_factor_count }, num_to_drop);
749            let total_drop_without_special = second_drop_without_special.pullback(&base_drop_without_special);
750            let total_drop = total_drop_without_special.union(&special_modulus_factors);
751            let C_target = Params::mod_switch_down_C(C_master, &total_drop);
752            let C_special = Params::mod_switch_down_C(C_master, &total_drop_without_special);
753            let rk_digits_after_total_drop = rk.digits_without_special.remove_indices(&total_drop_without_special);
754
755            let expected_noise = self.noise_estimator.estimate_log2_relative_noise_level(
756                P,
757                &C_target,
758                &self.noise_estimator.hom_mul(
759                    P,
760                    &C_target,
761                    &C_special,
762                    &self.noise_estimator.mod_switch_down(&P, &C_target, &Cx, &total_drop.pushforward(dropped_factors_x), noise_x),
763                    &self.noise_estimator.mod_switch_down(&P, &C_target, &Cy, &total_drop.pushforward(dropped_factors_y), noise_y),
764                    &KeySwitchKeyParams {
765                        digits_without_special: rk_digits_after_total_drop,
766                        special_modulus_factor_count: rk.special_modulus_factor_count
767                    }
768                )
769            );
770            return (total_drop, expected_noise);
771        };
772        return (0..(C_master.base_ring().len() - base_drop_without_special.len() - special_modulus_factors.len())).map(compute_result_noise).min_by(|(_, l), (_, r)| f64::total_cmp(l, r)).unwrap().0;
773    }
774
775    ///
776    /// Computes the value `x + sum_i cs[i] * y[i]`, by mod-switching all involved
777    /// ciphertexts to the RNS base of all shared RNS factors. In particular, if the
778    /// input ciphertexts are all defined w.r.t. the same RNS base, no modulus-switching
779    /// is performed at all.
780    /// 
781    /// This function is quite complicated, as there are many things to consider:
782    ///  - We have to handle both constants and ciphertexts
783    ///  - Special coefficients (e.g. `0, 1, -1`) should be handled without a full
784    ///    plaintext-ciphertext multiplication
785    ///  - We decide not to perform intermediate modulus-switches, but only modulus-switch
786    ///    at the very beginning. Note however that it might be possible to group
787    ///    summands depending on their RNS base, and reduce the number of modulus-switches
788    ///  - We have to decide on the `implicit_scale` of the result, its choice may
789    ///    affect noise growth 
790    ///  - using inner product functionality of the underlying ring can give us better
791    ///    performance than many isolated additions/multiplications
792    /// 
793    #[instrument(skip_all)]
794    fn add_inner_prod<'a, R>(
795        &self,
796        P: &PlaintextRing<Params>,
797        C_master: &CiphertextRing<Params>,
798        x: PlainOrCiphertext<'a, Params, Self, R::Type>,
799        cs: &[Coefficient<R::Type>],
800        ys: &[PlainOrCiphertext<'a, Params, Self, R::Type>],
801        ring: R,
802        debug_sk: Option<&SecretKey<Params>>
803    ) -> PlainOrCiphertext<'a, Params, Self, R::Type>
804        where R: RingStore + Copy,
805            R::Type: AsBGVPlaintext<Params>
806    {
807        assert_eq!(cs.len(), ys.len());
808
809        // first, we separate the inner product into three parts:
810        //  - the constant part, which does not contain any ciphertexts and is immediately computed
811        //  - the integer part, which is of the form `sum_i c[i] * ct[i]` with `c[i]` being integers
812        //  - the main part, which is of the form `sum_i c[i] * ct[i]` with `c[i]` being elements of `R`
813        let mut constant = Coefficient::Zero;
814        let mut int_products: Vec<(i32, &ModulusAwareCiphertext<Params, Self>)> = Vec::new();
815        let mut main_products:  Vec<(&El<R>, &ModulusAwareCiphertext<Params, Self>)> = Vec::new();
816
817        // while separating the different summands, we also keep track of which will be the result modulus
818        let mut total_drop = RNSFactorIndexList::empty();
819        let mut min_dropped_len = usize::MAX;
820        let mut update_total_drop = |ct: &ModulusAwareCiphertext<Params, Self>| {
821            total_drop = total_drop.union(&ct.dropped_rns_factor_indices);
822            min_dropped_len = min(min_dropped_len, ct.dropped_rns_factor_indices.len());
823        };
824
825        for (c, y) in cs.iter().zip(ys.iter()) {
826            match y.as_ciphertext_ref() {
827                Err(y) => constant = constant.add(c.clone(ring).mul(y.clone(ring), ring), ring),
828                Ok(y) => if !c.is_zero() {
829                    update_total_drop(y);
830                    match c {
831                        Coefficient::Zero => unreachable!(),
832                        Coefficient::One => int_products.push((1, y)),
833                        Coefficient::NegOne => int_products.push((-1, y)),
834                        Coefficient::Integer(c) => int_products.push((*c, y)),
835                        Coefficient::Other(c) => main_products.push((c, y)),
836                    }
837                }
838            }
839        }
840        match x.as_ciphertext_ref() {
841            Ok(x) => {
842                update_total_drop(x);
843            },
844            Err(x) => if int_products.len() == 0 && main_products.len() == 0 {
845                // if `x` is a constant and we have no products involving ciphertexts, everything is just a constant
846                return PlainOrCiphertext::Plaintext(x.clone(ring).add(constant, ring));
847            }
848        }
849        assert!(min_dropped_len <= total_drop.len());
850
851        let Ctarget = Params::mod_switch_down_C(C_master, &total_drop);
852
853        // now perform modulus-switches when necessary
854        let mut int_products: Vec<(i32, ModulusAwareCiphertext<Params, Self>)> = int_products.into_iter().map(|(c, y)| (
855            c,
856            self.mod_switch_down_ref(P, &Ctarget, C_master, &total_drop, y, "HomInnerProduct", debug_sk)
857        )).collect();
858
859        let mut main_products: Vec<(El<R>, ModulusAwareCiphertext<Params, Self>)> = main_products.into_iter().map(|(c, y)| (
860            ring.clone_el(c),
861            self.mod_switch_down_ref(P, &Ctarget, C_master, &total_drop, y, "HomInnerProduct", debug_sk)
862        )).collect();
863
864        // finally, we do another noise optimization technique: the implicit scale of the output is
865        // chosen as total scale (implicit scale + coefficient) of the highest-noise ciphertext; this way
866        // we avoid multiplying its size up further
867        let Zt = P.base_ring();
868        let output_implicit_scale: ZnEl = int_products.iter().filter_map(|(c, ct)| Zt.invert(&Zt.int_hom().map(*c)).map(|c| (c, ct)))
869            .map(|(c, ct)| (self.noise_estimator.estimate_log2_relative_noise_level(P, &Ctarget, &ct.info), Zt.mul(ct.data.implicit_scale, c))
870        ).max_by(|(l, _), (r, _)| f64::total_cmp(l, r)).map(|(_, scale)| scale).unwrap_or(P.base_ring().one());
871
872        for (c, ct) in &mut int_products {
873            *c = Zt.smallest_lift(Zt.mul(Zt.int_hom().map(*c), Zt.checked_div(&output_implicit_scale, &ct.data.implicit_scale).unwrap())) as i32;
874            ct.data.implicit_scale = output_implicit_scale;
875        }
876        for (c, ct) in &mut main_products {
877            let factor = Zt.smallest_lift(Zt.checked_div(&output_implicit_scale, &ct.data.implicit_scale).unwrap()) as i32;
878            if factor != 1 {
879                ring.int_hom().mul_assign_map(c, factor);
880            }
881            ct.data.implicit_scale = output_implicit_scale;
882        }
883
884        let int_product_noise = StaticRing::<i32>::RING.get_ring().hom_inner_product_noise(&self.noise_estimator, P, &Ctarget, &total_drop, int_products.iter().map(|(lhs, rhs)| (lhs, &rhs.info)));
885        let int_product_part = StaticRing::<i32>::RING.get_ring().hom_inner_product(P, &Ctarget, &total_drop, int_products.into_iter().map(|(lhs, rhs)| (lhs, rhs.data)));
886
887        let main_product_noise = ring.get_ring().hom_inner_product_noise(&self.noise_estimator, P, &Ctarget, &total_drop, main_products.iter().map(|(lhs, rhs)| (lhs, &rhs.info)));
888        let main_product_part = ring.get_ring().hom_inner_product(P, &Ctarget, &total_drop, main_products.into_iter().map(|(lhs, rhs)| (lhs, rhs.data)));
889
890        return PlainOrCiphertext::Ciphertext(match x.as_ciphertext(P, C_master, ring, self) {
891            Ok((_, x)) => {
892                let x_modswitch = self.mod_switch_down(P, &Ctarget, C_master, &total_drop, x, "HomAdd", debug_sk);
893                ModulusAwareCiphertext {
894                    info: self.noise_estimator.hom_add(P, &Ctarget, &x_modswitch.info, x_modswitch.data.implicit_scale, 
895                        &self.noise_estimator.hom_add(P, &Ctarget, &int_product_noise, P.base_ring().one(), &main_product_noise, P.base_ring().one()),
896                        P.base_ring().one()
897                    ),
898                    data: ring.get_ring().hom_add_to(P, &Ctarget, &total_drop,
899                        &constant.to_ring_el(&ring),
900                        Params::hom_add(P, &Ctarget, x_modswitch.data, Params::hom_add(P, &Ctarget, int_product_part, main_product_part))
901                    ),
902                    dropped_rns_factor_indices: total_drop
903                }
904            },
905            Err(x) => {
906                constant = constant.add(x, ring);
907                // ignore the last plaintext addition for noise analysis, its gonna be fine
908                let res_info = self.noise_estimator.hom_add(P, &Ctarget, &int_product_noise, P.base_ring().one(), &main_product_noise, P.base_ring().one());
909                let product_data = Params::hom_add(P, &Ctarget, int_product_part, main_product_part);
910                let res_data = match constant {
911                    Coefficient::Zero => product_data,
912                    Coefficient::One => Params::hom_add_plain_encoded(P, &Ctarget, &Ctarget.one(), product_data),
913                    Coefficient::NegOne => Params::hom_add_plain_encoded(P, &Ctarget, &Ctarget.neg_one(), product_data),
914                    Coefficient::Integer(c) => Params::hom_add_plain_encoded(P, &Ctarget, &Ctarget.int_hom().map(c), product_data),
915                    Coefficient::Other(m) => ring.get_ring().hom_add_to(P, &Ctarget, &total_drop, &m, product_data),
916                };
917                ModulusAwareCiphertext {
918                    data: res_data,
919                    info: res_info,
920                    dropped_rns_factor_indices: total_drop
921                }
922            }
923        });
924    }
925
926    #[instrument(skip_all)]
927    fn mul<'a, R>(
928        &self,
929        P: &PlaintextRing<Params>,
930        C_master: &CiphertextRing<Params>,
931        x: PlainOrCiphertext<'a, Params, Self, R::Type>,
932        y: PlainOrCiphertext<'a, Params, Self, R::Type>,
933        ring: R,
934        rk: Option<&RelinKey<Params>>,
935        key_switches: &RefCell<&mut usize>,
936        debug_sk: Option<&SecretKey<Params>>
937    ) -> PlainOrCiphertext<'a, Params, Self, R::Type>
938        where R: RingStore + Copy,
939            R::Type: AsBGVPlaintext<Params>
940    {
941        match (x.as_ciphertext(P, C_master, ring, self), y.as_ciphertext(P, C_master, ring, self)) {
942            (Err(x), Err(y)) => PlainOrCiphertext::Plaintext(x.mul(y, ring)),
943            // possibly swap `x` and `y` here so that we can handle both asymmetric cases in one statement
944            (Ok((Cx, x)), Err(y)) | (Err(y), Ok((Cx, x))) => PlainOrCiphertext::Ciphertext({
945                let total_drop = x.dropped_rns_factor_indices.clone();
946                let Ctarget = &Cx;
947                
948                let (res_info, res_data) = match y {
949                    Coefficient::Zero => unreachable!(),
950                    Coefficient::One => (x.info, x.data),
951                    Coefficient::NegOne => (x.info, Params::hom_mul_plain_i64(P, &Ctarget, -1, x.data)),
952                    Coefficient::Integer(c) => (
953                        StaticRing::<i64>::RING.get_ring().hom_mul_to_noise(&self.noise_estimator, P, &Ctarget, &total_drop, &(c as i64), &x.info, &x.data.implicit_scale),
954                        StaticRing::<i64>::RING.get_ring().hom_mul_to(P, &Ctarget, &total_drop, &(c as i64), Params::clone_ct(P, &Cx, &x.data)),
955                    ),
956                    Coefficient::Other(m) => (
957                        ring.get_ring().hom_mul_to_noise(&self.noise_estimator, P, &Ctarget, &total_drop, &m, &x.info, &x.data.implicit_scale),
958                        ring.get_ring().hom_mul_to(P, &Ctarget, &total_drop, &m, Params::clone_ct(P, &Cx, &x.data)),
959                    ),
960                };
961
962                ModulusAwareCiphertext {
963                    data: res_data,
964                    info: res_info,
965                    dropped_rns_factor_indices: total_drop
966                }
967            }),
968            // the ciphertext-ciphertext multiplication case
969            (Ok((_, x)), Ok((_, y))) => PlainOrCiphertext::Ciphertext({
970                **key_switches.borrow_mut() += 1;
971                let rk = rk.unwrap();
972
973                let total_drop = self.compute_optimal_mul_modswitch(P, C_master, &x.info, &x.dropped_rns_factor_indices, &y.info, &y.dropped_rns_factor_indices, rk.params());
974                let special_modulus_factors = RNSFactorIndexList::from(((C_master.base_ring().len() - rk.special_modulus_factor_count)..C_master.base_ring().len()).collect(), C_master.base_ring().len());
975                let total_drop_without_special = total_drop.subtract(&special_modulus_factors);
976                let C_special = Params::mod_switch_down_C(&C_master, &total_drop_without_special);
977                let C_target = Params::mod_switch_down_C(C_master, &total_drop);
978                let rk_modswitch = Params::mod_switch_down_rk(&C_special, C_master, &total_drop_without_special, &rk);
979                debug_assert!(total_drop.len() >= x.dropped_rns_factor_indices.len());
980                debug_assert!(total_drop.len() >= y.dropped_rns_factor_indices.len());
981
982                let x_modswitched = self.mod_switch_down(P, &C_target, C_master, &total_drop, x, "HomMul", debug_sk);
983                let y_modswitched = self.mod_switch_down(P, &C_target, C_master, &total_drop, y, "HomMul", debug_sk);
984
985                let res_data = Params::hom_mul(P, &C_target, &C_special, x_modswitched.data, y_modswitched.data, &rk_modswitch);
986                let res_info = self.noise_estimator.hom_mul(P, &C_target, &C_special, &x_modswitched.info, &y_modswitched.info, &rk_modswitch.params());
987
988                if LOG {
989                    println!("HomMul: Result has estimated noise budget {}/{}",
990                        -self.noise_estimator.estimate_log2_relative_noise_level(P, &C_target, &res_info).round(),
991                        ZZbig.abs_log2_ceil(C_target.base_ring().modulus()).unwrap()
992                    );
993                    if let Some(sk) = debug_sk {
994                        let sk_target = Params::mod_switch_down_sk(&C_target, C_master, &total_drop, sk);
995                        println!("  actual noise budget: {}", Params::noise_budget(P, &C_target, &res_data, &sk_target));
996                    }
997                }
998                ModulusAwareCiphertext {
999                    dropped_rns_factor_indices: total_drop,
1000                    info: res_info,
1001                    data: res_data
1002                }
1003            })
1004        }
1005    }
1006
1007    #[instrument(skip_all)]
1008    fn square<'a, R>(
1009        &self,
1010        P: &PlaintextRing<Params>,
1011        C_master: &CiphertextRing<Params>,
1012        x: PlainOrCiphertext<'a, Params, Self, R::Type>,
1013        ring: R,
1014        rk: Option<&RelinKey<Params>>,
1015        key_switches: &RefCell<&mut usize>,
1016        debug_sk: Option<&SecretKey<Params>>
1017    ) -> PlainOrCiphertext<'a, Params, Self, R::Type>
1018        where R: RingStore + Copy,
1019            R::Type: AsBGVPlaintext<Params>
1020    {
1021        match x.as_ciphertext(P, C_master, ring, self) {
1022            Err(x) => PlainOrCiphertext::Plaintext(x.clone(ring).mul(x, ring)),
1023            Ok((_, x)) => PlainOrCiphertext::Ciphertext({
1024                **key_switches.borrow_mut() += 1;
1025                let rk = rk.unwrap();
1026
1027                let total_drop = self.compute_optimal_mul_modswitch(P, C_master, &x.info, &x.dropped_rns_factor_indices, &x.info, &x.dropped_rns_factor_indices, rk.params());
1028                let special_modulus_factors = RNSFactorIndexList::from(((C_master.base_ring().len() - rk.special_modulus_factor_count)..C_master.base_ring().len()).collect(), C_master.base_ring().len());
1029                let total_drop_without_special = total_drop.subtract(&special_modulus_factors);
1030                let C_special = Params::mod_switch_down_C(&C_master, &total_drop_without_special);
1031                let C_target = Params::mod_switch_down_C(C_master, &total_drop);
1032                let rk_modswitch = Params::mod_switch_down_rk(&C_special, C_master, &total_drop_without_special, &rk);
1033                debug_assert!(total_drop.len() >= x.dropped_rns_factor_indices.len());
1034
1035                let x_modswitched = self.mod_switch_down(P, &C_target, C_master, &total_drop, x, "HomSquare", debug_sk);
1036
1037                let res_info = self.noise_estimator.hom_mul(P, &C_target, &C_special, &x_modswitched.info, &x_modswitched.info, &rk_modswitch.params());
1038                let res_data = Params::hom_square(P, &C_target, &C_special, x_modswitched.data, &rk_modswitch);
1039
1040                if LOG {
1041                    println!("HomSquare: Result has estimated noise budget {}/{}",
1042                        -self.noise_estimator.estimate_log2_relative_noise_level(P, &C_target, &res_info).round(),
1043                        ZZbig.abs_log2_ceil(C_target.base_ring().modulus()).unwrap()
1044                    );
1045                    if let Some(sk) = debug_sk {
1046                        let sk_target = Params::mod_switch_down_sk(&C_target, C_master, &total_drop, sk);
1047                        println!("  actual noise budget: {}", Params::noise_budget(P, &C_target, &res_data, &sk_target));
1048                    }
1049                }
1050                // self.log_modulus_switch("HomMul", P, C_master, &Cx, &Cx, &Ctarget, &x.dropped_rns_factor_indices, &x.dropped_rns_factor_indices, &drop_x, &drop_x, &total_drop, &x_data_copy, &x_data_copy, &res_data, &x.info, &x.info, &res_info, debug_sk);
1051                ModulusAwareCiphertext {
1052                    dropped_rns_factor_indices: total_drop,
1053                    info: res_info,
1054                    data: res_data
1055                }
1056            })
1057        }
1058    }
1059
1060    #[instrument(skip_all)]
1061    fn gal_many<'a, R>(
1062        &self,
1063        P: &PlaintextRing<Params>,
1064        C_master: &CiphertextRing<Params>,
1065        x: PlainOrCiphertext<'a, Params, Self, R::Type>,
1066        ring: R,
1067        gs: &[CyclotomicGaloisGroupEl],
1068        gks: &[(CyclotomicGaloisGroupEl, KeySwitchKey<Params>)],
1069        key_switches: &RefCell<&mut usize>,
1070        debug_sk: Option<&SecretKey<Params>>
1071    ) -> Vec<PlainOrCiphertext<'a, Params, Self, R::Type>>
1072        where R: RingStore + Copy,
1073            R::Type: AsBGVPlaintext<Params>
1074    {
1075        match x.as_ciphertext(P, C_master, ring, self) {
1076            Ok((Cx, x)) => {
1077                **key_switches.borrow_mut() += gs.len();
1078                
1079                // to compute Galois automorphisms, we require key-switching; hence, the ciphertext must
1080                // not contain RNS factors belonging to the special modulus - if it does, drop those RNS
1081                // factors here
1082                let special_modulus_factors = RNSFactorIndexList::from(((C_master.base_ring().len() - gks[0].1.special_modulus_factor_count)..C_master.base_ring().len()).collect(), C_master.base_ring().len());
1083                let (total_drop, C_target, x) = if !x.dropped_rns_factor_indices.contains_all(&special_modulus_factors) {
1084                    let total_drop = x.dropped_rns_factor_indices.union(&special_modulus_factors);
1085                    let C_target = Params::mod_switch_down_C(&Cx, &special_modulus_factors);
1086                    let x = self.mod_switch_down(P, &C_target, C_master, &total_drop, x, "HomGalois", debug_sk);
1087                    (total_drop, C_target, x)
1088                } else {
1089                    (x.dropped_rns_factor_indices.clone(), Cx, x)
1090                };
1091                let total_drop_without_special = total_drop.subtract(&special_modulus_factors);
1092                let C_special = Params::mod_switch_down_C(&C_master, &total_drop_without_special);
1093
1094                let gks_mod_switched = (0..gs.len()).map(|i| {
1095                    if let Some((_, gk)) = gks.iter().filter(|(provided_g, _)| C_master.galois_group().eq_el(gs[i], *provided_g)).next() {
1096                        Params::mod_switch_down_gk(&C_special, C_master, &total_drop_without_special, gk)
1097                    } else {
1098                        panic!("Galois key for {} not found", C_master.galois_group().representative(gs[i]))
1099                    }
1100                }).collect::<Vec<_>>();
1101        
1102                let result = if gs.len() == 1 {
1103                    vec![Params::hom_galois(P, &C_target, &C_special, x.data, gs[0], gks_mod_switched.at(0))]
1104                } else {
1105                    Params::hom_galois_many(P, &C_target, &C_special, x.data, gs, gks_mod_switched.as_fn())
1106                };
1107                result.into_iter().zip(gs.into_iter()).zip(gks_mod_switched.iter()).map(|((res, g), gk)| PlainOrCiphertext::Ciphertext(ModulusAwareCiphertext {
1108                    dropped_rns_factor_indices: total_drop.clone(),
1109                    info: self.noise_estimator.hom_galois(&P, &C_target, &C_special, &x.info, *g, &gk.params()),
1110                    data: res
1111                })).collect()
1112            },
1113            Err(Coefficient::Other(x)) => ring.get_ring().apply_galois_action_plain(P, &x, gs).into_iter().map(|x| PlainOrCiphertext::Plaintext(Coefficient::Other(x))).collect(),
1114            // integers are preserved under all galois automorphisms
1115            Err(x) => gs.iter().map(|_| PlainOrCiphertext::Plaintext(x.clone(ring))).collect()
1116        }
1117    }
1118}
1119
1120impl<Params: BGVCiphertextParams, N: BGVNoiseEstimator<Params>, const LOG: bool> BGVModswitchStrategy<Params> for DefaultModswitchStrategy<Params, N, LOG> {
1121
1122    type CiphertextInfo = N::CriticalQuantityLevel;
1123
1124    #[instrument(skip_all)]
1125    fn evaluate_circuit<R>(
1126        &self,
1127        circuit: &PlaintextCircuit<R::Type>,
1128        ring: R,
1129        P: &PlaintextRing<Params>,
1130        C_master: &CiphertextRing<Params>,
1131        inputs: &[ModulusAwareCiphertext<Params, Self>],
1132        rk: Option<&RelinKey<Params>>,
1133        gks: &[(CyclotomicGaloisGroupEl, KeySwitchKey<Params>)],
1134        key_switches: &mut usize,
1135        mut debug_sk: Option<&SecretKey<Params>>
1136    ) -> Vec<ModulusAwareCiphertext<Params, Self>>
1137        where R: RingStore,
1138            R::Type: AsBGVPlaintext<Params>
1139    {
1140        if !LOG {
1141            debug_sk = None;
1142        }
1143        let key_switches_refcell = std::cell::RefCell::new(key_switches);
1144
1145        let result = circuit.evaluate_generic(
1146            &inputs.iter().map(PlainOrCiphertext::CiphertextRef).collect::<Vec<_>>(),
1147            DefaultCircuitEvaluator::new(
1148                |x, y| self.mul(P, C_master, x, y, &ring, rk, &key_switches_refcell, debug_sk),
1149                |m| PlainOrCiphertext::PlaintextRef(m),
1150                |_, _, _| unreachable!(),
1151            ).with_square(
1152                |x| self.square(P, C_master, x, &ring, rk, &key_switches_refcell, debug_sk),
1153            ).with_gal(
1154                |x, gs| self.gal_many(P, C_master, x, &ring, gs, gks, &key_switches_refcell, debug_sk)
1155            ).with_inner_product(
1156                |x, cs, ys| self.add_inner_prod(P, C_master, x, cs, ys, &ring, debug_sk)
1157            )
1158        );
1159        return result.into_iter().map(|res| match res {
1160            PlainOrCiphertext::Ciphertext(x) => x,
1161            PlainOrCiphertext::CiphertextRef(x) => {
1162                let Cx = Params::mod_switch_down_C(C_master, &x.dropped_rns_factor_indices);
1163                ModulusAwareCiphertext {
1164                    data: Params::clone_ct(&P, &Cx, &x.data),
1165                    dropped_rns_factor_indices: x.dropped_rns_factor_indices.clone(),
1166                    info: self.clone_info(&x.info)
1167                }
1168            },
1169            PlainOrCiphertext::Plaintext(x) => {
1170                let x = x.to_ring_el(&ring);
1171                let res_info = ring.get_ring().hom_add_to_noise(&self.noise_estimator, P, C_master, &RNSFactorIndexList::empty(), &x, &self.noise_estimator.transparent_zero(), &P.base_ring().one());
1172                let res_data = ring.get_ring().hom_add_to(P, C_master, &RNSFactorIndexList::empty(), &x, Params::transparent_zero(P, C_master));
1173                ModulusAwareCiphertext {
1174                    data: res_data,
1175                    dropped_rns_factor_indices: RNSFactorIndexList::empty(),
1176                    info: res_info
1177                }
1178            },
1179            PlainOrCiphertext::PlaintextRef(x) => {
1180                let x = x.clone(&ring).to_ring_el(&ring);
1181                let res_info = ring.get_ring().hom_add_to_noise(&self.noise_estimator, P, C_master, &RNSFactorIndexList::empty(), &x, &self.noise_estimator.transparent_zero(), &P.base_ring().one());
1182                let res_data = ring.get_ring().hom_add_to(P, C_master, &RNSFactorIndexList::empty(), &x, Params::transparent_zero(P, C_master));
1183                ModulusAwareCiphertext {
1184                    data: res_data,
1185                    dropped_rns_factor_indices: RNSFactorIndexList::empty(),
1186                    info: res_info
1187                }
1188            }
1189        }).collect();
1190    }
1191
1192    fn info_for_fresh_encryption(&self, P: &PlaintextRing<Params>, C: &CiphertextRing<Params>, hwt: Option<usize>) -> <Self as BGVModswitchStrategy<Params>>::CiphertextInfo {
1193        self.from_noise_level(self.noise_estimator.enc_sym_zero(P, C, hwt))
1194    }
1195
1196    fn clone_info(&self, info: &Self::CiphertextInfo) -> Self::CiphertextInfo {
1197        self.noise_estimator.clone_critical_quantity_level(info)
1198    }
1199
1200    fn print_info(&self, P: &PlaintextRing<Params>, C_master: &CiphertextRing<Params>, ct: &ModulusAwareCiphertext<Params, Self>) {
1201        let Clocal = Params::mod_switch_down_C(C_master, &ct.dropped_rns_factor_indices);
1202        println!("estimated noise: {}", self.noise_estimator.estimate_log2_relative_noise_level(P, &Clocal, &ct.info));
1203    }
1204}
1205
1206#[cfg(test)]
1207use crate::bgv::noise_estimator::NaiveBGVNoiseEstimator;
1208
1209#[test]
1210fn test_default_modswitch_strategy_plain() {
1211    let mut rng = thread_rng();
1212
1213    let params = Pow2BGV {
1214        log2_q_min: 500,
1215        log2_q_max: 520,
1216        log2_N: 7,
1217        ciphertext_allocator: DefaultCiphertextAllocator::default(),
1218        negacyclic_ntt: PhantomData::<DefaultNegacyclicNTT>
1219    };
1220    let t = 257;
1221    
1222    let P = params.create_plaintext_ring(t);
1223    let C = params.create_initial_ciphertext_ring();
1224
1225    let sk = Pow2BGV::gen_sk(&C, &mut rng, None);
1226    let rk = Pow2BGV::gen_rk(&P, &C, &mut rng, &sk, &KeySwitchKeyParams::default(3, 0, C.base_ring().len()));
1227
1228    let input = P.int_hom().map(2);
1229    let ctxt = Pow2BGV::enc_sym(&P, &C, &mut rng, &input, &sk);
1230
1231    let modswitch_strategy: DefaultModswitchStrategy<Pow2BGV, _, true> = DefaultModswitchStrategy::new(NaiveBGVNoiseEstimator);
1232    let pow8_circuit = PlaintextCircuit::mul(ZZ)
1233        .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
1234        .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
1235        .compose(PlaintextCircuit::identity(1, ZZ).output_twice(ZZ), ZZ);
1236
1237    let res = modswitch_strategy.evaluate_circuit(
1238        &pow8_circuit,
1239        ZZi64,
1240        &P,
1241        &C,
1242        &[ModulusAwareCiphertext {
1243            dropped_rns_factor_indices: RNSFactorIndexList::empty(),
1244            info: modswitch_strategy.info_for_fresh_encryption(&P, &C, None),
1245            data: ctxt
1246        }],
1247        Some(&rk),
1248        &[],
1249        &mut 0,
1250        Some(&sk)
1251    ).into_iter().next().unwrap();
1252
1253    let res_C = Pow2BGV::mod_switch_down_C(&C, &res.dropped_rns_factor_indices);
1254    let res_sk = Pow2BGV::mod_switch_down_sk(&res_C, &C, &res.dropped_rns_factor_indices, &sk);
1255
1256    let res_noise = Pow2BGV::noise_budget(&P, &res_C, &res.data, &res_sk);
1257    println!("Actual output noise budget is {}", res_noise);
1258    assert_el_eq!(&P, &P.neg_one(), Pow2BGV::dec(&P, &res_C, res.data, &res_sk));
1259}
1260
///
/// Raises a fresh encryption of `2` to the 8th power modulo `t = 257` (so the expected
/// result is `2^8 = 256 = -1`). The default modulus-switching strategy decides when to
/// modulus-switch.
///
/// In contrast to the non-hybrid test, the relinearization key is generated with
/// `KeySwitchKeyParams::default(3, 2, ..)` instead of `(3, 0, ..)`. The test name
/// suggests this selects hybrid key switching (NOTE(review): confirm the meaning of the
/// second argument).
///
#[test]
fn test_default_modswitch_strategy_hybrid_keyswitch() {
    let mut rng = thread_rng();

    let params = Pow2BGV {
        log2_q_min: 500,
        log2_q_max: 520,
        log2_N: 7,
        ciphertext_allocator: DefaultCiphertextAllocator::default(),
        negacyclic_ntt: PhantomData::<DefaultNegacyclicNTT>
    };
    let plaintext_modulus = 257;
    let P = params.create_plaintext_ring(plaintext_modulus);
    let C = params.create_initial_ciphertext_ring();

    let sk = Pow2BGV::gen_sk(&C, &mut rng, None);
    let rk = {
        let key_switch_params = KeySwitchKeyParams::default(3, 2, C.base_ring().len());
        Pow2BGV::gen_rk(&P, &C, &mut rng, &sk, &key_switch_params)
    };
    let ctxt = Pow2BGV::enc_sym(&P, &C, &mut rng, &P.int_hom().map(2), &sk);

    let strategy: DefaultModswitchStrategy<Pow2BGV, _, true> = DefaultModswitchStrategy::new(NaiveBGVNoiseEstimator);
    // duplicate the input, then square three times: x -> x^2 -> x^4 -> x^8
    let x_pow_8 = PlaintextCircuit::mul(ZZ)
        .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
        .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
        .compose(PlaintextCircuit::identity(1, ZZ).output_twice(ZZ), ZZ);

    let output = strategy.evaluate_circuit(
        &x_pow_8,
        ZZi64,
        &P,
        &C,
        &[ModulusAwareCiphertext {
            data: ctxt,
            info: strategy.info_for_fresh_encryption(&P, &C, None),
            dropped_rns_factor_indices: RNSFactorIndexList::empty()
        }],
        Some(&rk),
        &[],
        &mut 0,
        Some(&sk)
    ).into_iter().next().unwrap();

    // bring the ring and the secret key down to the modulus the result ended up at
    let dropped = &output.dropped_rns_factor_indices;
    let out_ring = Pow2BGV::mod_switch_down_C(&C, dropped);
    let out_sk = Pow2BGV::mod_switch_down_sk(&out_ring, &C, dropped, &sk);

    let output_noise = Pow2BGV::noise_budget(&P, &out_ring, &output.data, &out_sk);
    println!("Actual output noise budget is {}", output_noise);
    assert_el_eq!(&P, &P.neg_one(), Pow2BGV::dec(&P, &out_ring, output.data, &out_sk));
}
1313
///
/// Evaluates power circuits on a fresh encryption of `2` modulo `t = 257` using
/// [`DefaultModswitchStrategy::never_modswitch()`].
///
/// Two cases are checked:
///  - `x^4` (computed as `(x^2)^2`) must still decrypt to `16`.
///  - the deeper `x^8` circuit is expected to consume the whole noise budget, so the
///    asserted remaining budget is `0`.
///
#[test]
fn test_never_modswitch_strategy() {
    let mut rng = thread_rng();

    let params = Pow2BGV {
        log2_q_min: 500,
        log2_q_max: 520,
        log2_N: 7,
        ciphertext_allocator: DefaultCiphertextAllocator::default(),
        negacyclic_ntt: PhantomData::<DefaultNegacyclicNTT>
    };
    let plaintext_modulus = 257;
    let P = params.create_plaintext_ring(plaintext_modulus);
    let C = params.create_initial_ciphertext_ring();

    let sk = Pow2BGV::gen_sk(&C, &mut rng, None);
    let rk = {
        let key_switch_params = KeySwitchKeyParams::default(3, 0, C.base_ring().len());
        Pow2BGV::gen_rk(&P, &C, &mut rng, &sk, &key_switch_params)
    };
    let ctxt = Pow2BGV::enc_sym(&P, &C, &mut rng, &P.int_hom().map(2), &sk);

    // case 1: x^4 = (x^2)^2, which should still decrypt correctly
    {
        let strategy = DefaultModswitchStrategy::never_modswitch();
        let x_pow_4 = PlaintextCircuit::mul(ZZ)
            .compose(PlaintextCircuit::square(ZZ).output_twice(ZZ), ZZ);

        let output = strategy.evaluate_circuit(
            &x_pow_4,
            ZZi64,
            &P,
            &C,
            &[ModulusAwareCiphertext {
                data: Pow2BGV::clone_ct(&P, &C, &ctxt),
                info: strategy.info_for_fresh_encryption(&P, &C, None),
                dropped_rns_factor_indices: RNSFactorIndexList::empty()
            }],
            Some(&rk),
            &[],
            &mut 0,
            None
        ).into_iter().next().unwrap();

        // ring and key matching whatever RNS factors the result has dropped
        let dropped = &output.dropped_rns_factor_indices;
        let out_ring = Pow2BGV::mod_switch_down_C(&C, dropped);
        let out_sk = Pow2BGV::mod_switch_down_sk(&out_ring, &C, dropped, &sk);

        let output_noise = Pow2BGV::noise_budget(&P, &out_ring, &output.data, &out_sk);
        println!("Actual output noise budget is {}", output_noise);
        assert_el_eq!(&P, &P.int_hom().map(16), Pow2BGV::dec(&P, &out_ring, output.data, &out_sk));
    }

    // case 2: x^8 = ((x^2)^2)^2, expected to exhaust the noise budget
    {
        let strategy = DefaultModswitchStrategy::never_modswitch();
        let x_pow_8 = PlaintextCircuit::mul(ZZ)
            .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
            .compose(PlaintextCircuit::mul(ZZ).output_twice(ZZ), ZZ)
            .compose(PlaintextCircuit::identity(1, ZZ).output_twice(ZZ), ZZ);

        let output = strategy.evaluate_circuit(
            &x_pow_8,
            ZZi64,
            &P,
            &C,
            &[ModulusAwareCiphertext {
                data: Pow2BGV::clone_ct(&P, &C, &ctxt),
                info: strategy.info_for_fresh_encryption(&P, &C, None),
                dropped_rns_factor_indices: RNSFactorIndexList::empty()
            }],
            Some(&rk),
            &[],
            &mut 0,
            None
        ).into_iter().next().unwrap();

        let dropped = &output.dropped_rns_factor_indices;
        let out_ring = Pow2BGV::mod_switch_down_C(&C, dropped);
        let out_sk = Pow2BGV::mod_switch_down_sk(&out_ring, &C, dropped, &sk);

        assert_eq!(0, Pow2BGV::noise_budget(&P, &out_ring, &output.data, &out_sk));
    }
}