use halo2_base::{
gates::{
circuit::{
builder::{BaseCircuitBuilder, RangeStatistics},
CircuitBuilderStage,
},
GateInstructions, RangeChip,
},
halo2_proofs::plonk::Circuit,
utils::ScalarField,
virtual_region::copy_constraints::SharedCopyConstraintManager,
AssignedValue, Context,
};
use itertools::Itertools;
use rayon::prelude::*;
use crate::rlc::{
chip::RlcChip,
virtual_region::{manager::RlcManager, RlcThreadBreakPoints},
RLC_PHASE,
};
use super::RlcCircuitParams;
/// Circuit builder that pairs a halo2-base [`BaseCircuitBuilder`] with an [`RlcManager`]
/// holding the RLC region's contexts.
///
/// The derived `Clone` clones each field as-is; use `deep_clone` for a copy whose RLC manager
/// is re-attached to the cloned base builder's copy-constraint manager.
#[derive(Clone, Debug, Default)]
pub struct RlcCircuitBuilder<F: ScalarField> {
    /// Builder for the base (non-RLC) part of the circuit.
    pub base: BaseCircuitBuilder<F>,
    /// Manager for the RLC region: owns the RLC contexts (`threads`) and RLC break points.
    pub rlc_manager: RlcManager<F>,
    /// Number of advice columns reserved for the RLC region; set by `set_params` or
    /// `calculate_params`.
    pub num_rlc_columns: usize,
    /// The RLC challenge value, if already known. `rlc_chip` substitutes `F::ZERO` when `None`.
    pub gamma: Option<F>,
    /// Cache size (in bits) handed to `RlcChip::load_rlc_cache` by `rlc_chip`.
    max_cache_bits: usize,
}
impl<F: ScalarField> RlcCircuitBuilder<F> {
    /// Creates a new builder.
    ///
    /// The RLC manager is handed a clone of the base builder's `SharedCopyConstraintManager`
    /// handle, so that (presumably) both regions record copy constraints in the same manager.
    ///
    /// * `witness_gen_only` - forwarded to both `BaseCircuitBuilder::new` and `RlcManager::new`.
    /// * `max_cache_bits` - cache size later passed to `RlcChip::load_rlc_cache` by `rlc_chip`.
    pub fn new(witness_gen_only: bool, max_cache_bits: usize) -> Self {
        let base = BaseCircuitBuilder::new(witness_gen_only);
        let rlc_manager = RlcManager::new(witness_gen_only, base.core().copy_manager.clone());
        // Fields are listed explicitly instead of `..Default::default()`, which would construct
        // (and immediately drop) an unused `BaseCircuitBuilder` and `RlcManager`.
        Self { base, rlc_manager, num_rlc_columns: 0, gamma: None, max_cache_bits }
    }

    /// Sets the `unknown` flag on both the base builder and the RLC manager.
    /// `from_stage` enables it for `CircuitBuilderStage::Keygen`.
    pub fn unknown(mut self, use_unknown: bool) -> Self {
        self.base = self.base.unknown(use_unknown);
        self.rlc_manager = self.rlc_manager.unknown(use_unknown);
        self
    }

    /// Creates a builder configured for `stage`: witness-generation-only mode follows
    /// `stage.witness_gen_only()`, and `unknown` is set exactly when `stage` is `Keygen`.
    pub fn from_stage(stage: CircuitBuilderStage, max_cache_bits: usize) -> Self {
        Self::new(stage.witness_gen_only(), max_cache_bits)
            .unknown(stage == CircuitBuilderStage::Keygen)
    }

    /// Creates a witness-generation-only builder preloaded with the given config params and
    /// break points (typically those produced during keygen).
    pub fn prover(
        config_params: RlcCircuitParams,
        break_points: RlcThreadBreakPoints,
        max_cache_bits: usize,
    ) -> Self {
        Self::new(true, max_cache_bits).use_params(config_params).use_break_points(break_points)
    }

    /// Returns the base builder's copy-constraint manager.
    pub fn copy_manager(&self) -> &SharedCopyConstraintManager<F> {
        &self.base.core().copy_manager
    }

