1use std::collections::HashMap;
29use std::rc::Rc;
30
31use pounce_common::types::Number;
32use pounce_nlp::expression_provider::{FbbtOp, FbbtTape};
33
34use crate::nl_reader::{BinOp, Expr, UnaryOp};
35
/// Incremental builder for a single [`FbbtTape`].
///
/// Ops are appended so that operands are always emitted before the op that
/// consumes them; the slot index returned when an op is appended is how later
/// ops refer to its result.
struct Builder {
    /// Ops emitted so far; an op's position in this vector is its slot id.
    ops: Vec<FbbtOp>,
    /// Maps the address of each shared `Expr::Cse` body (from `Rc::as_ptr`) to the
    /// slot holding its translated value, so a shared subexpression is lowered once.
    /// The keys are raw addresses only; they stay meaningful because a builder is used
    /// for a single `translate_constraint` call, which borrows the whole expression tree.
    cse_cache: HashMap<*const Expr, usize>,
}
42
43impl Builder {
44 fn new() -> Self {
45 Self {
46 ops: Vec::new(),
47 cse_cache: HashMap::new(),
48 }
49 }
50
51 fn emit(&mut self, op: FbbtOp) -> usize {
52 let idx = self.ops.len();
53 self.ops.push(op);
54 idx
55 }
56
57 fn translate(&mut self, expr: &Expr) -> usize {
60 match expr {
61 Expr::Const(v) => self.emit(FbbtOp::Const(*v)),
62 Expr::Var(i) => self.emit(FbbtOp::Var(*i)),
63 Expr::Cse(rc) => {
64 let key = Rc::as_ptr(rc);
65 if let Some(&slot) = self.cse_cache.get(&key) {
66 return slot;
67 }
68 let slot = self.translate(rc.as_ref());
69 self.cse_cache.insert(key, slot);
70 slot
71 }
72 Expr::Binary(op, lhs, rhs) => {
73 let a = self.translate(lhs);
74 let b = self.translate(rhs);
75 match op {
76 BinOp::Add => self.emit(FbbtOp::Add(a, b)),
77 BinOp::Sub => self.emit(FbbtOp::Sub(a, b)),
78 BinOp::Mul => self.emit(FbbtOp::Mul(a, b)),
79 BinOp::Div => self.emit(FbbtOp::Div(a, b)),
80 BinOp::Pow => {
81 let exp_const = const_value(rhs).and_then(integer_exponent);
86 if let Some(n) = exp_const {
87 self.emit(FbbtOp::PowInt(a, n))
88 } else {
89 self.emit(FbbtOp::Opaque)
90 }
91 }
92 }
93 }
94 Expr::Unary(op, x) => {
95 let a = self.translate(x);
96 match op {
97 UnaryOp::Neg => self.emit(FbbtOp::Neg(a)),
98 UnaryOp::Sqrt => self.emit(FbbtOp::Sqrt(a)),
99 UnaryOp::Log => self.emit(FbbtOp::Ln(a)),
100 UnaryOp::Exp => self.emit(FbbtOp::Exp(a)),
101 UnaryOp::Abs => self.emit(FbbtOp::Abs(a)),
102 UnaryOp::Sin => self.emit(FbbtOp::Sin(a)),
103 UnaryOp::Cos => self.emit(FbbtOp::Cos(a)),
104 UnaryOp::Log10 => {
107 let ln = self.emit(FbbtOp::Ln(a));
108 let denom = self.emit(FbbtOp::Const(std::f64::consts::LN_10));
109 self.emit(FbbtOp::Div(ln, denom))
110 }
111 }
112 }
113 Expr::Sum(parts) => {
114 if parts.is_empty() {
116 return self.emit(FbbtOp::Const(0.0));
117 }
118 let mut acc = self.translate(&parts[0]);
119 for p in &parts[1..] {
120 let next = self.translate(p);
121 acc = self.emit(FbbtOp::Add(acc, next));
122 }
123 acc
124 }
125 Expr::Funcall { .. } => {
126 self.emit(FbbtOp::Opaque)
128 }
129 }
130 }
131}
132
133fn const_value(expr: &Expr) -> Option<Number> {
138 match expr {
139 Expr::Const(v) => Some(*v),
140 Expr::Cse(rc) => const_value(rc.as_ref()),
141 _ => None,
142 }
143}
144
145fn integer_exponent(v: Number) -> Option<u32> {
150 if !v.is_finite() {
151 return None;
152 }
153 if v < 0.0 || v > 64.0 {
154 return None;
155 }
156 let rounded = v.round();
157 if (v - rounded).abs() > 1e-9 {
158 return None;
159 }
160 Some(rounded as u32)
161}
162
163pub fn translate_constraint(nonlinear: &Expr, linear: &[(usize, Number)]) -> Option<FbbtTape> {
174 let nonlinear_trivial = matches!(nonlinear, Expr::Const(c) if *c == 0.0);
175 if nonlinear_trivial && linear.is_empty() {
176 return None;
177 }
178
179 let mut b = Builder::new();
180 let mut root = if nonlinear_trivial {
181 None
184 } else {
185 Some(b.translate(nonlinear))
186 };
187
188 for &(var_idx, coef) in linear {
189 let v_slot = b.emit(FbbtOp::Var(var_idx));
190 let c_slot = b.emit(FbbtOp::Const(coef));
191 let term = b.emit(FbbtOp::Mul(v_slot, c_slot));
192 root = Some(match root {
193 None => term,
194 Some(prev) => b.emit(FbbtOp::Add(prev, term)),
195 });
196 }
197
198 debug_assert!(root.is_some());
201 Some(FbbtTape { ops: b.ops })
202}
203
#[cfg(test)]
mod tests {
    //! Unit tests for lowering `Expr` trees and linear parts into `FbbtOp` tapes.
    //!
    //! Several assertions pin exact slot indices. The slot layouts noted in the
    //! comments follow the builder's emission order: operands before the op that
    //! consumes them, left operand before right.
    use super::*;

