use std::{cell::RefCell, collections::HashSet, fmt::Debug, marker::PhantomData, rc::Rc};
use midnight_proofs::{
circuit::{Chip, Layouter, Region, Value},
plonk::{Advice, Column, ConstraintSystem, Error, Fixed, Selector, TableColumn},
poly::Rotation,
};
use crate::{
field::native::NB_ARITH_COLS, instructions::decomposition::Pow2RangeInstructions,
types::AssignedNative, CircuitField,
};
/// Column/selector layout of the power-of-two range-check chip.
///
/// A row is range-checked by enabling `q_pow2range` on it and writing the
/// bit-length `n` into `tag_col`; every advice column in `val_cols` is then
/// looked up as a `(tag, value)` pair in the `(t_tag, t_val)` table.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Pow2RangeConfig {
    /// Complex selector that activates the lookups on a row.
    pub(crate) q_pow2range: Selector,
    /// Fixed column holding the bit-length tag `n` of a checked row.
    pub(crate) tag_col: Column<Fixed>,
    /// Advice columns whose cells are range-checked.
    pub(crate) val_cols: Vec<Column<Advice>>,
    /// Lookup-table column with the bit-length tag of each table entry.
    t_tag: TableColumn,
    /// Lookup-table column with the admissible value of each table entry.
    t_val: TableColumn,
}
/// Chip asserting that native values lie in `[0, 2^n)` for some
/// `n <= max_bit_len`, via a lookup into a table of `(n, value)` pairs.
///
/// The table is built lazily: only the bit-lengths that were actually
/// requested (plus `0`) are loaded by `load_table`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Pow2RangeChip<F: CircuitField> {
    /// Columns and selector this chip operates on.
    config: Pow2RangeConfig,
    /// Largest bit-length `n` a range check may use.
    max_bit_len: usize,
    /// Bit-lengths requested so far. It is behind `Rc` so that clones of the
    /// chip share (and all contribute to) the same set.
    queried_tags: Rc<RefCell<HashSet<usize>>>,
    _marker: PhantomData<F>,
}
impl<F: CircuitField> Chip<F> for Pow2RangeChip<F> {
    type Config = Pow2RangeConfig;
    // Nothing is pre-loaded: the lookup table is filled explicitly by
    // `load_table`.
    type Loaded = ();

    /// Returns the chip's column/selector layout.
    fn config(&self) -> &Self::Config {
        let Self { config, .. } = self;
        config
    }

    /// Returns the (empty) loaded state.
    fn loaded(&self) -> &Self::Loaded {
        &()
    }
}
impl<F: CircuitField> Pow2RangeChip<F> {
    /// Range-checks row `offset` of `region` against `[0, 2^n)`.
    ///
    /// Enables the lookup selector on that row and writes `n` into the tag
    /// column. As a result, the cell of every value column on that row must
    /// match an `(n, value)` table entry. The caller is responsible for
    /// assigning those value cells. `n` is recorded so that `load_table`
    /// emits the entries for this bit-length.
    ///
    /// # Errors
    /// Propagates errors from enabling the selector or assigning the tag.
    ///
    /// # Panics
    /// Panics if `n` is greater than `max_bit_len`.
    pub(crate) fn assert_row_lower_than_2_pow_n(
        &self,
        region: &mut Region<'_, F>,
        n: usize,
        offset: usize,
    ) -> Result<(), Error> {
        assert!(
            n <= self.max_bit_len,
            "assert_row_lower_than_2_pow_n: n={} cannot exceed max_bit_len={}",
            n,
            self.max_bit_len
        );

        let tag = F::from(n as u64);
        self.config.q_pow2range.enable(region, offset)?;
        region.assign_fixed(
            || "pow2range_tag",
            self.config.tag_col,
            offset,
            || Value::known(tag),
        )?;

        // A set makes repeated requests for the same `n` harmless (e.g. when
        // a floor planner runs a region closure more than once).
        self.queried_tags.borrow_mut().insert(n);
        Ok(())
    }
}
impl<F: CircuitField> Pow2RangeInstructions<F> for Pow2RangeChip<F> {
    /// Asserts that every value in `values` lies in `[0, 2^n)`.
    ///
    /// Values are packed `val_cols.len()` per row, with one new region per
    /// row. Each packed cell is copy-constrained to the original assigned
    /// cell. Columns left over in the last row are filled with zero, and that
    /// zero also satisfies the lookup because the table for tag `n` contains
    /// `0`.
    ///
    /// # Errors
    /// Propagates any error raised while assigning a region.
    ///
    /// # Panics
    /// Panics if `values` is non-empty and either `n > max_bit_len` or the
    /// chip was configured with no value columns (`chunks(0)` panics).
    fn assert_values_lower_than_2_pow_n(
        &self,
        layouter: &mut impl Layouter<F>,
        values: &[AssignedNative<F>],
        n: usize,
    ) -> Result<(), Error> {
        let cols = &self.config.val_cols;
        values.chunks(cols.len()).try_for_each(|row_values| {
            layouter.assign_region(
                || "Assign values",
                |mut region| {
                    for (col, original) in cols.iter().zip(row_values) {
                        let copy = region.assign_advice(
                            || "pow2range val",
                            *col,
                            0,
                            || original.value().copied(),
                        )?;
                        region.constrain_equal(copy.cell(), original.cell())?;
                    }
                    for col in cols.iter().skip(row_values.len()) {
                        region.assign_advice(
                            || "pow2range zero",
                            *col,
                            0,
                            || Value::known(F::ZERO),
                        )?;
                    }
                    self.assert_row_lower_than_2_pow_n(&mut region, n, 0)
                },
            )
        })
    }
}
impl<F: CircuitField> Pow2RangeChip<F> {
    /// Creates a chip over `config` that accepts range checks of up to
    /// `max_bit_len` bits. No bit-lengths are recorded yet.
    pub fn new(config: &Pow2RangeConfig, max_bit_len: usize) -> Self {
        Self {
            config: config.clone(),
            max_bit_len,
            queried_tags: Rc::new(RefCell::new(HashSet::new())),
            _marker: PhantomData,
        }
    }

