use midnight_proofs::{circuit::Layouter, plonk::Error};
use super::{sha256_chip::IV, Sha256Chip};
use crate::{
field::{
decomposition::chip::P2RDecompositionChip, AssignedBounded, AssignedNative, NativeChip,
NativeGadget,
},
hash::sha256::{
sha256_chip::ROUND_CONSTANTS,
types::{AssignedPlain, CompressionState},
},
instructions::{
ArithInstructions, AssignmentInstructions, BinaryInstructions, ComparisonInstructions,
ControlFlowInstructions, DecompositionInstructions, DivisionInstructions,
EqualityInstructions, ZeroInstructions,
},
types::{AssignedBit, AssignedByte},
vec::AssignedVector,
CircuitField,
};
/// Gadget computing SHA-256 over byte vectors whose length is only known
/// in-circuit (an [`AssignedVector`] with a fixed maximum size).
///
/// It wraps a [`Sha256Chip`] and reuses that chip's native gadget and
/// SHA-256 primitives.
#[derive(Debug, Clone)]
pub struct VarLenSha256Gadget<F: CircuitField> {
    /// Underlying SHA-256 chip providing block parsing, message schedule and
    /// compression rounds.
    pub(super) sha256chip: Sha256Chip<F>,
}
impl<F: CircuitField> VarLenSha256Gadget<F> {
pub fn new(sha256chip: &Sha256Chip<F>) -> Self {
Self {
sha256chip: sha256chip.clone(),
}
}
}
impl<F: CircuitField> VarLenSha256Gadget<F> {
    /// Shorthand for the native gadget owned by the wrapped SHA-256 chip.
    fn ng(&self) -> &NativeGadget<F, P2RDecompositionChip<F>, NativeChip<F>> {
        let chip = &self.sha256chip;
        &chip.native_gadget
    }
}
impl<F> VarLenSha256Gadget<F>
where
F: CircuitField,
{
fn final_block_len<const M: usize>(
&self,
layouter: &mut impl Layouter<F>,
len: &AssignedNative<F>, ) -> Result<(AssignedBounded<F>, AssignedBit<F>), Error> {
let ng = &self.ng();
let final_block_len = {
let fb_len = ng.rem(layouter, len, 64u64.into(), Some(M.into()))?;
let full_final_block = {
let len_is_zero = ng.is_zero(layouter, len)?;
let fb_is_zero = ng.is_zero(layouter, &fb_len)?;
ng.xor(layouter, &[len_is_zero, fb_is_zero])?
};
let max_block_len = ng.assign_fixed(layouter, F::from(64u64))?;
ng.select(layouter, &full_final_block, &max_block_len, &fb_len)?
};
let len_lim: u64 = 56;
let final_block_len = ng.bounded_of_element(layouter, 7, &final_block_len)?;
let not_extra = ng.lower_than_fixed(layouter, &final_block_len, F::from(len_lim))?;
let extra = ng.not(layouter, ¬_extra)?;
Ok((final_block_len, extra))
}
fn insert_in_array<const L: usize>(
&self,
layouter: &mut impl Layouter<F>,
idx: &AssignedNative<F>,
array: &mut [AssignedByte<F>; L],
elem: AssignedByte<F>,
) -> Result<(), Error> {
let ng = self.ng();
for (i, item) in array.iter_mut().enumerate() {
let at_idx = ng.is_equal_to_fixed(layouter, idx, F::from(i as u64))?;
*item = ng.select(layouter, &at_idx, &elem, item)?;
}
Ok(())
}
fn merge_chunks<const L: usize>(
&self,
layouter: &mut impl Layouter<F>,
chunk_1: &[AssignedByte<F>; L],
chunk_2: &[AssignedByte<F>; L],
len: &AssignedNative<F>,
) -> Result<[AssignedByte<F>; L], Error> {
let ng = &self.ng();
let mut first_chunk: AssignedBit<F> = ng.assign_fixed(layouter, true)?;
let result = chunk_1
.iter()
.zip(chunk_2.iter())
.enumerate()
.map(|(i, (a, b))| {
let switch = ng.is_equal_to_fixed(layouter, len, F::from(i as u64))?;
first_chunk = ng.xor(layouter, &[first_chunk.clone(), switch])?;
ng.select(layouter, &first_chunk, a, b)
})
.collect::<Result<Vec<_>, Error>>()?;
Ok(result.try_into().expect("Chunks of equal length."))
}
fn compute_padding(
&self,
layouter: &mut impl Layouter<F>,
input_len: &AssignedNative<F>, final_chunk_len: &AssignedBounded<F>, final_chunk: &[AssignedByte<F>; 64],
extra_block: &AssignedBit<F>,
) -> Result<[AssignedByte<F>; 2 * 64], Error> {
let ng = self.ng();
let zero: AssignedByte<F> = ng.assign_fixed(layouter, 0u8)?;
let final_chunk_len = &ng.element_of_bounded(layouter, final_chunk_len)?;
let not_extra_block: AssignedNative<F> = ng.not(layouter, extra_block)?.into();
let block_1 = {
let zeros = &vec![zero.clone(); 64].try_into().unwrap();
self.merge_chunks(layouter, final_chunk, zeros, final_chunk_len)?
};
let block_2 = {
let zeros = &vec![zero; 56].try_into().unwrap();
let final_chunk: &[_; 56] = (&final_chunk[..56]).try_into().unwrap();
let cond_len = ng.mul(layouter, final_chunk_len, ¬_extra_block, None)?;
self.merge_chunks(layouter, final_chunk, zeros, &cond_len)?
};
let len_bytes = {
let len_in_bits = ng.mul_by_constant(layouter, input_len, F::from(8u64))?;
ng.assigned_to_be_bytes(layouter, &len_in_bits, Some(8usize))?
};
let mut padding = [block_1.as_slice(), &block_2, &len_bytes].concat();
{
let one: AssignedByte<F> = ng.assign_fixed(layouter, 0x80)?;
let idx = {
ng.linear_combination(
layouter,
&[
(F::ONE, final_chunk_len.clone()),
(F::from(64u64), not_extra_block),
],
-F::from(56u64),
)?
};
self.insert_in_array::<64>(
layouter,
&idx,
(&mut padding[56..120]).try_into().unwrap(),
one,
)?;
}
Ok(padding.try_into().unwrap())
}
}
impl<F: CircuitField> VarLenSha256Gadget<F> {
    /// Absorbs one 64-byte message block into `state` and returns the new
    /// chaining state.
    ///
    /// The block is parsed into words and expanded by the message schedule.
    /// It is then run through one compression round per round constant, and
    /// the result is added to the input `state` (via `CompressionState::add`).
    fn update_state(
        &self,
        layouter: &mut impl Layouter<F>,
        state: &CompressionState<F>,
        block: &[AssignedByte<F>; 64],
    ) -> Result<CompressionState<F>, Error> {
        let sha256 = &self.sha256chip;
        let block = sha256.block_from_bytes(layouter, block)?;
        let message_blocks = sha256.message_schedule(layouter, &block)?;
        let mut compression_state = state.clone();
        for (round_constant, message_block) in ROUND_CONSTANTS.iter().zip(message_blocks.iter()) {
            compression_state = sha256.compression_round(
                layouter,
                &compression_state,
                *round_constant,
                message_block,
            )?;
        }
        state.add(sha256, layouter, &compression_state)
    }

