Skip to main content

sl_paillier/
lib.rs

1// Copyright (c) Silence Laboratories Pte. Ltd. All Rights Reserved.
2// This software is licensed under the Silence Laboratories License Agreement.
3
4use std::ops::Deref;
5
6use crypto_bigint::modular::{MontyForm, MontyParams};
7use crypto_bigint::{
8    BoxedUint, Concat, Encoding, NonZero, RandomMod, Split, Uint,
9};
10use crypto_bigint::{U1024, U2048, U4096};
11
12use crypto_primes::generate_prime_with_rng;
13
14use rand_core::CryptoRngCore;
15
16#[cfg(feature = "serde")]
17use crypto_bigint::Bounded;
18
// print-type-size type: `SK<64, 32, 16>`: 5400 bytes, alignment: 8 bytes
// print-type-size     field `.phi`: 256 bytes
// print-type-size     field `.inv_phi`: 256 bytes
// print-type-size     field `.p`: 128 bytes
// print-type-size     field `.hp`: 128 bytes
// print-type-size     field `.q`: 128 bytes
// print-type-size     field `.hq`: 128 bytes
// print-type-size     field `.pk`: 2312 bytes
// print-type-size     field `.pp_params`: 1032 bytes
// print-type-size     field `.qq_params`: 1032 bytes
//
// NOTE: this snapshot predates the `pinv_q` field (a `Uint<16>`, 128 bytes),
// so the current size of `SK<64, 32, 16>` is 5400 + 128 = 5528 bytes.
29
30pub type SK2048 = SK<{ U4096::LIMBS }, { U2048::LIMBS }, { U1024::LIMBS }>;
31pub type PK2048 = PK<{ U4096::LIMBS }, { U2048::LIMBS }>;
32
33#[cfg(feature = "serde")]
34pub type MinimalSK2048 = MinimalSK<{ U1024::LIMBS }>;
35#[cfg(feature = "serde")]
36pub type MinimalPK2048 = MinimalPK<{ U4096::LIMBS }, { U2048::LIMBS }>;
37
38#[cfg(feature = "serde")]
39use serde::{Deserialize, Serialize};
40
/// Serializable form of a secret key: only the primes `p` and `q`.
///
/// Every other secret-key value is recomputed from these by `SK::from_pq`.
/// NOTE(review): the derived `Debug` prints both primes in the clear.
#[cfg(feature = "serde")]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MinimalSK<const P: usize>
where
    Uint<P>: Bounded + Encoding,
{
    /// First prime factor of `N`.
    pub p: Uint<P>,
    /// Second prime factor of `N`.
    pub q: Uint<P>,
}
51
/// Serializable form of a public key: only the modulus `N`.
///
/// `NN` is the limb width of `N^2`. It is not stored; it only fixes the type
/// that `N^2` is recomputed into.
#[cfg(feature = "serde")]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MinimalPK<const NN: usize, const N: usize>
where
    Uint<N>: Bounded + Encoding,
{
    /// The public modulus `N`.
    pub n: NonZero<Uint<N>>,
}
61
#[cfg(feature = "serde")]
impl<const NN: usize, const N: usize> MinimalPK<NN, N>
where
    Uint<N>: Bounded + Encoding,
    Uint<NN>: From<(Uint<N>, Uint<N>)>,
{
    /// Returns `N^2`, the ciphertext modulus, as an `NN`-limb integer.
    ///
    /// Both halves of the wide square are converted, so no bits are dropped.
    pub fn compute_nn(&self) -> Uint<NN> {
        let wide_square = self.n.square_wide();
        Uint::<NN>::from(wide_square)
    }
}
72
#[cfg(feature = "serde")]
impl<const NN: usize, const N: usize> From<MinimalPK<NN, N>> for PK<NN, N>
where
    Uint<N>: Bounded + Encoding,
    Uint<NN>: Bounded + Encoding,
    Uint<NN>: From<(Uint<N>, Uint<N>)>,
{
    /// Rebuilds a full public key by recomputing the Montgomery parameters for
    /// `N^2`.
    ///
    /// # Panics
    ///
    /// Panics if `N^2` is even (i.e. `N` is even), since it must be an odd
    /// Montgomery modulus.
    fn from(value: MinimalPK<NN, N>) -> Self {
        let odd_nn = value.compute_nn().to_odd().unwrap();
        Self {
            n: value.n,
            params: MontyParams::<NN>::new_vartime(odd_nn),
        }
    }
}
88
/// Serde support for the concrete 2048-bit key types and 4096-bit ciphertexts.
///
/// Keys are (de)serialized through their minimal forms. Derived values are
/// recomputed on load, after cheap structural validation that turns inputs
/// which would otherwise panic into deserialization errors.
#[cfg(feature = "serde")]
mod serialize {
    use super::*;
    use serde::{Deserialize, Deserializer, Serialize};