    /// Returns the largest bit-length `n` this chip accepts.
    pub fn max_bit_len(&self) -> usize {
        self.max_bit_len
    }

    /// Allocates the selector, the tag column and the two lookup-table
    /// columns, and registers one lookup per column in `columns`.
    ///
    /// For each value column `v` the lookup is
    /// `(tag, q_pow2range * v) ∈ (t_tag, t_val)`. On rows where the selector
    /// is off, this queries `(tag, 0)`.
    /// NOTE(review): that relies on the tag column being 0 on such rows
    /// (assumed default of unassigned fixed cells), so the `(0, 0)` entry
    /// that `load_table` always emits is matched.
    ///
    /// # Panics
    /// Panics if `columns.len()` is not smaller than `NB_ARITH_COLS`.
    pub fn configure(
        meta: &mut ConstraintSystem<F>,
        columns: &[Column<Advice>],
    ) -> Pow2RangeConfig {
        assert!(
            columns.len() < NB_ARITH_COLS,
            "Nr of range-check columns should be smaller than NB_ARITH_COLS."
        );
        let val_cols = columns.to_vec();

        let q_pow2range = meta.complex_selector();
        let tag_col = meta.fixed_column();
        let t_tag = meta.lookup_table_column();
        let t_val = meta.lookup_table_column();

        for val_col in &val_cols {
            meta.lookup("pow2range column check", |meta| {
                let sel = meta.query_selector(q_pow2range);
                let tag = meta.query_fixed(tag_col, Rotation::cur());
                let val = meta.query_advice(*val_col, Rotation::cur());
                vec![(tag, t_tag), (sel * val, t_val)]
            });
        }

        Pow2RangeConfig {
            q_pow2range,
            tag_col,
            val_cols,
            t_tag,
            t_val,
        }
    }

    /// Fills the lookup table with `(tag, value)` entries, one for every
    /// `value` in `[0, 2^tag)`.
    ///
    /// Tag `0` (a single `(0, 0)` entry) is always loaded. Any other tag is
    /// loaded only if it was requested via `assert_row_lower_than_2_pow_n`
    /// on this chip or one of its clones. This must therefore be called
    /// *after* all range checks have been laid out. Tags requested later are
    /// missing from the table, and their lookups fail.
    ///
    /// # Errors
    /// Propagates any error raised while assigning table cells.
    pub fn load_table(&self, layouter: &mut impl Layouter<F>) -> Result<(), Error> {
        layouter.assign_table(
            || "pow2range table",
            |mut table| {
                let queried = self.queried_tags.borrow();
                let loaded_tags =
                    (0..=self.max_bit_len).filter(|bit_len| *bit_len == 0 || queried.contains(bit_len));

                let mut offset = 0;
                for bit_len in loaded_tags {
                    let tag = Value::known(F::from(bit_len as u64));
                    for value in 0..(1u64 << bit_len) {
                        let val = Value::known(F::from(value));
                        table.assign_cell(|| "t_tag", self.config.t_tag, offset, || tag)?;
                        table.assign_cell(|| "t_val", self.config.t_val, offset, || val)?;
                        offset += 1;
                    }
                }
                Ok(())
            },
        )
    }
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use midnight_curves::Fq as Fp;
    use midnight_proofs::{
        circuit::{Layouter, SimpleFloorPlanner, Value},
        dev::MockProver,
        plonk::{Circuit, ConstraintSystem, Error},
    };
    use rand::Rng;

    use super::*;

