// concrete_integer/server_key/radix/scalar_sub.rs

1use crate::ciphertext::RadixCiphertext;
2use crate::server_key::CheckError;
3use crate::server_key::CheckError::CarryFull;
4use crate::ServerKey;
5
6impl ServerKey {
7    /// Computes homomorphically a subtraction between a ciphertext and a scalar.
8    ///
9    /// This function computes the operation without checking if it exceeds the capacity of the
10    /// ciphertext.
11    ///
12    /// The result is returned as a new ciphertext.
13    ///
14    /// # Example
15    ///
16    /// ```rust
17    /// use concrete_integer::gen_keys_radix;
18    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
19    ///
20    /// // We have 4 * 2 = 8 bits of message
21    /// let num_blocks = 4;
22    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
23    ///
24    /// let msg = 40;
25    /// let scalar = 3;
26    ///
27    /// let ct = cks.encrypt(msg);
28    ///
29    /// // Compute homomorphically an addition:
30    /// let ct_res = sks.unchecked_scalar_sub(&ct, scalar);
31    ///
32    /// // Decrypt:
33    /// let dec = cks.decrypt(&ct_res);
34    /// assert_eq!(msg - scalar, dec);
35    /// ```
36    pub fn unchecked_scalar_sub(&self, ct: &RadixCiphertext, scalar: u64) -> RadixCiphertext {
37        let mut result = ct.clone();
38        self.unchecked_scalar_sub_assign(&mut result, scalar);
39        result
40    }
41
42    pub fn unchecked_scalar_sub_assign(&self, ct: &mut RadixCiphertext, scalar: u64) {
43        //Bits of message put to 1
44        let mask = (self.key.message_modulus.0 - 1) as u64;
45
46        let modulus = self.key.message_modulus.0.pow(ct.blocks.len() as u32) as u64;
47
48        let neg_scalar = scalar.wrapping_neg() % modulus;
49
50        let mut power = 1_u64;
51        //Put each decomposition into a new ciphertext
52        for ct_i in ct.blocks.iter_mut() {
53            let mut decomp = neg_scalar & (mask * power);
54            decomp /= power;
55
56            self.key.unchecked_scalar_add_assign(ct_i, decomp as u8);
57
58            //modulus to the power i
59            power *= self.key.message_modulus.0 as u64;
60        }
61    }
62
63    /// Verifies if the subtraction of a ciphertext by scalar can be computed.
64    ///
65    /// # Example
66    ///
67    ///```rust
68    /// use concrete_integer::gen_keys_radix;
69    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
70    ///
71    /// // We have 4 * 2 = 8 bits of message
72    /// let num_blocks = 4;
73    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
74    ///
75    /// let msg = 40;
76    /// let scalar = 2;
77    ///
78    /// let ct1 = cks.encrypt(msg);
79    ///
80    /// // Check if we can perform an addition
81    /// let res = sks.is_scalar_sub_possible(&ct1, scalar);
82    ///
83    /// assert_eq!(true, res);
84    /// ```
85    pub fn is_scalar_sub_possible(&self, ct: &RadixCiphertext, scalar: u64) -> bool {
86        //Bits of message put to 1
87        let mask = (self.key.message_modulus.0 - 1) as u64;
88
89        let modulus = self.key.message_modulus.0.pow(ct.blocks.len() as u32) as u64;
90
91        let neg_scalar = scalar.wrapping_neg() % modulus;
92
93        let mut power = 1_u64;
94
95        for ct_i in ct.blocks.iter() {
96            let mut decomp = neg_scalar & (mask * power);
97            decomp /= power;
98
99            if !self.key.is_scalar_add_possible(ct_i, decomp as u8) {
100                return false;
101            }
102
103            //modulus to the power i
104            power *= self.key.message_modulus.0 as u64;
105        }
106        true
107    }
108
109    /// Computes homomorphically a subtraction of a ciphertext by a scalar.
110    ///
111    /// If the operation can be performed, the result is returned in a new ciphertext.
112    /// Otherwise [CheckError::CarryFull] is returned.
113    ///
114    /// # Example
115    ///
116    /// ```rust
117    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
118    /// use concrete_integer::gen_keys_radix;
119    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
120    ///
121    /// // We have 4 * 2 = 8 bits of message
122    /// let num_blocks = 4;
123    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
124    ///
125    /// let msg = 40;
126    /// let scalar = 4;
127    ///
128    /// let ct = cks.encrypt(msg);
129    ///
130    /// // Compute tne subtraction:
131    /// let ct_res = sks.checked_scalar_sub(&ct, scalar)?;
132    ///
133    /// // Decrypt:
134    /// let dec = cks.decrypt(&ct_res);
135    /// assert_eq!(msg - scalar, dec);
136    /// # Ok(())
137    /// # }
138    /// ```
139    pub fn checked_scalar_sub(
140        &self,
141        ct: &RadixCiphertext,
142        scalar: u64,
143    ) -> Result<RadixCiphertext, CheckError> {
144        if self.is_scalar_sub_possible(ct, scalar) {
145            Ok(self.unchecked_scalar_sub(ct, scalar))
146        } else {
147            Err(CarryFull)
148        }
149    }
150
151    /// Computes homomorphically a subtraction of a ciphertext by a scalar.
152    ///
153    /// If the operation can be performed, the result is returned in a new ciphertext.
154    /// Otherwise [CheckError::CarryFull] is returned.
155    ///
156    /// # Example
157    ///
158    /// ```rust
159    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
160    /// use concrete_integer::gen_keys_radix;
161    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
162    ///
163    /// // We have 4 * 2 = 8 bits of message
164    /// let num_blocks = 4;
165    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
166    ///
167    /// let msg = 232;
168    /// let scalar = 83;
169    ///
170    /// let mut ct = cks.encrypt(msg);
171    ///
172    /// // Compute tne subtraction:
173    /// sks.checked_scalar_sub_assign(&mut ct, scalar)?;
174    ///
175    /// // Decrypt:
176    /// let dec = cks.decrypt(&ct);
177    /// assert_eq!(msg - scalar, dec);
178    /// # Ok(())
179    /// # }
180    /// ```
181    pub fn checked_scalar_sub_assign(
182        &self,
183        ct: &mut RadixCiphertext,
184        scalar: u64,
185    ) -> Result<(), CheckError> {
186        if self.is_scalar_sub_possible(ct, scalar) {
187            self.unchecked_scalar_sub_assign(ct, scalar);
188            Ok(())
189        } else {
190            Err(CarryFull)
191        }
192    }
193
194    /// Computes homomorphically a subtraction of a ciphertext by a scalar.
195    ///
196    /// # Example
197    ///
198    /// ```rust
199    /// use concrete_integer::gen_keys_radix;
200    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
201    ///
202    /// // We have 4 * 2 = 8 bits of message
203    /// let num_blocks = 4;
204    /// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_blocks);
205    ///
206    /// let msg = 165;
207    /// let scalar = 112;
208    ///
209    /// let mut ct = cks.encrypt(msg);
210    ///
211    /// // Compute homomorphically an addition:
212    /// let ct_res = sks.smart_scalar_sub(&mut ct, scalar);
213    ///
214    /// // Decrypt:
215    /// let dec = cks.decrypt(&ct_res);
216    /// assert_eq!(msg - scalar, dec);
217    /// ```
218    pub fn smart_scalar_sub(&self, ct: &mut RadixCiphertext, scalar: u64) -> RadixCiphertext {
219        if !self.is_scalar_sub_possible(ct, scalar) {
220            self.full_propagate(ct);
221        }
222
223        self.unchecked_scalar_sub(ct, scalar)
224    }
225
226    pub fn smart_scalar_sub_assign(&self, ct: &mut RadixCiphertext, scalar: u64) {
227        if !self.is_scalar_sub_possible(ct, scalar) {
228            self.full_propagate(ct);
229        }
230
231        self.unchecked_scalar_sub_assign(ct, scalar);
232    }
233}