Skip to main content

tfhe/integer/server_key/radix_parallel/
div_mod.rs

1use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext, SignedRadixCiphertext};
2use crate::integer::server_key::comparator::ZeroComparisonType;
3use crate::integer::{BooleanBlock, IntegerCiphertext, ServerKey};
4use crate::shortint::MessageModulus;
5use rayon::prelude::*;
6
7impl ServerKey {
8    //======================================================================
9    //                Div Rem
10    //======================================================================
11    pub fn unchecked_div_rem_parallelized<T>(&self, numerator: &T, divisor: &T) -> (T, T)
12    where
13        T: IntegerRadixCiphertext,
14    {
15        if T::IS_SIGNED {
16            let n = SignedRadixCiphertext::from_blocks(numerator.blocks().to_vec());
17            let d = SignedRadixCiphertext::from_blocks(divisor.blocks().to_vec());
18            let (q, r) = self.signed_unchecked_div_rem_parallelized(&n, &d);
19            let q = T::from_blocks(q.into_blocks());
20            let r = T::from_blocks(r.into_blocks());
21            (q, r)
22        } else {
23            let n = RadixCiphertext::from_blocks(numerator.blocks().to_vec());
24            let d = RadixCiphertext::from_blocks(divisor.blocks().to_vec());
25            let (q, r) = self.unsigned_unchecked_div_rem_parallelized(&n, &d);
26            let q = T::from_blocks(q.into_blocks());
27            let r = T::from_blocks(r.into_blocks());
28            (q, r)
29        }
30    }
31
32    pub fn unchecked_div_rem_floor_parallelized(
33        &self,
34        numerator: &SignedRadixCiphertext,
35        divisor: &SignedRadixCiphertext,
36    ) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
37        let (quotient, remainder) = self.unchecked_div_rem_parallelized(numerator, divisor);
38
39        let (remainder_is_not_zero, remainder_and_divisor_signs_disagrees) = rayon::join(
40            || self.unchecked_scalar_ne_parallelized(&remainder, 0),
41            || {
42                let sign_bit_pos = self.key.message_modulus.0.ilog2() - 1;
43                let compare_sign_bits = |x, y| {
44                    let x_sign_bit = (x >> sign_bit_pos) & 1;
45                    let y_sign_bit = (y >> sign_bit_pos) & 1;
46                    u64::from(x_sign_bit != y_sign_bit)
47                };
48                let lut = self.key.generate_lookup_table_bivariate(compare_sign_bits);
49                self.key.unchecked_apply_lookup_table_bivariate(
50                    remainder.blocks().last().unwrap(),
51                    divisor.blocks().last().unwrap(),
52                    &lut,
53                )
54            },
55        );
56
57        let mut condition = remainder_is_not_zero.0;
58        let mut remainder_plus_divisor = remainder.clone();
59        let mut quotient_minus_one = quotient.clone();
60        rayon::scope(|s| {
61            s.spawn(|_| {
62                self.key
63                    .add_assign(&mut condition, &remainder_and_divisor_signs_disagrees);
64            });
65            s.spawn(|_| self.add_assign_parallelized(&mut remainder_plus_divisor, divisor));
66            s.spawn(|_| self.scalar_sub_assign_parallelized(&mut quotient_minus_one, 1));
67        });
68
69        let (quotient, remainder) = rayon::join(
70            || {
71                self.unchecked_programmable_if_then_else_parallelized(
72                    &condition,
73                    &quotient_minus_one,
74                    &quotient,
75                    |x| x == 2,
76                    true,
77                )
78            },
79            || {
80                self.unchecked_programmable_if_then_else_parallelized(
81                    &condition,
82                    &remainder_plus_divisor,
83                    &remainder,
84                    |x| x == 2,
85                    true,
86                )
87            },
88        );
89
90        (quotient, remainder)
91    }
92
    /// Block-by-block unsigned long division, specialized for 2_2 parameters.
    ///
    /// Walks the numerator from its most significant block down to block 0. At each
    /// step it computes, in parallel, `rem - 3*d`, `rem - 2*d` and `rem - d` (where
    /// `rem` is the top part of the running remainder), determines which of these
    /// multiples of the divisor fit, keeps the remainder corresponding to the largest
    /// fitting multiple `k` (or the unchanged `rem` when none fits), and pushes `k`
    /// (in `[0, 3]`) as the quotient block.
    ///
    /// Returns `(quotient, remainder)`. The quotient has `numerator.blocks.len()`
    /// blocks; the remainder is a copy of the numerator whose blocks are updated in place.
    ///
    /// NOTE(review): slices of `divisor` are taken with `num_blocks` (the numerator's
    /// length), so this assumes `divisor` has at least as many blocks as `numerator`;
    /// the caller in this file asserts equal lengths.
    ///
    /// # Panics
    ///
    /// If the server key's message modulus is not 4 or its carry modulus is not 4.
    fn unsigned_div_rem_block_by_block_2_2(
        &self,
        numerator: &RadixCiphertext,
        divisor: &RadixCiphertext,
    ) -> (RadixCiphertext, RadixCiphertext) {
        let num_bits_in_block = self.message_modulus().0.ilog2() as usize;
        assert!(
            num_bits_in_block == 2 && self.carry_modulus().0 == 4,
            "This algorithm only works for 2_2 parameters"
        );

        let num_blocks = numerator.blocks.len();

        let mut remainder = numerator.clone();
        let mut quotient_blocks = Vec::with_capacity(num_blocks);

        let mut d1 = divisor.clone();

        // Precompute 2 * d and 3 * d, each with one extra (MSB) block so they
        // cannot wrap around.
        let (d2, d3) = rayon::join(
            || {
                // 2 * d: extend by one zero block, then shift left by one bit
                let mut d2 = self.extend_radix_with_trivial_zero_blocks_msb(divisor, 1);
                self.scalar_left_shift_assign_parallelized(&mut d2, 1);
                d2
            },
            || {
                // d1 is temporarily extended so that blockshift/sub have room,
                // then trimmed back to its original length.
                self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut d1, 1);
                let mut d4 = self.blockshift(&d1, 1);
                self.sub_assign_parallelized(&mut d4, &d1);
                self.trim_radix_blocks_msb_assign(&mut d1, 1);
                d4 // 4 * d - d = 3 * d
            },
        );

        // This will be used on blocks that contain 2 blocks encoded in
        // the following way: (block, condition_block) = (block * 2) + condition_block
        // As the condition_block is always 0 or 1
        //
        // The goal is to return 0 if the condition is not 1
        // (i.e., return block if condition is 1)
        let zero_out_if_not_1_lut = (
            self.key.generate_lookup_table(|x| {
                let block = x / 2;
                let condition = (x & 1) == 1;

                block * u64::from(condition)
            }),
            2u8,
        );

        // This will be used on blocks that contain 2 blocks encoded in
        // the following way: (block, condition_block) = (block * 3) + condition_block
        // As the condition_block is in [0, 1, 2]
        //
        // The goal is to return 0 if the condition is not 2
        // (i.e., return block if condition is 2)
        let zero_out_if_not_2_lut = (
            self.key.generate_lookup_table(|x| {
                let block = x / 3;
                let condition = (x % 3) == 2;

                block * u64::from(condition)
            }),
            3u8,
        );

        // Luts to generate quotient blocks from a condition block
        // (applied respectively to c1, c2, c3 below)
        let quotient_block_luts = [
            // cond is in [0, 1, 2], but only 2 means true
            // (the divisor fit 1 time)
            self.key.generate_lookup_table(|cond| u64::from(cond == 2)),
            // cond is in [0, 1, 2], but only 2 means true
            // (the divisor fit 2 times)
            self.key
                .generate_lookup_table(|cond| u64::from(cond == 2) * 2),
            // cond is in [0, 1], 1 meaning true
            // (the divisor fit 3 times)
            self.key.generate_lookup_table(|cond| cond * 3),
        ];

        // From the most significant block down to the least significant one
        for block_index in (0..num_blocks).rev() {
            // The low `num_blocks - block_index` blocks of d, 2d and 3d: same
            // length as `rem`, the top part of the running remainder.
            let low1 = RadixCiphertext::from(d1.blocks[..num_blocks - block_index].to_vec());
            let low2 = RadixCiphertext::from(d2.blocks[..num_blocks - block_index].to_vec());
            let low3 = RadixCiphertext::from(d3.blocks[..num_blocks - block_index].to_vec());
            let mut rem = RadixCiphertext::from(remainder.blocks[block_index..].to_vec());

            // sub_results: (rem - k*d_low, overflowed) for k = 3, 2, 1
            // cmps: for k = 3, 2, 1, a boolean that is the negation of
            //       `are_all_blocks_zero` on the upper blocks of k*d that were cut off
            //       above; when set, k*d cannot fit in `rem`.
            let (mut sub_results, cmps) = rayon::join(
                || {
                    [&low3, &low2, &low1]
                        .into_par_iter()
                        .map(|rhs| self.unsigned_overflowing_sub_parallelized(&rem, rhs))
                        .collect::<Vec<_>>()
                },
                || {
                    [
                        &d3.blocks[num_blocks - block_index..],
                        &d2.blocks[num_blocks - block_index..],
                        &d1.blocks[num_blocks - block_index..],
                    ]
                    .into_par_iter()
                    .map(|blocks| {
                        let mut b = BooleanBlock::new_unchecked(self.are_all_blocks_zero(blocks));
                        self.boolean_bitnot_assign(&mut b);
                        b
                    })
                    .collect::<Vec<_>>()
                },
            );

            // Popped in reverse order of the [low3, low2, low1] array above
            let (mut r1, mut o1) = sub_results.pop().unwrap();
            let (mut r2, mut o2) = sub_results.pop().unwrap();
            let (mut r3, mut o3) = sub_results.pop().unwrap();

            // After this, ox == 1 means "x * d does not fit"
            [&mut o3, &mut o2, &mut o1]
                .into_par_iter()
                .zip(cmps.par_iter())
                .for_each(|(ox, cmpx)| {
                    self.boolean_bitor_assign(ox, cmpx);
                });

            // The cx variables tell whether the corresponding result of the subtraction
            // should be kept, and what value the quotient block should have
            //
            // for c3, c0; the block values are in [0, 1]
            // for c2, c1; the block values are in [0, 1, 2], 2 meaning true; 0,1 meaning false
            //
            // c3 == 1: 3d fits
            // c2 == 2: 2d fits and 3d does not
            // c1 == 2: d fits and 2d does not
            // c0 == 1: d does not fit (rem is kept unchanged)
            let c3 = self.boolean_bitnot(&o3).0;
            let c2 = {
                let mut c2 = self.boolean_bitnot(&o2).0;
                self.key.unchecked_add_assign(&mut c2, &o3.0);
                c2
            };
            let c1 = {
                let mut c1 = self.boolean_bitnot(&o1).0;
                self.key.unchecked_add_assign(&mut c1, &o2.0);
                c1
            };
            let c0 = o1.0;

            let (_, [q1, q2, q3]) = rayon::join(
                || {
                    [&c3, &c2, &c1, &c0]
                        .into_par_iter()
                        .zip([&mut r3, &mut r2, &mut r1, &mut rem])
                        .zip([
                            &zero_out_if_not_1_lut,
                            &zero_out_if_not_2_lut,
                            &zero_out_if_not_2_lut,
                            &zero_out_if_not_1_lut,
                        ])
                        .for_each(|((cx, rx), (lut, factor))| {
                            // Manual zero_out_if to avoid noise problems
                            rx.blocks.par_iter_mut().for_each(|block| {
                                self.key.unchecked_scalar_mul_assign(block, *factor);
                                self.key.unchecked_add_assign(block, cx);
                                self.key.apply_lookup_table_assign(block, lut);
                            });
                        });
                },
                || {
                    // q1 in {0, 1}, q2 in {0, 2}, q3 in {0, 3}
                    let mut qs = [c1.clone(), c2.clone(), c3.clone()];
                    qs.par_iter_mut()
                        .zip(&quotient_block_luts)
                        .for_each(|(qx, lut)| {
                            self.key.apply_lookup_table_assign(qx, lut);
                        });
                    qs
                },
            );

            // Only one of rx and rem is not zero
            for rx in [&r3, &r2, &r1] {
                self.unchecked_add_assign(&mut rem, rx);
            }

            // only one of q1, q2, q3 is not zero
            let mut q = q1;
            for qx in [q2, q3] {
                self.key.unchecked_add_assign(&mut q, &qx);
            }

            // Clean the sums produced by the unchecked additions above
            rayon::join(
                || {
                    rem.blocks.par_iter_mut().for_each(|block| {
                        self.key.message_extract_assign(block);
                    });
                },
                || {
                    self.key.message_extract_assign(&mut q);
                },
            );

            remainder.blocks[block_index..].clone_from_slice(&rem.blocks);
            quotient_blocks.push(q);
        }

        // Quotient blocks were produced MSB first
        quotient_blocks.reverse();

        (RadixCiphertext::from(quotient_blocks), remainder)
    }