    /// Circuit that lays out one input per row of a single region and
    /// range-checks each of those rows.
    ///
    /// Each input is `(values, n)`. `values` fills the `NR_COLS` range-check
    /// columns of its row, and that row is asserted to lie in `[0, 2^n)`.
    struct TestCircuit<F: CircuitField, const NR_COLS: usize> {
        inputs: Vec<([u64; NR_COLS], usize)>,
        max_bit_len: usize,
        _marker: PhantomData<F>,
    }

    impl<F: CircuitField, const NR_COLS: usize> Circuit<F> for TestCircuit<F, NR_COLS> {
        type Config = Pow2RangeConfig;
        type FloorPlanner = SimpleFloorPlanner;
        type Params = ();

        fn without_witnesses(&self) -> Self {
            unreachable!();
        }

        fn configure(meta: &mut ConstraintSystem<F>) -> Self::Config {
            let columns = (0..NR_COLS).map(|_| meta.advice_column()).collect::<Vec<_>>();
            Pow2RangeChip::configure(meta, &columns)
        }

        fn synthesize(
            &self,
            config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            let pow2range_chip = Pow2RangeChip::<F>::new(&config, self.max_bit_len);
            layouter.assign_region(
                || "pow2range test",
                |mut region| {
                    for (offset, (values, n)) in self.inputs.iter().enumerate() {
                        for (col, value) in pow2range_chip.config.val_cols.iter().zip(values) {
                            let val = Value::known(F::from(*value));
                            region.assign_advice(|| "pow2range val", *col, offset, || val)?;
                        }
                        // Enable the check on the row the values were just
                        // written to, so that every input row is checked.
                        pow2range_chip.assert_row_lower_than_2_pow_n(&mut region, *n, offset)?;
                    }
                    Ok(())
                },
            )?;
            // The table only contains the tags requested above, so load it last.
            pow2range_chip.load_table(&mut layouter)
        }
    }

    /// Row `n` holds the largest admissible value `2^n - 1` in its first
    /// column and random admissible values elsewhere; all rows must pass.
    fn run_pow2range_test<const NR_COLS: usize>() {
        const MAX_BIT_LEN: usize = 10;
        let mut rng = rand::thread_rng();
        let inputs = (0..MAX_BIT_LEN)
            .map(|n| {
                let mut values = [0u64; NR_COLS];
                values[0] = (1 << n) - 1;
                for value in values.iter_mut().skip(1) {
                    *value = rng.gen_range(0..(1 << n));
                }
                (values, n)
            })
            .collect();
        let circuit = TestCircuit::<Fp, NR_COLS> {
            inputs,
            max_bit_len: MAX_BIT_LEN,
            _marker: PhantomData,
        };
        let public_inputs = vec![];
        let prover = match MockProver::run(&circuit, public_inputs) {
            Ok(prover) => prover,
            Err(e) => panic!("{e:#?}"),
        };
        assert_eq!(prover.verify(), Ok(()));
    }

    #[test]
    fn test_pow2range() {
        run_pow2range_test::<1>();
        run_pow2range_test::<2>();
        run_pow2range_test::<3>();
        run_pow2range_test::<4>();
    }

    /// For each `n`, a single row with `2^n` in one column (the smallest
    /// inadmissible value) must be rejected.
    fn run_pow2range_negative_test<const NR_COLS: usize>() {
        const MAX_BIT_LEN: usize = 10;
        (0..MAX_BIT_LEN).for_each(|n| {
            let mut values = [0u64; NR_COLS];
            let i = n % NR_COLS;
            values[i] = 1 << n;
            let circuit = TestCircuit::<Fp, NR_COLS> {
                inputs: vec![(values, n)],
                max_bit_len: MAX_BIT_LEN,
                _marker: PhantomData,
            };
            let public_inputs = vec![];
            let prover = match MockProver::run(&circuit, public_inputs) {
                Ok(prover) => prover,
                Err(e) => panic!("{e:#?}"),
            };
            assert!(prover.verify() != Ok(()));
        })
    }

    #[test]
    fn test_pow2range_negative() {
        run_pow2range_negative_test::<1>();
        run_pow2range_negative_test::<2>();
        run_pow2range_negative_test::<3>();
        run_pow2range_negative_test::<4>();
    }

    /// An out-of-range value on a row other than the first must be caught.
    /// This guards against the test circuit only ever checking row 0.
    #[test]
    fn test_pow2range_negative_later_row() {
        const MAX_BIT_LEN: usize = 10;
        let circuit = TestCircuit::<Fp, 2> {
            // Row 0 is valid; row 1 holds 2^4, which is out of range for n = 4.
            inputs: vec![([0, 0], 0), ([1 << 4, 0], 4)],
            max_bit_len: MAX_BIT_LEN,
            _marker: PhantomData,
        };
        let public_inputs = vec![];
        let prover = match MockProver::run(&circuit, public_inputs) {
            Ok(prover) => prover,
            Err(e) => panic!("{e:#?}"),
        };
        assert!(prover.verify() != Ok(()));
    }
}