he_ring/digitextract/
mod.rs

1
2use feanor_math::algorithms::miller_rabin::is_prime;
3use feanor_math::divisibility::*;
4use feanor_math::primitive_int::{StaticRing, StaticRingBase};
5use feanor_math::ring::*;
6use feanor_math::rings::poly::dense_poly::DensePolyRing;
7use feanor_math::rings::zn::zn_64::Zn;
8use feanor_math::homomorphism::*;
9use polys::{digit_retain_poly, poly_to_circuit, precomputed_p_2};
10use tracing::instrument;
11
12use crate::circuit::PlaintextCircuit;
13
14pub mod polys;
15
16///
17/// The digit extraction operation, as required during BFV and
18/// BGV bootstrapping.
19/// 
20/// Concretely, this encapsulates an efficient implementation of the
21/// per-slot digit extraction function
22/// ```text
23///   Z/p^eZ -> Z/p^rZ x Z/p^eZ,  x -> (x - (x mod p^v) / p^v, x mod p^v)
24/// ```
25/// for `v = e - r`. Here `x mod p^v` refers to the smallest positive element
26/// of `Z/p^eZ` that is congruent to `x` modulo `p^v`.
27/// 
28/// This function can also be applied to values in a ring `Z/p^e'Z` for
29/// `e' > e`, i.e. it will then have the signature
30/// ```text
31///   Z/p^e'Z -> Z/p^(e' - e + r)Z x Z/p^e'Z
32/// ```
33/// In this case, the results are only specified modulo `p^r` resp. `p^e`, i.e.
34/// may be perturbed by an arbitrary value `p^r a` resp. `p^e a'`.
35/// 
36pub struct DigitExtract<R: ?Sized + RingBase = StaticRingBase<i64>> {
37    extraction_circuits: Vec<(Vec<usize>, PlaintextCircuit<R>)>,
38    /// the one-input, one-output identity circuit
39    identity_circuit: PlaintextCircuit<R>,
40    /// the two-input, one-output addition circuit
41    add_circuit: PlaintextCircuit<R>,
42    /// the two-input, one-output subtraction circuit
43    sub_circuit: PlaintextCircuit<R>,
44    v: usize,
45    e: usize,
46    p: i64
47}
48
49impl DigitExtract {
50
51    ///
52    /// Creates a [`DigitExtract`] for a scalar ring `Z/2^eZ`.
53    /// 
54    /// Uses the precomputed table of best digit extraction circuits for `e <= 23`.
55    /// 
56    #[instrument(skip_all)]
57    pub fn new_precomputed_p_is_2(p: i64, e: usize, r: usize) -> Self {
58        assert_eq!(2, p);
59        assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
60        return Self::new_with(
61            p, 
62            e, 
63            r, 
64            StaticRing::<i64>::RING, 
65            [1, 2, 4, 8, 16, 23].into_iter().map(|e| (
66                [1, 2, 4, 8, 16, 23].into_iter().take_while(|i| *i <= e).collect(),
67                precomputed_p_2(e)
68            )).collect::<Vec<_>>()
69        );
70    }
71    
72    ///
73    /// Creates a [`DigitExtract`] for a scalar ring `Z/p^eZ`.
74    /// 
75    /// Uses the Chen-Han digit retain polynomials <https://ia.cr/2018/067> together with
76    /// a heuristic method to compile them into an arithmetic circuit, based on the
77    /// Paterson-Stockmeyer method.
78    /// 
79    #[instrument(skip_all)]
80    pub fn new_default(p: i64, e: usize, r: usize) -> Self {
81        assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
82        assert!(e > r);
83        let v = e - r;
84        
85        let digit_extraction_circuits = (1..=v).rev().map(|i| {
86            let required_digits = (2..=(v - i + 1)).chain([r + v - i + 1].into_iter()).collect::<Vec<_>>();
87            let poly_ring = DensePolyRing::new(Zn::new(StaticRing::<i64>::RING.pow(p, *required_digits.last().unwrap()) as u64), "X");
88            let circuit = poly_to_circuit(&poly_ring, &required_digits.iter().map(|j| digit_retain_poly(&poly_ring, *j)).collect::<Vec<_>>());
89            return (required_digits, circuit);
90        }).collect::<Vec<_>>();
91        assert!(digit_extraction_circuits.is_sorted_by_key(|(digits, _)| *digits.last().unwrap()));
92        
93        return Self::new_with(p, e, r, StaticRing::<i64>::RING, digit_extraction_circuits);
94    }
95}
96
97impl<R: ?Sized + RingBase> DigitExtract<R> {
98
99    ///
100    /// Creates a new [`DigitExtract`] from the given circuits.
101    /// 
102    /// This functions expects the list of circuits to contain tuples `(digits, C)`,
103    /// where the circuit `C` takes a single input and computes `digits.len()` outputs, 
104    /// such that the `i`-th output is congruent to `lift(input mod p)` modulo 
105    /// `p^digits[i]`.
106    /// 
107    /// If you want to use the default choice of circuits, consider using [`DigitExtract::new_default()`].
108    /// 
109    pub fn new_with<S: Copy + RingStore<Type = R>>(p: i64, e: usize, r: usize, ring: S, extraction_circuits: Vec<(Vec<usize>, PlaintextCircuit<R>)>) -> Self {
110        assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
111        assert!(e > r);
112        for (digits, circuit) in &extraction_circuits {
113            assert!(digits.is_sorted());
114            assert_eq!(digits.len(), circuit.output_count());
115            assert_eq!(1, circuit.input_count());
116        }
117        assert!(extraction_circuits.iter().any(|(digits, _)| *digits.last().unwrap() >= e));
118        Self {
119            extraction_circuits: extraction_circuits,
120            add_circuit: PlaintextCircuit::add(ring),
121            sub_circuit: PlaintextCircuit::sub(ring),
122            identity_circuit: PlaintextCircuit::identity(1, ring),
123            v: e - r,
124            p: p,
125            e: e
126        }
127    }
128
129    pub fn r(&self) -> usize {
130        self.e - self.v
131    }
132
133    pub fn e(&self) -> usize {
134        self.e
135    }
136
137    pub fn v(&self) -> usize {
138        self.v
139    }
140
141    pub fn p(&self) -> i64 {
142        self.p
143    }
144    
145    ///
146    /// Evaluates the digit extraction function over any representation of elements of `Z/p^iZ`, which
147    /// supports the evaluation of [`PlaintextCircuit`]s. Since digit extraction requires computations
148    /// in all the rings `Z/p^(r - 1)Z, ...., Z/p^eZ`, we also require a `change_space` function, with
149    /// the following properties:
150    /// ```text
151    ///   change_space(e, e', .): Z/p^eZ -> Z/p^e' Z
152    ///   change_space(e, e', x mod p^e) = x p^(e' - e) mod p^e'      if e' > e
153    ///   change_space(e, e', x mod p^e) = x / p^(e - e') mod p^e'    if e' < e and p^(e - e') | x
154    /// ```
155    /// If the passed functions behave as specified, `change_space(e, e', x)` will never be called for
156    /// `e' < e` and an `x` which is not divisible by `p^(e - e')`.
157    /// 
158    /// Furthermore, the `eval_circuit` is given the exponent of the current ring we work in as the first
159    /// parameter. The result of [`DigitExtract::evaluate_generic()`] is then the tuple `(quo, rem)` with
160    /// `quo` in `Z/p^rZ` and `rem` in `Z/p^eZ` such that `x = p^(e - r) * quo + rem` and `rem < p^(e - r)`.
161    /// 
162    /// If [`DigitExtract`] is used on elements of `Z/p^e'Z` with `e' > e` (as mentioned at the end of
163    /// the doc of [`DigitExtract`]), the moduli passed to `eval_circuit()` and `change_space()` remain
164    /// nevertheless unchanged - after all, `evaluate_generic()` does not know that we are in a larger
165    /// ring. If necessary, you have to manually offset all exponents passed to `eval_circuit` and 
166    /// `change_space` by `e' - e`.
167    /// 
168    pub fn evaluate_generic<T, EvalCircuit, ChangeSpace>(&self, 
169        input: T,
170        mut eval_circuit: EvalCircuit,
171        mut change_space: ChangeSpace
172    ) -> (T, T) 
173        where EvalCircuit: FnMut(/* exponent of p */ usize, &[T], &PlaintextCircuit<R>) -> Vec<T>,
174            ChangeSpace: FnMut(/* input exponent of p */ usize, /* output exponent of p */ usize, T) -> T
175    {
176        let e = self.e;
177        let r = self.e - self.v;
178
179        enum OneOrTwoValues<T> {
180            One(T), Two([T; 2])
181        }
182
183        impl<T> OneOrTwoValues<T> {
184
185            fn with_first_el<'a>(&'a mut self, first: T) -> &'a mut [T; 2] {
186                take_mut::take(self, |value| match value {
187                    OneOrTwoValues::One(second) => OneOrTwoValues::Two([first, second]),
188                    OneOrTwoValues::Two([_, second]) => OneOrTwoValues::Two([first, second])
189                });
190                return match self {
191                    OneOrTwoValues::One(_) => unreachable!(),
192                    OneOrTwoValues::Two(data) => data
193                };
194            }
195
196            fn get_second<'a>(&'a self) -> &'a T {
197                match self {
198                    OneOrTwoValues::One(second) => second,
199                    OneOrTwoValues::Two([_, second]) => second
200                }
201            }
202        }
203
204        let clone_value = |modulus_exp: usize, value: &T, eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, std::slice::from_ref(value), &self.identity_circuit).into_iter().next().unwrap();
205        let sub_values = |modulus_exp: usize, params: &[T; 2], eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, params, &self.sub_circuit).into_iter().next().unwrap();
206        let add_values = |modulus_exp: usize, params: &[T; 2], eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, params, &self.add_circuit).into_iter().next().unwrap();
207
208        let mut mod_result: Option<T> = None;
209        let mut partial_floor_divs = (0..self.v).map(|_| Some(clone_value(e, &input, &mut eval_circuit))).collect::<Vec<_>>();
210        let mut floor_div_result = input;
211        for i in 0..self.v {
212            let remaining_digits = e - i;
213            debug_assert!(self.extraction_circuits.is_sorted_by_key(|(digits, _)| *digits.last().unwrap()));
214            let (use_circuit_digits, use_circuit) = self.extraction_circuits.iter().filter(|(digits, _)| *digits.last().unwrap() >= remaining_digits).next().unwrap();
215            debug_assert!(use_circuit_digits.is_sorted());
216
217            let current = change_space(e, remaining_digits, partial_floor_divs[i].take().unwrap());
218            let digit_extracted = eval_circuit(remaining_digits, std::slice::from_ref(&current), use_circuit);
219            let mut digit_extracted = digit_extracted.into_iter().map(|value| OneOrTwoValues::One(change_space(remaining_digits, e, value))).collect::<Vec<_>>();
220            
221            let last_digit_extracted = digit_extracted.last_mut().unwrap();
222            take_mut::take(&mut floor_div_result, |current| sub_values(e, last_digit_extracted.with_first_el(current), &mut eval_circuit));
223            if let Some(mod_result) = &mut mod_result {
224                take_mut::take(mod_result, |current| add_values(e, last_digit_extracted.with_first_el(current), &mut eval_circuit));
225            } else {
226                mod_result = Some(clone_value(e, last_digit_extracted.get_second(), &mut eval_circuit));
227            }
228            for j in (i + 1)..self.v {
229                let digit_extracted_index = use_circuit_digits.iter().enumerate().filter(|(_, cleared_digits)| **cleared_digits > j - i).next().unwrap().0;
230                take_mut::take(partial_floor_divs[j].as_mut().unwrap(), |current| sub_values(e, digit_extracted[digit_extracted_index].with_first_el(current), &mut eval_circuit));
231            }
232        }
233
234        return (change_space(e, r, floor_div_result), mod_result.unwrap());
235    }
236
237    ///
238    /// Computes `(quo, rem)` with `input = quo * p^(e - r) + rem` and `rem < p^(e - r)`.
239    /// Note that both `quo` and `rem` are returned as elements of `Z/p^eZ`, which means that
240    /// `quo` is defined only up to a multiple of `p^r`.
241    /// 
242    /// This function is designed to test digit extraction, since `quo` and `rem` will be computed
243    /// exactly in the same way as in a homomorphic setting. Note also that performing euclidean
244    /// division can be done much easier with [`feanor_math::pid::EuclideanRing::euclidean_div_rem()`]
245    /// when you have access to the ring elements.
246    /// 
247    /// This function does not perform any checks on the underlying ring, in particular, you can
248    /// call it on an input in `Z/p^e'Z` with `e' > e` or an input in `Z`. Of course, in any case,
249    /// the output will only be correct modulo `p^r` resp. `p^e`.
250    /// 
251    pub fn evaluate<H, S>(&self, input: S::Element, hom: H) -> (S::Element, S::Element)
252        where H: Homomorphism<R, S>,
253            S: ?Sized + RingBase + DivisibilityRing
254    {
255        let p = hom.codomain().int_hom().map(self.p as i32);
256        self.evaluate_generic(
257            input,
258            |_, params, circuit| circuit.evaluate_no_galois(params, &hom),
259            |from, to, x| if from < to {
260                hom.codomain().mul(x, hom.codomain().pow(hom.codomain().clone_el(&p), to - from))
261            } else {
262                hom.codomain().checked_div(&x, &hom.codomain().pow(hom.codomain().clone_el(&p), from - to)).unwrap()
263            }
264        )
265    }
266}
267
268#[cfg(test)]
269use feanor_math::rings::zn::ZnRingStore;
270#[cfg(test)]
271use feanor_math::assert_el_eq;
272#[cfg(test)]
273use feanor_math::divisibility::DivisibilityRingStore;
274#[cfg(test)]
275use feanor_math::rings::extension::FreeAlgebraStore;
276#[cfg(test)]
277use feanor_math::seq::VectorFn;
278#[cfg(test)]
279use rand::SeedableRng;
280#[cfg(test)]
281use rand::rngs::StdRng;
282#[cfg(test)]
283use crate::bfv::*;
284#[cfg(test)]
285use crate::DefaultNegacyclicNTT;
286#[cfg(test)]
287use std::alloc::Global;
288#[cfg(test)]
289use std::marker::PhantomData;
290
#[test]
fn test_digit_extract() {
    // p = 3, e = 5, r = 2, so p^(e - r) = 27 and p^r = 9
    let extractor = DigitExtract::new_default(3, 5, 2);
    let zn = Zn::new(StaticRing::<i64>::RING.pow(3, 5) as u64);
    let from_int = zn.can_hom(&StaticRing::<i64>::RING).unwrap();

    for x in 0..*zn.modulus() {
        // every value is tagged with the exponent of the ring it is currently considered in,
        // which allows to check that the callbacks are invoked with consistent exponents
        let ((quo_exp, quo), (rem_exp, rem)) = extractor.evaluate_generic(
            (5, from_int.map(x)),
            |exp, inputs, circuit| {
                assert!(inputs.iter().all(|&(input_exp, _)| input_exp == exp));
                let untagged = inputs.iter().map(|&(_, value)| value).collect::<Vec<_>>();
                circuit.evaluate_no_galois(&untagged, &from_int).into_iter().map(|value| (exp, value)).collect()
            },
            |from, to, (exp, value)| {
                assert_eq!(from, exp);
                let moved = if from < to {
                    zn.mul(value, zn.pow(from_int.map(3), to - from))
                } else {
                    zn.checked_div(&value, &zn.pow(from_int.map(3), from - to)).unwrap()
                };
                (to, moved)
            }
        );
        assert_eq!(5, rem_exp);
        assert_el_eq!(&zn, from_int.map(x % 27), rem);
        assert_eq!(2, quo_exp);
        assert_eq!(x / 27, zn.smallest_positive_lift(quo) % 9);
    }
}
319
#[test]
fn test_digit_extract_homomorphic() {
    let mut rng = StdRng::from_seed([1; 32]);
    
    let params = Pow2BFV {
        log2_q_min: 500,
        log2_q_max: 520,
        log2_N: 6,
        ciphertext_allocator: Global,
        negacyclic_ntt: PhantomData::<DefaultNegacyclicNTT>
    };
    let digits = 3;
    
    // the input lives modulo 17^3 (`P2`), the extracted high part modulo 17^2 (`P1`)
    let P1 = params.create_plaintext_ring(17 * 17);
    let P2 = params.create_plaintext_ring(17 * 17 * 17);
    let (C, Cmul) = params.create_ciphertext_rings();

    let sk = Pow2BFV::gen_sk(&C, &mut rng, None);
    let rk = Pow2BFV::gen_rk(&C, &mut rng, &sk, digits);

    // base-17 digits of the message are (5, 2, 1), lowest digit first
    let m = P2.int_hom().map(17 * 17 + 2 * 17 + 5);
    let ct = Pow2BFV::enc_sym(&P2, &C, &mut rng, &m, &sk);

    let digitextract = DigitExtract::new_default(17, 2, 1);

    let (ct_high, ct_low) = digitextract.evaluate_bfv::<Pow2BFV>(&P1, std::slice::from_ref(&P2), &C, &Cmul, ct, &rk);

    // the high part is only specified modulo p^r = 17
    let m_high = Pow2BFV::dec(&P1, &C, Pow2BFV::clone_ct(&C, &ct_high), &sk);
    assert!(P1.wrt_canonical_basis(&m_high).iter().skip(1).all(|x| P1.base_ring().is_zero(&x)));
    let m_high = P1.base_ring().smallest_positive_lift(P1.wrt_canonical_basis(&m_high).at(0));
    assert_eq!(2, m_high % 17);

    // the low part is decrypted w.r.t. `P2`, so it must also be inspected through `P2`;
    // it is only specified modulo p^e = 17^2
    let m_low = Pow2BFV::dec(&P2, &C, Pow2BFV::clone_ct(&C, &ct_low), &sk);
    assert!(P2.wrt_canonical_basis(&m_low).iter().skip(1).all(|x| P2.base_ring().is_zero(&x)));
    let m_low = P2.base_ring().smallest_positive_lift(P2.wrt_canonical_basis(&m_low).at(0));
    assert_eq!(5, m_low % (17 * 17));
}
357
#[test]
fn test_digit_extract_evaluate() {
    // each entry is `(p, e, r, p^(e - r), p^r)`; the ring under test is `Z/p^eZ`
    for (p, e, r, low_modulus, high_modulus) in [(2, 4, 2, 4, 4), (3, 4, 2, 9, 9), (5, 3, 2, 5, 25)] {
        let ring_modulus: i32 = low_modulus * high_modulus;
        let ring = Zn::new(ring_modulus as u64);
        let digit_extract = DigitExtract::new_default(p, e, r);
        // exhaustively check every element of the ring
        for x in 0..ring_modulus {
            let from_int = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
            let (actual_high, actual_low) = digit_extract.evaluate(ring.int_hom().map(x), from_int);
            assert_eq!(x / low_modulus, ring.smallest_positive_lift(actual_high) as i32 % high_modulus);
            assert_eq!(x % low_modulus, ring.smallest_positive_lift(actual_low) as i32);
        }
    }
}
384
#[test]
fn test_digit_extract_evaluate_ignore_higher() {
    // each entry is `(ring modulus, p, e, r, p^(e - r), p^r)`; the ring modulus is a power
    // of `p` larger than `p^e`, so the results are only compared modulo `p^r` resp. `p^e`
    for (ring_modulus, p, e, r, low_modulus, high_modulus) in [(64, 2, 4, 2, 4, 4), (243, 3, 4, 2, 9, 9), (625, 5, 3, 2, 5, 25)] {
        let input_modulus: i32 = low_modulus * high_modulus;
        let ring = Zn::new(ring_modulus as u64);
        let digit_extract = DigitExtract::new_default(p, e, r);
        // exhaustively check every element of the larger ring
        for x in 0..ring_modulus {
            let from_int = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
            let (actual_high, actual_low) = digit_extract.evaluate(ring.int_hom().map(x), from_int);
            assert_eq!((x / low_modulus) % high_modulus, ring.smallest_positive_lift(actual_high) as i32 % high_modulus);
            assert_eq!(x % low_modulus, ring.smallest_positive_lift(actual_low) as i32 % input_modulus);
        }
    }
}
410}