use alloc::{vec, vec::Vec};
use miden_ace_codegen::{
AceCircuit, AceConfig, AceDag, AceError, DagBuilder, InputKey, NodeId, NodeKind,
build_ace_dag_for_air,
};
use miden_core::{Felt, field::ExtensionField};
use miden_crypto::{
field::Algebra,
stark::air::{BaseAir, LiftedAir, symbolic::SymbolicExpressionExt},
};
use crate::{MidenAir, PV_PROGRAM_HASH, PV_TRANSCRIPT_STATE};
/// One element of a bus message, as consumed by `encode_bus_message`.
#[derive(Debug, Clone)]
pub enum MessageElement {
/// A fixed base-field value; emitted as a constant node (lifted into the extension field).
Constant(Felt),
/// The public input at this index; emitted as an `InputKey::Public(index)` input node.
PublicInput(usize),
}
/// Sign of a boundary fraction's numerator.
#[derive(Debug, Clone, Copy)]
pub enum Sign {
/// The fraction's numerator is `+1`.
Plus,
/// The fraction's numerator is `-1`.
Minus,
}
/// A single signed fraction `±1 / encode(bus, message)` contributing to the boundary term.
#[derive(Debug, Clone)]
pub struct BusFraction {
/// Whether the numerator is `+1` or `-1`.
pub sign: Sign,
/// Bus index passed to the message encoder (the encoder scales its bus prefix by `bus + 1`).
pub bus: usize,
/// Message elements, in order, that are folded into the denominator.
pub message: Vec<MessageElement>,
}
/// Describes the boundary term that `batch_logup_boundary_into_builder` adds to a constraint
/// root.
#[derive(Debug, Clone)]
pub struct LogUpBoundaryConfig {
/// Slot indices; each is read as `InputKey::AuxBusBoundary(slot)` and added to the sum.
pub sum_columns: Vec<usize>,
/// Signed fractions that are combined into a single numerator/denominator pair.
pub fractions: Vec<BusFraction>,
/// Additional input keys whose values are added to the same sum as `sum_columns`.
pub scalar_corrections: Vec<InputKey>,
}
/// Appends the LogUp boundary term to `constraint_root` and returns the new root.
///
/// With `sum` being the total of all `AuxBusBoundary(col)` inputs for `config.sum_columns`
/// plus all `config.scalar_corrections` inputs, and `num / den` being the configured
/// fractions (each with numerator `±1`) brought over a common denominator, the returned node
/// is `constraint_root + Gamma * (sum * den + num)`.
pub fn batch_logup_boundary_into_builder<EF>(
    builder: &mut DagBuilder<EF>,
    constraint_root: NodeId,
    config: &LogUpBoundaryConfig,
) -> NodeId
where
    EF: ExtensionField<Felt>,
{
    // Boundary slots first, then scalar corrections: the same order in which the inputs are
    // added to the running sum.
    let summand_keys = config
        .sum_columns
        .iter()
        .map(|&col| InputKey::AuxBusBoundary(col))
        .chain(config.scalar_corrections.iter().copied());
    let sum_seed = builder.constant(EF::ZERO);
    let total = summand_keys.fold(sum_seed, |running, key| {
        let input = builder.input(key);
        builder.add(running, input)
    });

    // Fold every fraction n_i / d_i into one num / den pair:
    //   num' = num * d_i + den * n_i,   den' = den * d_i.
    let num_seed = builder.constant(EF::ZERO);
    let den_seed = builder.constant(EF::ONE);
    let (num, den) =
        config.fractions.iter().fold((num_seed, den_seed), |(num, den), fraction| {
            let denom_i = encode_bus_message(builder, fraction.bus, &fraction.message);
            let numer_i = builder.constant(match fraction.sign {
                Sign::Plus => EF::ONE,
                Sign::Minus => -EF::ONE,
            });
            let cross_num = builder.mul(num, denom_i);
            let cross_den = builder.mul(den, numer_i);
            let next_num = builder.add(cross_num, cross_den);
            let next_den = builder.mul(den, denom_i);
            (next_num, next_den)
        });

    // boundary = sum * den + num, weighted by Gamma and added to the incoming root.
    let scaled_sum = builder.mul(total, den);
    let boundary = builder.add(scaled_sum, num);
    let gamma = builder.input(InputKey::Gamma);
    let weighted_boundary = builder.mul(gamma, boundary);
    builder.add(constraint_root, weighted_boundary)
}
fn encode_bus_message<EF>(
builder: &mut DagBuilder<EF>,
bus: usize,
elements: &[MessageElement],
) -> NodeId
where
EF: ExtensionField<Felt>,
{
use crate::constraints::lookup::messages::MIDEN_MAX_MESSAGE_WIDTH;
let alpha = builder.input(InputKey::AuxRandAlpha);
let beta = builder.input(InputKey::AuxRandBeta);
let mut gamma_bus = builder.constant(EF::ONE);
for _ in 0..MIDEN_MAX_MESSAGE_WIDTH {
gamma_bus = builder.mul(gamma_bus, beta);
}
let scale = builder.constant(EF::from(Felt::from_u32((bus as u32) + 1)));
let offset = builder.mul(gamma_bus, scale);
let bus_prefix = builder.add(alpha, offset);
let mut acc = bus_prefix;
let mut beta_power = builder.constant(EF::ONE);
for elem in elements {
let node = match elem {
MessageElement::Constant(f) => builder.constant(EF::from(*f)),
MessageElement::PublicInput(idx) => builder.input(InputKey::Public(*idx)),
};
let term = builder.mul(beta_power, node);
acc = builder.add(acc, term);
beta_power = builder.mul(beta_power, beta);
}
acc
}
/// Builds a single ACE circuit covering both `MidenAir::CORE` and `MidenAir::CHIPLETS`.
///
/// The construction proceeds as follows:
/// 1. Build one constraint DAG per AIR with `is_multi_air` cleared.
/// 2. Re-emit both DAGs (without their last node) into one builder. Selector inputs are
///    renamed to their per-AIR variants, and the chiplets' main, aux and aux-boundary
///    indices are shifted past the (aligned) core ones.
/// 3. Combine the two accumulators as
///    `MultiAirBetaCore * core_acc + MultiAirBetaChip * chip_acc - q*v`.
/// 4. Append the LogUp boundary term and emit the circuit against the combined layout.
///
/// # Errors
///
/// Propagates any error from `build_ace_dag_for_air` or `emit_circuit`, and returns
/// `AceError::InvalidInputLayout` when the two re-emitted DAGs do not resolve to the same
/// `q*v` node.
///
/// # Panics
///
/// Panics if `config.is_multi_air` is `false`, if the combined aux coordinate width is not a
/// multiple of `EXT_DEGREE`, or if either sub-DAG root is not a `Sub` node.
pub fn build_multi_air_ace_circuit<EF>(config: AceConfig) -> Result<AceCircuit<EF>, AceError>
where
    EF: ExtensionField<Felt>,
    SymbolicExpressionExt<Felt, EF>: Algebra<EF>,
{
    use miden_ace_codegen::{InputCounts, InputLayout};

    assert!(
        config.is_multi_air,
        "build_multi_air_ace_circuit requires AceConfig::is_multi_air = true"
    );

    // Build each AIR's DAG on its own; the multi-AIR combination is done by hand below.
    let core_air = MidenAir::CORE;
    let chip_air = MidenAir::CHIPLETS;
    let sub_config = AceConfig { is_multi_air: false, ..config };
    let core_artifacts = build_ace_dag_for_air::<MidenAir, Felt, EF>(&core_air, sub_config)?;
    let chip_artifacts = build_ace_dag_for_air::<MidenAir, Felt, EF>(&chip_air, sub_config)?;

    let core_main_w = <MidenAir as BaseAir<Felt>>::width(&core_air);
    let core_aux_w = <MidenAir as LiftedAir<Felt, EF>>::aux_width(&core_air);
    let core_aux_n = <MidenAir as LiftedAir<Felt, EF>>::num_aux_values(&core_air);
    let chip_main_w = <MidenAir as BaseAir<Felt>>::width(&chip_air);
    let chip_aux_w = <MidenAir as LiftedAir<Felt, EF>>::aux_width(&chip_air);
    let chip_aux_n = <MidenAir as LiftedAir<Felt, EF>>::num_aux_values(&chip_air);

    // Each AIR's main width and aux coordinate width is rounded up to this alignment before
    // the two are laid out side by side.
    const LMCS_ALIGNMENT: usize = 8;
    let aligned_core_main = core_main_w.next_multiple_of(LMCS_ALIGNMENT);
    let aligned_chip_main = chip_main_w.next_multiple_of(LMCS_ALIGNMENT);
    let aligned_core_aux_coord =
        (core_aux_w * miden_ace_codegen::EXT_DEGREE).next_multiple_of(LMCS_ALIGNMENT);
    let aligned_chip_aux_coord =
        (chip_aux_w * miden_ace_codegen::EXT_DEGREE).next_multiple_of(LMCS_ALIGNMENT);
    let combined_main_w = aligned_core_main + aligned_chip_main;
    let combined_aux_coord_w = aligned_core_aux_coord + aligned_chip_aux_coord;
    assert!(
        combined_aux_coord_w.is_multiple_of(miden_ace_codegen::EXT_DEGREE),
        "combined aux coord width must be even"
    );
    let combined_aux_w = combined_aux_coord_w / miden_ace_codegen::EXT_DEGREE;

    let combined_counts = InputCounts {
        width: combined_main_w,
        aux_width: combined_aux_w,
        num_aux_boundary: core_aux_n + chip_aux_n,
        num_public: core_artifacts.layout.counts.num_public,
        num_vlpi: core_artifacts.layout.counts.num_vlpi,
        num_randomness: 2,
        num_periodic: chip_artifacts.layout.counts.num_periodic,
        num_quotient_chunks: config.num_quotient_chunks,
    };
    let combined_layout = match config.layout {
        miden_ace_codegen::LayoutKind::Native => InputLayout::new_multi_air(combined_counts),
        miden_ace_codegen::LayoutKind::Masm => InputLayout::new_masm_multi_air(combined_counts),
    };

    // Core nodes keep their indices; only the row selectors get core-specific keys.
    let core_dag = core_artifacts.dag;
    let core_root_old = core_dag.root();
    let mut builder = DagBuilder::<EF>::new();
    let core_translation = reemit_dag_with_rewrite(
        &mut builder,
        &core_dag,
        |key| match key {
            InputKey::IsFirst => InputKey::IsFirstCore,
            InputKey::IsLast => InputKey::IsLastCore,
            InputKey::IsTransition => InputKey::IsTransitionCore,
            other => other,
        },
        true,
    );

    // Chiplets nodes are shifted past the core's aligned main/aux columns and aux boundary
    // slots, and get chiplets-specific row selectors.
    let chip_dag = chip_artifacts.dag;
    let chip_root_old = chip_dag.root();
    let aligned_core_aux_w = aligned_core_aux_coord / miden_ace_codegen::EXT_DEGREE;
    let chip_translation = reemit_dag_with_rewrite(
        &mut builder,
        &chip_dag,
        |key| match key {
            InputKey::Main { offset, index } => {
                InputKey::Main { offset, index: index + aligned_core_main }
            },
            InputKey::AuxCoord { offset, index, coord } => InputKey::AuxCoord {
                offset,
                index: index + aligned_core_aux_w,
                coord,
            },
            InputKey::AuxBusBoundary(slot) => InputKey::AuxBusBoundary(slot + core_aux_n),
            InputKey::IsFirst => InputKey::IsFirstChip,
            InputKey::IsLast => InputKey::IsLastChip,
            InputKey::IsTransition => InputKey::IsTransitionChip,
            other => other,
        },
        true,
    );

    // Each original root is `Sub(acc, q*v)`; recover both operands in the new builder.
    let (core_acc, core_qv) = match core_dag.nodes[core_root_old.index()] {
        NodeKind::Sub(acc_id, qv_id) => {
            (core_translation[acc_id.index()], core_translation[qv_id.index()])
        },
        _ => panic!("CoreAir sub-DAG root must be `Sub(acc, q*v)`"),
    };
    let (chip_acc, chip_qv) = match chip_dag.nodes[chip_root_old.index()] {
        NodeKind::Sub(acc_id, qv_id) => {
            (chip_translation[acc_id.index()], chip_translation[qv_id.index()])
        },
        _ => panic!("ChipletsAir sub-DAG root must be `Sub(acc, q*v)`"),
    };
    if core_qv != chip_qv {
        return Err(AceError::InvalidInputLayout {
            message: "CoreAir and ChipletsAir quotient bindings must share the same q*v node"
                .into(),
        });
    }

    // combined = MultiAirBetaCore * core_acc + MultiAirBetaChip * chip_acc - q*v
    let mab_core = builder.input(InputKey::MultiAirBetaCore);
    let mab_chip = builder.input(InputKey::MultiAirBetaChip);
    let core_term = builder.mul(mab_core, core_acc);
    let chip_term = builder.mul(mab_chip, chip_acc);
    let combined_acc = builder.add(core_term, chip_term);
    let combined_constraint = builder.sub(combined_acc, chip_qv);

    let combined_boundary_config = multi_air_logup_boundary_config(core_aux_n, chip_aux_n);
    let final_root = batch_logup_boundary_into_builder(
        &mut builder,
        combined_constraint,
        &combined_boundary_config,
    );
    let combined_dag = builder.build(final_root);
    miden_ace_codegen::emit_circuit(&combined_dag, combined_layout)
}
/// Re-emits the nodes of `source` into `builder`, mapping every input key through `rewrite`.
///
/// When `skip_root` is `true`, the last entry of `source.nodes` (assumed to be the root) is
/// not copied. The returned table maps the index of each copied source node to its id in
/// `builder`; its length equals the number of copied nodes.
///
/// Operands are resolved through the table built so far, so every node's operands must appear
/// before it in `source.nodes`; otherwise the lookup panics.
fn reemit_dag_with_rewrite<EF, F>(
    builder: &mut DagBuilder<EF>,
    source: &AceDag<EF>,
    rewrite: F,
    skip_root: bool,
) -> Vec<NodeId>
where
    EF: ExtensionField<Felt>,
    F: Fn(InputKey) -> InputKey,
{
    let source_nodes = &source.nodes;
    // Drop one trailing node when requested; an empty node list saturates at zero.
    let copy_count = source_nodes.len().saturating_sub(usize::from(skip_root));

    let mut translation: Vec<NodeId> = Vec::with_capacity(source_nodes.len());
    for kind in source_nodes.iter().take(copy_count) {
        let emitted = match kind {
            NodeKind::Input(key) => builder.input(rewrite(*key)),
            NodeKind::Constant(value) => builder.constant(*value),
            NodeKind::Add(lhs, rhs) => {
                let (left, right) = (translation[lhs.index()], translation[rhs.index()]);
                builder.add(left, right)
            },
            NodeKind::Sub(lhs, rhs) => {
                let (left, right) = (translation[lhs.index()], translation[rhs.index()]);
                builder.sub(left, right)
            },
            NodeKind::Mul(lhs, rhs) => {
                let (left, right) = (translation[lhs.index()], translation[rhs.index()]);
                builder.mul(left, right)
            },
            NodeKind::Neg(operand) => {
                let inner = translation[operand.index()];
                builder.neg(inner)
            },
        };
        translation.push(emitted);
    }
    translation
}
pub fn multi_air_logup_boundary_config(
core_aux_n: usize,
chip_aux_n: usize,
) -> LogUpBoundaryConfig {
use MessageElement::{Constant, PublicInput};
use crate::constraints::lookup::messages::BusId;
let ph_msg = vec![
PublicInput(PV_PROGRAM_HASH),
PublicInput(PV_PROGRAM_HASH + 1),
PublicInput(PV_PROGRAM_HASH + 2),
PublicInput(PV_PROGRAM_HASH + 3),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
];
let default_lp_msg = vec![
Constant(Felt::ZERO),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
Constant(Felt::ZERO),
];
let final_lp_msg = vec![
PublicInput(PV_TRANSCRIPT_STATE),
PublicInput(PV_TRANSCRIPT_STATE + 1),
PublicInput(PV_TRANSCRIPT_STATE + 2),
PublicInput(PV_TRANSCRIPT_STATE + 3),
];
let total_slots = core_aux_n + chip_aux_n;
let sum_columns: Vec<usize> = (0..total_slots).collect();
LogUpBoundaryConfig {
sum_columns,
fractions: vec![
BusFraction {
sign: Sign::Plus,
bus: BusId::BlockHashTable as usize,
message: ph_msg,
},
BusFraction {
sign: Sign::Plus,
bus: BusId::LogPrecompileTranscript as usize,
message: default_lp_msg,
},
BusFraction {
sign: Sign::Minus,
bus: BusId::LogPrecompileTranscript as usize,
message: final_lp_msg,
},
],
scalar_corrections: vec![InputKey::VlpiReduction(0)],
}
}