    #[test]
    fn pure_linear_translates_to_sum_of_terms() {
        let nonlinear = Expr::Const(0.0);
        let linear = vec![(0usize, 3.0), (1usize, -2.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        // Layout: Var(0)@0, Const(3)@1, Mul@2, Var(1)@3, Const(-2)@4, Mul@5, Add(2, 5)@6.
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(2, 5))));
    }

    #[test]
    fn purely_zero_constraint_returns_none() {
        // A `0.0` nonlinear part with no linear terms has nothing to propagate.
        let nonlinear = Expr::Const(0.0);
        assert!(translate_constraint(&nonlinear, &[]).is_none());
    }

    #[test]
    fn unary_translations_cover_all_supported_ops() {
        // Each case lowers to Var(0)@0 followed by the unary op reading slot 0.
        let inner = Box::new(Expr::Var(0));
        let cases = [
            (UnaryOp::Neg, FbbtOp::Neg(0)),
            (UnaryOp::Sqrt, FbbtOp::Sqrt(0)),
            (UnaryOp::Log, FbbtOp::Ln(0)),
            (UnaryOp::Exp, FbbtOp::Exp(0)),
            (UnaryOp::Abs, FbbtOp::Abs(0)),
            (UnaryOp::Sin, FbbtOp::Sin(0)),
            (UnaryOp::Cos, FbbtOp::Cos(0)),
        ];
        for (op, expected) in cases {
            let e = Expr::Unary(op, inner.clone());
            let tape = translate_constraint(&e, &[]).unwrap();
            assert_eq!(tape.ops.last().unwrap(), &expected);
        }
    }

    #[test]
    fn log10_decomposes_into_ln_div() {
        let e = Expr::Unary(UnaryOp::Log10, Box::new(Expr::Var(0)));
        let tape = translate_constraint(&e, &[]).unwrap();
        // Layout: Var(0)@0, Ln(0)@1, Const(LN_10)@2, Div(1, 2)@3.
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Div(1, 2))));
    }

    #[test]
    fn pow_with_const_int_rhs_uses_powint() {
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(3.0)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        // Layout: Var(0)@0, Const(3)@1 (emitted but unused), PowInt(0, 3)@2.
        assert!(matches!(tape.ops.last(), Some(FbbtOp::PowInt(0, 3))));
    }

    #[test]
    fn pow_with_variable_rhs_is_opaque() {
        // A non-constant exponent cannot be mapped to PowInt.
        let e = Expr::Binary(BinOp::Pow, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn pow_with_fractional_const_is_opaque() {
        // 1.5 is not within tolerance of an integer, so the power stays opaque.
        let e = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.5)),
        );
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn cse_shared_body_emitted_once() {
        // body = x0 + 1, referenced twice through two `Expr::Cse` wrappers of the same Rc.
        let body = Rc::new(Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(1.0)),
        ));
        let two_body = Expr::Binary(
            BinOp::Mul,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Const(2.0)),
        );
        let total = Expr::Binary(BinOp::Add, Box::new(two_body), Box::new(Expr::Cse(body)));
        let tape = translate_constraint(&total, &[]).unwrap();
        // The second Cse reference reuses the cached slot, so Var(0) appears only once.
        let n_var0 = tape
            .ops
            .iter()
            .filter(|op| matches!(op, FbbtOp::Var(0)))
            .count();
        assert_eq!(n_var0, 1, "CSE body must be emitted once: {:?}", tape.ops);
    }

    #[test]
    fn sum_node_folds_to_binary_adds() {
        let s = Expr::Sum(vec![Expr::Var(0), Expr::Var(1), Expr::Var(2)]);
        let tape = translate_constraint(&s, &[]).unwrap();
        // Layout: Var(0)@0, Var(1)@1, Add(0, 1)@2, Var(2)@3, Add(2, 3)@4.
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(2, 3))));
    }

    #[test]
    fn empty_sum_folds_to_zero_constant() {
        // An empty `Sum` is not the trivial `Const(0.0)` shortcut, so a tape is still
        // produced; it holds exactly one op, the constant 0.
        let s = Expr::Sum(vec![]);
        let tape = translate_constraint(&s, &[]).unwrap();
        assert_eq!(tape.ops.len(), 1);
        assert!(matches!(tape.ops[0], FbbtOp::Const(c) if c == 0.0));
    }

    #[test]
    fn funcall_collapses_to_opaque() {
        let e = Expr::Funcall {
            id: 0,
            args: vec![],
        };
        let tape = translate_constraint(&e, &[]).unwrap();
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Opaque)));
    }

    #[test]
    fn nonlinear_plus_linear_combines() {
        let nonlinear = Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(0)),
            Box::new(Expr::Const(2.0)),
        );
        let linear = vec![(1usize, 3.0), (2usize, 5.0)];
        let tape = translate_constraint(&nonlinear, &linear).unwrap();
        // The last linear term is chained onto the running total with Add.
        assert!(matches!(tape.ops.last(), Some(FbbtOp::Add(_, _))));
        // NOTE(review): `first_invalid_slot` is defined in pounce_nlp; presumably `None`
        // means every operand slot refers to an earlier op — confirm there.
        assert!(tape.first_invalid_slot().is_none());
    }

    #[test]
    fn translated_tape_is_well_formed() {
        // exp(x0) is shared between both sides of the outer Add via Cse.
        let body = Rc::new(Expr::Unary(UnaryOp::Exp, Box::new(Expr::Var(0))));
        let e = Expr::Binary(
            BinOp::Add,
            Box::new(Expr::Cse(Rc::clone(&body))),
            Box::new(Expr::Binary(
                BinOp::Mul,
                Box::new(Expr::Cse(body)),
                Box::new(Expr::Const(3.0)),
            )),
        );
        let tape = translate_constraint(&e, &[(1, 0.5)]).unwrap();
        assert!(tape.first_invalid_slot().is_none());
    }
}