1#![forbid(unsafe_code)]
2
3mod error;
4pub mod whir;
5
6pub use self::error::Error;
7pub use self::whir::Whir;
8
9use std::marker::PhantomData;
10
11use plonky_cat_field::Field;
12use plonky_cat_hash::Hasher;
13use plonky_cat_merkle::MerkleTree;
14use plonky_cat_reduce::{
15 ProverContinue, ProverDone, ProverStep, TranscriptSerialize,
16 ReductionFunctor,
17 VerifierContinue, VerifierDone, VerifierStep,
18};
19
20#[derive(Debug, Clone)]
23pub struct FriClaim<F> {
24 merkle_root: F,
25 codeword_len: usize,
26 target_len: usize,
27}
28
29impl<F: Field> FriClaim<F> {
30 #[must_use]
31 pub fn new(merkle_root: F, codeword_len: usize, target_len: usize) -> Self {
32 Self { merkle_root, codeword_len, target_len }
33 }
34
35 #[must_use]
37 pub fn until_constant(merkle_root: F, codeword_len: usize) -> Self {
38 Self::new(merkle_root, codeword_len, 1)
39 }
40
41 #[must_use]
42 pub fn merkle_root(&self) -> F {
43 self.merkle_root
44 }
45
46 #[must_use]
47 pub fn codeword_len(&self) -> usize {
48 self.codeword_len
49 }
50
51 #[must_use]
52 pub fn target_len(&self) -> usize {
53 self.target_len
54 }
55}
56
57#[derive(Debug, Clone)]
58pub struct FriWitness<H: Hasher> {
59 codeword: Vec<H::F>,
60 tree: MerkleTree<H>,
61}
62
63impl<H: Hasher> FriWitness<H> {
64 pub fn build(codeword: Vec<H::F>) -> Result<Self, Error> {
65 match () {
66 () if codeword.is_empty() => Err(Error::CodewordEmpty),
67 () if !codeword.len().is_power_of_two() =>
68 Err(Error::CodewordLengthNotPowerOfTwo { len: codeword.len() }),
69 () => {
70 let tree = MerkleTree::<H>::build(codeword.clone())?;
71 Ok(Self { codeword, tree })
72 }
73 }
74 }
75
76 pub fn merkle_root(&self) -> Result<H::F, Error> {
77 self.tree.root().map_err(Error::from)
78 }
79
80 #[must_use]
81 pub fn codeword(&self) -> &[H::F] {
82 &self.codeword
83 }
84}
85
86#[derive(Debug, Clone)]
89pub struct FriRoundMsg<F> {
90 folded_root: F,
91}
92
93impl<F: Field> FriRoundMsg<F> {
94 #[must_use]
95 pub fn new(folded_root: F) -> Self {
96 Self { folded_root }
97 }
98
99 #[must_use]
100 pub fn folded_root(&self) -> F {
101 self.folded_root
102 }
103}
104
105impl<F: Field> TranscriptSerialize<F> for FriRoundMsg<F> {
106 fn to_field_elements(&self) -> Vec<F> {
107 vec![self.folded_root]
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Eq)]
114pub struct FriOpening<F> {
115 constant_value: F,
116}
117
118impl<F: Field> FriOpening<F> {
119 #[must_use]
120 pub fn new(constant_value: F) -> Self {
121 Self { constant_value }
122 }
123
124 pub fn into_constant_value(self) -> F {
125 self.constant_value
126 }
127
128 #[must_use]
129 pub fn constant_value(&self) -> F {
130 self.constant_value
131 }
132}
133
134impl<F: Field> TranscriptSerialize<F> for FriOpening<F> {
135 fn to_field_elements(&self) -> Vec<F> {
136 vec![self.constant_value]
137 }
138}
139
140fn fold_codeword<F: Field>(codeword: &[F], challenge: F) -> Vec<F> {
146 let half = codeword.len() / 2;
147 (0..half)
148 .map(|i| codeword[2 * i] + challenge * codeword[2 * i + 1])
149 .collect()
150}
151
/// FRI reduction over Merkle commitments built with hasher `H`.
///
/// A zero-sized, type-level selector: it stores only a `PhantomData<H>`. In
/// this file, all of its behaviour lives in its `ReductionFunctor` impl.
pub struct Fri<H> {
    _marker: PhantomData<H>,
}
157
158impl<H: Hasher> ReductionFunctor for Fri<H>
159where
160 H::F: Field,
161{
162 type Claim = FriClaim<H::F>;
163 type Witness = FriWitness<H>;
164 type RoundMsg = FriRoundMsg<H::F>;
165 type Challenge = H::F;
166 type BaseOpening = FriOpening<H::F>;
167 type Error = Error;
168
169 fn prover_step(
170 claim: Self::Claim,
171 witness: Self::Witness,
172 challenge: Self::Challenge,
173 ) -> Result<
174 ProverStep<Self::Claim, Self::Witness, Self::RoundMsg, Self::BaseOpening>,
175 Self::Error,
176 > {
177 if claim.codeword_len <= claim.target_len {
178 witness.codeword.first()
179 .copied()
180 .ok_or(Error::CodewordEmpty)
181 .map(|val| ProverStep::Done(ProverDone::new(
182 claim,
183 witness,
184 FriOpening::new(val),
185 )))
186 } else {
187 let folded = fold_codeword(&witness.codeword, challenge);
188 let new_witness = FriWitness::<H>::build(folded)?;
189 let new_root = new_witness.tree.root()?;
190 let msg = FriRoundMsg::new(new_root);
191
192 Ok(ProverStep::Continue(ProverContinue::new(
193 FriClaim::new(new_root, claim.codeword_len / 2, claim.target_len),
194 new_witness,
195 msg,
196 )))
197 }
198 }
199
200 fn verifier_step(
201 claim: Self::Claim,
202 message: Self::RoundMsg,
203 _challenge: Self::Challenge,
204 ) -> Result<VerifierStep<Self::Claim, Self::BaseOpening>, Self::Error> {
205 if claim.codeword_len <= claim.target_len {
206 Err(Error::StepOnFinished)
207 } else {
208 let new_len = claim.codeword_len / 2;
209
210 if new_len <= claim.target_len {
211 Ok(VerifierStep::Done(VerifierDone::new(
212 FriClaim::new(message.folded_root, new_len, claim.target_len),
213 FriOpening::new(message.folded_root),
214 )))
215 } else {
216 Ok(VerifierStep::Continue(VerifierContinue::new(
217 FriClaim::new(message.folded_root, new_len, claim.target_len),
218 )))
219 }
220 }
221 }
222}