    /// Returns the state after absorbing `block` if `update` is set, or
    /// `state` unchanged otherwise.
    ///
    /// The compression is always laid out in the circuit; `update` only
    /// selects which of the two states is returned.
    fn conditional_update_state(
        &self,
        layouter: &mut impl Layouter<F>,
        state: &CompressionState<F>,
        block: &[AssignedByte<F>; 64],
        update: &AssignedBit<F>,
    ) -> Result<CompressionState<F>, Error> {
        let new_state = self.update_state(layouter, state, block)?;
        CompressionState::select(layouter, self.ng(), update, &new_state, state)
    }

    /// Computes the SHA-256 digest of the `inputs.len` message bytes held in
    /// `inputs`, where the length is only known at proving time.
    ///
    /// `inputs.buffer` is split into 64-byte blocks.
    /// * Every block but the last is compressed in-circuit. Its result is
    ///   kept only from the block whose offset from the end of the buffer
    ///   equals `rounded_len` onwards, where `rounded_len` is `inputs.len`
    ///   rounded up to a multiple of 64.
    /// * The last buffer block is treated as the final (possibly partial)
    ///   message block. Its first `final_block_len` bytes are padded and
    ///   absorbed in one compression, or two when the padding spills over.
    ///
    /// NOTE(review): this presumes the `AssignedVector` layout keeps the data
    /// in the trailing blocks of the buffer, with the last block's data at its
    /// start. Confirm against `AssignedVector`.
    ///
    /// # Panics
    /// Panics if `M` is not a positive multiple of 64.
    pub(super) fn sha256_varlen<const M: usize>(
        &self,
        layouter: &mut impl Layouter<F>,
        inputs: &AssignedVector<F, AssignedByte<F>, M, 64>,
    ) -> Result<[AssignedPlain<F, 32>; 8], Error> {
        // Without this check, M < 64 fails opaquely and a non-multiple of 64
        // silently drops the trailing `M % 64` bytes of the buffer.
        assert!(
            M >= 64 && M % 64 == 0,
            "sha256_varlen: buffer size M = {} must be a positive multiple of 64",
            M
        );

        let ng = self.ng();
        let (final_block_len, extra_block) = self.final_block_len::<M>(layouter, &inputs.len)?;

        // `inputs.len` rounded up to a multiple of 64, with 0 mapped to 0:
        // `len - final_block_len`, plus 64 unless `final_block_len` is 0.
        let rounded_len = {
            let fc_len = ng.element_of_bounded(layouter, &final_block_len)?;
            let is_zero = ng.is_zero(layouter, &fc_len)?;
            let len_round = ng.sub(layouter, &inputs.len, &fc_len)?;
            let len_round_extra = ng.add_constant(layouter, &len_round, F::from(64u64))?;
            ng.select(layouter, &is_zero, &len_round, &len_round_extra)?
        };

        // `updating` flips on at the first block whose offset from the end of
        // the buffer equals `rounded_len`, and stays on for every later block.
        let mut updating: AssignedBit<F> = ng.assign_fixed(layouter, false)?;
        let mut state = CompressionState::<F>::fixed(layouter, ng, IV)?;

        let (leading_blocks, last_block) = inputs.buffer.split_at(M - 64);
        for (i, block) in leading_blocks.chunks_exact(64).enumerate() {
            let starts_here =
                ng.is_equal_to_fixed(layouter, &rounded_len, F::from((M - (i * 64)) as u64))?;
            updating = ng.xor(layouter, &[starts_here, updating])?;
            let block: &[_; 64] = block
                .try_into()
                .expect("chunks_exact yields 64-byte chunks");
            state = self.conditional_update_state(layouter, &state, block, &updating)?;
        }

        let final_block: &[_; 64] = last_block
            .try_into()
            .expect("the buffer ends with one 64-byte block");
        let padding_data = self.compute_padding(
            layouter,
            &inputs.len,
            &final_block_len,
            final_block,
            &extra_block,
        )?;
        let (final_block_1, final_block_2) = padding_data.split_at(64);
        let final_block_1: &[_; 64] = final_block_1
            .try_into()
            .expect("first half of the 128-byte padding");
        let final_block_2: &[_; 64] = final_block_2
            .try_into()
            .expect("second half of the 128-byte padding");

        // The first padded block only carries data when the padding spills
        // into an extra block; the second one is always absorbed.
        state = self.conditional_update_state(layouter, &state, final_block_1, &extra_block)?;
        state = self.update_state(layouter, &state, final_block_2)?;
        Ok(state.plain())
    }
}
#[cfg(any(test, feature = "testing"))]
use midnight_proofs::plonk::{Advice, Column, ConstraintSystem, Fixed, Instance};
#[cfg(any(test, feature = "testing"))]
use crate::testing_utils::FromScratch;
#[cfg(any(test, feature = "testing"))]
impl<F: CircuitField> FromScratch<F> for VarLenSha256Gadget<F> {
    /// The gadget reuses the configuration of the wrapped SHA-256 chip.
    type Config = <Sha256Chip<F> as FromScratch<F>>::Config;

    /// Builds the gadget on top of a freshly created SHA-256 chip.
    fn new_from_scratch(config: &Self::Config) -> Self {
        let sha256chip = Sha256Chip::new_from_scratch(config);
        Self { sha256chip }
    }

    /// Delegates column configuration entirely to the SHA-256 chip.
    fn configure_from_scratch(
        meta: &mut ConstraintSystem<F>,
        advice_columns: &mut Vec<Column<Advice>>,
        fixed_columns: &mut Vec<Column<Fixed>>,
        instance_columns: &[Column<Instance>; 2],
    ) -> Self::Config {
        Sha256Chip::configure_from_scratch(
            meta,
            advice_columns,
            fixed_columns,
            instance_columns,
        )
    }

    /// Delegates table loading to the SHA-256 chip.
    fn load_from_scratch(&self, layouter: &mut impl Layouter<F>) -> Result<(), Error> {
        let chip = &self.sha256chip;
        chip.load_from_scratch(layouter)
    }
}