#include "risc0/zkp/verify/fri.h"
#include <memory>
#include "risc0/core/log.h"
#include "risc0/core/util.h"
#include "risc0/zkp/core/constants.h"
#include "risc0/zkp/core/ntt.h"
#include "risc0/zkp/core/rou.h"
#include "risc0/zkp/verify/merkle.h"
namespace risc0 {
namespace {
// Collapses one group of kFriFold Fp4 evaluations into a single Fp4 value.
//
// Parameters:
//   values - kFriFold Fp4 values; this buffer is OVERWRITTEN in place
//            (interpolateNTT followed by bitReverse), so callers must not
//            rely on its contents afterwards.
//   mix    - random combination challenge for this round.
//   S      - size of the folded domain; the unfolded domain has
//            kFriFold * S points.
//   j      - index of the group within the folded domain.
//
// Returns sum_k values'[k] * r^k * mix^k, where values' is the buffer after
// interpolateNTT + bitReverse and r = kRouRev[log2Ceil(kFriFold * S)]^j
// (presumably the inverse root of unity of the unfolded domain raised to j).
Fp4 foldEval(Fp4* values, Fp4 mix, size_t S, size_t j) {
  // Move the group's values into the representation the weighted sum below
  // expects (interpolate, then undo the bit-reversed ordering).
  interpolateNTT(values, kFriFold);
  bitReverse(values, kFriFold);

  // Per-coefficient twiddle step for this group.
  const Fp step = pow(kRouRev[log2Ceil(kFriFold * S)], j);

  Fp4 folded;
  Fp stepPow(1);
  Fp4 mixPow(1);
  for (size_t k = 0; k != kFriFold; ++k) {
    folded += values[k] * stepPow * mixPow;
    stepPow *= step;
    mixPow *= mix;
  }
  return folded;
}
struct VerifyRoundInfo {
size_t domain;
MerkleTreeVerifier merkle;
Fp4 mix;
VerifyRoundInfo(ReadIOP& iop, size_t inDomain)
: domain(inDomain / kFriFold)
, merkle(iop, domain, kFriFold * 4, kQueries)
, mix(Fp4::random(iop)) {}
void verifyQuery(ReadIOP& iop, size_t* pos, Fp4* goal) const {
size_t quot = *pos / domain;
size_t group = *pos % domain;
auto data = merkle.verify(iop, group);
std::vector<Fp4> data4(kFriFold);
for (size_t i = 0; i < kFriFold; i++) {
data4[i] = Fp4(data[0 * kFriFold + i],
data[1 * kFriFold + i],
data[2 * kFriFold + i],
data[3 * kFriFold + i]);
}
REQUIRE(data4[quot] == *goal);
*goal = foldEval(data4.data(), mix, domain, group);
*pos = group;
}
};
}
// Verifies a FRI proof read from `iop`.
//
// Parameters:
//   iop   - transcript that proof data is read from, committed to and
//           sampled through.
//   deg   - degree bound of the committed polynomial. It must be non-zero.
//           It is divided by kFriFold each round, so it presumably needs to
//           be a power of two (TODO confirm with callers).
//   inner - callback invoked once per query with the query position in the
//           original domain. It returns the Fp4 value that the first FRI
//           round must open to at that position (presumably after running
//           the caller's own checks).
//
// Flow: one VerifyRoundInfo is read per fold until deg <= kFriMinDegree. The
// remaining 4 * deg coefficients are then read and their hash committed.
// Finally, kQueries positions are drawn from iop.generate() and each is
// checked through every round, down to a direct evaluation of the final
// polynomial. Any mismatch fails via REQUIRE.
void friVerify(ReadIOP& iop, size_t deg, InnerVerify inner) {
  size_t domain = deg * kInvRate;
  const size_t origDomain = domain;
  // Query positions are reduced modulo origDomain below. A zero domain
  // (deg == 0, or a product that wraps to zero) would make that a division
  // by zero, which is undefined behavior, so reject it before touching iop.
  REQUIRE(origDomain != 0);

  // One commitment + mixing challenge per fold; each fold shrinks both the
  // domain and the degree by kFriFold.
  std::vector<VerifyRoundInfo> rounds;
  while (deg > kFriMinDegree) {
    rounds.emplace_back(iop, domain);
    domain /= kFriFold;
    deg /= kFriFold;
  }

  // The final polynomial's coefficients are read directly from iop as four
  // blocks of `deg` Fp values (one block per Fp4 component). Committing their
  // hash presumably binds subsequent iop.generate() challenges to them.
  std::vector<Fp> finalCoeffs(deg * 4);
  iop.read(finalCoeffs.data(), finalCoeffs.size());
  auto digest = shaHash(finalCoeffs.data(), finalCoeffs.size(), 1, false);
  iop.commit(digest);

  // Generator for the fully folded domain (kRouFwd indexed by its log2 size).
  Fp gen = kRouFwd[log2Ceil(domain)];

  for (size_t q = 0; q < kQueries; q++) {
    uint32_t rng = iop.generate();
    size_t pos = rng % origDomain;
    Fp4 goal = inner(iop, pos);
    // Each round checks the current goal and folds it, shrinking pos.
    for (const auto& round : rounds) {
      round.verifyQuery(iop, &pos, &goal);
    }
    // Evaluate the final polynomial at gen^pos, one Fp4 component at a time,
    // and compare it with the fully folded goal.
    Fp x = pow(gen, pos);
    Fp4 fx;
    Fp cur(1);
    for (size_t i = 0; i < deg; i++) {
      for (size_t j = 0; j < 4; j++) {
        fx.elems[j] += cur * finalCoeffs[j * deg + i];
      }
      cur *= x;
    }
    REQUIRE(fx == goal);
  }
}
}