fast_paillier/
decryption_key.rs

1use rand_core::{CryptoRng, RngCore};
2
3use crate::backend::Integer;
4use crate::{utils, Ciphertext, EncryptionKey, Nonce, Plaintext};
5use crate::{Error, Reason};
6
7/// Paillier decryption key
8#[derive(Clone)]
9pub struct DecryptionKey {
10    ek: EncryptionKey,
11    /// `lcm(p-1, q-1)`
12    lambda: Integer,
13    /// `lambda^-1 mod N`
14    mu: Integer,
15
16    p: Integer,
17    q: Integer,
18
19    crt_mod_nn: utils::CrtExp,
20    /// Calculates `x ^ N mod N^2`. It's used for faster encryption
21    exp_n: utils::Exponent,
22    /// Calculates `x ^ lambda mod N^2`. It's used for faster decryption
23    exp_lambda: utils::Exponent,
24}
25
26impl DecryptionKey {
27    /// Generates a paillier key
28    ///
29    /// Samples two safe 1536-bits primes that meets 128 bits security level
30    pub fn generate(rng: &mut (impl RngCore + CryptoRng)) -> Result<Self, Error> {
31        let p = Integer::generate_safe_prime(rng, 1536);
32        let q = Integer::generate_safe_prime(rng, 1536);
33        Self::from_primes(p, q)
34    }
35
36    /// Constructs a paillier key from primes `p`, `q`
37    ///
38    /// `p` and `q` need to be safe primes sufficiently large to meet security level requirements.
39    ///
40    /// Returns error if `p` and `q` do not correspond to a valid paillier key.
41    #[allow(clippy::many_single_char_names)]
42    pub fn from_primes(p: Integer, q: Integer) -> Result<Self, Error> {
43        // Paillier doesn't work if p == q
44        if p == q {
45            return Err(Reason::InvalidPQ.into());
46        }
47        let pm1 = &p - 1u8;
48        let qm1 = &q - 1u8;
49        let ek = EncryptionKey::from_n(&p * &q);
50        let lambda = pm1.lcm_ref(&qm1);
51        if lambda.cmp0().is_eq() {
52            return Err(Reason::InvalidPQ.into());
53        }
54
55        // u = lambda^-1 mod N
56        let u = lambda.invert_ref(ek.n()).ok_or(Reason::InvalidPQ)?;
57
58        let crt_mod_nn = utils::CrtExp::build_nn(&p, &q).ok_or(Reason::BuildFastExp)?;
59        let exp_n = crt_mod_nn.prepare_exponent(ek.n());
60        let exp_lambda = crt_mod_nn.prepare_exponent(&lambda);
61
62        Ok(Self {
63            ek,
64            lambda,
65            mu: u,
66            p,
67            q,
68            crt_mod_nn,
69            exp_n,
70            exp_lambda,
71        })
72    }
73
74    /// Decrypts the ciphertext, returns plaintext in `{-N/2, .., N_2}`
75    pub fn decrypt(&self, c: &Ciphertext) -> Result<Plaintext, Error> {
76        if !c.in_mult_group_of(self.ek.nn()) {
77            return Err(Reason::Decrypt.into());
78        }
79
80        // a = c^\lambda mod n^2
81        let a = self
82            .crt_mod_nn
83            .exp(c, &self.exp_lambda)
84            .ok_or(Reason::Decrypt)?;
85
86        // ell = L(a, N)
87        let l = self.ek.l(&a).ok_or(Reason::Decrypt)?;
88
89        // m = lu = L(a)*u = L(c^\lamba*)u mod n
90        let plaintext = (l * &self.mu) % self.ek.n();
91
92        if (&plaintext << 1) >= *self.n() {
93            Ok(plaintext - self.n())
94        } else {
95            Ok(plaintext)
96        }
97    }
98
99    /// Encrypts a plaintext `x` in `{-N/2, .., N/2}` with `nonce` from `Z*_n`
100    ///
101    /// It uses the fact that factorization of `N` is known to speed up encryption.
102    ///
103    /// Returns error if inputs are not in specified range
104    pub fn encrypt_with(&self, x: &Plaintext, nonce: &Nonce) -> Result<Ciphertext, Error> {
105        if !self.ek.in_signed_group(x) || !nonce.in_mult_group_of(self.n()) {
106            return Err(Reason::Encrypt.into());
107        }
108
109        let x = if x.cmp0().is_ge() {
110            x.clone()
111        } else {
112            x + self.n()
113        };
114
115        // a = (1 + N)^x mod N^2 = (1 + xN) mod N^2
116        let a = (Integer::one() + x * self.ek.n()) % self.ek.nn();
117        // b = nonce^N mod N^2
118        let b = self
119            .crt_mod_nn
120            .exp(nonce, &self.exp_n)
121            .ok_or(Reason::Encrypt)?;
122
123        Ok((a * b) % self.ek.nn())
124    }
125
126    /// Encrypts the plaintext `x` in `{-N/2, .., N_2}`
127    ///
128    /// It's uses the fact that factorization of `N` is known to speed up encryption.
129    ///
130    /// Nonce is sampled randomly using `rng`.
131    ///
132    /// Returns error if plaintext is not in specified range
133    pub fn encrypt_with_random(
134        &self,
135        rng: &mut (impl RngCore + CryptoRng),
136        x: &Plaintext,
137    ) -> Result<(Ciphertext, Nonce), Error> {
138        let nonce = Integer::sample_in_mult_group_of(rng, self.ek.n());
139        let ciphertext = self.encrypt_with(x, &nonce)?;
140        Ok((ciphertext, nonce))
141    }
142
143    /// Homomorphic multiplication of scalar at ciphertext
144    ///
145    /// It uses the fact that factorization of `N` is known to speed up an operation.
146    ///
147    /// ```text
148    /// omul(a, Enc(c)) = Enc(a * c)
149    /// ```
150    pub fn omul(&self, scalar: &Integer, ciphertext: &Ciphertext) -> Result<Ciphertext, Error> {
151        if !scalar.abs_in_mult_group_of(self.n()) || !ciphertext.in_mult_group_of(self.ek.nn()) {
152            return Err(Reason::Ops.into());
153        }
154
155        let e = self.crt_mod_nn.prepare_exponent(scalar);
156        Ok(self.crt_mod_nn.exp(ciphertext, &e).ok_or(Reason::Ops)?)
157    }
158
159    /// Returns a (public) encryption key corresponding to the (secret) decryption key
160    pub fn encryption_key(&self) -> &EncryptionKey {
161        &self.ek
162    }
163
164    /// The Paillier modulus
165    pub fn n(&self) -> &Integer {
166        self.ek.n()
167    }
168
169    /// The Paillier `lambda`
170    pub fn lambda(&self) -> &Integer {
171        &self.lambda
172    }
173
174    /// The Paillier `mu`
175    pub fn mu(&self) -> &Integer {
176        &self.mu
177    }
178
179    /// Prime `p`
180    pub fn p(&self) -> &Integer {
181        &self.p
182    }
183    /// Prime `q`
184    pub fn q(&self) -> &Integer {
185        &self.q
186    }
187
188    /// Bits length of smaller prime (`p` or `q`)
189    pub fn bits_length(&self) -> u64 {
190        self.p.significant_bits().min(self.q.significant_bits())
191    }
192}