use alloc::{vec, vec::Vec};
use miden_ace_codegen::{
AceCircuit, AceConfig, AceDag, AceError, DagBuilder, InputKey, NodeId, build_ace_dag_for_air,
};
use miden_core::{Felt, field::ExtensionField};
use miden_crypto::{
field::Algebra,
stark::air::{LiftedAir, symbolic::SymbolicExpressionExt},
};
use crate::{PV_PROGRAM_HASH, PV_TRANSCRIPT_STATE};
/// One field of a bus message, as it is encoded into the ACE DAG.
#[derive(Debug, Clone)]
pub enum MessageElement {
/// A fixed base-field value, embedded in the DAG as a constant node.
Constant(Felt),
/// Index of a public input, read in the DAG through `InputKey::Public`.
PublicInput(usize),
}
/// Sign of a boundary fraction's numerator.
///
/// `Plus` contributes `+1 / d` and `Minus` contributes `-1 / d` to the boundary sum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Sign {
    /// Numerator `+1`.
    Plus,
    /// Numerator `-1`.
    Minus,
}
/// A single boundary term `±1 / d`, where `d` is the encoding of `message` on bus `bus`.
#[derive(Debug, Clone)]
pub struct BusFraction {
/// Sign of the numerator: `+1` for `Plus`, `-1` for `Minus`.
pub sign: Sign,
/// Bus index; `bus + 1` scales the bus-specific offset in the encoded denominator.
pub bus: usize,
/// Message fields, combined in order with increasing powers of the `AuxRandBeta` input.
pub message: Vec<MessageElement>,
}
/// Describes the LogUp boundary terms that `batch_logup_boundary` folds into a constraint DAG.
#[derive(Debug, Clone)]
pub struct LogUpBoundaryConfig {
/// `AuxBusBoundary` column indices whose inputs are added into the boundary sum.
pub sum_columns: Vec<usize>,
/// `AuxBusBoundary` column indices whose inputs are summed into a separate term that is
/// weighted by `gamma^2` in the batched root.
pub zero_columns: Vec<usize>,
/// Signed fractions that are combined over a common denominator with the boundary sum.
pub fractions: Vec<BusFraction>,
/// Additional circuit inputs that are added to the boundary sum as-is.
pub scalar_corrections: Vec<InputKey>,
}
pub fn batch_logup_boundary<EF>(
constraint_dag: AceDag<EF>,
config: &LogUpBoundaryConfig,
) -> AceDag<EF>
where
EF: ExtensionField<Felt>,
{
let constraint_root = constraint_dag.root;
let mut builder = DagBuilder::from_dag(constraint_dag);
let mut sum_aux = builder.constant(EF::ZERO);
for &col in &config.sum_columns {
let node = builder.input(InputKey::AuxBusBoundary(col));
sum_aux = builder.add(sum_aux, node);
}
for &scalar in &config.scalar_corrections {
let node = builder.input(scalar);
sum_aux = builder.add(sum_aux, node);
}
let mut num = builder.constant(EF::ZERO);
let mut den = builder.constant(EF::ONE);
for fraction in &config.fractions {
let d_i = encode_bus_message(&mut builder, fraction.bus, &fraction.message);
let sign_value = match fraction.sign {
Sign::Plus => EF::ONE,
Sign::Minus => -EF::ONE,
};
let n_i = builder.constant(sign_value);
let num_d = builder.mul(num, d_i);
let den_n = builder.mul(den, n_i);
num = builder.add(num_d, den_n);
den = builder.mul(den, d_i);
}
let sum_times_den = builder.mul(sum_aux, den);
let boundary = builder.add(sum_times_den, num);
let mut zero_sum = builder.constant(EF::ZERO);
for &col in &config.zero_columns {
let node = builder.input(InputKey::AuxBusBoundary(col));
zero_sum = builder.add(zero_sum, node);
}
let gamma = builder.input(InputKey::Gamma);
let gamma_boundary = builder.mul(gamma, boundary);
let constraint_plus_boundary = builder.add(constraint_root, gamma_boundary);
let gamma_sq = builder.mul(gamma, gamma);
let gamma_sq_zero = builder.mul(gamma_sq, zero_sum);
let root = builder.add(constraint_plus_boundary, gamma_sq_zero);
builder.build(root)
}
/// Builds the DAG node for the denominator of one bus message.
///
/// The resulting node evaluates to
///
/// ```text
/// alpha + (bus + 1) * beta^W + sum_i beta^i * elements[i]
/// ```
///
/// where `alpha` / `beta` are the `AuxRandAlpha` / `AuxRandBeta` inputs and `W` is
/// `MIDEN_MAX_MESSAGE_WIDTH`.
fn encode_bus_message<EF>(
    builder: &mut DagBuilder<EF>,
    bus: usize,
    elements: &[MessageElement],
) -> NodeId
where
    EF: ExtensionField<Felt>,
{
    use crate::constraints::lookup::messages::MIDEN_MAX_MESSAGE_WIDTH;

    let alpha = builder.input(InputKey::AuxRandAlpha);
    let beta = builder.input(InputKey::AuxRandBeta);

    // beta^W by repeated multiplication, starting from 1.
    let one = builder.constant(EF::ONE);
    let beta_to_width =
        (0..MIDEN_MAX_MESSAGE_WIDTH).fold(one, |power, _| builder.mul(power, beta));

    // NOTE(review): `bus as u32` truncates for values above `u32::MAX`; this file only passes
    // `BusId` variants cast to `usize` — confirm no other caller can exceed that range.
    let bus_scale = builder.constant(EF::from(Felt::from_u32((bus as u32) + 1)));
    let bus_offset = builder.mul(beta_to_width, bus_scale);
    let prefix = builder.add(alpha, bus_offset);

    // Add each element weighted by the next power of beta (beta^0 for the first one).
    let mut encoded = prefix;
    let mut beta_pow = builder.constant(EF::ONE);
    for element in elements {
        let value = match *element {
            MessageElement::PublicInput(index) => builder.input(InputKey::Public(index)),
            MessageElement::Constant(felt) => builder.constant(EF::from(felt)),
        };
        let term = builder.mul(beta_pow, value);
        encoded = builder.add(encoded, term);
        beta_pow = builder.mul(beta_pow, beta);
    }
    encoded
}
pub fn logup_boundary_config() -> LogUpBoundaryConfig {
use MessageElement::{Constant, PublicInput};
use crate::constraints::lookup::messages::BusId;
let ph_msg = vec![
PublicInput(PV_PROGRAM_HASH),
PublicInput(PV_PROGRAM_HASH + 1),
PublicInput(PV_PROGRAM_HASH + 2),
PublicInput(PV_PROGRAM_HASH + 3),
Constant(Felt::ZERO), Constant(Felt::ZERO), Constant(Felt::ZERO), ];
let default_lp_msg = vec![
Constant(Felt::ZERO),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
];
let final_lp_msg = vec![
PublicInput(PV_TRANSCRIPT_STATE),
PublicInput(PV_TRANSCRIPT_STATE + 1),
PublicInput(PV_TRANSCRIPT_STATE + 2),
PublicInput(PV_TRANSCRIPT_STATE + 3),
];
LogUpBoundaryConfig {
sum_columns: vec![0],
zero_columns: vec![1],
fractions: vec![
BusFraction {
sign: Sign::Plus,
bus: BusId::BlockHashTable as usize,
message: ph_msg,
},
BusFraction {
sign: Sign::Plus,
bus: BusId::LogPrecompileTranscript as usize,
message: default_lp_msg,
},
BusFraction {
sign: Sign::Minus,
bus: BusId::LogPrecompileTranscript as usize,
message: final_lp_msg,
},
],
scalar_corrections: vec![InputKey::VlpiReduction(0)],
}
}
/// Builds the ACE circuit for `air` with the LogUp boundary checks batched into its root.
///
/// The constraint DAG is built with `build_ace_dag_for_air`, extended by
/// `batch_logup_boundary` using `boundary_config`, and then emitted as a circuit together with
/// the layout produced in the first step.
///
/// # Errors
///
/// Propagates any `AceError` returned by DAG construction or circuit emission.
pub fn build_batched_ace_circuit<A, EF>(
    air: &A,
    config: AceConfig,
    boundary_config: &LogUpBoundaryConfig,
) -> Result<AceCircuit<EF>, AceError>
where
    A: LiftedAir<Felt, EF>,
    EF: ExtensionField<Felt>,
    SymbolicExpressionExt<Felt, EF>: Algebra<EF>,
{
    let artifacts = build_ace_dag_for_air::<A, Felt, EF>(air, config)?;
    let layout = artifacts.layout;
    let dag_with_boundary = batch_logup_boundary(artifacts.dag, boundary_config);
    miden_ace_codegen::emit_circuit(&dag_with_boundary, layout)
}