    /// Returns `true` when the least-significant bit of `x` is clear
    /// (this includes zero).
    fn is_even<const L: usize>(x: &Uint<L>) -> bool {
        (x.as_words()[0] & 1) == 0
    }

    impl Serialize for SK2048 {
        /// Serializes only the primes `p` and `q`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let minimal = self.to_minimal();
            minimal.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for SK2048 {
        /// Deserializes `p` and `q` and recomputes the full secret key.
        ///
        /// Returns an error (instead of panicking inside `SK::from_pq`) when the
        /// factors are even or equal. Primality is NOT verified.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let minimal = MinimalSK2048::deserialize(deserializer)?;

            // Even factors make N / p^2 / q^2 unusable as odd Montgomery moduli,
            // and p == q leaves p without an inverse mod q. `from_pq` unwraps both.
            if is_even(&minimal.p) || is_even(&minimal.q) {
                return Err(serde::de::Error::custom(
                    "paillier secret key: p and q must be odd",
                ));
            }
            if minimal.p == minimal.q {
                return Err(serde::de::Error::custom(
                    "paillier secret key: p and q must be distinct",
                ));
            }

            Ok(minimal.into())
        }
    }

    impl Serialize for PK2048 {
        /// Serializes only the modulus `N`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let minimal = self.to_minimal();
            minimal.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for PK2048 {
        /// Deserializes `N` and recomputes the Montgomery parameters for `N^2`.
        ///
        /// Returns an error for an even `N` rather than panicking in the
        /// `From<MinimalPK>` conversion, so untrusted input cannot abort the
        /// process.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let minimal = MinimalPK2048::deserialize(deserializer)?;

            if is_even(minimal.n.as_ref()) {
                return Err(serde::de::Error::custom(
                    "paillier public key: modulus N must be odd",
                ));
            }

            Ok(minimal.into())
        }
    }

