Skip to main content

plonky_cat_fri/
lib.rs

1#![forbid(unsafe_code)]
2
3mod error;
4pub mod whir;
5
6pub use self::error::Error;
7pub use self::whir::Whir;
8
9use std::marker::PhantomData;
10
11use plonky_cat_field::Field;
12use plonky_cat_hash::Hasher;
13use plonky_cat_merkle::MerkleTree;
14use plonky_cat_reduce::{
15    ProverContinue, ProverDone, ProverStep, TranscriptSerialize,
16    ReductionFunctor,
17    VerifierContinue, VerifierDone, VerifierStep,
18};
19
20// -- Claims and witnesses --
21
22#[derive(Debug, Clone)]
23pub struct FriClaim<F> {
24    merkle_root: F,
25    codeword_len: usize,
26    target_len: usize,
27}
28
29impl<F: Field> FriClaim<F> {
30    #[must_use]
31    pub fn new(merkle_root: F, codeword_len: usize, target_len: usize) -> Self {
32        Self { merkle_root, codeword_len, target_len }
33    }
34
35    /// Convenience: fold until a single element remains.
36    #[must_use]
37    pub fn until_constant(merkle_root: F, codeword_len: usize) -> Self {
38        Self::new(merkle_root, codeword_len, 1)
39    }
40
41    #[must_use]
42    pub fn merkle_root(&self) -> F {
43        self.merkle_root
44    }
45
46    #[must_use]
47    pub fn codeword_len(&self) -> usize {
48        self.codeword_len
49    }
50
51    #[must_use]
52    pub fn target_len(&self) -> usize {
53        self.target_len
54    }
55}
56
57#[derive(Debug, Clone)]
58pub struct FriWitness<H: Hasher> {
59    codeword: Vec<H::F>,
60    tree: MerkleTree<H>,
61}
62
63impl<H: Hasher> FriWitness<H> {
64    pub fn build(codeword: Vec<H::F>) -> Result<Self, Error> {
65        match () {
66            () if codeword.is_empty() => Err(Error::CodewordEmpty),
67            () if !codeword.len().is_power_of_two() =>
68                Err(Error::CodewordLengthNotPowerOfTwo { len: codeword.len() }),
69            () => {
70                let tree = MerkleTree::<H>::build(codeword.clone())?;
71                Ok(Self { codeword, tree })
72            }
73        }
74    }
75
76    pub fn merkle_root(&self) -> Result<H::F, Error> {
77        self.tree.root().map_err(Error::from)
78    }
79
80    #[must_use]
81    pub fn codeword(&self) -> &[H::F] {
82        &self.codeword
83    }
84}
85
86// -- Round message: the folded codeword's Merkle root --
87
88#[derive(Debug, Clone)]
89pub struct FriRoundMsg<F> {
90    folded_root: F,
91}
92
93impl<F: Field> FriRoundMsg<F> {
94    #[must_use]
95    pub fn new(folded_root: F) -> Self {
96        Self { folded_root }
97    }
98
99    #[must_use]
100    pub fn folded_root(&self) -> F {
101        self.folded_root
102    }
103}
104
105impl<F: Field> TranscriptSerialize<F> for FriRoundMsg<F> {
106    fn to_field_elements(&self) -> Vec<F> {
107        vec![self.folded_root]
108    }
109}
110
111// -- Base opening: the final (constant) codeword value --
112
113#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct FriOpening<F> {
115    constant_value: F,
116}
117
118impl<F: Field> FriOpening<F> {
119    #[must_use]
120    pub fn new(constant_value: F) -> Self {
121        Self { constant_value }
122    }
123
124    pub fn into_constant_value(self) -> F {
125        self.constant_value
126    }
127
128    #[must_use]
129    pub fn constant_value(&self) -> F {
130        self.constant_value
131    }
132}
133
134impl<F: Field> TranscriptSerialize<F> for FriOpening<F> {
135    fn to_field_elements(&self) -> Vec<F> {
136        vec![self.constant_value]
137    }
138}
139
140// -- FRI folding: halve the codeword with challenge r --
141// fold(w, r)[i] = (w[2i] + w[2i+1]) / 2 + r * (w[2i] - w[2i+1]) / 2
142// Simplified: fold(w, r)[i] = w[2i] * (1 + r) / 2 + w[2i+1] * (1 - r) / 2
143// Even simpler for v0.1: fold(w, r)[i] = w[2i] + r * w[2i+1]
144
145fn fold_codeword<F: Field>(codeword: &[F], challenge: F) -> Vec<F> {
146    let half = codeword.len() / 2;
147    (0..half)
148        .map(|i| codeword[2 * i] + challenge * codeword[2 * i + 1])
149        .collect()
150}
151
152// -- FRI as ReductionFunctor --
153
154pub struct Fri<H> {
155    _marker: PhantomData<H>,
156}
157
158impl<H: Hasher> ReductionFunctor for Fri<H>
159where
160    H::F: Field,
161{
162    type Claim = FriClaim<H::F>;
163    type Witness = FriWitness<H>;
164    type RoundMsg = FriRoundMsg<H::F>;
165    type Challenge = H::F;
166    type BaseOpening = FriOpening<H::F>;
167    type Error = Error;
168
169    fn prover_step(
170        claim: Self::Claim,
171        witness: Self::Witness,
172        challenge: Self::Challenge,
173    ) -> Result<
174        ProverStep<Self::Claim, Self::Witness, Self::RoundMsg, Self::BaseOpening>,
175        Self::Error,
176    > {
177        if claim.codeword_len <= claim.target_len {
178            witness.codeword.first()
179                .copied()
180                .ok_or(Error::CodewordEmpty)
181                .map(|val| ProverStep::Done(ProverDone::new(
182                    claim,
183                    witness,
184                    FriOpening::new(val),
185                )))
186        } else {
187            let folded = fold_codeword(&witness.codeword, challenge);
188            let new_witness = FriWitness::<H>::build(folded)?;
189            let new_root = new_witness.tree.root()?;
190            let msg = FriRoundMsg::new(new_root);
191
192            Ok(ProverStep::Continue(ProverContinue::new(
193                FriClaim::new(new_root, claim.codeword_len / 2, claim.target_len),
194                new_witness,
195                msg,
196            )))
197        }
198    }
199
200    fn verifier_step(
201        claim: Self::Claim,
202        message: Self::RoundMsg,
203        _challenge: Self::Challenge,
204    ) -> Result<VerifierStep<Self::Claim, Self::BaseOpening>, Self::Error> {
205        if claim.codeword_len <= claim.target_len {
206            Err(Error::StepOnFinished)
207        } else {
208            let new_len = claim.codeword_len / 2;
209
210            if new_len <= claim.target_len {
211                Ok(VerifierStep::Done(VerifierDone::new(
212                    FriClaim::new(message.folded_root, new_len, claim.target_len),
213                    FriOpening::new(message.folded_root),
214                )))
215            } else {
216                Ok(VerifierStep::Continue(VerifierContinue::new(
217                    FriClaim::new(message.folded_root, new_len, claim.target_len),
218                )))
219            }
220        }
221    }
222}