Skip to main content

pounce_cli/
nl_fbbt_translate.rs

1//! Translate parsed `.nl` constraint expressions into an
2//! [`FbbtTape`] for the presolve FBBT pass (issue [#62]).
3//!
4//! The `Expr` tree pounce reads from a `.nl` file uses a richer
5//! operator set than FBBT supports (extern function calls,
6//! variable-exponent powers, AMPL `log10`, n-ary sums) and embeds
7//! common subexpressions via `Rc` sharing. This module flattens
8//! the tree into a tape where:
9//!
10//! * Every `Expr::Cse(rc)` is emitted **once** and re-referenced by
11//!   slot index on every subsequent occurrence — matching the
12//!   per-Rc-pointer caching strategy `nl_tape::Tape::build` already
13//!   uses for AD tapes.
14//! * Operators FBBT can reason about become the corresponding
15//!   [`FbbtOp`] variants.
16//! * Anything else collapses to [`FbbtOp::Opaque`], which forward /
17//!   reverse interval passes treat as "no information here." A single
18//!   unsupported sub-expression doesn't poison the whole constraint —
19//!   intervals just stop tightening through that subtree.
20//!
21//! The full constraint expression on row `i` is `con_nonlinear[i] +
22//! Σ_k coef_k · x_{var_k}`. The linear part is folded in after the
23//! nonlinear translation, so the resulting tape has *one* root
24//! representing the entire constraint.
25//!
26//! [#62]: https://github.com/jkitchin/pounce/issues/62
27
28use std::collections::HashMap;
29use std::rc::Rc;
30
31use pounce_common::types::Number;
32use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
33
34use crate::nl_reader::{BinOp, Expr, UnaryOp};
35
36/// Result of translating one `Expr` into a tape.
37struct Builder {
38    ops: Vec<FbbtOp>,
39    /// CSE cache: `Rc::as_ptr` → tape slot of the body.
40    cse_cache: HashMap<*const Expr, usize>,
41}
42
43impl Builder {
44    fn new() -> Self {
45        Self {
46            ops: Vec::new(),
47            cse_cache: HashMap::new(),
48        }
49    }
50
51    fn emit(&mut self, op: FbbtOp) -> usize {
52        let idx = self.ops.len();
53        self.ops.push(op);
54        idx
55    }
56
57    /// Recursively translate `expr` and return its slot index in
58    /// `self.ops`.
59    fn translate(&mut self, expr: &Expr) -> usize {
60        match expr {
61            Expr::Const(v) => self.emit(FbbtOp::Const(*v)),
62            Expr::Var(i) => self.emit(FbbtOp::Var(*i)),
63            Expr::Cse(rc) => {
64                let key = Rc::as_ptr(rc);
65                if let Some(&slot) = self.cse_cache.get(&key) {
66                    return slot;
67                }
68                let slot = self.translate(rc.as_ref());
69                self.cse_cache.insert(key, slot);
70                slot
71            }
72            Expr::Binary(op, lhs, rhs) => {
73                let a = self.translate(lhs);
74                let b = self.translate(rhs);
75                match op {
76                    BinOp::Add => self.emit(FbbtOp::Add(a, b)),
77                    BinOp::Sub => self.emit(FbbtOp::Sub(a, b)),
78                    BinOp::Mul => self.emit(FbbtOp::Mul(a, b)),
79                    BinOp::Div => self.emit(FbbtOp::Div(a, b)),
80                    BinOp::Pow => {
81                        // FBBT only handles integer exponent pinned at
82                        // compile time. If the right-hand side is an
83                        // `Expr::Const(c)` with `c` a small non-negative
84                        // integer, emit `PowInt`; otherwise bail.
85                        let exp_const = const_value(rhs).and_then(integer_exponent);
86                        if let Some(n) = exp_const {
87                            self.emit(FbbtOp::PowInt(a, n))
88                        } else {
89                            self.emit(FbbtOp::Opaque)
90                        }
91                    }
92                }
93            }
94            Expr::Unary(op, x) => {
95                let a = self.translate(x);
96                match op {
97                    UnaryOp::Neg => self.emit(FbbtOp::Neg(a)),
98                    UnaryOp::Sqrt => self.emit(FbbtOp::Sqrt(a)),
99                    UnaryOp::Log => self.emit(FbbtOp::Ln(a)),
100                    UnaryOp::Exp => self.emit(FbbtOp::Exp(a)),
101                    UnaryOp::Abs => self.emit(FbbtOp::Abs(a)),
102                    UnaryOp::Sin => self.emit(FbbtOp::Sin(a)),
103                    UnaryOp::Cos => self.emit(FbbtOp::Cos(a)),
104                    // log10 = ln / ln(10) — translate as (Ln(x) /
105                    // Const(ln 10)) so we don't drop info.
106                    UnaryOp::Log10 => {
107                        let ln = self.emit(FbbtOp::Ln(a));
108                        let denom = self.emit(FbbtOp::Const(std::f64::consts::LN_10));
109                        self.emit(FbbtOp::Div(ln, denom))
110                    }
111                }
112            }
113            Expr::Sum(parts) => {
114                // Left-fold the n-ary sum into binary Adds.
115                if parts.is_empty() {
116                    return self.emit(FbbtOp::Const(0.0));
117                }
118                let mut acc = self.translate(&parts[0]);
119                for p in &parts[1..] {
120                    let next = self.translate(p);
121                    acc = self.emit(FbbtOp::Add(acc, next));
122                }
123                acc
124            }
125            Expr::Funcall { .. } => {
126                // External / imported functions are opaque to FBBT.
127                self.emit(FbbtOp::Opaque)
128            }
129        }
130    }
131}
132
133/// Borrow the constant payload of an `Expr::Const`, or follow one
134/// layer of `Cse` to find a constant. Returns `None` for any other
135/// shape — including expressions that are *value-equivalent* to a
136/// constant but not syntactically one.
137fn const_value(expr: &Expr) -> Option<Number> {
138    match expr {
139        Expr::Const(v) => Some(*v),
140        Expr::Cse(rc) => const_value(rc.as_ref()),
141        _ => None,
142    }
143}
144
145/// Coerce a `Number` to a non-negative integer suitable for
146/// [`FbbtOp::PowInt`]. Caps at 64 — beyond that, interval arithmetic
147/// quickly hits the floating-point overflow band and produces
148/// uninformative bounds anyway.
149fn integer_exponent(v: Number) -> Option<u32> {
150    if !v.is_finite() {
151        return None;
152    }
153    if v < 0.0 || v > 64.0 {
154        return None;
155    }
156    let rounded = v.round();
157    if (v - rounded).abs() > 1e-9 {
158        return None;
159    }
160    Some(rounded as u32)
161}
162
163/// Translate the nonlinear part of constraint `i` together with its
164/// linear coefficients into a single tape. Returns `None` if neither
165/// part contributes anything (no nonlinear expression *and* no
166/// linear coefficients) — there's nothing for FBBT to tighten
167/// against.
168///
169/// `nonlinear` is the `Expr` from `NlProblem::con_nonlinear[i]`;
170/// `linear` is the corresponding `con_linear[i]` slice. Variable
171/// indices in `linear` are 0-based and refer to the same `Var(j)`
172/// slots in `nonlinear`.
173pub fn translate_constraint(nonlinear: &Expr, linear: &[(usize, Number)]) -> Option<FbbtTape> {
174    let nonlinear_trivial = matches!(nonlinear, Expr::Const(c) if *c == 0.0);
175    if nonlinear_trivial && linear.is_empty() {
176        return None;
177    }
178
179    let mut b = Builder::new();
180    let mut root = if nonlinear_trivial {
181        // Skip emitting a zero placeholder if we have linear terms;
182        // the linear fold will start from the first linear term.
183        None
184    } else {
185        Some(b.translate(nonlinear))
186    };
187
188    for &(var_idx, coef) in linear {
189        let v_slot = b.emit(FbbtOp::Var(var_idx));
190        let c_slot = b.emit(FbbtOp::Const(coef));
191        let term = b.emit(FbbtOp::Mul(v_slot, c_slot));
192        root = Some(match root {
193            None => term,
194            Some(prev) => b.emit(FbbtOp::Add(prev, term)),
195        });
196    }
197
198    // The builder's last emit is always the root after the linear
199    // fold; if both contributions were trivial we returned above.
200    debug_assert!(root.is_some());
201    Some(FbbtTape { ops: b.ops })
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_linear_translates_to_sum_of_terms() {
        // 3 * x0 + (-2) * x1
        let nonlinear = Expr::Const(0.0);
        let linear = [(0usize, 3.0), (1usize, -2.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        // No placeholder is emitted for the trivial nonlinear part; each
        // linear term is `Var * Const`, left-folded with `Add`.
        assert_eq!(
            tape.ops,
            vec![
                FbbtOp::Var(0),
                FbbtOp::Const(3.0),
                FbbtOp::Mul(0, 1),
                FbbtOp::Var(1),
                FbbtOp::Const(-2.0),
                FbbtOp::Mul(3, 4),
                FbbtOp::Add(2, 5),
            ]
        );
        assert!(tape.first_invalid_slot().is_none());
    }

    #[test]
    fn purely_zero_constraint_returns_none() {
        let nonlinear = Expr::Const(0.0);
        assert!(translate_constraint(&nonlinear, &[]).is_none());
    }

    #[test]
    fn unary_translations_cover_all_supported_ops() {
        // Log10 decomposes into several ops and is covered separately.
        let inner = Box::new(Expr::Var(0));
        let cases = [
            (UnaryOp::Neg, FbbtOp::Neg(0)),
            (UnaryOp::Sqrt, FbbtOp::Sqrt(0)),
            (UnaryOp::Log, FbbtOp::Ln(0)),
            (UnaryOp::Exp, FbbtOp::Exp(0)),
            (UnaryOp::Abs, FbbtOp::Abs(0)),
            (UnaryOp::Sin, FbbtOp::Sin(0)),
            (UnaryOp::Cos, FbbtOp::Cos(0)),
        ];
        for (op, expected) in cases {
            let e = Expr::Unary(op, inner.clone());
            let tape = translate_constraint(&e, &[]).unwrap();
            assert_eq!(tape.ops.last().unwrap(), &expected);
        }
    }

    #[test]
    fn log10_decomposes_into_ln_div() {
        let e = Expr::Unary(UnaryOp::Log10, Box::new(Expr::Var(0)));
        let tape = translate_constraint(&e, &[]).unwrap();
        assert_eq!(
            tape.ops,
            vec![
                FbbtOp::Var(0),
                FbbtOp::Ln(0),
                FbbtOp::Const(std::f64::consts::LN_10),
                FbbtOp::Div(1, 2),
            ]
        );
    }

    #[test]
    fn pow_with_const_int_rhs_uses_powint() {
        // x^3
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(3.0)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        // The exponent's Const slot is emitted but left unused.
        assert_eq!(
            tape.ops,
            vec![FbbtOp::Var(0), FbbtOp::Const(3.0), FbbtOp::PowInt(0, 3)]
        );
    }

    #[test]
    fn pow_with_variable_rhs_is_opaque() {
        // x^y
        let e = Expr::Binary(BinOp::Pow, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn pow_with_fractional_const_is_opaque() {
        // x^1.5
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.5)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn cse_shared_body_emitted_once() {
        // body = x + 1; (body * 2) + body
        let body = Rc::new(Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.0)),
        ));
        let two_body = Expr::Binary(
            BinOp::Mul,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Const(2.0)),
        );
        let total = Expr::Binary(BinOp::Add, Box::new(two_body), Box::new(Expr::Cse(body)));
        let tape = translate_constraint(&total, &[]).unwrap();
        // The body should appear only once: count Var(0)s.
        let n_var0 = tape
            .ops
            .iter()
            .filter(|op| matches!(op, FbbtOp::Var(0)))
            .count();
        assert_eq!(n_var0, 1, "CSE body must be emitted once: {:?}", tape.ops);
        // Layout: Var(0), Const(1), Add(0,1) [body], Const(2), Mul(2,3),
        // then the root reuses the cached body slot 2.
        assert_eq!(
            tape.ops.last(),
            Some(&FbbtOp::Add(4, 2)),
            "second CSE occurrence must reuse the cached slot: {:?}",
            tape.ops
        );
    }

    #[test]
    fn sum_node_folds_to_binary_adds() {
        let s = Expr::Sum(vec![Expr::Var(0), Expr::Var(1), Expr::Var(2)]);
        let tape = translate_constraint(&s, &[]).unwrap();
        assert_eq!(
            tape.ops,
            vec![
                FbbtOp::Var(0),
                FbbtOp::Var(1),
                FbbtOp::Add(0, 1),
                FbbtOp::Var(2),
                FbbtOp::Add(2, 3),
            ]
        );
    }

    #[test]
    fn empty_sum_folds_to_zero_constant() {
        let s = Expr::Sum(vec![]);
        let tape = translate_constraint(&s, &[]).unwrap();
        // Const(0) — and since linear is empty too, the whole tape
        // is just that one slot.
        assert_eq!(tape.ops.len(), 1);
        assert!(matches!(tape.ops[0], FbbtOp::Const(c) if c == 0.0));
    }

    #[test]
    fn funcall_collapses_to_opaque() {
        let e = Expr::Funcall {
            id: 0,
            args: vec![],
        };
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn nonlinear_plus_linear_combines() {
        // x0^2 + 3*x1 + 5*x2 (where x0^2 is nonlinear and 3*x1, 5*x2 are linear)
        let nonlinear = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(2.0)),
        );
        let linear = [(1usize, 3.0), (2usize, 5.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        // Slots 0..=2 hold x0^2; 3..=6 fold in 3*x1; 7..=10 fold in 5*x2.
        assert!(
            matches!(tape.ops.last(), Some(FbbtOp::Add(6, 9))),
            "{:?}",
            tape.ops
        );
        assert!(tape.first_invalid_slot().is_none());
    }

    #[test]
    fn translated_tape_is_well_formed() {
        // A messy expression mixing CSEs, unary, binary, sums.
        let body = Rc::new(Expr::Unary(UnaryOp::Exp, Box::new(Expr::Var(0))));
        let e = Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Binary(
                BinOp::Mul,
                Box::new(Expr::Cse(body)),
                Box::new(Expr::Const(3.0)),
            )),
        );
        let tape = translate_constraint(&e, &[(1, 0.5)]).unwrap();
        assert!(tape.first_invalid_slot().is_none());
    }
}