use std::collections::HashMap;
use std::fmt;
use log::{debug, trace};
use num_bigint::BigInt;
use program_structure::cfg::Cfg;
use program_structure::ir::value_meta::{ValueMeta, ValueReduction};
use program_structure::report_code::ReportCode;
use program_structure::report::{Report, ReportCollection};
use program_structure::ir::*;
pub struct UnconstrainedLessThanWarning {
value: Expression,
bit_sizes: Vec<(Meta, Expression)>,
}
impl UnconstrainedLessThanWarning {
fn primary_meta(&self) -> &Meta {
self.value.meta()
}
pub fn into_report(self) -> Report {
let mut report = Report::warning(
"Inputs to `LessThan` need to be constrained to ensure that they are non-negative"
.to_string(),
ReportCode::UnconstrainedLessThan,
);
if let Some(file_id) = self.primary_meta().file_id {
report.add_primary(
self.primary_meta().file_location(),
file_id,
format!("`{}` needs to be constrained to ensure that it is <= p/2.", self.value),
);
for (meta, size) in self.bit_sizes {
report.add_secondary(
meta.file_location(),
file_id,
Some(format!("`{}` is constrained to `{}` bits here.", self.value, size)),
);
}
}
report
}
}
#[derive(Eq, PartialEq, Hash)]
struct VariableAccess {
pub var: VariableName,
pub access: Vec<AccessType>,
}
impl VariableAccess {
fn new(var: &VariableName, access: &[AccessType]) -> Self {
VariableAccess { var: var.without_version(), access: access.to_vec() }
}
}
enum Component {
LessThan,
Num2Bits { bit_size: Box<Expression> },
}
impl Component {
fn less_than() -> Self {
Self::LessThan
}
fn num_2_bits(bit_size: &Expression) -> Self {
Self::Num2Bits { bit_size: Box::new(bit_size.clone()) }
}
}
enum ComponentInput {
LessThan { value: Box<Expression> },
Num2Bits { value: Box<Expression>, bit_size: Box<Expression> },
}
impl ComponentInput {
fn less_than(value: &Expression) -> Self {
Self::LessThan { value: Box::new(value.clone()) }
}
fn num_2_bits(value: &Expression, bit_size: &Expression) -> Self {
Self::Num2Bits { value: Box::new(value.clone()), bit_size: Box::new(bit_size.clone()) }
}
}
#[derive(Default)]
struct ConstraintData {
pub less_than: Vec<Meta>,
pub num_2_bits: Vec<Meta>,
pub bit_sizes: Vec<Expression>,
}
pub fn find_unconstrained_less_than(cfg: &Cfg) -> ReportCollection {
debug!("running unconstrained less-than analysis pass");
let mut components = HashMap::new();
for basic_block in cfg.iter() {
for stmt in basic_block.iter() {
update_components(stmt, &mut components);
}
}
let mut inputs = Vec::new();
for basic_block in cfg.iter() {
for stmt in basic_block.iter() {
update_inputs(stmt, &components, &mut inputs);
}
}
let mut constraints = HashMap::<Expression, ConstraintData>::new();
for input in inputs {
match input {
ComponentInput::LessThan { value } => {
let entry = constraints.entry(*value.clone()).or_default();
entry.less_than.push(value.meta().clone());
}
ComponentInput::Num2Bits { value, bit_size, .. } => {
let entry = constraints.entry(*value.clone()).or_default();
entry.num_2_bits.push(value.meta().clone());
entry.bit_sizes.push(*bit_size.clone());
}
}
}
let mut reports = ReportCollection::new();
let max_value = BigInt::from(cfg.constants().prime_size() - 1);
for (value, data) in constraints {
if data.less_than.is_empty() {
continue;
}
let mut is_positive = false;
for bit_size in &data.bit_sizes {
if let Some(ValueReduction::FieldElement { value }) = bit_size.value() {
if value < &max_value {
is_positive = true;
break;
}
}
}
if is_positive {
continue;
}
reports.push(build_report(&value, &data));
}
debug!("{} new reports generated", reports.len());
reports
}
fn update_components(stmt: &Statement, components: &mut HashMap<VariableAccess, Component>) {
use AssignOp::*;
use Statement::*;
use Expression::*;
if let Substitution { meta, var, op: AssignLocalOrComponent, rhe, .. } = stmt {
if meta.type_knowledge().is_local() || meta.type_knowledge().is_signal() {
return;
}
let (rhe, access) = if let Update { access, rhe, .. } = rhe {
(rhe.as_ref(), access.clone())
} else {
(rhe, Vec::new())
};
if let Call { name: component_name, args, .. } = rhe {
if component_name == "LessThan" && args.len() == 1 {
trace!(
"`LessThan` template instantiation `{var}{}` found",
vec_to_display(&access, "")
);
let component = VariableAccess::new(var, &access);
components.insert(component, Component::less_than());
} else if component_name == "Num2Bits" && args.len() == 1 {
trace!(
"`LessThan` template instantiation `{var}{}` found",
vec_to_display(&access, "")
);
let component = VariableAccess::new(var, &access);
components.insert(component, Component::num_2_bits(&args[0]));
}
}
}
}
fn update_inputs(
stmt: &Statement,
components: &HashMap<VariableAccess, Component>,
inputs: &mut Vec<ComponentInput>,
) {
use AssignOp::*;
use Statement::*;
use Expression::*;
use AccessType::*;
if let Substitution {
var, op: AssignConstraintSignal, rhe: Update { access, rhe, .. }, ..
} = stmt
{
let mut component_access = access.clone();
let signal_access = component_access.pop();
let component = VariableAccess::new(var, &component_access);
if let Some(Component::Num2Bits { bit_size, .. }) = components.get(&component) {
let Some(ComponentAccess(signal_name)) = signal_access else {
return;
};
if signal_name != "in" {
return;
}
trace!("`Num2Bits` input signal assignment `{rhe}` found");
inputs.push(ComponentInput::num_2_bits(rhe, bit_size));
}
let mut component_access = access.clone();
let index_access = component_access.pop();
let signal_access = component_access.pop();
let component = VariableAccess::new(var, &component_access);
if let Some(Component::LessThan { .. }) = components.get(&component) {
let (Some(ComponentAccess(signal_name)), Some(ArrayAccess(_))) =
(signal_access, index_access)
else {
return;
};
if signal_name != "in" {
return;
}
trace!("`LessThan` input signal assignment `{rhe}` found");
inputs.push(ComponentInput::less_than(rhe));
}
}
}
#[must_use]
fn build_report(value: &Expression, data: &ConstraintData) -> Report {
UnconstrainedLessThanWarning {
value: value.clone(),
bit_sizes: data.num_2_bits.iter().cloned().zip(data.bit_sizes.iter().cloned()).collect(),
}
.into_report()
}
#[must_use]
fn vec_to_display<T: fmt::Display>(elems: &[T], sep: &str) -> String {
elems.iter().map(|elem| format!("{elem}")).collect::<Vec<String>>().join(sep)
}
#[cfg(test)]
mod tests {
use parser::parse_definition;
use program_structure::{cfg::IntoCfg, constants::Curve};
use super::*;
#[test]
fn test_unconstrained_less_than() {
let src = r#"
template Test(n) {
signal input small;
signal input large;
signal output ok;
// Check that small < large.
component lt = LessThan(n);
lt.in[0] <== small;
lt.in[1] <== large;
ok <== lt.out;
}
"#;
validate_reports(src, 2);
let src = r#"
template Test(n) {
signal input small;
signal input large;
signal output ok;
// Constrain inputs to n bits.
component n2b[2];
n2b[0] = Num2Bits(n);
n2b[0].in <== small;
n2b[1] = Num2Bits(n + 1);
n2b[1].in <== large;
// Check that small < large.
component lt = LessThan(n);
lt.in[0] <== small;
lt.in[1] <== large;
ok <== lt.out;
}
"#;
validate_reports(src, 2);
let src = r#"
template Test(n) {
signal input small;
signal input large;
signal output ok;
// Constrain inputs to n bits.
component n2b[2];
n2b[0] = Num2Bits(n);
n2b[0].in <== small;
n2b[1] = Num2Bits(32);
n2b[1].in <== large;
// Check that small < large.
component lt = LessThan(n);
lt.in[0] <== small;
lt.in[1] <== large;
ok <== lt.out;
}
"#;
validate_reports(src, 1);
let src = r#"
template Test(n) {
signal input small;
signal input large;
signal output ok;
// Check that small < large.
component lt = LessThan(n);
lt.in[1] <== large;
lt.in[0] <== small;
// Constrain inputs to n bits.
component n2b[2];
n2b[0] = Num2Bits(32);
n2b[0].in <== small;
n2b[1] = Num2Bits(64);
n2b[1].in <== large;
ok <== lt.out;
}
"#;
validate_reports(src, 0);
}
fn validate_reports(src: &str, expected_len: usize) {
let mut reports = ReportCollection::new();
let cfg = parse_definition(src)
.unwrap()
.into_cfg(&Curve::default(), &mut reports)
.unwrap()
.into_ssa()
.unwrap();
assert!(reports.is_empty());
let reports = find_unconstrained_less_than(&cfg);
assert_eq!(reports.len(), expected_len);
}
}