    impl Serialize for RawCiphertext<{ U4096::LIMBS }> {
        /// Serializes the ciphertext as its underlying integer.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            self.0.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for RawCiphertext<{ U4096::LIMBS }> {
        /// Deserializes the underlying integer. No range check against `N^2`
        /// is possible here because the public key is not available.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let c = Uint::<{ U4096::LIMBS }>::deserialize(deserializer)?;
            Ok(RawCiphertext(c))
        }
    }
}
152
153#[derive(Debug, PartialEq)]
154pub struct RawPlaintext<const L: usize>(Uint<L>);
155
156impl<const L: usize> RawPlaintext<L> {
157    pub fn to_uint(&self) -> Uint<L> {
158        self.0
159    }
160}
161
162#[derive(Debug, PartialEq, Clone)]
163pub struct RawCiphertext<const L: usize>(Uint<L>);
164
165impl<const L: usize> RawCiphertext<L>
166where
167    Uint<L>: Encoding,
168{
169    pub fn from(c: Uint<L>) -> Self {
170        Self(c)
171    }
172
173    pub fn to_be_bytes(&self) -> <Uint<L> as Encoding>::Repr {
174        self.0.to_be_bytes()
175    }
176
177    pub fn to_le_bytes(&self) -> <Uint<L> as Encoding>::Repr {
178        self.0.to_le_bytes()
179    }
180
181    pub fn to_uint(&self) -> Uint<L> {
182        self.0
183    }
184
185    pub fn from_be_bytes(bytes: <Uint<L> as Encoding>::Repr) -> Self {
186        Self(Uint::from_be_bytes(bytes))
187    }
188
189    pub fn from_le_bytes(bytes: <Uint<L> as Encoding>::Repr) -> Self {
190        Self(Uint::from_le_bytes(bytes))
191    }
192
193    pub fn from_uint(c: Uint<L>) -> Self {
194        Self(c)
195    }
196}
197
198impl<const L: usize> Default for RawCiphertext<L> {
199    fn default() -> Self {
200        Self(Default::default())
201    }
202}
203
204impl<const L: usize> Default for RawPlaintext<L> {
205    fn default() -> Self {
206        Self(Default::default())
207    }
208}
209
210pub trait IntoRawPlaintext<const L: usize, T> {
211    fn into_plaintext(self, msg: T) -> Option<RawPlaintext<L>>;
212}
213
214fn inv_mod_uint<const L: usize>(
215    value: &Uint<L>,
216    modulus: &Uint<L>,
217) -> Uint<L> {
218    let value = BoxedUint::from(value);
219    let modulus = BoxedUint::from(modulus);
220    let inverse = value
221        .inv_mod(&modulus)
222        .expect("value should be invertible for the provided modulus");
223    Uint::<L>::from_words(
224        inverse
225            .to_words()
226            .as_ref()
227            .try_into()
228            .expect("boxed inverse should preserve limb width"),
229    )
230}
231
232#[derive(Debug, Clone, Copy)]
233pub struct PK<const C: usize, const M: usize> {
234    n: NonZero<Uint<M>>,
235    params: MontyParams<C>, // mod N^2
236}
237
238#[derive(Debug, Clone)]
239pub struct SK<const C: usize, const M: usize, const P: usize> {
240    pk: PK<C, M>,
241    phi: Uint<M>,
242    inv_phi: Uint<M>,
243    p: Uint<P>,
244    hp: Uint<P>,
245    q: Uint<P>,
246    hq: Uint<P>,
247    pinv_q: Uint<P>,
248    pp_params: MontyParams<M>,
249    qq_params: MontyParams<M>,
250}
251
252impl<const C: usize, const M: usize, const P: usize> SK<C, M, P>
253where
254    Uint<C>: Split<Output = Uint<M>>,
255    Uint<C>: From<(Uint<M>, Uint<M>)>,
256    Uint<M>: Concat<Output = Uint<C>>,
257    Uint<M>: From<(Uint<P>, Uint<P>)>,
258    Uint<P>: Concat<Output = Uint<M>>,
259    Uint<M>: Encoding + Split<Output = Uint<P>>,
260{
261    pub fn gen_pq(rng: &mut impl CryptoRngCore) -> (Uint<P>, Uint<P>) {
262        let q = generate_prime_with_rng(rng, Uint::<P>::BITS);
263        let p = generate_prime_with_rng(rng, Uint::<P>::BITS);
264
265        (p, q)
266    }
267
268    pub fn gen(rng: &mut impl CryptoRngCore) -> Self {
269        let (p, q) = Self::gen_pq(rng);
270        SK::from_pq(&p, &q)
271    }
272
273    pub fn gen_keys(rng: &mut impl CryptoRngCore) -> (SK<C, M, P>, PK<C, M>) {
274        let sk = SK::gen(rng);
275        let pk = sk.public_key();
276        (sk, pk)
277    }
278
279    pub fn get_phi(&self) -> &Uint<M> {
280        &self.phi
281    }
282
283    pub fn from_pq(p: &Uint<P>, q: &Uint<P>) -> Self {
284        // N = pq
285        let n: Uint<M> = q.split_mul(p).into();
286        let pk = PK::from_n(&n);
287
288        // phi = (q-1)(p-1)
289        let phi: Uint<M> = q
290            .wrapping_sub(&Uint::ONE)
291            .split_mul(&p.wrapping_sub(&Uint::ONE))
292            .into();
293
294        // inv_phi = phi^-1 mod N
295        let inv_phi = inv_mod_uint::<M>(&phi, pk.n.as_ref());
296
297        let pinv_q = inv_mod_uint::<P>(p, q);
298
299        let pp: Uint<M> = p.square_wide().into();
300        let pp_params = MontyParams::new(pp.to_odd().unwrap());
301        let hp = Self::h(p, pp_params.modulus().as_ref(), &n);
302
303        let qq: Uint<M> = q.square_wide().into();
304        let qq_params = MontyParams::new(qq.to_odd().unwrap());
305        let hq = Self::h(q, qq_params.modulus().as_ref(), &n);
306
307        SK {
308            phi,
309            inv_phi,
310            pk,
311            p: *p,
312            hp,
313            pp_params,
314            q: *q,
315            pinv_q,
316            hq,
317            qq_params,
318        }
319    }
320
321    pub fn public_key(&self) -> PK<C, M> {
322        PK {
323            n: self.pk.n,
324            params: self.pk.params,
325        }
326    }
327
328    pub fn decrypt(&self, c: &RawCiphertext<C>) -> RawPlaintext<M> {
329        let c = MontyForm::new(&c.0, self.params);
330        let n_wide = NonZero::new(self.n.as_ref().resize::<C>()).unwrap();
331
332        // m = (c^phi mod N^2 - 1) / N
333        let m: Uint<M> = c
334            .pow_bounded_exp(&self.phi, Uint::<M>::BITS)
335            .retrieve()
336            .wrapping_sub(&Uint::ONE)
337            .wrapping_div(&n_wide)
338            .resize(); // drop top half of the value
339
340        // m = (m * phi^-1) mod N
341
342        // m_mod_n = m mod N
343        let m_mod_n = m.rem(&self.n);
344
345        RawPlaintext(m_mod_n.mul_mod::<C>(&self.inv_phi, &self.n))
346    }
347
348    pub(crate) fn h(p: &Uint<P>, pp: &Uint<M>, n: &Uint<M>) -> Uint<P> {
349        // h = L_p(g^{p-1} mod p^2)^-1 mod p
350        //
351        // L_p (x) = (x-1) / p
352        //
353        // n == p*q
354        //
355        //    (1 + n)^{p-1}  mod p^2
356        // =   1 + n(p-1)    mod p^2
357        // =   1 - n + np    mod p^2
358        // =   1 - n + qp^2  mod p^2
359        // =   1 - n         mod p^2
360        //
361
362        let n_mod_pp = n.rem(&NonZero::new(*pp).unwrap()); // should be fast because N and p^2 are close
363        let p_wide = p.resize::<M>();
364        let p_wide_nz = NonZero::new(p_wide).unwrap();
365        let value = Uint::ONE
366            .sub_mod(&n_mod_pp, pp)
367            .wrapping_sub(&Uint::ONE) // L_p(x) = (x-1)/p
368            .wrapping_div(&p_wide_nz);
369
370        inv_mod_uint::<M>(&value, &p_wide).resize() // dropping top half of bits
371    }
372
373    fn mp(
374        &self,
375        cp: Uint<M>,
376        p: &NonZero<Uint<P>>,
377        hp: &Uint<P>,
378        param: &MontyParams<M>,
379    ) -> Uint<P> {
380        // L_p(cp^{p-1} mod p^2) h_p mod p
381        let p_wide_nz = NonZero::new(p.resize::<M>()).unwrap();
382        let mp: Uint<P> = MontyForm::new(&cp, *param)
383            .pow_bounded_exp(&p.wrapping_sub(&Uint::ONE), Uint::<P>::BITS)
384            .retrieve()
385            .wrapping_sub(&Uint::ONE) // Lp(x) = (x-1)/p
386            .wrapping_div(&p_wide_nz)
387            .resize();
388
389        let x: Uint<P> = mp.rem(p);
390
391        x.mul_mod::<M>(hp, p)
392    }
393
394    pub fn decrypt_fast(&self, c: &RawCiphertext<C>) -> RawPlaintext<M> {
395        let pp = self.pp_params.modulus();
396        let qq = self.qq_params.modulus();
397        let p = NonZero::new(self.p).unwrap();
398        let q = NonZero::new(self.q).unwrap();
399
400        let (cp, cq) = decompose(&c.0, pp, qq);
401
402        let mp = self.mp(cp, &p, &self.hp, &self.pp_params);
403        let mq = self.mp(cq, &q, &self.hq, &self.qq_params);
404
405        RawPlaintext(recombine(&self.pinv_q, &mp, &mq, &self.p, &self.q))
406    }
407
408    pub fn extract_n_root(
409        &self,
410        z: &Uint<M>,
411        init_params: &(Uint<P>, Uint<P>, MontyParams<P>, MontyParams<P>),
412    ) -> Uint<M> {
413        let (zp, zq) = decompose(z, &self.p, &self.q);
414        let rp = MontyForm::new(&zp, init_params.2)
415            .pow(&init_params.0)
416            .retrieve();
417        let rq = MontyForm::new(&zq, init_params.3)
418            .pow(&init_params.1)
419            .retrieve();
420
421        recombine(&self.pinv_q, &rp, &rq, &self.p, &self.q)
422    }
423
424    // To reduce recalculation of constant params
425    pub fn extract_n_root_init_params(
426        &self,
427    ) -> (Uint<P>, Uint<P>, MontyParams<P>, MontyParams<P>) {
428        let dk_qminusone = self.q.wrapping_sub(&Uint::ONE);
429        let dk_pminusone = self.p.wrapping_sub(&Uint::ONE);
430        let dk_dn = inv_mod_uint::<M>(self.n.as_ref(), &self.phi);
431
432        let (dk_dp, dk_dq) = decompose(&dk_dn, &dk_pminusone, &dk_qminusone);
433
434        let p_params = MontyParams::new(self.p.to_odd().unwrap());
435        let q_params = MontyParams::new(self.q.to_odd().unwrap());
436
437        (dk_dp, dk_dq, p_params, q_params)
438    }
439}
440
#[cfg(feature = "serde")]
impl<const C: usize, const M: usize, const P: usize> SK<C, M, P>
where
    Uint<P>: Bounded + Encoding,
{
    /// Extracts the serializable core of the key: the primes `p` and `q`.
    pub fn to_minimal(&self) -> MinimalSK<P> {
        let Self { p, q, .. } = self;
        MinimalSK { p: *p, q: *q }
    }
}
453
#[cfg(feature = "serde")]
impl<const C: usize, const M: usize, const P: usize> From<MinimalSK<P>>
    for SK<C, M, P>
