Skip to main content

sigma_proofs/
composition.rs

//! # Protocol Composition with AND, OR, and Threshold Logic
//!
//! This module defines the [`ComposedRelation`] enum, which generalizes [`CanonicalLinearRelation`]
//! by enabling compositional logic between multiple proof instances.
//!
//! Specifically, it supports:
//! - Simple atomic proofs (e.g., discrete logarithm, Pedersen commitments)
//! - Conjunctions (`And`) of multiple sub-protocols
//! - Disjunctions (`Or`) of multiple sub-protocols
//! - Thresholds (`Threshold`): at least `k` out of multiple sub-protocols
//!
//! ## Example Composition
//!
//! ```ignore
//! And(
//!    Or(dleq, pedersen_commitment),
//!    Simple(discrete_logarithm),
//!    And(pedersen_commitment_dleq, bbs_blind_commitment_computation)
//! )
//! ```
21
22use alloc::{vec, vec::Vec};
23use ff::{Field, PrimeField};
24use group::prime::PrimeGroup;
25use sha3::{Digest, Sha3_256};
26use spongefish::{
27    Decoding, Encoding, NargDeserialize, NargSerialize, VerificationError, VerificationResult,
28};
29use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
30
31use crate::errors::InvalidInstance;
32use crate::traits::ScalarRng;
33use crate::MultiScalarMul;
34use crate::{
35    errors::Error,
36    fiat_shamir::Nizk,
37    linear_relation::{CanonicalLinearRelation, LinearRelation},
38    traits::{SigmaProtocol, SigmaProtocolSimulator},
39};
40
/// A relation built by composing [`CanonicalLinearRelation`]s with AND, OR, and threshold
/// connectives; proving it demonstrates knowledge of a witness for the whole composition.
///
/// This generalizes [`CanonicalLinearRelation`]: a `Simple` leaf wraps a single linear
/// relation, and the other variants combine sub-relations recursively.
///
/// # Type Parameters
/// - `G`: the prime-order group the relations are defined over (bound: [`PrimeGroup`]).
#[derive(Clone)]
pub enum ComposedRelation<G: PrimeGroup> {
    /// A single linear relation.
    Simple(CanonicalLinearRelation<G>),
    /// Conjunction: every sub-relation must hold.
    And(Vec<ComposedRelation<G>>),
    /// Disjunction: at least one sub-relation must hold.
    Or(Vec<ComposedRelation<G>>),
    /// `Threshold(k, relations)`: at least `k` of `relations` must hold.
    Threshold(usize, Vec<ComposedRelation<G>>),
}
54
55impl<G: PrimeGroup + ConstantTimeEq + ConditionallySelectable> ComposedRelation<G> {
56    /// Create a [ComposedRelation] for an AND relation from the given list of relations.
57    pub fn and<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
58        Self::And(witness.into_iter().map(|x| x.into()).collect())
59    }
60
61    /// Create a [ComposedRelation] for an OR relation from the given list of relations.
62    pub fn or<T: Into<ComposedRelation<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
63        Self::Or(witness.into_iter().map(|x| x.into()).collect())
64    }
65
66    /// Create a [ComposedRelation] for a threshold relation from the given list of relations.
67    pub fn threshold<T: Into<ComposedRelation<G>>>(
68        threshold: usize,
69        witness: impl IntoIterator<Item = T>,
70    ) -> Self {
71        Self::Threshold(threshold, witness.into_iter().map(|x| x.into()).collect())
72    }
73}
74
75impl<G: PrimeGroup> From<CanonicalLinearRelation<G>> for ComposedRelation<G> {
76    fn from(value: CanonicalLinearRelation<G>) -> Self {
77        ComposedRelation::Simple(value)
78    }
79}
80
81impl<G: PrimeGroup + MultiScalarMul> TryFrom<LinearRelation<G>> for ComposedRelation<G> {
82    type Error = InvalidInstance;
83
84    fn try_from(value: LinearRelation<G>) -> Result<Self, Self::Error> {
85        Ok(Self::Simple(CanonicalLinearRelation::try_from(value)?))
86    }
87}
88
/// Commitment produced when a [`ComposedRelation`] is run as a [`SigmaProtocol`].
///
/// Each variant corresponds to the [`ComposedRelation`] variant of the same name.
#[derive(Clone)]
pub enum ComposedCommitment<G>
where
    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Group elements committed to by a single linear relation.
    Simple(Vec<G>),
    /// Commitments of each conjunct.
    And(Vec<ComposedCommitment<G>>),
    /// Commitments of each disjunct.
    Or(Vec<ComposedCommitment<G>>),
    /// Commitments of each sub-relation of a threshold relation.
    Threshold(Vec<ComposedCommitment<G>>),
}
102
103impl<G: PrimeGroup> ComposedCommitment<G>
104where
105    G: ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
106    G::Scalar:
107        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
108{
109    /// Conditionally select between two ComposedCommitment values.
110    /// This function performs constant-time selection of the commitment values.
111    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
112        match (a, b) {
113            (ComposedCommitment::Simple(a_elements), ComposedCommitment::Simple(b_elements)) => {
114                // Both vectors must have the same length for this to work
115                debug_assert_eq!(a_elements.len(), b_elements.len());
116                let selected: Vec<G> = a_elements
117                    .iter()
118                    .zip(b_elements.iter())
119                    .map(|(a, b)| G::conditional_select(a, b, choice))
120                    .collect();
121                ComposedCommitment::Simple(selected)
122            }
123            (ComposedCommitment::And(a_commitments), ComposedCommitment::And(b_commitments)) => {
124                debug_assert_eq!(a_commitments.len(), b_commitments.len());
125                let selected: Vec<ComposedCommitment<G>> = a_commitments
126                    .iter()
127                    .zip(b_commitments.iter())
128                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
129                    .collect();
130                ComposedCommitment::And(selected)
131            }
132            (ComposedCommitment::Or(a_commitments), ComposedCommitment::Or(b_commitments)) => {
133                debug_assert_eq!(a_commitments.len(), b_commitments.len());
134                let selected: Vec<ComposedCommitment<G>> = a_commitments
135                    .iter()
136                    .zip(b_commitments.iter())
137                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
138                    .collect();
139                ComposedCommitment::Or(selected)
140            }
141            (
142                ComposedCommitment::Threshold(a_commitments),
143                ComposedCommitment::Threshold(b_commitments),
144            ) => {
145                debug_assert_eq!(a_commitments.len(), b_commitments.len());
146                let selected: Vec<ComposedCommitment<G>> = a_commitments
147                    .iter()
148                    .zip(b_commitments.iter())
149                    .map(|(a, b)| ComposedCommitment::conditional_select(a, b, choice))
150                    .collect();
151                ComposedCommitment::Threshold(selected)
152            }
153            _ => {
154                unreachable!("Mismatched ComposedCommitment variants in conditional_select");
155            }
156        }
157    }
158}
159
/// Prover-side state carried from the commitment phase to the response phase when a
/// [`ComposedRelation`] is run as a [`SigmaProtocol`].
pub enum ComposedProverState<G>
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// State of a single linear relation, as returned by its `prover_commit`.
    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::ProverState),
    /// One state per conjunct.
    And(Vec<ComposedProverState<G>>),
    /// Per-branch state of a disjunction (presumably one entry per disjunct).
    Or(ComposedOrProverState<G>),
    /// Per-branch state of a threshold relation (presumably one entry per sub-relation).
    Threshold(ComposedThresholdProverState<G>),
}
178
/// Prover state of an OR composition: a list of per-branch entries.
pub type ComposedOrProverState<G> = Vec<ComposedOrProverStateEntry<G>>;

/// Prover state for one branch of an OR composition.
///
/// NOTE(review): the four fields appear to mirror the named fields of
/// [`ComposedThresholdProverStateEntry`] (`use_simulator`, `prover_state`,
/// `simulated_challenge`, `simulated_response`) — confirm against the OR prover code.
pub struct ComposedOrProverStateEntry<G>(
    // Presumably: whether this branch is simulated rather than proven with a witness.
    Choice,
    // Presumably: the real prover state for this branch.
    ComposedProverState<G>,
    // Presumably: the challenge chosen for a simulated branch.
    ComposedChallenge<G>,
    // Presumably: the response produced for a simulated branch.
    ComposedResponse<G>,
)
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable;
196
/// Prover state of a threshold composition: a list of per-branch entries.
pub type ComposedThresholdProverState<G> = Vec<ComposedThresholdProverStateEntry<G>>;

