// concrete_integer/server_key/radix_parallel/add.rs

1use std::sync::Mutex;
2
3use crate::ciphertext::RadixCiphertext;
4use crate::ServerKey;
5
6impl ServerKey {
7    /// Computes homomorphically an addition between two ciphertexts encrypting integer values.
8    ///
9    /// # Warning
10    ///
11    /// - Multithreaded
12    ///
13    /// # Example
14    ///
15    /// ```rust
16    /// use concrete_integer::gen_keys_radix;
17    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
18    ///
19    /// // Generate the client key and the server key:
20    /// let num_blocks = 4;
21    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
22    ///
23    /// let msg1 = 14;
24    /// let msg2 = 97;
25    ///
26    /// let mut ct1 = cks.encrypt(msg1);
27    /// let mut ct2 = cks.encrypt(msg2);
28    ///
29    /// // Compute homomorphically an addition:
30    /// let ct_res = sks.smart_add_parallelized(&mut ct1, &mut ct2);
31    ///
32    /// // Decrypt:
33    /// let dec_result = cks.decrypt(&ct_res);
34    /// assert_eq!(dec_result, msg1 + msg2);
35    /// ```
36    pub fn smart_add_parallelized(
37        &self,
38        ct_left: &mut RadixCiphertext,
39        ct_right: &mut RadixCiphertext,
40    ) -> RadixCiphertext {
41        if !self.is_add_possible(ct_left, ct_right) {
42            rayon::join(
43                || self.full_propagate_parallelized(ct_left),
44                || self.full_propagate_parallelized(ct_right),
45            );
46        }
47        self.unchecked_add(ct_left, ct_right)
48    }
49
50    pub fn smart_add_assign_parallelized(
51        &self,
52        ct_left: &mut RadixCiphertext,
53        ct_right: &mut RadixCiphertext,
54    ) {
55        if !self.is_add_possible(ct_left, ct_right) {
56            rayon::join(
57                || self.full_propagate_parallelized(ct_left),
58                || self.full_propagate_parallelized(ct_right),
59            );
60        }
61        self.unchecked_add_assign(ct_left, ct_right);
62    }
63
64    /// op must be associative and commutative
65    pub fn smart_binary_op_seq_parallelized<'this, 'item>(
66        &'this self,
67        ct_seq: impl IntoIterator<Item = &'item mut RadixCiphertext>,
68        op: impl for<'a> Fn(
69                &'a ServerKey,
70                &'a mut RadixCiphertext,
71                &'a mut RadixCiphertext,
72            ) -> RadixCiphertext
73            + Sync,
74    ) -> Option<RadixCiphertext> {
75        enum CiphertextCow<'a> {
76            Borrowed(&'a mut RadixCiphertext),
77            Owned(RadixCiphertext),
78        }
79        impl CiphertextCow<'_> {
80            fn as_mut(&mut self) -> &mut RadixCiphertext {
81                match self {
82                    CiphertextCow::Borrowed(b) => b,
83                    CiphertextCow::Owned(o) => o,
84                }
85            }
86        }
87
88        let ct_seq = ct_seq
89            .into_iter()
90            .map(CiphertextCow::Borrowed)
91            .collect::<Vec<_>>();
92        let op = &op;
93
94        // overhead of dynamic dispatch is negligible compared to multithreading, PBS, etc.
95        // we defer all calls to a single implementation to avoid code bloat and long compile
96        // times
97        fn reduce_impl(
98            sks: &ServerKey,
99            mut ct_seq: Vec<CiphertextCow>,
100            op: &(dyn for<'a> Fn(
101                &'a ServerKey,
102                &'a mut RadixCiphertext,
103                &'a mut RadixCiphertext,
104            ) -> RadixCiphertext
105                  + Sync),
106        ) -> Option<RadixCiphertext> {
107            use rayon::prelude::*;
108
109            if ct_seq.is_empty() {
110                None
111            } else {
112                // we repeatedly divide the number of terms by two by iteratively reducing
113                // consecutive terms in the array
114                while ct_seq.len() > 1 {
115                    let results =
116                        Mutex::new(Vec::<RadixCiphertext>::with_capacity(ct_seq.len() / 2));
117
118                    // if the number of elements is odd, we skip the first element
119                    let untouched_prefix = ct_seq.len() % 2;
120                    let ct_seq_slice = &mut ct_seq[untouched_prefix..];
121
122                    ct_seq_slice.par_chunks_mut(2).for_each(|chunk| {
123                        let (first, second) = chunk.split_at_mut(1);
124                        let first = &mut first[0];
125                        let second = &mut second[0];
126                        let result = op(sks, first.as_mut(), second.as_mut());
127                        results.lock().unwrap().push(result);
128                    });
129
130                    let results = results.into_inner().unwrap();
131                    ct_seq.truncate(untouched_prefix);
132                    ct_seq.extend(results.into_iter().map(CiphertextCow::Owned));
133                }
134
135                let sum = ct_seq.pop().unwrap();
136
137                Some(match sum {
138                    CiphertextCow::Borrowed(b) => b.clone(),
139                    CiphertextCow::Owned(o) => o,
140                })
141            }
142        }
143
144        reduce_impl(self, ct_seq, op)
145    }
146}