291
292    fn unsigned_unchecked_div_rem_parallelized(
293        &self,
294        numerator: &RadixCiphertext,
295        divisor: &RadixCiphertext,
296    ) -> (RadixCiphertext, RadixCiphertext) {
297        assert_eq!(
298            numerator.blocks.len(),
299            divisor.blocks.len(),
300            "numerator and divisor must have same number of blocks"
301        );
302
303        if self.message_modulus().0 == 4 && self.carry_modulus().0 == 4 {
304            return self.unsigned_div_rem_block_by_block_2_2(numerator, divisor);
305        }
306
307        // Pseudocode of the school-book / long-division algorithm:
308        //
309        //
310        // div(N/D):
311        // Q := 0                  -- Initialize quotient and remainder to zero
312        // R := 0
313        // for i := n − 1 .. 0 do  -- Where n is number of bits in N
314        //   R := R << 1           -- Left-shift R by 1 bit
315        //   R(0) := N(i)          -- Set the least-significant bit of R equal to bit i of the
316        //                         -- numerator
317        //   if R ≥ D then
318        //     R := R − D
319        //     Q(i) := 1
320        //   end
321        // end
322        assert_eq!(
323            numerator.blocks.len(),
324            divisor.blocks.len(),
325            "numerator and divisor must have same number of blocks \
326            numerator: {} blocks, divisor: {} blocks",
327            numerator.blocks.len(),
328            divisor.blocks.len(),
329        );
330        assert!(
331            self.key.message_modulus.0.is_power_of_two(),
332            "The message modulus ({}) needs to be a power of two",
333            self.key.message_modulus.0
334        );
335        assert!(
336            numerator.block_carries_are_empty(),
337            "The numerator must have its carries empty"
338        );
339        assert!(
340            divisor.block_carries_are_empty(),
341            "The numerator must have its carries empty"
342        );
343        assert!(numerator
344            .blocks()
345            .iter()
346            .all(|block| block.message_modulus == self.key.message_modulus
347                && block.carry_modulus == self.key.carry_modulus));
348        assert!(divisor
349            .blocks()
350            .iter()
351            .all(|block| block.message_modulus == self.key.message_modulus
352                && block.carry_modulus == self.key.carry_modulus));
353
354        let num_blocks = numerator.blocks.len();
355        let num_bits_in_message = self.key.message_modulus.0.ilog2() as u64;
356        let total_bits = num_bits_in_message * num_blocks as u64;
357
358        let mut quotient: RadixCiphertext = self.create_trivial_zero_radix(num_blocks);
359        let mut remainder1: RadixCiphertext = self.create_trivial_zero_radix(num_blocks);
360        let mut remainder2: RadixCiphertext = self.create_trivial_zero_radix(num_blocks);
361
362        let mut numerator_block_stack = numerator.blocks.clone();
363
364        // The overflow flag is computed by combining 2 separate values,
365        // this vec will contain the lut that merges these two flags.
366        //
367        // Normally only one lut should be needed, and that lut would output a block
368        // encrypting 0 or 1.
369        // However, since the resulting block would then be left shifted and added to
370        // another existing noisy block, we create many LUTs that shift the boolean value
371        // to the correct position, to reduce noise growth
372        let merge_overflow_flags_luts = (0..num_bits_in_message)
373            .map(|bit_position_in_block| {
374                self.key.generate_lookup_table_bivariate(|x, y| {
375                    u64::from(x == 0 && y == 0) << bit_position_in_block
376                })
377            })
378            .collect::<Vec<_>>();
379
380        for i in (0..total_bits as usize).rev() {
381            let block_of_bit = i / num_bits_in_message as usize;
382            let pos_in_block = i % num_bits_in_message as usize;
383
384            // i goes from [total_bits - 1 to 0]
385            // msb_bit_set goes from [0 to total_bits - 1]
386            let msb_bit_set = total_bits as usize - 1 - i;
387
388            let last_non_trivial_block = msb_bit_set / num_bits_in_message as usize;
389            // Index to the first block of the remainder that is fully trivial 0
390            // and all blocks after it are also trivial zeros
391            // This number is in range 1..=num_bocks -1
392            let first_trivial_block = last_non_trivial_block + 1;
393
394            // All blocks starting from the first_trivial_block are known to be trivial
395            // So we can avoid work.
396            // Note that, these are always non-empty (i.e. there is always at least one non trivial
397            // block)
398            let mut interesting_remainder1 =
399                RadixCiphertext::from(remainder1.blocks[..=last_non_trivial_block].to_vec());
400            let mut interesting_remainder2 =
401                RadixCiphertext::from(remainder2.blocks[..=last_non_trivial_block].to_vec());
402            let mut interesting_divisor =
403                RadixCiphertext::from(divisor.blocks[..=last_non_trivial_block].to_vec());
404            let mut divisor_ms_blocks = RadixCiphertext::from(
405                divisor.blocks[((msb_bit_set + 1) / num_bits_in_message as usize)..].to_vec(),
406            );
407
408            // We split the divisor at a block position, when in reality the split should be at a
409            // bit position meaning that potentially (depending on msb_bit_set) the
410            // split versions share some bits they should not. So we do one PBS on the
411            // last block of the interesting_divisor, and first block of divisor_ms_blocks
412            // to trim out bits which should not be there
413
414            let mut trim_last_interesting_divisor_bits = || {
415                if (msb_bit_set + 1).is_multiple_of(num_bits_in_message as usize) {
416                    return;
417                }
418                // The last block of the interesting part of the remainder
419                // can contain bits which we should not account for
420                // we have to zero them out.
421
422                // Where the msb is set in the block
423                let pos_in_block = msb_bit_set as u64 % num_bits_in_message;
424
425                // e.g 2 bits in message:
426                // if pos_in_block is 0, then we want to keep only first bit (right shift mask
427                // by 1) if pos_in_block is 1, then we want to keep the two
428                // bits (right shift mask by 0)
429                let shift_amount = num_bits_in_message - (pos_in_block + 1);
430                // Create mask of 1s on the message part, 0s in the carries
431                let full_message_mask = self.key.message_modulus.0 - 1;
432                // Shift the mask so that we will only keep bits we should
433                let shifted_mask = full_message_mask >> shift_amount;
434
435                let masking_lut = self.key.generate_lookup_table(|x| x & shifted_mask);
436                self.key.apply_lookup_table_assign(
437                    interesting_divisor.blocks.last_mut().unwrap(),
438                    &masking_lut,
439                );
440            };
441
442            let mut trim_first_divisor_ms_bits = || {
443                if divisor_ms_blocks.blocks.is_empty()
444                    || (msb_bit_set + 1).is_multiple_of(num_bits_in_message as usize)
445                {
446                    return;
447                }
448                // As above, we need to zero out some bits, but here it's in the
449                // first block of most significant blocks of the divisor.
450                // The block has the same value as the last block of interesting_divisor.
451                // Here we will zero out the bits that the trim_last_interesting_divisor_bits
452                // above wanted to keep.
453
454                // Where the msb is set in the block
455                let pos_in_block = msb_bit_set as u64 % num_bits_in_message;
456
457                // e.g 2 bits in message:
458                // if pos_in_block is 0, then we want to discard the first bit (left shift mask
459                // by 1) if pos_in_block is 1, then we want to discard the
460                // two bits (left shift mask by 2) let shift_amount =
461                // num_bits_in_message - pos_in_block as u64;
462                let shift_amount = pos_in_block + 1;
463                let full_message_mask = self.key.message_modulus.0 - 1;
464                let shifted_mask = full_message_mask << shift_amount;
465                // Keep the mask within the range of message bits, so that
466                // the estimated degree of the output is < msg_modulus
467                let shifted_mask = shifted_mask & full_message_mask;
468
469                let masking_lut = self.key.generate_lookup_table(|x| x & shifted_mask);
470                self.key.apply_lookup_table_assign(
471                    divisor_ms_blocks.blocks.first_mut().unwrap(),
472                    &masking_lut,
473                );
474            };
475
476            // This does
477            //  R := R << 1; R(0) := N(i)
478            //
479            // We could to that by left shifting, R by one, then unchecked_add the correct numerator
480            // bit.
481            //
482            // However, to keep the remainder clean (noise wise), what we do is that we put the
483            // remainder block from which we need to extract the bit, as the LSB of the
484            // Remainder, so that left shifting will pull the bit we need.
485            let mut left_shift_interesting_remainder1 = || {
486                let numerator_block = numerator_block_stack
487                    .pop()
488                    .expect("internal error: empty numerator block stack in div");
489                // prepend block and then shift
490                interesting_remainder1.blocks.insert(0, numerator_block);
491                self.unchecked_scalar_left_shift_assign_parallelized(
492                    &mut interesting_remainder1,
493                    1,
494                );
495
496                // Extract the block we prepended, and see if it should be dropped
497                // or added back for processing
498                interesting_remainder1.blocks.rotate_left(1);
499                // This unwrap is unreachable, as we are removing the block we added earlier
500                let numerator_block = interesting_remainder1.blocks.pop().unwrap();
501                if pos_in_block != 0 {
502                    // We have not yet extracted all the bits from this numerator
503                    // so, we put it back on the front so that it gets taken next iteration
504                    numerator_block_stack.push(numerator_block);
505                }
506            };
507
508            let mut left_shift_interesting_remainder2 = || {
509                self.unchecked_scalar_left_shift_assign_parallelized(
510                    &mut interesting_remainder2,
511                    1,
512                );
513            };
514
515            let tasks: [&mut (dyn FnMut() + Send + Sync); 4] = [
516                &mut trim_last_interesting_divisor_bits,
517                &mut trim_first_divisor_ms_bits,
518                &mut left_shift_interesting_remainder1,
519                &mut left_shift_interesting_remainder2,
520            ];
521            tasks.into_par_iter().for_each(|task| task());
522
523            // if interesting_remainder1 != 0 -> interesting_remainder2 == 0
524            // if interesting_remainder1 == 0 -> interesting_remainder2 != 0
525            // In practice interesting_remainder1 contains the numerator bit,
526            // but in that position, interesting_remainder2 always has a 0
527            let mut merged_interesting_remainder = interesting_remainder1;
528            self.unchecked_add_assign(&mut merged_interesting_remainder, &interesting_remainder2);
529
530            let do_overflowing_sub = || {
531                self.unchecked_unsigned_overflowing_sub_parallelized(
532                    &merged_interesting_remainder,
533                    &interesting_divisor,
534                )
535            };
536
537            let check_divisor_upper_blocks = || {
538                // Do a comparison (==) with 0 for trivial blocks
539                let trivial_blocks = &divisor_ms_blocks.blocks;
540                if trivial_blocks.is_empty() {
541                    self.key.create_trivial(0)
542                } else {
543                    // We could call unchecked_scalar_ne_parallelized
544                    // But we are in the special case where scalar == 0
545                    // So we can skip some stuff
546                    let tmp = self
547                        .compare_blocks_with_zero(trivial_blocks, ZeroComparisonType::Difference);
548                    self.is_at_least_one_comparisons_block_true(tmp)
549                }
550            };
551
552            // Creates a cleaned version (noise wise) of the merged remainder
553            // so that it can be safely used in bivariate PBSes
554            let create_clean_version_of_merged_remainder = || {
555                RadixCiphertext::from_blocks(
556                    merged_interesting_remainder
557                        .blocks
558                        .par_iter()
559                        .map(|b| self.key.message_extract(b))
560                        .collect(),
561                )
562            };
563
564            // Use nested join as its easier when we need to return values
565            let (
566                (mut new_remainder, subtraction_overflowed),
567                (at_least_one_upper_block_is_non_zero, mut cleaned_merged_interesting_remainder),
568            ) = rayon::join(do_overflowing_sub, || {
569                let (r1, r2) = rayon::join(
570                    check_divisor_upper_blocks,
571                    create_clean_version_of_merged_remainder,
572                );
573
574                (r1, r2)
575            });
576            // explicit drop, so that we do not use it by mistake
577            drop(merged_interesting_remainder);
578
579            let overflow_sum = self.key.unchecked_add(
580                subtraction_overflowed.as_ref(),
581                &at_least_one_upper_block_is_non_zero,
582            );
583            // Give name to closures to improve readability
584            let overflow_happened = |overflow_sum: u64| overflow_sum != 0;
585            let overflow_did_not_happen = |overflow_sum: u64| !overflow_happened(overflow_sum);
586
587            // Here, we will do what zero_out_if does, but to stay within noise constraints,
588            // we do it by hand so that we apply the factor (shift) to the correct block
589            assert!(overflow_sum.degree.get() <= 2); // at_least_one_upper_block_is_non_zero maybe be a trivial 0
590            let factor = MessageModulus(overflow_sum.degree.get() + 1);
591            let mut conditionally_zero_out_merged_interesting_remainder = || {
592                let zero_out_if_overflow_did_not_happen =
593                    self.key.generate_lookup_table_bivariate_with_factor(
594                        |block, overflow_sum| {
595                            if overflow_did_not_happen(overflow_sum) {
596                                0
597                            } else {
598                                block
599                            }
600                        },
601                        factor,
602                    );
603                cleaned_merged_interesting_remainder
604                    .blocks_mut()
605                    .par_iter_mut()
606                    .for_each(|block| {
607                        self.key.unchecked_apply_lookup_table_bivariate_assign(
608                            block,
609                            &overflow_sum,
610                            &zero_out_if_overflow_did_not_happen,
611                        );
612                    });
613            };
614
615            let mut conditionally_zero_out_merged_new_remainder = || {
616                let zero_out_if_overflow_happened =
617                    self.key.generate_lookup_table_bivariate_with_factor(
618                        |block, overflow_sum| {
619                            if overflow_happened(overflow_sum) {
620                                0
621                            } else {
622                                block
623                            }
624                        },
625                        factor,
626                    );
627                new_remainder.blocks_mut().par_iter_mut().for_each(|block| {
628                    self.key.unchecked_apply_lookup_table_bivariate_assign(
629                        block,
630                        &overflow_sum,
631                        &zero_out_if_overflow_happened,
632                    );
633                });
634            };
635
636            let mut set_quotient_bit = || {
637                let did_not_overflow = self.key.unchecked_apply_lookup_table_bivariate(
638                    subtraction_overflowed.as_ref(),
639                    &at_least_one_upper_block_is_non_zero,
640                    &merge_overflow_flags_luts[pos_in_block],
641                );
642
643                self.key
644                    .unchecked_add_assign(&mut quotient.blocks[block_of_bit], &did_not_overflow);
645            };
646
647            let tasks: [&mut (dyn FnMut() + Send + Sync); 3] = [
648                &mut conditionally_zero_out_merged_interesting_remainder,
649                &mut conditionally_zero_out_merged_new_remainder,
650                &mut set_quotient_bit,
651            ];
652            tasks.into_par_iter().for_each(|task| task());
653
654            assert_eq!(
655                remainder1.blocks[..first_trivial_block].len(),
656                cleaned_merged_interesting_remainder.blocks.len()
657            );
658            assert_eq!(
659                remainder2.blocks[..first_trivial_block].len(),
660                new_remainder.blocks.len()
661            );
662            remainder1.blocks[..first_trivial_block]
663                .iter_mut()
664                .zip(cleaned_merged_interesting_remainder.blocks.iter())
665                .for_each(|(remainder_block, new_value)| {
666                    remainder_block.clone_from(new_value);
667                });
668            remainder2.blocks[..first_trivial_block]
669                .iter_mut()
670                .zip(new_remainder.blocks.iter())
671                .for_each(|(remainder_block, new_value)| {
672                    remainder_block.clone_from(new_value);
673                });
674        }
675
676        // Clean the quotient and remainder
677        // as even though they have no carries, they are not at nominal noise level
678        rayon::join(
679            || {
680                remainder1
681                    .blocks_mut()
682                    .par_iter_mut()
683                    .zip(remainder2.blocks.par_iter())
684                    .for_each(|(r1_block, r2_block)| {
685                        self.key.unchecked_add_assign(r1_block, r2_block);
686                        self.key.message_extract_assign(r1_block);
687                    });
688            },
689            || {
690                quotient.blocks_mut().par_iter_mut().for_each(|block| {
691                    self.key.message_extract_assign(block);
692                });
693            },
694        );
695
696        (quotient, remainder1)
697    }
698
    /// Signed (two's-complement) division of `numerator` by `divisor`, returning
    /// `(quotient, remainder)`.
    ///
    /// The absolute values are divided with the unsigned routine, then signs are restored:
    /// - the quotient is negated when the sign bits of numerator and divisor differ, so the
    ///   result is truncated toward zero;
    /// - the remainder is negated when the numerator is negative, so it carries the sign of
    ///   the numerator.
    ///
    /// With a zero divisor, and given that the unsigned routine returns an all-ones quotient
    /// and the (absolute) numerator as remainder in that case, the quotient ends up being -1
    /// for a non-negative numerator and 1 for a negative one, and the remainder equals the
    /// numerator.
    ///
    /// No carry propagation is done here.
    /// NOTE(review): inputs presumably need empty carries, like the other `unchecked_*`
    /// paths. TODO confirm.
    ///
    /// # Panics
    ///
    /// Panics if `numerator` and `divisor` do not have the same number of blocks.
    fn signed_unchecked_div_rem_parallelized(
        &self,
        numerator: &SignedRadixCiphertext,
        divisor: &SignedRadixCiphertext,
    ) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
        assert_eq!(
            numerator.blocks.len(),
            divisor.blocks.len(),
            "numerator and divisor must have same length"
        );
        // Take the absolute values concurrently and reinterpret them as unsigned radix
        // ciphertexts so the unsigned division algorithm can be reused.
        let (positive_numerator, positive_divisor) = rayon::join(
            || {
                let positive_numerator = self.unchecked_abs_parallelized(numerator);
                RadixCiphertext::from_blocks(positive_numerator.into_blocks())
            },
            || {
                let positive_divisor = self.unchecked_abs_parallelized(divisor);
                RadixCiphertext::from_blocks(positive_divisor.into_blocks())
            },
        );

        // While dividing the magnitudes, compute an encrypted flag that is 1 when the sign
        // bits of the numerator and divisor differ (0 otherwise).
        let ((quotient, remainder), sign_bits_are_different) = rayon::join(
            || self.unsigned_unchecked_div_rem_parallelized(&positive_numerator, &positive_divisor),
            || {
                // Index of the top message bit inside a block; for the last block this is the
                // sign bit of the whole number (assumes message_modulus is a power of two).
                let sign_bit_pos = self.key.message_modulus.0.ilog2() - 1;
                let compare_sign_bits = |x, y| {
                    let x_sign_bit = (x >> sign_bit_pos) & 1;
                    let y_sign_bit = (y >> sign_bit_pos) & 1;
                    u64::from(x_sign_bit != y_sign_bit)
                };
                let lut = self.key.generate_lookup_table_bivariate(compare_sign_bits);
                // The most significant (last) blocks hold the sign bits.
                self.key.unchecked_apply_lookup_table_bivariate(
                    numerator.blocks().last().unwrap(),
                    divisor.blocks().last().unwrap(),
                    &lut,
                )
            },
        );

        // Rules are
        // Dividend (numerator) and remainder have the same sign
        // Quotient is negative if signs of numerator and divisor are different
        // Both candidates (negated and not) are computed, and an encrypted selection picks
        // one, since the sign information is only available encrypted.
        let (quotient, remainder) = rayon::join(
            || {
                let negated_quotient = self.neg_parallelized(&quotient);

                // Select the negated quotient when the sign flag decrypts to 1.
                let quotient = self.unchecked_programmable_if_then_else_parallelized(
                    &sign_bits_are_different,
                    &negated_quotient,
                    &quotient,
                    |x| x == 1,
                    true,
                );
                SignedRadixCiphertext::from_blocks(quotient.into_blocks())
            },
            || {
                let negated_remainder = self.neg_parallelized(&remainder);

                // The numerator's sign bit lives in its most significant (last) block.
                let sign_block = numerator.blocks().last().unwrap();
                let sign_bit_pos = self.key.message_modulus.0.ilog2() - 1;

                // Select the negated remainder when the numerator is negative.
                let remainder = self.unchecked_programmable_if_then_else_parallelized(
                    sign_block,
                    &negated_remainder,
                    &remainder,
                    |sign_block| (sign_block >> sign_bit_pos) == 1,
                    true,
                );
                SignedRadixCiphertext::from_blocks(remainder.into_blocks())
            },
        );

        (quotient, remainder)
    }