/// Prover state for one sub-relation of a threshold composition.
///
/// NOTE(review): field meanings below are inferred from the field names — confirm against
/// the threshold prover code.
pub struct ComposedThresholdProverStateEntry<G>
where
    G: PrimeGroup
        + ConstantTimeEq
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Whether this branch is answered by the simulator instead of with the witness.
    use_simulator: Choice,
    /// Prover state from committing to this branch.
    prover_state: ComposedProverState<G>,
    /// Challenge used if this branch is simulated.
    simulated_challenge: ComposedChallenge<G>,
    /// Response used if this branch is simulated.
    simulated_response: ComposedResponse<G>,
}
215
/// Response produced when a [`ComposedRelation`] is run as a [`SigmaProtocol`].
///
/// Each variant corresponds to the [`ComposedRelation`] variant of the same name.
#[derive(Clone)]
pub enum ComposedResponse<G>
where
    G: PrimeGroup
        + ConditionallySelectable
        + Encoding<[u8]>
        + NargSerialize
        + NargDeserialize
        + MultiScalarMul,
    G::Scalar:
        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
{
    /// Responses of a single linear relation (decoded as scalars on the wire).
    Simple(Vec<<CanonicalLinearRelation<G> as SigmaProtocol>::Response>),
    /// One response per conjunct.
    And(Vec<ComposedResponse<G>>),
    /// Branch challenges, then sub-responses of a disjunction.
    // NOTE(review): the challenge list may omit one branch that the verifier derives from the
    // main challenge — confirm against the OR prover/verifier.
    Or(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
    /// Challenge values, then sub-responses of a threshold relation.
    // NOTE(review): the challenges look like the `compressed_challenges` input of
    // `expand_threshold_challenges` — confirm.
    Threshold(Vec<ComposedChallenge<G>>, Vec<ComposedResponse<G>>),
}
234
// One-byte variant tags that open the wire encoding of `ComposedCommitment` and
// `ComposedResponse` values.

/// Tag for the `Simple` variant.
const TAG_SIMPLE: u8 = 0x00;
/// Tag for the `And` variant.
const TAG_AND: u8 = 0x01;
/// Tag for the `Or` variant.
const TAG_OR: u8 = 0x02;
/// Tag for the `Threshold` variant.
const TAG_THRESHOLD: u8 = 0x03;
239
240fn read_u32(buf: &mut &[u8]) -> VerificationResult<u32> {
241    if buf.len() < 4 {
242        return Err(VerificationError);
243    }
244    let (head, tail) = buf.split_at(4);
245    *buf = tail;
246    Ok(u32::from_le_bytes(head.try_into().unwrap()))
247}
248
/// Appends `len` to `out` as a 4-byte little-endian `u32` length prefix.
///
/// # Panics
///
/// Panics if `len` does not fit in a `u32`. A plain `as` cast would silently truncate it and
/// produce an encoding that decodes to a different structure.
fn write_len(out: &mut Vec<u8>, len: usize) {
    let len = u32::try_from(len).expect("length prefix does not fit in a u32");
    out.extend_from_slice(&len.to_le_bytes());
}
252
253impl<G> Encoding<[u8]> for ComposedCommitment<G>
254where
255    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
256    G::Scalar:
257        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
258{
259    fn encode(&self) -> impl AsRef<[u8]> {
260        let mut out = Vec::new();
261        match self {
262            ComposedCommitment::Simple(elems) => {
263                out.push(TAG_SIMPLE);
264                write_len(&mut out, elems.len());
265                for elem in elems {
266                    elem.serialize_into_narg(&mut out);
267                }
268            }
269            ComposedCommitment::And(cs) => {
270                out.push(TAG_AND);
271                write_len(&mut out, cs.len());
272                for c in cs {
273                    c.serialize_into_narg(&mut out);
274                }
275            }
276            ComposedCommitment::Or(cs) => {
277                out.push(TAG_OR);
278                write_len(&mut out, cs.len());
279                for c in cs {
280                    c.serialize_into_narg(&mut out);
281                }
282            }
283            ComposedCommitment::Threshold(cs) => {
284                out.push(TAG_THRESHOLD);
285                write_len(&mut out, cs.len());
286                for c in cs {
287                    c.serialize_into_narg(&mut out);
288                }
289            }
290        }
291        out
292    }
293}
294
295impl<G> NargDeserialize for ComposedCommitment<G>
296where
297    G: PrimeGroup + ConditionallySelectable + Encoding<[u8]> + NargSerialize + NargDeserialize,
298    G::Scalar:
299        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
300{
301    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
302        if buf.is_empty() {
303            return Err(VerificationError);
304        }
305        let (tag_bytes, rest) = buf.split_at(1);
306        *buf = rest;
307        match tag_bytes[0] {
308            TAG_SIMPLE => {
309                let len = read_u32(buf)? as usize;
310                let mut elems = Vec::new();
311                for _ in 0..len {
312                    elems.push(G::deserialize_from_narg(buf)?);
313                }
314                Ok(ComposedCommitment::Simple(elems))
315            }
316            TAG_AND => {
317                let len = read_u32(buf)? as usize;
318                let mut entries = Vec::new();
319                for _ in 0..len {
320                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
321                }
322                Ok(ComposedCommitment::And(entries))
323            }
324            TAG_OR => {
325                let len = read_u32(buf)? as usize;
326                let mut entries = Vec::new();
327                for _ in 0..len {
328                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
329                }
330                Ok(ComposedCommitment::Or(entries))
331            }
332            TAG_THRESHOLD => {
333                let len = read_u32(buf)? as usize;
334                let mut entries = Vec::new();
335                for _ in 0..len {
336                    entries.push(ComposedCommitment::deserialize_from_narg(buf)?);
337                }
338                Ok(ComposedCommitment::Threshold(entries))
339            }
340            _ => Err(VerificationError),
341        }
342    }
343}
344
345impl<G> Encoding<[u8]> for ComposedResponse<G>
346where
347    G: PrimeGroup
348        + ConditionallySelectable
349        + Encoding<[u8]>
350        + NargSerialize
351        + NargDeserialize
352        + MultiScalarMul,
353    G::Scalar:
354        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
355{
356    fn encode(&self) -> impl AsRef<[u8]> {
357        let mut out = Vec::new();
358        match self {
359            ComposedResponse::Simple(responses) => {
360                out.push(TAG_SIMPLE);
361                write_len(&mut out, responses.len());
362                for r in responses {
363                    r.serialize_into_narg(&mut out);
364                }
365            }
366            ComposedResponse::And(entries) => {
367                out.push(TAG_AND);
368                write_len(&mut out, entries.len());
369                for r in entries {
370                    r.serialize_into_narg(&mut out);
371                }
372            }
373            ComposedResponse::Or(challenges, responses) => {
374                out.push(TAG_OR);
375                write_len(&mut out, challenges.len());
376                for c in challenges {
377                    c.serialize_into_narg(&mut out);
378                }
379                write_len(&mut out, responses.len());
380                for r in responses {
381                    r.serialize_into_narg(&mut out);
382                }
383            }
384            ComposedResponse::Threshold(challenges, responses) => {
385                out.push(TAG_THRESHOLD);
386                write_len(&mut out, challenges.len());
387                for c in challenges {
388                    c.serialize_into_narg(&mut out);
389                }
390                write_len(&mut out, responses.len());
391                for r in responses {
392                    r.serialize_into_narg(&mut out);
393                }
394            }
395        }
396        out
397    }
398}
399
400impl<G> NargDeserialize for ComposedResponse<G>
401where
402    G: PrimeGroup
403        + ConditionallySelectable
404        + Encoding<[u8]>
405        + NargSerialize
406        + NargDeserialize
407        + MultiScalarMul,
408    G::Scalar:
409        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
410{
411    fn deserialize_from_narg(buf: &mut &[u8]) -> VerificationResult<Self> {
412        if buf.is_empty() {
413            return Err(VerificationError);
414        }
415        let (tag_bytes, rest) = buf.split_at(1);
416        *buf = rest;
417        match tag_bytes[0] {
418            TAG_SIMPLE => {
419                let len = read_u32(buf)? as usize;
420                let mut elems = Vec::new();
421                for _ in 0..len {
422                    elems.push(G::Scalar::deserialize_from_narg(buf)?);
423                }
424                Ok(ComposedResponse::Simple(elems))
425            }
426            TAG_AND => {
427                let len = read_u32(buf)? as usize;
428                let mut entries = Vec::new();
429                for _ in 0..len {
430                    entries.push(ComposedResponse::deserialize_from_narg(buf)?);
431                }
432                Ok(ComposedResponse::And(entries))
433            }
434            TAG_OR => {
435                let ch_len = read_u32(buf)? as usize;
436                let mut challenges = Vec::new();
437                for _ in 0..ch_len {
438                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
439                }
440                let resp_len = read_u32(buf)? as usize;
441                let mut responses = Vec::new();
442                for _ in 0..resp_len {
443                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
444                }
445                Ok(ComposedResponse::Or(challenges, responses))
446            }
447            TAG_THRESHOLD => {
448                let ch_len = read_u32(buf)? as usize;
449                let mut challenges = Vec::new();
450                for _ in 0..ch_len {
451                    challenges.push(G::Scalar::deserialize_from_narg(buf)?);
452                }
453                let resp_len = read_u32(buf)? as usize;
454                let mut responses = Vec::new();
455                for _ in 0..resp_len {
456                    responses.push(ComposedResponse::deserialize_from_narg(buf)?);
457                }
458                Ok(ComposedResponse::Threshold(challenges, responses))
459            }
460            _ => Err(VerificationError),
461        }
462    }
463}
464
465impl<G> ComposedResponse<G>
466where
467    G: PrimeGroup
468        + ConditionallySelectable
469        + Encoding<[u8]>
470        + NargSerialize
471        + NargDeserialize
472        + MultiScalarMul,
473    G::Scalar:
474        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
475{
476    /// Conditionally select between two ComposedResponse values.
477    /// This function performs constant-time selection of the response values.
478    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
479        match (a, b) {
480            (ComposedResponse::Simple(a_scalars), ComposedResponse::Simple(b_scalars)) => {
481                // Both vectors must have the same length for this to work
482                debug_assert_eq!(a_scalars.len(), b_scalars.len());
483                let selected: Vec<G::Scalar> = a_scalars
484                    .iter()
485                    .zip(b_scalars.iter())
486                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
487                    .collect();
488                ComposedResponse::Simple(selected)
489            }
490            (ComposedResponse::And(a_responses), ComposedResponse::And(b_responses)) => {
491                debug_assert_eq!(a_responses.len(), b_responses.len());
492                let selected: Vec<ComposedResponse<G>> = a_responses
493                    .iter()
494                    .zip(b_responses.iter())
495                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
496                    .collect();
497                ComposedResponse::And(selected)
498            }
499            (
500                ComposedResponse::Or(a_challenges, a_responses),
501                ComposedResponse::Or(b_challenges, b_responses),
502            ) => {
503                debug_assert_eq!(a_challenges.len(), b_challenges.len());
504                debug_assert_eq!(a_responses.len(), b_responses.len());
505
506                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
507                    .iter()
508                    .zip(b_challenges.iter())
509                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
510                    .collect();
511
512                let selected_responses: Vec<ComposedResponse<G>> = a_responses
513                    .iter()
514                    .zip(b_responses.iter())
515                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
516                    .collect();
517
518                ComposedResponse::Or(selected_challenges, selected_responses)
519            }
520            (
521                ComposedResponse::Threshold(a_challenges, a_responses),
522                ComposedResponse::Threshold(b_challenges, b_responses),
523            ) => {
524                debug_assert_eq!(a_challenges.len(), b_challenges.len());
525                debug_assert_eq!(a_responses.len(), b_responses.len());
526
527                let selected_challenges: Vec<ComposedChallenge<G>> = a_challenges
528                    .iter()
529                    .zip(b_challenges.iter())
530                    .map(|(a, b)| G::Scalar::conditional_select(a, b, choice))
531                    .collect();
532
533                let selected_responses: Vec<ComposedResponse<G>> = a_responses
534                    .iter()
535                    .zip(b_responses.iter())
536                    .map(|(a, b)| ComposedResponse::conditional_select(a, b, choice))
537                    .collect();
538
539                ComposedResponse::Threshold(selected_challenges, selected_responses)
540            }
541            _ => {
542                unreachable!("Mismatched ComposedResponse variants in conditional_select");
543            }
544        }
545    }
546}
547
/// Witness for a [`ComposedRelation`]; each variant corresponds to the relation variant of
/// the same name, with sub-witnesses matched to sub-relations by position.
#[derive(Clone)]
pub enum ComposedWitness<G>
where
    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
{
    /// Witness for a single linear relation.
    Simple(<CanonicalLinearRelation<G> as SigmaProtocol>::Witness),
    /// One witness per conjunct; all must be valid.
    And(Vec<ComposedWitness<G>>),
    /// One witness per disjunct; at least one must be valid.
    Or(Vec<ComposedWitness<G>>),
    /// One witness per sub-relation; at least the relation's threshold must be valid.
    Threshold(Vec<ComposedWitness<G>>),
}
560
561impl<G> ComposedWitness<G>
562where
563    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
564    G::Scalar: Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]>,
565{
566    /// Create a [ComposedWitness] for an AND relation from the given list of witnesses.
567    pub fn and<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
568        Self::And(witness.into_iter().map(|x| x.into()).collect())
569    }
570
571    /// Create a [ComposedWitness] for an OR relation from the given list of witnesses.
572    pub fn or<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
573        Self::Or(witness.into_iter().map(|x| x.into()).collect())
574    }
575
576    /// Create a [ComposedWitness] for a threshold relation from the given list of witnesses.
577    pub fn threshold<T: Into<ComposedWitness<G>>>(witness: impl IntoIterator<Item = T>) -> Self {
578        Self::Threshold(witness.into_iter().map(|x| x.into()).collect())
579    }
580}
581
582impl<G> From<<CanonicalLinearRelation<G> as SigmaProtocol>::Witness> for ComposedWitness<G>
583where
584    G: PrimeGroup + Encoding<[u8]> + NargSerialize + NargDeserialize + MultiScalarMul,
585    G::Scalar:
586        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
587{
588    fn from(value: <CanonicalLinearRelation<G> as SigmaProtocol>::Witness) -> Self {
589        Self::Simple(value)
590    }
591}
592
/// Challenge type used throughout the composition; identical to the challenge type of
/// [`CanonicalLinearRelation`].
type ComposedChallenge<G> = <CanonicalLinearRelation<G> as SigmaProtocol>::Challenge;
594fn threshold_x<F: PrimeField>(index: usize) -> F {
595    F::from((index + 1) as u64)
596}
597
598fn poly_mul_linear<F: Field>(coeffs: &[F], constant: F) -> Vec<F> {
599    let mut out = vec![F::ZERO; coeffs.len() + 1];
600    for (i, coeff) in coeffs.iter().enumerate() {
601        out[i] += *coeff * constant;
602        out[i + 1] += *coeff;
603    }
604    out
605}
606
607fn interpolate_polynomial<F: Field>(points: &[Evaluation<F>]) -> Result<Vec<F>, Error> {
608    if points.is_empty() {
609        return Err(Error::InvalidInstanceWitnessPair);
610    }
611
612    let mut coeffs = vec![F::ZERO; points.len()];
613
614    for (i, point_i) in points.iter().enumerate() {
615        let mut basis = vec![F::ONE];
616        let mut denom = F::ONE;
617
618        for (j, point_j) in points.iter().enumerate() {
619            if i == j {
620                continue;
621            }
622            denom *= point_i.x - point_j.x;
623            basis = poly_mul_linear::<F>(&basis, -point_j.x);
624        }
625
626        let denom_inv = denom.invert();
627        if denom_inv.is_none().into() {
628            return Err(Error::InvalidInstanceWitnessPair);
629        }
630        let scale = point_i.y * denom_inv.unwrap_or(F::ZERO);
631        for (coeff, basis_coeff) in coeffs.iter_mut().zip(basis.iter()) {
632            *coeff += *basis_coeff * scale;
633        }
634    }
635
636    Ok(coeffs)
637}
638
639fn evaluate_polynomial<F: Field>(coeffs: &[F], x: F) -> F {
640    coeffs
641        .iter()
642        .rev()
643        .fold(F::ZERO, |acc, coeff| acc * x + coeff)
644}
645
646fn expand_threshold_challenges<F: PrimeField>(
647    threshold: usize,
648    total: usize,
649    challenge: F,
650    compressed_challenges: &[F],
651) -> Result<Vec<F>, Error> {
652    if threshold == 0 || threshold > total {
653        return Err(Error::InvalidInstanceWitnessPair);
654    }
655
656    let degree = total - threshold;
657    if compressed_challenges.len() != degree {
658        return Err(Error::InvalidInstanceWitnessPair);
659    }
660
661    let mut points = Vec::with_capacity(degree + 1);
662    points.push(Evaluation {
663        x: F::ZERO,
664        y: challenge,
665    });
666    for (index, share) in compressed_challenges.iter().enumerate() {
667        points.push(Evaluation {
668            x: threshold_x::<F>(index),
669            y: *share,
670        });
671    }
672
673    let coeffs = interpolate_polynomial::<F>(&points)?;
674    let mut challenges = Vec::with_capacity(total);
675    for index in 0..total {
676        challenges.push(evaluate_polynomial::<F>(&coeffs, threshold_x::<F>(index)));
677    }
678
679    Ok(challenges)
680}
681
682fn count_choices(choices: &[Choice]) -> usize {
683    let mut sum: u32 = 0;
684    for choice in choices {
685        let inc = sum.wrapping_add(1);
686        sum = u32::conditional_select(&sum, &inc, *choice);
687    }
688    sum as usize
689}
690
/// A point `(x, y)` on a polynomial, used as input to `interpolate_polynomial`.
#[derive(Clone, Copy)]
struct Evaluation<T> {
    /// Evaluation point.
    x: T,
    /// Value at `x`.
    y: T,
}
696
697impl<T: ConditionallySelectable> ConditionallySelectable for Evaluation<T> {
698    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
699        Evaluation {
700            x: T::conditional_select(&a.x, &b.x, choice),
701            y: T::conditional_select(&a.y, &b.y, choice),
702        }
703    }
704}
705
706impl<T> From<(T, T)> for Evaluation<T> {
707    fn from(value: (T, T)) -> Self {
708        Evaluation {
709            x: value.0,
710            y: value.1,
711        }
712    }
713}
714
715fn conditional_swap_point<T: ConditionallySelectable>(
716    points: &mut [T],
717    left: usize,
718    right: usize,
719    swap: Choice,
720) {
721    if left == right {
722        return;
723    }
724    if left < right {
725        let (head, tail) = points.split_at_mut(right);
726        T::conditional_swap(&mut head[left], &mut tail[0], swap);
727    } else {
728        let (head, tail) = points.split_at_mut(left);
729        T::conditional_swap(&mut tail[0], &mut head[right], swap);
730    }
731}
732
/// Recursive helper for `oblivious_compact_points` that works on a power-of-two-length
/// slice with a rotation `offset`.
///
/// NOTE(review): this appears to follow OROffCompact from Sasy, Johnson & Goldberg, "Fast
/// Fully Oblivious Compaction and Shuffling". Per that algorithm, the entries whose `marks`
/// are set end up contiguous starting at position `offset`, wrapping modulo `points.len()`.
/// Confirm against the paper.
///
/// All element movement goes through `conditional_swap_point`. The index pairs it touches
/// depend only on `points.len()`; the marks only affect the `Choice` bits.
/// NOTE(review): the `%` and comparisons on the mark count `m` below are not audited for
/// constant time.
///
/// `marks.len()` must equal `points.len()`, and the length must be a power of two; both are
/// checked only by `debug_assert!`.
fn oroffcompact_points<T: ConditionallySelectable>(
    points: &mut [T],
    marks: &[Choice],
    offset: usize,
) {
    let n = points.len();
    if n <= 1 {
        return;
    }
    debug_assert_eq!(n, marks.len());
    debug_assert!(n.is_power_of_two());

    // m = number of marked entries in the first half (unused by the n == 2 base case).
    let half = n / 2;
    let mut m = 0usize;
    for mark in &marks[..half] {
        m += mark.unwrap_u8() as usize;
    }

    if n == 2 {
        // Swap when exactly the second entry is marked, flipped by the offset's parity.
        let z = Choice::from((offset & 1) as u8);
        let b = ((!marks[0]) & marks[1]) ^ z;
        conditional_swap_point(points, 0, 1, b);
        return;
    }

    // Compact each half with offsets chosen so the two outputs can be merged below.
    let offset_mod = offset % half;
    oroffcompact_points(&mut points[..half], &marks[..half], offset_mod);
    let offset_plus_m_mod = (offset + m) % half;
    oroffcompact_points(&mut points[half..], &marks[half..], offset_plus_m_mod);

    // Merge: conditionally swap position i with i + half. `s` selects which side of the
    // boundary `offset_plus_m_mod` gets swapped.
    let s = Choice::from(((offset_mod + m) >= half) as u8) ^ Choice::from((offset >= half) as u8);
    for i in 0..half {
        let b = s ^ Choice::from((i >= offset_plus_m_mod) as u8);
        conditional_swap_point(points, i, i + half, b);
    }
}
769
/// Obliviously compacts `points`, of any length, according to `marks`.
///
/// NOTE(review): this appears to follow ORCompact from Sasy, Johnson & Goldberg, "Fast Fully
/// Oblivious Compaction and Shuffling". Per that algorithm, the marked entries move to the
/// front, presumably preserving their relative order. Confirm against the paper and callers.
///
/// The length is split as `n = n1 + n2`, where `n1` is the largest power of two `<= n`:
/// 1. The first `n2` entries are compacted recursively.
/// 2. The last `n1` entries go through `oroffcompact_points`, with an offset chosen to line
///    the two parts up.
/// 3. A final pass of conditional swaps between positions `i` and `i + n1` merges them.
///
/// `marks.len()` must equal `points.len()`; this is checked only by `debug_assert!`.
fn oblivious_compact_points<T: ConditionallySelectable>(points: &mut [T], marks: &[Choice]) {
    let n = points.len();
    if n == 0 {
        return;
    }
    debug_assert_eq!(n, marks.len());

    // n1 = largest power of two <= n; n2 = remainder.
    let n1 = 1usize << (usize::BITS as usize - 1 - n.leading_zeros() as usize);
    let n2 = n - n1;
    // m = number of marked entries in the leading n2-element part.
    let mut m = 0usize;
    for mark in &marks[..n2] {
        m += mark.unwrap_u8() as usize;
    }

    if n2 > 0 {
        oblivious_compact_points(&mut points[..n2], &marks[..n2]);
    }
    oroffcompact_points(&mut points[n2..], &marks[n2..], (n1 - n2 + m) % n1);

    // Merge: positions at or beyond the m compacted entries of the first part swap with their
    // partner n1 positions later.
    for i in 0..n2 {
        let b = Choice::from((i >= m) as u8);
        conditional_swap_point(points, i, i + n1, b);
    }
}
794
795impl<G> ComposedRelation<G>
796where
797    G: PrimeGroup
798        + ConstantTimeEq
799        + ConditionallySelectable
800        + Encoding<[u8]>
801        + NargSerialize
802        + NargDeserialize
803        + MultiScalarMul,
804    G::Scalar:
805        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
806{
807    fn is_witness_valid(&self, witness: &ComposedWitness<G>) -> Choice {
808        match (self, witness) {
809            (ComposedRelation::Simple(instance), ComposedWitness::Simple(witness)) => {
810                instance.is_witness_valid(witness)
811            }
812            (ComposedRelation::And(instances), ComposedWitness::And(witnesses)) => instances
813                .iter()
814                .zip(witnesses)
815                .fold(Choice::from(1), |bit, (instance, witness)| {
816                    bit & instance.is_witness_valid(witness)
817                }),
818            (ComposedRelation::Or(instances), ComposedWitness::Or(witnesses)) => instances
819                .iter()
820                .zip(witnesses)
821                .fold(Choice::from(0), |bit, (instance, witness)| {
822                    bit | instance.is_witness_valid(witness)
823                }),
824            (
825                ComposedRelation::Threshold(threshold, instances),
826                ComposedWitness::Threshold(witnesses),
827            ) => {
828                if *threshold == 0 || instances.len() != witnesses.len() {
829                    return Choice::from(0);
830                }
831                let mut count = 0usize;
832                for (instance, witness) in instances.iter().zip(witnesses) {
833                    if instance.is_witness_valid(witness).unwrap_u8() == 1 {
834                        count += 1;
835                    }
836                }
837                Choice::from((count >= *threshold) as u8)
838            }
839            _ => Choice::from(0),
840        }
841    }
842
843    fn prover_commit_simple(
844        protocol: &CanonicalLinearRelation<G>,
845        witness: &<CanonicalLinearRelation<G> as SigmaProtocol>::Witness,
846        rng: &mut impl ScalarRng,
847    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
848        protocol.prover_commit(witness, rng).map(|(c, s)| {
849            (
850                ComposedCommitment::Simple(c),
851                ComposedProverState::Simple(s),
852            )
853        })
854    }
855
856    fn prover_response_simple(
857        instance: &CanonicalLinearRelation<G>,
858        state: <CanonicalLinearRelation<G> as SigmaProtocol>::ProverState,
859        challenge: &<CanonicalLinearRelation<G> as SigmaProtocol>::Challenge,
860    ) -> Result<ComposedResponse<G>, Error> {
861        instance
862            .prover_response(state, challenge)
863            .map(ComposedResponse::Simple)
864    }
865
866    fn prover_commit_and(
867        protocols: &[ComposedRelation<G>],
868        witnesses: &[ComposedWitness<G>],
869        rng: &mut impl ScalarRng,
870    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error> {
871        if protocols.len() != witnesses.len() {
872            return Err(Error::InvalidInstanceWitnessPair);
873        }
874
875        let mut commitments = Vec::with_capacity(protocols.len());
876        let mut prover_states = Vec::with_capacity(protocols.len());
877
878        for (p, w) in protocols.iter().zip(witnesses.iter()) {
879            let (mut c, s) = p.prover_commit(w, rng)?;
880            let commitment = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
881            if !c.is_empty() {
882                return Err(Error::InvalidInstanceWitnessPair);
883            }
884            commitments.push(commitment);
885            prover_states.push(s);
886        }
887
888        Ok((
889            ComposedCommitment::And(commitments),
890            ComposedProverState::And(prover_states),
891        ))
892    }
893
894    fn prover_response_and(
895        instances: &[ComposedRelation<G>],
896        prover_state: Vec<ComposedProverState<G>>,
897        challenge: &ComposedChallenge<G>,
898    ) -> Result<ComposedResponse<G>, Error> {
899        if instances.len() != prover_state.len() {
900            return Err(Error::InvalidInstanceWitnessPair);
901        }
902
903        let responses: Result<Vec<_>, _> = instances
904            .iter()
905            .zip(prover_state)
906            .map(|(p, s)| {
907                let mut res = p.prover_response(s, challenge)?;
908                res.pop().ok_or(Error::InvalidInstanceWitnessPair)
909            })
910            .collect();
911
912        Ok(ComposedResponse::And(responses?))
913    }
914
915    fn prover_commit_or(
916        instances: &[ComposedRelation<G>],
917        witnesses: &[ComposedWitness<G>],
918        rng: &mut impl ScalarRng,
919    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
920    where
921        G: ConditionallySelectable,
922    {
923        if instances.len() != witnesses.len() {
924            return Err(Error::InvalidInstanceWitnessPair);
925        }
926
927        let mut commitments = Vec::new();
928        let mut prover_states = Vec::new();
929
930        // Selector value set when the first valid witness is found.
931        let mut valid_witness_found = Choice::from(0);
932        for (i, w) in witnesses.iter().enumerate() {
933            let (mut commitment_vec, prover_state) = instances[i].prover_commit(w, rng)?;
934            let commitment = commitment_vec
935                .pop()
936                .ok_or(Error::InvalidInstanceWitnessPair)?;
937            if !commitment_vec.is_empty() {
938                return Err(Error::InvalidInstanceWitnessPair);
939            }
940
941            let (mut simulated_commitment_vec, simulated_challenge, mut simulated_response_vec) =
942                instances[i].simulate_transcript(rng)?;
943            let simulated_commitment = simulated_commitment_vec
944                .pop()
945                .ok_or(Error::InvalidInstanceWitnessPair)?;
946            if !simulated_commitment_vec.is_empty() {
947                return Err(Error::InvalidInstanceWitnessPair);
948            }
949            let simulated_response = simulated_response_vec
950                .pop()
951                .ok_or(Error::InvalidInstanceWitnessPair)?;
952            if !simulated_response_vec.is_empty() {
953                return Err(Error::InvalidInstanceWitnessPair);
954            }
955
956            let valid_witness = instances[i].is_witness_valid(w) & !valid_witness_found;
957            let select_witness = valid_witness;
958
959            let commitment = ComposedCommitment::conditional_select(
960                &simulated_commitment,
961                &commitment,
962                select_witness,
963            );
964
965            commitments.push(commitment);
966            prover_states.push(ComposedOrProverStateEntry(
967                select_witness,
968                prover_state,
969                simulated_challenge,
970                simulated_response,
971            ));
972
973            valid_witness_found |= valid_witness;
974        }
975
976        if valid_witness_found.unwrap_u8() == 0 {
977            Err(Error::InvalidInstanceWitnessPair)
978        } else {
979            Ok((
980                ComposedCommitment::Or(commitments),
981                ComposedProverState::Or(prover_states),
982            ))
983        }
984    }
985
986    fn prover_response_or(
987        instances: &[ComposedRelation<G>],
988        prover_state: ComposedOrProverState<G>,
989        challenge: &ComposedChallenge<G>,
990    ) -> Result<ComposedResponse<G>, Error> {
991        let mut result_challenges = Vec::with_capacity(instances.len());
992        let mut result_responses = Vec::with_capacity(instances.len());
993
994        let mut witness_challenge = *challenge;
995        for ComposedOrProverStateEntry(
996            valid_witness,
997            _prover_state,
998            simulated_challenge,
999            _simulated_response,
1000        ) in &prover_state
1001        {
1002            let c = G::Scalar::conditional_select(
1003                simulated_challenge,
1004                &G::Scalar::ZERO,
1005                *valid_witness,
1006            );
1007            witness_challenge -= c;
1008        }
1009        for (
1010            instance,
1011            ComposedOrProverStateEntry(
1012                valid_witness,
1013                prover_state,
1014                simulated_challenge,
1015                simulated_response,
1016            ),
1017        ) in instances.iter().zip(prover_state)
1018        {
1019            let challenge_i = G::Scalar::conditional_select(
1020                &simulated_challenge,
1021                &witness_challenge,
1022                valid_witness,
1023            );
1024
1025            let mut response_vec = instance.prover_response(prover_state, &challenge_i)?;
1026            let response = response_vec
1027                .pop()
1028                .ok_or(Error::InvalidInstanceWitnessPair)?;
1029            if !response_vec.is_empty() {
1030                return Err(Error::InvalidInstanceWitnessPair);
1031            }
1032            let response =
1033                ComposedResponse::conditional_select(&simulated_response, &response, valid_witness);
1034
1035            result_challenges.push(challenge_i);
1036            result_responses.push(response.clone());
1037        }
1038
1039        result_challenges.pop();
1040        Ok(ComposedResponse::Or(result_challenges, result_responses))
1041    }
1042
1043    fn prover_commit_threshold(
1044        threshold: usize,
1045        instances: &[ComposedRelation<G>],
1046        witnesses: &[ComposedWitness<G>],
1047        rng: &mut impl ScalarRng,
1048    ) -> Result<(ComposedCommitment<G>, ComposedProverState<G>), Error>
1049    where
1050        G: ConditionallySelectable,
1051    {
1052        if instances.len() != witnesses.len() || threshold == 0 || threshold > instances.len() {
1053            return Err(Error::InvalidInstanceWitnessPair);
1054        }
1055        let degree = instances.len() - threshold;
1056
1057        let valid_witnesses = instances
1058            .iter()
1059            .zip(witnesses.iter())
1060            .map(|(x, w)| x.is_witness_valid(w))
1061            .collect::<Vec<Choice>>();
1062
1063        // Degree-(t-1) interpolation can only satisfy t fixed points.
1064        let invalid_count = instances.len() - count_choices(&valid_witnesses);
1065        if invalid_count > degree {
1066            return Err(Error::InvalidInstanceWitnessPair);
1067        }
1068
1069        let mut remaining_seeds = (degree - invalid_count) as u32;
1070        let mut commitments = Vec::with_capacity(instances.len());
1071        let mut prover_states = Vec::with_capacity(instances.len());
1072        for (i, (instance, witness)) in instances.iter().zip(witnesses.iter()).enumerate() {
1073            let (mut commitment_vec, prover_state) = instance.prover_commit(witness, rng)?;
1074            let commitment = commitment_vec
1075                .pop()
1076                .ok_or(Error::InvalidInstanceWitnessPair)?;
1077            if !commitment_vec.is_empty() {
1078                return Err(Error::InvalidInstanceWitnessPair);
1079            }
1080
1081            let (mut simulated_commitments, simulated_challenge, mut simulated_responses) =
1082                instance.simulate_transcript(rng)?;
1083            let simulated_commitment = simulated_commitments
1084                .pop()
1085                .ok_or(Error::InvalidInstanceWitnessPair)?;
1086            if !simulated_commitments.is_empty() {
1087                return Err(Error::InvalidInstanceWitnessPair);
1088            }
1089            let simulated_response = simulated_responses
1090                .pop()
1091                .ok_or(Error::InvalidInstanceWitnessPair)?;
1092            if !simulated_responses.is_empty() {
1093                return Err(Error::InvalidInstanceWitnessPair);
1094            }
1095
1096            let valid_witness = valid_witnesses[i];
1097            let should_seed = valid_witness & Choice::from((remaining_seeds != 0) as u8);
1098            remaining_seeds = remaining_seeds.wrapping_sub(should_seed.unwrap_u8() as u32);
1099            let use_simulator = (!valid_witness) | should_seed;
1100            let commitment = ComposedCommitment::conditional_select(
1101                &commitment,
1102                &simulated_commitment,
1103                use_simulator,
1104            );
1105            commitments.push(commitment);
1106            prover_states.push(ComposedThresholdProverStateEntry {
1107                use_simulator,
1108                prover_state,
1109                simulated_challenge,
1110                simulated_response,
1111            });
1112        }
1113
1114        Ok((
1115            ComposedCommitment::Threshold(commitments),
1116            ComposedProverState::Threshold(prover_states),
1117        ))
1118    }
1119
1120    fn prover_response_threshold(
1121        threshold: usize,
1122        instances: &[ComposedRelation<G>],
1123        prover_states: ComposedThresholdProverState<G>,
1124        challenge: &ComposedChallenge<G>,
1125    ) -> Result<ComposedResponse<G>, Error> {
1126        if threshold == 0 || threshold > instances.len() || instances.len() != prover_states.len() {
1127            return Err(Error::InvalidInstanceWitnessPair);
1128        }
1129        let degree = instances.len() - threshold;
1130
1131        let marks = prover_states
1132            .iter()
1133            .map(|entry| entry.use_simulator)
1134            .collect::<Vec<_>>();
1135        debug_assert_eq!(count_choices(&marks), degree);
1136
1137        let mut points = prover_states
1138            .iter()
1139            .enumerate()
1140            .map(|(i, entry)| Evaluation {
1141                x: threshold_x::<G::Scalar>(i),
1142                y: entry.simulated_challenge,
1143            })
1144            .collect::<Vec<Evaluation<G::Scalar>>>();
1145        oblivious_compact_points(&mut points, &marks);
1146        points.drain(degree..);
1147
1148        let mut full_points = Vec::with_capacity(degree + 1);
1149        full_points.push(Evaluation {
1150            x: G::Scalar::ZERO,
1151            y: *challenge,
1152        });
1153        full_points.extend_from_slice(&points);
1154
1155        let coeffs = interpolate_polynomial::<G::Scalar>(&full_points)?;
1156        let mut compressed_challenges = Vec::with_capacity(degree);
1157        for index in 0..degree {
1158            compressed_challenges.push(evaluate_polynomial::<G::Scalar>(
1159                &coeffs,
1160                threshold_x::<G::Scalar>(index),
1161            ));
1162        }
1163
1164        let expanded_challenges = expand_threshold_challenges::<G::Scalar>(
1165            threshold,
1166            instances.len(),
1167            *challenge,
1168            &compressed_challenges,
1169        )?;
1170
1171        let mut responses = Vec::with_capacity(instances.len());
1172
1173        for (i, (instance, prover_state)) in instances.iter().zip(prover_states).enumerate() {
1174            let poly_challenge = expanded_challenges[i];
1175            let challenge = G::Scalar::conditional_select(
1176                &poly_challenge,
1177                &prover_state.simulated_challenge,
1178                prover_state.use_simulator,
1179            );
1180
1181            let mut response_vec =
1182                instance.prover_response(prover_state.prover_state, &challenge)?;
1183            let response = response_vec
1184                .pop()
1185                .ok_or(Error::InvalidInstanceWitnessPair)?;
1186            if !response_vec.is_empty() {
1187                return Err(Error::InvalidInstanceWitnessPair);
1188            }
1189            let response = ComposedResponse::conditional_select(
1190                &response,
1191                &prover_state.simulated_response,
1192                prover_state.use_simulator,
1193            );
1194
1195            responses.push(response);
1196        }
1197
1198        Ok(ComposedResponse::Threshold(
1199            compressed_challenges,
1200            responses,
1201        ))
1202    }
1203}
1204
1205impl<G> SigmaProtocol for ComposedRelation<G>
1206where
1207    G: PrimeGroup
1208        + ConstantTimeEq
1209        + ConditionallySelectable
1210        + Encoding<[u8]>
1211        + NargSerialize
1212        + NargDeserialize
1213        + MultiScalarMul,
1214    G::Scalar:
1215        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1216{
1217    type Commitment = ComposedCommitment<G>;
1218    type ProverState = ComposedProverState<G>;
1219    type Response = ComposedResponse<G>;
1220    type Witness = ComposedWitness<G>;
1221    type Challenge = ComposedChallenge<G>;
1222
1223    fn prover_commit(
1224        &self,
1225        witness: &Self::Witness,
1226        rng: &mut impl ScalarRng,
1227    ) -> Result<(Vec<Self::Commitment>, Self::ProverState), Error> {
1228        let (commitment, state) = match (self, witness) {
1229            (ComposedRelation::Simple(p), ComposedWitness::Simple(w)) => {
1230                Self::prover_commit_simple(p, w, rng)
1231            }
1232            (ComposedRelation::And(ps), ComposedWitness::And(ws)) => {
1233                Self::prover_commit_and(ps, ws, rng)
1234            }
1235            (ComposedRelation::Or(ps), ComposedWitness::Or(witnesses)) => {
1236                Self::prover_commit_or(ps, witnesses, rng)
1237            }
1238            (ComposedRelation::Threshold(threshold, ps), ComposedWitness::Threshold(witnesses)) => {
1239                Self::prover_commit_threshold(*threshold, ps, witnesses, rng)
1240            }
1241            _ => Err(Error::InvalidInstanceWitnessPair),
1242        }?;
1243        Ok((vec![commitment], state))
1244    }
1245
1246    fn prover_response(
1247        &self,
1248        state: Self::ProverState,
1249        challenge: &Self::Challenge,
1250    ) -> Result<Vec<Self::Response>, Error> {
1251        let response = match (self, state) {
1252            (ComposedRelation::Simple(instance), ComposedProverState::Simple(state)) => {
1253                Self::prover_response_simple(instance, state, challenge)
1254            }
1255            (ComposedRelation::And(instances), ComposedProverState::And(prover_state)) => {
1256                Self::prover_response_and(instances, prover_state, challenge)
1257            }
1258            (ComposedRelation::Or(instances), ComposedProverState::Or(prover_state)) => {
1259                Self::prover_response_or(instances, prover_state, challenge)
1260            }
1261            (
1262                ComposedRelation::Threshold(threshold, instances),
1263                ComposedProverState::Threshold(prover_state),
1264            ) => Self::prover_response_threshold(*threshold, instances, prover_state, challenge),
1265            _ => Err(Error::InvalidInstanceWitnessPair),
1266        }?;
1267        Ok(vec![response])
1268    }
1269
1270    fn verifier(
1271        &self,
1272        commitment: &[Self::Commitment],
1273        challenge: &Self::Challenge,
1274        response: &[Self::Response],
1275    ) -> Result<(), Error> {
1276        let (commitment, response) = match (commitment.first(), response.first()) {
1277            (Some(c), Some(r)) => (c, r),
1278            _ => return Err(Error::InvalidInstanceWitnessPair),
1279        };
1280
1281        match (self, commitment, response) {
1282            (
1283                ComposedRelation::Simple(p),
1284                ComposedCommitment::Simple(c),
1285                ComposedResponse::Simple(r),
1286            ) => p.verifier(c, challenge, r),
1287            (
1288                ComposedRelation::And(ps),
1289                ComposedCommitment::And(commitments),
1290                ComposedResponse::And(responses),
1291            ) => {
1292                if ps.len() != commitments.len() || commitments.len() != responses.len() {
1293                    return Err(Error::InvalidInstanceWitnessPair);
1294                }
1295                ps.iter()
1296                    .zip(commitments)
1297                    .zip(responses)
1298                    .try_for_each(|((p, c), r)| {
1299                        p.verifier(
1300                            core::slice::from_ref(c),
1301                            challenge,
1302                            core::slice::from_ref(r),
1303                        )
1304                    })
1305            }
1306            (
1307                ComposedRelation::Or(ps),
1308                ComposedCommitment::Or(commitments),
1309                ComposedResponse::Or(challenges, responses),
1310            ) => {
1311                if ps.len() != commitments.len()
1312                    || commitments.len() != responses.len()
1313                    || challenges.len() != ps.len() - 1
1314                {
1315                    return Err(Error::InvalidInstanceWitnessPair);
1316                }
1317                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1318                ps.iter()
1319                    .zip(commitments)
1320                    .zip(challenges.iter().chain(&Some(last_challenge)))
1321                    .zip(responses)
1322                    .try_for_each(|(((p, commitment), challenge), response)| {
1323                        p.verifier(
1324                            core::slice::from_ref(commitment),
1325                            challenge,
1326                            core::slice::from_ref(response),
1327                        )
1328                    })
1329            }
1330            (
1331                ComposedRelation::Threshold(threshold, ps),
1332                ComposedCommitment::Threshold(commitments),
1333                ComposedResponse::Threshold(challenges, responses),
1334            ) => {
1335                if *threshold == 0
1336                    || *threshold > ps.len()
1337                    || commitments.len() != ps.len()
1338                    || challenges.len() != ps.len() - *threshold
1339                    || responses.len() != ps.len()
1340                {
1341                    return Err(Error::InvalidInstanceWitnessPair);
1342                }
1343
1344                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1345                    *threshold,
1346                    ps.len(),
1347                    *challenge,
1348                    challenges,
1349                )?;
1350
1351                ps.iter()
1352                    .zip(commitments)
1353                    .zip(full_challenges.iter())
1354                    .zip(responses)
1355                    .try_for_each(|(((p, commitment), challenge), response)| {
1356                        p.verifier(
1357                            core::slice::from_ref(commitment),
1358                            challenge,
1359                            core::slice::from_ref(response),
1360                        )
1361                    })
1362            }
1363            _ => Err(Error::InvalidInstanceWitnessPair),
1364        }
1365    }
1366
1367    fn commitment_len(&self) -> usize {
1368        1
1369    }
1370
1371    fn response_len(&self) -> usize {
1372        1
1373    }
1374
1375    fn instance_label(&self) -> impl AsRef<[u8]> {
1376        match self {
1377            ComposedRelation::Simple(p) => {
1378                let label = p.instance_label();
1379                label.as_ref().to_vec()
1380            }
1381            ComposedRelation::And(ps) => {
1382                let mut bytes = Vec::new();
1383                for p in ps {
1384                    bytes.extend(p.instance_label().as_ref());
1385                }
1386                bytes
1387            }
1388            ComposedRelation::Or(ps) => {
1389                let mut bytes = Vec::new();
1390                for p in ps {
1391                    bytes.extend(p.instance_label().as_ref());
1392                }
1393                bytes
1394            }
1395            ComposedRelation::Threshold(threshold, ps) => {
1396                let mut bytes = Vec::new();
1397                bytes.extend_from_slice(&((*threshold as u64).to_le_bytes()));
1398                for p in ps {
1399                    bytes.extend(p.instance_label().as_ref());
1400                }
1401                bytes
1402            }
1403        }
1404    }
1405
1406    fn protocol_identifier(&self) -> [u8; 64] {
1407        let mut hasher = Sha3_256::new();
1408
1409        match self {
1410            ComposedRelation::Simple(p) => {
1411                // take the digest of the simple protocol id
1412                hasher.update([0u8; 32]);
1413                hasher.update(p.protocol_identifier());
1414            }
1415            ComposedRelation::And(protocols) => {
1416                hasher.update([1u8; 32]);
1417                for p in protocols {
1418                    hasher.update(p.protocol_identifier().as_ref());
1419                }
1420            }
1421            ComposedRelation::Or(protocols) => {
1422                hasher.update([2u8; 32]);
1423                for p in protocols {
1424                    hasher.update(p.protocol_identifier().as_ref());
1425                }
1426            }
1427            ComposedRelation::Threshold(threshold, protocols) => {
1428                hasher.update([3u8; 32]);
1429                hasher.update(((*threshold as u64).to_le_bytes()).as_ref());
1430                for p in protocols {
1431                    hasher.update(p.protocol_identifier().as_ref());
1432                }
1433            }
1434        }
1435
1436        let mut protocol_id = [0u8; 64];
1437        protocol_id[..32].clone_from_slice(&hasher.finalize());
1438        protocol_id
1439    }
1440}
1441
1442impl<G> SigmaProtocolSimulator for ComposedRelation<G>
1443where
1444    G: PrimeGroup
1445        + ConstantTimeEq
1446        + ConditionallySelectable
1447        + Encoding<[u8]>
1448        + NargSerialize
1449        + NargDeserialize
1450        + MultiScalarMul,
1451    G::Scalar:
1452        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1453{
1454    fn simulate_commitment(
1455        &self,
1456        challenge: &Self::Challenge,
1457        response: &[Self::Response],
1458    ) -> Result<Vec<Self::Commitment>, Error> {
1459        let response = response.first().ok_or(Error::InvalidInstanceWitnessPair)?;
1460        let commitment = match (self, response) {
1461            (ComposedRelation::Simple(p), ComposedResponse::Simple(r)) => {
1462                ComposedCommitment::Simple(p.simulate_commitment(challenge, r)?)
1463            }
1464            (ComposedRelation::And(ps), ComposedResponse::And(rs)) => {
1465                let commitments = ps
1466                    .iter()
1467                    .zip(rs)
1468                    .map(|(p, r)| {
1469                        p.simulate_commitment(challenge, core::slice::from_ref(r))
1470                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1471                    })
1472                    .collect::<Result<Vec<_>, _>>()?;
1473                ComposedCommitment::And(commitments)
1474            }
1475            (ComposedRelation::Or(ps), ComposedResponse::Or(challenges, rs)) => {
1476                let last_challenge = *challenge - challenges.iter().sum::<G::Scalar>();
1477                let commitments = ps
1478                    .iter()
1479                    .zip(challenges.iter().chain(&Some(last_challenge)))
1480                    .zip(rs)
1481                    .map(|((p, ch), r)| {
1482                        p.simulate_commitment(ch, core::slice::from_ref(r))
1483                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1484                    })
1485                    .collect::<Result<Vec<_>, _>>()?;
1486                ComposedCommitment::Or(commitments)
1487            }
1488            (
1489                ComposedRelation::Threshold(threshold, ps),
1490                ComposedResponse::Threshold(challenges, rs),
1491            ) => {
1492                if rs.len() != ps.len() || challenges.len() != ps.len() - threshold {
1493                    return Err(Error::InvalidInstanceWitnessPair);
1494                }
1495
1496                let full_challenges = expand_threshold_challenges::<G::Scalar>(
1497                    *threshold,
1498                    ps.len(),
1499                    *challenge,
1500                    challenges,
1501                )?;
1502                let commitments = ps
1503                    .iter()
1504                    .zip(full_challenges.iter())
1505                    .zip(rs)
1506                    .map(|((p, ch), r)| {
1507                        p.simulate_commitment(ch, core::slice::from_ref(r))
1508                            .and_then(|mut c| c.pop().ok_or(Error::InvalidInstanceWitnessPair))
1509                    })
1510                    .collect::<Result<Vec<_>, _>>()?;
1511                ComposedCommitment::Threshold(commitments)
1512            }
1513            _ => unreachable!(),
1514        };
1515
1516        Ok(vec![commitment])
1517    }
1518
1519    fn simulate_response(&self, rng: &mut impl ScalarRng) -> Vec<Self::Response> {
1520        let response = match self {
1521            ComposedRelation::Simple(p) => ComposedResponse::Simple(p.simulate_response(rng)),
1522            ComposedRelation::And(ps) => {
1523                let responses = ps
1524                    .iter()
1525                    .map(|p| {
1526                        let mut r = p.simulate_response(rng);
1527                        r.pop().ok_or(Error::InvalidInstanceWitnessPair)
1528                    })
1529                    .collect::<Result<Vec<_>, _>>()
1530                    .expect("simulate_response invariant");
1531                ComposedResponse::And(responses)
1532            }
1533            ComposedRelation::Or(ps) => {
1534                let challenges = rng.random_scalars_vec::<G>(ps.len()).to_vec();
1535                let mut responses = Vec::with_capacity(ps.len());
1536                for p in ps.iter() {
1537                    let mut r = p.simulate_response(&mut *rng);
1538                    let resp = r
1539                        .pop()
1540                        .expect("simulate_response should return at least one element");
1541                    responses.push(resp);
1542                }
1543                ComposedResponse::Or(challenges, responses)
1544            }
1545            ComposedRelation::Threshold(threshold, ps) => {
1546                if *threshold == 0 || *threshold > ps.len() {
1547                    return vec![ComposedResponse::Threshold(Vec::new(), Vec::new())];
1548                }
1549
1550                let degree = ps.len() - *threshold;
1551                let compressed_challenges = rng.random_scalars_vec::<G>(degree).to_vec();
1552                let mut responses = Vec::with_capacity(ps.len());
1553                for p in ps.iter() {
1554                    let mut r = p.simulate_response(&mut *rng);
1555                    let response = r
1556                        .pop()
1557                        .expect("simulate_response should return at least one element");
1558                    responses.push(response);
1559                }
1560                ComposedResponse::Threshold(compressed_challenges, responses)
1561            }
1562        };
1563        vec![response]
1564    }
1565
1566    fn simulate_transcript(
1567        &self,
1568        rng: &mut impl ScalarRng,
1569    ) -> Result<(Vec<Self::Commitment>, Self::Challenge, Vec<Self::Response>), Error> {
1570        match self {
1571            ComposedRelation::Simple(p) => {
1572                let (c, ch, r) = p.simulate_transcript(rng)?;
1573                Ok((
1574                    vec![ComposedCommitment::Simple(c)],
1575                    ch,
1576                    vec![ComposedResponse::Simple(r)],
1577                ))
1578            }
1579            ComposedRelation::And(ps) => {
1580                let [challenge] = rng.random_scalars::<G, _>();
1581                let mut responses = Vec::with_capacity(ps.len());
1582                for p in ps.iter() {
1583                    let mut resp = p.simulate_response(&mut *rng);
1584                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1585                    if !resp.is_empty() {
1586                        return Err(Error::InvalidInstanceWitnessPair);
1587                    }
1588                    responses.push(response);
1589                }
1590                let commitments = ps
1591                    .iter()
1592                    .enumerate()
1593                    .map(|(i, p)| {
1594                        p.simulate_commitment(&challenge, &[responses[i].clone()])
1595                            .and_then(|mut c| {
1596                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1597                                if !c.is_empty() {
1598                                    return Err(Error::InvalidInstanceWitnessPair);
1599                                }
1600                                Ok(first)
1601                            })
1602                    })
1603                    .collect::<Result<Vec<_>, Error>>()?;
1604
1605                Ok((
1606                    vec![ComposedCommitment::And(commitments)],
1607                    challenge,
1608                    vec![ComposedResponse::And(responses)],
1609                ))
1610            }
1611            ComposedRelation::Or(ps) => {
1612                let challenges = rng.random_scalars_vec::<G>(ps.len() - 1);
1613                let mut responses = Vec::with_capacity(ps.len());
1614                for p in ps.iter() {
1615                    let mut resp = p.simulate_response(&mut *rng);
1616                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1617                    if !resp.is_empty() {
1618                        return Err(Error::InvalidInstanceWitnessPair);
1619                    }
1620                    responses.push(response);
1621                }
1622
1623                let mut commitments = Vec::with_capacity(ps.len());
1624                for i in 0..ps.len() {
1625                    let mut commitment = ps[i].simulate_commitment(
1626                        &if i == ps.len() - 1 {
1627                            challenges.iter().fold(G::Scalar::ZERO, |acc, x| acc - x)
1628                        } else {
1629                            challenges[i]
1630                        },
1631                        &[responses[i].clone()],
1632                    )?;
1633                    let commitment = commitment.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1634                    commitments.push(commitment);
1635                }
1636
1637                Ok((
1638                    vec![ComposedCommitment::Or(commitments)],
1639                    challenges.iter().sum::<G::Scalar>(),
1640                    vec![ComposedResponse::Or(challenges, responses)],
1641                ))
1642            }
1643            ComposedRelation::Threshold(threshold, ps) => {
1644                if *threshold == 0 || *threshold > ps.len() {
1645                    return Err(Error::InvalidInstanceWitnessPair);
1646                }
1647
1648                let degree = ps.len() - *threshold;
1649                let compressed_challenges = rng.random_scalars_vec::<G>(degree);
1650                let mut responses = Vec::with_capacity(ps.len());
1651                for p in ps.iter() {
1652                    let mut resp = p.simulate_response(&mut *rng);
1653                    let response = resp.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1654                    if !resp.is_empty() {
1655                        return Err(Error::InvalidInstanceWitnessPair);
1656                    }
1657                    responses.push(response);
1658                }
1659
1660                let [challenge] = rng.random_scalars::<G, _>();
1661                let full_challenges = expand_threshold_challenges(
1662                    *threshold,
1663                    ps.len(),
1664                    challenge,
1665                    &compressed_challenges,
1666                )?;
1667                let commitments = ps
1668                    .iter()
1669                    .zip(full_challenges.iter())
1670                    .zip(responses.iter())
1671                    .map(|((p, ch), r)| {
1672                        p.simulate_commitment(ch, core::slice::from_ref(r))
1673                            .and_then(|mut c| {
1674                                let first = c.pop().ok_or(Error::InvalidInstanceWitnessPair)?;
1675                                if !c.is_empty() {
1676                                    return Err(Error::InvalidInstanceWitnessPair);
1677                                }
1678                                Ok(first)
1679                            })
1680                    })
1681                    .collect::<Result<Vec<_>, Error>>()?;
1682                Ok((
1683                    vec![ComposedCommitment::Threshold(commitments)],
1684                    challenge,
1685                    vec![ComposedResponse::Threshold(
1686                        compressed_challenges,
1687                        responses,
1688                    )],
1689                ))
1690            }
1691        }
1692    }
1693}
1694
1695impl<G> ComposedRelation<G>
1696where
1697    G: PrimeGroup
1698        + ConstantTimeEq
1699        + ConditionallySelectable
1700        + Encoding<[u8]>
1701        + NargSerialize
1702        + NargDeserialize
1703        + MultiScalarMul,
1704    G::Scalar:
1705        Encoding<[u8]> + NargSerialize + NargDeserialize + Decoding<[u8]> + ConditionallySelectable,
1706{
1707    /// Convert this Protocol into a non-interactive zero-knowledge proof
1708    /// using the Shake128DuplexSponge codec and a specified session identifier.
1709    ///
1710    /// This method provides a convenient way to create a NIZK from a Protocol
1711    /// without exposing the specific codec type to the API caller.
1712    ///
1713    /// # Parameters
1714    /// - `session_identifier`: Domain separator bytes for the Fiat-Shamir transform
1715    ///
1716    /// # Returns
1717    /// A `Nizk` instance ready for proving and verification
1718    pub fn into_nizk(self, session_identifier: &[u8]) -> Nizk<ComposedRelation<G>> {
1719        Nizk::new(session_identifier, self)
1720    }
1721}