concrete_integer/server_key/radix/scalar_mul.rs
1use crate::ciphertext::RadixCiphertext;
2use crate::server_key::CheckError;
3use crate::server_key::CheckError::CarryFull;
4use crate::ServerKey;
5use std::collections::BTreeMap;
6
7impl ServerKey {
8 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
9 ///
10 /// This function computes the operation without checking if it exceeds the capacity of the
11 /// ciphertext.
12 ///
13 /// The result is returned as a new ciphertext.
14 ///
15 /// # Example
16 ///
17 /// ```rust
18 /// use concrete_integer::gen_keys_radix;
19 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
20 ///
21 /// // We have 4 * 2 = 8 bits of message
22 /// let size = 4;
23 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
24 ///
25 /// let msg = 30;
26 /// let scalar = 3;
27 ///
28 /// let ct = cks.encrypt(msg);
29 ///
30 /// // Compute homomorphically a scalar multiplication:
31 /// let ct_res = sks.unchecked_small_scalar_mul(&ct, scalar);
32 ///
33 /// let clear = cks.decrypt(&ct_res);
34 /// assert_eq!(scalar * msg, clear);
35 /// ```
36 pub fn unchecked_small_scalar_mul(
37 &self,
38 ctxt: &RadixCiphertext,
39 scalar: u64,
40 ) -> RadixCiphertext {
41 let mut ct_result = ctxt.clone();
42 self.unchecked_small_scalar_mul_assign(&mut ct_result, scalar);
43
44 ct_result
45 }
46
47 pub fn unchecked_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
48 for ct_i in ctxt.blocks.iter_mut() {
49 self.key.unchecked_scalar_mul_assign(ct_i, scalar as u8);
50 }
51 }
52
53 ///Verifies if ct1 can be multiplied by scalar.
54 ///
55 /// # Example
56 ///
57 ///```rust
58 /// use concrete_integer::gen_keys_radix;
59 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
60 ///
61 /// // We have 4 * 2 = 8 bits of message
62 /// let size = 4;
63 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
64 ///
65 /// let msg = 25;
66 /// let scalar1 = 3;
67 ///
68 /// let ct = cks.encrypt(msg);
69 ///
70 /// // Verification if the scalar multiplication can be computed:
71 /// let res = sks.is_small_scalar_mul_possible(&ct, scalar1);
72 ///
73 /// assert_eq!(true, res);
74 ///
75 /// let scalar2 = 7;
76 /// // Verification if the scalar multiplication can be computed:
77 /// let res = sks.is_small_scalar_mul_possible(&ct, scalar2);
78 /// assert_eq!(false, res);
79 /// ```
80 pub fn is_small_scalar_mul_possible(&self, ctxt: &RadixCiphertext, scalar: u64) -> bool {
81 for ct_i in ctxt.blocks.iter() {
82 if !self.key.is_scalar_mul_possible(ct_i, scalar as u8) {
83 return false;
84 }
85 }
86 true
87 }
88
89 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
90 ///
91 /// If the operation can be performed, the result is returned in a new ciphertext.
92 /// Otherwise [CheckError::CarryFull] is returned.
93 ///
94 /// # Example
95 ///
96 /// ```rust
97 /// use concrete_integer::gen_keys_radix;
98 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
99 ///
100 /// // We have 4 * 2 = 8 bits of message
101 /// let size = 4;
102 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
103 ///
104 /// let msg = 33;
105 /// let scalar = 3;
106 ///
107 /// let ct = cks.encrypt(msg);
108 ///
109 /// // Compute homomorphically a scalar multiplication:
110 /// let ct_res = sks.checked_small_scalar_mul(&ct, scalar);
111 ///
112 /// match ct_res {
113 /// Err(x) => panic!("{:?}", x),
114 /// Ok(y) => {
115 /// let clear = cks.decrypt(&y);
116 /// assert_eq!(msg * scalar, clear);
117 /// }
118 /// }
119 /// ```
120 pub fn checked_small_scalar_mul(
121 &self,
122 ct: &RadixCiphertext,
123 scalar: u64,
124 ) -> Result<RadixCiphertext, CheckError> {
125 let mut ct_result = ct.clone();
126
127 // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext
128 if self.is_small_scalar_mul_possible(ct, scalar) {
129 ct_result = self.unchecked_small_scalar_mul(&ct_result, scalar);
130
131 Ok(ct_result)
132 } else {
133 Err(CarryFull)
134 }
135 }
136
137 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
138 ///
139 /// If the operation can be performed, the result is assigned to the ciphertext given
140 /// as parameter.
141 /// Otherwise [CheckError::CarryFull] is returned.
142 ///
143 /// # Example
144 ///
145 /// ```rust
146 /// use concrete_integer::gen_keys_radix;
147 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
148 ///
149 /// // We have 4 * 2 = 8 bits of message
150 /// let size = 4;
151 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
152 ///
153 /// let msg = 33;
154 /// let scalar = 3;
155 ///
156 /// let mut ct = cks.encrypt(msg);
157 ///
158 /// // Compute homomorphically a scalar multiplication:
159 /// sks.checked_small_scalar_mul_assign(&mut ct, scalar);
160 ///
161 /// let clear_res = cks.decrypt(&ct);
162 /// assert_eq!(clear_res, msg * scalar);
163 /// ```
164 pub fn checked_small_scalar_mul_assign(
165 &self,
166 ct: &mut RadixCiphertext,
167 scalar: u64,
168 ) -> Result<(), CheckError> {
169 // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext
170 if self.is_small_scalar_mul_possible(ct, scalar) {
171 self.unchecked_small_scalar_mul_assign(ct, scalar);
172 Ok(())
173 } else {
174 Err(CarryFull)
175 }
176 }
177
178 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
179 ///
180 /// `small` means the scalar value shall fit in a __shortint block__.
181 /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2,
182 /// the scalar should fit in 2 bits.
183 ///
184 /// The result is returned as a new ciphertext.
185 ///
186 /// # Example
187 ///
188 /// ```rust
189 /// use concrete_integer::gen_keys_radix;
190 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
191 ///
192 /// // We have 4 * 2 = 8 bits of message
193 /// let modulus = 1 << 8;
194 /// let size = 4;
195 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
196 ///
197 /// let msg = 13;
198 /// let scalar = 5;
199 ///
200 /// let mut ct = cks.encrypt(msg);
201 ///
202 /// // Compute homomorphically a scalar multiplication:
203 /// let ct_res = sks.smart_small_scalar_mul(&mut ct, scalar);
204 ///
205 /// // Decrypt:
206 /// let clear = cks.decrypt(&ct_res);
207 /// assert_eq!(msg * scalar % modulus, clear);
208 /// ```
209 pub fn smart_small_scalar_mul(
210 &self,
211 ctxt: &mut RadixCiphertext,
212 scalar: u64,
213 ) -> RadixCiphertext {
214 if !self.is_small_scalar_mul_possible(ctxt, scalar) {
215 self.full_propagate(ctxt);
216 }
217 self.unchecked_small_scalar_mul(ctxt, scalar)
218 }
219
220 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
221 ///
222 /// `small` means the scalar shall value fit in a __shortint block__.
223 /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2,
224 /// the scalar should fit in 2 bits.
225 ///
226 /// The result is assigned to the input ciphertext
227 ///
228 /// # Example
229 ///
230 /// ```rust
231 /// use concrete_integer::gen_keys_radix;
232 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
233 ///
234 /// // We have 4 * 2 = 8 bits of message
235 /// let modulus = 1 << 8;
236 /// let size = 4;
237 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
238 ///
239 /// let msg = 9;
240 /// let scalar = 3;
241 ///
242 /// let mut ct = cks.encrypt(msg);
243 ///
244 /// // Compute homomorphically a scalar multiplication:
245 /// sks.smart_small_scalar_mul_assign(&mut ct, scalar);
246 ///
247 /// // Decrypt:
248 /// let clear = cks.decrypt(&ct);
249 /// assert_eq!(msg * scalar % modulus, clear);
250 /// ```
251 pub fn smart_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
252 if !self.is_small_scalar_mul_possible(ctxt, scalar) {
253 self.full_propagate(ctxt);
254 }
255 self.unchecked_small_scalar_mul_assign(ctxt, scalar);
256 }
257
258 /// # Example
259 ///
260 /// ```rust
261 /// use concrete_integer::gen_keys_radix;
262 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
263 ///
264 /// // We have 4 * 2 = 8 bits of message
265 /// let size = 4;
266 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
267 ///
268 /// let msg = 1;
269 /// let power = 2;
270 ///
271 /// let ct = cks.encrypt(msg);
272 ///
273 /// // Compute homomorphically a scalar multiplication:
274 /// let ct_res = sks.blockshift(&ct, power);
275 ///
276 /// // Decrypt:
277 /// let clear = cks.decrypt(&ct_res);
278 /// assert_eq!(16, clear);
279 /// ```
280 pub fn blockshift(&self, ctxt: &RadixCiphertext, shift: usize) -> RadixCiphertext {
281 let ctxt_zero = self.key.create_trivial(0_u8);
282 let mut result = ctxt.clone();
283
284 for res_i in result.blocks[..shift].iter_mut() {
285 *res_i = ctxt_zero.clone();
286 }
287
288 for (res_i, c_i) in result.blocks[shift..].iter_mut().zip(ctxt.blocks.iter()) {
289 *res_i = c_i.clone();
290 }
291 result
292 }
293
294 /// Computes homomorphically a multiplication between a scalar and a ciphertext.
295 ///
296 ///
297 /// # Example
298 ///
299 /// ```rust
300 /// use concrete_integer::gen_keys_radix;
301 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
302 ///
303 /// // We have 4 * 2 = 8 bits of message
304 /// let modulus = 1 << 8;
305 /// let size = 4;
306 /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
307 ///
308 /// let msg = 230;
309 /// let scalar = 376;
310 ///
311 /// let mut ct = cks.encrypt(msg);
312 ///
313 /// // Compute homomorphically a scalar multiplication:
314 /// let ct_res = sks.smart_scalar_mul(&mut ct, scalar);
315 ///
316 /// // Decrypt:
317 /// let clear = cks.decrypt(&ct_res);
318 /// assert_eq!(msg * scalar % modulus, clear);
319 /// ```
320 pub fn smart_scalar_mul(&self, ctxt: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext {
321 let mask = (self.key.message_modulus.0 - 1) as u64;
322
323 //Propagate the carries before doing the multiplications
324 self.full_propagate(ctxt);
325
326 //Store the computations
327 let mut map: BTreeMap<u64, RadixCiphertext> = BTreeMap::new();
328
329 let mut result = self.create_trivial_zero_radix(ctxt.blocks.len());
330
331 let mut tmp;
332
333 let mut b_i = 1_u64;
334 for i in 0..ctxt.blocks.len() {
335 //lambda = sum u_ib^i
336 let u_ib_i = scalar & (mask * b_i);
337 let u_i = u_ib_i / b_i;
338
339 if u_i == 0 {
340 //update the power b^{i+1}
341 b_i *= self.key.message_modulus.0 as u64;
342 continue;
343 } else if u_i == 1 {
344 // tmp = ctxt * 1 * b^i
345 tmp = self.blockshift(ctxt, i);
346 } else {
347 tmp = map
348 .entry(u_i)
349 .or_insert_with(|| self.smart_small_scalar_mul(ctxt, u_i))
350 .clone();
351
352 //tmp = ctxt* u_i * b^i
353 tmp = self.blockshift(&tmp, i);
354 }
355
356 //update the result
357 result = self.smart_add(&mut result, &mut tmp);
358
359 //update the power b^{i+1}
360 b_i *= self.key.message_modulus.0 as u64;
361 }
362
363 result
364 }
365
366 pub fn smart_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
367 *ctxt = self.smart_scalar_mul(ctxt, scalar);
368 }
369}