use {
crate::{
memory::{MemoryBlock, MemoryOperation},
noir_to_r1cs::NoirToR1CSCompiler,
},
ark_ff::{One, Zero},
provekit_common::{
witness::{
SpiceMemoryOperation, SpiceWitnesses, SumTerm, WitnessBuilder, WitnessCoefficient,
},
FieldElement,
},
std::ops::Neg,
};
/// Construction of a memory-checking witness layout from a memory block's initial values and
/// its sequence of operations.
///
/// Used by `add_ram_checking` to reserve the auxiliary witness indices (read timestamps,
/// overwritten values, final values and final timestamps) that the constraints refer to.
pub trait SpiceWitnessesBuilder {
    /// Builds the layout, reserving consecutive witness indices starting at `next_witness_idx`.
    ///
    /// * `next_witness_idx` – first free witness index; all reserved indices follow it.
    /// * `memory_length` – number of memory cells.
    /// * `initial_value_witnesses` – witness index holding the initial value of each cell.
    /// * `memory_operations` – the block's loads and stores, in program order.
    fn new(
        next_witness_idx: usize,
        memory_length: usize,
        initial_value_witnesses: Vec<usize>,
        memory_operations: Vec<MemoryOperation>,
    ) -> Self;
}
impl SpiceWitnessesBuilder for SpiceWitnesses {
    /// Assigns consecutive witness indices, beginning at `next_witness_idx`, to the auxiliary
    /// values needed for memory checking.
    ///
    /// Layout, in order:
    /// * per operation, in program order: a `Load` receives one index (its read timestamp);
    ///   a `Store` receives two (the overwritten value, then its read timestamp);
    /// * `memory_length` indices for the final cell values, starting at `rv_final_start`;
    /// * `memory_length` indices for the final cell timestamps, starting at `rt_final_start`.
    ///
    /// `num_witnesses` records how many indices were reserved in total.
    fn new(
        mut next_witness_idx: usize,
        memory_length: usize,
        initial_value_witnesses: Vec<usize>,
        memory_operations: Vec<MemoryOperation>,
    ) -> Self {
        let first_witness_idx = next_witness_idx;

        let mut spice_memory_operations = Vec::with_capacity(memory_operations.len());
        for op in memory_operations {
            let spice_op = match op {
                MemoryOperation::Load(addr, value) => {
                    let read_timestamp = next_witness_idx;
                    next_witness_idx += 1;
                    SpiceMemoryOperation::Load(addr, value, read_timestamp)
                }
                MemoryOperation::Store(addr, new_value) => {
                    // The overwritten value comes first, immediately followed by the timestamp.
                    let old_value = next_witness_idx;
                    let read_timestamp = old_value + 1;
                    next_witness_idx = read_timestamp + 1;
                    SpiceMemoryOperation::Store(addr, old_value, new_value, read_timestamp)
                }
            };
            spice_memory_operations.push(spice_op);
        }

        // Two contiguous runs of `memory_length` indices follow the per-operation witnesses.
        let rv_final_start = next_witness_idx;
        let rt_final_start = rv_final_start + memory_length;
        let num_witnesses = rt_final_start + memory_length - first_witness_idx;

        Self {
            memory_length,
            initial_value_witnesses,
            memory_operations: spice_memory_operations,
            rv_final_start,
            rt_final_start,
            first_witness_idx,
            num_witnesses,
        }
    }
}
/// Adds offline memory-checking constraints for `block` to `r1cs_compiler`.
///
/// Two challenges `rs` and `sz` are allocated (plus `rs^2`). Every memory tuple
/// `(address, value, timestamp)` is folded into one of two running products through
/// `add_mem_op_multiset_factor`:
///
/// * write set: each initial cell value at timestamp 0, and the value written by operation
///   `i` at the constant timestamp `i + 1`;
/// * read set: the value read by operation `i` at its read-timestamp witness, and each cell's
///   final value at its final-timestamp witness.
///
/// The two products are then constrained to be equal.
///
/// Returns `(num_bits, witnesses)`. For every read, `witnesses` holds the read-timestamp
/// witness followed by a witness equal to `bound - timestamp`. Here `bound` is the operation
/// index, or `block.operations.len()` for the final reads. No range check is emitted here.
/// NOTE: the caller is expected to range-check every returned witness to `num_bits` bits.
/// That check is what forces each read timestamp to be at most its bound.
pub fn add_ram_checking(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    block: &MemoryBlock,
) -> (u32, Vec<usize>) {
    // Allocation order determines witness indices: rs, then rs^2, then sz.
    let rs_challenge =
        r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));
    let rs_challenge_sqrd = r1cs_compiler.add_product(rs_challenge, rs_challenge);
    let sz_challenge =
        r1cs_compiler.add_witness_builder(WitnessBuilder::Challenge(r1cs_compiler.num_witnesses()));

    let one = r1cs_compiler.witness_one();
    let memory_length = block.initial_value_witnesses.len();
    let num_operations = block.operations.len();
    let mut read_set_hash = one;
    let mut write_set_hash = one;
    // (upper bound, read-timestamp witness) pairs, turned into range-check inputs at the end.
    let mut timestamp_bounds = Vec::with_capacity(num_operations + memory_length);

    // Initial memory contents enter the write set with timestamp 0.
    for (addr, &initial_value_witness) in block.initial_value_witnesses.iter().enumerate() {
        let factor = add_mem_op_multiset_factor(
            r1cs_compiler,
            sz_challenge,
            rs_challenge,
            rs_challenge_sqrd,
            (FieldElement::from(addr as u64), one),
            initial_value_witness,
            (FieldElement::zero(), one),
        );
        write_set_hash = r1cs_compiler.add_product(write_set_hash, factor);
    }

    let spice_witnesses = SpiceWitnesses::new(
        r1cs_compiler.num_witnesses(),
        memory_length,
        block.initial_value_witnesses.clone(),
        block.operations.clone(),
    );
    r1cs_compiler.add_witness_builder(WitnessBuilder::SpiceWitnesses(spice_witnesses.clone()));

    for (mem_op_index, op) in spice_witnesses.memory_operations.iter().enumerate() {
        // A load writes back the value it read; a store replaces the old value with a new one.
        let (addr_witness, read_value_witness, written_value_witness, rt_witness) = match *op {
            SpiceMemoryOperation::Load(addr, value, rt) => (addr, value, value, rt),
            SpiceMemoryOperation::Store(addr, old_value, new_value, rt) => {
                (addr, old_value, new_value, rt)
            }
        };
        timestamp_bounds.push((mem_op_index, rt_witness));

        let read_factor = add_mem_op_multiset_factor(
            r1cs_compiler,
            sz_challenge,
            rs_challenge,
            rs_challenge_sqrd,
            (FieldElement::one(), addr_witness),
            read_value_witness,
            (FieldElement::one(), rt_witness),
        );
        read_set_hash = r1cs_compiler.add_product(read_set_hash, read_factor);

        // The write performed by operation `i` carries the constant timestamp `i + 1`.
        let write_factor = add_mem_op_multiset_factor(
            r1cs_compiler,
            sz_challenge,
            rs_challenge,
            rs_challenge_sqrd,
            (FieldElement::one(), addr_witness),
            written_value_witness,
            (FieldElement::from((mem_op_index + 1) as u64), one),
        );
        write_set_hash = r1cs_compiler.add_product(write_set_hash, write_factor);
    }

    // Final memory contents enter the read set at their final timestamps.
    for addr in 0..memory_length {
        let rv_final_witness = spice_witnesses.rv_final_start + addr;
        let rt_final_witness = spice_witnesses.rt_final_start + addr;
        timestamp_bounds.push((num_operations, rt_final_witness));
        let factor = add_mem_op_multiset_factor(
            r1cs_compiler,
            sz_challenge,
            rs_challenge,
            rs_challenge_sqrd,
            (FieldElement::from(addr as u64), one),
            rv_final_witness,
            (FieldElement::one(), rt_final_witness),
        );
        read_set_hash = r1cs_compiler.add_product(read_set_hash, factor);
    }

    // 1 * read_set_hash = write_set_hash
    r1cs_compiler.r1cs.add_constraint(
        &[(FieldElement::one(), one)],
        &[(FieldElement::one(), read_set_hash)],
        &[(FieldElement::one(), write_set_hash)],
    );

    // Smallest width whose range [0, 2^num_bits) contains every value in 0..=num_operations.
    let num_bits = (num_operations + 1).next_power_of_two().ilog2();
    let mut range_check = Vec::with_capacity(2 * timestamp_bounds.len());
    for &(bound, rt_witness) in &timestamp_bounds {
        let slack_witness = r1cs_compiler.add_sum(vec![
            SumTerm(Some(FieldElement::from(bound as u64)), one),
            SumTerm(Some(FieldElement::one().neg()), rt_witness),
        ]);
        range_check.extend([rt_witness, slack_witness]);
    }
    (num_bits, range_check)
}
/// Allocates and constrains one multiset fingerprint factor for a memory tuple.
///
/// For address term `a * addr_witness`, value `v` and timestamp term `t * timer_witness`,
/// the returned witness `factor` is constrained to equal
/// `sz - a * addr_witness - rs * v - t * rs^2 * timer_witness`.
/// The single R1CS row enforcing this is
/// `rs * (-v) = factor - sz + t * (rs^2 * timer_witness) + a * addr_witness`.
///
/// The product `rs^2 * timer_witness` is only materialised when it is needed:
/// * if `t` is zero, the timestamp term vanishes and is omitted;
/// * if `timer_witness` is the constant-one witness, `rs^2 * 1` is just `rs_challenge_sqrd`,
///   so that witness is used directly.
///
/// In both cases no extra witness or constraint is allocated.
///
/// Returns the witness index of `factor`.
fn add_mem_op_multiset_factor(
    r1cs_compiler: &mut NoirToR1CSCompiler,
    sz_challenge: usize,
    rs_challenge: usize,
    rs_challenge_sqrd: usize,
    (addr, addr_witness): (FieldElement, usize),
    value_witness: usize,
    (timer, timer_witness): (FieldElement, usize),
) -> usize {
    let factor = r1cs_compiler.add_witness_builder(WitnessBuilder::SpiceMultisetFactor(
        r1cs_compiler.num_witnesses(),
        sz_challenge,
        rs_challenge,
        WitnessCoefficient(addr, addr_witness),
        value_witness,
        WitnessCoefficient(timer, timer_witness),
    ));

    // Linear combination on the C side: factor - sz + a * addr_witness [+ t * rs^2 * timer].
    let mut c_terms = vec![
        (FieldElement::one(), factor),
        (FieldElement::one().neg(), sz_challenge),
        (addr, addr_witness),
    ];
    if !timer.is_zero() {
        let rs_sqrd_times_timer = if timer_witness == r1cs_compiler.witness_one() {
            // rs^2 * 1 == rs^2: reuse the existing witness instead of a redundant product.
            rs_challenge_sqrd
        } else {
            r1cs_compiler.add_product(rs_challenge_sqrd, timer_witness)
        };
        c_terms.push((timer, rs_sqrd_times_timer));
    }

    r1cs_compiler.r1cs.add_constraint(
        &[(FieldElement::one(), rs_challenge)],
        &[(FieldElement::one().neg(), value_witness)],
        &c_terms,
    );
    factor
}