// concrete_integer/server_key/radix/scalar_mul.rs

1use crate::ciphertext::RadixCiphertext;
2use crate::server_key::CheckError;
3use crate::server_key::CheckError::CarryFull;
4use crate::ServerKey;
5use std::collections::BTreeMap;
6
7impl ServerKey {
8    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
9    ///
10    /// This function computes the operation without checking if it exceeds the capacity of the
11    /// ciphertext.
12    ///
13    /// The result is returned as a new ciphertext.
14    ///
15    /// # Example
16    ///
17    /// ```rust
18    /// use concrete_integer::gen_keys_radix;
19    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
20    ///
21    /// // We have 4 * 2 = 8 bits of message
22    /// let size = 4;
23    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
24    ///
25    /// let msg = 30;
26    /// let scalar = 3;
27    ///
28    /// let ct = cks.encrypt(msg);
29    ///
30    /// // Compute homomorphically a scalar multiplication:
31    /// let ct_res = sks.unchecked_small_scalar_mul(&ct, scalar);
32    ///
33    /// let clear = cks.decrypt(&ct_res);
34    /// assert_eq!(scalar * msg, clear);
35    /// ```
36    pub fn unchecked_small_scalar_mul(
37        &self,
38        ctxt: &RadixCiphertext,
39        scalar: u64,
40    ) -> RadixCiphertext {
41        let mut ct_result = ctxt.clone();
42        self.unchecked_small_scalar_mul_assign(&mut ct_result, scalar);
43
44        ct_result
45    }
46
47    pub fn unchecked_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
48        for ct_i in ctxt.blocks.iter_mut() {
49            self.key.unchecked_scalar_mul_assign(ct_i, scalar as u8);
50        }
51    }
52
53    ///Verifies if ct1 can be multiplied by scalar.
54    ///
55    /// # Example
56    ///
57    ///```rust
58    /// use concrete_integer::gen_keys_radix;
59    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
60    ///
61    /// // We have 4 * 2 = 8 bits of message
62    /// let size = 4;
63    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
64    ///
65    /// let msg = 25;
66    /// let scalar1 = 3;
67    ///
68    /// let ct = cks.encrypt(msg);
69    ///
70    /// // Verification if the scalar multiplication can be computed:
71    /// let res = sks.is_small_scalar_mul_possible(&ct, scalar1);
72    ///
73    /// assert_eq!(true, res);
74    ///
75    /// let scalar2 = 7;
76    /// // Verification if the scalar multiplication can be computed:
77    /// let res = sks.is_small_scalar_mul_possible(&ct, scalar2);
78    /// assert_eq!(false, res);
79    /// ```
80    pub fn is_small_scalar_mul_possible(&self, ctxt: &RadixCiphertext, scalar: u64) -> bool {
81        for ct_i in ctxt.blocks.iter() {
82            if !self.key.is_scalar_mul_possible(ct_i, scalar as u8) {
83                return false;
84            }
85        }
86        true
87    }
88
89    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
90    ///
91    /// If the operation can be performed, the result is returned in a new ciphertext.
92    /// Otherwise [CheckError::CarryFull] is returned.
93    ///
94    /// # Example
95    ///
96    /// ```rust
97    /// use concrete_integer::gen_keys_radix;
98    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
99    ///
100    /// // We have 4 * 2 = 8 bits of message
101    /// let size = 4;
102    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
103    ///
104    /// let msg = 33;
105    /// let scalar = 3;
106    ///
107    /// let ct = cks.encrypt(msg);
108    ///
109    /// // Compute homomorphically a scalar multiplication:
110    /// let ct_res = sks.checked_small_scalar_mul(&ct, scalar);
111    ///
112    /// match ct_res {
113    ///     Err(x) => panic!("{:?}", x),
114    ///     Ok(y) => {
115    ///         let clear = cks.decrypt(&y);
116    ///         assert_eq!(msg * scalar, clear);
117    ///     }
118    /// }
119    /// ```
120    pub fn checked_small_scalar_mul(
121        &self,
122        ct: &RadixCiphertext,
123        scalar: u64,
124    ) -> Result<RadixCiphertext, CheckError> {
125        let mut ct_result = ct.clone();
126
127        // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext
128        if self.is_small_scalar_mul_possible(ct, scalar) {
129            ct_result = self.unchecked_small_scalar_mul(&ct_result, scalar);
130
131            Ok(ct_result)
132        } else {
133            Err(CarryFull)
134        }
135    }
136
137    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
138    ///
139    /// If the operation can be performed, the result is assigned to the ciphertext given
140    /// as parameter.
141    /// Otherwise [CheckError::CarryFull] is returned.
142    ///
143    /// # Example
144    ///
145    /// ```rust
146    /// use concrete_integer::gen_keys_radix;
147    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
148    ///
149    /// // We have 4 * 2 = 8 bits of message
150    /// let size = 4;
151    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
152    ///
153    /// let msg = 33;
154    /// let scalar = 3;
155    ///
156    /// let mut ct = cks.encrypt(msg);
157    ///
158    /// // Compute homomorphically a scalar multiplication:
159    /// sks.checked_small_scalar_mul_assign(&mut ct, scalar);
160    ///
161    /// let clear_res = cks.decrypt(&ct);
162    /// assert_eq!(clear_res, msg * scalar);
163    /// ```
164    pub fn checked_small_scalar_mul_assign(
165        &self,
166        ct: &mut RadixCiphertext,
167        scalar: u64,
168    ) -> Result<(), CheckError> {
169        // If the ciphertext cannot be multiplied without exceeding the capacity of a ciphertext
170        if self.is_small_scalar_mul_possible(ct, scalar) {
171            self.unchecked_small_scalar_mul_assign(ct, scalar);
172            Ok(())
173        } else {
174            Err(CarryFull)
175        }
176    }
177
178    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
179    ///
180    /// `small` means the scalar value shall fit in a __shortint block__.
181    /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2,
182    /// the scalar should fit in 2 bits.
183    ///
184    /// The result is returned as a new ciphertext.
185    ///
186    /// # Example
187    ///
188    /// ```rust
189    /// use concrete_integer::gen_keys_radix;
190    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
191    ///
192    /// // We have 4 * 2 = 8 bits of message
193    /// let modulus = 1 << 8;
194    /// let size = 4;
195    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
196    ///
197    /// let msg = 13;
198    /// let scalar = 5;
199    ///
200    /// let mut ct = cks.encrypt(msg);
201    ///
202    /// // Compute homomorphically a scalar multiplication:
203    /// let ct_res = sks.smart_small_scalar_mul(&mut ct, scalar);
204    ///
205    /// // Decrypt:
206    /// let clear = cks.decrypt(&ct_res);
207    /// assert_eq!(msg * scalar % modulus, clear);
208    /// ```
209    pub fn smart_small_scalar_mul(
210        &self,
211        ctxt: &mut RadixCiphertext,
212        scalar: u64,
213    ) -> RadixCiphertext {
214        if !self.is_small_scalar_mul_possible(ctxt, scalar) {
215            self.full_propagate(ctxt);
216        }
217        self.unchecked_small_scalar_mul(ctxt, scalar)
218    }
219
220    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
221    ///
222    /// `small` means the scalar shall value fit in a __shortint block__.
223    /// For example, if the parameters are PARAM_MESSAGE_2_CARRY_2,
224    /// the scalar should fit in 2 bits.
225    ///
226    /// The result is assigned to the input ciphertext
227    ///
228    /// # Example
229    ///
230    /// ```rust
231    /// use concrete_integer::gen_keys_radix;
232    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
233    ///
234    /// // We have 4 * 2 = 8 bits of message
235    /// let modulus = 1 << 8;
236    /// let size = 4;
237    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
238    ///
239    /// let msg = 9;
240    /// let scalar = 3;
241    ///
242    /// let mut ct = cks.encrypt(msg);
243    ///
244    /// // Compute homomorphically a scalar multiplication:
245    /// sks.smart_small_scalar_mul_assign(&mut ct, scalar);
246    ///
247    /// // Decrypt:
248    /// let clear = cks.decrypt(&ct);
249    /// assert_eq!(msg * scalar % modulus, clear);
250    /// ```
251    pub fn smart_small_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
252        if !self.is_small_scalar_mul_possible(ctxt, scalar) {
253            self.full_propagate(ctxt);
254        }
255        self.unchecked_small_scalar_mul_assign(ctxt, scalar);
256    }
257
258    /// # Example
259    ///
260    /// ```rust
261    /// use concrete_integer::gen_keys_radix;
262    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
263    ///
264    /// // We have 4 * 2 = 8 bits of message
265    /// let size = 4;
266    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
267    ///
268    /// let msg = 1;
269    /// let power = 2;
270    ///
271    /// let ct = cks.encrypt(msg);
272    ///
273    /// // Compute homomorphically a scalar multiplication:
274    /// let ct_res = sks.blockshift(&ct, power);
275    ///
276    /// // Decrypt:
277    /// let clear = cks.decrypt(&ct_res);
278    /// assert_eq!(16, clear);
279    /// ```
280    pub fn blockshift(&self, ctxt: &RadixCiphertext, shift: usize) -> RadixCiphertext {
281        let ctxt_zero = self.key.create_trivial(0_u8);
282        let mut result = ctxt.clone();
283
284        for res_i in result.blocks[..shift].iter_mut() {
285            *res_i = ctxt_zero.clone();
286        }
287
288        for (res_i, c_i) in result.blocks[shift..].iter_mut().zip(ctxt.blocks.iter()) {
289            *res_i = c_i.clone();
290        }
291        result
292    }
293
294    /// Computes homomorphically a multiplication between a scalar and a ciphertext.
295    ///
296    ///
297    /// # Example
298    ///
299    /// ```rust
300    /// use concrete_integer::gen_keys_radix;
301    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
302    ///
303    /// // We have 4 * 2 = 8 bits of message
304    /// let modulus = 1 << 8;
305    /// let size = 4;
306    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
307    ///
308    /// let msg = 230;
309    /// let scalar = 376;
310    ///
311    /// let mut ct = cks.encrypt(msg);
312    ///
313    /// // Compute homomorphically a scalar multiplication:
314    /// let ct_res = sks.smart_scalar_mul(&mut ct, scalar);
315    ///
316    /// // Decrypt:
317    /// let clear = cks.decrypt(&ct_res);
318    /// assert_eq!(msg * scalar % modulus, clear);
319    /// ```
320    pub fn smart_scalar_mul(&self, ctxt: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext {
321        let mask = (self.key.message_modulus.0 - 1) as u64;
322
323        //Propagate the carries before doing the multiplications
324        self.full_propagate(ctxt);
325
326        //Store the computations
327        let mut map: BTreeMap<u64, RadixCiphertext> = BTreeMap::new();
328
329        let mut result = self.create_trivial_zero_radix(ctxt.blocks.len());
330
331        let mut tmp;
332
333        let mut b_i = 1_u64;
334        for i in 0..ctxt.blocks.len() {
335            //lambda = sum u_ib^i
336            let u_ib_i = scalar & (mask * b_i);
337            let u_i = u_ib_i / b_i;
338
339            if u_i == 0 {
340                //update the power b^{i+1}
341                b_i *= self.key.message_modulus.0 as u64;
342                continue;
343            } else if u_i == 1 {
344                // tmp = ctxt * 1 * b^i
345                tmp = self.blockshift(ctxt, i);
346            } else {
347                tmp = map
348                    .entry(u_i)
349                    .or_insert_with(|| self.smart_small_scalar_mul(ctxt, u_i))
350                    .clone();
351
352                //tmp = ctxt* u_i * b^i
353                tmp = self.blockshift(&tmp, i);
354            }
355
356            //update the result
357            result = self.smart_add(&mut result, &mut tmp);
358
359            //update the power b^{i+1}
360            b_i *= self.key.message_modulus.0 as u64;
361        }
362
363        result
364    }
365
366    pub fn smart_scalar_mul_assign(&self, ctxt: &mut RadixCiphertext, scalar: u64) {
367        *ctxt = self.smart_scalar_mul(ctxt, scalar);
368    }
369}