use std::collections::HashMap;
use proc_macro2::{Ident, Span};
use super::config_types::{ArithmeticOp, ArithmeticResult, Bounds, ConstraintDef, Sign};
/// Reports whether `op` stays finite at every corner of the two ranges.
///
/// A missing lower bound is replaced by `f64::MIN` and a missing upper bound
/// by `f64::MAX`. `op` is then evaluated for each (lhs endpoint, rhs endpoint)
/// pair — four calls in total — and the function returns `true` only when all
/// four results are finite (neither infinite nor NaN).
///
/// Only the corners are sampled; interior points of the ranges are not examined.
pub fn bounds_op_is_finite(lhs: &Bounds, rhs: &Bounds, op: impl Fn(f64, f64) -> f64) -> bool {
    let lhs_ends = [lhs.lower.unwrap_or(f64::MIN), lhs.upper.unwrap_or(f64::MAX)];
    let rhs_ends = [rhs.lower.unwrap_or(f64::MIN), rhs.upper.unwrap_or(f64::MAX)];

    // Evaluate every corner (no short-circuit) so `op` is invoked exactly four times.
    let mut every_corner_finite = true;
    for a in lhs_ends.iter() {
        for b in rhs_ends.iter() {
            let value = op(*a, *b);
            every_corner_finite &= value.is_finite();
        }
    }
    every_corner_finite
}
/// Reports whether `lhs / rhs` stays finite over the given ranges.
///
/// Missing bounds are replaced by `f64::MIN` (lower) and `f64::MAX` (upper).
/// When `rhs_excludes_zero` is set and a divisor endpoint is exactly zero,
/// that endpoint is replaced by `f64::MIN_POSITIVE` (lower) or
/// `-f64::MIN_POSITIVE` (upper), i.e. the nearest normal value on the
/// permitted side of zero.
///
/// Both dividend endpoints are divided by each sampled divisor; the function
/// returns `true` only when every quotient is finite.
///
/// Sampling just the divisor endpoints is not enough when the divisor range
/// has zero strictly inside it: the largest quotients occur next to zero, not
/// at the endpoints. In that case:
/// * if zero is not excluded, the divisor may be zero itself, so the result
///   is `false`;
/// * otherwise `f64::MIN_POSITIVE` and `-f64::MIN_POSITIVE` are sampled as
///   additional divisors, using the same "closest to zero" stand-in as the
///   endpoint substitution above (subnormal divisors are not modelled).
pub fn bounds_div_is_finite(lhs: &Bounds, rhs: &Bounds, rhs_excludes_zero: bool) -> bool {
    let l_min = lhs.lower.unwrap_or(f64::MIN);
    let l_max = lhs.upper.unwrap_or(f64::MAX);
    let r_min = if rhs_excludes_zero && rhs.lower == Some(0.0) {
        f64::MIN_POSITIVE
    } else {
        rhs.lower.unwrap_or(f64::MIN)
    };
    let r_max = if rhs_excludes_zero && rhs.upper == Some(0.0) {
        -f64::MIN_POSITIVE
    } else {
        rhs.upper.unwrap_or(f64::MAX)
    };

    // Both dividend endpoints must give a finite quotient for `divisor`.
    let finite_for =
        |divisor: f64| (l_min / divisor).is_finite() && (l_max / divisor).is_finite();

    if !(finite_for(r_min) && finite_for(r_max)) {
        return false;
    }

    // Zero lies strictly inside the divisor range: the endpoints alone miss
    // the worst case, which sits right next to (or exactly at) zero.
    if r_min < 0.0 && r_max > 0.0 {
        return rhs_excludes_zero
            && finite_for(f64::MIN_POSITIVE)
            && finite_for(-f64::MIN_POSITIVE);
    }

    true
}
/// Builds the lookup table of arithmetic results for every operand pairing.
///
/// Every ordered pair of constraints (a constraint paired with itself
/// included) is combined with each of `Add`, `Sub`, `Mul` and `Div`. The
/// outcome is stored under the key `(op, lhs name, rhs name)`, where the names
/// are the string form of each constraint's `name`.
pub fn compute_all_arithmetic_results(
    constraints: &[ConstraintDef],
) -> HashMap<(ArithmeticOp, String, String), ArithmeticResult> {
    let every_op = [
        ArithmeticOp::Add,
        ArithmeticOp::Sub,
        ArithmeticOp::Mul,
        ArithmeticOp::Div,
    ];

    let mut table = HashMap::new();
    for left in constraints {
        let left_name = left.name.to_string();
        for right in constraints {
            let right_name = right.name.to_string();
            for op in every_op.iter().copied() {
                let outcome = compute_arithmetic_result(op, left, right, constraints);
                table.insert((op, left_name.clone(), right_name.clone()), outcome);
            }
        }
    }
    table
}
fn compute_arithmetic_result(
op: ArithmeticOp,
lhs: &ConstraintDef,
rhs: &ConstraintDef,
all_constraints: &[ConstraintDef],
) -> ArithmeticResult {
let (output_sign, output_excludes_zero, is_safe) = match op {
ArithmeticOp::Add => compute_add_properties(lhs, rhs),
ArithmeticOp::Sub => compute_sub_properties(lhs, rhs),
ArithmeticOp::Mul => compute_mul_properties(lhs, rhs),
ArithmeticOp::Div => compute_div_properties(lhs, rhs),
};
let output_type = find_matching_constraint(
op,
output_sign,
output_excludes_zero,
lhs,
rhs,
all_constraints,
);
ArithmeticResult {
output_type,
is_safe,
}
}
/// Derives `(output sign, output excludes zero, is safe)` for `lhs + rhs`.
///
/// * Sign: two positives give `Positive`, two negatives give `Negative`,
///   anything else gives `Any`.
/// * Excludes zero: only when both operands exclude zero and share the same
///   definite (non-`Any`) sign.
/// * Safe: only when one operand is `Positive` and the other `Negative`, and
///   the sum is finite at every corner of the operand bounds.
fn compute_add_properties(lhs: &ConstraintDef, rhs: &ConstraintDef) -> (Sign, bool, bool) {
    // Classify the sign pairing once: resulting sign plus "opposite signs" flag.
    let (output_sign, opposite_signs) = match (lhs.sign, rhs.sign) {
        (Sign::Positive, Sign::Positive) => (Sign::Positive, false),
        (Sign::Negative, Sign::Negative) => (Sign::Negative, false),
        (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive) => (Sign::Any, true),
        _ => (Sign::Any, false),
    };

    let is_safe =
        opposite_signs && bounds_op_is_finite(&lhs.bounds, &rhs.bounds, |a, b| a + b);

    // A definite output sign occurs exactly when both operands share a non-`Any` sign.
    let output_excludes_zero =
        output_sign != Sign::Any && lhs.excludes_zero && rhs.excludes_zero;

    (output_sign, output_excludes_zero, is_safe)
}
/// Derives `(output sign, output excludes zero, is safe)` for `lhs - rhs`.
///
/// Subtraction is treated as adding the negation of `rhs`:
/// * Sign: positive minus negative gives `Positive`, negative minus positive
///   gives `Negative`, anything else gives `Any`.
/// * Excludes zero: only when both operands exclude zero and the output sign
///   is definite (i.e. the operands have opposite definite signs).
/// * Safe: only when both operands share the same definite sign and the
///   difference is finite at every corner of the operand bounds.
fn compute_sub_properties(lhs: &ConstraintDef, rhs: &ConstraintDef) -> (Sign, bool, bool) {
    let output_sign = match (lhs.sign, rhs.sign) {
        (Sign::Positive, Sign::Negative) => Sign::Positive,
        (Sign::Negative, Sign::Positive) => Sign::Negative,
        _ => Sign::Any,
    };

    let same_definite_sign = lhs.sign != Sign::Any && lhs.sign == rhs.sign;
    let is_safe =
        same_definite_sign && bounds_op_is_finite(&lhs.bounds, &rhs.bounds, |a, b| a - b);

    let output_excludes_zero =
        output_sign != Sign::Any && lhs.excludes_zero && rhs.excludes_zero;

    (output_sign, output_excludes_zero, is_safe)
}
/// Returns the larger of the absolute values of the two bounds.
///
/// A missing bound contributes `0.0` to the comparison.
const fn max_abs_value(bounds: Bounds) -> f64 {
    let lower_magnitude = match bounds.lower {
        Some(value) => value.abs(),
        None => 0.0f64,
    };
    let upper_magnitude = match bounds.upper {
        Some(value) => value.abs(),
        None => 0.0f64,
    };
    lower_magnitude.max(upper_magnitude)
}
/// Estimates the `(lower, upper)` bounds of `lhs * rhs` from operand magnitudes.
///
/// The magnitude is the product of each operand's largest absolute bound.
/// Operands with the same definite sign yield `[0, magnitude]`, operands with
/// opposite definite signs yield `[-magnitude, 0]`, and every other pairing
/// yields `[-magnitude, magnitude]`. Both returned bounds are always `Some`.
fn compute_mul_result_bounds(
    lhs: &ConstraintDef,
    rhs: &ConstraintDef,
) -> (Option<f64>, Option<f64>) {
    let magnitude = max_abs_value(lhs.bounds) * max_abs_value(rhs.bounds);

    let same_definite_sign = matches!(
        (lhs.sign, rhs.sign),
        (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative)
    );
    let opposite_definite_signs = matches!(
        (lhs.sign, rhs.sign),
        (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive)
    );

    if same_definite_sign {
        (Some(0.0), Some(magnitude))
    } else if opposite_definite_signs {
        (Some(-magnitude), Some(0.0))
    } else {
        (Some(-magnitude), Some(magnitude))
    }
}
/// Collects, in input order, the constraints whose `sign` and `excludes_zero`
/// both equal the requested values.
fn filter_constraints_by_properties(
    constraints: &[ConstraintDef],
    sign: Sign,
    excludes_zero: bool,
) -> Vec<&ConstraintDef> {
    let mut selected = Vec::new();
    for candidate in constraints {
        if candidate.excludes_zero == excludes_zero && candidate.sign == sign {
            selected.push(candidate);
        }
    }
    selected
}
/// Collects, in input order, the constraints whose `sign` equals `sign`.
fn filter_constraints_by_sign(constraints: &[ConstraintDef], sign: Sign) -> Vec<&ConstraintDef> {
    let mut selected = Vec::new();
    for candidate in constraints {
        if candidate.sign == sign {
            selected.push(candidate);
        }
    }
    selected
}
/// Derives `(output sign, output excludes zero, is safe)` for `lhs * rhs`.
///
/// * Sign: `Any` if either operand is `Any`; otherwise `Positive` for equal
///   signs and `Negative` for opposite signs.
/// * Excludes zero: only when both operands exclude zero.
/// * Safe: only when both operands report `is_bounded()` and the product is
///   finite at every corner of the operand bounds.
fn compute_mul_properties(lhs: &ConstraintDef, rhs: &ConstraintDef) -> (Sign, bool, bool) {
    let is_safe = lhs.bounds.is_bounded()
        && rhs.bounds.is_bounded()
        && bounds_op_is_finite(&lhs.bounds, &rhs.bounds, |a, b| a * b);

    let output_sign = match (lhs.sign, rhs.sign) {
        (Sign::Any, _) | (_, Sign::Any) => Sign::Any,
        (left, right) if left == right => Sign::Positive,
        _ => Sign::Negative,
    };

    let output_excludes_zero = rhs.excludes_zero && lhs.excludes_zero;

    (output_sign, output_excludes_zero, is_safe)
}
/// Derives `(output sign, output excludes zero, is safe)` for `lhs / rhs`.
///
/// * Sign: same rule as multiplication — equal definite signs give
///   `Positive`, opposite definite signs give `Negative`, otherwise `Any`.
/// * Excludes zero: mirrors `lhs.excludes_zero`.
/// * Safe: requires the dividend to have both bounds present and within
///   `[-1, 1]`, the divisor to exclude zero, and `bounds_div_is_finite` to
///   accept the operand bounds.
fn compute_div_properties(lhs: &ConstraintDef, rhs: &ConstraintDef) -> (Sign, bool, bool) {
    let lhs_in_unit_range = matches!(
        (lhs.bounds.lower, lhs.bounds.upper),
        (Some(lower), Some(upper)) if lower >= -1.0 && upper <= 1.0
    );

    let is_safe = lhs_in_unit_range
        && rhs.excludes_zero
        && bounds_div_is_finite(&lhs.bounds, &rhs.bounds, rhs.excludes_zero);

    let same_definite_sign = matches!(
        (lhs.sign, rhs.sign),
        (Sign::Positive, Sign::Positive) | (Sign::Negative, Sign::Negative)
    );
    let opposite_definite_signs = matches!(
        (lhs.sign, rhs.sign),
        (Sign::Positive, Sign::Negative) | (Sign::Negative, Sign::Positive)
    );
    let output_sign = if same_definite_sign {
        Sign::Positive
    } else if opposite_definite_signs {
        Sign::Negative
    } else {
        Sign::Any
    };

    (output_sign, lhs.excludes_zero, is_safe)
}
/// Picks the name of the constraint that should represent an operation's result.
///
/// Selection order:
/// 1. Among constraints matching both `sign` and `excludes_zero`:
///    * for `Mul` with both operands bounded, the first bounded candidate
///      whose bounds equal the estimated product bounds, else the first
///      bounded candidate;
///    * otherwise (or if no bounded candidate exists) the first candidate with
///      neither bound set, else simply the first candidate.
/// 2. If step 1 had no candidates, among constraints matching `sign` only:
///    * when both operands are bounded with identical bounds, the first
///      bounded candidate carrying those same bounds;
///    * else the first candidate with neither bound set, else the first
///      candidate.
/// 3. If nothing matches the sign either, an identifier named `Fin`.
fn find_matching_constraint(
    op: ArithmeticOp,
    sign: Sign,
    excludes_zero: bool,
    lhs: &ConstraintDef,
    rhs: &ConstraintDef,
    constraints: &[ConstraintDef],
) -> Ident {
    let both_operands_bounded = lhs.bounds.is_bounded() && rhs.bounds.is_bounded();

    // Step 1: candidates agreeing on sign and zero-exclusion.
    let exact = filter_constraints_by_properties(constraints, sign, excludes_zero);
    if !exact.is_empty() {
        if both_operands_bounded && matches!(op, ArithmeticOp::Mul) {
            let (product_lower, product_upper) = compute_mul_result_bounds(lhs, rhs);
            // Prefer an exact bounds match, then any bounded candidate.
            let bounded_pick = exact
                .iter()
                .find(|c| {
                    c.bounds.is_bounded()
                        && c.bounds.lower == product_lower
                        && c.bounds.upper == product_upper
                })
                .or_else(|| exact.iter().find(|c| c.bounds.is_bounded()));
            if let Some(chosen) = bounded_pick {
                return chosen.name.clone();
            }
        }

        return exact
            .iter()
            .find(|c| c.bounds.lower.is_none() && c.bounds.upper.is_none())
            .or_else(|| exact.first())
            .expect("matches should not be empty at this point")
            .name
            .clone();
    }

    // Step 2: candidates agreeing on sign only.
    let by_sign = filter_constraints_by_sign(constraints, sign);
    if by_sign.is_empty() {
        // Step 3: nothing fits at all.
        return Ident::new("Fin", Span::call_site());
    }

    let operands_share_bounds = both_operands_bounded
        && lhs.bounds.lower == rhs.bounds.lower
        && lhs.bounds.upper == rhs.bounds.upper;
    if operands_share_bounds {
        let same_as_operands = by_sign.iter().find(|c| {
            c.bounds.is_bounded()
                && c.bounds.lower == lhs.bounds.lower
                && c.bounds.upper == lhs.bounds.upper
        });
        if let Some(chosen) = same_as_operands {
            return chosen.name.clone();
        }
    }

    by_sign
        .iter()
        .find(|c| c.bounds.lower.is_none() && c.bounds.upper.is_none())
        .or_else(|| by_sign.first())
        .expect("sign_matches should not be empty at this point")
        .name
        .clone()
}