    /// Installs `copy_manager` in both the base builder and the RLC manager.
    pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager<F>) {
        self.base.set_copy_manager(copy_manager.clone());
        self.rlc_manager.set_copy_manager(copy_manager);
    }

    /// Builder-style variant of [`Self::set_copy_manager`].
    pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager<F>) -> Self {
        self.set_copy_manager(copy_manager);
        self
    }

    /// Sets the cache size (in bits) used when `rlc_chip` loads the RLC cache.
    pub fn set_max_cache_bits(&mut self, max_cache_bits: usize) {
        self.max_cache_bits = max_cache_bits;
    }

    /// Builder-style variant of [`Self::set_max_cache_bits`].
    pub fn use_max_cache_bits(mut self, max_cache_bits: usize) -> Self {
        self.set_max_cache_bits(max_cache_bits);
        self
    }

    /// Returns a copy built from `BaseCircuitBuilder::deep_clone`, with the cloned RLC manager
    /// re-attached to the new base builder's copy-constraint manager.
    pub fn deep_clone(&self) -> Self {
        let base = self.base.deep_clone();
        let rlc_manager =
            self.rlc_manager.clone().use_copy_manager(base.core().copy_manager.clone());
        Self {
            base,
            rlc_manager,
            num_rlc_columns: self.num_rlc_columns,
            gamma: self.gamma,
            max_cache_bits: self.max_cache_bits,
        }
    }

    /// Clears the base builder and the RLC manager.
    ///
    /// `gamma`, `num_rlc_columns` and `max_cache_bits` are left unchanged.
    pub fn clear(&mut self) {
        self.base.clear();
        self.rlc_manager.clear();
    }

    /// Returns whether this builder only generates witnesses.
    ///
    /// # Panics
    /// If the base builder and the RLC manager disagree on the flag.
    pub fn witness_gen_only(&self) -> bool {
        assert_eq!(self.base.witness_gen_only(), self.rlc_manager.witness_gen_only());
        self.base.witness_gen_only()
    }

    /// Returns the current config params (base params plus `num_rlc_columns`).
    pub fn params(&self) -> RlcCircuitParams {
        RlcCircuitParams { base: self.base.params(), num_rlc_columns: self.num_rlc_columns }
    }

    /// Applies `params` to the base builder and stores `params.num_rlc_columns`.
    pub fn set_params(&mut self, params: RlcCircuitParams) {
        self.base.set_params(params.base);
        self.num_rlc_columns = params.num_rlc_columns;
    }

    /// Builder-style variant of [`Self::set_params`].
    pub fn use_params(mut self, params: RlcCircuitParams) -> Self {
        self.set_params(params);
        self
    }

    /// Returns the break points of both the base builder and the RLC region.
    ///
    /// # Panics
    /// With "break points not set" if the RLC break points have not been set.
    pub fn break_points(&self) -> RlcThreadBreakPoints {
        let base = self.base.break_points();
        let rlc =
            self.rlc_manager.break_points.borrow().as_ref().expect("break points not set").clone();
        RlcThreadBreakPoints { base, rlc }
    }

    /// Stores base and RLC break points in the respective managers.
    pub fn set_break_points(&mut self, break_points: RlcThreadBreakPoints) {
        self.base.set_break_points(break_points.base);
        *self.rlc_manager.break_points.borrow_mut() = Some(break_points.rlc);
    }

    /// Builder-style variant of [`Self::set_break_points`].
    pub fn use_break_points(mut self, break_points: RlcThreadBreakPoints) -> Self {
        self.set_break_points(break_points);
        self
    }

    /// Writes `lookup_bits` directly into the base builder's `config_params`.
    pub fn set_lookup_bits(&mut self, lookup_bits: usize) {
        self.base.config_params.lookup_bits = Some(lookup_bits);
    }

    /// Builder-style variant of [`Self::set_lookup_bits`].
    pub fn use_lookup_bits(mut self, lookup_bits: usize) -> Self {
        self.set_lookup_bits(lookup_bits);
        self
    }

    /// Writes `k` (log2 of the number of rows) directly into the base builder's `config_params`.
    pub fn set_k(&mut self, k: usize) {
        self.base.config_params.k = k;
    }

    /// Builder-style variant of [`Self::set_k`].
    pub fn use_k(mut self, k: usize) -> Self {
        self.set_k(k);
        self
    }

    /// Returns the main gate context of the base builder in `RLC_PHASE`, together with the main
    /// context of the RLC manager.
    pub fn rlc_ctx_pair(&mut self) -> RlcContextPair<F> {
        (self.base.main(RLC_PHASE), self.rlc_manager.main())
    }

    /// Returns base statistics plus the total number of RLC advice cells.
    pub fn statistics(&self) -> RlcStatistics {
        let base = self.base.statistics();
        let total_rlc_advice = self.rlc_manager.total_advice();
        RlcStatistics { base, total_rlc_advice }
    }

    /// Mutable access to the base builder's `assigned_instances`.
    pub fn public_instances(&mut self) -> &mut [Vec<AssignedValue<F>>] {
        &mut self.base.assigned_instances
    }

    /// Computes config params from the current circuit contents.
    ///
    /// The base params come from `BaseCircuitBuilder::calculate_params(minimum_rows)`. The RLC
    /// column count is `ceil(total_rlc_advice / (2^k - minimum_rows))`. It is stored in
    /// `self.num_rlc_columns` and returned.
    ///
    /// # Panics
    /// If `minimum_rows` is not strictly less than `2^k`, since then no rows are usable. This
    /// is checked explicitly, so release builds cannot silently wrap around.
    pub fn calculate_params(&mut self, minimum_rows: Option<usize>) -> RlcCircuitParams {
        let base = self.base.calculate_params(minimum_rows);
        let total_rlc_advice = self.rlc_manager.total_advice();
        let k = base.k;
        let reserved_rows = minimum_rows.unwrap_or(0);
        let max_rows = match (1usize << k).checked_sub(reserved_rows) {
            Some(rows) if rows > 0 => rows,
            _ => panic!(
                "minimum_rows ({reserved_rows}) must be less than 2^k = {} (k = {k})",
                1usize << k
            ),
        };
        // Ceiling division: the number of columns of height `max_rows` needed to hold every
        // RLC advice cell.
        let num_rlc_columns = (total_rlc_advice + max_rows - 1) / max_rows;
        self.num_rlc_columns = num_rlc_columns;
        let params = RlcCircuitParams { base, num_rlc_columns };
        #[cfg(feature = "display")]
        {
            println!("Total RLC advice cells: {total_rlc_advice}");
            log::info!("Auto-calculated config params:\n {params:#?}");
        }
        params
    }

    /// Returns the base builder's range chip.
    pub fn range_chip(&self) -> RangeChip<F> {
        self.base.range_chip()
    }

    /// Creates an [`RlcChip`] for the challenge `gamma` and loads its RLC cache, sized by
    /// `max_cache_bits`, into the contexts returned by `rlc_ctx_pair`.
    ///
    /// If `gamma` is unset, `F::ZERO` is used instead.
    ///
    /// # Panics
    /// With the `halo2-axiom` feature, if this is a witness-generation-only builder and `gamma`
    /// has not been set.
    pub fn rlc_chip(&mut self, gate: &impl GateInstructions<F>) -> RlcChip<F> {
        #[cfg(feature = "halo2-axiom")]
        {
            assert!(
                !self.witness_gen_only() || self.gamma.is_some(),
                "Challenge value not available before SecondPhase"
            );
        }
        let gamma = self.gamma.unwrap_or(F::ZERO);
        let rlc_chip = RlcChip::new(gamma);
        // Read before `rlc_ctx_pair` takes a mutable borrow of `self`.
        let cache_bits = self.max_cache_bits;
        rlc_chip.load_rlc_cache(self.rlc_ctx_pair(), gate, cache_bits);
        rlc_chip
    }

    /// Runs `f` over `input` in parallel.
    ///
    /// Each item gets a fresh pair of contexts: a gate context in `RLC_PHASE` and an RLC
    /// context. Outputs are returned in input order.
    ///
    /// The new contexts are appended to the respective managers' `threads` afterwards, also in
    /// input order.
    pub fn parallelize_phase1<T, R, FR>(&mut self, input: Vec<T>, f: FR) -> Vec<R>
    where
        T: Send,
        R: Send,
        FR: Fn(RlcContextPair<F>, T) -> R + Send + Sync,
    {
        // Context ids are assigned sequentially *before* the parallel section. Together with
        // the ordered append below, this makes the final thread layout independent of rayon's
        // scheduling, so the circuit stays deterministic.
        let core_thread_count = self.base.pool(RLC_PHASE).thread_count();
        let rlc_thread_count = self.rlc_manager.thread_count();
        let num_items = input.len();
        let mut ctxs_gate = (0..num_items)
            .map(|i| self.base.pool(RLC_PHASE).new_context(core_thread_count + i))
            .collect_vec();
        let mut ctxs_rlc = (0..num_items)
            .map(|i| self.rlc_manager.new_context(rlc_thread_count + i))
            .collect_vec();
        // Indexed parallel iterators zip positionally and `collect` preserves order.
        let outputs: Vec<_> = input
            .into_par_iter()
            .zip(ctxs_gate.par_iter_mut().zip(ctxs_rlc.par_iter_mut()))
            .map(|(item, (ctx_gate, ctx_rlc))| f((ctx_gate, ctx_rlc), item))
            .collect();
        self.base.pool(RLC_PHASE).threads.append(&mut ctxs_gate);
        self.rlc_manager.threads.append(&mut ctxs_rlc);
        outputs
    }
}
/// A pair of mutable contexts `(gate context, RLC context)`, as produced by
/// `RlcCircuitBuilder::rlc_ctx_pair` and passed to the closure of `parallelize_phase1`.
pub type RlcContextPair<'a, F> = (&'a mut Context<F>, &'a mut Context<F>);
/// Cell-usage statistics returned by `RlcCircuitBuilder::statistics`.
pub struct RlcStatistics {
    /// Statistics of the base builder, from `BaseCircuitBuilder::statistics`.
    pub base: RangeStatistics,
    /// Total number of advice cells used in the RLC region, from `RlcManager::total_advice`.
    pub total_rlc_advice: usize,
}