where
    Uint<P>: Bounded + Encoding,
    Uint<C>: Split<Output = Uint<M>>,
    Uint<C>: From<(Uint<M>, Uint<M>)>,
    Uint<M>: Concat<Output = Uint<C>>,
    Uint<M>: From<(Uint<P>, Uint<P>)>,
    Uint<P>: Concat<Output = Uint<M>>,
    Uint<M>: Encoding + Split<Output = Uint<P>>,
{
    /// Recomputes the full secret key from `p` and `q`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `SK::from_pq`.
    fn from(minimal: MinimalSK<P>) -> Self {
        let MinimalSK { p, q } = minimal;
        Self::from_pq(&p, &q)
    }
}
470
471impl<const C: usize, const M: usize, const P: usize> Deref for SK<C, M, P> {
472    type Target = PK<C, M>;
473
474    fn deref(&self) -> &Self::Target {
475        &self.pk
476    }
477}
478
479pub fn decompose<const C: usize, const M: usize>(
480    c: &Uint<C>,
481    p: &Uint<M>,
482    q: &Uint<M>,
483) -> (Uint<M>, Uint<M>)
484where
485    Uint<C>: Split<Output = Uint<M>>,
486    Uint<C>: From<(Uint<M>, Uint<M>)>,
487{
488    let (lo, hi) = c.split();
489
490    let cp: Uint<M> = Uint::<C>::from((lo, hi))
491        .rem(&NonZero::new(p.resize::<C>()).unwrap())
492        .resize();
493    let cq: Uint<M> = Uint::<C>::from((lo, hi))
494        .rem(&NonZero::new(q.resize::<C>()).unwrap())
495        .resize();
496
497    (cp, cq)
498}
499
500// Algo 14.71 with Note 14.75 (i)
501pub fn recombine<const M: usize, const P: usize>(
502    p_inv_q: &Uint<P>,
503    v1: &Uint<P>,
504    v2: &Uint<P>,
505    p: &Uint<P>,
506    q: &Uint<P>,
507) -> Uint<M>
508where
509    Uint<P>: Concat<Output = Uint<M>>,
510    Uint<M>: From<(Uint<P>, Uint<P>)>,
511    Uint<M>: Split<Output = Uint<P>>,
512{
513    // C_2 = p^-1 mod q
514    // let c_2 = p.inv_odd_mod(q).0;
515
516    let non_zero_q = NonZero::new(*q).unwrap();
517    // d = (v_2 - v_1) mod q
518    // NOTE: Peforming one mod reduction, as sub_mod assumes
519    // that v2 - v1 is in range [-q, q);
520    let v1_less_q = v1 % non_zero_q;
521    let d = v2.sub_mod(&v1_less_q, q);
522
523    // u = (v_2 - v_1) C_2 mod q
524    let u: Uint<P> = d.mul_mod::<M>(p_inv_q, &non_zero_q);
525
526    // x = v_1 + u p
527    Uint::from(u.split_mul(p)).wrapping_add(&v1.resize())
528}
529
530impl<const C: usize, const M: usize> PK<C, M>
531where
532    Uint<C>: From<(Uint<M>, Uint<M>)>,
533    Uint<M>: Encoding,
534{
535    pub fn from_n(n: &Uint<M>) -> Self {
536        // We generate N as half of L, so hi part of n.square_wide() is zero
537        let nn: Uint<C> = n.square_wide().into();
538        let params = MontyParams::<C>::new_vartime(nn.to_odd().unwrap());
539
540        Self {
541            n: NonZero::new(*n).unwrap(),
542            params,
543        }
544    }
545
546    #[cfg(feature = "serde")]
547    pub fn to_minimal(&self) -> MinimalPK<C, M>
548    where
549        Uint<M>: Bounded + Encoding,
550        Uint<C>: Bounded + Encoding,
551    {
552        MinimalPK { n: self.n }
553    }
554
555    pub fn get_n(&self) -> &NonZero<Uint<M>> {
556        &self.n
557    }
558
559    pub fn get_nn(&self) -> &Uint<C> {
560        self.params.modulus().as_ref()
561    }
562
563    pub fn gen_r(&self, rng: &mut impl CryptoRngCore) -> Uint<M> {
564        loop {
565            let r = Uint::random_mod(rng, &self.n);
566
567            if !r.eq(&Uint::ZERO) {
568                break r;
569            }
570        }
571    }
572
573    pub fn message(&self, bytes: &[u8]) -> Option<RawPlaintext<M>> {
574        let size = std::cmp::min(Uint::<M>::BYTES, bytes.len());
575
576        let mut buf = Uint::<M>::default().to_le_bytes();
577
578        buf.as_mut()[..size].copy_from_slice(&bytes[..size]);
579
580        let m = Uint::<M>::from_le_slice(buf.as_ref());
581
582        self.into_message(&m)
583    }
584
585    pub fn into_message(&self, m: &Uint<M>) -> Option<RawPlaintext<M>> {
586        m.lt(self.n.as_ref()).then_some(RawPlaintext(*m))
587    }
588
589    pub fn encrypt(
590        &self,
591        m: &RawPlaintext<M>,
592        rng: &mut impl CryptoRngCore,
593    ) -> RawCiphertext<C> {
594        let r = self.gen_r(rng);
595        self.encrypt_with_r(m, &r)
596    }
597
598    pub fn encrypt_with_r(
599        &self,
600        m: &RawPlaintext<M>,
601        r: &Uint<M>,
602    ) -> RawCiphertext<C> {
603        let r = MontyForm::new(&r.resize::<C>(), self.params);
604
605        // r^N mod N^2
606        let r_pow_n = r
607            .pow_bounded_exp(self.n.as_ref(), self.n.as_ref().bits_vartime());
608
609        //
610        // g == (1 + N)
611        //
612        // 0 <= m < N
613        //
614        // (1+ N)^m mod N^2 = 1 + m*N mod N^2
615        //
616        // 1 + m*N <= 1 + N^2 - N < N^2
617        //
618        let g_pow_m = MontyForm::new(
619            &Uint::<C>::from(m.0.split_mul(self.n.as_ref()))
620                .wrapping_add(&Uint::ONE),
621            self.params,
622        );
623
624        // c = g^m * r^N mod N^2
625        let c = g_pow_m.mul(&r_pow_n);
626        RawCiphertext(c.retrieve())
627    }
628
629    pub fn add(
630        &self,
631        c_1: &RawCiphertext<C>,
632        c_2: &RawCiphertext<C>,
633    ) -> RawCiphertext<C> {
634        // c_1 * c_2 mod N^2
635        let c_1 = MontyForm::new(&c_1.0, self.params);
636        let c_2 = MontyForm::new(&c_2.0, self.params);
637
638        RawCiphertext(c_1.mul(&c_2).retrieve())
639    }
640
641    pub fn mul(
642        &self,
643        c: &RawCiphertext<C>,
644        m: &RawPlaintext<M>,
645    ) -> RawCiphertext<C> {
646        // c = c^m mod N^2
647        let c = MontyForm::new(&c.0, self.params).pow(&m.0).retrieve();
648
649        RawCiphertext(c)
650    }
651
652    pub fn mul_vartime(
653        &self,
654        c: &RawCiphertext<C>,
655        m: &RawPlaintext<M>,
656    ) -> RawCiphertext<C> {
657        let bits = m.0.bits_vartime();
658
659        // c = c^m mod N^2
660        let c = MontyForm::new(&c.0, self.params)
661            .pow_bounded_exp(&m.0, bits)
662            .retrieve();
663
664        RawCiphertext(c)
665    }
666}
667
668// #[cfg(test)]
669// #[macro_use(quickcheck)]
670// extern crate quickcheck_macros;
671
#[cfg(test)]
mod tests {
    use quickcheck::quickcheck;

