Skip to main content

rival/eval/
ops.rs

//! Operation registry that generates the evaluation and optimization helpers for each operator.
//! Defines the interval operators together with their dispatch, amplification-bound, and
//! path-reduction hooks.
3use crate::eval::macros::def_ops;
4use crate::eval::adjust::path_reduction;
5use crate::eval::tricks::{AmplBounds, TrickContext, crosses_zero, get_slack};
6use crate::interval::Ival;
7use Expr::*;
8use rug::Float;
9
10def_ops! {
11    constant {
12        Pi: {
13            method: set_pi,
14        },
15
16        E: {
17            method: set_e,
18        },
19    },
20
21    unary {
22        Pow2: {
23            method: pow2_assign,
24            bounds: |ctx, _out, inp| {
25                AmplBounds::new(ctx.logspan(inp) + 1, 0)
26            },
27        },
28        Fabs: {
29            method: fabs_assign,
30            bounds: |_, _, _| AmplBounds::zero(),
31        },
32
33        Neg: {
34            method: neg_assign,
35            bounds: |_, _, _| AmplBounds::zero(),
36        },
37
38        Sqrt: {
39            method: sqrt_assign,
40            bounds: |ctx, _, inp| AmplBounds::new((ctx.logspan(inp) / 2).saturating_sub(1), 0),
41            optimize: |arg| {
42                // sqrt(x^2 + y^2) => hypot(x, y)
43                // sqrt(x^2 + 1) => hypot(x, 1)
44                // sqrt(1 + x^2) => hypot(1, x)
45                match arg {
46                    // TODO: Consider pow(x, 2) pattern in addition to x * x
47                    Add(a, b) => match (&*a, &*b) {
48                        // sqrt(x^2 + y^2)
49                        (Mul(x1, x2), Mul(y1, y2)) if x1 == x2 && y1 == y2 => {
50                            Hypot(x1.clone(), y1.clone())
51                        }
52                        // sqrt(x^2 + 1)
53                        (Mul(x1, x2), Literal(one)) if x1 == x2 && *one == 1.0 => {
54                            Hypot(x1.clone(), Box::new(Literal(one.clone())))
55                        }
56                        // sqrt(1 + x^2)
57                        (Literal(one), Mul(x1, x2)) if x1 == x2 && *one == 1.0 => {
58                            Hypot(Box::new(Literal(one.clone())), x1.clone())
59                        }
60                        _ => Sqrt(Box::new(Add(a, b))),
61                    },
62                    other => Sqrt(Box::new(other)),
63                }
64            },
65        },
66
67        Cbrt: {
68            method: cbrt_assign,
69            bounds: |ctx, _, inp| AmplBounds::new(((2 * ctx.logspan(inp)) / 3).saturating_sub(1), 0),
70        },
71
72        Exp: {
73            method: exp_assign,
74            bounds: |ctx, out, inp| {
75                let upper = ctx.maxlog(inp, false) + ctx.logspan(out);
76                let lower = if ctx.lower_bound_early_stopping {
77                    ctx.minlog(inp, true)
78                } else { 0 };
79                AmplBounds::new(upper, lower)
80            },
81            optimize: |arg| {
82                // exp(log(x)) => x
83                if let Log(x) = arg {
84                    If(
85                        Box::new(Gt(x.clone(), Box::new(Literal(Float::with_val(53, 0.0))))),
86                        x.clone(),
87                        Box::new(Literal(Float::with_val(53, f64::NAN))),
88                    )
89                } else {
90                    Exp(Box::new(arg))
91                }
92            },
93        },
94
95        Exp2: {
96            method: exp2_assign,
97            bounds: |ctx, out, inp| {
98                let upper = ctx.maxlog(inp, false) + ctx.logspan(out);
99                let lower = if ctx.lower_bound_early_stopping {
100                    ctx.minlog(inp, true)
101                } else { 0 };
102                AmplBounds::new(upper, lower)
103            },
104            optimize: |arg| {
105                if let Log2(x) = &arg {
106                    If(
107                        Box::new(Gt(x.clone(), Box::new(Literal(Float::with_val(53, 0.0))))),
108                        x.clone(),
109                        Box::new(Literal(Float::with_val(53, f64::NAN))),
110                    )
111                } else {
112                    Exp2(Box::new(arg))
113                }
114            },
115        },
116
117        Expm1: {
118            method: expm1_assign,
119            bounds: |ctx, out, inp| {
120                let mx = ctx.maxlog(inp, false);
121                let upper = (1 + mx).max(1 + mx - ctx.minlog(out, false));
122                AmplBounds::new(upper, 0)
123            },
124        },
125
126        Log: {
127            method: log_assign,
128            bounds: |ctx, out, inp| {
129                let upper = ctx.logspan(inp) - ctx.minlog(out, false) + 1;
130                let lower = if ctx.lower_bound_early_stopping {
131                    -ctx.maxlog(out, true)
132                } else { 0 };
133                AmplBounds::new(upper, lower)
134            },
135            // TODO: Get rid of these optimize clones
136            optimize: |arg| {
137                // log(exp(x)) => x
138                match arg {
139                    Exp(x) => *x,
140                    // log(1 + x) or log(x + 1) => log1p(x)
141                    Add(a, b) => match (&*a, &*b) {
142                        (Literal(one), x) if *one == 1.0 => Log1p(Box::new(x.clone())),
143                        (x, Literal(one)) if *one == 1.0 => Log1p(Box::new(x.clone())),
144                        _ => Log(Box::new(Add(a, b))),
145                    },
146                    other => Log(Box::new(other)),
147                }
148            },
149        },
150
151        Log2: {
152            method: log2_assign,
153            bounds: |ctx, out, inp| {
154                let upper = ctx.logspan(inp) - ctx.minlog(out, false) + 1;
155                let lower = if ctx.lower_bound_early_stopping {
156                    -ctx.maxlog(out, true)
157                } else { 0 };
158                AmplBounds::new(upper, lower)
159            },
160        },
161
162        Log10: {
163            method: log10_assign,
164            bounds: |ctx, out, inp| {
165                let upper = ctx.logspan(inp) - ctx.minlog(out, false) + 1;
166                let lower = if ctx.lower_bound_early_stopping {
167                    -ctx.maxlog(out, true)
168                } else { 0 };
169                AmplBounds::new(upper, lower)
170            },
171        },
172
173        Log1p: {
174            method: log1p_assign,
175            bounds: |ctx, out, inp| {
176                let upper_base = ctx.maxlog(inp, false) - ctx.minlog(out, false);
177                let lo_neg = inp.lo.as_float().is_sign_negative();
178                let hi_neg = inp.hi.as_float().is_sign_negative();
179                let slack = if lo_neg || hi_neg { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
180                AmplBounds::new(upper_base + slack, 0)
181            },
182        },
183
184        Logb: {
185            method: logb_assign,
186            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
187        },
188
189        Sin: {
190            method: sin_assign,
191            bounds: |ctx, out, inp| {
192                let upper = ctx.maxlog(inp, false) - ctx.minlog(out, false);
193                let lower = if ctx.lower_bound_early_stopping {
194                    if ctx.maxlog(inp, false) >= 1 { -1 - ctx.maxlog(out, true) } else { 0 }
195                } else { 0 };
196                AmplBounds::new(upper, lower)
197            },
198            optimize: |arg| {
199                // sin(PI * x / n) => sinu(2*n, x)
200                // sin(PI * x) => sinu(2, x)
201                // sin(2 * PI * x) => sinu(1, x)
202                match arg {
203                    // sin(PI * (x / n))
204                    Mul(a, b) if matches!(&*a, Pi) => match &*b {
205                        Div(x, n) => if let Literal(nval) = &**n {
206                            let i = nval.to_f64() as u64;
207                            if i as f64 == nval.to_f64() && i > 0 {
208                                return Sinu(2 * i, x.clone());
209                            }
210                            Sin(Box::new(Mul(a, b)))
211                        } else {
212                            Sin(Box::new(Mul(a, b)))
213                        },
214                        // sin(PI * x)
215                        _ => Sinu(2, b.clone()),
216                    },
217                    // sin((x / n) * PI) or sin(x * PI)
218                    Mul(a, b) if matches!(&*b, Pi) => match &*a {
219                        Div(x, n) => if let Literal(nval) = &**n {
220                            let i = nval.to_f64() as u64;
221                            if i as f64 == nval.to_f64() && i > 0 {
222                                return Sinu(2 * i, x.clone());
223                            }
224                            Sin(Box::new(Mul(a, b)))
225                        } else {
226                            Sin(Box::new(Mul(a, b)))
227                        },
228                        // sin(x * PI)
229                        _ => Sinu(2, a.clone()),
230                    },
231                    _ => Sin(Box::new(arg)),
232                }
233            },
234        },
235
236        Cos: {
237            method: cos_assign,
238            bounds: |ctx, out, inp| {
239                let upper = ctx.maxlog(inp, false) - ctx.minlog(out, false)
240                    + ctx.maxlog(inp, false).min(0);
241                let lower = if ctx.lower_bound_early_stopping {
242                    -ctx.maxlog(out, true) - 2
243                } else { 0 };
244                AmplBounds::new(upper, lower)
245            },
246            optimize: |arg| {
247                // cos(PI * x / n) => cosu(2*n, x)
248                // cos(PI * x) => cosu(2, x)
249                // cos(2 * PI * x) => cosu(1, x)
250                match arg {
251                    // cos(PI * (x / n))
252                    Mul(a, b) if matches!(&*a, Pi) => match &*b {
253                        Div(x, n) => if let Literal(nval) = &**n {
254                            let i = nval.to_f64() as u64;
255                            if i as f64 == nval.to_f64() && i > 0 {
256                                return Cosu(2 * i, x.clone());
257                            }
258                            Cos(Box::new(Mul(a, b)))
259                        } else {
260                            Cos(Box::new(Mul(a, b)))
261                        },
262                        // cos(PI * x)
263                        _ => Cosu(2, b.clone()),
264                    },
265                    // cos((x / n) * PI) or cos(x * PI)
266                    Mul(a, b) if matches!(&*b, Pi) => match &*a {
267                        Div(x, n) => if let Literal(nval) = &**n {
268                            let i = nval.to_f64() as u64;
269                            if i as f64 == nval.to_f64() && i > 0 {
270                                return Cosu(2 * i, x.clone());
271                            }
272                            Cos(Box::new(Mul(a, b)))
273                        } else {
274                            Cos(Box::new(Mul(a, b)))
275                        },
276                        // cos(x * PI)
277                        _ => Cosu(2, a.clone()),
278                    },
279                    _ => Cos(Box::new(arg)),
280                }
281            },
282        },
283
284        Tan: {
285            method: tan_assign,
286            bounds: |ctx, out, inp| {
287                let upper = ctx.maxlog(inp, false)
288                    + ctx.maxlog(out, false).abs().max(ctx.minlog(out, false).abs())
289                    + ctx.logspan(out)
290                    + 1;
291                let lower = if ctx.lower_bound_early_stopping {
292                    ctx.minlog(inp, true)
293                        + ctx
294                            .maxlog(out, true)
295                            .abs()
296                            .min(ctx.minlog(out, true).abs())
297                        - 1
298                } else { 0 };
299                AmplBounds::new(upper, lower)
300            },
301            optimize: |arg| {
302                // tan(PI * x / n) => tanu(2 * n, x)
303                // tan(PI * x) => tanu(2, x)
304                // tan(2 * PI * x) => tanu(1, x)
305                match arg {
306                    // tan(PI * (x / n))
307                    Mul(a, b) if matches!(&*a, Pi) => match &*b {
308                        Div(x, n) => if let Literal(nval) = &**n {
309                            let i = nval.to_f64() as u64;
310                            if i as f64 == nval.to_f64() && i > 0 {
311                                return Tanu(2 * i, x.clone());
312                            }
313                            Tan(Box::new(Mul(a, b)))
314                        } else {
315                            Tan(Box::new(Mul(a, b)))
316                        },
317                        // tan(PI * x)
318                        _ => Tanu(2, b.clone()),
319                    },
320                    // tan((x / n) * PI) or tan(x * PI)
321                    Mul(a, b) if matches!(&*b, Pi) => match &*a {
322                        Div(x, n) => if let Literal(nval) = &**n {
323                            let i = nval.to_f64() as u64;
324                            if i as f64 == nval.to_f64() && i > 0 {
325                                return Tanu(2 * i, x.clone());
326                            }
327                            Tan(Box::new(Mul(a, b)))
328                        } else {
329                            Tan(Box::new(Mul(a, b)))
330                        },
331                        // tan(x * PI)
332                        _ => Tanu(2, a.clone()),
333                    },
334                    _ => Tan(Box::new(arg)),
335                }
336            },
337        },
338
339        Asin: {
340            method: asin_assign,
341            bounds: |ctx, out, _| {
342                let upper = if ctx.maxlog(out, false) >= 1 { get_slack(ctx.iteration, ctx.slack_unit) } else { 1 };
343                AmplBounds::new(upper, 0)
344            },
345        },
346
347        Acos: {
348            method: acos_assign,
349            bounds: |ctx, _out, inp| {
350                let upper = if ctx.maxlog(inp, false) >= 0 { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
351                AmplBounds::new(upper, 0)
352            },
353        },
354
355        Atan: {
356            method: atan_assign,
357            bounds: |ctx, out, inp| {
358                let upper = ctx.logspan(inp)
359                    - ctx.minlog(inp, false).abs().min(ctx.maxlog(inp, false).abs())
360                    - ctx.minlog(out, false);
361                let lower = if ctx.lower_bound_early_stopping {
362                    - (ctx.minlog(inp, true).abs().max(ctx.maxlog(inp, true).abs()))
363                        - ctx.maxlog(out, true)
364                        - 2
365                } else { 0 };
366                AmplBounds::new(upper, lower)
367            },
368        },
369
370        Sinh: {
371            method: sinh_assign,
372            bounds: |ctx, out, inp| {
373                let upper = ctx.maxlog(inp, false) + ctx.logspan(out) - ctx.minlog(inp, false).min(0);
374                let lower = if ctx.lower_bound_early_stopping { ctx.minlog(inp, true).max(0) } else { 0 };
375                AmplBounds::new(upper, lower)
376            },
377        },
378
379        Cosh: {
380            method: cosh_assign,
381            bounds: |ctx, out, inp| {
382                let upper = ctx.maxlog(inp, false) + ctx.logspan(out) + ctx.maxlog(inp, false).min(0);
383                let lower = if ctx.lower_bound_early_stopping { (ctx.minlog(inp, true) - 1).max(0) } else { 0 };
384                AmplBounds::new(upper, lower)
385            },
386        },
387
388        Tanh: {
389            method: tanh_assign,
390            bounds: |ctx, out, inp| {
391                let upper = ctx.logspan(out) + ctx.logspan(inp);
392                AmplBounds::new(upper, 0)
393            },
394        },
395
396        Asinh: {
397            method: asinh_assign,
398            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
399        },
400
401        Acosh: {
402            method: acosh_assign,
403            bounds: |ctx, out, _| {
404                let z_exp = ctx.minlog(out, false);
405                let upper = if z_exp < 2 { get_slack(ctx.iteration, ctx.slack_unit) - z_exp } else { 0 };
406                AmplBounds::new(upper, 0)
407            },
408        },
409
410        Atanh: {
411            method: atanh_assign,
412            bounds: |ctx, _out, inp| {
413                let upper = if ctx.maxlog(inp, false) >= 1 { get_slack(ctx.iteration, ctx.slack_unit) } else { 1 };
414                AmplBounds::new(upper, 0)
415            },
416        },
417
418        Erf: {
419            method: erf_assign,
420            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
421        },
422
423        Erfc: {
424            method: erfc_assign,
425            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
426        },
427
428        Rint: {
429            method: rint_assign,
430            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
431        },
432
433        Round: {
434            method: round_assign,
435            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
436        },
437
438        Ceil: {
439            method: ceil_assign,
440            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
441        },
442
443        Floor: {
444            method: floor_assign,
445            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
446        },
447
448        Trunc: {
449            method: trunc_assign,
450            bounds: |ctx, _, _| AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0),
451        },
452
453        Not: {
454            method: not_assign,
455            bounds: |_, _, _| AmplBounds::zero(),
456            path_reduce: path_reduction::not_op_path_reduce,
457        },
458
459        Error: {
460            method: error_assign,
461            bounds: |_, _, _| AmplBounds::zero(),
462        },
463
464        Assert: {
465            method: assert_assign,
466            bounds: |_, _, _| AmplBounds::zero(),
467            path_reduce: path_reduction::assert_op_path_reduce,
468        },
469    },
470
471    unary_param {
472        Cosu: {
473            method: cosu_assign,
474            bounds: |ctx, param, out, inp| {
475                let n_log = param as i64;
476                let upper = ctx.maxlog(inp, false) - n_log - ctx.minlog(out, false) + 2;
477                let lower = 0;
478                AmplBounds::new(upper, lower)
479            },
480        },
481
482        Sinu: {
483            method: sinu_assign,
484            bounds: |ctx, param, out, inp| {
485                let n_log = param as i64;
486                let upper = ctx.maxlog(inp, false) - n_log - ctx.minlog(out, false) + 2;
487                let lower = 0;
488                AmplBounds::new(upper, lower)
489            },
490        },
491
492        Tanu: {
493            method: tanu_assign,
494            bounds: |ctx, param, out, inp| {
495                let n_log = param as i64;
496                let upper = ctx.maxlog(inp, false) - n_log
497                    + ctx.maxlog(out, false).abs().max(ctx.minlog(out, false).abs()) + 3;
498                let lower = 0;
499                AmplBounds::new(upper, lower)
500            },
501        },
502    },
503
504    binary {
505        Pow: {
506            method: pow_assign,
507            bounds: |ctx, out, x, y| {
508                let maxlog_y = ctx.maxlog(y, false);
509                let minlog_y_less = ctx.minlog(y, true);
510                let logspan_x = ctx.logspan(x);
511                let logspan_out = ctx.logspan(out);
512                let maxlog_x = ctx.maxlog(x, false);
513                let minlog_x = ctx.minlog(x, false);
514                // Slack adjustments
515                let y_slack = if crosses_zero(out) && x.lo.as_float().is_sign_negative() { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
516                let x_slack = if out.lo.as_float().is_zero() { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
517
518                // Upper bounds
519                let upper_x_base = maxlog_y + logspan_x + logspan_out + x_slack;
520                let upper_x = upper_x_base.max(x_slack);
521                let abs_maxlog_x = maxlog_x.abs();
522                let abs_minlog_x = minlog_x.abs();
523                let span_x_mag = abs_maxlog_x.max(abs_minlog_x);
524                let upper_y_base = maxlog_y + span_x_mag + logspan_out + y_slack;
525                let upper_y = upper_y_base.max(y_slack);
526
527                // Lower bounds
528                let lower_x = if ctx.lower_bound_early_stopping { minlog_y_less } else { 0 };
529                let min_abs_span = abs_maxlog_x.min(abs_minlog_x);
530                let lower_y = if ctx.lower_bound_early_stopping {
531                    if min_abs_span == 0 { 0 } else { minlog_y_less }
532                } else { 0 };
533
534                (AmplBounds::new(upper_x, lower_x), AmplBounds::new(upper_y, lower_y))
535            },
536            optimize: |base, exp| {
537                if let Literal(exp_val) = &exp {
538                    // pow(arg, 2) => pow2(arg)
539                    if (exp_val.to_f64() - 2.0).abs() == 0.0 {
540                        return Pow2(Box::new(base));
541                    }
542                    // pow(arg, 0.5) => sqrt(arg)
543                    if (exp_val.to_f64() - 0.5).abs() == 0.0 {
544                        return Sqrt(Box::new(base));
545                    }
546                }
547
548                // pow(x, p/q) optimizations
549                if let Rational(rat) = &exp {
550                    let num = rat.numer();
551                    let den = rat.denom();
552                    if *den == 1 {
553                        return Pow(Box::new(base), Box::new(Rational(rat.clone())));
554                    }
555                    let den_odd = den.is_odd();
556                    let num_odd = num.is_odd();
557                    if den_odd && !num_odd {
558                        return Pow(Box::new(Fabs(Box::new(base))), Box::new(Rational(rat.clone())));
559                    }
560                    if den_odd && num_odd {
561                        return Copysign(
562                            Box::new(Pow(Box::new(Fabs(Box::new(base.clone()))), Box::new(Rational(rat.clone())))),
563                            Box::new(base),
564                        );
565                    }
566                }
567
568                match base {
569                    // pow(10, log10(x)) => x
570                    Literal(base_val) if (base_val.to_f64() - 10.0).abs() == 0.0 => {
571                        match exp {
572                            Log10(x) => If(
573                                Box::new(Gt(x.clone(), Box::new(Literal(Float::with_val(53, 0.0))))),
574                                x,
575                                Box::new(Literal(Float::with_val(53, f64::NAN))),
576                            ),
577                            _ => Pow(Box::new(Literal(base_val.clone())), Box::new(exp)),
578                        }
579                    }
580                    // pow(2, arg) => exp2(arg)
581                    Literal(base_val) if (base_val.to_f64() - 2.0).abs() == 0.0 => {
582                        Exp2(Box::new(exp))
583                    }
584                    // pow(E, arg) => exp(arg)
585                    E => Exp(Box::new(exp)),
586                    // pow(fabs(x), y) stays as is (already optimal for handling negative bases)
587                    Fabs(_) => Pow(Box::new(base), Box::new(exp)),
588                    _ => Pow(Box::new(base), Box::new(exp)),
589                }
590            },
591        },
592
593        Fdim: {
594            method: fdim_assign,
595            bounds: |ctx, out, x, y| {
596                let output_min = ctx.minlog(out, false);
597                let lhs_upper = ctx.maxlog(x, false) - output_min;
598                let rhs_upper = ctx.maxlog(y, false) - output_min;
599                let lhs_lower = if ctx.lower_bound_early_stopping { ctx.minlog(x, true) - ctx.maxlog(out, true) } else { 0 };
600                let rhs_lower = if ctx.lower_bound_early_stopping { ctx.minlog(y, true) - ctx.maxlog(out, true) } else { 0 };
601                (AmplBounds::new(lhs_upper, lhs_lower), AmplBounds::new(rhs_upper, rhs_lower))
602            },
603        },
604
605        Hypot: {
606            method: hypot_assign,
607            bounds: |ctx, _, _, _| {
608                let bounds = AmplBounds::new(get_slack(ctx.iteration, ctx.slack_unit), 0);
609                (bounds, bounds)
610            },
611        },
612        Add: {
613            method: add_assign,
614            bounds: |ctx, out, lhs, rhs| {
615                let output_min = ctx.minlog(out, false);
616                let lhs_upper = ctx.maxlog(lhs, false) - output_min;
617                let rhs_upper = ctx.maxlog(rhs, false) - output_min;
618                let lhs_lower = if ctx.lower_bound_early_stopping {
619                    ctx.minlog(lhs, true) - ctx.maxlog(out, true)
620                } else { 0 };
621                let rhs_lower = if ctx.lower_bound_early_stopping {
622                    ctx.minlog(rhs, true) - ctx.maxlog(out, true)
623                } else { 0 };
624                (AmplBounds::new(lhs_upper, lhs_lower), AmplBounds::new(rhs_upper, rhs_lower))
625            },
626        },
627
628        Sub: {
629            method: sub_assign,
630            bounds: |ctx, out, lhs, rhs| {
631                let output_min = ctx.minlog(out, false);
632                let lhs_upper = ctx.maxlog(lhs, false) - output_min;
633                let rhs_upper = ctx.maxlog(rhs, false) - output_min;
634                let lhs_lower = if ctx.lower_bound_early_stopping {
635                    ctx.minlog(lhs, true) - ctx.maxlog(out, true)
636                } else { 0 };
637                let rhs_lower = if ctx.lower_bound_early_stopping {
638                    ctx.minlog(rhs, true) - ctx.maxlog(out, true)
639                } else { 0 };
640                (AmplBounds::new(lhs_upper, lhs_lower), AmplBounds::new(rhs_upper, rhs_lower))
641            },
642            optimize: |lhs, rhs| {
643                match (&lhs, &rhs) {
644                    // (- (exp x) 1) => expm1(x)
645                    (Exp(x), Literal(one)) if one == &1.0 => {
646                        Expm1(x.clone())
647                    }
648                    // (- 1 (exp x)) => neg(expm1(x))
649                    (Literal(one), Exp(x)) if one == &1.0 => {
650                        Neg(Box::new(Expm1(x.clone())))
651                    }
652                    // (- 1 (erf x)) => erfc(x)
653                    (Literal(one), Erf(x)) if *one == 1.0 => {
654                        Erfc(x.clone())
655                    }
656                    // (- (erf x) 1) => neg(erfc(x))
657                    (Erf(x), Literal(one)) if *one == 1.0 => {
658                        Neg(Box::new(Erfc(x.clone())))
659                    }
660                    _ => Sub(Box::new(lhs), Box::new(rhs))
661                }
662            },
663        },
664
665        Mul: {
666            method: mul_assign,
667            bounds: |ctx, _, lhs, rhs| {
668                (AmplBounds::new(ctx.logspan(rhs), 0),
669                 AmplBounds::new(ctx.logspan(lhs), 0))
670            },
671        },
672
673        Div: {
674            method: div_assign,
675            bounds: |ctx, _, lhs, rhs| {
676                let lhs_bounds = AmplBounds::new(ctx.logspan(rhs), 0);
677                let rhs_bounds = AmplBounds::new(ctx.logspan(lhs) + 2 * ctx.logspan(rhs), 0);
678                (lhs_bounds, rhs_bounds)
679            },
680        },
681
682        And: {
683            method: and_assign,
684            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
685            path_reduce: path_reduction::bool_op_path_reduce,
686        },
687
688        Or: {
689            method: or_assign,
690            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
691            path_reduce: path_reduction::bool_op_path_reduce,
692        },
693
694        Eq: {
695            method: eq_assign,
696            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
697            path_reduce: path_reduction::bool_op_path_reduce,
698        },
699
700        Ne: {
701            method: ne_assign,
702            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
703            path_reduce: path_reduction::bool_op_path_reduce,
704        },
705
706        Lt: {
707            method: lt_assign,
708            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
709            path_reduce: path_reduction::bool_op_path_reduce,
710        },
711
712        Le: {
713            method: le_assign,
714            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
715            path_reduce: path_reduction::bool_op_path_reduce,
716        },
717
718        Gt: {
719            method: gt_assign,
720            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
721            path_reduce: path_reduction::bool_op_path_reduce,
722        },
723
724        Ge: {
725            method: ge_assign,
726            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
727            path_reduce: path_reduction::bool_op_path_reduce,
728        },
729
730        Fmin: {
731            method: fmin_assign,
732            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
733            path_reduce: |machine, idx, mark| {
734                path_reduction::minmax_path_reduce(machine, idx, mark, false)
735            },
736        },
737
738        Fmax: {
739            method: fmax_assign,
740            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
741            path_reduce: |machine, idx, mark| {
742                path_reduction::minmax_path_reduce(machine, idx, mark, true)
743            },
744        },
745
746        Copysign: {
747            method: copysign_assign,
748            bounds: |_, _, _, _| (AmplBounds::zero(), AmplBounds::zero()),
749        },
750
751        Atan2: {
752            method: atan2_assign,
753            bounds: |ctx, out, y, x| {
754                let upper = ctx.maxlog(x, false) + ctx.maxlog(y, false)
755                    - 2 * ctx.minlog(x, false).min(ctx.minlog(y, false))
756                    - ctx.minlog(out, false);
757                let lower = if ctx.lower_bound_early_stopping {
758                    ctx.minlog(x, true) + ctx.minlog(y, true)
759                        - 2 * ctx.maxlog(x, true).max(ctx.maxlog(y, true))
760                        - ctx.maxlog(out, true)
761                } else { 0 };
762                (AmplBounds::new(upper, lower), AmplBounds::new(upper, lower))
763            },
764        },
765
766        Fmod: {
767            method: fmod_assign,
768            bounds: |ctx, out, x, y| {
769                let slack = if crosses_zero(y) { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
770                let upper_x = ctx.maxlog(x, false) - ctx.minlog(out, false);
771                let upper_y = upper_x + slack;
772                (AmplBounds::new(upper_x, 0), AmplBounds::new(upper_y, 0))
773            },
774        },
775
776        Remainder: {
777            method: remainder_assign,
778            bounds: |ctx, out, x, y| {
779                let slack = if crosses_zero(y) { get_slack(ctx.iteration, ctx.slack_unit) } else { 0 };
780                let upper_x = ctx.maxlog(x, false) - ctx.minlog(out, false);
781                let upper_y = upper_x + slack;
782                (AmplBounds::new(upper_x, 0), AmplBounds::new(upper_y, 0))
783            },
784        },
785    },
786
787    ternary {
788        Fma: {
789            method: fma_assign,
790            bounds: |ctx, out, a, b, _c| {
791                (AmplBounds::new(ctx.logspan(b) + ctx.logspan(out), 0),
792                 AmplBounds::new(ctx.logspan(a) + ctx.logspan(out), 0),
793                 AmplBounds::new(ctx.logspan(out), 0))
794            },
795            optimize: |x, y, z| {
796                // fma(x, y, z) => x * y + z
797                Add(Box::new(Mul(Box::new(x), Box::new(y))), Box::new(z))
798            },
799        },
800
801        If: {
802            method: if_assign,
803            bounds: |_, _, _, _, _| (AmplBounds::zero(), AmplBounds::zero(), AmplBounds::zero()),
804            path_reduce: path_reduction::if_op_path_reduce,
805        },
806    },
807}