use std::collections::HashMap;
use std::rc::Rc;
use pounce_common::types::Number;
use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
use crate::nl_reader::{BinOp, Expr, UnaryOp};
struct Builder {
ops: Vec<FbbtOp>,
cse_cache: HashMap<*const Expr, usize>,
}
impl Builder {
fn new() -> Self {
Self {
ops: Vec::new(),
cse_cache: HashMap::new(),
}
}
fn emit(&mut self, op: FbbtOp) -> usize {
let idx = self.ops.len();
self.ops.push(op);
idx
}
fn translate(&mut self, expr: &Expr) -> usize {
match expr {
Expr::Const(v) => self.emit(FbbtOp::Const(*v)),
Expr::Var(i) => self.emit(FbbtOp::Var(*i)),
Expr::Cse(rc) => {
let key = Rc::as_ptr(rc);
if let Some(&slot) = self.cse_cache.get(&key) {
return slot;
}
let slot = self.translate(rc.as_ref());
self.cse_cache.insert(key, slot);
slot
}
Expr::Binary(op, lhs, rhs) => {
let a = self.translate(lhs);
let b = self.translate(rhs);
match op {
BinOp::Add => self.emit(FbbtOp::Add(a, b)),
BinOp::Sub => self.emit(FbbtOp::Sub(a, b)),
BinOp::Mul => self.emit(FbbtOp::Mul(a, b)),
BinOp::Div => self.emit(FbbtOp::Div(a, b)),
BinOp::Pow => {
let exp_const = const_value(rhs).and_then(integer_exponent);
if let Some(n) = exp_const {
self.emit(FbbtOp::PowInt(a, n))
} else {
self.emit(FbbtOp::Opaque)
}
}
}
}
Expr::Unary(op, x) => {
let a = self.translate(x);
match op {
UnaryOp::Neg => self.emit(FbbtOp::Neg(a)),
UnaryOp::Sqrt => self.emit(FbbtOp::Sqrt(a)),
UnaryOp::Log => self.emit(FbbtOp::Ln(a)),
UnaryOp::Exp => self.emit(FbbtOp::Exp(a)),
UnaryOp::Abs => self.emit(FbbtOp::Abs(a)),
UnaryOp::Sin => self.emit(FbbtOp::Sin(a)),
UnaryOp::Cos => self.emit(FbbtOp::Cos(a)),
UnaryOp::Log10 => {
let ln = self.emit(FbbtOp::Ln(a));
let denom = self.emit(FbbtOp::Const(std::f64::consts::LN_10));
self.emit(FbbtOp::Div(ln, denom))
}
}
}
Expr::Sum(parts) => {
if parts.is_empty() {
return self.emit(FbbtOp::Const(0.0));
}
let mut acc = self.translate(&parts[0]);
for p in &parts[1..] {
let next = self.translate(p);
acc = self.emit(FbbtOp::Add(acc, next));
}
acc
}
Expr::Funcall { .. } => {
self.emit(FbbtOp::Opaque)
}
}
}
}
/// Returns the literal value of `expr` if it is a constant, looking through
/// any number of `Expr::Cse` wrappers; returns `None` for every other node.
fn const_value(expr: &Expr) -> Option<Number> {
    let mut node = expr;
    loop {
        match node {
            Expr::Const(value) => return Some(*value),
            Expr::Cse(shared) => node = shared.as_ref(),
            _ => return None,
        }
    }
}
/// Interprets `v` as a small non-negative integer exponent for `FbbtOp::PowInt`.
///
/// Returns `Some(n)` when `v` is finite, lies in `[0, 64]` (checked on the raw
/// value, before rounding), and is within `1e-9` of the integer `n`.
/// Otherwise returns `None`.
///
/// NOTE(review): because of the tolerance, a constant such as `2.0000000001`
/// is treated as exactly `2` — confirm this is acceptable for bound tightening.
fn integer_exponent(v: Number) -> Option<u32> {
    let in_range = v.is_finite() && (0.0..=64.0).contains(&v);
    if !in_range {
        return None;
    }
    let nearest = v.round();
    ((v - nearest).abs() <= 1e-9).then_some(nearest as u32)
}
/// Builds an FBBT tape for the constraint body `nonlinear + Σ coef · x[var]`.
///
/// Returns `None` when `nonlinear` is the literal constant `0.0` and `linear`
/// is empty, because there is nothing to lower. Otherwise the nonlinear part
/// (if present) is lowered first. Each linear term is then appended as
/// `Var * Const` and accumulated onto the running result with `Add`.
pub fn translate_constraint(nonlinear: &Expr, linear: &[(usize, Number)]) -> Option<FbbtTape> {
    let has_nonlinear = !matches!(nonlinear, Expr::Const(c) if *c == 0.0);
    if !has_nonlinear && linear.is_empty() {
        return None;
    }

    let mut builder = Builder::new();
    let seed = has_nonlinear.then(|| builder.translate(nonlinear));
    let root = linear.iter().fold(seed, |acc, &(var_idx, coef)| {
        let var_slot = builder.emit(FbbtOp::Var(var_idx));
        let coef_slot = builder.emit(FbbtOp::Const(coef));
        let term = builder.emit(FbbtOp::Mul(var_slot, coef_slot));
        Some(match acc {
            Some(prev) => builder.emit(FbbtOp::Add(prev, term)),
            None => term,
        })
    });

    // At least one of the two parts was non-empty, so something was emitted.
    debug_assert!(root.is_some());
    Some(FbbtTape { ops: builder.ops })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_linear_translates_to_sum_of_terms() {
        let nonlinear = Expr::Const(0.0);
        let linear = vec![(0usize, 3.0), (1usize, -2.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(2, 5))));
    }

    #[test]
    fn purely_zero_constraint_returns_none() {
        let nonlinear = Expr::Const(0.0);
        assert!(translate_constraint(&nonlinear, &[]).is_none());
    }

    #[test]
    fn nonzero_constant_alone_is_a_single_const_op() {
        let tape = translate_constraint(&Expr::Const(4.0), &[]).unwrap();
        assert_eq!(tape.ops.len(), 1);
        assert!(matches!(tape.ops[0], FbbtOp::Const(c) if c == 4.0));
    }

    #[test]
    fn unary_translations_cover_all_supported_ops() {
        let inner = Box::new(Expr::Var(0));
        let cases = [
            (UnaryOp::Neg, FbbtOp::Neg(0)),
            (UnaryOp::Sqrt, FbbtOp::Sqrt(0)),
            (UnaryOp::Log, FbbtOp::Ln(0)),
            (UnaryOp::Exp, FbbtOp::Exp(0)),
            (UnaryOp::Abs, FbbtOp::Abs(0)),
            (UnaryOp::Sin, FbbtOp::Sin(0)),
            (UnaryOp::Cos, FbbtOp::Cos(0)),
        ];
        for (op, expected) in cases {
            let e = Expr::Unary(op, inner.clone());
            let tape = translate_constraint(&e, &[]).unwrap();
            assert_eq!(tape.ops.last().unwrap(), &expected);
        }
    }

    #[test]
    fn log10_decomposes_into_ln_div() {
        let e = Expr::Unary(UnaryOp::Log10, Box::new(Expr::Var(0)));
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Div(1, 2))));
    }

    #[test]
    fn pow_with_const_int_rhs_uses_powint() {
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(3.0)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::PowInt(0, 3))));
    }

    #[test]
    fn pow_with_cse_wrapped_int_const_uses_powint() {
        let exponent = Rc::new(Expr::Const(2.0));
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Cse(exponent)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::PowInt(0, 2))));
        assert!(tape.first_invalid_slot().is_none());
    }

    #[test]
    fn pow_with_variable_rhs_is_opaque() {
        let e = Expr::Binary(BinOp::Pow, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn pow_with_fractional_const_is_opaque() {
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.5)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn pow_with_negative_const_is_opaque() {
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(-2.0)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn integer_exponent_bounds() {
        assert_eq!(integer_exponent(0.0), Some(0));
        assert_eq!(integer_exponent(64.0), Some(64));
        assert_eq!(integer_exponent(2.0 + 1e-12), Some(2));
        assert_eq!(integer_exponent(65.0), None);
        assert_eq!(integer_exponent(-1.0), None);
        assert_eq!(integer_exponent(0.5), None);
        assert_eq!(integer_exponent(Number::NAN), None);
        assert_eq!(integer_exponent(Number::INFINITY), None);
    }

    #[test]
    fn const_value_sees_through_cse() {
        let nested = Expr::Cse(Rc::new(Expr::Cse(Rc::new(Expr::Const(3.0)))));
        assert_eq!(const_value(&nested), Some(3.0));
        assert_eq!(const_value(&Expr::Var(0)), None);
    }

    #[test]
    fn cse_shared_body_emitted_once() {
        let body = Rc::new(Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.0)),
        ));
        let two_body = Expr::Binary(
            BinOp::Mul,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Const(2.0)),
        );
        let total = Expr::Binary(BinOp::Add, Box::new(two_body), Box::new(Expr::Cse(body)));
        let tape = translate_constraint(&total, &[]).unwrap();
        let n_var0 = tape
            .ops
            .iter()
            .filter(|op| matches!(op, FbbtOp::Var(0)))
            .count();
        assert_eq!(n_var0, 1, "CSE body must be emitted once: {:?}", tape.ops);
    }

    #[test]
    fn cse_second_reference_reuses_slot() {
        let body = Rc::new(Expr::Var(0));
        let e = Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Cse(body)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert_eq!(tape.ops.len(), 2, "{:?}", tape.ops);
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(0, 0))));
    }

    #[test]
    fn sum_node_folds_to_binary_adds() {
        let s = Expr::Sum(vec![Expr::Var(0), Expr::Var(1), Expr::Var(2)]);
        let tape = translate_constraint(&s, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(2, 3))));
    }

    #[test]
    fn empty_sum_folds_to_zero_constant() {
        let s = Expr::Sum(vec![]);
        let tape = translate_constraint(&s, &[]).unwrap();
        assert_eq!(tape.ops.len(), 1);
        assert!(matches!(tape.ops[0], FbbtOp::Const(c) if c == 0.0));
    }

    #[test]
    fn funcall_collapses_to_opaque() {
        let e = Expr::Funcall {
            id: 0,
            args: vec![],
        };
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn nonlinear_plus_linear_combines() {
        let nonlinear = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(2.0)),
        );
        let linear = vec![(1usize, 3.0), (2usize, 5.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(_, _))));
        assert!(tape.first_invalid_slot().is_none());
    }

    #[test]
    fn translated_tape_is_well_formed() {
        let body = Rc::new(Expr::Unary(UnaryOp::Exp, Box::new(Expr::Var(0))));
        let e = Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Binary(
                BinOp::Mul,
                Box::new(Expr::Cse(body)),
                Box::new(Expr::Const(3.0)),
            )),
        );
        let tape = translate_constraint(&e, &[(1, 0.5)]).unwrap();
        assert!(tape.first_invalid_slot().is_none());
    }
}