    use super::*;
    use crypto_bigint::{U1024, U2048, U4096};

    static P: &str = "95779f0de6b61f3db4c53b1b32aa29e2efb52ebedab7968c37cb10917767547963a121d454c8024dc56f22c523da2dff553ad8a1621ad8f0c093ad09561165fce74fdf977ab1b5f57b4cdcce58f449bcce50cd80359ed0ec4083000c091fbb237e52b8237438ea82932ad0ed7d58fae54ea300461755a0dabc41b5e46af4cee1";
    static Q: &str = "a80137484b2e0082dbcc520642ea0fcff5652a2367084c052c340b15f0c3ecfeb334024e28e5a982c8971d06f332fc2e91ca985ee37a8e51daa2bae16841b75617a43b52fecea902c5858276ef3ab5282a0635ef34579d5ea2de61bd56f4d7ec26afbcb8ae127c4bc5c0a5799a48d41565a7656fffa056ac3b73ccb3fd0098d1";
    static R: &str = "1a8b6c80c0cad628e4146e473d49b90b445d09e9a7934431c5cb3e7a43b162018e50b116ed8a0ebaf4b8907a18ad30edfbf573614ededd1bc763265be3a6eeef307d40c2431fa9970590fecd7c8af25d599b513749f998c1ba7a64caeedb2d5dd034f718b9efdf5cf62b129459134b257cf28c61bbe40fc4c20caec7c58b9fa4fa4aea0e2164a398a3c2a21cd012aee7bba3f502b9b10680a36e615d81ef690346d33c05966415c0bff5e6f856ca2bca5786947cca9adfd8300cbf0d2d6f0d4c848b21f46961443fb4519b8ee2dae018c586afe0ee0f430fde643e423cce0cf56f0a59baf6652b250ef6184ffcf09039d34e0a2e0d95c3b24295929e3db4d5f4";

