tfhe_zk_pok/proofs/
pke_v2.rs

1// to follow the notation of the paper
2#![allow(non_snake_case)]
3
4use super::*;
5use crate::backward_compatibility::pke_v2::*;
6use crate::backward_compatibility::BoundVersions;
7use crate::curve_api::{CompressedG1, CompressedG2};
8use crate::four_squares::*;
9use crate::serialization::{
10    try_vec_to_array, InvalidSerializedAffineError, InvalidSerializedPublicParamsError,
11    SerializableGroupElements, SerializablePKEv2PublicParams,
12};
13
14use core::marker::PhantomData;
15use rayon::prelude::*;
16use serde::{Deserialize, Serialize};
17
/// Iterates over the `nbits` lowest bits of `x`, least significant bit first.
fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
    (0..nbits).map(move |bit_index| ((x >> bit_index) & 1) == 1)
}
21
22/// The CRS of the zk scheme
23#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
24#[serde(
25    try_from = "SerializablePKEv2PublicParams",
26    into = "SerializablePKEv2PublicParams",
27    bound(
28        deserialize = "PublicParams<G>: TryFrom<SerializablePKEv2PublicParams, Error = InvalidSerializedPublicParamsError>",
29        serialize = "PublicParams<G>: Into<SerializablePKEv2PublicParams>"
30    )
31)]
32#[versionize(try_convert = SerializablePKEv2PublicParams)]
33pub struct PublicParams<G: Curve> {
34    pub(crate) g_lists: GroupElements<G>,
35    pub(crate) D: usize,
36    pub n: usize,
37    pub d: usize,
38    pub k: usize,
39    // We store the square of the bound to avoid rounding on sqrt operations
40    pub B_bound_squared: u128,
41    pub B_inf: u64,
42    pub q: u64,
43    pub t: u64,
44    pub msbs_zero_padding_bit_count: u64,
45    pub bound_type: Bound,
46    pub(crate) hash: [u8; HASH_METADATA_LEN_BYTES],
47    pub(crate) hash_R: [u8; HASH_METADATA_LEN_BYTES],
48    pub(crate) hash_t: [u8; HASH_METADATA_LEN_BYTES],
49    pub(crate) hash_w: [u8; HASH_METADATA_LEN_BYTES],
50    pub(crate) hash_agg: [u8; HASH_METADATA_LEN_BYTES],
51    pub(crate) hash_lmap: [u8; HASH_METADATA_LEN_BYTES],
52    pub(crate) hash_phi: [u8; HASH_METADATA_LEN_BYTES],
53    pub(crate) hash_xi: [u8; HASH_METADATA_LEN_BYTES],
54    pub(crate) hash_z: [u8; HASH_METADATA_LEN_BYTES],
55    pub(crate) hash_chi: [u8; HASH_METADATA_LEN_BYTES],
56}
57
58impl<G: Curve> Compressible for PublicParams<G>
59where
60    GroupElements<G>: Compressible<
61        Compressed = SerializableGroupElements,
62        UncompressError = InvalidSerializedGroupElementsError,
63    >,
64{
65    type Compressed = SerializablePKEv2PublicParams;
66
67    type UncompressError = InvalidSerializedPublicParamsError;
68
69    fn compress(&self) -> Self::Compressed {
70        let PublicParams {
71            g_lists,
72            D,
73            n,
74            d,
75            k,
76            B_bound_squared,
77            B_inf,
78            q,
79            t,
80            msbs_zero_padding_bit_count,
81            bound_type,
82            hash,
83            hash_R,
84            hash_t,
85            hash_w,
86            hash_agg,
87            hash_lmap,
88            hash_phi,
89            hash_xi,
90            hash_z,
91            hash_chi,
92        } = self;
93        SerializablePKEv2PublicParams {
94            g_lists: g_lists.compress(),
95            D: *D,
96            n: *n,
97            d: *d,
98            k: *k,
99            B_inf: *B_inf,
100            B_bound_squared: *B_bound_squared,
101            q: *q,
102            t: *t,
103            msbs_zero_padding_bit_count: *msbs_zero_padding_bit_count,
104            bound_type: *bound_type,
105            hash: hash.to_vec(),
106            hash_R: hash_R.to_vec(),
107            hash_t: hash_t.to_vec(),
108            hash_w: hash_w.to_vec(),
109            hash_agg: hash_agg.to_vec(),
110            hash_lmap: hash_lmap.to_vec(),
111            hash_phi: hash_phi.to_vec(),
112            hash_xi: hash_xi.to_vec(),
113            hash_z: hash_z.to_vec(),
114            hash_chi: hash_chi.to_vec(),
115        }
116    }
117
118    fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
119        let SerializablePKEv2PublicParams {
120            g_lists,
121            D,
122            n,
123            d,
124            k,
125            B_bound_squared,
126            B_inf,
127            q,
128            t,
129            msbs_zero_padding_bit_count,
130            bound_type,
131            hash,
132            hash_R,
133            hash_t,
134            hash_w,
135            hash_agg,
136            hash_lmap,
137            hash_phi,
138            hash_xi,
139            hash_z,
140            hash_chi,
141        } = compressed;
142        Ok(Self {
143            g_lists: GroupElements::uncompress(g_lists)?,
144            D,
145            n,
146            d,
147            k,
148            B_bound_squared,
149            B_inf,
150            q,
151            t,
152            msbs_zero_padding_bit_count,
153            bound_type,
154            hash: try_vec_to_array(hash)?,
155            hash_R: try_vec_to_array(hash_R)?,
156            hash_t: try_vec_to_array(hash_t)?,
157            hash_w: try_vec_to_array(hash_w)?,
158            hash_agg: try_vec_to_array(hash_agg)?,
159            hash_lmap: try_vec_to_array(hash_lmap)?,
160            hash_phi: try_vec_to_array(hash_phi)?,
161            hash_xi: try_vec_to_array(hash_xi)?,
162            hash_z: try_vec_to_array(hash_z)?,
163            hash_chi: try_vec_to_array(hash_chi)?,
164        })
165    }
166}
167
168impl<G: Curve> PublicParams<G> {
169    #[allow(clippy::too_many_arguments)]
170    pub fn from_vec(
171        g_list: Vec<Affine<G::Zp, G::G1>>,
172        g_hat_list: Vec<Affine<G::Zp, G::G2>>,
173        d: usize,
174        k: usize,
175        B_inf: u64,
176        q: u64,
177        t: u64,
178        msbs_zero_padding_bit_count: u64,
179        bound_type: Bound,
180        hash: [u8; HASH_METADATA_LEN_BYTES],
181        hash_R: [u8; HASH_METADATA_LEN_BYTES],
182        hash_t: [u8; HASH_METADATA_LEN_BYTES],
183        hash_w: [u8; HASH_METADATA_LEN_BYTES],
184        hash_agg: [u8; HASH_METADATA_LEN_BYTES],
185        hash_lmap: [u8; HASH_METADATA_LEN_BYTES],
186        hash_phi: [u8; HASH_METADATA_LEN_BYTES],
187        hash_xi: [u8; HASH_METADATA_LEN_BYTES],
188        hash_z: [u8; HASH_METADATA_LEN_BYTES],
189        hash_chi: [u8; HASH_METADATA_LEN_BYTES],
190    ) -> Self {
191        let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
192        let (n, D, B_bound_squared, _) =
193            compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
194        Self {
195            g_lists: GroupElements::<G>::from_vec(g_list, g_hat_list),
196            D,
197            n,
198            d,
199            k,
200            B_bound_squared,
201            B_inf,
202            q,
203            t,
204            msbs_zero_padding_bit_count,
205            bound_type,
206            hash,
207            hash_R,
208            hash_t,
209            hash_w,
210            hash_agg,
211            hash_lmap,
212            hash_phi,
213            hash_xi,
214            hash_z,
215            hash_chi,
216        }
217    }
218
219    pub fn exclusive_max_noise(&self) -> u64 {
220        // Here we return the bound without slack because users aren't supposed to generate noise
221        // inside the slack
222        self.B_inf + 1
223    }
224
225    /// Check if the crs can be used to generate or verify a proof
226    ///
227    /// This means checking that the points are:
228    /// - valid points of the curve
229    /// - in the correct subgroup
230    pub fn is_usable(&self) -> bool {
231        self.g_lists.is_valid()
232    }
233}
234
235/// This represents a proof that the given ciphertext is a valid encryptions of the input messages
236/// with the provided public key.
237#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)]
238#[serde(bound(
239    deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
240    serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
241))]
242#[versionize(ProofVersions)]
243pub struct Proof<G: Curve> {
244    pub(crate) C_hat_e: G::G2,
245    pub(crate) C_e: G::G1,
246    pub(crate) C_r_tilde: G::G1,
247    pub(crate) C_R: G::G1,
248    pub(crate) C_hat_bin: G::G2,
249    pub(crate) C_y: G::G1,
250    pub(crate) C_h1: G::G1,
251    pub(crate) C_h2: G::G1,
252    pub(crate) C_hat_t: G::G2,
253    pub(crate) pi: G::G1,
254    pub(crate) pi_kzg: G::G1,
255
256    pub(crate) compute_load_proof_fields: Option<ComputeLoadProofFields<G>>,
257}
258
259impl<G: Curve> Proof<G> {
260    /// Check if the proof can be used by the Verifier.
261    ///
262    /// This means checking that the points in the proof are:
263    /// - valid points of the curve
264    /// - in the correct subgroup
265    pub fn is_usable(&self) -> bool {
266        let &Proof {
267            C_hat_e,
268            C_e,
269            C_r_tilde,
270            C_R,
271            C_hat_bin,
272            C_y,
273            C_h1,
274            C_h2,
275            C_hat_t,
276            pi,
277            pi_kzg,
278            ref compute_load_proof_fields,
279        } = self;
280
281        C_hat_e.validate_projective()
282            && C_e.validate_projective()
283            && C_r_tilde.validate_projective()
284            && C_R.validate_projective()
285            && C_hat_bin.validate_projective()
286            && C_y.validate_projective()
287            && C_h1.validate_projective()
288            && C_h2.validate_projective()
289            && C_hat_t.validate_projective()
290            && pi.validate_projective()
291            && pi_kzg.validate_projective()
292            && compute_load_proof_fields.as_ref().is_none_or(
293                |&ComputeLoadProofFields { C_hat_h3, C_hat_w }| {
294                    C_hat_h3.validate_projective() && C_hat_w.validate_projective()
295                },
296            )
297    }
298}
299
300/// These fields can be pre-computed on the prover side in the faster Verifier scheme. If that's the
301/// case, they should be included in the proof.
302#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)]
303#[versionize(ComputeLoadProofFieldsVersions)]
304pub(crate) struct ComputeLoadProofFields<G: Curve> {
305    pub(crate) C_hat_h3: G::G2,
306    pub(crate) C_hat_w: G::G2,
307}
308
309#[derive(Serialize, Deserialize, Versionize)]
310#[serde(bound(
311    deserialize = "G: Curve, CompressedG1<G>: serde::Deserialize<'de>, CompressedG2<G>: serde::Deserialize<'de>",
312    serialize = "G: Curve, CompressedG1<G>: serde::Serialize, CompressedG2<G>: serde::Serialize"
313))]
314#[versionize(CompressedProofVersions)]
315pub struct CompressedProof<G: Curve>
316where
317    G::G1: Compressible,
318    G::G2: Compressible,
319{
320    pub(crate) C_hat_e: CompressedG2<G>,
321    pub(crate) C_e: CompressedG1<G>,
322    pub(crate) C_r_tilde: CompressedG1<G>,
323    pub(crate) C_R: CompressedG1<G>,
324    pub(crate) C_hat_bin: CompressedG2<G>,
325    pub(crate) C_y: CompressedG1<G>,
326    pub(crate) C_h1: CompressedG1<G>,
327    pub(crate) C_h2: CompressedG1<G>,
328    pub(crate) C_hat_t: CompressedG2<G>,
329    pub(crate) pi: CompressedG1<G>,
330    pub(crate) pi_kzg: CompressedG1<G>,
331
332    pub(crate) compute_load_proof_fields: Option<CompressedComputeLoadProofFields<G>>,
333}
334
335#[derive(Serialize, Deserialize, Versionize)]
336#[serde(bound(
337    deserialize = "G: Curve, CompressedG1<G>: serde::Deserialize<'de>, CompressedG2<G>: serde::Deserialize<'de>",
338    serialize = "G: Curve, CompressedG1<G>: serde::Serialize, CompressedG2<G>: serde::Serialize"
339))]
340#[versionize(CompressedComputeLoadProofFieldsVersions)]
341pub(crate) struct CompressedComputeLoadProofFields<G: Curve>
342where
343    G::G1: Compressible,
344    G::G2: Compressible,
345{
346    pub(crate) C_hat_h3: CompressedG2<G>,
347    pub(crate) C_hat_w: CompressedG2<G>,
348}
349
350impl<G: Curve> Compressible for Proof<G>
351where
352    G::G1: Compressible<UncompressError = InvalidSerializedAffineError>,
353    G::G2: Compressible<UncompressError = InvalidSerializedAffineError>,
354{
355    type Compressed = CompressedProof<G>;
356
357    type UncompressError = InvalidSerializedAffineError;
358
359    fn compress(&self) -> Self::Compressed {
360        let Proof {
361            C_hat_e,
362            C_e,
363            C_r_tilde,
364            C_R,
365            C_hat_bin,
366            C_y,
367            C_h1,
368            C_h2,
369            C_hat_t,
370            pi,
371            pi_kzg,
372            compute_load_proof_fields,
373        } = self;
374
375        CompressedProof {
376            C_hat_e: C_hat_e.compress(),
377            C_e: C_e.compress(),
378            C_r_tilde: C_r_tilde.compress(),
379            C_R: C_R.compress(),
380            C_hat_bin: C_hat_bin.compress(),
381            C_y: C_y.compress(),
382            C_h1: C_h1.compress(),
383            C_h2: C_h2.compress(),
384            C_hat_t: C_hat_t.compress(),
385            pi: pi.compress(),
386            pi_kzg: pi_kzg.compress(),
387
388            compute_load_proof_fields: compute_load_proof_fields.as_ref().map(
389                |ComputeLoadProofFields { C_hat_h3, C_hat_w }| CompressedComputeLoadProofFields {
390                    C_hat_h3: C_hat_h3.compress(),
391                    C_hat_w: C_hat_w.compress(),
392                },
393            ),
394        }
395    }
396
397    fn uncompress(compressed: Self::Compressed) -> Result<Self, Self::UncompressError> {
398        let CompressedProof {
399            C_hat_e,
400            C_e,
401            C_r_tilde,
402            C_R,
403            C_hat_bin,
404            C_y,
405            C_h1,
406            C_h2,
407            C_hat_t,
408            pi,
409            pi_kzg,
410            compute_load_proof_fields,
411        } = compressed;
412
413        Ok(Proof {
414            C_hat_e: G::G2::uncompress(C_hat_e)?,
415            C_e: G::G1::uncompress(C_e)?,
416            C_r_tilde: G::G1::uncompress(C_r_tilde)?,
417            C_R: G::G1::uncompress(C_R)?,
418            C_hat_bin: G::G2::uncompress(C_hat_bin)?,
419            C_y: G::G1::uncompress(C_y)?,
420            C_h1: G::G1::uncompress(C_h1)?,
421            C_h2: G::G1::uncompress(C_h2)?,
422            C_hat_t: G::G2::uncompress(C_hat_t)?,
423            pi: G::G1::uncompress(pi)?,
424            pi_kzg: G::G1::uncompress(pi_kzg)?,
425
426            compute_load_proof_fields: if let Some(CompressedComputeLoadProofFields {
427                C_hat_h3,
428                C_hat_w,
429            }) = compute_load_proof_fields
430            {
431                Some(ComputeLoadProofFields {
432                    C_hat_h3: G::G2::uncompress(C_hat_h3)?,
433                    C_hat_w: G::G2::uncompress(C_hat_w)?,
434                })
435            } else {
436                None
437            },
438        })
439    }
440}
441
442/// This is the public part of the commitment.
443#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
444pub struct PublicCommit<G: Curve> {
445    /// Mask of the public key
446    a: Vec<i64>,
447    /// Body of the public key
448    b: Vec<i64>,
449    /// Mask of the ciphertexts
450    c1: Vec<i64>,
451    /// Bodies of the ciphertexts
452    c2: Vec<i64>,
453    __marker: PhantomData<G>,
454}
455
456impl<G: Curve> PublicCommit<G> {
457    pub fn new(a: Vec<i64>, b: Vec<i64>, c1: Vec<i64>, c2: Vec<i64>) -> Self {
458        Self {
459            a,
460            b,
461            c1,
462            c2,
463            __marker: PhantomData,
464        }
465    }
466}
467
468#[derive(Clone, Debug)]
469pub struct PrivateCommit<G: Curve> {
470    /// Public key sampling vector
471    r: Vec<i64>,
472    /// Error vector associated with the masks
473    e1: Vec<i64>,
474    /// Input messages
475    m: Vec<i64>,
476    /// Error vector associated with the bodies
477    e2: Vec<i64>,
478    __marker: PhantomData<G>,
479}
480
481#[derive(PartialEq, Copy, Clone, Debug, Serialize, Deserialize, Versionize)]
482#[versionize(BoundVersions)]
483pub enum Bound {
484    GHL,
485    CS,
486}
487
/// Returns `ceil(log2(value))`.
///
/// # Panics
/// Panics if `value` is 0 (the logarithm is undefined; raised by `u128::ilog2`).
fn ceil_ilog2(value: u128) -> u64 {
    let floor_log = u64::from(value.ilog2());
    if value.is_power_of_two() {
        floor_log
    } else {
        floor_log + 1
    }
}
491
492pub fn compute_crs_params(
493    d: usize,
494    k: usize,
495    B_squared: u128,
496    t: u64,
497    msbs_zero_padding_bit_count: u64,
498    bound_type: Bound,
499) -> (usize, usize, u128, usize) {
500    assert!(
501        k <= d,
502        "Invalid parameters for zk_pok, the maximum number of messages k should be smaller \
503than the lwe dimension d. Please pick a smaller k: k = {k}, d = {d}"
504    );
505
506    let mut B_bound_squared = {
507        (match bound_type {
508            // GHL factor is 9.75, 9.75**2 = 95.0625
509            // Result is multiplied and divided by 10000 to avoid floating point operations.
510            // This could be avoided if one day we need to support bigger params.
511            Bound::GHL => 950625,
512            Bound::CS => 2 * (d as u128 + k as u128) + 4,
513        })
514        .checked_mul(B_squared + (sqr((d + 2) as u64) * (d + k) as u128) / 4)
515        .unwrap_or_else(|| {
516            panic!(
517                "Invalid parameters for zk_pok, B_squared: {B_squared}, d: {d}, k: {k}. \
518Please select a smaller B, d and/or k"
519            )
520        })
521    };
522
523    if bound_type == Bound::GHL {
524        B_bound_squared = B_bound_squared.div_ceil(10000);
525    }
526
527    // Formula is round_up(1 + B_bound.ilog2()).
528    // Since we use B_bound_square, the log is divided by 2
529    let m_bound = 1 + ceil_ilog2(B_bound_squared).div_ceil(2) as usize;
530
531    // m_bound is used to do the bit decomposition of a u64 integer, so we check that it can be
532    // safely used for this
533    assert!(
534        m_bound <= 64,
535        "Invalid parameters for zk_pok, we only support 64 bits integer. \
536The computed m parameter is {m_bound} > 64. Please select a smaller B, d and/or k"
537    );
538
539    // This is also the effective t for encryption
540    let effective_t_for_decomposition = t >> msbs_zero_padding_bit_count;
541
542    // formula in Prove_pp: 2.
543    let D = d + k * (effective_t_for_decomposition.ilog2() as usize);
544    let n = D + 128 * m_bound;
545
546    (n, D, B_bound_squared, m_bound)
547}
548
549/// Convert a bound on the infinite norm  of a vector into a bound on the square of the euclidean
550/// norm.
551///
552/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the
553/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2.
554fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
555    let norm_squared = sqr(B_inf);
556    norm_squared
557        .checked_mul(dim as u128)
558        .unwrap_or_else(|| panic!("Invalid parameters for zk_pok, B_inf: {B_inf}, d+k: {dim}"))
559}
560
561/// Generates a CRS based on the bound the heuristic provided by the lemma 2 of the paper.
562pub fn crs_gen_ghl<G: Curve>(
563    d: usize,
564    k: usize,
565    B_inf: u64,
566    q: u64,
567    t: u64,
568    msbs_zero_padding_bit_count: u64,
569    rng: &mut dyn RngCore,
570) -> PublicParams<G> {
571    let bound_type = Bound::GHL;
572    let alpha = G::Zp::rand(rng);
573    let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
574    let (n, D, B_bound_squared, _) =
575        compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
576    PublicParams {
577        g_lists: GroupElements::<G>::new(n, alpha),
578        D,
579        n,
580        d,
581        k,
582        B_inf,
583        B_bound_squared,
584        q,
585        t,
586        msbs_zero_padding_bit_count,
587        bound_type,
588        hash: core::array::from_fn(|_| rng.gen()),
589        hash_R: core::array::from_fn(|_| rng.gen()),
590        hash_t: core::array::from_fn(|_| rng.gen()),
591        hash_w: core::array::from_fn(|_| rng.gen()),
592        hash_agg: core::array::from_fn(|_| rng.gen()),
593        hash_lmap: core::array::from_fn(|_| rng.gen()),
594        hash_phi: core::array::from_fn(|_| rng.gen()),
595        hash_xi: core::array::from_fn(|_| rng.gen()),
596        hash_z: core::array::from_fn(|_| rng.gen()),
597        hash_chi: core::array::from_fn(|_| rng.gen()),
598    }
599}
600
601/// Generates a CRS based on the Cauchy-Schwartz inequality. This removes the need of a heuristic
602/// used by GHL (see section 3.5 of the reference paper), but the bound is less strict.
603pub fn crs_gen_cs<G: Curve>(
604    d: usize,
605    k: usize,
606    B_inf: u64,
607    q: u64,
608    t: u64,
609    msbs_zero_padding_bit_count: u64,
610    rng: &mut dyn RngCore,
611) -> PublicParams<G> {
612    let bound_type = Bound::CS;
613    let alpha = G::Zp::rand(rng);
614    let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
615    let (n, D, B_bound_squared, _) =
616        compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, bound_type);
617    PublicParams {
618        g_lists: GroupElements::<G>::new(n, alpha),
619        D,
620        n,
621        d,
622        k,
623        B_bound_squared,
624        B_inf,
625        q,
626        t,
627        msbs_zero_padding_bit_count,
628        bound_type,
629        hash: core::array::from_fn(|_| rng.gen()),
630        hash_R: core::array::from_fn(|_| rng.gen()),
631        hash_t: core::array::from_fn(|_| rng.gen()),
632        hash_w: core::array::from_fn(|_| rng.gen()),
633        hash_agg: core::array::from_fn(|_| rng.gen()),
634        hash_lmap: core::array::from_fn(|_| rng.gen()),
635        hash_phi: core::array::from_fn(|_| rng.gen()),
636        hash_xi: core::array::from_fn(|_| rng.gen()),
637        hash_z: core::array::from_fn(|_| rng.gen()),
638        hash_chi: core::array::from_fn(|_| rng.gen()),
639    }
640}
641
642/// Generates a new CRS. When applied to TFHE, the parameters are mapped like this:
643/// - d: lwe_dimension
644/// - k: max_num_cleartext
645/// - B: noise_bound
646/// - q: ciphertext_modulus
647/// - t: plaintext_modulus
648pub fn crs_gen<G: Curve>(
649    d: usize,
650    k: usize,
651    B: u64,
652    q: u64,
653    t: u64,
654    msbs_zero_padding_bit_count: u64,
655    rng: &mut dyn RngCore,
656) -> PublicParams<G> {
657    crs_gen_cs(d, k, B, q, t, msbs_zero_padding_bit_count, rng)
658}
659
660#[allow(clippy::too_many_arguments)]
661pub fn commit<G: Curve>(
662    a: Vec<i64>,
663    b: Vec<i64>,
664    c1: Vec<i64>,
665    c2: Vec<i64>,
666    r: Vec<i64>,
667    e1: Vec<i64>,
668    m: Vec<i64>,
669    e2: Vec<i64>,
670    public: &PublicParams<G>,
671    rng: &mut dyn RngCore,
672) -> (PublicCommit<G>, PrivateCommit<G>) {
673    let _ = (public, rng);
674    (
675        PublicCommit {
676            a,
677            b,
678            c1,
679            c2,
680            __marker: PhantomData,
681        },
682        PrivateCommit {
683            r,
684            e1,
685            m,
686            e2,
687            __marker: PhantomData,
688        },
689    )
690}
691
692pub fn prove<G: Curve>(
693    public: (&PublicParams<G>, &PublicCommit<G>),
694    private_commit: &PrivateCommit<G>,
695    metadata: &[u8],
696    load: ComputeLoad,
697    rng: &mut dyn RngCore,
698) -> Proof<G> {
699    prove_impl(
700        public,
701        private_commit,
702        metadata,
703        load,
704        rng,
705        ProofSanityCheckMode::Panic,
706    )
707}
708
709fn prove_impl<G: Curve>(
710    public: (&PublicParams<G>, &PublicCommit<G>),
711    private_commit: &PrivateCommit<G>,
712    metadata: &[u8],
713    load: ComputeLoad,
714    rng: &mut dyn RngCore,
715    sanity_check_mode: ProofSanityCheckMode,
716) -> Proof<G> {
717    _ = load;
718    let (
719        &PublicParams {
720            ref g_lists,
721            D: D_max,
722            n,
723            d,
724            k: k_max,
725            B_bound_squared,
726            B_inf,
727            q,
728            t: t_input,
729            msbs_zero_padding_bit_count,
730            bound_type,
731            ref hash,
732            ref hash_R,
733            ref hash_t,
734            ref hash_w,
735            ref hash_agg,
736            ref hash_lmap,
737            ref hash_phi,
738            ref hash_xi,
739            ref hash_z,
740            ref hash_chi,
741        },
742        PublicCommit { a, b, c1, c2, .. },
743    ) = public;
744    let g_list = &*g_lists.g_list.0;
745    let g_hat_list = &*g_lists.g_hat_list.0;
746
747    let PrivateCommit { r, e1, m, e2, .. } = private_commit;
748
749    let k = c2.len();
750
751    let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count;
752
753    let decoded_q = decode_q(q);
754
755    // Recompute some params for our case if k is smaller than the k max
756    let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
757    let (_, D, _, m_bound) = compute_crs_params(
758        d,
759        k,
760        B_squared,
761        t_input,
762        msbs_zero_padding_bit_count,
763        bound_type,
764    );
765
766    let e_sqr_norm = e1
767        .iter()
768        .chain(e2)
769        .map(|x| sqr(x.unsigned_abs()))
770        .sum::<u128>();
771
772    if sanity_check_mode == ProofSanityCheckMode::Panic {
773        assert_pke_proof_preconditions(c1, e1, c2, e2, d, k_max, D, D_max);
774        assert!(
775            B_squared >= e_sqr_norm,
776            "squared norm of error ({e_sqr_norm}) exceeds threshold ({B_squared})",
777        );
778    }
779
780    // FIXME: div_round
781    let delta = {
782        // delta takes the encoding with the padding bit
783        (decoded_q / t_input as u128) as u64
784    };
785
786    let g = G::G1::GENERATOR;
787    let g_hat = G::G2::GENERATOR;
788    let gamma_e = G::Zp::rand(rng);
789    let gamma_hat_e = G::Zp::rand(rng);
790    let gamma_r = G::Zp::rand(rng);
791    let gamma_R = G::Zp::rand(rng);
792    let gamma_bin = G::Zp::rand(rng);
793    let gamma_y = G::Zp::rand(rng);
794
795    let r1 = compute_r1(e1, c1, a, r, d, decoded_q);
796    let r2 = compute_r2(e2, c2, m, b, r, d, delta, decoded_q);
797
798    let u64 = |x: i64| x as u64;
799
800    let w_tilde = r
801        .iter()
802        .rev()
803        .map(|&r| r != 0)
804        .chain(
805            m.iter()
806                .flat_map(|&m| bit_iter(u64(m), effective_cleartext_t.ilog2())),
807        )
808        .collect::<Box<[_]>>();
809
810    let v = four_squares(B_squared - e_sqr_norm).map(|v| v as i64);
811
812    let e1_zp = &*e1
813        .iter()
814        .copied()
815        .map(G::Zp::from_i64)
816        .collect::<Box<[_]>>();
817    let e2_zp = &*e2
818        .iter()
819        .copied()
820        .map(G::Zp::from_i64)
821        .collect::<Box<[_]>>();
822    let v_zp = v.map(G::Zp::from_i64);
823
824    let r1_zp = &*r1
825        .iter()
826        .copied()
827        .map(G::Zp::from_i64)
828        .collect::<Box<[_]>>();
829    let r2_zp = &*r2
830        .iter()
831        .copied()
832        .map(G::Zp::from_i64)
833        .collect::<Box<[_]>>();
834
835    let mut scalars = e1_zp
836        .iter()
837        .copied()
838        .chain(e2_zp.iter().copied())
839        .chain(v_zp)
840        .collect::<Box<[_]>>();
841    let C_hat_e =
842        g_hat.mul_scalar(gamma_hat_e) + G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &scalars);
843
844    let (C_e, C_r_tilde) = rayon::join(
845        || {
846            scalars.reverse();
847            g.mul_scalar(gamma_e) + G::G1::multi_mul_scalar(&g_list[n - (d + k + 4)..n], &scalars)
848        },
849        || {
850            let scalars = r1_zp
851                .iter()
852                .chain(r2_zp.iter())
853                .copied()
854                .collect::<Box<[_]>>();
855            g.mul_scalar(gamma_r) + G::G1::multi_mul_scalar(&g_list[..d + k], &scalars)
856        },
857    );
858
859    let x_bytes = &*[
860        q.to_le_bytes().as_slice(),
861        (d as u64).to_le_bytes().as_slice(),
862        B_squared.to_le_bytes().as_slice(),
863        t_input.to_le_bytes().as_slice(),
864        msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
865        &*a.iter()
866            .flat_map(|&x| x.to_le_bytes())
867            .collect::<Box<[_]>>(),
868        &*b.iter()
869            .flat_map(|&x| x.to_le_bytes())
870            .collect::<Box<[_]>>(),
871        &*c1.iter()
872            .flat_map(|&x| x.to_le_bytes())
873            .collect::<Box<[_]>>(),
874        &*c2.iter()
875            .flat_map(|&x| x.to_le_bytes())
876            .collect::<Box<[_]>>(),
877    ]
878    .iter()
879    .copied()
880    .flatten()
881    .copied()
882    .collect::<Box<[_]>>();
883
884    // make R_bar a random number generator from the given bytes
885    use sha3::digest::{ExtendableOutput, Update, XofReader};
886
887    let mut hasher = sha3::Shake256::default();
888    for &data in &[
889        hash_R,
890        metadata,
891        x_bytes,
892        C_hat_e.to_le_bytes().as_ref(),
893        C_e.to_le_bytes().as_ref(),
894        C_r_tilde.to_le_bytes().as_ref(),
895    ] {
896        hasher.update(data);
897    }
898    let mut R_bar = hasher.finalize_xof();
899    let R = (0..128 * (2 * (d + k) + 4))
900        .map(|_| {
901            let mut byte = 0u8;
902            R_bar.read(core::slice::from_mut(&mut byte));
903
904            // take two bits
905            match byte & 0b11 {
906                // probability 1/2
907                0 | 1 => 0,
908                // probability 1/4
909                2 => 1,
910                // probability 1/4
911                3 => -1,
912                _ => unreachable!(),
913            }
914        })
915        .collect::<Box<[i8]>>();
916
917    let R = |i: usize, j: usize| R[i + j * 128];
918    let R_bytes = &*(0..128)
919        .flat_map(|i| (0..(2 * (d + k) + 4)).map(move |j| R(i, j) as u8))
920        .collect::<Box<[u8]>>();
921
922    let w_R = (0..128)
923        .map(|i| {
924            let R = |j| R(i, j);
925
926            let mut acc = 0i128;
927            e1.iter()
928                .chain(e2)
929                .chain(&v)
930                .chain(&r1)
931                .chain(&r2)
932                .copied()
933                .enumerate()
934                .for_each(|(j, x)| match R(j) {
935                    0 => {}
936                    1 => acc += x as i128,
937                    -1 => acc -= x as i128,
938                    _ => unreachable!(),
939                });
940            if sanity_check_mode == ProofSanityCheckMode::Panic {
941                assert!(
942                    checked_sqr(acc.unsigned_abs()).unwrap() <= B_bound_squared,
943                    "sqr(acc) ({}) > B_bound_squared ({B_bound_squared})",
944                    checked_sqr(acc.unsigned_abs()).unwrap()
945                );
946            }
947            acc as i64
948        })
949        .collect::<Box<[_]>>();
950
951    let C_R = g.mul_scalar(gamma_R)
952        + G::G1::multi_mul_scalar(
953            &g_list[..128],
954            &w_R.iter()
955                .copied()
956                .map(G::Zp::from_i64)
957                .collect::<Box<[_]>>(),
958        );
959
960    let mut phi = vec![G::Zp::ZERO; 128];
961    G::Zp::hash(
962        &mut phi,
963        &[
964            hash_phi,
965            metadata,
966            x_bytes,
967            R_bytes,
968            C_hat_e.to_le_bytes().as_ref(),
969            C_e.to_le_bytes().as_ref(),
970            C_R.to_le_bytes().as_ref(),
971            C_r_tilde.to_le_bytes().as_ref(),
972        ],
973    );
974    let phi_bytes = &*phi
975        .iter()
976        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
977        .collect::<Box<[_]>>();
978
979    let m = m_bound;
980
981    let w_R_bin = w_R
982        .iter()
983        .flat_map(|&x| bit_iter(x as u64, m as u32))
984        .collect::<Box<[_]>>();
985    let w_bin = w_tilde
986        .iter()
987        .copied()
988        .chain(w_R_bin.iter().copied())
989        .collect::<Box<[_]>>();
990
991    let C_hat_bin = g_hat.mul_scalar(gamma_bin)
992        + g_hat_list
993            .iter()
994            .zip(&*w_bin)
995            .filter(|&(_, &w)| w)
996            .map(|(&x, _)| x)
997            .map(G::G2::projective)
998            .sum::<G::G2>();
999
1000    let mut xi = vec![G::Zp::ZERO; 128];
1001    G::Zp::hash(
1002        &mut xi,
1003        &[
1004            hash_xi,
1005            metadata,
1006            x_bytes,
1007            C_hat_e.to_le_bytes().as_ref(),
1008            C_e.to_le_bytes().as_ref(),
1009            R_bytes,
1010            phi_bytes,
1011            C_R.to_le_bytes().as_ref(),
1012            C_hat_bin.to_le_bytes().as_ref(),
1013            C_r_tilde.to_le_bytes().as_ref(),
1014        ],
1015    );
1016
1017    let xi_bytes = &*xi
1018        .iter()
1019        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1020        .collect::<Box<[_]>>();
1021
1022    let mut y = vec![G::Zp::ZERO; D + 128 * m];
1023    G::Zp::hash(
1024        &mut y,
1025        &[
1026            hash,
1027            metadata,
1028            x_bytes,
1029            R_bytes,
1030            phi_bytes,
1031            xi_bytes,
1032            C_hat_e.to_le_bytes().as_ref(),
1033            C_e.to_le_bytes().as_ref(),
1034            C_R.to_le_bytes().as_ref(),
1035            C_hat_bin.to_le_bytes().as_ref(),
1036            C_r_tilde.to_le_bytes().as_ref(),
1037        ],
1038    );
1039    let y_bytes = &*y
1040        .iter()
1041        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1042        .collect::<Box<[_]>>();
1043
1044    if sanity_check_mode == ProofSanityCheckMode::Panic {
1045        assert_eq!(y.len(), w_bin.len());
1046    }
1047    let scalars = y
1048        .iter()
1049        .zip(w_bin.iter())
1050        .rev()
1051        .map(|(&y, &w)| if w { y } else { G::Zp::ZERO })
1052        .collect::<Box<[_]>>();
1053    let C_y =
1054        g.mul_scalar(gamma_y) + G::G1::multi_mul_scalar(&g_list[n - (D + 128 * m)..n], &scalars);
1055
1056    let mut t = vec![G::Zp::ZERO; n];
1057    G::Zp::hash_128bit(
1058        &mut t,
1059        &[
1060            hash_t,
1061            metadata,
1062            x_bytes,
1063            y_bytes,
1064            phi_bytes,
1065            xi_bytes,
1066            C_hat_e.to_le_bytes().as_ref(),
1067            C_e.to_le_bytes().as_ref(),
1068            R_bytes,
1069            C_R.to_le_bytes().as_ref(),
1070            C_hat_bin.to_le_bytes().as_ref(),
1071            C_r_tilde.to_le_bytes().as_ref(),
1072            C_y.to_le_bytes().as_ref(),
1073        ],
1074    );
1075    let t_bytes = &*t
1076        .iter()
1077        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1078        .collect::<Box<[_]>>();
1079
1080    let mut theta = vec![G::Zp::ZERO; d + k];
1081    G::Zp::hash(
1082        &mut theta,
1083        &[
1084            hash_lmap,
1085            metadata,
1086            x_bytes,
1087            y_bytes,
1088            t_bytes,
1089            phi_bytes,
1090            xi_bytes,
1091            C_hat_e.to_le_bytes().as_ref(),
1092            C_e.to_le_bytes().as_ref(),
1093            R_bytes,
1094            C_R.to_le_bytes().as_ref(),
1095            C_hat_bin.to_le_bytes().as_ref(),
1096            C_r_tilde.to_le_bytes().as_ref(),
1097            C_y.to_le_bytes().as_ref(),
1098        ],
1099    );
1100    let theta_bytes = &*theta
1101        .iter()
1102        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1103        .collect::<Box<[_]>>();
1104
1105    let mut a_theta = vec![G::Zp::ZERO; D];
1106    compute_a_theta::<G>(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta);
1107
1108    let t_theta = theta
1109        .iter()
1110        .copied()
1111        .zip(c1.iter().chain(c2.iter()).copied().map(G::Zp::from_i64))
1112        .map(|(x, y)| x * y)
1113        .sum::<G::Zp>();
1114
1115    let mut w = vec![G::Zp::ZERO; n];
1116    G::Zp::hash_128bit(
1117        &mut w,
1118        &[
1119            hash_w,
1120            metadata,
1121            x_bytes,
1122            y_bytes,
1123            t_bytes,
1124            phi_bytes,
1125            xi_bytes,
1126            theta_bytes,
1127            C_hat_e.to_le_bytes().as_ref(),
1128            C_e.to_le_bytes().as_ref(),
1129            R_bytes,
1130            C_R.to_le_bytes().as_ref(),
1131            C_hat_bin.to_le_bytes().as_ref(),
1132            C_r_tilde.to_le_bytes().as_ref(),
1133            C_y.to_le_bytes().as_ref(),
1134        ],
1135    );
1136    let w_bytes = &*w
1137        .iter()
1138        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1139        .collect::<Box<[_]>>();
1140
1141    let mut delta = [G::Zp::ZERO; 7];
1142    G::Zp::hash(
1143        &mut delta,
1144        &[
1145            hash_agg,
1146            metadata,
1147            x_bytes,
1148            y_bytes,
1149            t_bytes,
1150            phi_bytes,
1151            xi_bytes,
1152            theta_bytes,
1153            w_bytes,
1154            C_hat_e.to_le_bytes().as_ref(),
1155            C_e.to_le_bytes().as_ref(),
1156            R_bytes,
1157            C_R.to_le_bytes().as_ref(),
1158            C_hat_bin.to_le_bytes().as_ref(),
1159            C_r_tilde.to_le_bytes().as_ref(),
1160            C_y.to_le_bytes().as_ref(),
1161        ],
1162    );
1163    let [delta_r, delta_dec, delta_eq, delta_y, delta_theta, delta_e, delta_l] = delta;
1164    let delta_bytes = &*delta
1165        .iter()
1166        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
1167        .collect::<Box<[_]>>();
1168
1169    let mut poly_0_lhs = vec![G::Zp::ZERO; 1 + n];
1170    let mut poly_0_rhs = vec![G::Zp::ZERO; 1 + D + 128 * m];
1171    let mut poly_1_lhs = vec![G::Zp::ZERO; 1 + n];
1172    let mut poly_1_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
1173    let mut poly_2_lhs = vec![G::Zp::ZERO; 1 + d + k];
1174    let mut poly_2_rhs = vec![G::Zp::ZERO; 1 + n];
1175    let mut poly_3_lhs = vec![G::Zp::ZERO; 1 + 128];
1176    let mut poly_3_rhs = vec![G::Zp::ZERO; 1 + n];
1177    let mut poly_4_lhs = vec![G::Zp::ZERO; 1 + n];
1178    let mut poly_4_rhs = vec![G::Zp::ZERO; 1 + d + k + 4];
1179    let mut poly_5_lhs = vec![G::Zp::ZERO; 1 + n];
1180    let mut poly_5_rhs = vec![G::Zp::ZERO; 1 + n];
1181
1182    let mut xi_scaled = xi.clone();
1183    poly_0_lhs[0] = delta_y * gamma_y;
1184    for j in 0..D + 128 * m {
1185        let p = &mut poly_0_lhs[n - j];
1186
1187        if !w_bin[j] {
1188            *p -= delta_y * y[j];
1189        }
1190
1191        if j < D {
1192            *p += delta_theta * a_theta[j];
1193        }
1194        *p += delta_eq * t[j] * y[j];
1195
1196        if j >= D {
1197            let j = j - D;
1198
1199            let xi = &mut xi_scaled[j / m];
1200            let H_xi = *xi;
1201            *xi = *xi + *xi;
1202
1203            let r = delta_dec * H_xi;
1204
1205            if j % m < m - 1 {
1206                *p += r;
1207            } else {
1208                *p -= r;
1209            }
1210        }
1211    }
1212
1213    poly_0_rhs[0] = gamma_bin;
1214    for j in 0..D + 128 * m {
1215        let p = &mut poly_0_rhs[j + 1];
1216
1217        if w_bin[j] {
1218            *p = G::Zp::ONE;
1219        }
1220    }
1221
1222    poly_1_lhs[0] = delta_l * gamma_e;
1223    for j in 0..d {
1224        let p = &mut poly_1_lhs[n - j];
1225        *p = delta_l * e1_zp[j];
1226    }
1227    for j in 0..k {
1228        let p = &mut poly_1_lhs[n - (d + j)];
1229        *p = delta_l * e2_zp[j];
1230    }
1231    for j in 0..4 {
1232        let p = &mut poly_1_lhs[n - (d + k + j)];
1233        *p = delta_l * v_zp[j];
1234    }
1235
1236    for j in 0..n {
1237        let p = &mut poly_1_lhs[n - j];
1238        let mut acc = delta_e * w[j];
1239        if j < d + k {
1240            acc += delta_theta * theta[j];
1241        }
1242
1243        if j < d + k + 4 {
1244            let mut acc2 = G::Zp::ZERO;
1245            for (i, &phi) in phi.iter().enumerate() {
1246                match R(i, j) {
1247                    0 => {}
1248                    1 => acc2 += phi,
1249                    -1 => acc2 -= phi,
1250                    _ => unreachable!(),
1251                }
1252            }
1253            acc += delta_r * acc2;
1254        }
1255        *p += acc;
1256    }
1257
1258    poly_1_rhs[0] = gamma_hat_e;
1259    for j in 0..d {
1260        let p = &mut poly_1_rhs[1 + j];
1261        *p = e1_zp[j];
1262    }
1263    for j in 0..k {
1264        let p = &mut poly_1_rhs[1 + (d + j)];
1265        *p = e2_zp[j];
1266    }
1267    for j in 0..4 {
1268        let p = &mut poly_1_rhs[1 + (d + k + j)];
1269        *p = v_zp[j];
1270    }
1271
1272    poly_2_lhs[0] = gamma_r;
1273    for j in 0..d {
1274        let p = &mut poly_2_lhs[1 + j];
1275        *p = r1_zp[j];
1276    }
1277    for j in 0..k {
1278        let p = &mut poly_2_lhs[1 + (d + j)];
1279        *p = r2_zp[j];
1280    }
1281
1282    let delta_theta_q = delta_theta * G::Zp::from_u128(decoded_q);
1283    for j in 0..d + k {
1284        let p = &mut poly_2_rhs[n - j];
1285
1286        let mut acc = G::Zp::ZERO;
1287        for (i, &phi) in phi.iter().enumerate() {
1288            match R(i, d + k + 4 + j) {
1289                0 => {}
1290                1 => acc += phi,
1291                -1 => acc -= phi,
1292                _ => unreachable!(),
1293            }
1294        }
1295        *p = delta_r * acc - delta_theta_q * theta[j];
1296    }
1297
1298    poly_3_lhs[0] = gamma_R;
1299    for j in 0..128 {
1300        let p = &mut poly_3_lhs[1 + j];
1301        *p = G::Zp::from_i64(w_R[j]);
1302    }
1303
1304    for j in 0..128 {
1305        let p = &mut poly_3_rhs[n - j];
1306        *p = delta_r * phi[j] + delta_dec * xi[j];
1307    }
1308
1309    poly_4_lhs[0] = delta_e * gamma_e;
1310    for j in 0..d {
1311        let p = &mut poly_4_lhs[n - j];
1312        *p = delta_e * e1_zp[j];
1313    }
1314    for j in 0..k {
1315        let p = &mut poly_4_lhs[n - (d + j)];
1316        *p = delta_e * e2_zp[j];
1317    }
1318    for j in 0..4 {
1319        let p = &mut poly_4_lhs[n - (d + k + j)];
1320        *p = delta_e * v_zp[j];
1321    }
1322
1323    for j in 0..d + k + 4 {
1324        let p = &mut poly_4_rhs[1 + j];
1325        *p = w[j];
1326    }
1327
1328    poly_5_lhs[0] = delta_eq * gamma_y;
1329    for j in 0..D + 128 * m {
1330        let p = &mut poly_5_lhs[n - j];
1331
1332        if w_bin[j] {
1333            *p = delta_eq * y[j];
1334        }
1335    }
1336
1337    for j in 0..n {
1338        let p = &mut poly_5_rhs[1 + j];
1339        *p = t[j];
1340    }
1341
1342    let poly = [
1343        (&poly_0_lhs, &poly_0_rhs),
1344        (&poly_1_lhs, &poly_1_rhs),
1345        (&poly_2_lhs, &poly_2_rhs),
1346        (&poly_3_lhs, &poly_3_rhs),
1347        (&poly_4_lhs, &poly_4_rhs),
1348        (&poly_5_lhs, &poly_5_rhs),
1349    ];
1350
1351    let [mut poly_0, poly_1, poly_2, poly_3, poly_4, poly_5] = {
1352        let tmp: Box<[Vec<G::Zp>; 6]> = poly
1353            .into_par_iter()
1354            .map(|(lhs, rhs)| G::Zp::poly_mul(lhs, rhs))
1355            .collect::<Box<[_]>>()
1356            .try_into()
1357            .unwrap();
1358        *tmp
1359    };
1360
1361    let len = [
1362        poly_0.len(),
1363        poly_1.len(),
1364        poly_2.len(),
1365        poly_3.len(),
1366        poly_4.len(),
1367        poly_5.len(),
1368    ]
1369    .into_iter()
1370    .max()
1371    .unwrap();
1372
1373    poly_0.resize(len, G::Zp::ZERO);
1374
1375    {
1376        let chunk_size = len.div_ceil(rayon::current_num_threads());
1377
1378        poly_0
1379            .par_chunks_mut(chunk_size)
1380            .enumerate()
1381            .for_each(|(j, p0)| {
1382                let offset = j * chunk_size;
1383                let p1 = poly_1.get(offset..).unwrap_or(&[]);
1384                let p2 = poly_2.get(offset..).unwrap_or(&[]);
1385                let p3 = poly_3.get(offset..).unwrap_or(&[]);
1386                let p4 = poly_4.get(offset..).unwrap_or(&[]);
1387                let p5 = poly_5.get(offset..).unwrap_or(&[]);
1388
1389                for (j, p0) in p0.iter_mut().enumerate() {
1390                    if j < p1.len() {
1391                        *p0 += p1[j];
1392                    }
1393                    if j < p2.len() {
1394                        *p0 += p2[j];
1395                    }
1396                    if j < p3.len() {
1397                        *p0 -= p3[j];
1398                    }
1399                    if j < p4.len() {
1400                        *p0 -= p4[j];
1401                    }
1402                    if j < p5.len() {
1403                        *p0 -= p5[j];
1404                    }
1405                }
1406            });
1407    }
1408    let mut P_pi = poly_0;
1409    if P_pi.len() > n + 1 {
1410        P_pi[n + 1] -= delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared);
1411    }
1412
1413    let pi = if P_pi.is_empty() {
1414        G::G1::ZERO
1415    } else {
1416        g.mul_scalar(P_pi[0]) + G::G1::multi_mul_scalar(&g_list[..P_pi.len() - 1], &P_pi[1..])
1417    };
1418
1419    let mut xi_scaled = xi.clone();
1420    let mut scalars = (0..D + 128 * m)
1421        .map(|j| {
1422            let mut acc = G::Zp::ZERO;
1423            if j < D {
1424                acc += delta_theta * a_theta[j];
1425            }
1426            acc -= delta_y * y[j];
1427            acc += delta_eq * t[j] * y[j];
1428
1429            if j >= D {
1430                let j = j - D;
1431                let xi = &mut xi_scaled[j / m];
1432                let H_xi = *xi;
1433                *xi = *xi + *xi;
1434
1435                let r = delta_dec * H_xi;
1436
1437                if j % m < m - 1 {
1438                    acc += r;
1439                } else {
1440                    acc -= r;
1441                }
1442            }
1443
1444            acc
1445        })
1446        .collect::<Box<[_]>>();
1447    scalars.reverse();
1448    let C_h1 = G::G1::multi_mul_scalar(&g_list[n - (D + 128 * m)..n], &scalars);
1449
1450    let mut scalars = (0..n)
1451        .map(|j| {
1452            let mut acc = G::Zp::ZERO;
1453            if j < d + k {
1454                acc += delta_theta * theta[j];
1455            }
1456
1457            acc += delta_e * w[j];
1458
1459            if j < d + k + 4 {
1460                let mut acc2 = G::Zp::ZERO;
1461                for (i, &phi) in phi.iter().enumerate() {
1462                    match R(i, j) {
1463                        0 => {}
1464                        1 => acc2 += phi,
1465                        -1 => acc2 -= phi,
1466                        _ => unreachable!(),
1467                    }
1468                }
1469                acc += delta_r * acc2;
1470            }
1471            acc
1472        })
1473        .collect::<Box<[_]>>();
1474    scalars.reverse();
1475    let C_h2 = G::G1::multi_mul_scalar(&g_list[..n], &scalars);
1476    let compute_load_proof_fields = match load {
1477        ComputeLoad::Proof => {
1478            let (C_hat_h3, C_hat_w) = rayon::join(
1479                || {
1480                    G::G2::multi_mul_scalar(
1481                        &g_hat_list[n - (d + k)..n],
1482                        &(0..d + k)
1483                            .rev()
1484                            .map(|j| {
1485                                let mut acc = G::Zp::ZERO;
1486                                for (i, &phi) in phi.iter().enumerate() {
1487                                    match R(i, d + k + 4 + j) {
1488                                        0 => {}
1489                                        1 => acc += phi,
1490                                        -1 => acc -= phi,
1491                                        _ => unreachable!(),
1492                                    }
1493                                }
1494                                delta_r * acc - delta_theta_q * theta[j]
1495                            })
1496                            .collect::<Box<[_]>>(),
1497                    )
1498                },
1499                || G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &w[..d + k + 4]),
1500            );
1501
1502            Some(ComputeLoadProofFields { C_hat_h3, C_hat_w })
1503        }
1504        ComputeLoad::Verify => None,
1505    };
1506
1507    let byte_generators =
1508        if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = compute_load_proof_fields {
1509            Some((G::G2::to_le_bytes(C_hat_h3), G::G2::to_le_bytes(C_hat_w)))
1510        } else {
1511            None
1512        };
1513
1514    let (C_hat_h3_bytes, C_hat_w_bytes): (&[u8], &[u8]) =
1515        if let Some((C_hat_h3_bytes_owner, C_hat_w_bytes_owner)) = byte_generators.as_ref() {
1516            (C_hat_h3_bytes_owner.as_ref(), C_hat_w_bytes_owner.as_ref())
1517        } else {
1518            (&[], &[])
1519        };
1520
1521    let C_hat_t = G::G2::multi_mul_scalar(g_hat_list, &t);
1522
1523    let mut z = G::Zp::ZERO;
1524    G::Zp::hash(
1525        core::slice::from_mut(&mut z),
1526        &[
1527            hash_z,
1528            metadata,
1529            x_bytes,
1530            y_bytes,
1531            t_bytes,
1532            phi_bytes,
1533            x_bytes,
1534            theta_bytes,
1535            delta_bytes,
1536            C_hat_e.to_le_bytes().as_ref(),
1537            C_e.to_le_bytes().as_ref(),
1538            R_bytes,
1539            C_R.to_le_bytes().as_ref(),
1540            C_hat_bin.to_le_bytes().as_ref(),
1541            C_r_tilde.to_le_bytes().as_ref(),
1542            C_y.to_le_bytes().as_ref(),
1543            C_h1.to_le_bytes().as_ref(),
1544            C_h2.to_le_bytes().as_ref(),
1545            C_hat_t.to_le_bytes().as_ref(),
1546            C_hat_h3_bytes,
1547            C_hat_w_bytes,
1548        ],
1549    );
1550
1551    let mut P_h1 = vec![G::Zp::ZERO; 1 + n];
1552    let mut P_h2 = vec![G::Zp::ZERO; 1 + n];
1553    let mut P_t = vec![G::Zp::ZERO; 1 + n];
1554    let mut P_h3 = match load {
1555        ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n],
1556        ComputeLoad::Verify => vec![],
1557    };
1558    let mut P_w = match load {
1559        ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4],
1560        ComputeLoad::Verify => vec![],
1561    };
1562
1563    let mut xi_scaled = xi.clone();
1564    for j in 0..D + 128 * m {
1565        let p = &mut P_h1[n - j];
1566        if j < D {
1567            *p += delta_theta * a_theta[j];
1568        }
1569        *p -= delta_y * y[j];
1570        *p += delta_eq * t[j] * y[j];
1571
1572        if j >= D {
1573            let j = j - D;
1574            let xi = &mut xi_scaled[j / m];
1575            let H_xi = *xi;
1576            *xi = *xi + *xi;
1577
1578            let r = delta_dec * H_xi;
1579
1580            if j % m < m - 1 {
1581                *p += r;
1582            } else {
1583                *p -= r;
1584            }
1585        }
1586    }
1587
1588    for j in 0..n {
1589        let p = &mut P_h2[n - j];
1590
1591        if j < d + k {
1592            *p += delta_theta * theta[j];
1593        }
1594
1595        *p += delta_e * w[j];
1596
1597        if j < d + k + 4 {
1598            let mut acc = G::Zp::ZERO;
1599            for (i, &phi) in phi.iter().enumerate() {
1600                match R(i, j) {
1601                    0 => {}
1602                    1 => acc += phi,
1603                    -1 => acc -= phi,
1604                    _ => unreachable!(),
1605                }
1606            }
1607            *p += delta_r * acc;
1608        }
1609    }
1610
1611    P_t[1..].copy_from_slice(&t);
1612
1613    if !P_h3.is_empty() {
1614        for j in 0..d + k {
1615            let p = &mut P_h3[n - j];
1616
1617            let mut acc = G::Zp::ZERO;
1618            for (i, &phi) in phi.iter().enumerate() {
1619                match R(i, d + k + 4 + j) {
1620                    0 => {}
1621                    1 => acc += phi,
1622                    -1 => acc -= phi,
1623                    _ => unreachable!(),
1624                }
1625            }
1626            *p = delta_r * acc - delta_theta_q * theta[j];
1627        }
1628    }
1629
1630    if !P_w.is_empty() {
1631        P_w[1..].copy_from_slice(&w[..d + k + 4]);
1632    }
1633
1634    let mut p_h1 = G::Zp::ZERO;
1635    let mut p_h2 = G::Zp::ZERO;
1636    let mut p_t = G::Zp::ZERO;
1637    let mut p_h3 = G::Zp::ZERO;
1638    let mut p_w = G::Zp::ZERO;
1639
1640    let mut pow = G::Zp::ONE;
1641    for j in 0..n + 1 {
1642        p_h1 += P_h1[j] * pow;
1643        p_h2 += P_h2[j] * pow;
1644        p_t += P_t[j] * pow;
1645
1646        if j < P_h3.len() {
1647            p_h3 += P_h3[j] * pow;
1648        }
1649        if j < P_w.len() {
1650            p_w += P_w[j] * pow;
1651        }
1652
1653        pow = pow * z;
1654    }
1655
1656    let mut chi = G::Zp::ZERO;
1657    G::Zp::hash(
1658        core::slice::from_mut(&mut chi),
1659        &[
1660            hash_chi,
1661            metadata,
1662            x_bytes,
1663            y_bytes,
1664            t_bytes,
1665            phi_bytes,
1666            xi_bytes,
1667            theta_bytes,
1668            delta_bytes,
1669            C_hat_e.to_le_bytes().as_ref(),
1670            C_e.to_le_bytes().as_ref(),
1671            R_bytes,
1672            C_R.to_le_bytes().as_ref(),
1673            C_hat_bin.to_le_bytes().as_ref(),
1674            C_r_tilde.to_le_bytes().as_ref(),
1675            C_y.to_le_bytes().as_ref(),
1676            C_h1.to_le_bytes().as_ref(),
1677            C_h2.to_le_bytes().as_ref(),
1678            C_hat_t.to_le_bytes().as_ref(),
1679            C_hat_h3_bytes,
1680            C_hat_w_bytes,
1681            z.to_le_bytes().as_ref(),
1682            p_h1.to_le_bytes().as_ref(),
1683            p_h2.to_le_bytes().as_ref(),
1684            p_t.to_le_bytes().as_ref(),
1685        ],
1686    );
1687
1688    let mut Q_kzg = vec![G::Zp::ZERO; 1 + n];
1689    let chi2 = chi * chi;
1690    let chi3 = chi2 * chi;
1691    let chi4 = chi3 * chi;
1692    for j in 1..n + 1 {
1693        Q_kzg[j] = P_h1[j] + chi * P_h2[j] + chi2 * P_t[j];
1694        if j < P_h3.len() {
1695            Q_kzg[j] += chi3 * P_h3[j];
1696        }
1697        if j < P_w.len() {
1698            Q_kzg[j] += chi4 * P_w[j];
1699        }
1700    }
1701    Q_kzg[0] -= p_h1 + chi * p_h2 + chi2 * p_t + chi3 * p_h3 + chi4 * p_w;
1702
1703    // https://en.wikipedia.org/wiki/Polynomial_long_division#Pseudocode
1704    let mut q = vec![G::Zp::ZERO; n];
1705    for j in (0..n).rev() {
1706        Q_kzg[j] = Q_kzg[j] + z * Q_kzg[j + 1];
1707        q[j] = Q_kzg[j + 1];
1708        Q_kzg[j + 1] = G::Zp::ZERO;
1709    }
1710
1711    let pi_kzg = g.mul_scalar(q[0]) + G::G1::multi_mul_scalar(&g_list[..n - 1], &q[1..n]);
1712
1713    Proof {
1714        C_hat_e,
1715        C_e,
1716        C_r_tilde,
1717        C_R,
1718        C_hat_bin,
1719        C_y,
1720        C_h1,
1721        C_h2,
1722        C_hat_t,
1723        pi,
1724        pi_kzg,
1725        compute_load_proof_fields,
1726    }
1727}
1728
1729#[allow(clippy::too_many_arguments)]
1730fn compute_a_theta<G: Curve>(
1731    a_theta: &mut [G::Zp],
1732    theta: &[G::Zp],
1733    a: &[i64],
1734    k: usize,
1735    b: &[i64],
1736    t: u64,
1737    delta: u64,
1738) {
1739    // a_theta = Ã.T theta
1740    //  = [
1741    //    rot(a).T theta1 + phi[d](bar(b)) theta2_1 + ... + phi[d-k+1](bar(b)) theta2_k
1742    //
1743    //    delta g[log t].T theta2_1
1744    //    delta g[log t].T theta2_2
1745    //    ...
1746    //    delta g[log t].T theta2_k
1747    //    ]
1748
1749    let d = a.len();
1750
1751    let theta1 = &theta[..d];
1752    let theta2 = &theta[d..];
1753
1754    {
1755        // rewrite rot(a).T theta1 and rot(b).T theta2.rev() as negacyclic polynomial multiplication
1756        let a_theta = &mut a_theta[..d];
1757
1758        let mut a_rev = vec![G::Zp::ZERO; d].into_boxed_slice();
1759        a_rev[0] = G::Zp::from_i64(a[0]);
1760        for i in 1..d {
1761            a_rev[i] = -G::Zp::from_i64(a[d - i]);
1762        }
1763
1764        let mut b_rev = vec![G::Zp::ZERO; d].into_boxed_slice();
1765        b_rev[0] = G::Zp::from_i64(b[0]);
1766        for i in 1..d {
1767            b_rev[i] = -G::Zp::from_i64(b[d - i]);
1768        }
1769
1770        let theta2_rev = &*(0..d - k)
1771            .map(|_| G::Zp::ZERO)
1772            .chain(theta2.iter().copied().rev())
1773            .collect::<Box<[_]>>();
1774
1775        // compute full poly mul
1776        let (a_rev_theta1, b_rev_theta2_rev) = rayon::join(
1777            || G::Zp::poly_mul(&a_rev, theta1),
1778            || G::Zp::poly_mul(&b_rev, theta2_rev),
1779        );
1780
1781        // make it negacyclic
1782        let min = usize::min(a_theta.len(), a_rev_theta1.len());
1783        a_theta[..min].copy_from_slice(&a_rev_theta1[..min]);
1784
1785        let len = a_theta.len();
1786        let chunk_size = len.div_ceil(rayon::current_num_threads());
1787        a_theta
1788            .par_chunks_mut(chunk_size)
1789            .enumerate()
1790            .for_each(|(j, a_theta)| {
1791                let offset = j * chunk_size;
1792                let a_rev_theta1 = a_rev_theta1.get(offset..).unwrap_or(&[]);
1793                let b_rev_theta2_rev = b_rev_theta2_rev.get(offset..).unwrap_or(&[]);
1794
1795                for (j, a_theta) in a_theta.iter_mut().enumerate() {
1796                    if j + d < a_rev_theta1.len() {
1797                        *a_theta -= a_rev_theta1[j + d];
1798                    }
1799                    if j < b_rev_theta2_rev.len() {
1800                        *a_theta += b_rev_theta2_rev[j];
1801                    }
1802                    if j + d < b_rev_theta2_rev.len() {
1803                        *a_theta -= b_rev_theta2_rev[j + d];
1804                    }
1805                }
1806            });
1807    }
1808
1809    {
1810        let a_theta = &mut a_theta[d..];
1811        let delta = G::Zp::from_u64(delta);
1812        let step = t.ilog2() as usize;
1813
1814        a_theta
1815            .par_chunks_exact_mut(step)
1816            .zip_eq(theta2)
1817            .for_each(|(a_theta, &theta)| {
1818                let mut theta = delta * theta;
1819                let mut first = true;
1820                for a_theta in a_theta {
1821                    if !first {
1822                        theta = theta + theta;
1823                    }
1824                    first = false;
1825                    *a_theta = theta;
1826                }
1827            });
1828    }
1829}
1830
1831#[allow(clippy::result_unit_err)]
1832pub fn verify<G: Curve>(
1833    proof: &Proof<G>,
1834    public: (&PublicParams<G>, &PublicCommit<G>),
1835    metadata: &[u8],
1836) -> Result<(), ()> {
1837    let &Proof {
1838        C_hat_e,
1839        C_e,
1840        C_r_tilde,
1841        C_R,
1842        C_hat_bin,
1843        C_y,
1844        C_h1,
1845        C_h2,
1846        C_hat_t,
1847        pi,
1848        pi_kzg,
1849        ref compute_load_proof_fields,
1850    } = proof;
1851
1852    let pairing = G::Gt::pairing;
1853
1854    let &PublicParams {
1855        ref g_lists,
1856        D: D_max,
1857        n,
1858        d,
1859        k: k_max,
1860        B_bound_squared: _,
1861        B_inf,
1862        q,
1863        t: t_input,
1864        msbs_zero_padding_bit_count,
1865        bound_type,
1866        ref hash,
1867        ref hash_R,
1868        ref hash_t,
1869        ref hash_w,
1870        ref hash_agg,
1871        ref hash_lmap,
1872        ref hash_phi,
1873        ref hash_xi,
1874        ref hash_z,
1875        ref hash_chi,
1876    } = public.0;
1877    let g_list = &*g_lists.g_list.0;
1878    let g_hat_list = &*g_lists.g_hat_list.0;
1879
1880    let decoded_q = decode_q(q);
1881
1882    // FIXME: div_round
1883    let delta = {
1884        // delta takes the encoding with the padding bit
1885        (decoded_q / t_input as u128) as u64
1886    };
1887
1888    let PublicCommit { a, b, c1, c2, .. } = public.1;
1889    let k = c2.len();
1890    if k > k_max {
1891        return Err(());
1892    }
1893
1894    let effective_cleartext_t = t_input >> msbs_zero_padding_bit_count;
1895    let B_squared = inf_norm_bound_to_euclidean_squared(B_inf, d + k);
1896    let (_, D, _, m_bound) = compute_crs_params(
1897        d,
1898        k,
1899        B_squared,
1900        t_input,
1901        msbs_zero_padding_bit_count,
1902        bound_type,
1903    );
1904
1905    let m = m_bound;
1906
1907    if D > D_max {
1908        return Err(());
1909    }
1910
1911    let byte_generators = if let Some(&ComputeLoadProofFields { C_hat_h3, C_hat_w }) =
1912        compute_load_proof_fields.as_ref()
1913    {
1914        Some((G::G2::to_le_bytes(C_hat_h3), G::G2::to_le_bytes(C_hat_w)))
1915    } else {
1916        None
1917    };
1918
1919    let (C_hat_h3_bytes, C_hat_w_bytes): (&[u8], &[u8]) =
1920        if let Some((C_hat_h3_bytes_owner, C_hat_w_bytes_owner)) = byte_generators.as_ref() {
1921            (C_hat_h3_bytes_owner.as_ref(), C_hat_w_bytes_owner.as_ref())
1922        } else {
1923            (&[], &[])
1924        };
1925
1926    let x_bytes = &*[
1927        q.to_le_bytes().as_slice(),
1928        (d as u64).to_le_bytes().as_slice(),
1929        B_squared.to_le_bytes().as_slice(),
1930        t_input.to_le_bytes().as_slice(),
1931        msbs_zero_padding_bit_count.to_le_bytes().as_slice(),
1932        &*a.iter()
1933            .flat_map(|&x| x.to_le_bytes())
1934            .collect::<Box<[_]>>(),
1935        &*b.iter()
1936            .flat_map(|&x| x.to_le_bytes())
1937            .collect::<Box<[_]>>(),
1938        &*c1.iter()
1939            .flat_map(|&x| x.to_le_bytes())
1940            .collect::<Box<[_]>>(),
1941        &*c2.iter()
1942            .flat_map(|&x| x.to_le_bytes())
1943            .collect::<Box<[_]>>(),
1944    ]
1945    .iter()
1946    .copied()
1947    .flatten()
1948    .copied()
1949    .collect::<Box<[_]>>();
1950
1951    // make R_bar a random number generator from the given bytes
1952    use sha3::digest::{ExtendableOutput, Update, XofReader};
1953
1954    let mut hasher = sha3::Shake256::default();
1955    for &data in &[
1956        hash_R,
1957        metadata,
1958        x_bytes,
1959        C_hat_e.to_le_bytes().as_ref(),
1960        C_e.to_le_bytes().as_ref(),
1961        C_r_tilde.to_le_bytes().as_ref(),
1962    ] {
1963        hasher.update(data);
1964    }
1965    let mut R_bar = hasher.finalize_xof();
1966    let R = (0..128 * (2 * (d + k) + 4))
1967        .map(|_| {
1968            let mut byte = 0u8;
1969            R_bar.read(core::slice::from_mut(&mut byte));
1970
1971            // take two bits
1972            match byte & 0b11 {
1973                // probability 1/2
1974                0 | 1 => 0,
1975                // probability 1/4
1976                2 => 1,
1977                // probability 1/4
1978                3 => -1,
1979                _ => unreachable!(),
1980            }
1981        })
1982        .collect::<Box<[i8]>>();
1983
1984    let R = |i: usize, j: usize| R[i + j * 128];
1985    let R_bytes = &*(0..128)
1986        .flat_map(|i| (0..(2 * (d + k) + 4)).map(move |j| R(i, j) as u8))
1987        .collect::<Box<[u8]>>();
1988
1989    let mut phi = vec![G::Zp::ZERO; 128];
1990    G::Zp::hash(
1991        &mut phi,
1992        &[
1993            hash_phi,
1994            metadata,
1995            x_bytes,
1996            R_bytes,
1997            C_hat_e.to_le_bytes().as_ref(),
1998            C_e.to_le_bytes().as_ref(),
1999            C_R.to_le_bytes().as_ref(),
2000            C_r_tilde.to_le_bytes().as_ref(),
2001        ],
2002    );
2003    let phi_bytes = &*phi
2004        .iter()
2005        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2006        .collect::<Box<[_]>>();
2007
2008    let mut xi = vec![G::Zp::ZERO; 128];
2009    G::Zp::hash(
2010        &mut xi,
2011        &[
2012            hash_xi,
2013            metadata,
2014            x_bytes,
2015            C_hat_e.to_le_bytes().as_ref(),
2016            C_e.to_le_bytes().as_ref(),
2017            R_bytes,
2018            phi_bytes,
2019            C_R.to_le_bytes().as_ref(),
2020            C_hat_bin.to_le_bytes().as_ref(),
2021            C_r_tilde.to_le_bytes().as_ref(),
2022        ],
2023    );
2024    let xi_bytes = &*xi
2025        .iter()
2026        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2027        .collect::<Box<[_]>>();
2028
2029    let mut y = vec![G::Zp::ZERO; D + 128 * m];
2030    G::Zp::hash(
2031        &mut y,
2032        &[
2033            hash,
2034            metadata,
2035            x_bytes,
2036            R_bytes,
2037            phi_bytes,
2038            xi_bytes,
2039            C_hat_e.to_le_bytes().as_ref(),
2040            C_e.to_le_bytes().as_ref(),
2041            C_R.to_le_bytes().as_ref(),
2042            C_hat_bin.to_le_bytes().as_ref(),
2043            C_r_tilde.to_le_bytes().as_ref(),
2044        ],
2045    );
2046    let y_bytes = &*y
2047        .iter()
2048        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2049        .collect::<Box<[_]>>();
2050
2051    let mut t = vec![G::Zp::ZERO; n];
2052    G::Zp::hash_128bit(
2053        &mut t,
2054        &[
2055            hash_t,
2056            metadata,
2057            x_bytes,
2058            y_bytes,
2059            phi_bytes,
2060            xi_bytes,
2061            C_hat_e.to_le_bytes().as_ref(),
2062            C_e.to_le_bytes().as_ref(),
2063            R_bytes,
2064            C_R.to_le_bytes().as_ref(),
2065            C_hat_bin.to_le_bytes().as_ref(),
2066            C_r_tilde.to_le_bytes().as_ref(),
2067            C_y.to_le_bytes().as_ref(),
2068        ],
2069    );
2070    let t_bytes = &*t
2071        .iter()
2072        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2073        .collect::<Box<[_]>>();
2074
2075    let mut theta = vec![G::Zp::ZERO; d + k];
2076    G::Zp::hash(
2077        &mut theta,
2078        &[
2079            hash_lmap,
2080            metadata,
2081            x_bytes,
2082            y_bytes,
2083            t_bytes,
2084            phi_bytes,
2085            xi_bytes,
2086            C_hat_e.to_le_bytes().as_ref(),
2087            C_e.to_le_bytes().as_ref(),
2088            R_bytes,
2089            C_R.to_le_bytes().as_ref(),
2090            C_hat_bin.to_le_bytes().as_ref(),
2091            C_r_tilde.to_le_bytes().as_ref(),
2092            C_y.to_le_bytes().as_ref(),
2093        ],
2094    );
2095    let theta_bytes = &*theta
2096        .iter()
2097        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2098        .collect::<Box<[_]>>();
2099
2100    let mut w = vec![G::Zp::ZERO; n];
2101    G::Zp::hash_128bit(
2102        &mut w,
2103        &[
2104            hash_w,
2105            metadata,
2106            x_bytes,
2107            y_bytes,
2108            t_bytes,
2109            phi_bytes,
2110            xi_bytes,
2111            theta_bytes,
2112            C_hat_e.to_le_bytes().as_ref(),
2113            C_e.to_le_bytes().as_ref(),
2114            R_bytes,
2115            C_R.to_le_bytes().as_ref(),
2116            C_hat_bin.to_le_bytes().as_ref(),
2117            C_r_tilde.to_le_bytes().as_ref(),
2118            C_y.to_le_bytes().as_ref(),
2119        ],
2120    );
2121    let w_bytes = &*w
2122        .iter()
2123        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2124        .collect::<Box<[_]>>();
2125
2126    let mut a_theta = vec![G::Zp::ZERO; D];
2127    compute_a_theta::<G>(&mut a_theta, &theta, a, k, b, effective_cleartext_t, delta);
2128
2129    let t_theta = theta
2130        .iter()
2131        .copied()
2132        .zip(c1.iter().chain(c2.iter()).copied().map(G::Zp::from_i64))
2133        .map(|(x, y)| x * y)
2134        .sum::<G::Zp>();
2135
2136    let mut delta = [G::Zp::ZERO; 7];
2137    G::Zp::hash(
2138        &mut delta,
2139        &[
2140            hash_agg,
2141            metadata,
2142            x_bytes,
2143            y_bytes,
2144            t_bytes,
2145            phi_bytes,
2146            xi_bytes,
2147            theta_bytes,
2148            w_bytes,
2149            C_hat_e.to_le_bytes().as_ref(),
2150            C_e.to_le_bytes().as_ref(),
2151            R_bytes,
2152            C_R.to_le_bytes().as_ref(),
2153            C_hat_bin.to_le_bytes().as_ref(),
2154            C_r_tilde.to_le_bytes().as_ref(),
2155            C_y.to_le_bytes().as_ref(),
2156        ],
2157    );
2158    let [delta_r, delta_dec, delta_eq, delta_y, delta_theta, delta_e, delta_l] = delta;
2159    let delta_bytes = &*delta
2160        .iter()
2161        .flat_map(|x| x.to_le_bytes().as_ref().to_vec())
2162        .collect::<Box<[_]>>();
2163
2164    let g = G::G1::GENERATOR;
2165    let g_hat = G::G2::GENERATOR;
2166
2167    let delta_theta_q = delta_theta * G::Zp::from_u128(decoded_q);
2168
2169    let rhs = pairing(pi, g_hat);
2170    let lhs = {
2171        let lhs0 = pairing(C_y.mul_scalar(delta_y) + C_h1, C_hat_bin);
2172        let lhs1 = pairing(C_e.mul_scalar(delta_l) + C_h2, C_hat_e);
2173
2174        let lhs2 = pairing(
2175            C_r_tilde,
2176            match compute_load_proof_fields.as_ref() {
2177                Some(&ComputeLoadProofFields {
2178                    C_hat_h3,
2179                    C_hat_w: _,
2180                }) => C_hat_h3,
2181                None => G::G2::multi_mul_scalar(
2182                    &g_hat_list[n - (d + k)..n],
2183                    &(0..d + k)
2184                        .rev()
2185                        .map(|j| {
2186                            let mut acc = G::Zp::ZERO;
2187                            for (i, &phi) in phi.iter().enumerate() {
2188                                match R(i, d + k + 4 + j) {
2189                                    0 => {}
2190                                    1 => acc += phi,
2191                                    -1 => acc -= phi,
2192                                    _ => unreachable!(),
2193                                }
2194                            }
2195                            delta_r * acc - delta_theta_q * theta[j]
2196                        })
2197                        .collect::<Box<[_]>>(),
2198                ),
2199            },
2200        );
2201        let lhs3 = pairing(
2202            C_R,
2203            G::G2::multi_mul_scalar(
2204                &g_hat_list[n - 128..n],
2205                &(0..128)
2206                    .rev()
2207                    .map(|j| delta_r * phi[j] + delta_dec * xi[j])
2208                    .collect::<Box<[_]>>(),
2209            ),
2210        );
2211        let lhs4 = pairing(
2212            C_e.mul_scalar(delta_e),
2213            match compute_load_proof_fields.as_ref() {
2214                Some(&ComputeLoadProofFields {
2215                    C_hat_h3: _,
2216                    C_hat_w,
2217                }) => C_hat_w,
2218                None => G::G2::multi_mul_scalar(&g_hat_list[..d + k + 4], &w[..d + k + 4]),
2219            },
2220        );
2221        let lhs5 = pairing(C_y.mul_scalar(delta_eq), C_hat_t);
2222        let lhs6 = pairing(
2223            G::G1::projective(g_list[0]),
2224            G::G2::projective(g_hat_list[n - 1]),
2225        )
2226        .mul_scalar(delta_theta * t_theta + delta_l * G::Zp::from_u128(B_squared));
2227
2228        lhs0 + lhs1 + lhs2 - lhs3 - lhs4 - lhs5 - lhs6
2229    };
2230
2231    if lhs != rhs {
2232        return Err(());
2233    }
2234
2235    let mut z = G::Zp::ZERO;
2236    G::Zp::hash(
2237        core::slice::from_mut(&mut z),
2238        &[
2239            hash_z,
2240            metadata,
2241            x_bytes,
2242            y_bytes,
2243            t_bytes,
2244            phi_bytes,
2245            x_bytes,
2246            theta_bytes,
2247            delta_bytes,
2248            C_hat_e.to_le_bytes().as_ref(),
2249            C_e.to_le_bytes().as_ref(),
2250            R_bytes,
2251            C_R.to_le_bytes().as_ref(),
2252            C_hat_bin.to_le_bytes().as_ref(),
2253            C_r_tilde.to_le_bytes().as_ref(),
2254            C_y.to_le_bytes().as_ref(),
2255            C_h1.to_le_bytes().as_ref(),
2256            C_h2.to_le_bytes().as_ref(),
2257            C_hat_t.to_le_bytes().as_ref(),
2258            C_hat_h3_bytes,
2259            C_hat_w_bytes,
2260        ],
2261    );
2262
2263    let load = if compute_load_proof_fields.is_some() {
2264        ComputeLoad::Proof
2265    } else {
2266        ComputeLoad::Verify
2267    };
2268
2269    let mut P_h1 = vec![G::Zp::ZERO; 1 + n];
2270    let mut P_h2 = vec![G::Zp::ZERO; 1 + n];
2271    let mut P_t = vec![G::Zp::ZERO; 1 + n];
2272    let mut P_h3 = match load {
2273        ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + n],
2274        ComputeLoad::Verify => vec![],
2275    };
2276    let mut P_w = match load {
2277        ComputeLoad::Proof => vec![G::Zp::ZERO; 1 + d + k + 4],
2278        ComputeLoad::Verify => vec![],
2279    };
2280
2281    let mut xi_scaled = xi.clone();
2282    for j in 0..D + 128 * m {
2283        let p = &mut P_h1[n - j];
2284        if j < D {
2285            *p += delta_theta * a_theta[j];
2286        }
2287        *p -= delta_y * y[j];
2288        *p += delta_eq * t[j] * y[j];
2289
2290        if j >= D {
2291            let j = j - D;
2292            let xi = &mut xi_scaled[j / m];
2293            let H_xi = *xi;
2294            *xi = *xi + *xi;
2295
2296            let r = delta_dec * H_xi;
2297
2298            if j % m < m - 1 {
2299                *p += r;
2300            } else {
2301                *p -= r;
2302            }
2303        }
2304    }
2305
2306    for j in 0..n {
2307        let p = &mut P_h2[n - j];
2308
2309        if j < d + k {
2310            *p += delta_theta * theta[j];
2311        }
2312
2313        *p += delta_e * w[j];
2314
2315        if j < d + k + 4 {
2316            let mut acc = G::Zp::ZERO;
2317            for (i, &phi) in phi.iter().enumerate() {
2318                match R(i, j) {
2319                    0 => {}
2320                    1 => acc += phi,
2321                    -1 => acc -= phi,
2322                    _ => unreachable!(),
2323                }
2324            }
2325            *p += delta_r * acc;
2326        }
2327    }
2328
2329    P_t[1..].copy_from_slice(&t);
2330
2331    if !P_h3.is_empty() {
2332        for j in 0..d + k {
2333            let p = &mut P_h3[n - j];
2334
2335            let mut acc = G::Zp::ZERO;
2336            for (i, &phi) in phi.iter().enumerate() {
2337                match R(i, d + k + 4 + j) {
2338                    0 => {}
2339                    1 => acc += phi,
2340                    -1 => acc -= phi,
2341                    _ => unreachable!(),
2342                }
2343            }
2344            *p = delta_r * acc - delta_theta_q * theta[j];
2345        }
2346    }
2347
2348    if !P_w.is_empty() {
2349        P_w[1..].copy_from_slice(&w[..d + k + 4]);
2350    }
2351
2352    let mut p_h1 = G::Zp::ZERO;
2353    let mut p_h2 = G::Zp::ZERO;
2354    let mut p_t = G::Zp::ZERO;
2355    let mut p_h3 = G::Zp::ZERO;
2356    let mut p_w = G::Zp::ZERO;
2357
2358    let mut pow = G::Zp::ONE;
2359    for j in 0..n + 1 {
2360        p_h1 += P_h1[j] * pow;
2361        p_h2 += P_h2[j] * pow;
2362        p_t += P_t[j] * pow;
2363
2364        if j < P_h3.len() {
2365            p_h3 += P_h3[j] * pow;
2366        }
2367        if j < P_w.len() {
2368            p_w += P_w[j] * pow;
2369        }
2370
2371        pow = pow * z;
2372    }
2373
2374    let mut chi = G::Zp::ZERO;
2375    G::Zp::hash(
2376        core::slice::from_mut(&mut chi),
2377        &[
2378            hash_chi,
2379            metadata,
2380            x_bytes,
2381            y_bytes,
2382            t_bytes,
2383            phi_bytes,
2384            xi_bytes,
2385            theta_bytes,
2386            delta_bytes,
2387            C_hat_e.to_le_bytes().as_ref(),
2388            C_e.to_le_bytes().as_ref(),
2389            R_bytes,
2390            C_R.to_le_bytes().as_ref(),
2391            C_hat_bin.to_le_bytes().as_ref(),
2392            C_r_tilde.to_le_bytes().as_ref(),
2393            C_y.to_le_bytes().as_ref(),
2394            C_h1.to_le_bytes().as_ref(),
2395            C_h2.to_le_bytes().as_ref(),
2396            C_hat_t.to_le_bytes().as_ref(),
2397            C_hat_h3_bytes,
2398            C_hat_w_bytes,
2399            z.to_le_bytes().as_ref(),
2400            p_h1.to_le_bytes().as_ref(),
2401            p_h2.to_le_bytes().as_ref(),
2402            p_t.to_le_bytes().as_ref(),
2403        ],
2404    );
2405    let chi2 = chi * chi;
2406    let chi3 = chi2 * chi;
2407    let chi4 = chi3 * chi;
2408
2409    let lhs = pairing(
2410        C_h1 + C_h2.mul_scalar(chi) - g.mul_scalar(p_h1 + chi * p_h2),
2411        g_hat,
2412    ) + pairing(
2413        g,
2414        {
2415            let mut C_hat = C_hat_t.mul_scalar(chi2);
2416            if let Some(ComputeLoadProofFields { C_hat_h3, C_hat_w }) = compute_load_proof_fields {
2417                C_hat += C_hat_h3.mul_scalar(chi3);
2418                C_hat += C_hat_w.mul_scalar(chi4);
2419            }
2420            C_hat
2421        } - g_hat.mul_scalar(p_t * chi2 + p_h3 * chi3 + p_w * chi4),
2422    );
2423    let rhs = pairing(
2424        pi_kzg,
2425        G::G2::projective(g_hat_list[0]) - g_hat.mul_scalar(z),
2426    );
2427    if lhs != rhs {
2428        Err(())
2429    } else {
2430        Ok(())
2431    }
2432}
2433
2434#[cfg(test)]
2435mod tests {
2436    use crate::curve_api::{self, bls12_446};
2437
2438    use super::super::test::*;
2439    use super::*;
2440    use rand::rngs::StdRng;
2441    use rand::{Rng, SeedableRng};
2442
    /// Curve used to instantiate the generic CRS, commit, prove and verify functions in these
    /// tests (BLS12-446).
    type Curve = curve_api::Bls12_446;
2444
2445    /// Compact key params used with pkev2
2446    pub(super) const PKEV2_TEST_PARAMS: PkeTestParameters = PkeTestParameters {
2447        d: 2048,
2448        k: 320,
2449        B: 131072, // 2**17
2450        q: 0,
2451        t: 32, // 2b msg, 2b carry, 1b padding
2452        msbs_zero_padding_bit_count: 1,
2453    };
2454
2455    /// Compact key params used with pkve2 to encrypt a single message
2456    pub(super) const PKEV2_TEST_PARAMS_SINGLE: PkeTestParameters = PkeTestParameters {
2457        d: 2048,
2458        k: 1,
2459        B: 131072, // 2**17
2460        q: 0,
2461        t: 32, // 2b msg, 2b carry, 1b padding
2462        msbs_zero_padding_bit_count: 1,
2463    };
2464
2465    /// Compact key params with limits values to test that there is no overflow, using a GHL bound
2466    pub(super) const BIG_TEST_PARAMS_CS: PkeTestParameters = PkeTestParameters {
2467        d: 2048,
2468        k: 2048,
2469        B: 1125899906842624, // 2**50
2470        q: 0,
2471        t: 4, // 1b message, 1b padding
2472        msbs_zero_padding_bit_count: 1,
2473    };
2474
2475    /// Compact key params with limits values to test that there is no overflow, using a
2476    /// Cauchy-Schwarz bound
2477    pub(super) const BIG_TEST_PARAMS_GHL: PkeTestParameters = PkeTestParameters {
2478        d: 2048,
2479        k: 2048,
2480        B: 281474976710656, // 2**48
2481        q: 0,
2482        t: 4, // 1b message, 1b padding
2483        msbs_zero_padding_bit_count: 1,
2484    };
2485
2486    /// Test that the proof is rejected if we use a different value between encryption and proof
2487    #[test]
2488    fn test_pke() {
2489        let PkeTestParameters {
2490            d,
2491            k,
2492            B,
2493            q,
2494            t,
2495            msbs_zero_padding_bit_count,
2496        } = PKEV2_TEST_PARAMS;
2497
2498        let effective_cleartext_t = t >> msbs_zero_padding_bit_count;
2499
2500        let rng = &mut StdRng::seed_from_u64(0);
2501
2502        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
2503        let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
2504
2505        let fake_e1 = (0..d)
2506            .map(|_| (rng.gen::<u64>() % (2 * B)) as i64 - B as i64)
2507            .collect::<Vec<_>>();
2508        let fake_e2 = (0..k)
2509            .map(|_| (rng.gen::<u64>() % (2 * B)) as i64 - B as i64)
2510            .collect::<Vec<_>>();
2511
2512        let fake_r = (0..d)
2513            .map(|_| (rng.gen::<u64>() % 2) as i64)
2514            .collect::<Vec<_>>();
2515
2516        let fake_m = (0..k)
2517            .map(|_| (rng.gen::<u64>() % effective_cleartext_t) as i64)
2518            .collect::<Vec<_>>();
2519
2520        let mut fake_metadata = [255u8; METADATA_LEN];
2521        fake_metadata.fill_with(|| rng.gen::<u8>());
2522
2523        // To check management of bigger k_max from CRS during test
2524        let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
2525
2526        let original_public_param =
2527            crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
2528        let public_param_that_was_compressed =
2529            serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
2530        let public_param_that_was_not_compressed =
2531            serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
2532
2533        for (
2534            public_param,
2535            use_fake_r,
2536            use_fake_e1,
2537            use_fake_e2,
2538            use_fake_m,
2539            use_fake_metadata_verify,
2540        ) in itertools::iproduct!(
2541            [
2542                original_public_param,
2543                public_param_that_was_compressed,
2544                public_param_that_was_not_compressed,
2545            ],
2546            [false, true],
2547            [false, true],
2548            [false, true],
2549            [false, true],
2550            [false, true]
2551        ) {
2552            let (public_commit, private_commit) = commit(
2553                testcase.a.clone(),
2554                testcase.b.clone(),
2555                ct.c1.clone(),
2556                ct.c2.clone(),
2557                if use_fake_r {
2558                    fake_r.clone()
2559                } else {
2560                    testcase.r.clone()
2561                },
2562                if use_fake_e1 {
2563                    fake_e1.clone()
2564                } else {
2565                    testcase.e1.clone()
2566                },
2567                if use_fake_m {
2568                    fake_m.clone()
2569                } else {
2570                    testcase.m.clone()
2571                },
2572                if use_fake_e2 {
2573                    fake_e2.clone()
2574                } else {
2575                    testcase.e2.clone()
2576                },
2577                &public_param,
2578                rng,
2579            );
2580
2581            for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
2582                let proof = prove(
2583                    (&public_param, &public_commit),
2584                    &private_commit,
2585                    &testcase.metadata,
2586                    load,
2587                    rng,
2588                );
2589
2590                let verify_metadata = if use_fake_metadata_verify {
2591                    &fake_metadata
2592                } else {
2593                    &testcase.metadata
2594                };
2595
2596                assert_eq!(
2597                    verify(&proof, (&public_param, &public_commit), verify_metadata).is_err(),
2598                    use_fake_e1
2599                        || use_fake_e2
2600                        || use_fake_r
2601                        || use_fake_m
2602                        || use_fake_metadata_verify
2603                );
2604            }
2605        }
2606    }
2607
2608    fn prove_and_verify(
2609        testcase: &PkeTestcase,
2610        ct: &PkeTestCiphertext,
2611        crs: &PublicParams<Curve>,
2612        load: ComputeLoad,
2613        sanity_check_mode: ProofSanityCheckMode,
2614        rng: &mut StdRng,
2615    ) -> VerificationResult {
2616        let (public_commit, private_commit) = commit(
2617            testcase.a.clone(),
2618            testcase.b.clone(),
2619            ct.c1.clone(),
2620            ct.c2.clone(),
2621            testcase.r.clone(),
2622            testcase.e1.clone(),
2623            testcase.m.clone(),
2624            testcase.e2.clone(),
2625            crs,
2626            rng,
2627        );
2628
2629        let proof = prove_impl(
2630            (crs, &public_commit),
2631            &private_commit,
2632            &testcase.metadata,
2633            load,
2634            rng,
2635            sanity_check_mode,
2636        );
2637
2638        if verify(&proof, (crs, &public_commit), &testcase.metadata).is_ok() {
2639            VerificationResult::Accept
2640        } else {
2641            VerificationResult::Reject
2642        }
2643    }
2644
2645    fn assert_prove_and_verify(
2646        testcase: &PkeTestcase,
2647        ct: &PkeTestCiphertext,
2648        testcase_name: &str,
2649        crs: &PublicParams<Curve>,
2650        sanity_check_mode: ProofSanityCheckMode,
2651        expected_result: VerificationResult,
2652        rng: &mut StdRng,
2653    ) {
2654        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
2655            assert_eq!(
2656                prove_and_verify(testcase, ct, crs, load, sanity_check_mode, rng),
2657                expected_result,
2658                "Testcase {testcase_name} with load {load} failed"
2659            )
2660        }
2661    }
2662
2663    #[derive(Clone, Copy)]
2664    enum BoundTestSlackMode {
2665        /// Generate test noise vectors with all coeffs at 0 except one
2666        // Here ||e||inf == ||e||2 so the slack is the biggest, since B is multiplied by
2667        // sqrt(d+k) anyways
2668        Max,
2669        /// Generate test noise vectors with random coeffs and one just around the bound
2670        // Here the slack should be "average"
2671        Avg,
2672        /// Generate test noise vectors with all coeffs equals to B except one at +/-1
2673        // Here the slack should be minimal since ||e||_2 = sqrt(d+k)*||e||_inf, which is exactly
2674        // what we are proving.
2675        Min,
2676    }
2677
2678    impl Display for BoundTestSlackMode {
2679        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2680            match self {
2681                BoundTestSlackMode::Min => write!(f, "min_slack"),
2682                BoundTestSlackMode::Avg => write!(f, "avg_slack"),
2683                BoundTestSlackMode::Max => write!(f, "max_slack"),
2684            }
2685        }
2686    }
2687
2688    #[derive(Clone, Copy)]
2689    enum TestedCoeffOffsetType {
2690        /// Noise term is after the bound, the proof should be refused
2691        After,
2692        /// Noise term is right on the bound, the proof should be accepted
2693        On,
2694        /// Noise term is before the bound, the proof should be accepted
2695        Before,
2696    }
2697
2698    impl Display for TestedCoeffOffsetType {
2699        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2700            match self {
2701                TestedCoeffOffsetType::After => write!(f, "after_bound"),
2702                TestedCoeffOffsetType::On => write!(f, "on_bound"),
2703                TestedCoeffOffsetType::Before => write!(f, "before_bound"),
2704            }
2705        }
2706    }
2707
2708    impl TestedCoeffOffsetType {
2709        fn offset(self) -> i64 {
2710            match self {
2711                TestedCoeffOffsetType::After => 1,
2712                TestedCoeffOffsetType::On => 0,
2713                TestedCoeffOffsetType::Before => -1,
2714            }
2715        }
2716
2717        fn expected_result(self) -> VerificationResult {
2718            match self {
2719                TestedCoeffOffsetType::After => VerificationResult::Reject,
2720                TestedCoeffOffsetType::On => VerificationResult::Accept,
2721                TestedCoeffOffsetType::Before => VerificationResult::Accept,
2722            }
2723        }
2724    }
2725
2726    #[derive(Clone, Copy)]
2727    enum TestedCoeffType {
2728        E1,
2729        E2,
2730    }
2731
2732    impl Display for TestedCoeffType {
2733        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2734            match self {
2735                TestedCoeffType::E1 => write!(f, "e1"),
2736                TestedCoeffType::E2 => write!(f, "e2"),
2737            }
2738        }
2739    }
2740
2741    struct PkeBoundTestcase {
2742        name: String,
2743        testcase: PkeTestcase,
2744        expected_result: VerificationResult,
2745    }
2746
2747    impl PkeBoundTestcase {
2748        fn new(
2749            ref_testcase: &PkeTestcase,
2750            B: u64,
2751            slack_mode: BoundTestSlackMode,
2752            offset_type: TestedCoeffOffsetType,
2753            coeff_type: TestedCoeffType,
2754            rng: &mut StdRng,
2755        ) -> Self {
2756            let mut testcase = ref_testcase.clone();
2757
2758            let d = testcase.e1.len();
2759            let k = testcase.e2.len();
2760
2761            // Select a random index for the tested term
2762            let tested_idx = match coeff_type {
2763                TestedCoeffType::E1 => rng.gen::<usize>() % d,
2764                TestedCoeffType::E2 => rng.gen::<usize>() % k,
2765            };
2766
2767            // Initialize the "good" terms of the error, that are not above the bound
2768            match slack_mode {
2769                BoundTestSlackMode::Max => {
2770                    // In this mode, all the terms are 0 except the tested one
2771                    testcase.e1 = vec![0; d];
2772                    testcase.e2 = vec![0; k];
2773                }
2774                BoundTestSlackMode::Avg => {
2775                    // In this mode we keep the original random vector
2776                }
2777                BoundTestSlackMode::Min => {
2778                    // In this mode all the terms are exactly at the bound
2779                    let good_term = B as i64;
2780                    testcase.e1 = (0..d)
2781                        .map(|_| if rng.gen() { good_term } else { -good_term })
2782                        .collect();
2783                    testcase.e2 = (0..k)
2784                        .map(|_| if rng.gen() { good_term } else { -good_term })
2785                        .collect();
2786                }
2787            };
2788
2789            let B_with_slack_squared = inf_norm_bound_to_euclidean_squared(B, d + k);
2790            let B_with_slack = B_with_slack_squared.isqrt() as u64;
2791
2792            let bound = match slack_mode {
2793                // The slack is maximal, any term above B+slack should be refused
2794                BoundTestSlackMode::Max => B_with_slack as i64,
2795                // The actual accepted bound depends on the content of the test vector
2796                BoundTestSlackMode::Avg => {
2797                    let e_sqr_norm = testcase
2798                        .e1
2799                        .iter()
2800                        .chain(&testcase.e2)
2801                        .map(|x| sqr(x.unsigned_abs()))
2802                        .sum::<u128>();
2803
2804                    let orig_value = match coeff_type {
2805                        TestedCoeffType::E1 => testcase.e1[tested_idx],
2806                        TestedCoeffType::E2 => testcase.e2[tested_idx],
2807                    };
2808
2809                    let bound_squared =
2810                        B_with_slack_squared - (e_sqr_norm - sqr(orig_value as u64));
2811                    bound_squared.isqrt() as i64
2812                }
2813                // There is no slack effect, any term above B should be refused
2814                BoundTestSlackMode::Min => B as i64,
2815            };
2816
2817            let tested_term = bound + offset_type.offset();
2818
2819            match coeff_type {
2820                TestedCoeffType::E1 => testcase.e1[tested_idx] = tested_term,
2821                TestedCoeffType::E2 => testcase.e2[tested_idx] = tested_term,
2822            };
2823
2824            Self {
2825                name: format!("test_{slack_mode}_{offset_type}_{coeff_type}"),
2826                testcase,
2827                expected_result: offset_type.expected_result(),
2828            }
2829        }
2830    }
2831
2832    /// Test that the proof is rejected if we use a noise outside of the bounds, taking the slack
2833    /// into account
2834    #[test]
2835    fn test_pke_bad_noise() {
2836        let PkeTestParameters {
2837            d,
2838            k,
2839            B,
2840            q,
2841            t,
2842            msbs_zero_padding_bit_count,
2843        } = PKEV2_TEST_PARAMS;
2844
2845        let rng = &mut StdRng::seed_from_u64(0);
2846
2847        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
2848
2849        let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng);
2850        let crs_max_k = crs_gen::<Curve>(d, d, B, q, t, msbs_zero_padding_bit_count, rng);
2851
2852        let B_with_slack_squared = inf_norm_bound_to_euclidean_squared(B, d + k);
2853        let B_with_slack_upper = B_with_slack_squared.isqrt() as u64 + 1;
2854
2855        // Generate test noise vectors with random coeffs and one completely out of bounds
2856
2857        let mut testcases = Vec::new();
2858        let mut testcase_bad_e1 = testcase.clone();
2859        let bad_idx = rng.gen::<usize>() % d;
2860        let bad_term =
2861            (rng.gen::<u64>() % (i64::MAX as u64 - B_with_slack_upper)) + B_with_slack_upper;
2862        let bad_term = bad_term as i64;
2863
2864        testcase_bad_e1.e1[bad_idx] = if rng.gen() { bad_term } else { -bad_term };
2865
2866        testcases.push(PkeBoundTestcase {
2867            name: "testcase_bad_e1".to_string(),
2868            testcase: testcase_bad_e1,
2869            expected_result: VerificationResult::Reject,
2870        });
2871
2872        let mut testcase_bad_e2 = testcase.clone();
2873        let bad_idx = rng.gen::<usize>() % k;
2874
2875        testcase_bad_e2.e2[bad_idx] = if rng.gen() { bad_term } else { -bad_term };
2876
2877        testcases.push(PkeBoundTestcase {
2878            name: "testcase_bad_e2".to_string(),
2879            testcase: testcase_bad_e2,
2880            expected_result: VerificationResult::Reject,
2881        });
2882
2883        // Generate test vectors with a noise term right around the bound
2884
2885        testcases.extend(
2886            itertools::iproduct!(
2887                [
2888                    BoundTestSlackMode::Min,
2889                    BoundTestSlackMode::Avg,
2890                    BoundTestSlackMode::Max
2891                ],
2892                [
2893                    TestedCoeffOffsetType::Before,
2894                    TestedCoeffOffsetType::On,
2895                    TestedCoeffOffsetType::After
2896                ],
2897                [TestedCoeffType::E1, TestedCoeffType::E2]
2898            )
2899            .map(|(slack_mode, offset_type, coeff_type)| {
2900                PkeBoundTestcase::new(&testcase, B, slack_mode, offset_type, coeff_type, rng)
2901            }),
2902        );
2903
2904        for PkeBoundTestcase {
2905            name,
2906            testcase,
2907            expected_result,
2908        } in testcases
2909        {
2910            let ct = testcase.encrypt_unchecked(PKEV2_TEST_PARAMS);
2911            assert_prove_and_verify(
2912                &testcase,
2913                &ct,
2914                &format!("{name}_crs"),
2915                &crs,
2916                ProofSanityCheckMode::Ignore,
2917                expected_result,
2918                rng,
2919            );
2920            assert_prove_and_verify(
2921                &testcase,
2922                &ct,
2923                &format!("{name}_crs_max_k"),
2924                &crs_max_k,
2925                ProofSanityCheckMode::Ignore,
2926                expected_result,
2927                rng,
2928            );
2929        }
2930    }
2931
2932    /// Compare the computed params with manually calculated ones to check the formula
2933    #[test]
2934    fn test_compute_crs_params() {
2935        let PkeTestParameters {
2936            d,
2937            k,
2938            B,
2939            q: _,
2940            t,
2941            msbs_zero_padding_bit_count,
2942        } = PKEV2_TEST_PARAMS;
2943
2944        let B_squared = inf_norm_bound_to_euclidean_squared(B, d + k);
2945        assert_eq!(B_squared, 40681930227712);
2946
2947        let (n, D, B_bound_squared, m_bound) =
2948            compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, Bound::GHL);
2949        assert_eq!(n, 6784);
2950        assert_eq!(D, 3328);
2951        assert_eq!(B_bound_squared, 3867562496364372);
2952        assert_eq!(m_bound, 27);
2953
2954        let (n, D, B_bound_squared, m_bound) =
2955            compute_crs_params(d, k, B_squared, t, msbs_zero_padding_bit_count, Bound::CS);
2956        assert_eq!(n, 7168);
2957        assert_eq!(D, 3328);
2958        assert_eq!(B_bound_squared, 192844141830554880);
2959        assert_eq!(m_bound, 30);
2960    }
2961
2962    /// Test that the proof is rejected if we don't have the padding bit set to 0
2963    #[test]
2964    fn test_pke_w_padding_fail_verify() {
2965        let PkeTestParameters {
2966            d,
2967            k,
2968            B,
2969            q,
2970            t,
2971            msbs_zero_padding_bit_count,
2972        } = PKEV2_TEST_PARAMS;
2973
2974        let effective_cleartext_t = t >> msbs_zero_padding_bit_count;
2975
2976        let rng = &mut StdRng::seed_from_u64(0);
2977
2978        let mut testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
2979        // Generate messages with padding set to fail verification
2980        testcase.m = {
2981            let mut tmp = (0..k)
2982                .map(|_| (rng.gen::<u64>() % t) as i64)
2983                .collect::<Vec<_>>();
2984            while tmp.iter().all(|&x| (x as u64) < effective_cleartext_t) {
2985                tmp.fill_with(|| (rng.gen::<u64>() % t) as i64);
2986            }
2987
2988            tmp
2989        };
2990
2991        let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
2992
2993        // To check management of bigger k_max from CRS during test
2994        let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
2995
2996        let original_public_param =
2997            crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
2998        let public_param_that_was_compressed =
2999            serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
3000        let public_param_that_was_not_compressed =
3001            serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
3002
3003        for (public_param, test_name) in [
3004            (original_public_param, "original_params"),
3005            (
3006                public_param_that_was_compressed,
3007                "serialized_compressed_params",
3008            ),
3009            (public_param_that_was_not_compressed, "serialize_params"),
3010        ] {
3011            assert_prove_and_verify(
3012                &testcase,
3013                &ct,
3014                test_name,
3015                &public_param,
3016                ProofSanityCheckMode::Panic,
3017                VerificationResult::Reject,
3018                rng,
3019            );
3020        }
3021    }
3022
3023    /// Test verification with modified ciphertexts
3024    #[test]
3025    fn test_bad_ct() {
3026        let PkeTestParameters {
3027            d,
3028            k,
3029            B,
3030            q,
3031            t,
3032            msbs_zero_padding_bit_count,
3033        } = PKEV2_TEST_PARAMS;
3034
3035        let effective_cleartext_t = t >> msbs_zero_padding_bit_count;
3036
3037        let rng = &mut StdRng::seed_from_u64(0);
3038
3039        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS_SINGLE);
3040        let ct = testcase.encrypt(PKEV2_TEST_PARAMS_SINGLE);
3041
3042        let ct_zero = testcase.sk_encrypt_zero(PKEV2_TEST_PARAMS_SINGLE, rng);
3043
3044        let c1_plus_zero = ct
3045            .c1
3046            .iter()
3047            .zip(ct_zero.iter())
3048            .map(|(a1, az)| a1.wrapping_add(*az))
3049            .collect();
3050        let c2_plus_zero = vec![ct.c2[0].wrapping_add(*ct_zero.last().unwrap())];
3051
3052        let ct_plus_zero = PkeTestCiphertext {
3053            c1: c1_plus_zero,
3054            c2: c2_plus_zero,
3055        };
3056
3057        let m_plus_zero = testcase.decrypt(&ct_plus_zero, PKEV2_TEST_PARAMS_SINGLE);
3058        assert_eq!(testcase.m, m_plus_zero);
3059
3060        let delta = {
3061            let q = decode_q(q) as i128;
3062            // delta takes the encoding with the padding bit
3063            (q / t as i128) as u64
3064        };
3065
3066        let trivial = rng.gen::<u64>() % effective_cleartext_t;
3067        let trivial_pt = trivial * delta;
3068        let c2_plus_trivial = vec![ct.c2[0].wrapping_add(trivial_pt as i64)];
3069
3070        let ct_plus_trivial = PkeTestCiphertext {
3071            c1: ct.c1.clone(),
3072            c2: c2_plus_trivial,
3073        };
3074
3075        let m_plus_trivial = testcase.decrypt(&ct_plus_trivial, PKEV2_TEST_PARAMS_SINGLE);
3076        assert_eq!(testcase.m[0] + trivial as i64, m_plus_trivial[0]);
3077
3078        let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng);
3079
3080        // Test proving with one ct and verifying another
3081        let (public_commit_proof, private_commit) = commit(
3082            testcase.a.clone(),
3083            testcase.b.clone(),
3084            ct.c1.clone(),
3085            ct.c2.clone(),
3086            testcase.r.clone(),
3087            testcase.e1.clone(),
3088            testcase.m.clone(),
3089            testcase.e2.clone(),
3090            &crs,
3091            rng,
3092        );
3093
3094        let (public_commit_verify_zero, _) = commit(
3095            testcase.a.clone(),
3096            testcase.b.clone(),
3097            ct_plus_zero.c1.clone(),
3098            ct_plus_zero.c2.clone(),
3099            testcase.r.clone(),
3100            testcase.e1.clone(),
3101            testcase.m.clone(),
3102            testcase.e2.clone(),
3103            &crs,
3104            rng,
3105        );
3106
3107        let (public_commit_verify_trivial, _) = commit(
3108            testcase.a.clone(),
3109            testcase.b.clone(),
3110            ct_plus_trivial.c1.clone(),
3111            ct_plus_trivial.c2.clone(),
3112            testcase.r.clone(),
3113            testcase.e1.clone(),
3114            testcase.m.clone(),
3115            testcase.e2.clone(),
3116            &crs,
3117            rng,
3118        );
3119
3120        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
3121            let proof = prove(
3122                (&crs, &public_commit_proof),
3123                &private_commit,
3124                &testcase.metadata,
3125                load,
3126                rng,
3127            );
3128
3129            assert!(verify(
3130                &proof,
3131                (&crs, &public_commit_verify_zero),
3132                &testcase.metadata
3133            )
3134            .is_err());
3135
3136            assert!(verify(
3137                &proof,
3138                (&crs, &public_commit_verify_trivial),
3139                &testcase.metadata
3140            )
3141            .is_err());
3142        }
3143    }
3144
3145    /// Test encryption of a message where the delta used for encryption is not the one used for
3146    /// proof/verify
3147    #[test]
3148    fn test_bad_delta() {
3149        let PkeTestParameters {
3150            d,
3151            k,
3152            B,
3153            q,
3154            t,
3155            msbs_zero_padding_bit_count,
3156        } = PKEV2_TEST_PARAMS;
3157
3158        let effective_cleartext_t = t >> msbs_zero_padding_bit_count;
3159
3160        let rng = &mut StdRng::seed_from_u64(0);
3161
3162        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
3163        let mut testcase_bad_delta = testcase.clone();
3164
3165        // Make sure that the messages lower bit is set so the change of delta has an impact on the
3166        // validity of the ct
3167        testcase_bad_delta.m = (0..k)
3168            .map(|_| (rng.gen::<u64>() % effective_cleartext_t) as i64 | 1)
3169            .collect::<Vec<_>>();
3170
3171        let mut params_bad_delta = PKEV2_TEST_PARAMS;
3172        params_bad_delta.t *= 2; // Multiply t by 2 to "spill" 1 bit of message into the noise
3173
3174        // Encrypt using wrong delta
3175        let ct_bad_delta = testcase_bad_delta.encrypt(params_bad_delta);
3176
3177        // Prove using a crs built using the "right" delta
3178        let crs = crs_gen::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng);
3179
3180        assert_prove_and_verify(
3181            &testcase,
3182            &ct_bad_delta,
3183            "testcase_bad_delta",
3184            &crs,
3185            ProofSanityCheckMode::Panic,
3186            VerificationResult::Reject,
3187            rng,
3188        );
3189    }
3190
3191    /// Test encryption of a message with params that are at the limits of what is supported
3192    #[test]
3193    fn test_big_params() {
3194        let rng = &mut StdRng::seed_from_u64(0);
3195
3196        for bound in [Bound::CS, Bound::GHL] {
3197            let params = match bound {
3198                Bound::GHL => BIG_TEST_PARAMS_GHL,
3199                Bound::CS => BIG_TEST_PARAMS_CS,
3200            };
3201            let PkeTestParameters {
3202                d,
3203                k,
3204                B,
3205                q,
3206                t,
3207                msbs_zero_padding_bit_count,
3208            } = params;
3209
3210            let testcase = PkeTestcase::gen(rng, params);
3211            let ct = testcase.encrypt(params);
3212
3213            // Check that there is no overflow with both bounds
3214            let crs = match bound {
3215                Bound::GHL => crs_gen_ghl::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng),
3216                Bound::CS => crs_gen_cs::<Curve>(d, k, B, q, t, msbs_zero_padding_bit_count, rng),
3217            };
3218
3219            assert_prove_and_verify(
3220                &testcase,
3221                &ct,
3222                &format!("testcase_big_params_{bound:?}"),
3223                &crs,
3224                ProofSanityCheckMode::Panic,
3225                VerificationResult::Accept,
3226                rng,
3227            );
3228        }
3229    }
3230
3231    /// Test compression of proofs
3232    #[test]
3233    fn test_proof_compression() {
3234        let PkeTestParameters {
3235            d,
3236            k,
3237            B,
3238            q,
3239            t,
3240            msbs_zero_padding_bit_count,
3241        } = PKEV2_TEST_PARAMS;
3242
3243        let rng = &mut StdRng::seed_from_u64(0);
3244
3245        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
3246        let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
3247
3248        let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
3249
3250        let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
3251
3252        let (public_commit, private_commit) = commit(
3253            testcase.a.clone(),
3254            testcase.b.clone(),
3255            ct.c1.clone(),
3256            ct.c2.clone(),
3257            testcase.r.clone(),
3258            testcase.e1.clone(),
3259            testcase.m.clone(),
3260            testcase.e2.clone(),
3261            &public_param,
3262            rng,
3263        );
3264
3265        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
3266            let proof = prove(
3267                (&public_param, &public_commit),
3268                &private_commit,
3269                &testcase.metadata,
3270                load,
3271                rng,
3272            );
3273
3274            let compressed_proof = bincode::serialize(&proof.compress()).unwrap();
3275            let proof =
3276                Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap();
3277
3278            verify(&proof, (&public_param, &public_commit), &testcase.metadata).unwrap()
3279        }
3280    }
3281
3282    /// Test the `is_usable` method, that checks the correctness of the EC points in the proof
3283    #[test]
3284    fn test_proof_usable() {
3285        let PkeTestParameters {
3286            d,
3287            k,
3288            B,
3289            q,
3290            t,
3291            msbs_zero_padding_bit_count,
3292        } = PKEV2_TEST_PARAMS;
3293
3294        let rng = &mut StdRng::seed_from_u64(0);
3295
3296        let testcase = PkeTestcase::gen(rng, PKEV2_TEST_PARAMS);
3297        let ct = testcase.encrypt(PKEV2_TEST_PARAMS);
3298
3299        let crs_k = k + 1 + (rng.gen::<usize>() % (d - k));
3300
3301        let public_param = crs_gen::<Curve>(d, crs_k, B, q, t, msbs_zero_padding_bit_count, rng);
3302
3303        let (public_commit, private_commit) = commit(
3304            testcase.a.clone(),
3305            testcase.b.clone(),
3306            ct.c1.clone(),
3307            ct.c2.clone(),
3308            testcase.r.clone(),
3309            testcase.e1.clone(),
3310            testcase.m.clone(),
3311            testcase.e2.clone(),
3312            &public_param,
3313            rng,
3314        );
3315
3316        for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
3317            let valid_proof = prove(
3318                (&public_param, &public_commit),
3319                &private_commit,
3320                &testcase.metadata,
3321                load,
3322                rng,
3323            );
3324
3325            let compressed_proof = bincode::serialize(&valid_proof.compress()).unwrap();
3326            let proof_that_was_compressed: Proof<Curve> =
3327                Proof::uncompress(bincode::deserialize(&compressed_proof).unwrap()).unwrap();
3328
3329            assert!(valid_proof.is_usable());
3330            assert!(proof_that_was_compressed.is_usable());
3331
3332            let not_on_curve_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
3333                inner: point_not_on_curve(rng),
3334            });
3335
3336            let not_on_curve_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
3337                inner: point_not_on_curve(rng),
3338            });
3339
3340            let not_in_group_g1 = bls12_446::G1::projective(bls12_446::G1Affine {
3341                inner: point_on_curve_wrong_subgroup(rng),
3342            });
3343
3344            let not_in_group_g2 = bls12_446::G2::projective(bls12_446::G2Affine {
3345                inner: point_on_curve_wrong_subgroup(rng),
3346            });
3347
3348            {
3349                let mut proof = valid_proof.clone();
3350                proof.C_hat_e = not_on_curve_g2;
3351                assert!(!proof.is_usable());
3352                proof.C_hat_e = not_in_group_g2;
3353                assert!(!proof.is_usable());
3354            }
3355
3356            {
3357                let mut proof = valid_proof.clone();
3358                proof.C_e = not_on_curve_g1;
3359                assert!(!proof.is_usable());
3360                proof.C_e = not_in_group_g1;
3361                assert!(!proof.is_usable());
3362            }
3363
3364            {
3365                let mut proof = valid_proof.clone();
3366                proof.C_r_tilde = not_on_curve_g1;
3367                assert!(!proof.is_usable());
3368                proof.C_r_tilde = not_in_group_g1;
3369                assert!(!proof.is_usable());
3370            }
3371
3372            {
3373                let mut proof = valid_proof.clone();
3374                proof.C_R = not_on_curve_g1;
3375                assert!(!proof.is_usable());
3376                proof.C_R = not_in_group_g1;
3377                assert!(!proof.is_usable());
3378            }
3379
3380            {
3381                let mut proof = valid_proof.clone();
3382                proof.C_hat_bin = not_on_curve_g2;
3383                assert!(!proof.is_usable());
3384                proof.C_hat_bin = not_in_group_g2;
3385                assert!(!proof.is_usable());
3386            }
3387
3388            {
3389                let mut proof = valid_proof.clone();
3390                proof.C_y = not_on_curve_g1;
3391                assert!(!proof.is_usable());
3392                proof.C_y = not_in_group_g1;
3393                assert!(!proof.is_usable());
3394            }
3395
3396            {
3397                let mut proof = valid_proof.clone();
3398                proof.C_h1 = not_on_curve_g1;
3399                assert!(!proof.is_usable());
3400                proof.C_h1 = not_in_group_g1;
3401                assert!(!proof.is_usable());
3402            }
3403
3404            {
3405                let mut proof = valid_proof.clone();
3406                proof.C_h2 = not_on_curve_g1;
3407                assert!(!proof.is_usable());
3408                proof.C_h2 = not_in_group_g1;
3409                assert!(!proof.is_usable());
3410            }
3411
3412            {
3413                let mut proof = valid_proof.clone();
3414                proof.C_hat_t = not_on_curve_g2;
3415                assert!(!proof.is_usable());
3416                proof.C_hat_t = not_in_group_g2;
3417                assert!(!proof.is_usable());
3418            }
3419
3420            {
3421                let mut proof = valid_proof.clone();
3422                proof.pi = not_on_curve_g1;
3423                assert!(!proof.is_usable());
3424                proof.pi = not_in_group_g1;
3425                assert!(!proof.is_usable());
3426            }
3427
3428            {
3429                let mut proof = valid_proof.clone();
3430                proof.pi_kzg = not_on_curve_g1;
3431                assert!(!proof.is_usable());
3432                proof.pi_kzg = not_in_group_g1;
3433                assert!(!proof.is_usable());
3434            }
3435
3436            if let Some(ref valid_compute_proof_fields) = valid_proof.compute_load_proof_fields {
3437                {
3438                    let mut proof = valid_proof.clone();
3439                    proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
3440                        C_hat_h3: not_on_curve_g2,
3441                        C_hat_w: valid_compute_proof_fields.C_hat_w,
3442                    });
3443
3444                    assert!(!proof.is_usable());
3445                    proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
3446                        C_hat_h3: not_in_group_g2,
3447                        C_hat_w: valid_compute_proof_fields.C_hat_w,
3448                    });
3449
3450                    assert!(!proof.is_usable());
3451                }
3452
3453                {
3454                    let mut proof = valid_proof.clone();
3455                    proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
3456                        C_hat_h3: valid_compute_proof_fields.C_hat_h3,
3457                        C_hat_w: not_on_curve_g2,
3458                    });
3459
3460                    assert!(!proof.is_usable());
3461
3462                    proof.compute_load_proof_fields = Some(ComputeLoadProofFields {
3463                        C_hat_h3: valid_compute_proof_fields.C_hat_h3,
3464                        C_hat_w: not_in_group_g2,
3465                    });
3466
3467                    assert!(!proof.is_usable());
3468                }
3469            }
3470        }
3471    }
3472}