773
774    /// Computes homomorphically the quotient and remainder of the division between two ciphertexts
775    ///
776    /// # Notes
777    ///
778    /// When the divisor is 0:
779    ///
780    /// - For unsigned operands, the returned quotient will be the max value (i.e. all bits set to
781    ///   1), the remainder will have the value of the numerator.
782    ///
783    /// - For signed operands, remainder will have the same value as the numerator, and, if the
784    ///   numerator is < 0, quotient will be -1 else 1
785    ///
786    /// This behaviour should not be relied on.
787    ///
788    /// # Example
789    ///
790    /// ```rust
791    /// use tfhe::integer::gen_keys_radix;
792    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
793    ///
794    /// // Generate the client key and the server key:
795    /// let num_blocks = 4;
796    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
797    ///
798    /// let msg1 = 97;
799    /// let msg2 = 14;
800    ///
801    /// let ct1 = cks.encrypt(msg1);
802    /// let ct2 = cks.encrypt(msg2);
803    ///
804    /// // Compute homomorphically the quotient and remainder:
805    /// let (q_res, r_res) = sks.div_rem_parallelized(&ct1, &ct2);
806    ///
807    /// // Decrypt:
808    /// let q: u64 = cks.decrypt(&q_res);
809    /// let r: u64 = cks.decrypt(&r_res);
810    /// assert_eq!(q, msg1 / msg2);
811    /// assert_eq!(r, msg1 % msg2);
812    /// ```
813    pub fn div_rem_parallelized<T>(&self, numerator: &T, divisor: &T) -> (T, T)
814    where
815        T: IntegerRadixCiphertext,
816    {
817        let mut tmp_numerator;
818        let mut tmp_divisor;
819
820        let (numerator, divisor) = match (
821            numerator.block_carries_are_empty(),
822            divisor.block_carries_are_empty(),
823        ) {
824            (true, true) => (numerator, divisor),
825            (true, false) => {
826                tmp_divisor = divisor.clone();
827                self.full_propagate_parallelized(&mut tmp_divisor);
828                (numerator, &tmp_divisor)
829            }
830            (false, true) => {
831                tmp_numerator = numerator.clone();
832                self.full_propagate_parallelized(&mut tmp_numerator);
833                (&tmp_numerator, divisor)
834            }
835            (false, false) => {
836                tmp_divisor = divisor.clone();
837                tmp_numerator = numerator.clone();
838                rayon::join(
839                    || self.full_propagate_parallelized(&mut tmp_numerator),
840                    || self.full_propagate_parallelized(&mut tmp_divisor),
841                );
842                (&tmp_numerator, &tmp_divisor)
843            }
844        };
845
846        self.unchecked_div_rem_parallelized(numerator, divisor)
847    }
848
849    pub fn smart_div_rem_parallelized<T>(&self, numerator: &mut T, divisor: &mut T) -> (T, T)
850    where
851        T: IntegerRadixCiphertext,
852    {
853        rayon::join(
854            || {
855                if !numerator.block_carries_are_empty() {
856                    self.full_propagate_parallelized(numerator);
857                }
858            },
859            || {
860                if !divisor.block_carries_are_empty() {
861                    self.full_propagate_parallelized(divisor);
862                }
863            },
864        );
865        self.unchecked_div_rem_parallelized(numerator, divisor)
866    }
867
868    //======================================================================
869    //                Div
870    //======================================================================
871
872    pub fn unchecked_div_assign_parallelized<T>(&self, numerator: &mut T, divisor: &T)
873    where
874        T: IntegerRadixCiphertext,
875    {
876        let (q, _r) = self.unchecked_div_rem_parallelized(numerator, divisor);
877        *numerator = q;
878    }
879
880    pub fn unchecked_div_parallelized<T>(&self, numerator: &T, divisor: &T) -> T
881    where
882        T: IntegerRadixCiphertext,
883    {
884        let (q, _r) = self.unchecked_div_rem_parallelized(numerator, divisor);
885        q
886    }
887
888    pub fn smart_div_assign_parallelized<T>(&self, numerator: &mut T, divisor: &mut T)
889    where
890        T: IntegerRadixCiphertext,
891    {
892        let (q, _r) = self.smart_div_rem_parallelized(numerator, divisor);
893        *numerator = q;
894    }
895
896    pub fn smart_div_parallelized<T>(&self, numerator: &mut T, divisor: &mut T) -> T
897    where
898        T: IntegerRadixCiphertext,
899    {
900        let (q, _r) = self.smart_div_rem_parallelized(numerator, divisor);
901        q
902    }
903
904    pub fn div_assign_parallelized<T>(&self, numerator: &mut T, divisor: &T)
905    where
906        T: IntegerRadixCiphertext,
907    {
908        let mut tmp_divisor;
909
910        let (numerator, divisor) = match (
911            numerator.block_carries_are_empty(),
912            divisor.block_carries_are_empty(),
913        ) {
914            (true, true) => (numerator, divisor),
915            (true, false) => {
916                tmp_divisor = divisor.clone();
917                self.full_propagate_parallelized(&mut tmp_divisor);
918                (numerator, &tmp_divisor)
919            }
920            (false, true) => {
921                self.full_propagate_parallelized(numerator);
922                (numerator, divisor)
923            }
924            (false, false) => {
925                tmp_divisor = divisor.clone();
926                rayon::join(
927                    || self.full_propagate_parallelized(numerator),
928                    || self.full_propagate_parallelized(&mut tmp_divisor),
929                );
930                (numerator, &tmp_divisor)
931            }
932        };
933
934        let (q, _r) = self.unchecked_div_rem_parallelized(numerator, divisor);
935        *numerator = q;
936    }
937
938    /// Computes homomorphically the quotient of the division between two ciphertexts
939    ///
940    /// # Note
941    ///
942    /// If you need both the quotient and remainder use [Self::div_rem_parallelized].
943    ///
944    /// # Example
945    ///
946    /// ```rust
947    /// use tfhe::integer::gen_keys_radix;
948    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
949    ///
950    /// // Generate the client key and the server key:
951    /// let num_blocks = 4;
952    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
953    ///
954    /// let msg1 = 97;
955    /// let msg2 = 14;
956    ///
957    /// let ct1 = cks.encrypt(msg1);
958    /// let ct2 = cks.encrypt(msg2);
959    ///
960    /// // Compute homomorphically a division:
961    /// let ct_res = sks.div_parallelized(&ct1, &ct2);
962    ///
963    /// // Decrypt:
964    /// let dec_result: u64 = cks.decrypt(&ct_res);
965    /// assert_eq!(dec_result, msg1 / msg2);
966    /// ```
967    pub fn div_parallelized<T>(&self, numerator: &T, divisor: &T) -> T
968    where
969        T: IntegerRadixCiphertext,
970    {
971        let (q, _r) = self.div_rem_parallelized(numerator, divisor);
972        q
973    }
974
975    //======================================================================
976    //                Rem
977    //======================================================================
978
979    pub fn unchecked_rem_assign_parallelized<T>(&self, numerator: &mut T, divisor: &T)
980    where
981        T: IntegerRadixCiphertext,
982    {
983        let (_q, r) = self.unchecked_div_rem_parallelized(numerator, divisor);
984        *numerator = r;
985    }
986
987    pub fn unchecked_rem_parallelized<T>(&self, numerator: &T, divisor: &T) -> T
988    where
989        T: IntegerRadixCiphertext,
990    {
991        let (_q, r) = self.unchecked_div_rem_parallelized(numerator, divisor);
992        r
993    }
994
995    pub fn smart_rem_assign_parallelized<T>(&self, numerator: &mut T, divisor: &mut T)
996    where
997        T: IntegerRadixCiphertext,
998    {
999        let (_q, r) = self.smart_div_rem_parallelized(numerator, divisor);
1000        *numerator = r;
1001    }
1002
1003    pub fn smart_rem_parallelized<T>(&self, numerator: &mut T, divisor: &mut T) -> T
1004    where
1005        T: IntegerRadixCiphertext,
1006    {
1007        let (_q, r) = self.smart_div_rem_parallelized(numerator, divisor);
1008        r
1009    }
1010
1011    pub fn rem_assign_parallelized<T>(&self, numerator: &mut T, divisor: &T)
1012    where
1013        T: IntegerRadixCiphertext,
1014    {
1015        let mut tmp_divisor;
1016
1017        let (numerator, divisor) = match (
1018            numerator.block_carries_are_empty(),
1019            divisor.block_carries_are_empty(),
1020        ) {
1021            (true, true) => (numerator, divisor),
1022            (true, false) => {
1023                tmp_divisor = divisor.clone();
1024                self.full_propagate_parallelized(&mut tmp_divisor);
1025                (numerator, &tmp_divisor)
1026            }
1027            (false, true) => {
1028                self.full_propagate_parallelized(numerator);
1029                (numerator, divisor)
1030            }
1031            (false, false) => {
1032                tmp_divisor = divisor.clone();
1033                rayon::join(
1034                    || self.full_propagate_parallelized(numerator),
1035                    || self.full_propagate_parallelized(&mut tmp_divisor),
1036                );
1037                (numerator, &tmp_divisor)
1038            }
1039        };
1040
1041        let (_q, r) = self.unchecked_div_rem_parallelized(numerator, divisor);
1042        *numerator = r;
1043    }
1044
1045    /// Computes homomorphically the remainder (rest) of the division between two ciphertexts
1046    ///
1047    /// # Note
1048    ///
1049    /// If you need both the quotient and remainder use [Self::div_rem_parallelized].
1050    ///
1051    /// # Example
1052    ///
1053    /// ```rust
1054    /// use tfhe::integer::gen_keys_radix;
1055    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
1056    ///
1057    /// // Generate the client key and the server key:
1058    /// let num_blocks = 4;
1059    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
1060    ///
1061    /// let msg1 = 97;
1062    /// let msg2 = 14;
1063    ///
1064    /// let ct1 = cks.encrypt(msg1);
1065    /// let ct2 = cks.encrypt(msg2);
1066    ///
1067    /// // Compute homomorphically the remainder:
1068    /// let ct_res = sks.rem_parallelized(&ct1, &ct2);
1069    ///
1070    /// // Decrypt:
1071    /// let dec_result: u64 = cks.decrypt(&ct_res);
1072    /// assert_eq!(dec_result, msg1 % msg2);
1073    /// ```
1074    pub fn rem_parallelized<T>(&self, numerator: &T, divisor: &T) -> T
1075    where
1076        T: IntegerRadixCiphertext,
1077    {
1078        let (_q, r) = self.div_rem_parallelized(numerator, divisor);
1079        r
1080    }
1081
1082    /// Computes homomorphically the quotient and remainder of the division between two ciphertexts
1083    ///
1084    /// Returns an additional flag indicating if the divisor was 0
1085    ///
1086    /// # Example
1087    ///
1088    /// ```rust
1089    /// use tfhe::integer::gen_keys_radix;
1090    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
1091    ///
1092    /// // Generate the client key and the server key:
1093    /// let num_blocks = 4;
1094    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
1095    ///
1096    /// let msg = 97u8;
1097    ///
1098    /// let ct1 = cks.encrypt(msg);
1099    /// let ct2 = cks.encrypt(0u8);
1100    ///
1101    /// // Compute homomorphically a division:
1102    /// let (ct_q, ct_r, div_by_0) = sks.checked_div_rem_parallelized(&ct1, &ct2);
1103    ///
1104    /// // Decrypt:
1105    /// let div_by_0 = cks.decrypt_bool(&div_by_0);
1106    /// assert!(div_by_0);
1107    ///
1108    /// let q: u8 = cks.decrypt(&ct_q);
1109    /// assert_eq!(u8::MAX, q);
1110    ///
1111    /// let r: u8 = cks.decrypt(&ct_r);
1112    /// assert_eq!(msg, r);
1113    /// ```
1114    pub fn checked_div_rem_parallelized<T>(
1115        &self,
1116        numerator: &T,
1117        divisor: &T,
1118    ) -> (T, T, BooleanBlock)
1119    where
1120        T: IntegerRadixCiphertext,
1121    {
1122        let ((q, r), div_by_0) = rayon::join(
1123            || self.div_rem_parallelized(numerator, divisor),
1124            || self.are_all_blocks_zero(divisor.blocks()),
1125        );
1126
1127        (q, r, BooleanBlock::new_unchecked(div_by_0))
1128    }
1129
1130    /// Computes homomorphically the quotient of the division between two ciphertexts
1131    ///
1132    /// Returns an additional flag indicating if the divisor was 0
1133    ///
1134    /// # Note
1135    ///
1136    /// If you need both the quotient and remainder use [Self::div_rem_parallelized].
1137    ///
1138    /// # Example
1139    ///
1140    /// ```rust
1141    /// use tfhe::integer::gen_keys_radix;
1142    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
1143    ///
1144    /// // Generate the client key and the server key:
1145    /// let num_blocks = 4;
1146    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
1147    ///
1148    /// let msg = 97u8;
1149    ///
1150    /// let ct1 = cks.encrypt(msg);
1151    /// let ct2 = cks.encrypt(0u8);
1152    ///
1153    /// // Compute homomorphically a division:
1154    /// let (ct_res, div_by_0) = sks.checked_div_parallelized(&ct1, &ct2);
1155    ///
1156    /// // Decrypt:
1157    /// let div_by_0 = cks.decrypt_bool(&div_by_0);
1158    /// assert!(div_by_0);
1159    ///
1160    /// let dec_result: u8 = cks.decrypt(&ct_res);
1161    /// assert_eq!(u8::MAX, dec_result);
1162    /// ```
1163    pub fn checked_div_parallelized<T>(&self, numerator: &T, divisor: &T) -> (T, BooleanBlock)
1164    where
1165        T: IntegerRadixCiphertext,
1166    {
1167        let (q, div_by_0) = rayon::join(
1168            || self.div_parallelized(numerator, divisor),
1169            || self.are_all_blocks_zero(divisor.blocks()),
1170        );
1171
1172        (q, BooleanBlock::new_unchecked(div_by_0))
1173    }
1174
1175    /// Computes homomorphically the remainder (rest) of the division between two ciphertexts
1176    ///
1177    /// Returns an additional flag indicating if the divisor was 0
1178    ///
1179    /// # Note
1180    ///
1181    /// If you need both the quotient and remainder use [Self::checked_div_rem_parallelized].
1182    ///
1183    /// # Example
1184    ///
1185    /// ```rust
1186    /// use tfhe::integer::gen_keys_radix;
1187    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
1188    ///
1189    /// // Generate the client key and the server key:
1190    /// let num_blocks = 4;
1191    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_blocks);
1192    ///
1193    /// let msg = 97u8;
1194    ///
1195    /// let ct1 = cks.encrypt(msg);
1196    /// let ct2 = cks.encrypt(0u8);
1197    ///
1198    /// // Compute homomorphically the remainder:
1199    /// let (ct_res, rem_by_0) = sks.checked_rem_parallelized(&ct1, &ct2);
1200    ///
1201    /// // Decrypt:
1202    /// let rem_by_0 = cks.decrypt_bool(&rem_by_0);
1203    /// assert!(rem_by_0);
1204    ///
1205    /// let dec_result: u8 = cks.decrypt(&ct_res);
1206    /// assert_eq!(dec_result, msg);
1207    /// ```
1208    pub fn checked_rem_parallelized<T>(&self, numerator: &T, divisor: &T) -> (T, BooleanBlock)
1209    where
1210        T: IntegerRadixCiphertext,
1211    {
1212        let (r, rem_by_0) = rayon::join(
1213            || self.rem_parallelized(numerator, divisor),
1214            || self.are_all_blocks_zero(divisor.blocks()),
1215        );
1216
1217        (r, BooleanBlock::new_unchecked(rem_by_0))
1218    }
1219}