Skip to main content

tfhe/shortint/ciphertext/
common.rs

1use super::super::CheckError;
2pub use crate::core_crypto::commons::parameters::PBSOrder;
3use crate::shortint::backward_compatibility::ciphertext::*;
4use crate::shortint::parameters::{CarryModulus, MessageModulus};
5use serde::{Deserialize, Serialize};
6use std::cmp;
7use std::fmt::Debug;
8use tfhe_versionable::Versionize;
9
/// Error for when a non-trivial ciphertext was used where a trivial one was expected.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct NotTrivialCiphertextError;

impl std::fmt::Display for NotTrivialCiphertextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "The ciphertext is not a trivial ciphertext")
    }
}

impl std::error::Error for NotTrivialCiphertextError {}
21
22/// This tracks the maximal amount of noise of a [super::Ciphertext]
23/// that guarantees the target p-error when doing a PBS on it
24#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
25#[versionize(MaxNoiseLevelVersions)]
26pub struct MaxNoiseLevel(u64);
27
28impl MaxNoiseLevel {
29    pub(crate) const UNKNOWN: Self = Self(u64::MAX);
30
31    pub const fn new(value: u64) -> Self {
32        Self(value)
33    }
34
35    pub const fn get(&self) -> u64 {
36        self.0
37    }
38
39    // This function is valid for current parameters as they guarantee the p-error for a norm2 noise
40    // limit equal to the norm2 limit which guarantees a clean padding bit
41    //
42    // TODO: remove this functions once noise norm2 constraint is decorrelated and stored in
43    // parameter sets
44    pub const fn from_msg_carry_modulus(
45        msg_modulus: MessageModulus,
46        carry_modulus: CarryModulus,
47    ) -> Self {
48        let level = (carry_modulus.0 * msg_modulus.0 - 1) / (msg_modulus.0 - 1);
49        Self(level)
50    }
51
52    pub const fn validate(&self, noise_level: NoiseLevel) -> Result<(), CheckError> {
53        if noise_level.0 > self.0 {
54            return Err(CheckError::NoiseTooBig {
55                noise_level,
56                max_noise_level: *self,
57            });
58        }
59        Ok(())
60    }
61}
62
63/// This tracks the amount of noise in a ciphertext.
64#[derive(
65    Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize,
66)]
67#[versionize(NoiseLevelVersions)]
68pub struct NoiseLevel(pub(crate) u64);
69
70impl NoiseLevel {
71    pub const NOMINAL: Self = Self(1);
72    pub const ZERO: Self = Self(0);
73    // As a safety measure the unknown noise level is set to the max value
74    pub const UNKNOWN: Self = Self(u64::MAX);
75}
76
77impl NoiseLevel {
78    pub fn get(&self) -> u64 {
79        self.0
80    }
81}
82
83impl std::ops::AddAssign for NoiseLevel {
84    fn add_assign(&mut self, rhs: Self) {
85        self.0 = self.0.saturating_add(rhs.0);
86    }
87}
88
89impl std::ops::Add for NoiseLevel {
90    type Output = Self;
91
92    fn add(mut self, rhs: Self) -> Self {
93        self += rhs;
94        self
95    }
96}
97
98impl std::ops::MulAssign<u64> for NoiseLevel {
99    fn mul_assign(&mut self, rhs: u64) {
100        self.0 = self.0.saturating_mul(rhs);
101    }
102}
103
104impl std::ops::Mul<u64> for NoiseLevel {
105    type Output = Self;
106
107    fn mul(mut self, rhs: u64) -> Self::Output {
108        self *= rhs;
109
110        self
111    }
112}
113
114/// Maximum value that the degree can reach.
115#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
116#[versionize(MaxDegreeVersions)]
117pub struct MaxDegree(pub(crate) u64);
118
119impl MaxDegree {
120    pub fn new(value: u64) -> Self {
121        Self(value)
122    }
123
124    pub fn get(&self) -> u64 {
125        self.0
126    }
127
128    pub fn from_msg_carry_modulus(
129        msg_modulus: MessageModulus,
130        carry_modulus: CarryModulus,
131    ) -> Self {
132        Self(carry_modulus.0 * msg_modulus.0 - 1)
133    }
134
135    pub fn validate(&self, degree: Degree) -> Result<(), CheckError> {
136        if degree.get() > self.0 {
137            return Err(CheckError::CarryFull {
138                degree,
139                max_degree: *self,
140            });
141        }
142        Ok(())
143    }
144}
145
146/// The maximum value a given ciphertext can have. This helps with optimizations.
147#[derive(
148    Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize,
149)]
150#[versionize(DegreeVersions)]
151pub struct Degree(pub(crate) u64);
152
153impl Degree {
154    pub fn new(degree: u64) -> Self {
155        Self(degree)
156    }
157
158    pub fn get(self) -> u64 {
159        self.0
160    }
161}
162
#[cfg(test)]
impl AsMut<u64> for Degree {
    /// Gives tests direct mutable access to the raw degree value.
    fn as_mut(&mut self) -> &mut u64 {
        let Self(raw) = self;
        raw
    }
}
169
170impl Degree {
171    pub(crate) fn after_bitxor(self, other: Self) -> Self {
172        let max = cmp::max(self.0, other.0);
173        let min = cmp::min(self.0, other.0);
174        let mut result = max;
175
176        //Try every possibility to find the worst case
177        for i in 0..min + 1 {
178            if max ^ i > result {
179                result = max ^ i;
180            }
181        }
182
183        Self(result)
184    }
185
186    pub(crate) fn after_bitor(self, other: Self) -> Self {
187        let max = cmp::max(self.0, other.0);
188        let min = cmp::min(self.0, other.0);
189        let mut result = max;
190
191        for i in 0..min + 1 {
192            if max | i > result {
193                result = max | i;
194            }
195        }
196
197        Self(result)
198    }
199
200    pub(crate) fn after_bitand(self, other: Self) -> Self {
201        Self(cmp::min(self.0, other.0))
202    }
203
204    pub(crate) fn after_left_shift(self, shift: u8, modulus: u64) -> Self {
205        let mut result = 0;
206
207        for i in 0..self.0 + 1 {
208            let tmp = (i << shift) % modulus;
209            if tmp > result {
210                result = tmp;
211            }
212        }
213
214        Self(result)
215    }
216}
217
218impl std::ops::AddAssign for Degree {
219    fn add_assign(&mut self, rhs: Self) {
220        self.0 = self.0.saturating_add(rhs.0);
221    }
222}
223
224impl std::ops::Add for Degree {
225    type Output = Self;
226
227    fn add(mut self, rhs: Self) -> Self {
228        self += rhs;
229        self
230    }
231}
232
233impl std::ops::MulAssign<u64> for Degree {
234    fn mul_assign(&mut self, rhs: u64) {
235        self.0 = self.0.saturating_mul(rhs);
236    }
237}
238
239impl std::ops::Mul<u64> for Degree {
240    type Output = Self;
241
242    fn mul(mut self, rhs: u64) -> Self::Output {
243        self *= rhs;
244
245        self
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `op` to every `(lhs, rhs, expected)` row of `data` and asserts the result,
    /// naming the operation and the offending operands on failure.
    fn check_degree_op(
        op_name: &str,
        op: fn(Degree, Degree) -> Degree,
        data: &[(Degree, Degree, Degree)],
    ) {
        for &(lhs, rhs, expected) in data {
            let result = op(lhs, rhs);
            assert_eq!(
                result, expected,
                "For a {op_name} between variables of degree {lhs:?} and {rhs:?}, \
                 expected resulting degree: {expected:?}, got {result:?}"
            );
        }
    }

    #[test]
    fn test_noise_level_ci_run_filter() {
        use rand::{thread_rng, Rng};

        let mut rng = thread_rng();

        assert_eq!(NoiseLevel::UNKNOWN.0, u64::MAX);

        let max_noise_level = NoiseLevel::UNKNOWN;
        let random_addend = rng.gen::<u64>();
        let add = max_noise_level + NoiseLevel(random_addend);
        assert_eq!(add, NoiseLevel::UNKNOWN);

        let random_positive_multiplier = rng.gen_range(1u64..=u64::MAX);
        let mul = max_noise_level * random_positive_multiplier;
        assert_eq!(mul, NoiseLevel::UNKNOWN);

        let random_noise = NoiseLevel(rng.gen_range(2..=u64::MAX));

        assert!(NoiseLevel::NOMINAL >= NoiseLevel::ZERO);
        assert!(random_noise > NoiseLevel::NOMINAL);
        assert!(random_noise <= NoiseLevel::UNKNOWN);
    }

    #[test]
    fn test_max_noise_level_from_msg_carry_modulus_ci_run_filter() {
        let max_noise_level =
            MaxNoiseLevel::from_msg_carry_modulus(MessageModulus(4), CarryModulus(4));

        assert_eq!(max_noise_level.0, 5);
    }

    #[test]
    fn test_max_noise_level_validate_ci_run_filter() {
        let max_noise_level = MaxNoiseLevel::new(5);

        assert!(max_noise_level.validate(NoiseLevel(5)).is_ok());
        assert!(max_noise_level.validate(NoiseLevel(6)).is_err());
        // u64::MAX is the largest possible value, so nothing exceeds the UNKNOWN maximum.
        assert!(MaxNoiseLevel::UNKNOWN.validate(NoiseLevel::UNKNOWN).is_ok());
    }

    #[test]
    fn test_max_degree_ci_run_filter() {
        let max_degree = MaxDegree::from_msg_carry_modulus(MessageModulus(4), CarryModulus(4));

        assert_eq!(max_degree.get(), 15);
        assert!(max_degree.validate(Degree::new(15)).is_ok());
        assert!(max_degree.validate(Degree::new(16)).is_err());
    }

    // The tables list each operand pair in both orders, since the result must not depend on it.
    #[test]
    fn degree_after_bitxor_ci_run_filter() {
        let data = [
            (Degree(3), Degree(3), Degree(3)),
            (Degree(3), Degree(1), Degree(3)),
            (Degree(1), Degree(3), Degree(3)),
            (Degree(3), Degree(2), Degree(3)),
            (Degree(2), Degree(3), Degree(3)),
            (Degree(2), Degree(2), Degree(3)),
            (Degree(2), Degree(1), Degree(3)),
            (Degree(1), Degree(2), Degree(3)),
            (Degree(1), Degree(1), Degree(1)),
            (Degree(0), Degree(1), Degree(1)),
            (Degree(1), Degree(0), Degree(1)),
            (Degree(0), Degree(0), Degree(0)),
        ];

        check_degree_op("bitxor", Degree::after_bitxor, &data);
    }

    #[test]
    fn degree_after_bitor_ci_run_filter() {
        let data = [
            (Degree(3), Degree(3), Degree(3)),
            (Degree(3), Degree(1), Degree(3)),
            (Degree(1), Degree(3), Degree(3)),
            (Degree(3), Degree(2), Degree(3)),
            (Degree(2), Degree(3), Degree(3)),
            (Degree(2), Degree(2), Degree(3)),
            (Degree(2), Degree(1), Degree(3)),
            (Degree(1), Degree(2), Degree(3)),
            (Degree(1), Degree(1), Degree(1)),
            (Degree(0), Degree(1), Degree(1)),
            (Degree(1), Degree(0), Degree(1)),
            (Degree(0), Degree(0), Degree(0)),
        ];

        check_degree_op("bitor", Degree::after_bitor, &data);
    }

    #[test]
    fn degree_after_bitand_ci_run_filter() {
        let data = [
            (Degree(3), Degree(3), Degree(3)),
            (Degree(3), Degree(1), Degree(1)),
            (Degree(1), Degree(3), Degree(1)),
            (Degree(3), Degree(2), Degree(2)),
            (Degree(2), Degree(3), Degree(2)),
            (Degree(2), Degree(2), Degree(2)),
            (Degree(2), Degree(1), Degree(1)),
            (Degree(1), Degree(2), Degree(1)),
            (Degree(1), Degree(1), Degree(1)),
            (Degree(0), Degree(1), Degree(0)),
            (Degree(1), Degree(0), Degree(0)),
            (Degree(0), Degree(0), Degree(0)),
        ];

        check_degree_op("bitand", Degree::after_bitand, &data);
    }

    #[test]
    fn degree_after_left_shift_ci_run_filter() {
        // (degree, shift, modulus, expected)
        let data = [
            (Degree(0), 3u8, 16u64, Degree(0)),
            (Degree(3), 1, 16, Degree(6)),
            (Degree(3), 2, 8, Degree(4)),
            (Degree(7), 1, 8, Degree(6)),
        ];

        for (degree, shift, modulus, expected) in data {
            let result = degree.after_left_shift(shift, modulus);
            assert_eq!(
                result, expected,
                "For a left shift by {shift} modulo {modulus} of a variable of degree \
                 {degree:?}, expected resulting degree: {expected:?}, got {result:?}"
            );
        }
    }
}