    lazy_static::lazy_static! {
        static ref SK: SK2048 = {
            let p: U1024 = from_hex(P);
            let q: U1024 = from_hex(Q);

            SK2048::from_pq(&p, &q)
        };

        static ref RU: U2048 = from_hex(R);
    }

    /// Parses a big-endian hex string (no `0x` prefix) into a `Uint<L>`.
    ///
    /// The string may have fewer than `16 * L` digits; missing high limbs stay
    /// zero. Assumes 64-bit limbs, i.e. 16 hex digits per limb.
    fn from_hex<const L: usize>(h: &str) -> Uint<L> {
        assert!(!h.is_empty());

        // Number of limbs needed, counting a partial leading limb as one.
        let total = (h.len() + 15) / 16;
        assert!(total <= L);

        let mut r = Uint::<L>::ZERO;

        // `rchunks` walks from the least-significant (rightmost) digits, so
        // chunk i is exactly limb i. Only the final, most-significant chunk can
        // be shorter than 16 digits. Pairing limbs and chunks with `zip` needs
        // no manual index arithmetic.
        for (limb, digits) in r.as_words_mut().iter_mut().zip(h.as_bytes().rchunks(16)) {
            let s = std::str::from_utf8(digits).unwrap();
            *limb = u64::from_str_radix(s, 16).unwrap();
        }

        r
    }

