1
2use feanor_math::algorithms::miller_rabin::is_prime;
3use feanor_math::divisibility::*;
4use feanor_math::primitive_int::{StaticRing, StaticRingBase};
5use feanor_math::ring::*;
6use feanor_math::rings::poly::dense_poly::DensePolyRing;
7use feanor_math::rings::zn::zn_64::Zn;
8use feanor_math::homomorphism::*;
9use polys::{digit_retain_poly, poly_to_circuit, precomputed_p_2};
10use tracing::instrument;
11
12use crate::circuit::PlaintextCircuit;
13
14pub mod polys;
15
16pub struct DigitExtract<R: ?Sized + RingBase = StaticRingBase<i64>> {
37 extraction_circuits: Vec<(Vec<usize>, PlaintextCircuit<R>)>,
38 identity_circuit: PlaintextCircuit<R>,
40 add_circuit: PlaintextCircuit<R>,
42 sub_circuit: PlaintextCircuit<R>,
44 v: usize,
45 e: usize,
46 p: i64
47}
48
49impl DigitExtract {
50
51 #[instrument(skip_all)]
57 pub fn new_precomputed_p_is_2(p: i64, e: usize, r: usize) -> Self {
58 assert_eq!(2, p);
59 assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
60 return Self::new_with(
61 p,
62 e,
63 r,
64 StaticRing::<i64>::RING,
65 [1, 2, 4, 8, 16, 23].into_iter().map(|e| (
66 [1, 2, 4, 8, 16, 23].into_iter().take_while(|i| *i <= e).collect(),
67 precomputed_p_2(e)
68 )).collect::<Vec<_>>()
69 );
70 }
71
72 #[instrument(skip_all)]
80 pub fn new_default(p: i64, e: usize, r: usize) -> Self {
81 assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
82 assert!(e > r);
83 let v = e - r;
84
85 let digit_extraction_circuits = (1..=v).rev().map(|i| {
86 let required_digits = (2..=(v - i + 1)).chain([r + v - i + 1].into_iter()).collect::<Vec<_>>();
87 let poly_ring = DensePolyRing::new(Zn::new(StaticRing::<i64>::RING.pow(p, *required_digits.last().unwrap()) as u64), "X");
88 let circuit = poly_to_circuit(&poly_ring, &required_digits.iter().map(|j| digit_retain_poly(&poly_ring, *j)).collect::<Vec<_>>());
89 return (required_digits, circuit);
90 }).collect::<Vec<_>>();
91 assert!(digit_extraction_circuits.is_sorted_by_key(|(digits, _)| *digits.last().unwrap()));
92
93 return Self::new_with(p, e, r, StaticRing::<i64>::RING, digit_extraction_circuits);
94 }
95}
96
97impl<R: ?Sized + RingBase> DigitExtract<R> {
98
99 pub fn new_with<S: Copy + RingStore<Type = R>>(p: i64, e: usize, r: usize, ring: S, extraction_circuits: Vec<(Vec<usize>, PlaintextCircuit<R>)>) -> Self {
110 assert!(is_prime(&StaticRing::<i64>::RING, &p, 10));
111 assert!(e > r);
112 for (digits, circuit) in &extraction_circuits {
113 assert!(digits.is_sorted());
114 assert_eq!(digits.len(), circuit.output_count());
115 assert_eq!(1, circuit.input_count());
116 }
117 assert!(extraction_circuits.iter().any(|(digits, _)| *digits.last().unwrap() >= e));
118 Self {
119 extraction_circuits: extraction_circuits,
120 add_circuit: PlaintextCircuit::add(ring),
121 sub_circuit: PlaintextCircuit::sub(ring),
122 identity_circuit: PlaintextCircuit::identity(1, ring),
123 v: e - r,
124 p: p,
125 e: e
126 }
127 }
128
129 pub fn r(&self) -> usize {
130 self.e - self.v
131 }
132
133 pub fn e(&self) -> usize {
134 self.e
135 }
136
137 pub fn v(&self) -> usize {
138 self.v
139 }
140
141 pub fn p(&self) -> i64 {
142 self.p
143 }
144
145 pub fn evaluate_generic<T, EvalCircuit, ChangeSpace>(&self,
169 input: T,
170 mut eval_circuit: EvalCircuit,
171 mut change_space: ChangeSpace
172 ) -> (T, T)
173 where EvalCircuit: FnMut(usize, &[T], &PlaintextCircuit<R>) -> Vec<T>,
174 ChangeSpace: FnMut(usize, usize, T) -> T
175 {
176 let e = self.e;
177 let r = self.e - self.v;
178
179 enum OneOrTwoValues<T> {
180 One(T), Two([T; 2])
181 }
182
183 impl<T> OneOrTwoValues<T> {
184
185 fn with_first_el<'a>(&'a mut self, first: T) -> &'a mut [T; 2] {
186 take_mut::take(self, |value| match value {
187 OneOrTwoValues::One(second) => OneOrTwoValues::Two([first, second]),
188 OneOrTwoValues::Two([_, second]) => OneOrTwoValues::Two([first, second])
189 });
190 return match self {
191 OneOrTwoValues::One(_) => unreachable!(),
192 OneOrTwoValues::Two(data) => data
193 };
194 }
195
196 fn get_second<'a>(&'a self) -> &'a T {
197 match self {
198 OneOrTwoValues::One(second) => second,
199 OneOrTwoValues::Two([_, second]) => second
200 }
201 }
202 }
203
204 let clone_value = |modulus_exp: usize, value: &T, eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, std::slice::from_ref(value), &self.identity_circuit).into_iter().next().unwrap();
205 let sub_values = |modulus_exp: usize, params: &[T; 2], eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, params, &self.sub_circuit).into_iter().next().unwrap();
206 let add_values = |modulus_exp: usize, params: &[T; 2], eval_circuit: &mut EvalCircuit| eval_circuit(modulus_exp, params, &self.add_circuit).into_iter().next().unwrap();
207
208 let mut mod_result: Option<T> = None;
209 let mut partial_floor_divs = (0..self.v).map(|_| Some(clone_value(e, &input, &mut eval_circuit))).collect::<Vec<_>>();
210 let mut floor_div_result = input;
211 for i in 0..self.v {
212 let remaining_digits = e - i;
213 debug_assert!(self.extraction_circuits.is_sorted_by_key(|(digits, _)| *digits.last().unwrap()));
214 let (use_circuit_digits, use_circuit) = self.extraction_circuits.iter().filter(|(digits, _)| *digits.last().unwrap() >= remaining_digits).next().unwrap();
215 debug_assert!(use_circuit_digits.is_sorted());
216
217 let current = change_space(e, remaining_digits, partial_floor_divs[i].take().unwrap());
218 let digit_extracted = eval_circuit(remaining_digits, std::slice::from_ref(¤t), use_circuit);
219 let mut digit_extracted = digit_extracted.into_iter().map(|value| OneOrTwoValues::One(change_space(remaining_digits, e, value))).collect::<Vec<_>>();
220
221 let last_digit_extracted = digit_extracted.last_mut().unwrap();
222 take_mut::take(&mut floor_div_result, |current| sub_values(e, last_digit_extracted.with_first_el(current), &mut eval_circuit));
223 if let Some(mod_result) = &mut mod_result {
224 take_mut::take(mod_result, |current| add_values(e, last_digit_extracted.with_first_el(current), &mut eval_circuit));
225 } else {
226 mod_result = Some(clone_value(e, last_digit_extracted.get_second(), &mut eval_circuit));
227 }
228 for j in (i + 1)..self.v {
229 let digit_extracted_index = use_circuit_digits.iter().enumerate().filter(|(_, cleared_digits)| **cleared_digits > j - i).next().unwrap().0;
230 take_mut::take(partial_floor_divs[j].as_mut().unwrap(), |current| sub_values(e, digit_extracted[digit_extracted_index].with_first_el(current), &mut eval_circuit));
231 }
232 }
233
234 return (change_space(e, r, floor_div_result), mod_result.unwrap());
235 }
236
237 pub fn evaluate<H, S>(&self, input: S::Element, hom: H) -> (S::Element, S::Element)
252 where H: Homomorphism<R, S>,
253 S: ?Sized + RingBase + DivisibilityRing
254 {
255 let p = hom.codomain().int_hom().map(self.p as i32);
256 self.evaluate_generic(
257 input,
258 |_, params, circuit| circuit.evaluate_no_galois(params, &hom),
259 |from, to, x| if from < to {
260 hom.codomain().mul(x, hom.codomain().pow(hom.codomain().clone_el(&p), to - from))
261 } else {
262 hom.codomain().checked_div(&x, &hom.codomain().pow(hom.codomain().clone_el(&p), from - to)).unwrap()
263 }
264 )
265 }
266}
267
268#[cfg(test)]
269use feanor_math::rings::zn::ZnRingStore;
270#[cfg(test)]
271use feanor_math::assert_el_eq;
272#[cfg(test)]
273use feanor_math::divisibility::DivisibilityRingStore;
274#[cfg(test)]
275use feanor_math::rings::extension::FreeAlgebraStore;
276#[cfg(test)]
277use feanor_math::seq::VectorFn;
278#[cfg(test)]
279use rand::SeedableRng;
280#[cfg(test)]
281use rand::rngs::StdRng;
282#[cfg(test)]
283use crate::bfv::*;
284#[cfg(test)]
285use crate::DefaultNegacyclicNTT;
286#[cfg(test)]
287use std::alloc::Global;
288#[cfg(test)]
289use std::marker::PhantomData;
290
/// Checks `evaluate_generic()` exhaustively on `Z/3^5` with `r = 2`, while
/// tracking the exponent every value currently belongs to, so that a
/// mismatch between the exponents passed to the callbacks and the actual
/// values is detected.
#[test]
fn test_digit_extract() {
    let ring = Zn::new(StaticRing::<i64>::RING.pow(3, 5) as u64);
    let hom = ring.can_hom(&StaticRing::<i64>::RING).unwrap();
    let digitextract = DigitExtract::new_default(3, 5, 2);

    for x in 0..*ring.modulus() {
        // every value is a pair (current exponent, ring element)
        let ((quo_exp, quo_value), (rem_exp, rem_value)) = digitextract.evaluate_generic(
            (5, hom.map(x)),
            |exp, params, circuit| {
                assert!(params.iter().all(|(p_exp, _)| *p_exp == exp));
                let plain_inputs = params.iter().map(|(_, x)| *x).collect::<Vec<_>>();
                circuit.evaluate_no_galois(&plain_inputs, &hom).into_iter().map(|x| (exp, x)).collect()
            },
            |from, to, (exp, x)| {
                assert_eq!(from, exp);
                let converted = if from < to {
                    ring.mul(x, ring.pow(hom.map(3), to - from))
                } else {
                    ring.checked_div(&x, &ring.pow(hom.map(3), from - to)).unwrap()
                };
                (to, converted)
            }
        );
        assert_eq!(5, rem_exp);
        assert_el_eq!(&ring, hom.map(x % 27), rem_value);
        assert_eq!(2, quo_exp);
        assert_eq!(x / 27, ring.smallest_positive_lift(quo_value) % 9);
    }
}
319
/// Checks the BFV-based digit extraction on an encryption of the constant
/// `17^2 + 2 * 17 + 5`, with `p = 17`, `e = 2` and `r = 1`.
#[test]
fn test_digit_extract_homomorphic() {
    let mut rng = StdRng::from_seed([1; 32]);

    let params = Pow2BFV {
        log2_q_min: 500,
        log2_q_max: 520,
        log2_N: 6,
        ciphertext_allocator: Global,
        negacyclic_ntt: PhantomData::<DefaultNegacyclicNTT>
    };
    let digits = 3;

    let P1 = params.create_plaintext_ring(17 * 17);
    let P2 = params.create_plaintext_ring(17 * 17 * 17);
    let (C, Cmul) = params.create_ciphertext_rings();

    let sk = Pow2BFV::gen_sk(&C, &mut rng, None);
    let rk = Pow2BFV::gen_rk(&C, &mut rng, &sk, digits);

    let m = P2.int_hom().map(17 * 17 + 2 * 17 + 5);
    let ct = Pow2BFV::enc_sym(&P2, &C, &mut rng, &m, &sk);

    let digitextract = DigitExtract::new_default(17, 2, 1);

    let (ct_high, ct_low) = digitextract.evaluate_bfv::<Pow2BFV>(&P1, std::slice::from_ref(&P2), &C, &Cmul, ct, &rk);

    // the high part is decrypted w.r.t. `P1`; only its value modulo 17 is checked
    let m_high = Pow2BFV::dec(&P1, &C, Pow2BFV::clone_ct(&C, &ct_high), &sk);
    assert!(P1.wrt_canonical_basis(&m_high).iter().skip(1).all(|x| P1.base_ring().is_zero(&x)));
    let m_high = P1.base_ring().smallest_positive_lift(P1.wrt_canonical_basis(&m_high).at(0));
    assert_eq!(2, m_high % 17);

    // the low part is decrypted w.r.t. `P2`, so it must also be inspected through `P2`;
    // reading it through `P1` would interpret its coefficients w.r.t. the wrong modulus
    let m_low = Pow2BFV::dec(&P2, &C, Pow2BFV::clone_ct(&C, &ct_low), &sk);
    assert!(P2.wrt_canonical_basis(&m_low).iter().skip(1).all(|x| P2.base_ring().is_zero(&x)));
    let m_low = P2.base_ring().smallest_positive_lift(P2.wrt_canonical_basis(&m_low).at(0));
    assert_eq!(5, m_low % (17 * 17));
}
357
/// Checks `evaluate()` exhaustively on `Z/p^e` for a few small parameter
/// sets `(p, e, r)`.
#[test]
fn test_digit_extract_evaluate() {
    for (p, e, r) in [(2i32, 4u32, 2u32), (3, 4, 2), (5, 3, 2)] {
        let ring_modulus = p.pow(e);
        // the low part has `v = e - r` digits, the high part has `r` digits
        let low_modulus = p.pow(e - r);
        let high_modulus = p.pow(r);

        let ring = Zn::new(ring_modulus as u64);
        let digit_extract = DigitExtract::new_default(p as i64, e as usize, r as usize);
        for x in 0..ring_modulus {
            let (actual_high, actual_low) = digit_extract.evaluate(ring.int_hom().map(x), ring.can_hom(&StaticRing::<i64>::RING).unwrap());
            assert_eq!(x / low_modulus, ring.smallest_positive_lift(actual_high) as i32 % high_modulus);
            assert_eq!(x % low_modulus, ring.smallest_positive_lift(actual_low) as i32);
        }
    }
}
384
/// Checks `evaluate()` on rings whose modulus is a larger power of `p` than
/// `p^e`; the digits beyond the `e`-th one must not influence the results
/// when these are considered modulo `p^r` resp. `p^e`.
#[test]
fn test_digit_extract_evaluate_ignore_higher() {
    for (ring_modulus, p, e, r) in [(64i32, 2i32, 4u32, 2u32), (243, 3, 4, 2), (625, 5, 3, 2)] {
        // the low part has `v = e - r` digits, the high part has `r` digits
        let low_modulus = p.pow(e - r);
        let high_modulus = p.pow(r);
        let considered_modulus = p.pow(e);

        let ring = Zn::new(ring_modulus as u64);
        let digit_extract = DigitExtract::new_default(p as i64, e as usize, r as usize);
        for x in 0..ring_modulus {
            let (actual_high, actual_low) = digit_extract.evaluate(ring.int_hom().map(x), ring.can_hom(&StaticRing::<i64>::RING).unwrap());
            assert_eq!((x / low_modulus) % high_modulus, ring.smallest_positive_lift(actual_high) as i32 % high_modulus);
            assert_eq!(x % low_modulus, ring.smallest_positive_lift(actual_low) as i32 % considered_modulus);
        }
    }
}