use anyhow::{anyhow, Result};
use itertools::Itertools;
use plonky2_field::types::Field;
use crate::field::extension::Extendable;
use crate::fri::proof::{FriProof, FriProofTarget};
use crate::hash::hash_types::{HashOut, RichField};
use crate::iop::witness::WitnessWrite;
use crate::plonk::config::AlgebraicHasher;
pub fn set_fri_proof_target<F, W, H, const D: usize>(
witness: &mut W,
fri_proof_target: &FriProofTarget<D>,
fri_proof: &FriProof<F, H, D>,
) -> Result<()>
where
F: RichField + Extendable<D>,
W: WitnessWrite<F> + ?Sized,
H: AlgebraicHasher<F>,
{
witness.set_target(fri_proof_target.pow_witness, fri_proof.pow_witness)?;
match (
&fri_proof_target.batch_mask_proof,
&fri_proof.batch_mask_proof,
) {
(Some(target_mask_proof), Some(mask_proof)) => {
witness.set_cap_target(&target_mask_proof.cap, &mask_proof.cap)?;
for (target_query, query) in target_mask_proof
.query_openings
.iter()
.zip_eq(&mask_proof.query_openings)
{
witness.set_extension_targets(&target_query.values, &query.values)?;
let target_len = target_query.merkle_proof.siblings.len();
let siblings_len = query.merkle_proof.siblings.len();
if target_len < siblings_len {
return Err(anyhow!(
"fri_proof->batch_mask_proof->query_openings->merkle_proof target length is less than the proof length"
));
}
for i in 0..siblings_len {
witness.set_hash_target(
target_query.merkle_proof.siblings[i],
query.merkle_proof.siblings[i],
)?;
}
for i in siblings_len..target_len {
witness
.set_hash_target(target_query.merkle_proof.siblings[i], HashOut::ZERO)?;
}
}
}
(None, None) => {}
(Some(_), None) => {
return Err(anyhow!(
"recursive FRI targets expect a batch-mask proof but the proof is missing it"
))
}
(None, Some(_)) => {
return Err(anyhow!(
"recursive FRI targets are missing batch-mask targets for a masked proof"
))
}
}
if fri_proof_target.final_polys.chunks.len() != fri_proof.final_polys.chunks.len() {
return Err(anyhow!(
"recursive FRI targets expect {} final polynomial chunks but the proof has {}",
fri_proof_target.final_polys.chunks.len(),
fri_proof.final_polys.chunks.len(),
));
}
for (target_chunk, proof_chunk) in fri_proof_target
.final_polys
.chunks
.iter()
.zip_eq(&fri_proof.final_polys.chunks)
{
let target_len = target_chunk.0.len();
let coeffs_len = proof_chunk.coeffs.len();
if target_len < coeffs_len {
return Err(anyhow!(
"fri_proof->final_polys chunk target length is less than the proof length"
));
}
for i in 0..coeffs_len {
witness.set_extension_target(target_chunk.0[i], proof_chunk.coeffs[i])?;
}
for i in coeffs_len..target_len {
witness.set_extension_target(target_chunk.0[i], F::Extension::ZERO)?;
}
}
let target_caps = &fri_proof_target.commit_phase_merkle_caps;
let proof_caps = &fri_proof.commit_phase_merkle_caps;
if target_caps.len() < proof_caps.len() {
return Err(anyhow!(
"fri_proof->commit_phase_merkle_caps's target length is less than the proof length"
));
}
for (target_cap, proof_cap) in target_caps.iter().zip(proof_caps) {
witness.set_cap_target(target_cap, proof_cap)?;
}
for target_cap in target_caps.iter().skip(proof_caps.len()) {
for hash in target_cap.0.iter() {
witness.set_hash_target(*hash, HashOut::ZERO)?;
}
}
for (query_round_index, (qt, q)) in fri_proof_target
.query_round_proofs
.iter()
.zip_eq(&fri_proof.query_round_proofs)
.enumerate()
{
for (oracle_index, (at, a)) in qt
.initial_trees_proof
.evals_proofs
.iter()
.zip_eq(&q.initial_trees_proof.evals_proofs)
.enumerate()
{
let target_len = at.0.len();
let evals_len = a.0.len();
if target_len < evals_len {
return Err(anyhow!(
"fri_proof->query_round_proofs[{query_round_index}]->initial_trees_proof->evals_proofs[{oracle_index}] target length {target_len} is less than the proof length {evals_len}"
));
}
for i in 0..evals_len {
witness.set_target(at.0[i], a.0[i])?;
}
for i in evals_len..target_len {
witness.set_target(at.0[i], F::ZERO)?;
}
let target_len = at.1.siblings.len();
let siblings_len = a.1.siblings.len();
if target_len < siblings_len {
return Err(anyhow!("fri_proof->query_round_proofs->initial_trees_proof->evals_proofs->siblings' target length is less than the proof length"));
}
for i in 0..siblings_len {
witness.set_hash_target(at.1.siblings[i], a.1.siblings[i])?;
}
for i in siblings_len..target_len {
witness.set_hash_target(at.1.siblings[i], HashOut::ZERO)?;
}
}
for (st, s) in qt.steps.iter().zip(&q.steps) {
let target_len = st.evals.len();
let evals_len = s.evals.len();
if target_len < evals_len {
return Err(anyhow!(
"fri_proof->query_round_proofs->steps->evals target length is less than the proof length"
));
}
for i in 0..evals_len {
witness.set_extension_target(st.evals[i], s.evals[i])?;
}
for i in evals_len..target_len {
witness.set_extension_target(st.evals[i], F::Extension::ZERO)?;
}
let target_len = st.merkle_proof.siblings.len();
let siblings_len = s.merkle_proof.siblings.len();
if target_len < siblings_len {
return Err(anyhow!("fri_proof->query_round_proofs->steps->merkle_proof->siblings' target length is less than the proof length"));
}
for i in 0..siblings_len {
witness.set_hash_target(st.merkle_proof.siblings[i], s.merkle_proof.siblings[i])?;
}
for i in siblings_len..target_len {
witness.set_hash_target(st.merkle_proof.siblings[i], HashOut::ZERO)?;
}
}
for st in qt.steps.iter().skip(q.steps.len()) {
for &eval in &st.evals {
witness.set_extension_target(eval, F::Extension::ZERO)?;
}
for &sibling in &st.merkle_proof.siblings {
witness.set_hash_target(sibling, HashOut::ZERO)?;
}
}
}
Ok(())
}