    #[test]
    fn add() {
        fn prop(x: u64, y: u64) -> bool {
            let mx = SK.into_message(&Uint::from_u64(x)).unwrap();
            let my = SK.into_message(&Uint::from_u64(y)).unwrap();

            let c1 = SK.encrypt_with_r(&mx, &RU);
            let c2 = SK.encrypt_with_r(&my, &RU);

            let c3 = SK.add(&c1, &c2);

            let mr = SK.decrypt_fast(&c3);

            mr == RawPlaintext(
                Uint::from_u64(x).wrapping_add(&Uint::from_u64(y)),
            )
        }

        quickcheck(prop as fn(u64, u64) -> bool)
    }

    #[test]
    fn mul() {
        fn prop(x: u64, y: u64) -> bool {
            let mx = SK.into_message(&Uint::from_u64(x)).unwrap();

            let c1 = SK.encrypt_with_r(&mx, &RU);

            let c3 = SK.mul(&c1, &RawPlaintext(Uint::from_u64(y)));

            let mr = SK.decrypt_fast(&c3);

            mr == RawPlaintext(
                Uint::from_u64(x).wrapping_mul(&U4096::from_u64(y)),
            )
        }

        quickcheck(prop as fn(u64, u64) -> bool)
    }

    #[test]
    fn big() {
        let n = SK.get_n();
        let m = n.wrapping_sub(&Uint::ONE);

        let m = SK.into_message(&m).unwrap();

        let c = SK.encrypt_with_r(&m, &n.wrapping_sub(&Uint::ONE));

        let d = SK.decrypt(&c);

        assert_eq!(m, d);

        let d = SK.decrypt_fast(&c);

        assert_eq!(m, d);
    }

