1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
mod abs;
mod add;
mod bit_extractor;
mod bitwise_op;
mod block_shift;
pub(crate) mod cmux;
mod comparison;
mod count_consecutive_bits;
mod div_mod;
mod modulus_switch_compression;
mod mul;
mod neg;
mod rotate;
mod scalar_add;
mod scalar_bitwise_op;
mod scalar_comparison;
pub(crate) mod scalar_div_mod;
mod scalar_mul;
mod scalar_rotate;
mod scalar_shift;
mod scalar_sub;
mod shift;
pub(crate) mod sub;
mod sum;
mod count_zeros_ones;
pub(crate) mod ilog2;
pub(crate) mod kv_store;
mod reverse_bits;
mod scalar_dot_prod;
mod slice;
#[cfg(test)]
pub(crate) mod tests_cases_unsigned;
#[cfg(test)]
pub(crate) mod tests_long_run;
#[cfg(test)]
pub(crate) mod tests_signed;
#[cfg(test)]
pub(crate) mod tests_unsigned;
mod vector_comparisons;
mod vector_find;
use std::borrow::Cow;
use super::ServerKey;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::RadixCiphertext;
use crate::shortint::ciphertext::{Ciphertext, NoiseLevel};
pub(crate) use add::OutputFlag;
use rayon::prelude::*;
pub use scalar_div_mod::{MiniUnsignedInteger, Reciprocable};
pub use vector_find::MatchValues;
// parallelized versions
impl ServerKey {
    /// Propagate the carry of the 'index' block to the next one.
    ///
    /// The block at `index` is replaced by its message part, and its carry part is added to
    /// the block at `index + 1` (when such a block exists). The extracted carry is returned.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a valid block index of `ctxt`.
    ///
    /// # Example
    ///
    ///```rust
    /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext};
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
    ///
    /// // Generate the client key and the server key:
    /// let num_blocks = 4;
    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks);
    ///
    /// let msg = 7u64;
    ///
    /// let ct1 = cks.encrypt(msg);
    /// let ct2 = cks.encrypt(msg);
    ///
    /// // Compute homomorphically an addition:
    /// let mut ct_res = sks.unchecked_add(&ct1, &ct2);
    /// let carry = sks.propagate_parallelized(&mut ct_res, 0);
    ///
    /// // Decrypt one block:
    /// let res: u64 = cks.decrypt_one_block(&ct_res.blocks()[1]);
    /// assert_eq!(3, res);
    /// ```
    pub fn propagate_parallelized<T>(&self, ctxt: &mut T, index: usize) -> Ciphertext
    where
        T: IntegerRadixCiphertext,
    {
        let (carry, message) = rayon::join(
            || self.key.carry_extract(&ctxt.blocks()[index]),
            || self.key.message_extract(&ctxt.blocks()[index]),
        );
        ctxt.blocks_mut()[index] = message;

        // Add the carry to the next block, if there is one
        if index + 1 < ctxt.blocks().len() {
            self.key
                .unchecked_add_assign(&mut ctxt.blocks_mut()[index + 1], &carry);
        }
        carry
    }

    /// Propagates carries starting from start_index.
    ///
    /// Does nothing if start_index >= ctxt.len() or ctxt is empty
    pub fn partial_propagate_parallelized<T>(&self, ctxt: &mut T, start_index: usize)
    where
        T: IntegerRadixCiphertext,
    {
        if start_index >= ctxt.blocks().len() {
            return;
        }

        self.partial_propagate_blocks_parallelized(&mut ctxt.blocks_mut()[start_index..]);
    }

    /// Propagates the carries of `blocks`, from the first block to the last one.
    ///
    /// The carry produced by the last block is discarded. Does nothing if `blocks` is empty.
    fn partial_propagate_blocks_parallelized(&self, blocks: &mut [Ciphertext]) {
        if blocks.is_empty() {
            return;
        }

        // Extracts the message of every block, and the carry of every block but the last
        // (that carry would just be thrown away).
        // The carries Vec therefore has one less block than the messages Vec.
        let extract_message_and_carry_blocks = |blocks: &[Ciphertext]| {
            let num_blocks = blocks.len();
            rayon::join(
                || {
                    blocks
                        .par_iter()
                        .map(|block| self.key.message_extract(block))
                        .collect::<Vec<_>>()
                },
                || {
                    blocks[..num_blocks - 1]
                        .par_iter()
                        .map(|block| self.key.carry_extract(block))
                        .collect::<Vec<_>>()
                },
            )
        };

        if self.is_eligible_for_parallel_single_carry_propagation(blocks.len()) {
            let highest_degree = blocks
                .iter()
                .map(|block| block.degree.get())
                .max()
                .expect("blocks was checked to be non-empty");

            if highest_degree >= (self.key.message_modulus.0 - 1) * 2 {
                // At least one of the blocks has more than one carry,
                // we need to extract message and carries, then add + propagate
                let (mut message_blocks, carry_blocks) = extract_message_and_carry_blocks(blocks);
                // The first block never receives a carry: its message is final
                blocks[0] = message_blocks.remove(0);
                let mut lhs = RadixCiphertext::from(message_blocks);
                let rhs = RadixCiphertext::from(carry_blocks);
                self.add_assign_with_carry_parallelized(&mut lhs, &rhs, None);
                // Move the results back rather than cloning every ciphertext
                blocks[1..].swap_with_slice(&mut lhs.blocks);
            } else {
                self.propagate_single_carry_parallelized(&mut blocks[..]);
            }
        } else {
            // We do not care about the degree of the first block as it won't receive any carry
            let maybe_highest_degree = blocks[1..].iter().map(|block| block.degree.get()).max();

            let mut start_index = 0;
            if maybe_highest_degree.is_some_and(|degree| degree > self.key.max_degree.get())
                || blocks[1..]
                    .iter()
                    // We need to be able to add a carry, which is a fresh ciphertext
                    .any(|b| b.noise_level().get() >= self.key.max_noise_level.get() - 1)
            {
                // At least one of the blocks that can receive a carry won't be able to,
                // so we need to do a first 'partial' round
                let (mut message_blocks, carry_blocks) = extract_message_and_carry_blocks(blocks);
                blocks.swap_with_slice(&mut message_blocks);
                for (block, carry) in blocks[1..].iter_mut().zip(carry_blocks.iter()) {
                    self.key.unchecked_add_assign(block, carry);
                }
                // We can start propagation one index later as we already did the first block
                start_index += 1;
            }

            let len = blocks.len();
            // If start_index >= len, the range is empty
            for i in start_index..len {
                if i + 1 < len {
                    let (carry, message) = rayon::join(
                        || self.key.carry_extract(&blocks[i]),
                        || self.key.message_extract(&blocks[i]),
                    );
                    blocks[i] = message;
                    // Add the carry to the next block
                    self.key.unchecked_add_assign(&mut blocks[i + 1], &carry);
                } else {
                    // The carry of the last block would be discarded,
                    // so only its message needs to be extracted
                    self.key.message_extract_assign(&mut blocks[i]);
                }
            }
        }
    }

    /// Propagate all the carries.
    ///
    /// After this call, every block holds only its message, and blocks whose noise was above
    /// the nominal level have been cleaned.
    ///
    /// # Example
    ///
    ///```rust
    /// use tfhe::integer::gen_keys_radix;
    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
    ///
    /// // Generate the client key and the server key:
    /// let num_blocks = 4;
    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, num_blocks);
    ///
    /// let msg = 10u64;
    ///
    /// let ct1 = cks.encrypt(msg);
    /// let ct2 = cks.encrypt(msg);
    ///
    /// // Compute homomorphically an addition:
    /// let mut ct_res = sks.unchecked_add(&ct1, &ct2);
    /// sks.full_propagate_parallelized(&mut ct_res);
    ///
    /// // Decrypt:
    /// let res: u64 = cks.decrypt(&ct_res);
    /// assert_eq!(msg + msg, res);
    /// ```
    pub fn full_propagate_parallelized<T>(&self, ctxt: &mut T)
    where
        T: IntegerRadixCiphertext,
    {
        // Start the propagation on the first block that has carries
        let Some(start_index) = ctxt
            .blocks()
            .iter()
            .position(|block| !block.carry_is_empty())
        else {
            // No block has any carries. However, we still need to clean the noise
            ctxt.blocks_mut()
                .par_iter_mut()
                .filter(|block| block.noise_level() > NoiseLevel::NOMINAL)
                .for_each(|block| self.key.message_extract_assign(block));
            return;
        };

        // End the propagation on the last block that can still hold a non-zero value once
        // every carry has been propagated: the blocks after it are known to be zero, so they
        // only need their noise cleaned.
        //
        // A block's degree is an upper bound of its value, so we bound the propagation by
        // running it on degrees: block `i` holds at most `degree_i + carry_in_i`, and sends at
        // most `(degree_i + carry_in_i) / message_modulus` to block `i + 1`.
        //
        // When message_modulus >= carry_modulus, this ends at most 2 blocks after the last
        // block with a non-zero degree (block `n`):
        // 1) the 'immediate' carry of block `n` lands in block `n + 1`
        // 2) a carry coming from a block `< n` may push block `n + 1` over, reaching `n + 2`
        //
        // e.g., 2_2 parameters, notation: 0bcarry|msg, blocks in little endian order
        // input: [0b11|11, 0b11|11, 0b00|00, 0b00|00, 0b00|00] = 75 = 5 * 15
        // msg  : [0b00|11, 0b00|11, 0b00|00, 0b00|00, 0b00|00] = 15
        // carry: [0b00|00, 0b00|11, 0b00|11, 0b00|00, 0b00|00] = 60
        // msg + carry = [0b00|11, 0b00|11+11, 0b00|11, 0b00|00, 0b00|00]
        //             = [0b00|11, 0b00|10, 0b00|11+1, 0b00|00, 0b00|00]
        //             = [0b00|11, 0b00|10, 0b00|00, 0b00|01, 0b00|00] = 75
        // The result reaches block 3, i.e. n + 2 with n = 1.
        //
        // When carry_modulus > message_modulus (e.g. 1_3 parameters), a single block can
        // spill over more than 2 blocks (a degree of 15 needs 4 one-bit blocks), which is
        // why the bound is computed instead of being hard-coded to n + 2.
        let message_modulus = self.key.message_modulus.0;
        let mut carry_bound = 0;
        // The block at start_index has carries, hence a non-zero degree, so the loop below
        // always moves end_index to start_index or later.
        let mut end_index = start_index;
        for (i, block) in ctxt.blocks().iter().enumerate() {
            let value_bound = block.degree.get() + carry_bound;
            if value_bound != 0 {
                end_index = i;
            }
            carry_bound = value_bound / message_modulus;
        }

        let (blocks, ignored) = ctxt.blocks_mut().split_at_mut(end_index + 1);
        let n = blocks.len();
        let (to_be_cleaned, to_be_propagated) = blocks.split_at_mut(start_index.min(n));

        rayon::scope(|s| {
            if !to_be_propagated.is_empty() {
                s.spawn(|_| {
                    self.partial_propagate_blocks_parallelized(to_be_propagated);
                })
            }

            if !to_be_cleaned.is_empty() {
                s.spawn(|_| {
                    to_be_cleaned
                        .par_iter_mut()
                        .filter(|block| block.noise_level() > NoiseLevel::NOMINAL)
                        .for_each(|block| self.key.message_extract_assign(block));
                });
            }

            if !ignored.is_empty() {
                // We may actually still need to clean blocks in the 'ignored' part of the
                // propagation
                s.spawn(|_| {
                    ignored
                        .par_iter_mut()
                        .filter(|block| block.noise_level() > NoiseLevel::NOMINAL)
                        .for_each(|block| self.key.message_extract_assign(block));
                });
            }
        })
    }

    /// Cleans the input ct so that it is ready to be used in a default ops
    ///
    /// Returns a Cow::Owned if a clone was necessary for the cleaning,
    /// Cow::Borrowed otherwise
    pub(crate) fn clean_for_default_op<'a, T>(&self, ct: &'a T) -> Cow<'a, T>
    where
        T: IntegerRadixCiphertext,
    {
        if ct
            .blocks()
            .iter()
            .any(|block| !block.carry_is_empty() || block.noise_level() != NoiseLevel::NOMINAL)
        {
            let mut cloned = ct.clone();
            self.full_propagate_parallelized(&mut cloned);
            Cow::Owned(cloned)
        } else {
            Cow::Borrowed(ct)
        }
    }
}