    #[test]
    fn decrypt() {
        let mut rng = rand::thread_rng();

        let r = SK.gen_r(&mut rng); // random R

        // check that decrypt(encrypt(0)) == 0
        let m = SK.into_message(&Uint::ZERO).unwrap();
        let c = SK.encrypt_with_r(&m, &r);
        let d = SK.decrypt(&c);

        assert_eq!(d, m);

        let d = SK.decrypt_fast(&c);

        assert_eq!(d, m);

        // now test random M in range [1 ..N)
        for _ in 0..20 {
            let m = SK.gen_r(&mut rng);
            let m = SK.into_message(&m).unwrap();

            let c = SK.encrypt_with_r(&m, &r);
            let d = SK.decrypt(&c);

            assert_eq!(d, m);

            let d = SK.decrypt_fast(&c);

            assert_eq!(d, m);
        }
    }

    #[test]
    fn small() {
        // numbers from handbook
        const P: u8 = 11;
        const Q: u8 = 17;
        const M: u8 = 175;
        const R: u8 = 83;
        const N: u8 = P * Q;

        let pk = PK2048::from_n(&N.into());

        let m = pk.message(&[M]).unwrap();

        let c = pk.encrypt_with_r(&m, &R.into());

        assert_eq!(c, RawCiphertext::from(U4096::from_u64(23911u64)));

        let sk = SK2048::from_pq(&11u8.into(), &17u8.into());

        assert_eq!(sk.decrypt(&c), m);
        assert_eq!(sk.decrypt_fast(&c), m);

        let c2 = pk.add(&c, &c);

        let m2 = (M as u32 + M as u32) % (N as u32);

        assert_eq!(sk.decrypt(&c2), pk.message(&[m2 as u8]).unwrap());
        assert_eq!(sk.decrypt_fast(&c2), pk.message(&[m2 as u8]).unwrap());

        let c3 = pk.mul(&c, &pk.message(&[3u8]).unwrap());

        let m3 = ((M as u32) * 3) % (N as u32);

        assert_eq!(sk.decrypt(&c3), pk.message(&[m3 as u8]).unwrap());
    }

    #[test]
    fn gen() {
        let mut rng = rand::thread_rng();
        let (p, q) = SK2048::gen_pq(&mut rng);
        let sk = SK2048::from_pq(&p, &q);
        let _: U2048 = sk.gen_r(&mut rng);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_ser_de() {
        let mut rng = rand::thread_rng();
        let (p, q) = SK2048::gen_pq(&mut rng);

        let sk = SK2048::from_pq(&p, &q);
        let pk = sk.public_key();
        let s1 = serde_json::to_string_pretty(&sk).unwrap();
        let p1 = serde_json::to_string_pretty(&pk).unwrap();
        let sk: SK2048 = serde_json::from_str(&s1).unwrap();
        let pk: PK2048 = serde_json::from_str(&p1).unwrap();

        let s2 = serde_json::to_string_pretty(&sk).unwrap();
        let p2 = serde_json::to_string_pretty(&pk).unwrap();
        assert_eq!(s1, s2);
        assert_eq!(p1, p2);

        let mut bytes = Vec::new();
        ciborium::into_writer(&sk, &mut bytes).unwrap();
        let sk1: SK2048 = ciborium::from_reader(bytes.as_slice()).unwrap();
        assert_eq!(sk.to_minimal(), sk1.to_minimal());
    }
}